强化学习¶
强化学习通过试错来训练智能体,使其通过最大化累积奖励来做出序贯决策。本文件涵盖马尔可夫决策过程、价值函数、贝尔曼方程、Q学习、策略梯度、演员-评论家方法、近端策略优化以及基于人类反馈的强化学习——这是游戏智能体和语言模型对齐背后的框架。
-
监督学习需要带标签的数据。无监督学习在无标签数据中发现模式。强化学习与前两者都不同:智能体通过与环境的交互、采取行动并获得奖励来学习。这里没有正确的标签;智能体必须通过试错来发现良好的行为。
-
想象教一只狗新技能。你不会向它展示一个正确行为的数据集。相反,它会尝试各种事情,你对好的动作给予奖励,随着时间的推移,它会明白你想要什么。强化学习将这个形式化。
-
强化学习的设定包括五个核心组件。智能体是学习者和决策者。环境是智能体外部的、与其交互的一切。在每个时间步,智能体观察到状态 \(s_t\),选择一个动作 \(a_t\),获得一个奖励 \(r_t\),并转移到新状态 \(s_{t+1}\)。智能体的目标是最大化随时间收集的总奖励。
-
策略 \(\pi\) 是智能体的策略:从状态到动作的映射。确定性策略给每个状态一个动作:\(a = \pi(s)\)。随机策略给出动作上的概率分布:\(\pi(a \mid s)\)。强化学习的目标是找到最优策略,即最大化期望累积奖励的策略。
-
强化学习的数学框架是马尔可夫决策过程,由元组 \((S, A, P, R, \gamma)\) 定义:状态集合 \(S\),动作集合 \(A\),转移概率 \(P(s' \mid s, a)\),奖励函数 \(R(s, a)\),以及折扣因子 \(\gamma\)。
-
马尔可夫性质(来自第05章)表示未来只依赖于当前状态,而不依赖于到达该状态的历史:\(P(s_{t+1} \mid s_t, a_t, s_{t-1}, \ldots) = P(s_{t+1} \mid s_t, a_t)\)。这意味着状态包含了决策所需的所有信息。
-
折扣因子 \(\gamma \in [0, 1)\) 决定了智能体在多大程度上关心未来奖励而非当前奖励。从时间 \(t\) 开始的折扣回报为:
-
当 \(\gamma = 0\) 时,智能体完全短视,只关心下一个奖励。当 \(\gamma\) 接近1时,智能体更具远见。折扣因子也保证了求和收敛(如果奖励有界),这对数学上的良好定义很重要。
-
价值函数估计处于某个状态(或在某个状态采取某个动作)的好坏程度。状态价值函数 \(V^\pi(s)\) 是从状态 \(s\) 出发并遵循策略 \(\pi\) 的期望回报:
- 动作价值函数 \(Q^\pi(s, a)\) 是从状态 \(s\) 出发、采取动作 \(a\),然后遵循策略 \(\pi\) 的期望回报:
-
两者关系:\(V^\pi(s) = \sum_a \pi(a \mid s) \, Q^\pi(s, a)\)。状态价值是动作价值按策略加权的平均。
-
贝尔曼方程表达了递归关系:一个状态的价值等于即时奖励加上下一状态的折扣价值。对于状态价值函数:
- 对于最优价值函数 \(V^{*}(s)\),智能体总是选择最佳动作:
- 类似地,\(Q^{*}\) 的贝尔曼最优方程:
-
一旦得到 \(Q^{*}\),最优策略就很简单:总是选择 Q 值最高的动作:\(\pi^{*}(s) = \arg\max_a Q^{*}(s, a)\)。
-
动态规划方法在已知转移概率和奖励(即完整模型)时求解 MDP。策略评估通过迭代应用贝尔曼方程直到收敛,为给定策略计算 \(V^\pi\)。策略改进利用价值函数,通过贪婪地选择动作来构建更好的策略:\(\pi'(s) = \arg\max_a \sum_{s'} P(s' \mid s, a)[R(s,a) + \gamma V^\pi(s')]\)。
-
策略迭代交替进行评估和改进,直到策略不再变化。它保证收敛到最优策略。
-
价值迭代将两个步骤合二为一:反复应用贝尔曼最优方程直到 \(V^{*}\) 收敛,然后提取策略。
-
动态规划需要知道 \(P(s' \mid s, a)\),这在现实中通常不可行。在大多数实际问题中,智能体不知道环境的动态特性;它只能与环境交互。这时无模型方法就派上用场了。
-
时序差分学习无需知道模型,直接从经验中学习。关键思想是自举:不需要等到回合结束才计算实际回报 \(G_t\),而是用当前价值函数来估计它:
- 括号中的项是TD误差:TD目标(\(r_t + \gamma V(s_{t+1})\))与当前估计 \(V(s_t)\) 之间的差值。如果 TD 误差为正,说明状态比预期好,因此提高其价值;如果为负,则降低。
-
TD学习在每一步之后都进行更新(而不是在整个回合之后),这使其比蒙特卡洛方法高效得多。它还可以在持续(非回合制)环境中工作。
-
SARSA(状态-动作-奖励-状态-动作)是将TD学习应用于Q值的方法。智能体在状态 \(s\) 下采取动作 \(a\),观察到奖励 \(r\) 和下一状态 \(s'\),然后根据其策略选择下一个动作 \(a'\):
-
SARSA 是同策略的:它使用智能体实际采取的动作进行更新,其中包括探索。这使得 SARSA 更加保守;它学习到的策略会考虑到自身的探索噪声。
-
Q学习是最著名的强化学习算法。它与 SARSA 类似,但不同之处在于它不使用智能体实际采取的动作,而是使用最佳可能的动作:
-
Q学习是异策略的:无论当前遵循什么策略,它都能学到最优的 Q 值。智能体可以随机探索,同时仍能学到最优的动作价值。这使得 Q学习更加激进,通常收敛更快,但可能高估价值。
-
探索与利用是根本的困境:智能体应该利用已知信息(选择估计价值最高的动作),还是探索未知的动作(可能结果更好)?
-
最简单的策略是 epsilon-greedy:以概率 \(\epsilon\) 随机选择动作(探索);以概率 \(1 - \epsilon\) 选择贪婪动作(利用)。常见的调度是从较高的 \(\epsilon\) 开始(大量探索),随后随时间衰减。
-
表格方法(在表格中为每个状态-动作对存储一个值)适用于小的、离散的状态空间。对于大的或连续的状态空间,需要函数近似。深度Q网络使用神经网络来近似 \(Q(s, a; \theta)\),其中 \(\theta\) 是网络权重。
-
DQN 引入了两个关键的稳定技术。经验回放:不使用连续过渡(高度相关)进行学习,而是将过渡存储在回放缓冲区中,并随机采样小批量进行训练。这打破了相关性,并高效地复用数据。
-
目标网络:使用一个单独的、缓慢更新的网络副本来计算 TD 目标。如果没有这个,每次更新网络时目标都会移动,造成“追逐自己的尾巴”的不稳定性。目标网络定期更新(每 \(N\) 步硬更新)或连续更新(软更新:\(\theta^{-} \leftarrow \tau\theta + (1-\tau)\theta^{-}\))。
-
DQN 的损失就是预测 Q 值与 TD 目标之间的均方误差:
-
到目前为止,所有方法都学习价值函数并从中推导出策略。策略梯度方法采用不同的思路:它们直接参数化策略 \(\pi(a \mid s; \theta)\),并通过期望回报上的梯度上升来优化它。
-
策略梯度定理给出了期望回报对策略参数的梯度:
-
这意味着:增加导致高回报的动作的概率,减少导致低回报的动作的概率。对数概率梯度给出了改变策略的方向,\(G_t\) 缩放改变的幅度。
-
REINFORCE 是最简单的策略梯度算法。运行一个回合,计算每个步骤的回报 \(G_t\),然后更新:
- REINFORCE 的方差很高,因为 \(G_t\) 是期望回报的有噪声的、单样本估计。一个常见的修复方法是减去一个基线(通常是平均回报或学习到的价值函数),以在不引入偏差的情况下降低方差:
- 演员-评论家方法使用两个网络。演员是策略 \(\pi(a \mid s; \theta)\)。评论家是一个价值函数 \(V(s; \phi)\),作为基线。优势 \(A_t = r_t + \gamma V(s_{t+1}) - V(s_t)\) 取代了 \(G_t - b\):
- 评论家通过最小化 TD 误差来更新,类似于基于价值的方法。演员使用策略梯度更新,评论家的优势估计降低了方差。这是两全其美的方案。
-
近端策略优化是实际中最广泛使用的策略梯度算法。它解决了一个关键问题:如果策略更新太大,性能可能灾难性地崩溃。
-
PPO 使用一个裁剪的替代目标函数。令 \(r_t(\theta) = \frac{\pi(a_t | s_t; \theta)}{\pi(a_t | s_t; \theta_{\text{old}})}\) 为新旧策略之间的概率比。损失函数为:
-
裁剪(通常 \(\epsilon = 0.2\))防止概率比偏离1太远,从而保持更新小而稳定。如果优势为正(动作好),比率被限制在 \(1 + \epsilon\);如果优势为负(动作差),比率被限制在 \(1 - \epsilon\)。这比早期的信赖域方法(TRPO)更简单、更稳定。
-
PPO 正是用于通过基于人类反馈的强化学习训练 ChatGPT 风格模型的方法。在 RLHF 中,首先在人类偏好数据上训练一个奖励模型(两个输出中人类更喜欢哪一个?),然后 PPO 优化语言模型的策略以最大化这个学习到的奖励。
-
直接偏好优化通过完全消除奖励模型来简化 RLHF。DPO 不训练奖励模型然后运行强化学习,而是推导出一个闭式损失,直接从偏好数据中优化策略:
-
这里 \(y_w\) 是被偏好的(获胜的)响应,\(y_l\) 是不被偏好的(失败的)响应。DPO 增加了偏好输出的相对概率,并且实现上比基于 PPO 的 RLHF 简单得多。
-
强化学习算法中的两个重要区别。同策略 vs 异策略:同策略方法(SARSA、PPO)从当前策略生成的数据中学习;异策略方法(Q学习、DQN)可以从任何策略生成的数据中学习。异策略方法的样本效率更高(可以复用旧数据),但可能不太稳定。
-
基于模型 vs 无模型:无模型方法(到目前为止讨论的所有方法)直接从经验中学习价值或策略。基于模型的方法学习环境的模型(\(P(s' \mid s, a)\) 和 \(R(s, a)\)),并用它进行规划(在不实际采取动作的情况下想象未来的轨迹)。基于模型的方法样本效率更高,但增加了学习准确模型的复杂性。
-
强化学习领域概览:
| 方法 | 类型 | 核心思想 | 优势 |
|---|---|---|---|
| 价值迭代 | DP, 基于模型 | 贝尔曼最优性 | 精确解(小MDP) |
| SARSA | TD, 同策略 | 同策略学习Q值 | 保守、安全 |
| Q学习 | TD, 异策略 | 学习Q*,贪婪目标 | 简单、有效 |
| DQN | 深度, 异策略 | 神经网络Q + 回放 + 目标网络 | 可扩展至高维状态 |
| REINFORCE | 策略梯度 | 对数概率*回报的梯度 | 简单的策略优化 |
| 演员-评论家 | PG + 价值 | 演员 + 评论家,低方差 | 实用、灵活 |
| PPO | PG, 裁剪 | 类似信赖域的稳定性 | 行业标准 |
| DPO | 直接偏好 | 跳过奖励模型 | 更简单的RLHF |
编码任务(使用 CoLab 或 notebook)¶
-
为一个简单的网格世界实现价值迭代。计算最优价值函数并提取最优策略。用热力图和箭头图将两者可视化。
import jax.numpy as jnp import matplotlib.pyplot as plt # 4x4网格世界:目标在(3,3),每步奖励-1,目标处为0 grid_size = 4 gamma = 0.99 goal = (3, 3) # 动作:上、下、左、右 actions = [(-1, 0), (1, 0), (0, -1), (0, 1)] action_names = ['up', 'down', 'left', 'right'] action_arrows = ['\u2191', '\u2193', '\u2190', '\u2192'] def step(s, a): """确定性转移。""" ns = (max(0, min(grid_size-1, s[0]+a[0])), max(0, min(grid_size-1, s[1]+a[1]))) return ns # 价值迭代 V = jnp.zeros((grid_size, grid_size)) for iteration in range(100): V_new = jnp.array(V) for i in range(grid_size): for j in range(grid_size): if (i, j) == goal: continue values = [] for a in actions: ns = step((i, j), a) values.append(-1 + gamma * float(V[ns[0], ns[1]])) V_new = V_new.at[i, j].set(max(values)) if jnp.max(jnp.abs(V_new - V)) < 1e-6: print(f"在 {iteration+1} 次迭代后收敛") break V = V_new # 提取策略 policy = [['' for _ in range(grid_size)] for _ in range(grid_size)] for i in range(grid_size): for j in range(grid_size): if (i, j) == goal: policy[i][j] = 'G' continue best_a = max(range(4), key=lambda a: -1 + gamma * float(V[step((i,j), actions[a])[0], step((i,j), actions[a])[1]])) policy[i][j] = action_arrows[best_a] fig, axes = plt.subplots(1, 2, figsize=(10, 4)) im = axes[0].imshow(V, cmap='YlOrRd_r') axes[0].set_title("最优价值函数") for i in range(grid_size): for j in range(grid_size): axes[0].text(j, i, f"{V[i,j]:.1f}", ha='center', va='center', fontsize=10) plt.colorbar(im, ax=axes[0]) axes[1].imshow(jnp.ones((grid_size, grid_size)), cmap='Greys', vmin=0, vmax=2) axes[1].set_title("最优策略") for i in range(grid_size): for j in range(grid_size): axes[1].text(j, i, policy[i][j], ha='center', va='center', fontsize=18) plt.tight_layout(); plt.show() -
在一个简单的网格世界上实现表格型Q学习。训练智能体,绘制学习曲线,并显示学到的Q值。
import jax import jax.numpy as jnp import matplotlib.pyplot as plt grid_size = 5 goal = (4, 4) actions = [(-1,0), (1,0), (0,-1), (0,1)] # Q表 Q = {} for i in range(grid_size): for j in range(grid_size): Q[(i,j)] = [0.0] * 4 alpha = 0.1 gamma = 0.95 epsilon = 1.0 epsilon_decay = 0.995 min_epsilon = 0.01 def step(s, a_idx): a = actions[a_idx] ns = (max(0, min(grid_size-1, s[0]+a[0])), max(0, min(grid_size-1, s[1]+a[1]))) r = 0.0 if ns == goal else -1.0 done = ns == goal return ns, r, done key = jax.random.PRNGKey(42) rewards_per_episode = [] for ep in range(500): s = (0, 0) total_reward = 0 for _ in range(100): key, subkey = jax.random.split(key) if float(jax.random.uniform(subkey)) < epsilon: key, subkey = jax.random.split(key) a = int(jax.random.randint(subkey, (), 0, 4)) else: a = max(range(4), key=lambda i: Q[s][i]) ns, r, done = step(s, a) total_reward += r # Q学习更新 Q[s][a] += alpha * (r + gamma * max(Q[ns]) - Q[s][a]) s = ns if done: break rewards_per_episode.append(total_reward) epsilon = max(min_epsilon, epsilon * epsilon_decay) plt.figure(figsize=(8, 4)) # 平滑曲线 window = 20 smoothed = [sum(rewards_per_episode[max(0,i-window):i+1])/min(i+1, window) for i in range(len(rewards_per_episode))] plt.plot(smoothed, color='#3498db', linewidth=1.5) plt.xlabel("回合"); plt.ylabel("总奖励(平滑后)") plt.title("网格世界上的Q学习") plt.grid(alpha=0.3); plt.show() # 显示学到的策略 arrow = ['\u2191', '\u2193', '\u2190', '\u2192'] print("学到的策略:") for i in range(grid_size): row = "" for j in range(grid_size): if (i,j) == goal: row += " G " else: row += f" {arrow[max(range(4), key=lambda a: Q[(i,j)][a])]} " print(row) -
在一个多臂赌博机问题上实现REINFORCE。展示策略在训练过程中如何演变,以偏好最佳臂。
import jax import jax.numpy as jnp import matplotlib.pyplot as plt # 5臂赌博机,不同期望奖励 true_rewards = jnp.array([0.2, 0.5, 0.8, 0.3, 0.1]) n_arms = len(true_rewards) # 策略:logits上的softmax logits = jnp.zeros(n_arms) lr = 0.1 key = jax.random.PRNGKey(42) policy_history = [] reward_history = [] for step in range(2000): probs = jax.nn.softmax(logits) policy_history.append(probs) # 采样动作 key, subkey = jax.random.split(key) action = jax.random.choice(subkey, n_arms, p=probs) # 获取奖励(伯努利) key, subkey = jax.random.split(key) reward = float(jax.random.uniform(subkey) < true_rewards[action]) reward_history.append(reward) # REINFORCE更新 # grad log pi(a) = e_a - probs (对于softmax参数化) grad_log_pi = -probs.at[action].add(1.0) # one-hot(a) - probs logits = logits + lr * reward * grad_log_pi policy_history = jnp.stack(policy_history) fig, axes = plt.subplots(1, 2, figsize=(12, 4)) colors = ['#3498db', '#e74c3c', '#27ae60', '#9b59b6', '#f39c12'] for i in range(n_arms): axes[0].plot(policy_history[:, i], color=colors[i], label=f'臂 {i} (真实={true_rewards[i]:.1f})', linewidth=1.5) axes[0].set_xlabel("步数"); axes[0].set_ylabel("P(臂)") axes[0].set_title("策略演变 (REINFORCE)") axes[0].legend(fontsize=8); axes[0].grid(alpha=0.3) # 平滑奖励 window = 50 smoothed = [sum(reward_history[max(0,i-window):i+1])/min(i+1,window) for i in range(len(reward_history))] axes[1].plot(smoothed, color='#27ae60', linewidth=1.5) axes[1].axhline(y=0.8, color='#e74c3c', linestyle='--', alpha=0.5, label='最佳臂') axes[1].set_xlabel("步数"); axes[1].set_ylabel("平均奖励") axes[1].set_title("随时间变化的奖励"); axes[1].legend() axes[1].grid(alpha=0.3) plt.tight_layout(); plt.show()