Skip to content

Triton和TPU

CUDA C强大但冗长。Triton允许你用Python写GPU内核。TPU提供了与GPU不同的权衡。本文件涵盖Triton内核编程、Flash Attention案例研究、TPU架构与JAX/Pallas,以及如何选择正确的工具。关于Vulkan和跨平台GPU计算,参见文件07。

  • 前一文件教了CUDA C中的GPU编程。本文件攀登抽象阶梯:Triton用Python以20%的精力给予80%的CUDA性能。TPU和Vulkan为特定用例提供替代硬件目标。

Triton:Python中的GPU内核

  • Triton(OpenAI)是一种用于编写GPU内核的基于Python的语言。你不是像CUDA那样考虑单个线程,而是考虑数据。Triton编译器自动处理线程映射、内存合并、共享内存管理和许多优化。

  • 为什么Triton重要:CUDA C需要深入了解warp调度、共享内存bank冲突、寄存器压力和合并模式。Triton抽象了其中大部分,使GPU内核开发对了解Python但不懂系统编程的ML研究者变得可及。

你的第一个Triton内核

import triton
import triton.language as tl
import torch

@triton.jit
def add_kernel(
    x_ptr, y_ptr, output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,  # 编译时常量
):
    # 每个程序实例处理一个BLOCK_SIZE个元素的块
    pid = tl.program_id(axis=0)  # 我是哪个block?
    block_start = pid * BLOCK_SIZE

    # 此块的偏移量
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # 处理n_elements不是BLOCK_SIZE倍数情况的掩码
    mask = offsets < n_elements

    # 加载数据(掩码:越界读取返回0)
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    # 计算
    output = x + y

    # 存储结果
    tl.store(output_ptr + offsets, output, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    n_elements = output.numel()

    # 发射:每block一个程序
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)

    return output


# 使用
x = torch.randn(1000000, device='cuda')
y = torch.randn(1000000, device='cuda')
z = add(x, y)
  • 与CUDA的关键差异
    • 无显式线程管理。你按(程序)思考,而非线程。
    • tl.arange(0, BLOCK_SIZE)为整个block创建偏移向量。对此向量的所有操作隐式向量化。
    • mask处理边界条件(如AVX-512掩码寄存器,文件03)。不需要标量清理循环。
    • tl.loadtl.store自动处理合并访问。
    • @triton.jit在首次调用时将函数编译为PTX(GPU汇编),然后缓存编译后的内核。

Triton Softmax内核

  • Softmax是一个很好的Triton示例,因为它需要多次遍历数据(max、subtract、exp、sum、divide),并受益于在遍历间将数据保持在SRAM(共享内存)中:
@triton.jit
def softmax_kernel(
    output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    # 每个程序处理一行
    row_idx = tl.program_id(0)
    row_start = input_ptr + row_idx * input_row_stride

    # 加载该行
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    row = tl.load(row_start + col_offsets, mask=mask, other=-float('inf'))

    # Softmax:max数值稳定,然后exp,再归一化
    row_max = tl.max(row, axis=0)
    numerator = tl.exp(row - row_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator

    # 存储结果
    output_start = output_ptr + row_idx * output_row_stride
    tl.store(output_start + col_offsets, softmax_output, mask=mask)
  • 在PyTorch中,F.softmax(x, dim=-1)发射3个独立内核(max、exp-and-sum、divide),每个从全局内存读写。Triton版本在一个内核中完成一切,将数据保持在寄存器/SRAM中。这种内核融合是自定义Triton内核可以比PyTorch内置操作快2-4倍的原因。

Triton自动调优

  • Triton支持自动调优:尝试多种配置并选择最快的:
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}),
    ],
    key=['M', 'N', 'K'],  # 当这些变化时重新调优
)
@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, ...):
    ...
  • Triton在实际硬件上对每种配置进行基准测试并选择最快的。最优分块大小取决于GPU架构、矩阵维度和内存布局——自动调优无需手动实验即可找到它们。

Triton vs CUDA:何时使用

Triton CUDA C
语言 Python C/C++
抽象 块级 线程级
开发速度 快(每内核10-50行) 慢(100-500行)
性能上限 手工调优CUDA的~80-95% 100%(完全硬件控制)
共享内存 自动 手动
合并 自动 手动
Warp级原语 有限 完全(shuffle、vote等)
硬件支持 仅NVIDIA(AMD实验性) 仅NVIDIA
  • 使用Triton用于:融合内核、自定义注意力模式、激活函数、大多数ML研究的内核需求。
  • 使用CUDA C用于:最大性能(最后5-20%)、warp级原语、复杂数据依赖并行、当Triton无法表达你的模式时。

