深入解析ikostrikov/pytorch-a2c-ppo-acktr-gail中的策略模型设计
本文将对ikostrikov/pytorch-a2c-ppo-acktr-gail项目中的model.py文件进行深入解析,重点讲解其强化学习策略模型的设计架构和实现细节。
模型架构概述
该文件实现了强化学习中常用的策略网络架构,主要包含以下几个核心组件:
- 基础网络模块:包括CNNBase和MLPBase,分别处理图像输入和向量输入
- 策略分布模块:处理不同类型的动作空间(离散、连续、多二元)
- 循环网络支持:通过GRU实现时序记忆功能
核心类解析
1. Policy类:策略网络主框架
Policy类是策略网络的核心框架,它根据输入类型自动选择适当的基础网络,并构建相应的动作分布:
class Policy(nn.Module):
def __init__(self, obs_shape, action_space, base=None, base_kwargs=None):
super(Policy, self).__init__()
# 根据观测形状选择基础网络
if len(obs_shape) == 3:
base = CNNBase # 图像输入使用CNN
elif len(obs_shape) == 1:
base = MLPBase # 向量输入使用MLP
# 根据动作空间类型选择分布
if action_space.__class__.__name__ == "Discrete":
self.dist = Categorical(...) # 离散动作
elif action_space.__class__.__name__ == "Box":
self.dist = DiagGaussian(...) # 连续动作
elif action_space.__class__.__name__ == "MultiBinary":
self.dist = Bernoulli(...) # 多二元动作
Policy类提供了三个关键方法:
act()
:根据当前状态选择动作get_value()
:获取状态价值估计evaluate_actions()
:评估给定动作的概率和熵
2. NNBase类:神经网络基础类
NNBase是所有基础网络的父类,主要实现了循环神经网络(GRU)的相关功能:
class NNBase(nn.Module):
def __init__(self, recurrent, recurrent_input_size, hidden_size):
super(NNBase, self).__init__()
if recurrent:
self.gru = nn.GRU(recurrent_input_size, hidden_size)
# GRU参数初始化
nn.init.orthogonal_(param) # 权重正交初始化
nn.init.constant_(param, 0) # 偏置初始化为0
_forward_gru()
方法实现了带掩码的GRU前向传播,能够正确处理序列中的填充部分。
3. CNNBase类:卷积神经网络实现
CNNBase处理图像输入,采用经典的CNN架构:
self.main = nn.Sequential(
nn.Conv2d(num_inputs, 32, 8, stride=4), nn.ReLU(), # 第一层卷积
nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(), # 第二层卷积
nn.Conv2d(64, 32, 3, stride=1), nn.ReLU(), # 第三层卷积
Flatten(),
nn.Linear(32 * 7 * 7, hidden_size), nn.ReLU() # 全连接层
)
输入图像会先被归一化到[0,1]范围(除以255.0),然后经过三层卷积和一层全连接。
4. MLPBase类:多层感知机实现
MLPBase处理向量输入,采用两层全连接网络:
self.actor = nn.Sequential(
nn.Linear(num_inputs, hidden_size), nn.Tanh(),
nn.Linear(hidden_size, hidden_size), nn.Tanh())
self.critic = nn.Sequential(
nn.Linear(num_inputs, hidden_size), nn.Tanh(),
nn.Linear(hidden_size, hidden_size), nn.Tanh())
注意这里使用了Tanh激活函数,相比ReLU能产生更平滑的梯度。
关键技术点
-
参数初始化:使用正交初始化(orthogonal)配合适当的增益(gain)值
init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2))
-
循环网络处理:GRU的隐藏状态会与掩码(masks)相乘,正确处理序列中的填充部分
-
多类型动作空间支持:
- 离散动作:Categorical分布
- 连续动作:DiagGaussian对角高斯分布
- 多二元动作:Bernoulli分布
-
确定性vs随机性策略:
if deterministic: action = dist.mode() # 确定性策略取最大概率动作 else: action = dist.sample() # 随机策略按概率采样
设计亮点
-
模块化设计:将基础网络、策略分布等组件分离,便于扩展和维护
-
通用接口:Policy类提供了统一的接口(act, get_value等),便于不同算法复用
-
高效实现:GRU前向传播中优化了带掩码的序列处理,提升了计算效率
-
灵活的初始化系统:通过init_函数统一管理参数初始化方式
实际应用建议
-
对于图像输入任务,优先使用CNNBase,其卷积结构能有效提取空间特征
-
对于需要记忆的任务,启用recurrent选项,GRU能帮助模型记住历史信息
-
连续控制任务中,DiagGaussian分布通常能取得较好效果
-
训练初期可以使用随机策略(deterministic=False)进行探索,后期可转为确定性策略
通过这种设计,该策略模型能够灵活应对各种强化学习任务,同时保持代码的清晰和高效。理解这些设计思想对于实现自己的强化学习模型有很大帮助。