首页
/ DeepRL项目中的神经网络头设计解析

DeepRL项目中的神经网络头设计解析

2025-07-09 08:08:04作者:仰钰奇

概述

在深度强化学习框架DeepRL中,network_heads.py文件定义了多种用于强化学习任务的神经网络头部结构。这些头部结构通常与网络主体(body)结合使用,构成完整的深度强化学习模型。本文将深入解析这些网络头的设计原理和实现细节。

基础网络头

VanillaNet (普通网络头)

VanillaNet是最基础的网络头结构,适用于标准的深度Q网络(DQN):

  • 接收来自网络主体的特征向量
  • 通过一个全连接层输出Q值
  • 结构简单,计算效率高
class VanillaNet(nn.Module, BaseNet):
    def __init__(self, output_dim, body):
        super(VanillaNet, self).__init__()
        self.fc_head = layer_init(nn.Linear(body.feature_dim, output_dim))
        self.body = body

高级网络头结构

DuelingNet (决斗网络头)

DuelingNet实现了著名的Dueling DQN架构:

  • 将Q值分解为状态价值函数V和优势函数A
  • 通过两个独立的全连接层分别计算V和A
  • 最后将两者组合得到最终的Q值
class DuelingNet(nn.Module, BaseNet):
    def __init__(self, action_dim, body):
        self.fc_value = layer_init(nn.Linear(body.feature_dim, 1))  # 状态价值函数
        self.fc_advantage = layer_init(nn.Linear(body.feature_dim, action_dim))  # 优势函数

CategoricalNet (分类网络头)

CategoricalNet实现了Categorical DQN(也称为Distributional DQN):

  • 输出每个动作的价值分布而非单一Q值
  • 使用softmax将输出转换为概率分布
  • 适用于处理价值分布的不确定性
class CategoricalNet(nn.Module, BaseNet):
    def __init__(self, action_dim, num_atoms, body):
        self.fc_categorical = layer_init(nn.Linear(body.feature_dim, action_dim * num_atoms))

RainbowNet (彩虹网络头)

RainbowNet整合了多种DQN改进技术:

  • 结合了Dueling架构和Categorical DQN
  • 可选是否使用噪声线性层(NoisyNet)
  • 实现了reset_noise方法用于噪声重置
class RainbowNet(nn.Module, BaseNet):
    def reset_noise(self):
        if self.noisy_linear:
            self.fc_value.reset_noise()
            self.fc_advantage.reset_noise()

策略梯度网络头

DeterministicActorCriticNet (确定性策略梯度网络头)

适用于确定性策略梯度算法(如DDPG):

  • 包含actor和critic两个分支
  • actor输出确定性动作(使用tanh激活)
  • critic评估状态-动作对的Q值
class DeterministicActorCriticNet(nn.Module, BaseNet):
    def actor(self, phi):
        return torch.tanh(self.fc_action(self.actor_body(phi)))

GaussianActorCriticNet (高斯策略网络头)

适用于输出连续动作空间的随机策略:

  • 输出动作的高斯分布参数(均值和标准差)
  • 使用softplus确保标准差为正
  • 支持动作采样和概率计算
class GaussianActorCriticNet(nn.Module, BaseNet):
    def forward(self, obs, action=None):
        dist = torch.distributions.Normal(mean, F.softplus(self.std))

CategoricalActorCriticNet (分类策略网络头)

适用于离散动作空间的策略梯度算法:

  • 输出动作的分类分布
  • 使用softmax计算动作概率
  • 支持动作采样和熵计算
class CategoricalActorCriticNet(nn.Module, BaseNet):
    def forward(self, obs, action=None):
        dist = torch.distributions.Categorical(logits=logits)

特殊网络头

OptionCriticNet (选项网络头)

实现选项(option)框架的分层强化学习:

  • 同时学习选项策略、终止条件和选项价值函数
  • 使用sigmoid激活计算选项终止概率
  • 输出包含多个选项的相关信息
class OptionCriticNet(nn.Module, BaseNet):
    def forward(self, x):
        return {'q': q, 'beta': beta, 'log_pi': log_pi, 'pi': pi}

TD3Net (双延迟DDPG网络头)

实现TD3(Twin Delayed DDPG)算法:

  • 包含两个独立的critic网络以减少过高估计偏差
  • 使用目标网络和延迟更新策略网络
  • 实现动作值函数的多重估计
class TD3Net(nn.Module, BaseNet):
    def q(self, obs, a):
        q_1 = self.fc_critic_1(self.critic_body_1(x))
        q_2 = self.fc_critic_2(self.critic_body_2(x))

设计特点总结

  1. 模块化设计:每个网络头都是独立的模块,可与不同的网络主体组合使用
  2. 统一接口:大多数网络头都继承自BaseNet,保持一致的接口规范
  3. 设备管理:自动将网络参数转移到配置指定的设备(CPU/GPU)
  4. 参数初始化:使用layer_init进行规范的参数初始化
  5. 功能完备:覆盖了从值函数到策略梯度的多种强化学习算法需求

这些网络头结构为DeepRL框架提供了强大的灵活性,使得实现各种深度强化学习算法变得简单而高效。