pytorch之SmoothL1Loss原理与用法

您所在的位置:网站首页 l1loss求导 pytorch之SmoothL1Loss原理与用法

pytorch之SmoothL1Loss原理与用法

2024-07-09 07:03| 来源: 网络整理| 查看: 265

文章目录 官方说明:解读图像用法

官方说明:

在这里插入图片描述

解读

我们直接看那个loss计算公式 l n l_n ln​,可以发现,是一个分段函数,我们将绝对值差视为一个变量 z z z,那么这个变量是大于0的,即分段函数只在大于等于0处有定义,有图像。我们再来看看分段点,就是beta。

有意思的是,在分段函数和这个分段点有关,在第一个公式(左边分段函数)中,函数值小于等于 0.5 z 0.5z 0.5z,因为除了beta。右边分段函数中,大于等于 0.5 z 0.5z 0.5z。所以是连续的,所以叫做Smooth。

而且beta固定下来的时候,当 z z z很大时,损失是线性函数,也就是说损失不会像MSE那样平方倍的爆炸。

总结就是:前半段随着 z z z的增长,损失增长得非常缓慢,后面快了一点点,但是也仍然是线性的。

图像 plt.figure(figsize=(20,8),dpi=80) beta=[0.5,1,2,3] for i in range(len(beta)): x1=np.linspace(0,beta[i],21) y1=0.5*x1*x1/beta[i] x2=np.linspace(beta[i],6,21) y2=x2-0.5*beta[i] plt.plot(np.hstack([x1,x2]),np.hstack([y1,y2]),label="beta:{}".format(beta[i])) plt.xlabel("the absolute element-wise error") plt.ylabel("the real loss") plt.legend()

在这里插入图片描述

用法 import torch import torch.nn as nn a=[1,2,3] b=[3,1,9] loss_fn=nn.SmoothL1Loss() loss_fn(torch.tensor(a,dtype=torch.float32),torch.tensor(b,dtype=torch.float32))

在这里插入图片描述



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3