介绍

深度Q网络(Deep Q-Network,DQN)是深度强化学习(Deep Reinforcement Learning)中的一个重要方法,它将Q学习(Q-learning)深度神经网络相结合,实现了在高维状态空间下的强化学习任务,代表性成果是 DeepMind 在 2015 年提出的利用 DQN 在多个 Atari 游戏中实现了接近或超过人类水平的表现。


一、背景与基本概念

1. 强化学习框架

强化学习问题通常用马尔可夫决策过程(Markov Decision Process, MDP)建模:

  • 状态空间: S \mathcal{S} S
  • 动作空间: A \mathcal{A} A
  • 状态转移概率: P ( s ′ ∣ s , a ) P(s'|s,a) P(ss,a)
  • 奖励函数: R ( s , a ) R(s,a) R(s,a)
  • 折扣因子: γ ∈ [ 0 , 1 ) \gamma \in [0,1) γ[0,1)

目标是学习一个策略 π ( a ∣ s ) \pi(a|s) π(as),使得累积期望奖励最大化:

E [ ∑ t = 0 ∞ γ t r t ] \mathbb{E} \left[ \sum_{t=0}^{\infty} \gamma^t r_t \right] E[t=0γtrt]


二、Q学习回顾

Q学习是一种无模型的强化学习方法,学习动作价值函数(Q函数):

Q π ( s , a ) = E [ ∑ t = 0 ∞ γ t r t ∣ s 0 = s , a 0 = a , π ] Q^\pi(s,a) = \mathbb{E} \left[ \sum_{t=0}^{\infty} \gamma^t r_t \mid s_0 = s, a_0 = a, \pi \right] Qπ(s,a)=E[t=0γtrts0=s,a0=a,π]

最优Q函数满足 Bellman最优方程

Q ∗ ( s , a ) = E s ′ [ r + γ max ⁡ a ′ Q ∗ ( s ′ , a ′ ) ∣ s , a ] Q^*(s,a) = \mathbb{E}_{s'} \left[ r + \gamma \max_{a'} Q^*(s', a') \mid s, a \right] Q(s,a)=Es[r+γamaxQ(s,a)s,a]

Q学习通过以下更新规则进行学习:

