Guided

您所在的位置:网站首页 SAMPLE的解读,错误的是 Guided

Guided

2024-07-11 01:53| 来源: 网络整理| 查看: 265

本文所分析的DDPM论文以及代码的相关信息如下:

Diffusion Models Beat GANs on Image Synthesis https://github.com/openai/guided-diffusion 一、算法流程

论文核心方法: 在扩散模型上作出两个方面的改进提升,即改进模型架构(Architecture Improvements)和使用分类器引导(Classifier Guidance),从而在FID上获得优于GAN的表现。

  1.通过设计一系列的消融实验,比较各部分改进对FID的影响,最终确定使用架构如下:

  2.使用分类器引导的DDPM和DDIM的算法流程分别如下:

二、 代码分析 2.1 分类引导采样流程

采样步骤包括:

加载UNet模型 (预测噪声$\theta$) 和扩散模型 ($\mu_\theta(x_t),\sigma_\theta(x_t)$); 加载预训练好的分类噪声图像的分类器 $p_\phi(y|x_t)$; 进行DDPM或DDIM采样过程,并加入引导梯度; 将样本转化为图片并保存。

相关核心函数调用见下图(以DDPM采样为例): classifier_sample.py代码如下:

1#使用一个噪声图像分类器来引导采样过程,从而生成更逼真的图像 2#非核心代码已省略.... 3def main(): 4 logger.log("creating model and diffusion...") 5 model, diffusion = create_model_and_diffusion(#初始化UNet模型和扩散模型 6 **args_to_dict(args, model_and_diffusion_defaults().keys())) 7 model.load_state_dict( 8 dist_util.load_state_dict(args.model_path, map_location="cpu")) 9 model.to(dist_util.dev()) 10 if args.use_fp16: 11 model.convert_to_fp16()#使用浮点进行原始模型的训练推理,float16加快速度 12 model.eval()#.train()模式主要用于激活某些特定于训练的层如Dropout和BatchNorm) 13 # 而.eval()模式则确保这些层在评估或测试时不激活 14 logger.log("loading classifier...") 15 classifier = create_classifier(**args_to_dict(args, classifier_defaults().keys())) 16 #.... 与加载model类似 17 logger.log("sampling...") 18 def model_fn(x, t, y=None):#预测噪声 19 assert y is not None 20 return model(x, t, y if args.class_cond else None) 21 all_images = [] 22 all_labels = [] 23 while len(all_images) * args.batch_size


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3