首页
/ 基于Bi-LSTM与注意力机制的文本分类模型实现解析

基于Bi-LSTM与注意力机制的文本分类模型实现解析

2025-07-06 02:10:34作者:戚魁泉Nursing

模型概述

本文将详细解析一个使用双向LSTM(Bi-LSTM)结合注意力机制(Attention)实现的文本分类模型。该模型属于自然语言处理(NLP)领域中的经典架构,能够有效捕捉文本中的长距离依赖关系,并通过注意力机制突出关键词语对分类决策的影响。

模型架构详解

1. 嵌入层(Embedding Layer)

模型首先通过嵌入层将离散的词语索引转换为连续的向量表示:

self.embedding = nn.Embedding(vocab_size, embedding_dim)
  • vocab_size: 词汇表大小
  • embedding_dim: 词向量维度(本例中设为2)

2. 双向LSTM层

双向LSTM能够同时考虑词语的前向和后向上下文信息:

self.lstm = nn.LSTM(embedding_dim, n_hidden, bidirectional=True)
  • n_hidden: 隐藏层单元数(本例中设为5)
  • bidirectional=True: 启用双向LSTM

3. 注意力机制

注意力机制是该模型的核心创新点,它能够自动学习不同词语对分类结果的重要性权重:

def attention_net(self, lstm_output, final_state):
    hidden = final_state.view(-1, n_hidden * 2, 1)
    attn_weights = torch.bmm(lstm_output, hidden).squeeze(2)
    soft_attn_weights = F.softmax(attn_weights, 1)
    context = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)
    return context, soft_attn_weights.data.numpy()

注意力计算过程:

  1. 将LSTM的最终隐藏状态作为注意力查询(query)
  2. 计算注意力权重(attention weights)
  3. 应用softmax归一化
  4. 计算上下文向量(context vector)

4. 输出层

最后通过一个全连接层将上下文向量映射到类别空间:

self.out = nn.Linear(n_hidden * 2, num_classes)

训练过程分析

数据准备

示例使用了6个简单的英文句子作为训练数据,分为正面(1)和负面(0)两类:

sentences = ["i love you", "he loves me", "she likes baseball", 
             "i hate you", "sorry for that", "this is awful"]
labels = [1, 1, 1, 0, 0, 0]

训练配置

  • 损失函数:交叉熵损失(CrossEntropyLoss)
  • 优化器:Adam,学习率0.001
  • 训练轮次:5000次

训练监控

每1000轮输出一次损失值:

if (epoch + 1) % 1000 == 0:
    print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))

模型测试与可视化

测试示例

使用"sorry hate you"作为测试输入:

test_text = 'sorry hate you'
tests = [np.asarray([word_dict[n] for n in test_text.split()])]
test_batch = torch.LongTensor(tests)

注意力权重可视化

模型使用matplotlib绘制了注意力权重的热力图:

fig = plt.figure(figsize=(6, 3))
ax = fig.add_subplot(1, 1, 1)
ax.matshow(attention, cmap='viridis')
ax.set_xticklabels(['']+['first_word', 'second_word', 'third_word'], 
                   fontdict={'fontsize': 14}, rotation=90)
ax.set_yticklabels(['']+['batch_1', 'batch_2', 'batch_3', 'batch_4', 'batch_5', 'batch_6'], 
                   fontdict={'fontsize': 14})
plt.show()

模型特点与优势

  1. 双向上下文捕捉:Bi-LSTM能够同时考虑词语的前后文信息
  2. 注意力机制:自动学习不同词语的重要性,提高模型可解释性
  3. 端到端训练:整个模型可以联合训练,无需复杂的特征工程
  4. 可视化能力:注意力权重可视化帮助理解模型决策过程

实际应用建议

  1. 对于真实场景,建议使用更大的词向量维度(如300维)
  2. 可以尝试不同的注意力机制变体,如多头注意力
  3. 考虑加入预训练词向量(如Word2Vec、GloVe)提升性能
  4. 对于长文本,可以结合分层注意力机制

这个实现展示了如何将Bi-LSTM与注意力机制结合用于文本分类任务,代码简洁但功能完整,非常适合作为学习注意力机制的入门示例。