抽样¶
抽样决定了我们如何收集数据,并直接控制我们得出的每个结论的质量。本章涵盖简单随机抽样、分层抽样、整群抽样、系统抽样、抽样分布、大数定律和自助法,这些方法对于机器学习中的训练/测试拆分和数据集构建至关重要。
-
在理想情况下,你会测量你所关注群体中的每一个个体。但在实际中,这几乎总是不可能的。你无法调查每一个选民、测试每一个灯泡或扫描每一位患者。因此,你抽取一个样本,并用它来了解整体。
-
总体是你想要研究的全部个体或项目的集合。样本是你实际观测到的子集。
-
参数是描述总体的数字(例如,一个国家所有成年人的真实平均身高)。
-
统计量是根据样本计算出的数字(例如,你测量的500人的平均身高)。统计量被用来估计参数。
-
结论的质量完全取决于你如何选择样本。无论你的分析多么复杂,有偏的样本都会导致有偏的结论。
-
抽样框是你实际从中抽取样本的所有个体的列表。理想情况下,抽样框与总体完全匹配,但在实践中会存在差距。
-
例如,如果你通过电话调查人们,你就会漏掉所有没有电话的人。抽样框与总体之间的差异称为覆盖误差。
-
抽样误差是样本统计量与总体参数之间的自然差异。
-
即使是一个完美的随机样本也不会完全等于总体。更大的样本可以减少抽样误差。
-
抽样有两大类别:概率抽样和非概率抽样。
-
概率抽样意味着总体中的每个成员都有已知的、非零的被选中的概率。这使你能够量化不确定性并推广结果。
-
简单随机抽样:每个个体被选中的机会均等,且每个大小为\(n\)的可能样本出现的概率相同。想象把每个名字放进帽子里然后盲抽。
-
分层抽样:根据某个共同特征(例如年龄组、地区)将总体划分为互不重叠的组(层),然后从每一层中随机抽样。这保证了每个群体都有代表性,并在各层彼此不同时减少方差。
-
整群抽样:将总体划分为多个组(群),随机选择一些群,然后将选中群中的所有个体都纳入样本。当总体在地理上分散时(例如,抽样整个学校而不是跨学区抽个别学生),这种方法很实用。
-
系统抽样:选取一个随机起点,然后从列表中每隔\(k\)个个体选取一个。例如,从第7个人开始,然后每10个人取一个(7, 17, 27, …)。实现简单,但如果列表存在隐藏模式可能会引入偏差。
-
非概率抽样并不赋予每个成员已知的被选中的概率。结果不能严格地推广,但这些方法通常更快、更便宜。
-
便利抽样:选择最容易接触到的人。在购物中心调查人们很方便,但会遗漏不去那里购物的人。
-
配额抽样:类似于分层抽样,但没有随机性。研究者通过从每组中选择容易接触的个体来填补配额(例如50名男性和50名女性)。
-
滚雪球抽样:从少数参与者开始,然后请他们招募其他人。适用于难以接触到的总体(例如研究罕见疾病),但对有联系网络的个体存在严重偏差。
-
一旦你有了抽样方法,一个自然的问题就会出现:如果我抽取一个不同的样本,会得到不同的统计量吗?几乎肯定会的。抽样分布是指同一个统计量(如样本均值)在所有可能的大小相同的样本上的分布。
-
想象抽取1000个不同的大小为30的样本,并计算每个样本的平均身高。这1000个均值形成一个分布。有些会略高于真实总体均值,有些略低于,但大多数会聚集在真实值周围。
-
这个抽样分布的标准差称为标准误:
-
注意到标准误随着\(n\)的增大而减小。更大的样本给出更精确的估计。样本量增加到四倍,标准误减半。
-
统计学中最重要的结果是中心极限定理(CLT)。它指出:无论原始总体的形状如何,样本均值的分布随着样本量的增加趋近于正态分布。
- 更精确地说,如果\(X_1, X_2, \ldots, X_n\)是来自任意具有均值\(\mu\)和有限方差\(\sigma^2\)的分布的独立观测值,那么随着\(n\)增大:
-
CLT是大多数推断统计学有效的基础。它允许我们使用正态分布作为近似,即使底层数据不是正态的,只要样本足够大。
-
“足够大”是多大?一个常见的经验法则是\(n \ge 30\),但这取决于总体偏离正态的程度。对于高度偏斜的分布,你可能需要更大的样本。对于大致对称的总体,即使\(n = 10\)也可能足够。
-
CLT有三个关键条件:
- 独立性:每个观测值不能影响其他观测值
- 有限方差:总体方差必须存在(排除了一些奇特的分布)
- 同分布:所有观测值都来自同一个分布
编程任务(使用 CoLab 或 notebook)¶
-
可视化展现中心极限定理:从一个高度偏斜的分布中抽取样本,计算样本均值,观察均值的直方图如何变成钟形。
import jax import jax.numpy as jnp import matplotlib.pyplot as plt key = jax.random.PRNGKey(0) # 指数分布(高度偏斜) population = jax.random.exponential(key, shape=(100_000,)) fig, axes = plt.subplots(1, 4, figsize=(14, 3)) sample_sizes = [1, 5, 30, 100] for ax, n in zip(axes, sample_sizes): keys = jax.random.split(key, 2000) means = jnp.array([jax.random.choice(k, population, shape=(n,)).mean() for k in keys]) ax.hist(means, bins=40, color="#3498db", alpha=0.7, density=True) ax.set_title(f"n = {n}") ax.set_xlim(0, 4) fig.suptitle("中心极限定理:随着 n 增加,样本均值趋于正态", fontsize=13) plt.tight_layout() plt.show() -
比较简单随机抽样与分层抽样。创建一个具有明显分组的总体,并表明分层抽样在估计中具有更低的方差。
import jax import jax.numpy as jnp key = jax.random.PRNGKey(42) # 总体:两个不同的组 group_a = jax.random.normal(key, shape=(500,)) + 10 # 均值约10 key, subkey = jax.random.split(key) group_b = jax.random.normal(subkey, shape=(500,)) + 20 # 均值约20 population = jnp.concatenate([group_a, group_b]) # 简单随机抽样:1000次试验,样本量20 srs_means = [] for i in range(1000): key, subkey = jax.random.split(key) sample = jax.random.choice(subkey, population, shape=(20,), replace=False) srs_means.append(sample.mean()) srs_means = jnp.array(srs_means) # 分层抽样:每个组取10个 strat_means = [] for i in range(1000): key, k1, k2 = jax.random.split(key, 3) s_a = jax.random.choice(k1, group_a, shape=(10,), replace=False) s_b = jax.random.choice(k2, group_b, shape=(10,), replace=False) strat_means.append(jnp.concatenate([s_a, s_b]).mean()) strat_means = jnp.array(strat_means) print(f"简单随机 - 均值: {srs_means.mean():.3f}, 标准差: {srs_means.std():.3f}") print(f"分层抽样 - 均值: {strat_means.mean():.3f}, 标准差: {strat_means.std():.3f}") print(f"分层抽样将方差减少了 {(1 - strat_means.var()/srs_means.var())*100:.1f}%") -
探索样本量如何影响标准误。绘制标准误对样本量的图,并确认\(1/\sqrt{n}\)的关系。
import jax import jax.numpy as jnp import matplotlib.pyplot as plt key = jax.random.PRNGKey(7) population = jax.random.normal(key, shape=(50_000,)) * 10 + 50 sample_sizes = [5, 10, 20, 50, 100, 200, 500, 1000] std_errors = [] for n in sample_sizes: means = [] for _ in range(500): key, subkey = jax.random.split(key) sample = jax.random.choice(subkey, population, shape=(n,)) means.append(sample.mean()) std_errors.append(jnp.array(means).std()) plt.figure(figsize=(8, 4)) plt.plot(sample_sizes, std_errors, "o-", color="#e74c3c", label="观测到的标准误") theoretical = population.std() / jnp.sqrt(jnp.array(sample_sizes, dtype=jnp.float32)) plt.plot(sample_sizes, theoretical, "--", color="#3498db", label="σ/√n (理论值)") plt.xlabel("样本量 (n)") plt.ylabel("标准误") plt.legend() plt.title("更大的样本量使标准误缩小") plt.show()