Stable Diffusion Quick Kit 动手实践

您所在的位置:网站首页 sts安装配置 Stable Diffusion Quick Kit 动手实践

Stable Diffusion Quick Kit 动手实践

#Stable Diffusion Quick Kit 动手实践| 来源: 网络整理| 查看: 265

什么是 LoRA

很多小伙伴接触 LoRA 是 Stable Diffusion 的 LoRA 模型,用于人物和风格训练,但本质上 LoRA 并不专为 Stable Diffusion 服务。LoRA 英文全称 Low-Rank Adaptation of Large Language Models,是微软开源的解决大模型微调而开发的一项技术。

大模型参数规模巨大,比如 GPT-3 有 1750 亿参数,对这类大模型进行微调成本太高,LoRA 的做法是,冻结预训练好的模型权重参数,然后在每个 Transformer 块里进行低秩矩阵运算,注入新的训练的层参数。

这种方式与普通 fine tuning 需要对模型的权重参数重新计算梯度不同,相当于在原有 layer 上新增加的 network 层参数,所以大大减少了需要训练的计算量,并且保存的训练后的模型文件只是 network 超参值,相比于原模型文件大小小很多,方便进行分享和转换。

LoRA 本来是给大语言模型准备的,但把它用在 cross-attention layers 也能影响用文字生成图片的效果,在 Stable Diffusion 模型支持 LoRA 后,效果出乎意料,一时火遍全网。

在 Stable Diffusion 中使用 LoRA,是一种使用少量图像来训练模型的方法,由于冻结原有基础模型的权重层并重新计算,LoRA 训练速度很快,通常 8-10 张图片在 T4 单显卡机器上只需要 20 分钟即可训练完毕,且产生模型文件只有几 MB 到一两百 MB,相对于原几个 GB 的模型文件显著降低存储成本,提升效率。

LoRA 和 Dreambooth 的区别

LoRA 与 Dreambooth 都是目前业界主流的 Stable Diffusion 模型 fine tuning 的方法,二者面向的业务场景和实现方式各不相同,这里简单对比如下:

LoRA  Dreamboth 类似 hypernetwork 的单独网络层参数训练模型大小适中,8~200MB 推理加载时需要 LoRA 模型和基础模型融合推理时可以多个不同的 LoRA 模型+权重叠加使用本地训练时需要显存适中,>=7GB 推荐训练人物。 根据 instance token/class token 重新训练 unet/Clip 等子模型模型文件很大,2-5GB 独立的完整模型加载可以进行多次 fine tuning,训练不同的 concept,从而融合多个造型或者物件款式本地训练时需要高显存,>=12GB 推荐训练人脸及物件。 在 SageMaker 上进行 LoRA fine tuning

Stable Diffusion 的 LoRA 如此火爆,自然吸引众多业界商机和关注,很多行业(e.g: 游戏,社交)将 LoRA 模型训练做为 VIP/付费用户的高级体验功能,允许其训练模型生成自己的专属人物,性格画像,二次元虚拟人物,数字模特等。因此将 LoRA 的训练和推理在业务系统/AIOps/ML 中台上的工程化,是实施落地的关键。

本文将详细讲解 LoRA 在 Amazon SageMaker training job 的 fine tuning,以及 Amazon SageMaker inference 推理的开发和部署,以及在 Stable Diffusion Quick Kit 上简单快捷的集成和使用的具体内容,以帮助客户快速上手并将该功能集成到整个后台端到端业务流程中。

SageMaker LoRA 整体流程

我们使用 SageMaker BYOC Training Job 进行 LoRA 模型的 fine tuning,传入待训练的基础模型和图像数据集,做为训练 input,训练完成后输出模型保存在 S3 路径推理部署时,同样通过 BYOC inference 打包推理镜像,传入基础模型 uri 及训练后的 LoRA 模型 S3 位置,合并二者并加载,进行模型的推理生图。

SageMaker BYOC(Bring your own Container)方式训练及推理具体方法这里不再赘述,感兴趣的小伙伴可以在附录中查阅 Amazon 官方文档。

整个流程 pipeline 如下图所示:

LoRA on SageMaker Training Job

