jstetina commited on
Commit
05cd0b4
·
1 Parent(s): e18f2de

Add: model class

Browse files
Files changed (1) hide show
  1. model.py +111 -0
model.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import DataLoader, Dataset
5
+ from torchvision import datasets, transforms
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ from PIL import Image
9
+ import os
10
+
11
+ class ColorNet(nn.Module):
12
+ DEFAULT_CHECKPOINT_PATH = "checkpoint/colornet.pt"
13
+
14
+ def __init__(self, checkpoint_path:str=DEFAULT_CHECKPOINT_PATH):
15
+ super(ColorNet, self).__init__()
16
+
17
+ self.encoder = nn.Sequential(
18
+ nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
19
+ nn.ReLU(),
20
+ nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
21
+ nn.ReLU(),
22
+ nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
23
+ nn.ReLU()
24
+ )
25
+ self.decoder = nn.Sequential(
26
+ nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
27
+ nn.ReLU(),
28
+ nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
29
+ nn.ReLU(),
30
+ nn.ConvTranspose2d(64, 3, kernel_size=3, stride=1, padding=1),
31
+ nn.Sigmoid() # to scale the output to [0, 1]
32
+ )
33
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
+ self.to(self.device)
35
+
36
+ if os.path.exists(checkpoint_path):
37
+ self._load_model(checkpoint_path)
38
+
39
+ def _load_model(self, path):
40
+ print("Loading ColorNet model...", end="")
41
+ self.load_state_dict(torch.load(path, map_location=self.device))
42
+ print("done.")
43
+
44
+ def forward(self, x):
45
+ x = x.to(self.device)
46
+ x = self.encoder(x)
47
+ x = self.decoder(x)
48
+ return x
49
+
50
+ def train_model(self, model, train_loader, criterion, optimizer, num_epochs=10):
51
+ for epoch in range(num_epochs):
52
+ model.train()
53
+ running_loss = 0.0
54
+ for inputs, _ in train_loader:
55
+ gray_images = transforms.Grayscale(num_output_channels=1)(inputs).to(self.device)
56
+ gray_images = gray_images.repeat(1,3,1,1)
57
+ color_images = inputs.to(self.device)
58
+
59
+ optimizer.zero_grad()
60
+
61
+ outputs = model(gray_images)
62
+ loss = criterion(outputs, color_images)
63
+ loss.backward()
64
+ optimizer.step()
65
+
66
+ running_loss += loss.item() * gray_images.size(0)
67
+
68
+ epoch_loss = running_loss / len(train_loader.dataset)
69
+ print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')
70
+
71
+ torch.save(model.state_dict(), self.DEFAULT_CHECKPOINT_PATH)
72
+
73
+
74
+ def colorize(self, input_path:str, output_path):
75
+ input_image = Image.open(input_path).convert("RGB")
76
+ input_image = transforms.ToTensor()(input_image).unsqueeze(0).to(self.device)
77
+
78
+ with torch.inference_mode():
79
+ output_image_tnsr = self(input_image)
80
+ output_image_tnsr = output_image_tnsr.squeeze(0).cpu()
81
+ output_image_tnsr = transforms.ToPILImage()(output_image_tnsr)
82
+
83
+ output_image_tnsr.save(output_path)
84
+
85
+ def visualize_results(model, test_loader, num_images=5):
86
+ model.eval()
87
+ with torch.no_grad():
88
+ data_iter = iter(test_loader)
89
+ images, _ = data_iter.next()
90
+
91
+ # Get grayscale and colorized images
92
+ gray_images = images[:num_images]
93
+ colorized_images = model(gray_images)
94
+
95
+ # Plotting the results
96
+ for i in range(num_images):
97
+ plt.subplot(3, num_images, i+1)
98
+ plt.imshow(gray_images[i].permute(1, 2, 0).squeeze(), cmap="gray")
99
+ plt.axis('off')
100
+
101
+ plt.subplot(3, num_images, num_images+i+1)
102
+ plt.imshow(colorized_images[i].permute(1, 2, 0))
103
+ plt.axis('off')
104
+
105
+ plt.subplot(3, num_images, 2*num_images+i+1)
106
+ plt.imshow(gray_images[i].permute(1, 2, 0).repeat(3, 1, 1).permute(1, 2, 0))
107
+ plt.axis('off')
108
+
109
+ plt.show()
110
+
111
+