Pytorch

您所在的位置:网站首页 pytorch进阶 Pytorch

Pytorch

2023-02-21 23:21| 来源: 网络整理| 查看: 265

LightningCLI 是lighting团队开发的一个拓展模块,用于高效、进阶支持yaml,需要在安装pl后,再安装一下jsonargparse[signatures] : pip install jsonargparse[signatures]

与pytorch- Lightning配合使用非常方便,可惜网上资料甚少,谷歌也找不到多少,官方文档也不够详尽,比如如何配合使用自定义 logger和callback等都没提及。我来做一个比较实用的补充:

需要将原本的train入口部分改为如下形式以支持CLI、自定义logger和callbacks

logger = WandbLogger(project="xx_Baseline", name = "xxx", log_model=True) def cli_main(): checkpoint_callback = ModelCheckpoint( monitor='xx', #dirpath='xxx', filename='xxx-epoch{epoch:02d}-psnr{psnr:.3f}-ssim{ssim:.3f}', auto_insert_metric_name=False, every_n_epochs=1, save_top_k=3, mode = "max", save_last=True ) trainer_defaults = {'gpus':[0,1],'callbacks':[checkpoint_callback],'logger':logger} cli = LightningCLI( model_class=Mymodel, trainer_defaults=trainer_defaults, save_config_overwrite=True ) if __name__ == '__main__': #your code cli_main()

我们可以看下CLI自带的help信息:

训练时,使用如下命令即可:

python PL_train_CIL.py fit --config ./configs/base_config.yaml

LightningCIL 好用的特性挺多,详细可以参考一下官方文档:

Lightning CLI and config files



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3