Nikhil Mudhalwadkar committed
Commit 10a6c3a · 1 Parent(s): 7b96f60

update gradio

app.py CHANGED
@@ -5,7 +5,7 @@ import matplotlib
 import torch
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT
 
-matplotlib.use('Agg')
+matplotlib.use("Agg")
 import numpy as np
 from PIL import Image
 import albumentations as A
@@ -13,13 +13,17 @@ import albumentations.pytorch as al_pytorch
 import torchvision
 from pl_bolts.models.gans import Pix2Pix
 
+# Hack for spaces
+import os
+os.system("pip uninstall -y gradio")
+os.system("pip install -r requirements.txt")
+
 """ Class """
 
 
 class OverpoweredPix2Pix(Pix2Pix):
-
     def validation_step(self, batch, batch_idx):
-        """ Validation step """
+        """Validation step"""
         real, condition = batch
         with torch.no_grad():
             loss = self._disc_step(real, condition)
@@ -28,33 +32,36 @@ class OverpoweredPix2Pix(Pix2Pix):
             loss = self._gen_step(real, condition)
             self.log("val_generator_loss", loss)
 
-        return {
-            'sketch': real,
-            'colour': condition
-        }
+        return {"sketch": real, "colour": condition}
 
-    def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
-        sketch = outputs[0]['sketch']
-        colour = outputs[0]['colour']
+    def validation_epoch_end(
+        self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]
+    ) -> None:
+        sketch = outputs[0]["sketch"]
+        colour = outputs[0]["colour"]
         with torch.no_grad():
             gen_coloured = self.gen(sketch)
             grid_image = torchvision.utils.make_grid(
                 [
-                    sketch[0], colour[0], gen_coloured[0],
+                    sketch[0],
+                    colour[0],
+                    gen_coloured[0],
                 ],
-                normalize=True
+                normalize=True,
             )
             self.logger.experiment.add_image(
-                f'Image Grid {str(self.current_epoch)}',
-                grid_image,
-                self.current_epoch
+                f"Image Grid {str(self.current_epoch)}", grid_image, self.current_epoch
             )
 
 
 """ Load the model """
 # train_64_val_16_patchgan_1val_plbolts_model_chkpt = "model/lightning_bolts_model/modified_path_gan.ckpt"
-train_64_val_16_plbolts_model_chkpt = "model/lightning_bolts_model/epoch=99-step=44600.ckpt"
-train_16_val_1_plbolts_model_chkpt = "model/lightning_bolts_model/epoch=99-step=89000.ckpt"
+train_64_val_16_plbolts_model_chkpt = (
+    "model/lightning_bolts_model/epoch=99-step=44600.ckpt"
+)
+train_16_val_1_plbolts_model_chkpt = (
+    "model/lightning_bolts_model/epoch=99-step=89000.ckpt"
+)
 # model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=199-step=355600.ckpt"
 # model_checkpoint_path = "model/pix2pix_lightning_model/gen.pth"
 
@@ -72,18 +79,20 @@ train_16_val_1_plbolts_model.eval()
 
 
 def predict(img: Image, type_of_model: str):
-    """ Create predictions """
+    """Create predictions"""
     # transform img
     image = np.asarray(img)
    # use on inference
-    inference_transform = A.Compose([
-        A.Resize(width=256, height=256),
-        A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
-        al_pytorch.ToTensorV2(),
-    ])
-    inference_img = inference_transform(
-        image=image
-    )['image'].unsqueeze(0)
+    inference_transform = A.Compose(
+        [
+            A.Resize(width=256, height=256),
+            A.Normalize(
+                mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0
+            ),
+            al_pytorch.ToTensorV2(),
+        ]
+    )
+    inference_img = inference_transform(image=image)["image"].unsqueeze(0)
 
     # Choose model
     if type_of_model == "train batch size 16, val batch size 1":
@@ -113,7 +122,7 @@ model_input = gr.inputs.Radio(
         "train batch size 64, val batch size 16",
         "train batch size 64, val batch size 16, patch gan has 1 output score instead of 16*16",
     ],
-    label="Type of Pix2Pix model to use : "
+    label="Type of Pix2Pix model to use : ",
 )
 image_input = gr.inputs.Image(type="pil")
 img_examples = [
@@ -132,13 +141,17 @@ with gr.Blocks() as demo:
     gr.Markdown(" There are three Pix2Pix models in this example:")
     gr.Markdown(" 1. Training batch size is 16 , validation is 1")
     gr.Markdown(" 2. Training batch size is 64 , validation is 16")
-    gr.Markdown(" 3. PatchGAN is changed, 1 value only instead of 16*16 ;"
-                "training batch size is 64 , validation is 16")
+    gr.Markdown(
+        " 3. PatchGAN is changed, 1 value only instead of 16*16 ;"
+        "training batch size is 64 , validation is 16"
+    )
     with gr.Tabs():
         with gr.TabItem("tr_16_val_1"):
             with gr.Row():
                 image_input1 = gr.inputs.Image(type="pil")
-                image_output1 = gr.outputs.Image(type="pil", )
+                image_output1 = gr.outputs.Image(
+                    type="pil",
+                )
             colour_1 = gr.Button("Colour it!")
             gr.Examples(
                 examples=img_examples,
@@ -149,7 +162,9 @@ with gr.Blocks() as demo:
         with gr.TabItem("tr_64_val_14"):
             with gr.Row():
                 image_input2 = gr.inputs.Image(type="pil")
-                image_output2 = gr.outputs.Image(type="pil", )
+                image_output2 = gr.outputs.Image(
+                    type="pil",
+                )
             colour_2 = gr.Button("Colour it!")
            with gr.Row():
                gr.Examples(
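
For reference, a minimal inference sketch against the updated app.py. The OverpoweredPix2Pix class, the checkpoint path, and the preprocessing come from the diff above; loading via LightningModule.load_from_checkpoint and the sample image filename are assumptions, not code from this commit.

import numpy as np
import torch
import albumentations as A
import albumentations.pytorch as al_pytorch
from PIL import Image

# Assumes OverpoweredPix2Pix is importable from app.py and that the checkpoint
# stores the module hyperparameters (hedged assumption).
model = OverpoweredPix2Pix.load_from_checkpoint(
    "model/lightning_bolts_model/epoch=99-step=89000.ckpt"
)
model.eval()

# Same preprocessing as predict() in app.py.
transform = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
        al_pytorch.ToTensorV2(),
    ]
)
img = np.asarray(Image.open("sample_sketch.png").convert("RGB"))  # hypothetical input file
batch = transform(image=img)["image"].unsqueeze(0)
with torch.no_grad():
    coloured = model.gen(batch)  # generator output, shape (1, 3, 256, 256)
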
app/__init__.py DELETED
File without changes
app/config.py DELETED
@@ -1,3 +0,0 @@
-num_workers = 4
-train_batch_size = 32
-val_batch_size = 1
app/consume_data/__init__.py DELETED
File without changes
app/consume_data/consume_data.py DELETED
@@ -1,165 +0,0 @@
-import torch
-import os
-from typing import List, Optional
-from PIL import Image
-import matplotlib.pyplot as plt
-from torchvision import transforms
-import albumentations as A
-import numpy as np
-import albumentations.pytorch as al_pytorch
-from typing import Dict, Tuple
-from app import config
-import pytorch_lightning as pl
-
-torch.__version__
-
-
-class AnimeDataset(torch.utils.data.Dataset):
-    """ Sketchs and Colored Image dataset """
-
-    def __init__(self, imgs_path: List[str], transforms: transforms.Compose) -> None:
-        """ Set the transforms and file path """
-        self.list_files = imgs_path
-        self.transform = transforms
-
-    def __len__(self) -> int:
-        """ Should return number of files """
-        return len(self.list_files)
-
-    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        """ Get image and mask by index """
-        # read image file
-        img_file = self.list_files[index]
-        # img_path = os.path.join(self.root_dir, img_file)
-        image = np.array(Image.open(img_file))
-
-        # divide image into sketchs and colored_imgs, right is sketch and left is colored images
-        sketchs = image[:, image.shape[1] // 2:, :]
-        colored_imgs = image[:, :image.shape[1] // 2, :]
-
-        # data augmentation on both sketchs and colored_imgs
-        augmentations = self.transform.both_transform(image=sketchs, image0=colored_imgs)
-        sketchs, colored_imgs = augmentations['image'], augmentations['image0']
-
-        # conduct data augmentation respectively
-        sketchs = self.transform.transform_only_input(image=sketchs)['image']
-        colored_imgs = self.transform.transform_only_mask(image=colored_imgs)['image']
-        return sketchs, colored_imgs
-
-
-# Data Augmentation
-class Transforms:
-    def __init__(self):
-        # use on both sketchs and colored images
-        self.both_transform = A.Compose([
-            A.Resize(width=256, height=256),
-            A.HorizontalFlip(p=.5)
-        ], additional_targets={'image0': 'image'})
-
-        # use on sketchs only
-        self.transform_only_input = A.Compose([
-            A.ColorJitter(p=.1),
-            A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
-            al_pytorch.ToTensorV2(),
-        ])
-
-        # use on colored images
-        self.transform_only_mask = A.Compose([
-            A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
-            al_pytorch.ToTensorV2(),
-        ])
-
-
-class Transforms_v1:
-    """ Class to hold transforms """
-
-    def __init__(self):
-        # use on both sketchs and colored images
-        self.resize_572 = A.Compose([
-            A.Resize(width=572, height=572)
-        ])
-
-        self.resize_388 = A.Compose([
-            A.Resize(width=388, height=388)
-        ])
-
-        self.resize_256 = A.Compose([
-            A.Resize(width=256, height=256)
-        ])
-
-        # use on sketchs only
-        self.transform_only_input = A.Compose([
-            # A.ColorJitter(p=.1),
-            A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
-            al_pytorch.ToTensorV2(),
-        ])
-
-        # use on colored images
-        self.transform_only_mask = A.Compose([
-            A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
-            al_pytorch.ToTensorV2(),
-        ])
-
-
-class AnimeSketchDataModule(pl.LightningDataModule):
-    """ Class to hold the Anime sketch Data"""
-
-    def __init__(
-        self,
-        data_dir: str,
-        train_folder_name: str = "train/",
-        val_folder_name: str = "val/",
-        train_batch_size: int = config.train_batch_size,
-        val_batch_size: int = config.val_batch_size,
-        train_num_images: int = 0,
-        val_num_images: int = 0,
-    ):
-        super().__init__()
-        self.val_dataset = None
-        self.train_dataset = None
-        self.data_dir: str = data_dir
-        # Set train and val images folder
-        train_path: str = f"{self.data_dir}{train_folder_name}/"
-        train_images: List[str] = [f"{train_path}{x}" for x in os.listdir(train_path)]
-        val_path: str = f"{self.data_dir}{val_folder_name}"
-        val_images: List[str] = [f"{val_path}{x}" for x in os.listdir(val_path)]
-        #
-        self.train_images = train_images[:train_num_images] if train_num_images else train_images
-        self.val_images = val_images[:val_num_images] if val_num_images else val_images
-        #
-        self.train_batch_size = train_batch_size
-        self.val_batch_size = val_batch_size
-
-    def set_datasets(self) -> None:
-        """ Get the train and test datasets """
-        self.train_dataset = AnimeDataset(
-            imgs_path=self.train_images,
-            transforms=Transforms()
-        )
-        self.val_dataset = AnimeDataset(
-            imgs_path=self.val_images,
-            transforms=Transforms()
-        )
-        print("The train test dataset lengths are : ", len(self.train_dataset), len(self.val_dataset))
-        return None
-
-    def setup(self, stage: Optional[str] = None) -> None:
-        self.set_datasets()
-
-    def train_dataloader(self):
-        return torch.utils.data.DataLoader(
-            self.train_dataset,
-            batch_size=self.train_batch_size,
-            shuffle=False,
-            num_workers=2,
-            pin_memory=True
-        )
-
-    def val_dataloader(self):
-        return torch.utils.data.DataLoader(
-            self.val_dataset,
-            batch_size=self.val_batch_size,
-            shuffle=False,
-            num_workers=2,
-            pin_memory=True
-        )
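
As a usage note for the removed data module, a minimal sketch of how AnimeSketchDataModule was wired together; the data_dir value and layout (train/ and val/ sub-folders of side-by-side sketch|colour images) and the batch sizes are assumptions taken from the deleted config.py.

# Hypothetical directory; the module expects data_dir/train/ and data_dir/val/.
dm = AnimeSketchDataModule(
    data_dir="data/anime_sketch/",
    train_batch_size=32,
    val_batch_size=1,
)
dm.setup()
# Each batch is a (sketches, coloured_images) pair; Transforms resizes to 256x256.
sketches, coloured = next(iter(dm.train_dataloader()))
print(sketches.shape, coloured.shape)  # e.g. torch.Size([32, 3, 256, 256]) each
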
app/data.py DELETED
@@ -1,69 +0,0 @@
-import torch
-import os
-from typing import List
-from PIL import Image
-import matplotlib.pyplot as plt
-from torchvision import transforms
-import albumentations as A
-import numpy as np
-import albumentations.pytorch as al_pytorch
-from typing import Dict, Tuple
-
-
-class AnimeDataset(torch.utils.data.Dataset):
-    """ Sketchs and Colored Image dataset """
-
-    def __init__(self, imgs_path: List[str], transforms: transforms.Compose) -> None:
-        """ Set the transforms and file path """
-        self.list_files = imgs_path
-        self.transform = transforms
-
-    def __len__(self) -> int:
-        """ Should return number of files """
-        return len(self.list_files)
-
-    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        """ Get image and mask by index """
-        # read image file
-        img_path = img_file = self.list_files[index]
-        image = np.array(Image.open(img_path))
-
-        # divide image into sketchs and colored_imgs, right is sketch and left is colored images
-        # as according to the dataset
-        sketchs = image[:, image.shape[1] // 2:, :]
-        colored_imgs = image[:, :image.shape[1] // 2, :]
-
-        # data augmentation on both sketchs and colored_imgs
-        augmentations = self.transform.both_transform(image=sketchs, image0=colored_imgs)
-        sketchs, colored_imgs = augmentations['image'], augmentations['image0']
-
-        # conduct data augmentation respectively
-        sketchs = self.transform.transform_only_input(image=sketchs)['image']
-        colored_imgs = self.transform.transform_only_mask(image=colored_imgs)['image']
-        return sketchs, colored_imgs
-
-
-class Transforms:
-    """ Class to hold transforms """
-
-    def __init__(self):
-        # use on both sketchs and colored images
-        self.both_transform = A.Compose([
-            A.Resize(width=1024, height=1024),
-            A.HorizontalFlip(p=.5)
-        ],
-            additional_targets={'image0': 'image'}
-        )
-
-        # use on sketchs only
-        self.transform_only_input = A.Compose([
-            # A.ColorJitter(p=.1),
-            A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
-            al_pytorch.ToTensorV2(),
-        ])
-
-        # use on colored images
-        self.transform_only_mask = A.Compose([
-            A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
-            al_pytorch.ToTensorV2(),
-        ])
app/discriminator/__init__.py DELETED
File without changes
app/discriminator/patch_gan.py DELETED
@@ -1,137 +0,0 @@
-import torch.nn as nn
-import torch
-import albumentations as A
-
-
-# CNN block will be used repeatly later
-class CNNBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, stride=2):
-        super().__init__()
-        self.conv = nn.Sequential(
-            nn.Conv2d(in_channels, out_channels, 4, stride, bias=False, padding_mode='reflect'),
-            nn.BatchNorm2d(out_channels),
-            nn.LeakyReLU(0.2)
-        )
-
-    def forward(self, x):
-        return self.conv(x)
-
-
-class PatchGan(torch.nn.Module):
-    """ Patch GAN Architecture """
-
-    @staticmethod
-    def create_contracting_block(in_channels: int, out_channels: int):
-        """
-        Create encoding layer
-        :param in_channels:
-        :param out_channels:
-        :return:
-        """
-        conv_layer = torch.nn.Sequential(
-            torch.nn.Conv2d(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=3,
-                padding=1,
-            ),
-            torch.nn.ReLU(),
-            torch.nn.Conv2d(
-                in_channels=out_channels,
-                out_channels=out_channels,
-                kernel_size=3,
-                padding=1,
-            ),
-            torch.nn.ReLU(),
-        )
-        max_pool = torch.nn.Sequential(
-            torch.nn.MaxPool2d(
-                stride=2,
-                kernel_size=2,
-            ),
-        )
-        layer = torch.nn.Sequential(
-            conv_layer,
-            max_pool,
-        )
-        return layer
-
-    def __init__(self, input_channels: int, hidden_channels: int) -> None:
-        super().__init__()
-        self.resize_channels = torch.nn.Conv2d(
-            in_channels=input_channels,
-            out_channels=hidden_channels,
-            kernel_size=1,
-        )
-
-        self.enc1 = self.create_contracting_block(
-            in_channels=hidden_channels,
-            out_channels=hidden_channels * 2
-        )
-
-        self.enc2 = self.create_contracting_block(
-            in_channels=hidden_channels * 2,
-            out_channels=hidden_channels * 4
-        )
-
-        self.enc3 = self.create_contracting_block(
-            in_channels=hidden_channels * 4,
-            out_channels=hidden_channels * 8
-        )
-        self.enc4 = self.create_contracting_block(
-            in_channels=hidden_channels * 8,
-            out_channels=hidden_channels * 16
-        )
-
-        self.final_layer = torch.nn.Conv2d(
-            in_channels=hidden_channels * 16,
-            out_channels=1,
-            kernel_size=1,
-        )
-
-    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        """ Forward patch gan layer """
-        inpt = torch.cat([x, y], axis=1)
-        resize_img = self.resize_channels(inpt)
-        enc1 = self.enc1(resize_img)
-        enc2 = self.enc2(enc1)
-        enc3 = self.enc3(enc2)
-        enc4 = self.enc4(enc3)
-        final_layer = self.final_layer(enc4)
-        return final_layer
-
-
-# x, y <- concatenate the gen image and the input image to determin the gen image is real or not
-class Discriminator(nn.Module):
-    def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
-        super().__init__()
-        self.initial = nn.Sequential(
-            nn.Conv2d(in_channels * 2, features[0], kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
-            nn.LeakyReLU(.2)
-        )
-
-        # save layers into a list
-        layers = []
-        in_channels = features[0]
-        for feature in features[1:]:
-            layers.append(
-                CNNBlock(
-                    in_channels,
-                    feature,
-                    stride=1 if feature == features[-1] else 2
-                ),
-            )
-            in_channels = feature
-
-        # append last conv layer
-        layers.append(
-            nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode='reflect')
-        )
-
-        # create a model using the list of layers
-        self.model = nn.Sequential(*layers)
-
-    def forward(self, x, y):
-        x = torch.cat([x, y], dim=1)
-        x = self.initial(x)
-        return self.model(x)
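
A shape check for the removed PatchGan helps explain the "1 output score instead of 16*16" wording in app.py: with two 3-channel 256x256 inputs (so input_channels=6, an assumption, since forward concatenates x and y along the channel axis) it emits a 16x16 grid of patch scores. A minimal sketch, assuming PatchGan is importable from this deleted module:

import torch

disc = PatchGan(input_channels=6, hidden_channels=64)  # 6 = sketch + colour channels
sketch = torch.randn(1, 3, 256, 256)
colour = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    scores = disc(sketch, colour)
print(scores.shape)  # torch.Size([1, 1, 16, 16]) -- one logit per patch-grid cell
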
app/generator/__init__.py DELETED
File without changes
app/generator/unetGen.py DELETED
@@ -1,174 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from app.generator import unetParts
4
-
5
-
6
- class UNET(torch.nn.Module):
7
- """ Implementation of unet """
8
-
9
- def __init__(
10
- self,
11
- ) -> None:
12
- """
13
- Create the UNET here
14
- """
15
- super().__init__()
16
- self.enc_layer1: unetParts.EncoderLayer = unetParts.EncoderLayer(
17
- in_channels=3,
18
- out_channels=64
19
- )
20
- self.enc_layer2: unetParts.EncoderLayer = unetParts.EncoderLayer(
21
- in_channels=64,
22
- out_channels=128
23
- )
24
- self.enc_layer3: unetParts.EncoderLayer = unetParts.EncoderLayer(
25
- in_channels=128,
26
- out_channels=256
27
- )
28
- self.enc_layer4: unetParts.EncoderLayer = unetParts.EncoderLayer(
29
- in_channels=256,
30
- out_channels=512
31
- )
32
- # Middle layer
33
- self.middle_layer: unetParts.MiddleLayer = unetParts.MiddleLayer(
34
- in_channels=512,
35
- out_channels=1024,
36
- )
37
- # Decoding layer
38
- self.dec_layer1: unetParts.DecoderLayer = unetParts.DecoderLayer(
39
- in_channels=1024,
40
- out_channels=512,
41
- )
42
- self.dec_layer2: unetParts.DecoderLayer = unetParts.DecoderLayer(
43
- in_channels=512,
44
- out_channels=256,
45
- )
46
-
47
- self.dec_layer3: unetParts.DecoderLayer = unetParts.DecoderLayer(
48
- in_channels=256,
49
- out_channels=128,
50
- )
51
- self.dec_layer4: unetParts.DecoderLayer = unetParts.DecoderLayer(
52
- in_channels=128,
53
- out_channels=64,
54
- )
55
- self.final_layer: torch.nn.Conv2d = torch.nn.Conv2d(
56
- in_channels=64,
57
- out_channels=3,
58
- kernel_size=1
59
- )
60
-
61
- def forward(self, x: torch.Tensor) -> torch.Tensor:
62
- """
63
- Forward function
64
- :param x:
65
- :return:
66
- """
67
- # enc layers
68
- enc1, conv1 = self.enc_layer1(x=x) # 64
69
- enc2, conv2 = self.enc_layer2(x=enc1) # 128
70
- enc3, conv3 = self.enc_layer3(x=enc2) # 256
71
- enc4, conv4 = self.enc_layer4(x=enc3) # 512
72
- # middle layers
73
- mid = self.middle_layer(x=enc4) # 1024
74
- # expanding layers
75
- # 512
76
- dec1 = self.dec_layer1(
77
- input_layer=mid,
78
- cropping_layer=conv4,
79
- )
80
- # 256
81
- dec2 = self.dec_layer2(
82
- input_layer=dec1,
83
- cropping_layer=conv3,
84
- )
85
- # 128
86
- dec3 = self.dec_layer3(
87
- input_layer=dec2,
88
- cropping_layer=conv2,
89
- )
90
- # 64
91
- dec4 = self.dec_layer4(
92
- input_layer=dec3,
93
- cropping_layer=conv1,
94
- )
95
- # 3
96
- fin_layer = self.final_layer(
97
- dec4,
98
- )
99
- # Interpolate to retain size
100
- fin_layer_resized = torch.nn.functional.interpolate(fin_layer, 572)
101
- return fin_layer_resized
102
-
103
-
104
- class Generator(nn.Module):
105
- def __init__(self, in_channels=3, features=64):
106
- super().__init__()
107
- # Encoder
108
- self.initial_down = nn.Sequential(
109
- nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode='reflect'),
110
- nn.LeakyReLU(.2),
111
- )
112
- self.down1 = Block(features, features * 2, down=True, act='leaky', use_dropout=False) # 64
113
- self.down2 = Block(features * 2, features * 4, down=True, act='leaky', use_dropout=False) # 32
114
- self.down3 = Block(features * 4, features * 8, down=True, act='leaky', use_dropout=False) # 16
115
- self.down4 = Block(features * 8, features * 8, down=True, act='leaky', use_dropout=False) # 8
116
- self.down5 = Block(features * 8, features * 8, down=True, act='leaky', use_dropout=False) # 4
117
- self.down6 = Block(features * 8, features * 8, down=True, act='leaky', use_dropout=False) # 2
118
- self.bottleneck = nn.Sequential(
119
- nn.Conv2d(features * 8, features * 8, 4, 2, 1, padding_mode='reflect'),
120
- nn.ReLU(), # 1x1
121
- )
122
- # Decoder
123
- self.up1 = Block(features * 8, features * 8, down=False, act='relu', use_dropout=True)
124
- self.up2 = Block(features * 8 * 2, features * 8, down=False, act='relu', use_dropout=True)
125
- self.up3 = Block(features * 8 * 2, features * 8, down=False, act='relu', use_dropout=True)
126
- self.up4 = Block(features * 8 * 2, features * 8, down=False, act='relu', use_dropout=False)
127
- self.up5 = Block(features * 8 * 2, features * 4, down=False, act='relu', use_dropout=False)
128
- self.up6 = Block(features * 4 * 2, features * 2, down=False, act='relu', use_dropout=False)
129
- self.up7 = Block(features * 2 * 2, features, down=False, act='relu', use_dropout=False)
130
- self.final_up = nn.Sequential(
131
- nn.ConvTranspose2d(features * 2, in_channels, kernel_size=4, stride=2, padding=1),
132
- nn.Tanh()
133
- )
134
-
135
- def forward(self, x):
136
- # Encoder
137
- d1 = self.initial_down(x)
138
- d2 = self.down1(d1)
139
- d3 = self.down2(d2)
140
- d4 = self.down3(d3)
141
- d5 = self.down4(d4)
142
- d6 = self.down5(d5)
143
- d7 = self.down6(d6)
144
- bottleneck = self.bottleneck(d7)
145
-
146
- # Decoder
147
- u1 = self.up1(bottleneck)
148
- u2 = self.up2(torch.cat([u1, d7], 1))
149
- u3 = self.up3(torch.cat([u2, d6], 1))
150
- u4 = self.up4(torch.cat([u3, d5], 1))
151
- u5 = self.up5(torch.cat([u4, d4], 1))
152
- u6 = self.up6(torch.cat([u5, d3], 1))
153
- u7 = self.up7(torch.cat([u6, d2], 1))
154
- return self.final_up(torch.cat([u7, d1], 1))
155
-
156
-
157
- # block will be use repeatly later
158
- class Block(nn.Module):
159
- def __init__(self, in_channels, out_channels, down=True, act='relu', use_dropout=False):
160
- super().__init__()
161
- self.conv = nn.Sequential(
162
- # the block will be use on both encoder (down=True) and decoder (down=False)
163
- nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode='reflect')
164
- if down
165
- else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
166
- nn.BatchNorm2d(out_channels),
167
- nn.ReLU() if act == 'relu' else nn.LeakyReLU(.2)
168
- )
169
- self.use_dropout = use_dropout
170
- self.dropout = nn.Dropout(.5)
171
-
172
- def forward(self, x):
173
- x = self.conv(x)
174
- return self.dropout(x) if self.use_dropout else x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/generator/unetParts.py DELETED
@@ -1,106 +0,0 @@
-import torch
-from typing import Tuple
-
-
-class DecoderLayer(torch.nn.Module):
-    """Decoder model"""
-
-    def __init__(self, in_channels: int, out_channels: int):
-        super().__init__()
-        self.up_sample_layer = torch.nn.Sequential(
-            torch.nn.ConvTranspose2d(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=2,
-                stride=2,
-                bias=False,
-            )
-        )
-        self.conv_layer = EncoderLayer(
-            in_channels=in_channels,
-            out_channels=out_channels,
-        ).conv_layer
-
-    @staticmethod
-    def _get_cropping_shape(previous_layer_shape: torch.Size, current_layer_shape: torch.Size) -> int:
-        """ Get the shape to crop """
-        return (previous_layer_shape[2] - current_layer_shape[2]) // 2 * -1
-
-    def forward(
-        self,
-        input_layer: torch.Tensor,
-        cropping_layer: torch.Tensor
-    ) -> torch.Tensor:
-        """
-        Forward function to concatenate and conv the figure
-        :param cropping_layer:
-        :param input_layer:
-        :return:
-        """
-        input_layer = self.up_sample_layer(input_layer)
-
-        cropping_shape = self._get_cropping_shape(
-            current_layer_shape=input_layer.shape,
-            previous_layer_shape=cropping_layer.shape,
-        )
-
-        cropping_layer = torch.nn.functional.pad(
-            input=cropping_layer,
-            pad=[cropping_shape for _ in range(4)]
-        )
-        combined_layer = torch.cat(
-            tensors=[input_layer, cropping_layer],
-            dim=1
-        )
-        result = self.conv_layer(combined_layer)
-        return result
-
-
-class EncoderLayer(torch.nn.Module):
-    """Encoder Layer"""
-
-    def __init__(self, in_channels: int, out_channels: int) -> None:
-        super().__init__()
-        self.conv_layer = torch.nn.Sequential(
-            torch.nn.Conv2d(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=3,
-                stride=2,
-                padding=1,
-            ),
-            torch.nn.LeakyReLU(),
-            torch.nn.Conv2d(
-                in_channels=out_channels,
-                out_channels=out_channels,
-                kernel_size=3,
-                stride=2,
-                padding=1,
-            ),
-            torch.nn.LeakyReLU(),
-        )
-        self.max_pool = torch.nn.Sequential(
-            torch.nn.MaxPool2d(2),
-        )
-        self.layer = torch.nn.Sequential(
-            self.conv_layer,
-            self.max_pool,
-        )
-
-    def get_conv_layers(self, x: torch.Tensor) -> torch.Tensor:
-        """Need to concatenate the layer"""
-        return self.conv_layer(x)
-
-    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Forward pass to return conv layer and the max pool layer"""
-        conv_output: torch.tensor = self.conv_layer(x)
-        fin_out: torch.Tensor = self.max_pool(conv_output)
-        return fin_out, conv_output
-
-
-class MiddleLayer(EncoderLayer):
-    """Middle layer only"""
-
-    def forward(self, x: torch.tensor) -> torch.tensor:
-        """Forward pass"""
-        return self.conv_layer(x)
app/model/__init__.py DELETED
File without changes
app/model/lit_model.py DELETED
@@ -1,145 +0,0 @@
1
- import matplotlib.pyplot as plt
2
- import pytorch_lightning as pl
3
- import torch
4
- import torch.nn as nn
5
- import torchvision
6
-
7
-
8
- class Pix2PixLitModule(pl.LightningModule):
9
- """ Lightning Module for pix2pix """
10
-
11
- @staticmethod
12
- def _weights_init(m):
13
- if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
14
- torch.nn.init.normal_(m.weight, 0.0, 0.02)
15
- if isinstance(m, nn.BatchNorm2d):
16
- torch.nn.init.normal_(m.weight, 0.0, 0.02)
17
- torch.nn.init.constant_(m.bias, 0)
18
-
19
- def __init__(
20
- self,
21
- generator,
22
- discriminator,
23
- use_gpu: bool,
24
- lambda_recon=100
25
- ):
26
- super().__init__()
27
- self.save_hyperparameters()
28
-
29
- self.gen = generator
30
- self.disc = discriminator
31
-
32
- # intializing weights
33
- self.gen = self.gen.apply(self._weights_init)
34
- self.disc = self.disc.apply(self._weights_init)
35
- #
36
- self.adversarial_criterion = nn.BCEWithLogitsLoss()
37
- self.recon_criterion = nn.L1Loss()
38
- self.lambda_l1 = lambda_recon
39
-
40
- def _gen_step(self, sketch, coloured_sketches):
41
- # Pix2Pix has adversarial and a reconstruction loss
42
- # First calculate the adversarial loss
43
- gen_coloured_sketches = self.gen(sketch)
44
- # disc_logits = self.disc(gen_coloured_sketches, coloured_sketches)
45
- disc_logits = self.disc(sketch, gen_coloured_sketches)
46
- adversarial_loss = self.adversarial_criterion(disc_logits, torch.ones_like(disc_logits))
47
- # calculate reconstruction loss
48
- recon_loss = self.recon_criterion(gen_coloured_sketches, coloured_sketches) * self.lambda_l1
49
- #
50
- self.log("Gen recon_loss", recon_loss)
51
- self.log("Gen adversarial_loss", adversarial_loss)
52
- #
53
- return adversarial_loss + recon_loss
54
-
55
- def _disc_step(self, sketch, coloured_sketches):
56
- gen_coloured_sketches = self.gen(sketch).detach()
57
- #
58
- # fake_logits = self.disc(gen_coloured_sketches, coloured_sketches)
59
- fake_logits = self.disc(sketch, gen_coloured_sketches)
60
- real_logits = self.disc(sketch, coloured_sketches)
61
- #
62
- fake_loss = self.adversarial_criterion(fake_logits, torch.zeros_like(fake_logits))
63
- real_loss = self.adversarial_criterion(real_logits, torch.ones_like(real_logits))
64
- #
65
- self.log("PatchGAN fake_loss", fake_loss)
66
- self.log("PatchGAN real_loss", real_loss)
67
- return (real_loss + fake_loss) / 2
68
-
69
- def forward(self, x):
70
- return self.gen(x)
71
-
72
- def training_step(self, batch, batch_idx, optimizer_idx):
73
- real, condition = batch
74
- loss = None
75
- if optimizer_idx == 0:
76
- loss = self._disc_step(real, condition)
77
- self.log("TRAIN_PatchGAN Loss", loss)
78
- elif optimizer_idx == 1:
79
- loss = self._gen_step(real, condition)
80
- self.log("TRAIN_Generator Loss", loss)
81
- return loss
82
-
83
- def validation_epoch_end(self, outputs) -> None:
84
- """ Log the images"""
85
- sketch = outputs[0]['sketch']
86
- colour = outputs[0]['colour']
87
- gen_coloured = self.gen(sketch)
88
- grid_image = torchvision.utils.make_grid(
89
- [sketch[0], colour[0], gen_coloured[0]],
90
- normalize=True
91
- )
92
- self.logger.experiment.add_image(f'Image Grid {str(self.current_epoch)}', grid_image, self.current_epoch)
93
- #plt.imshow(grid_image.permute(1, 2, 0))
94
-
95
- def validation_step(self, batch, batch_idx):
96
- """ Validation step """
97
- real, condition = batch
98
- return {
99
- 'sketch': real,
100
- 'colour': condition
101
- }
102
-
103
- def configure_optimizers(self, lr=2e-4):
104
- gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr, betas=(0.5, 0.999))
105
- disc_opt = torch.optim.Adam(self.disc.parameters(), lr=lr, betas=(0.5, 0.999))
106
- return disc_opt, gen_opt
107
-
108
- # class EpochInference(pl.Callback):
109
- # """
110
- # Callback on each end of training epoch
111
- # The callback will do inference on test dataloader based on corresponding checkpoints
112
- # The results will be saved as an image with 4-rows:
113
- # 1 - Input image e.g. grayscale edged input
114
- # 2 - Ground-truth
115
- # 3 - Single inference
116
- # 4 - Mean of hundred accumulated inference
117
- # Note that the inference have a noise factor that will generate different output on each execution
118
- # """
119
- #
120
- # def __init__(self, dataloader, use_gpu: bool, *args, **kwargs):
121
- # super().__init__(*args, **kwargs)
122
- # self.dataloader = dataloader
123
- # self.use_gpu = use_gpu
124
- #
125
- # def on_train_epoch_end(self, trainer, pl_module):
126
- # super().on_train_epoch_end(trainer, pl_module)
127
- # data = next(iter(self.dataloader))
128
- # image, target = data
129
- # if self.use_gpu:
130
- # image = image.cuda()
131
- # target = target.cuda()
132
- # with torch.no_grad():
133
- # # Take average of multiple inference as there is a random noise
134
- # # Single
135
- # reconstruction_init = pl_module(image)
136
- # reconstruction_init = torch.clip(reconstruction_init, 0, 1)
137
- # # # Mean
138
- # # reconstruction_mean = torch.stack([pl_module(image) for _ in range(10)])
139
- # # reconstruction_mean = torch.clip(reconstruction_mean, 0, 1)
140
- # # reconstruction_mean = torch.mean(reconstruction_mean, dim=0)
141
- # # Grayscale 1-D to 3-D
142
- # # image = torch.stack([image for _ in range(3)], dim=1)
143
- # # image = torch.squeeze(image)
144
- # grid_image = torchvision.utils.make_grid([image[0], target[0], reconstruction_init[0]])
145
- # torchvision.utils.save_image(grid_image, fp=f'{trainer.default_root_dir}/epoch-{trainer.current_epoch:04}.png')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/scratch.py DELETED
@@ -1,34 +0,0 @@
-
-class GANInference:
-    def __init__(
-        self,
-        model: Pix2PixLitModule,
-        img_file: str = "/Users/nimud/Downloads/thesis_test2.png",
-    ) -> None:
-        self.img_file = img_file
-        self.model = model
-
-    def _get_image_from_path(self) -> torch.Tensor:
-        """ gets the tensor from filepath """
-        image = np.array(Image.open(self.img_file))
-        # use on inference
-        inference_transform = A.Compose([
-            A.Resize(width=256, height=256),
-            A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
-            al_pytorch.ToTensorV2(),
-        ])
-        inference_img = inference_transform(image=image)['image'].unsqueeze(0)
-        return inference_img
-
-    def _create_grid(self, result: torch.Tensor) -> np.array:
-        return torchvision.utils.make_grid(
-            [result[0].permute(1, 2, 0).detach()],
-            normalize=True
-        )
-
-    def run(self) -> np.array:
-        """ Returns a plottable image """
-        inference_img = self._get_image_from_path()
-        result = self.model(inference_img)
-        adjusted_result = self._create_grid(result=result)
-        return adjusted_result