论文精读 对比学习综述 2021

1. 百花齐放

 在第一阶段上,方法模型都没有统一,目标函数,代理任务也没有统一,所以说是一个百花齐放的年代

1.1 InstDisc

image-20230731101249802

基本思想:图片聚集在一起的原因,并不是这些图片有相似的语义标签信息,而是因为这些图片长得比较像。通过一个卷积神经网络来将图片进行编码成一个低维特征,然后使得这些特征在特征空间上都尽可能的区分开,因为个体判别认为每张图片都是自成一类,提出了个体判别任务

Forward:假设模型的batchsize是256,有256张图片进入CNN网络,将256张图片编码为128维的向量。因为batchsize是256,因此有256个正样本。负样本来自memory bank,每次从memory bank中随机采样出4096个负数样本,利用 InfoNCE loss去更新CNN的参数。本次更新结束后,会将CNN编码得到的向量替换掉memory bank中原有的存储。就这样循环往复的更新CNN和memory bank,最后让模型收敛,就训练好一个CNN encoder了。

1.2 InvaSpread

image-20230731103807221

基本思想:用mini batch中的数据作为负样本,使用一个编码器进行端到端的学习,所选取的字典长度不够大。

Forward:首先利用数据增广,将每个图片增广一次,也就是将256张图片变为512个图片了。之后将512张图片通过CNN编码为向量,并使用一个全连接层将数据的维度降低。之后将$x{i}$和其经过增广后的图$\widetilde{x}{i}$作为正样本,其余的512-2张图片都认为是负样本。所以总计有256个正例,有2×(256-1)张负例。之后的在特征空间中$x{i}$ 与$\widetilde{x}{i}$的距离应该尽可能相近,而$x{i}$与$\widetilde{x}_{j}$的距离应该尽可能相远。

以上两篇工作都是使用个体判别 Instance Discrimination 作为代理任务的

1.3 CPC

image-20230731112441320

基本思想:有一个持续的序列,把之前时刻的输入喂给编码器,返回的特征再喂给一个自回归模型gar(auto-regressive,一般的自回归模型是RNN或LSTM),然后得到一个context representation,这是一个代表上下文的特征表示。如果context representation足够好,那么其应该可以做出一些合理的预测,所以可以用$c_{t}$预测未来时刻的特征输出$z_{t+i}$

生成式的代理任务

1.4 CMC

image-20230731113140008

基本思想:CMC想学一个非常强大的特征,其具有视角的不变性(不管是看见了一只狗,还是听到了狗叫声,都能判断出这是个狗)。所以,CMC的工作目的就是去增大这个互信息,就是所有视角之间的互信息。如果能学到一种特征,能够抓住所有视角下的这个关键的因素,那么这个特征就比较好。最大化互信息

方法:输入view来自于不同的传感器,或者说是不同的模态,但是这些所有的输入其实对应的都是一整的图片,一个东西,那么它们就应该互为正样本,相互配对。而这些相互配对的视角在特征空间中应该尽可能的相近,而与其他的视角尽可能的远离。Teacher和student编码得到的相同图片的向量互为正例,不同图片得到的输出作为负例,利用对比学习的思路进行知识蒸馏。

问题:在于multi view的工作可能需要多个编码器进行编码,训练代价可能有点高。比如CLIP,就是用大型的语言编码器BERT对语言模型进行编码,用视觉模型VIT对视觉信息进行编码。

InfoMin是CMC的作者做的一个分析型的延伸性工作,要是提出了一个InfoMin的原则,InfoMin的本意是不能一味的最大化这个互信息,而是要不多不少刚刚好,去选择合适的数据增强与合适的对比学习的视角。

小结

 可以看到以上的工作代理任务不尽相同,其中有个体判别,有预测未来,还有多视角多模态。使用的目标函数也不尽相同,有NCE,infoNCE以及其变体。使用的模型也可以是不同的,比如InvaSpread使用的是相同的编码器对key和query进行编码,CMC对key和query使用的是不同的编码,是百花齐放的。

