DeepRL项目中的神经网络头设计解析
2025-07-09 08:08:04作者:仰钰奇
概述
在深度强化学习框架DeepRL中,network_heads.py
文件定义了多种用于强化学习任务的神经网络头部结构。这些头部结构通常与网络主体(body)结合使用,构成完整的深度强化学习模型。本文将深入解析这些网络头的设计原理和实现细节。
基础网络头
VanillaNet (普通网络头)
VanillaNet是最基础的网络头结构,适用于标准的深度Q网络(DQN):
- 接收来自网络主体的特征向量
- 通过一个全连接层输出Q值
- 结构简单,计算效率高
class VanillaNet(nn.Module, BaseNet):
def __init__(self, output_dim, body):
super(VanillaNet, self).__init__()
self.fc_head = layer_init(nn.Linear(body.feature_dim, output_dim))
self.body = body
高级网络头结构
DuelingNet (决斗网络头)
DuelingNet实现了著名的Dueling DQN架构:
- 将Q值分解为状态价值函数V和优势函数A
- 通过两个独立的全连接层分别计算V和A
- 最后将两者组合得到最终的Q值
class DuelingNet(nn.Module, BaseNet):
def __init__(self, action_dim, body):
self.fc_value = layer_init(nn.Linear(body.feature_dim, 1)) # 状态价值函数
self.fc_advantage = layer_init(nn.Linear(body.feature_dim, action_dim)) # 优势函数
CategoricalNet (分类网络头)
CategoricalNet实现了Categorical DQN(也称为Distributional DQN):
- 输出每个动作的价值分布而非单一Q值
- 使用softmax将输出转换为概率分布
- 适用于处理价值分布的不确定性
class CategoricalNet(nn.Module, BaseNet):
def __init__(self, action_dim, num_atoms, body):
self.fc_categorical = layer_init(nn.Linear(body.feature_dim, action_dim * num_atoms))
RainbowNet (彩虹网络头)
RainbowNet整合了多种DQN改进技术:
- 结合了Dueling架构和Categorical DQN
- 可选是否使用噪声线性层(NoisyNet)
- 实现了reset_noise方法用于噪声重置
class RainbowNet(nn.Module, BaseNet):
def reset_noise(self):
if self.noisy_linear:
self.fc_value.reset_noise()
self.fc_advantage.reset_noise()
策略梯度网络头
DeterministicActorCriticNet (确定性策略梯度网络头)
适用于确定性策略梯度算法(如DDPG):
- 包含actor和critic两个分支
- actor输出确定性动作(使用tanh激活)
- critic评估状态-动作对的Q值
class DeterministicActorCriticNet(nn.Module, BaseNet):
def actor(self, phi):
return torch.tanh(self.fc_action(self.actor_body(phi)))
GaussianActorCriticNet (高斯策略网络头)
适用于输出连续动作空间的随机策略:
- 输出动作的高斯分布参数(均值和标准差)
- 使用softplus确保标准差为正
- 支持动作采样和概率计算
class GaussianActorCriticNet(nn.Module, BaseNet):
def forward(self, obs, action=None):
dist = torch.distributions.Normal(mean, F.softplus(self.std))
CategoricalActorCriticNet (分类策略网络头)
适用于离散动作空间的策略梯度算法:
- 输出动作的分类分布
- 使用softmax计算动作概率
- 支持动作采样和熵计算
class CategoricalActorCriticNet(nn.Module, BaseNet):
def forward(self, obs, action=None):
dist = torch.distributions.Categorical(logits=logits)
特殊网络头
OptionCriticNet (选项网络头)
实现选项(option)框架的分层强化学习:
- 同时学习选项策略、终止条件和选项价值函数
- 使用sigmoid激活计算选项终止概率
- 输出包含多个选项的相关信息
class OptionCriticNet(nn.Module, BaseNet):
def forward(self, x):
return {'q': q, 'beta': beta, 'log_pi': log_pi, 'pi': pi}
TD3Net (双延迟DDPG网络头)
实现TD3(Twin Delayed DDPG)算法:
- 包含两个独立的critic网络以减少过高估计偏差
- 使用目标网络和延迟更新策略网络
- 实现动作值函数的多重估计
class TD3Net(nn.Module, BaseNet):
def q(self, obs, a):
q_1 = self.fc_critic_1(self.critic_body_1(x))
q_2 = self.fc_critic_2(self.critic_body_2(x))
设计特点总结
- 模块化设计:每个网络头都是独立的模块,可与不同的网络主体组合使用
- 统一接口:大多数网络头都继承自BaseNet,保持一致的接口规范
- 设备管理:自动将网络参数转移到配置指定的设备(CPU/GPU)
- 参数初始化:使用layer_init进行规范的参数初始化
- 功能完备:覆盖了从值函数到策略梯度的多种强化学习算法需求
这些网络头结构为DeepRL框架提供了强大的灵活性,使得实现各种深度强化学习算法变得简单而高效。