Pytorch中张量矩阵乘法函数(mm, bmm, matmul)使用说明,含高维张量实例及运行结果 |
您所在的位置:网站首页 › excel的矩阵乘法 › Pytorch中张量矩阵乘法函数(mm, bmm, matmul)使用说明,含高维张量实例及运行结果 |
Pytorch中张量矩阵乘法函数使用说明
1 torch.mm() 函数1.1 torch.mm() 函数定义及参数1.2 torch.bmm() 官方示例
2 torch.bmm() 函数2.1 torch.bmm() 函数定义及参数2.2 torch.bmm() 官方示例
3 torch.matmul() 函数3.1 torch.matmul() 函数定义及参数3.2 torch.matmul() 规则约定3.3 torch.matmul() 官方示例3.4 高维数据实例解释
参考博文及感谢
1 torch.mm() 函数
全称为matrix-matrix product,对输入的张量做矩阵乘法运算,输入输出维度一定是2维; 1.1 torch.mm() 函数定义及参数torch.bmm(input, mat2, , out=None) → Tensor input (Tensor) – – 第一个要相乘的矩阵 ** mat2* (Tensor) – – 第二个要相乘的矩阵 不支持广播到通用形状、类型推广以及整数、浮点和复杂输入。 1.2 torch.bmm() 官方示例 mat1 = torch.randn(2, 3) mat2 = torch.randn(3, 3) torch.mm(mat1, mat2) tensor([[ 0.4851, 0.5037, -0.3633], [-0.0760, -3.6705, 2.4784]]) 2 torch.bmm() 函数全称为batch matrix-matrix product,对输入的张量做矩阵乘法运算,输入输出维度一定是3维; 2.1 torch.bmm() 函数定义及参数torch.bmm(input, mat2, , out=None) → Tensor input (Tensor) – – 第一批要相乘的矩阵 ** mat2* (Tensor) – – 第二批要相乘的矩阵 不支持广播到通用形状、类型推广以及整数、浮点和复杂输入。 2.2 torch.bmm() 官方示例 input = torch.randn(10, 3, 4) mat2 = torch.randn(10, 4, 5) res = torch.bmm(input, mat2) res.size() torch.Size([10, 3, 5]) 3 torch.matmul() 函数可进行多维矩阵运算,根据不同输入维度进行广播机制然后运算,和点积类似,广播机制可参考之前博文torch.mul()函数。 3.1 torch.matmul() 函数定义及参数torch.matmul(input, mat2, , out=None) → Tensor input (Tensor) – – 第一个要相乘的张量 ** mat2* (Tensor) – – 第二个要相乘的张量 支持广播到通用形状、类型推广以及整数、浮点和复杂输入。 3.2 torch.matmul() 规则约定(1)若两个都是1D(向量)的,则返回两个向量的点积; (2)若两个都是2D(矩阵)的,则按照(矩阵相乘)规则返回2D; (3)若input维度1D,other维度2D,则先将1D的维度扩充到2D(1D的维数前面+1),然后得到结果后再将此维度去掉,得到的与input的维度相同。即使作扩充(广播)处理,input的维度也要和other维度做对应关系; (4)若input是2D,other是1D,则返回两者的点积结果; (5)如果一个维度至少是1D,另外一个大于2D,则返回的是一个批矩阵乘法( a batched matrix multiply) (a)若input是1D,other是大于2D的,则类似于规则(3);(b)若other是1D,input是大于2D的,则类似于规则(4);(c)若input和other都是3D的,则与torch.bmm()函数功能一样;(d)如果input中某一维度满足可以广播(扩充),那么也是可以进行相乘操作的。例如 input(j,1,n,m)* other (k,m,p) = output(j,k,n,p)matmul() 根据输入矩阵自动决定如何相乘。低维根据高维需求,合理广播。 3.3 torch.matmul() 官方示例 # vector x vector tensor1 = torch.randn(3) tensor2 = torch.randn(3) torch.matmul(tensor1, tensor2).size() torch.Size([]) # matrix x vector tensor1 = torch.randn(3, 4) tensor2 = torch.randn(4) torch.matmul(tensor1, tensor2).size() torch.Size([3]) # batched matrix x broadcasted vector tensor1 = torch.randn(10, 3, 4) tensor2 = torch.randn(4) torch.matmul(tensor1, tensor2).size() torch.Size([10, 3]) # batched matrix x batched matrix tensor1 = torch.randn(10, 3, 4) tensor2 = torch.randn(10, 4, 5) torch.matmul(tensor1, tensor2).size() torch.Size([10, 3, 5]) # batched matrix x broadcasted matrix tensor1 = torch.randn(10, 3, 4) tensor2 = torch.randn(4, 5) torch.matmul(tensor1, tensor2).size() torch.Size([10, 3, 5]) 3.4 高维数据实例解释直接看一个4维的二值例子,先看图(红虚线和实线是为了便于区分维度而添加),不懂再结合代码和结果分析,先做广播,然后对应矩阵进行乘积运算。 代码如下: import torch import numpy as np np.random.seed(2022) a = np.random.randint(low=0, high=2, size=(2, 2, 3, 4)) a = torch.tensor(a) b = np.random.randint(low=0, high=2, size=(2, 1, 4, 3)) b = torch.tensor(b) c = torch.matmul(a, b) # or # c = a @ b print(a) print("=============================================") print(b) print("=============================================") print(c.size()) print("=============================================") print(c)运行结果为: tensor([[[[1, 0, 1, 0], [1, 1, 0, 1], [0, 0, 0, 0]], [[1, 1, 1, 1], [1, 1, 0, 0], [0, 1, 0, 1]]], [[[0, 0, 0, 1], [0, 0, 0, 1], [0, 1, 0, 0]], [[1, 1, 1, 1], [1, 1, 1, 1], [0, 0, 0, 0]]]], dtype=torch.int32) ============================================= tensor([[[[0, 1, 0], [1, 1, 0], [0, 0, 0], [1, 1, 0]]], [[[0, 1, 0], [1, 1, 1], [1, 1, 1], [1, 0, 1]]]], dtype=torch.int32) ============================================= torch.Size([2, 2, 3, 3]) ============================================= tensor([[[[0, 1, 0], [2, 3, 0], [0, 0, 0]], [[2, 3, 0], [1, 2, 0], [2, 2, 0]]], [[[1, 0, 1], [1, 0, 1], [1, 1, 1]], [[3, 3, 3], [3, 3, 3], [0, 0, 0]]]], dtype=torch.int32) 参考博文及感谢部分内容参考以下链接,这里表示感谢 Thanks♪(・ω・)ノ 参考博文1 官方文档查询地址 https://pytorch.org/docs/stable/index.html 参考博文2 Pytorch矩阵乘法之torch.mul() 、 torch.mm() 及torch.matmul()的区别 https://blog.csdn.net/irober/article/details/113686080 |
今日新闻 |
推荐新闻 |
CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3 |