基于x-transformers的Transformer语言模型训练教程
2025-07-08 01:34:44作者:虞亚竹Luna
本文将以x-transformers库中的train_enwik8.py为例,详细介绍如何使用Transformer架构训练一个字符级的语言模型。我们将从数据准备、模型构建到训练过程进行全面解析,帮助读者理解Transformer模型在文本生成任务中的应用。
环境与数据准备
首先需要准备训练数据,这里使用的是经典的enwik8数据集,它包含了网络百科的前1亿字节数据。代码中通过gzip模块读取压缩文件,并将其转换为numpy数组:
with gzip.open('./data/enwik8.gz') as file:
data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
train_x, valid_x = np.split(data, [int(90e6)])
数据被分割为9000万字节的训练集和500万字节的验证集,随后转换为PyTorch张量。这种字符级处理方式使得模型可以学习生成任意文本,而不受限于特定词汇表。
模型架构设计
x-transformers提供了灵活的Transformer构建方式,这里创建了一个类似GPT的解码器模型:
model = TransformerWrapper(
num_tokens = 256, # 字符级,256个可能值
max_seq_len = SEQ_LEN,
attn_layers = Decoder(
dim = 512, # 模型维度
depth = 6, # 6层Transformer
heads = 8, # 8头注意力
rotary_pos_emb = True # 使用旋转位置编码
)
)
model = AutoregressiveWrapper(model)
关键组件说明:
TransformerWrapper
:处理token嵌入和位置编码Decoder
:标准的Transformer解码器层AutoregressiveWrapper
:添加自回归训练功能rotary_pos_emb
:采用旋转位置编码,相比传统位置编码有更好的外推能力
数据加载与批处理
自定义TextSamplerDataset
类实现了滑动窗口采样,每次随机选取一段连续文本:
class TextSamplerDataset(Dataset):
def __getitem__(self, index):
rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
return full_seq
这种设计确保了:
- 每次训练都能看到不同的文本片段
- 序列长度固定为1024,适合GPU内存
- 数据利用率最大化
训练流程详解
训练过程采用了多项优化技术:
- 梯度累积:每4个批次更新一次参数,模拟更大的batch size
for __ in range(GRADIENT_ACCUMULATE_EVERY):
loss = model(next(train_loader))
(loss / GRADIENT_ACCUMULATE_EVERY).backward()
- 梯度裁剪:防止梯度爆炸
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
- 周期性验证:每100批次在验证集上评估
if i % VALIDATE_EVERY == 0:
model.eval()
with torch.no_grad():
loss = model(next(val_loader))
- 文本生成测试:每500批次生成样例文本,直观观察模型进步
sample = model.generate(
prompts = inp,
seq_len = GENERATE_LENGTH,
cache_kv = True # 使用KV缓存加速生成
)
关键训练参数
- 学习率:1e-4(使用Adam优化器)
- 批量大小:4(实际等效批量16,因梯度累积)
- 序列长度:1024
- 训练步数:100,000
模型生成能力
生成阶段展示了Transformer的核心能力 - 给定前缀续写文本。代码中随机选取验证集片段作为提示(prompt),然后让模型生成后续文本:
inp = random.choice(val_dataset)[:-1]
prime = decode_tokens(inp)
sample = model.generate(prompts=inp, seq_len=GENERATE_LENGTH)
这种自回归生成方式与人类写作过程类似,逐个字符预测,将前一步输出作为下一步输入。
总结
通过这个示例,我们学习了:
- 如何使用x-transformers构建Transformer语言模型
- 字符级语言模型的训练方法
- 训练过程中的各种优化技巧
- Transformer模型的文本生成能力
该框架可以轻松扩展到其他文本数据集或调整模型架构,是学习现代NLP技术的优秀起点。读者可以尝试调整模型深度、注意力头数等超参数,或更换更大的数据集来进一步提升模型性能。