Pytorch BN(BatchNormal)计算过程与源码分析和train与eval的区别

您所在的位置:网站首页 train的复数 Pytorch BN(BatchNormal)计算过程与源码分析和train与eval的区别

Pytorch BN(BatchNormal)计算过程与源码分析和train与eval的区别

2024-06-08 23:05| 来源: 网络整理| 查看: 265

文章目录 1. Pytorch的net.train 和 net.eval2. net.train2.1 BN (Batch Normalization)一、什么是BN?二、BN核心公式三、以全连接网络的BN为例(图例过程)四、PyTorch 源码解读之 BN1.**BatchNorm 原理**2. BatchNorm 的 PyTorch 实现2.1 _NormBase 类**2.1.1 初始化**2.1.2 模拟 BN forward2.1.3 running_mean、running_var 的更新 3. 再回到train和eval3.1 调试验证 4. 对于BatchNorm2d

1. Pytorch的net.train 和 net.eval

​ 神经网络模块存在两种模式: train模式( **net.train() ** ) 和eval模式( net.eval() )

2. net.train

​ 一般的神经网络中,这两种模式是一样的,只有当模型中存在dropout和batchnorm的时候才有区别。说到这里,先回顾以下神经网络中的batchnorm(BN)

2.1 BN (Batch Normalization) 一、什么是BN?

​ Batch Normalization是2015年一篇论文中提出的数据归一化方法,往往用在深度神经网络中激活层之前。其作用可以加快模型训练时的收敛速度,使得模型训练过程更加稳定,避免梯度爆炸或者梯度消失。并且起到一定的正则化作用,几乎代替了Dropout。

​ 神经网络训练开始前,都要对数据做一个归一化处理,归一化有很多好处,原因是网络学习的过程的本质就是学习数据分布,一旦训练数据和测试数据的分布不同,那么网络的泛化能力就会大大降低,另外一方面,每一批次的数据分布如果不相同的话,那么网络就要在每次迭代的时候都去适应不同的分布,这样会大大降低网络的训练速度,这也就是为什么要对数据做一个归一化预处理的原因。另外对图片进行归一化处理还可以处理光照,对比度等影响。

二、BN核心公式

在这里插入图片描述

一般来说,BN层的输出将作为下一层激活层的输入。BN层的输入一组数据 X={ x1 , x2, x3, x4,… , xm }, 计算 平均值 uB然后计算方差再对输入X的每一个数据进行标准化输出y通过γ与β的线性变换得到新的值 (γ,β 正是需要训练的参数) 三、以全连接网络的BN为例(图例过程)

在这里插入图片描述

​ 假设输入的数据为**[ [ 1, 2, 3] , [4 ,5 ,6] ]**

在这里插入图片描述

