梯度累加是什么意思 |
您所在的位置:网站首页 › 1129391646_16771526175211njpg › 梯度累加是什么意思 |
每次看到“梯度是累加的,所以需要清除梯度”这句话都感觉云里雾里,貌似懂了实际没懂,原来竟是这么简单的意思! 1、直接上代码: import torch x = torch.Tensor([1, 2, 3]) x.requires_grad_() print(x) y = x**2 # 连续调用backward时,需要retain_graph=True确保计算图暂时不被释放 y.sum().backward(retain_graph=True) print(x.grad) y.sum().backward() print(x.grad) # 如果梯度不归零的话,梯度是累加的运行结果是: tensor([1., 2., 3.], requires_grad=True) tensor([2., 4., 6.]) tensor([ 4., 8., 12.])第一次调用backward反向传播,结果是(2 4 6),中间没有梯度清零,第二次调用backward反向传播,又有了一波结果(2 4 6),加在之前的结果上就得了(4 8 12) 2、接下来,我们在两次调用之间加一个梯度清零操作看看: import torch x = torch.Tensor([1, 2, 3]) x.requires_grad_() print(x) y = x**2 y.sum().backward(retain_graph=True) # 连续调用backward时,需要retain_graph=True确保计算图暂时不被释放 print(x.grad) x.grad.zero_() y.sum().backward() print(x.grad) # 如果梯度不归零的话,梯度是累加的运行结果是: tensor([1., 2., 3.], requires_grad=True) tensor([2., 4., 6.]) tensor([2., 4., 6.]) |
CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3 |