Skip to content

梯度机器学习

基于梯度的学习通过沿着损失曲面的斜率迭代更新模型参数来优化模型。本文件涵盖线性回归、逻辑回归、Softmax分类、梯度下降变体、正则化(L1/L2)以及偏差-方差权衡

  • 文件01中的经典方法使用巧妙的启发式或闭式解。而本文件介绍通过沿着梯度下降、在损失曲面上小步下山直到找到良好参数的算法。梯度学习是从线性回归到最大型神经网络的引擎。

  • 线性回归是最简单的基于梯度的模型,同时也有闭式解,因此是完美的起点。该模型是一条直线(或高维空间中的超平面):

\[\hat{y} = w \cdot x + b = \sum_{i=1}^{d} w_i x_i + b\]
  • 用矩阵表示(来自第02章),如果将所有训练输入堆叠为矩阵 \(X\) 的行,并通过附加一列1将偏置吸收进 \(w\),则变为 \(\hat{y} = Xw\)

  • 目标是最小化均方误差(MSE),即预测值与实际值之间平均平方差:

\[\mathcal{L}(w) = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 = \frac{1}{n} \|y - Xw\|^2\]
  • 为什么用平方误差?它有概率论依据:假设目标由 \(y = Xw + \epsilon\) 生成,其中 \(\epsilon \sim \mathcal{N}(0, \sigma^2)\),那么最大化数据的似然(第05章)等价于最小化MSE。平方误差对大错误的惩罚大于小错误,这通常是可取的。

散点图,数据点带有最佳拟合线和显示残差的垂直虚线

  • 由于MSE是 \(w\) 的二次函数,它有唯一的全局最小值,可以通过解析方式找到。求导、设为零并求解,得到正规方程
\[w^{*} = (X^T X)^{-1} X^T y\]
  • 这直接使用了第02章的矩阵逆。表达式 \(X^T X\) 是一个 \(d \times d\) 矩阵(\(d\) 是特征数),\(X^T y\) 是一个 \(d\) 维向量。正规方程一步给出精确的最优权重。

  • 正规方程何时失效?当 \(X^T X\) 奇异(不可逆)时,发生在特征线性相关或特征数大于样本数(\(d > n\))的情况下。此时需要使用正则化(后文介绍)或梯度下降。

  • 逻辑回归将线性模型应用于二分类。我们想要介于0和1之间的概率,而非连续值。Sigmoid函数将任意实数压缩到这个范围:

\[\sigma(z) = \frac{1}{1 + e^{-z}}\]
  • 模型计算 \(z = w \cdot x + b\)(与线性回归相同的线性得分),然后通过sigmoid:\(\hat{y} = \sigma(w \cdot x + b)\)。输出 \(\hat{y}\) 解释为 \(P(y = 1 \mid x)\)