Q ( s , a ) ← Q ( s , a ) + α [ r + γ max ⁡ a ′ Q ( s ′ , a ′ ) − Q ( s , a ) ] Q(s,a) \leftarrow Q(s,a) + \alpha \left[ r + \gamma \max_{a'} Q(s',a') - Q(s,a) \right] Q(s,a)Q(s,a)+α[r+γamaxQ(s,a)Q(s,a)]


三、深度Q网络(DQN)

在高维状态空间(如图像)下,用表格表示 Q 函数不再可行,DQN 使用一个神经网络 Q ( s , a ; θ ) Q(s,a;\theta) Q(s,a;θ) 来逼近 Q 函数。

核心思想

  • 使用神经网络表示 Q 函数
  • 使用经验回放(experience replay)
  • 使用固定目标网络(target network)

四、DQN算法公式推导

1. 损失函数(目标函数)

使用神经网络参数 θ \theta θ 逼近 Q 值。目标是最小化 Q 值与目标 Q 值的差距。

定义当前网络: Q ( s , a ; θ ) Q(s,a;\theta) Q(s,a;θ)

定义目标网络(固定参数若干步后更新一次): Q ( s , a ; θ − ) Q(s,a;\theta^-) Q(s,a;θ)

构造目标:

y = r + γ max ⁡ a ′ Q ( s ′ , a ′ ; θ − ) y = r + \gamma \max_{a'} Q(s',a'; \theta^-) y=r+γamaxQ(s,a;θ)

定义损失函数:

L ( θ ) = E ( s , a , r , s ′ ) ∼ D [ ( y − Q ( s , a ; θ ) ) 2 ] L(\theta) = \mathbb{E}_{(s,a,r,s') \sim D} \left[ \left( y - Q(s,a;\theta) \right)^2 \right] L(θ)=E(s,a,r,s)D[(yQ(s,a;θ))2]

也就是:

L ( θ ) = E ( s , a , r , s ′ ) [ ( r + γ max ⁡ a ′ Q ( s ′ , a ′ ; θ − ) − Q ( s , a ; θ ) ) 2 ] L(\theta) = \mathbb{E}_{(s,a,r,s')} \left[ \left( r + \gamma \max_{a'} Q(s', a'; \theta^-) - Q(s,a;\theta) \right)^2 \right] L(θ)=E(s,a,r,s)[(r+γamaxQ(s,a;θ)Q(s,a;θ))2]

2. 梯度更新

对损失函数求梯度(用于反向传播):

∇ θ L ( θ ) = E ( s , a , r , s ′ ) [ ( r + γ max ⁡ a ′ Q ( s ′ , a ′ ; θ − ) − Q ( s , a ; θ ) ) ∇ θ Q ( s , a ; θ ) ] \nabla_\theta L(\theta) = \mathbb{E}_{(s,a,r,s')} \left[ \left( r + \gamma \max_{a'} Q(s',a';\theta^-) - Q(s,a;\theta) \right) \nabla_\theta Q(s,a;\theta) \right] θL(θ)=E(s,a,r,s)[(r+γamaxQ(s,a;θ)Q(s,a;θ))θQ(s,a;θ)]


五、DQN核心技术细节

1. 经验回放(Experience Replay)

  • 用一个缓存 D D D 存储经验元组 ( s , a , r , s ′ ) (s, a, r, s') (s,a,r,s)
  • 每次训练从 D D D 中随机采样 mini-batch
  • 破除数据相关性,提高样本利用率

2. 固定目标网络(Target Network)

  • 用另一个网络 Q ( s , a ; θ − ) Q(s,a;\theta^-) Q(s,a;θ) 生成目标 Q 值
  • 每隔 C C C 步更新一次目标网络: θ − ← θ \theta^- \leftarrow \theta θθ

3. ε-贪心策略(Exploration)

  • 以概率 ϵ \epsilon ϵ 随机选择动作(探索)
  • 以概率 1 − ϵ 1-\epsilon 1ϵ 选择 arg ⁡ max ⁡ a Q ( s , a ; θ ) \arg\max_a Q(s,a;\theta) argmaxaQ(s,a;θ)(利用)
  • 通常采用 ε 衰减策略

六、DQN算法流程总结

初始化 Q 网络参数 θ,目标网络参数 θ⁻ ← θ
初始化经验回放池 D

for 每一轮 episode:
    初始化状态 s
    for 每一步:
        ε-贪心选择动作 a
        执行动作 a,观察 r 和 s'
        存储 (s,a,r,s') 到 D 中
        从 D 中采样一小批样本 (s,a,r,s')
        计算目标 y = r + γ max_{a'} Q(s',a';θ⁻)
        计算损失 L(θ) = (y - Q(s,a;θ))²
        更新 θ 以最小化损失
        每 C 步:更新 θ⁻ ← θ

七、DQN的改进版本(可扩展了解)

  1. Double DQN:解决 Q 值过估计问题
  2. Dueling DQN:分离状态价值和动作优势函数
  3. Prioritized Experience Replay:根据 TD 误差优先采样
  4. Rainbow DQN:集成多种改进(Double, Dueling, PER, Noisy Nets 等)

八、小结

项目 说明
模型 用深度网络逼近 Q 函数
核心 经验回放 + 固定目标网络
损失函数 L ( θ ) = ( r + γ max ⁡ Q ′ − Q ) 2 L(\theta) = (r + \gamma \max Q' - Q)^2 L(θ)=(r+γmaxQQ)2
策略 ε-贪心探索策略
训练方式 小批量梯度下降更新 Q 网络参数

代码实现示例

下面是一个最小可运行版本的 DQN 实现代码,使用 PyTorch,在 OpenAI Gym 环境(如 CartPole-v1)中运行。这个版本涵盖了:

  • Q 网络
  • 目标网络
  • 经验回放
  • ε-贪心策略
  • 网络参数更新

✅ 环境依赖

先确保你安装了以下库:

pip install gym torch numpy

✅ 完整代码(适用于 CartPole-v1)

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import deque

# 超参数
GAMMA = 0.99
LR = 1e-3
BATCH_SIZE = 64
MEMORY_SIZE = 10000
TARGET_UPDATE_FREQ = 100
EPSILON_START = 1.0
EPSILON_END = 0.01
EPSILON_DECAY = 500

# Q 网络定义
class DQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(DQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim)
        )

    def forward(self, x):
        return self.net(x)

# 经验回放
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, s, a, r, s_, done):
        self.buffer.append((s, a, r, s_, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        s, a, r, s_, d = zip(*batch)
        return (
            torch.tensor(s, dtype=torch.float),
            torch.tensor(a, dtype=torch.long),
            torch.tensor(r, dtype=torch.float),
            torch.tensor(s_, dtype=torch.float),
            torch.tensor(d, dtype=torch.float)
        )

    def __len__(self):
        return len(self.buffer)

# 主训练逻辑
def train_dqn(env_name='CartPole-v1', episodes=500):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    policy_net = DQN(state_dim, action_dim)
    target_net = DQN(state_dim, action_dim)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = optim.Adam(policy_net.parameters(), lr=LR)
    replay_buffer = ReplayBuffer(MEMORY_SIZE)

    steps_done = 0

    def select_action(state):
        nonlocal steps_done
        epsilon = EPSILON_END + (EPSILON_START - EPSILON_END) * \
                  np.exp(-1. * steps_done / EPSILON_DECAY)
        steps_done += 1
        if random.random() < epsilon:
            return random.randrange(action_dim)
        else:
            with torch.no_grad():
                return policy_net(torch.tensor(state, dtype=torch.float).unsqueeze(0)).argmax().item()

    for episode in range(episodes):
        state = env.reset()
        episode_reward = 0
        done = False

        while not done:
            action = select_action(state)
            next_state, reward, done, _ = env.step(action)
            replay_buffer.push(state, action, reward, next_state, done)
            state = next_state
            episode_reward += reward

            if len(replay_buffer) > BATCH_SIZE:
                states, actions, rewards, next_states, dones = replay_buffer.sample(BATCH_SIZE)

                q_values = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze()
                next_q_values = target_net(next_states).max(1)[0].detach()
                targets = rewards + GAMMA * next_q_values * (1 - dones)

                loss = nn.MSELoss()(q_values, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if steps_done % TARGET_UPDATE_FREQ == 0:
                target_net.load_state_dict(policy_net.state_dict())

        print(f"Episode {episode}, Reward: {episode_reward}")

    env.close()

# 运行训练
if __name__ == "__main__":
    train_dqn()

✅ 输出示例

Episode 0, Reward: 23.0
Episode 1, Reward: 19.0
Episode 2, Reward: 28.0
...
Episode 200, Reward: 200.0

当 DQN 学习稳定时,CartPole-v1 达到 200 分说明它已经学会了控制策略。


✅ 下一步建议

如你想进一步提升性能或学习更复杂环境:

  • ✅ 添加 Double DQN 或 Dueling DQN
  • ✅ 支持图像输入(比如 Atari)
  • ✅ 使用 PyTorch Lightning 等框架简化训练逻辑
  • ✅ 可视化 Q 值或策略演化过程
Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