副标题 / 摘要
多头注意力并不是“多次重复”,而是让模型在不同子空间中同时关注不同关系。本文从原理、复杂度与工程场景出发解释其必要性,并给出最小 PyTorch 示例。
- 预计阅读时长:14~18 分钟
- 标签:
multi-head-attention、attention、transformer - SEO 关键词:多头注意力, Multi-Head Attention, Transformer
- 元描述:系统解释多头注意力机制的优势与工程取舍,含最小示例。
目标读者
- 想理解 Transformer 关键设计的入门读者
- 需要做模型结构选型的工程实践者
- 关注注意力可解释性与效率的开发者
背景 / 动机
单头注意力只能在一个投影空间里“看关系”。
而自然语言/多模态里存在多种关系(语法、语义、位置、对齐)。
多头注意力让模型并行捕捉多种关系,提高表达能力与泛化。
核心概念
- Head(注意力头):一个独立的注意力子空间。
- 子空间投影:每个头有独立的 Q/K/V 线性投影。
- 拼接与映射:多个头输出拼接后再线性映射回模型维度。
A — Algorithm(题目与算法)
用通俗语言说明主题内容
- 单头注意力像“单一视角”。
- 多头注意力像“多视角协作”,同时关注不同关系。
基础示例(1)
- 机器翻译中,一个头关注语法对齐,另一个头关注实体对齐。
基础示例(2)
- 同一序列中,一个头关注局部邻近词,另一个头关注长距离依赖。
实践指南 / 步骤
- 选择头数
h,保持d_model % h == 0。 - 每个头在子空间
d_head = d_model / h中计算注意力。 - 拼接各头输出,线性投影回
d_model。 - 观察注意力分布是否更丰富。
可运行示例(最小多头注意力)
import torch
import torch.nn as nn
torch.manual_seed(42)
mha = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
x = torch.randn(2, 5, 32)
attn_out, attn_weights = mha(x, x, x)
print(attn_out.shape)
print(attn_weights.shape)
解释与原理
- 每个头在不同线性子空间建模关系。
- 多头输出拼接后,模型获得更丰富的特征组合。
- 这使得同一层能同时学习多种依赖模式。
C — Concepts(核心思想)
方法类型
多头注意力属于并行子空间注意力建模范式。
关键公式
单头注意力:
$ \text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^\top}{\sqrt{d}})V $
多头注意力:
$ \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) $
$ \text{MHA}(Q,K,V) = \text{Concat}(\text{head}_1, …, \text{head}_h)W^O $
解释与原理
- 通过多组投影矩阵,把“不同关系”分配给不同头。
- 拼接后线性映射,让模型融合多视角信息。
E — Engineering(工程应用)
场景 1:机器翻译对齐
- 背景:源语言与目标语言存在多种对齐关系。
- 为什么适用:不同头可以学习不同类型对齐。
- 代码示例(Python):
import torch
src = torch.randn(1, 6, 32)
tgt = torch.randn(1, 5, 32)
print(src.shape, tgt.shape)
场景 2:长文档理解
- 背景:需要同时捕捉局部上下文与全局主题。
- 为什么适用:不同头分工关注不同尺度。
- 代码示例(Python):
import torch
x = torch.randn(1, 128, 32)
print(x.mean().item())
场景 3:图文对齐
- 背景:文本需要对齐图像区域。
- 为什么适用:多头能同时关注多个视觉区域。
- 代码示例(Python):
import torch
text = torch.randn(1, 10, 32)
image = torch.randn(1, 49, 32)
score = text @ image.transpose(-2, -1)
print(score.shape)
R — Reflection(反思与深入)
- 时间复杂度:理论上仍为
O(n^2),但多头带来常数开销。 - 空间复杂度:注意力矩阵与头数成比例增长。
- 替代方案:
- 单头注意力:成本更低但表达能力弱。
- 多查询注意力(MQA/GQA):减少 KV 计算成本。
- 工程可行性:在多数 NLP/多模态任务上,多头注意力是稳健默认。
常见问题与注意事项
- 头数过多会导致每头维度过小,表示能力下降。
- 头数过少会限制多视角建模。
- 实际效果依赖
d_model与h的匹配。
最佳实践与建议
- 默认选择 8 或 12 头作为起点。
- 观察注意力可视化,确认多头是否学习到不同模式。
- 若推理成本高,考虑 GQA/MQA。
S — Summary(总结)
核心收获
- 多头注意力让模型在不同子空间并行建模关系。
- 它提升表达能力与稳定性,是 Transformer 的关键设计。
- 头数需要与维度匹配,否则会削弱效果。
- 工程上可在性能与成本间权衡。
推荐延伸阅读
- Attention Is All You Need
- Multi-Query Attention / Grouped-Query Attention 研究
- Transformer 结构优化实践
参考与延伸阅读
小结 / 结论
多头注意力不是“堆数量”,而是“并行视角”。
它让模型在同一层同时学习多种关系,是 Transformer 成功的关键。
行动号召(CTA)
从 4 或 8 个头开始实验,观察不同头数对效果与成本的影响。