Weight Normalization(WN) 权重归一化

您所在的位置:网站首页 wn英文缩写 Weight Normalization(WN) 权重归一化

Weight Normalization(WN) 权重归一化

2023-12-21 04:22| 来源: 网络整理| 查看: 265

      BN/LN/IN/GN都是在数据的层面上做的归一化,而Weight Normalization(WN)是对网络权值W做的归一化。WN的做法是将权值向量w在其欧氏范数和其方向上解耦成了参数向量 v 和参数标量 g 后使用SGD分别优化这两个参数。

      WN也是和样本量无关的,所以可以应用在batchsize较小以及RNN等动态网络中;另外BN使用的基于mini-batch的归一化统计量代替全局统计量,相当于在梯度计算中引入了噪声。而WN则没有这个问题,所以在生成模型,强化学习等噪声敏感的环境中WN的效果也要优于BN。

      WN没有额外参数,这样更节约显存。同时WN的计算效率也要优于要计算归一化统计量的BN。

      但是,WN不具备BN把每一层的输出Y固定在一个变化范围的作用。因此采用WN的时候要特别注意参数初始值的选择

可以认为v是本来的权重

v除以v的模,可以得到它的单位方向向量,再乘以g,g是可学习的

 本来的权重是v的,现在又新增了一个g,得到的新的w是保留了v的方向,然后又新增了一个可学习的幅度

torch.nn.utils.weight_norm(module, name='weight', dim=0) import torch from torch import nn layer = nn.Linear(20, 40) m = nn.utils.weight_norm(layer, name='weight') print(m) print(m.weight_g.size()) print(m.weight_v.size())

手动实现 import torch from torch import nn input = torch.randn(8, 3, 20) linear = nn.Linear(20, 40, bias=False) wn_layer = nn.utils.weight_norm(linear, name='weight') wn_output = wn_layer(input) weight_direction = linear.weight / torch.norm(linear.weight, p=2, dim=1, keepdim=True) #二范数 weight_magnitude = wn_layer.weight_g output = input @ (weight_direction.permute(1,0).contiguous() * weight_magnitude.permute(1,0).contiguous()) assert torch.allclose(wn_output, output)



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3