首页
/ DiffSynth-Studio中的Kolors模型训练与微调指南

DiffSynth-Studio中的Kolors模型训练与微调指南

2025-07-06 08:06:19作者:侯霆垣

概述

Kolors是一个基于ChatGLM和Stable Diffusion XL的中文扩散模型,专注于中文文本到图像生成任务。本文将详细介绍如何在DiffSynth-Studio环境中使用Kolors模型进行训练和微调。

模型准备

模型下载与结构

Kolors模型由多个组件构成,需要下载以下文件:

models
├── kolors
│   └── Kolors
│       ├── text_encoder
│       ├── unet
│       └── vae
└── sdxl-vae-fp16-fix

推荐使用以下Python代码下载模型:

from diffsynth import download_models
download_models(["Kolors", "SDXL-vae-fp16-fix"])

模型组件说明

  1. Text Encoder:基于ChatGLM的文本编码器,负责将中文文本转换为潜在空间表示
  2. UNet:扩散模型的核心组件,负责图像生成过程
  3. VAE:变分自编码器,用于图像潜在空间与像素空间之间的转换
  4. SDXL-vae-fp16-fix:修复了FP16精度问题的VAE模型

训练环境准备

依赖安装

执行以下命令安装训练所需依赖:

pip install peft lightning pandas torchvision

硬件要求

训练Kolors模型需要较高配置的GPU:

  • 推荐显存:≥22GB
  • 支持FP16混合精度训练

数据集准备

数据集结构

训练数据应按以下结构组织:

data/类别名称/
└── train
    ├── 00.jpg
    ├── 01.jpg
    ├── ...
    └── metadata.csv

元数据文件

metadata.csv文件应包含两列:

  • file_name:图片文件名
  • text:对应的文本描述

示例内容:

file_name,text
00.jpg,一只小狗
01.jpg,一只小狗
...

LoRA微调训练

训练脚本使用

Kolors提供了train_kolors_lora.py训练脚本,推荐使用以下参数配置:

CUDA_VISIBLE_DEVICES="0" python train_kolors_lora.py \
  --pretrained_unet_path models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors \
  --pretrained_text_encoder_path models/kolors/Kolors/text_encoder \
  --pretrained_fp16_vae_path models/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors \
  --dataset_path data/dog \
  --output_path ./models \
  --max_epochs 10 \
  --center_crop \
  --use_gradient_checkpointing \
  --precision "16-mixed"

关键参数说明

  1. 模型路径参数

    • pretrained_unet_path:预训练UNet模型路径
    • pretrained_text_encoder_path:预训练文本编码器路径
    • pretrained_fp16_vae_path:FP16修复版VAE路径
  2. 训练参数

    • max_epochs:训练轮数
    • batch_size:批次大小(根据显存调整)
    • learning_rate:学习率(默认5e-5)
    • precision:训练精度(推荐"16-mixed")
  3. LoRA参数

    • lora_rank:LoRA矩阵的秩(默认4)
    • lora_alpha:LoRA更新权重(默认4.0)
  4. 优化参数

    • use_gradient_checkpointing:启用梯度检查点以节省显存
    • accumulate_grad_batches:梯度累积步数

模型推理

加载LoRA模型

训练完成后,可以使用以下代码加载并使用LoRA模型:

from diffsynth import ModelManager, KolorsImagePipeline
from peft import LoraConfig, inject_adapter_in_model
import torch

def load_lora(model, lora_rank, lora_alpha, lora_path):
    lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        init_lora_weights="gaussian",
        target_modules=["to_q", "to_k", "to_v", "to_out"],
    )
    model = inject_adapter_in_model(lora_config, model)
    state_dict = torch.load(lora_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=False)
    return model

# 初始化管道
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
                             file_path_list=[
                                 "models/kolors/Kolors/text_encoder",
                                 "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
                                 "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors"
                             ])
pipe = KolorsImagePipeline.from_model_manager(model_manager)

# 加载LoRA权重
pipe.unet = load_lora(
    pipe.unet,
    lora_rank=4, lora_alpha=4.0,
    lora_path="path/to/your/lora/model/lightning_logs/version_x/checkpoints/epoch=x-step=xxx.ckpt"
)

# 生成图像
torch.manual_seed(0)
image = pipe(
    prompt="一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉",
    negative_prompt="",
    cfg_scale=4,
    num_inference_steps=50, height=1024, width=1024,
)
image.save("image_with_lora.jpg")

生成效果对比

使用相同提示词"一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉":

原始模型 LoRA微调后
生成通用风格图像 生成符合训练数据特征的图像

训练建议

  1. 数据质量:确保训练图片质量高、多样性强
  2. 文本描述:描述应准确反映图片内容
  3. 训练轮数:通常5-10个epoch足够
  4. 学习率:从默认值开始,根据loss变化调整
  5. 监控训练:定期检查生成的样本图像

通过本指南,您应该能够成功地在DiffSynth-Studio环境中使用Kolors模型进行训练和微调。LoRA技术可以高效地实现模型个性化,而无需完整训练整个大模型。