Sigmoid曲线,标记了0.5阈值,显示“预测0”和“预测1”的分类区域

  • Sigmoid有很好的性质:\(\sigma(0) = 0.5\)\(\sigma(z) \to 1\)\(z \to \infty\)\(\sigma(z) \to 0\)\(z \to -\infty\),其导数为优雅的形式 \(\sigma'(z) = \sigma(z)(1 - \sigma(z))\)

  • 逻辑回归的损失函数是二元交叉熵(BCE),直接来自伯努利似然(第05章):

\[\mathcal{L} = -\frac{1}{n} \sum_{i=1}^{n} \left[ y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right]\]
  • 当真标签为1时,只有第一项起作用,惩罚低的预测值;当真标签为0时,只有第二项起作用,惩罚高的预测值。对数使得错误预测的惩罚非常陡峭:预测0.01而真值为1的成本远高于预测0.4。

  • 与线性回归的MSE不同,最小化BCE的权重没有闭式解。我们需要一种迭代方法:梯度下降

  • 梯度下降的直觉很简单:想象你站在雾气笼罩的山丘上(损失曲面),看不见全局最低点,但能感觉到脚下的坡度。你朝下坡迈出一步,再次感受坡度,重复这个过程,最终你会到达一个山谷。

\[w \leftarrow w - \eta \frac{\partial \mathcal{L}}{\partial w}\]
  • 学习率 \(\eta\) 控制你的步长。太大则越过山谷,来回震荡无法收敛;太小则缓慢前行,可能陷入局部最小值。

一维损失曲线,三个球:大学习率过冲,好的学习率收敛,小学习率卡住

  • 梯度 \(\frac{\partial \mathcal{L}}{\partial w}\) 是一个指向最陡上升方向的向量。我们减去它是因为我们想下坡。这就是第03章的链式法则应用于损失函数。

  • 批量梯度下降每一步使用整个训练集计算梯度。这给出精确的梯度,但当 \(n\) 很大时计算昂贵。

  • 随机梯度下降(SGD) 每一步使用单个随机样本。梯度是有噪声的(从单个样本估计真实梯度),但每一步极快。噪声实际上有助于逃离浅的局部最小值。

  • 小批量梯度下降折中:每一步使用一个批次 \(B\) 个样本(通常32、64或256)。这平衡了计算效率(对批次进行向量化操作)与梯度质量。几乎所有深度学习都使用小批量SGD。

  • 反向传播是我们在具有许多参数的模型(如神经网络)中实际计算梯度的方法。它是第03章的链式法则在计算图中的系统应用。

  • 任何模型都可以表示为有向无环图的操作:输入流入,与权重相乘,相加,通过非线性函数,最终产生损失值。前向传播通过从输入到输出的数据流计算输出(和损失)。

  • 反向传播(反向传递)使梯度反向流动。从损失开始,对于每个中间值,使用链式法则计算损失对该中间值的变化率。如果 \(L\) 依赖于 \(z\),而 \(z\) 依赖于 \(w\),则:

\[\frac{\partial L}{\partial w} = \frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial w}\]
  • 每个节点只需要知道自己的局部导数和从上方流入的梯度。这使得反向传播模块化且高效:成本大约是前向传播的两倍(一次前向,一次反向)。

  • 普通SGD有一个问题:它在具有陡峭曲率的方向上振荡,而在平坦方向上进展缓慢。优化器通过根据梯度历史自适应步长来改进。

  • 带动量的SGD 维护过去梯度的运行平均值(指数移动平均,来自第04章)。这平滑了振荡并加速了沿一致方向的进展:

\[v_t = \beta v_{t-1} + (1 - \beta) \nabla \mathcal{L}$$ $$w \leftarrow w - \eta \, v_t\]
  • 想象一个滚下山的球:动量让它沿一致方向积累速度,并抑制侧向抖动。典型值为 \(\beta = 0.9\)

  • Nesterov加速梯度(NAG) 是一个巧妙的改进:与其在当前位置计算梯度,不如在“前瞻”位置 \(w - \eta \beta v_{t-1}\) 处计算梯度。这个修正步骤减少了过冲:

\[v_t = \beta \, v_{t-1} + \nabla \mathcal{L}(w - \eta \beta \, v_{t-1})$$ $$w \leftarrow w - \eta \, v_t\]
  • Adagrad 为每个参数自适应学习率。获得大梯度的参数得到较小的学习率,反之亦然。它累积平方梯度:
\[G_t = G_{t-1} + g_t^2, \quad w \leftarrow w - \frac{\eta}{\sqrt{G_t + \epsilon}} g_t\]
  • 问题:\(G_t\) 只增不减,因此有效学习率单调递减,最终变得太小而无法学习任何东西。

  • RMSprop 通过使用平方梯度的指数移动平均而不是总和来修复这个问题,使得近期梯度比遥远梯度更重要:

\[s_t = \beta \, s_{t-1} + (1 - \beta) g_t^2, \quad w \leftarrow w - \frac{\eta}{\sqrt{s_t + \epsilon}} g_t\]
  • Adam(自适应矩估计)结合了动量和RMSprop。它同时维护一阶矩估计(梯度的均值,如动量)和二阶矩估计(平方梯度的均值,如RMSprop):
\[m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$$ $$v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\]
  • 由于 \(m_t\)\(v_t\) 初始化为零,它们在早期步长中会偏向零。偏差修正解决了这个问题:
\[\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}\]
\[w \leftarrow w - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t\]

二维等高线图:SGD走之字形,动量沿更平滑的路径,Adam采取最直接的路径到达最小值

  • 默认超参数(\(\beta_1 = 0.9\)\(\beta_2 = 0.999\)\(\epsilon = 10^{-8}\))在广泛问题上表现良好,这就是为什么Adam是大多数深度学习工作中的默认优化器。

  • AdamW 将权重衰减与梯度更新解耦。标准的L2正则化和权重衰减对于SGD等价,但对于Adam则不同。AdamW直接对参数应用权重衰减,而不是在梯度中添加 \(\lambda w\)。这带来了更好的泛化,现在已成为Transformer训练的标准:

