sebastiansarasti commited on
Commit
c5a5fc6
·
verified ·
1 Parent(s): eb29cc1

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +40 -0
model.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from huggingface_hub import PyTorchModelHubMixin
3
+
4
+ class ModelColorization(nn.Module, PyTorchModelHubMixin):
5
+ def __init__(self):
6
+ super(ModelColorization, self).__init__()
7
+ self.encoder = nn.Sequential(
8
+ nn.Conv2d(1, 256, kernel_size=3, stride=1, padding=1),
9
+ nn.MaxPool2d(kernel_size=2, stride=2),
10
+ nn.ReLU(),
11
+ nn.BatchNorm2d(256),
12
+ nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
13
+ nn.MaxPool2d(kernel_size=2, stride=2),
14
+ nn.ReLU(),
15
+ nn.BatchNorm2d(128),
16
+ nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
17
+ nn.MaxPool2d(kernel_size=2, stride=2),
18
+ nn.ReLU(),
19
+ nn.BatchNorm2d(64),
20
+ nn.Flatten(),
21
+ nn.Linear(64 * 16 * 16, 1024),
22
+ )
23
+ self.decoder = nn.Sequential(
24
+ nn.Linear(1024, 64 * 16 * 16),
25
+ nn.ReLU(),
26
+ nn.Unflatten(1, (64, 16, 16)),
27
+ nn.ConvTranspose2d(64, 128, kernel_size=2, stride=2),
28
+ nn.ReLU(),
29
+ nn.BatchNorm2d(128),
30
+ nn.ConvTranspose2d(128, 256, kernel_size=2, stride=2),
31
+ nn.ReLU(),
32
+ nn.BatchNorm2d(256),
33
+ nn.ConvTranspose2d(256, 3, kernel_size=2, stride=2),
34
+ nn.Sigmoid(),
35
+ )
36
+
37
+ def forward(self, x):
38
+ x = self.encoder(x)
39
+ x = self.decoder(x)
40
+ return x