【Pytorch】repeat()和expand()区别

您所在的位置:网站首页 repeatitve和repeated区别 【Pytorch】repeat()和expand()区别

【Pytorch】repeat()和expand()区别

2024-02-19 04:51| 来源: 网络整理| 查看: 265

torch.Tensor是包含一种数据类型元素的多维矩阵。

A  torch.Tensor is a multi-dimensional matrix containing elements of a single data type.

torch.Tensor有两个实例方法可以用来扩展某维的数据的尺寸,分别是repeat()和expand():

expand() expand(*sizes) -> Tensor *sizes(torch.Size or int) - the desired expanded  size Returns a new view of the self tensor with singleton dimensions expanded to a larger size.

返回当前张量在某维扩展更大后的张量。

例子:

import torch x = torch.tensor([1, 2, 3]) y = x.expand(2, 3) print(y) 结果为: tensor([[1, 2, 3], [1, 2, 3]]) print(x) 结果为: tensor([1, 2, 3])

就拿上面的例子来说 x的形状为3,他被扩展成2行3列,该扩展形状即为最终形状。

那么此时x会自动在高位添加1这个空维度,这时x会变为1*3的形状,随后使用复制的方式,将形状变为2*3。

>> x = torch.randn(2, 1, 1, 4) >> x.expand(-1, 2, 3, -1) torch.Size([2, 2, 3, 4])

正如上面的这个例子,不难看出,都是沿着1所在的维度进行复制。 

repeat() repeat(*sizes) -> Tensor *size(torch.Size or int) - The  number of times to repeat this tensor along each dimension. Repeats this tensor along the specified dimensions.

沿着特定的维度重复这个张量。

例子:

import torch >> x = torch.tensor([1, 2, 3]) >> x.repeat(3, 2) tensor([[1, 2, 3, 1, 2, 3], [1, 2, 3, 1, 2, 3], [1, 2, 3, 1, 2, 3]])

上面的x形状为3,由于repeat后面的维度是2个维度,因此x也需要变成2个维度,即为1*3。

接下来repeat即从x最右边的3这个维度开始,x的3所在的维度被重复了2次,此时x变成[1,2,3,1,2,3];

然后看x的1这个维度,被重复了3次,变成[[1,2,3,1,2,3],[1,2,3,1,2,3][1,2,3,1,2,3]]。

>> x2 = torch.randn(2, 3, 4) >> x2.repeat(2, 1, 3).shape torch.Tensor([4, 3, 12])

假设x2的矩阵为[[[a1,a2,a3,a4],[a1,a2,a3,a4],[a1,a2,a3,a4]],[[a1,a2,a3,a4],[a1,a2,a3,a4],[a1,a2,a3,a4]]]

上面x2的形状为2*3*4,由于repeat的维度为3,所以x2的维度中不需要补1,可以直接进行重复操作。

首先从x2的最后边的4这个维度开始,被重复了3次,变为

[[[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4]],

[[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4]]];

然后看x2的3这个维度,被重复了1次,其矩阵没有变化,依旧为

[[[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4]],

[[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4]]];

最后看x2的2这个维度,被重复了2次,这时变为

[[[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4]],

[[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4]],

[[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4]],

[[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4]]];

转自:https://zhuanlan.zhihu.com/p/58109107



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3