The following are 15 code examples, extracted from open source Python projects, that show how to use torch.addbmm().
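For reference, torch.addbmm() performs a batch of matrix multiplications, sums the results over the batch dimension, and adds a matrix: out = beta * input + alpha * sum_i batch1[i] @ batch2[i]. A minimal sketch using the current keyword-argument signature (several examples below use the older positional beta/alpha form, which has since been deprecated):

import torch

M = torch.randn(3, 5)            # matrix to add, shape (n, p)
batch1 = torch.randn(10, 3, 4)   # shape (b, n, m)
batch2 = torch.randn(10, 4, 5)   # shape (b, m, p)

out = torch.addbmm(M, batch1, batch2, beta=1.0, alpha=0.5)

# Equivalent to a bmm followed by a sum over the batch dimension
assert torch.allclose(out, M + 0.5 * torch.bmm(batch1, batch2).sum(dim=0))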
def forward(self, add_matrix, batch1, batch2):
    self.save_for_backward(batch1, batch2)
    output = self._get_output(add_matrix)
    # Legacy call form (out, v1, M, v2, batch1, batch2):
    # out = v1 * M + v2 * sum_i batch1[i] @ batch2[i]
    return torch.addbmm(output, self.alpha, add_matrix, self.beta,
                        batch1, batch2)
def test_addbmm(self):
    # num_batches = 10
    # M, N, O = 12, 8, 5
    num_batches = 2
    M, N, O = 2, 3, 4
    b1 = torch.randn(num_batches, M, N)
    b2 = torch.randn(num_batches, N, O)
    res = torch.bmm(b1, b2)
    res2 = torch.Tensor().resize_as_(res[0]).zero_()

    # sum(0) kept the reduced dimension in this PyTorch version, hence the [0]
    res2.addbmm_(b1, b2)
    self.assertEqual(res2, res.sum(0)[0])

    res2.addbmm_(1, b1, b2)
    self.assertEqual(res2, res.sum(0)[0] * 2)

    res2.addbmm_(1., .5, b1, b2)
    self.assertEqual(res2, res.sum(0)[0] * 2.5)

    res3 = torch.addbmm(1, res2, 0, b1, b2)
    self.assertEqual(res3, res2)

    res4 = torch.addbmm(1, res2, .5, b1, b2)
    self.assertEqual(res4, res.sum(0)[0] * 3)

    res5 = torch.addbmm(0, res2, 1, b1, b2)
    self.assertEqual(res5, res.sum(0)[0])

    res6 = torch.addbmm(.1, res2, .5, b1, b2)
    self.assertEqual(res6, res2 * .1 + res.sum(0) * .5)
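The in-place variant Tensor.addbmm_() used above accumulates into the tensor itself. A minimal sketch with the current keyword form:

import torch

acc = torch.zeros(3, 5)
b1 = torch.randn(2, 3, 4)
b2 = torch.randn(2, 4, 5)

acc.addbmm_(b1, b2)                       # acc = sum_i b1[i] @ b2[i]
acc.addbmm_(b1, b2, beta=1.0, alpha=0.5)  # acc = acc + 0.5 * sum_i b1[i] @ b2[i]

assert torch.allclose(acc, 1.5 * torch.bmm(b1, b2).sum(0))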
def forward(ctx, add_matrix, batch1, batch2, alpha=1, beta=1, inplace=False):
    ctx.alpha = alpha
    ctx.beta = beta
    ctx.save_for_backward(batch1, batch2)
    output = _get_output(ctx, add_matrix, inplace=inplace)
    # Deprecated positional form: the first scalar scales add_matrix,
    # the second scales the summed batch product
    return torch.addbmm(alpha, add_matrix, beta, batch1, batch2, out=output)
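The backward pass is not shown in these snippets. A hypothetical reconstruction on the current torch.autograd.Function API, assuming out = alpha * add_matrix + beta * sum_i batch1[i] @ batch2[i] as in the positional call above (note this maps the snippet's alpha to the modern keyword beta and vice versa):

import torch

class Addbmm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, add_matrix, batch1, batch2, alpha=1, beta=1):
        ctx.alpha, ctx.beta = alpha, beta
        ctx.save_for_backward(batch1, batch2)
        # The snippet's alpha scales add_matrix (keyword beta in the modern
        # API); its beta scales the batch product (keyword alpha)
        return torch.addbmm(add_matrix, batch1, batch2, beta=alpha, alpha=beta)

    @staticmethod
    def backward(ctx, grad_output):
        batch1, batch2 = ctx.saved_tensors
        grad_add = grad_b1 = grad_b2 = None
        # Every batch slice receives the same (scaled) output gradient
        expanded = grad_output.unsqueeze(0).expand(batch1.size(0),
                                                   *grad_output.shape)
        if ctx.needs_input_grad[0]:
            grad_add = ctx.alpha * grad_output
        if ctx.needs_input_grad[1]:
            grad_b1 = ctx.beta * expanded.bmm(batch2.transpose(1, 2))
        if ctx.needs_input_grad[2]:
            grad_b2 = ctx.beta * batch1.transpose(1, 2).bmm(expanded)
        return grad_add, grad_b1, grad_b2, None, None

Function.apply() takes positional arguments only, e.g. Addbmm.apply(M, b1, b2, 1.0, 0.5).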
def test_functional_blas(self):
    def compare(fn, *args):
        unpacked_args = tuple(arg.data if isinstance(arg, Variable) else arg
                              for arg in args)
        self.assertEqual(fn(*args).data, fn(*unpacked_args))

    def test_blas_add(fn, x, y, z):
        # Checks all signatures
        compare(fn, x, y, z)
        compare(fn, 0.5, x, y, z)
        compare(fn, 0.5, x, 0.25, y, z)

    def test_blas(fn, x, y):
        compare(fn, x, y)

    test_blas(torch.mm, Variable(torch.randn(2, 10)),
              Variable(torch.randn(10, 4)))
    test_blas_add(torch.addmm, Variable(torch.randn(2, 4)),
                  Variable(torch.randn(2, 10)), Variable(torch.randn(10, 4)))
    test_blas(torch.bmm, Variable(torch.randn(4, 2, 10)),
              Variable(torch.randn(4, 10, 4)))
    test_blas_add(torch.addbmm, Variable(torch.randn(2, 4)),
                  Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4)))
    test_blas_add(torch.baddbmm, Variable(torch.randn(4, 2, 4)),
                  Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4)))
    test_blas(torch.mv, Variable(torch.randn(2, 10)),
              Variable(torch.randn(10)))
    test_blas_add(torch.addmv, Variable(torch.randn(2)),
                  Variable(torch.randn(2, 10)), Variable(torch.randn(10)))
    test_blas(torch.ger, Variable(torch.randn(5)),
              Variable(torch.randn(6)))
    test_blas_add(torch.addr, Variable(torch.randn(5, 6)),
                  Variable(torch.randn(5)), Variable(torch.randn(6)))
def test_addbmm(self):
    # num_batches = 10
    # M, N, O = 12, 8, 5
    num_batches = 2
    M, N, O = 2, 3, 4
    b1 = torch.randn(num_batches, M, N)
    b2 = torch.randn(num_batches, N, O)
    res = torch.bmm(b1, b2)
    res2 = torch.Tensor().resize_as_(res[0]).zero_()

    res2.addbmm_(b1, b2)
    self.assertEqual(res2, res.sum(0)[0])

    res2.addbmm_(1, b1, b2)
    self.assertEqual(res2, res.sum(0)[0] * 2)

    res2.addbmm_(1., .5, b1, b2)
    self.assertEqual(res2, res.sum(0)[0] * 2.5)

    res3 = torch.addbmm(1, res2, 0, b1, b2)
    self.assertEqual(res3, res2)

    res4 = torch.addbmm(1, res2, .5, b1, b2)
    self.assertEqual(res4, res.sum(0)[0] * 3)

    res5 = torch.addbmm(0, res2, 1, b1, b2)
    self.assertEqual(res5, res.sum(0)[0])

    res6 = torch.addbmm(.1, res2, .5, b1, b2)
    self.assertEqual(res6, res2 * .1 + res.sum(0) * .5)
def test_addbmm(self):
    # num_batches = 10
    # M, N, O = 12, 8, 5
    num_batches = 2
    M, N, O = 2, 3, 4
    b1 = torch.randn(num_batches, M, N)
    b2 = torch.randn(num_batches, N, O)
    res = torch.bmm(b1, b2)
    res2 = torch.Tensor().resize_as_(res[0]).zero_()

    # sum(0, False) sums out the batch dimension with keepdim=False
    res2.addbmm_(b1, b2)
    self.assertEqual(res2, res.sum(0, False))

    res2.addbmm_(1, b1, b2)
    self.assertEqual(res2, res.sum(0, False) * 2)

    res2.addbmm_(1., .5, b1, b2)
    self.assertEqual(res2, res.sum(0, False) * 2.5)

    res3 = torch.addbmm(1, res2, 0, b1, b2)
    self.assertEqual(res3, res2)

    res4 = torch.addbmm(1, res2, .5, b1, b2)
    self.assertEqual(res4, res.sum(0, False) * 3)

    res5 = torch.addbmm(0, res2, 1, b1, b2)
    self.assertEqual(res5, res.sum(0, False))

    res6 = torch.addbmm(.1, res2, .5, b1, b2)
    self.assertEqual(res6, res2 * .1 + res.sum(0) * .5)
def forward(ctx, add_matrix, batch1, batch2, alpha=1, beta=1, inplace=False):
    ctx.alpha = alpha
    ctx.beta = beta
    # Remember the original size so backward can reduce a broadcast gradient
    ctx.add_matrix_size = add_matrix.size()
    ctx.save_for_backward(batch1, batch2)
    output = _get_output(ctx, add_matrix, inplace=inplace)
    return torch.addbmm(alpha, add_matrix, beta, batch1, batch2, out=output)
def test_addbmm(self):
    # num_batches = 10
    # M, N, O = 12, 8, 5
    num_batches = 2
    M, N, O = 2, 3, 4
    b1 = torch.randn(num_batches, M, N)
    b2 = torch.randn(num_batches, N, O)
    res = torch.bmm(b1, b2)
    res2 = torch.Tensor().resize_as_(res[0]).zero_()

    res2.addbmm_(b1, b2)
    self.assertEqual(res2, res.sum(0, False))

    res2.addbmm_(1, b1, b2)
    self.assertEqual(res2, res.sum(0, False) * 2)

    res2.addbmm_(1., .5, b1, b2)
    self.assertEqual(res2, res.sum(0, False) * 2.5)

    res3 = torch.addbmm(1, res2, 0, b1, b2)
    self.assertEqual(res3, res2)

    res4 = torch.addbmm(1, res2, .5, b1, b2)
    self.assertEqual(res4, res.sum(0, False) * 3)

    res5 = torch.addbmm(0, res2, 1, b1, b2)
    self.assertEqual(res5, res.sum(0, False))

    res6 = torch.addbmm(.1, res2, .5, b1, b2)
    self.assertEqual(res6, res2 * .1 + (res.sum(0) * .5))
def _test_broadcast_fused_matmul(self, cast):
    fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"]

    for fn in fns:
        batch_dim = random.randint(1, 8)
        n_dim = random.randint(1, 8)
        m_dim = random.randint(1, 8)
        p_dim = random.randint(1, 8)

        def dims_full_for_fn():
            if fn == "baddbmm":
                return ([batch_dim, n_dim, p_dim],
                        [batch_dim, n_dim, m_dim],
                        [batch_dim, m_dim, p_dim])
            elif fn == "addbmm":
                return ([n_dim, p_dim],
                        [batch_dim, n_dim, m_dim],
                        [batch_dim, m_dim, p_dim])
            elif fn == "addmm":
                return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim])
            elif fn == "addmv":
                return ([n_dim], [n_dim, m_dim], [m_dim])
            elif fn == "addr":
                return ([n_dim, m_dim], [n_dim], [m_dim])
            else:
                raise AssertionError("unknown function")

        (t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn()
        (t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full)

        t0_small = cast(torch.randn(*t0_dims_small).float())
        t1 = cast(torch.randn(*t1_dims).float())
        t2 = cast(torch.randn(*t2_dims).float())

        t0_full = cast(t0_small.expand(*t0_dims_full))

        fntorch = getattr(torch, fn)
        r0 = fntorch(t0_small, t1, t2)
        r1 = fntorch(t0_full, t1, t2)
        self.assertEqual(r0, r1)
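The test above checks that the matrix argument of the fused functions broadcasts against the result shape; a minimal sketch of that behavior for torch.addbmm():

import torch

b1 = torch.randn(4, 2, 3)
b2 = torch.randn(4, 3, 5)

row = torch.randn(1, 5)   # broadcastable against the (2, 5) result
full = row.expand(2, 5)   # explicitly expanded copy

# Broadcasting the added matrix matches expanding it up front
assert torch.allclose(torch.addbmm(row, b1, b2), torch.addbmm(full, b1, b2))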
def test_functional_blas(self):
    def compare(fn, *args):
        unpacked_args = tuple(arg.data if isinstance(arg, Variable) else arg
                              for arg in args)
        unpacked_result = fn(*unpacked_args)
        packed_result = fn(*args).data
        # if non-Variable torch function returns a scalar, compare to scalar
        if not torch.is_tensor(unpacked_result):
            assert packed_result.dim() == 1
            assert packed_result.nelement() == 1
            packed_result = packed_result[0]
        self.assertEqual(packed_result, unpacked_result)

    def test_blas_add(fn, x, y, z):
        # Checks all signatures
        compare(fn, x, y, z)
        compare(fn, 0.5, x, y, z)
        compare(fn, 0.5, x, 0.25, y, z)

    def test_blas(fn, x, y):
        compare(fn, x, y)

    test_blas(torch.mm, Variable(torch.randn(2, 10)),
              Variable(torch.randn(10, 4)))
    test_blas_add(torch.addmm, Variable(torch.randn(2, 4)),
                  Variable(torch.randn(2, 10)), Variable(torch.randn(10, 4)))
    test_blas(torch.bmm, Variable(torch.randn(4, 2, 10)),
              Variable(torch.randn(4, 10, 4)))
    test_blas_add(torch.addbmm, Variable(torch.randn(2, 4)),
                  Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4)))
    test_blas_add(torch.baddbmm, Variable(torch.randn(4, 2, 4)),
                  Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4)))
    test_blas(torch.mv, Variable(torch.randn(2, 10)),
              Variable(torch.randn(10)))
    test_blas_add(torch.addmv, Variable(torch.randn(2)),
                  Variable(torch.randn(2, 10)), Variable(torch.randn(10)))
    test_blas(torch.ger, Variable(torch.randn(5)),
              Variable(torch.randn(6)))
    test_blas_add(torch.addr, Variable(torch.randn(5, 6)),
                  Variable(torch.randn(5)), Variable(torch.randn(6)))
    test_blas(torch.matmul, Variable(torch.randn(6)), Variable(torch.randn(6)))
    test_blas(torch.matmul, Variable(torch.randn(10, 4)), Variable(torch.randn(4)))
    test_blas(torch.matmul, Variable(torch.randn(5)), Variable(torch.randn(5, 6)))
    test_blas(torch.matmul, Variable(torch.randn(2, 10)), Variable(torch.randn(10, 4)))
    test_blas(torch.matmul, Variable(torch.randn(5, 2, 10)),
              Variable(torch.randn(5, 10, 4)))
    test_blas(torch.matmul, Variable(torch.randn(3, 5, 2, 10)),
              Variable(torch.randn(3, 5, 10, 4)))
    test_blas(torch.matmul, Variable(torch.randn(3, 5, 2, 10)),
              Variable(torch.randn(10)))
    test_blas(torch.matmul, Variable(torch.randn(10)),
              Variable(torch.randn(3, 5, 10, 4)))