
Understanding the BERT Model Implementation: A Code Walkthrough of the nlp-tutorial Project

2025-07-06 02:15:34 · Author: 昌雅子Ethen

Introduction

BERT (Bidirectional Encoder Representations from Transformers) is a landmark model in natural language processing: through pretraining followed by fine-tuning, it achieved breakthrough results across a wide range of NLP tasks. This article walks through the BERT implementation in the nlp-tutorial project to explain BERT's core mechanisms and implementation details.

Overview of the BERT Model

The core idea of BERT is to use a Transformer encoder, pretrained on large amounts of unlabeled text, to learn general-purpose language representations. Its two main pretraining tasks are:

  1. Masked Language Model (MLM): randomly mask some of the tokens in the input sequence and train the model to predict the masked tokens
  2. Next Sentence Prediction (NSP): decide whether two sentences form a consecutive context
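As a concrete, made-up illustration (not taken from the project's corpus), a single pretraining instance produced by these two tasks might look like this:

# a made-up pretraining example, not from the project's corpus
# original pair: "my dog is cute" + "he likes playing"  (B directly follows A)
tokens        = ['[CLS]', 'my', 'dog', 'is', '[MASK]', '[SEP]', 'he', 'likes', 'playing', '[SEP]']
segment_ids   = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]   # 0 = sentence A (incl. [CLS] and first [SEP]), 1 = sentence B
masked_pos    = [4]        # index of the masked token
masked_tokens = ['cute']   # target the MLM head must recover
is_next       = True       # NSP label: sentence B really follows sentence A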

Code Walkthrough

1. Data Preprocessing

The make_batch() function is responsible for generating the training data.
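The listing below relies on imports and module-level globals (sentences, token_list, word_dict, number_dict, batch_size, max_pred, maxlen, vocab_size) defined earlier in the project file. A minimal sketch of that setup, with illustrative hyperparameter values and a toy corpus (assumptions for readability; the project's actual corpus and values may differ):

import re
from random import randrange, shuffle, random, randint

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# illustrative hyperparameters (assumed values, not necessarily the project's)
maxlen = 30          # maximum sequence length after padding
batch_size = 6
max_pred = 5         # maximum number of masked tokens predicted per sequence
n_layers = 6         # number of encoder layers
n_heads = 12         # number of attention heads
d_model = 768        # hidden size
d_ff = 4 * d_model   # feed-forward hidden size
d_k = d_v = 64       # per-head dimension of K (= Q) and V
n_segments = 2       # sentence A / sentence B

# toy corpus: each line is treated as one sentence
text = (
    'Hello, how are you? I am Romeo.\n'
    'Hello, Romeo. My name is Juliet. Nice to meet you.\n'
    'Nice to meet you too. How are you today?'
)
sentences = re.sub("[.,!?\\-]", '', text.lower()).split('\n')

# vocabulary with ids 0-3 reserved for the special tokens
word_dict = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}
for i, w in enumerate(set(" ".join(sentences).split())):
    word_dict[w] = i + 4
number_dict = {i: w for i, w in enumerate(word_dict)}  # id -> word
vocab_size = len(word_dict)

# the corpus as lists of token ids
token_list = [[word_dict[w] for w in s.split()] for s in sentences]

With that setup in place, make_batch() builds each training example through the following key steps: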

def make_batch():
    batch = []
    positive = negative = 0
    while positive != batch_size/2 or negative != batch_size/2:
        # randomly pick two sentences (they may or may not be adjacent)
        tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences))
        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]

        # build the input sequence and the segment ids (0 for sentence A, 1 for sentence B)
        input_ids = [word_dict['[CLS]']] + tokens_a + [word_dict['[SEP]']] + tokens_b + [word_dict['[SEP]']]
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)
        
        # masked language model: mask 15% of the tokens (at least 1, at most max_pred)
        n_pred = min(max_pred, max(1, int(round(len(input_ids) * 0.15))))
        cand_maked_pos = [i for i, token in enumerate(input_ids) if token != word_dict['[CLS]'] and token != word_dict['[SEP]']]
        shuffle(cand_maked_pos)
        masked_tokens, masked_pos = [], []
        for pos in cand_maked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            if random() < 0.8:  # 80%: replace the token with [MASK]
                input_ids[pos] = word_dict['[MASK]']
            elif random() < 0.5:  # 10% overall: replace with a random vocabulary token (the remaining 10% keep the original token)
                index = randint(0, vocab_size - 1)
                input_ids[pos] = word_dict[number_dict[index]]
        
        # zero-pad the input sequence and segment ids up to maxlen
        n_pad = maxlen - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        # zero-pad masked_tokens / masked_pos up to max_pred so every example has the same length
        if max_pred > n_pred:
            n_pad = max_pred - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)
        
        # assign the NSP label: IsNext only when sentence B immediately follows sentence A
        if tokens_a_index + 1 == tokens_b_index and positive < batch_size/2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True]) # IsNext
            positive += 1
        elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False]) # NotNext
            negative += 1
    return batch

2. Model Architecture

The BERT model is built from the following components:

2.1 Embedding Layer

class Embedding(nn.Module):
    def __init__(self):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding
        self.pos_embed = nn.Embedding(maxlen, d_model)  # position embedding (learned, not sinusoidal)
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment (sentence A/B) embedding
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long)
        pos = pos.unsqueeze(0).expand_as(x)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

2.2 Scaled Dot-Product Attention

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)
        scores.masked_fill_(attn_mask, -1e9)
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context, attn
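The attn_mask argument is produced by get_attn_pad_mask, which this excerpt does not show. A minimal sketch consistent with how the mask is used here (marking key positions that hold [PAD], whose id is 0):

def get_attn_pad_mask(seq_q, seq_k):
    # returns a (batch, len_q, len_k) boolean mask, True wherever the key position is [PAD] (id 0)
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # (batch, 1, len_k)
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # broadcast over every query position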

2.3 Multi-Head Attention

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        # define the output projection and LayerNorm here so their parameters are registered
        # and trained (creating them inside forward() would re-initialize them on every call
        # and leave them out of the optimizer)
        self.fc = nn.Linear(n_heads * d_v, d_model)
        self.norm = nn.LayerNorm(d_model)
    
    def forward(self, Q, K, V, attn_mask):
        residual, batch_size = Q, Q.size(0)
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # (batch, n_heads, len_q, d_k)
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # (batch, n_heads, len_k, d_k)
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2)  # (batch, n_heads, len_k, d_v)

        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)  # copy the pad mask for every head
        
        context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)
        output = self.fc(context)  # project the concatenated heads back to d_model
        return self.norm(output + residual), attn  # residual connection + LayerNorm

2.4 Position-wise Feed-Forward Network

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.fc2(gelu(self.fc1(x)))
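The gelu function called here (and assigned to activ2 in the BERT class below) is defined elsewhere in the project file. A standard exact definition that matches this usage, assuming the erf-based form:

import math
import torch

def gelu(x):
    # exact GELU: x * Phi(x), where Phi is the standard normal CDF
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))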

2.5 Encoder Layer

class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs, attn

2.6 The Complete BERT Model

class BERT(nn.Module):
    def __init__(self):
        super(BERT, self).__init__()
        self.embedding = Embedding()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        self.fc = nn.Linear(d_model, d_model)
        self.activ1 = nn.Tanh()
        self.linear = nn.Linear(d_model, d_model)
        self.activ2 = gelu
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, 2)
        # the MLM decoder shares its weight matrix with the token embedding (weight tying)
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))
    
    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)
        
        # next sentence prediction: classify from the final hidden state of [CLS]
        h_pooled = self.activ1(self.fc(output[:, 0]))
        logits_clsf = self.classifier(h_pooled)
        
        # masked language model: gather the hidden states at the masked positions
        masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1))
        h_masked = torch.gather(output, 1, masked_pos)
        h_masked = self.norm(self.activ2(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias
        
        return logits_lm, logits_clsf
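The torch.gather step above is the least obvious part of the forward pass: masked_pos of shape (batch, max_pred) is expanded to (batch, max_pred, d_model) so that, for each example, the hidden vectors at the masked positions are pulled out of output. A small standalone illustration of the same pattern, using made-up sizes:

import torch

output = torch.arange(2 * 5 * 3, dtype=torch.float).view(2, 5, 3)  # (batch=2, seq_len=5, d_model=3)
masked_pos = torch.tensor([[1, 4], [0, 2]])  # two masked positions per example

idx = masked_pos[:, :, None].expand(-1, -1, output.size(-1))  # (2, 2, 3)
h_masked = torch.gather(output, 1, idx)  # (2, 2, 3): hidden vectors at positions 1,4 and 0,2
print(torch.equal(h_masked[0, 0], output[0, 1]))  # True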

Training Process

model = BERT()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

batch = make_batch()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch))

for epoch in range(100):
    optimizer.zero_grad()
    logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)
    loss_lm = criterion(logits_lm.transpose(1, 2), masked_tokens) # MLM loss
    loss_lm = (loss_lm.float()).mean()
    loss_clsf = criterion(logits_clsf, isNext) # NSP loss
    loss = loss_lm + loss_clsf
    if (epoch + 1) % 10 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
    loss.backward()
    optimizer.step()
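After training, the same forward pass can be reused to inspect what the model has learned. A minimal sketch, reusing the tensors built above and the number_dict id-to-word mapping assumed in the earlier setup sketch:

# inspect the predictions for the first example in the batch
model.eval()
with torch.no_grad():
    logits_lm, logits_clsf = model(input_ids[:1], segment_ids[:1], masked_pos[:1])

pred_ids = logits_lm.argmax(dim=-1)[0]  # predicted token ids at the masked positions
real = [number_dict[t.item()] for t in masked_tokens[0] if t.item() != 0]
pred = [number_dict[p.item()] for p, t in zip(pred_ids, masked_tokens[0]) if t.item() != 0]
print('masked tokens   :', real)
print('predicted tokens:', pred)
print('isNext (true)   :', bool(isNext[0].item()))
print('isNext (pred)   :', bool(logits_clsf.argmax(dim=-1)[0].item()))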

Key Implementation Details

  1. Masked language model

    • 15% of the tokens are randomly selected for masking (at most max_pred per example)
    • Of those, 80% are replaced with [MASK], 10% are replaced with a random token, and 10% are left unchanged
    • This strategy trains the model not only to predict masked tokens but also to handle unmodified input
  2. Next sentence prediction

    • Classification is performed on the final hidden state of the [CLS] token
    • Positive examples are consecutive sentences; negative examples are randomly paired sentences
  3. Attention mask

    • The get_attn_pad_mask function (sketched in section 2.2) marks the padded positions
    • This keeps the model from attending to information at padding positions
  4. Weight sharing

    • The decoder shares its weights with the token embedding, which is standard practice in BERT
    • This reduces the parameter count and improves training stability; a quick check is shown after this list
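A one-line check against the model built in the training code confirms that the tying is by reference, so MLM gradients update the token embedding directly:

# the decoder and the token embedding share the same Parameter object
assert model.decoder.weight is model.embedding.tok_embed.weight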

Summary

Working through the BERT implementation in the nlp-tutorial project gives a concrete understanding of:

  1. BERT's two pretraining objectives (MLM and NSP)
  2. How a Transformer encoder is implemented in practice
  3. How the input data is constructed and fed to the model
  4. The key details of the training loop

Although simplified, this implementation captures BERT's core ideas and is an excellent example for learning how the model works. Understanding this code makes it easier to apply BERT to real NLP tasks and lays a solid foundation for implementing more complex Transformer models.