代码库设计与模式¶
良好的代码库设计是区分研究原型和生产级软件的关键。本文涵盖项目结构、清洁代码原则、与机器学习相关的设计模式、配置管理、日志记录、API 设计以及打包。
-
大多数机器学习代码最初都是一个 Jupyter notebook。随着 notebook 不断增长、被复制、修改、分享,最终它会变成一团无法维护的乱麻:全局变量、死掉的单元格和魔法数字四处散落。代码库设计就是一套组织代码的纪律,让项目在成长过程中始终保持可理解和可修改。
-
这不是为了遵守规则而遵守规则。它的目标是缩短从“我想修改 X”到“X 被修改成功并能正常工作”之间的时间。在精心设计的代码库中,这段时间是以分钟计算的;而在设计糟糕的代码库中,这可能需要几天时间在无文档的意大利面式代码中考古。
项目结构¶
- 一致的项目布局能让任何人(包括未来的你自己)快速浏览代码库。
my_project/
├── src/my_project/ # 源代码(可导入的包)
│ ├── __init__.py
│ ├── data/ # 数据加载与预处理
│ │ ├── __init__.py
│ │ ├── dataset.py
│ │ └── transforms.py
│ ├── models/ # 模型架构
│ │ ├── __init__.py
│ │ ├── transformer.py
│ │ └── layers.py
│ ├── training/ # 训练循环、优化器
│ │ ├── __init__.py
│ │ ├── trainer.py
│ │ └── losses.py
│ └── utils/ # 共享工具
│ ├── __init__.py
│ └── logging.py
├── configs/ # 配置文件
│ ├── base.yaml
│ └── experiment_1.yaml
├── scripts/ # 入口脚本(训练、评估、服务)
│ ├── train.py
│ ├── evaluate.py
│ └── serve.py
├── tests/ # 测试文件(镜像 src/ 的结构)
│ ├── test_dataset.py
│ ├── test_model.py
│ └── test_trainer.py
├── notebooks/ # 仅用于探索(非生产代码)
├── pyproject.toml # 项目元数据和依赖
├── README.md
├── .gitignore
└── Dockerfile
-
src/布局:将源代码放在src/my_project/下可以防止意外从当前目录导入(这能掩盖在生产环境中才会暴露的导入错误)。开发时使用pip install -e .安装。 -
单仓库 vs 多仓库:单仓库(monorepo) 将所有相关项目放在一个代码仓库中(跨项目修改更容易,CI 共享)。多仓库(multi-repo) 为每个项目提供独立的仓库(边界更清晰,版本独立)。大多数机器学习团队从单仓库开始,需要时再拆分。
-
脚本 vs 库:将入口点(
train.py、evaluate.py)放在scripts/中。将可复用逻辑放在src/中。一个训练脚本应该大约 50 行:解析配置、构建数据集、构建模型、构建训练器、开始训练。所有复杂性都存在于库中。
清洁代码原则¶
- 命名:这是你能做的最有影响力的一件事。变量名为
x迫使你阅读周围代码来理解它;变量名为learning_rate则是自文档的。
# 糟糕
def proc(d, n, lr):
for i in range(n):
for k, v in d.items():
v -= lr * g[k]
# 良好
def update_parameters(parameters, num_steps, learning_rate):
for step in range(num_steps):
for name, param in parameters.items():
param -= learning_rate * gradients[name]
-
单一职责原则:每个函数/类只做一件事。名为
load_data_and_train_model的函数同时做了两件事,应该拆分。这样每个部分都可以独立测试、复用和理解。 -
DRY(不要重复自己)——但不要过早抽象。如果你复制粘贴了三次代码,就把它提取成一个函数。但对于只用过一次的代码,不要创建抽象。过早的抽象比重复更糟糕:它在没有经过验证的好处下增加了复杂性。
# 过早的抽象(只有一个用例,过度设计)
class AbstractDataTransformPipelineFactory:
...
# 恰到好处(直接、清晰,且在三处使用)
def normalise_image(image, mean, std):
return (image - mean) / std
- 魔法数字:永远不要使用未经解释的字面常量。
# 糟糕
if len(batch) > 32:
split_batch(batch, 32)
# 良好
MAX_BATCH_SIZE = 32
if len(batch) > MAX_BATCH_SIZE:
split_batch(batch, MAX_BATCH_SIZE)
- 函数应该短小:如果一个函数不能在一屏(约 30 行)内显示完整,那么它可能做了太多事情。将逻辑块提取为具有描述性名称的辅助函数。这样函数体本身读起来就像一个高层摘要。
机器学习的设计模式¶
-
设计模式是针对常见问题的可复用解决方案。以下是与机器学习代码库最相关的几种模式:
-
工厂模式:在不指定具体类的情况下创建对象。当你的配置中说
model: "transformer"而你需要实例化正确的类时,这非常有用:
MODEL_REGISTRY = {
"transformer": TransformerModel,
"cnn": CNNModel,
"mlp": MLPModel,
}
def build_model(config):
model_cls = MODEL_REGISTRY[config["model"]]
return model_cls(**config["model_params"])
-
这使训练脚本与具体的模型实现解耦。添加一个新模型只需在注册表中添加一行,而无需修改训练循环。
-
策略模式:在运行时切换算法。对损失函数、优化器、学习率调度器很有用:
LOSS_FUNCTIONS = {
"mse": nn.MSELoss,
"cross_entropy": nn.CrossEntropyLoss,
"focal": FocalLoss,
}
loss_fn = LOSS_FUNCTIONS[config["loss"]]()
- 观察者模式(回调/钩子):让模块在无需紧耦合的情况下对事件做出响应。训练框架(PyTorch Lightning、Keras)广泛使用回调:
class EarlyStopping:
def __init__(self, patience=5):
self.patience = patience
self.best_loss = float('inf')
self.counter = 0
def on_epoch_end(self, epoch, val_loss):
if val_loss < self.best_loss:
self.best_loss = val_loss
self.counter = 0
else:
self.counter += 1
if self.counter >= self.patience:
return "stop"
- 依赖注入:将依赖传递给函数/类,而不是在内部创建它们。这使得测试变得容易(注入一个 mock 对象)并且配置灵活:
# 糟糕:硬编码依赖
class Trainer:
def __init__(self):
self.logger = WandbLogger() # 离开 W&B 就无法测试
# 良好:依赖注入
class Trainer:
def __init__(self, logger):
self.logger = logger # 可以注入任何 logger,包括 mock
配置管理¶
-
硬编码超参数、文件路径和模型设置会使实验无法复现,修改也变得痛苦。将配置外部化到文件中。
-
YAML 是最常见的机器学习配置格式:
# configs/experiment_1.yaml
model:
name: transformer
d_model: 512
n_heads: 8
n_layers: 6
training:
batch_size: 64
learning_rate: 3e-4
max_epochs: 100
early_stopping_patience: 10
data:
train_path: /data/train.parquet
val_path: /data/val.parquet
max_seq_length: 512
-
Hydra(Facebook 出品)是一个配置框架,支持组合(将基础配置与实验特定的覆盖合并)、命令行覆盖(
python train.py training.lr=1e-3)和多运行(超参数搜索)。 -
argparse 对于只有少量参数的脚本更简单:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=3e-4)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--config", type=str, default="configs/base.yaml")
args = parser.parse_args()
- 最佳实践:有一个包含所有默认值的基础配置,以及每个实验的配置,后者只覆盖需要改变的部分。将每个实验的配置与其结果一起跟踪。
日志与可观测性¶
print语句用于调试。日志用于生产环境:
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.debug("Batch loaded: %d samples", len(batch)) # 详细,用于调试
logger.info("Epoch %d: loss=%.4f, lr=%.6f", epoch, loss, lr) # 正常运行
logger.warning("GPU memory >90%%, consider reducing batch size")
logger.error("Failed to load checkpoint: %s", path) # 可恢复的错误
logger.critical("CUDA out of memory, aborting") # 致命错误
-
为什么不用 print:日志支持级别(生产环境过滤掉调试信息)、格式(时间戳、模块名)和处理程序(写入文件、发送到监控系统),而无需修改日志调用本身。
-
结构化日志输出机器可解析的格式(JSON),同时保留人类可读的消息。这使得可以根据特定字段进行搜索和告警:
API 设计¶
-
如果你的模型将被其他服务(Web 应用、移动应用、另一个机器学习流水线)使用,那么它需要一个 API(应用程序编程接口)。
-
REST API 使用 HTTP 方法:
GET读取,POST创建/预测,PUT更新,DELETE删除。端点遵循基于资源的命名:
POST /api/v1/predict # 发送输入,获取预测结果
GET /api/v1/models # 列出可用模型
GET /api/v1/models/{id} # 获取模型详情
POST /api/v1/models/{id}/predict # 使用特定模型进行预测
- FastAPI 是 Python 机器学习服务的主流框架:
from fastapi import FastAPI
from pydantic import BaseModel
app = FastAPI()
class PredictRequest(BaseModel):
text: str
class PredictResponse(BaseModel):
label: str
confidence: float
@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
result = model.predict(request.text)
return PredictResponse(label=result.label, confidence=result.score)
-
FastAPI 自动生成 API 文档(Swagger UI 位于
/docs),使用 Pydantic 模型验证输入/输出,并支持异步以获取高吞吐量。 -
gRPC 在内部服务间通信时比 REST 更快。它使用 Protocol Buffers(二进制序列化,比 JSON 更小更快)并支持流式传输。TensorFlow Serving、Triton Inference Server 和许多微服务架构都在使用它。
打包与分发¶
- 将代码打包成可安装的包,可以让其他人(以及你自己的脚本)清晰地导入它:
# pyproject.toml
[project]
name = "my-ml-project"
version = "0.1.0"
requires-python = ">=3.10"
dependencies = [
"torch>=2.0",
"jax>=0.4",
"pydantic>=2.0",
]
[project.optional-dependencies]
dev = ["pytest", "ruff", "mypy"]
[build-system]
requires = ["setuptools>=64"]
build-backend = "setuptools.backends._legacy:_Backend"
-
可编辑安装(
-e):对源代码的更改会立即生效,无需重新安装。开发过程中必不可少。 -
锁定依赖:使用精确版本号的
requirements.txt(torch==2.2.1,而不是torch>=2.0)可以确保可复现性。使用pip freeze > requirements.txt来捕获当前环境。对于更复杂的依赖管理,可以使用uv、poetry或pip-tools。
与 AI 编程助手协作¶
-
AI 编程助手(Claude Code、GitHub Copilot、Cursor 等)现在已经成为专业工程工作流的一部分。使用得当,它们能极大加速开发;使用不当,则会引入微妙的错误、侵蚀你对自身代码库的理解,并制造虚假的生产力感。
-
正确的心智模型是:AI 助手是一个快速但经验不足的结对编程伙伴。它能快速编写代码,熟悉语法和标准模式,阅读过的文档比你多得多。但它不理解你的具体系统、你的业务约束、你的边界情况,以及你设计决策背后的原因。你是资深工程师,助手是初级工程师。你负责指导、审查并承担责任。
AI 助手的优势场景¶
-
样板代码和脚手架:生成 Dockerfile、CI 配置、测试夹具、数据类定义、argparse 设置等。这些遵循众所周知的模式,手写起来很繁琐。让助手生成,然后检查正确性。
-
编写测试:描述函数的行为,助手生成测试用例。它经常能捕捉到你可能会遗漏的边界情况(空输入、负值、Unicode 字符)。一定要阅读生成的测试——它们验证的是你的假设,而不仅仅是你的代码。
-
重构:“把这个块提取成一个函数”,“将这个类改为使用 dataclass”,“给这个模块添加类型注解”。这些是机械性的转换,意图清晰,引入细微错误的风险较低。
-
探索和原型设计:“写一个快速脚本,用于基准测试推理延迟”或“向我展示如何使用 HuggingFace 的 tokenizer API”。相比于阅读文档,助手能更快地让你获得一个可运行的起点。
-
文档和文档字符串:助手可以根据你的代码结构生成文档。检查其准确性,但繁琐的工作已经被自动化了。
-
调试辅助:粘贴一个错误回溯,询问诊断意见。助手通常能找出根本原因并建议修复方法,尤其是对于常见问题(形状不匹配、导入错误、CUDA 内存不足)。
何时不应依赖 AI 助手¶
-
新颖的架构决策:如果你正在设计一个新的训练流水线,助手会给你一个通用的答案。它不知道你的数据约束、延迟要求或团队的专业知识。使用助手来实现你已经考虑清楚的设计。
-
安全关键代码:身份验证、加密、输入净化。助手生成的代码可能看起来正确,但存在微妙的漏洞(SQL 注入、不安全的默认值、时序攻击)。安全代码应该由理解威胁模型的人编写,并由另一个人审查。
-
性能关键的内层循环:助手会写出正确但幼稚的代码。对于 GPU 内核、内存关键的数据结构或延迟敏感的服务路径,你需要理解硬件约束(第13章、第16章)并有意识地进行优化。
-
你不理解的代码:如果助手生成了 200 行代码,而你无法解释每一行的作用,那就不要提交它。你现在将维护一段你不理解的代码,而当它出问题时(一定会出问题),你无法调试。这是最常见也最危险的失败模式。
审查纪律¶
-
在提交之前,始终阅读生成代码的每一行。这不是可选项。助手的代码是草稿,而不是成品。像对待同事的拉取请求一样对待它:批判性地审查。
-
需要检查的事项:
- 正确性:它是否真的做了你要求的事?助手经常解决一个与你意图略有不同的问题。
- 边界情况:它是否处理了空输入、None 值、负数、极大输入?助手经常遗漏边界情况的处理。
- 幻觉 API:助手可能调用了不存在的函数或使用了不存在的参数,尤其是对于较新或较不常见的库。验证每个 API 调用是否真实存在。
- 过度工程:助手倾向于生成比所需更多的代码。一个 50 行的解决方案解决了一个 10 行的问题,增加了复杂性却没有好处。果断简化。
- 安全性:硬编码的密钥、未净化的用户输入、不安全的默认值。助手不会以对抗性思维思考。
- 风格一致性:生成的代码是否匹配你项目的约定(命名、模式、错误处理)?
如何编写好的提示词¶
-
AI 助手输出的质量与你指令的质量成正比。模糊的提示得到模糊的代码。
-
不好:“写一个数据加载器”
-
好:“写一个 PyTorch DataLoader,用于一个包含 'text' 和 'label' 列的 CSV 文件。使用 HuggingFace tokenizer 'bert-base-uncased' 对文本进行 tokenize,max_length=512。返回 input_ids、attention_mask 和作为张量的 label。处理 CSV 中 label 列缺失值时,跳过这些行。”
-
提供上下文:告诉助手你的项目结构、现有代码、约束和约定。上下文越多,输出越好。
-
指定约束:“只使用标准库”,“必须能在 Python 3.10 上工作”,“不要使用全局变量”,“遵循
src/models/transformer.py中的现有模式。” -
要求解释:“实现 X,并解释关键的设计决策。”这迫使助手阐述其推理,使你更容易发现错误的假设。
使用质量门捕捉助手的错误¶
-
你现有的质量基础设施(文件 04)捕捉助手的错误就像捕捉人类的错误一样有效:
- 类型检查(mypy):捕获幻觉的 API 签名和类型不匹配。
- 静态检查(ruff):捕获未使用的导入、未定义的变量和风格违规。
- 测试(pytest):如果助手的代码通过了你的测试套件,它更有可能是正确的。如果你没有测试,在要求助手实现功能之前就写好它们(测试驱动开发在配合 AI 助手时尤其有效)。
- CI 流水线:在每次提交时自动运行上述所有检查。
-
“助手写代码” + “质量门验证” 的组合比单独任何一项都更高效。助手快速但马虎;质量门全面但不写代码。两者结合,你既能获得速度又能获得正确性。
生产力陷阱¶
-
使用 AI 助手的最大风险是虚假的生产力感。你可以在 10 分钟内生成 500 行代码。但如果因为你没有理解这 500 行代码而花了 2 小时调试它们,那其实比你花 30 分钟自己写 200 行代码还要慢。
-
使用 AI 助手获得真正的生产力来自:
- 保持控制:你决定架构,助手填充实现。
- 理解生成的内容:如果你无法解释它,要么重写它,要么要求助手简化它。
- 投资质量门:测试、类型检查和静态检查的成本会在与 AI 助手的每次交互中分摊。
- 让助手弥补你的弱点:如果你擅长算法但不擅长写测试,让助手写测试。如果你擅长 UI 代码但不熟悉数据库查询,让助手起草 SQL。发挥你的优势,委派你的短板。
-
从 AI 编程助手中获益最多的工程师,是那些已经擅长编程的人。AI 助手放大你现有的技能,而不是取代它。理解数据结构、算法、系统设计和软件工程(正是这一章的内容)才能让你有效地指导 AI 助手并批判性地评估其输出。