Skip to content

测试与质量保证

测试是确保代码能正常工作的手段——不仅是现在,而是在每次变更之后。本文涵盖测试金字塔、pytest 单元测试、模拟、ML 特定代码测试、CI/CD 流水线、代码检查、格式化以及代码审查,这些实践能在 bug 进入生产环境之前就发现它们。

  • 机器学习代码的测试不足是出了名的。“能训练,所以它是对的”是普遍心态。这导致了隐性的 bug:数据加载器打乱顺序出错,损失函数符号写反,预处理步骤悄悄丢弃了 5% 的数据。这些 bug 不会让程序崩溃。它们只会让你的模型悄无声息地变差,而你却白白浪费数周时间去调试那些“应该更高”的指标。

  • 测试不是负担。它是在不破坏东西的前提下快速前进的最快方式。

测试金字塔

  • 测试按层次组织,从快速且范围窄到慢速且范围广:

    • 单元测试(底座):独立测试单个函数和类。快(毫秒级),数量多(成百上千个)。比如:“normalise_image 产生的值是否在 [0, 1] 内?”

    • 集成测试(中间层):测试组件之间是否能协同工作。较慢(秒级)。比如:“数据加载器产生的 batch 格式是否符合模型的预期?”

    • 端到端测试(顶层):测试从输入到输出的完整流水线。慢(分钟级)。比如:“python train.py --config test.yaml 是否能无错误地完成并生成一个有效的 checkpoint?”

  • 金字塔形状意味着:多写单元测试,少写集成测试,只写少量端到端测试。单元测试能捕捉大多数 bug,并且在几秒钟内运行完成。端到端测试能发现集成问题,但运行缓慢且脆弱。

使用 pytest 进行单元测试

  • pytest 是 Python 的标准测试框架。一个测试是以 test_ 开头的函数,放在以 test_ 开头的文件中:
# tests/test_utils.py

def test_normalise_image():
    import numpy as np
    image = np.array([0, 128, 255], dtype=np.uint8)
    result = normalise_image(image, mean=128, std=128)
    assert result.min() >= -1.0
    assert result.max() <= 1.0
    assert abs(result[1]) < 1e-6  # 128 按 mean=128 归一化后应约等于 0

def test_normalise_empty():
    import numpy as np
    image = np.array([], dtype=np.uint8)
    result = normalise_image(image, mean=128, std=128)
    assert len(result) == 0
pytest tests/                     # 运行所有测试
pytest tests/test_utils.py        # 运行单个文件
pytest -v                         # 详细输出
pytest -x                         # 遇到第一个失败就停止
pytest -k "normalise"             # 运行名称匹配模式的测试
pytest --tb=short                 # 更短的回溯信息

Fixture

  • Fixture 为测试提供可重用的设置代码。不用在每个测试中重复设置代码,而是定义一次:
import pytest

@pytest.fixture
def sample_dataset():
    """创建一个小型数据集用于测试。"""
    return {
        "inputs": torch.randn(10, 3, 32, 32),
        "labels": torch.randint(0, 10, (10,))
    }

@pytest.fixture
def trained_model():
    """加载一个小型预训练模型。"""
    model = SmallModel()
    model.load_state_dict(torch.load("tests/fixtures/small_model.pt"))
    return model

def test_model_output_shape(trained_model, sample_dataset):
    output = trained_model(sample_dataset["inputs"])
    assert output.shape == (10, 10)  # batch_size x num_classes
  • Fixture 可以有作用域scope="function"(默认,每个测试独立),scope="module"(每个文件一次),scope="session"(整个测试运行一次)。对于像加载模型这样开销大的设置,使用 scope="session"

参数化测试

  • 用多组输入测试同一个函数,无需重复代码:
@pytest.mark.parametrize("input,expected", [
    ([1, 2, 3], 6),
    ([], 0),
    ([-1, 1], 0),
    ([1000000, 1000000], 2000000),
])
def test_sum(input, expected):
    assert sum(input) == expected

模拟与打补丁

  • 模拟(Mocking) 在测试期间用假的依赖替换真实的依赖。这让你可以独立测试一个函数,无需数据库、API 或 GPU。
from unittest.mock import patch, MagicMock

def test_training_logs_metrics():
    mock_logger = MagicMock()

    with patch("my_project.training.trainer.wandb") as mock_wandb:
        trainer = Trainer(logger=mock_logger)
        trainer.train_one_epoch()

        # 验证训练器记录了指标
        mock_logger.log.assert_called()
        # 验证它记录了 loss 值
        call_args = mock_logger.log.call_args
        assert "loss" in call_args[1]
  • 何时模拟:外部服务(API、数据库、云存储),开销大的操作(GPU 计算、大文件 I/O),以及非确定性行为(随机数生成器、时间戳)。

  • 何时不模拟:你自己的代码。如果你模拟了一切,你的测试验证的是模拟对象的行为符合预期,而不是你的代码真的能工作。在边界处模拟,直接测试你自己的逻辑。

