首页
/ 基于Bi-LSTM与注意力机制的文本分类模型实现教程

基于Bi-LSTM与注意力机制的文本分类模型实现教程

2025-07-06 02:09:29作者:钟日瑜

模型概述

本教程将详细介绍如何使用双向长短期记忆网络(Bi-LSTM)结合注意力机制(Attention)实现一个简单的文本分类模型。该模型能够自动学习文本中的重要特征,并对文本情感倾向进行分类判断。

模型架构

1. 嵌入层(Embedding Layer)

模型首先通过嵌入层将离散的单词索引转换为连续的向量表示。这种表示能够捕捉单词之间的语义关系。

self.embedding = nn.Embedding(vocab_size, embedding_dim)

2. Bi-LSTM层

双向LSTM能够同时考虑文本的前向和后向信息,更好地理解上下文关系。

self.lstm = nn.LSTM(embedding_dim, n_hidden, bidirectional=True)

3. 注意力机制(Attention Mechanism)

注意力机制的核心思想是让模型能够自动关注输入序列中对当前任务最重要的部分。

def attention_net(self, lstm_output, final_state):
    hidden = final_state.view(-1, n_hidden * 2, 1)
    attn_weights = torch.bmm(lstm_output, hidden).squeeze(2)
    soft_attn_weights = F.softmax(attn_weights, 1)
    context = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)
    return context, soft_attn_weights.data.numpy()

4. 输出层

最后通过一个全连接层将注意力机制的输出映射到分类结果。

self.out = nn.Linear(n_hidden * 2, num_classes)

实现步骤详解

1. 数据准备

首先定义了一个简单的数据集,包含6个句子和对应的情感标签(1表示正面,0表示负面):

sentences = ["i love you", "he loves me", "she likes baseball", 
             "i hate you", "sorry for that", "this is awful"]
labels = [1, 1, 1, 0, 0, 0]

然后构建词汇表,将单词映射为索引:

word_list = " ".join(sentences).split()
word_list = list(set(word_list))
word_dict = {w: i for i, w in enumerate(word_list)}
vocab_size = len(word_dict)

2. 模型训练

使用交叉熵损失函数和Adam优化器进行模型训练:

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

训练过程中每1000个epoch打印一次损失值:

for epoch in range(5000):
    optimizer.zero_grad()
    output, attention = model(inputs)
    loss = criterion(output, targets)
    if (epoch + 1) % 1000 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
    loss.backward()
    optimizer.step()

3. 模型测试

训练完成后,可以对新的文本进行情感分类预测:

test_text = 'sorry hate you'
tests = [np.asarray([word_dict[n] for n in test_text.split()])]
test_batch = torch.LongTensor(tests)
predict, _ = model(test_batch)

4. 注意力可视化

模型还提供了注意力权重的可视化功能,可以直观地看到模型在分类时关注了哪些词语:

fig = plt.figure(figsize=(6, 3))
ax = fig.add_subplot(1, 1, 1)
ax.matshow(attention, cmap='viridis')
ax.set_xticklabels(['']+['first_word', 'second_word', 'third_word'], 
                  fontdict={'fontsize': 14}, rotation=90)
ax.set_yticklabels(['']+['batch_1', 'batch_2', 'batch_3', 
                        'batch_4', 'batch_5', 'batch_6'], 
                  fontdict={'fontsize': 14})
plt.show()

关键参数说明

  • embedding_dim: 词向量的维度
  • n_hidden: LSTM隐藏层的维度
  • num_classes: 分类类别数(本示例中为2类)
  • vocab_size: 词汇表大小

模型优势

  1. 双向上下文理解:Bi-LSTM能够同时考虑前后文信息
  2. 注意力机制:自动学习文本中的重要部分,提高分类准确性
  3. 可解释性:通过注意力权重可视化,可以理解模型的决策过程

实际应用建议

  1. 对于真实场景,建议使用更大的数据集和预训练词向量
  2. 可以尝试不同的注意力机制变体,如多头注意力
  3. 考虑加入Dropout层防止过拟合
  4. 对于长文本,可以尝试分层注意力机制

通过本教程,读者可以掌握Bi-LSTM与注意力机制结合的基本原理和实现方法,为进一步研究更复杂的NLP模型打下基础。