2. CV双雄

2.1 MoCo v1

image-20230731142805835

基本思想:将一系列的对比学习方法归纳为一个字典查询的问题building dynamic dictionaries。将负样本图片通过编码器后所得的输出看成是一个特征key,将正样本图片通过另外一个编码器所得到的输出看成是一个query。对比学习本质上,就是希望在字典中找到与query最匹配的那个key,而这个key是正样本通过一些列的数据增强变化获得,所以语义信息应该相同,在特征空间上也应该类似,而与其他的负样本的特征key应该尽可能的远离,损失函数InfoNEC

贡献:

  1. queue 数据结构
  2. Momentum Encoder $θ_k←mθ{_k}+(1−m)θq$
  3. Shuffling BN - BN可能导致信息泄露

对比:

image-20230731144205429

方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Algorithm 1 Pseudocode of MoCo in a PyTorch-like style
# f_q, f_k: encoder networks for query and key
# queue: dictionary as a queue of K keys (CxK)
# m: momentum
# t: temperature
f_k.params = f_q.params # initialize
for x in loader: # load a minibatch x with N samples
x_q = aug(x) # a randomly augmented version
x_k = aug(x) # another randomly augmented version

q = f_q.forward(x_q) # queries: NxC (256x128)
k = f_k.forward(x_k) # keys: NxC (256x128)
k = k.detach() # no gradient to keys

# positive logits: Nx1 (256x1)
l_pos = bmm(q.view(N,1,C), k.view(N,C,1)) # q·k+

# negative logits: NxK (256x65536)
l_neg = mm(q.view(N,C), queue.view(C,K)) # sum q·ki

# logits: Nx(1+K) (256x65537)
logits = cat([l_pos, l_neg], dim=1)

# contrastive loss, Eqn.(1)
labels = zeros(N) # positives are the 0-th;利用pytorch函数特性
loss = CrossEntropyLoss(logits/t, labels)

# SGD update: query network
loss.backward()
update(f_q.params)

# momentum update: key network
f_k.params = m * f_k.params+(1-m) * f_q.params

# update dictionary
enqueue(queue, k) # enqueue the current minibatch
dequeue(queue) # dequeue the earliest minibatch
# bmm: batch matrix multiplication; mm: matrix multiplication; cat: concatenation.

2.2 SimCLR v1

image-20230731144944108

思路:假如有一个minibatch的图片,对整个minibatch的所有图片做数据增强,对图片x xx做不同的数据增强就会得$x_{i}$和$x_{j}$ 同一个图片延申得到的两个图片就是正样本,比如batchSize是n的话,那么正样本就是n,这个batchsize剩下的所有的样本以及其经过数据增强后得到的都是负样本,也就是2(n-1)。有了正负样本之后,对其进行编码,通过一个编码器$f ( ⋅ )$得到正负样本的编码结。SimCLR的创新点就是在得到数据的编码之后在后面加了一个编码层$g ( ⋅ )$函数,就是一个MLP层,得到较低维度的特征$z_{i}$和 $z_{j}$ ,用其进行对比学习,拉近正例之间的距离,拉远负例之间的距离。但是需要注意的一点就是投影函数仅仅在训练的时候才使用,在测试的时候是不使用的,测试的时候仅仅使用编码器$f(·)$ 。加上投影函数的目的也仅仅是想让模型训练的更好。

与InvaSpread相比:

  1. SimCLR使用了更多的数据增强 其中随机的裁剪以及随机的色彩变换最重要

  2. 加入了投影的$g ( ⋅ )$ 函数

  3. 就是SimCLR用了更大的batchsize,且训练的时间更久

损失函数:the normalized temperature-scaled cross entropy loss

方法:

image-20230731152831099

2.3 MoCo v2

改进:

image-20230731153325242

更省钱!!

