Skip to content

代码库设计与模式

良好的代码库设计是区分研究原型和生产级软件的关键。本文涵盖项目结构、清洁代码原则、与机器学习相关的设计模式、配置管理、日志记录、API 设计以及打包。

  • 大多数机器学习代码最初都是一个 Jupyter notebook。随着 notebook 不断增长、被复制、修改、分享,最终它会变成一团无法维护的乱麻:全局变量、死掉的单元格和魔法数字四处散落。代码库设计就是一套组织代码的纪律,让项目在成长过程中始终保持可理解和可修改。

  • 这不是为了遵守规则而遵守规则。它的目标是缩短从“我想修改 X”到“X 被修改成功并能正常工作”之间的时间。在精心设计的代码库中,这段时间是以分钟计算的;而在设计糟糕的代码库中,这可能需要几天时间在无文档的意大利面式代码中考古。

项目结构

  • 一致的项目布局能让任何人(包括未来的你自己)快速浏览代码库。
my_project/
├── src/my_project/       # 源代码(可导入的包)
│   ├── __init__.py
│   ├── data/             # 数据加载与预处理
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   └── transforms.py
│   ├── models/           # 模型架构
│   │   ├── __init__.py
│   │   ├── transformer.py
│   │   └── layers.py
│   ├── training/         # 训练循环、优化器
│   │   ├── __init__.py
│   │   ├── trainer.py
│   │   └── losses.py
│   └── utils/            # 共享工具
│       ├── __init__.py
│       └── logging.py
├── configs/              # 配置文件
│   ├── base.yaml
│   └── experiment_1.yaml
├── scripts/              # 入口脚本(训练、评估、服务)
│   ├── train.py
│   ├── evaluate.py
│   └── serve.py
├── tests/                # 测试文件(镜像 src/ 的结构)
│   ├── test_dataset.py
│   ├── test_model.py
│   └── test_trainer.py
├── notebooks/            # 仅用于探索(非生产代码)
├── pyproject.toml        # 项目元数据和依赖
├── README.md
├── .gitignore
└── Dockerfile
  • src/ 布局:将源代码放在 src/my_project/ 下可以防止意外从当前目录导入(这能掩盖在生产环境中才会暴露的导入错误)。开发时使用 pip install -e . 安装。

  • 单仓库 vs 多仓库单仓库(monorepo) 将所有相关项目放在一个代码仓库中(跨项目修改更容易,CI 共享)。多仓库(multi-repo) 为每个项目提供独立的仓库(边界更清晰,版本独立)。大多数机器学习团队从单仓库开始,需要时再拆分。

  • 脚本 vs 库:将入口点(train.pyevaluate.py)放在 scripts/ 中。将可复用逻辑放在 src/ 中。一个训练脚本应该大约 50 行:解析配置、构建数据集、构建模型、构建训练器、开始训练。所有复杂性都存在于库中。

清洁代码原则

  • 命名:这是你能做的最有影响力的一件事。变量名为 x 迫使你阅读周围代码来理解它;变量名为 learning_rate 则是自文档的。
# 糟糕
def proc(d, n, lr):
    for i in range(n):
        for k, v in d.items():
            v -= lr * g[k]

# 良好
def update_parameters(parameters, num_steps, learning_rate):
    for step in range(num_steps):
        for name, param in parameters.items():
            param -= learning_rate * gradients[name]
  • 单一职责原则:每个函数/类只做一件事。名为 load_data_and_train_model 的函数同时做了两件事,应该拆分。这样每个部分都可以独立测试、复用和理解。

  • DRY(不要重复自己)——但不要过早抽象。如果你复制粘贴了三次代码,就把它提取成一个函数。但对于只用过一次的代码,不要创建抽象。过早的抽象比重复更糟糕:它在没有经过验证的好处下增加了复杂性。

# 过早的抽象(只有一个用例,过度设计)
class AbstractDataTransformPipelineFactory:
    ...

# 恰到好处(直接、清晰,且在三处使用)
def normalise_image(image, mean, std):
    return (image - mean) / std
  • 魔法数字:永远不要使用未经解释的字面常量。
# 糟糕
if len(batch) > 32:
    split_batch(batch, 32)

# 良好
MAX_BATCH_SIZE = 32
if len(batch) > MAX_BATCH_SIZE:
    split_batch(batch, MAX_BATCH_SIZE)
  • 函数应该短小:如果一个函数不能在一屏(约 30 行)内显示完整,那么它可能做了太多事情。将逻辑块提取为具有描述性名称的辅助函数。这样函数体本身读起来就像一个高层摘要。

