0%

Dual Attention Network for Scene Segmentation

0. 前言

这篇文章的重点在于 dual attention 的作用,并且attention的使用和之前看到的 SE block 还不太一样。dual attention 主要解决了全局依赖性,即其他位置的物体对当前位置的的物体的特征的影响。重点不是场景分割,自己也不是很懂分割的代码和实现,暂时对分割不做过多研究。

  • paper: CPVR2019: Dual Attention Network for Scene Segmentation
  • code: pytorch
  • team: 中科院自动化所图像与视频分析团队(IVA),隶属于模式识别国家重点实验室,在 ICCV 2017 COCO-Places 场景解析竞赛、京东 AI 时尚挑战赛和阿里巴巴大规模图像搜索大赛踢馆赛等多次拔得头筹。嗯,一句话,很牛逼。
  • 解读

1. Introduction

作者提出了 Dual Attention Network(DANet) 来融合局部特征。具体的过程是在 dilated FCN 上添加了两种 attention modules: the position attention module and the channel attention module,这两个attention主要解决的是全局依赖性。

场景分割需要解决的两个问题:区分相似的东西(田地和草),识别不同大小外观的同一个东西(车)。因此,场景分割模型需要提高像素级别识别的特征表示。

一种方法是多尺度融合来识别不同大小的物体,但是不能以全局地角度来很好地处理物体与物体之间的关系。应该是指最后的特征的感受野有大有小,可以理解成不同大小的物体都能识别到。

还有一种方法是利用了LSTM来实现 long-range dependencies,可以理解成物体的识别不仅依靠自己的特征,还依赖于其他物体的的特征,即全局依赖或者空间依赖性

其中 the position attention module 主要用于解决全局的空间位置依赖问题,the channel attention module 解决的是全局的通道依赖性。

所以作者主要解决的是全局依赖性,并没有考虑不同大小的物体的分割问题。

按照作者的说法,DANet 有两种作用:第一,可以避免显眼的大的物体的特征影响不起眼的小的物体的标签;第二,可以在一定程度上融合不同尺寸的物体的相似特征;第三,利用空间和通道的依懒性解决全局依赖问题。

2. Dual Attention Network

3.1 Overview

基准网络是 dilated FCN.

3.2 Position Attention Module

通过特征提取网络得到特征图 $A\in R^{C\times H\times W}$,分别通过一个卷积层得两个特征图 $\lbrace B,C \rbrace \in R^{C\times H\times W}$,并且 reshape 成 $R^{C\times N}$,其中$N=H\times W$,然后得到$S\in R^{N\times N}$,此时把$R^{C}$看成这个位置的特征。下面阐述下具体的过程:

从而得到$S$

可以理解成对$\hat{S}$的每一行都做一次softmax,即 S 的每一行和为1,可以解释成C中的点与B中所有点的相似性,越相似值越大。其中B和C是对称的。

同时,将A送进第三个滤波器得到 $D\in R^{C\times H\times W}$ 并且 reshape 成 $R^{C\times N}$ ,从而得到最后的输出$E$,下面阐述具体计算过程:

从而得到$E\in R^{C\times N}$,

相当于 $\hat{E}$ 与 $A$ 进行了线性组合,并对其reshape变成$E\in R^{C\times H\times W}$,其中$\alpha$是一个可学习参数,网络自动学习,初始化为0。

如果不考虑其中的 softmax, 可以写成:

3.3 Channel Attention Module

Position Attention Module 是把每个位置的通道作为其特征 $R^C$ ,Channel Attention Module 是把每个通道的特征图作为其特征 $R^N$。

与 Position Attention Module 不同的地方还有没有经过三个滤波器得到 $B,C,D$ ,而是直接使用A。

仍然是先把 A reshape 成 $A\in R^{C\times N}$,然后进行和上述类似的操作,可以令$B,C,D=A^T$下面阐述具体过程:

结合后面的代码分析,从而得到$X\in R^{C\times C}$:

同样可以理解成对 $\hat{X}$ 的每一行做一次softmax,可以理解成A的自相关性。结合后面的 channel attention 的可视化,不同通道代表的类别不同,所以这里应该是越不相似值越大。

然后类似地我们得到$E\in R^{C\times H \times W}$:

如果不考虑其中的 softmax, 可以写成:

这里给我的感觉更多地是在加法,而不是 SE block 用的乘法。

其实看到这里我是表示很怀疑的,这种 attention 能有效果吗?后面的可视化证明了作者的思路是正确的。

其中$\beta$也是一个可学习参数,网络自动学习,初始化为0。

4. Experiments

4.1 Implementation Details

学习率: 多项式衰减

下面专门做一个学习率衰减的情况。

