
InstaGAN

Instance translation

0. Preface

InstaGAN: Instance-aware Image-to-Image Translation

Sangwoo Mo, Minsu Cho, Jinwoo Shin

github: https://github.com/sangwoomo/instagan

paper (OpenReview): https://openreview.net/forum?id=ryxwJhC9YX

1. Introduction

Translation results

The method consists of three parts:

  1. an instance-augmented neural architecture: takes as input an image together with the corresponding set of instance attributes.
  2. a context preserving loss: encourages the network to learn an identity function outside the target instances.
  3. a sequential mini-batch inference/training technique: translates the mini-batches of instance attributes sequentially.

2. InstaGAN

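Notation:

| Symbol | Meaning |
|---|---|
| $\mathcal{X}$, $\mathcal{Y}$ | image domains |
| $\mathcal{A}$, $\mathcal{B}$ | spaces of sets of instance attributes |
| $\boldsymbol{a} = \lbrace a_i \rbrace_{i=1}^N$ | set of instance attributes; here each $a_i$ is an instance segmentation mask |
| $G_{XY}: \mathcal{X}\times\mathcal{A}\to\mathcal{Y}\times\mathcal{B}$, $G_{YX}: \mathcal{Y}\times\mathcal{B}\to\mathcal{X}\times\mathcal{A}$ | translation functions |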

2.1 InstaGAN architecture

InstaGAN architecture

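| Symbol | Meaning |
|---|---|
| $f_{GX}$ | image feature extractor |
| $f_{GA}$ | attribute feature extractor |
| $h_{GX}(x,\boldsymbol{a})=[f_{GX}(x);\sum_{i=1}^N f_{GA}(a_i)]$ | image representation, decoded by $g_{GX}$ into $y'$ |
| $h_{GA}^n(x,\boldsymbol{a})=[f_{GX}(x);\sum_{i=1}^N f_{GA}(a_i);f_{GA}(a_n)]$ | representation of the $n$-th instance, decoded by $g_{GA}$ into $b'_n$ |
| $h_{DX}(x,\boldsymbol{a})=[f_{DX}(x);\sum_{i=1}^N f_{DA}(a_i)]$ | image representation for the discriminator |
| $f_{GX}, f_{GA}, f_{DX}, f_{DA}, g_{GX}, g_{GA}, g_{DX}$ | network components: feature extractors $f$, generators $g_{GX}, g_{GA}$, classifier $g_{DX}$ |
| $(x,\boldsymbol{a})\to(y',\boldsymbol{b}')$ | translation by $G_{XY}$ |
| $(y,\boldsymbol{b})\to(x',\boldsymbol{a}')$ | translation by $G_{YX}$ |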

To make the representation invariant to the order of the masks, the authors aggregate the attribute features by summation.
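A quick way to see why summation gives order invariance: permuting the masks leaves $\sum_i f_{GA}(a_i)$ unchanged, whereas concatenating the per-mask features does not. A minimal sketch, where `f_GA` is a hypothetical stand-in (one linear layer on 4x4 masks), not the network used in the paper:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
f_GA = nn.Sequential(nn.Flatten(), nn.Linear(16, 4))  # hypothetical attribute extractor
masks = torch.randn(3, 1, 4, 4)                       # N = 3 instance masks
perm = torch.tensor([2, 0, 1])                        # a reordering of the masks


def set_feature(ms):
    # sum_i f_GA(a_i): does not depend on the order of the masks
    return f_GA(ms).sum(dim=0)


def concat_feature(ms):
    # [f_GA(a_1); ...; f_GA(a_N)]: depends on the order of the masks
    return f_GA(ms).flatten()


same_sum = torch.allclose(set_feature(masks), set_feature(masks[perm]), atol=1e-6)
same_cat = torch.allclose(concat_feature(masks), concat_feature(masks[perm]))
print(same_sum, same_cat)  # True False
```

Summation also keeps the feature size fixed regardless of $N$, so the same decoder works for any number of instances.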

2.2 Training loss

  1. domain loss: GAN loss
  2. content loss: cycle-consistency loss, identity mapping loss and context preserving loss

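LSGAN loss: the discriminator judges whether an image (together with its masks) is real or generated.

cycle-consistency loss: translating a sample to the other domain and back should reproduce the original.

identity mapping loss: a sample that already belongs to the target domain should pass through the generator unchanged.

context preserving loss: the background should be preserved by the translation.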

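Here $w(\boldsymbol{a},\boldsymbol{b}')$ and $w(\boldsymbol{b},\boldsymbol{a}')$ are weight maps equal to 1 at pixels that are background in both the original and the translated masks, and 0 wherever either set of masks contains an instance.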
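A minimal sketch of the weight map and one direction of the context preserving loss, assuming masks in [-1, 1] with -1 as background (the convention used in the code walkthrough of section 5) and an L1 penalty averaged over pixels. The function names are hypothetical, the mean reduction is an assumption, and the full loss would add the symmetric term for $(y, x')$ weighted by $w(\boldsymbol{b},\boldsymbol{a}')$:

```python
import torch


def merge_masks(segs):
    # union of N masks in [-1, 1]: (B, N, H, W) -> (B, 1, H, W)
    return ((segs + 1) / 2).sum(dim=1, keepdim=True).clamp(0, 1) * 2 - 1


def ctx_weight(a, b_fake):
    # 1 where a pixel is background in BOTH mask sets, 0 where any instance is
    z = merge_masks(torch.cat([a, b_fake], dim=1))
    return (1 - z) / 2


def ctx_loss(x, y_fake, a, b_fake):
    # L1 difference between input and translated image, counted only on background
    w = ctx_weight(a, b_fake)
    return (w * (x - y_fake)).abs().mean()


x = torch.zeros(1, 3, 4, 4)
a = -torch.ones(1, 1, 4, 4)
a[..., :, :2] = 1            # source instance: left half
b = -torch.ones(1, 1, 4, 4)
b[..., :1, :] = 1            # translated instance: top row
y = x.clone()
y[..., :, :2] = 5.0          # change pixels only inside the source instance
print(ctx_loss(x, y, a, b).item())  # 0.0
```

Changes inside either instance region cost nothing; only background pixels are pulled towards the input.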

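Total loss: $\mathcal{L}_{\text{LSGAN}} + \lambda_{cyc}\mathcal{L}_{cyc} + \lambda_{idt}\mathcal{L}_{idt} + \lambda_{ctx}\mathcal{L}_{ctx}$ (the $\lambda$ values are listed in 4.2).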

2.3 sequential mini-batch translation

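An image may contain many instances, and the required GPU memory grows linearly with their number, so translating all of them at once may be infeasible in practice. Instead, only a small subset (mini-batch) of the instances is translated at a time.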

sequential mini-batch translation

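Notation:

| Symbol | Meaning |
|---|---|
| $\boldsymbol{a}=\cup_{i=1}^M \boldsymbol{a}_i$ | divide the set of instance masks $\boldsymbol{a}$ into mini-batches $\boldsymbol{a}_1,\boldsymbol{a}_2,\dots,\boldsymbol{a}_M$ |
| $(x_m, \boldsymbol{a}_m)\to(y'_m, \boldsymbol{b}'_m)$ | translation of the $m$-th mini-batch; the output image becomes the next input, $x_{m+1}=y'_m$ |
| $(y'_m, \boldsymbol{b}'_{1:m})=(y'_m, \cup_{i=1}^m \boldsymbol{b}'_i)$ | fed to the discriminator to judge real vs. fake |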

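In this setting the losses apply to different scopes: at step $m$, the content losses act on $(x_m, \boldsymbol{a}_m)$ and $(y'_m, \boldsymbol{b}'_m)$, while the domain loss acts on $(x, \boldsymbol{a})$ and $(y'_m, \boldsymbol{b}'_{1:m})$.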

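  • The computation graph is detached after each mini-batch step, so GPU memory usage stays fixed regardless of the number of instances.
  • Instances are assigned to mini-batches by size, from largest to smallest.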
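The two mechanics above can be sketched as a loop. This is a minimal sketch under stated assumptions: `G` is any module mapping [image; k masks] to [image; k masks] with the same channel count, the masks are already sorted from largest to smallest, N is a multiple of `ins_per`, and `step_loss` is a hypothetical callback standing in for the per-step losses. Each step's output image becomes the next input, and both the image and the finished masks are detached after that step's backward pass:

```python
import torch
import torch.nn as nn


def sequential_translate(G, x, masks, ins_per, step_loss=None):
    """x: (1, C, H, W) image; masks: (1, N, H, W), sorted large -> small.
    Assumes N is a multiple of ins_per (hypothetical simplification)."""
    c = x.size(1)
    x_m, b_prev = x, []
    for start in range(0, masks.size(1), ins_per):
        a_m = masks[:, start:start + ins_per]         # m-th mini-batch of masks
        out = G(torch.cat([x_m, a_m], dim=1))
        y_m, b_m = out[:, :c], out[:, c:]
        if step_loss is not None:
            # the discriminator would see (y'_m, b'_{1:m}): earlier masks are
            # detached, the current ones still carry gradients
            step_loss(y_m, torch.cat(b_prev + [b_m], dim=1)).backward()
        b_prev.append(b_m.detach())
        x_m = y_m.detach()                            # x_{m+1} = y'_m, cut from the graph
    return x_m, torch.cat(b_prev, dim=1)


torch.manual_seed(0)
G = nn.Conv2d(3 + 2, 3 + 2, kernel_size=1)            # toy generator for 2 masks per step
x = torch.randn(1, 3, 8, 8)
masks = -torch.ones(1, 4, 8, 8)
y, b = sequential_translate(G, x, masks, 2, step_loss=lambda y, b: y.pow(2).mean() + b.mean())
```

Because every step backpropagates and then detaches, memory is bounded by one mini-batch, while gradients from all steps accumulate in `G`.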

3. experimental results

3.1 image-to-image translation results

translation results

Based on these results, I think InstaGAN outperforms CycleGAN on these tasks and produces translations that better match the intended targets.

results of translation

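The first result shows that the generated image can be controlled by editing the input masks.

The second result shows that predicted masks can be used for translation, which reduces the cost of obtaining masks.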

3.2 ablation study

ablation study

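Fig. 9 studies the effect of the three proposed components (the instance masks, the loss function, and the sequential mini-batch technique). Visually, the last image gives the best result.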

ablation study on the effects of the sequential mini-batch inference/training technique

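Fig. 10 compares the one-step and sequential approaches during training and inference. To me the results look about the same, but the sequential approach is very useful when GPU memory is limited.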

4. Appendix

4.1 architecture details

PatchGAN discriminator is composed of 5 convolutional layers, including normalization and non-linearity layers. We used the first 3 convolution layers for feature extractors, and the last 2 convolution layers for classifier.

4.2 training details

  • $\lambda_{cyc}=10, \lambda_{idt}=10, \lambda_{ctx}=10$
  • Adam: $\beta_1=0.5, \beta_2=0.999$
  • batch_size=4
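  • GPUs: 4
  • learning rate: 0.0002 for G and 0.0001 for D, kept constant for the first m epochs and then decayed linearly to 0 over the next n epochs; m and n differ across datasets.
  • The image size also differs across datasets.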
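The optimizer settings and learning-rate schedule in the list above can be sketched with PyTorch's `LambdaLR`. This is a minimal sketch, assuming 0-indexed epochs and one `scheduler.step()` per epoch; `linear_decay_factor` and `make_optimizers` are hypothetical helpers, and the repository's own schedule may differ from this one by an epoch:

```python
import torch
from torch.optim.lr_scheduler import LambdaLR


def linear_decay_factor(epoch, n_const, n_decay):
    """Multiplier on the base LR for a 0-indexed epoch: 1.0 for the first
    n_const epochs, then a linear ramp that reaches 0 at n_const + n_decay."""
    if epoch < n_const:
        return 1.0
    return max(0.0, 1.0 - (epoch - n_const) / n_decay)


def make_optimizers(G, D, n_const, n_decay):
    # Adam(beta1=0.5, beta2=0.999); base LR 2e-4 for G and 1e-4 for D
    opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_D = torch.optim.Adam(D.parameters(), lr=1e-4, betas=(0.5, 0.999))

    def rule(epoch):
        return linear_decay_factor(epoch, n_const, n_decay)

    return opt_G, opt_D, LambdaLR(opt_G, rule), LambdaLR(opt_D, rule)
```

With m = n = 100, the LR halves at epoch 150 (2e-4 to 1e-4 for G) and reaches 0 at epoch 200.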

4.3 trend of translation results

trend of translation results

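4.4 Other thoughts

In my view, this amounts to adding targeted generation to CycleGAN: instead of rendering the whole image in the target-domain style, it translates only the specified regions into the target domain.

One question that just occurred to me: InstaGAN can generate objects of a specified shape, but it may not handle objects with the same shape and different appearance, e.g. generating a red skirt versus a black skirt.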

4.5 video translation results

video translation results

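The authors use pix2pix for segmentation.

In the video, after the pants are translated into skirts, the skirt looks the same in every frame, which suggests the translation is quite stable.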

5. code

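For the details, you still have to read the implementation.

5.1 Directory structure

Directory structure

The file organization shows that the CycleGAN codebase (on which this repository is built) was designed with extensibility in mind.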

```
.
|-- LICENSE
|-- README.md
|-- data
| |-- __init__.py
| |-- aligned_dataset.py
| |-- base_data_loader.py
| |-- base_dataset.py
| |-- image_folder.py
| |-- single_dataset.py
| |-- unaligned_dataset.py
| `-- unaligned_seg_dataset.py
|-- datasets
| |-- combine_A_and_B.py
| |-- download_coco.sh
| |-- download_cyclegan_dataset.sh
| |-- download_pix2pix_dataset.sh
| |-- generate_ccp_dataset.py
| |-- generate_coco_dataset.py
| |-- generate_mhp_dataset.py
| |-- make_dataset_aligned.py
| |-- pants2skirt_mhp
|-- docs
| `-- more_results.md
|-- environment.yml
|-- models
| |-- __init__.py
| |-- base_model.py
| |-- cycle_gan_model.py
| |-- insta_gan_model.py
| |-- networks.py
| |-- pix2pix_model.py
| `-- test_model.py
|-- options
| |-- __init__.py
| |-- base_options.py
| |-- test_options.py
| `-- train_options.py
|-- requirements.txt
|-- scripts
| |-- conda_deps.sh
| |-- download_cyclegan_model.sh
| |-- download_pix2pix_model.sh
| |-- install_deps.sh
| |-- test_before_push.py
| |-- test_cyclegan.sh
| |-- test_pix2pix.sh
| |-- test_single.sh
| |-- train_cyclegan.sh
| `-- train_pix2pix.sh
|-- test.py
|-- train.py
`-- util
    |-- __init__.py
    |-- get_data.py
    |-- html.py
    |-- image_pool.py
    |-- util.py
    `-- visualizer.py
```

5.2 seg

As the code below shows, a fixed number of instance segmentation masks (max_instances) is read for each image.

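```python
import os

import torch
from PIL import Image


# self.max_instances = 20
def read_segs(self, seg_path, seed):
    segs = list()
    for i in range(self.max_instances):
        # instance masks are stored next to the image as <name>_0.png, <name>_1.png, ...
        path = seg_path.replace('.png', '_{}.png'.format(i))
        if os.path.isfile(path):
            seg = Image.open(path).convert('L')
            seg = self.fixed_transform(seg, seed)
            segs.append(seg)
        else:
            # pad missing slots with an all -1 (background) mask;
            # note: assumes <name>_0.png exists, otherwise segs[0] raises IndexError
            segs.append(-torch.ones(segs[0].size()))
    return torch.cat(segs)
```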

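Note: after the transforms, image values are mapped from 0~1 to -1~1. Segmentation masks are mapped to -1~1 as well, with -1 meaning background, which is why missing masks are padded with -1.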
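A small sketch of this convention, assuming the mask transform ends with the usual `(m - 0.5) / 0.5` normalization, which equals `m * 2 - 1`: background 0 becomes -1, and an absent instance is an all -1 mask, which is exactly what the `(segs + 1).mean(...) > 0` checks in the generator and discriminator code treat as empty. The helper names are hypothetical; unlike `read_segs`, `pad_masks` takes the mask size explicitly, so it does not rely on the first mask existing:

```python
import torch


def to_signed(mask01):
    # [0, 1] -> [-1, 1], i.e. (m - 0.5) / 0.5; background 0 becomes -1
    return mask01 * 2 - 1


def pad_masks(masks, max_instances, size):
    """Stack up to max_instances masks of shape (1, H, W) and fill the
    remaining slots with all -1 (pure background) masks."""
    masks = list(masks)[:max_instances]
    masks += [-torch.ones(1, *size) for _ in range(max_instances - len(masks))]
    return torch.cat(masks)  # (max_instances, H, W)


def nonempty(segs):
    # a slot is empty iff every pixel is -1, i.e. (seg + 1) sums to 0
    return (segs + 1).flatten(1).sum(dim=1) > 0


m = torch.zeros(1, 4, 4)
m[0, :2] = 1                        # instance covers the top two rows
segs = pad_masks([to_signed(m)], 3, (4, 4))
print(nonempty(segs).tolist())      # [True, False, False]
```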

5.3 generator

ResNet generator is composed of downsampling blocks, residual blocks, and upsampling blocks. We used downsampling blocks and residual blocks for encoders, and used upsampling blocks for generators.

```python
import functools

import torch
import torch.nn as nn

# ResnetBlock comes from the same models/networks.py (not shown here)


class ResnetSetGenerator(nn.Module):
    def __init__(self, input_nc=3, output_nc=3, ngf=64, norm_layer=nn.InstanceNorm2d, use_dropout=False, n_blocks=9, padding_type='reflect'):
        assert (n_blocks >= 0)
        super(ResnetSetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        n_downsampling = 2
        self.encoder_img = self.get_encoder(input_nc, n_downsampling, ngf, norm_layer, use_dropout, n_blocks, padding_type, use_bias)
        self.encoder_seg = self.get_encoder(1, n_downsampling, ngf, norm_layer, use_dropout, n_blocks, padding_type, use_bias)
        # image decoder input: [image feature; summed mask features] -> base width 2*ngf
        self.decoder_img = self.get_decoder(output_nc, n_downsampling, 2 * ngf, norm_layer, use_bias)
        # mask decoder input: [this mask's feature; image feature; summed mask features] -> base width 3*ngf
        self.decoder_seg = self.get_decoder(1, n_downsampling, 3 * ngf, norm_layer, use_bias)

    def get_encoder(self, input_nc, n_downsampling, ngf, norm_layer, use_dropout, n_blocks, padding_type, use_bias):
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        return nn.Sequential(*model)

    def get_decoder(self, output_nc, n_downsampling, ngf, norm_layer, use_bias):
        model = []
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]
        return nn.Sequential(*model)

    def forward(self, inp):
        # split data
        img = inp[:, :self.input_nc, :, :]  # (B, CX, W, H)
        segs = inp[:, self.input_nc:, :, :]  # (B, CA, W, H)
        # mean of (seg + 1) is 0 exactly when a mask is all -1 (empty)
        mean = (segs + 1).mean(0).mean(-1).mean(-1)
        if mean.sum() == 0:
            mean[0] = 1  # forward at least one segmentation

        # run encoder
        enc_img = self.encoder_img(img)
        enc_segs = list()
        for i in range(segs.size(1)):  # i-th instance
            if mean[i] > 0:  # skip empty segmentation
                seg = segs[:, i, :, :].unsqueeze(1)
                enc_segs.append(self.encoder_seg(seg))
        # NOTE: concatenating along dim 0 and summing with keepdim=True only
        # lines up with enc_img when the batch size is 1
        enc_segs = torch.cat(enc_segs)
        enc_segs_sum = torch.sum(enc_segs, dim=0, keepdim=True)  # aggregated set feature

        # run decoder
        feat = torch.cat([enc_img, enc_segs_sum], dim=1)
        out = [self.decoder_img(feat)]
        idx = 0
        for i in range(segs.size(1)):
            if mean[i] > 0:
                enc_seg = enc_segs[idx].unsqueeze(0)  # (1, ngf, w, h)
                idx += 1  # move to next index
                feat = torch.cat([enc_seg, enc_img, enc_segs_sum], dim=1)
                out += [self.decoder_seg(feat)]
            else:
                out += [segs[:, i, :, :].unsqueeze(1)]  # skip empty segmentation
        return torch.cat(out, dim=1)
```
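To see the set-structured forward pass of `ResnetSetGenerator` in isolation, here is a toy sketch in which single 3x3 convolutions stand in for the ResNet encoders and decoders (`TinySetGenerator` and its layer sizes are hypothetical). It keeps the three ideas of the original: all -1 masks are skipped and passed through unchanged, mask features are summed into an order-invariant set feature, and each mask is decoded from [its own feature; image feature; set feature]. One deliberate difference: it sums with `torch.stack(...).sum(dim=0)` as the discriminator does, whereas the generator above concatenates along dim 0 and sums with `keepdim=True`, which only works with one image per forward call.

```python
import torch
import torch.nn as nn


class TinySetGenerator(nn.Module):
    """Hypothetical toy version of ResnetSetGenerator for shape-checking."""

    def __init__(self, img_nc=3, nf=8):
        super().__init__()
        self.img_nc = img_nc
        self.enc_img = nn.Conv2d(img_nc, nf, 3, padding=1)
        self.enc_seg = nn.Conv2d(1, nf, 3, padding=1)
        self.dec_img = nn.Conv2d(2 * nf, img_nc, 3, padding=1)  # [img; set]
        self.dec_seg = nn.Conv2d(3 * nf, 1, 3, padding=1)       # [seg_n; img; set]

    def forward(self, inp):
        img, segs = inp[:, :self.img_nc], inp[:, self.img_nc:]
        nonempty = (segs + 1).mean(dim=(0, 2, 3)) > 0   # an all -1 mask is empty
        if not nonempty.any():
            nonempty[0] = True                          # encode at least one mask
        f_img = self.enc_img(img)
        f_segs = [self.enc_seg(segs[:, i:i + 1]) for i in range(segs.size(1)) if nonempty[i]]
        f_sum = torch.stack(f_segs).sum(dim=0)          # order-invariant set feature
        out = [torch.tanh(self.dec_img(torch.cat([f_img, f_sum], dim=1)))]
        k = 0
        for i in range(segs.size(1)):
            if nonempty[i]:
                feat = torch.cat([f_segs[k], f_img, f_sum], dim=1)
                out.append(torch.tanh(self.dec_seg(feat)))
                k += 1
            else:
                out.append(segs[:, i:i + 1])            # empty slot passed through
        return torch.cat(out, dim=1)


torch.manual_seed(0)
g = TinySetGenerator()
x = torch.randn(2, 3, 8, 8)
m = -torch.ones(2, 4, 8, 8)
m[:, 0, :4] = 1                                         # two non-empty masks,
m[:, 1, 4:] = 1                                         # two empty slots
out = g(torch.cat([x, m], dim=1))                       # (2, 7, 8, 8)
```

Swapping the two non-empty masks leaves the translated image unchanged and simply swaps the two translated masks, which is the order invariance that motivates the sum.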


5.4 Discriminator

On the other hand, PatchGAN discriminator is composed of 5 convolutional layers, including normalization and non-linearity layers. We used the first 3 convolution layers for feature extractors, and the last 2 convolution layers for classifier.

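In addition, we observed that applying Spectral Normalization (SN) (Miyato et al., 2018) for discriminators significantly improves the performance, although we used LSGAN (Mao et al., 2017), while the original motivation of SN was to enforce Lipschitz condition to match with the theory of WGAN (Arjovsky et al., 2017; Gulrajani et al., 2017).

I still don't fully understand the theory, but roughly: SN divides each weight matrix by its spectral norm (its largest singular value). Since that is hard to compute exactly at every step, it is approximated iteratively by updating the vectors u and v (power iteration).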
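The power iteration described above can be isolated into a few lines. A minimal sketch (`spectral_norm_estimate` is a hypothetical helper, not part of the repository) that follows the same update order as `_update_u_v` in the code below: flatten the weight to a matrix W, alternate v ← Wᵀu/‖Wᵀu‖ and u ← Wv/‖Wv‖, then read off σ ≈ uᵀWv:

```python
import torch


def spectral_norm_estimate(w, n_iter=1, u=None, eps=1e-12):
    """Power-iteration estimate of the largest singular value of a weight.

    The weight is flattened to an (out_channels, rest) matrix W, as in SpectralNorm.
    Returns (sigma, u); passing u back in on the next call warm-starts the
    iteration, which is why one iteration per training step is typically enough.
    """
    assert n_iter >= 1
    W = w.reshape(w.shape[0], -1)
    if u is None:
        u = torch.randn(W.shape[0])
    u = u / (u.norm() + eps)
    for _ in range(n_iter):
        v = W.t() @ u
        v = v / (v.norm() + eps)      # v <- W^T u / ||W^T u||
        u = W @ v
        u = u / (u.norm() + eps)      # u <- W v / ||W v||
    sigma = torch.dot(u, W @ v)       # u^T W v approximates the largest singular value
    return sigma, u


torch.manual_seed(0)
w = torch.randn(64, 3, 4, 4)          # same shape as the first conv's weight below
sigma, u = spectral_norm_estimate(w, n_iter=1000)
w_sn = w / sigma                      # spectral norm of w_sn is about 1
```

PyTorch also ships a built-in version, `torch.nn.utils.spectral_norm(module, n_power_iterations=1)`, which could replace the custom `SpectralNorm` wrapper.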

```python
import functools

import torch
import torch.nn as nn


# Define spectral normalization layer
# Code from Christian Cosgrove's repository
# https://github.com/christiancosgrove/pytorch-spectral-normalization-gan/blob/master/spectral_normalization.py
def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)


class SpectralNorm(nn.Module):
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            # power iteration: v <- W^T u / ||W^T u||, u <- W v / ||W v||
            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))

        # sigma = u^T W v estimates the spectral norm; the layer then uses W / sigma
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)  # shape: (64,3,4,4)

        height = w.data.shape[0]  # int 64
        width = w.view(height, -1).data.shape[1]  # int 48

        u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)  # shape (64)
        v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)  # shape (48)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = nn.Parameter(w.data)

        # replace the original weight by the trainable w_bar plus non-trainable u, v
        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)


class NLayerSetDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        super(NLayerSetDiscriminator, self).__init__()
        self.input_nc = input_nc
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        self.feature_img = self.get_feature_extractor(input_nc, ndf, n_layers, kw, padw, norm_layer, use_bias)
        self.feature_seg = self.get_feature_extractor(1, ndf, n_layers, kw, padw, norm_layer, use_bias)
        self.classifier = self.get_classifier(2 * ndf, n_layers, kw, padw, norm_layer, use_sigmoid)  # 2*ndf

    def get_feature_extractor(self, input_nc, ndf, n_layers, kw, padw, norm_layer, use_bias):
        model = [
            # Use spectral normalization
            SpectralNorm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)),
            nn.LeakyReLU(0.2, True)
        ]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            model += [
                # Use spectral normalization
                SpectralNorm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias)),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]
        return nn.Sequential(*model)

    def get_classifier(self, ndf, n_layers, kw, padw, norm_layer, use_sigmoid):
        nf_mult_prev = min(2 ** (n_layers-1), 8)
        nf_mult = min(2 ** n_layers, 8)
        model = [
            # Use spectral normalization
            SpectralNorm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw)),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]
        # Use spectral normalization
        model += [SpectralNorm(nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw))]
        if use_sigmoid:
            model += [nn.Sigmoid()]
        return nn.Sequential(*model)

    def forward(self, inp):
        # split data
        img = inp[:, :self.input_nc, :, :]  # (B, CX, W, H)
        segs = inp[:, self.input_nc:, :, :]  # (B, CA, W, H)
        mean = (segs + 1).mean(0).mean(-1).mean(-1)
        if mean.sum() == 0:
            mean[0] = 1  # forward at least one segmentation

        # run feature extractor
        feat_img = self.feature_img(img)
        feat_segs = list()
        for i in range(segs.size(1)):  # i-th instance
            if mean[i] > 0:  # skip empty segmentation
                seg = segs[:, i, :, :].unsqueeze(1)
                feat_segs.append(self.feature_seg(seg))
        # stack, then sum over the instance axis; the batch dimension is kept
        feat_segs_sum = torch.sum(torch.stack(feat_segs), dim=0)  # aggregated set feature

        # run classifier
        feat = torch.cat([feat_img, feat_segs_sum], dim=1)
        out = self.classifier(feat)
        return out
```


5.5 Model input

```python
def set_input(self, input):
    AtoB = self.opt.direction == 'AtoB'
    self.real_A_img = input['A' if AtoB else 'B'].to(self.device)
    self.real_B_img = input['B' if AtoB else 'A'].to(self.device)
    real_A_segs = input['A_segs' if AtoB else 'B_segs']
    real_B_segs = input['B_segs' if AtoB else 'A_segs']
    self.real_A_segs = self.select_masks(real_A_segs).to(self.device)  # shape:(1,4,240,160)
    self.real_B_segs = self.select_masks(real_B_segs).to(self.device)
    self.real_A = torch.cat([self.real_A_img, self.real_A_segs], dim=1)  # shape:(1,7,240,160)
    self.real_B = torch.cat([self.real_B_img, self.real_B_segs], dim=1)
    self.real_A_seg = self.merge_masks(self.real_A_segs)  # merged mask
    self.real_B_seg = self.merge_masks(self.real_B_segs)  # merged mask
    self.image_paths = input['A_paths' if AtoB else 'B_paths']
```

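As mentioned above, 20 mask slots are always read, with missing ones padded with -1. When feeding the network, only the 4 largest masks are kept, and they are ordered either from largest to smallest or randomly.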

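```python
import numpy as np
import torch


# ins_max = 4
def select_masks_random(self, segs_batch):
    """Select masks in random order"""
    ret = list()
    for segs in segs_batch:
        mean = (segs + 1).mean(-1).mean(-1)  # proportional to mask area; 0 for an empty (all -1) mask
        m, i = mean.topk(self.opt.ins_max)  # indices of the ins_max largest masks
        num = min(len(mean.nonzero()), self.opt.ins_max)  # how many of them are non-empty
        # shuffle the non-empty masks; empty slots stay at the end
        reorder = np.concatenate((np.random.permutation(num), np.arange(num, self.opt.ins_max)))
        ret.append(segs[i[reorder], :, :])
    return torch.stack(ret)
```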
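The paragraph before this code also mentions the largest-to-smallest ordering, which matches how 2.3 forms mini-batches. A minimal sketch of that variant in the same style; `select_masks_decreasing` is a hypothetical name here, written as a plain function with `ins_max` passed explicitly instead of `self.opt.ins_max`:

```python
import torch


def select_masks_decreasing(segs_batch, ins_max):
    """Keep the ins_max largest masks of each sample, largest first.
    segs_batch: (B, N, H, W) with values in [-1, 1]; empty masks are all -1."""
    ret = []
    for segs in segs_batch:
        area = (segs + 1).mean(dim=(-2, -1))   # 0 for an empty mask
        _, idx = area.topk(ins_max)            # topk returns values sorted descending
        ret.append(segs[idx])
    return torch.stack(ret)


segs = -torch.ones(1, 5, 4, 4)
segs[0, 1, :1] = 1   # 4-pixel instance
segs[0, 3, :3] = 1   # 12-pixel instance
segs[0, 4, :2] = 1   # 8-pixel instance
out = select_masks_decreasing(segs, 3)
print(((out + 1) / 2).sum(dim=(-2, -1))[0].tolist())  # [12.0, 8.0, 4.0]
```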

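At first I did not understand the mask merging here: is it meant to remove values outside (-1, 1)?

After running the code, I think so. The sum over several masks can exceed 1 wherever they overlap, so clamping to [0, 1] before mapping back turns the sum into a union mask, and every remaining value lies in -1~1.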

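```python
import torch


def merge_masks(self, segs):
    """Merge masks (B, N, W, H) -> (B, 1, W, H)"""
    ret = torch.sum((segs + 1)/2, dim=1, keepdim=True)  # map each mask to [0, 1] and add: (B, 1, W, H)
    return ret.clamp(max=1, min=0) * 2 - 1  # overlaps clamp to 1 (union), then map back to [-1, 1]
```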
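A tiny worked example of `merge_masks` (re-declared as a free function) shows why the clamp matters: where two masks overlap, the per-pixel sum of `(segs + 1) / 2` reaches 2, and clamping to [0, 1] turns the sum into a union before mapping back to [-1, 1].

```python
import torch


def merge_masks(segs):
    """(B, N, H, W) in [-1, 1] -> (B, 1, H, W) union mask in [-1, 1]."""
    ret = torch.sum((segs + 1) / 2, dim=1, keepdim=True)  # per-pixel count of covering masks
    return ret.clamp(min=0, max=1) * 2 - 1                # overlaps (count 2) clamp to 1


a = torch.tensor([[-1., 1., 1.]])          # one row of 3 pixels
b = torch.tensor([[-1., -1., 1.]])         # overlaps a at the last pixel
segs = torch.stack([a, b]).unsqueeze(0)    # (1, 2, 1, 3)
print(merge_masks(segs).flatten().tolist())  # [-1.0, 1.0, 1.0]
```

Without the clamp, the overlapping pixel would come out as 2 * 2 - 1 = 3.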

Other questions

  • [ ] What is the purpose of this step?

Understood: if the image has no (remaining) instances, the next translation step is unnecessary.

```python
self.forward_A = (self.real_A_seg_sng + 1).sum() > 0  # check if there are remaining instances
self.forward_B = (self.real_B_seg_sng + 1).sum() > 0  # check if there are remaining instances
```
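  • [x] What is fake_B_mul for?

In sequential mini-batch translation the GAN loss is global, so the fake_B_seg_sng produced in earlier steps must be kept and included when computing it. Hence self.fake_B_mul is a temporary tensor for the current step, while self.fake_B_seg_list stores the masks produced by each mini-batch.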

```python
if self.forward_A:
    self.real_A_sng = torch.cat([self.real_A_img_sng, self.real_A_seg_sng], dim=1)
    self.fake_B_sng = self.netG_A(self.real_A_sng)
    self.rec_A_sng = self.netG_B(self.fake_B_sng)

    self.fake_B_img_sng, self.fake_B_seg_sng = self.split(self.fake_B_sng)
    self.rec_A_img_sng, self.rec_A_seg_sng = self.split(self.rec_A_sng)
    fake_B_seg_list = self.fake_B_seg_list + [self.fake_B_seg_sng]  # not detach
    for i in range(self.ins_iter - idx - 1):
        fake_B_seg_list.append(empty)

    self.fake_B_seg_mul = torch.cat(fake_B_seg_list, dim=1)
    self.fake_B_mul = torch.cat([self.fake_B_img_sng, self.fake_B_seg_mul], dim=1)
```
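My reading of this code, as a minimal sketch: at mini-batch step `idx`, the discriminator input must always have the same number of mask channels, so the masks from earlier steps (already detached), the current step's masks (still in the graph), and all -1 placeholders for steps not yet run are concatenated. `build_fake_mul` is a hypothetical helper, and building the placeholder with `-torch.ones_like(cur_seg)` assumes every step translates the same number of masks, as `empty = -torch.ones(self.real_A_seg_sng.size())` at the end of this section suggests.

```python
import torch


def build_fake_mul(fake_img, prev_segs, cur_seg, ins_iter, idx):
    """Assemble [y'_m; b'_1 .. b'_m; empty ...] with a fixed channel count.
    prev_segs: detached masks from earlier steps; cur_seg: this step's masks."""
    empty = -torch.ones_like(cur_seg)                      # all-background placeholder
    segs = prev_segs + [cur_seg] + [empty] * (ins_iter - idx - 1)
    return torch.cat([fake_img] + segs, dim=1)


img = torch.zeros(1, 3, 4, 4)
cur = torch.zeros(1, 2, 4, 4, requires_grad=True)          # 2 masks per step
out0 = build_fake_mul(img, [], cur, ins_iter=3, idx=0)                # step 1 of 3
out1 = build_fake_mul(img, [cur.detach()], cur, ins_iter=3, idx=1)    # step 2 of 3
print(out0.shape, out1.shape)  # both torch.Size([1, 9, 4, 4])
```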
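  • [x] How is the background chosen?

A pixel counts as background only if it is background in both A and B; any pixel covered by an instance in either one is not background.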

```python
import torch


def merge_masks(self, segs):
    """Merge masks (B, N, W, H) -> (B, 1, W, H)"""
    # segs: shape (1,4,240,260), values in -1~1; here 2 instance masks come from A and 2 from B
    ret = torch.sum((segs + 1)/2, dim=1, keepdim=True)  # (B, 1, W, H)
    return ret.clamp(max=1, min=0) * 2 - 1


def get_weight_for_ctx(self, x, y):
    """Get weight for context preserving loss"""
    z = self.merge_masks(torch.cat([x, y], dim=1))  # union of the A and B masks
    return (1 - z) / 2  # [-1,1] -> [1,0]: 1 on shared background, 0 on any instance
```
  • [ ] What is empty used for here?
```python
empty = -torch.ones(self.real_A_seg_sng.size()).to(self.device)
```
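  • [ ] How does pix2pix predict the masks? Does it need to be trained beforehand, and how is its dataset provided? If it can be used off the shelf, could it be used to swap people for person re-identification?

  • [x] Paper + code: 4 days in total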