Pytorch BN(BatchNormal)计算过程与源码分析和train与eval的区别 |
您所在的位置:网站首页 › train的复数 › Pytorch BN(BatchNormal)计算过程与源码分析和train与eval的区别 |
文章目录
1. Pytorch的net.train 和 net.eval2. net.train2.1 BN (Batch Normalization)一、什么是BN?二、BN核心公式三、以全连接网络的BN为例(图例过程)四、PyTorch 源码解读之 BN1.**BatchNorm 原理**2. BatchNorm 的 PyTorch 实现2.1 _NormBase 类**2.1.1 初始化**2.1.2 模拟 BN forward2.1.3 running_mean、running_var 的更新
3. 再回到train和eval3.1 调试验证
4. 对于BatchNorm2d
1. Pytorch的net.train 和 net.eval
神经网络模块存在两种模式: train模式( **net.train() ** ) 和eval模式( net.eval() ) 2. net.train 一般的神经网络中,这两种模式是一样的,只有当模型中存在dropout和batchnorm的时候才有区别。说到这里,先回顾以下神经网络中的batchnorm(BN) 2.1 BN (Batch Normalization) 一、什么是BN? Batch Normalization是2015年一篇论文中提出的数据归一化方法,往往用在深度神经网络中激活层之前。其作用可以加快模型训练时的收敛速度,使得模型训练过程更加稳定,避免梯度爆炸或者梯度消失。并且起到一定的正则化作用,几乎代替了Dropout。 神经网络训练开始前,都要对数据做一个归一化处理,归一化有很多好处,原因是网络学习的过程的本质就是学习数据分布,一旦训练数据和测试数据的分布不同,那么网络的泛化能力就会大大降低,另外一方面,每一批次的数据分布如果不相同的话,那么网络就要在每次迭代的时候都去适应不同的分布,这样会大大降低网络的训练速度,这也就是为什么要对数据做一个归一化预处理的原因。另外对图片进行归一化处理还可以处理光照,对比度等影响。 二、BN核心公式 假设输入的数据为**[ [ 1, 2, 3] , [4 ,5 ,6] ]** 示例: 以Pytorch为例: class torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)如下代码 class example(nn.Module): def __init__(self): super(example, self).__init__() self.fc1 = nn.Linear(3, 3) self.bn = nn.BatchNorm1d(num_features=3) def forward(self, x): print(x) #输入 x = self.fc1(x) x = self.bn(x) return x if __name__ == '__main__': datas = torch.tensor([[1,2,3], [4,5,6]], dtype=torch.float) datas = datas.cuda() net = example().cuda() # summary(net.cuda(),(3,)) out = net(datas) print(out)调试: (1) 查看全连接层的权值weight和偏置bias如和输入如下: (2)全连接层forward 其计算过程如下: (i) 输入一组[ 1, 2, 3 ] 为例, 第一个神经元计算输出 ( 1*0.0193 + 2*0.3252+3*(-0.3773) ) + (- 0.2935) = - 0.7556 第二个神经元计算输出 ( 1*0.3813 + 2*0.2321+3*0.5265) + 0.5100 = 2.9349 第三个神经元计算输出 ( 1*(- 0.3829) + 2*(- 0.1440)+3*0.1517) + (-0.0860)= - 0.3017 对于这组输入的最终结果为 [ - 0.7556, 2.9349 , -0.3017] (ii)对于第二组输入[ 4, 5 , 6 ] 计算同上, 最后输出 [ -0.8537, 6.3545, -1.4271] **(3)* bn层 ** 对于全连接层fc1输出了 tensor([[-0.7556, 2.9349, -0.3017], [-0.8537, 6.3545, -1.4271]], device=‘cuda:0’, grad_fn=) 计算过程如下: 注意:此时BN层输入通道数为3, 即 对于BN第一个神经元的输入 为 上一层输出的 第一维的集合 即 [-0.7556, -0.8537] 且 weight和bias分别对应 γ,β 根据上面的BN核心公式: (i)以第一个BN层神经元为例, 计算输入的平均值 E[x] = ( -0.7556 + (- 0.8537))/2 = - 0.80465 计算输入的方差(有偏估计) Var[x] = 0.0024 (ii)根据公式继续计算, 其中eps=1e-5, 得到:X=[ 0.9979, -0.9979], 即第一组输出和第二组输入的第一维度值分别为 0.9979, -0.9979,同理可以计算其他数。 具体计算如下: >>> data tensor([[-0.7556, 2.9349, -0.3017], [-0.8537, 6.3545, -1.4271]]) >>> mean_var = data.mean(0) >>> mean_var tensor([-0.8046, 4.6447, -0.8644]) >>> var_var = data.var(0, unbiased=False) >>> var_var tensor([2.4059e-03, 2.9234e+00, 3.1663e-01]) >>> out = (data-mean_var)/torch.sqrt(var_var+1e-5) >>> out tensor([[ 0.9979, -1.0000, 1.0000], [-0.9979, 1.0000, -1.0000]]) 四、PyTorch 源码解读之 BN 以下主要摘自,便于自己学习:PyTorch 源码解读之 BN 1.BatchNorm 原理PyTorch 中与 BN 相关的几个类放在 torch.nn.modules.batchnorm 中,包含以下几个类: _NormBase:nn.Module 的子类,定义了 BN 中的一系列属性与初始化、读数据的方法;_BatchNorm:_NormBase 的子类,定义了 forward 方法;BatchNorm1d & BatchNorm2d & BatchNorm3d:_BatchNorm的子类,定义了不同的_check_input_dim方法。 2.1 _NormBase 类 2.1.1 初始化_NormBase类定义了 BN 相关的一些属性,如下表所示: attributemeaningnum_features输入的 channel 数track_running_stats默认为 True,是否统计 running_mean,running_varrunning_mean训练时统计输入的 mean,之后用于 inferencerunning_var训练时统计输入的 var,之后用于 inferencemomentum默认 0.1,更新 running_mean,running_var 时的动量num_batches_trackedPyTorch 0.4 后新加入,当 momentum 设置为 None 时,使用 num_batches_tracked 计算每一轮更新的动量affine默认为 True,训练 weight 和 bias;否则不更新它们的值weight公式中的 \gamma,初始化为全 1 tensorbias公式中的 \beta,初始化为全 0 tensor这里贴一下 PyTorch 的源码: class _NormBase(Module): """Common base of _InstanceNorm and _BatchNorm""" # 读checkpoint时会用version来区分是 PyTorch 0.4.1 之前还是之后的版本 _version = 2 __constants__ = ['track_running_stats', 'momentum', 'eps', 'num_features', 'affine'] def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): super(_NormBase, self).__init__() self.num_features = num_features self.eps = eps self.momentum = momentum self.affine = affine self.track_running_stats = track_running_stats if self.affine: # 如果打开 affine,就使用缩放因子和平移因子 self.weight = Parameter(torch.Tensor(num_features)) self.bias = Parameter(torch.Tensor(num_features)) else: self.register_parameter('weight', None) self.register_parameter('bias', None) # 训练时是否需要统计 mean 和 variance if self.track_running_stats: # buffer 不会在self.parameters()中出现 self.register_buffer('running_mean', torch.zeros(num_features)) self.register_buffer('running_var', torch.ones(num_features)) self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) else: self.register_parameter('running_mean', None) self.register_parameter('running_var', None) self.register_parameter('num_batches_tracked', None) self.reset_parameters() def reset_running_stats(self): if self.track_running_stats: self.running_mean.zero_() self.running_var.fill_(1) self.num_batches_tracked.zero_() def reset_parameters(self): self.reset_running_stats() if self.affine: init.ones_(self.weight) init.zeros_(self.bias) def _check_input_dim(self, input): # 具体在 BN1d, BN2d, BN3d 中实现,验证输入合法性 raise NotImplementedError def extra_repr(self): return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \ 'track_running_stats={track_running_stats}'.format(**self.__dict__) def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): version = local_metadata.get('version', None) if (version is None or version |
今日新闻 |
推荐新闻 |
CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3 |