Pytorch 钩子函数hook的使用

您所在的位置:网站首页 pytorch节省显存 Pytorch 钩子函数hook的使用

Pytorch 钩子函数hook的使用

#Pytorch 钩子函数hook的使用| 来源: 网络整理| 查看: 265

1. hook函数

为了节省显存(内存),PyTorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用hook函数。hook函数在使用后应及时删除(remove),以避免每次都运行钩子增加运行负载。

这里总结一下并给出实际用法和注意点。列举了常见的4种hook函数:

1、Tensor.register_hook() # 用来导出指定张量的梯度,或修改这个梯度值 2、torch.nn.Module.register_forward_hook() 3、torch.nn.Module.register_backward_hook() 4、torch.nn.Module.register_forward_pre_hook() 2. hook函数说明 2.1 Tensor.register_hook()

用来导出指定张量的梯度,或修改这个梯度值。 在这里插入图片描述 注意:

上述代码是有效的,但如果写成 grad = grad * 2就失效了,因为此时没有对grad进行本地操作,新的grad 值没有传递给指定的梯度。保险起见,最好在def语句中写明re


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3