Understanding the BERT Model Implementation: A Code Walkthrough Based on the nlp-tutorial Project
2025-07-06 02:15:34 Author: 昌雅子Ethen
Introduction
BERT (Bidirectional Encoder Representations from Transformers) is a landmark model in natural language processing. Through pre-training followed by fine-tuning, it achieved breakthrough results on a wide range of NLP tasks. This article walks through the BERT implementation in the nlp-tutorial project to explain BERT's core mechanisms and implementation details.
Overview of the BERT Model
The core idea of BERT is to use a Transformer encoder, pre-trained on large-scale unlabeled corpora, to learn general-purpose language representations. Its two pre-training tasks are:
- Masked Language Model (MLM): randomly mask part of the tokens in the input sequence and train the model to predict the masked tokens
- Next Sentence Prediction (NSP): decide whether two sentences form a consecutive pair in the original text
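To make the two tasks concrete, here is a purely illustrative example of a single pre-training instance (word-level tokens; the sentence pair and the masking choice are made up for illustration, not taken from the tutorial's corpus):

tokens_a = ['my', 'dog', 'is', 'cute']   # sentence A
tokens_b = ['he', 'likes', 'playing']    # sentence B (here, the sentence that really follows A)

input_tokens = ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]']
segment_ids  = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)

# MLM: roughly 15% of the non-special tokens are selected, e.g. 'dog' -> '[MASK]'
# NSP: the label for this pair is IsNext (True), because B actually follows A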
Code Walkthrough
1. Data Preprocessing
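The snippets below rely on several module-level names (sentences, token_list, word_dict, number_dict, vocab_size, maxlen, batch_size, max_pred and the model sizes) that the tutorial script defines near its top. The following is a minimal sketch of that setup; the toy corpus and the hyperparameter values mirror those commonly used in nlp-tutorial and should be read as assumptions rather than a verbatim copy:

import math
import re
from random import random, randrange, randint, shuffle

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# NOTE: the corpus and the numeric values below are illustrative assumptions.
text = (
    'Hello, how are you? I am Romeo.\n'
    'Hello, Romeo. My name is Juliet. Nice to meet you.\n'
    'Nice to meet you too. How are you today?'
)
sentences = re.sub(r"[.,!?\-]", '', text.lower()).split('\n')   # one sentence per line
word_list = list(set(" ".join(sentences).split()))
word_dict = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}
for i, w in enumerate(word_list):
    word_dict[w] = i + 4
number_dict = {i: w for w, i in word_dict.items()}              # id -> token
vocab_size = len(word_dict)
token_list = [[word_dict[w] for w in s.split()] for s in sentences]

maxlen = 30       # maximum sequence length after padding
batch_size = 6
max_pred = 5      # maximum number of masked tokens predicted per sequence
n_layers = 6      # number of encoder layers
n_heads = 12      # number of attention heads
d_model = 768     # hidden size
d_ff = 768 * 4    # feed-forward inner size
d_k = d_v = 64    # per-head dimension of queries/keys and values
n_segments = 2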
The make_batch() function generates the training data. Its key steps are:
def make_batch():
    batch = []
    positive = negative = 0
    while positive != batch_size/2 or negative != batch_size/2:
        # randomly pick two sentences
        tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences))
        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]

        # build the input sequence and the segment ids
        input_ids = [word_dict['[CLS]']] + tokens_a + [word_dict['[SEP]']] + tokens_b + [word_dict['[SEP]']]
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)

        # masked language model: mask about 15% of the tokens (at least 1, at most max_pred)
        n_pred = min(max_pred, max(1, int(round(len(input_ids) * 0.15))))
        cand_maked_pos = [i for i, token in enumerate(input_ids)
                          if token != word_dict['[CLS]'] and token != word_dict['[SEP]']]
        shuffle(cand_maked_pos)
        masked_tokens, masked_pos = [], []
        for pos in cand_maked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            if random() < 0.8:      # 80% of the selected tokens: replace with [MASK]
                input_ids[pos] = word_dict['[MASK]']
            elif random() < 0.5:    # 10% overall: replace with a random token (remaining 10%: keep unchanged)
                index = randint(0, vocab_size - 1)
                input_ids[pos] = word_dict[number_dict[index]]

        # zero-pad the input sequence and the segment ids to maxlen
        n_pad = maxlen - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        # zero-pad masked_tokens / masked_pos to max_pred so every sample has the same length
        if max_pred > n_pred:
            n_pad = max_pred - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)

        # next sentence prediction label: IsNext if sentence B directly follows sentence A
        if tokens_a_index + 1 == tokens_b_index and positive < batch_size/2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True])    # IsNext
            positive += 1
        elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False])   # NotNext
            negative += 1
    return batch
2. Model Architecture
The BERT model is built from the following components:
2.1 Embedding Layer
class Embedding(nn.Module):
    def __init__(self):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)   # token embedding
        self.pos_embed = nn.Embedding(maxlen, d_model)       # learned position embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)   # segment embedding
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long)
        pos = pos.unsqueeze(0).expand_as(x)   # [seq_len] -> [batch_size, seq_len]
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)
2.2 Scaled Dot-Product Attention
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        # scores: [batch_size, n_heads, len_q, len_k]
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)
        scores.masked_fill_(attn_mask, -1e9)   # suppress padded positions before the softmax
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context, attn
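The forward pass above is a direct implementation of scaled dot-product attention,

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V,$$

where the positions marked in attn_mask are filled with -1e9 before the softmax, so padded tokens end up with (effectively) zero attention weight.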
2.3 Multi-Head Attention
class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        self.fc = nn.Linear(n_heads * d_v, d_model)   # output projection, registered as a submodule so it is trained
        self.norm = nn.LayerNorm(d_model)

    def forward(self, Q, K, V, attn_mask):
        residual, batch_size = Q, Q.size(0)
        # split into heads: [batch_size, len, d_model] -> [batch_size, n_heads, len, d_k or d_v]
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)
        # replicate the mask for every head: [batch_size, n_heads, len_q, len_k]
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
        context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
        # merge the heads back: [batch_size, len_q, n_heads * d_v]
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)
        output = self.fc(context)
        return self.norm(output + residual), attn   # residual connection + layer norm
2.4 Position-wise Feed-Forward Network
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # [batch_size, seq_len, d_model] -> [.., d_ff] -> [.., d_model]
        return self.fc2(gelu(self.fc1(x)))
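The gelu used here is not a call into torch.nn but a small helper defined elsewhere in the tutorial script. A minimal sketch, assuming the erf-based formulation (recent PyTorch versions also provide torch.nn.functional.gelu, which could be used instead):

import math
import torch

def gelu(x):
    # GELU activation, erf formulation (assumed definition of the tutorial's helper)
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))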
2.5 Encoder Layer
class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs, attn
2.6 The Complete BERT Model
class BERT(nn.Module):
    def __init__(self):
        super(BERT, self).__init__()
        self.embedding = Embedding()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        self.fc = nn.Linear(d_model, d_model)
        self.activ1 = nn.Tanh()
        self.linear = nn.Linear(d_model, d_model)
        self.activ2 = gelu
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, 2)
        # the MLM decoder shares its weights with the token embedding
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))

    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)

        # next sentence prediction: classify from the final hidden state of [CLS]
        h_pooled = self.activ1(self.fc(output[:, 0]))
        logits_clsf = self.classifier(h_pooled)

        # masked language model: gather the hidden states at the masked positions
        masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1))
        h_masked = torch.gather(output, 1, masked_pos)
        h_masked = self.norm(self.activ2(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias
        return logits_lm, logits_clsf
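The forward pass also relies on get_attn_pad_mask, another helper defined elsewhere in the script. A minimal sketch consistent with how it is used here (an assumption about the exact implementation: it flags key positions whose token id is 0, i.e. [PAD]):

def get_attn_pad_mask(seq_q, seq_k):
    # sketch of the helper, inferred from its usage above
    batch_size, len_q = seq_q.size()
    _, len_k = seq_k.size()
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)           # [batch_size, 1, len_k], True at padding
    return pad_attn_mask.expand(batch_size, len_q, len_k)   # [batch_size, len_q, len_k]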
Training
model = BERT()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

batch = make_batch()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch))

for epoch in range(100):
    optimizer.zero_grad()
    logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)
    loss_lm = criterion(logits_lm.transpose(1, 2), masked_tokens)   # MLM loss
    loss_lm = (loss_lm.float()).mean()
    loss_clsf = criterion(logits_clsf, isNext)                      # NSP loss
    loss = loss_lm + loss_clsf
    if (epoch + 1) % 10 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
    loss.backward()
    optimizer.step()
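After training, the same forward pass can be reused to inspect what the model has learned. The sketch below is an illustration (modeled on the prediction step in nlp-tutorial, not part of the excerpt above); it checks the model's output for the first example in the batch:

# illustrative sanity check, not part of the training code above
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(batch[0]))
logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)

# masked-token predictions at the masked positions
pred_tokens = logits_lm.max(2)[1][0].tolist()
print('masked tokens (ground truth):', [t for t in masked_tokens[0].tolist() if t != 0])
print('masked tokens (predicted):   ', [t for t in pred_tokens if t != 0])

# next sentence prediction
pred_isnext = logits_clsf.max(1)[1][0].item()
print('isNext (ground truth):', bool(isNext[0].item()))
print('isNext (predicted):   ', bool(pred_isnext))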
Key Implementation Details
- Masked language model:
  - 15% of the tokens in each sequence are randomly selected for masking
  - of those, 80% are replaced with [MASK], 10% are replaced with a random token, and 10% are left unchanged
  - this strategy trains the model both to predict the masked tokens and to handle unmodified input
- Next sentence prediction:
  - classification is performed on the final hidden state of the [CLS] token
  - positive samples are consecutive sentences; negative samples are randomly paired sentences
- Attention mask:
  - the get_attn_pad_mask function (sketched after the BERT class above) marks the padded positions
  - this ensures the model never attends to padding
- Weight sharing:
  - the MLM decoder shares its weight matrix with the token embedding, which is standard practice in BERT; a quick check follows below
  - this reduces the parameter count and can improve training stability
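The weight tying can be verified directly on the model instance created in the training section:

# quick illustrative check of the weight tying
print(model.decoder.weight is model.embedding.tok_embed.weight)   # True: same Parameter object
print(model.decoder.weight.shape)                                 # torch.Size([vocab_size, d_model])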
Summary
Working through the BERT implementation in the nlp-tutorial project gives a concrete picture of:
- BERT's two pre-training tasks (MLM and NSP)
- how a Transformer encoder is implemented
- how the input data is constructed and the model is assembled
- the key details of the training loop
Although simplified, this implementation captures the core ideas of BERT and is a good example for studying how the model works. Understanding this code makes it easier to apply BERT to real NLP tasks and lays a solid foundation for implementing more complex Transformer models.