Pytorch中scatter |
您所在的位置:网站首页 › scatter的用法和短语 › Pytorch中scatter |
先看一个例子: torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 1, 1, 1]]), 2) tensor([[2., 0., 0., 0., 0.], [0., 2., 2., 2., 2.], [0., 0., 0., 0., 0.]])首先是定义了一个3行5列的数组,_scatter中第一个参数0.表示沿着第0轴, 后面第二个参数是坐标,第三个是对应坐标的值,整个意思就是给torch.zeros(3, 5)对应元素赋值,怎么理解呢: 我们看torch.tensor([[0, 1, 1, 1, 1]]), 这个是一个1行5列的数组,【0, 0】的值是0, 【0,1】的值是1,等等,这个地方,数组的值就是torch.zeros(3, 5)的第几行,因为这个地方是沿着第0轴,对应的是行, 那么torch.tensor([[0, 1, 1, 1, 1]]) 对应torch.zeros(3, 5)的坐标就是 【0, 0 】, 【1, 1】, 【1,3】, 【1, 4】,torch.tensor([[0, 1, 1, 1, 1]])这个数组的列就是对应torch.zeros(3, 5)的列,值代表的是他的行,因为他是沿着第0轴,所以如果就是行, 所以因为沿着第0轴,所以torch.tensor([[0, 1, 1, 1, 1]])的列和 torch.zeros(3, 5)列必须一样,都是5, 如果这个地方torch.tensor([[0, 1, 1, 1, 1]]) 列是4,会报错,同时因为torch.tensor([[0, 1, 1, 1, 1]])的值表示行,所以他的值不能大于等于3,不然会报错,因为 torch.zeros(3, 5)索引越界 再来看个例子: torch.zeros(3, 5).scatter_(1, torch.tensor([[0], [1], [3]]), 2) tensor([[2., 0., 0., 0., 0.], [0., 2., 0., 0., 0.], [0., 0., 0., 2., 0.]])因为沿着第1轴,所以值表示第1轴的索引,所以对应索引是【0, 0】, 【1,1】,【2,3】对应的元素是2, 在看下面: torch.zeros(3, 5).scatter_(1, torch.tensor([[0], [1], [3]]), torch.from_numpy(np.array([[0], [1], [3]], np.float32))) tensor([[0., 0., 0., 0., 0.], [0., 1., 0., 0., 0.], [0., 0., 0., 3., 0.]])第三参数也可以是一个数组,也就是对应的值可以赋值不一样的,但是对应位置是第二个数组控制的。 所以,我们可以拓展到多维度的, torch.zeros(2, 3, 5).scatter_(0, torch.tensor([[[0, 1, 1, 1, 1], [1, 1, 1, 1,1], [1,1,1,1,1]]]), 2) tensor([[[2., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.]], [[0., 2., 2., 2., 2.], [2., 2., 2., 2., 2.], [2., 2., 2., 2., 2.]]])沿着第0轴,所以第二个参数的值就是第0轴的坐标,比如第二个参数的第一个值,表示[0,0,0], 第二个是【1, 0, 1】等等 torch.tensor([[[0, 1, 1, 1, 1], [1, 1, 1, 1,1], [1,1,1,1,1]]])的shape是[1, 3,5]这个数组的第0轴,表示的是我们可以多赋值,就是也可以他的shape也可以是[2, 3,5], ,[3, 3,5]等,可以理解为是为了方便赋值更多,所以这个数组的第0个维度不做对应索引元素的值,就像上面数组维度为2一样, 所以,给定任意的scatter_函数,都是可以这样理解的。 |
今日新闻 |
推荐新闻 |
CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3 |