案例研究:Flash Attention

  • Flash Attention(Dao et al., 2022)是近期ML中影响最大的自定义内核。它以\(O(n)\)而非\(O(n^2)\)内存计算注意力,从而支持长得多的序列。

  • 问题:标准注意力计算\(\text{softmax}(QK^T / \sqrt{d}) \cdot V\)\(QK^T\)矩阵是\(n \times n\),其中\(n\)是序列长度。对于\(n = 128K\),此矩阵为\(128K \times 128K \times 4\)字节 = 64 GB。放不进GPU内存。

  • 洞察:不需要物化完整的\(n \times n\)矩阵。以分块计算注意力:加载一块\(Q\)、一块\(K\),计算它们的部分注意力分数,累积,然后移动到下一块。\(n \times n\)矩阵从不被完全物化——SRAM中一次只存在一个分块。

  • 在线softmax:tricky的部分是softmax,它需要知道整行的最大值(用于数值稳定性)。Flash Attention使用在线softmax技巧:维护一个运行最大值并在发现新的最大值时重新缩放先前计算的值。这允许增量计算softmax,一次一个分块。

  • 算法:

对于Q行的每个block:
    对于K列的每个block:
        1. 从HBM加载Q_block到SRAM
        2. 从HBM加载K_block到SRAM
        3. 在SRAM中计算S_block = Q_block @ K_block.T
        4. 更新运行最大值,重新缩放先前结果
        5. 计算exp(S_block - running_max)
        6. 更新运行总和和输出累加器
    加载V_block并计算最终输出
    将输出block写回HBM
  • 为什么快:内部循环完全在SRAM(共享内存)中运行。全局内存(HBM)仅被访问以加载Q、K、V的block和写入最终输出。数据重用因子与SRAM大小成正比,而访问SRAM比HBM快约100倍。

  • Flash Attention在Triton和CUDA C中都有实现。CUDA版本更快(效率高约10%),但Triton版本更可读且可修改,这对研究新的注意力变体很重要。

TPU架构

  • TPU(Tensor Processing Units)是Google的定制ML加速器。它们采用与GPU根本不同的方法:

  • 脉动阵列:TPU的核心计算单元是矩阵乘法单元(MXU),一个128×128或256×256脉动阵列,通过让数据流过一个乘加单元网格来计算矩阵乘法。数据从边缘进入并在阵列中传播,每个单元执行一次乘加并将结果传递给下一个。

  • 与GPU(调度数千独立线程)不同,脉动阵列是单个确定的數據流。没有线程调度、没有warp发散、没有分支预测。这种简洁性使MXU在矩阵乘法方面极为能效。

  • HBM:TPU使用与GPU相同的高带宽内存。TPU v5e每芯片有16 GB HBM2e;TPU v5p有95 GB HBM2e。

  • ICI(Inter-Chip Interconnect):TPU pod通过自定义高速网络连接数百个TPU。跨TPU pod的数据并行和模型并行(第6章)由JAX原生支持。

  • BFloat16:TPU率先使用bfloat16(第13章文件02)。BF16与float32具有相同的指数范围(防止训练期间溢出)但尾数精度更低。这种权衡对ML是理想的,因为梯度值跨越很宽的范围但不需要23位精度。

编程TPU:JAX和Pallas

  • TPU通过JAXXLA编程。你写Python/JAX代码,jax.jit将其编译到XLA HLO,XLA将HLO编译为TPU专用指令。无需CUDA,无需C++。
import jax
import jax.numpy as jnp

@jax.jit
def matmul(a, b):
    return jnp.dot(a, b)

# 这根据设备在CPU、GPU或TPU上运行
a = jnp.ones((1024, 1024))
b = jnp.ones((1024, 1024))
c = matmul(a, b)
  • Pallas是JAX的内核编写API——JAX的Triton等价物。它让你编写XLA为GPU或TPU编译的低级内核:
from jax.experimental import pallas as pl
import jax.numpy as jnp

def add_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] + y_ref[...]