\[w \leftarrow w - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \, w \right)\]
  • LION(演化符号动量)是一种通过程序搜索发现的新型优化器。它只使用动量更新的符号(而不是大小),这使得每次更新的规模一致。LION使用的内存少于Adam(没有二阶矩缓冲区),并在许多任务上可以匹敌或超越Adam:
\[w \leftarrow w - \eta \cdot \text{sign}(\beta_1 \, m_{t-1} + (1 - \beta_1) \, g_t)$$ $$m_t = \beta_2 \, m_{t-1} + (1 - \beta_2) \, g_t\]
  • Muon(动量+正交化)应用Nesterov动量,然后使用Newton-Schulz迭代对更新矩阵进行正交化,该迭代近似极分解。得到的更新方向位于Stiefel流形上,每次更新在所有奇异方向上的大小大致相等,防止任何单一方向占主导。这消除了对自适应二阶矩估计的需求(没有像Adam那样的 \(v_t\) 缓冲区),减少了内存。Muon在Transformer训练上显示出强劲的结果,通常以更快的收敛速度匹配AdamW的质量,特别是对于注意力和MLP权重矩阵。嵌入层和输出层通常仍由AdamW处理。
\[G_t = \text{NesterovMomentum}(\nabla \mathcal{L})$$ $$U_t = \text{NewtonSchulz}(G_t) \approx G_t (G_t^T G_t)^{-1/2}$$ $$W \leftarrow W - \eta \, U_t\]
  • Newton-Schulz迭代通过重复 \(X_{k+1} = \frac{1}{2} X_k (3I - X_k^T X_k)\) 几步(通常5-10次)来计算正交因子。这避免了完整SVD的成本,同时给出了良好的近似。

Muon正交化:动量更新的奇异值偏斜,Newton-Schulz迭代使其均衡,因此所有方向均匀更新

优化器内存对比:每个优化器每个参数存储的内容

  • 除了MSE和BCE,还有几种常用的损失函数

  • 平均绝对误差(MAE),即L1损失,取绝对差的平均值:\(\frac{1}{n}\sum|y_i - \hat{y}_i|\)。它对异常值更鲁棒,因为不平方大误差。

  • Huber损失 结合了两者的优点:对小误差行为像MSE(平滑,易优化),对大误差行为像MAE(对异常值鲁棒)。它有一个阈值 \(\delta\) 控制转换点。

  • 分类交叉熵(CCE) 将BCE推广到多类。如果 \(\hat{y}_k\) 是第 \(k\) 类的预测概率,真实类别是 \(c\)

\[\mathcal{L} = -\log(\hat{y}_c)\]
  • 这仅仅是正确类别的负对数概率。最小化交叉熵等价于最大化似然,这联系到第05章的信息论:交叉熵衡量当你使用预测分布而非真实分布时所需的额外比特数。

  • 合页损失 用于支持向量机:\(\mathcal{L} = \max(0, 1 - y \cdot f(x))\)。它只惩罚那些在边界错误侧或在边界内的预测。一旦一个点以足够置信度被正确分类,损失为零。

  • 正则化 通过增加对复杂模型的惩罚来防止过拟合。正则化后的损失为:

