副标题 / 摘要
不做横向罗列,而是用两个核心概念深入解释:依赖路径长度与资源复杂度。
把“路径长度”当作信息能否有效传递的尺度,把“资源复杂度”当作可训练性的硬约束。
看懂它们,就能判断 CNN/RNN/LSTM/Transformer 在什么场景最合适,并能进行可量化的取舍。
- 预计阅读时长:约 18 分钟
- 标签:
cnn、rnn、lstm、transformer - SEO 关键词:CNN, RNN, LSTM, Transformer
- 元描述:从路径长度与复杂度两条主线系统对比 CNN、RNN、LSTM 与 Transformer。
目标读者
- 想快速理解主流神经网络结构差异的初学者
- 需要做模型选型的工程实践者
- 关注序列建模与多模态扩展的开发者
背景 / 动机
模型结构的选择本质上回答两个问题:
- 信息在序列里传多远、传多久(依赖路径长度)
- 计算与显存能否支撑(资源复杂度)
这篇文章只围绕这两条主线深入讲透,避免“平铺式扩展”。
举个具象例子:当 n=1024 时,RNN 需要 1024 次顺序步进才能完成一次前向;
Transformer 在 6~12 层内即可完成全局交互,但注意力矩阵有 n^2=1,048,576 个元素。
这两个“硬事实”几乎决定了:你要么被 路径长度 卡住,要么被 显存/吞吐 卡住。
忽略任意一条主线,都会让模型在性能或成本上失衡。
快速掌握地图(60-120s)
- 问题形态:图像/网格 → CNN;序列 → RNN/LSTM/Transformer。
- 核心差异:RNN/LSTM 的路径长度随
n增长;Transformer 路径长度接近 1,但成本随n^2爆炸。 - 何时使用/避免:
n<=256且低算力 → LSTM/RNN;n>=512且需并行 → Transformer;纯视觉 → CNN。 - 复杂度关键词:CNN
O(HWk^2);RNNO(n d^2)串行;LSTMO(4 n d^2);TransformerO(n^2 d)。 - 常见坑:忽略
n^2显存、误判依赖范围、mask/形状错配。
大师级心智模型
- 核心抽象:把“序列建模”看作“在计算图上进行信息路由”。路径长度决定信息是否能抵达,资源复杂度决定能否承载。
- 问题家族:局部连接(CNN)、链式传播(RNN/LSTM)、全局相似度聚合(Transformer),本质上是不同的图结构与最短路长度。
- 同构模板:
信息路由 = 聚合(邻居)。RNN 是线性链邻居,CNN 是固定半径邻居,注意力是全连接邻居。 - 关键不变量:若最短路径
L随n增长,长依赖学习将受到梯度衰减或延迟;若交互数为n^2,则显存与时间成本不可避免。
核心概念与术语(本篇只深挖两个)
- 依赖路径长度与并行性:决定“长依赖能否有效建模”。
- 资源复杂度(时间/显存)随 n 的增长:决定“是否能训练/部署”。
关键术语定义(后文反复使用):
- 路径长度 L:计算图中从位置
i到j的最短边数。 - 并行步数 S:一次前向需要的顺序步数。RNN 的
S≈n,CNN/Transformer 的S≈层数。 - 感受野 R:CNN 在输入空间可覆盖的跨度,
R = 1 + (k-1)L(无空洞时)。 - 序列长度 n / 隐向量维度 d:复杂度中的主导变量。
这四个量足以写出“路径长度”和“资源复杂度”的核心公式。
一个直接可用的估算式是:
若每层只能连接半径 r 的邻居,则跨越距离 d 的依赖需要L >= ceil(d / r)。
例如 r=2, d=256 时,L>=128,这在训练深度与梯度上都很吃力。
这个公式把“依赖跨度”与“层数需求”直接连在了一起。
问题抽象(输入/输出)
- 图像输入:
X ∈ R^{B x C x H x W},输出为分类/检测 logits。 - 序列输入:
X ∈ R^{B x n x d},输出为每步预测或序列表示。 - 优化目标:在算力/显存预算下最大化准确率与吞吐,同时满足延迟要求。
典型约束(工程上经常遇到的区间):
- 序列长度:
n ∈ [128, 8192],其中n>=1024进入“长序列”区域。 - 显存预算:单卡 16~80GB;
n>=4096时全量注意力经常触发 OOM。 - 延迟目标:在线推理常要求
P95 < 200ms,这会放大串行结构的劣势。
可行性与下界直觉
路径长度下界:如果每层只允许连接半径 r 的邻居(RNN 的 r=1,CNN 的 r=(k-1)/2),
则跨越距离 d 的依赖至少需要 L >= ceil(d / r) 层。
例:k=3 的 1D CNN,r=1,要覆盖 d=512 需要 L>=512 层;
即便改成 k=5,r=2,也要 L>=256 层,深度成本依旧极高。
注意力下界:全量注意力要计算任意 i,j 的相似度,
这意味着至少需要 Ω(n^2) 级别的交互或内存读写。
除非你主动丢弃一部分交互(窗口、稀疏、近似),否则不可能突破这个上界。
一个常用的折中是先下采样再注意力:
如果把序列长度从 n=2048 压到 n=1024,注意力成本会下降到 1/4;
但每个 token 代表的信息范围也被放大,等价于改变“有效路径长度”。
这说明你永远在两条主线上做权衡:要么压缩长度,要么付出平方成本。
朴素基线与瓶颈
- 基线 1:RNN 直接建模长序列
当n=1024时需要 1024 次顺序步进,GPU 利用率低;
反向传播要保存所有中间状态,训练耗时显著上升。 - 基线 2:浅层 CNN 覆盖长依赖
k=3, L=8的感受野仅R=17,对n=512任务几乎等于“看不到全局”。
想靠堆深度补足感受野,参数量与训练时间会迅速膨胀。
即使每步计算很轻,串行步数仍会决定延迟:
若单步 0.3ms,n=512 的 RNN 前向耗时约 154ms;
而 Transformer 的顺序步数仅是层数(例如 6 层 ≈ 1.8ms)。
这也是“基线可用但难扩展”的现实原因。
关键观察
“依赖”不是时间顺序本身,而是位置之间的关联强弱。
如果能在同一层中让所有位置互相“看见”,路径长度就可以从 O(n) 下降到 O(1);
代价是相互作用从 O(n) 变成 O(n^2),即资源复杂度提升。
深挖概念一:依赖路径长度与并行性(PDKH)
1) 问题重述(Pólya)
如果位置 i 的信息要影响位置 j,它必须沿计算图传播。
传播路径越长,梯度越容易衰减,训练越慢。
可以把“层-位置”看作节点,把“可达连接”看作边:
路径长度 L 就是最短路长度。
路径短意味着信息可以快速聚合,路径长意味着信息要经过多次变换才能抵达。
这就是为什么“路径长度”几乎直接决定了“长依赖能否学到”。
2) 最小示例(Bentley)
设序列长度 n=6,要让位置 1 影响位置 6:
- RNN/LSTM:必须逐步传递,路径长度 = 5。
- CNN(k=3, L=2):感受野为
1+(k-1)L=5,仍无法覆盖 6。
需要L=3层才覆盖全部。 - Transformer:同层任意位置直接交互,路径长度 = 1。
路径长度与并行度对比表
| 结构 | 路径长度 L(依赖距离 d) | 并行度 | 备注 |
|---|---|---|---|
| RNN | L=d | 低 | 串行依赖,难并行 |
| LSTM | L=d | 低 | 门控缓解梯度衰减 |
| CNN | L>=ceil((d-1)/(k-1)) | 中-高 | 依赖于层数与核宽 |
| Transformer | L=1 | 高 | 全局注意力并行 |
并行步数示例(S)
假设 n=1024,对单样本前向来说:
- RNN/LSTM:需要 1024 次顺序步进,
S≈1024。 - Transformer(6 层):需要 6 次顺序步进,
S≈6。 - CNN(20 层):需要 20 次顺序步进,
S≈20。
这解释了为什么 RNN 在 GPU 上往往吞吐最低:它不是算子慢,而是串行步数太多。
粗略估算:如果单步计算约为 0.2ms,S=1024 的 RNN 一次前向需要约 205ms;
而 S=6 的 Transformer 仅需约 1.2ms(不计通信与内存瓶颈)。
这也是“路径长度”直接决定吞吐的现实体现。
Worked Example:长依赖需要多少层 CNN?
若要覆盖 d=512 的依赖,k=3 卷积需满足L >= (d-1)/(k-1) = 255.5,至少 256 层。
这解释了为什么 CNN 在长序列任务上常被注意力替代。
微型追踪:n=4 的依赖传播
设序列为 [x1, x2, x3, x4],目标是让 x1 影响 x4:
- RNN:
x1 -> h2 -> h3 -> h4,路径长度为 3。 - CNN(k=3, L=2):第 1 层
x1只能影响{x1,x2},第 2 层才影响x3,仍达不到x4。 - Transformer:
x1可以直接参与x4的注意力加权,路径长度为 1。
这个极小例子说明:路径长度的差异从最小规模就已经出现,不是大规模才有的问题。
3) 不变量/契约(Dijkstra/Hoare)
若模型要稳定捕捉距离为
d的依赖,计算图必须提供长度L<=d的路径。
当L与n同阶增长,长依赖训练难度显著升高。
梯度衰减的数学直觉
RNN 的梯度链路是多个雅可比矩阵的连乘:∂h_t/∂h_{t-k} = Π_{i=t-k+1}^{t} J_i。
当 k 很大时,连乘会迅速缩小或爆炸,这就是“长依赖难学”的根源。
LSTM 通过 c_t 的门控通道让梯度更“直通”,但路径仍然是 O(n)。
一个直观的数值例子:假设每步的平均谱半径约为 0.9,
那么 100 步后的梯度规模约为 0.9^100 ≈ 0.000026;
即便提高到 0.99,0.99^100 ≈ 0.366,仍然在持续衰减。
这说明 路径长度越长,模型必须越依赖门控或残差来维持可训练性。
依赖跨度示例(为什么长依赖难)
复制任务:输入序列长度 n=512,模型需要在最后输出第 1 个 token。
RNN/LSTM 必须将信息连续传递 511 次;
Transformer 可以在一次注意力中直接连通开头和结尾。
括号匹配:匹配最外层括号往往需要跨越整段序列。
这类任务对路径长度极其敏感,往往更偏向 Transformer。
依赖跨度的估计方法
- 文本:统计同一句内的依赖跨度(通常 < 128),
若跨段落依赖频繁出现,跨度可接近 512 甚至更大。 - 时间序列:用自相关长度估计“有效记忆跨度”。
- 视觉序列/视频:依赖跨度常由帧间物体轨迹决定。
经验上可用 P90 作为“安全跨度”:
若 90% 的依赖都小于 256,则优先考虑 CNN/LSTM;
若 P90 已经超过 512,Transformer 的优势通常更稳。
当你能估计出一个“典型跨度 d”,选型就有了方向。
4) 形式化(Knuth)
- RNN/LSTM:路径长度
L = |i-j|。 - 1D CNN:感受野
R = 1 + (k-1)L,要覆盖d需L >= (d-1)/(k-1)。 - Transformer:单层即可连接任意位置,
L = 1。
并行度可以用“需要多少顺序步”理解:
RNN/LSTM 需要 n 次顺序步,CNN/Transformer 主要受层数影响。
这也是 Transformer 训练吞吐高的直接原因。
5) 正确性草证(Dijkstra/Hoare)
- RNN 的状态只从
t-1传到t,所以要跨d步必须经过d次传递。 - CNN 每层扩展
(k-1)的感受野,叠L层得到1+(k-1)L。 - Transformer 的注意力矩阵直接构建全局依赖,因此路径长度为 1。
结构逐一深化(路径长度视角)
CNN:
感受野增长是线性的。以 k=3 为例,感受野序列为 3、5、7、9…
当 L=6 时感受野只有 13;当 L=20 时也只有 41。
这说明 CNN 在“长依赖任务”上需要非常深的网络才能覆盖全局。
若引入空洞卷积(dilation),感受野公式可写为R = 1 + (k-1) * Σ d_l。
例如 4 层空洞卷积,d_l = [1, 2, 4, 8],则R = 1 + 2 * (1+2+4+8) = 31。
这比普通卷积大,但距离 n=512 仍然相去甚远。
它改善了路径长度,却没有从根本上改变“线性增长”的本质。
RNN:
路径长度等于时间步数。n=512 时最远依赖需要 511 次状态传递。
即使每步计算很轻,长链路也会放大梯度衰减问题。
LSTM:
门控让“有效记忆”更稳定,但路径长度仍然是 O(n)。
工程上常用 b_f=1 这类策略延长记忆,但并不能改变路径的本质长度。
Transformer:
路径长度为 1,使长依赖建模变成“全局并行”的矩阵乘法。
代价是显存与计算开销上升(见概念二)。
LSTM 门控机制如何延长“有效记忆”
LSTM 的核心是细胞状态 c_t 与三个门控:f_t = σ(W_f [x_t, h_{t-1}])(遗忘门)i_t = σ(W_i [x_t, h_{t-1}])(输入门)o_t = σ(W_o [x_t, h_{t-1}])(输出门)c_t = f_t ⊙ c_{t-1} + i_t ⊙ tanh(W_c [x_t, h_{t-1}])
当 f_t 接近 1,c_t 就能在长序列中保持信息不衰减。
这解释了为什么 LSTM 比普通 RNN 更适合中长序列,
但它仍然无法改变“路径长度随 n 增长”的事实。
如果 f_t 的均值约为 0.95,200 步后的记忆系数约为 0.95^200 ≈ 0.00034;
即便提升到 0.99,200 步后仍只有 0.99^200 ≈ 0.133。
这说明 门控只是在延长“有效路径长度”,无法从数量级上改变路径长度。
Transformer 的“路径短”仍需要顺序信号
注意力本身是置换不变的,如果不加位置编码,
Transformer 会把序列当作无序集合。
因此路径长度为 1 并不等于“顺序问题自动解决”,
位置编码是保证顺序信息可达的必要条件。
常用的正弦位置编码形式为:PE(pos, 2i) = sin(pos / 10000^(2i/d_model))PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
它为每个位置提供多尺度的相位信息,使注意力具备顺序感知。
从图视角看,注意力矩阵 A = softmax(QK^T) 就是一个带权全连接图:A 的每一行都归一化为 1,输出是值向量的凸组合。
因此单层注意力就能把任何位置的信息路由到任意位置,
这也是路径长度被压缩到 1 的核心原因。
Worked Example:CNN 感受野增长表(k=3)
| 层数 L | 感受野 R |
|---|---|
| 1 | 3 |
| 2 | 5 |
| 3 | 7 |
| 4 | 9 |
| 8 | 17 |
| 16 | 33 |
对比 n=512 的序列长度,这些感受野仍然非常有限。
def receptive_field(k, layers):
return 1 + (k - 1) * layers
for L in [1, 2, 4, 8, 16]:
print(L, receptive_field(3, L))
6) 阈值与规模(Knuth)
- 当 依赖跨度 > 256,RNN/LSTM 通常开始吃力;
- 当 跨度 > 512,Transformer 的优势开始显著;
- 但这同时引入
n^2成本(见概念二)。
这些阈值不是“理论极限”,而是经验尺度:
在语音与短文本(n≈128~256)中,LSTM 往往还能保持稳定;
在长文档与代码(n>=512)中,路径长度成为主要瓶颈,
如果还要求高吞吐,注意力的并行优势会更明显。
7) 反例/失败模式(Bentley/Sedgewick)
如果任务是局部依赖(如 n<=128 的短文本分类),
Transformer 反而可能因过强全局建模而过拟合,
此时 LSTM/1D CNN 仍是更稳妥的选择。
例如在 n=64 的评论情感任务上,
如果训练集只有几万样本,Transformer 的参数量与自由度会明显过剩,
路径长度优势无法转化为收益,反而可能导致验证集性能下降。
8) 工程现实(Knuth)
路径长度短 ≠ 一定更好:
Transformer 必须有 位置编码 才能表达顺序;
RNN/LSTM 通过门控机制在 n=200~500 仍可保持稳定记忆。
工程上需要同时评估依赖跨度与训练成本。
在实践中常用的“补救手段”是:
- CNN 通过残差与金字塔结构扩展感受野;
- RNN/LSTM 通过截断 BPTT 控制训练成本;
- Transformer 通过相对位置编码增强局部性。
这些技巧的共同目标是:在不改变结构主线的前提下缩短有效路径。
截断 BPTT 的具体影响:当你把反向传播长度截到 256,
等价于承认“有效依赖跨度上限就是 256”。
这在语音、短文本任务上很合理,但在长文摘要或代码理解中会显著损失性能。
因此 BPTT 的截断长度其实就是“工程上的路径长度预算”。
深挖概念二:资源复杂度随 n 增长(PDKH)
1) 问题重述(Pólya)
当 n 变大时,模型还能否训练和部署?
这个问题由时间/显存复杂度决定。
把资源复杂度拆成三个维度会更清晰:
- 计算量(FLOPs):决定训练/推理速度;
- 显存占用:决定是否 OOM;
- 内存带宽:决定实际吞吐是否被读写拖慢。
Transformer 常常不是算子数量最慢,而是“内存读写”最贵。
2) 最小示例(Bentley)
设 n=2048, d_model=512, h=8:
- 注意力矩阵元素数
n^2=4,194,304。 - 单头 FP16 权重约 8 MB,8 头约 64 MB。
- 训练还需激活与梯度,峰值常见是 3~5 倍。
资源估算公式(注意力权重)
若 batch 为 B、头数为 h、dtype 为 FP16(2 bytes):memory ≈ B * h * n^2 * 2 bytes。
例如 B=4, h=8, n=2048 时:4 * 8 * 2048^2 * 2 ≈ 512 MB(仅注意力权重,不含激活与梯度)。
这个公式的放大效应非常“残酷”:
- B 翻倍 → 显存翻倍;
- n 翻倍 → 显存变为 4 倍;
- h 翻倍 → 显存翻倍。
所以“把 n 从 2k 提到 4k”常常比“把层数从 12 提到 16”更致命。
一个更实用的估算方式是先解出“可承受的 n 上限”:n_max ≈ sqrt(显存预算 / (B * h * 2 bytes))。
如果显存预算是 8GB、B=2, h=8,粗算 n_max ≈ sqrt(8GB / 32 bytes) ≈ 16k。
但考虑 48 倍峰值倍率,实际可用 4 折。n 通常要再打 3
n 与显存的量级表(单头 FP16)
| n | n^2 元素 | 约显存 |
|---|---|---|
| 512 | 262,144 | ~0.5 MB |
| 1024 | 1,048,576 | ~2 MB |
| 2048 | 4,194,304 | ~8 MB |
| 4096 | 16,777,216 | ~32 MB |
| 8192 | 67,108,864 | ~128 MB |
将以上数值乘以 B 与 h,即可得到真实占用。
还要考虑内存带宽:
以 n=2048 为例,单头注意力权重约 8MB;12 层就是 96MB 的读写量。
训练时还要读写梯度与激活,实际带宽压力会进一步放大,
这也是为什么 FlashAttention 通过“少读写”就能带来显著加速。
def attn_memory_mb(n, h=8, batch=4, bytes_per_elem=2):
return batch * h * n * n * bytes_per_elem / (1024 ** 2)
for n in [512, 1024, 2048, 4096]:
print(n, f\"{attn_memory_mb(n):.1f} MB\")
3) 不变量/契约(Dijkstra/Hoare)
只要使用全量注意力,就必须显式或隐式计算
n^2级别的交互。
不引入近似,就无法突破这一开销。
4) 形式化(Knuth)
- CNN:
O(HWk^2)(或序列化为O(n k d^2)) - RNN:
O(n d^2)(串行) - LSTM:
O(4 n d^2) - Transformer:
O(n^2 d)+O(n d^2)(FFN)
计算量粗估(以 n=1024, d=512 为例)
- RNN:每步
d^2,总计1024 * 512^2 ≈ 268M乘加。 - LSTM:约 4 倍,
≈ 1.07B乘加。 - Transformer 注意力:
n^2 * d_k,若d_k=64,约1024^2 * 64 ≈ 67M乘加,
但还需 FFN:2 * n * d * d_ff(d_ff=2048时约 2.1B)。
结论:Transformer 的瓶颈常在 FFN 与注意力矩阵的内存,而非单纯算子数量。
注意力与 FFN 的主导区间可用一个简单比较得到:n^2 d(注意力) vs 2 n d d_ff(FFN)。
约简后得到阈值 n > 2 d_ff 时注意力开始主导。
若 d_ff=2048,则当 n>4096,注意力成本才会明显压过 FFN。
这解释了“中等长度时 FFN 是瓶颈,超长序列时注意力是瓶颈”。
5) 正确性草证(Dijkstra/Hoare)
Transformer 的 QK^T 必须计算每对 token 相似度,
因此时间和显存随 n^2 增长是不可避免的。
6) 阈值与规模(Knuth)
n<=2048:全量注意力通常可接受。2048 < n <= 8192:建议 FlashAttention 或分块注意力。n>8192:需要稀疏/线性注意力或检索增强。
一个可操作的上限估计:
若单卡 24GB、B=2, h=8,仅注意力权重在 n=4096 时约 512MB,
结合激活与优化器后容易逼近 16~24GB。
这意味着“n=4k 已经是单卡训练的警戒线”。
结构逐一深化(资源复杂度视角)
CNN:
计算量主要随输入分辨率 H*W 增长,显存与 H*W 同阶。
在视觉任务中,参数共享让 CNN 的参数量相对可控。
RNN/LSTM:
计算量随 n 线性增长,但必须串行执行;显存相对稳定。
当 n 变大时,训练时间常成为瓶颈而非显存。
Transformer:
显存和计算随 n^2 增长,最敏感。
当 n 翻倍时,注意力矩阵规模变为 4 倍,训练成本急剧上升。
显存组成(训练阶段的主要项)
- 注意力权重:
B * h * n^2 - Q/K/V 激活:
3 * B * n * d - FFN 激活:
B * n * d_ff - 优化器状态(Adam):约 2 倍参数量
这意味着即便权重不大,激活与优化器状态也会把显存拉到很高。
预算工作表(粗略估算)
以 B=2, n=2048, d=512, h=8, d_ff=2048 为例,做一个粗略预算:
- 注意力权重:
B*h*n^2*2 bytes ≈ 256 MB - Q/K/V 激活:
3*B*n*d*2 bytes ≈ 12 MB - FFN 激活:
B*n*d_ff*2 bytes ≈ 16 MB - 参数与优化器状态(Adam):每参数约 12 bytes(权重+动量+方差)
实际训练中还要考虑梯度缓存与临时张量,
因此“粗算 300MB”往往会变成“峰值 1GB+”。
这也是为什么全量注意力在长序列下极其昂贵。
一个经验性的“峰值倍率”是 4~8 倍:
参数 + 梯度 + 优化器状态 + 激活缓存叠加后,
你很难仅凭“参数量”判断显存是否足够。
这也解释了为什么许多显存 OOM 并非来自模型权重,而是来自激活与注意力矩阵。
7) 反例/失败模式(Bentley/Sedgewick)
在显存只有 16GB 的单卡上强行使用 n=8k 全量注意力,
极易 OOM 或必须极小 batch,训练效率反而更差。
8) 工程现实(Knuth)
常见解决路径:FlashAttention、分块注意力、KV cache、梯度检查点。
这些方法牺牲一定实现复杂度,换取可训练性和吞吐。
训练 vs 推理的复杂度差异
- 训练:全量注意力需要
n^2的矩阵,显存与计算都高。 - 推理(自回归):使用 KV cache 后,每步仅与历史
K/V交互,
单步复杂度近似O(n),显存也更可控。
这也是为什么 Transformer 在推理时常“勉强可用”,
但训练时需要更强的算力与更精细的内存优化。
KV cache 的显存规模可用公式估计:memory ≈ B * h * n * d_k * 2 bytes。
若 B=1, h=8, n=4096, d_k=64,则约为 4 MB;
但若 B=8 或 n=16k,显存会线性膨胀,需要提前规划。
Worked Example:n=1024 与 n=4096 的成本差异
以单头 FP16 为例:n=1024 时注意力权重约 2 MB;n=4096 时约 32 MB,直接增加 16 倍。
如果 B=4, h=8,n=4096 的注意力权重单项就超过 1 GB,
这还不包含梯度与激活。
这说明“长度翻倍”不只是线性增量,而是几何级别的成本跳跃。
复杂度与规模总结(对两条主线做汇总)
| 结构 | 路径长度 L | 顺序步数 S | 时间复杂度(主导项) | 显存复杂度(主导项) |
|---|---|---|---|---|
| CNN | ~(d/(k-1)) | ≈层数 | O(n k d^2) 或 O(HWk^2) | O(n d) |
| RNN | d | ≈n | O(n d^2) | O(n d) |
| LSTM | d | ≈n | O(4 n d^2) | O(n d) |
| Transformer | 1 | ≈层数 | O(n^2 d) + O(n d^2) | O(n^2) + O(n d) |
这张表把“路径长度”和“资源复杂度”放在同一平面上:
路径短的结构(Transformer)在资源上更昂贵;
资源稳定的结构(RNN/LSTM)在路径长度上更吃亏。
还要记住“顺序步数 S”是一种硬上限:
即便加机器,S 也很难通过并行彻底消除。
例如 RNN 在 n=1024 时需要 1024 次顺序步进,
多卡只能分摊批量,无法缩短这 1024 次“时间步”。
常量因素与工程现实(与两条主线相关)
- 算子粒度差异:RNN 的矩阵乘法粒度小、次数多,GPU 吞吐难以打满;
Transformer 的矩阵乘法粒度大、次数少,但内存带宽压力大。
这解释了为什么“理论 FLOPs 不高”并不意味着“实际训练快”。 - 精度与显存:FP16/BF16 可把注意力权重与激活显存减半,
例如n=2048时注意力权重从 ~8MB 降到 ~4MB(单头)。
但路径长度与依赖跨度不会因此改变。 - 残差与缓存:残差连接缩短有效路径,但也会让激活缓存变大;
路径越短的结构往往越依赖缓存与带宽,工程上需要更精细的内存规划。
Worked Example(Trace):同一任务的路径与成本
设一个玩具任务:序列长度 n=8,希望让第 1 个 token 的信息影响第 8 个 token。
我们用同一目标对四种结构做“路径与成本”的同步对照:
RNN/LSTM
路径长度L = 7,必须经过 7 次状态传递。
顺序步数S = 8,无法并行。CNN(k=3)
感受野R = 1 + 2L。L=1 -> R=3,L=2 -> R=5,L=3 -> R=7,L=4 -> R=9。
只有L>=4才能覆盖x1 -> x8的依赖。Transformer(1 层)
路径长度L = 1,x1可直接影响x8。
但注意力矩阵是n^2=64个元素(每头)。
这在n=8时极小,但当n=2048时会跃迁到 4,194,304。
这个例子强调:路径优势在小规模就成立,资源劣势在大规模才爆发。
因此选型必须同时关注“依赖跨度”和“序列长度”。
实践指南 / 步骤(选型流程)
- 估计依赖跨度:对文本可用依存跨度或句间跨度;若大量跨段落依赖,通常
d>=512。 - 估计序列长度 n:给出分位数(P50/P90/Max),因为
n决定n^2成本。 - 评估预算:用公式
B * h * n^2 * 2 bytes粗估注意力显存,并留出 4~8 倍峰值余量。 - 评估并行需求:若在线推理需要
P95 < 200ms,串行结构会被优先排除。 - 先跑轻量基线:用小 CNN/LSTM 验证“数据是否可学”,设定最低准确率门槛。
- 再升级结构:当
d大且预算足够时转向 Transformer;预算不足时考虑局部注意力或混合结构。
可以把流程进一步压缩成两个“必须回答”的问题:
- 依赖跨度问题:最远依赖
d是否明显大于 256? - 资源预算问题:显存是否能承受
B * h * n^2的注意力矩阵?
只要这两个问题有明确答案,模型选型通常就不会偏离太远。
一个简化的“2x2 决策矩阵”可以直接落地:
- d 小 + 预算小 → CNN/LSTM;
- d 小 + 预算大 → 小型 Transformer 或 CNN;
- d 大 + 预算小 → 局部注意力或混合结构;
- d 大 + 预算大 → 全量 Transformer。
它把两条主线变成可执行的工程判断。
如果你对 d 缺乏把握,可以先训练一个小型注意力模型,
统计注意力跨度分布,再决定是否需要全量注意力。
决策准则(Selection Guide)
- 依赖跨度门槛:若
d<=128,CNN/小型 RNN 往往足够;d>=512时优先考虑 Transformer。 - 序列长度门槛:
n<=256时全量注意力成本低;n>=2048必须提前做显存预算。 - 显存预算门槛:单卡 24GB 下,
B=2, h=8, n=4096的注意力权重就接近 512MB,
结合激活与优化器后容易突破 16~24GB。 - 实现复杂度容忍度:若团队不具备算子优化能力,优先用成熟实现(如标准 Transformer + FlashAttention)。
这些门槛不是绝对标准,但它们提供了“能否跑起来”的第一性过滤。
可运行示例(最小对比)
下面的代码只做“结构层面的最小对比”,不涉及训练与损失函数:
它帮助你感受 CNN 的局部聚合、LSTM 的顺序状态传递 与 Transformer 的全局交互。
运行后可观察各模块的输出形状,直观理解它们如何处理不同输入形态。
import torch
import torch.nn as nn
# CNN
cnn = nn.Sequential(
nn.Conv2d(3, 16, 3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(16, 10),
)
img = torch.randn(2, 3, 32, 32)
print("cnn:", cnn(img).shape)
# LSTM
lstm = nn.LSTM(input_size=16, hidden_size=32, batch_first=True)
seq = torch.randn(2, 5, 16)
out, _ = lstm(seq)
print("lstm:", out.shape)
# Transformer
encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=32, nhead=4, batch_first=True),
num_layers=2,
)
seq = torch.randn(2, 6, 32)
print("transformer:", encoder(seq).shape)
解释与原理(归纳到两条主线)
- 依赖路径:Transformer 最短,RNN/LSTM 最长,CNN 取决于层数与核大小。
- 资源成本:Transformer 最贵(
n^2),RNN/LSTM 计算量大但显存稳定。
其余差异(如门控、位置编码)都可以看作对这两条主线的补强手段。
如果把模型放在二维坐标中理解:
- 横轴是“路径长度”(越短越靠左),
- 纵轴是“资源复杂度”(越低越靠下)。
RNN/LSTM 位于“下方但靠右”,Transformer 位于“上方但靠左”,
CNN 的位置则取决于核大小与层数。
这也是为什么在实际工程里经常需要混合结构:
用轻量局部模块保证资源预算,用少量全局模块补足路径长度。
更重要的是,两条主线并非独立:
你降低 n 会压缩资源成本,但也会扩大单 token 的“语义覆盖范围”;
你增加层数能缩短路径,但也提高计算量与训练难度。
因此真正的工程解法往往是“压缩 + 局部 + 少量全局”的组合。
工程应用场景(仅保留与两条主线相关的 3 个)
- 短文本分类(n<=128):依赖跨度小 → LSTM/1D CNN 通常足够。
当n<=128时,注意力矩阵只有16k级别元素,Transformer 的优势很难体现。 - 长文摘要(n>=1024):依赖跨度大 → Transformer,但需考虑
n^2成本。n=2048时注意力权重已达 4.2M 元素,需要 FlashAttention 或分块策略。 - 流式语音识别:低延迟要求 → CNN+LSTM 的混合结构更稳。
因为串行步数对实时性更敏感,局部 CNN 可先压缩,再用 LSTM 维持中程依赖。
替代方案与取舍(只围绕两条主线)
- 全量注意力 vs 局部注意力:
全量注意力是O(n^2),局部窗口注意力是O(n w)。
当n=2048, w=256时,成本约减少 8 倍,但路径长度约增加到n/w≈8。
换句话说:你是在用“更长路径”换“更低显存”。 如果依赖跨度d=2048,窗口大小w=256,
需要至少L>=ceil(d/w)=8层才能让信息跨越全局。
这会把“资源优势”转化为“深度与训练难度”的额外成本。 - 加深 CNN vs 引入注意力:
k=3的 CNN 要覆盖d=512需 256 层;
引入注意力可以把路径长度降到 1,但显存成本变为n^2。 - RNN/LSTM vs Transformer:
前者是线性资源但长路径;后者是短路径但平方资源。
当n小而d不大时,RNN/LSTM 的实际性价比常更好。 - 增大卷积核 vs 增加层数:
增大k可以缩短所需层数,但计算量增加为O(k d^2);
增加层数可以保持小核,但路径长度依旧增长,且训练更深网络更难。
迁移路径(Skill Ladder)
- 先掌握局部结构:理解 CNN 的感受野与路径长度。
- 再掌握链式传播:理解 RNN/LSTM 的状态传递与梯度衰减。
- 最后掌握全局路由:理解 Transformer 的全局交互与
n^2成本。 - 实战扩展:当
n极长或预算有限,尝试局部注意力或混合结构。
常见问题与注意事项
- 低估
n^2显存会导致训练无法展开。 - 位置编码缺失会使 Transformer 无法表达顺序。
- LSTM 隐状态过大在小数据上容易过拟合。
- 用浅层 CNN 处理长依赖会“看不到全局”,效果往往无法提升。
- 截断 BPTT 太短会把有效依赖压到 128/256,长依赖任务会明显掉分。
- 把
n翻倍但不增加数据量时,模型容易过拟合且显存激增。 - 只看参数量不看激活量,常常低估 Transformer 的真实显存需求。
- 只用平均
n估算成本会踩坑:P90 若是 2 倍,注意力显存就是 4 倍。 - 大量 padding 会把“无效 token”也送进注意力,建议使用长度分桶。
最佳实践与建议
- 把“依赖跨度”当作首要判断依据。
- 把“显存/吞吐预算”当作第二判断依据。
- 先用轻量基线验证数据可学,再升级结构。
- 如果两者冲突(跨度大但预算小),优先考虑稀疏/分块注意力或混合结构。
- 在长序列任务上,先做
n的分位数统计,再决定是否需要全量注意力。 - 调参时优先调
n和h,它们对显存与吞吐的影响最大。 - 记录一次完整训练的“峰值显存”,而不是只看模型参数量。
- 超长序列先做分块/降采样实验,观察准确率与依赖跨度是否被削弱。
- 训练日志里固定记录 P90
n、峰值显存与吞吐,作为结构调整依据。
小结 / 结论
- CNN 适合局部模式与视觉网格数据。
- RNN/LSTM 适合中短序列与低算力场景。
- Transformer 擅长长依赖与并行训练,但
n^2成本高。 - 选型的关键是两件事:依赖路径长度 与 资源复杂度。
- 当
d大于 512 时,路径长度往往决定上限;当n超过 2048 时,显存决定上限。 - 如果预算不足,优先缩短
n或使用局部注意力,再谈更深的模型结构。
参考与延伸阅读
行动号召(CTA)
用同一数据集分别跑一个 LSTM 和一个 Transformer,比较依赖跨度与显存成本,写下你的结论。