DiffSynth-Studio中的Kolors模型训练与微调指南
2025-07-06 08:06:19作者:侯霆垣
概述
Kolors是一个基于ChatGLM和Stable Diffusion XL的中文扩散模型,专注于中文文本到图像生成任务。本文将详细介绍如何在DiffSynth-Studio环境中使用Kolors模型进行训练和微调。
模型准备
模型下载与结构
Kolors模型由多个组件构成,需要下载以下文件:
models
├── kolors
│ └── Kolors
│ ├── text_encoder
│ ├── unet
│ └── vae
└── sdxl-vae-fp16-fix
推荐使用以下Python代码下载模型:
from diffsynth import download_models
download_models(["Kolors", "SDXL-vae-fp16-fix"])
模型组件说明
- Text Encoder:基于ChatGLM的文本编码器,负责将中文文本转换为潜在空间表示
- UNet:扩散模型的核心组件,负责图像生成过程
- VAE:变分自编码器,用于图像潜在空间与像素空间之间的转换
- SDXL-vae-fp16-fix:修复了FP16精度问题的VAE模型
训练环境准备
依赖安装
执行以下命令安装训练所需依赖:
pip install peft lightning pandas torchvision
硬件要求
训练Kolors模型需要较高配置的GPU:
- 推荐显存:≥22GB
- 支持FP16混合精度训练
数据集准备
数据集结构
训练数据应按以下结构组织:
data/类别名称/
└── train
├── 00.jpg
├── 01.jpg
├── ...
└── metadata.csv
元数据文件
metadata.csv
文件应包含两列:
file_name
:图片文件名text
:对应的文本描述
示例内容:
file_name,text
00.jpg,一只小狗
01.jpg,一只小狗
...
LoRA微调训练
训练脚本使用
Kolors提供了train_kolors_lora.py
训练脚本,推荐使用以下参数配置:
CUDA_VISIBLE_DEVICES="0" python train_kolors_lora.py \
--pretrained_unet_path models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors \
--pretrained_text_encoder_path models/kolors/Kolors/text_encoder \
--pretrained_fp16_vae_path models/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors \
--dataset_path data/dog \
--output_path ./models \
--max_epochs 10 \
--center_crop \
--use_gradient_checkpointing \
--precision "16-mixed"
关键参数说明
-
模型路径参数:
pretrained_unet_path
:预训练UNet模型路径pretrained_text_encoder_path
:预训练文本编码器路径pretrained_fp16_vae_path
:FP16修复版VAE路径
-
训练参数:
max_epochs
:训练轮数batch_size
:批次大小(根据显存调整)learning_rate
:学习率(默认5e-5)precision
:训练精度(推荐"16-mixed")
-
LoRA参数:
lora_rank
:LoRA矩阵的秩(默认4)lora_alpha
:LoRA更新权重(默认4.0)
-
优化参数:
use_gradient_checkpointing
:启用梯度检查点以节省显存accumulate_grad_batches
:梯度累积步数
模型推理
加载LoRA模型
训练完成后,可以使用以下代码加载并使用LoRA模型:
from diffsynth import ModelManager, KolorsImagePipeline
from peft import LoraConfig, inject_adapter_in_model
import torch
def load_lora(model, lora_rank, lora_alpha, lora_path):
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
init_lora_weights="gaussian",
target_modules=["to_q", "to_k", "to_v", "to_out"],
)
model = inject_adapter_in_model(lora_config, model)
state_dict = torch.load(lora_path, map_location="cpu")
model.load_state_dict(state_dict, strict=False)
return model
# 初始化管道
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
file_path_list=[
"models/kolors/Kolors/text_encoder",
"models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
"models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors"
])
pipe = KolorsImagePipeline.from_model_manager(model_manager)
# 加载LoRA权重
pipe.unet = load_lora(
pipe.unet,
lora_rank=4, lora_alpha=4.0,
lora_path="path/to/your/lora/model/lightning_logs/version_x/checkpoints/epoch=x-step=xxx.ckpt"
)
# 生成图像
torch.manual_seed(0)
image = pipe(
prompt="一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉",
negative_prompt="",
cfg_scale=4,
num_inference_steps=50, height=1024, width=1024,
)
image.save("image_with_lora.jpg")
生成效果对比
使用相同提示词"一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉":
原始模型 | LoRA微调后 |
---|---|
生成通用风格图像 | 生成符合训练数据特征的图像 |
训练建议
- 数据质量:确保训练图片质量高、多样性强
- 文本描述:描述应准确反映图片内容
- 训练轮数:通常5-10个epoch足够
- 学习率:从默认值开始,根据loss变化调整
- 监控训练:定期检查生成的样本图像
通过本指南,您应该能够成功地在DiffSynth-Studio环境中使用Kolors模型进行训练和微调。LoRA技术可以高效地实现模型个性化,而无需完整训练整个大模型。