ulichovick commited on
Commit
cdfd2e1
1 Parent(s): a4e7e18

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +185 -0
model.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import PyTorchModelHubMixin
2
+ from torch import nn
3
+
4
+ class SurfinBird(nn.Module, PyTorchModelHubMixin):
5
+ def __init__(self, config: dict) -> None:
6
+ super().__init__()
7
+ self.conv1 = nn.Conv2d(
8
+ in_channels=config["num_channels"],
9
+ out_channels=64,
10
+ kernel_size=7,
11
+ stride=2,
12
+ padding=3)
13
+ self.bn1 = nn.BatchNorm2d(64)
14
+ self.relu1 = nn.ReLU()
15
+ self.mp1 = nn.MaxPool2d(kernel_size=2,
16
+ stride=2)
17
+ self.conv_block_2 = nn.Sequential(
18
+ nn.Conv2d(
19
+ in_channels=64,
20
+ out_channels=64,
21
+ kernel_size=3,
22
+ stride=1,
23
+ padding=1
24
+ ),
25
+ nn.BatchNorm2d(64),
26
+ nn.ReLU(),
27
+ nn.Conv2d(
28
+ in_channels=64,
29
+ out_channels=64,
30
+ kernel_size=3,
31
+ stride=1,
32
+ padding=1
33
+ ),
34
+ nn.BatchNorm2d(64),
35
+ nn.ReLU(),
36
+ nn.Conv2d(
37
+ in_channels=64,
38
+ out_channels=64,
39
+ kernel_size=3,
40
+ stride=1,
41
+ padding=1
42
+ ),
43
+ nn.BatchNorm2d(64),
44
+ nn.ReLU(),
45
+ nn.MaxPool2d(kernel_size=2,
46
+ stride=2)
47
+ )
48
+ self.conv_block_3 = nn.Sequential(
49
+ nn.Conv2d(
50
+ in_channels=64,
51
+ out_channels=128,
52
+ kernel_size=3,
53
+ stride=1,
54
+ padding=1
55
+ ),
56
+ nn.BatchNorm2d(128),
57
+ nn.ReLU(),
58
+ nn.Conv2d(
59
+ in_channels=128,
60
+ out_channels=128,
61
+ kernel_size=3,
62
+ stride=1,
63
+ padding=1
64
+ ),
65
+ nn.BatchNorm2d(128),
66
+ nn.ReLU(),
67
+ nn.Conv2d(
68
+ in_channels=128,
69
+ out_channels=128,
70
+ kernel_size=3,
71
+ stride=1,
72
+ padding=1
73
+ ),
74
+ nn.BatchNorm2d(128),
75
+ nn.ReLU(),
76
+ nn.MaxPool2d(kernel_size=2,
77
+ stride=2)
78
+ )
79
+ self.conv_block_4 = nn.Sequential(
80
+ nn.Conv2d(
81
+ in_channels=128,
82
+ out_channels=128,
83
+ kernel_size=3,
84
+ stride=1,
85
+ padding=1
86
+ ),
87
+ nn.BatchNorm2d(128),
88
+ nn.ReLU(),
89
+ nn.Conv2d(
90
+ in_channels=128,
91
+ out_channels=128,
92
+ kernel_size=3,
93
+ stride=1,
94
+ padding=1
95
+ ),
96
+ nn.BatchNorm2d(128),
97
+ nn.ReLU(),
98
+ nn.Conv2d(
99
+ in_channels=128,
100
+ out_channels=128,
101
+ kernel_size=3,
102
+ stride=1,
103
+ padding=1
104
+ ),
105
+ nn.BatchNorm2d(128),
106
+ nn.ReLU(),
107
+ nn.MaxPool2d(kernel_size=2,
108
+ stride=2)
109
+ )
110
+ self.conv_block_5 = nn.Sequential(
111
+ nn.Conv2d(
112
+ in_channels=128,
113
+ out_channels=256,
114
+ kernel_size=3,
115
+ stride=1,
116
+ padding=1
117
+ ),
118
+ nn.BatchNorm2d(256),
119
+ nn.ReLU(),
120
+ nn.Conv2d(
121
+ in_channels=256,
122
+ out_channels=256,
123
+ kernel_size=3,
124
+ stride=1,
125
+ padding=1
126
+ ),
127
+ nn.BatchNorm2d(256),
128
+ nn.ReLU(),
129
+ nn.Conv2d(
130
+ in_channels=256,
131
+ out_channels=256,
132
+ kernel_size=3,
133
+ stride=1,
134
+ padding=1
135
+ ),
136
+ nn.BatchNorm2d(256),
137
+ nn.ReLU(),
138
+ nn.MaxPool2d(kernel_size=2,
139
+ stride=2)
140
+ )
141
+ self.conv_block_6 = nn.Sequential(
142
+ nn.Conv2d(
143
+ in_channels=256,
144
+ out_channels=256,
145
+ kernel_size=3,
146
+ stride=1,
147
+ padding=1
148
+ ),
149
+ nn.BatchNorm2d(256),
150
+ nn.ReLU(),
151
+ nn.Conv2d(
152
+ in_channels=256,
153
+ out_channels=256,
154
+ kernel_size=3,
155
+ stride=1,
156
+ padding=1
157
+ ),
158
+ nn.BatchNorm2d(256),
159
+ nn.ReLU(),
160
+ nn.Conv2d(
161
+ in_channels=256,
162
+ out_channels=256,
163
+ kernel_size=3,
164
+ stride=1,
165
+ padding=1
166
+ ),
167
+ nn.BatchNorm2d(256),
168
+ nn.ReLU(),
169
+ nn.MaxPool2d(kernel_size=2,
170
+ stride=2)
171
+ )
172
+
173
+ self.avgpool = nn.Sequential(
174
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
175
+ )
176
+
177
+ self.classifier = nn.Sequential(
178
+ nn.Flatten(),
179
+ nn.Linear(in_features=config["hidden_units"]*1*1,
180
+ out_features=config["num_classes"])
181
+ )
182
+ def forward(self, x: torch.Tensor):
183
+ return self.classifier(self.avgpool(self.conv_block_6(self.conv_block_5(self.conv_block_4(self.conv_block_3(self.conv_block_2(self.mp1(self.relu1(self.bn1(self.conv1(x)))))))))))
184
+
185
+ config = {"num_channels": 3, "hidden_units": 256, "num_classes": 525}