首页
/ 深入解析cloneofsimo/lora项目中的LoRA训练脚本

深入解析cloneofsimo/lora项目中的LoRA训练脚本

2025-07-07 04:03:41作者:董灵辛Dennis

本文将对cloneofsimo/lora项目中的train_lora_w_ti.py训练脚本进行深入解析,帮助读者理解如何使用LoRA(Low-Rank Adaptation)技术结合文本反转(Textual Inversion)来微调Stable Diffusion模型。

脚本概述

train_lora_w_ti.py是一个用于训练LoRA权重并结合文本反转技术的Python脚本,主要功能包括:

  1. 加载预训练的Stable Diffusion模型
  2. 注入可训练的LoRA层到UNet模型中
  3. 实现文本反转训练以学习新的概念
  4. 支持多种训练配置选项

核心功能解析

1. 数据集处理

脚本中定义了DreamBoothTiDataset类来处理训练数据,主要特点包括:

  • 支持实例图像和类别图像的加载
  • 提供多种图像预处理选项:裁剪、颜色抖动、水平翻转等
  • 使用模板生成多样化的提示文本
  • 支持随机属性混合增强训练数据
class DreamBoothTiDataset(Dataset):
    def __init__(self, instance_data_root, learnable_property, placeholder_token, 
                 stochastic_attribute, tokenizer, class_data_root=None, ...):
        # 初始化代码...

2. LoRA注入与训练

脚本使用lora_diffusion模块提供的功能来注入和训练LoRA层:

  • inject_trainable_lora(): 将LoRA层注入到UNet模型中
  • extract_lora_ups_down(): 提取训练好的LoRA权重
  • save_lora_weight(): 保存LoRA权重到文件
# 注入LoRA层
unet = inject_trainable_lora(unet, r=args.lora_rank)

# 训练完成后保存LoRA权重
save_lora_weight(unet, os.path.join(args.output_dir, "lora_weight.pt"))

3. 文本反转训练

脚本支持同时训练文本反转嵌入:

  • 使用特殊的占位符标记(placeholder token)来表示新概念
  • 可以控制文本编码器的训练方式
  • 支持保存学习到的嵌入向量
# 保存文本反转嵌入
learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id]
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
torch.save(learned_embeds_dict, save_path)

训练配置选项

脚本提供了丰富的训练参数配置:

基本参数

  • pretrained_model_name_or_path: 预训练模型路径
  • instance_data_dir: 实例图像目录
  • placeholder_token: 占位符标记
  • learnable_property: 学习属性(style/object)

训练参数

  • lora_rank: LoRA的秩大小
  • train_text_encoder: 是否训练文本编码器
  • learning_rate: 学习率设置
  • unfreeze_lora_step: 解冻LoRA层的训练步数

图像处理参数

  • resolution: 图像分辨率
  • center_crop: 是否中心裁剪
  • color_jitter: 是否使用颜色抖动
  • h_flip: 是否水平翻转

训练流程

  1. 初始化阶段:

    • 加载预训练模型和分词器
    • 准备数据集和数据加载器
    • 注入LoRA层到UNet模型
    • 设置优化器和学习率调度器
  2. 训练循环:

    • 前向传播计算损失
    • 反向传播更新参数
    • 定期保存检查点
    • 支持梯度累积和混合精度训练
  3. 保存结果:

    • 保存训练好的LoRA权重
    • 保存文本反转嵌入
    • 支持多种输出格式(pt/safe/both)

技术亮点

  1. LoRA与文本反转的结合:

    • 同时优化模型权重和新概念嵌入
    • 提供更灵活的概念定制能力
  2. 渐进式训练策略:

    • 可以控制LoRA层解冻的时机
    • 支持分阶段优化不同组件
  3. 高效训练技术:

    • 支持梯度检查点节省内存
    • 提供8-bit Adam优化器选项
    • 支持xFormers内存高效注意力

使用建议

  1. 对于新概念学习,建议同时启用LoRA和文本反转
  2. 根据显存大小调整批次大小和梯度累积步数
  3. 对于风格学习,使用imagenet_style_templates_small模板
  4. 监控学习到的嵌入向量变化以评估训练效果

通过这个脚本,用户可以高效地定制Stable Diffusion模型,使其学习新的视觉概念或风格,同时保持模型原有的大部分能力。