强化学习入门:DQN 玩 Atari 游戏

1. DQN 核心概念

深度 Q 网络(Deep Q-Network, DQN)是强化学习中的里程碑算法,通过神经网络近似 Q 函数: $$Q(s,a) \approx Q(s,a;\theta)$$ 其中:

  • $s$ 为状态(游戏画面)
  • $a$ 为动作(如左右移动)
  • $\theta$ 为神经网络参数
2. 关键技术创新

(1) 经验回放(Experience Replay)

  • 存储转移样本 $(s_t, a_t, r_t, s_{t+1})$ 到记忆库
  • 训练时随机采样小批量样本,打破数据相关性

(2) 目标网络(Target Network)

  • 使用独立网络 $\theta^-$ 计算目标值: $$y_t = r_t + \gamma \max_{a'} Q(s_{t+1}, a'; \theta^-)$$
  • 主网络 $\theta$ 每 $C$ 步同步到目标网络
3. 损失函数

均方误差损失: $$L(\theta) = \mathbb{E}_{(s,a,r,s') \sim D} \left[ \left( y - Q(s,a;\theta) \right)^2 \right]$$ 其中 $D$ 为经验回放库,$\gamma$ 为折扣因子(通常取 0.99)

4. 网络架构(Atari 版)
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, input_shape, n_actions):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )
        conv_out = self._conv_out_size(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(conv_out, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )
    
    def _conv_out_size(self, shape):
        return self.conv(torch.zeros(1, *shape)).view(1, -1).size(1)
    
    def forward(self, x):
        conv_out = self.conv(x).view(x.size(0), -1)
        return self.fc(conv_out)

5. 训练流程
# 伪代码框架
for episode in range(MAX_EPISODES):
    state = env.reset()
    while not done:
        # ε-贪婪策略选择动作
        action = agent.select_action(state, epsilon) 
        
        # 执行动作
        next_state, reward, done, _ = env.step(action)
        
        # 存储经验
        agent.store_transition(state, action, reward, next_state, done)
        
        # 更新网络
        if len(replay_buffer) > BATCH_SIZE:
            agent.update()
            
        state = next_state
    
    # 衰减探索率
    epsilon = max(MIN_EPSILON, epsilon * EPS_DECAY)

6. 实战技巧
  • 输入预处理:将原始 210×160 RGB 帧转换为 84×84 灰度图,堆叠 4 帧作为状态
  • 奖励裁剪:将奖励值限制为 $[-1, 1]$ 提高稳定性
  • 帧跳过:每 4 帧执行一次动作,中间帧重复动作
  • 目标网络更新:$C$ 通常取 10000 步
7. 性能优化
  • Double DQN:解耦动作选择与价值评估
  • Dueling DQN:分离状态价值和动作优势
  • 优先级经验回放:根据 TD 误差加权采样

经典结果:DQN 在 49 款 Atari 游戏中,75% 超越人类专业玩家水平,部分游戏(如 Breakout)达到人类水平的 10 倍以上。

通过上述框架,可训练智能体在 Pong、Breakout 等经典游戏中实现超人表现。实际实现需结合 OpenAI Gym 环境库和 PyTorch/TensorFlow 深度学习框架。

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