[PyTorch] 拼接多个tensor:torch.cat((A,B),axis) |
您所在的位置:网站首页 › python将两个数组合并为一个数组 › [PyTorch] 拼接多个tensor:torch.cat((A,B),axis) |
注:参考博客Pytorch中的torch.cat()函数。本人在其基础上增加了更为详细的解释。 torch.cat((A,B),axis)是对A, B两个tensor进行拼接。 参数axis指定拼接的方式。axis=0为按行拼接;axis=1为按列拼接。 拼接的时候把待拼接的tensor视作整体。(注意在示例中理解这句话) import torch # 初始化三个 tensor A=torch.ones(2,3) #2x3的张量(矩阵) # tensor([[ 1., 1., 1.], # [ 1., 1., 1.]]) B=2*torch.ones(4,3) #4x3的张量(矩阵) # tensor([[ 2., 2., 2.], # [ 2., 2., 2.], # [ 2., 2., 2.], # [ 2., 2., 2.]]) D=2*torch.ones(2,4) # 2x4的张量(矩阵) # tensor([[ 2., 2., 2., 2.], # [ 2., 2., 2., 2.], # 按维数0(行)拼接 A 和 B C=torch.cat((A,B),0) # tensor([[ 1., 1., 1.], # [ 1., 1., 1.], # [ 2., 2., 2.], # [ 2., 2., 2.], # [ 2., 2., 2.], # [ 2., 2., 2.]]) print(C.shape) # torch.Size([6, 3]) # 按维数1(列)拼接 A 和 D C=torch.cat((A,D),1) # tensor([[ 1., 1., 1., 2., 2., 2., 2.], # [ 1., 1., 1., 2., 2., 2., 2.]]) print(C.shape) # torch.Size([2, 7])另外,torch.cat((A,B),axis)还能把list中的tensor拼接起来。 import torch x = torch.Tensor([1, 2, 3]) x = x.unsqueeze(1) x2 = torch.cat( [ x*2 for i in range (1,4) ], 1 ) # tensor([[2., 2., 2.], # [4., 4., 4.], # [6., 6., 6.]])x = torch.Tensor([1, 2, 3])生成的x的shape为torch.Size([3]),我们需要用x = x.unsqueeze(1)为x增加第二个维度,使其变为二维的tensor:torch.Size([3, 1])。 关于升维函数 x.unsqueeze(axis) 和降维函数 x.unsqueeze(axis) 的详细说明,可以去我的另一篇博客增加维度或者减少维度 ——a.squeeze(axis) 和 a.unsqueeze(axis) torch.cat( [ x*2 for i in range (1,4) ], 1 )先生成了一个包含 3 个tensor的list,然后对list中的元素按列拼接(axis=1)。 关于range这个基本的函数,见本人的另一篇博客创建一个整数列表—— range() 故最后的结果为: # tensor([[2., 2., 2.], # [4., 4., 4.], # [6., 6., 6.]]) |
今日新闻 |
推荐新闻 |
CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3 |