Pytorch模型测试时显存一直上升导致爆显存

您所在的位置:网站首页 如何清显存 Pytorch模型测试时显存一直上升导致爆显存

Pytorch模型测试时显存一直上升导致爆显存

2023-08-10 12:34| 来源: 网络整理| 查看: 265

问题描述

首先说明: 由于我的测试集很大, 因此需要对测试集进行分批次推理.

在写代码的时候发现进行训练的时候大概显存只占用了2GB左右, 而且训练过程中显存占用量也基本上是不变的. 而在测试的时候, 发现显存在每个batch数据推理后逐渐增加, 直至最后导致爆显存, 程序fail.

这里放一下我测试的代码:

y, y_ = torch.Tensor(), torch.Tensor() for batch in tqdm(loader): x, batch_y = batch[0], batch[1] batch_y_ = model(x) y = torch.cat([y, batch_y], dim=0) y_ = torch.cat([y_, batch_y_], dim=0) 解决方法

遇到问题后我就进行单步调试, 然后观察显存的变化. 发现在模型推理这一步, 每一轮次显存都会增加.

batch_y_ = model(x)

这里令人费解的是, 模型推理实际上在训练和测试中都是存在的, 为什么训练的时候就不会出现这个问题呢.

最后发现其实是在训练的时候有这样一步与测试不同:

self.optimizer.zero_grad()

⭐️ 在训练时, 每一个batch后都会将模型的梯度进行一次清零. 而测试的时候我则没有加这一步, 这样的话每次模型再做推理的时候都会产生新的梯度, 并累积到显存当中.

清楚了问题, 那么解决方法也就随之而来, 在测试的时候让模型不要记录梯度就好, 因为其实也用不到:

with torch.no_grad(): test() 总结 训练的时候每一个batch结束除了梯度反向传播, 还要提前清理梯度梯度如果不清理的话, 会在显存中累积下来


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3