副标题 / 摘要
这是一篇“算法解释型”长文:用结构化心智模型讲清 Transformer 的核心算法、为什么它能替代 RNN,以及工程上怎么落地。
- 预计阅读时长:约 15 分钟
- 标签:
transformer、attention、self-attention、multi-head-attention - SEO 关键词:Attention Is All You Need, Transformer, 自注意力, 位置编码
- 元描述:讲清 Transformer 的算法结构、复杂度与工程取舍,含可运行示例。
目标读者
- 想系统理解 Transformer 算法的中级工程师
- 熟悉 RNN/CNN,但对注意力与位置编码缺乏系统图景的读者
- 需要从“算法原理”过渡到“工程实现”的实践者
背景 / 动机
RNN 能处理序列,却难以并行;CNN 能并行,却难以建模长程依赖。
以长度 n=512 的序列为例,RNN 需要 512 次顺序步进,GPU 难以充分并行;而注意力主要由几次大矩阵乘法构成,更容易吃满算力。
Attention Is All You Need 的突破点是:直接在序列内部做全局依赖建模,同时保持高度并行。
这使得 Transformer 成为 NLP 与多模态的基础结构,尤其在 n>=128 的长依赖任务上优势明显。
快速掌握地图(60-120s)
- 问题形态:序列到序列或序列到表示的建模
- 核心思想:用自注意力为每个 token 动态聚合全局信息
- 何时使用/避免:需要全局依赖、并行训练时用;超长序列且内存极限时慎用
- 复杂度关键词:注意力
O(n^2),n=2048时注意力矩阵约 420 万元素 - 常见坑:Mask、位置编码与张量形状错配(例如因果 mask 漏加导致“看未来”)
大师级心智模型
- 核心抽象:把“序列建模”理解为“路由信息的相似度聚合”
- 问题家族:基于相似度的全局上下文建模(attention family)
- 同构模板:
softmax(QK^T)V的加权聚合范式 - 不变量:注意力权重矩阵每一行非负且和为 1(行随机矩阵),输出是值向量的凸组合
核心概念与术语
- Q/K/V:查询、键、值的线性投影矩阵
- 自注意力:序列内部每个位置对所有位置的加权汇聚
- 多头注意力:在多个子空间并行建模不同关系
- 位置编码:让模型感知顺序,弥补注意力的置换不变性
- 关键公式:
Q = X W_Q, K = X W_K, V = X W_VAttention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V- 其中
X ∈ R^{n x d_model},W_Q, W_K, W_V ∈ R^{d_model x d_k},d_k = d_model / h
位置编码的具体公式(Sinusoidal)
论文使用固定正弦/余弦位置编码,让不同维度对应不同频率:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
例如 d_model=8、pos=3 时,第 i=0 维使用 sin(3), 第 i=1 维使用 sin(3/10000^(2/8)),高维变化更慢,形成“多尺度位置信号”。
这使得注意力可以利用相对位移(两个位置差值)来推断顺序关系。
可行性与下界直觉
- 下界直觉:全量自注意力本质上要求任意 token 互相关联,天然需要
O(n^2)的 pairwise 交互。 - 模型破坏条件:当序列极长且内存受限时必须引入稀疏/局部/近似注意力。
例如n=4096时注意力矩阵约 1677 万元素,FP16 约 32 MB/头;n=16k时约 2.68 亿元素,单头约 512 MB,8 头就接近 4 GB,仅注意力权重就会成为瓶颈。 若再考虑梯度与激活缓存,实际峰值可达 2~3 倍,单卡训练几乎不可行。
问题抽象(输入/输出)
- 输入:长度为
n的 token 序列(embedding) - 输出:每个 token 的上下文化表示或解码分布
- 优化目标:提高建模能力与并行性,同时控制时间/空间开销(例如
n在 128~2048 时保持可训练,n>=8k时仍可部署)
评价指标与实验设计
在翻译任务中常用 BLEU,在语言建模中常用 Perplexity。
若只关注吞吐与延迟,可以记录 tokens/sec 与显存峰值。
例如在相同 batch 下比较 n=512 与 n=1024 的吞吐,可直接体现注意力 n^2 的代价。
如果关注可解释性,可以保存注意力权重并统计 top-k 覆盖率。
实际项目还会追踪显存峰值与单步延迟。
张量形状追踪(Encoder Block)
以 B=2, n=128, d_model=512, h=8 为例:
- 输入
X:[B, n, d_model] = [2, 128, 512] - 线性投影后
Q/K/V:[2, 128, 512] - 分头后
Q/K/V:[2, 8, 128, 64](每头维度d_k=64) - 注意力权重
A:[2, 8, 128, 128] - 头输出拼接:
[2, 128, 512] - FFN 中间维度:
d_ff=2048,张量形状[2, 128, 2048]
这套形状可以用来检查实现是否“维度对齐”,也是排查 mask 错误的最短路径。
交叉注意力形状追踪(Decoder -> Encoder)
- 编码器输出
M:[B, n_src, d_model] - 解码器当前状态
Y:[B, n_tgt, d_model] - 交叉注意力中
Q来自Y,K/V来自M - 权重张量形状
[B, h, n_tgt, n_src]
例如n_src=128, n_tgt=64时,注意力矩阵大小为64 x 128,可直接反映“对齐关系”。
思路推导(从朴素到突破)
- 朴素方案:RNN 顺序建模,长依赖难、梯度易衰减。
- 瓶颈:不可并行 + 长程依赖效率差(时间复杂度
O(n * d_model^2),且必须按时间步串行)。 - 关键观察:依赖不是“时间序列”本身,而是“任意位置之间的关联”。
- 方法选择:直接让所有位置互相“注意”,用矩阵乘法实现并行。
路径长度对比(为什么注意力更擅长长依赖)
RNN 中任意两个位置的依赖路径长度是 O(n);在 n=512 的句子中,信息要经过 512 次传递。
自注意力的路径长度是 O(1):每个 token 直接与所有位置交互,因此长程依赖不会被“长链路”稀释。
关键观察
注意力可以被表示为两次矩阵乘法:QK^T 得到相似度,再与 V 相乘得到聚合结果。
这意味着核心计算是 GEMM(矩阵乘法),在 GPU/TPU 上效率很高,且可以完全并行。
解释与原理
- 自注意力为何有效:每个 token 根据相似度拉取全局信息,动态选择“该看谁”。
- 多头机制的作用:让模型在不同子空间同时关注语法、语义、位置等不同关系。
- 位置编码的必要性:注意力对顺序不敏感,位置编码提供顺序/相对位置信号。
- 残差 + LayerNorm:稳定梯度,控制深层训练难度。
微型推导:为什么要除以 sqrt(d_k)
若 q_i, k_j 的元素近似均值 0、方差 1,未缩放的点积 q_i · k_j 的方差近似为 d_k。
当 d_k=64 时,标准差约 8,softmax 的输入易饱和,梯度变小。
用 1 / sqrt(d_k) 缩放可把标准差拉回到 1 量级,训练更稳定。
参数量估算(以 d_model=512 为例)
多头注意力常用“合并投影”实现:W_Q/W_K/W_V 与输出投影 W_O。
参数量近似为 4 * d_model^2:4 * 512^2 = 1,048,576(约 1.05M),不含偏置。
这解释了为什么 Transformer 的参数主要集中在投影层与 FFN 中,而不是 softmax。
单层总参数量估算
注意力层约 4 * d_model^2 = 1.05M,FFN 约 2 * d_model * d_ff = 2.10M,
单层合计约 3.15M。若堆叠 L=6 层,参数量约 18.9M(不含词嵌入与输出层)。
位置编码方案对比
| 方案 | 是否可外推 | 参数量 | 优点 | 代价 |
|---|---|---|---|---|
| 正弦/余弦 | 强 | 0 | 无需学习,长度可外推 | 表达受限 |
| 可学习绝对位置 | 弱 | n * d_model | 训练收敛快 | 长度受限 |
| 相对/旋转位置(RoPE) | 中-强 | 0 或少量 | 更适合长序列 | 公式更复杂 |
当 n 扩展到 4k/8k 时,RoPE/相对位置编码通常更稳健,
而可学习绝对位置在超出训练长度时容易退化。
位置编码最小实现(PyTorch)
import torch
def sinusoidal_position_encoding(n, d_model):
position = torch.arange(n).float().unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model)
)
pe = torch.zeros(n, d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
return pe
pe = sinusoidal_position_encoding(8, 16)
print(pe.shape)
实践指南 / 步骤(算法流程)
- Embedding:将 token 映射到向量空间,得到
X ∈ R^{n x d_model}。 - 加位置编码:引入绝对或相对位置信息,保持序列顺序。
- 多头自注意力:计算
Q/K/V,每个头的维度d_k = d_model / h。 - 注意力矩阵:得到
A = softmax(QK^T / sqrt(d_k)),A ∈ R^{n x n}。 - 拼接与输出投影:拼接多头后乘
W_O回到d_model维度。 - 前馈网络:逐位置非线性变换增强表达力(通常
d_ff ≈ 4 * d_model)。 - 残差与归一化:稳定训练与深层堆叠。
- 编码器-解码器:解码器额外包含对编码器输出的交叉注意力。
前馈网络为何必要(Position-wise FFN)
注意力只是“加权汇聚”,本质是线性组合;
FFN 提供逐位置的非线性变换,提升表示能力:
FFN(x) = W2 * relu(W1 * x + b1) + b2
常见配置是 d_ff ≈ 4 * d_model。
当 d_model=512 时 d_ff=2048,单层 FFN 参数量约 2 * d_model * d_ff = 2,097,152,
这也是 Transformer 参数主要消耗的原因之一。
选型决策(Selection Guide)
- 输入长度:
n <= 2048全量自注意力可用;2048 < n <= 8192优先 FlashAttention 或分块;n > 8192考虑稀疏/线性注意力。 - 分布特征:依赖远距离上下文时优先 Transformer。
- 内存约束:24 GB 显存下,
n=4096、h=8时注意力权重开销已接近 256 MB,需关注激活与梯度叠加。 - 实现复杂度:小规模任务可用现成模块,大规模需定制算子。
注意力变体对比(工程选型)
| 方案 | 计算复杂度 | 显存复杂度 | 适用场景 | 代价 |
|---|---|---|---|---|
| 全量注意力 | O(n^2) | O(n^2) | n<=2k | 成本随 n^2 爆炸 |
| FlashAttention | O(n^2) | O(n) | 长序列训练 | 需要特定内核 |
| 分块注意力 | O(n^2) | O(n) | 显存受限 | 编程复杂度高 |
| 稀疏/线性注意力 | O(n) | O(n) | 超长序列 | 近似误差 |
头数与维度的经验选择
- 经验上保持
d_k在 32~128 之间效果更稳定。
例如d_model=512时,h=8得d_k=64;h=16得d_k=32。 - 若
d_k太小(如 16),每头表达力不足;太大(如 256),每头计算量激增。
Worked Example(Trace)
设有 3 个 token,维度 d=2,简化为 Q=K=V=X:
x1=(1,0),x2=(0,1),x3=(1,1)scores = QK^T / sqrt(2)得到:- 行 1:
(0.707, 0, 0.707) - 行 2:
(0, 0.707, 0.707) - 行 3:
(0.707, 0.707, 1.414)
- 行 1:
- softmax 后近似权重:
- 行 1:
(0.401, 0.198, 0.401) - 行 2:
(0.198, 0.401, 0.401) - 行 3:
(0.248, 0.248, 0.503)
- 行 1:
- 输出(加权求和):
y1 ≈ (0.802, 0.599)y2 ≈ (0.599, 0.802)y3 ≈ (0.751, 0.751)
Worked Example 2(因果 Mask 影响)
设 n=3 的自回归场景,因果 mask 禁止位置 i 看到 j>i:
注意力矩阵中 A[i, j] = -inf (j > i),softmax 后权重为 0。
这样第 1 个 token 只聚合自身,第 2 个 token 只能看前 2 个,避免“看未来”。
一个简化分数矩阵示例(未缩放):S = [[1, 2, 3], [1, 1, 1], [2, 0, 1]]
加入因果 mask 后变为:S' = [[1, -inf, -inf], [1, 1, -inf], [2, 0, 1]]
softmax 后第一行权重变成 (1, 0, 0),第二行变成 (0.5, 0.5, 0),第三行才保留完整的三项分布。
正确性(Proof Sketch)
- 不变量:每一行注意力权重非负且和为 1(softmax)。
- 保持性:softmax 对任意实数输入都产生概率分布。
- 正确性含义:输出是值向量的凸组合,保证表示位于输入子空间中,且是“基于相似度的加权信息融合”。
复杂度分析
- 时间复杂度:
O(n^2 * d_k)(QK^T)+O(n^2 * d_k)(AV),主导项为O(n^2 * d)。 - 空间复杂度:
O(n^2)(注意力矩阵)+O(n * d)(Q/K/V)。 - 结论:长序列的瓶颈来自
n^2,当n翻倍,内存和计算都约增加 4 倍。
算力粗估(以 n=1024, d_model=512 为例)
QK^T 计算量约为 n^2 * d_k = 1024^2 * 64 ≈ 67M 乘加。AV 同量级,再加上线性投影与 FFN,总体一次前向约数百 MFLOPs。
这也是为什么长序列训练需要高效注意力内核与混合精度。
注意力矩阵显存估算(单头 FP16)
| n | n^2 元素数 | 约占用 |
|---|---|---|
| 512 | 262,144 | ~0.5 MB |
| 1024 | 1,048,576 | ~2 MB |
| 2048 | 4,194,304 | ~8 MB |
| 4096 | 16,777,216 | ~32 MB |
这只是单头的权重矩阵,还未包含梯度与激活缓存。
多头与 batch 叠加后,显存压力会迅速放大。
常数因子与工程现实
- 带宽瓶颈:注意力矩阵占用大量显存与带宽,
n=2048时单头约 8 MB(FP16)。 - 优化路径:FlashAttention、块状注意力、KV 缓存。
- 风险点:数值稳定性(softmax 溢出)、mask 错误导致信息泄露。
在训练中还要考虑梯度与激活缓存。以 B=4, n=2048, h=8 为例,注意力权重只是其中一部分,
前向激活与反向梯度叠加后,显存峰值可能是权重矩阵的 3~5 倍。
超长序列工程策略(n>=8k)
当 n 提升到 8k/16k 时,单纯堆算力并不能解决问题,通常需要算法与工程同时改造:
- 分块注意力:将序列切为长度
w的窗口,显存从O(n^2)变为O(nw)。
例如n=16k, w=512时注意力权重规模约为n*w=8,192,000,比全量n^2小两个数量级。 - 稀疏模式:保留局部 + 少量全局 token(如 [CLS]),在不破坏全局信息的前提下降成本。
- 检索增强:先检索
k个相关片段再做注意力,复杂度从O(n^2)变为O(nk)。
这些策略的代价是“近似误差 + 实现复杂度”,因此需要结合任务验证效果。
常见性能误区
- 盲目加头数:
h翻倍会导致QK^T的并行头数增加,但总算力几乎不变,反而增加调度开销。 - 忽略 batch 维度:
B从 4 提到 16 时,注意力矩阵内存直接乘 4。 - 序列长度未裁剪:许多文本任务把
n固定为 2048,但有效内容可能只有 300~500 token,浪费严重。
可运行示例(Python / PyTorch)
import torch
import torch.nn as nn
torch.manual_seed(42)
class MiniTransformerBlock(nn.Module):
def __init__(self, d_model=32, num_heads=4, d_ff=64):
super().__init__()
self.mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model),
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x):
attn_out, _ = self.mha(x, x, x)
x = self.norm1(x + attn_out)
ffn_out = self.ffn(x)
return self.norm2(x + ffn_out)
x = torch.randn(2, 5, 32) # batch=2, seq=5, dim=32
block = MiniTransformerBlock()
out = block(x)
print(out.shape)
可运行示例 2(NumPy 实现单头注意力)
import numpy as np
def softmax(x, axis=-1):
x = x - np.max(x, axis=axis, keepdims=True)
exp_x = np.exp(x)
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
def attention(x, wq, wk, wv):
q = x @ wq
k = x @ wk
v = x @ wv
d_k = q.shape[-1]
scores = (q @ k.T) / np.sqrt(d_k)
weights = softmax(scores, axis=-1)
return weights @ v, weights
np.random.seed(0)
x = np.random.randn(4, 8) # n=4, d_model=8
wq = np.random.randn(8, 4) # d_k=4
wk = np.random.randn(8, 4)
wv = np.random.randn(8, 4)
out, weights = attention(x, wq, wk, wv)
print(out.shape, weights.shape)
注意力权重的 Top-k 观察(PyTorch)
import torch
import torch.nn as nn
torch.manual_seed(0)
x = torch.randn(1, 6, 32)
mha = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
_, attn = mha(x, x, x, need_weights=True, average_attn_weights=True)
topk_vals, topk_idx = torch.topk(attn[0, 0], k=3)
print(topk_vals, topk_idx)
这个小实验可以用来验证“关注分布”是否合理:若 top-k 永远只集中在局部位置,
可能说明模型没有学到全局依赖,或者数据本身缺少长程关系。
工程应用场景(含最小代码片段)
- 文本分类(Encoder-only):用 Transformer 编码句子后接分类头。
encoder = nn.TransformerEncoderLayer(d_model=32, nhead=4, batch_first=True)
clf = nn.Linear(32, 2)
seq = torch.randn(4, 10, 32)
encoded = encoder(seq)
logits = clf(encoded[:, 0]) # 取 CLS 位置
这一模式的关键是确定聚合位置(CLS 或平均池化),以及确保 mask 与 padding 对齐。
- 机器翻译(Encoder-Decoder):解码器对编码器输出做交叉注意力。
encoder_layer = nn.TransformerEncoderLayer(d_model=32, nhead=4, batch_first=True)
decoder_layer = nn.TransformerDecoderLayer(d_model=32, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=2)
src = torch.randn(2, 8, 32)
tgt = torch.randn(2, 6, 32)
mem = encoder(src)
out = decoder(tgt, mem)
翻译场景需要同时处理源序列 mask 与目标序列因果 mask,两者缺一不可。
- 视觉 Transformer(Patch Embedding):把图像切块后做注意力。
patches = torch.randn(2, 196, 32) # 14x14 patches, dim=32
encoder = nn.TransformerEncoderLayer(d_model=32, nhead=4, batch_first=True)
encoded = encoder(patches)
ViT 的核心在于 patch embedding 的尺寸与 stride 选择,直接影响 token 数 n。
推理加速:KV Cache(自回归解码)
在生成任务中,解码长度从 1 增长到 n。不使用 KV cache 时每步都要重新计算全部注意力。
KV cache 通过缓存历史 K/V,每步只计算新 token 的投影,使复杂度从 O(n^2) 变为近似 O(n)。
import torch
def step_decode(x_t, k_cache, v_cache, wq, wk, wv):
q_t = x_t @ wq
k_t = x_t @ wk
v_t = x_t @ wv
k_cache = torch.cat([k_cache, k_t], dim=0)
v_cache = torch.cat([v_cache, v_t], dim=0)
scores = (q_t @ k_cache.T) / (q_t.shape[-1] ** 0.5)
attn = torch.softmax(scores, dim=-1)
out = attn @ v_cache
return out, k_cache, v_cache
d_model = 8
x_t = torch.randn(1, d_model)
wq = torch.randn(d_model, d_model)
wk = torch.randn(d_model, d_model)
wv = torch.randn(d_model, d_model)
k_cache = torch.empty((0, d_model))
v_cache = torch.empty((0, d_model))
out, k_cache, v_cache = step_decode(x_t, k_cache, v_cache, wq, wk, wv)
print(out.shape, k_cache.shape, v_cache.shape)
训练与推理的复杂度差异
- 训练:每个 batch 需要全量
n x n注意力,复杂度O(n^2)。 - 推理:自回归时只新增 1 个 token;使用 KV cache 后每步开销近似
O(n)。
当n=2048时,训练阶段的注意力矩阵规模约 4.2M,而推理阶段单步只需要2048的点积。
替代方案与取舍
- RNN:长依赖差、难并行,但参数少、适合短序列(
n<=128时更轻量)。 - CNN:并行强但感受野受限,需要堆叠多层(感受野需层数
L≈n/k)。 - Transformer:全局依赖强、并行高,但
O(n^2)成本高。
取舍表(时间/空间)
| 方法 | 时间复杂度 | 空间复杂度 | 优势 | 代价 |
|---|---|---|---|---|
| RNN | O(n * d^2) | O(n * d) | 参数少、短序列快 | 串行、长依赖差 |
| CNN | O(n * k * d^2) | O(n * d) | 并行好 | 感受野增长慢 |
| Transformer | O(n^2 * d) | O(n^2) | 全局依赖强 | 长序列成本高 |
如果任务是 n<=128 的短序列分类,RNN 或轻量 CNN 往往更省显存;
而当 n>=512 且需要跨句依赖时,Transformer 的优势更明显,尤其在多 GPU 并行训练下。
迁移路径(Skill Ladder)
- 下一步:学习高效注意力(FlashAttention、Sparse Attention)。
- 更复杂问题:长文本建模(
n>=16k)、检索增强、跨模态对齐。 - 工程深化:实现 KV cache、混合精度与算子融合,理解吞吐/延迟的权衡。
常见问题与注意事项
- mask 错误:自回归任务必须使用因果 mask,否则会“看未来”。
- 位置编码缺失:没有位置编码会破坏顺序信息,模型容易退化成无序集合。
- 维度错配:
d_model必须能被num_heads整除,否则无法分头。 - 缩放缺失:缺少
1/sqrt(d_k)时大d_k会导致 softmax 饱和。 - 投影维度误设:例如
d_model=512, num_heads=6时d_k不是整数,权重无法按头切分。
Padding Mask 示例
import torch
seq = torch.randn(2, 5, 32)
pad_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]]).bool()
attn_mask = ~pad_mask # True 表示需要屏蔽
attn_mask = attn_mask.unsqueeze(1).repeat(1, 5, 1)
print(attn_mask.shape)
padding mask 的维度错误是常见 bug,尤其在批量输入与多头广播时。
最佳实践与建议
- 使用 LayerNorm + residual 保障深层稳定性。
- 长序列任务优先考虑高效注意力方案。
- 训练时关注数值稳定(softmax、FP16 下溢/上溢)。
- 基于原论文的基线配置:
d_model=512,h=8,d_ff=2048,便于复现和调参起点。 - 小模型优先减少
d_model而不是h,以避免头维度过小(例如d_model=256, h=8时d_k=32)。
实现自检清单(Debug Checklist)
- 输入张量形状是否是
[B, n, d_model],并且d_model % h == 0。 - 计算
QK^T时是否是[B, h, n, n],mask 是否广播到该形状。 - softmax 维度是否正确(应在最后一维
n上)。 - FP16 下是否使用了数值稳定技巧(减去 max)。
- 解码时是否启用因果 mask 与 KV cache,避免重复计算。
小型消融实验建议(验证关键设计)
- 去掉位置编码:观察性能下降幅度,通常会出现明显退化。
- 头数对比:
h=4vsh=8,保持d_model不变,看是否过拟合或欠拟合。 - 缩放因子:移除
1/sqrt(d_k)后比较 loss 曲线是否变得不稳定。 - FFN 宽度:
d_ff=2*d_modelvs4*d_model,观察表达能力与计算成本的折中。 - 序列长度裁剪:
n=512与n=1024对比,确认长序列是否真的带来收益。
小结 / 结论
- Transformer 把序列建模转换为“全局相似度聚合”。
- 多头注意力让模型并行学习多种关系。
- 位置编码解决注意力的顺序盲点。
- 算法核心是
softmax(QK^T/sqrt(d_k))V。 - 工程瓶颈来自
O(n^2),需结合高效实现。 - 真正的工程难点在显存与带宽,需要算子级优化与策略性近似。
- 做到可复现的关键是形状与 mask 的一致性检查。
参考与延伸阅读
- Vaswani et al., Attention Is All You Need (2017) - https://arxiv.org/abs/1706.03762
- The Annotated Transformer - https://nlp.seas.harvard.edu/annotated-transformer/
- FlashAttention - https://arxiv.org/abs/2205.14135
行动号召(CTA)
从最小实现开始,亲手把一个注意力块跑起来;然后对比 RNN/CNN 的效果与训练速度,写下你的观察。