首页
/ 深入解析ikostrikov/pytorch-a2c-ppo-acktr-gail中的策略模型设计

深入解析ikostrikov/pytorch-a2c-ppo-acktr-gail中的策略模型设计

2025-07-09 04:17:48作者:晏闻田Solitary

本文将对ikostrikov/pytorch-a2c-ppo-acktr-gail项目中的model.py文件进行深入解析,重点讲解其强化学习策略模型的设计架构和实现细节。

模型架构概述

该文件实现了强化学习中常用的策略网络架构,主要包含以下几个核心组件:

  1. 基础网络模块:包括CNNBase和MLPBase,分别处理图像输入和向量输入
  2. 策略分布模块:处理不同类型的动作空间(离散、连续、多二元)
  3. 循环网络支持:通过GRU实现时序记忆功能

核心类解析

1. Policy类:策略网络主框架

Policy类是策略网络的核心框架,它根据输入类型自动选择适当的基础网络,并构建相应的动作分布:

class Policy(nn.Module):
    def __init__(self, obs_shape, action_space, base=None, base_kwargs=None):
        super(Policy, self).__init__()
        # 根据观测形状选择基础网络
        if len(obs_shape) == 3:
            base = CNNBase  # 图像输入使用CNN
        elif len(obs_shape) == 1:
            base = MLPBase  # 向量输入使用MLP
            
        # 根据动作空间类型选择分布
        if action_space.__class__.__name__ == "Discrete":
            self.dist = Categorical(...)  # 离散动作
        elif action_space.__class__.__name__ == "Box":
            self.dist = DiagGaussian(...)  # 连续动作
        elif action_space.__class__.__name__ == "MultiBinary":
            self.dist = Bernoulli(...)  # 多二元动作

Policy类提供了三个关键方法:

  • act():根据当前状态选择动作
  • get_value():获取状态价值估计
  • evaluate_actions():评估给定动作的概率和熵

2. NNBase类:神经网络基础类

NNBase是所有基础网络的父类,主要实现了循环神经网络(GRU)的相关功能:

class NNBase(nn.Module):
    def __init__(self, recurrent, recurrent_input_size, hidden_size):
        super(NNBase, self).__init__()
        if recurrent:
            self.gru = nn.GRU(recurrent_input_size, hidden_size)
            # GRU参数初始化
            nn.init.orthogonal_(param)  # 权重正交初始化
            nn.init.constant_(param, 0)  # 偏置初始化为0

_forward_gru()方法实现了带掩码的GRU前向传播,能够正确处理序列中的填充部分。

3. CNNBase类:卷积神经网络实现

CNNBase处理图像输入,采用经典的CNN架构:

self.main = nn.Sequential(
    nn.Conv2d(num_inputs, 32, 8, stride=4), nn.ReLU(),  # 第一层卷积
    nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),         # 第二层卷积
    nn.Conv2d(64, 32, 3, stride=1), nn.ReLU(),         # 第三层卷积
    Flatten(),
    nn.Linear(32 * 7 * 7, hidden_size), nn.ReLU()      # 全连接层
)

输入图像会先被归一化到[0,1]范围(除以255.0),然后经过三层卷积和一层全连接。

4. MLPBase类:多层感知机实现

MLPBase处理向量输入,采用两层全连接网络:

self.actor = nn.Sequential(
    nn.Linear(num_inputs, hidden_size), nn.Tanh(),
    nn.Linear(hidden_size, hidden_size), nn.Tanh())
    
self.critic = nn.Sequential(
    nn.Linear(num_inputs, hidden_size), nn.Tanh(),
    nn.Linear(hidden_size, hidden_size), nn.Tanh())

注意这里使用了Tanh激活函数,相比ReLU能产生更平滑的梯度。

关键技术点

  1. 参数初始化:使用正交初始化(orthogonal)配合适当的增益(gain)值

    init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2))
    
  2. 循环网络处理:GRU的隐藏状态会与掩码(masks)相乘,正确处理序列中的填充部分

  3. 多类型动作空间支持

    • 离散动作:Categorical分布
    • 连续动作:DiagGaussian对角高斯分布
    • 多二元动作:Bernoulli分布
  4. 确定性vs随机性策略

    if deterministic:
        action = dist.mode()  # 确定性策略取最大概率动作
    else:
        action = dist.sample()  # 随机策略按概率采样
    

设计亮点

  1. 模块化设计:将基础网络、策略分布等组件分离,便于扩展和维护

  2. 通用接口:Policy类提供了统一的接口(act, get_value等),便于不同算法复用

  3. 高效实现:GRU前向传播中优化了带掩码的序列处理,提升了计算效率

  4. 灵活的初始化系统:通过init_函数统一管理参数初始化方式

实际应用建议

  1. 对于图像输入任务,优先使用CNNBase,其卷积结构能有效提取空间特征

  2. 对于需要记忆的任务,启用recurrent选项,GRU能帮助模型记住历史信息

  3. 连续控制任务中,DiagGaussian分布通常能取得较好效果

  4. 训练初期可以使用随机策略(deterministic=False)进行探索,后期可转为确定性策略

通过这种设计,该策略模型能够灵活应对各种强化学习任务,同时保持代码的清晰和高效。理解这些设计思想对于实现自己的强化学习模型有很大帮助。