Skip to content

矩阵运算

矩阵运算是深度学习的计算引擎。本文涵盖矩阵加法、标量乘法、矩阵-向量乘积、矩阵乘法、逐元素运算、克罗内克积和广播——这些是每一次前向传播和梯度更新背后的运算。

  • 矩阵可以像向量一样进行加法和缩放。

  • 对于加法,两个矩阵必须具有相同的维度,并且逐元素相加:

\[ \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} + \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix} = \begin{bmatrix} 6 & 8 \\ 10 & 12 \end{bmatrix} \]
  • 对于标量乘法,将每个元素乘以标量:
\[ 3 \times \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} = \begin{bmatrix} 3 & 6 \\ 9 & 12 \end{bmatrix} \]
  • 用矩阵能做的最简单的事情就是将它乘以一个向量。矩阵-向量乘法 \(A\mathbf{x}\) 使用 \(\mathbf{x}\) 的条目作为权重来组合 \(A\) 的列:
\[ \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \begin{bmatrix} 5 \\ 6 \end{bmatrix} = 5 \begin{bmatrix} 1 \\ 3 \end{bmatrix} + 6 \begin{bmatrix} 2 \\ 4 \end{bmatrix} = \begin{bmatrix} 17 \\ 39 \end{bmatrix} \]
  • 这是机器学习中的核心运算。每个神经网络层计算 \(A\mathbf{x} + \mathbf{b}\):一个矩阵乘以输入向量,再加上偏置。

  • 一般情况是矩阵乘法。给定 \(A\)\(m \times n\))和 \(B\)\(n \times p\)),乘积 \(C = AB\) 是一个 \(m \times p\) 矩阵,其中每个元素都是一个点积:

\[C_{ij} = \sum_{k=1}^{n} A_{ik} B_{kj}\]
  • 结果中的每个条目是 \(A\) 的一行与 \(B\) 的一列的点积。内部维度必须匹配(\(n\)),结果取外部维度(\(m \times p\))。

  • 另一种理解方式:结果的每一列是 \(A\) 的列的加权和,权重来自 \(B\) 的对应列。

  • 如果 \(B\) 的某一列为 \([2, 3]^T\),那么结果的对应列是 \(2 \times (A \text{ 的第 1 列}) + 3 \times (A \text{ 的第 2 列})\)

  • 一个有用的特例:将矩阵乘以其转置总是得到一个方阵。\(AA^T\)\(m \times m\)\(A^TA\)\(n \times n\)

\[ \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \begin{bmatrix} 1 & 4 \\ 2 & 5 \\ 3 & 6 \end{bmatrix} = \begin{bmatrix} 14 & 32 \\ 32 & 77 \end{bmatrix} \]
  • 矩阵乘法的重要规则:

    • 不可交换:一般情况下 \(AB \neq BA\)。顺序很重要。

    • 可结合\((AB)C = A(BC)\)。你可以任意分组乘法。

    • 可分配\(A(B + C) = AB + AC\)

    • 单位元\(AI = IA = A\)

  • 哈达玛积(逐元素乘积)将两个相同大小的矩阵逐条目相乘,记作 \(A \odot B\)

\[ \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \odot \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix} = \begin{bmatrix} 5 & 12 \\ 21 & 32 \end{bmatrix} \]
  • 与标准矩阵乘法不同,哈达玛积是可交换的(\(A \odot B = B \odot A\)),并且要求两个矩阵具有相同的维度。它在机器学习中大量用于门控:通过与一个介于 0 和 1 之间的掩码逐元素相乘,控制每个条目的“通过”程度。

  • 两个向量 \(\mathbf{u}\)\(\mathbf{v}\)外积产生一个矩阵:\(\mathbf{u}\mathbf{v}^T\)。每个条目是 \(\mathbf{u}\) 的一个元素与 \(\mathbf{v}\) 的一个元素的乘积:

\[ \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix} \begin{bmatrix} 4 & 5 \end{bmatrix} = \begin{bmatrix} 4 & 5 \\ 8 & 10 \\ 12 & 15 \end{bmatrix} \]
  • 结果总是秩为 1,因为每一行都是 \(\mathbf{v}^T\) 的缩放版本。任何矩阵都可以写成一系列秩为 1 的外积之和,这正是奇异值分解(SVD)所做的(在分解章节中介绍)。

  • 矩阵乘法计算成本很高。将两个 \(n \times n\) 矩阵相乘需要 \(O(n^3)\) 次运算。对于一个 \(1000 \times 1000\) 的矩阵,那就是十亿次乘法。

  • 当矩阵是稀疏的(大部分为零)时,朴素的乘法会浪费时间去乘以零。压缩稀疏行(CSR)格式只存储非零元素及其位置:

    • 数值:按行顺序排列的非零条目
    • 列索引:每个值所属的列
    • 行偏移:每行在数值列表中的起始位置
  • 例如,矩阵:

