首页
/ DeepMind PolyGen项目:基于Transformer的3D网格生成模型训练指南

DeepMind PolyGen项目:基于Transformer的3D网格生成模型训练指南

2025-07-06 02:56:17作者:彭桢灵Jeremy

项目概述

DeepMind PolyGen是一个创新的3D网格生成框架,它采用了两阶段Transformer架构来生成高质量的3D网格模型。该项目通过分离顶点生成和面生成两个阶段,实现了对3D几何结构的有效建模。

环境准备

在开始训练前,需要配置以下环境依赖:

!pip install tensorflow==1.15 dm-sonnet==1.36 tensor2tensor==1.14

import os
import numpy as np
import tensorflow.compat.v1 as tf
tf.logging.set_verbosity(tf.logging.ERROR)
import matplotlib.pyplot as plt

注意:项目基于TensorFlow 1.x版本构建,使用兼容性API确保代码运行。

数据集准备

PolyGen使用合成数据集进行演示,包含四种基本几何体:

  1. 立方体(cube)
  2. 圆柱体(cylinder)
  3. 圆锥体(cone)
  4. 二十面体(icosphere)

数据处理流程

ex_list = []
for k, mesh in enumerate(['cube', 'cylinder', 'cone', 'icosphere']):
  mesh_dict = data_utils.load_process_mesh(
      os.path.join('meshes', '{}.obj'.format(mesh)))
  mesh_dict['class_label'] = k
  ex_list.append(mesh_dict)

关键处理步骤:

  • 加载原始.obj文件
  • 对网格进行归一化和中心化处理
  • 对顶点位置进行量化
  • 将面数据展平为序列,用特殊标记(=1)分隔不同面
  • 为每个网格分配类别标签

模型架构

PolyGen采用双模型架构:

1. 顶点模型(Vertex Model)

负责生成3D网格的顶点坐标序列:

vertex_model = modules.VertexModel(
    decoder_config={
        'hidden_size': 128,
        'fc_size': 512, 
        'num_layers': 3,
        'dropout_rate': 0.
    },
    class_conditional=True,
    num_classes=4,
    max_num_input_verts=250,
    quantization_bits=8,
)

关键特性:

  • 基于Transformer解码器架构
  • 支持类别条件生成
  • 处理最大250个顶点
  • 使用8位量化

2. 面模型(Face Model)

基于生成的顶点生成面序列:

face_model = modules.FaceModel(
    encoder_config={
        'hidden_size': 128,
        'fc_size': 512, 
        'num_layers': 3,
        'dropout_rate': 0.
    },
    decoder_config={
        'hidden_size': 128,
        'fc_size': 512, 
        'num_layers': 3,
        'dropout_rate': 0.
    },
    class_conditional=False,
    max_seq_length=500,
    quantization_bits=8,
    decoder_cross_attention=True,
    use_discrete_vertex_embeddings=True,
)

关键特性:

  • 编码器-解码器结构
  • 解码器使用交叉注意力机制
  • 处理最大500长度的面序列
  • 使用离散顶点嵌入

训练流程

训练过程采用联合优化策略:

# 优化设置
learning_rate = 5e-4
training_steps = 500
check_step = 5

# 创建优化器
optimizer = tf.train.AdamOptimizer(learning_rate)
vertex_model_optim_op = optimizer.minimize(vertex_model_loss)
face_model_optim_op = optimizer.minimize(face_model_loss)

训练循环中定期检查模型生成的样本:

v_samples_np, f_samples_np, b_np = sess.run(
    (vertex_samples, face_samples, vertex_model_batch))
mesh_list = []
for n in range(4):
    mesh_list.append(
        {
            'vertices': v_samples_np['vertices'][n][:v_samples_np['num_vertices'][n]],
            'faces': data_utils.unflatten_faces(
                f_samples_np['faces'][n][:f_samples_np['num_face_indices'][n]])
        }
    )
data_utils.plot_meshes(mesh_list, ax_lims=0.5)

技术要点解析

  1. 顶点序列化:顶点按Z->Y->X坐标顺序展平,形成序列数据
  2. 面序列表示:面数据被展平为索引序列,用特殊标记分隔不同面
  3. 量化处理:顶点坐标进行8位量化,减少模型学习难度
  4. 条件生成:顶点模型支持基于类别的条件生成
  5. 联合训练:两个模型可以独立训练,也可以联合优化

应用前景

PolyGen框架在以下领域有潜在应用价值:

  • 3D内容自动生成
  • 计算机辅助设计(CAD)
  • 游戏资产创建
  • 虚拟现实内容制作

通过本教程,开发者可以理解PolyGen的核心思想,并基于此框架进行3D生成任务的扩展开发。