4.2 Results on Datasets

4.2.1 Ablation Study for Attention Modules

从实验结果可以看出, the position module 和 the channel module 互为补充,两个合起来后的提升效果没有单个的提升效果明显。

4.2.3 Visualization of Attention Module

这一小节很有意思的。

对于 position attention,得到的 $E\in R^{(H\times W)\times (H\times W)}$ ,可以理解点与点之间的相似性,对每个图片,选两个点,记为 ( #1 and #2 ),并且展示这两个点的 position attention map. 第一张图 #1 标记的是建筑物, #2 标记的是车,第二张图分别标记的是交通标记和行人,第三行标记的是植物和行人。可以看出来,同一类事物哪怕离得远也可以标记出来,不同事物哪怕离得近也标记不出来。或者说, position attention 具有在全局的角度来标记同一类事物,哪怕离得远,哪怕事物很小,同时区分近距离的不同事物。

对于 channel attention, 从图片中可以看出来,主要是同一通道得到的是同一类别。

现在还不知道是怎么可视化的。

4.2.4 Comparing with State-of-the-art

嗯,比其他方法都强。

作者一共在四个数据集上做了实验,说明是真的强。

5. Learning rate

以前虽然一直在用一些学习率衰减方式,但是都不系统。

5.1 fixed

5.2 step

离散的学习率变化策略

其中,向下取整,并且 $\gamma$ 和 step_size 都需要设置

gamma一般取0.1, step_wise一般取40

5.3 exp

其中 $\gamma$ 需要设置

gamma一般取0.99

5.4 inv

其中$\gamma$和power都需要设置

gamma控制下降速率,power控制曲线在饱和状态下学习率达到的最低值。可以理解成当epoch达到最大值的时候,学习率在不同的power下最低值不一样。

5.5 multistep

多次step,只是学习率改变的迭代次数不均匀

1
2
3
4
5
lr_policy: "multistep"
gamma: 0.5
stepvalue: 10000
stepvalue: 30000
stepvalue: 60000

5.6 poly

其中,power需要设置,并且epoch为0时,lr是base_lr,当达到最大次数时,学习率变成0.

5.7 sigmoid

其中step_size控制sigmoid为0.5的位置,gamma学习率的变化速率。

5.8 warm up

在前10个epoch使用较小的lr,之后正常使用

5.9 all

其中,step和multi_step最好,其次是exp,ploy,最差的是 inv,sigmoid.

5.10 code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import matplotlib.pyplot as plt
x = list(range(1000))
base_lr=0.01
def step_lr(epoch):
step_wise=50
gamma=0.1
return base_lr*gamma**(epoch//step_wise)
y = [step_lr(i) for i in x]
plt.plot(x,y)
def exp_lr(epoch):
gamma=0.999
return base_lr*gamma**epoch
y = [exp_lr(i) for i in x]
plt.plot(x,y)
def inv_lr(epoch):
gamma=0.1
power=0.75
return base_lr*(1+gamma*epoch)**(-power)
y = [inv_lr(i) for i in x]
plt.plot(x,y)
def multi_lr(epoch):
step_wise1=200
step_wise2=300
step_wise3=400
gamma=0.5
power = [0]*step_wise1 + [1]*step_wise2+[2]*step_wise3
if epoch<len(power):
return base_lr*gamma**power[epoch]
else:
return base_lr*gamma**3
y = [multi_lr(i) for i in x]
plt.plot(x,y)

6. code

这次的代码很有含金量,用到了多GPU。

一是涉及到的代码有点多,二是自己没有跑过分割的代码,不清楚具体的代码组织形式。所以下面从小到大一个个讲关键的地方。有些代码和作者的论文描述不是非常一致,但不影响总体。

6.1 PAM and CAM

position attention and channel attention

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class PAM_Module(Module):
""" Position attention module"""
#Ref from SAGAN
def __init__(self, in_dim):
super(PAM_Module, self).__init__()
self.chanel_in = in_dim

self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
self.gamma = Parameter(torch.zeros(1))

self.softmax = Softmax(dim=-1)
def forward(self, x):
"""
inputs :
x : input feature maps( B X C X H X W)
returns :
out : attention value + input feature
attention: B X (HxW) X (HxW)
"""
# C: 512, C//8: 64
m_batchsize, C, height, width = x.size()
# x: B,C,H,W
proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)
# C': B,HxW,C//8
proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)
# B: B,C//8,HxW
energy = torch.bmm(proj_query, proj_key)
# \hat{S} = C'xB : B,HxW,HxW
attention = self.softmax(energy)
# S: B,HxW,HxW
proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)
# D: B,C,HxW
out = torch.bmm(proj_value, attention.permute(0, 2, 1))
# \hat{E} = DxS': B,C,HxW
out = out.view(m_batchsize, C, height, width)
# \hat{E} : B,C,H,W
out = self.gamma*out + x
# E: B,C,H,W
return out
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class CAM_Module(Module):
""" Channel attention module"""
def __init__(self, in_dim):
super(CAM_Module, self).__init__()
self.chanel_in = in_dim
self.gamma = Parameter(torch.zeros(1))
self.softmax = Softmax(dim=-1)
def forward(self,x):
"""
inputs :
x : input feature maps( B X C X H X W)
returns :
out : attention value + input feature
attention: B X C X C
"""
m_batchsize, C, height, width = x.size()
# x: B,C,H,W
proj_query = x.view(m_batchsize, C, -1)
# A: B,C,HxW
proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
# A': B,HxW,C
energy = torch.bmm(proj_query, proj_key)
# \hat{X} = AxA': B,C,C
energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy
# note that, 作者在这里用了一次 max-v_i,而不是常见的v_i-max,按照github上的解释,
# 作者选用前者而不是后者的原因是后者的效果不好,不知道该怎么反驳,
# channel attention 主要衡量的是通道与通道之间的相似性,
# 按照这个公式,结合channel的可视化,只能强行解释成,希望通道之间不相似,越不相似给的值越高,
attention = self.softmax(energy_new)
# X: B,C,C
proj_value = x.view(m_batchsize, C, -1)
# D: B,C,HxW
out = torch.bmm(attention, proj_value)
# \hat{E} = XxD : B,C,HxW,这里也满足\hat{E}中的每个元素的系数之后为1
out = out.view(m_batchsize, C, height, width)
# \hat{E}: B,C,H,W
out = self.gamma*out + x
# E: B,C,H,W
return out

