nn.Linear()函数详解及代码使用

您所在的位置:网站首页 linear是什么意思 nn.Linear()函数详解及代码使用

nn.Linear()函数详解及代码使用

2023-09-17 23:37| 来源: 网络整理| 查看: 265

这是官方给出的文档,需要注意的是,虽然在神经网络中,我们一般输入都是二维的tensor矩阵(batch,input_size),但其实输入的维度是不做限制的。如果是三维的输入,会将前两维的数据先乘一起,然后在做计算,实际上还是单层神经网络的计算。个人理解,这个函数就是改变最后一维,也就是数据的特征维度,通过调整output_size的尺寸来扩张或者是收缩特征。

import torch.nn as nn import torch import numpy as np X_2dim=np.array([[1,2,3,4],[2,3,45,6]]) #二维数组(2,4) X_3dim=np.array([[[1,2,3,4],[2,3,4,6],[3,4,5,5]],[[1,1,5,6],[0,0,6,5],[3,3,5,7]]]) # 三维数组(2,3,4) #转成tensor的形式,因为Linear要求输入是float类型,因此还需要转成float32 X2_tensor=torch.from_numpy(X_2dim.astype(np.float32)) X3_tensor=torch.from_numpy(X_3dim.astype(np.float32)) #用来改变最后数组最后一维的维度 #用来缩小或者扩展特征维度 emdeding=nn.Linear(4,3) Y2=emdeding(X2_tensor) Y3=emdeding(X3_tensor) #输出 print(Y2) print(Y3) #Y2 tensor([[ 0.6468, 0.6430, 0.4253], [-2.9180, -3.3393, 6.3075]], grad_fn=) #Y3 tensor([[[0.6468, 0.6430, 0.4253], [1.0562, 0.8781, 0.6216], [0.7615, 0.3500, 0.7439]], [[1.1430, 0.6462, 0.8132], [0.7745, 0.4598, 0.9190], [1.4516, 0.5589, 0.8545]]], grad_fn=)



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3