\[\mathcal{L}_{\text{reg}} = \mathcal{L}_{\text{data}} + \lambda \, R(w)\]
  • L2正则化(岭回归,权重衰减)惩罚权重的平方和:\(R(w) = \|w\|^2 = \sum w_i^2\)。它阻止任何单个权重大幅增长,有效地将所有权重向零收缩,但很少使它们精确为零。

  • L1正则化(Lasso)惩罚权重的绝对值之和:\(R(w) = \|w\|_1 = \sum |w_i|\)。它鼓励稀疏性,使许多权重精确为零,从而执行自动特征选择。

  • 弹性网 结合两者:\(R(w) = \alpha \|w\|_1 + (1 - \alpha) \|w\|^2\),混合了稀疏性和收缩。

  • 有一个优美的贝叶斯解释(来自第05章)。L2正则化等价于在权重上放置高斯先验并求最大后验估计。L1正则化对应拉普拉斯先验。正则化强度 \(\lambda\) 控制你相对于数据对先验的信任程度。

  • 评估指标 告诉你模型是否真正有效。对于回归,标准指标是MSE和MAE。对于分类,情况更为细致。

  • 混淆矩阵 是一个包含二分类四个计数的表格:

  • 真正例(TP):预测为正,实际为正
  • 假正例(FP):预测为正,实际为负
  • 真负例(TN):预测为负,实际为负
  • 假负例(FN):预测为负,实际为正

  • 准确率 = \(\frac{TP + TN}{TP + TN + FP + FN}\) 在类别不平衡时可能具有误导性。如果99%的邮件不是垃圾邮件,那么总是预测“非垃圾邮件”的模型有99%的准确率,但毫无用处。

  • 精确率 = \(\frac{TP}{TP + FP}\) 回答:在所有预测为正的样本中,有多少实际为正?高精确率意味着很少误报。

  • 召回率(灵敏度)= \(\frac{TP}{TP + FN}\) 回答:在所有实际为正的样本中,你捕捉到了多少?高召回率意味着很少漏报。

  • F1分数 = \(\frac{2 \cdot \text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}}\) 是精确率和召回率的调和平均数,平衡两者。

  • ROC曲线 绘制真正例率(召回率)与假正例率(\(\frac{FP}{FP + TN}\))的关系,随着分类阈值从0变化到1。完美分类器紧贴左上角。AUC(ROC曲线下面积)用一个数字概括性能:1.0为完美,0.5为随机猜测。

  • 交叉验证 提供了更可靠的泛化性能估计。在 \(k\) 折交叉验证中,你将数据分成 \(k\) 折,用 \(k-1\) 折训练,在剩余一折上测试,然后轮换。所有 \(k\) 折的平均测试性能就是你的估计。这使用了所有数据既用于训练又用于测试(只是不同时),在数据稀缺时尤其有价值。

  • 偏差-方差权衡(来自第04章)是机器学习中的基本张力。模型的期望误差分解为:

\[\text{Error} = \text{Bias}^2 + \text{Variance} + \text{Irreducible Noise}\]
  • 偏差 是由错误假设引起的系统性误差(例如,对弯曲数据拟合直线)。方差 是对训练数据波动的敏感性(例如,一个20次多项式拟合噪声)。简单模型具有高偏差和低方差;复杂模型具有低偏差和高方差。最佳点使总误差最小化。

  • 学习率调度 在训练期间调整 \(\eta\)。常见策略:

  • 阶梯衰减:每 \(N\) 个周期将 \(\eta\) 乘以一个因子(例如0.1)
  • 余弦退火:按照余弦曲线从初始值平滑降低 \(\eta\) 到接近零
  • 预热:开始时使用非常小的 \(\eta\),在前几千步线性增加,然后衰减。这防止大的初始梯度使训练不稳定
  • 1cycle:一次余弦先上升后下降,可以加速收敛

  • 超参数调优 是寻找学习率、批量大小、正则化强度以及其他不能由梯度下降学习的好值的过程。常见方法:

  • 网格搜索:在预定义网格上尝试每个组合(详尽但昂贵)
  • 随机搜索:随机采样组合,这通常更高效,因为并非所有超参数同等重要
  • 贝叶斯优化:建立目标函数的模型,并智能选择下一步尝试的超参数
  • ASHA(异步连续减半算法):并行运行许多试验,预算较小,然后将最有希望的试验提升到更大的预算,同时提前终止其余试验。它结合了早期停止的效率和大规模并行性——与其运行100次完整训练,不如廉价地启动所有100次,在每个梯级保留前四分之一,只有少数运行到完成。这是现代大规模调优框架(如Ray Tune)的支柱。

  • 无调度学习 完全消除了对学习率调度的需求。它不是在固定曲线上衰减 \(\eta\),而是维护两个序列:一个慢速移动平均的迭代 \(z_t\)(收敛到最优)和一个快速探索的迭代 \(y_t\)(梯度在此计算)。最终输出是平均后的序列,从理论上证明其收敛速度与事后最优调度的收敛速度相匹配。这完全消除了调度作为一个超参数——你只需设置基础学习率,优化器处理其余部分。SGD和Adam的无调度变体已被证明能够匹配或超过其调优调度对应版本。

