首页
/ PyTorch Text库中的文本分类模型训练教程

PyTorch Text库中的文本分类模型训练教程

2025-07-09 06:19:33作者:侯霆垣

概述

本文将详细介绍如何使用PyTorch Text库中的工具和组件来训练一个文本分类模型。文本分类是自然语言处理中的基础任务,广泛应用于情感分析、新闻分类、垃圾邮件检测等领域。我们将通过一个完整的训练流程,展示如何构建数据管道、定义模型结构、训练和评估模型。

环境准备

在开始之前,确保你已经安装了以下Python库:

  • PyTorch
  • torchtext

数据预处理

1. 数据集加载

PyTorch Text提供了多种内置数据集,如AG_NEWS、DBpedia等。我们可以通过DATASETS字典轻松加载这些数据集:

train_iter = DATASETS[args.dataset](root=data_dir, split="train")

2. 分词处理

文本分类的第一步是将原始文本转换为模型可以处理的token序列。PyTorch Text提供了两种分词方式:

  1. 基础英语分词器:使用basic_english分词器,适合简单的英文文本处理
  2. SentencePiece分词器:更先进的分词方式,支持多种语言
if use_sp_tokenizer:
    sp_model = load_sp_model(sp_model_path)
    tokenizer = SentencePieceTokenizer(sp_model)
else:
    tokenizer = get_tokenizer("basic_english")

3. 词汇表构建

使用build_vocab_from_iterator函数从训练数据中构建词汇表:

vocab = build_vocab_from_iterator(yield_tokens(train_iter, ngrams), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

yield_tokens函数负责生成n-gram tokens,这对于捕捉局部词序信息很有帮助。

数据管道

1. 文本处理管道

定义text_pipeline函数,将原始文本转换为词汇表索引:

def text_pipeline(x):
    return vocab(list(ngrams_iterator(tokenizer(x), ngrams)))

2. 标签处理管道

定义label_pipeline函数,将标签转换为模型需要的格式:

def label_pipeline(x):
    return int(x) - 1  # 通常将标签转换为从0开始的索引

3. 批处理函数

collate_batch函数负责将多个样本组合成一个批次:

def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    # 合并处理后的数据
    label_list = torch.tensor(label_list, dtype=torch.int64)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)
    return label_list.to(device), text_list.to(device), offsets.to(device)

模型定义

文本分类模型通常包含以下组件:

  1. 嵌入层(Embedding):将离散的token索引转换为连续的向量表示
  2. 线性层(Linear):将嵌入向量映射到类别空间
from model import TextClassificationModel

虽然示例中引用了外部模型定义,但典型的文本分类模型结构如下:

class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        
    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

EmbeddingBag特别适合处理变长序列,因为它可以高效地计算嵌入向量的平均值。

训练流程

1. 初始化模型和优化器

model = TextClassificationModel(len(vocab), embed_dim, num_class).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
criterion = torch.nn.CrossEntropyLoss().to(device)

2. 训练循环

def train(dataloader, model, optimizer, criterion, epoch):
    model.train()
    total_acc, total_count = 0, 0
    
    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predited_label = model(text, offsets)
        loss = criterion(predited_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # 梯度裁剪
        optimizer.step()
        # 计算准确率
        total_acc += (predited_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        # 定期打印训练信息
        if idx % log_interval == 0 and idx > 0:
            print(f"| epoch {epoch:3d} | {idx:5d}/{len(dataloader):5d} batches "
                  f"| accuracy {total_acc/total_count:8.3f}")
            total_acc, total_count = 0, 0

3. 评估函数

def evaluate(dataloader, model):
    model.eval()
    total_acc, total_count = 0, 0
    
    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predited_label = model(text, offsets)
            total_acc += (predited_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc / total_count

完整训练过程

for epoch in range(1, num_epochs + 1):
    epoch_start_time = time.time()
    train(train_dataloader, model, optimizer, criterion, epoch)
    accu_val = evaluate(valid_dataloader, model)
    scheduler.step()
    print("-" * 59)
    print(f"| end of epoch {epoch:3d} | time: {time.time()-epoch_start_time:5.2f}s | "
          f"valid accuracy {accu_val:8.3f} ")
    print("-" * 59)

模型保存与应用

训练完成后,我们可以保存模型和词汇表供后续使用:

if args.save_model_path:
    torch.save(model.to("cpu"), args.save_model_path)
    
if args.dictionary is not None:
    torch.save(vocab, args.dictionary)

总结

本文详细介绍了使用PyTorch Text库进行文本分类的完整流程,包括:

  1. 数据加载与预处理
  2. 分词与词汇表构建
  3. 数据管道设计
  4. 模型定义与训练
  5. 评估与模型保存

通过这个示例,你可以快速上手PyTorch Text库,并将其应用于自己的文本分类任务中。根据具体需求,你可以调整模型结构、优化器参数或数据处理方式以获得更好的性能。