2.4 SimCLR v2

image-20230731153856595

2.5 SWaV

image-20230731154140013

基本思想:给定同样的一张图片,如果去生成不同的视角(views),希望可以用一个视角得到的特征去预测另外一个视角的得到的特征,因为所有的这些视角的特征按道理来说都应该是非常接近的。然后SWaV将对比学习和之前的聚类的方法合在的一起,这样做也不是偶然,因为聚类也是无监督特征表示学习的方法,而且它也希望相似的物体都聚集在一个聚类中心附近,不相似的物体推到别的聚类中心

方法:聚类中心C CC就是Prototypes,作为一个矩阵维度是$d $$ *k$(d是特征的维度128维,k是聚类中心的数目3000)SwAV前向过程依旧是一个实例x通过两次数据增强变为 $x_{1}$ 和$x_{2}$ ,之后利用编码器对其进行编码,从而得到嵌入向量$z_{1}$ 和$z_{2}$ 。但是有了$z_{1}$和$z_{2}$ 之后,并不是直接在特征上去做对比学习的loss,而且让 $z_{1}$和$z_{2}$和聚类中心C进行聚类,从而得到ground truth的标签$Q_{1}$ 和$Q_{2}$ 。如果说两个特征比较相似或者是含有等量的信息,按道理来说应该是可以相互预测的。也就是说,用$z_{1}$ 和C作点乘按道理是可以去预测$Q_{2}$的,反过来用$z_{2}$ 和C作点乘按道理是可以去预测$Q_{1}$ 的,SwAV通过这种换位交叉预测的方法来对模型进行训练更新参数。

keys:

  1. Multi-crop:两个160×160的crop去注意全局特征,选择四个96×96的crop去注意局部特征
  2. 聚类

小结

 到了第二阶段,其实很多细节都趋于统一了,比如目标函数都是使用infoNCE,模型都归一为用一个encoder+projection head了,大家都采用了一个更强的数据增强,都想用一个动量编码器,也都尝试训练更久,最后在ImageNet上的准确度也逐渐逼近于有监督的基线模型。

3. 不用负样本

3.1 BYOL

image-20230731161241934

 在之前的对比学习工作中,是让$z_{\theta}$和$z_{\xi}^{‘}$尽可能的相似,而在BYOL这里,又加了一层predictor的全连接层$q_{\theta}$ ,$q_{\theta}$ 的网络结构和$g_{\theta}$ 的网络结构是完全一样的$z_{\theta}$ 通过$q_{\theta}$又得到了一个新的特征$q_{\theta}(z_{\theta})$现在的目的是想让特征$q_{\theta}(z_{\theta})$与$z_{\xi}^{‘}$

 图中的sg表示stop gradient,这里是没有梯度的。模型的上一支相当于query编码器,下面一支相当于key编码器,而key编码器都是通过query编码器来动量更新。不同是代理任务不一样,BYOL相当于是自己一个视角的特征去预测另外一个视角的特征,通过这种预测性的任务来完成模型的训练。

损失函数:MSE

3.2 SimSiam

image-20230731163008906

基本思想:实例x xx经过数据增强变为$x_{1}$ 和$x_{2}$ ,之后经过孪生的编码器$f ( ⋅ )$ ,得到嵌入$z_{1}$和$z_{2}$ ,之后经过预测层得到$p_{1}$ 和$p_{2}$ ,之后让$p_{1}$ 预测$z_{2}$,用$ p_{2}$去预测$z_{1}$,进行模型的训练。

伪代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# f: backbone + projection mlp
# h: prediction mlp
for x in loader: # load a minibatch x with n samples
x1, x2 = aug(x), aug(x) # random augmentation
z1, z2 = f(x1), f(x2) # projections, n-by-d
p1, p2 = h(z1), h(z2) # predictions, n-by-d

L = D(p1, z2)/2 + D(p2, z1)/2 # loss

