[PyTorch] 拼接多个tensor:torch.cat((A,B),axis)

您所在的位置:网站首页 python将两个数组合并为一个数组 [PyTorch] 拼接多个tensor:torch.cat((A,B),axis)

[PyTorch] 拼接多个tensor:torch.cat((A,B),axis)

2024-07-10 11:26| 来源: 网络整理| 查看: 265

注:参考博客Pytorch中的torch.cat()函数。本人在其基础上增加了更为详细的解释。

torch.cat((A,B),axis)是对A, B两个tensor进行拼接。

参数axis指定拼接的方式。axis=0为按行拼接;axis=1为按列拼接。

拼接的时候把待拼接的tensor视作整体。(注意在示例中理解这句话)

import torch # 初始化三个 tensor A=torch.ones(2,3) #2x3的张量(矩阵) # tensor([[ 1., 1., 1.], # [ 1., 1., 1.]]) B=2*torch.ones(4,3) #4x3的张量(矩阵) # tensor([[ 2., 2., 2.], # [ 2., 2., 2.], # [ 2., 2., 2.], # [ 2., 2., 2.]]) D=2*torch.ones(2,4) # 2x4的张量(矩阵) # tensor([[ 2., 2., 2., 2.], # [ 2., 2., 2., 2.], # 按维数0(行)拼接 A 和 B C=torch.cat((A,B),0) # tensor([[ 1., 1., 1.], # [ 1., 1., 1.], # [ 2., 2., 2.], # [ 2., 2., 2.], # [ 2., 2., 2.], # [ 2., 2., 2.]]) print(C.shape) # torch.Size([6, 3]) # 按维数1(列)拼接 A 和 D C=torch.cat((A,D),1) # tensor([[ 1., 1., 1., 2., 2., 2., 2.], # [ 1., 1., 1., 2., 2., 2., 2.]]) print(C.shape) # torch.Size([2, 7])

另外,torch.cat((A,B),axis)还能把list中的tensor拼接起来。

import torch x = torch.Tensor([1, 2, 3]) x = x.unsqueeze(1) x2 = torch.cat( [ x*2 for i in range (1,4) ], 1 ) # tensor([[2., 2., 2.], # [4., 4., 4.], # [6., 6., 6.]])

x = torch.Tensor([1, 2, 3])生成的x的shape为torch.Size([3]),我们需要用x = x.unsqueeze(1)为x增加第二个维度,使其变为二维的tensor:torch.Size([3, 1])。

关于升维函数 x.unsqueeze(axis) 和降维函数 x.unsqueeze(axis) 的详细说明,可以去我的另一篇博客增加维度或者减少维度 ——a.squeeze(axis) 和 a.unsqueeze(axis)

torch.cat( [ x*2 for i in range (1,4) ], 1 )先生成了一个包含 3 个tensor的list,然后对list中的元素按列拼接(axis=1)。 关于range这个基本的函数,见本人的另一篇博客创建一个整数列表—— range() 故最后的结果为:

# tensor([[2., 2., 2.], # [4., 4., 4.], # [6., 6., 6.]])


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3