Facebook DLRM项目核心模型解析:PyTorch实现详解
2025-07-09 03:25:32作者:羿妍玫Ivan
概述
Facebook Research团队提出的DLRM(Deep Learning Recommendation Model)是一种用于个性化推荐系统的深度学习模型。该模型创新性地结合了密集特征和稀疏特征的处理方式,通过多层感知机(MLP)和嵌入层(Embedding)的交互,实现了高效的推荐预测。
模型架构
DLRM模型的核心架构可以分为以下几个部分:
- 底部MLP:处理密集特征(dense features)
- 嵌入层:处理稀疏特征(sparse features)
- 顶部MLP:处理特征交互后的结果
- 特征交互层:实现不同特征间的交互操作
模型的数据流向可以表示为:
[密集特征] -> 底部MLP -> 特征交互 <- [稀疏特征通过嵌入层]
顶部MLP -> 输出
核心组件实现
1. MLP构建
create_mlp
方法负责构建多层感知机:
def create_mlp(self, ln, sigmoid_layer):
layers = nn.ModuleList()
for i in range(0, ln.size - 1):
n = ln[i] # 输入维度
m = ln[i+1] # 输出维度
# 构建全连接层
LL = nn.Linear(int(n), int(m), bias=True)
# Xavier初始化
std_dev = np.sqrt(2 / (m + n))
W = np.random.normal(0, std_dev, size=(m, n)).astype(np.float32)
bt = np.random.normal(0, np.sqrt(1/m), size=m).astype(np.float32)
LL.weight.data = torch.tensor(W, requires_grad=True)
LL.bias.data = torch.tensor(bt, requires_grad=True)
layers.append(LL)
# 添加激活函数
if i == sigmoid_layer:
layers.append(nn.Sigmoid())
else:
layers.append(nn.ReLU())
return torch.nn.Sequential(*layers)
该方法特点:
- 使用Xavier初始化保证网络训练稳定性
- 支持自定义sigmoid层位置
- 默认使用ReLU激活函数
2. 嵌入层构建
create_emb
方法构建嵌入层,支持多种优化技术:
def create_emb(self, m, ln, weighted_pooling=None):
emb_l = nn.ModuleList()
v_W_l = []
for i in range(0, ln.size):
n = ln[i] # 嵌入表大小
# QR嵌入(降低存储)
if self.qr_flag and n > self.qr_threshold:
EE = QREmbeddingBag(n, m, self.qr_collisions,
operation=self.qr_operation)
# MD嵌入(混合维度)
elif self.md_flag and n > self.md_threshold:
base = max(m)
_m = m[i] if n > self.md_threshold else base
EE = PrEmbeddingBag(n, _m, base)
# 常规嵌入
else:
EE = nn.EmbeddingBag(n, m, mode="sum", sparse=True)
W = np.random.uniform(-np.sqrt(1/n), np.sqrt(1/n), (n, m))
EE.weight.data = torch.tensor(W, requires_grad=True)
emb_l.append(EE)
v_W_l.append(torch.ones(n, dtype=torch.float32) if weighted_pooling else None)
return emb_l, v_W_l
嵌入层优化技术:
- QR嵌入:使用商余技巧减少嵌入表大小
- MD嵌入:混合维度技术,不同嵌入表使用不同维度
- 加权池化:支持学习或固定权重的池化方式
前向传播逻辑
模型的前向传播分为几个关键步骤:
- 处理密集特征:
x = self.apply_mlp(dense_x, self.bot_l)
- 处理稀疏特征:
ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l)
- 特征交互:
if self.arch_interaction_op == "dot":
# 点积交互
z = torch.cat([x] + [torch.matmul(ly[i], x.t()) for i in range(len(ly))], dim=1)
elif self.arch_interaction_op == "cat":
# 拼接交互
z = torch.cat([x] + ly, dim=1)
else:
# 求和交互(默认)
z = torch.cat([x] + [torch.sum(ly, dim=0, keepdim=True)], dim=1)
- 顶部MLP处理:
p = self.apply_mlp(z, self.top_l)
训练相关功能
损失函数
支持多种损失函数配置:
- MSE (均方误差)
- BCE (二元交叉熵)
- WBCE (加权二元交叉熵)
if self.loss_function == "mse":
self.loss_fn = torch.nn.MSELoss(reduction="mean")
elif self.loss_function == "bce":
self.loss_fn = torch.nn.BCELoss(reduction="mean")
elif self.loss_function == "wbce":
self.loss_ws = torch.tensor(np.fromstring(args.loss_weights, dtype=float, sep="-"))
self.loss_fn = torch.nn.BCELoss(reduction="none")
学习率调度
实现自定义学习率调度策略:
class LRPolicyScheduler(_LRScheduler):
def __init__(self, optimizer, num_warmup_steps, decay_start_step, num_decay_steps):
# 包含预热和衰减阶段
self.num_warmup_steps = num_warmup_steps
self.decay_start_step = decay_start_step
self.decay_end_step = decay_start_step + num_decay_steps
性能优化技术
- 分布式训练支持:
if ext_dist.my_size > 1:
self.n_local_emb, self.n_emb_per_rank = ext_dist.get_split_lengths(n_emb)
self.local_emb_slice = ext_dist.get_my_slice(n_emb)
- 嵌入量化:
def quantize_embedding(self, bits):
if bits == 4:
self.emb_l_q[k] = ops.quantized.embedding_bag_4bit_prepack(self.emb_l[k].weight)
elif bits == 8:
self.emb_l_q[k] = ops.quantized.embedding_bag_byte_prepack(self.emb_l[k].weight)
使用建议
-
超参数选择:
- 对于大型嵌入表(>200维),考虑使用QR或MD技术
- 交互操作根据任务需求选择dot/cat/sum
-
训练技巧:
- 使用学习率预热避免早期训练不稳定
- 对于类别不平衡数据,使用WBCE损失
-
部署优化:
- 考虑使用8bit或4bit量化减小模型大小
- 分布式训练可以显著加速大型嵌入表的训练
DLRM模型通过创新的架构设计,在推荐系统任务中实现了高精度和高效能的平衡,其PyTorch实现提供了丰富的配置选项,可以灵活适应各种推荐场景的需求。