贝曼方程

贝曼方程是强化学习中最核心的理论,贯穿了从经典的动态规划到现代深度强化学习(如 DQN、PPO)的发展脉络。其核心思想在于递归分解:将无限远未来的复杂评估,拆解为“当前即时奖励”与“下一状态折现价值”之和。

1. 核心前置概念

在推导贝曼方程之前,必须明确两个支撑模型:

1.1 折现回报 (Discounted Return)

智能体的目标是最大化未来的累积奖励,称为 回报 G t G_t Gt。引入折扣因子 γ ∈ [ 0 , 1 ] \gamma \in [0, 1] γ[0,1] 以保证数学收敛并体现未来的不确定性。

G t = R t + 1 + γ R t + 2 + γ 2 R t + 3 + ⋯ = ∑ k = 0 ∞ γ k R t + k + 1 G_t = R_{t+1} + \gamma R_{t+2} + \gamma^2 R_{t+3} + \dots = \sum_{k=0}^{\infty} \gamma^k R_{t+k+1} Gt=Rt+1+γRt+2+γ2Rt+3+=k=0γkRt+k+1

递归性质:
通过提取公因式 γ \gamma γ,可以得到贝曼方程的基础递归式:

G t = R t + 1 + γ G t + 1 G_t = R_{t+1} + \gamma G_{t+1} Gt=Rt+1+γGt+1

即: 当前总回报 = 即时奖励 + γ × \gamma \times γ× 下一时刻总回报。

1.2 马尔可夫决策过程 (MDP)

