1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
| import torch as t from torch import nn
class Net(nn.Module):
    """Toy module holding two ``nn.Sequential`` containers.

    Construction ends with ``self.apply(init_weights)``, so ``init_weights``
    is called on every registered submodule and then on this ``Net`` itself.
    ``forward`` is an unimplemented stub that always returns ``None``.
    """

    def __init__(self):
        super().__init__()
        # Layers are created and registered in the original order (pre, then
        # two), which also fixes the order in which apply() visits them.
        # NOTE(review): Linear(2, 2) feeding Conv2d(3, 3, 3) does not look
        # shape-compatible if ever run in sequence; harmless while forward()
        # is a stub — confirm intent before implementing forward().
        pre_layers = [nn.Linear(2, 2), nn.Conv2d(3, 3, 3)]
        self.pre = nn.Sequential(*pre_layers)
        self.two = nn.Sequential(nn.Linear(3, 3))
        self.apply(init_weights)

    def forward(self, x):
        # Stub: no computation is performed.
        return None
def init_weights(m: nn.Module) -> None:
    """Print diagnostics for *m* and zero its weight if it is a Linear layer.

    Designed to be passed to ``nn.Module.apply``, which calls it once per
    submodule and finally on the root module.

    Args:
        m: The module currently being visited.
    """
    print(m)
    print(type(m))
    print(nn.Linear)
    print(m.__class__)
    print(m.__class__.__name__)
    # isinstance (instead of ``type(m) == nn.Linear``) also matches subclasses
    # of nn.Linear, so specialised linear layers get the same initialisation.
    if isinstance(m, nn.Linear):
        # nn.init.constant_ fills the parameter in place without recording
        # autograd history -- the supported alternative to mutating ``.data``.
        nn.init.constant_(m.weight, 0.0)
        print(m.weight.data)
    print("_______________________")
# Instantiating Net runs Net.__init__, whose self.apply(init_weights) call
# invokes init_weights on each submodule and then on the Net itself, printing
# diagnostics for every module it visits.
net2 = Net()
Linear(in_features=2, out_features=2, bias=True) <class 'torch.nn.modules.linear.Linear'> <class 'torch.nn.modules.linear.Linear'> <class 'torch.nn.modules.linear.Linear'> Linear tensor([[ 0., 0.], [ 0., 0.]]) _______________________ Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1)) <class 'torch.nn.modules.conv.Conv2d'> <class 'torch.nn.modules.linear.Linear'> <class 'torch.nn.modules.conv.Conv2d'> Conv2d _______________________ Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1)) ) <class 'torch.nn.modules.container.Sequential'> <class 'torch.nn.modules.linear.Linear'> <class 'torch.nn.modules.container.Sequential'> Sequential _______________________ Linear(in_features=3, out_features=3, bias=True) <class 'torch.nn.modules.linear.Linear'> <class 'torch.nn.modules.linear.Linear'> <class 'torch.nn.modules.linear.Linear'> Linear tensor([[ 0., 0., 0.], [ 0., 0., 0.], [ 0., 0., 0.]]) _______________________ Sequential( (0): Linear(in_features=3, out_features=3, bias=True) ) <class 'torch.nn.modules.container.Sequential'> <class 'torch.nn.modules.linear.Linear'> <class 'torch.nn.modules.container.Sequential'> Sequential _______________________ Net( (pre): Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1)) ) (two): Sequential( (0): Linear(in_features=3, out_features=3, bias=True) ) ) <class '__main__.Net'> <class 'torch.nn.modules.linear.Linear'> <class '__main__.Net'> Net _______________________
|