LLaMA 中 RMSNorm 相比 LayerNorm 的优势
副标题 / 摘要 LLaMA 使用 RMSNorm 替代 LayerNorm,主要是为了简化计算、提升训练稳定性与推理效率。本文用公式、示例与工程场景讲清差异,并提供最小 PyTorch 代码。 预计阅读时长:12~16 分钟 标签:rmsnorm、layernorm、llama、pytorch SEO 关键词:RMSNorm, LayerNorm, LLaMA, 归一化 元描述:解释 RMSNorm 与 LayerNorm 的差异与优势,并给出可运行的 PyTorch 示例。 目标读者 想理解 LLaMA 架构细节的入门读者 关注训练/推理效率的工程实践者 需要在模型中选择归一化方案的开发者 背景 / 动机 归一化是稳定训练的关键步骤。 LayerNorm 是 Transformer 的默认选择,但在大模型中成本可观。 RMSNorm 用更少的计算达到相似效果,是 LLaMA 等模型的常见替代。 核心概念 LayerNorm(LN):对每个 token 的特征维度做均值和方差归一化。 RMSNorm:只做均方根归一化,不减均值。 缩放参数:两者都保留可学习的缩放向量 g。 A — Algorithm(题目与算法) 用通俗语言说明主题内容 LayerNorm:把每个 token 的特征变成“均值 0、方差 1”。 RMSNorm:只把特征的“幅度”缩放到稳定范围,不强制均值为 0。 基础示例(1) 输入向量 [1, 2, 3],LN 会中心化;RMSNorm 只缩放长度。 基础示例(2) 在大 batch 推理时,RMSNorm 少了一次均值计算,吞吐更高。 实践指南 / 步骤 若追求推理效率与训练稳定性,优先尝试 RMSNorm。 如果模型对偏移敏感,可保留 LN 或搭配残差调参。 对比训练曲线与损失波动,确认稳定性。 可运行示例(最小 PyTorch 对比) import torch import torch.nn as nn torch.manual_seed(42) class RMSNorm(nn.Module): def __init__(self, dim, eps=1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x): # x: (..., dim) rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt() x = x / rms return x * self.weight x = torch.randn(2, 4, 8) ln = nn.LayerNorm(8) rms = RMSNorm(8) out_ln = ln(x) out_rms = rms(x) print(out_ln.mean(dim=-1)) print(out_rms.mean(dim=-1)) print(out_ln.std(dim=-1)) print(out_rms.std(dim=-1)) 解释与原理 LN 同时消除均值与缩放;RMSNorm 只控制尺度。 RMSNorm 计算少、数值更稳定,适合大模型训练。 由于不做中心化,RMSNorm 可能保留有用的偏移信息。 C — Concepts(核心思想) 方法类型 两者都属于特征归一化,用于稳定训练并加速收敛。 ...