L.backward() # back-propagate
update(f, h) # SGD update

def D(p, z): # negative cosine similarity 负余弦相似性
z = z.detach() # stop gradient

p = normalize(p, dim=1) # l2-normalize
z = normalize(z, dim=1) # l2-normalize
return -(p*z).sum(dim=1).mean()

 SimSiam能够成功训练的原因,不会发生模型坍塌,主要就是因为有stop gradient这个操作的存在。由于stop gradient,可以将SimSiam的结构看成是一个EM算法,相当于是在解决两个子问题,而模型更新也在交替进行,相当于不断的更新聚类中心。

image-20230731164336674

4. Transformer

 在vision transformer之后,因为其大大提升了encoder的效果,所以很多对比学习任务打算使用vision transformer作为backbone进行对比学习,涌现出了两篇工作,分别是MoCov3和DINO。

4.1 MoCo v3

骨干网络从ResNet 替换为 ViT

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# f_q: encoder: backbone + proj mlp + pred mlp
# f_k: momentum encoder: backbone + proj mlp
# m: momentum coefficient
# tau: temperature

for x in loader: # load a minibatch x with N samples
x1, x2 = aug(x), aug(x) # augmentation
q1, q2 = f_q(x1), f_q(x2) # queries: [N, C] each
k1, k2 = f_k(x1), f_k(x2) # keys: [N, C] each

loss = ctr(q1, k2) + ctr(q2, k1) # symmetrized
loss.backward()

update(f_q) # optimizer update: f_q
f_k = m*f_k + (1-m)*f_q # momentum update: f_k

# contrastive loss
def ctr(q, k):
logits = mm(q, k.t()) # [N, N] pairs
labels = range(N) # positives are in diagonal
loss = CrossEntropyLoss(logits/tau, labels)
return 2 * tau * loss

# Notes: mm is matrix multiplication. k.t() is k’s transpose. The prediction head is excluded from f k (and thus the momentum update).

第一阶段的Patch投影时冻住,有效解决梯度波动问题

4.2 DINO

image-20230731165738061

 这里想表达的意思是一个完全不用任何标签信息训练出来的Vision Transformers,如果将其自注意力图拿出来进行可视化,可以发现其可以非常准确的抓住每个物体的轮廓,这个效果甚至可以直接匹配对这个物体作语义分割。

image-20230731165935686

 DINO的前向过程都是类似的,当有一个图片x的两个视角$x_{1}$和$x_{2}$之后,$ x_{1}$和$x_{2}$分别通过学生网络编码器$g_{\theta s}$和教师网络编码器$g_{\theta t}$得到两个特征$p_{1}$和$p_{2}$,其中编码器结构中同样包含projection head和prediction head。而为了避免模型的坍塌,DINO做了一个额外的工作centering,这个操作就是把整个batch里的样本都算一个均值,然后减掉这个均值其实就是centering。最后也是有一个stop gradient的操作,然后用$p_{1}$预测$p_{2}$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# gs, gt: student and teacher networks
# C: center (K)
# tps, tpt: student and teacher temperatures
# l, m: network and center momentum rates

gt.params = gs.params
for x in loader: # load a minibatch x with n samples
x1, x2 = augment(x), augment(x) # random views

s1, s2 = gs(x1), gs(x2) # student output n-by-K
t1, t2 = gt(x1), gt(x2) # teacher output n-by-K

loss = H(t1, s2)/2 + H(t2, s1)/2
loss.backward() # back-propagate

# student, teacher and center updates
update(gs) # SGD
gt.params = l*gt.params + (1-l)*gs.params
C = m*C + (1-m)*cat([t1, t2]).mean(dim=0)

def H(t, s):
t = t.detach() # stop gradient
s = softmax(s / tps, dim=1)
t = softmax((t - C) / tpt, dim=1) # center + sharpen
return - (t * log(s)).sum(dim=1).mean()

5. 总结

image-20230731222419004