对比学习损失函数系列(3/4):InfoNCE 与 SimCLR
副标题 / 摘要 InfoNCE 是现代对比学习的核心损失,SimCLR 则把它推向实用化。本文用公式、步骤与最小实验,带你理解“批内负样本 + 增强视图”的训练逻辑。 预计阅读时长:18~22 分钟 标签:infonce、simclr、self-supervised SEO 关键词:InfoNCE, SimCLR, 对比学习, 自监督 元描述:讲清 InfoNCE 的数学目标与 SimCLR 的训练结构,含可运行代码示例。 系列导航 (1/4)对比损失 Contrastive Loss (2/4)三元组损失 Triplet Loss (3/4)InfoNCE + SimCLR(本文) (4/4)CLIP 对比学习目标 目标读者 希望入门自监督对比学习的读者 需要理解 SimCLR 训练流程的工程实践者 想把对比学习迁移到业务数据的开发者 背景 / 动机 有标注数据昂贵,而无标注数据充足。 InfoNCE 让我们用“正负样本对齐”替代人工标签, SimCLR 则证明:只要数据增强和 batch 够大,效果可以接近监督学习。 核心概念 正样本视图:同一图像的两种增强视图。 批内负样本:同一 batch 中其他样本视为负样本。 投影头:把表示映射到对比空间,提高对比学习效果。 A — Algorithm(题目与算法) 用通俗语言说明主题内容 InfoNCE 的核心是“在一堆负样本里找到正确配对”。 SimCLR 则把“正确配对”定义为同一张图像的两个增强视图。 基础示例(1) 图像 A 经过两种增强得到 A1 与 A2 目标:A1 与 A2 相似度最大化 基础示例(2) A1 在 batch 中看到 B1、C1 等视为负样本 目标:A1 与 A2 的相似度高于 A1 与其他样本 实践指南 / 步骤 设计增强策略(裁剪、翻转、颜色扰动)。 构造两份增强视图作为正样本对。 编码器 + 投影头输出对比向量。 使用 InfoNCE 计算对比损失并训练。 可运行示例(最小 SimCLR 实验) import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader from torchvision import datasets, transforms torch.manual_seed(42) class TwoCrops: def __init__(self, base_transform): self.base = base_transform def __call__(self, x): return self.base(x), self.base(x) def info_nce(z1, z2, temp=0.5): z1 = F.normalize(z1, dim=1) z2 = F.normalize(z2, dim=1) logits = z1 @ z2.T / temp labels = torch.arange(z1.size(0), device=z1.device) loss1 = F.cross_entropy(logits, labels) loss2 = F.cross_entropy(logits.T, labels) return (loss1 + loss2) / 2 class Encoder(nn.Module): def __init__(self, out_dim=128): super().__init__() self.backbone = nn.Sequential( nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten(), ) self.proj = nn.Sequential( nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, out_dim), ) def forward(self, x): x = self.backbone(x) return self.proj(x) base_tf = transforms.Compose( [ transforms.RandomResizedCrop(32, scale=(0.6, 1.0)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), ] ) dataset = datasets.FakeData( size=512, image_size=(3, 32, 32), num_classes=10, transform=TwoCrops(base_tf), ) loader = DataLoader(dataset, batch_size=128, shuffle=True) model = Encoder() opt = torch.optim.Adam(model.parameters(), lr=1e-3) for epoch in range(1, 6): total = 0.0 for (x1, x2), _ in loader: z1 = model(x1) z2 = model(x2) loss = info_nce(z1, z2, temp=0.5) opt.zero_grad() loss.backward() opt.step() total += loss.item() print(f"epoch={epoch} loss={total/len(loader):.4f}") C — Concepts(核心思想) 方法类型 InfoNCE 与 SimCLR 属于自监督对比学习,通过增强视图构造正样本对。 ...