副标题 / 摘要
ViT 把图像切成 patch 序列,再交给 Transformer 编码器处理。本文用 ACERS 框架拆解 ViT 的核心结构与工程选择,并提供最小可运行的 PyTorch 示例。
- 预计阅读时长:16~20 分钟
- 标签:
vit、transformer、vision - SEO 关键词:ViT, Vision Transformer, Patch Embedding, 图像分类
- 元描述:系统描述 ViT 架构与工程实践,含最小 PyTorch 示例。
目标读者
- 想理解 ViT 架构的入门读者
- 需要做视觉模型选型的工程实践者
- 想从 CNN 迁移到 Transformer 的开发者
背景 / 动机
CNN 通过局部卷积捕获特征,但长程依赖与全局建模能力有限。
ViT 把图像当成序列,直接用自注意力做全局建模,
在大规模数据预训练下表现非常强。
核心概念
- Patch Embedding:将图像切成 patch 并线性投影。
- Position Embedding:补充位置信息。
- [CLS] Token:聚合全局特征用于分类。
- Transformer Encoder:多头自注意力 + FFN 堆叠。
A — Algorithm(题目与算法)
用通俗语言说明主题内容
ViT 的核心流程:
- 把图像切成固定大小 patch。
- 每个 patch 拉平成向量并投影成 token。
- 加上位置编码和 [CLS] token。
- 送入 Transformer Encoder 得到全局表征。
- 用 [CLS] token 输出做分类。
基础示例(1)
- 图像 224x224,patch 16x16 → 196 个 patch + 1 个 [CLS]。
基础示例(2)
- 只保留编码器,就能做图像分类与检索。
实践指南 / 步骤
- 选择 patch 大小(8/16/32)。
- 设置隐藏维度与层数(如 12 层,768 维)。
- 添加位置编码与 [CLS] token。
- 训练:优先用预训练权重再微调。
可运行示例(最小 ViT 前向)
import torch
import torch.nn as nn
torch.manual_seed(42)
class MiniViT(nn.Module):
def __init__(self, img_size=32, patch=8, dim=64, depth=2, heads=4):
super().__init__()
self.patch = patch
self.unfold = nn.Unfold(kernel_size=patch, stride=patch)
num_patches = (img_size // patch) ** 2
self.proj = nn.Linear(3 * patch * patch, dim)
self.cls = nn.Parameter(torch.zeros(1, 1, dim))
self.pos = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
enc_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, batch_first=True)
self.encoder = nn.TransformerEncoder(enc_layer, num_layers=depth)
self.head = nn.Linear(dim, 10)
def forward(self, x):
patches = self.unfold(x).transpose(1, 2) # B, N, patch_dim
tokens = self.proj(patches)
cls = self.cls.expand(x.size(0), -1, -1)
tokens = torch.cat([cls, tokens], dim=1) + self.pos
z = self.encoder(tokens)
return self.head(z[:, 0])
x = torch.randn(2, 3, 32, 32)
model = MiniViT()
print(model(x).shape)
解释与原理
- patch embedding 把图像变成序列。
- self-attention 能在全局范围建模依赖。
- [CLS] token 作为全局聚合向量用于分类。
C — Concepts(核心思想)
方法类型
ViT 属于基于注意力的视觉表征模型,用 Transformer Encoder 替代卷积堆叠。
关键公式
Patch embedding:
$ x \in \mathbb{R}^{H\times W\times C} \rightarrow x_p \in \mathbb{R}^{N\times (P^2C)} $
自注意力:
$ \text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^\top}{\sqrt{d}})V $
解释与原理
- patch 大小决定 token 数量,从而决定注意力复杂度。
- 全局注意力使 ViT 对长程依赖更敏感。
E — Engineering(工程应用)
场景 1:图像分类
- 背景:ImageNet 级别分类任务。
- 为什么适用:ViT 在大规模预训练下精度高。
- 代码示例(Python):
import torch
logits = torch.randn(1, 1000)
print(logits.argmax(dim=1).item())
场景 2:小数据迁移学习
- 背景:小样本任务直接训练易过拟合。
- 为什么适用:预训练 ViT 微调更稳定。
- 代码示例(Python):
import torch
features = torch.randn(1, 768)
head = torch.randn(768, 5)
print((features @ head).shape)
场景 3:多模态图文对齐
- 背景:CLIP 等模型需要视觉编码器。
- 为什么适用:ViT 输出可直接对齐文本特征。
- 代码示例(Python):
import torch
import torch.nn.functional as F
img = F.normalize(torch.randn(1, 512), dim=-1)
text = F.normalize(torch.randn(1, 512), dim=-1)
print((img @ text.T).item())
R — Reflection(反思与深入)
- 时间复杂度:注意力为
O(N^2),N为 patch 数。 - 空间复杂度:注意力矩阵也为
O(N^2)。 - 替代方案:
- CNN:更高效但全局建模弱。
- Swin Transformer:窗口注意力降低复杂度。
- Hybrid 模型:卷积 + Transformer。
- 工程可行性:ViT 对数据量依赖更强,预训练是关键。
常见问题与注意事项
- patch 太小会导致显存爆炸。
- 小数据集训练易过拟合。
- 位置编码选择(绝对/相对)会影响性能。
最佳实践与建议
- 先用预训练权重,再做任务微调。
- 调整 patch 大小平衡精度与成本。
- 结合强数据增强提升泛化。
S — Summary(总结)
核心收获
- ViT 将图像转成 token 序列并用 Transformer 编码。
- Patch 大小决定复杂度与表现。
- 预训练 + 微调是 ViT 的主流工程路径。
- 与 CNN 相比,ViT 更擅长全局建模。
推荐延伸阅读
- An Image is Worth 16x16 Words
- DeiT:Data-efficient Image Transformers
- Swin Transformer
参考与延伸阅读
小结 / 结论
ViT 用最简洁的方式把视觉任务带入 Transformer 世界。
理解 patch embedding 与编码器结构,就能快速上手 ViT。
行动号召(CTA)
用本文的最小 ViT 结构替换你的视觉模型,观察精度与成本变化。