机器学习的设计模式

  • 设计模式是针对常见问题的可复用解决方案。以下是与机器学习代码库最相关的几种模式:

  • 工厂模式:在不指定具体类的情况下创建对象。当你的配置中说 model: "transformer" 而你需要实例化正确的类时,这非常有用:

MODEL_REGISTRY = {
    "transformer": TransformerModel,
    "cnn": CNNModel,
    "mlp": MLPModel,
}

def build_model(config):
    model_cls = MODEL_REGISTRY[config["model"]]
    return model_cls(**config["model_params"])
  • 这使训练脚本与具体的模型实现解耦。添加一个新模型只需在注册表中添加一行,而无需修改训练循环。

  • 策略模式:在运行时切换算法。对损失函数、优化器、学习率调度器很有用:

LOSS_FUNCTIONS = {
    "mse": nn.MSELoss,
    "cross_entropy": nn.CrossEntropyLoss,
    "focal": FocalLoss,
}

loss_fn = LOSS_FUNCTIONS[config["loss"]]()
  • 观察者模式(回调/钩子):让模块在无需紧耦合的情况下对事件做出响应。训练框架(PyTorch Lightning、Keras)广泛使用回调:
class EarlyStopping:
    def __init__(self, patience=5):
        self.patience = patience
        self.best_loss = float('inf')
        self.counter = 0

    def on_epoch_end(self, epoch, val_loss):
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                return "stop"
  • 依赖注入:将依赖传递给函数/类,而不是在内部创建它们。这使得测试变得容易(注入一个 mock 对象)并且配置灵活:
# 糟糕:硬编码依赖
class Trainer:
    def __init__(self):
        self.logger = WandbLogger()  # 离开 W&B 就无法测试

# 良好:依赖注入
class Trainer:
    def __init__(self, logger):
        self.logger = logger  # 可以注入任何 logger,包括 mock

配置管理

  • 硬编码超参数、文件路径和模型设置会使实验无法复现,修改也变得痛苦。将配置外部化到文件中。

  • YAML 是最常见的机器学习配置格式:

# configs/experiment_1.yaml
model:
  name: transformer
  d_model: 512
  n_heads: 8
  n_layers: 6

training:
  batch_size: 64
  learning_rate: 3e-4
  max_epochs: 100
  early_stopping_patience: 10

data:
  train_path: /data/train.parquet
  val_path: /data/val.parquet
  max_seq_length: 512
  • Hydra(Facebook 出品)是一个配置框架,支持组合(将基础配置与实验特定的覆盖合并)、命令行覆盖(python train.py training.lr=1e-3)和多运行(超参数搜索)。

  • argparse 对于只有少量参数的脚本更简单:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=3e-4)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--config", type=str, default="configs/base.yaml")
args = parser.parse_args()
  • 最佳实践:有一个包含所有默认值的基础配置,以及每个实验的配置,后者只覆盖需要改变的部分。将每个实验的配置与其结果一起跟踪。

日志与可观测性

  • print 语句用于调试。日志用于生产环境:
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

logger.debug("Batch loaded: %d samples", len(batch))     # 详细,用于调试
logger.info("Epoch %d: loss=%.4f, lr=%.6f", epoch, loss, lr)  # 正常运行
logger.warning("GPU memory >90%%, consider reducing batch size")
logger.error("Failed to load checkpoint: %s", path)       # 可恢复的错误
logger.critical("CUDA out of memory, aborting")            # 致命错误
  • 为什么不用 print:日志支持级别(生产环境过滤掉调试信息)、格式(时间戳、模块名)和处理程序(写入文件、发送到监控系统),而无需修改日志调用本身。

  • 结构化日志输出机器可解析的格式(JSON),同时保留人类可读的消息。这使得可以根据特定字段进行搜索和告警:

logger.info("training_step", extra={
    "epoch": 5, "step": 1200, "loss": 0.0342, "lr": 2.1e-4
})

API 设计

  • 如果你的模型将被其他服务(Web 应用、移动应用、另一个机器学习流水线)使用,那么它需要一个 API(应用程序编程接口)。

  • REST API 使用 HTTP 方法:GET 读取,POST 创建/预测,PUT 更新,DELETE 删除。端点遵循基于资源的命名:

