pytorch之torch中的几种乘法 #点乘torch.mm() #矩阵乘torch.mul(),torch.matmul() #高维Tensor相乘维度要求

您所在的位置:网站首页 n1盒子网口松 pytorch之torch中的几种乘法 #点乘torch.mm() #矩阵乘torch.mul(),torch.matmul() #高维Tensor相乘维度要求

pytorch之torch中的几种乘法 #点乘torch.mm() #矩阵乘torch.mul(),torch.matmul() #高维Tensor相乘维度要求

2023-08-11 03:54| 来源: 网络整理| 查看: 265

文章目录 1. 点乘——`torch.mul(a, b)`2. 矩阵乘2.1. 二维矩阵乘——`torch.mm(a, b)`2.2. 高维矩阵乘——`torch.matmul(a, b)` 3. 高维的Tensor相乘维度要求3.1. 对于维数相同的张量3.2. 对于维数不一样的张量

1. 点乘——torch.mul(a, b)

点乘是对应位置元素相乘 点乘都是broadcast的,可以用torch.mul(a, b)实现,也可以直接用*实现。

python中的广播机制(broadcasting) broadcasting可以这样理解:如果你有一个(m,n)的矩阵,让它加减乘除一个(1,n)的矩阵,它会被复制m次,成为一个(m,n)的矩阵,然后再逐元素地进行加减乘除操作。同样地对(m,1)的矩阵成立 在这里插入图片描述 图源:https://www.jianshu.com/p/fadd169cd396

当a, b维度满足广播机制时,会自动填充到相同维度相点乘。 例如:a的维度为(2,3),b的维度为(1,3); 或者:a的维度为(2,3),b的维度为(2,1)。当a, b维度不满足广播机制时,要求a和b的维度必须相等。 a的维度为(1,2),b的维度为(2,3)就会报错:The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1 报错的意思是b中维度为3的位置必须和a中维度为2的位置相匹配,因为a中有个维度1,要想满足广播机制就必须是(1,2)和(2,2),否则就需要满足维度必须相等(2,3)和(2,3) import torch a = torch.ones(3,4) print(a) b = torch.Tensor([1,2,3]).reshape((3,1)) print(b) print(torch.mul(a, b))

在这里插入图片描述

2. 矩阵乘

矩阵相乘有torch.mm(a, b)和torch.matmul(a, b)两个函数。

前一个是针对二维矩阵,后一个是高维。当torch.mm(a, b)用于大于二维时将报错。

2.1. 二维矩阵乘——torch.mm(a, b) import torch a = torch.ones(3,4) print(a) b = torch.ones(4,2) print(b) print(torch.mm(a, b))

在这里插入图片描述 当torch.mm(a, b)用于大于二维时将报错:

2.2. 高维矩阵乘——torch.matmul(a, b)

torch.matmul(a, b)可以用于二维:

import torch a = torch.ones(3,4) print(a) b = torch.ones(4,2) print(b) print(torch.matmul(a, b))

torch.matmul(a, b)可以用于高维:

import torch a = torch.ones(3,1,2) print(a) b = torch.ones(3,2,2) print(b) print(torch.matmul(a, b)) 3. 高维的Tensor相乘维度要求

两个Tensor维度要求:

"2维以上"的尺寸必须完全对应相等;"2维"具有实际意义的单位,只要满足矩阵相乘的尺寸规律即可。 3.1. 对于维数相同的张量

A.shape =(b,m,n);B.shape = (b,n,k) numpy.matmul(A,B) 结果shape为(b,m,k)

要求第一维度相同,后两个维度能满足矩阵相乘条件。

import torch a = torch.ones(3,1,2) print(a) b = torch.ones(3,2,2) print(b) print(torch.matmul(a, b))

在这里插入图片描述

3.2. 对于维数不一样的张量

比如 A.shape =(m,n); B.shape = (b,n,k); C.shape=(k,l)

numpy.matmul(A,B) 结果shape为(b,m,k)

numpy.matmul(B,C) 结果shape为(b,n,l)

2D张量要和3D张量的后两个维度满足矩阵相乘条件。

import torch a = torch.ones(1,2) print(a) b = torch.ones(2,2,3) print(b) c = torch.ones(3,1) print(b) print(torch.matmul(a, b)) print(torch.matmul(b, c))

在这里插入图片描述



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3