0%

pytorch-init

pytorch-init

pytorch模型的初始化

PyTorch 模型初始化的常用方法。

1.apply+type

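apply 会递归地把函数作用于每个子模块（即 .children() 返回的模块及其后代），先处理子模块，最后处理模块本身（见下方输出：Linear、Conv2d、Sequential……最后才是 Net）。
它可以在模型的 `__init__` 中调用，也可以在模型实例化之后调用，通常与 `type`（或 `isinstance`）判断配合使用，只初始化特定类型的层。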

Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also torch.nn.init).

示例代码及运行输出：
import torch as t
from torch import nn

# define model
class Net(nn.Module):
    """Toy model used to demonstrate ``nn.Module.apply``.

    ``apply`` visits every submodule before its parent and the model itself
    last, so construction calls the init function on, in order:
    Linear, Conv2d, Sequential (``pre``), Linear, Sequential (``two``), Net.

    Args:
        init_fn: callable taking one module. When omitted, the module-level
            ``init_weights`` is used (the original behaviour). Lets callers
            plug in a different initializer without editing the class.
    """

    def __init__(self, init_fn=None):
        super().__init__()
        # NOTE: Linear(2, 2) feeding Conv2d(3, 3, 3) is not a runnable
        # pipeline (2 output features vs. 3 input channels); the layers only
        # exist so that apply() has a nested structure to walk.
        self.pre = nn.Sequential(nn.Linear(2, 2), nn.Conv2d(3, 3, 3))
        self.two = nn.Sequential(nn.Linear(3, 3))
        # init_weights is looked up at call time, so it may be defined after
        # this class in the module.
        self.apply(init_weights if init_fn is None else init_fn)

    def forward(self, x):
        # Not implemented: this model is only constructed, never run.
        # Calling it returns None.
        pass

def init_weights(m):
    """Zero the weight of every ``nn.Linear`` visited by ``Module.apply``.

    Also prints the module, its class and its class name so the order in
    which ``apply`` visits modules can be observed. Biases are left
    untouched, and modules that are not ``nn.Linear`` are only printed.

    Args:
        m: the module ``apply`` is currently visiting.
    """
    print(m)
    print(type(m))
    print(nn.Linear)
    print(m.__class__)
    print(m.__class__.__name__)

    # isinstance (rather than ``type(m) == nn.Linear``) also matches
    # subclasses of nn.Linear.
    if isinstance(m, nn.Linear):
        # Use the nn.init API instead of writing through the legacy ``.data``
        # attribute. NOTE: all-zero weights are for demonstration only; they
        # make every output unit of the layer compute the same value.
        nn.init.zeros_(m.weight)
        print(m.weight.data)
    print("_______________________")


# Constructing the model runs Net.__init__, which calls self.apply(init_weights);
# init_weights prints one entry per visited module (the trace listed below).
net2 = Net()

Linear(in_features=2, out_features=2, bias=True)
<class 'torch.nn.modules.linear.Linear'>
<class 'torch.nn.modules.linear.Linear'>
<class 'torch.nn.modules.linear.Linear'>
Linear
tensor([[ 0., 0.],
[ 0., 0.]])
_______________________
Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.linear.Linear'>
<class 'torch.nn.modules.conv.Conv2d'>
Conv2d
_______________________
Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
)
<class 'torch.nn.modules.container.Sequential'>
<class 'torch.nn.modules.linear.Linear'>
<class 'torch.nn.modules.container.Sequential'>
Sequential
_______________________
Linear(in_features=3, out_features=3, bias=True)
<class 'torch.nn.modules.linear.Linear'>
<class 'torch.nn.modules.linear.Linear'>
<class 'torch.nn.modules.linear.Linear'>
Linear
tensor([[ 0., 0., 0.],
[ 0., 0., 0.],
[ 0., 0., 0.]])
_______________________
Sequential(
(0): Linear(in_features=3, out_features=3, bias=True)
)
<class 'torch.nn.modules.container.Sequential'>
<class 'torch.nn.modules.linear.Linear'>
<class 'torch.nn.modules.container.Sequential'>
Sequential
_______________________
Net(
(pre): Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
)
(two): Sequential(
(0): Linear(in_features=3, out_features=3, bias=True)
)
)
<class '__main__.Net'>
<class 'torch.nn.modules.linear.Linear'>
<class '__main__.Net'>
Net
_______________________

2.apply+m.__class__.__name__

weights_init_kaiming
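还有一种初始化函数的写法：通过 `m.__class__.__name__` 取得类名字符串，按名字匹配模块类型（如下面的 weights_init_kaiming 和 weights_init_classifier）。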

示例代码：
def weights_init_kaiming(m):
    """Kaiming-initialise conv/linear layers and reset BatchNorm1d layers.

    Intended for ``model.apply(weights_init_kaiming)``. The module type is
    matched by substring of its class name:

    * name contains ``Conv``: weight ~ Kaiming normal (``a=0``,
      ``mode='fan_in'``); bias left unchanged.
    * name contains ``Linear``: weight ~ Kaiming normal (``a=0``,
      ``mode='fan_out'``); bias set to 0.
    * name contains ``BatchNorm1d``: weight ~ N(mean=1.0, std=0.02);
      bias set to 0.

    Any other module is left unchanged. Parameters that are absent
    (``bias=False``, ``affine=False``) are skipped instead of crashing.

    Args:
        m: the module being visited.
    """
    # Imported locally: the module-level imports bind only ``t`` and ``nn``.
    from torch.nn import init

    classname = m.__class__.__name__
    if 'Conv' in classname:
        init.kaiming_normal_(m.weight, a=0, mode='fan_in')
    elif 'Linear' in classname:
        init.kaiming_normal_(m.weight, a=0, mode='fan_out')
        if m.bias is not None:
            init.constant_(m.bias, 0.0)
    elif 'BatchNorm1d' in classname:
        if m.weight is not None:
            init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            init.constant_(m.bias, 0.0)

def weights_init_classifier(m):
    """Initialise linear layers of a classifier head.

    Intended for ``model.apply(weights_init_classifier)``. For modules whose
    class name contains ``Linear``: weight ~ N(mean=0, std=0.001) and bias
    set to 0 when the layer has one. Every other module is left unchanged.

    Args:
        m: the module being visited.
    """
    # Imported locally: the module-level imports bind only ``t`` and ``nn``.
    from torch.nn import init

    if 'Linear' not in m.__class__.__name__:
        return
    init.normal_(m.weight, std=0.001)
    # Layers built with bias=False have m.bias set to None.
    if m.bias is not None:
        init.constant_(m.bias, 0.0)