POST /api/v1/predict          # 发送输入,获取预测结果
GET  /api/v1/models           # 列出可用模型
GET  /api/v1/models/{id}      # 获取模型详情
POST /api/v1/models/{id}/predict  # 使用特定模型进行预测
  • FastAPI 是 Python 机器学习服务的主流框架:
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class PredictRequest(BaseModel):
    text: str

class PredictResponse(BaseModel):
    label: str
    confidence: float

@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    result = model.predict(request.text)
    return PredictResponse(label=result.label, confidence=result.score)
  • FastAPI 自动生成 API 文档(Swagger UI 位于 /docs),使用 Pydantic 模型验证输入/输出,并支持异步以获取高吞吐量。

  • gRPC 在内部服务间通信时比 REST 更快。它使用 Protocol Buffers(二进制序列化,比 JSON 更小更快)并支持流式传输。TensorFlow Serving、Triton Inference Server 和许多微服务架构都在使用它。

打包与分发

  • 将代码打包成可安装的包,可以让其他人(以及你自己的脚本)清晰地导入它:
# pyproject.toml
[project]
name = "my-ml-project"
version = "0.1.0"
requires-python = ">=3.10"
dependencies = [
    "torch>=2.0",
    "jax>=0.4",
    "pydantic>=2.0",
]

[project.optional-dependencies]
dev = ["pytest", "ruff", "mypy"]

[build-system]
requires = ["setuptools>=64"]
build-backend = "setuptools.backends._legacy:_Backend"
pip install -e ".[dev]"    # 以可编辑模式安装,包含开发依赖
  • 可编辑安装-e):对源代码的更改会立即生效,无需重新安装。开发过程中必不可少。

  • 锁定依赖:使用精确版本号的 requirements.txttorch==2.2.1,而不是 torch>=2.0)可以确保可复现性。使用 pip freeze > requirements.txt 来捕获当前环境。对于更复杂的依赖管理,可以使用 uvpoetrypip-tools

与 AI 编程助手协作

  • AI 编程助手(Claude Code、GitHub Copilot、Cursor 等)现在已经成为专业工程工作流的一部分。使用得当,它们能极大加速开发;使用不当,则会引入微妙的错误、侵蚀你对自身代码库的理解,并制造虚假的生产力感。

  • 正确的心智模型是:AI 助手是一个快速但经验不足的结对编程伙伴。它能快速编写代码,熟悉语法和标准模式,阅读过的文档比你多得多。但它不理解你的具体系统、你的业务约束、你的边界情况,以及你设计决策背后的原因。你是资深工程师,助手是初级工程师。你负责指导、审查并承担责任。

AI 助手的优势场景

  • 样板代码和脚手架:生成 Dockerfile、CI 配置、测试夹具、数据类定义、argparse 设置等。这些遵循众所周知的模式,手写起来很繁琐。让助手生成,然后检查正确性。

  • 编写测试:描述函数的行为,助手生成测试用例。它经常能捕捉到你可能会遗漏的边界情况(空输入、负值、Unicode 字符)。一定要阅读生成的测试——它们验证的是你的假设,而不仅仅是你的代码。

  • 重构:“把这个块提取成一个函数”,“将这个类改为使用 dataclass”,“给这个模块添加类型注解”。这些是机械性的转换,意图清晰,引入细微错误的风险较低。

  • 探索和原型设计:“写一个快速脚本,用于基准测试推理延迟”或“向我展示如何使用 HuggingFace 的 tokenizer API”。相比于阅读文档,助手能更快地让你获得一个可运行的起点。

  • 文档和文档字符串:助手可以根据你的代码结构生成文档。检查其准确性,但繁琐的工作已经被自动化了。

  • 调试辅助:粘贴一个错误回溯,询问诊断意见。助手通常能找出根本原因并建议修复方法,尤其是对于常见问题(形状不匹配、导入错误、CUDA 内存不足)。

