DeepMind PolyGen项目:基于Transformer的3D网格生成模型训练指南
2025-07-06 02:56:17作者:彭桢灵Jeremy
项目概述
DeepMind PolyGen是一个创新的3D网格生成框架,它采用了两阶段Transformer架构来生成高质量的3D网格模型。该项目通过分离顶点生成和面生成两个阶段,实现了对3D几何结构的有效建模。
环境准备
在开始训练前,需要配置以下环境依赖:
!pip install tensorflow==1.15 dm-sonnet==1.36 tensor2tensor==1.14
import os
import numpy as np
import tensorflow.compat.v1 as tf
tf.logging.set_verbosity(tf.logging.ERROR)
import matplotlib.pyplot as plt
注意:项目基于TensorFlow 1.x版本构建,使用兼容性API确保代码运行。
数据集准备
PolyGen使用合成数据集进行演示,包含四种基本几何体:
- 立方体(cube)
- 圆柱体(cylinder)
- 圆锥体(cone)
- 二十面体(icosphere)
数据处理流程
ex_list = []
for k, mesh in enumerate(['cube', 'cylinder', 'cone', 'icosphere']):
mesh_dict = data_utils.load_process_mesh(
os.path.join('meshes', '{}.obj'.format(mesh)))
mesh_dict['class_label'] = k
ex_list.append(mesh_dict)
关键处理步骤:
- 加载原始.obj文件
- 对网格进行归一化和中心化处理
- 对顶点位置进行量化
- 将面数据展平为序列,用特殊标记(=1)分隔不同面
- 为每个网格分配类别标签
模型架构
PolyGen采用双模型架构:
1. 顶点模型(Vertex Model)
负责生成3D网格的顶点坐标序列:
vertex_model = modules.VertexModel(
decoder_config={
'hidden_size': 128,
'fc_size': 512,
'num_layers': 3,
'dropout_rate': 0.
},
class_conditional=True,
num_classes=4,
max_num_input_verts=250,
quantization_bits=8,
)
关键特性:
- 基于Transformer解码器架构
- 支持类别条件生成
- 处理最大250个顶点
- 使用8位量化
2. 面模型(Face Model)
基于生成的顶点生成面序列:
face_model = modules.FaceModel(
encoder_config={
'hidden_size': 128,
'fc_size': 512,
'num_layers': 3,
'dropout_rate': 0.
},
decoder_config={
'hidden_size': 128,
'fc_size': 512,
'num_layers': 3,
'dropout_rate': 0.
},
class_conditional=False,
max_seq_length=500,
quantization_bits=8,
decoder_cross_attention=True,
use_discrete_vertex_embeddings=True,
)
关键特性:
- 编码器-解码器结构
- 解码器使用交叉注意力机制
- 处理最大500长度的面序列
- 使用离散顶点嵌入
训练流程
训练过程采用联合优化策略:
# 优化设置
learning_rate = 5e-4
training_steps = 500
check_step = 5
# 创建优化器
optimizer = tf.train.AdamOptimizer(learning_rate)
vertex_model_optim_op = optimizer.minimize(vertex_model_loss)
face_model_optim_op = optimizer.minimize(face_model_loss)
训练循环中定期检查模型生成的样本:
v_samples_np, f_samples_np, b_np = sess.run(
(vertex_samples, face_samples, vertex_model_batch))
mesh_list = []
for n in range(4):
mesh_list.append(
{
'vertices': v_samples_np['vertices'][n][:v_samples_np['num_vertices'][n]],
'faces': data_utils.unflatten_faces(
f_samples_np['faces'][n][:f_samples_np['num_face_indices'][n]])
}
)
data_utils.plot_meshes(mesh_list, ax_lims=0.5)
技术要点解析
- 顶点序列化:顶点按Z->Y->X坐标顺序展平,形成序列数据
- 面序列表示:面数据被展平为索引序列,用特殊标记分隔不同面
- 量化处理:顶点坐标进行8位量化,减少模型学习难度
- 条件生成:顶点模型支持基于类别的条件生成
- 联合训练:两个模型可以独立训练,也可以联合优化
应用前景
PolyGen框架在以下领域有潜在应用价值:
- 3D内容自动生成
- 计算机辅助设计(CAD)
- 游戏资产创建
- 虚拟现实内容制作
通过本教程,开发者可以理解PolyGen的核心思想,并基于此框架进行3D生成任务的扩展开发。