0. 前言 这篇文章的重点在于 dual attention 的作用,并且attention的使用和之前看到的 SE block 还不太一样。dual attention 主要解决了全局依赖性,即其他位置的物体对当前位置的的物体的特征的影响。重点不是场景分割,自己也不是很懂分割的代码和实现,暂时对分割不做过多研究。
1. Introduction 作者提出了 Dual Attention Network(DANet) 来融合局部特征。具体的过程是在 dilated FCN 上添加了两种 attention modules: the position attention module and the channel attention module ,这两个attention主要解决的是全局依赖性。
场景分割需要解决的两个问题:区分相似的东西(田地和草),识别不同大小外观的同一个东西(车)。因此,场景分割模型需要提高像素级别识别的特征表示。
一种方法是多尺度 融合来识别不同大小的物体,但是不能以全局地角度来很好地处理物体与物体之间的关系。应该是指最后的特征的感受野有大有小,可以理解成不同大小的物体都能识别到。
还有一种方法是利用了LSTM来实现 long-range dependencies,可以理解成物体的识别不仅依靠自己的特征,还依赖于其他物体的的特征,即全局依赖 或者空间依赖性 。
其中 the position attention module 主要用于解决全局的空间位置依赖问题,the channel attention module 解决的是全局的通道依赖性。
所以作者主要解决的是全局依赖性,并没有考虑不同大小的物体的分割问题。
按照作者的说法,DANet 有两种作用:第一,可以避免显眼的大的物体的特征影响不起眼的小的物体的标签;第二,可以在一定程度上融合不同尺寸的物体的相似特征;第三,利用空间和通道的依懒性解决全局依赖问题。
2. Dual Attention Network 3.1 Overview 基准网络是 dilated FCN.
3.2 Position Attention Module
通过特征提取网络得到特征图 $A\in R^{C\times H\times W}$,分别通过一个卷积层得两个特征图 $\lbrace B,C \rbrace \in R^{C\times H\times W}$,并且 reshape 成 $R^{C\times N}$,其中$N=H\times W$,然后得到$S\in R^{N\times N}$,此时把$R^{C}$看成这个位置的特征。下面阐述下具体的过程:
从而得到$S$
可以理解成对$\hat{S}$的每一行都做一次softmax,即 S 的每一行和为1,可以解释成C中的点与B中所有点的相似性,越相似值越大。其中B和C是对称的。
同时,将A送进第三个滤波器得到 $D\in R^{C\times H\times W}$ 并且 reshape 成 $R^{C\times N}$ ,从而得到最后的输出$E$,下面阐述具体计算过程:
从而得到$E\in R^{C\times N}$,
相当于 $\hat{E}$ 与 $A$ 进行了线性组合,并对其reshape变成$E\in R^{C\times H\times W}$,其中$\alpha$是一个可学习参数,网络自动学习,初始化为0。
如果不考虑其中的 softmax, 可以写成:
3.3 Channel Attention Module Position Attention Module 是把每个位置的通道作为其特征 $R^C$ ,Channel Attention Module 是把每个通道的特征图作为其特征 $R^N$。
与 Position Attention Module 不同的地方还有没有经过三个滤波器得到 $B,C,D$ ,而是直接使用A。
仍然是先把 A reshape 成 $A\in R^{C\times N}$,然后进行和上述类似的操作,可以令$B,C,D=A^T$下面阐述具体过程:
结合后面的代码分析,从而得到$X\in R^{C\times C}$:
同样可以理解成对 $\hat{X}$ 的每一行做一次softmax,可以理解成A的自相关性。结合后面的 channel attention 的可视化,不同通道代表的类别不同,所以这里应该是越不相似值越大。
然后类似地我们得到$E\in R^{C\times H \times W}$:
如果不考虑其中的 softmax, 可以写成:
这里给我的感觉更多地是在加法,而不是 SE block 用的乘法。
其实看到这里我是表示很怀疑的,这种 attention 能有效果吗?后面的可视化证明了作者的思路是正确的。
其中$\beta$也是一个可学习参数,网络自动学习,初始化为0。
4. Experiments 4.1 Implementation Details 学习率: 多项式衰减
下面专门做一个学习率衰减的情况。
4.2 Results on Datasets 4.2.1 Ablation Study for Attention Modules
从实验结果可以看出, the position module 和 the channel module 互为补充,两个合起来后的提升效果没有单个的提升效果明显。
4.2.3 Visualization of Attention Module 这一小节很有意思的。
对于 position attention,得到的 $E\in R^{(H\times W)\times (H\times W)}$ ,可以理解点与点之间的相似性,对每个图片,选两个点,记为 ( #1 and #2 ),并且展示这两个点的 position attention map. 第一张图 #1 标记的是建筑物, #2 标记的是车,第二张图分别标记的是交通标记和行人,第三行标记的是植物和行人。可以看出来,同一类事物哪怕离得远也可以标记出来,不同事物哪怕离得近也标记不出来。或者说, position attention 具有在全局的角度来标记同一类事物,哪怕离得远,哪怕事物很小,同时区分近距离的不同事物。
对于 channel attention, 从图片中可以看出来,主要是同一通道得到的是同一类别。
现在还不知道是怎么可视化的。
4.2.4 Comparing with State-of-the-art
嗯,比其他方法都强。
作者一共在四个数据集上做了实验,说明是真的强。
5. Learning rate 以前虽然一直在用一些学习率衰减方式,但是都不系统。
5.1 fixed 5.2 step 离散的学习率变化策略
其中,向下取整,并且 $\gamma$ 和 step_size 都需要设置 gamma一般取0.1, step_wise一般取40
5.3 exp 其中 $\gamma$ 需要设置 gamma一般取0.99
5.4 inv 其中$\gamma$和power都需要设置 gamma控制下降速率,power控制曲线在饱和状态下学习率达到的最低值。可以理解成当epoch达到最大值的时候,学习率在不同的power下最低值不一样。
5.5 multistep 多次step,只是学习率改变的迭代次数不均匀
1 2 3 4 5 lr_policy: "multistep" gamma: 0.5 stepvalue: 10000 stepvalue: 30000 stepvalue: 60000
5.6 poly 其中,power需要设置,并且epoch为0时,lr是base_lr,当达到最大次数时,学习率变成0.
5.7 sigmoid
其中step_size控制sigmoid为0.5的位置,gamma学习率的变化速率。
5.8 warm up 在前10个epoch使用较小的lr,之后正常使用
5.9 all
其中,step和multi_step最好,其次是exp,ploy,最差的是 inv,sigmoid.
5.10 code 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 import matplotlib.pyplot as pltx = list (range (1000 )) base_lr=0.01 def step_lr (epoch ): step_wise=50 gamma=0.1 return base_lr*gamma**(epoch//step_wise) y = [step_lr(i) for i in x] plt.plot(x,y) def exp_lr (epoch ): gamma=0.999 return base_lr*gamma**epoch y = [exp_lr(i) for i in x] plt.plot(x,y) def inv_lr (epoch ): gamma=0.1 power=0.75 return base_lr*(1 +gamma*epoch)**(-power) y = [inv_lr(i) for i in x] plt.plot(x,y) def multi_lr (epoch ): step_wise1=200 step_wise2=300 step_wise3=400 gamma=0.5 power = [0 ]*step_wise1 + [1 ]*step_wise2+[2 ]*step_wise3 if epoch<len (power): return base_lr*gamma**power[epoch] else : return base_lr*gamma**3 y = [multi_lr(i) for i in x] plt.plot(x,y)
6. code 这次的代码很有含金量,用到了多GPU。
一是涉及到的代码有点多,二是自己没有跑过分割的代码,不清楚具体的代码组织形式。所以下面从小到大一个个讲关键的地方。有些代码和作者的论文描述不是非常一致,但不影响总体。
6.1 PAM and CAM position attention and channel attention
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 class PAM_Module (Module ): """ Position attention module""" def __init__ (self, in_dim ): super (PAM_Module, self).__init__() self.chanel_in = in_dim self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8 , kernel_size=1 ) self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8 , kernel_size=1 ) self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1 ) self.gamma = Parameter(torch.zeros(1 )) self.softmax = Softmax(dim=-1 ) def forward (self, x ): """ inputs : x : input feature maps( B X C X H X W) returns : out : attention value + input feature attention: B X (HxW) X (HxW) """ m_batchsize, C, height, width = x.size() proj_query = self.query_conv(x).view(m_batchsize, -1 , width*height).permute(0 , 2 , 1 ) proj_key = self.key_conv(x).view(m_batchsize, -1 , width*height) energy = torch.bmm(proj_query, proj_key) attention = self.softmax(energy) proj_value = self.value_conv(x).view(m_batchsize, -1 , width*height) out = torch.bmm(proj_value, attention.permute(0 , 2 , 1 )) out = out.view(m_batchsize, C, height, width) out = self.gamma*out + x return out
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 class CAM_Module (Module ): """ Channel attention module""" def __init__ (self, in_dim ): super (CAM_Module, self).__init__() self.chanel_in = in_dim self.gamma = Parameter(torch.zeros(1 )) self.softmax = Softmax(dim=-1 ) def forward (self,x ): """ inputs : x : input feature maps( B X C X H X W) returns : out : attention value + input feature attention: B X C X C """ m_batchsize, C, height, width = x.size() proj_query = x.view(m_batchsize, C, -1 ) proj_key = x.view(m_batchsize, C, -1 ).permute(0 , 2 , 1 ) energy = torch.bmm(proj_query, proj_key) energy_new = torch.max (energy, -1 , keepdim=True )[0 ].expand_as(energy)-energy attention = self.softmax(energy_new) proj_value = x.view(m_batchsize, C, -1 ) out = torch.bmm(attention, proj_value) out = out.view(m_batchsize, C, height, width) out = self.gamma*out + x return out
6.2 DANetHead 从代码上看,过程大概是:
是在进入 attention module 会进行一次通道缩小,2048->512,
position attention module: sa_conv, channel attention module: sc_conv
得到三种预测的类别:sa_conv+sc_conv->sasc_output, sa_conv->sa_output, sc_conv->sc_output
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 class DANetHead (nn.Module): def __init__ (self, in_channels, out_channels, norm_layer ): super (DANetHead, self).__init__() inter_channels = in_channels // 4 self.conv5a = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3 , padding=1 , bias=False ), norm_layer(inter_channels), nn.ReLU()) self.sa = PAM_Module(inter_channels) self.conv51 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3 , padding=1 , bias=False ), norm_layer(inter_channels), nn.ReLU()) self.conv6 = nn.Sequential(nn.Dropout2d(0.1 , False ), nn.Conv2d(512 , out_channels, 1 )) self.conv5c = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3 , padding=1 , bias=False ), norm_layer(inter_channels), nn.ReLU()) self.sc = CAM_Module(inter_channels) self.conv52 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3 , padding=1 , bias=False ), norm_layer(inter_channels), nn.ReLU()) self.conv7 = nn.Sequential(nn.Dropout2d(0.1 , False ), nn.Conv2d(512 , out_channels, 1 )) self.conv8 = nn.Sequential(nn.Dropout2d(0.1 , False ), nn.Conv2d(512 , out_channels, 1 )) def forward (self, x ): feat1 = self.conv5a(x) sa_feat = self.sa(feat1) sa_conv = self.conv51(sa_feat) sa_output = self.conv6(sa_conv) feat2 = self.conv5c(x) sc_feat = self.sc(feat2) sc_conv = self.conv52(sc_feat) sc_output = self.conv7(sc_conv) feat_sum = sa_conv+sc_conv sasc_output = self.conv8(feat_sum) output = [sasc_output] output.append(sa_output) output.append(sc_output) return tuple (output)
6.3 BaseNet 以ResNet-50为例,相当于求得每一个layer的输出 [c1, c2, c3, c4]
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 class BaseNet (nn.Module): def __init__ (self, nclass, backbone, aux, se_loss, dilated=True , norm_layer=None , base_size=576 , crop_size=608 , mean=[.485 , .456 , .406 ], std=[.229 , .224 , .225 ], root='./pretrain_models' , multi_grid=False , multi_dilation=None ): super (BaseNet, self).__init__() self.nclass = nclass self.aux = aux self.se_loss = se_loss self.mean = mean self.std = std self.base_size = base_size self.crop_size = crop_size if backbone == 'resnet50' : self.pretrained = resnet.resnet50(pretrained=True , dilated=dilated, norm_layer=norm_layer, root=root, multi_grid=multi_grid, multi_dilation=multi_dilation) elif backbone == 'resnet101' : self.pretrained = resnet.resnet101(pretrained=True , dilated=dilated, norm_layer=norm_layer, root=root, multi_grid=multi_grid,multi_dilation=multi_dilation) elif backbone == 'resnet152' : self.pretrained = resnet.resnet152(pretrained=True , dilated=dilated, norm_layer=norm_layer, root=root, multi_grid=multi_grid, multi_dilation=multi_dilation) else : raise RuntimeError('unknown backbone: {}' .format (backbone)) self._up_kwargs = up_kwargs def base_forward (self, x ): x = self.pretrained.conv1(x) x = self.pretrained.bn1(x) x = self.pretrained.relu(x) x = self.pretrained.maxpool(x) c1 = self.pretrained.layer1(x) c2 = self.pretrained.layer2(c1) c3 = self.pretrained.layer3(c2) c4 = self.pretrained.layer4(c3) return c1, c2, c3, c4 def evaluate (self, x, target=None ): pred = self.forward(x) if isinstance (pred, (tuple , list )): pred = pred[0 ] if target is None : return pred correct, labeled = batch_pix_accuracy(pred.data, target.data) inter, union = batch_intersection_union(pred.data, target.data, self.nclass) return correct, labeled, inter, union
6.5 DANet 相当于求这三种的预测:sa_conv+sc_conv->sasc_output, sa_conv->sa_output, sc_conv->sc_output
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 class DANet (BaseNet ): r"""Fully Convolutional Networks for Semantic Segmentation Parameters ---------- nclass : int Number of categories for the training dataset. backbone : string Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50', 'resnet101' or 'resnet152'). norm_layer : object Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`; Reference: Long, Jonathan, Evan Shelhamer, and Trevor Darrell. "Fully convolutional networks for semantic segmentation." *CVPR*, 2015 """ def __init__ (self, nclass, backbone, aux=False , se_loss=False , norm_layer=nn.BatchNorm2d, **kwargs ): super (DANet, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer, **kwargs) self.head = DANetHead(2048 , nclass, norm_layer) def forward (self, x ): imsize = x.size()[2 :] _, _, c3, c4 = self.base_forward(x) x = self.head(c4) x = list (x) x[0 ] = upsample(x[0 ], imsize, **self._up_kwargs) x[1 ] = upsample(x[1 ], imsize, **self._up_kwargs) x[2 ] = upsample(x[2 ], imsize, **self._up_kwargs) outputs = [x[0 ]] outputs.append(x[1 ]) outputs.append(x[2 ]) return tuple (outputs)
6.6 SegmentationMultiLosses 希望 position+channel attetion, position attention, channel attention 三种预测都准确
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 class SegmentationMultiLosses (CrossEntropyLoss ): """2D Cross Entropy Loss with Multi-L1oss""" def __init__ (self, nclass=-1 , weight=None ,size_average=True , ignore_index=-1 ): super (SegmentationMultiLosses, self).__init__(weight, size_average, ignore_index) self.nclass = nclass def forward (self, *inputs ): *preds, target = tuple (inputs) pred1, pred2 ,pred3= tuple (preds) loss1 = super (SegmentationMultiLosses, self).forward(pred1, target) loss2 = super (SegmentationMultiLosses, self).forward(pred2, target) loss3 = super (SegmentationMultiLosses, self).forward(pred3, target) loss = loss1 + loss2 + loss3 return loss
6.7 其他 其他的代码暂时就不看了,只是记录一个自己没有看到过的函数
Synchronized Cross-GPU Batch Normalization functions