Pytorch中scatter

您所在的位置:网站首页 scatter的用法和短语 Pytorch中scatter

Pytorch中scatter

#Pytorch中scatter| 来源: 网络整理| 查看: 265

先看一个例子:

torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 1, 1, 1]]), 2) tensor([[2., 0., 0., 0., 0.], [0., 2., 2., 2., 2.], [0., 0., 0., 0., 0.]])

首先是定义了一个3行5列的数组,_scatter中第一个参数0.表示沿着第0轴, 后面第二个参数是坐标,第三个是对应坐标的值,整个意思就是给torch.zeros(3, 5)对应元素赋值,怎么理解呢:

我们看torch.tensor([[0, 1, 1, 1, 1]]), 这个是一个1行5列的数组,【0, 0】的值是0, 【0,1】的值是1,等等,这个地方,数组的值就是torch.zeros(3, 5)的第几行,因为这个地方是沿着第0轴,对应的是行, 那么torch.tensor([[0, 1, 1, 1, 1]]) 对应torch.zeros(3, 5)的坐标就是 【0, 0 】, 【1, 1】, 【1,3】, 【1, 4】,torch.tensor([[0, 1, 1, 1, 1]])这个数组的列就是对应torch.zeros(3, 5)的列,值代表的是他的行,因为他是沿着第0轴,所以如果就是行,

所以因为沿着第0轴,所以torch.tensor([[0, 1, 1, 1, 1]])的列和 torch.zeros(3, 5)列必须一样,都是5, 如果这个地方torch.tensor([[0, 1, 1, 1, 1]]) 列是4,会报错,同时因为torch.tensor([[0, 1, 1, 1, 1]])的值表示行,所以他的值不能大于等于3,不然会报错,因为 torch.zeros(3, 5)索引越界

再来看个例子:

torch.zeros(3, 5).scatter_(1, torch.tensor([[0], [1], [3]]), 2) tensor([[2., 0., 0., 0., 0.],         [0., 2., 0., 0., 0.],         [0., 0., 0., 2., 0.]])

因为沿着第1轴,所以值表示第1轴的索引,所以对应索引是【0, 0】, 【1,1】,【2,3】对应的元素是2,

在看下面:

torch.zeros(3, 5).scatter_(1, torch.tensor([[0], [1], [3]]), torch.from_numpy(np.array([[0], [1], [3]], np.float32))) tensor([[0., 0., 0., 0., 0.], [0., 1., 0., 0., 0.], [0., 0., 0., 3., 0.]])

第三参数也可以是一个数组,也就是对应的值可以赋值不一样的,但是对应位置是第二个数组控制的。

所以,我们可以拓展到多维度的,

torch.zeros(2, 3, 5).scatter_(0, torch.tensor([[[0, 1, 1, 1, 1], [1, 1, 1, 1,1], [1,1,1,1,1]]]), 2) tensor([[[2., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.]], [[0., 2., 2., 2., 2.], [2., 2., 2., 2., 2.], [2., 2., 2., 2., 2.]]])

沿着第0轴,所以第二个参数的值就是第0轴的坐标,比如第二个参数的第一个值,表示[0,0,0], 第二个是【1, 0, 1】等等

torch.tensor([[[0, 1, 1, 1, 1], [1, 1, 1, 1,1], [1,1,1,1,1]]])的shape是[1, 3,5]这个数组的第0轴,表示的是我们可以多赋值,就是也可以他的shape也可以是[2, 3,5], ,[3, 3,5]等,可以理解为是为了方便赋值更多,所以这个数组的第0个维度不做对应索引元素的值,就像上面数组维度为2一样,

所以,给定任意的scatter_函数,都是可以这样理解的。



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3