Skip to content

优化

优化是模型训练的数学核心,旨在找到使损失函数最小化的参数。本章涵盖临界点、凸性、梯度下降、牛顿法、带拉格朗日乘子的约束优化,以及驱动现代深度学习的优化器(SGD、Adam)。

  • 训练神经网络、拟合回归线、调整超参数:几乎每个机器学习算法的核心都是一个优化问题。

  • 我们有一个函数(损失、代价、目标函数),想要找到使其尽可能小(或大)的输入。

  • 在优化之前,我们需要理解函数的零点(或根)。\(f(x)\) 的零点是满足 \(f(x) = 0\)\(x\) 值。在图形上,就是 \(x\) 轴截距。

  • 例如,\(f(x) = x^2 - 3x + 2 = (x-1)(x-2)\) 的零点在 \(x = 1\)\(x = 2\)。在两个零点之间,函数为负(\(f(1.5) = -0.25\));在零点之外,函数为正。零点将数轴划分为函数符号恒定的区域。

  • 零点的重数是相应因子出现的次数。

  • 在单零点(重数1)处,图形穿过 \(x\) 轴。在二重零点(重数2)处,图形接触 \(x\) 轴但弹回而不穿过,在该点呈现“平坦”。

  • 求零点很重要,因为导数 \(f'(x)\) 的零点就是 \(f(x)\)临界点,它们是极大值和极小值的候选点。

  • 在极大值或极小值处,切线是水平的(斜率为0),因此 \(f'(x) = 0\)

临界点:导数等于零的点,函数在此处有峰值、谷值或鞍点

  • 但并非每个临界点都是极大值或极小值。\(f'(x) = 0\) 的点也可能是拐点(例如 \(f(x) = x^3\)\(x = 0\) 处),函数在此短暂变平但不改变方向。

  • 二阶导数判别法可解决这个问题。在临界点 \(x = c\)\(f'(c) = 0\))处:

    • \(f''(c) > 0\):曲线是凹向上的(如碗状),则 \(c\)局部极小值
    • \(f''(c) < 0\):曲线是凹向下的(如山丘状),则 \(c\)局部极大值
    • \(f''(c) = 0\):判别法失效,需要更高阶导数或其他方法。
  • 例如,\(f(x) = x^3 - 3x\)。导数为 \(f'(x) = 3x^2 - 3 = 3(x-1)(x+1)\),因此临界点在 \(x = -1\)\(x = 1\)。二阶导数为 \(f''(x) = 6x\)。在 \(x = -1\) 处:\(f''(-1) = -6 < 0\)(局部极大值)。在 \(x = 1\) 处:\(f''(1) = 6 > 0\)(局部极小值)。

  • 如果函数图像上任意两点之间的线段位于图像上方(或图像上),则该函数是凸的。想象一个碗形,处处向上弯曲。数学上,若对所有 \(x\)\(f''(x) \geq 0\),则 \(f\) 是凸的。

凸函数有唯一的全局极小值;非凸函数可能有多个局部极小值

  • 凸性非常强大,因为凸函数有一个显著性质:每个局部极小值同时也是全局极小值。不会存在欺骗性的局部低谷让你陷入。如果你把一个球滚进一个凸碗,它总会到达底部。

  • \(-f\) 是凸的,则 \(f\)凹的(向下弯曲)。函数从凹转为凸的点是拐点,出现在 \(f''(x) = 0\) 处。

  • 牛顿法利用切线来寻找函数的零点(并由此找到其导数的临界点)。从初始猜测 \(x_0\) 开始,迭代更新:

\[x_{n+1} = x_n - \frac{f(x_n)}{f'(x_n)}\]

牛顿法:沿切线找到根的更好近似

  • 思路:在 \(x_n\) 处作切线,找到切线与 \(x\) 轴的交点,该交点成为 \(x_{n+1}\)。对于良态函数且初始点良好,牛顿法收敛极快(二次收敛,即正确数字的位数大约每一步翻倍)。

  • 例如,求 \(\sqrt{5}\)\(f(x) = x^2 - 5\) 的零点):\(f'(x) = 2x\),所以 \(x_{n+1} = x_n - \frac{x_n^2 - 5}{2x_n}\)。从 \(x_0 = 2\) 开始:\(x_1 = 2.25\)\(x_2 = 2.2361\ldots\),已经精确到小数点后四位。

  • 如果初始猜测远离根、根附近 \(f'(x) = 0\)、或者函数附近有拐点,牛顿法可能失败。它还需要计算导数,这可能很昂贵。

  • 对于优化(寻找极小值而非零点),我们将牛顿法应用于 \(f'(x) = 0\),得到更新公式:

\[x_{n+1} = x_n - \frac{f'(x_n)}{f''(x_n)}\]
  • 在多维情况下,这变为 \(\mathbf{x}_{n+1} = \mathbf{x}_n - H^{-1} \nabla f(\mathbf{x}_n)\),其中 \(H\) 是海森矩阵。这正是上一文件中的二阶泰勒逼近的应用:将函数近似为二次函数,跳到该二次函数的最小值,然后重复。

  • 拉格朗日乘子法求解约束优化问题:在约束 \(g(x, y) = c\) 下找到 \(f(x, y)\) 的最优值。我们不是在 \(\mathbb{R}^n\) 中全域搜索,而是被限制在满足约束的集合(一条曲线或曲面)上。

  • 关键几何直觉是:在约束最优值处,\(f\) 的梯度必须与 \(g\) 的梯度平行。如果它们不平行,我们可以沿着约束朝某个方向移动来继续改善 \(f\),因此尚未达到最优。

  • 我们引入一个新变量 \(\lambda\)(拉格朗日乘子),并定义拉格朗日函数

\[\mathcal{L}(x, y, \lambda) = f(x, y) - \lambda(g(x, y) - c)\]
  • 令所有偏导数为零得到一个方程组,其解即为约束最优值:
\[\frac{\partial \mathcal{L}}{\partial x} = 0, \quad \frac{\partial \mathcal{L}}{\partial y} = 0, \quad \frac{\partial \mathcal{L}}{\partial \lambda} = 0\]

拉格朗日乘子:在最优值处,f 和 g 的梯度平行

  • 例如,在约束 \(x^2 + y^2 = 1\) 下最大化 \(f(x,y) = x^2 y\)。拉格朗日函数为 \(\mathcal{L} = x^2 y - \lambda(x^2 + y^2 - 1)\)。求偏导:
\[2xy - 2\lambda x = 0, \quad x^2 - 2\lambda y = 0, \quad x^2 + y^2 = 1\]
  • 由第一个方程(假设 \(x \neq 0\)):\(\lambda = y\)。代入第二个:\(x^2 = 2y^2\)。结合约束:\(2y^2 + y^2 = 1\),所以 \(y = \frac{1}{\sqrt{3}}\)。最大值为 \(f = \frac{2}{3\sqrt{3}}\)

  • 对于不等式约束(\(g(x,y) \leq c\) 而非 \(= c\)),Karush-Kuhn-Tucker (KKT) 条件 推广了拉格朗日乘子法。约束要么是激活的(紧的,按等式处理),要么是不激活的(解在内部,约束无关)。

  • 在实践中,我们很少手动优化。以下是主要的算法族:

    • 一阶方法(只使用梯度):梯度下降、随机梯度下降(SGD)、Adam。每步计算便宜,但收敛可能较慢,尤其是在病态问题上。

    • 二阶方法(使用梯度和海森矩阵):牛顿法收敛快,但计算和求逆海森矩阵代价高(对于 \(n\) 个参数是 \(O(n^3)\))。拟牛顿法(如 BFGS 和 L-BFGS)仅利用梯度信息近似海森矩阵,比一阶方法收敛更快,又没有二阶方法的全部代价。

    • 共轭梯度法:对大型稀疏系统高效,仅需矩阵-向量乘积而不存储完整海森矩阵。

    • 高斯-牛顿法莱文贝格-马夸尔特法:专门用于最小二乘问题(回归中常见),通过雅可比矩阵近似海森矩阵。

    • 自然梯度下降:利用 Fisher 信息矩阵考虑参数空间的几何结构,对概率模型更有效。

  • 优化器的选择取决于问题。对于深度学习,一阶方法(尤其是 Adam)占主导地位,因为参数数量巨大(数百万到数十亿),使得海森矩阵计算不现实。对于较小且目标光滑的问题,二阶方法可能快得多。

编程任务(使用 CoLab 或 notebook)

  1. 实现牛顿法求 \(\sqrt{7}\)\(f(x) = x^2 - 7\) 的零点)。观察快速收敛。

    import jax.numpy as jnp
    
    f = lambda x: x**2 - 7
    df = lambda x: 2*x
    
    x = 3.0  # 初始猜测
    for i in range(6):
        x = x - f(x) / df(x)
        print(f"第 {i+1} 步: x = {x:.10f}  (误差: {abs(x - jnp.sqrt(7.0)):.2e})")
    

  2. 使用梯度下降最小化 \(f(x, y) = (x - 3)^2 + (y + 1)^2\)。最小值在 \((3, -1)\)。尝试不同的学习率。

    import jax
    import jax.numpy as jnp
    
    def f(params):
        x, y = params
        return (x - 3)**2 + (y + 1)**2
    
    grad_f = jax.grad(f)
    params = jnp.array([0.0, 0.0])
    lr = 0.1
    
    for i in range(20):
        g = grad_f(params)
        params = params - lr * g
        if i % 5 == 0 or i == 19:
            print(f"第 {i:2d} 步: ({params[0]:.4f}, {params[1]:.4f})  loss={f(params):.6f}")
    

  3. 数值求解约束优化问题。通过参数化 \(y = 10 - x\) 并求单变量函数的最优值,在约束 \(x + y = 10\) 下最大化 \(f(x,y) = xy\)

    import jax
    import jax.numpy as jnp
    
    # 代入约束:y = 10 - x,所以 f = x(10 - x) = 10x - x²
    f = lambda x: x * (10 - x)
    df = jax.grad(f)
    
    # 梯度上升(我们想要最大值,所以加上梯度)
    x = 1.0
    lr = 0.1
    for i in range(20):
        x = x + lr * df(x)
    print(f"x={x:.4f}, y={10-x:.4f}, f={f(x):.4f}")  # 应为 x=5, y=5, f=25