Pytorch中torch.nn.Softmax的dim参数含义

您所在的位置:网站首页 p-dim是什么意思 Pytorch中torch.nn.Softmax的dim参数含义

Pytorch中torch.nn.Softmax的dim参数含义

2024-07-10 14:39| 来源: 网络整理| 查看: 265

涉及到多维tensor时,对softmax的参数dim总是很迷,下面用一个例子说明

import torch.nn as nn m = nn.Softmax(dim=0) n = nn.Softmax(dim=1) k = nn.Softmax(dim=2) input = torch.randn(2, 2, 3) print(input) print(m(input)) print(n(input)) print(k(input))

输出: input

tensor([[[ 0.5450, -0.6264, 1.0446], [ 0.6324, 1.9069, 0.7158]], [[ 1.0092, 0.2421, -0.8928], [ 0.0344, 0.9723, 0.4328]]])

dim=0

tensor([[[0.3860, 0.2956, 0.8741], [0.6452, 0.7180, 0.5703]], [[0.6140, 0.7044, 0.1259], [0.3548, 0.2820, 0.4297]]])

dim=0时,在第0维上sum=1,即: [0][0][0]+[1][0][0]=0.3860+0.6140=1 [0][0][1]+[1][0][1]=0.2956+0.7044=1 … …

dim=1

tensor([[[0.4782, 0.0736, 0.5815], [0.5218, 0.9264, 0.4185]], [[0.7261, 0.3251, 0.2099], [0.2739, 0.6749, 0.7901]]])

dim=1时,在第1维上sum=1,即: [0][0][0]+[0][1][0]=0.4782+0.5218=1 [0][0][1]+[0][1][1]=0.0736+0.9264=1 … …

dim=2

tensor([[[0.3381, 0.1048, 0.5572], [0.1766, 0.6315, 0.1919]], [[0.6197, 0.2878, 0.0925], [0.1983, 0.5065, 0.2953]]])

dim=2时,在第2维上sum=1,即: [0][0][0]+[0][0][1]+[0][0][2]=0.3381+0.1048+0.5572=1.0001(四舍五入问题) [0][1][0]+[0][1][1]+[0][1][2]=0.1766+0.6315+0.1919=1 … …

用图表示223的张量如下: 在这里插入图片描述



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3