副标题 / 摘要
多头注意力并不是“多次重复”,而是让模型在不同子空间中同时关注不同关系。本文从原理、复杂度与工程场景出发解释其必要性,并给出最小 PyTorch 示例。

  • 预计阅读时长:14~18 分钟
  • 标签multi-head-attentionattentiontransformer
  • SEO 关键词:多头注意力, Multi-Head Attention, Transformer
  • 元描述:系统解释多头注意力机制的优势与工程取舍,含最小示例。

目标读者

  • 想理解 Transformer 关键设计的入门读者
  • 需要做模型结构选型的工程实践者
  • 关注注意力可解释性与效率的开发者

背景 / 动机

单头注意力只能在一个投影空间里“看关系”。
而自然语言/多模态里存在多种关系(语法、语义、位置、对齐)。
多头注意力让模型并行捕捉多种关系,提高表达能力与泛化。

核心概念

  • Head(注意力头):一个独立的注意力子空间。
  • 子空间投影:每个头有独立的 Q/K/V 线性投影。
  • 拼接与映射:多个头输出拼接后再线性映射回模型维度。

A — Algorithm(题目与算法)

用通俗语言说明主题内容

  • 单头注意力像“单一视角”。
  • 多头注意力像“多视角协作”,同时关注不同关系。

基础示例(1)

  • 机器翻译中,一个头关注语法对齐,另一个头关注实体对齐。

基础示例(2)

  • 同一序列中,一个头关注局部邻近词,另一个头关注长距离依赖。

实践指南 / 步骤

  1. 选择头数 h,保持 d_model % h == 0
  2. 每个头在子空间 d_head = d_model / h 中计算注意力。
  3. 拼接各头输出,线性投影回 d_model
  4. 观察注意力分布是否更丰富。

可运行示例(最小多头注意力)

import torch
import torch.nn as nn


torch.manual_seed(42)

mha = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)

x = torch.randn(2, 5, 32)
attn_out, attn_weights = mha(x, x, x)
print(attn_out.shape)
print(attn_weights.shape)

解释与原理

  • 每个头在不同线性子空间建模关系。
  • 多头输出拼接后,模型获得更丰富的特征组合。
  • 这使得同一层能同时学习多种依赖模式。

C — Concepts(核心思想)

方法类型

多头注意力属于并行子空间注意力建模范式。

关键公式

单头注意力:

$ \text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^\top}{\sqrt{d}})V $

多头注意力:

$ \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) $

$ \text{MHA}(Q,K,V) = \text{Concat}(\text{head}_1, …, \text{head}_h)W^O $

解释与原理

  • 通过多组投影矩阵,把“不同关系”分配给不同头。
  • 拼接后线性映射,让模型融合多视角信息。

E — Engineering(工程应用)

场景 1:机器翻译对齐

  • 背景:源语言与目标语言存在多种对齐关系。
  • 为什么适用:不同头可以学习不同类型对齐。
  • 代码示例(Python):
import torch

src = torch.randn(1, 6, 32)
tgt = torch.randn(1, 5, 32)
print(src.shape, tgt.shape)

场景 2:长文档理解

  • 背景:需要同时捕捉局部上下文与全局主题。
  • 为什么适用:不同头分工关注不同尺度。
  • 代码示例(Python):
import torch

x = torch.randn(1, 128, 32)
print(x.mean().item())

场景 3:图文对齐

  • 背景:文本需要对齐图像区域。
  • 为什么适用:多头能同时关注多个视觉区域。
  • 代码示例(Python):
import torch

text = torch.randn(1, 10, 32)
image = torch.randn(1, 49, 32)
score = text @ image.transpose(-2, -1)
print(score.shape)

R — Reflection(反思与深入)

  • 时间复杂度:理论上仍为 O(n^2),但多头带来常数开销。
  • 空间复杂度:注意力矩阵与头数成比例增长。
  • 替代方案
    • 单头注意力:成本更低但表达能力弱。
    • 多查询注意力(MQA/GQA):减少 KV 计算成本。
  • 工程可行性:在多数 NLP/多模态任务上,多头注意力是稳健默认。

常见问题与注意事项

  • 头数过多会导致每头维度过小,表示能力下降。
  • 头数过少会限制多视角建模。
  • 实际效果依赖 d_modelh 的匹配。

最佳实践与建议

  • 默认选择 8 或 12 头作为起点。
  • 观察注意力可视化,确认多头是否学习到不同模式。
  • 若推理成本高,考虑 GQA/MQA。

S — Summary(总结)

核心收获

  • 多头注意力让模型在不同子空间并行建模关系。
  • 它提升表达能力与稳定性,是 Transformer 的关键设计。
  • 头数需要与维度匹配,否则会削弱效果。
  • 工程上可在性能与成本间权衡。

推荐延伸阅读

  • Attention Is All You Need
  • Multi-Query Attention / Grouped-Query Attention 研究
  • Transformer 结构优化实践

参考与延伸阅读

小结 / 结论

多头注意力不是“堆数量”,而是“并行视角”。
它让模型在同一层同时学习多种关系,是 Transformer 成功的关键。

行动号召(CTA)

从 4 或 8 个头开始实验,观察不同头数对效果与成本的影响。