这篇博客主要记录在跟随cycleGAN作者的代码复现学到的东西。
title: Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks(ICCV2017)
paper: https://arxiv.org/abs/1703.10593
code: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
mycode: https://github.com/TJJTJJTJJ/pytorch_cycleGAN
cycle_gan的整体框架写得很漂亮,frame可以参考github的frame
1.动态导入模块以及文件内的类
类似这种文件结构
.models
|— init.py
|— base_model.py
|— cycle_gan_model.py
|— networks.py
|— pix2pix_model.py
`— test_model.py
在init.py这样写两个函数
1 | def find_model_using_name(model_name): |
1 | modellib.__dict__ == vars(modellib) |
1 | import importlib |
1 | exit(0)无错误退出 |
2.学习率直线下降
1 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): |
3.NotImplemented && NotImplementedError
参考:
http://www.php.cn/python-tutorials-160083.html
https://stackoverflow.com/questions/1062096/python-notimplemented-constant
return NotImplemented
raise NotImplementedError(‘initialization method [%s] is not implemented’ % init_type)
4.parser的修改
这里既有外界传入的参数,也有自己的参数isTrain,在主函数里调用的时候调用方式是一致的,只是一个可以通过外界传参,一个不能通过外界传参
1 | class TrainOptions(): |
5.eval()和test()函数的结合
1 | def eval(self): |
6.多GPU
结论
1 | modelb = torch.nn.DataParallel(modela, device_ids=[0,1,2]) |
对于单gpu和Module
对于普通的model.cuda,在保存模型会自动变成cpu,需要再次cuda一次
对于DataParallel,在保存模型会自动变成cpu,需要再次cuda一次
通过源码可以得知,DataParallel的device_ids初始化就已经确定,所以不用担心cuda到第一个GPU上而导致DataParallel忘记自己可以复制到哪些GPU上,会自动复制的
1 | import torch |
7.Norm
参考:
https://blog.csdn.net/liuxiao214/article/details/81037416
输入图像:[N,C,H,W]
BatchNorm: [1,C,1,1]
InstanceNorm: [N,C,1,1]
经过实验,instanceNorm层的weight, bias, running_mean, running_var总是None
代码中加载模型的时候对instanceNorm层进行了删除操作,为什么
对于pytorch之前的版本instanceNorm层是有running_mean和running_var的,之后的版本修正了之后,就不再需要了
1 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): |
8.functools
偏函数:适合为多个调用函数提供一致的函数接口
1 | from functools import partial |
9.论文与代码
ndf
模型的定义与论文有一个地方不一致,论文写的第一个conv之后通道数是32,但实现是64.
与作者沟通得知,第一层不是32,而是64,剩下的也依次递增.
下采样的时候没有使用reflect进行补充,而是使用了0填充.
与作者沟通后,提出的是都可以尝试一下
unet model
Unet model
与网上的不是很一致
3->1->2->4->8->8->8
3<-2<-4<-8<-16<-16<-16
参数 no_lsgan
1 | self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device) |
也就是
opt.no_lsgan为True时, netD使用sigmoid, GANloss使用BCELoss()
opt.no_lsgan为False时, netD不使用sigmoid, GANloss使用MSELoss()
MSELoss:均方误差 (x-y)*2
BCELoss:二分类的交叉熵:使用前需要使用sigmoid函数,input和target的输入维度是一样的.(N,)
根据作者提供的运行代码,猜测作者使用的是opt.no_lsgan为False,均方误差
L1loss: |x-y|
G and D 的反向传播过程
回顾一下G和D的反向传播
train G
1 | set_requires_grad(D, False) |
train D
1 | set_requires_grad(D, False) |
10.ConTransposed的计算方法
逆卷积后的图像大小和之前的能对应上,需要output_padding
1 | nn.ConvTranspose2d(ngf*mult, int(ngf*mult/2), kernel_size=3,stride=2, padding=1, output_padding=1, bias=use_bias)] |
11.初始化参数
1 | def init_weights(net, init_type='normal', gain=0.02): |
12.disriminator PatchGAN and GANLoss
PatchGAN的kernel是4.
1 | class GANLoss(nn.Module): |
GANLoss的备注
使用时,直观上可将layer看成数学概念中的函数,调用layer(input)即可得到input对应的结果。它等价于layers.call(input),在call函数中,主要调用的是 layer.forward(x),另外还对钩子做了一些处理。所以在实际使用中应尽量使用layer(x)而不是使用layer.forward(x)。
13.PatchGAN的感受野
论文使用的是70X70 PatchGAN
PatchGAN:
paper:
Image-to-Image Translation with Conditional Adversarial Networks
https://arxiv.org/abs/1611.07004
自动计算网址:https://fomoro.com/tools/receptive-fields/
1 | 感受野的计算规则 |
14.torch.tensor.clone()
clone()
梯度受影响,clone之后的新的tensor的梯度也会影响到原tensor,但是新tensor本身是没有梯度的.
clone之后的新tensor的改变不会影响原有的tensor
应该这么理解,clone也是计算图中的一个操作,这样的话就可以解释通了.
1 | import torch |
1 | import torch |
1 | print(input2.grad_fn) |
clone的用法
tensor保留梯度的交换
1 | tmp = tensor1.clone() |
15.from XX import
这里还有一些不太对的地方
1 | from .base_model import BaseModel # 同一个文件夹 |
16.register_buffer
register_buffer
self.register_buffer可以将tensor注册成buffer,在forward中使用self.mybuffer, 而不是self.mybuffer_tmp.
定义Parameter和buffer都只需要传入Tensor即可。也不需要将其转成gpu。这是因为,当网络进行.cuda()时候,会自动将里面的层的参数,buffer等转换成相应的GPU上。
网络存储时也会将buffer存下,当网络load模型时,会将存储的模型的buffer也进行赋值。
buffer的更新在forward中,optim.step只能更新nn.Parameter类型的参数。
用法
self.register_buffer(‘running_mean’, torch.zeros(num_features))
17. itertools
无限迭代器
itertools,用于创建高效迭代器的函数,
itertools.chain 连接多个列表或者迭代器。
将多个网络写在一起,使用一个优化器
1 | self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), |
1 | # 自然数无限迭代器 |
18.visdom
1 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, |
19.三引号
1 | 三引号的作用 |
20.异常
1 | # 1.触发异常 |
21.自定义类的iter
1 | # 自定义类的iter |