pytorch中stack()和cat()的理解和区别图解

您所在的位置:网站首页 carter和cat区别 pytorch中stack()和cat()的理解和区别图解

pytorch中stack()和cat()的理解和区别图解

2023-08-22 05:37| 来源: 网络整理| 查看: 265

torch.cat() 和 torch.stack() 函数的作用都是将多个维度参数相同的张量连接成一个张量,不同之处在与 stock()相比于cat()多了一维。这里两个函数都有 dim 这个参数,但是指的意思却不一样。使用下图来解释,在这里将两个张量理解成树这种形式,希望可以帮助理解。

这里竖线代表一个维度,竖线上所有节点代表同一维度的所有元素,在下面所有图中,同颜色的元素都是按照从上往下按顺序排列的。

dim在cat()函数中表示索所要连接的维度,也就是连接 所要连接的多个张量 的这个维度上面的参数。

但是在stack()中,dim表示多出来的维度,这个维度被用来连接之后维度的参数。原来的维度则变成子节点了,例如dim=1,那么 原来张量的第一维度 就变成了 连接之后的张量 的第二维度

假设这里一个torch.randn(2, 3, 4)生成的两个张量,如下图 在这里插入图片描述

红和蓝分别表示两个不同的张量,后面所有的图中左边的是使用stack()函数,右边是使用cat()函数,黄色的表示stack()函数生成的多的一维。 那么当 dim = 0时,如下图在这里插入图片描述

dim = 1, 如下图

在这里插入图片描述

dim = 2,如下图 在这里插入图片描述

对于stack()函数生成的结果会多一个维度,所有在这个例子中会有3这个索引值所代表的第四维度,dim = 3是成立的,但是对于cat()函数则没有这个

在这里插入图片描述



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3