Pytorch的两个拼接函数,torch.cat() 和 torch.stack()

您所在的位置:网站首页 python中len和length区别 Pytorch的两个拼接函数,torch.cat() 和 torch.stack()

Pytorch的两个拼接函数,torch.cat() 和 torch.stack()

2023-03-21 15:12| 来源: 网络整理| 查看: 265

Pytorch中常用的两个拼接函数,torch.cat() 和 torch.stack()

1. torch.cat()

一般torch.cat()是为了把多个tensor进行拼接而存在的。实际使用中,和torch.stack()使用场景不同。

torch.cat()和python中的内置函数cat(), 在使用和目的上,是没有区别的,区别在于前者操作对象是tensor。

1.【函数目的】:

在给定维度上对输入的张量序列seq 进行连接操作。

outputs = torch.cat(inputs, dim=?) → Tensor

inputs : 待连接的张量序列,可以是任意相同Tensor类型的python 序列

dim : 选择的扩维, 必须在0到len(inputs[0])之间,沿着此维连接张量序列。

2.【note】:输入数据必须是序列,序列中数据是任意相同的shape的同类型tensor。维度不可以超过输入数据的任一个张量的维度。import torch # x1 x1 = torch.tensor([[11,21,31],[21,31,41]],dtype=torch.int) x1.shape # torch.Size([2, 3]) # x2 x2 = torch.tensor([[12,22,32],[22,32,42]],dtype=torch.int) x2.shape # torch.Size([2, 3]) inputs = [x1, x2] print(inputs) #'打印查看' #[tensor([[11, 21, 31], # [21, 31, 41]], dtype=torch.int32), # tensor([[12, 22, 32], # [22, 32, 42]], dtype=torch.int32)]

#测试不同的dim拼接结果

In [1]: torch.cat(inputs, dim=0).shape Out[1]: torch.Size([4, 3]) In [2]: torch.cat(inputs, dim=1).shape Out[2]: torch.Size([2, 6]) In [3]: torch.cat(inputs, dim=2).shape IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)2. torch.stack()1.【函数目的】:

使用stack可以保留两个信息:[1. 序列] 和 [2. 张量矩阵] 信息,属于【扩张再拼接】的函数。

释义:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。

,假如数据都是二维矩阵(平面),它可以把这些一个个平面按第三维(例如:时间序列)压成一个三维的立方体,而立方体的长度就是时间序列长度。

outputs = torch.stack(inputs, dim=?) → Tensor

inputs : 待连接的张量序列。注:python的序列数据只有list和tuple。dim : 新的维度, 必须在0到len(outputs)之间。注:len(outputs)是生成数据的维度大小,也就是outputs的维度值。2.【note】:函数中的输入inputs只允许是序列;且序列内部的张量元素,必须shape相等

----[tensor_1, tensor_2,..]或者(tensor_1, tensor_2,..),且必须tensor_1.shape == tensor_2.shape

2. dim是选择生成的维度,必须满足0



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3