Triton和TPU¶
CUDA C强大但冗长。Triton允许你用Python写GPU内核。TPU提供了与GPU不同的权衡。本文件涵盖Triton内核编程、Flash Attention案例研究、TPU架构与JAX/Pallas,以及如何选择正确的工具。关于Vulkan和跨平台GPU计算,参见文件07。
- 前一文件教了CUDA C中的GPU编程。本文件攀登抽象阶梯:Triton用Python以20%的精力给予80%的CUDA性能。TPU和Vulkan为特定用例提供替代硬件目标。
Triton:Python中的GPU内核¶
-
Triton(OpenAI)是一种用于编写GPU内核的基于Python的语言。你不是像CUDA那样考虑单个线程,而是考虑数据块。Triton编译器自动处理线程映射、内存合并、共享内存管理和许多优化。
-
为什么Triton重要:CUDA C需要深入了解warp调度、共享内存bank冲突、寄存器压力和合并模式。Triton抽象了其中大部分,使GPU内核开发对了解Python但不懂系统编程的ML研究者变得可及。
你的第一个Triton内核¶
import triton
import triton.language as tl
import torch
@triton.jit
def add_kernel(
x_ptr, y_ptr, output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr, # 编译时常量
):
# 每个程序实例处理一个BLOCK_SIZE个元素的块
pid = tl.program_id(axis=0) # 我是哪个block?
block_start = pid * BLOCK_SIZE
# 此块的偏移量
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# 处理n_elements不是BLOCK_SIZE倍数情况的掩码
mask = offsets < n_elements
# 加载数据(掩码:越界读取返回0)
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
# 计算
output = x + y
# 存储结果
tl.store(output_ptr + offsets, output, mask=mask)
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
output = torch.empty_like(x)
n_elements = output.numel()
# 发射:每block一个程序
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
return output
# 使用
x = torch.randn(1000000, device='cuda')
y = torch.randn(1000000, device='cuda')
z = add(x, y)
- 与CUDA的关键差异:
- 无显式线程管理。你按块(程序)思考,而非线程。
tl.arange(0, BLOCK_SIZE)为整个block创建偏移向量。对此向量的所有操作隐式向量化。mask处理边界条件(如AVX-512掩码寄存器,文件03)。不需要标量清理循环。tl.load和tl.store自动处理合并访问。@triton.jit在首次调用时将函数编译为PTX(GPU汇编),然后缓存编译后的内核。
Triton Softmax内核¶
- Softmax是一个很好的Triton示例,因为它需要多次遍历数据(max、subtract、exp、sum、divide),并受益于在遍历间将数据保持在SRAM(共享内存)中:
@triton.jit
def softmax_kernel(
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
BLOCK_SIZE: tl.constexpr,
):
# 每个程序处理一行
row_idx = tl.program_id(0)
row_start = input_ptr + row_idx * input_row_stride
# 加载该行
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
row = tl.load(row_start + col_offsets, mask=mask, other=-float('inf'))
# Softmax:max数值稳定,然后exp,再归一化
row_max = tl.max(row, axis=0)
numerator = tl.exp(row - row_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# 存储结果
output_start = output_ptr + row_idx * output_row_stride
tl.store(output_start + col_offsets, softmax_output, mask=mask)
- 在PyTorch中,
F.softmax(x, dim=-1)发射3个独立内核(max、exp-and-sum、divide),每个从全局内存读写。Triton版本在一个内核中完成一切,将数据保持在寄存器/SRAM中。这种内核融合是自定义Triton内核可以比PyTorch内置操作快2-4倍的原因。
Triton自动调优¶
- Triton支持自动调优:尝试多种配置并选择最快的:
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}),
],
key=['M', 'N', 'K'], # 当这些变化时重新调优
)
@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, ...):
...
- Triton在实际硬件上对每种配置进行基准测试并选择最快的。最优分块大小取决于GPU架构、矩阵维度和内存布局——自动调优无需手动实验即可找到它们。
Triton vs CUDA:何时使用¶
| Triton | CUDA C | |
|---|---|---|
| 语言 | Python | C/C++ |
| 抽象 | 块级 | 线程级 |
| 开发速度 | 快(每内核10-50行) | 慢(100-500行) |
| 性能上限 | 手工调优CUDA的~80-95% | 100%(完全硬件控制) |
| 共享内存 | 自动 | 手动 |
| 合并 | 自动 | 手动 |
| Warp级原语 | 有限 | 完全(shuffle、vote等) |
| 硬件支持 | 仅NVIDIA(AMD实验性) | 仅NVIDIA |
- 使用Triton用于:融合内核、自定义注意力模式、激活函数、大多数ML研究的内核需求。
- 使用CUDA C用于:最大性能(最后5-20%)、warp级原语、复杂数据依赖并行、当Triton无法表达你的模式时。
案例研究:Flash Attention¶
-
Flash Attention(Dao et al., 2022)是近期ML中影响最大的自定义内核。它以\(O(n)\)而非\(O(n^2)\)内存计算注意力,从而支持长得多的序列。
-
问题:标准注意力计算\(\text{softmax}(QK^T / \sqrt{d}) \cdot V\)。\(QK^T\)矩阵是\(n \times n\),其中\(n\)是序列长度。对于\(n = 128K\),此矩阵为\(128K \times 128K \times 4\)字节 = 64 GB。放不进GPU内存。
-
洞察:不需要物化完整的\(n \times n\)矩阵。以分块计算注意力:加载一块\(Q\)、一块\(K\),计算它们的部分注意力分数,累积,然后移动到下一块。\(n \times n\)矩阵从不被完全物化——SRAM中一次只存在一个分块。
-
在线softmax:tricky的部分是softmax,它需要知道整行的最大值(用于数值稳定性)。Flash Attention使用在线softmax技巧:维护一个运行最大值并在发现新的最大值时重新缩放先前计算的值。这允许增量计算softmax,一次一个分块。
-
算法:
对于Q行的每个block:
对于K列的每个block:
1. 从HBM加载Q_block到SRAM
2. 从HBM加载K_block到SRAM
3. 在SRAM中计算S_block = Q_block @ K_block.T
4. 更新运行最大值,重新缩放先前结果
5. 计算exp(S_block - running_max)
6. 更新运行总和和输出累加器
加载V_block并计算最终输出
将输出block写回HBM
-
为什么快:内部循环完全在SRAM(共享内存)中运行。全局内存(HBM)仅被访问以加载Q、K、V的block和写入最终输出。数据重用因子与SRAM大小成正比,而访问SRAM比HBM快约100倍。
-
Flash Attention在Triton和CUDA C中都有实现。CUDA版本更快(效率高约10%),但Triton版本更可读且可修改,这对研究新的注意力变体很重要。
TPU架构¶
-
TPU(Tensor Processing Units)是Google的定制ML加速器。它们采用与GPU根本不同的方法:
-
脉动阵列:TPU的核心计算单元是矩阵乘法单元(MXU),一个128×128或256×256脉动阵列,通过让数据流过一个乘加单元网格来计算矩阵乘法。数据从边缘进入并在阵列中传播,每个单元执行一次乘加并将结果传递给下一个。
-
与GPU(调度数千独立线程)不同,脉动阵列是单个确定的數據流。没有线程调度、没有warp发散、没有分支预测。这种简洁性使MXU在矩阵乘法方面极为能效。
-
HBM:TPU使用与GPU相同的高带宽内存。TPU v5e每芯片有16 GB HBM2e;TPU v5p有95 GB HBM2e。
-
ICI(Inter-Chip Interconnect):TPU pod通过自定义高速网络连接数百个TPU。跨TPU pod的数据并行和模型并行(第6章)由JAX原生支持。
-
BFloat16:TPU率先使用bfloat16(第13章文件02)。BF16与float32具有相同的指数范围(防止训练期间溢出)但尾数精度更低。这种权衡对ML是理想的,因为梯度值跨越很宽的范围但不需要23位精度。
编程TPU:JAX和Pallas¶
- TPU通过JAX和XLA编程。你写Python/JAX代码,
jax.jit将其编译到XLA HLO,XLA将HLO编译为TPU专用指令。无需CUDA,无需C++。
import jax
import jax.numpy as jnp
@jax.jit
def matmul(a, b):
return jnp.dot(a, b)
# 这根据设备在CPU、GPU或TPU上运行
a = jnp.ones((1024, 1024))
b = jnp.ones((1024, 1024))
c = matmul(a, b)
- Pallas是JAX的内核编写API——JAX的Triton等价物。它让你编写XLA为GPU或TPU编译的低级内核:
from jax.experimental import pallas as pl
import jax.numpy as jnp
def add_kernel(x_ref, y_ref, o_ref):
o_ref[...] = x_ref[...] + y_ref[...]
def add_pallas(x, y):
return pl.pallas_call(
add_kernel,
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
grid=(x.shape[0] // 128,),
in_specs=[pl.BlockSpec((128,), lambda i: (i,)),
pl.BlockSpec((128,), lambda i: (i,))],
out_specs=pl.BlockSpec((128,), lambda i: (i,)),
)(x, y)
- Pallas比Triton新且不够成熟,但它是编写TPU自定义内核的唯一方式(因为TPU不支持CUDA)。
GPU vs TPU¶
| GPU(NVIDIA) | TPU(Google) | |
|---|---|---|
| 可用性 | 任何云、本地部署 | 仅Google Cloud |
| 编程 | CUDA C、Triton、PyTorch | JAX/XLA、Pallas |
| 灵活性 | 通用计算 | 针对矩阵密集型ML优化 |
| 峰值矩阵乘法FLOPS | 非常高(Tensor Cores) | 非常高(MXU) |
| 非矩阵乘法操作 | 好 | 较慢(通过向量单元路由,非MXU) |
| 多芯片扩展 | NVLink(8 GPU)、InfiniBand | ICI(数千TPU,更紧密集成) |
| 成本效率 | 有竞争力 | 大规模训练通常更便宜 |
| 生态 | 最大(PyTorch、TensorFlow、JAX) | 以JAX为中心 |
- 使用GPU用于:大多数ML工作负载、基于PyTorch的研究、推理服务、有显著非矩阵乘法计算的工作负载。
- 使用TPU用于:大规模JAX训练(数千芯片)、Google Cloud上的成本敏感训练、以矩阵乘法为主的工作负载。
选择合适的工具¶
| 工作负载 | 最佳工具 | 原因 |
|---|---|---|
| ML训练(PyTorch) | NVIDIA GPU + CUDA/Triton | 最大生态、最好工具 |
| ML训练(JAX,大规模) | TPU或NVIDIA GPU | TPU在Google规模上成本更低,GPU灵活性高 |
| 自定义融合内核 | Triton(Python)或CUDA C | Triton用于开发速度,CUDA用于峰值性能 |
| JAX自定义内核 | Pallas | 用于TPU的唯一选项,在GPU上也可用 |
| 跨平台推理 | Vulkan(文件07)或ONNX Runtime | 在任何GPU厂商上运行 |
| 移动/边缘推理 | Metal(Apple)、Vulkan(Android)、NNAPI | 平台特定加速器 |
| 浏览器推理 | WebGPU(文件07) | 浏览器中唯一选项 |
| 仅CPU推理 | ONNX Runtime + AVX/NEON | 无需GPU,使用SIMD(文件02-03) |
| 新型硬件 | 厂商特定SDK | 每种加速器有自己的工具链 |
编程任务(使用带GPU运行时的CoLab)¶
-
为向量加法编写并运行Triton内核。将其性能与PyTorch内置加法比较。
import triton import triton.language as tl import torch import time @triton.jit def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): pid = tl.program_id(0) offs = pid * BLOCK + tl.arange(0, BLOCK) mask = offs < n x = tl.load(x_ptr + offs, mask=mask) y = tl.load(y_ptr + offs, mask=mask) tl.store(out_ptr + offs, x + y, mask=mask) n = 10_000_000 x = torch.randn(n, device='cuda') y = torch.randn(n, device='cuda') # Triton out_triton = torch.empty_like(x) grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),) add_kernel[grid](x, y, out_triton, n, BLOCK=1024) # PyTorch out_torch = x + y # 验证正确性 assert torch.allclose(out_triton, out_torch, atol=1e-5) # 基准测试 torch.cuda.synchronize() start = time.time() for _ in range(1000): add_kernel[grid](x, y, out_triton, n, BLOCK=1024) torch.cuda.synchronize() triton_time = (time.time() - start) / 1000 start = time.time() for _ in range(1000): out_torch = x + y torch.cuda.synchronize() torch_time = (time.time() - start) / 1000 print(f"Triton: {triton_time*1000:.3f} ms") print(f"PyTorch: {torch_time*1000:.3f} ms") print(f"比率: {torch_time/triton_time:.2f}x") -
编写一个Triton融合内核,在单次遍历中做multiply + add + ReLU。与三个单独的PyTorch操作比较。
import triton import triton.language as tl import torch import time @triton.jit def fused_mul_add_relu_kernel(x_ptr, w_ptr, b_ptr, out_ptr, n, BLOCK: tl.constexpr): pid = tl.program_id(0) offs = pid * BLOCK + tl.arange(0, BLOCK) mask = offs < n x = tl.load(x_ptr + offs, mask=mask) w = tl.load(w_ptr + offs, mask=mask) b = tl.load(b_ptr + offs, mask=mask) result = tl.maximum(x * w + b, 0.0) # 融合:mul + add + relu tl.store(out_ptr + offs, result, mask=mask) n = 10_000_000 x = torch.randn(n, device='cuda') w = torch.randn(n, device='cuda') b = torch.randn(n, device='cuda') # 融合(Triton) out_fused = torch.empty_like(x) grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),) fused_mul_add_relu_kernel[grid](x, w, b, out_fused, n, BLOCK=1024) # 未融合(PyTorch) out_unfused = torch.relu(x * w + b) assert torch.allclose(out_fused, out_unfused, atol=1e-5) # 基准测试 torch.cuda.synchronize() start = time.time() for _ in range(1000): fused_mul_add_relu_kernel[grid](x, w, b, out_fused, n, BLOCK=1024) torch.cuda.synchronize() fused_time = (time.time() - start) / 1000 start = time.time() for _ in range(1000): out_unfused = torch.relu(x * w + b) torch.cuda.synchronize() unfused_time = (time.time() - start) / 1000 print(f"融合(Triton): {fused_time*1000:.3f} ms") print(f"未融合(PyTorch): {unfused_time*1000:.3f} ms") print(f"加速比: {unfused_time/fused_time:.2f}x") -
测量JAX的XLA编译器如何自动融合操作。比较带jit和不带jit的操作链。
import jax import jax.numpy as jnp import time def chain_ops(x): x = x * 2.0 x = x + 1.0 x = jnp.maximum(x, 0.0) # ReLU x = x / jnp.sum(x) return x chain_jit = jax.jit(chain_ops) x = jax.random.normal(jax.random.PRNGKey(0), (10000, 1000)) # 预热 _ = chain_jit(x) jax.block_until_ready(_) # Eager(每个操作是单独的内核发射) start = time.time() for _ in range(100): y = chain_ops(x) jax.block_until_ready(y) eager_time = (time.time() - start) / 100 # JIT(XLA融合操作) start = time.time() for _ in range(100): y = chain_jit(x) jax.block_until_ready(y) jit_time = (time.time() - start) / 100 print(f"Eager: {eager_time*1000:.2f} ms") print(f"JIT: {jit_time*1000:.2f} ms") print(f"加速比: {eager_time/jit_time:.1f}x(XLA将4个操作融合为1个内核)")