编码任务(使用CoLab或notebook)

  1. 实现线性回归,使用正规方程和梯度下降。比较解,并绘制GD损失随迭代的收敛情况。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    # 生成合成数据:y = 3x + 2 + 噪声
    key = jax.random.PRNGKey(42)
    n = 100
    X = jax.random.uniform(key, (n, 1), minval=0, maxval=10)
    y = 3 * X[:, 0] + 2 + jax.random.normal(key, (n,)) * 1.5
    
    # 添加偏置列
    X_b = jnp.column_stack([X, jnp.ones(n)])
    
    # 正规方程
    w_exact = jnp.linalg.solve(X_b.T @ X_b, X_b.T @ y)
    print(f"正规方程: w={w_exact[0]:.4f}, b={w_exact[1]:.4f}")
    
    # 梯度下降
    w_gd = jnp.zeros(2)
    lr = 0.005
    losses = []
    for step in range(500):
        pred = X_b @ w_gd
        error = pred - y
        loss = jnp.mean(error ** 2)
        losses.append(float(loss))
        grad = (2 / n) * X_b.T @ error
        w_gd = w_gd - lr * grad
    
    print(f"梯度下降: w={w_gd[0]:.4f}, b={w_gd[1]:.4f}")
    
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    axes[0].scatter(X[:, 0], y, s=15, alpha=0.5, color='#3498db')
    axes[0].plot([0, 10], [w_exact[1], w_exact[0]*10 + w_exact[1]], color='#e74c3c', linewidth=2)
    axes[0].set_title("线性回归拟合")
    axes[0].set_xlabel("x"); axes[0].set_ylabel("y")
    
    axes[1].plot(losses, color='#27ae60', linewidth=1.5)
    axes[1].set_title("GD损失收敛")
    axes[1].set_xlabel("步数"); axes[1].set_ylabel("MSE")
    axes[1].set_yscale('log')
    plt.tight_layout()
    plt.show()
    

  2. 从头实现逻辑回归,使用梯度下降。在一个二维数据集上训练,并可视化学到的决策边界。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    from sklearn.datasets import make_moons
    
    # 生成数据
    X, y = make_moons(n_samples=300, noise=0.2, random_state=42)
    X, y = jnp.array(X), jnp.array(y, dtype=jnp.float32)
    
    def sigmoid(z):
        return 1 / (1 + jnp.exp(-z))
    
    # 添加偏置列
    X_b = jnp.column_stack([X, jnp.ones(len(X))])
    w = jnp.zeros(3)
    lr = 0.5
    losses = []
    
    for step in range(2000):
        z = X_b @ w
        pred = sigmoid(z)
        # BCE损失
        loss = -jnp.mean(y * jnp.log(pred + 1e-8) + (1 - y) * jnp.log(1 - pred + 1e-8))
        losses.append(float(loss))
        # 梯度
        grad = X_b.T @ (pred - y) / len(y)
        w = w - lr * grad
    
    # 决策边界
    xx, yy = jnp.meshgrid(jnp.linspace(-2, 3, 200), jnp.linspace(-1.5, 2, 200))
    grid = jnp.column_stack([xx.ravel(), yy.ravel(), jnp.ones(xx.size)])
    zz = sigmoid(grid @ w).reshape(xx.shape)
    
    plt.figure(figsize=(8, 6))
    plt.contourf(xx, yy, zz, levels=[0, 0.5, 1], alpha=0.3, colors=['#e74c3c', '#3498db'])
    plt.contour(xx, yy, zz, levels=[0.5], colors='#9b59b6', linewidths=2)
    plt.scatter(X[y==0, 0], X[y==0, 1], c='#e74c3c', s=15, label='类别0')
    plt.scatter(X[y==1, 0], X[y==1, 1], c='#3498db', s=15, label='类别1')
    plt.title("逻辑回归决策边界")
    plt.legend()
    plt.grid(alpha=0.3)
    plt.show()
    

  3. 在二维二次曲面上比较优化器的轨迹。从相同的起点运行SGD、SGD+Momentum和Adam,并绘制它们的路径。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    # 拉长的二次曲面:L(w1, w2) = 0.5*w1^2 + 10*w2^2
    def loss_fn(w):
        return 0.5 * w[0]**2 + 10 * w[1]**2
    
    grad_fn = jax.grad(loss_fn)
    
    def run_sgd(w0, lr=0.05, steps=80):
        w = w0.copy()
        path = [w.copy()]
        for _ in range(steps):
            g = grad_fn(w)
            w = w - lr * g
            path.append(w.copy())
        return jnp.stack(path)
    
    def run_momentum(w0, lr=0.05, beta=0.9, steps=80):
        w, v = w0.copy(), jnp.zeros(2)
        path = [w.copy()]
        for _ in range(steps):
            g = grad_fn(w)
            v = beta * v + (1 - beta) * g
            w = w - lr * v
            path.append(w.copy())
        return jnp.stack(path)
    
    def run_adam(w0, lr=0.05, b1=0.9, b2=0.999, eps=1e-8, steps=80):
        w, m, v = w0.copy(), jnp.zeros(2), jnp.zeros(2)
        path = [w.copy()]
        for t in range(1, steps + 1):
            g = grad_fn(w)
            m = b1 * m + (1 - b1) * g
            v = b2 * v + (1 - b2) * g**2
            m_hat = m / (1 - b1**t)
            v_hat = v / (1 - b2**t)
            w = w - lr * m_hat / (jnp.sqrt(v_hat) + eps)
            path.append(w.copy())
        return jnp.stack(path)
    
    w0 = jnp.array([8.0, 3.0])
    sgd_path = run_sgd(w0)
    mom_path = run_momentum(w0)
    adam_path = run_adam(w0)
    
    # 绘图
    fig, ax = plt.subplots(figsize=(8, 6))
    w1 = jnp.linspace(-10, 10, 100)
    w2 = jnp.linspace(-4, 4, 100)
    W1, W2 = jnp.meshgrid(w1, w2)
    L = 0.5 * W1**2 + 10 * W2**2
    ax.contour(W1, W2, L, levels=20, cmap='Greys', alpha=0.4)
    ax.plot(sgd_path[:,0], sgd_path[:,1], 'o-', color='#3498db', markersize=2, linewidth=1, label='SGD')
    ax.plot(mom_path[:,0], mom_path[:,1], 'o-', color='#27ae60', markersize=2, linewidth=1, label='动量')
    ax.plot(adam_path[:,0], adam_path[:,1], 'o-', color='#e74c3c', markersize=2, linewidth=1, label='Adam')
    ax.plot(0, 0, 'k*', markersize=15, label='最小值')
    ax.set_xlabel('w₁'); ax.set_ylabel('w₂')
    ax.set_title("优化器在拉长二次曲面上的轨迹")
    ax.legend()
    plt.grid(alpha=0.3)
    plt.show()
    

  4. 展示L1与L2正则化对权重稀疏性的影响。使用两种惩罚训练线性回归,并比较得到的权重向量。

    import jax
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    # 合成数据:只有20个特征中的前3个是相关的
    key = jax.random.PRNGKey(0)
    n, d = 200, 20
    w_true = jnp.zeros(d).at[:3].set(jnp.array([3.0, -2.0, 1.5]))
    X = jax.random.normal(key, (n, d))
    y = X @ w_true + 0.5 * jax.random.normal(key, (n,))
    
    def train_ridge(X, y, lam=1.0, lr=0.01, steps=2000):
        """通过GD进行L2正则化线性回归"""
        w = jnp.zeros(X.shape[1])
        for _ in range(steps):
            pred = X @ w
            grad = (2/len(y)) * X.T @ (pred - y) + 2 * lam * w
            w = w - lr * grad
        return w
    
    def train_lasso(X, y, lam=1.0, lr=0.01, steps=2000):
        """通过近端GD进行L1正则化线性回归"""
        w = jnp.zeros(X.shape[1])
        for _ in range(steps):
            pred = X @ w
            grad = (2/len(y)) * X.T @ (pred - y)
            w = w - lr * grad
            # 软阈值(L1的近端算子)
            w = jnp.sign(w) * jnp.maximum(jnp.abs(w) - lr * lam, 0)
        return w
    
    w_l2 = train_ridge(X, y, lam=0.1)
    w_l1 = train_lasso(X, y, lam=0.1)
    
    fig, axes = plt.subplots(1, 3, figsize=(14, 4))
    axes[0].bar(range(d), w_true, color='#333', alpha=0.7)
    axes[0].set_title("真实权重"); axes[0].set_xlabel("特征")
    axes[1].bar(range(d), w_l2, color='#3498db', alpha=0.7)
    axes[1].set_title("L2 (岭回归): 收缩所有"); axes[1].set_xlabel("特征")
    axes[2].bar(range(d), w_l1, color='#e74c3c', alpha=0.7)
    axes[2].set_title("L1 (Lasso): 将不相关的置零"); axes[2].set_xlabel("特征")
    plt.tight_layout()
    plt.show()
    
    print(f"L2非零权重数: {int(jnp.sum(jnp.abs(w_l2) > 0.01))}/{d}")
    print(f"L1非零权重数: {int(jnp.sum(jnp.abs(w_l1) > 0.01))}/{d}")