pytorch中的矩阵乘法:函数mul,mm,mv以及 @运算 和 *运算 |
您所在的位置:网站首页 › 两个矩阵怎么乘 › pytorch中的矩阵乘法:函数mul,mm,mv以及 @运算 和 *运算 |
pytorch中矩阵运算种类
关于@运算,*运算,torch.mul(), torch.mm(), torch.mv(), tensor.t() @ 和 *代表矩阵的两种相乘方式:@表示常规的数学上定义的矩阵相乘;*表示两个矩阵对应位置处的两个元素相乘。 x.dot(y): 向量乘积,x,y均为一维向量。 *和torch.mul()等同:表示相同shape矩阵点乘,即对应位置相乘,得到矩阵有相同的shape。 @和torch.mm(a, b)等同:正常矩阵相乘,要求a的列数与b的行数相同。 torch.mv(X, w0):是矩阵和向量相乘.第一个参数是矩阵,第二个参数只能是一维向量,等价于X乘以w0的转置 Y.t():矩阵Y的转置。 实例展示: a = torch.tensor([ [1, 0, 1], [0, 1, 0], [1, 0, 1], ]) b = torch.tensor([ [3, 1, 3], [1, 0, 1], [3, 1, 3], ]) c = torch.tensor([ [1, 1, 1], [0, 1, 1], [0, 0, 1], ]) w1 = torch.tensor([1, 2, 1]) w2 = torch.tensor([3, 4, 3])1、 *和torch.mul()等同,矩阵点乘 res1 = a*b res11 = torch.mul(a, b) print("a*b={}\n".format(res1)) print("torch.mul(a, b)={}\n".format(res11))Output: a*b=tensor([[3, 0, 3], [0, 0, 0], [3, 0, 3]]) torch.mul(a, b)=tensor([[3, 0, 3], [0, 0, 0], [3, 0, 3]]) 2、 @和torch.mm(a, b)等同,矩阵乘法 res2 = a@b res22 = torch.mm(a, b) print("a@b={}\n".format(res2)) print("torch.mm(a,b)={}\n".format(res22))Output: a@b=tensor([[6, 2, 6], [1, 0, 1], [6, 2, 6]]) torch.mm(a,b)=tensor([[6, 2, 6], [1, 0, 1], [6, 2, 6]]) 3、 dot()向量乘法 向量运算,参数不能是多维矩阵,否则报错:RuntimeError: 1D tensors expected, got 2D, 2D tensors at. res3 = w1.dot(w2) print("w1.dot(w2)={}\n".format(res3))Output: w1.dot(w2)=14 4、 c.t()矩阵转置 res4 = c.t() print("c.t()={}\n".format(res4))Output: c.t()=tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]]) 5、torch.mv(a, w1)矩阵乘向量 res5 = torch.mv(a, w1) res55 = torch.mv(a, w1.t()) print("torch.mv(a, w1)={}\n".format(res5)) print("torch.mv(a, w1.t())={}\n".format(res55))Output: 向量w1是行向量和列向量都可以,结果一样。 torch.mv(a, w1)=tensor([2, 2, 2]) torch.mv(a, w1.t())=tensor([2, 2, 2]) 总结:如果命题拿不准就多测试两遍,相互验证。 |
今日新闻 |
推荐新闻 |
CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3 |