深入解析cloneofsimo/lora项目中的LoRA训练脚本
2025-07-07 04:03:41作者:董灵辛Dennis
本文将对cloneofsimo/lora项目中的train_lora_w_ti.py
训练脚本进行深入解析,帮助读者理解如何使用LoRA(Low-Rank Adaptation)技术结合文本反转(Textual Inversion)来微调Stable Diffusion模型。
脚本概述
train_lora_w_ti.py
是一个用于训练LoRA权重并结合文本反转技术的Python脚本,主要功能包括:
- 加载预训练的Stable Diffusion模型
- 注入可训练的LoRA层到UNet模型中
- 实现文本反转训练以学习新的概念
- 支持多种训练配置选项
核心功能解析
1. 数据集处理
脚本中定义了DreamBoothTiDataset
类来处理训练数据,主要特点包括:
- 支持实例图像和类别图像的加载
- 提供多种图像预处理选项:裁剪、颜色抖动、水平翻转等
- 使用模板生成多样化的提示文本
- 支持随机属性混合增强训练数据
class DreamBoothTiDataset(Dataset):
def __init__(self, instance_data_root, learnable_property, placeholder_token,
stochastic_attribute, tokenizer, class_data_root=None, ...):
# 初始化代码...
2. LoRA注入与训练
脚本使用lora_diffusion
模块提供的功能来注入和训练LoRA层:
inject_trainable_lora()
: 将LoRA层注入到UNet模型中extract_lora_ups_down()
: 提取训练好的LoRA权重save_lora_weight()
: 保存LoRA权重到文件
# 注入LoRA层
unet = inject_trainable_lora(unet, r=args.lora_rank)
# 训练完成后保存LoRA权重
save_lora_weight(unet, os.path.join(args.output_dir, "lora_weight.pt"))
3. 文本反转训练
脚本支持同时训练文本反转嵌入:
- 使用特殊的占位符标记(placeholder token)来表示新概念
- 可以控制文本编码器的训练方式
- 支持保存学习到的嵌入向量
# 保存文本反转嵌入
learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id]
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
torch.save(learned_embeds_dict, save_path)
训练配置选项
脚本提供了丰富的训练参数配置:
基本参数
pretrained_model_name_or_path
: 预训练模型路径instance_data_dir
: 实例图像目录placeholder_token
: 占位符标记learnable_property
: 学习属性(style/object)
训练参数
lora_rank
: LoRA的秩大小train_text_encoder
: 是否训练文本编码器learning_rate
: 学习率设置unfreeze_lora_step
: 解冻LoRA层的训练步数
图像处理参数
resolution
: 图像分辨率center_crop
: 是否中心裁剪color_jitter
: 是否使用颜色抖动h_flip
: 是否水平翻转
训练流程
-
初始化阶段:
- 加载预训练模型和分词器
- 准备数据集和数据加载器
- 注入LoRA层到UNet模型
- 设置优化器和学习率调度器
-
训练循环:
- 前向传播计算损失
- 反向传播更新参数
- 定期保存检查点
- 支持梯度累积和混合精度训练
-
保存结果:
- 保存训练好的LoRA权重
- 保存文本反转嵌入
- 支持多种输出格式(pt/safe/both)
技术亮点
-
LoRA与文本反转的结合:
- 同时优化模型权重和新概念嵌入
- 提供更灵活的概念定制能力
-
渐进式训练策略:
- 可以控制LoRA层解冻的时机
- 支持分阶段优化不同组件
-
高效训练技术:
- 支持梯度检查点节省内存
- 提供8-bit Adam优化器选项
- 支持xFormers内存高效注意力
使用建议
- 对于新概念学习,建议同时启用LoRA和文本反转
- 根据显存大小调整批次大小和梯度累积步数
- 对于风格学习,使用
imagenet_style_templates_small
模板 - 监控学习到的嵌入向量变化以评估训练效果
通过这个脚本,用户可以高效地定制Stable Diffusion模型,使其学习新的视觉概念或风格,同时保持模型原有的大部分能力。