首页
/ 基于x-transformers的Transformer语言模型训练教程

基于x-transformers的Transformer语言模型训练教程

2025-07-08 01:34:44作者:虞亚竹Luna

本文将以x-transformers库中的train_enwik8.py为例,详细介绍如何使用Transformer架构训练一个字符级的语言模型。我们将从数据准备、模型构建到训练过程进行全面解析,帮助读者理解Transformer模型在文本生成任务中的应用。

环境与数据准备

首先需要准备训练数据,这里使用的是经典的enwik8数据集,它包含了网络百科的前1亿字节数据。代码中通过gzip模块读取压缩文件,并将其转换为numpy数组:

with gzip.open('./data/enwik8.gz') as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
    train_x, valid_x = np.split(data, [int(90e6)])

数据被分割为9000万字节的训练集和500万字节的验证集,随后转换为PyTorch张量。这种字符级处理方式使得模型可以学习生成任意文本,而不受限于特定词汇表。

模型架构设计

x-transformers提供了灵活的Transformer构建方式,这里创建了一个类似GPT的解码器模型:

model = TransformerWrapper(
    num_tokens = 256,  # 字符级,256个可能值
    max_seq_len = SEQ_LEN,
    attn_layers = Decoder(
        dim = 512,      # 模型维度
        depth = 6,      # 6层Transformer
        heads = 8,      # 8头注意力
        rotary_pos_emb = True  # 使用旋转位置编码
    )
)
model = AutoregressiveWrapper(model)

关键组件说明:

  1. TransformerWrapper:处理token嵌入和位置编码
  2. Decoder:标准的Transformer解码器层
  3. AutoregressiveWrapper:添加自回归训练功能
  4. rotary_pos_emb:采用旋转位置编码,相比传统位置编码有更好的外推能力

数据加载与批处理

自定义TextSamplerDataset类实现了滑动窗口采样,每次随机选取一段连续文本:

class TextSamplerDataset(Dataset):
    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq

这种设计确保了:

  • 每次训练都能看到不同的文本片段
  • 序列长度固定为1024,适合GPU内存
  • 数据利用率最大化

训练流程详解

训练过程采用了多项优化技术:

  1. 梯度累积:每4个批次更新一次参数,模拟更大的batch size
for __ in range(GRADIENT_ACCUMULATE_EVERY):
    loss = model(next(train_loader))
    (loss / GRADIENT_ACCUMULATE_EVERY).backward()
  1. 梯度裁剪:防止梯度爆炸
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
  1. 周期性验证:每100批次在验证集上评估
if i % VALIDATE_EVERY == 0:
    model.eval()
    with torch.no_grad():
        loss = model(next(val_loader))
  1. 文本生成测试:每500批次生成样例文本,直观观察模型进步
sample = model.generate(
    prompts = inp,
    seq_len = GENERATE_LENGTH,
    cache_kv = True  # 使用KV缓存加速生成
)

关键训练参数

  • 学习率:1e-4(使用Adam优化器)
  • 批量大小:4(实际等效批量16,因梯度累积)
  • 序列长度:1024
  • 训练步数:100,000

模型生成能力

生成阶段展示了Transformer的核心能力 - 给定前缀续写文本。代码中随机选取验证集片段作为提示(prompt),然后让模型生成后续文本:

inp = random.choice(val_dataset)[:-1]
prime = decode_tokens(inp)
sample = model.generate(prompts=inp, seq_len=GENERATE_LENGTH)

这种自回归生成方式与人类写作过程类似,逐个字符预测,将前一步输出作为下一步输入。

总结

通过这个示例,我们学习了:

  1. 如何使用x-transformers构建Transformer语言模型
  2. 字符级语言模型的训练方法
  3. 训练过程中的各种优化技巧
  4. Transformer模型的文本生成能力

该框架可以轻松扩展到其他文本数据集或调整模型架构,是学习现代NLP技术的优秀起点。读者可以尝试调整模型深度、注意力头数等超参数,或更换更大的数据集来进一步提升模型性能。