# CLIP Series (2/3): A Complete, Reproducible PyTorch Walkthrough, from Data to a Closed Training Loop
> **Subtitle / Abstract**: This article presents a "minimal but complete" CLIP training loop: CIFAR-10 images paired with text prompts, plus a ready-to-run PyTorch script, so you can reproduce both training and zero-shot classification locally.

- **Estimated reading time**: 20–25 minutes
- **Tags**: clip, pytorch, reproducible, cifar10
- **SEO keywords**: CLIP, PyTorch, reproducible, CIFAR10, contrastive learning
- **Meta description**: From data preparation to training and evaluation, a complete and reproducible hands-on CLIP script in PyTorch.

## Series Navigation

- (1/3) Principles and the contrastive-learning formulation
- (2/3) A complete, reproducible PyTorch walkthrough (this article)
- (3/3) Engineering and optimization

## Target Audience

- Beginners who want to get a CLIP training loop running end to end
- Practitioners who need a reproducible experiment template
- Readers who want to prototype multimodal ideas in PyTorch

## Background / Motivation

CLIP's training pipeline looks simple, but making it *reproducible* is hard: missing data, missing scripts, and missing evaluation leave many experiments stuck at "understood in theory." This article closes the loop on a small dataset, prioritizing something you can actually run on your own machine.

## Core Concepts

- **Reproducibility**: fix random seeds; control the data split and preprocessing.
- **Weakly labeled text**: build text prompts from class names to simulate image-text alignment.
- **Contrastive loss**: bidirectional cross-entropy with a temperature parameter.
- **Zero-shot evaluation**: classify by treating text prompts as "class descriptions."

## A — Algorithm (Problem and Algorithm)

Core flow of the training loop:

1. Generate a text prompt for each image (e.g., `a photo of a cat`).
2. Encode images and texts into vectors and normalize them.
3. Compute the similarity matrix and train with a contrastive loss.
4. At inference time, run zero-shot classification against a set of text prompts.

Basic example (1):

- Image: a cat
- Prompt set: `cat`, `dog`, `car`
- Goal: the prompt with the highest similarity is the predicted class

Basic example (2):

- Within a batch, the diagonal entries of the similarity matrix are the "correct image-text pairs"
- Training objective: maximize the diagonal, suppress the off-diagonal entries
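In formula form, this is the standard symmetric InfoNCE objective, which the `clip_loss` function in the script below implements directly ($N$ is the batch size):

$$
\mathcal{L} = \tfrac{1}{2}\left(\mathcal{L}_{\mathrm{img}} + \mathcal{L}_{\mathrm{txt}}\right),\quad
\mathcal{L}_{\mathrm{img}} = -\frac{1}{N}\sum_{i=1}^{N}\log\frac{\exp(s_{ii})}{\sum_{j=1}^{N}\exp(s_{ij})},\quad
\mathcal{L}_{\mathrm{txt}} = -\frac{1}{N}\sum_{i=1}^{N}\log\frac{\exp(s_{ii})}{\sum_{j=1}^{N}\exp(s_{ji})}
$$

with $s_{ij} = \alpha\,\mathbf{v}_i^{\top}\mathbf{t}_j$, where $\mathbf{v}_i$ and $\mathbf{t}_j$ are the L2-normalized image and text embeddings and $\alpha$ is the learned inverse temperature (`logit_scale.exp()`, initialized to $1/0.07$). One caveat specific to this toy setup: all images of a class share the same prompt, so same-class images inside a batch act as false negatives on the off-diagonal. Training still works on CIFAR-10, but it is worth keeping in mind when you read the loss curve.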
## Practice Guide / Steps

1. Create an environment and install the dependencies:

   ```bash
   python -m venv .venv
   source .venv/bin/activate
   pip install torch torchvision tqdm
   ```

2. Save the script below as `clip_cifar10.py`.

3. Run training (GPU recommended):

   ```bash
   python clip_cifar10.py --epochs 10 --batch-size 256 --device cuda
   ```

4. Watch the output: the loss should fall steadily and the zero-shot accuracy should climb.

## Runnable Example (Complete PyTorch Script)

```python
import argparse
import math
import random
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms, models
from tqdm import tqdm

CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def set_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


class SimpleTokenizer:
    """Whitespace tokenizer with a vocabulary built from the prompt set."""

    def __init__(self, texts):
        self.pad_token = "<pad>"
        self.unk_token = "<unk>"
        self.bos_token = "<bos>"
        self.eos_token = "<eos>"
        vocab = {
            self.pad_token: 0,
            self.unk_token: 1,
            self.bos_token: 2,
            self.eos_token: 3,
        }
        for text in texts:
            for token in text.lower().split():
                if token not in vocab:
                    vocab[token] = len(vocab)
        self.stoi = vocab
        self.itos = {i: t for t, i in vocab.items()}
        self.pad_id = self.stoi[self.pad_token]
        self.unk_id = self.stoi[self.unk_token]
        self.bos_id = self.stoi[self.bos_token]
        self.eos_id = self.stoi[self.eos_token]

    def encode(self, text, max_len=16):
        tokens = text.lower().split()
        ids = [self.bos_id]
        ids.extend(self.stoi.get(t, self.unk_id) for t in tokens)
        ids.append(self.eos_id)
        if len(ids) > max_len:
            ids = ids[:max_len]
            ids[-1] = self.eos_id
        return ids


def pad_tokens(token_lists, pad_id):
    max_len = max(len(t) for t in token_lists)
    tokens = torch.full((len(token_lists), max_len), pad_id, dtype=torch.long)
    attn = torch.zeros((len(token_lists), max_len), dtype=torch.bool)
    for i, ids in enumerate(token_lists):
        tokens[i, : len(ids)] = torch.tensor(ids, dtype=torch.long)
        attn[i, : len(ids)] = True
    return tokens, attn


class CIFAR10Text(Dataset):
    """Wraps CIFAR-10 so each image is paired with a class-name prompt."""

    def __init__(self, root, train, transform, tokenizer, max_len=16):
        self.ds = datasets.CIFAR10(root=root, train=train, download=True, transform=transform)
        self.prompts = [f"a photo of a {name}" for name in CIFAR10_CLASSES]
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        image, label = self.ds[idx]
        text = self.prompts[label]
        token_ids = self.tokenizer.encode(text, max_len=self.max_len)
        return image, token_ids, label


def collate_fn(batch, pad_id):
    images, token_lists, labels = zip(*batch)
    images = torch.stack(images)
    tokens, attn = pad_tokens(token_lists, pad_id)
    labels = torch.tensor(labels, dtype=torch.long)
    return images, tokens, attn, labels


class ImageEncoder(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.backbone = models.resnet18(weights=None)
        # CIFAR-10 images are 32x32: shrink the stem conv and drop the max-pool.
        self.backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.backbone.maxpool = nn.Identity()
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, embed_dim)

    def forward(self, x):
        x = self.backbone(x)
        return F.normalize(x, dim=-1)


class TextEncoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, width=256, layers=2, heads=4, max_len=16):
        super().__init__()
        self.token = nn.Embedding(vocab_size, width)
        self.pos = nn.Embedding(max_len, width)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=width,
            nhead=heads,
            dim_feedforward=width * 4,
            dropout=0.1,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=layers)
        self.proj = nn.Linear(width, embed_dim)

    def forward(self, token_ids, attn_mask):
        bsz, seq_len = token_ids.shape
        pos_ids = torch.arange(seq_len, device=token_ids.device).unsqueeze(0).expand(bsz, -1)
        x = self.token(token_ids) + self.pos(pos_ids)
        x = self.encoder(x, src_key_padding_mask=~attn_mask)
        # Mean-pool over valid (non-pad) positions only.
        attn = attn_mask.unsqueeze(-1)
        x = (x * attn).sum(dim=1) / attn.sum(dim=1).clamp(min=1)
        x = self.proj(x)
        return F.normalize(x, dim=-1)


class CLIPModel(nn.Module):
    def __init__(self, vocab_size, embed_dim=256, max_len=16):
        super().__init__()
        self.image_encoder = ImageEncoder(embed_dim)
        self.text_encoder = TextEncoder(vocab_size, embed_dim, max_len=max_len)
        # Learnable temperature, initialized to 1/0.07 as in the CLIP paper.
        self.logit_scale = nn.Parameter(torch.tensor(math.log(1 / 0.07)))

    def forward(self, images, token_ids, attn_mask):
        image_features = self.image_encoder(images)
        text_features = self.text_encoder(token_ids, attn_mask)
        logit_scale = self.logit_scale.exp().clamp(max=100)
        logits = logit_scale * image_features @ text_features.T
        return logits


def clip_loss(logits):
    # Symmetric cross-entropy: the i-th image matches the i-th text.
    labels = torch.arange(logits.size(0), device=logits.device)
    loss_i = F.cross_entropy(logits, labels)
    loss_t = F.cross_entropy(logits.T, labels)
    return (loss_i + loss_t) / 2


@torch.no_grad()
def zero_shot_accuracy(model, loader, tokenizer, device, max_len=16):
    model.eval()
    prompts = [f"a photo of a {name}" for name in CIFAR10_CLASSES]
    token_lists = [tokenizer.encode(p, max_len=max_len) for p in prompts]
    tokens, attn = pad_tokens(token_lists, tokenizer.pad_id)
    tokens = tokens.to(device)
    attn = attn.to(device)
    text_features = model.text_encoder(tokens, attn)
    correct = 0
    total = 0
    for images, _, _, labels in loader:
        images = images.to(device)
        image_features = model.image_encoder(images)
        logits = image_features @ text_features.T
        preds = logits.argmax(dim=1).cpu()
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return correct / max(total, 1)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--embed-dim", type=int, default=256)
    parser.add_argument("--max-len", type=int, default=16)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num-workers", type=int, default=2)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--data-root", type=str, default="./data")
    args = parser.parse_args()

    # Fall back to CPU if CUDA is unavailable.
    device = args.device if torch.cuda.is_available() and args.device == "cuda" else "cpu"
    set_seed(args.seed)

    prompts = [f"a photo of a {name}" for name in CIFAR10_CLASSES]
    tokenizer = SimpleTokenizer(prompts)

    train_tf = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
        ]
    )
    test_tf = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
        ]
    )

    train_ds = CIFAR10Text(args.data_root, True, train_tf, tokenizer, args.max_len)
    test_ds = CIFAR10Text(args.data_root, False, test_tf, tokenizer, args.max_len)

    # partial() instead of a lambda so DataLoader workers can pickle the
    # collate function on spawn-based platforms (Windows, macOS).
    collate = partial(collate_fn, pad_id=tokenizer.pad_id)
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        collate_fn=collate,
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate,
    )

    model = CLIPModel(len(tokenizer.stoi), embed_dim=args.embed_dim, max_len=args.max_len).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    for epoch in range(1, args.epochs + 1):
        model.train()
        total_loss = 0.0
        for images, tokens, attn, _ in tqdm(train_loader, desc=f"Epoch {epoch}"):
            images = images.to(device)
            tokens = tokens.to(device)
            attn = attn.to(device)
            logits = model(images, tokens, attn)
            loss = clip_loss(logits)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * images.size(0)
        scheduler.step()
        avg_loss = total_loss / len(train_ds)
        acc = zero_shot_accuracy(model, test_loader, tokenizer, device, args.max_len)
        print(f"Epoch {epoch}: loss={avg_loss:.4f}, zero-shot acc={acc:.4f}")


if __name__ == "__main__":
    main()
```
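As a follow-up, here is a minimal sketch (not part of the article's script) of zero-shot inference on a single image of your own. It assumes you add `torch.save(model.state_dict(), "clip_cifar10.pt")` at the end of `main()`; all imported names come from `clip_cifar10.py`, and `example.jpg` is a placeholder path:

```python
# Minimal single-image zero-shot inference sketch. Assumes a checkpoint was
# saved as "clip_cifar10.pt" (see lead-in) and that training used the default
# embed_dim=256 and max_len=16.
import torch
from PIL import Image
from torchvision import transforms

from clip_cifar10 import CIFAR10_CLASSES, CLIPModel, SimpleTokenizer, pad_tokens

device = "cuda" if torch.cuda.is_available() else "cpu"

# Rebuilding the tokenizer from the same prompts reproduces the training vocab.
prompts = [f"a photo of a {name}" for name in CIFAR10_CLASSES]
tokenizer = SimpleTokenizer(prompts)

model = CLIPModel(len(tokenizer.stoi), embed_dim=256, max_len=16).to(device)
model.load_state_dict(torch.load("clip_cifar10.pt", map_location=device))
model.eval()

# Same normalization statistics as the training script; resize to CIFAR size.
tf = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
image = tf(Image.open("example.jpg").convert("RGB")).unsqueeze(0).to(device)

with torch.no_grad():
    token_lists = [tokenizer.encode(p, max_len=16) for p in prompts]
    tokens, attn = pad_tokens(token_lists, tokenizer.pad_id)
    text_features = model.text_encoder(tokens.to(device), attn.to(device))
    image_features = model.image_encoder(image)
    # Cosine similarity; the temperature does not change the argmax.
    pred = (image_features @ text_features.T).argmax(dim=1).item()

print(f"predicted class: {CIFAR10_CLASSES[pred]}")
```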
## C — Concepts (Core Ideas)

**Method type**: CLIP sits in the contrastive learning + multimodal representation learning paradigm: a dual-tower pair of image and text encoders is trained to align a shared semantic space.

...