def add_pallas(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(x.shape[0] // 128,),
        in_specs=[pl.BlockSpec((128,), lambda i: (i,)),
                  pl.BlockSpec((128,), lambda i: (i,))],
        out_specs=pl.BlockSpec((128,), lambda i: (i,)),
    )(x, y)
  • Pallas比Triton新且不够成熟,但它是编写TPU自定义内核的唯一方式(因为TPU不支持CUDA)。

GPU vs TPU

GPU(NVIDIA) TPU(Google)
可用性 任何云、本地部署 仅Google Cloud
编程 CUDA C、Triton、PyTorch JAX/XLA、Pallas
灵活性 通用计算 针对矩阵密集型ML优化
峰值矩阵乘法FLOPS 非常高(Tensor Cores) 非常高(MXU)
非矩阵乘法操作 较慢(通过向量单元路由,非MXU)
多芯片扩展 NVLink(8 GPU)、InfiniBand ICI(数千TPU,更紧密集成)
成本效率 有竞争力 大规模训练通常更便宜
生态 最大(PyTorch、TensorFlow、JAX) 以JAX为中心
  • 使用GPU用于:大多数ML工作负载、基于PyTorch的研究、推理服务、有显著非矩阵乘法计算的工作负载。
  • 使用TPU用于:大规模JAX训练(数千芯片)、Google Cloud上的成本敏感训练、以矩阵乘法为主的工作负载。

选择合适的工具

工作负载 最佳工具 原因
ML训练(PyTorch) NVIDIA GPU + CUDA/Triton 最大生态、最好工具
ML训练(JAX,大规模) TPU或NVIDIA GPU TPU在Google规模上成本更低,GPU灵活性高
自定义融合内核 Triton(Python)或CUDA C Triton用于开发速度,CUDA用于峰值性能
JAX自定义内核 Pallas 用于TPU的唯一选项,在GPU上也可用
跨平台推理 Vulkan(文件07)或ONNX Runtime 在任何GPU厂商上运行
移动/边缘推理 Metal(Apple)、Vulkan(Android)、NNAPI 平台特定加速器
浏览器推理 WebGPU(文件07) 浏览器中唯一选项
仅CPU推理 ONNX Runtime + AVX/NEON 无需GPU,使用SIMD(文件02-03)
新型硬件 厂商特定SDK 每种加速器有自己的工具链

编程任务(使用带GPU运行时的CoLab)

  1. 为向量加法编写并运行Triton内核。将其性能与PyTorch内置加法比较。

    import triton
    import triton.language as tl
    import torch
    import time
    
    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        x = tl.load(x_ptr + offs, mask=mask)
        y = tl.load(y_ptr + offs, mask=mask)
        tl.store(out_ptr + offs, x + y, mask=mask)
    
    n = 10_000_000
    x = torch.randn(n, device='cuda')
    y = torch.randn(n, device='cuda')
    
    # Triton
    out_triton = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),)
    add_kernel[grid](x, y, out_triton, n, BLOCK=1024)
    
    # PyTorch
    out_torch = x + y
    
    # 验证正确性
    assert torch.allclose(out_triton, out_torch, atol=1e-5)
    
    # 基准测试
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(1000):
        add_kernel[grid](x, y, out_triton, n, BLOCK=1024)
    torch.cuda.synchronize()
    triton_time = (time.time() - start) / 1000
    
    start = time.time()
    for _ in range(1000):
        out_torch = x + y
    torch.cuda.synchronize()
    torch_time = (time.time() - start) / 1000
    
    print(f"Triton:  {triton_time*1000:.3f} ms")
    print(f"PyTorch: {torch_time*1000:.3f} ms")
    print(f"比率:   {torch_time/triton_time:.2f}x")
    

  2. 编写一个Triton融合内核,在单次遍历中做multiply + add + ReLU。与三个单独的PyTorch操作比较。

    import triton
    import triton.language as tl
    import torch
    import time
    
    @triton.jit
    def fused_mul_add_relu_kernel(x_ptr, w_ptr, b_ptr, out_ptr, n, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        x = tl.load(x_ptr + offs, mask=mask)
        w = tl.load(w_ptr + offs, mask=mask)
        b = tl.load(b_ptr + offs, mask=mask)
        result = tl.maximum(x * w + b, 0.0)  # 融合:mul + add + relu
        tl.store(out_ptr + offs, result, mask=mask)
    
    n = 10_000_000
    x = torch.randn(n, device='cuda')
    w = torch.randn(n, device='cuda')
    b = torch.randn(n, device='cuda')
    
    # 融合(Triton)
    out_fused = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),)
    fused_mul_add_relu_kernel[grid](x, w, b, out_fused, n, BLOCK=1024)
    
    # 未融合(PyTorch)
    out_unfused = torch.relu(x * w + b)
    
    assert torch.allclose(out_fused, out_unfused, atol=1e-5)
    
    # 基准测试
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(1000):
        fused_mul_add_relu_kernel[grid](x, w, b, out_fused, n, BLOCK=1024)
    torch.cuda.synchronize()
    fused_time = (time.time() - start) / 1000
    
    start = time.time()
    for _ in range(1000):
        out_unfused = torch.relu(x * w + b)
    torch.cuda.synchronize()
    unfused_time = (time.time() - start) / 1000
    
    print(f"融合(Triton):    {fused_time*1000:.3f} ms")
    print(f"未融合(PyTorch): {unfused_time*1000:.3f} ms")
    print(f"加速比:           {unfused_time/fused_time:.2f}x")
    

  3. 测量JAX的XLA编译器如何自动融合操作。比较带jit和不带jit的操作链。

    import jax
    import jax.numpy as jnp
    import time
    
    def chain_ops(x):
        x = x * 2.0
        x = x + 1.0
        x = jnp.maximum(x, 0.0)  # ReLU
        x = x / jnp.sum(x)
        return x
    
    chain_jit = jax.jit(chain_ops)
    x = jax.random.normal(jax.random.PRNGKey(0), (10000, 1000))
    
    # 预热
    _ = chain_jit(x)
    jax.block_until_ready(_)
    
    # Eager(每个操作是单独的内核发射)
    start = time.time()
    for _ in range(100):
        y = chain_ops(x)
    jax.block_until_ready(y)
    eager_time = (time.time() - start) / 100
    
    # JIT(XLA融合操作)
    start = time.time()
    for _ in range(100):
        y = chain_jit(x)
    jax.block_until_ready(y)
    jit_time = (time.time() - start) / 100
    
    print(f"Eager: {eager_time*1000:.2f} ms")
    print(f"JIT:   {jit_time*1000:.2f} ms")
    print(f"加速比: {eager_time/jit_time:.1f}x(XLA将4个操作融合为1个内核)")