为什么用C++以及ML框架如何工作¶
本书中每次jnp.matmul、每次torch.nn.Linear、每次np.dot调用,底层都是在执行C++和CUDA代码。本文件揭开面纱:为什么ML框架这样构建,给Python工程师的C++快速入门,何时编写自定义C++内核,以及如何将它们绑定进Python——连接你写的代码与其运行的硬件之间的桥梁。
-
你已经花了15章写Python。你导入了JAX,调用了
jax.grad,运行了训练循环,构建了模型。感觉都像Python。但真相是:几乎没有实际计算发生在Python中。 -
当你在PyTorch中写
output = model(input)或在JAX中写output = jnp.matmul(W, x)时,Python几乎什么都没做。它构造计算的描述(操作图),然后交给做真正工作的C++/CUDA后端。Python是方向盘;C++是引擎。
为什么Python前端、C++后端¶
- 这种双语言架构存在是因为Python和C++擅长的东西相反:
| Python | C++ | |
|---|---|---|
| 开发速度 | 快(动态类型、REPL、无需编译) | 慢(静态类型、头文件、编译时间) |
| 执行速度 | ~比C慢100倍(解释执行、GIL) | 接近硬件速度(编译执行、无开销) |
| 内存控制 | 自动(GC),无法控制布局 | 手动,精确控制每个字节 |
| 硬件访问 | 无(无SIMD、无GPU、无自定义内存) | 完全(内联函数、CUDA、内联汇编) |
| 生态 | 丰富的ML生态(notebooks、可视化、数据) | 丰富的系统编程生态(OS、驱动、引擎) |
-
洞见:每种语言用于它擅长的方面。Python处理与人类生产力有关的方面(实验设计、超参数调优、数据探索)。C++处理与机器性能有关的方面(矩阵乘法、卷积、注意力内核)。
-
一次简单的矩阵乘法
jnp.matmul(A, B),\(A\)是\(4096 \times 4096\),执行约1370亿次浮点运算。在纯Python(嵌套循环)中,这需要约30分钟。在带AVX-512 SIMD和多线程的优化C++中,需要约10毫秒。这是180,000倍的差距。没有任何Python技巧能弥补这个差距。
ML框架如何组织¶
- 每个主流ML框架遵循相同的架构:
用户代码(Python)
↓
Python API层(torch.nn, jax.numpy, numpy)
↓
分发/JIT编译器(torch.compile, XLA, NumPy分发)
↓
C++内核库(ATen/PyTorch, XLA, BLAS/LAPACK)
↓
硬件专用后端(CUDA, cuDNN, MKL, oneDNN, Metal)
↓
硬件(CPU SIMD单元, GPU核心, TPU MXU)
NumPy¶
-
NumPy的核心是用C写的。当你调用
np.dot(A, B)时,Python调用一个C函数,该函数调用BLAS(Basic Linear Algebra Subprograms),通常是Intel MKL或OpenBLAS。BLAS是手工优化的C和Fortran代码,使用SIMD指令、缓存感知的内存访问模式和多线程。数十年的优化投入让矩阵乘法变快。 -
NumPy仅支持CPU。它不使用GPU。但在CPU上,它非常快,因为它委托给最好的可用BLAS实现。
PyTorch¶
-
PyTorch的计算引擎是ATen(A Tensor Library),用C++编写。ATen实现了约2000个张量操作(add、matmul、conv2d、softmax……),每个都有CPU和CUDA后端。
-
当你调用
torch.matmul(A, B)时:- Python分发到ATen C++函数。
- ATen检查设备(CPU或CUDA)和数据类型。
- CPU上:调用MKL/OpenBLAS。GPU上:调用cuBLAS(NVIDIA针对GPU优化的BLAS)。
- 结果包装在Python张量对象中返回。
-
torch.compile(PyTorch 2.0+)更进一步:它追踪你的Python代码,构建计算图,并使用Triton(GPU)或C++/OpenMP(CPU)编译。编译后的代码融合操作、消除Python开销,可以比eager模式快2-5倍。
JAX¶
-
JAX将Python函数编译到XLA(Accelerated Linear Algebra),Google的ML工作负载编译器。当你
jax.jit一个函数时:- JAX追踪函数,将操作捕获为XLA计算图(HLO——High Level Operations)。
- XLA优化图:融合操作、消除冗余计算、优化内存布局。
- XLA编译到目标后端:CPU(通过LLVM)、GPU(通过CUDA/PTX)或TPU(通过TPU专用指令)。
- 编译后的代码直接在硬件上运行,完全不涉及Python。
-
这就是
jax.jit如此重要的原因:没有它,每个操作都是一次单独的Python→C++往返。有了它,整个函数就是单个编译内核。
给Python工程师的C++快速入门¶
- 你不需要成为C++专家。你需要理解足够的内容来阅读内核代码、编写简单扩展和理解性能讨论。以下是基础内容。
类型和变量¶
// C++需要显式类型(与Python不同)
int count = 0; // 32位整数
float loss = 0.5f; // 32位浮点数
double lr = 3e-4; // 64位浮点数
bool training = true; // 布尔值
// 数组(固定大小,栈分配)
float weights[1024]; // 1024个浮点数,内存连续
// 指针:持有内存地址的变量
float* ptr = weights; // ptr指向weights的第一个元素
float val = ptr[42]; // 通过指针算术访问第42个元素
// ptr[42]等价于*(ptr + 42)
- 指针是与Python最大的概念差异。在Python中,一切是引用,你从不考虑内存地址。在C++中,指针让你直接访问内存——强大但危险(悬空指针、缓冲区溢出)。
函数¶
// 函数声明:return_type name(param_type param_name)
float relu(float x) {
return x > 0.0f ? x : 0.0f;
}
// 按引用传递(避免复制大对象)
void scale_vector(std::vector<float>& vec, float factor) {
for (size_t i = 0; i < vec.size(); i++) {
vec[i] *= factor;
}
}
// const引用:只读、无拷贝
float sum(const std::vector<float>& vec) {
float total = 0.0f;
for (float x : vec) { // 基于范围的for循环(类似Python的for x in vec)
total += x;
}
return total;
}
内存:栈 vs 堆¶
// 栈分配:快速,自动生命周期(函数返回时释放)
float buffer[256]; // 栈上256个浮点数
// 堆分配:手动管理,生命周期超出函数
float* data = new float[n]; // 在堆上分配n个浮点数
// ... 使用data ...
delete[] data; // 你必须释放它(没有垃圾回收器)
// 现代C++:智能指针(自动清理,类似Python引用)
#include <memory>
auto data = std::make_unique<float[]>(n); // 超出作用域自动释放
- 关键规则:栈快速但有限(通常1-8 MB)。大数组(张量、特征图)必须在堆上。在Python中,一切在堆上,GC处理清理。在C++中,你自己管理(或使用智能指针)。
模板(泛型)¶
// 适用于任何数值类型的函数
template <typename T>
T add(T a, T b) {
return a + b;
}
add<float>(1.5f, 2.5f); // 返回4.0f
add<int>(3, 4); // 返回7
- 模板是C++库(如ATen)如何编写适用于float16、float32、float64等的代码,而不需要重复实现的方式。
标准库要点¶
#include <vector> // 动态数组(类似Python list)
#include <string> // 字符串类型
#include <unordered_map> // 哈希映射(类似Python dict)
#include <algorithm> // sort, find, transform等
#include <cmath> // 数学函数
std::vector<float> vec = {1.0f, 2.0f, 3.0f};
vec.push_back(4.0f); // 追加
float first = vec[0]; // 索引
size_t len = vec.size(); // 长度
std::unordered_map<std::string, int> counts;
counts["hello"] = 5; // 插入
if (counts.count("hello")) { } // 检查存在
何时编写自定义C++内核¶
-
大多数ML工程师从不需要写C++。框架的内置操作覆盖了99%的用例。你应该只在以下情况考虑自定义C++:
-
你的操作在框架中不存在:一种新的激活函数、一种自定义注意力模式、一种不能用现有操作组合表达的专用损失函数。
-
融合操作以提高性能:你的模型做
relu(layernorm(matmul(x, W) + b))。每个操作启动单独的内核、读写内存并同步。融合内核是一次性做全部,避免内存往返。这可以快2-5倍。 -
减少内存使用:自定义内核可以在不存储所有中间激活的情况下计算梯度(在内核级别的梯度检查点)。
-
面向新型硬件:新的加速器(如Cerebras、Groq)可能没有框架支持。你直接写内核。
-
对于情况1-2,Triton(第16章文件05)通常足够,并且比直接写CUDA C容易得多。只有在Triton无法表达你所需时才降级到CUDA C。
如何将C++绑定到Python¶
- 写C++是工作的一半。你还需要从Python调用它。
pybind11(通用)¶
- pybind11用最少的样板代码为C++函数创建Python绑定:
// my_ops.cpp
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
namespace py = pybind11;
// 一个简单的自定义操作
py::array_t<float> custom_relu(py::array_t<float> input) {
auto buf = input.request();
float* ptr = static_cast<float*>(buf.ptr);
size_t n = buf.size;
auto result = py::array_t<float>(n);
float* out = static_cast<float*>(result.request().ptr);
for (size_t i = 0; i < n; i++) {
out[i] = ptr[i] > 0 ? ptr[i] : 0;
}
return result;
}
PYBIND11_MODULE(my_ops, m) {
m.def("custom_relu", &custom_relu, "自定义ReLU操作");
}
# 编译
pip install pybind11
c++ -O3 -shared -std=c++17 -fPIC $(python3 -m pybind11 --includes) my_ops.cpp -o my_ops$(python3-config --extension-suffix)
# 从Python使用
import my_ops
import numpy as np
x = np.array([-1.0, 2.0, -3.0, 4.0], dtype=np.float32)
y = my_ops.custom_relu(x)
print(y) # [0. 2. 0. 4.]
PyTorch C++扩展¶
- PyTorch提供了添加自定义操作的简化方法:
// custom_op.cpp
#include <torch/extension.h>
torch::Tensor custom_gelu(torch::Tensor x) {
return x * 0.5 * (1.0 + torch::erf(x / std::sqrt(2.0)));
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("custom_gelu", &custom_gelu, "自定义GELU激活");
}
# 即时加载和编译
from torch.utils.cpp_extension import load
custom_ops = load(
name="custom_ops",
sources=["custom_op.cpp"],
extra_cflags=["-O3"],
)
x = torch.randn(1000)
y = custom_ops.custom_gelu(x)
torch.utils.cpp_extension.load一次性编译C++代码、创建共享库并将其作为Python模块加载。这是在PyTorch中试验自定义C++操作的最简单方式。
JAX自定义调用¶
-
JAX使用XLA自定义调用。流程更复杂(你向XLA注册一个C函数),但概念相同:写C/C++,绑定它,从Python调用。
-
对于大多数JAX用户,Pallas(在文件05中介绍)是更好的选择:它让你用XLA编译的类似Python的语法写GPU内核,而无需脱离JAX生态。
全景图¶
-
本文件解释了Python和硬件之间的层。本章其余文件更深入:
- 文件01:硬件本身(CPU架构、GPU架构、内存系统)
- 文件02-03:CPU上的SIMD编程(ARM NEON、x86 AVX)——你写使用CPU向量单元的C++
- 文件04:CUDA GPU编程——你写运行在数千GPU核心上的C++
- 文件05:Triton、Pallas和更高级的GPU编程——你写编译为GPU内核的Python
-
这个递进镜像了抽象阶梯:C++内联函数(最低、最多控制)→ CUDA(GPU专用)→ Triton/Pallas(Python风格、编译)→ JAX/PyTorch(最高、自动化)。每一级用控制换便利。理解较低级让你成为更高级的更好用户。
编程任务(用g++或clang++编译)¶
-
编写你的第一个C++程序。分配一个数组,填充它,计算总和,并计时。这介绍了编译、数组、指针和计时。
// task1_basics.cpp // 编译:g++ -O3 -o task1 task1_basics.cpp // 运行:./task1 #include <iostream> #include <chrono> #include <vector> int main() { const int N = 10'000'000; // C++允许'作为数字分隔符 std::vector<float> data(N); // 填充数组 for (int i = 0; i < N; i++) { data[i] = static_cast<float>(i) * 0.001f; } // 计算总和 auto start = std::chrono::high_resolution_clock::now(); float sum = 0.0f; for (int i = 0; i < N; i++) { sum += data[i]; } auto end = std::chrono::high_resolution_clock::now(); double elapsed = std::chrono::duration<double, std::milli>(end - start).count(); std::cout << "总和: " << sum << std::endl; std::cout << "时间: " << elapsed << " ms" << std::endl; std::cout << "元素数: " << N << std::endl; std::cout << "吞吐量: " << (N * sizeof(float)) / elapsed / 1e6 << " GB/s" << std::endl; return 0; } -
编写一个在数组上计算ReLU的C++函数,然后用pybind11构建Python绑定。从Python调用并与NumPy比较速度。
// task2_relu.cpp // 编译:c++ -O3 -shared -std=c++17 -fPIC $(python3 -m pybind11 --includes) \ // task2_relu.cpp -o my_relu$(python3-config --extension-suffix) #include <pybind11/pybind11.h> #include <pybind11/numpy.h> namespace py = pybind11; py::array_t<float> cpp_relu(py::array_t<float> input) { auto buf = input.request(); float* ptr = static_cast<float*>(buf.ptr); int n = buf.size; auto result = py::array_t<float>(n); float* out = static_cast<float*>(result.request().ptr); for (int i = 0; i < n; i++) { out[i] = ptr[i] > 0.0f ? ptr[i] : 0.0f; } return result; } PYBIND11_MODULE(my_relu, m) { m.def("relu", &cpp_relu, "C++ ReLU"); }# test_relu.py — 在上面的C++模块编译后运行 import numpy as np import time import my_relu # 编译后的C++模块 x = np.random.randn(10_000_000).astype(np.float32) # C++ ReLU start = time.time() for _ in range(100): y_cpp = my_relu.relu(x) cpp_time = (time.time() - start) / 100 # NumPy ReLU start = time.time() for _ in range(100): y_np = np.maximum(x, 0) np_time = (time.time() - start) / 100 print(f"C++ ReLU: {cpp_time*1000:.2f} ms") print(f"NumPy ReLU: {np_time*1000:.2f} ms") print(f"匹配: {np.allclose(y_cpp, y_np)}") -
编写一个演示内存布局为何重要的C++程序。比较行主序 vs 列主序访问模式,并测量性能差异。
// task3_layout.cpp // 编译:g++ -O3 -o task3 task3_layout.cpp #include <iostream> #include <chrono> #include <vector> int main() { const int N = 4096; std::vector<float> matrix(N * N, 1.0f); // 行主序访问:顺序内存地址(缓存友好) auto start = std::chrono::high_resolution_clock::now(); float sum_row = 0.0f; for (int i = 0; i < N; i++) { for (int j = 0; j < N; j++) { sum_row += matrix[i * N + j]; // stride-1访问 } } auto end = std::chrono::high_resolution_clock::now(); double row_ms = std::chrono::duration<double, std::milli>(end - start).count(); // 列主序访问:stride-N访问(缓存不友好) start = std::chrono::high_resolution_clock::now(); float sum_col = 0.0f; for (int j = 0; j < N; j++) { for (int i = 0; i < N; i++) { sum_col += matrix[i * N + j]; // stride-N访问(缓存缺失!) } } end = std::chrono::high_resolution_clock::now(); double col_ms = std::chrono::duration<double, std::milli>(end - start).count(); std::cout << "行主序(缓存友好): " << row_ms << " ms" << std::endl; std::cout << "列主序(缓存不友好): " << col_ms << " ms" << std::endl; std::cout << "变慢: " << col_ms / row_ms << "x" << std::endl; std::cout << "(两个和: " << sum_row << ", " << sum_col << ")" << std::endl; return 0; }