PyTorch: loading a model whose weights don't fully match & loading only some parameter weights (.pth files) & loading weights from the network (URL)



2023-12-17 11:34 | Source: compiled from the web

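Loading a model whose weights don't fully match: model.load_state_dict(torch.load(weight_path), strict=False)

With strict=False, every checkpoint entry whose key matches a key in the network is loaded, and the rest are skipped. Instead of raising, load_state_dict returns the missing and unexpected keys.

If strict is True, the keys must match exactly; otherwise an error is raised.

The default is True.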
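A minimal sketch of the strict flag, assuming two hypothetical toy nn.Sequential models (not from the original post): the checkpoint has one key the new model lacks, and the new model has one layer the checkpoint lacks.

```python
import torch
import torch.nn as nn

# Hypothetical toy models: the new model has an extra final layer ("3.*").
pretrained = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2), nn.Linear(2, 2))

state = pretrained.state_dict()
state["extra.weight"] = torch.zeros(1)  # hypothetical key the new model does not have

# strict=True (the default) raises on any missing or unexpected key.
try:
    model.load_state_dict(state)
    strict_failed = False
except RuntimeError:
    strict_failed = True

# strict=False loads the matching keys and reports the rest instead of raising.
result = model.load_state_dict(state, strict=False)
print("missing:", result.missing_keys)        # in the model, not in the checkpoint
print("unexpected:", result.unexpected_keys)  # in the checkpoint, not in the model
```

Checking the returned missing_keys and unexpected_keys is a cheap way to confirm that strict=False skipped only the entries you expected it to skip.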

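Note, however: if you transfer a model, for example from English to Chinese, and change the number of classes (say from 26 to 3600), strict=False cannot resolve the mismatch. The key names still correspond; only the tensor sizes differ, and load_state_dict raises a size-mismatch error even with strict=False.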
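A small sketch of this failure, assuming hypothetical two-layer models whose output layer has 26 classes in the checkpoint and 3600 in the new model:

```python
import torch.nn as nn

# Hypothetical models: same layer names, different classifier sizes.
pretrained = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 26))
model = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 3600))

try:
    # Keys are identical ("1.weight", "1.bias"); only the shapes differ,
    # so strict=False does not help and PyTorch still raises.
    model.load_state_dict(pretrained.state_dict(), strict=False)
    error = None
except RuntimeError as e:
    error = str(e)

print(error)
```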

Loading only some of the parameter weights

In that case, delete the mismatched entries from the loaded state dict and then load the remaining ones:

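```python
x = torch.load(self.weight)  # state dict loaded from the .pth file
# Drop the classifier entries whose shapes changed with the new class count
del x['char_recognizer.classifier.bias']
del x['char_recognizer.classifier.weight']
# The classifier keys are now missing, so strict=False is required
self.load_state_dict(x, strict=False)
```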
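The deletion approach can be tried end to end with a hypothetical stand-in for char_recognizer (a backbone plus a classifier). This sketch assumes an in-memory state dict in place of torch.load(self.weight), and the key names are those of the toy module, not of the original model.

```python
import torch
import torch.nn as nn


class Recognizer(nn.Module):
    """Hypothetical stand-in: a backbone plus a class-count-dependent classifier."""

    def __init__(self, num_classes):
        super().__init__()
        self.backbone = nn.Linear(16, 32)
        self.classifier = nn.Linear(32, num_classes)


old = Recognizer(num_classes=26)    # e.g. the English model
new = Recognizer(num_classes=3600)  # e.g. the Chinese model

x = old.state_dict()  # stands in for torch.load(weight_path)
# Remove the entries whose shapes changed, then load everything else.
del x["classifier.weight"]
del x["classifier.bias"]
result = new.load_state_dict(x, strict=False)

# The classifier keeps its fresh initialisation; the backbone is copied.
print(result.missing_keys)
```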

Or:

```python
# Use when some parts of the pretrained model are not needed
pretrained_dict = checkpoint['state_dict']
model_dict = model.state_dict()
# 1. filter out unnecessary keys; the shape check also drops layers
#    whose size changed (e.g. a classifier with a new class count)
pretrained_dict = {k: v for k, v in pretrained_dict.items()
                   if k in model_dict and v.shape == model_dict[k].shape}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
```

Or:

Load the pretrained model, then specify the parameters to ignore via args.finetune_ignore.

code is from DAB-DETR

```python
from collections import OrderedDict

import torch


def clean_state_dict(state_dict):
    # Strip the `module.` prefix that DataParallel / DistributedDataParallel add
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k[:7] == 'module.':
            k = k[7:]  # remove `module.`
        new_state_dict[k] = v
    return new_state_dict


# code is from DAB-DETR
# `args` and `model_without_ddp` come from the surrounding training script
if not args.resume and args.pretrain_model_path:
    checkpoint = torch.load(args.pretrain_model_path, map_location='cpu')['model']
    _ignorekeywordlist = args.finetune_ignore if args.finetune_ignore else []
    ignorelist = []

    def check_keep(keyname, ignorekeywordlist):
        # Drop a key when it exactly matches an entry of the ignore list
        for keyword in ignorekeywordlist:
            if keyword == keyname:
                ignorelist.append(keyname)
                return False
        return True

    # logger.info("Ignore keys: {}".format(json.dumps(ignorelist, indent=2)))
    _tmp_st = OrderedDict({k: v for k, v in clean_state_dict(checkpoint).items()
                           if check_keep(k, _ignorekeywordlist)})
    _load_output = model_without_ddp.load_state_dict(_tmp_st, strict=False)
```

Loading weights from the network (URL)

```python
# code is from DETR
checkpoint = torch.hub.load_state_dict_from_url(
    url, map_location='cpu', check_hash=True)
```
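A minimal offline sketch of torch.hub.load_state_dict_from_url with check_hash=True. Assumptions: a local checkpoint is served through a file:// URL instead of a real https:// URL, the toy nn.Linear model and file names are hypothetical, and model_dir points the download cache at a temporary directory. With check_hash=True, the file name must contain the leading hex digits of the file's SHA-256 digest, as in "name-<hash prefix>.pth".

```python
import hashlib
import pathlib
import tempfile

import torch
import torch.nn as nn

# Hypothetical source and cache directories for an offline demo.
src_dir = pathlib.Path(tempfile.mkdtemp())
cache_dir = tempfile.mkdtemp()

model = nn.Linear(4, 2)  # hypothetical toy model
tmp_path = src_dir / "weights.pth"
torch.save(model.state_dict(), tmp_path)

# Embed the first 8 hex digits of the SHA-256 digest in the file name,
# which is what check_hash=True verifies after downloading.
digest = hashlib.sha256(tmp_path.read_bytes()).hexdigest()
ckpt_path = tmp_path.rename(src_dir / f"weights-{digest[:8]}.pth")

url = ckpt_path.as_uri()  # file:// URL standing in for a real download link
state_dict = torch.hub.load_state_dict_from_url(
    url, model_dir=cache_dir, map_location="cpu", progress=False, check_hash=True)

restored = nn.Linear(4, 2)
restored.load_state_dict(state_dict)
```

The downloaded file is cached under model_dir, so a second call with the same URL reads the cached copy instead of downloading again.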


