我们从Python开源项目中,提取了以下9个代码示例,用于说明如何使用torch.unbind()。
def transformImage(opt,image,pMtrx): refMtrx = util.toTorch(opt.refMtrx) refMtrx = refMtrx.repeat(opt.batchSize,1,1) transMtrx = refMtrx.matmul(pMtrx) # warp the canonical coordinates X,Y = np.meshgrid(np.linspace(-1,1,opt.W),np.linspace(-1,1,opt.H)) X,Y = X.flatten(),Y.flatten() XYhom = np.stack([X,Y,np.ones_like(X)],axis=1).T XYhom = np.tile(XYhom,[opt.batchSize,1,1]).astype(np.float32) XYhom = util.toTorch(XYhom) XYwarpHom = transMtrx.matmul(XYhom) XwarpHom,YwarpHom,ZwarpHom = torch.unbind(XYwarpHom,dim=1) Xwarp = (XwarpHom/(ZwarpHom+1e-8)).view(opt.batchSize,opt.H,opt.W) Ywarp = (YwarpHom/(ZwarpHom+1e-8)).view(opt.batchSize,opt.H,opt.W) grid = torch.stack([Xwarp,Ywarp],dim=-1) # sampling with bilinear interpolation imageWarp = torch.nn.functional.grid_sample(image,grid,mode="bilinear") return imageWarp
def test_unbind(self): x = torch.rand(2, 3, 4, 5) for dim in range(4): res = torch.unbind(x, dim) self.assertEqual(x.size(dim), len(res)) for i in range(dim): self.assertEqual(x.select(dim, i), res[i])
def parallel(layer): """ Creates a parallel operation (i.e., map/distributed operation). """ def func(module, x): """ The actual wrapped operation. """ return torch.stack( tuple(Layer.resolve(layer)(module, X) for X in torch.unbind(x, 0)), 0 ) func.pure = True return func ### EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF.EOF
def parallel(layer): """ Creates a parallel operation (i.e., map/distributed operation). """ def func(module, x): """ The actual wrapped operation. """ return torch.stack( tuple(Layer.resolve(layer)(module, X) for X in torch.unbind(x, 0)), 0 ) func.pure = True return func ###############################################################################
def vec2mtrx(opt,p): O = util.toTorch(np.zeros([opt.batchSize],dtype=np.float32)) I = util.toTorch(np.ones([opt.batchSize],dtype=np.float32)) if opt.warpType=="translation": tx,ty = torch.unbind(p,dim=1) pMtrx = torch.stack([torch.stack([I,O,tx],dim=-1), torch.stack([O,I,ty],dim=-1), torch.stack([O,O,I],dim=-1)],dim=1) if opt.warpType=="similarity": pc,ps,tx,ty = torch.unbind(p,dim=1) pMtrx = torch.stack([torch.stack([I+pc,-ps,tx],dim=-1), torch.stack([ps,I+pc,ty],dim=-1), torch.stack([O,O,I],dim=-1)],dim=1) if opt.warpType=="affine": p1,p2,p3,p4,p5,p6,p7,p8 = torch.unbind(p,dim=1) pMtrx = torch.stack([torch.stack([I+p1,p2,p3],dim=-1), torch.stack([p4,I+p5,p6],dim=-1), torch.stack([O,O,I],dim=-1)],dim=1) if opt.warpType=="homography": p1,p2,p3,p4,p5,p6,p7,p8 = torch.unbind(p,dim=1) pMtrx = torch.stack([torch.stack([I+p1,p2,p3],dim=-1), torch.stack([p4,I+p5,p6],dim=-1), torch.stack([p7,p8,I],dim=-1)],dim=1) return pMtrx # convert warp matrix to parameters
def mtrx2vec(opt,pMtrx): [row0,row1,row2] = torch.unbind(pMtrx,dim=1) [e00,e01,e02] = torch.unbind(row0,dim=1) [e10,e11,e12] = torch.unbind(row1,dim=1) [e20,e21,e22] = torch.unbind(row2,dim=1) if opt.warpType=="translation": p = torch.stack([e02,e12],dim=1) if opt.warpType=="similarity": p = torch.stack([e00-1,e10,e02,e12],dim=1) if opt.warpType=="affine": p = torch.stack([e00-1,e01,e02,e10,e11-1,e12],dim=1) if opt.warpType=="homography": p = torch.stack([e00-1,e01,e02,e10,e11-1,e12,e20,e21],dim=1) return p # warp the image