基于Bi-LSTM与注意力机制的文本分类模型实现解析
2025-07-06 02:10:34作者:戚魁泉Nursing
模型概述
本文将详细解析一个使用双向LSTM(Bi-LSTM)结合注意力机制(Attention)实现的文本分类模型。该模型属于自然语言处理(NLP)领域中的经典架构,能够有效捕捉文本中的长距离依赖关系,并通过注意力机制突出关键词语对分类决策的影响。
模型架构详解
1. 嵌入层(Embedding Layer)
模型首先通过嵌入层将离散的词语索引转换为连续的向量表示:
self.embedding = nn.Embedding(vocab_size, embedding_dim)
vocab_size
: 词汇表大小embedding_dim
: 词向量维度(本例中设为2)
2. 双向LSTM层
双向LSTM能够同时考虑词语的前向和后向上下文信息:
self.lstm = nn.LSTM(embedding_dim, n_hidden, bidirectional=True)
n_hidden
: 隐藏层单元数(本例中设为5)bidirectional=True
: 启用双向LSTM
3. 注意力机制
注意力机制是该模型的核心创新点,它能够自动学习不同词语对分类结果的重要性权重:
def attention_net(self, lstm_output, final_state):
hidden = final_state.view(-1, n_hidden * 2, 1)
attn_weights = torch.bmm(lstm_output, hidden).squeeze(2)
soft_attn_weights = F.softmax(attn_weights, 1)
context = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)
return context, soft_attn_weights.data.numpy()
注意力计算过程:
- 将LSTM的最终隐藏状态作为注意力查询(query)
- 计算注意力权重(attention weights)
- 应用softmax归一化
- 计算上下文向量(context vector)
4. 输出层
最后通过一个全连接层将上下文向量映射到类别空间:
self.out = nn.Linear(n_hidden * 2, num_classes)
训练过程分析
数据准备
示例使用了6个简单的英文句子作为训练数据,分为正面(1)和负面(0)两类:
sentences = ["i love you", "he loves me", "she likes baseball",
"i hate you", "sorry for that", "this is awful"]
labels = [1, 1, 1, 0, 0, 0]
训练配置
- 损失函数:交叉熵损失(CrossEntropyLoss)
- 优化器:Adam,学习率0.001
- 训练轮次:5000次
训练监控
每1000轮输出一次损失值:
if (epoch + 1) % 1000 == 0:
print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
模型测试与可视化
测试示例
使用"sorry hate you"作为测试输入:
test_text = 'sorry hate you'
tests = [np.asarray([word_dict[n] for n in test_text.split()])]
test_batch = torch.LongTensor(tests)
注意力权重可视化
模型使用matplotlib绘制了注意力权重的热力图:
fig = plt.figure(figsize=(6, 3))
ax = fig.add_subplot(1, 1, 1)
ax.matshow(attention, cmap='viridis')
ax.set_xticklabels(['']+['first_word', 'second_word', 'third_word'],
fontdict={'fontsize': 14}, rotation=90)
ax.set_yticklabels(['']+['batch_1', 'batch_2', 'batch_3', 'batch_4', 'batch_5', 'batch_6'],
fontdict={'fontsize': 14})
plt.show()
模型特点与优势
- 双向上下文捕捉:Bi-LSTM能够同时考虑词语的前后文信息
- 注意力机制:自动学习不同词语的重要性,提高模型可解释性
- 端到端训练:整个模型可以联合训练,无需复杂的特征工程
- 可视化能力:注意力权重可视化帮助理解模型决策过程
实际应用建议
- 对于真实场景,建议使用更大的词向量维度(如300维)
- 可以尝试不同的注意力机制变体,如多头注意力
- 考虑加入预训练词向量(如Word2Vec、GloVe)提升性能
- 对于长文本,可以结合分层注意力机制
这个实现展示了如何将Bi-LSTM与注意力机制结合用于文本分类任务,代码简洁但功能完整,非常适合作为学习注意力机制的入门示例。