\[ A = \begin{bmatrix} 5 & 0 & 0 & 2 \\ 0 & 0 & 3 & 0 \\ 0 & 0 & 0 & -1 \end{bmatrix} \]
  • 存储为:values = [5, 2, 3, -1],columns = [0, 3, 2, 3],row offsets = [0, 2, 3, 4]。这样跳过了所有零,使得稀疏运算快得多。

  • 矩阵的一个核心用途是求解线性方程组。方程组 \(A\mathbf{x} = \mathbf{b}\) 问的是:“哪个向量 \(\mathbf{x}\),经过 \(A\) 变换后,会产生 \(\mathbf{b}\)?”

  • 例如,假设你在买水果。苹果单价 \(x_1\) 元,香蕉单价 \(x_2\) 元。已知 2 个苹果和 1 个香蕉花费 5 元,1 个苹果和 3 个香蕉花费 10 元。写成矩阵形式:

\[ \begin{bmatrix} 2 & 1 \\ 1 & 3 \end{bmatrix} \begin{bmatrix} x_1 \\ x_2 \end{bmatrix} = \begin{bmatrix} 5 \\ 10 \end{bmatrix} \]
  • 将矩阵逐行乘以向量(每行与 \([x_1, x_2]^T\) 点积)得到两个方程:
\[2x_1 + 1x_2 = 5 \qquad \text{(第 1 行)} \qquad \qquad x_1 + 3x_2 = 10 \qquad \text{(第 2 行)}\]
  • 由第 1 行得 \(x_2 = 5 - 2x_1\)。代入第 2 行:\(x_1 + 3(5 - 2x_1) = 10\),解得 \(x_1 = 1\),然后 \(x_2 = 3\)。苹果 1 元,香蕉 3 元。

  • 验证——结果正确:

\[ \begin{bmatrix} 2 & 1 \\ 1 & 3 \end{bmatrix} \begin{bmatrix} 1 \\ 3 \end{bmatrix} = \begin{bmatrix} 2 + 3 \\ 1 + 9 \end{bmatrix} = \begin{bmatrix} 5 \\ 10 \end{bmatrix} \]
  • 如果 \(A\) 可逆,解就是 \(\mathbf{x} = A^{-1}\mathbf{b}\)。但直接计算逆矩阵代价高昂且数值不稳定。在实践中,我们使用分解方法代替。

  • 并非所有矩阵都是方阵,也并非所有方阵都可逆。伪逆 \(A^+\) 将逆推广到任意矩阵。它总是存在并提供“尽可能好”的逆:

\[A^+ = (A^TA)^{-1}A^T\]
  • \(A\) 是下三角矩阵时,通过前代求解 \(L\mathbf{x} = \mathbf{b}\) 很容易:先解出 \(x_1\),然后用它求 \(x_2\),依此类推。

  • \(A\) 是上三角矩阵时,通过回代求解 \(U\mathbf{x} = \mathbf{b}\):先解出最后一个变量,然后向上回代。

  • 这就是将矩阵分解为三角因子(我们将在分解章节中看到)如此有用的原因。它将一个难题转化为两个简单的问题。

编程任务(使用 CoLab 或 notebook)

  1. 将两个矩阵相乘并验证维度。然后交换顺序,观察结果的变化(如果维度不匹配,会失败)。

    import jax.numpy as jnp
    
    A = jnp.array([[1.0, 2.0],
                   [3.0, 4.0]])
    B = jnp.array([[5.0, 6.0],
                   [7.0, 8.0]])
    
    print(f"A @ B:\n{A @ B}")
    print(f"B @ A:\n{B @ A}")
    print(f"是否相等: {jnp.allclose(A @ B, B @ A)}")
    

  2. 求解线性方程组 \(A\mathbf{x} = \mathbf{b}\),并通过回乘验证解。尝试改变 \(\mathbf{b}\) 观察解如何变化。

    import jax.numpy as jnp
    
    A = jnp.array([[2.0, 1.0],
                   [5.0, 3.0]])
    b = jnp.array([4.0, 7.0])
    
    x = jnp.linalg.solve(A, b)
    print(f"解 x: {x}")
    print(f"A @ x: {A @ x}")