首页
/ Twitter推荐算法中的Twhin模型解析与实现

Twitter推荐算法中的Twhin模型解析与实现

2025-07-06 06:24:50作者:仰钰奇

概述

本文将深入解析Twitter推荐算法项目中的Twhin模型实现,这是一个基于图嵌入的推荐系统模型。Twhin模型通过捕捉用户与内容之间的复杂关系来提升推荐质量,是Twitter推荐系统的重要组成部分。

模型架构

Twhin模型的核心由以下几个组件构成:

  1. 大型嵌入层(LargeEmbeddings):处理高维稀疏特征
  2. 关系转换参数:学习不同关系类型间的转换
  3. 负采样机制:通过对比学习提升模型性能

嵌入层初始化

self.large_embeddings = LargeEmbeddings(model_config.embeddings)

模型使用LargeEmbeddings来处理大规模稀疏特征,这是处理推荐系统中海量用户和内容ID的标准做法。嵌入维度由配置中的embedding_dim参数决定。

关系转换参数

self.all_trans_embs = torch.nn.parameter.Parameter(
  torch.nn.init.uniform_(torch.empty(self.num_relations, self.embedding_dim))

为每种关系类型学习一个独立的转换向量,这使得模型能够捕捉不同类型关系(如关注、点赞等)之间的差异。

前向传播逻辑

模型的前向传播过程可以分为几个关键步骤:

  1. 嵌入查找:通过large_embeddings获取节点嵌入
  2. 关系转换:应用关系特定的转换向量
  3. 相似度计算:计算正样本和负样本的相似度

嵌入处理

# 2B x TD -> 2B x T x D -> 2B x D
x = outs.values().reshape(2 * self.batch_size, -1, self.embedding_dim)
x = torch.sum(x, 1)

这里将查找得到的嵌入进行重塑和求和,最终得到每个节点的聚合表示。

负采样机制

模型支持两种负采样方式:

  • 批内负采样(in_batch_negatives)
  • 全局负采样(global_negatives)
if self.in_batch_negatives:
  # 构造负样本的相似度矩阵
  negs_rhs = torch.flatten(torch.matmul(lhs_matrix, sampled_rhs.t()))
  negs_lhs = torch.flatten(torch.matmul(rhs_matrix, sampled_lhs.t()))

批内负采样通过随机排列当前批次中的样本来构造负样本对,这是一种高效且有效的对比学习策略。

优化器应用

Twhin模型采用了特殊的优化器应用方式,为每个嵌入表单独配置优化器:

def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig):
  for table in model_config.embeddings.tables:
    optimizer_class = get_optimizer_class(table.optimizer)
    apply_optimizer_in_backward(
      optimizer_class=optimizer_class,
      params=params,
      optimizer_kwargs=optimizer_kwargs,
    )

这种方法允许对不同部分的参数使用不同的优化策略,在处理大规模稀疏特征时特别有用。

损失计算

TwhinModelAndLoss类封装了模型和损失函数,主要特点包括:

  1. 正负样本平衡:通过neg_weight调整正负样本的权重
  2. 多任务支持:可以同时处理多种类型的预测任务
neg_weight = float(num_positives) / num_negatives
labels = torch.cat([batch.labels.float(), torch.ones(num_negatives).to(self.device)])
weights = torch.cat([batch.weights.float(), (torch.ones(num_negatives) * neg_weight).to(self.device)])

这种设计使得模型能够有效地处理推荐系统中常见的类别不平衡问题。

实际应用建议

  1. 嵌入维度选择:根据数据规模选择合适的嵌入维度,太大容易过拟合,太小则表达能力不足
  2. 负采样比例:调整in_batch_negatives参数平衡训练效率和模型性能
  3. 关系类型设计:仔细设计关系类型,确保它们能够捕捉到重要的用户行为模式

Twhin模型通过结合图嵌入和关系转换,能够有效地捕捉Twitter平台上的复杂用户-内容交互模式,为推荐系统提供强大的基础能力。