A visual walkthrough of the Partial FC (CVPR 2022) training loop: gather features across GPUs, shard the class centers, sample positive and negative classes, compute partial logits, then update only the sampled weights.
Each GPU first computes feature embeddings for its local batch. Because the classification loss is computed over the global batch, the embeddings from all GPUs are then gathered onto every GPU (AllGather).
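The gather step can be pictured with a single-process simulation. This is only a sketch: `all_gather` here is a hypothetical stand-in for a real collective such as `torch.distributed.all_gather`, and the toy sizes (4 ranks, 2 samples each, 3 dimensions) are assumptions chosen for readability.

```python
def all_gather(per_rank_batches):
    """Simulate AllGather: every rank receives all local batches, concatenated in rank order."""
    gathered = [row for batch in per_rank_batches for row in batch]
    # Each rank gets its own copy, as it would after a real collective.
    return [[list(row) for row in gathered] for _ in per_rank_batches]


# Hypothetical toy setup: 4 ranks, 2 samples per rank, 3-dimensional embeddings.
world_size, local_bs, dim = 4, 2, 3
local_features = [
    [[float(rank)] * dim for _ in range(local_bs)] for rank in range(world_size)
]
global_features = all_gather(local_features)
print(len(global_features[0]))  # global batch size = world_size * local_bs
```

After the gather, every rank holds the same global batch, which is what lets each rank score all samples against its own shard of class centers.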
The 10M-class classification weight matrix is split column-wise across GPUs, so each GPU stores only its own shard of class centers: about 2.5M per GPU, which corresponds to 4 GPUs in this walkthrough.
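As a rough illustration of the partitioning, the sketch below computes which contiguous range of class ids a given rank would own. The helper name `shard_range` and the contiguous, near-equal split are assumptions; the source only says the matrix is split column-wise.

```python
def shard_range(num_classes, world_size, rank):
    """Return the half-open range [start, end) of class ids owned by `rank`."""
    base, rem = divmod(num_classes, world_size)
    # The first `rem` ranks take one extra class when the split is uneven.
    start = rank * base + min(rank, rem)
    size = base + (1 if rank < rem else 0)
    return start, start + size


# 10M classes over 4 GPUs gives 2.5M class centers per GPU.
for rank in range(4):
    print(rank, shard_range(10_000_000, 4, rank))
```

With 64 GPUs, as assumed in the comparison table, the same helper gives 156,250 classes per GPU.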
Instead of computing logits against all 2.5M local classes, each GPU keeps the positive classes (those whose labels appear in the batch) and randomly samples a small subset of the remaining negative classes.
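A minimal sketch of the sampling step on one shard follows. The function name, the label remapping to `-1` for samples whose class lives on another shard, and the toy sizes are assumptions for illustration; the sample rate of 0.1 mirrors the value used in the comparison table.

```python
import random


def sample_classes(labels, start, end, sample_rate, rng):
    """Pick the class ids in [start, end) that take part in this step.

    Returns the sorted sampled ids and, for every label, its index in that
    list (or -1 when the label belongs to another shard).
    """
    positives = {y for y in labels if start <= y < end}
    budget = int((end - start) * sample_rate)
    num_neg = max(budget - len(positives), 0)
    # Simple but O(shard size); a real implementation would avoid building this list.
    pool = [c for c in range(start, end) if c not in positives]
    negatives = rng.sample(pool, num_neg)
    sampled = sorted(positives | set(negatives))
    index_of = {c: i for i, c in enumerate(sampled)}
    return sampled, [index_of.get(y, -1) for y in labels]


rng = random.Random(0)
labels = [3, 17, 42, 3, 99]  # global labels of the gathered batch
sampled, local_labels = sample_classes(labels, start=0, end=50, sample_rate=0.1, rng=rng)
print(sampled, local_labels)
```

Positives are always kept, so the random part only fills the remaining budget with negatives.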
The normalized embeddings are compared only with the sampled class centers to get similarity scores, and the ArcFace angular margin is then added to the target logit.
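The logit computation can be sketched in plain Python as below. The margin `0.5` and scale `64.0` are assumed defaults, not values given on this page, and the sketch ignores the special handling some implementations apply when the angle plus margin exceeds π.

```python
import math


def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def partial_arcface_logits(features, centers, local_labels, margin=0.5, scale=64.0):
    """Scaled cosine logits against the sampled centers, with the margin on the target only."""
    unit_centers = [l2_normalize(w) for w in centers]
    logits = []
    for f, y in zip(features, local_labels):
        f = l2_normalize(f)
        row = []
        for j, w in enumerate(unit_centers):
            cos = max(-1.0, min(1.0, sum(a * b for a, b in zip(f, w))))
            if j == y:  # y == -1 means the target class is not on this shard
                cos = math.cos(math.acos(cos) + margin)
            row.append(scale * cos)
        logits.append(row)
    return logits


logits = partial_arcface_logits(
    features=[[1.0, 0.0], [0.0, 2.0]],
    centers=[[1.0, 0.0], [0.0, 1.0]],
    local_labels=[0, -1],
)
print(logits)
```

The first sample has its target on this shard, so its matching logit is lowered by the margin; the second sample's target is elsewhere, so its row is left as plain scaled cosines.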
Gradients flow backward and are synchronized via AllReduce; the optimizer then updates only the class-center weights that were sampled in this step.
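The two ideas in this step, a sum-style AllReduce and an update restricted to the sampled rows, can be sketched as follows. Both helpers are hypothetical single-process stand-ins, and plain SGD is an assumption; the page does not say which optimizer is used.

```python
def all_reduce_sum(per_rank_grads):
    """Simulate AllReduce(sum): every rank ends up with the element-wise sum."""
    total = [sum(vals) for vals in zip(*per_rank_grads)]
    return [list(total) for _ in per_rank_grads]


def sgd_update_sampled(shard_weights, sampled_rows, sampled_grads, lr):
    """Apply plain SGD to the sampled rows only; all other rows are left untouched."""
    for row, grad in zip(sampled_rows, sampled_grads):
        shard_weights[row] = [w - lr * g for w, g in zip(shard_weights[row], grad)]


# Hypothetical toy shard: 4 class centers of dimension 2; rows 1 and 3 were sampled.
weights = [[1.0, 1.0] for _ in range(4)]
sgd_update_sampled(weights, sampled_rows=[1, 3], sampled_grads=[[1.0, 0.0], [0.0, 2.0]], lr=0.5)
reduced = all_reduce_sum([[1.0, 2.0], [3.0, 4.0]])
print(weights, reduced)
```

Rows 0 and 2 keep their original values, which is the point of the partial update: unsampled class centers receive no gradient in this step.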
References:
- "Killing Two Birds with One Stone: Efficient and Robust Training of Face Recognition CNNs by Partial FC" — arXiv:2203.15565 (CVPR 2022)
- "Partial FC: Training 10 Million Identities on a Single Machine" — arXiv:2010.05222
| | Data Parallel | Model Parallel | PartialFC (r=0.1) |
|---|---|---|---|
| Weight memory / GPU | 19.07 GB (full C×d on every GPU) | 298 MB (C/N columns per GPU) | 298 MB (same as Model Parallel) |
| Logits memory | BS×C = 156 GB (all classes) | BS×C/N = 2.44 GB (sharded but still large) | BS×C·r/N = 250 MB (sampled subset only) |
| Phase 1: AllGather features | n/a | 384 KB / GPU | 384 KB / GPU |
| Phase 3: Softmax communication | n/a | 32 KB (2 scalars / sample) | 32 KB (2 scalars / sample) |
| Phase 4: Gradient sync | 2×\|W\| ≈ 38 GB (full AllReduce) | 505 MB (AllReduce feature grads) | 488 MB (AllReduce sampled grads) |
| Total comm / step | ~38 GB | ~1.01 GB | ~0.98 GB |
| Throughput | OOM (cannot train 10M classes) | 4,840 img/s (64 GPUs) | 17,819 img/s (64 GPUs, 3.7× speedup) |
* Based on N=64 GPUs, BS=64 per GPU, feature_dim=512, C=10M, sample_rate=0.1.
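The memory formulas in the table (C×d, BS×C/N, BS×C·r/N) can be evaluated directly. The sketch below assumes 4-byte (fp32) values, which the page does not state, and treats BS as the global batch N × 64. Under that assumption the full weight matrix comes out at the table's 19.07 GB when measured in GiB, while the other entries land slightly off (for example about 305 MiB per weight shard), which suggests the table's figures mix decimal and binary units or are rounded.

```python
def footprint(num_classes, dim, world_size, local_bs, sample_rate, bytes_per_value=4):
    """Byte counts for the weight and logits tensors under the table's formulas."""
    global_bs = world_size * local_bs
    weights_full = num_classes * dim * bytes_per_value
    logits_full = global_bs * num_classes * bytes_per_value
    logits_shard = logits_full // world_size
    return {
        "weights_full": weights_full,                 # C x d
        "weights_shard": weights_full // world_size,  # C x d / N
        "logits_full": logits_full,                   # BS x C
        "logits_shard": logits_shard,                 # BS x C / N
        "logits_sampled": round(logits_shard * sample_rate),  # BS x C x r / N
    }


sizes = footprint(num_classes=10_000_000, dim=512, world_size=64, local_bs=64, sample_rate=0.1)
for name, nbytes in sizes.items():
    print(f"{name}: {nbytes / 2**30:.2f} GiB")
```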