何时不应依赖 AI 助手

  • 新颖的架构决策:如果你正在设计一个新的训练流水线,助手会给你一个通用的答案。它不知道你的数据约束、延迟要求或团队的专业知识。使用助手来实现你已经考虑清楚的设计。

  • 安全关键代码:身份验证、加密、输入净化。助手生成的代码可能看起来正确,但存在微妙的漏洞(SQL 注入、不安全的默认值、时序攻击)。安全代码应该由理解威胁模型的人编写,并由另一个人审查。

  • 性能关键的内层循环:助手会写出正确但幼稚的代码。对于 GPU 内核、内存关键的数据结构或延迟敏感的服务路径,你需要理解硬件约束(第13章、第16章)并有意识地进行优化。

  • 你不理解的代码:如果助手生成了 200 行代码,而你无法解释每一行的作用,那就不要提交它。你现在将维护一段你不理解的代码,而当它出问题时(一定会出问题),你无法调试。这是最常见也最危险的失败模式。

审查纪律

  • 在提交之前,始终阅读生成代码的每一行。这不是可选项。助手的代码是草稿,而不是成品。像对待同事的拉取请求一样对待它:批判性地审查。

  • 需要检查的事项

    • 正确性:它是否真的做了你要求的事?助手经常解决一个与你意图略有不同的问题。
    • 边界情况:它是否处理了空输入、None 值、负数、极大输入?助手经常遗漏边界情况的处理。
    • 幻觉 API:助手可能调用了不存在的函数或使用了不存在的参数,尤其是对于较新或较不常见的库。验证每个 API 调用是否真实存在。
    • 过度工程:助手倾向于生成比所需更多的代码。一个 50 行的解决方案解决了一个 10 行的问题,增加了复杂性却没有好处。果断简化。
    • 安全性:硬编码的密钥、未净化的用户输入、不安全的默认值。助手不会以对抗性思维思考。
    • 风格一致性:生成的代码是否匹配你项目的约定(命名、模式、错误处理)?

如何编写好的提示词

  • AI 助手输出的质量与你指令的质量成正比。模糊的提示得到模糊的代码。

  • 不好:“写一个数据加载器”

  • :“写一个 PyTorch DataLoader,用于一个包含 'text' 和 'label' 列的 CSV 文件。使用 HuggingFace tokenizer 'bert-base-uncased' 对文本进行 tokenize,max_length=512。返回 input_ids、attention_mask 和作为张量的 label。处理 CSV 中 label 列缺失值时,跳过这些行。”

  • 提供上下文:告诉助手你的项目结构、现有代码、约束和约定。上下文越多,输出越好。

  • 指定约束:“只使用标准库”,“必须能在 Python 3.10 上工作”,“不要使用全局变量”,“遵循 src/models/transformer.py 中的现有模式。”

  • 要求解释:“实现 X,并解释关键的设计决策。”这迫使助手阐述其推理,使你更容易发现错误的假设。

使用质量门捕捉助手的错误

  • 你现有的质量基础设施(文件 04)捕捉助手的错误就像捕捉人类的错误一样有效:

    • 类型检查(mypy):捕获幻觉的 API 签名和类型不匹配。
    • 静态检查(ruff):捕获未使用的导入、未定义的变量和风格违规。
    • 测试(pytest):如果助手的代码通过了你的测试套件,它更有可能是正确的。如果你没有测试,在要求助手实现功能之前就写好它们(测试驱动开发在配合 AI 助手时尤其有效)。
    • CI 流水线:在每次提交时自动运行上述所有检查。
  • “助手写代码” + “质量门验证” 的组合比单独任何一项都更高效。助手快速但马虎;质量门全面但不写代码。两者结合,你既能获得速度又能获得正确性。

生产力陷阱

  • 使用 AI 助手的最大风险是虚假的生产力感。你可以在 10 分钟内生成 500 行代码。但如果因为你没有理解这 500 行代码而花了 2 小时调试它们,那其实比你花 30 分钟自己写 200 行代码还要慢。

  • 使用 AI 助手获得真正的生产力来自:

    1. 保持控制:你决定架构,助手填充实现。
    2. 理解生成的内容:如果你无法解释它,要么重写它,要么要求助手简化它。
    3. 投资质量门:测试、类型检查和静态检查的成本会在与 AI 助手的每次交互中分摊。
    4. 让助手弥补你的弱点:如果你擅长算法但不擅长写测试,让助手写测试。如果你擅长 UI 代码但不熟悉数据库查询,让助手起草 SQL。发挥你的优势,委派你的短板。
  • 从 AI 编程助手中获益最多的工程师,是那些已经擅长编程的人。AI 助手放大你现有的技能,而不是取代它。理解数据结构、算法、系统设计和软件工程(正是这一章的内容)才能让你有效地指导 AI 助手并批判性地评估其输出。