我们从Python开源项目中提取了以下8个代码示例，用于说明如何使用torch.round()。
def visual(self, input_ts, target_ts, mask_ts, output_ts=None):
    """Print a readable dump of batch element 0 of one sample.

    Every tensor is rendered with ``self._readable``: one text line per
    index ``i`` along dimension 2, each built from ``ts[:, 0, i]``.

    Args:
        input_ts: [(num_wordsx2+2) x batch_size x (len_word+2)]
        target_ts: [(num_wordsx2+2) x batch_size x (len_word)]
        mask_ts: [(num_wordsx2+2) x batch_size x (len_word)]
        output_ts: [(num_wordsx2+2) x batch_size x (len_word)], optional.
            When given, it is multiplied by ``mask_ts`` and passed through
            ``torch.round`` before rendering. When ``None`` the Output
            section is skipped entirely.

    Returns:
        None. The sections are written to stdout with ``print``.
    """
    def render(title, lines):
        # Header line followed by one line per rendered slice.
        return title + ':\n' + '\n'.join(lines)

    def slices(ts):
        # Batch element 0, one readable string per index along dim 2.
        return [self._readable(ts[:, 0, i]) for i in range(ts.size(2))]

    # Build every section before printing anything, so a failure while
    # rendering leaves stdout untouched (same as the original ordering).
    sections = [
        render('Input', slices(input_ts)),
        render('Target', slices(target_ts)),
        # Only column 0 of the mask is shown; presumably the mask is
        # constant along dim 2 -- TODO confirm.
        render('Mask', [self._readable(mask_ts[:, 0, 0])]),
    ]
    if output_ts is not None:
        # Zero out masked positions, then snap to the nearest integer.
        rounded = torch.round(output_ts * mask_ts)
        sections.append(render('Output', slices(rounded)))

    for section in sections:
        print(section)
def test_round(self): self._testMath(torch.round, round)
def visual(self, input_ts, target_ts, mask_ts, output_ts=None):
    """Print a readable dump of batch element 0 of one sample.

    ``input_ts`` is first passed through ``self._unnormalize_repeats``.
    Every tensor is then rendered with ``self._readable``: one text line
    per index ``i`` along dimension 2, each built from ``ts[:, 0, i]``.

    Args:
        input_ts: [(num_wordsx(repeats+1)+3) x batch_size x (len_word+2)]
        target_ts: [(num_wordsx(repeats+1)+3) x batch_size x (len_word+1)]
        mask_ts: [(num_wordsx(repeats+1)+3) x batch_size x (len_word+1)]
        output_ts: [(num_wordsx(repeats+1)+3) x batch_size x (len_word+1)],
            optional. When given, it is multiplied by ``mask_ts`` and passed
            through ``torch.round`` before rendering. When ``None`` the
            Output section is skipped entirely.

    Returns:
        None. The sections are written to stdout with ``print``.
    """
    def render(title, lines):
        # Header line followed by one line per rendered slice.
        return title + ':\n' + '\n'.join(lines)

    def slices(ts):
        # Batch element 0, one readable string per index along dim 2.
        return [self._readable(ts[:, 0, i]) for i in range(ts.size(2))]

    # Undo the repeat-count normalisation so the input reads as authored.
    input_ts = self._unnormalize_repeats(input_ts)

    # Build every section before printing anything, so a failure while
    # rendering leaves stdout untouched (same as the original ordering).
    sections = [
        render('Input', slices(input_ts)),
        render('Target', slices(target_ts)),
        # Only column 0 of the mask is shown; presumably the mask is
        # constant along dim 2 -- TODO confirm.
        render('Mask', [self._readable(mask_ts[:, 0, 0])]),
    ]
    if output_ts is not None:
        # Zero out masked positions, then snap to the nearest integer.
        rounded = torch.round(output_ts * mask_ts)
        sections.append(render('Output', slices(rounded)))

    for section in sections:
        print(section)
def round(x):
    """Apply ``torch.round`` to ``x`` through the wrapper built by ``get_op``.

    ``get_op`` is defined elsewhere; presumably it turns the given callable
    into an op that can be applied to ``x`` -- TODO confirm its contract.

    NOTE: defined at module scope, this name shadows the builtin ``round``.
    """
    rounding_op = get_op(lambda value: torch.round(value))
    return rounding_op(x)