6.2 DANetHead

从代码上看,过程大概是:

  1. 是在进入 attention module 会进行一次通道缩小,2048->512,
  2. position attention module: sa_conv, channel attention module: sc_conv
  3. 得到三种预测的类别:sa_conv+sc_conv->sasc_output, sa_conv->sa_output, sc_conv->sc_output
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class DANetHead(nn.Module):
def __init__(self, in_channels, out_channels, norm_layer):
# in_channels: 2048
# out_channels: dataset.num_classes
super(DANetHead, self).__init__()
inter_channels = in_channels // 4
self.conv5a = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
norm_layer(inter_channels),
nn.ReLU())
self.sa = PAM_Module(inter_channels)
self.conv51 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
norm_layer(inter_channels),
nn.ReLU())
self.conv6 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(512, out_channels, 1))

self.conv5c = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
norm_layer(inter_channels),
nn.ReLU())
self.sc = CAM_Module(inter_channels)
self.conv52 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
norm_layer(inter_channels),
nn.ReLU())
self.conv7 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(512, out_channels, 1))

self.conv8 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(512, out_channels, 1))

def forward(self, x):
# x: B,C,H,W: C 2048
feat1 = self.conv5a(x)
# feat1: B,C//4,H,W
sa_feat = self.sa(feat1)
# sa_feat: B,C//4,H,W
sa_conv = self.conv51(sa_feat)
# sa_conv: B,C//4,H,W
sa_output = self.conv6(sa_conv)
# sa_output: B,C_out,H,W

# x: B,C,H,W: C 2048
feat2 = self.conv5c(x)
# feat2: B,C//4,H,W
sc_feat = self.sc(feat2)
# sc_feat: B,C//4,H,W
sc_conv = self.conv52(sc_feat)
# sc_conv: B,C//4,H,W
sc_output = self.conv7(sc_conv)
# sc_output: B,C_out,H,W

feat_sum = sa_conv+sc_conv
# feat_sum: B,C//4,H,W
sasc_output = self.conv8(feat_sum)
# sasc_output: B,C_out,H,W

output = [sasc_output]
output.append(sa_output)
output.append(sc_output)
# output:[sasc_output, sa_output, sc_output]: 3,B,C_out,H,W
return tuple(output)

6.3 BaseNet

