nn.Dropout应该放在哪里?有什么用?

您所在的位置:网站首页 蜡笔干什么用的 nn.Dropout应该放在哪里?有什么用?

nn.Dropout应该放在哪里?有什么用?

2023-08-30 22:39| 来源: 网络整理| 查看: 265

nn.Dropout是常用的防止过拟合的技术,同时有提生模型鲁棒性等诸多优点。 在pytorch中,Dropout共有两个参数:

p – probability of an element to be zeroed. Default: 0.5 p为元素被置0的概率,即被‘丢’掉的概率inplace – If set to True, will do this operation in-place. Default: False 是否进行就地操作,默认为否

下面记录几个例子方便理解:

import torch import torch.utils.data as Data import numpy as np from torch import nn test1=torch.tensor(np.ones((10))) print(test1) #tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=torch.float64) shape=test1.size()[0] l1=nn.Linear(shape,shape) test1=l1(test1) '''tensor([ 1.2093, 0.4491, -1.3389, -0.1974, -0.3614, -0.4381, 0.0361, 0.5686, -0.3484, 0.0110], grad_fn=)''' d=nn.Dropout(0.3) test1=d(test1) ''' tensor([ 1.7276, 0.6416, -0.0000, -0.2820, -0.0000, -0.6258, 0.0000, 0.0000, -0.4977, 0.0157], grad_fn=) note: 实测,0.3只是一个大概的概率,对单样本来说并不完全精确,但在大样本下无限接近于0.3 ''' l2=nn.Linear(shape,shape) test1=l2(test1) ''' tensor([ 0.2900, 0.2310, 0.0016, -0.0806, 0.5717, -0.3444, -0.4782, 0.0626, 0.5609, 0.5971], grad_fn=) ''' print(l2.weight) print(l2.weight.size()) ''' tensor([[-0.1172, 0.1366, -0.1985, -0.2270, 0.2187, 0.0618, 0.0044, 0.0445, -0.2501, -0.1953], [ 0.2652, -0.1562, -0.0418, -0.1766, 0.1065, -0.0898, 0.0416, 0.2394, -0.0220, -0.1494], [ 0.0389, 0.1469, -0.0870, 0.1732, 0.0536, -0.1480, -0.2159, -0.2245, 0.1741, 0.1966], [-0.0640, 0.1407, -0.2929, -0.2074, -0.0870, 0.0739, 0.1317, 0.1397, -0.1414, -0.1190], [ 0.2213, -0.0548, -0.2381, -0.3103, -0.0315, 0.0421, 0.2332, -0.0250, -0.0922, -0.2802], [-0.2725, 0.1465, -0.0210, 0.2928, 0.2386, 0.1554, -0.0050, -0.2504, -0.3003, 0.1500], [-0.1510, -0.2754, 0.1225, -0.0894, -0.0776, 0.1202, 0.2340, -0.0242, 0.2875, 0.1247], [ 0.2576, 0.2990, -0.2164, 0.1260, -0.2340, 0.2892, -0.1079, 0.0174, 0.2302, -0.0158], [ 0.2372, 0.1228, 0.0886, 0.2946, 0.2048, -0.3042, -0.2103, -0.1585, 0.2601, -0.3085], [ 0.1054, 0.2851, 0.2042, -0.2554, 0.0978, -0.2108, 0.2923, 0.1692, -0.1682, -0.0952]], requires_grad=True) torch.Size([10, 10]) '''

记录一个小知识点, 不构建nn.model时, 单行使用完 l2=nn.Linear(shape,shape) test1=l2(test1) l2.weight会自动更新



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3