Aanisha commited on
Commit
6b8aec5
·
1 Parent(s): 3f754a1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -0
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision.utils import make_grid
3
+ from torchvision import transforms
4
+ import torchvision.transforms.functional as TF
5
+ from torch import nn, optim
6
+ from torch.optim.lr_scheduler import CosineAnnealingLR
7
+ from torch.utils.data import DataLoader, Dataset
8
+ from huggingface_hub import hf_hub_download
9
+ import requests
10
+ import gradio as gr
11
+
12
+ class Upsample(nn.Module):
13
+ def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, dropout=True):
14
+ super(Upsample, self).__init__()
15
+ self.dropout = dropout
16
+ self.block = nn.Sequential(
17
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=nn.InstanceNorm2d),
18
+ nn.InstanceNorm2d(out_channels),
19
+ nn.ReLU(inplace=True)
20
+ )
21
+ self.dropout_layer = nn.Dropout2d(0.5)
22
+
23
+ def forward(self, x, shortcut=None):
24
+ x = self.block(x)
25
+ if self.dropout:
26
+ x = self.dropout_layer(x)
27
+
28
+ if shortcut is not None:
29
+ x = torch.cat([x, shortcut], dim=1)
30
+
31
+ return x
32
+
33
+
34
+ class Downsample(nn.Module):
35
+ def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, apply_instancenorm=True):
36
+ super(Downsample, self).__init__()
37
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=nn.InstanceNorm2d)
38
+ self.norm = nn.InstanceNorm2d(out_channels)
39
+ self.relu = nn.LeakyReLU(0.2, inplace=True)
40
+ self.apply_norm = apply_instancenorm
41
+
42
+ def forward(self, x):
43
+ x = self.conv(x)
44
+ if self.apply_norm:
45
+ x = self.norm(x)
46
+ x = self.relu(x)
47
+
48
+ return x
49
+
50
+
51
+ class CycleGAN_Unet_Generator(nn.Module):
52
+ def __init__(self, filter=64):
53
+ super(CycleGAN_Unet_Generator, self).__init__()
54
+ self.downsamples = nn.ModuleList([
55
+ Downsample(3, filter, kernel_size=4, apply_instancenorm=False), # (b, filter, 128, 128)
56
+ Downsample(filter, filter * 2), # (b, filter * 2, 64, 64)
57
+ Downsample(filter * 2, filter * 4), # (b, filter * 4, 32, 32)
58
+ Downsample(filter * 4, filter * 8), # (b, filter * 8, 16, 16)
59
+ Downsample(filter * 8, filter * 8), # (b, filter * 8, 8, 8)
60
+ Downsample(filter * 8, filter * 8), # (b, filter * 8, 4, 4)
61
+ Downsample(filter * 8, filter * 8), # (b, filter * 8, 2, 2)
62
+ ])
63
+
64
+ self.upsamples = nn.ModuleList([
65
+ Upsample(filter * 8, filter * 8),
66
+ Upsample(filter * 16, filter * 8),
67
+ Upsample(filter * 16, filter * 8),
68
+ Upsample(filter * 16, filter * 4, dropout=False),
69
+ Upsample(filter * 8, filter * 2, dropout=False),
70
+ Upsample(filter * 4, filter, dropout=False)
71
+ ])
72
+
73
+ self.last = nn.Sequential(
74
+ nn.ConvTranspose2d(filter * 2, 3, kernel_size=4, stride=2, padding=1),
75
+ nn.Tanh()
76
+ )
77
+
78
+ def forward(self, x):
79
+ skips = []
80
+ for l in self.downsamples:
81
+ x = l(x)
82
+ skips.append(x)
83
+
84
+ skips = reversed(skips[:-1])
85
+ for l, s in zip(self.upsamples, skips):
86
+ x = l(x, s)
87
+
88
+ out = self.last(x)
89
+
90
+ return out
91
+
92
+ class ImageTransform:
93
+ def __init__(self, img_size=256):
94
+ self.transform = {
95
+ 'train': transforms.Compose([
96
+ transforms.Resize((img_size, img_size)),
97
+ transforms.RandomHorizontalFlip(),
98
+ transforms.RandomVerticalFlip(),
99
+ transforms.ToTensor(),
100
+ transforms.Normalize(mean=[0.5], std=[0.5])
101
+ ]),
102
+ 'test': transforms.Compose([
103
+ transforms.Resize((img_size, img_size)),
104
+ transforms.ToTensor(),
105
+ transforms.Normalize(mean=[0.5], std=[0.5])
106
+ })}
107
+
108
+ def __call__(self, img, phase='train'):
109
+ img = self.transform[phase](img)
110
+
111
+ return img
112
+
113
+
114
+ path = hf_hub_download('huggan/NeonGAN', 'model.bin')
115
+ model_gen_n = torch.load(path, map_location=torch.device('cpu'))
116
+
117
+
118
+
119
+
120
+
121
+