以ResNet-50为例,相当于求得每一个layer的输出 [c1, c2, c3, c4]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class BaseNet(nn.Module):
def __init__(self, nclass, backbone, aux, se_loss, dilated=True, norm_layer=None,
base_size=576, crop_size=608, mean=[.485, .456, .406],
std=[.229, .224, .225], root='./pretrain_models',
multi_grid=False, multi_dilation=None):
super(BaseNet, self).__init__()
self.nclass = nclass
self.aux = aux
self.se_loss = se_loss
self.mean = mean
self.std = std
self.base_size = base_size
self.crop_size = crop_size
# copying modules from pretrained models
if backbone == 'resnet50':
self.pretrained = resnet.resnet50(pretrained=True, dilated=dilated,
norm_layer=norm_layer, root=root,
multi_grid=multi_grid, multi_dilation=multi_dilation)
elif backbone == 'resnet101':
self.pretrained = resnet.resnet101(pretrained=True, dilated=dilated,
norm_layer=norm_layer, root=root,
multi_grid=multi_grid,multi_dilation=multi_dilation)
elif backbone == 'resnet152':
self.pretrained = resnet.resnet152(pretrained=True, dilated=dilated,
norm_layer=norm_layer, root=root,
multi_grid=multi_grid, multi_dilation=multi_dilation)
else:
raise RuntimeError('unknown backbone: {}'.format(backbone))
# bilinear upsample options
self._up_kwargs = up_kwargs

def base_forward(self, x):
x = self.pretrained.conv1(x)
x = self.pretrained.bn1(x)
x = self.pretrained.relu(x)
x = self.pretrained.maxpool(x)
c1 = self.pretrained.layer1(x)
c2 = self.pretrained.layer2(c1)
c3 = self.pretrained.layer3(c2)
c4 = self.pretrained.layer4(c3)
return c1, c2, c3, c4

def evaluate(self, x, target=None):
pred = self.forward(x)
if isinstance(pred, (tuple, list)):
pred = pred[0]
if target is None:
return pred
correct, labeled = batch_pix_accuracy(pred.data, target.data)
inter, union = batch_intersection_union(pred.data, target.data, self.nclass)
return correct, labeled, inter, union

6.5 DANet

相当于求这三种的预测:sa_conv+sc_conv->sasc_output, sa_conv->sa_output, sc_conv->sc_output

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class DANet(BaseNet):
r"""Fully Convolutional Networks for Semantic Segmentation

Parameters
----------
nclass : int
Number of categories for the training dataset.
backbone : string
Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50',
'resnet101' or 'resnet152').
norm_layer : object
Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;


Reference:

Long, Jonathan, Evan Shelhamer, and Trevor Darrell. "Fully convolutional networks
for semantic segmentation." *CVPR*, 2015

"""
def __init__(self, nclass, backbone, aux=False, se_loss=False, norm_layer=nn.BatchNorm2d, **kwargs):
super(DANet, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer, **kwargs)
self.head = DANetHead(2048, nclass, norm_layer)

def forward(self, x):
# 具体的图片大小还是需要看图像分割的输入,这里以标准的224为例
# x: 3,H,W && 224, 224
imsize = x.size()[2:]
_, _, c3, c4 = self.base_forward(x)
# c3, c4: ResNet-50 的 layer3 和 layer 4 的输出
# c3: 1024, H//16, W//16 && 1024, 14, 14=224//16
# c4: 2018, H//32, W//32 && 7, 7=224//32
x = self.head(c4)
# x: [sasc_output, sa_output, sc_output]: 3,B,dataset.num_classes,H//32, W//32 && 7, 7=224//32
x = list(x)
x[0] = upsample(x[0], imsize, **self._up_kwargs)
x[1] = upsample(x[1], imsize, **self._up_kwargs)
x[2] = upsample(x[2], imsize, **self._up_kwargs)
# 上采样
# x: [sasc_output, sa_output, sc_output]: 3,B,dataset.num_classes,H,W && 224, 224
outputs = [x[0]]
outputs.append(x[1])
outputs.append(x[2])
# x: [sasc_output, sa_output, sc_output]: 3,B,dataset.num_classes,H,W && 224, 224
return tuple(outputs)

6.6 SegmentationMultiLosses

希望 position+channel attetion, position attention, channel attention 三种预测都准确

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class SegmentationMultiLosses(CrossEntropyLoss):
"""2D Cross Entropy Loss with Multi-L1oss"""
def __init__(self, nclass=-1, weight=None,size_average=True, ignore_index=-1):
super(SegmentationMultiLosses, self).__init__(weight, size_average, ignore_index)
self.nclass = nclass


def forward(self, *inputs):

*preds, target = tuple(inputs)
pred1, pred2 ,pred3= tuple(preds)
# sa_conv+sc_conv->sasc_output, sa_conv->sa_output, sc_conv->sc_output
loss1 = super(SegmentationMultiLosses, self).forward(pred1, target)
loss2 = super(SegmentationMultiLosses, self).forward(pred2, target)
loss3 = super(SegmentationMultiLosses, self).forward(pred3, target)
loss = loss1 + loss2 + loss3
return loss

6.7 其他

其他的代码暂时就不看了,只是记录一个自己没有看到过的函数

Synchronized Cross-GPU Batch Normalization functions