为什么注意力要除以 √(d_k):从数值稳定到工程收益
副标题 / 摘要 注意力中的缩放项 \u221a(d_k) 不是装饰,而是数值稳定的关键:它控制 QK^T 的方差,避免 softmax 饱和和梯度消失。本文用公式与实验解释其必要性,并给出工程场景建议。 预计阅读时长:12~16 分钟 标签:attention、transformer、scaled-dot-product SEO 关键词:Attention, Scaled Dot-Product, \u221a(d_k) 元描述:解释注意力缩放项的数学动机与工程收益。 目标读者 想理解 Transformer 注意力细节的入门读者 需要排查训练不稳定问题的工程实践者 关注数值稳定性与性能优化的开发者 背景 / 动机 在点积注意力中,维度越大,QK^T 的数值越大,softmax 越容易饱和。 一旦饱和,梯度接近 0,训练会变慢甚至不稳定。 \u221a(d_k) 的缩放项就是为了解决这个问题。 核心概念 点积注意力:$QK^\top$ 衡量相似度。 缩放项 \u221a(d_k):控制相似度的尺度。 softmax 饱和:输入过大导致概率趋近 0/1,梯度变小。 A — Algorithm(题目与算法) 用通俗语言说明主题内容 维度大时,QK^T 变大,softmax 过于“自信”。 缩放 \u221a(d_k) 后,数值回到合理范围,梯度更健康。 基础示例(1) d_k=64 时,如果不缩放,softmax 输出会非常尖锐。 基础示例(2) d_k=512 时,缩放与否会直接影响训练是否稳定。 实践指南 / 步骤 使用标准缩放:$QK^\top / \sqrt{d_k}$。 如果做自定义注意力,先验证 softmax 分布是否过尖锐。 在混合精度训练下,缩放更重要。 可运行示例(缩放与不缩放的对比) import torch import torch.nn.functional as F def attn_scores(d, scale=True): q = torch.randn(1, 1, d) k = torch.randn(1, 8, d) scores = q @ k.transpose(-2, -1) if scale: scores = scores / (d ** 0.5) probs = F.softmax(scores, dim=-1) return probs.max().item(), probs.min().item() for d in [32, 128, 512]: mx_s, mn_s = attn_scores(d, scale=True) mx_u, mn_u = attn_scores(d, scale=False) print(f"d={d} scaled max={mx_s:.3f} min={mn_s:.3f} | unscaled max={mx_u:.3f} min={mn_u:.3f}") 解释与原理 如果 $q_i, k_i \sim \mathcal{N}(0, 1)$, ...