PyTorch Text库中的文本分类模型训练教程
2025-07-09 06:19:33作者:侯霆垣
概述
本文将详细介绍如何使用PyTorch Text库中的工具和组件来训练一个文本分类模型。文本分类是自然语言处理中的基础任务,广泛应用于情感分析、新闻分类、垃圾邮件检测等领域。我们将通过一个完整的训练流程,展示如何构建数据管道、定义模型结构、训练和评估模型。
环境准备
在开始之前,确保你已经安装了以下Python库:
- PyTorch
- torchtext
数据预处理
1. 数据集加载
PyTorch Text提供了多种内置数据集,如AG_NEWS、DBpedia等。我们可以通过DATASETS
字典轻松加载这些数据集:
train_iter = DATASETS[args.dataset](root=data_dir, split="train")
2. 分词处理
文本分类的第一步是将原始文本转换为模型可以处理的token序列。PyTorch Text提供了两种分词方式:
- 基础英语分词器:使用
basic_english
分词器,适合简单的英文文本处理 - SentencePiece分词器:更先进的分词方式,支持多种语言
if use_sp_tokenizer:
sp_model = load_sp_model(sp_model_path)
tokenizer = SentencePieceTokenizer(sp_model)
else:
tokenizer = get_tokenizer("basic_english")
3. 词汇表构建
使用build_vocab_from_iterator
函数从训练数据中构建词汇表:
vocab = build_vocab_from_iterator(yield_tokens(train_iter, ngrams), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])
yield_tokens
函数负责生成n-gram tokens,这对于捕捉局部词序信息很有帮助。
数据管道
1. 文本处理管道
定义text_pipeline
函数,将原始文本转换为词汇表索引:
def text_pipeline(x):
return vocab(list(ngrams_iterator(tokenizer(x), ngrams)))
2. 标签处理管道
定义label_pipeline
函数,将标签转换为模型需要的格式:
def label_pipeline(x):
return int(x) - 1 # 通常将标签转换为从0开始的索引
3. 批处理函数
collate_batch
函数负责将多个样本组合成一个批次:
def collate_batch(batch):
label_list, text_list, offsets = [], [], [0]
for (_label, _text) in batch:
label_list.append(label_pipeline(_label))
processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
text_list.append(processed_text)
offsets.append(processed_text.size(0))
# 合并处理后的数据
label_list = torch.tensor(label_list, dtype=torch.int64)
offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
text_list = torch.cat(text_list)
return label_list.to(device), text_list.to(device), offsets.to(device)
模型定义
文本分类模型通常包含以下组件:
- 嵌入层(Embedding):将离散的token索引转换为连续的向量表示
- 线性层(Linear):将嵌入向量映射到类别空间
from model import TextClassificationModel
虽然示例中引用了外部模型定义,但典型的文本分类模型结构如下:
class TextClassificationModel(nn.Module):
def __init__(self, vocab_size, embed_dim, num_class):
super().__init__()
self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
self.fc = nn.Linear(embed_dim, num_class)
def forward(self, text, offsets):
embedded = self.embedding(text, offsets)
return self.fc(embedded)
EmbeddingBag
特别适合处理变长序列,因为它可以高效地计算嵌入向量的平均值。
训练流程
1. 初始化模型和优化器
model = TextClassificationModel(len(vocab), embed_dim, num_class).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
criterion = torch.nn.CrossEntropyLoss().to(device)
2. 训练循环
def train(dataloader, model, optimizer, criterion, epoch):
model.train()
total_acc, total_count = 0, 0
for idx, (label, text, offsets) in enumerate(dataloader):
optimizer.zero_grad()
predited_label = model(text, offsets)
loss = criterion(predited_label, label)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) # 梯度裁剪
optimizer.step()
# 计算准确率
total_acc += (predited_label.argmax(1) == label).sum().item()
total_count += label.size(0)
# 定期打印训练信息
if idx % log_interval == 0 and idx > 0:
print(f"| epoch {epoch:3d} | {idx:5d}/{len(dataloader):5d} batches "
f"| accuracy {total_acc/total_count:8.3f}")
total_acc, total_count = 0, 0
3. 评估函数
def evaluate(dataloader, model):
model.eval()
total_acc, total_count = 0, 0
with torch.no_grad():
for idx, (label, text, offsets) in enumerate(dataloader):
predited_label = model(text, offsets)
total_acc += (predited_label.argmax(1) == label).sum().item()
total_count += label.size(0)
return total_acc / total_count
完整训练过程
for epoch in range(1, num_epochs + 1):
epoch_start_time = time.time()
train(train_dataloader, model, optimizer, criterion, epoch)
accu_val = evaluate(valid_dataloader, model)
scheduler.step()
print("-" * 59)
print(f"| end of epoch {epoch:3d} | time: {time.time()-epoch_start_time:5.2f}s | "
f"valid accuracy {accu_val:8.3f} ")
print("-" * 59)
模型保存与应用
训练完成后,我们可以保存模型和词汇表供后续使用:
if args.save_model_path:
torch.save(model.to("cpu"), args.save_model_path)
if args.dictionary is not None:
torch.save(vocab, args.dictionary)
总结
本文详细介绍了使用PyTorch Text库进行文本分类的完整流程,包括:
- 数据加载与预处理
- 分词与词汇表构建
- 数据管道设计
- 模型定义与训练
- 评估与模型保存
通过这个示例,你可以快速上手PyTorch Text库,并将其应用于自己的文本分类任务中。根据具体需求,你可以调整模型结构、优化器参数或数据处理方式以获得更好的性能。