Microsoft UniLM中的s2s-ft:序列到序列微调技术详解
2025-07-05 08:13:43作者:董宙帆
项目概述
s2s-ft是Microsoft UniLM项目中的一个重要组件,它是一个基于PyTorch的工具包,专门用于对预训练的Transformer模型进行序列到序列(Sequence-to-Sequence)语言生成任务的微调。该项目支持多种预训练模型,包括不同大小和配置的UniLM模型,能够广泛应用于文本摘要、机器翻译等自然语言生成任务。
环境配置
推荐环境
建议使用Docker容器来运行代码,确保环境一致性:
docker run -it --rm --runtime=nvidia --ipc=host --privileged pytorch/pytorch:1.2-cuda10.0-cudnn7-devel bash
依赖安装
在容器内需要安装以下Python包:
pip install --user methodtools py-rouge pyrouge nltk
python -c "import nltk; nltk.download('punkt')"
git clone apex仓库 && cd apex && git reset --hard de6378f5dae8fcf2879a4be8ecea8bbcb9e59d5 && python setup.py install --cuda_ext --cpp_ext
最后将项目安装为可编辑包:
cd ${code_dir} ; pip install --editable .
预训练模型选择
项目支持多种预训练模型,根据任务需求可选择:
-
基础模型(推荐使用无大小写敏感版本):
- unilm1.2-base-uncased:12层,768隐藏单元,12头,1.1亿参数
- unilm2-base-uncased:12层,768隐藏单元,12头,1.1亿参数
-
大小写敏感模型:
- unilm1-base-cased:12层,768隐藏单元,12头,1.1亿参数
- unilm1-large-cased:24层,1024隐藏单元,16头,3.4亿参数
- unilm2-large-uncased:24层,1024隐藏单元,16头,3.4亿参数
- unilm2-large-cased:24层,1024隐藏单元,16头,3.4亿参数
-
轻量级模型(推理速度更快):
- minilm-l12-h384-uncased:12层,384隐藏单元,12头,3300万参数
输入数据格式
项目支持两种数据格式:
1. 文本格式
每行包含一个JSON对象,其中:
"src"
字段:源序列文本"tgt"
字段:目标序列文本(解码时可忽略)
示例:
{"src": "Messages posted on social media...", "tgt": "Threats to kill pupils..."}
2. 分词格式
使用与BERT相同的WordPiece分词器预处理后的数据:
"src"
字段:源序列分词列表"tgt"
字段:目标序列分词列表(解码时可忽略)
示例:
{"src": ["messages", "posted", ...], "tgt": ["threats", "to", ...]}
系统会自动检测输入格式,如果JSON行包含list
则按分词格式处理,包含string
则自动进行分词。
实战示例:XSum数据集微调
1. 使用unilm1.2-base-uncased模型
微调训练
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m torch.distributed.launch --nproc_per_node=4 run_seq2seq.py \
--train_file ${TRAIN_FILE} --output_dir ${OUTPUT_DIR} \
--model_type unilm --model_name_or_path unilm1.2-base-uncased \
--do_lower_case --fp16 --fp16_opt_level O2 \
--max_source_seq_length 464 --max_target_seq_length 48 \
--per_gpu_train_batch_size 16 --gradient_accumulation_steps 1 \
--learning_rate 7e-5 --num_warmup_steps 500 --num_training_steps 32000
关键参数说明:
- 总batch size = GPU数量 × 单GPU batch size × 梯度累积步数
--do_lower_case
:无大小写敏感模型必须启用--fp16
:启用混合精度训练加速
解码生成
python decode_seq2seq.py \
--fp16 --model_type unilm --tokenizer_name unilm1.2-base-uncased \
--input_file ${INPUT_JSON} --split $SPLIT --do_lower_case \
--model_path ${MODEL_PATH} --max_seq_length 512 --max_tgt_length 48 \
--batch_size 32 --beam_size 5 --length_penalty 0
生成结果保存在${MODEL_PATH}.${SPLIT}
文件中。
评估
使用专用评估脚本计算ROUGE分数:
python evaluations/eval_for_xsum.py \
--pred ${MODEL_PATH}.${SPLIT} --gold ${GOLD_PATH} --split ${SPLIT}
2. 使用minilm-l12-h384-uncased轻量模型
轻量模型训练参数略有不同:
python -m torch.distributed.launch --nproc_per_node=4 run_seq2seq.py \
--train_file ${TRAIN_FILE} --output_dir ${OUTPUT_DIR} \
--model_type minilm --model_name_or_path minilm-l12-h384-uncased \
--do_lower_case --fp16 --fp16_opt_level O2 \
--max_source_seq_length 464 --max_target_seq_length 48 \
--per_gpu_train_batch_size 16 --gradient_accumulation_steps 1 \
--learning_rate 1e-4 --num_warmup_steps 500 --num_training_steps 108000
CNN/Daily Mail数据集实战
1. 使用unilm1-base-cased模型
微调训练
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m torch.distributed.launch --nproc_per_node=4 run_seq2seq.py \
--train_file $TRAIN_FILE --cached_train_features_file $CACHED_FEATURE_FILE \
--output_dir $OUTPUT_DIR --model_type unilm --model_name_or_path unilm1-base-cased \
--fp16 --fp16_opt_level O2 --max_source_seq_length 608 --max_target_seq_length 160 \
--per_gpu_train_batch_size 8 --gradient_accumulation_steps 2 \
--learning_rate 7e-5 --num_warmup_steps 1000 --num_training_steps 45000
评估
使用专用评估脚本:
python evaluations/eval_for_cnndm.py \
--pred ${MODEL_PATH}.${SPLIT} --gold ${GOLD_PATH} --split ${SPLIT} --trunc_len 160
2. 使用unilm2-base-uncased模型
新增目标片段词丢弃概率参数:
python -m torch.distributed.launch --nproc_per_node=4 run_seq2seq.py \
--train_file $TRAIN_FILE --output_dir $OUTPUT_DIR \
--model_type unilm --model_name_or_path unilm2-base-uncased --do_lower_case \
--fp16 --fp16_opt_level O2 --max_source_seq_length 720 --max_target_seq_length 48 \
--per_gpu_train_batch_size 8 --gradient_accumulation_steps 2 \
--learning_rate 7e-5 --num_warmup_steps 1000 --num_training_steps 48000 \
--target_mask_prob 0.4
关键新增参数:
--target_mask_prob
:目标片段词丢弃概率,XSum推荐0.4-0.5,CNN/DM推荐0.7-0.8
技术原理
s2s-ft的核心思想是通过微调预训练的UniLM模型,使其适应特定的序列到序列生成任务。UniLM(Unified Language Model)是一种统一的语言模型,通过不同的注意力掩码机制,可以同时支持自编码、自回归和序列到序列三种任务。
在微调过程中,模型会:
- 对输入序列进行编码
- 使用特定的注意力掩码机制生成目标序列
- 通过最大似然估计优化生成质量
项目支持混合精度训练、多GPU分布式训练等优化技术,大幅提高了训练效率。
应用建议
-
模型选择:
- 对质量要求高:选择large版本模型
- 对推理速度要求高:选择minilm轻量模型
- 英文任务:优先考虑uncased版本
-
参数调优:
- 学习率:base模型5e-5~1e-4,large模型1e-5~3e-5
- batch size:根据显存调整,保持总batch size稳定
- 序列长度:根据任务特点调整max_source_seq_length和max_target_seq_length
-
训练技巧:
- 使用fp16混合精度训练加速
- 适当使用梯度累积模拟更大batch size
- 根据任务特点调整target_mask_prob参数
s2s-ft项目为序列到序列任务提供了强大的基础框架,通过合理的配置和微调,可以在各种文本生成任务上取得优异的效果。