查表法 · 量化 Softmax

您所在的位置:网站首页 sum嵌套offset累加求和 查表法 · 量化 Softmax

查表法 · 量化 Softmax

#查表法 · 量化 Softmax| 来源: 网络整理| 查看: 265

关于查表法基础性介绍请回看

基本信息

表达公式:y = exp(x) / sum(exp(x))

函数曲线:没有固定曲线

数学推演消除 max

计算 softmax 的第一步通常都是做如下这样一个等价变化,来保证求和时不会发生数据溢出,

y = exp(x) / sum(exp(x)) = exp(x - offset) / sum(exp(x - offset)),通常 offset = max(x)

随后将问题拆解为如何得到 exp(x - max(x))。带入量化的表达式 x = sx * X,得,

exp(sx * X - max(sx * X)) = exp(sx(X - max(X)))

对于 X(量化 tensor)而言,

max(X) P,其中 X 用 input_quant_bit 位宽存放,T 与 acc 保持一致用 acc_quant_bit 位宽存放,P 用 acc_quant_bit + output_quant_bit 位宽存放。

功能代码

代码地址

class Softmax(torch.nn.Module): def __init__(self, dim_len: int, input_bit: int, input_amax: float, input_unsign: bool, output_bit: int, output_amax: float, output_unsign: bool = True, acc_bit: int = 16, narrow: bool = False, dim: int = None) -> None: super().__init__() assert (input_bit minus_max -> DQ -> float_func -> (exp_float) input_quant = self.input_qconfig.range.to(torch.int32) input_quant_minus_max = input_quant - self.input_qconfig.quant_max input_float_minus_max = self.input_qconfig.dequantize(input_quant_minus_max) exp_float = torch.exp(input_float_minus_max) acc_quant_max = quant_max(bit=acc_bit, unsign=False) # denominator denominator_scale = 1 / (acc_quant_max // dim_len) # denominator allowed min quant scale self.denominator_element_qconfig = QuantConfig(bit=acc_bit, narrow=False, unsign=False, scale=denominator_scale) # (exp_float) -> Q -> (denominator_element_quant) denominator_element_quant = self.denominator_element_qconfig.quantize(exp_float) # numerator numerator_bit = acc_bit + output_bit numerator_scale = denominator_scale * self.output_qconfig.scale self.numerator_qconfig = QuantConfig(bit=numerator_bit, narrow=False, unsign=False, scale=numerator_scale) # (exp_float) -> Q -> (numerator_quant) numerator_quant = self.numerator_qconfig.quantize(exp_float) # adjust sequence of output_quant for easier retrieve if input_unsign: self._denominator_element_table = denominator_element_quant self._numerator_table = numerator_quant else: index = self.input_qconfig.quant_max if narrow else self.input_qconfig.quant_max + 1 self._denominator_element_table = torch.cat( (denominator_element_quant[index:], denominator_element_quant[:index])) self._numerator_table = torch.cat((numerator_quant[index:], numerator_quant[:index])) def forward(self, x: torch.Tensor): denominator_element = self._denominator_element_table[x.to(torch.int64)] denominator = torch.sum(denominator_element, dim=self._dim) numerator = self._numerator_table[x.to(torch.int64)] y = numerator / denominator y = torch.clamp(y, self.output_qconfig.quant_min, self.output_qconfig.quant_max) y = y.to(self.output_qconfig.dtype) return y测试代码

代码地址

测试代码和测试设计中提到的会有所不同。因为 Softmax 不是简单查表就能实现的,过程中存在累加和除法,所以存在无法避免的误差。在测试代码中,将量化输出的最大绝对值误差(max absolute error)限定在 1 以内(包括 1),也就是等价浮点输出误差在 output_quant_scale 以内,对应代码块 L19。

def _check_symmetric_quant_table_softmax(dim_len_range: Tuple[int], input_bit_range: Tuple[int], input_amax_range: Tuple[float], input_unsign_range: Tuple[bool], output_bit_range: Tuple[int], output_amax_range: Tuple[float], output_unsign_range: Tuple[bool] = (False,)): for dim_len in tqdm(range(dim_len_range[0], dim_len_range[1]), desc=f'Testing {Softmax}'): for input_bit in input_bit_range: for input_amax in input_amax_range: for input_unsign in input_unsign_range: for output_bit in output_bit_range: for output_amax in output_amax_range: for output_unsign in output_unsign_range: for narrow in (True, False): max_absolute_error = __check_symmetric_quant_table_softmax( dim_len, input_bit, input_amax, input_unsign, output_bit, output_amax, output_unsign, narrow) if max_absolute_error > 1: print(f'dim_len = {dim_len}, input_bit = {input_bit}, input_amax = {input_amax}, input_unsign = {input_unsign}, '\ f'output_bit = {output_bit}, output_amax = {output_amax}, output_unsign = {output_unsign}, '\ f'narrow = {narrow} max_absolute_error is {max_absolute_error}!') return False return True

备注:测试代码中未添加 acc_quant_bit 的遍历测试。经手动调节 acc_quant_bit 大小发现,acc_quant_bit 越大,误差越小。acc_quant_bit 较小时,误差很大,测试无法通过。acc_quant_bit 足够大时,误差可以控制在允许范围内,测试能通过。此时 acc_quant_bit 再增大会发现,绝对值误差为 1 的数量会逐渐减少。有兴趣的朋友也可以试一下。究其原因,是 acc_quant_bit 位宽增大,保留了更多原来浮点数末尾的小数,保留的越多累加后的误差也就越小。



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3