微分学¶
微分学捕捉瞬时变化率。本文涵盖极限、导数、求导法则、链式法则(反向传播的基础)以及机器学习中常用的导数。
-
在前面的章节中,我们学习了如何将数据表示为向量并用矩阵进行变换。但许多现实世界的现象并不是静态的。汽车在加速、股票价格在波动、神经网络损失随着权重更新而变化。微积分是描述变化的数学。
-
微积分提出两个问题:某事物此刻变化得有多快?(微分学)以及随着时间的推移它累积了多少?(积分学)。本节讨论“有多快”的问题。
-
想象你在开车,瞥一眼速度表,读数为 60 公里/小时。这个数字不是你整个行程的平均速度,而是你在这一瞬间的速度。微分学为我们提供了计算这种瞬时变化率的工具。
-
但首先,让我们重新审视一条直线的方程:\(y = mx + b\)。
-
这是两个量之间最简单的关系。
- \(b\) 是 y 轴截距,直线与 y 轴的交点(\(x = 0\) 时的起始值)。
- \(m\) 是 斜率,变化率:\(x\) 每增加 1 个单位,\(y\) 就改变 \(m\)。
- 如果 \(m = 3\),直线陡峭上升;如果 \(m = 0\),直线是水平的;如果 \(m = -2\),直线下降。
-
斜率计算公式为 \(m = \frac{\Delta y}{\Delta x} = \frac{y_2 - y_1}{x_2 - x_1}\) ,即“\(y\) 变化了多少”与“\(x\) 变化了多少”的比值。
-
一旦知道 \(m\) 和 \(b\),就可以对任意 \(x\) 计算出 \(y\)。
-
例如,如果 \(m = 2\) 且 \(b = 3\),那么在 \(x = 5\) 处:\(y = 2(5) + 3 = 13\)。
-
这两个参数完全决定了直线,预测任何输出只需代入即可。
-
对于直线,斜率处处相同。
-
这个思想可以推广到直线以外的函数。任何函数都是一个将输入映射到输出的规则,一旦知道它的公式(参数和形状),就可以计算任意输入对应的输出并绘制结果。
-
\(y = x^2\) 给出抛物线,\(y = \sin(x)\) 给出波形,\(y = e^x\) 给出指数增长。每个公式定义了一条特定的曲线,能够将函数视为形状来理解,对于后续内容至关重要。
-
对于直线,斜率处处相同。但大多数有趣的函数是弯曲的,因此斜率随点的不同而不同。微积分为我们提供了一种方法,可以找到曲线上任意一点的斜率。
-
我们还需要极限的概念。极限描述了当输入越来越接近某个目标值(不一定要达到)时,函数值所趋近的值。
-
这读作:“当 \(x\) 趋近于 \(a\) 时,\(f(x)\) 趋近于 \(L\)。”函数在 \(x = a\) 处不一定等于 \(L\),只需要无限接近即可。
-
例如,取 \(f(x) = \frac{x^2 - 1}{x - 1}\)。如果直接代入 \(x = 1\),得到 \(\frac{0}{0}\),这是未定义的。
-
但尝试接近 1 的值:\(f(0.9) = 1.9\),\(f(0.99) = 1.99\),\(f(1.01) = 2.01\)。输出显然趋向于 2。
-
从代数角度可以看出原因:将分子分解为 \((x-1)(x+1)\),消去 \((x-1)\) 项,得到 \(f(x) = x + 1\)(对所有 \(x \neq 1\) 成立)。因此当 \(x \to 1\) 时,\(f(x) \to 2\)。
-
函数在 \(x = 1\) 处有一个空洞,但极限仍然存在。
-
极限是微积分其他所有内容的基础。
-
函数 \(f(x)\) 在点 \(x = a\) 处的导数衡量瞬时变化率。几何上,它是曲线在该点处切线的斜率。
- 为了计算这个斜率,我们先取曲线上的两个点,计算通过这两点的直线(割线)的斜率。然后让第二个点不断靠近第一个点,观察割线斜率趋近的值。这就是差商:
-
分子 \(f(a+h) - f(a)\) 是输出变化量,分母 \(h\) 是输入变化量。它们的比值是在一个极小区间上的平均变化率。当 \(h \to 0\) 时,这个平均值就变成了瞬时变化率。
-
例如,设 \(f(x) = x^2\),在 \(x = 3\) 处:
-
因此在 \(x = 3\) 处,函数 \(x^2\) 的变化率是每单位输入增加 6 单位输出。
-
如果极限存在,则称函数在该点可导。为此,函数必须连续(无跳跃)、光滑(无尖角),并且在点附近有定义。
-
如果你能不抬笔地画出曲线且没有任何折点,那么该点处很可能是可导的。
-
每次都从极限定义计算导数会很繁琐。幸运的是,有一些法则可以让我们快速求出几乎所有函数的导数。
-
常数规则:常数的导数为零。若 \(f(x) = 5\),则 \(f'(x) = 0\)。水平线的斜率为零。
-
幂规则:求导的主力。将指数提到前面,指数减一:
-
例如:\(\frac{d}{dx} x^3 = 3x^2\)。三次函数变成了二次函数。这对任何实数指数都成立,包括负数和分数:\(\frac{d}{dx} x^{-1} = -x^{-2}\),\(\frac{d}{dx} \sqrt{x} = \frac{d}{dx} x^{1/2} = \frac{1}{2}x^{-1/2}\)。
-
和/差规则:逐项求导。
- 乘积规则:当两个函数相乘时,导数不是简单地将各自的导数相乘。而是:
-
可以理解为:“第一个的变化率乘以第二个,加上第一个乘以第二个的变化率”。例如,\(\frac{d}{dx}[x^2 \sin x] = 2x \sin x + x^2 \cos x\)。
-
商规则:对于两个函数的比值:
-
助记:“下乘上导减上乘下导,除以分母的平方”。
-
链式法则:对机器学习最重要的法则。当函数复合(一个函数套在另一个函数里面)时,导数是沿链的各导数之积:
- 可以理解为“剥洋葱”:先对外层函数求导(保持内层函数不变),再乘以内层函数的导数。
-
例如,\(\frac{d}{dx} (3x + 1)^5 = 5(3x+1)^4 \cdot 3 = 15(3x+1)^4\)。外层函数是 \((\cdot)^5\),内层是 \(3x+1\)。
-
链式法则是神经网络中反向传播的数学基础。一个深度网络是一个长长的复合函数链。为了计算损失相对于每个权重的变化率,我们从输出层反向传播到输入层,每一步都乘以局部导数。
-
以下是你会遇到的最常见的导数。每一个都可以从极限定义推导出来,但熟记它们可以节省时间:
| 函数 | 导数 | 备注 |
|---|---|---|
| \(e^x\) | \(e^x\) | 唯一一个导数为自身的函数 |
| \(a^x\) | \(a^x \ln a\) | 指数函数的一般形式 |
| \(\ln x\) | \(\frac{1}{x}\) | 自然对数 |
| \(\log_a x\) | \(\frac{1}{x \ln a}\) | 一般对数 |
| \(\sin x\) | \(\cos x\) | |
| \(\cos x\) | \(-\sin x\) | 注意负号 |
| \(\tan x\) | \(\sec^2 x\) |
-
指数函数 \(e^x\) 非常特别:它是唯一一个等于自身导数的函数。这就是为什么 \(e\) 在机器学习中无处不在,从 softmax 激活函数到概率分布。
-
洛必达法则 处理产生 \(\frac{0}{0}\) 或 \(\frac{\infty}{\infty}\) 等不定式形式的极限。当直接代入得到这些形式之一时,你可以分别对分子和分母求导,然后再尝试求极限:
-
条件:\(f\) 和 \(g\) 必须在 \(a\) 附近可导,并且在 \(a\) 附近(可能 \(a\) 本身除外)\(g'(x) \neq 0\)。原始极限必须给出不定式形式。
-
例如:\(\lim_{x \to 0} \frac{\sin x}{x}\)。直接代入得到 \(\frac{0}{0}\)。应用洛必达法则:\(\lim_{x \to 0} \frac{\cos x}{1} = 1\)。这个极限是基础性的,出现在信号处理和傅里叶分析中。
-
如果结果仍然是不定式,你可以重复应用该法则。例如,\(\lim_{x \to 0} \frac{1 - \cos x}{x^2}\) 给出 \(\frac{0}{0}\)。第一次应用:\(\lim_{x \to 0} \frac{\sin x}{2x}\),仍然是 \(\frac{0}{0}\)。第二次应用:\(\lim_{x \to 0} \frac{\cos x}{2} = \frac{1}{2}\)。
-
如果两个函数可导,那么它们的和、差、积、复合以及商(分母不为零)也是可导的。这就是为什么我们可以自信地对由简单部分构成的复杂表达式进行求导。
编程任务(使用 CoLab 或 notebook)¶
-
可视化常见函数。并排绘制 \(x^2\)、\(\sin(x)\) 和 \(e^x\),建立不同公式产生不同形状的直观感受。尝试改变参数(例如 \(2x^2\)、\(\sin(2x)\)),观察曲线如何变化。
import jax.numpy as jnp import matplotlib.pyplot as plt x = jnp.linspace(-3, 3, 300) fig, axes = plt.subplots(1, 3, figsize=(12, 3)) axes[0].plot(x, x**2, color="#e74c3c") axes[0].set_title("x² (抛物线)") axes[1].plot(x, jnp.sin(x), color="#3498db") axes[1].set_title("sin(x) (波形)") axes[2].plot(x, jnp.exp(x), color="#27ae60") axes[2].set_title("eˣ (指数)") for ax in axes: ax.axhline(0, color="gray", linewidth=0.5) ax.axvline(0, color="gray", linewidth=0.5) plt.tight_layout() plt.show() -
使用 JAX 的自动微分计算函数 \(f(x) = x^3 - 2x + 1\) 在若干点处的导数。与解析导数 \(f'(x) = 3x^2 - 2\) 进行比较。
-
数值验证链式法则。定义 \(f(x) = \sin(x^2)\),通过
jax.grad计算其导数,并与解析结果 \(2x\cos(x^2)\) 进行比较。 -
可视化导数。在同一图中绘制 \(f(x) = x^3 - 3x\) 及其导数 \(f'(x) = 3x^2 - 3\)。注意 \(f'(x) = 0\) 的点对应 \(f\) 的峰和谷。
import jax import jax.numpy as jnp import matplotlib.pyplot as plt f = lambda x: x**3 - 3*x # jax.grad 作用于标量;jax.vmap 将其向量化,以便同时处理数组的每个输入 df = jax.vmap(jax.grad(f)) x = jnp.linspace(-2.5, 2.5, 200) plt.plot(x, jax.vmap(f)(x), label="f(x)") plt.plot(x, df(x), label="f'(x)", linestyle="--") plt.axhline(0, color="gray", linewidth=0.5) plt.legend() plt.title("函数及其导数") plt.show()