MDP 是描述环境的标准模型,其核心是马尔可夫性质:下一状态只取决于当前状态和动作。

  • 五元组: ( S , A , P , R , γ ) (S, A, P, R, \gamma) (S,A,P,R,γ)
  • P ( s ′ ∣ s , a ) P(s'|s,a) P(ss,a) 状态转移概率。
  • R ( s , a ) R(s,a) R(s,a) 奖励函数。
  1. S S S (State) - 状态它是智能体眼中的“世界快照”。 它包含了环境在某一时刻的所有关键信息,比如你在地图上的坐标、赛车的速度或棋盘上棋子的布局。它告诉智能体:“你现在处于什么样的处境中。
  2. A A A (Action) - 动作它是智能体拥有的“决策清单”。 它定义了智能体在特定状态下合法的所有行为,比如格斗游戏里的出拳、防御,或者走迷宫时的上下左右。它决定了智能体:“面对当前处境,你有哪些选择。
  3. P ( s ′ ∣ s , a ) P(s'|s, a) P(ss,a) (Transition) - 状态转移概率它是环境变化的“因果规律”。 它描述了当你执行某个动作后,世界发生变化的随机性,例如在结冰路面开车,你踩了刹车(动作),但车子由于惯性滑向了不同位置(新状态)。它代表了:“你做了某个动作后,有多大概率会变成什么样子。
  4. R ( s , a ) R(s, a) R(s,a) (Reward) - 奖励函数它是衡量行为好坏的“反馈刻度”。 它是环境在你做出动作后立刻给你的即时奖励或惩罚,比如吃到金币得 10 分,掉进陷阱扣 100 分。它指引智能体:“这一步走完,你是尝到了甜头还是吃到了苦头。
  5. γ \gamma γ (Discount Factor) - 折扣因子它是智能体对待未来的“耐心程度”。 它的数值在 0 到 1 之间,数值越大说明越看重长远利益,数值越小则说明越贪图眼前的即时快感。它决定了:“未来的回报在现在的你看来,到底还值多少钱。”

2. 价值函数:V 与 Q

为了量化策略 π \pi π 的好坏,我们定义了两种价值函数:

函数类型 定义 数学表达
状态价值 V π ( s ) V^\pi(s) Vπ(s) 在状态 s s s 下,遵循策略 π \pi π 的期望回报 V π ( s ) = E π [ G t ∣ S t = s ] V^\pi(s) = \mathbb{E}_\pi [ G_t \mid S_t = s ] Vπ(s)=Eπ[GtSt=s]
动作价值 Q π ( s , a ) Q^\pi(s,a) Qπ(s,a) 在状态 s s s 执行动作 a a a 后,遵循策略 π \pi π 的期望回报 Q π ( s , a ) = E π [ G t ∣ S t = s , A t = a ] Q^\pi(s,a) = \mathbb{E}_\pi [ G_t \mid S_t = s, A_t = a ] Qπ(s,a)=Eπ[GtSt=s,At=a]

注意:
1.这里的 π \pi π指的的是策略而非圆周率, 通常写成 π ( a ∣ s ) \pi(a|s) π(as), 在状态 s s s 的情况下,采取动作 a a a 的概率。
2.策略 π \pi π 与 动作 a a a 的本质区别动作 a a a 是“零件”: 它是动作空间 A A A 里的一个具体元素,比如“左转”、“跳跃”。它本身没有思想,只是一个动作指令。策略 π \pi π 是“蓝图”: 它是一个概率分布函数。它的任务是观察当前状态 s s s,然后计算出每个动作 a a a 被选中的概率。
3.动作价值 Q Q Q 负责将即时收益 r r r 与未来潜力 γ V n e x t \gamma V_{next} γVnext 挂钩,它告诉智能体执行某一步具体的动作能换回多少“现钱”和多大的“前程”;而状态价值 V V V 则通过策略 π \pi π 对这些动作进行加权平均,从而对当前的整体处境给出一个综合评分。
4. V V V(处境分) 是由 π \pi π(出牌概率) 和 Q Q Q(单张牌的分数) 共同决定的。公式表达就是: V = ∑ ( π × Q ) V = \sum (\pi \times Q) V=(π×Q)即:处境的总分 = (动作1的概率 × \times × 动作1的分数) + (动作2的概率 × \times × 动作2的分数) + …

3. 贝曼期望方程 (Bellman Expectation Equation)

通过将 V V V Q Q Q 互相嵌套,我们得到了描述当前状态与后续状态之间关系的递归方程。

3.1 相互转化关系

  1. 状态价值由动作价值加权决定:

V π ( s ) = ∑ a ∈ A π ( a ∣ s ) Q π ( s , a ) V^\pi(s) = \sum_{a \in A} \pi(a|s) Q^\pi(s,a) Vπ(s)=aAπ(as)Qπ(s,a)

  1. 动作价值由即时奖励和后续状态价值决定:

Q π ( s , a ) = R ( s , a ) + γ ∑ s ′ ∈ S P ( s ′ ∣ s , a ) V π ( s ′ ) Q^\pi(s,a) = R(s,a) + \gamma \sum_{s' \in S} P(s'|s,a) V^\pi(s') Qπ(s,a)=R(s,a)+γsSP(ss,a)Vπ(s)

3.2 完整递归形式

将上述两式合并,得到针对 V V V 的完整贝曼期望方程:

V π ( s ) = ∑ a ∈ A π ( a ∣ s ) [ R ( s , a ) + γ ∑ s ′ ∈ S P ( s ′ ∣ s , a ) V π ( s ′ ) ] V^\pi(s) = \sum_{a \in A} \pi(a|s) \left[ R(s,a) + \gamma \sum_{s' \in S} P(s'|s,a) V^\pi(s') \right] Vπ(s)=aAπ(as)[R(s,a)+γsSP(ss,a)Vπ(s)]


4. 贝曼最优方程 (Bellman Optimality Equation)

强化学习的终极目标是找到最优策略 π ∗ \pi^* π。在最优策略下,智能体会选择 Q Q Q 值最大的动作,而非按概率分布选择。

  • 最优状态价值:

V ∗ ( s ) = max ⁡ a ∈ A [ R ( s , a ) + γ ∑ s ′ ∈ S P ( s ′ ∣ s , a ) V ∗ ( s ′ ) ] V^*(s) = \max_{a \in A} \left[ R(s,a) + \gamma \sum_{s' \in S} P(s'|s,a) V^*(s') \right] V(s)=aAmax[R(s,a)+γsSP(ss,a)V(s)]

  • 最优动作价值:

Q ∗ ( s , a ) = R ( s , a ) + γ ∑ s ′ ∈ S P ( s ′ ∣ s , a ) max ⁡ a ′ Q ∗ ( s ′ , a ′ ) Q^*(s,a) = R(s,a) + \gamma \sum_{s' \in S} P(s'|s,a) \max_{a'} Q^*(s',a') Q(s,a)=R(s,a)+γsSP(ss,a)amaxQ(s,a)

核心逻辑: 只要解出 Q ∗ ( s , a ) Q^*(s,a) Q(s,a),最优策略即为: π ∗ ( s ) = arg ⁡ max ⁡ a Q ∗ ( s , a ) \pi^*(s) = \arg\max_a Q^*(s,a) π(s)=argmaxaQ(s,a)


5. 算法实现:价值迭代 (Python)

在环境已知的情况下,我们可以利用贝曼最优方程通过迭代逼近真实价值。

# bellman.py
class BellmanSolver:
    def __init__(self, states, get_actions_func, get_transitions_func, gamma=0.9):
        self.states = states
        self.get_actions = get_actions_func
        self.get_transitions = get_transitions_func
        self.gamma = gamma
        self.V = {s: 0.0 for s in states}

    def compute_q(self, state, action):
        """核心逻辑:下半场 Q = r + gamma * V_next"""
        q_value = 0
        for prob, next_s, reward in self.get_transitions(state, action):
            q_value += prob * (reward + self.gamma * self.V.get(next_s, 0.0))
        return q_value

    def step(self, policy=None):
        """核心逻辑:上半场 V = sum(pi * Q) 或 V = max(Q)"""
        new_V = {}
        delta = 0.0

        for s in self.states:
            actions = self.get_actions(s)
            if not actions:
                new_V[s] = 0.0
                continue

            qs = {a: self.compute_q(s, a) for a in actions}

            if policy:
                new_V[s] = sum(policy(s, a) * qs[a] for a in actions)
            else:
                new_V[s] = max(qs.values())

            delta = max(delta, abs(self.V[s] - new_V[s]))

        self.V = new_V
        return delta

    def solve(self, theta=1e-6):
        while self.step() > theta:
            pass
        return self.V


if __name__ == '__main__':
    my_states = ['安全区', '寻宝区', '悬崖']


    def my_actions(state):
        if state == '安全区':
            return ['前往寻宝']
        elif state == '寻宝区':
            return ['挖宝', '退回安全区']
        return []


    def my_transitions(state, action):
        if state == '安全区' and action == '前往寻宝':
            return [(1.0, '寻宝区', -1)]
        elif state == '寻宝区' and action == '挖宝':
            return [(0.8, '寻宝区', 50), (0.2, '悬崖', -100)]
        elif state == '寻宝区' and action == '退回安全区':
            return [(1.0, '安全区', 0)]
        return [(1.0, state, 0)]


    solver = BellmanSolver(my_states, my_actions, my_transitions, gamma=0.9)
    optimal_values = solver.solve()

    for s, v in optimal_values.items():
        print(f"状态 [{s}] 价值: {v:.2f}")

    q_dig = solver.compute_q('寻宝区', '挖宝')
    q_back = solver.compute_q('寻宝区', '退回安全区')
    print(f"动作评估 - 挖宝: {q_dig:.2f}, 退回: {q_back:.2f}")

经典的两个问题

  1. 认知荒漠:因“奖励真空”导致的逻辑坍缩认知荒漠的本质是即时奖励 r r r 的全线缺失,使得折现因子 γ \gamma γ 失去了赖以生存的“放大对象”。即便公式 Q = r + γ V n e x t Q = r + \gamma V_{next} Q=r+γVnext 在理论上严丝合缝,但若环境设计仅在终点设奖,漫长的中间路径便会沦为真空地带;此时由于捕捉不到任何初始的“奖励火种”,贝曼方程在每一层递归中都只能传递零值,导致智能体如同拿着精密放大镜在空无一物的白纸上寻找线索,最终因感知不到任何反馈向量而在原地盲目打转。
  2. 土拨鼠陷阱:因“ max ⁡ \max max 偏见”导致的伪风险规避土拨鼠陷阱本质上是 max ⁡ ( Q ) \max(Q) max(Q) 逻辑引发的永久性偏见,它将算法锁死在已知的局部最优解中。贝曼最优方程表现出的“风险规避”并非由于算法胆小,而是因为它极度现实且缺乏主动探索的脉冲:一旦智能体在初次尝试动作 B 时偶然遭遇负分,该动作的评估值会立即跌入谷底,而贪婪的 max ⁡ \max max 操作从此会永久性地过滤掉这个“已知低分”选项。这种为了维持当前数值最大化而拒绝再次尝试的行为,使智能体陷入了不断重复平庸经验的死循环,彻底丧失了发现全局最优解的机会。
  3. 一句话点破:荒漠是因为“没得放( r = 0 r=0 r=0)”,陷阱是因为“不敢放( max ⁡ \max max 剔除)”。这正是为什么在自动驾驶或围棋的生产级实现中,必须在贝曼方程之外额外增加“奖励塑造”来填补荒漠,以及“ ϵ \epsilon ϵ-贪婪采样”来打破陷阱。

6. Q-learning:贝尔曼方程的第一个落地算法

理解了贝尔曼最优方程的理论框架后,我们自然会问:如何在实际中求解这个方程? 如果环境模型已知,我们可以用上一节的价值迭代来求解;但如果模型未知(绝大多数现实问题),就必须通过与环境的交互来学习。Q‑learning 正是这样一个开创性的无模型算法,它直接利用贝尔曼最优方程的思想,在交互中逐步逼近最优动作价值函数 Q ∗ Q^* Q

从贝尔曼最优方程到增量更新

回忆最优动作价值函数满足:
Q ∗ ( s , a ) = E s ′ ∼ P [ r + γ max ⁡ a ′ Q ∗ ( s ′ , a ′ ) ] Q^*(s,a) = \mathbb{E}_{s' \sim P}\left[ r + \gamma \max_{a'} Q^*(s',a') \right] Q(s,a)=EsP[r+γamaxQ(s,a)]

这个等式表明:在状态 s s s 执行动作 a a a 后,期望得到的回报等于即时奖励 r r r 加上折扣后的未来最优价值。然而,期望依赖于状态转移概率 P P P,而我们在真实环境中只有一次实际的交互结果。Q‑learning 的做法非常朴素——用一次采样来近似期望

假设智能体在状态 s s s 执行动作 a a a 后,观测到奖励 r r r 和下一状态 s ′ s' s,那么可以构造一个 时序差分目标(TD target):

y t = r + γ max ⁡ a ′ Q ( s ′ , a ′ ) y_t = r + \gamma \max_{a'} Q(s', a') yt=r+γamaxQ(s,a)

其中 Q ( s ′ , a ′ ) Q(s',a') Q(s,a) 是当前对下一状态动作价值的估计。这个目标就是对贝尔曼最优方程右边的一次采样。当前估计 Q ( s , a ) Q(s,a) Q(s,a) 与目标 y t y_t yt 之间的差距称为 TD 误差

δ t = y t − Q ( s , a ) \delta_t = y_t - Q(s,a) δt=ytQ(s,a)

TD 误差衡量了当前估计与“采样后的贝尔曼最优目标”之间的偏差。如果 δ t > 0 \delta_t > 0 δt>0,说明实际结果比当前估计更好,应该提高 Q ( s , a ) Q(s,a) Q(s,a);反之则应降低。Q‑learning 采用增量更新,将当前 Q Q Q 值向目标方向调整一小步:

Q ( s , a ) ← Q ( s , a ) + α [ r + γ max ⁡ a ′ Q ( s ′ , a ′ ) − Q ( s , a ) ] Q(s,a) \leftarrow Q(s,a) + \alpha \left[ r + \gamma \max_{a'} Q(s',a') - Q(s,a) \right] Q(s,a)Q(s,a)+α[r+γamaxQ(s,a)Q(s,a)]

这里 α ∈ ( 0 , 1 ] \alpha \in (0,1] α(0,1] 是学习率,控制着每次更新的步长。这个公式可以直观理解为:用实际观察到的奖励和下一状态的最佳估计来修正当前状态的 Q Q Q,每一步都在拉近 Q ( s , a ) Q(s,a) Q(s,a) 与贝尔曼最优方程右侧的距离。

为了更直观,考虑一个简单例子:在TD学习中,从19:00刷题状态出发,经历刷题获得即时奖励5分,到达估计价值50分的22:00睡觉状态,用折扣因子0.8计算采样目标5 + 0.8×50 = 45分,以此修正原估值。

离线策略特性:数据可以“回收利用”

仔细观察更新公式,你会发现目标值 max ⁡ a ′ Q ( s ′ , a ′ ) \max_{a'} Q(s',a') maxaQ(s,a) 与当前行为策略如何选择动作 a a a 完全无关。无论智能体是用贪婪策略、随机策略还是人类演示来生成动作 a a a,只要得到了 ( s , a , r , s ′ ) (s,a,r,s') (s,a,r,s) 这个经验,就可以用来更新 Q ( s , a ) Q(s,a) Q(s,a)。这种性质称为 离线策略(off-policy)
离线策略带来两大好处:

  • 数据复用:智能体可以从任意策略产生的历史数据中学习,甚至可以学习专家演示。这意味着我们可以将过去的经验存储起来,反复使用,大大提高样本效率。
  • 经验回放:可以将过去的经验存储在一个缓冲区中,训练时随机采样更新,这不仅能复用数据,还能打破连续样本之间的相关性,使训练更加稳定。这个思想后来成为深度强化学习的基石,如 DQN 中的经验回放。
    与离线策略相对的是 在线策略(on-policy) 算法,如 Sarsa。在 Sarsa 中,更新公式为 Q ( s , a ) ← Q ( s , a ) + α [ r + γ Q ( s ′ , a ′ ) − Q ( s , a ) ] Q(s,a) \leftarrow Q(s,a) + \alpha [r + \gamma Q(s',a') - Q(s,a)] Q(s,a)Q(s,a)+α[r+γQ(s,a)Q(s,a)],其中 a ′ a' a 必须是当前策略实际选择的动作。这意味着 Sarsa 只能用当前策略产生的数据进行更新,一旦策略改变,旧数据就失效了。相比之下,Q‑learning 的离线特性使其更加灵活高效。

算法流程与探索机制

下面是 Q‑learning 的经典伪代码(表格型版本),它清晰地展示了如何通过反复交互来更新 Q Q Q 值:

初始化 Q(s,a) 表格,通常设为零(或小随机数)
循环每个回合(episode):
    初始化状态 s
    循环每一步:
        根据当前 Q 值采用 ε-greedy 策略选择动作 a
        执行动作 a,观测奖励 r 和下一状态 s'
        更新 Q 值:
            Q(s,a) ← Q(s,a) + α [ r + γ max_{a'} Q(s',a') - Q(s,a) ]
        s ← s'
        如果 s 是终止状态,跳出循环

其中 ε-greedy 策略 是平衡“利用”与“探索”的简单而有效的方法:以概率 ϵ \epsilon ϵ 随机选择动作(探索),以概率 1 − ϵ 1-\epsilon 1ϵ 选择当前 Q Q Q 值最大的动作(利用)。随着学习的进行, ϵ \epsilon ϵ 通常会逐渐衰减(例如从 1.0 衰减到 0.01),让智能体从探索为主过渡到利用为主。

探索机制至关重要:如果没有探索,智能体可能永远无法发现更好的动作,这正是“土拨鼠陷阱”的根源。ε-greedy 通过强制尝试非最优动作,打破了 max ⁡ \max max 偏见。此外,还可以采用更高级的探索策略,如 Boltzmann 探索(根据 Q Q Q 值的 softmax 分布选择动作),或加入探索噪声等。

收敛性与局限性

收敛性保证:在表格型表示、所有状态-动作对无限次访问、学习率满足 Robbins‑Monro 条件( ∑ t = 0 ∞ α t = ∞ \sum_{t=0}^{\infty} \alpha_t = \infty t=0αt= ∑ t = 0 ∞ α t 2 < ∞ \sum_{t=0}^{\infty} \alpha_t^2 < \infty t=0αt2<)时,Q‑learning 以概率 1 收敛到最优 Q ∗ Q^* Q。直观解释:第一个条件保证学习率不会衰减得太快,使得算法有足够机会修正估计;第二个条件保证最终步长足够小,避免震荡。常见的满足条件的衰减策略如 α t = 1 / t \alpha_t = 1/t αt=1/t

局限性

  • 状态空间爆炸:当状态空间或动作空间巨大时(例如围棋、Atari 游戏画面),表格无法存储,必须引入函数近似(如神经网络)。但函数近似会带来新的挑战,如收敛性不再保证、可能发散等。
  • 探索效率:在高维空间中,随机探索可能极慢,需要更智能的探索机制(如基于不确定性的探索、内在动机等)。
  • 收敛速度:即使理论上收敛,实际中可能需要海量样本。对于复杂任务,直接使用表格 Q-learning 几乎不可行。
  • 过估计问题:由于 max ⁡ \max max 操作,Q-learning 往往会高估 Q Q Q 值,尤其是在使用函数近似时。这催生了 Double Q-learning 等改进算法。

Q‑learning 如何回应“两个经典问题”

有趣的是,Q‑learning 的设计恰好体现了对前文“认知荒漠”和“土拨鼠陷阱”的朴素解法:

  • 认知荒漠:Q‑learning 通过与环境的持续交互,不断获得即时奖励 r r r,从而让贝尔曼方程的递归有了“火种”。即使环境奖励稀疏,只要偶尔有非零奖励,就能通过 自举(bootstrap)——用当前估计的 Q ( s ′ , a ′ ) Q(s',a') Q(s,a) 来更新 Q ( s , a ) Q(s,a) Q(s,a)——逐渐将终局奖励反向传播到早期状态。例如,在迷宫中只有终点有 +1 奖励,智能体第一次到达终点时,会将奖励回传给前一步的状态,然后前一步的状态再回传给更早的状态,最终整个路径上的状态都能获得非零价值估计。这正是“奖励塑造”的一种天然形式,无需人工设计中间奖励。

  • 土拨鼠陷阱:Q‑learning 采用的 ε-greedy 策略强制智能体以一定概率尝试非最优动作,这正是打破 max ⁡ \max max 偏见的直接手段。此外,离线策略特性允许智能体从过去的探索经验中反复学习,即使某次偶然失败导致某个动作被低估,未来仍有机会通过再次探索修正它。相比之下,纯粹贪心的价值迭代(如 V ∗ ( s ) = max ⁡ V^*(s)=\max V(s)=max)一旦陷入局部最优就无法自拔,而 Q-learning 的动态探索机制提供了逃逸的可能。
    因此,Q‑learning 不仅是一个实用的算法,更是对贝尔曼方程深刻理解后的自然产物。它开启了无模型强化学习的广阔天地,并直接催生了后来的深度 Q 网络(DQN)等一系列重要进展。

DQN的出现

DQN 的出现,主要是因为经典的 Q-learning 在稍微复杂一点的问题面前就转不动了。Q-learning 的核心是维护一张表,把所有状态和动作对应的 Q 值记下来。这个方法在小迷宫这类状态有限的问题上没问题,但一旦状态多起来,比如让计算机直接看游戏画面,画面里每个像素的颜色和位置组合起来就是天文数字,别说存一张表,就算把整个硬盘填满也放不下。那怎么办?换个思路:不存具体的值,而是用一个函数去拟合。就像给你一堆点,你画一条曲线去穿过它们。神经网络就是一种特别灵活的函数,给它输入状态,它能输出每个动作对应的 Q 值。这就是 DQN 的基本想法——用神经网络代替 Q 表。网络的参数就是那些神经元之间的连接权重,训练的过程就是不断调整这些权重,让网络的输出越来越接近真实的 Q 值。但这个想法一开始直接做是行不通的,因为用神经网络套 Q-learning 会遇到两个很棘手的问题。

两个麻烦:数据太像,目标在跑

第一个问题是数据长得太像了。玩一个游戏,前后几步的画面肯定是连续变化的,相关性非常高。如果用这些连续的数据去训练神经网络,网络会很快“记住”最近这几步的规律,但稍微遇到一点不一样的情况就懵了。就像学生考试前只背了同一类题,题目一变形就不会做。这种学习很不稳定,甚至会出现前面刚学会的后面又忘了的情况,模型一直震荡,很难收敛。
第二个问题是目标值一直在变。Q-learning 的更新公式里有一个目标值,它是即时奖励加上下一个状态的最大 Q 值乘以一个折扣。问题在于,下一个状态的 Q 值也是同一个网络估计出来的。这意味着网络一边在调整自己,想去接近一个目标,但这个目标本身又随着网络的调整而变动。这就好比跑步的时候,终点线也在跟着你跑,你永远追不上。这会导致训练过程非常不稳定,甚至发散。

两个巧妙的解决办法

为了解决这两个问题,DQN 引入了两个关键机制。
第一个叫经验回放。智能体在和环境交互的过程中,会把每一次的经历——当前状态、做的动作、拿到的奖励、下一个状态——都存进一个叫作“经验池”的地方。训练的时候,不直接用刚发生的那一步数据,而是从这个池子里随机抽出一小批来学。这个随机抽取的动作,打破了前后数据之间的相关性。抽出来的经历可能来自十分钟前,也可能来自一小时前,它们之间没什么关联,这样网络就能学到更通用的规律。而且这些经历可以反复拿出来用,不像在线学习那样用过一次就扔,数据的利用率也提高了。

对比维度 Q-learning 的离线策略特性 DQN 的经验回放机制
本质 算法本身的一种数学性质,指更新目标值与行为策略无关 一种具体的数据处理方法,用于存储和复用历史经验
目的 允许从任意策略产生的数据中学习(理论上的可能性) 打破连续样本的相关性,提高样本效率,稳定神经网络训练
操作方式 通常在线更新:每交互一步,用当前数据更新Q表后丢弃(理论上可复用,但传统表格实现不存储) 将每一步经验存入缓冲区,训练时从中随机采样小批量进行更新
是否存储数据 一般不存储历史数据,数据用过即弃 必须存储大量历史经验,依靠经验池反复采样
必要性 不是必须的,只是算法具有的特性,不影响基础版本的运行 对DQN来说是必须的,否则神经网络训练会因数据相关而发散
数据利用效率 每条数据通常只用一次,效率较低 同一条经验可被多次采样学习,大幅提高利用率
解决的核心问题 使算法能够利用异策略数据(如专家演示) 解决连续样本相关性导致的训练不稳定和过拟合问题
第二个叫目标网络。这个机制也很巧妙,就是再准备一个一模一样的网络,专门用来计算目标值。平时主网络正常更新,目标网络先不动。计算目标值的时候,用的是目标网络的输出,而不是主网络的。这样一来,在一段时间内,目标值是相对固定的,主网络可以稳稳地向这个固定目标靠近。等主网络更新了若干步之后,再把主网络的参数复制给目标网络,让它也更新一下。然后重复这个过程。这就解决了“终点线乱跑”的问题,训练稳定多了。

整个流程串起来看

一开始,两个网络参数相同,经验池是空的。智能体开始玩游戏,看到一帧画面,主网络根据当前的状态给出每个动作的分数。然后用 ε-greedy 的方法选动作,大部分时候选分数最高的,小部分时候随机选一个——这是为了探索那些没试过的可能性。选了动作之后,游戏会返回下一个画面和一个得分,这一整条经历就被存进经验池。

接下来就是从经验池里随机抽一批历史经历出来,对每一条经历,用目标网络算出一个目标分数,公式就是那条经历里的得分加上下一状态的最大 Q 值乘以折扣。然后让主网络去猜这条经历里的动作值多少钱,用猜的结果和目标分数之间的差距去调整主网络的参数。每过固定的步数,把主网络的参数同步给目标网络。这个过程不断重复,主网络的估计就会越来越准。

所以 DQN 本质上还是 Q-learning 的那套逻辑,用实际经历的样本来修正当前的估计。只不过它用神经网络代替了表格,用经验回放和目标网络解决了神经网络训练中出现的两个致命问题。这样一来,Q-learning 就能处理以前完全没辙的复杂任务,比如直接从高维度的图像输入里学会玩游戏,这也是当年 DQN 引起轰动的原因。

DQN VS Q-Learning代码

import matplotlib

matplotlib.use('TkAgg')  # 强制使用 TkAgg 后端,避免 PyCharm 显示错误

import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random

env = gym.make('FrozenLake-v1', map_name="4x4", is_slippery=False)
state_dim = env.observation_space.n
action_dim = env.action_space.n

gamma = 0.99
epsilon = 0.1
episodes = 2000


def moving_average(data, window=100):
    if len(data) < window:
        return data
    return np.convolve(data, np.ones(window) / window, mode='valid')


def state_to_onehot(s, dim=16):
    onehot = np.zeros(dim)
    onehot[s] = 1.0
    return onehot


print("Running Q-learning...")
Q = np.zeros((state_dim, action_dim))
rewards_ql = []

for ep in range(episodes):
    state, _ = env.reset()
    done = False
    total_reward = 0
    while not done:
        if np.random.rand() < epsilon:
            action = env.action_space.sample()
        else:
            action = np.argmax(Q[state, :])

        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        total_reward += reward

        best_next = np.max(Q[next_state, :])
        Q[state, action] += 0.1 * (reward + gamma * best_next - Q[state, action])

        state = next_state
    rewards_ql.append(total_reward)

print("Running DQN...")


class DQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.out = nn.Linear(64, action_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.out(x)


lr = 0.001
batch_size = 32
memory_size = 10000
target_update_freq = 100
epsilon_min = 0.01
epsilon_decay = 0.995

policy_net = DQN(state_dim, action_dim)
target_net = DQN(state_dim, action_dim)
target_net.load_state_dict(policy_net.state_dict())
optimizer = optim.Adam(policy_net.parameters(), lr=lr)

memory = deque(maxlen=memory_size)

rewards_dqn = []
epsilon_current = epsilon

for ep in range(episodes):
    state, _ = env.reset()
    state_onehot = state_to_onehot(state)
    done = False
    total_reward = 0

    while not done:
        if np.random.rand() < epsilon_current:
            action = env.action_space.sample()
        else:
            with torch.no_grad():
                q_vals = policy_net(torch.FloatTensor(state_onehot).unsqueeze(0))
                action = q_vals.argmax().item()

        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        next_state_onehot = state_to_onehot(next_state)
        total_reward += reward

        memory.append((state_onehot, action, reward, next_state_onehot, done))
        state_onehot = next_state_onehot

        if len(memory) >= batch_size:
            batch = random.sample(memory, batch_size)
            states = torch.FloatTensor(np.array([b[0] for b in batch]))
            actions = torch.LongTensor(np.array([b[1] for b in batch])).unsqueeze(1)
            rewards = torch.FloatTensor(np.array([b[2] for b in batch]))
            next_states = torch.FloatTensor(np.array([b[3] for b in batch]))
            dones = torch.FloatTensor(np.array([b[4] for b in batch]))

            current_q = policy_net(states).gather(1, actions).squeeze()

            with torch.no_grad():
                next_q = target_net(next_states).max(1)[0]
                target_q = rewards + gamma * next_q * (1 - dones)

            loss = nn.MSELoss()(current_q, target_q)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    if epsilon_current > epsilon_min:
        epsilon_current *= epsilon_decay

    if ep % target_update_freq == 0:
        target_net.load_state_dict(policy_net.state_dict())

    rewards_dqn.append(total_reward)

plt.figure(figsize=(12, 5))
plt.plot(moving_average(rewards_ql, 100), label='Q-learning', linewidth=2)
plt.plot(moving_average(rewards_dqn, 100), label='DQN', linewidth=2)
plt.xlabel('Episode')
plt.ylabel('Average Reward (100-episode smooth)')
plt.title('Q-learning vs DQN on FrozenLake (deterministic)')
plt.legend()
plt.grid(alpha=0.3)

try:
    plt.show()
except Exception:
    plt.savefig('comparison.png')
    print("Figure saved as comparison.png due to display error.")

print("Done!")

深入理解REINFORCE算法:从原理到实现

REINFORCE算法是强化学习领域中策略梯度方法的奠基之作,由Williams于1992年提出。它巧妙地将策略参数化,利用蒙特卡洛采样估计梯度,为后续复杂的策略梯度算法奠定了理论基础。本文将系统介绍REINFORCE的数学原理、算法流程、实用技巧及改进方向,并提供一个完整的PyTorch实现,帮助读者建立对这类算法直观而深刻的理解。

1. 为什么需要REINFORCE?

在强化学习中,智能体通过与环境交互来学习最优策略,目标是最大化期望累积折扣奖励:
J ( π ) = E τ ∼ π [ ∑ t = 0 T γ t r t ] J(\pi) = \mathbb{E}_{\tau\sim\pi}\left[\sum_{t=0}^T \gamma^t r_t\right] J(π)=Eτπ[t=0Tγtrt]

传统的值函数方法(如Q-learning)走的是"间接"路线:先估计状态或动作的价值,再据此选择动作。而策略梯度方法则另辟蹊径,直接对策略 π θ ( a ∣ s ) \pi_\theta(a|s) πθ(as) 进行参数化,通过梯度上升优化 J ( θ ) J(\theta) J(θ)。这种方式具有独特的优势:

  • 天然适应连续动作空间和高维问题
  • 能够学习随机策略,内置探索机制
  • 更好的收敛性保证,避免值函数方法中的"死循环"

REINFORCE作为最简单的策略梯度算法,其核心思想既直观又优雅:用当前策略采样一条完整轨迹,然后根据每个时间步后的累积奖励来调整该步动作的概率——动作带来高回报就"鼓励",否则就"抑制"。这个朴素的想法可以通过严格的数学推导得到。

2. 策略梯度定理的直观理解

策略梯度定理给出了目标函数 J ( θ ) J(\theta) J(θ) 对参数 θ \theta θ 的梯度表达式:

∇ θ J ( θ ) = E τ ∼ π θ [ ∑ t = 0 T ∇ θ log ⁡ π θ ( a t ∣ s t )   G t ] \nabla_\theta J(\theta) = \mathbb{E}_{\tau\sim\pi_\theta}\left[ \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t|s_t) \, G_t \right] θJ(θ)=Eτπθ[t=0Tθlogπθ(atst)Gt]

其中 G t = ∑ k = t T γ k − t r k G_t = \sum_{k=t}^{T} \gamma^{k-t} r_k Gt=k=tTγktrk 是从时刻 t t t 开始的累积折扣奖励,称为回报

这个公式的精妙之处在于将梯度分解为两个部分的乘积:

  • ∇ θ log ⁡ π θ ( a t ∣ s t ) \nabla_\theta \log \pi_\theta(a_t|s_t) θlogπθ(atst)得分函数,指示如何调整参数以增加该动作的概率
  • G t G_t Gt权重因子,决定了调整的幅度和方向

直观理解:如果某动作带来的回报高于平均水平,我们希望增加它的概率;反之则降低。这正是"试错学习"思想的数学实现。

3. REINFORCE算法流程详解

基于上述理论,REINFORCE采用蒙特卡洛方法估计梯度,具体步骤如下:

Step 1: 采样轨迹
使用当前策略 π θ \pi_\theta πθ 与环境交互,生成一条完整轨迹(直到终止状态或最大步长),记录每一步的状态、动作和奖励 ( s t , a t , r t ) (s_t, a_t, r_t) (st,at,rt)

Step 2: 计算回报
对轨迹中的每个时间步 t t t,计算回报 G t = ∑ k = t T − 1 γ k − t r k G_t = \sum_{k=t}^{T-1} \gamma^{k-t} r_k Gt=k=tT1γktrk

Step 3: 估计梯度
∇ θ J ( θ ) ≈ ∑ t = 0 T − 1 ∇ θ log ⁡ π θ ( a t ∣ s t )   G t \nabla_\theta J(\theta) \approx \sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t|s_t) \, G_t θJ(θ)t=0T1θlogπθ(atst)Gt

Step 4: 更新策略
θ ← θ + α ∇ θ J ( θ ) \theta \leftarrow \theta + \alpha \nabla_\theta J(\theta) θθ+αθJ(θ)

在实际实现中,为了利用深度学习框架的自动求导功能,我们通常构造一个等价的损失函数:
L ( θ ) = − ∑ t log ⁡ π θ ( a t ∣ s t ) G t L(\theta) = -\sum_t \log \pi_\theta(a_t|s_t) G_t L(θ)=tlogπθ(atst)Gt
对其最小化等价于对原目标进行梯度上升,因为 ∇ θ L = − ∇ θ J \nabla_\theta L = -\nabla_\theta J θL=θJ

4. 降低方差的利器:基线方法

REINFORCE最突出的问题是梯度估计方差过大,导致学习过程剧烈波动。原因在于单条轨迹的回报可能变化很大,就像用一次考试分数评判一个学生的整体水平一样不可靠。

引入基线是降低方差的有效手段:
∇ θ J ( θ ) = E [ ∑ t ∇ θ log ⁡ π θ ( a t ∣ s t ) ( G t − b ( s t ) ) ] \nabla_\theta J(\theta) = \mathbb{E}\left[ \sum_t \nabla_\theta \log \pi_\theta(a_t|s_t) \big( G_t - b(s_t) \big) \right] θJ(θ)=E[tθlogπθ(atst)(Gtb(st))]

只要基线 b ( s t ) b(s_t) b(st) 不依赖于当前动作 a t a_t at,这个期望与原梯度相等。证明的关键在于:
E [ ∇ θ log ⁡ π θ ( a t ∣ s t ) b ( s t ) ] = ∑ a ∇ θ π θ ( a ∣ s ) b ( s ) = b ( s ) ∇ θ ∑ a π θ ( a ∣ s ) = 0 \mathbb{E}[\nabla_\theta \log \pi_\theta(a_t|s_t) b(s_t)] = \sum_a \nabla_\theta \pi_\theta(a|s) b(s) = b(s) \nabla_\theta \sum_a \pi_\theta(a|s) = 0 E[θlogπθ(atst)b(st)]=aθπθ(as)b(s)=b(s)θaπθ(as)=0

最优基线通常是状态价值函数 V π ( s t ) V^\pi(s_t) Vπ(st),它表示从该状态出发的平均回报。此时 G t − V ( s t ) G_t - V(s_t) GtV(st) 称为优势函数的估计,直观反映了当前动作相对于平均表现的好坏。加入基线后的新解释:如果动作的结果比预期好,就加强它;否则就削弱它。

5. 算法优缺点剖析

优点:

  • 理论简洁:数学推导清晰,易于理解和实现
  • 无偏估计:理论上保证收敛到局部最优
  • 适用性广:支持离散和连续动作空间,只需策略可微
  • 随机策略:内置探索机制,无需额外处理

缺点:

  • 方差大:学习过程波动剧烈,收敛缓慢
  • 样本效率低:每条轨迹只能使用一次,且必须on-policy采样
  • 超参数敏感:需要仔细调整学习率、基线结构等

6. 从REINFORCE到现代策略梯度算法

REINFORCE的局限性催生了一系列改进工作:

  • Actor-Critic架构:同时学习策略(Actor)和价值函数(Critic),利用自举法(bootstrapping)进一步降低方差,代表作有A2C/A3C
  • 自然策略梯度:通过KL散度约束控制更新步长,提高稳定性,如TRPO和PPO
  • Off-policy扩展:通过重要性采样复用历史数据,但需注意分布偏移修正

理解REINFORCE就像是掌握了策略梯度方法的"第一性原理",为学习这些先进算法打下坚实基础。

7. PyTorch实战:带基线的REINFORCE

下面提供一个完整的带基线的REINFORCE实现,以CartPole环境为例。代码中包含了详细的注释,帮助读者理解每一步的用意。

import torch
import torch.nn as nn
import torch.optim as optim
import gym
import numpy as np
from collections import deque

class PolicyNet(nn.Module):
    """策略网络:输出动作的概率分布"""
    def __init__(self, state_dim, action_dim, hidden=128):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim)
        )
        self.softmax = nn.Softmax(dim=-1)
    
    def forward(self, x):
        return self.softmax(self.fc(x))

class ValueNet(nn.Module):
    """价值网络:估计状态价值,用作基线"""
    def __init__(self, state_dim, hidden=128):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1)
        )
    
    def forward(self, x):
        return self.fc(x)

def reinforce_with_baseline(env_name='CartPole-v1', num_episodes=1000, gamma=0.99, 
                            lr_policy=1e-2, lr_value=1e-2):
    # 初始化环境和网络
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    
    policy = PolicyNet(state_dim, action_dim)
    value = ValueNet(state_dim)
    policy_opt = optim.Adam(policy.parameters(), lr=lr_policy)
    value_opt = optim.Adam(value.parameters(), lr=lr_value)
    
    # 记录训练过程
    reward_history = deque(maxlen=100)
    
    for episode in range(num_episodes):
        # 收集轨迹数据
        log_probs = []
        rewards = []
        states = []
        
        state = env.reset()
        done = False
        
        while not done:
            states.append(state)
            state_t = torch.FloatTensor(state).unsqueeze(0)
            
            # 采样动作
            probs = policy(state_t)
            m = torch.distributions.Categorical(probs)
            action = m.sample()
            log_prob = m.log_prob(action)
            
            next_state, reward, done, _ = env.step(action.item())
            
            log_probs.append(log_prob)
            rewards.append(reward)
            state = next_state
        
        # 计算折扣回报
        returns = []
        G = 0
        for r in reversed(rewards):
            G = r + gamma * G
            returns.insert(0, G)
        returns = torch.tensor(returns)
        
        # 计算基线值并估计优势
        states_t = torch.FloatTensor(np.array(states))
        values = value(states_t).squeeze()
        advantages = returns - values.detach()  # 关键:detach阻止梯度流向价值网络
        
        # 更新策略网络:最大化优势加权对数概率
        policy_loss = 0
        for log_prob, adv in zip(log_probs, advantages):
            policy_loss -= log_prob * adv
        policy_opt.zero_grad()
        policy_loss.backward()
        policy_opt.step()
        
        # 更新价值网络:拟合真实回报
        value_loss = nn.MSELoss()(values, returns)
        value_opt.zero_grad()
        value_loss.backward()
        value_opt.step()
        
        # 记录和输出
        reward_history.append(sum(rewards))
        if episode % 100 == 0:
            avg_reward = np.mean(reward_history)
            print(f"Episode {episode}, Avg Reward (last 100): {avg_reward:.2f}")
    
    env.close()
    return policy, value

# 运行示例
if __name__ == "__main__":
    policy, value = reinforce_with_baseline(num_episodes=1000)

代码要点解析:

  1. 网络设计:策略网络输出softmax概率,价值网络输出标量价值估计
  2. 轨迹收集:使用Categorical分布采样动作,保留对数概率
  3. 回报计算:从后向前递推计算折扣回报
  4. 优势估计:使用 G t − V ( s t ) G_t - V(s_t) GtV(st) 作为优势,其中 V ( s t ) V(s_t) V(st) 需要detach防止影响策略梯度
  5. 双重更新:策略网络用优势加权更新,价值网络用MSE拟合回报
  6. 经验追踪:维护最近100个回合的奖励以监控训练效果

8. 实践技巧与调参建议

基于经验,这里提供几个实用的调参建议:

  • 回报归一化:将returns在整个batch内做标准化,可以稳定训练
  • 学习率调整:策略网络通常需要比价值网络稍大的学习率
  • 熵正则化:在损失函数中加入策略熵的负项,鼓励探索
  • 多线程采样:同时运行多个环境收集轨迹,降低样本相关性
  • 广义优势估计(GAE):用GAE替代单步优势,平衡方差和偏差

9. 总结与展望

REINFORCE作为策略梯度方法的开山之作,其价值不仅在于算法的实用性,更在于它所奠定的思想基础。从REINFORCE出发,我们可以清晰地看到策略梯度方法的发展脉络:从简单蒙特卡洛采样,到引入基线降低方差,再到Actor-Critic架构,最后演进到PPO、SAC等现代主流算法。

理解REINFORCE就像掌握了强化学习的一把钥匙,它揭示了"如何直接优化策略"这一核心问题的本质。希望本文能帮助读者建立对这一经典算法的深刻认识,为深入学习更复杂的强化学习方法打下坚实基础。

树搜索强化学习(阿尔法go的核心之一,我用象棋做了个demo)

没问题,为您将 UCT 公式的核心设计原理做一份结构化的梳理。这份整理不仅适合作为您的核心学习笔记,未来如果您需要将这项技术提炼并在 PPT 中向团队或领导做汇报,也可以直接作为核心素材使用。

MCTS

MCTS概述

蒙特卡洛树搜索(MCTS)MCTS 的英文全拼是 Monte Carlo Tree Search(蒙特卡洛树搜索),它不需要死板地穷举所有未来的可能性,而是通过在脑海中进行成千上万次快速的模拟试错来预判局势。它巧妙地利用数学公式在“选择已知的高胜率走法(利用)”和“发掘未知的高潜力走法(探索)”之间寻找平衡,通过不断循环选择、扩展、模拟和回溯这四个动作,将零散的试错经验沉淀到一棵庞大的决策树中,最终在面对极其复杂的局面(如围棋残局或复杂的数学推理)时,为您指出那条被反复验证过胜算最高的最优路径。

MCTS 核心引擎:UCT 公式设计原理拆解

核心痛点(多臂老虎机问题):在算力/时间有限的情况下,算法必须在**“选择已知高胜率的走法(利用)”“尝试未知可能的高收益走法(探索)”**之间找到完美的数学平衡。

完整公式:

U C T = W i N i + c ln ⁡ N N i UCT = \frac{W_i}{N_i} + c \sqrt{\frac{\ln N}{N_i}} UCT=NiWi+cNilnN

该公式巧妙地分为左右两部分,代表了两种截然不同的决策心理:

一、 左半部分:经验主义者的“贪婪” (Exploitation)
  • 公式提取 W i N i \frac{W_i}{N_i} NiWi
  • 核心含义:当前节点的平均胜率(历史胜利次数 W i W_i Wi ÷ 历史访问次数 N i N_i Ni)。
  • 作用机制:基于历史经验,优先选择过去赢面最大的走法。
  • 局限性:如果只依赖这一项,算法会变得极其短视。一旦某条潜力巨大的分支在第一次尝试时偶然失败,算法就会永远放弃它,陷入“局部最优”。
二、 右半部分:理想主义者的“好奇心” (Exploration)
  • 公式提取 c ln ⁡ N N i c \sqrt{\frac{\ln N}{N_i}} cNilnN
  • 核心含义:当前节点的探索价值(基于统计学置信区间求出的“误差上限”)。
  • 作用机制(变量拆解)
  • 分子 ln ⁡ N \ln N lnN(父节点总访问量的自然对数)—— 作用:防止遗忘
    随着全局推演次数 N N N 的增加,分子的值会缓慢持续增长。即使某个节点被冷落( N i N_i Ni 不变),随着时间推移,它的探索加分也会膨胀到足以逼迫算法再次去尝试它。
  • 分母 N i N_i Ni(当前子节点的访问量)—— 作用:见好就收
    一旦算法去访问某个好奇的节点,分母 N i N_i Ni 增大,整体的探索加分会迅速衰减,防止在已经摸透的分支上无限浪费算力。
  • 常数 c c c:调节探索欲望的超参数(通常取 2 \sqrt{2} 2 )。如果 c c c 设得很小,机器就会变得极其保守,只要发现一步胜率还凑合的棋,就会死死咬住不放,陷入固步自封的陷阱。反过来,如果 c c c 设得太大,机器又会变成一个不顾后果的赌徒,无视过去积累的高胜率经验,执意去测试所有冷门走法,导致算力被严重浪费。
    补充说明:
    UCT 公式中的这两个符号严谨推导自统计学中计算“误差上限”的定律(霍夫丁不等式):开根号 … \sqrt{\dots} )代表了评估不确定性的数学规律,因为我们对某步棋尝试的次数越少,算出的平均胜率与它客观真实胜率之间的潜在误差就越大,而这种误差在统计学上天然是沿着样本量平方根的轨迹反向缩小的;位于分子的自然对数 ln ⁡ N \ln N lnN)则是为了给机器的好奇心套上一条“极其克制的缰绳” ,随着全局总推演次数 N N N 呈百倍千倍地暴涨,对数函数极度缓慢的增长特性(如 ln ⁡ 10 ≈ 2.3 \ln 10\approx 2.3 ln102.3 ln ⁡ 1000 ≈ 6.9 \ln 1000\approx 6.9 ln10006.9)保证了那些长期被冷落的走法,其“探索加分”虽然在缓慢回升以防被永远遗忘,但绝不会失控膨胀,从而确保机器始终能将绝大部分算力稳稳集中在已知的高胜率主线路上。

象棋代码

import tkinter as tk
import random
import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tkinter import messagebox
from collections import deque
import threading
import time

ROWS = 10
COLS = 9
SQUARE_SIZE = 65
MARGIN = 40
BOARD_BG = "#D49A5A"
LINE_COLOR = "#3E2723"
TEXT_RED = "#D32F2F"
TEXT_BLACK = "#1A1A1A"
FONT_UI = ("微软雅黑", 11)

PIECES = {
    'R': {'G': '帥', 'A': '仕', 'E': '相', 'R': '車', 'N': '馬', 'C': '炮', 'P': '兵'},
    'B': {'G': '將', 'A': '士', 'E': '象', 'R': '車', 'N': '馬', 'C': '炮', 'P': '卒'}
}

INPUT_DIM = ROWS * COLS * 9
HIDDEN_DIM = 512
LEARNING_RATE = 0.001
MCTS_SIMULATIONS = 400
C_PUCT = 1.5
BATCH_SIZE = 256
REPLAY_BUFFER_SIZE = 50000
TRAIN_PER_GAME = 20
MODEL_SAVE_PATH = "chess_net_final.pth"
MAX_MOVES_PER_GAME = 300
TEMPERATURE_THRESHOLD = 30

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def action_to_idx(from_r, from_c, to_r, to_c):
    return from_r * COLS * ROWS * COLS + from_c * ROWS * COLS + to_r * COLS + to_c

def idx_to_action(idx):
    to_c = idx % COLS
    idx //= COLS
    to_r = idx % ROWS
    idx //= ROWS
    from_c = idx % COLS
    from_r = idx // COLS
    return (from_r, from_c), (to_r, to_c)

def encode_board(board, turn):
    encoding = np.zeros((ROWS, COLS, 9), dtype=np.float32)
    for r in range(ROWS):
        for c in range(COLS):
            piece = board[r][c]
            if piece:
                color = 1.0 if piece[0] == 'R' else -1.0
                ptype = piece[1]
                type_index = {'G':0,'A':1,'E':2,'R':3,'N':4,'C':5,'P':6}[ptype]
                encoding[r,c,0] = color
                encoding[r,c,1+type_index] = 1.0
    encoding[:,:,8] = 1.0 if turn == 'R' else 0.0
    return encoding.flatten()

class AlphaZeroNet(nn.Module):
    def __init__(self, input_dim=INPUT_DIM, hidden_dim=HIDDEN_DIM, action_dim=ROWS*COLS*ROWS*COLS):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.policy_head = nn.Linear(hidden_dim, action_dim)
        self.value_head = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        policy = self.policy_head(x)
        value = torch.tanh(self.value_head(x))
        return policy, value.squeeze(-1)

class ChessBoard:
    def __init__(self):
        self.board = self._create_initial_board()
        self.current_turn = 'R'
        self.history = []
        self.move_count = 0
        self.position_count = {}
        self.history_states = []

    def _create_initial_board(self):
        board = [['' for _ in range(COLS)] for _ in range(ROWS)]
        board[0][4] = 'RG'
        board[0][3] = 'RA'; board[0][5] = 'RA'
        board[0][2] = 'RE'; board[0][6] = 'RE'
        board[0][0] = 'RR'; board[0][8] = 'RR'
        board[0][1] = 'RN'; board[0][7] = 'RN'
        board[2][1] = 'RC'; board[2][7] = 'RC'
        for col in [0,2,4,6,8]: board[3][col] = 'RP'
        board[9][4] = 'BG'
        board[9][3] = 'BA'; board[9][5] = 'BA'
        board[9][2] = 'BE'; board[9][6] = 'BE'
        board[9][0] = 'BR'; board[9][8] = 'BR'
        board[9][1] = 'BN'; board[9][7] = 'BN'
        board[7][1] = 'BC'; board[7][7] = 'BC'
        for col in [0,2,4,6,8]: board[6][col] = 'BP'
        return board

    def copy(self):
        new_board = ChessBoard()
        new_board.board = [row[:] for row in self.board]
        new_board.current_turn = self.current_turn
        new_board.move_count = self.move_count
        new_board.position_count = self.position_count.copy()
        new_board.history_states = self.history_states.copy()
        return new_board

    def board_to_string(self):
        return ''.join(''.join(row if row else ' ' for row in line) for line in self.board) + self.current_turn

    def _can_attack(self, piece, from_row, from_col, to_row, to_col):
        piece_type = piece[1]
        dr = abs(from_row - to_row)
        dc = abs(from_col - to_col)
        if piece_type == 'G':
            if dr + dc != 1: return False
            if piece[0] == 'R' and not (0 <= to_row <= 2 and 3 <= to_col <= 5): return False
            if piece[0] == 'B' and not (7 <= to_row <= 9 and 3 <= to_col <= 5): return False
            return True
        elif piece_type == 'A':
            if dr != 1 or dc != 1: return False
            if piece[0] == 'R' and not (0 <= to_row <= 2 and 3 <= to_col <= 5): return False
            if piece[0] == 'B' and not (7 <= to_row <= 9 and 3 <= to_col <= 5): return False
            return True
        elif piece_type == 'E':
            if dr != 2 or dc != 2: return False
            if piece[0] == 'R' and to_row >= 5: return False
            if piece[0] == 'B' and to_row <= 4: return False
            return True
        elif piece_type == 'R':
            return from_row == to_row or from_col == to_col
        elif piece_type == 'N':
            return (dr == 2 and dc == 1) or (dr == 1 and dc == 2)
        elif piece_type == 'C':
            return from_row == to_row or from_col == to_col
        elif piece_type == 'P':
            if piece[0] == 'R':
                if from_row < 5:
                    return to_row == from_row + 1 and to_col == from_col
                else:
                    return (to_row == from_row + 1 and to_col == from_col) or (to_row == from_row and abs(to_col - from_col) == 1)
            else:
                if from_row > 4:
                    return to_row == from_row - 1 and to_col == from_col
                else:
                    return (to_row == from_row - 1 and to_col == from_col) or (to_row == from_row and abs(to_col - from_col) == 1)
        return False

    def _is_path_clear(self, from_row, from_col, to_row, to_col):
        if from_row == to_row:
            step = 1 if to_col > from_col else -1
            for col in range(from_col + step, to_col, step):
                if self.board[from_row][col] != '': return False
        elif from_col == to_col:
            step = 1 if to_row > from_row else -1
            for row in range(from_row + step, to_row, step):
                if self.board[row][from_col] != '': return False
        else:
            return False
        return True

    def _get_hobble_pos(self, from_row, from_col, to_row, to_col):
        dr = to_row - from_row
        dc = to_col - from_col
        if abs(dr) == 2:
            return (from_row + (1 if dr > 0 else -1), from_col)
        else:
            return (from_row, from_col + (1 if dc > 0 else -1))

    def _are_generals_facing(self):
        red_g = black_g = None
        for r in range(ROWS):
            for c in range(COLS):
                if self.board[r][c] == 'RG': red_g = (r,c)
                elif self.board[r][c] == 'BG': black_g = (r,c)
        if not red_g or not black_g: return False
        if red_g[1] != black_g[1]: return False
        r1,c1 = red_g; r2,c2 = black_g
        for r in range(min(r1,r2)+1, max(r1,r2)):
            if self.board[r][c1] != '': return False
        return True

    def is_valid_move(self, from_row, from_col, to_row, to_col):
        if not (0 <= to_row < ROWS and 0 <= to_col < COLS): return False
        piece = self.board[from_row][from_col]
        if piece == '': return False
        color = piece[0]
        target = self.board[to_row][to_col]
        if target and target[0] == color: return False
        piece_type = piece[1]
        if piece_type == 'R':
            if from_row != to_row and from_col != to_col: return False
            if not self._is_path_clear(from_row, from_col, to_row, to_col): return False
        elif piece_type == 'N':
            dr = abs(from_row - to_row)
            dc = abs(from_col - to_col)
            if not ((dr == 2 and dc == 1) or (dr == 1 and dc == 2)): return False
            hr, hc = self._get_hobble_pos(from_row, from_col, to_row, to_col)
            if self.board[hr][hc] != '': return False
        elif piece_type == 'C':
            if from_row != to_row and from_col != to_col: return False
            if target != '':
                cnt = 0
                if from_row == to_row:
                    step = 1 if to_col > from_col else -1
                    for col in range(from_col + step, to_col, step):
                        if self.board[from_row][col] != '': cnt += 1
                else:
                    step = 1 if to_row > from_row else -1
                    for row in range(from_row + step, to_row, step):
                        if self.board[row][from_col] != '': cnt += 1
                if cnt != 1: return False
            else:
                if not self._is_path_clear(from_row, from_col, to_row, to_col): return False
        elif piece_type == 'G':
            if color == 'R' and not (0 <= to_row <= 2 and 3 <= to_col <= 5): return False
            if color == 'B' and not (7 <= to_row <= 9 and 3 <= to_col <= 5): return False
            if abs(from_row - to_row) + abs(from_col - to_col) != 1: return False
        elif piece_type == 'A':
            if color == 'R' and not (0 <= to_row <= 2 and 3 <= to_col <= 5): return False
            if color == 'B' and not (7 <= to_row <= 9 and 3 <= to_col <= 5): return False
            if abs(from_row - to_row) != 1 or abs(from_col - to_col) != 1: return False
        elif piece_type == 'E':
            if color == 'R' and to_row >= 5: return False
            if color == 'B' and to_row <= 4: return False
            if abs(from_row - to_row) != 2 or abs(from_col - to_col) != 2: return False
            if self.board[(from_row+to_row)//2][(from_col+to_col)//2] != '': return False
        elif piece_type == 'P':
            if color == 'R':
                if from_row < 5:
                    if to_row != from_row + 1 or to_col != from_col: return False
                else:
                    forward = (to_row == from_row + 1 and to_col == from_col)
                    horizontal = (to_row == from_row and abs(to_col - from_col) == 1)
                    if not (forward or horizontal): return False
            else:
                if from_row > 4:
                    if to_row != from_row - 1 or to_col != from_col: return False
                else:
                    forward = (to_row == from_row - 1 and to_col == from_col)
                    horizontal = (to_row == from_row and abs(to_col - from_col) == 1)
                    if not (forward or horizontal): return False
        else:
            return False
        captured = self.board[to_row][to_col]
        self.board[to_row][to_col] = self.board[from_row][from_col]
        self.board[from_row][from_col] = ''
        facing = self._are_generals_facing()
        in_check = False
        if not facing:
            in_check = self._is_general_in_check(color)
        self.board[from_row][from_col] = self.board[to_row][to_col]
        self.board[to_row][to_col] = captured
        if facing or in_check: return False
        return True

    def _is_general_in_check(self, color):
        gr = gc = -1
        for r in range(ROWS):
            for c in range(COLS):
                if self.board[r][c] == color + 'G':
                    gr,gc = r,c; break
            if gr != -1: break
        if gr == -1: return False
        for r in range(ROWS):
            for c in range(COLS):
                p = self.board[r][c]
                if p and p[0] != color:
                    if self._can_attack(p, r, c, gr, gc):
                        pt = p[1]
                        if pt in ('R','C'):
                            if pt == 'C' and self.board[gr][gc] != '':
                                cnt = 0
                                if r == gr:
                                    step = 1 if gc > c else -1
                                    for col in range(c+step, gc, step):
                                        if self.board[r][col] != '': cnt += 1
                                else:
                                    step = 1 if gr > r else -1
                                    for row in range(r+step, gr, step):
                                        if self.board[row][c] != '': cnt += 1
                                if cnt != 1: continue
                            elif pt == 'R':
                                if not self._is_path_clear(r,c,gr,gc): continue
                        elif pt == 'N':
                            hr,hc = self._get_hobble_pos(r,c,gr,gc)
                            if self.board[hr][hc] != '': continue
                        elif pt == 'E':
                            if abs(r-gr)==2 and abs(c-gc)==2:
                                if self.board[(r+gr)//2][(c+gc)//2] != '': continue
                        return True
        return False

    def has_legal_moves(self, color):
        for r in range(ROWS):
            for c in range(COLS):
                if self.board[r][c] and self.board[r][c][0] == color:
                    if self.get_valid_moves(r,c): return True
        return False

    def make_move(self, from_pos, to_pos):
        fr,fc = from_pos
        tr,tc = to_pos
        if self.is_valid_move(fr,fc,tr,tc):
            captured = self.board[tr][tc]
            self.history.append((from_pos, to_pos, captured))
            self.board[tr][tc] = self.board[fr][fc]
            self.board[fr][fc] = ''
            self.current_turn = 'B' if self.current_turn == 'R' else 'R'
            self.move_count += 1
            pos_str = self.board_to_string()
            self.position_count[pos_str] = self.position_count.get(pos_str,0)+1
            self.history_states.append(pos_str)
            return True
        return False

    def undo_move(self):
        if not self.history: return False
        from_pos, to_pos, captured = self.history.pop()
        fr,fc = from_pos; tr,tc = to_pos
        self.board[fr][fc] = self.board[tr][tc]
        self.board[tr][tc] = captured
        self.current_turn = 'B' if self.current_turn == 'R' else 'R'
        self.move_count -= 1
        pos_str = self.board_to_string()
        self.position_count[pos_str] = self.position_count.get(pos_str,0)-1
        self.history_states.pop()
        return True

    def get_valid_moves(self, row, col):
        return [(r,c) for r in range(ROWS) for c in range(COLS) if self.is_valid_move(row,col,r,c)]

    def get_valid_moves_for_color(self, color):
        return [((r,c),(tr,tc)) for r in range(ROWS) for c in range(COLS)
                if self.board[r][c] and self.board[r][c][0]==color for tr,tc in self.get_valid_moves(r,c)]

    def get_valid_moves_idx(self, color):
        moves = []
        for (fr,fc),(tr,tc) in self.get_valid_moves_for_color(color):
            moves.append(action_to_idx(fr,fc,tr,tc))
        return moves

    def is_game_over(self):
        red_g = any(self.board[r][c]=='RG' for r in range(ROWS) for c in range(COLS))
        black_g = any(self.board[r][c]=='BG' for r in range(ROWS) for c in range(COLS))
        if not red_g or not black_g: return True
        if not self.has_legal_moves(self.current_turn): return True
        if self._are_generals_facing(): return True
        if self.move_count >= MAX_MOVES_PER_GAME: return True
        if self.position_count.get(self.board_to_string(),0) >= 3: return True
        return False

    def get_winner(self):
        red_g = any(self.board[r][c]=='RG' for r in range(ROWS) for c in range(COLS))
        black_g = any(self.board[r][c]=='BG' for r in range(ROWS) for c in range(COLS))
        if not red_g: return 'B'
        if not black_g: return 'R'
        if self._are_generals_facing(): return 'R' if self.current_turn=='B' else 'B'
        if not self.has_legal_moves(self.current_turn): return 'B' if self.current_turn=='R' else 'R'
        if self.move_count >= MAX_MOVES_PER_GAME: return None
        if self.position_count.get(self.board_to_string(),0) >= 3: return None
        return None

    def get_chinese_move(self, from_pos, to_pos, color):
        fr,fc = from_pos; tr,tc = to_pos
        piece = self.board[fr][fc]
        if not piece: return "未知"
        ptype = piece[1]
        pname = PIECES[color][ptype]
        red_cols = ["一","二","三","四","五","六","七","八","九"]
        blk_cols = ["1","2","3","4","5","6","7","8","9"]
        if color == 'R':
            fcol = red_cols[fc]
            tcol = red_cols[tc]
        else:
            fcol = blk_cols[8-fc]
            tcol = blk_cols[8-tc]
        same_file = []
        for r in range(ROWS):
            if self.board[r][fc] == piece: same_file.append(r)
        prefix = pname + fcol
        if len(same_file) >= 2 and ptype in ('R','N','C','P'):
            if color == 'R':
                if fr == max(same_file): prefix = "前"+pname
                elif fr == min(same_file): prefix = "后"+pname
                else: prefix = "中"+pname
            else:
                if fr == min(same_file): prefix = "前"+pname
                elif fr == max(same_file): prefix = "后"+pname
                else: prefix = "中"+pname
        if fr == tr:
            return f"{prefix}{tcol}"
        if color == 'R':
            action = "进" if tr > fr else "退"
            step_str = red_cols[abs(tr-fr)-1]
        else:
            action = "进" if tr < fr else "退"
            step_str = str(abs(tr-fr))
        if ptype in ('N','A','E'):
            return f"{prefix}{action}{tcol}"
        else:
            return f"{prefix}{action}{step_str}"

class MCTSNode:
    def __init__(self, state, parent=None, move=None, prior=0.0):
        self.state = state
        self.parent = parent
        self.move = move
        self.children = []
        self.visits = 0
        self.value_sum = 0.0
        self.prior = prior

    def value(self):
        return self.value_sum / self.visits if self.visits > 0 else 0.0

    def puct_score(self, c_puct=C_PUCT):
        u = c_puct * self.prior * math.sqrt(self.parent.visits) / (1 + self.visits)
        return self.value() + u

class MCTSPlayer:
    def __init__(self, color, net, simulations=MCTS_SIMULATIONS, temperature=1.0):
        self.color = color
        self.net = net
        self.simulations = simulations
        self.temperature = temperature

    def get_move_probs(self, board):
        root = MCTSNode(board.copy())
        for _ in range(self.simulations):
            node = root
            path = [node]
            while node.children:
                node = max(node.children, key=lambda c: c.puct_score())
                path.append(node)
            if not node.state.is_game_over():
                moves_idx = node.state.get_valid_moves_idx(node.state.current_turn)
                if moves_idx:
                    state_enc = torch.from_numpy(encode_board(node.state.board, node.state.current_turn)).float().unsqueeze(0).to(device)
                    with torch.no_grad():
                        logits, val = self.net(state_enc)
                        logits = logits.cpu().numpy().flatten()
                    mask = np.zeros_like(logits)
                    mask[moves_idx] = 1
                    logits = logits * mask - 1e8 * (1 - mask)
                    probs = np.exp(logits - logits.max())
                    probs = probs / probs.sum()
                    for idx in moves_idx:
                        child_state = node.state.copy()
                        fr,fc,tr,tc = idx_to_action(idx)[0][0], idx_to_action(idx)[0][1], idx_to_action(idx)[1][0], idx_to_action(idx)[1][1]
                        child_state.make_move((fr,fc),(tr,tc))
                        node.children.append(MCTSNode(child_state, parent=node, move=idx, prior=probs[idx]))
                    node = random.choice(node.children)
                    path.append(node)
                    leaf_val = -val.item()
                else:
                    leaf_val = -1.0
            else:
                w = node.state.get_winner()
                leaf_val = 0.0 if w is None else (1.0 if w == self.color else -1.0)
            for n in reversed(path):
                n.visits += 1
                n.value_sum += leaf_val
        visits = np.array([c.visits for c in root.children])
        if self.temperature == 0:
            best_idx = np.argmax(visits)
            probs = np.zeros(len(visits))
            probs[best_idx] = 1.0
        else:
            visits = visits ** (1.0 / self.temperature)
            probs = visits / visits.sum()
        move_probs = {}
        for c, p in zip(root.children, probs):
            move_probs[c.move] = p
        return move_probs

    def get_move(self, board):
        move_probs = self.get_move_probs(board)
        if not move_probs: return None
        moves = list(move_probs.keys())
        probs = list(move_probs.values())
        chosen = np.random.choice(moves, p=probs)
        return idx_to_action(chosen)

class ChessGame:
    def __init__(self, root):
        self.root = root
        self.root.title("中国象棋 - 智能推演与训练中心")
        self.root.configure(bg="#F0F0F0")
        self.root.geometry("1100x780")
        self.net = AlphaZeroNet().to(device)
        self.optimizer = optim.Adam(self.net.parameters(), lr=LEARNING_RATE)
        try:
            self.net.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))
        except:
            pass
        self.board = ChessBoard()
        self.mcts_red = MCTSPlayer('R', self.net, temperature=1.0)
        self.mcts_black = MCTSPlayer('B', self.net, temperature=1.0)
        self.flipped = False
        self.setup_mode = False
        self.setup_piece_selected = tk.StringVar(value="")
        main_frame = tk.Frame(self.root, bg="#F0F0F0")
        main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
        self.canvas_w = (COLS-1)*SQUARE_SIZE + 2*MARGIN
        self.canvas_h = (ROWS-1)*SQUARE_SIZE + 2*MARGIN
        self.canvas = tk.Canvas(main_frame, width=self.canvas_w, height=self.canvas_h, bg=BOARD_BG,
                                highlightthickness=3, highlightbackground="#5A3A22")
        self.canvas.pack(side=tk.LEFT, padx=10)
        self.canvas.bind("<Button-1>", self.on_click)
        right_panel = tk.Frame(main_frame, bg="#F0F0F0", width=340)
        right_panel.pack(side=tk.RIGHT, fill=tk.Y, padx=10)
        right_panel.pack_propagate(False)
        ctl_frame = tk.LabelFrame(right_panel, text="基础控制", bg="#F0F0F0", font=FONT_UI)
        ctl_frame.pack(fill=tk.X, pady=5)
        tk.Button(ctl_frame, text="翻转视角", command=self.toggle_flip, width=12).grid(row=0, column=0, padx=8, pady=5)
        tk.Button(ctl_frame, text="重置对局", command=self.reset_game, width=12).grid(row=0, column=1, padx=8, pady=5)
        mode_frame = tk.LabelFrame(right_panel, text="引擎模式", bg="#F0F0F0", font=FONT_UI)
        mode_frame.pack(fill=tk.X, pady=5)
        self.btn_train = tk.Button(mode_frame, text="启动闭环进化训练", command=self.toggle_training, bg="#4CAF50",
                                   fg="white", font=("微软雅黑",10,"bold"))
        self.btn_train.pack(fill=tk.X, padx=8, pady=4)
        tk.Button(mode_frame, text="观战: AI 左右互搏", command=self.start_ai_vs_ai, bg="#FF9800", fg="white",
                  font=("微软雅黑",10,"bold")).pack(fill=tk.X, padx=8, pady=4)
        h_frame = tk.Frame(mode_frame, bg="#F0F0F0")
        h_frame.pack(fill=tk.X, pady=4)
        self.human_color = tk.StringVar(value='R')
        tk.Radiobutton(h_frame, text="执红", variable=self.human_color, value='R', bg="#F0F0F0").pack(side=tk.LEFT, padx=5)
        tk.Radiobutton(h_frame, text="执黑", variable=self.human_color, value='B', bg="#F0F0F0").pack(side=tk.LEFT, padx=5)
        self.collect_var = tk.BooleanVar(value=False)
        tk.Checkbutton(h_frame, text="记录对局", variable=self.collect_var, bg="#F0F0F0").pack(side=tk.LEFT, padx=5)
        tk.Button(h_frame, text="人机对战", command=self.start_human_vs_ai, bg="#2196F3", fg="white",
                  font=("微软雅黑",10,"bold")).pack(side=tk.RIGHT, padx=8)
        self.setup_frame = tk.LabelFrame(right_panel, text="沙盘推演 (残局摆放)", bg="#F0F0F0", font=FONT_UI)
        self.setup_frame.pack(fill=tk.X, pady=5)
        self.btn_setup = tk.Button(self.setup_frame, text="进入摆放模式", command=self.toggle_setup, bg="#9C27B0", fg="white")
        self.btn_setup.pack(fill=tk.X, padx=8, pady=5)
        self.palette_frame = tk.Frame(self.setup_frame, bg="#F0F0F0")
        pcs = [('RG','帥'),('RR','紅車'),('RN','紅馬'),('RC','紅炮'),('RP','兵'),
               ('BG','將'),('BR','黑車'),('BN','黑馬'),('BC','黑炮'),('BP','卒'),('','橡皮擦')]
        for i,(val,txt) in enumerate(pcs):
            tk.Radiobutton(self.palette_frame, text=txt, variable=self.setup_piece_selected, value=val,
                           bg="#F0F0F0").grid(row=i//3, column=i%3, sticky=tk.W)
        self.info_label = tk.Label(right_panel, text="就绪 | 等待指令", bg="#F0F0F0", font=("微软雅黑",10), fg="blue")
        self.info_label.pack(pady=5)
        log_frame = tk.LabelFrame(right_panel, text="实时推演日志 (支持复制)", bg="#F0F0F0", font=FONT_UI)
        log_frame.pack(fill=tk.BOTH, expand=True, pady=5)
        self.move_log = tk.Text(log_frame, font=("微软雅黑",10), state=tk.DISABLED, width=30, bg="#FFFFFF")
        self.move_log.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=(5,0), pady=5)
        scroll = tk.Scrollbar(log_frame, command=self.move_log.yview)
        scroll.pack(side=tk.RIGHT, fill=tk.Y, pady=5, padx=(0,5))
        self.move_log.config(yscrollcommand=scroll.set)
        self.selected = None
        self.game_mode = None
        self.ai_thinking = False
        self.training_active = False
        self.collect_human_data = False
        self.experience_buffer = deque(maxlen=REPLAY_BUFFER_SIZE)
        self.current_game_data = []
        self.games_played = 0
        self.total_loss = 0.0
        self.draw_board()
        self.root.protocol("WM_DELETE_WINDOW", self.on_close)

    def append_log(self, text):
        self.move_log.config(state=tk.NORMAL)
        self.move_log.insert(tk.END, text+"\n")
        self.move_log.see(tk.END)
        self.move_log.config(state=tk.DISABLED)

    def toggle_flip(self):
        self.flipped = not self.flipped
        self.selected = None
        self.draw_board()

    def toggle_setup(self):
        if not self.setup_mode:
            self.setup_mode = True
            self.game_mode = None
            if self.training_active: self.toggle_training()
            self.btn_setup.config(text="完成摆放", bg="#E91E63")
            self.palette_frame.pack(fill=tk.X, padx=5, pady=5)
            self.selected = None
            self.board.board = [['' for _ in range(COLS)] for _ in range(ROWS)]
            self.append_log("--- 进入沙盘模式 ---")
            self.draw_board()
        else:
            red_g = sum(1 for r in range(ROWS) for c in range(COLS) if self.board.board[r][c]=='RG')
            black_g = sum(1 for r in range(ROWS) for c in range(COLS) if self.board.board[r][c]=='BG')
            if red_g != 1 or black_g != 1:
                messagebox.showerror("错误","沙盘上必须且只能有一个将和一个帅!")
                return
            self.setup_mode = False
            self.btn_setup.config(text="进入摆放模式", bg="#9C27B0")
            self.palette_frame.pack_forget()
            self.board.history.clear()
            self.board.move_count = 0
            self.board.position_count.clear()
            self.board.history_states.clear()
            self.current_game_data.clear()
            self.append_log("--- 沙盘摆放完成 ---")
            self.draw_board()

    def logical_to_screen(self, r, c):
        return (9-r, 8-c) if self.flipped else (r,c)

    def screen_to_logical(self, sr, sc):
        return (9-sr, 8-sc) if self.flipped else (sr,sc)

    def draw_board(self):
        self.canvas.delete("all")
        grid_w = (COLS-1)*SQUARE_SIZE
        grid_h = (ROWS-1)*SQUARE_SIZE
        self.canvas.create_rectangle(MARGIN-5, MARGIN-5, MARGIN+grid_w+5, MARGIN+grid_h+5, outline=LINE_COLOR, width=4)
        for i in range(ROWS):
            y = MARGIN + i*SQUARE_SIZE
            self.canvas.create_line(MARGIN, y, MARGIN+grid_w, y, fill=LINE_COLOR, width=2)
        for i in range(COLS):
            x = MARGIN + i*SQUARE_SIZE
            if i==0 or i==COLS-1:
                self.canvas.create_line(x, MARGIN, x, MARGIN+grid_h, fill=LINE_COLOR, width=2)
            else:
                self.canvas.create_line(x, MARGIN, x, MARGIN+4*SQUARE_SIZE, fill=LINE_COLOR, width=2)
                self.canvas.create_line(x, MARGIN+5*SQUARE_SIZE, x, MARGIN+grid_h, fill=LINE_COLOR, width=2)
        self.canvas.create_text(MARGIN+grid_w//2, MARGIN+4.5*SQUARE_SIZE, text="楚 河         汉 界",
                                font=("楷体",28,"bold"), fill=LINE_COLOR)
        for base_r in [0,7]:
            top = MARGIN + base_r*SQUARE_SIZE
            bot = MARGIN + (base_r+2)*SQUARE_SIZE
            left = MARGIN + 3*SQUARE_SIZE
            right = MARGIN + 5*SQUARE_SIZE
            self.canvas.create_line(left, top, right, bot, fill=LINE_COLOR, width=2)
            self.canvas.create_line(right, top, left, bot, fill=LINE_COLOR, width=2)
        def draw_mark(sr,sc):
            x = MARGIN + sc*SQUARE_SIZE
            y = MARGIN + sr*SQUARE_SIZE
            d, length = 5, 12
            if sc>0:
                self.canvas.create_line(x-d, y-d, x-d-length, y-d, fill=LINE_COLOR, width=2)
                self.canvas.create_line(x-d, y-d, x-d, y-d-length, fill=LINE_COLOR, width=2)
                self.canvas.create_line(x-d, y+d, x-d-length, y+d, fill=LINE_COLOR, width=2)
                self.canvas.create_line(x-d, y+d, x-d, y+d+length, fill=LINE_COLOR, width=2)
            if sc<8:
                self.canvas.create_line(x+d, y-d, x+d+length, y-d, fill=LINE_COLOR, width=2)
                self.canvas.create_line(x+d, y-d, x+d, y-d-length, fill=LINE_COLOR, width=2)
                self.canvas.create_line(x+d, y+d, x+d+length, y+d, fill=LINE_COLOR, width=2)
                self.canvas.create_line(x+d, y+d, x+d, y+d+length, fill=LINE_COLOR, width=2)
        marks = [(2,1),(2,7),(7,1),(7,7),(3,0),(3,2),(3,4),(3,6),(3,8),(6,0),(6,2),(6,4),(6,6),(6,8)]
        for mr,mc in marks:
            sr,sc = self.logical_to_screen(mr,mc)
            draw_mark(sr,sc)
        for r in range(ROWS):
            for c in range(COLS):
                piece = self.board.board[r][c]
                if piece:
                    sr,sc = self.logical_to_screen(r,c)
                    x = MARGIN + sc*SQUARE_SIZE
                    y = MARGIN + sr*SQUARE_SIZE
                    color = TEXT_RED if piece[0]=='R' else TEXT_BLACK
                    txt = PIECES[piece[0]][piece[1]]
                    self.canvas.create_oval(x-28, y-28, x+29, y+29, fill="#A56D3D", outline="")
                    self.canvas.create_oval(x-27, y-27, x+27, y+27, fill="#F5D098", outline=LINE_COLOR, width=1)
                    self.canvas.create_oval(x-21, y-21, x+21, y+21, outline=color, width=2)
                    self.canvas.create_text(x, y, text=txt, fill=color, font=("楷体",24,"bold"))
        if self.selected:
            sr,sc = self.logical_to_screen(self.selected[0],self.selected[1])
            x = MARGIN + sc*SQUARE_SIZE
            y = MARGIN + sr*SQUARE_SIZE
            self.canvas.create_rectangle(x-32, y-32, x+32, y+32, outline="#4CAF50", width=4, dash=(4,4))
        self.update_status()

    def update_status(self):
        if self.game_mode:
            t = "当前回合: " + ("红方" if self.board.current_turn=='R' else "黑方")
            if self.ai_thinking: t += " [AI 深度推演中...]"
            elif self.training_active: t += " [后台进化中]"
            self.info_label.config(text=t)

    def on_click(self, event):
        sc = round((event.x - MARGIN) / SQUARE_SIZE)
        sr = round((event.y - MARGIN) / SQUARE_SIZE)
        if not (0 <= sr < ROWS and 0 <= sc < COLS): return
        r,c = self.screen_to_logical(sr,sc)
        if self.setup_mode:
            self.board.board[r][c] = self.setup_piece_selected.get()
            self.draw_board()
            return
        if self.game_mode != 'human_vs_ai' or self.ai_thinking: return
        pcolor = self.human_color.get()
        if self.board.current_turn != pcolor: return
        if self.selected is None:
            if self.board.board[r][c] and self.board.board[r][c][0]==pcolor:
                self.selected = (r,c)
                self.draw_board()
        else:
            sel_r,sel_c = self.selected
            if self.board.is_valid_move(sel_r,sel_c,r,c):
                desc = self.board.get_chinese_move((sel_r,sel_c),(r,c),pcolor)
                self.board.make_move((sel_r,sel_c),(r,c))
                self.selected = None
                self.draw_board()
                self.append_log(f"{'红' if pcolor=='R' else '黑'}: {desc}")
                if self.check_game_over(): return
                self.ai_thinking = True
                self.update_status()
                threading.Thread(target=self._compute_ai_move_async, daemon=True).start()
            else:
                self.selected = None
                self.draw_board()

    def start_ai_vs_ai(self):
        if self.setup_mode: return
        if self.training_active: self.toggle_training()
        self.reset_game()
        self.game_mode = 'ai_vs_ai'
        self.ai_thinking = True
        self.append_log("--- AI 左右互搏开始 ---")
        threading.Thread(target=self._compute_ai_move_async, daemon=True).start()

    def start_human_vs_ai(self):
        if self.setup_mode: return
        if self.training_active: self.toggle_training()
        self.reset_game()
        self.game_mode = 'human_vs_ai'
        self.collect_human_data = self.collect_var.get()
        self.selected = None
        self.append_log("--- 人机对战开始 ---")
        self.draw_board()
        if self.human_color.get() == 'B':
            self.ai_thinking = True
            threading.Thread(target=self._compute_ai_move_async, daemon=True).start()

    def toggle_training(self):
        if self.setup_mode: return
        if not self.training_active:
            self.training_active = True
            self.btn_train.config(text="停止闭环进化训练", bg="#F44336")
            self.game_mode = 'ai_vs_ai'
            self.reset_game(keep_mode=True)
            self.games_played = 0
            self.total_loss = 0.0
            self.experience_buffer.clear()
            self.ai_thinking = True
            self.append_log("--- 闭环训练启动 ---")
            threading.Thread(target=self._compute_ai_move_async, daemon=True).start()
        else:
            self.training_active = False
            self.btn_train.config(text="启动闭环进化训练", bg="#4CAF50")
            self.game_mode = None
            self.info_label.config(text="训练已手动终止")

    def _compute_ai_move_async(self):
        if self.board.is_game_over():
            self.root.after(0, self.check_game_over)
            return
        temp = 1.0 if self.board.move_count < TEMPERATURE_THRESHOLD else 0.1
        player = self.mcts_red if self.board.current_turn == 'R' else self.mcts_black
        player.temperature = temp
        move_probs = player.get_move_probs(self.board)
        if self.training_active or (self.game_mode == 'human_vs_ai' and self.collect_human_data):
            state_enc = encode_board(self.board.board, self.board.current_turn)
            self.current_game_data.append((state_enc, move_probs))
        move = player.get_move(self.board)
        self.root.after(0, self._apply_ai_move, move, self.board.current_turn)

    def _apply_ai_move(self, move, color):
        if self.game_mode is None:
            self.ai_thinking = False
            return
        if move:
            desc = self.board.get_chinese_move(move[0], move[1], color)
            self.board.make_move(move[0], move[1])
            self.draw_board()
            self.append_log(f"{'红' if color=='R' else '黑'} (AI): {desc}")
            if self.check_game_over(): return
            if self.game_mode == 'ai_vs_ai' or (self.game_mode == 'human_vs_ai' and self.board.current_turn != self.human_color.get()):
                threading.Thread(target=self._compute_ai_move_async, daemon=True).start()
            else:
                self.ai_thinking = False
                self.update_status()
        else:
            self.ai_thinking = False
            self.check_game_over()

    def check_game_over(self):
        over = False
        winner = None
        reason = ""
        if self.board._are_generals_facing():
            over = True
            winner = 'R' if self.board.current_turn=='B' else 'B'
            reason = "将帅照面,主动送吃!"
        elif not self.board.has_legal_moves(self.board.current_turn):
            over = True
            winner = 'B' if self.board.current_turn=='R' else 'R'
            if self.board._is_general_in_check(self.board.current_turn): reason = "绝杀无解!"
            else: reason = "无子可动,困毙!"
        elif self.board.is_game_over():
            over = True
            winner = self.board.get_winner()
            reason = "回合上限或重复局面"
        if over:
            self.ai_thinking = False
            msg = "和棋!" if winner is None else f"{'红方' if winner=='R' else '黑方'} 获胜!({reason})"
            self.append_log(f"*** 对局结束: {msg} ***")
            if self.training_active or (self.game_mode == 'human_vs_ai' and self.collect_human_data):
                game_data = list(self.current_game_data)
                threading.Thread(target=self._async_train, args=(game_data, winner), daemon=True).start()
            else:
                messagebox.showinfo("对局结束", msg)
                self.game_mode = None
                self.update_status()
            return True
        return False

    def _async_train(self, game_data, winner):
        for state_enc, move_probs in game_data:
            target_value = 1.0 if self.board.current_turn == winner else -1.0
            self.experience_buffer.append((state_enc, move_probs, target_value))
        total_loss = 0.0
        count = 0
        for _ in range(TRAIN_PER_GAME):
            if len(self.experience_buffer) < BATCH_SIZE: break
            batch = random.sample(self.experience_buffer, BATCH_SIZE)
            states = torch.from_numpy(np.array([b[0] for b in batch])).float().to(device)
            target_values = torch.from_numpy(np.array([b[2] for b in batch])).float().to(device)
            target_policies = []
            for b in batch:
                probs = np.zeros(ROWS*COLS*ROWS*COLS)
                for move_idx, p in b[1].items():
                    probs[move_idx] = p
                target_policies.append(probs)
            target_policies = torch.from_numpy(np.array(target_policies)).float().to(device)
            self.optimizer.zero_grad()
            logits, values = self.net(states)
            policy_loss = -torch.mean(torch.sum(target_policies * F.log_softmax(logits, dim=1), dim=1))
            value_loss = F.mse_loss(values, target_values)
            loss = policy_loss + value_loss
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            count += 1
        self.root.after(0, self._post_train_update, count, total_loss)

    def _post_train_update(self, count, total_loss):
        self.games_played += 1
        if count > 0:
            avg_loss = total_loss / count
            self.total_loss = 0.9 * self.total_loss + 0.1 * avg_loss
            self.append_log(f"--- 训练完成 | 均损: {avg_loss:.4f} ---")
        if self.training_active:
            self.reset_game(keep_mode=True)
            self.ai_thinking = True
            threading.Thread(target=self._compute_ai_move_async, daemon=True).start()

    def reset_game(self, keep_mode=False):
        if self.training_active and not keep_mode: self.toggle_training()
        self.board = ChessBoard()
        self.selected = None
        self.current_game_data.clear()
        if not keep_mode:
            self.game_mode = None
            self.move_log.config(state=tk.NORMAL)
            self.move_log.delete(1.0, tk.END)
            self.move_log.config(state=tk.DISABLED)
        self.ai_thinking = False
        self.draw_board()

    def on_close(self):
        torch.save(self.net.state_dict(), MODEL_SAVE_PATH)
        self.game_mode = None
        self.training_active = False
        self.root.destroy()

if __name__ == "__main__":
    root = tk.Tk()
    game = ChessGame(root)
    root.mainloop()
Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