0. 前言
0.0 前言
这篇文章主要解决的问题是:当语义分割网络在合成数据集(有标签)上训练好,在真实数据集(没有标签)上性能下降比较多。作者认为有两个原因:对合成数据过拟合,合成数据与真实数据存在分布差异。(好吧,我认为这两是一个原因)。作者提出target guided distillation 和 spatial-aware adaptation 来改进性能,效果还挺好的。我主要看target guided distillaiton。
0.1 HydraPlus-Net
顺便记录下刚看的论文 HydraPlus-Net,因为这篇论文是caffe代码且对我的帮助不大,所以只是简单地记录下其中的创新点。
其主要创新点在于:
- attention不仅可以用于本block,也可以用于其他block
- 一个block可以生成多个attention map
1.Introduction
针对合成数据集的模型在真实数据集上性能很差,作者提出原因可能是:过拟合和分布不一致,因此提出模型:ROAD-Net。下面对作者提出的名词做出解释
- Reality Oriented Adaptation Networks(ROAD-Net)
- real style orientation & target guided distillation: 为了避免过拟合合成数据集,使用 distillation 使 model 的输出和预训练的模型的输出一致,注意,这里 distillation 针对的是真实数据集,而不是合成数据集,这种方法称为 target guided distillation.
- real distribution orientation & spatial-aware adaptation & domain classifier: 因为合成数据集和真实数据集的分布不一致,提出 DANN 也就是 domain classifier,其实类似GAN的D,来使得合成数据集和真实数据集的特征分布一致。主要过程就是将合成数据集和真实数据集的特征图分割成几个区域,然后判断这几个区域是不是同一个domain。
2. Reality Oriented Adaptation Networks
2.1 Target Guided Distillation
其中,pretrained model 在训练的时候不更新.
distillation loss:
其中,$x{i,j}, z{i,j}$分别表示 segmentation model 和 pretrained model 得到的 feature map 在位置 $(i,j)$ 的值,二范是简单的欧式距离。这个损失简单粗暴,后续的实验也证明了这种方法的确要更好一些。
我觉得这个是一个思路,这种 distillation 可以在一定程度上使得 segmentation model 学习到真实数据集的特征分布。
当然,除了 distillation loss 用来防止过拟合,也有其他方法用来防止过拟合,比如冻结一些层然后循环,或者用 source (合成数据集) 用来进行 distillation,是 learning without forgetting,这篇文章也很有用,等下简单地讲下这篇文章。
其实这里大有文章可做。
2.2 Spatial-Aware Adaptation
假设把特征图分割成了$m=1,…,M$块,每一块的区域坐标集合表示为$(u,v)\in Rm$,记点(u,v)对应的特征图为$x{u,v}$,记区域对应的特征图为$Xm^s=\{x{u,v}^s | (u,v)\in Rm\}$和$X_m^t=\{x{u,v}^t | (u,v)\in R_m\}$,定义其loss为:
其中,$L_{da}$表示 domain adaptation loss,其实就是domain classifier loss,具体表示如下。
其中$h:x\to \{0,1\}$,采用的DANN模型,应该是类似GAN中的D,$d\in \{0,1\}$.
这个domain classifer的目的是为了使生成的source和target生成的特征尽量相似。
2.3 Network Overview
3. Experimental Results
3.1 Experimental Results
- dst: target guided distillation
- spt: spatial-aware adaptation
通过结果可以看出,这两个创新点是有用的。
3.2 Analysis on Real Style Orientation
3.3 Analysis on Real Distribution Orientation
4. Others
这里主要简单介绍下论文中提到的几篇参考文献的主要内容,并没有细读这几篇参考文献。
4.1 Domain Adaptation
Unsupervised Domain Adaptation by Backpropagation
4.2 Distillation
Awesome Knowledge Distillation中有相关的论文和一部分实现。
终于找到一个 KD (knowledge-distillation) loss 的代码. 这个我看懂了。
1 | # 标准的KD,利用的是交叉熵求KD |
1 | # 利用KL散度求KD |
1 | # 这段代码证明了DL散度和交叉熵的反向传播是一样的 |
或者也可以参考: code