首页
/ NLP教程:基于Skip-gram和Softmax的Word2Vec实现详解

NLP教程:基于Skip-gram和Softmax的Word2Vec实现详解

2025-07-06 01:54:24作者:管翌锬

前言

Word2Vec是自然语言处理领域中最著名的词嵌入技术之一,它能够将词语映射到低维向量空间,同时保留词语之间的语义关系。本文将深入解析一个使用PyTorch实现Skip-gram模型并结合Softmax分类器的Word2Vec实现。

核心概念解析

Skip-gram模型

Skip-gram是Word2Vec的两种主要架构之一(另一种是CBOW)。它的核心思想是通过中心词预测上下文词,这与人类理解语言的方式相似——通过一个词来联想它可能出现的上下文环境。

Softmax分类器

在传统的Word2Vec实现中,Softmax函数用于计算给定中心词时上下文词的概率分布。虽然计算复杂度较高,但对于小型词汇表或教学目的来说,它是最直观的实现方式。

代码实现详解

1. 数据准备

首先定义了一些简单的句子作为训练数据:

sentences = ["apple banana fruit", "banana orange fruit", 
             "orange banana fruit", "dog cat animal", 
             "cat monkey animal", "monkey dog animal"]

这些句子被处理成单词序列,并建立词汇表字典:

word_sequence = " ".join(sentences).split()
word_list = list(set(word_sequence))
word_dict = {w: i for i, w in enumerate(word_list)}

2. Skip-gram样本生成

通过滑动窗口的方式生成训练样本:

skip_grams = []
for i in range(1, len(word_sequence) - 1):
    target = word_dict[word_sequence[i]]  # 中心词
    context = [word_dict[word_sequence[i - 1]], word_dict[word_sequence[i + 1]]]  # 上下文词
    for w in context:
        skip_grams.append([target, w])

3. 模型架构

定义了一个简单的神经网络模型:

class Word2Vec(nn.Module):
    def __init__(self):
        super(Word2Vec, self).__init__()
        self.W = nn.Linear(voc_size, embedding_size, bias=False)
        self.WT = nn.Linear(embedding_size, voc_size, bias=False)
        
    def forward(self, X):
        hidden_layer = self.W(X)
        output_layer = self.WT(hidden_layer)
        return output_layer

这个模型包含两个线性层:

  • 第一个线性层将one-hot编码的输入词映射到低维嵌入空间
  • 第二个线性层将嵌入向量映射回词汇表大小的空间

4. 训练过程

训练过程采用随机小批量梯度下降:

model = Word2Vec()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(5000):
    input_batch, target_batch = random_batch()
    # 前向传播、计算损失、反向传播等步骤
    ...

5. 结果可视化

训练完成后,将学习到的词向量可视化:

for i, label in enumerate(word_list):
    W, WT = model.parameters()
    x, y = W[0][i].item(), W[1][i].item()
    plt.scatter(x, y)
    plt.annotate(label, xy=(x, y), xytext=(5, 2), 
                 textcoords='offset points', ha='right', va='bottom')
plt.show()

关键点解析

  1. 嵌入维度选择:代码中设置为2维(embedding_size=2),这是为了可视化方便,实际应用中通常使用50-300维。

  2. 负采样与Softmax:原始Word2Vec论文使用负采样来提高效率,但本实现使用完整的Softmax,更适合教学目的。

  3. 权重共享:在更高效的实现中,输入和输出矩阵通常共享权重,但本实现保持了两个独立的矩阵。

  4. 批处理策略:使用小批量训练(batch_size=2),每次随机选择样本进行训练。

实际应用建议

  1. 对于大规模语料,建议使用负采样或层次Softmax来替代完整Softmax
  2. 增加嵌入维度可以提高模型表达能力
  3. 可以使用更大的上下文窗口来捕获更广泛的语义关系
  4. 添加子采样技术可以降低高频词的影响

总结

本教程展示了一个简洁但完整的Word2Vec实现,涵盖了从数据准备、模型构建到训练和可视化的全过程。虽然实现简单,但它清晰地展示了Skip-gram模型的核心思想和工作原理,是理解词嵌入技术的优秀起点。