【CleanRL】强化学习SAC进阶——离散动作实现与数学原理
本文介绍了离散动作空间下SAC(Soft Actor-Critic)算法的实现要点。与连续版本不同,离散SAC需要让策略输出所有动作的概率分布,并通过加权求和直接计算期望值。
文章目录
SAC离散和连续动作的核心区别就在于,生成动作的时候,没法在借用重参数技巧,基于高斯分布的均值和方差采样动作了。而是要让策略输出一个在所有离散动作上的 概率分布,并且损失函数的计算也不能依赖于某个“采样”的动作,而必须考虑整个概率分布。
Actor
任务是给定状态(游戏中的图像帧),输出一个在所有离散动作上的得分,再使用softmax转换成概率。
class Actor(nn.Module):
def __init__(self, envs):
super().__init__()
# Atari 环境输入是图像,所以使用 NatureCNN 来提取特征
self.network = NatureCNN(envs.single_observation_space.shape[0], 512)
# 输出层的维度是动作空间的数量 (env.single_action_space.n)
self.actor_head = nn.Linear(512, envs.single_action_space.n)
def forward(self, x):
# x / 255.0 是图像数据的标准化
x = self.network(x / 255.0)
# 输出的是每个动作的 "logits" (未经 softmax 的原始分数)
logits = self.actor_head(x)
return logits
# 注意:这里的方法名和连续版本不同,直接返回概率和对数概率
def get_action(self, x):
logits = self(x)
# --- 核心区别 ---
# 1. 用 softmax 将 logits 转换为概率
probs = F.softmax(logits, dim=-1)
# 2. 计算 log 概率
log_probs = F.log_softmax(logits, dim=-1)
# 返回整个分布,而不是一个采样的动作
return probs, log_probs
Actor网络更新
用当前的状态和动作来计算期望,利用的是向量对应相乘再求和。
L π ( ϕ ) = E s ∼ D [ ∑ a ∼ A π ϕ ( a ∣ s ) ( α log π ϕ ( a ∣ s ) − min j = 1 , 2 Q θ j ( s , a ) ) ] L_{\pi}(\phi)=\mathbb{E}_{s\sim\mathcal{D}}\left[\sum_{a\sim\mathcal{A}}\pi_{\phi}(a|s)\left(\alpha\log\pi_{\phi}(a|s)-\min_{j=1,2}Q_{\theta_j}(s,a)\right)\right] Lπ(ϕ)=Es∼D[a∼A∑πϕ(a∣s)(αlogπϕ(a∣s)−j=1,2minQθj(s,a))]
# if global_step % args.policy_frequency == 0:
# 同样,获取当前状态下所有动作的概率和对数概率
action_probs, log_action_probs = actor.get_action(data.observations)
# 用主 Q 网络计算所有动作的 Q 值
with torch.no_grad():
qf1_pi = qf1(data.observations)
qf2_pi = qf2(data.observations)
min_qf_pi = torch.min(qf1_pi, qf2_pi)
# --- 核心区别:计算期望 ---
# 对应公式 Σ [π * (α*logπ - Q)]
actor_loss = (action_probs * ((alpha * log_action_probs) - min_qf_pi)).sum(dim=1).mean()
# 更新 Actor
actor_optimizer.zero_grad()
actor_loss.backward()
actor_optimizer.step()
Critic
评价评价,评价什么?给定一个状态,一次性输出所有离散动作对应的Q值
函数 Q θ ( s , a ) Q_θ(s,a) Qθ(s,a) 的输入是状态 s 和离散动作 a。但为了效率,我们让网络只输入 s,输出一个向量,向量的第 i 个元素就是 Q θ ( s , a i ) Q_θ(s,a_i) Qθ(s,ai)。
class SoftQNetwork(nn.Module):
def __init__(self, envs):
super().__init__()
self.network = NatureCNN(envs.single_observation_space.shape[0], 512)
# --- 核心区别 ---
# 输出头的维度是动作空间的数量
# 这样,对于一个状态,它能并行计算出所有动作的 Q 值
self.q_head = nn.Linear(512, envs.single_action_space.n)
def forward(self, x):
x = self.network(x / 255.0)
# 输出一个向量,[Q(s, a_1), Q(s, a_2), ...]
q_values = self.q_head(x)
return q_values
Critic网络的更新
和前面的Bellman方程形式相同,但是目标值y的计算方式不再是采样,而是要真的计算期望,对下一状态 s ′ s' s′的所有可能动作 a ′ a' a′进行求和
y = r + γ ( E a ′ ∼ π ( ⋅ ∣ s ′ ) [ min j = 1 , 2 Q θ ˉ j ( s ′ , a ′ ) − α log π ϕ ( a ′ ∣ s ′ ) ] ) y=r+\gamma\left(\mathbb{E}_{a'\sim\pi(\cdot|s')}\left[\min_{j=1,2}Q_{\bar{\theta}_j}(s',a')-\alpha\log\pi_{\phi}(a'|s')\right]\right) y=r+γ(Ea′∼π(⋅∣s′)[j=1,2minQθˉj(s′,a′)−αlogπϕ(a′∣s′)])
期望展开为加权:
y = r + γ ∑ a ′ ∼ A π ϕ ( a ′ ∣ s ′ ) ( min j = 1 , 2 Q θ ˉ j ( s ′ , a ′ ) − α log π ϕ ( a ′ ∣ s ′ ) ) y=r+\gamma\sum_{a'\sim\mathcal{A}}\pi_{\phi}(a'|s')\left(\min_{j=1,2}Q_{\bar{\theta}_j}(s',a')-\alpha\log\pi_{\phi}(a'|s')\right) y=r+γa′∼A∑πϕ(a′∣s′)(j=1,2minQθˉj(s′,a′)−αlogπϕ(a′∣s′))
# with torch.no_grad():
# --- 对应公式中的 π(a'|s') 和 log π(a'|s') ---
# actor.get_action 返回的是所有下一状态的动作的概率和对数概率
next_state_action_probs, next_state_log_probs = actor.get_action(data.next_observations)
# --- 对应公式中的 Q_target(s', a') ---
# 目标网络一次性输出所有动作的 Q 值
qf1_next_target = qf1_target(data.next_observations)
qf2_next_target = qf2_target(data.next_observations)
min_qf_next_target = torch.min(qf1_next_target, qf2_next_target)
# --- 核心区别:计算期望 ---
# 对应公式中的 Σ [π * (Q_target - α*logπ)]
# next_state_action_probs 是一个 (batch_size, action_dim) 的张量
# min_qf_next_target 也是 (batch_size, action_dim)
# 逐元素相乘再按动作维度求和,就完成了期望的计算
next_q_value = (next_state_action_probs * (min_qf_next_target - alpha * next_state_log_probs)).sum(dim=1)
# 加上奖励和折扣因子
next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * next_q_value
# --- Q 损失计算 ---
# 这里的 Q 值是所有动作的 Q 值,那个q_values的向量
qf1_a_values = qf1(data.observations)
qf2_a_values = qf2(data.observations)
# --- 核心区别:用 gather 选取实际执行动作的 Q 值 ---
# data.actions 是实际执行的动作的索引
# .gather(1, data.actions) 会根据索引,从 qf1_a_values 中挑出对应的 Q 值
qf1_a_values = qf1_a_values.gather(1, data.actions).squeeze()
qf2_a_values = qf2_a_values.gather(1, data.actions).squeeze()
# MSE 损失的计算和之前一样
qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
qf_loss = qf1_loss + qf2_loss
总结
-
连续 SAC 通过聪明的重参数化技巧,把对期望的优化问题转换成了一个基于单次采样的优化问题。
-
离散 SAC 因为没有这个技巧,只能老老实实地通过对整个动作分布进行加权求和来直接计算期望。
SAC依然具备对离散动作的原生支持,相对于DDPG那样需要Gumbel-Softmax有着更好的适应性。
更多推荐
所有评论(0)