LoRA 发展迅速,开源的 fine tuning 框架众多,百花齐放,使用的时候要注意根据业务场景需求进行选择,这里列举部分 LoRA 模型训练业界主要的 github repository,供大家参考:

https://github.com/crosstyan/sd-LoRA/:第一个让 Stable Diffusion 支持 LoRA 的开源框架,但最近没看到更新。 https://github.com/cloneofsimo/LoRA/pulse/monthly:和 SD 的 LoRA 不兼容,训练方式有所不一样,相当于把 embedding 和 LoRA 合到一起调整网络了,训练出来的 LoRA model WebUI 不一定能加载成功。 https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_LoRA.py:Stable Diffusion 官方的 LoRA training。

我们这里选择 Kohya_ss 的 sd-scripts 开源代码,该 repository 是 Stable Diffusion WebUI 的 LoRA 插件的开发者,使用该开源 repository 可以保持与前端 UI 的参数兼容性,并且该插件支持 safetensor、checkpoint 格式的模型文件,自带了这些格式到 Stable Diffusion 模型格式的转换,方便 diffuser pipeline 的加载,最新的 kohya_ss 上还实现了单独的 GUI,方便进行 LoRA 训练的开发调试。

以下详细讲解 LoRA 在 Amazon SageMaker training job 的 fine tuning 开发,以及 Amazon SageMaker inference 推理部署的实现,以及在 Stable Diffusion Quick Kit 上简单快捷的使用的具体内容。

1. 准备阶段

1. 1 准备训练数据集

与早期的 Hypernetwork 网络训练类似,LoRA 也是通过图像训练权重参数层。

传入 LoRA 图像训练数据集有两种方式——

可以使用 HuggingFace 的 dataset 数据集格式,通过传递 HuggineFace 的 dataset 路径 url(e.g: lambdalabs/pokemon-blip-captions),此方式下可以方便地寻找开放的训练数据集,或者按照 HuggingFace 规定的数据格式上传图像及 metadata 元数据文件,即可使用 HuggingFace 的 Load_dataset 标准 API 进行数据 download 和加载。

HuggingFace image Dataset 图像训练数据集格式此处不再赘述,感兴趣的小伙伴可以参考其官方文档说明详细了解:https://huggingface.co/docs/datasets/image_dataset#imagefolder。

另一种方式是客户已经通过图像工具自行准备好了待训练的 images 及 prompt 文本(e.g: 卡通风格的一组二次元 IP 图片),这时可以使用 Kohya_ss 自定义的训练数据集格式,构造 toml 元数据配置文件,指明训练图片放置的位置及层次结构,训练脚本会自动识别该配置文件并获取对应目录下的图像文件。

toml 格式元数据配置文件如下示例所示:

[general] shuffle_caption = true caption_extension = '.txt' keep_tokens = 1 [[datasets]] resolution = [768, 768] batch_size = 2 [[datasets.subsets]] image_dir = '/opt/ml/input/data/images/' # metadata_file = '/opt/ml/input/data/images/metadata.jsonl'

如上示例文件中,[general]为训练数据集整体配置,指定了整体配置的设置,比如每张图像对应的 prompt 文本文件(caption_extension)格式后缀。[[datasets]] 是 general 下的二级配置,指定训练数据集的 revision 像素及训练 batch size 等。[[datasets.subsets]]是具体的每一类训练数据集的详细配置,比如图像所在目录 image_dir,该类图像的元数据 metadaga 文件(如果已经有每张图片名.txt 的 prompt 文本配置文件,则该配置项注释掉)。

详细 toml 格式配置项可以参见 kohya_ss 插件的说明文档:https://github.com/bmaltais/kohya_ss/blob/master/train_network_README.md。

在本次示例中,我们将使用 https://d374aanje223q0.cloudfront.net/pokemon-blip-captions-dataset.tar.gz 作为训练数据集(数据来源于开源链接:https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions),我们将数据集解压后与 dataset.toml 文件一起上传到 S3 中。

# 下载数据集 $ wget https://d374aanje223q0.cloudfront.net/pokemon-blip-captions-dataset.tar.gz $ tar -xzvf pokemon-blip-captions-dataset.tar.gz $ cd pokemon-blip-captions-dataset # 创建配置文件 $ cat > dataset.toml


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3