SAC离散和连续动作的核心区别就在于,生成动作的时候,没法在借用重参数技巧,基于高斯分布的均值和方差采样动作了。而是要让策略输出一个在所有离散动作上的 概率分布,并且损失函数的计算也不能依赖于某个“采样”的动作,而必须考虑整个概率分布。

Actor

任务是给定状态(游戏中的图像帧),输出一个在所有离散动作上的得分,再使用softmax转换成概率。

class Actor(nn.Module):
    def __init__(self, envs):
        super().__init__()
        # Atari 环境输入是图像,所以使用 NatureCNN 来提取特征
        self.network = NatureCNN(envs.single_observation_space.shape[0], 512)
        # 输出层的维度是动作空间的数量 (env.single_action_space.n)
        self.actor_head = nn.Linear(512, envs.single_action_space.n)

    def forward(self, x):
        # x / 255.0 是图像数据的标准化
        x = self.network(x / 255.0)
        # 输出的是每个动作的 "logits" (未经 softmax 的原始分数)
        logits = self.actor_head(x)
        return logits

    # 注意:这里的方法名和连续版本不同,直接返回概率和对数概率
    def get_action(self, x):
        logits = self(x)
        # --- 核心区别 ---
        # 1. 用 softmax 将 logits 转换为概率
        probs = F.softmax(logits, dim=-1)
        # 2. 计算 log 概率
        log_probs = F.log_softmax(logits, dim=-1)
        # 返回整个分布,而不是一个采样的动作
        return probs, log_probs

Actor网络更新

用当前的状态和动作来计算期望,利用的是向量对应相乘再求和。
L π ( ϕ ) = E s ∼ D [ ∑ a ∼ A π ϕ ( a ∣ s ) ( α log ⁡ π ϕ ( a ∣ s ) − min ⁡ j = 1 , 2 Q θ j ( s , a ) ) ] L_{\pi}(\phi)=\mathbb{E}_{s\sim\mathcal{D}}\left[\sum_{a\sim\mathcal{A}}\pi_{\phi}(a|s)\left(\alpha\log\pi_{\phi}(a|s)-\min_{j=1,2}Q_{\theta_j}(s,a)\right)\right] Lπ(ϕ)=EsD[aAπϕ(as)(αlogπϕ(as)j=1,2minQθj(s,a))]

# if global_step % args.policy_frequency == 0:
    # 同样,获取当前状态下所有动作的概率和对数概率
    action_probs, log_action_probs = actor.get_action(data.observations)

    # 用主 Q 网络计算所有动作的 Q 值
    with torch.no_grad():
        qf1_pi = qf1(data.observations)
        qf2_pi = qf2(data.observations)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

    # --- 核心区别:计算期望 ---
    # 对应公式 Σ [π * (α*logπ - Q)]
    actor_loss = (action_probs * ((alpha * log_action_probs) - min_qf_pi)).sum(dim=1).mean()

    # 更新 Actor
    actor_optimizer.zero_grad()
    actor_loss.backward()
    actor_optimizer.step()

Critic

评价评价,评价什么?给定一个状态,一次性输出所有离散动作对应的Q值

函数 Q θ ( s , a ) Q_θ(s,a) Qθ(s,a) 的输入是状态 s 和离散动作 a。但为了效率,我们让网络只输入 s,输出一个向量,向量的第 i 个元素就是 Q θ ( s , a i ) Q_θ(s,a_i) Qθ(s,ai)

class SoftQNetwork(nn.Module):
    def __init__(self, envs):
        super().__init__()
        self.network = NatureCNN(envs.single_observation_space.shape[0], 512)
        # --- 核心区别 ---
        # 输出头的维度是动作空间的数量
        # 这样,对于一个状态,它能并行计算出所有动作的 Q 值
        self.q_head = nn.Linear(512, envs.single_action_space.n)

    def forward(self, x):
        x = self.network(x / 255.0)
        # 输出一个向量,[Q(s, a_1), Q(s, a_2), ...]
        q_values = self.q_head(x)
        return q_values

Critic网络的更新

和前面的Bellman方程形式相同,但是目标值y的计算方式不再是采样,而是要真的计算期望,对下一状态 s ′ s' s的所有可能动作 a ′ a' a进行求和
y = r + γ ( E a ′ ∼ π ( ⋅ ∣ s ′ ) [ min ⁡ j = 1 , 2 Q θ ˉ j ( s ′ , a ′ ) − α log ⁡ π ϕ ( a ′ ∣ s ′ ) ] ) y=r+\gamma\left(\mathbb{E}_{a'\sim\pi(\cdot|s')}\left[\min_{j=1,2}Q_{\bar{\theta}_j}(s',a')-\alpha\log\pi_{\phi}(a'|s')\right]\right) y=r+γ(Eaπ(s)[j=1,2minQθˉj(s,a)αlogπϕ(as)])
期望展开为加权:
y = r + γ ∑ a ′ ∼ A π ϕ ( a ′ ∣ s ′ ) ( min ⁡ j = 1 , 2 Q θ ˉ j ( s ′ , a ′ ) − α log ⁡ π ϕ ( a ′ ∣ s ′ ) ) y=r+\gamma\sum_{a'\sim\mathcal{A}}\pi_{\phi}(a'|s')\left(\min_{j=1,2}Q_{\bar{\theta}_j}(s',a')-\alpha\log\pi_{\phi}(a'|s')\right) y=r+γaAπϕ(as)(j=1,2minQθˉj(s,a)αlogπϕ(as))

# with torch.no_grad():
    # --- 对应公式中的 π(a'|s') 和 log π(a'|s') ---
    # actor.get_action 返回的是所有下一状态的动作的概率和对数概率
    next_state_action_probs, next_state_log_probs = actor.get_action(data.next_observations)

    # --- 对应公式中的 Q_target(s', a') ---
    # 目标网络一次性输出所有动作的 Q 值
    qf1_next_target = qf1_target(data.next_observations)
    qf2_next_target = qf2_target(data.next_observations)
    min_qf_next_target = torch.min(qf1_next_target, qf2_next_target)

    # --- 核心区别:计算期望 ---
    # 对应公式中的 Σ [π * (Q_target - α*logπ)]
    # next_state_action_probs 是一个 (batch_size, action_dim) 的张量
    # min_qf_next_target 也是 (batch_size, action_dim)
    # 逐元素相乘再按动作维度求和,就完成了期望的计算
    next_q_value = (next_state_action_probs * (min_qf_next_target - alpha * next_state_log_probs)).sum(dim=1)

    # 加上奖励和折扣因子
    next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * next_q_value

# --- Q 损失计算 ---
# 这里的 Q 值是所有动作的 Q 值,那个q_values的向量
qf1_a_values = qf1(data.observations)
qf2_a_values = qf2(data.observations)

# --- 核心区别:用 gather 选取实际执行动作的 Q 值 ---
# data.actions 是实际执行的动作的索引
# .gather(1, data.actions) 会根据索引,从 qf1_a_values 中挑出对应的 Q 值
qf1_a_values = qf1_a_values.gather(1, data.actions).squeeze()
qf2_a_values = qf2_a_values.gather(1, data.actions).squeeze()

# MSE 损失的计算和之前一样
qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
qf_loss = qf1_loss + qf2_loss

总结

  • 连续 SAC 通过聪明的重参数化技巧,把对期望的优化问题转换成了一个基于单次采样的优化问题。

  • 离散 SAC 因为没有这个技巧,只能老老实实地通过对整个动作分布进行加权求和来直接计算期望
    SAC依然具备对离散动作的原生支持,相对于DDPG那样需要Gumbel-Softmax有着更好的适应性。

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