​ 对于输 [ 1,2,3 ] 第一个神经元输出: (1*w1 + 2*w2 + 3/*w3) + b​ 同理可得其他输出

示例:

以Pytorch为例:

class torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

如下代码

class example(nn.Module): def __init__(self): super(example, self).__init__() self.fc1 = nn.Linear(3, 3) self.bn = nn.BatchNorm1d(num_features=3) def forward(self, x): print(x) #输入 x = self.fc1(x) x = self.bn(x) return x if __name__ == '__main__': datas = torch.tensor([[1,2,3], [4,5,6]], dtype=torch.float) datas = datas.cuda() net = example().cuda() # summary(net.cuda(),(3,)) out = net(datas) print(out)

调试:

(1) 查看全连接层的权值weight和偏置bias如和输入如下:

在这里插入图片描述

(2)全连接层forward

在这里插入图片描述

​ 其计算过程如下:

(i) 输入一组[ 1, 2, 3 ] 为例, 第一个神经元计算输出 ( 1*0.0193 + 2*0.3252+3*(-0.3773) ) + (- 0.2935) = - 0.7556

​ 第二个神经元计算输出 ( 1*0.3813 + 2*0.2321+3*0.5265) + 0.5100 = 2.9349

​ 第三个神经元计算输出 ( 1*(- 0.3829) + 2*(- 0.1440)+3*0.1517) + (-0.0860)= - 0.3017

​ 对于这组输入的最终结果为 [ - 0.7556, 2.9349 , -0.3017]

(ii)对于第二组输入[ 4, 5 , 6 ] 计算同上, 最后输出 [ -0.8537, 6.3545, -1.4271]

**(3)* bn层 **

​ 对于全连接层fc1输出了 tensor([[-0.7556, 2.9349, -0.3017], [-0.8537, 6.3545, -1.4271]], device=‘cuda:0’, grad_fn=) 在这里插入图片描述

​ 计算过程如下:

​ 注意:此时BN层输入通道数为3, 即 对于BN第一个神经元的输入 为 上一层输出的 第一维的集合 即 [-0.7556, -0.8537]

​ 且 weight和bias分别对应 γ,β

​ 根据上面的BN核心公式:

​ (i)以第一个BN层神经元为例, 计算输入的平均值 E[x] = ( -0.7556 + (- 0.8537))/2 = - 0.80465

​ 计算输入的方差(有偏估计) Var[x] = 0.0024

​ (ii)根据公式继续计算, 其中eps=1e-5, 得到:X=[ 0.9979, -0.9979], 即第一组输出和第二组输入的第一维度值分别为 0.9979, -0.9979,同理可以计算其他数。

具体计算如下:

>>> data tensor([[-0.7556, 2.9349, -0.3017], [-0.8537, 6.3545, -1.4271]]) >>> mean_var = data.mean(0) >>> mean_var tensor([-0.8046, 4.6447, -0.8644]) >>> var_var = data.var(0, unbiased=False) >>> var_var tensor([2.4059e-03, 2.9234e+00, 3.1663e-01]) >>> out = (data-mean_var)/torch.sqrt(var_var+1e-5) >>> out tensor([[ 0.9979, -1.0000, 1.0000], [-0.9979, 1.0000, -1.0000]]) 四、PyTorch 源码解读之 BN

​ 以下主要摘自,便于自己学习:PyTorch 源码解读之 BN

1.BatchNorm 原理

在这里插入图片描述

​	BatchNorm 最早在全连接网络中被提出,对每个神经元的输入做归一化。扩展到 CNN 中,就是对每个卷积核的输入做归一化,或者说在 channel 之外的所有维度做归一化。

2. BatchNorm 的 PyTorch 实现

PyTorch 中与 BN 相关的几个类放在 torch.nn.modules.batchnorm 中,包含以下几个类:

_NormBase:nn.Module 的子类,定义了 BN 中的一系列属性与初始化、读数据的方法;_BatchNorm:_NormBase 的子类,定义了 forward 方法;BatchNorm1d & BatchNorm2d & BatchNorm3d:_BatchNorm的子类,定义了不同的_check_input_dim方法。 2.1 _NormBase 类 2.1.1 初始化

_NormBase类定义了 BN 相关的一些属性,如下表所示:

attributemeaningnum_features输入的 channel 数track_running_stats默认为 True,是否统计 running_mean,running_varrunning_mean训练时统计输入的 mean,之后用于 inferencerunning_var训练时统计输入的 var,之后用于 inferencemomentum默认 0.1,更新 running_mean,running_var 时的动量num_batches_trackedPyTorch 0.4 后新加入,当 momentum 设置为 None 时,使用 num_batches_tracked 计算每一轮更新的动量affine默认为 True,训练 weight 和 bias;否则不更新它们的值weight公式中的 \gamma,初始化为全 1 tensorbias公式中的 \beta,初始化为全 0 tensor

这里贴一下 PyTorch 的源码:

class _NormBase(Module): """Common base of _InstanceNorm and _BatchNorm""" # 读checkpoint时会用version来区分是 PyTorch 0.4.1 之前还是之后的版本 _version = 2 __constants__ = ['track_running_stats', 'momentum', 'eps', 'num_features', 'affine'] def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): super(_NormBase, self).__init__() self.num_features = num_features self.eps = eps self.momentum = momentum self.affine = affine self.track_running_stats = track_running_stats if self.affine: # 如果打开 affine,就使用缩放因子和平移因子 self.weight = Parameter(torch.Tensor(num_features)) self.bias = Parameter(torch.Tensor(num_features)) else: self.register_parameter('weight', None) self.register_parameter('bias', None) # 训练时是否需要统计 mean 和 variance if self.track_running_stats: # buffer 不会在self.parameters()中出现 self.register_buffer('running_mean', torch.zeros(num_features)) self.register_buffer('running_var', torch.ones(num_features)) self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) else: self.register_parameter('running_mean', None) self.register_parameter('running_var', None) self.register_parameter('num_batches_tracked', None) self.reset_parameters() def reset_running_stats(self): if self.track_running_stats: self.running_mean.zero_() self.running_var.fill_(1) self.num_batches_tracked.zero_() def reset_parameters(self): self.reset_running_stats() if self.affine: init.ones_(self.weight) init.zeros_(self.bias) def _check_input_dim(self, input): # 具体在 BN1d, BN2d, BN3d 中实现,验证输入合法性 raise NotImplementedError def extra_repr(self): return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \ 'track_running_stats={track_running_stats}'.format(**self.__dict__) def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): version = local_metadata.get('version', None) if (version is None or version


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3