测试 ML 代码

  • ML 代码有独特的测试挑战:输出是概率性的,训练很慢,“正确”的定义是模糊的。

确定性种子

  • 在所有地方设置随机种子,使测试可复现:
import random
import numpy as np
import torch

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

数值容差

  • 浮点数比较需要容差(见第13章,IEEE 754):
# 糟糕:精确比较会因浮点数而失败
assert model_output == 0.5

# 良好:近似比较
import numpy as np
assert np.isclose(model_output, 0.5, atol=1e-5)

# 针对张量
assert torch.allclose(output, expected, atol=1e-4)

ML 代码中应该测试什么

  • 形状测试:验证输出具有预期的维度。
def test_model_output_shape():
    model = MyModel(d_model=256, n_classes=10)
    x = torch.randn(8, 32, 256)  # batch=8, seq=32, dim=256
    output = model(x)
    assert output.shape == (8, 10)
  • 梯度流:验证可训练参数的梯度非零。
def test_gradients_flow():
    model = MyModel()
    x = torch.randn(4, 3, 32, 32)
    y = torch.randint(0, 10, (4,))

    output = model(x)
    loss = F.cross_entropy(output, y)
    loss.backward()

    for name, param in model.named_parameters():
        assert param.grad is not None, f"{name} 没有梯度"
        assert param.grad.abs().sum() > 0, f"{name} 的梯度为零"
  • 过拟合一个 batch:模型应该能够记住单个 batch。如果不能,说明有根本性问题。
def test_overfit_one_batch():
    model = MyModel()
    optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
    x, y = get_single_batch()

    for _ in range(100):
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        optimiser.step()
        optimiser.zero_grad()

    assert loss.item() < 0.01, f"无法过拟合一个 batch: loss={loss.item()}"
  • 数据验证:验证数据加载产生有效的输出。
def test_dataset_basics():
    dataset = MyDataset("tests/fixtures/small_data.csv")
    assert len(dataset) > 0
    x, y = dataset[0]
    assert x.shape == (3, 224, 224)
    assert 0 <= y < 10
    assert not torch.isnan(x).any()
    assert not torch.isinf(x).any()
  • 确定性:相同输入 + 相同种子 → 相同输出。
def test_determinism():
    set_seed(42)
    output1 = model(input_data)
    set_seed(42)
    output2 = model(input_data)
    assert torch.allclose(output1, output2)

CI/CD 流水线

  • 持续集成(CI):在每次提交或 PR 时自动运行测试。如果测试失败,PR 就不能合并。这能防止损坏的代码进入 main 分支。

  • GitHub Actions 示例.github/workflows/ci.yml):

name: CI
on: [push, pull_request]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - run: pip install -e ".[dev]"
      - run: ruff check src/
      - run: mypy src/
      - run: pytest tests/ -v --tb=short
  • Pre-commit 钩子:在每个提交之前(本地)运行检查,在代码进入 CI 之前就发现问题:
# .pre-commit-config.yaml
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.3.0
    hooks:
      - id: ruff
        args: [--fix]
      - id: ruff-format
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
pip install pre-commit
pre-commit install    # 现在每次 git commit 时都会运行钩子

代码检查与格式化

  • 代码检查(Linting) 可以在不运行代码的情况下发现 bug 和风格问题。格式化(Formatting) 自动强制执行一致的风格。

  • Ruff:一个快速的 Python 代码检查器和格式化工具(在一个工具中替代了 flake8、isort 和 black):

ruff check src/          # 代码检查
ruff check --fix src/    # 代码检查并自动修复
ruff format src/         # 格式化
  • mypy:Python 静态类型检查器。在运行之前就能捕捉类型错误:
mypy src/
# src/model.py:42: error: Argument 1 to "forward" has incompatible type "int"; expected "Tensor"
  • 类型注解使代码自文档化,并能捕捉 bug:
def train(
    model: nn.Module,
    dataloader: DataLoader,
    optimiser: torch.optim.Optimizer,
    num_epochs: int = 10,
) -> float:
    """训练模型并返回最终 loss。"""
    ...

代码审查最佳实践

  • 对于作者

    • 在请求审查之前,先自己 review 自己的 diff。你会发现一些明显的问题。
    • 保持 PR 小而聚焦。每个 PR 只解决一个关注点。
    • 写清晰的描述:改了什么,为什么改,如何测试。
    • 回应每一条评论(即使只是“已处理”)。
  • 对于审查者

    • 保持友善。批评代码,而不是人。说“这里可以更清晰”而不是“这令人困惑”。
    • 区分阻塞性问题(bug、安全)和建议(风格、命名)。使用标签:“nit:”、“suggestion:”、“blocking:”。
    • 用提问代替命令。“如果这个列表是空的会怎样?”比“处理空列表情况”更有帮助。
    • 及时批准。一个等待数天才能被审查的 PR 会阻塞作者,并且鼓励大批量的 PR(而这种 PR 更难审查)。