tincri commited on
Commit
603f92a
·
1 Parent(s): 6ac1d70

Fix #5 app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -49
app.py CHANGED
@@ -13,70 +13,73 @@ from huggingface_hub import hf_hub_download
13
  class ResBlk(nn.Module):
14
  def __init__(self, dim_in, dim_out, normalize=False, downsample=False):
15
  super().__init__()
16
- self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
17
- self.norm1 = nn.InstanceNorm2d(dim_out, affine=True) if normalize else None
18
- self.relu1 = nn.ReLU(inplace=True)
19
- self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
20
- self.norm2 = nn.InstanceNorm2d(dim_out, affine=True) if normalize else None
21
- self.relu2 = nn.ReLU(inplace=True)
22
  self.downsample = downsample
23
- if self.downsample:
24
- self.avg_pool = nn.AvgPool2d(2)
 
 
 
 
 
 
 
25
 
26
  def forward(self, x):
27
- residual = x
28
- out = self.conv1(x)
29
- if self.norm1:
30
- out = self.norm1(out)
31
- out = self.relu1(out)
32
- out = self.conv2(out)
33
- if self.norm2:
34
- out = self.norm2(out)
35
- out = self.relu2(out)
36
- if self.downsample:
37
- out = self.avg_pool(out)
38
- residual = self.avg_pool(residual)
39
- out = out + residual
40
- return out
41
 
42
  class AdainResBlk(nn.Module):
43
- def __init__(self, dim_in, dim_out, style_dim, upsample=False):
44
  super().__init__()
 
 
 
 
 
 
45
  self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
46
- self.norm1 = AdaIN(dim_out, style_dim)
47
- self.relu1 = nn.ReLU(inplace=True)
48
  self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
49
- self.norm2 = AdaIN(dim_out, style_dim)
50
- self.relu2 = nn.ReLU(inplace=True)
51
- self.upsample = upsample
 
 
52
 
53
  def forward(self, x, s):
54
- residual = x
55
- if self.upsample:
56
- residual = F.interpolate(residual, scale_factor=2, mode='nearest')
57
- out = self.conv1(x)
58
- out = self.norm1(out, s)
59
- out = self.relu1(out)
60
  if self.upsample:
61
- out = F.interpolate(out, scale_factor=2, mode='nearest')
62
- out = self.conv2(out)
63
- out = self.norm2(out, s)
64
- out = self.relu2(out)
65
- out = out + residual
 
 
 
 
 
 
 
 
 
66
  return out
67
 
68
  class AdaIN(nn.Module):
69
  def __init__(self, num_features, style_dim):
70
- super().__init__()
71
- self.norm = nn.InstanceNorm2d(num_features, affine=False)
72
  self.fc = nn.Linear(style_dim, num_features * 2)
73
 
74
  def forward(self, x, s):
75
  h = self.fc(s)
76
- gamma, beta = torch.chunk(h, 2, dim=1)
77
  gamma = gamma.unsqueeze(2).unsqueeze(3)
78
  beta = beta.unsqueeze(2).unsqueeze(3)
79
- return (1 + gamma) * self.norm(x) + beta
80
 
81
  class MappingNetwork(nn.Module):
82
  def __init__(self, latent_dim, style_dim, num_domains):
@@ -145,7 +148,7 @@ class Generator(nn.Module):
145
  self.encode = nn.Sequential(*blocks)
146
 
147
  self.decode = nn.ModuleList()
148
- for _ in range(repeat_num):
149
  dim_out = dim_in // 2
150
  self.decode += [AdainResBlk(dim_in, dim_out, style_dim, upsample=True)]
151
  dim_in = dim_out
@@ -214,22 +217,23 @@ def create_interface(generator, style_encoder, domain_names, device='cpu'):
214
  )
215
  return iface
216
 
 
217
  if __name__ == '__main__':
218
  #CARGAR EL MODELO ENTRENADO
219
- checkpoint_path = 'iter/10500_nets_ema.ckpt'
220
  img_size = 128
221
- style_dim = 64
222
- num_domains = 3
223
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
224
 
225
  try:
226
  generator, style_encoder = load_pretrained_model(checkpoint_path, img_size, style_dim, num_domains, device)
227
  print("Modelo cargado exitosamente.")
228
 
229
- #DEFINIR LOS NOMBRES DE LOS DOMINIOS
230
  domain_names = ["BMW", "Corvette", "Mazda"]
231
 
232
- # CREAR E LANZAR LA INTERFAZ DE GRADIO
233
  iface = create_interface(generator, style_encoder, domain_names, device)
234
  iface.launch(share=True)
235
 
 
13
  class ResBlk(nn.Module):
14
  def __init__(self, dim_in, dim_out, normalize=False, downsample=False):
15
  super().__init__()
16
+ self.normalize = normalize
 
 
 
 
 
17
  self.downsample = downsample
18
+ self.main = nn.Sequential(
19
+ nn.Conv2d(dim_in, dim_out, 3, 1, 1),
20
+ nn.InstanceNorm2d(dim_out, affine=True) if normalize else nn.Identity(),
21
+ nn.ReLU(inplace=True),
22
+ nn.Conv2d(dim_out, dim_out, 3, 1, 1),
23
+ nn.InstanceNorm2d(dim_out, affine=True) if normalize else nn.Identity()
24
+ )
25
+ self.downsample_layer = nn.AvgPool2d(2) if downsample else nn.Identity()
26
+ self.skip = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
27
 
28
  def forward(self, x):
29
+ out = self.main(x)
30
+ out = self.downsample_layer(out)
31
+ skip = self.skip(x)
32
+ skip = self.downsample_layer(skip)
33
+ return (out + skip) / math.sqrt(2)
 
 
 
 
 
 
 
 
 
34
 
35
  class AdainResBlk(nn.Module):
36
+ def __init__(self, dim_in, dim_out, style_dim=64, w_hpf=1, upsample=False):
37
  super().__init__()
38
+ self.upsample = upsample
39
+ self.w_hpf = w_hpf
40
+
41
+ self.norm1 = AdaIN(dim_in, style_dim)
42
+ self.norm2 = AdaIN(dim_out, style_dim)
43
+ self.actv = nn.LeakyReLU(0.2)
44
  self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
 
 
45
  self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
46
+
47
+ if dim_in != dim_out:
48
+ self.skip = nn.Conv2d(dim_in, dim_out, 1, 1, 0)
49
+ else:
50
+ self.skip = nn.Identity()
51
 
52
  def forward(self, x, s):
53
+ x_orig = x
54
+
 
 
 
 
55
  if self.upsample:
56
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
57
+ x_orig = F.interpolate(x_orig, scale_factor=2, mode='nearest')
58
+
59
+ h = self.norm1(x, s)
60
+ h = self.actv(h)
61
+ h = self.conv1(h)
62
+
63
+ h = self.norm2(h, s)
64
+ h = self.actv(h)
65
+ h = self.conv2(h)
66
+
67
+ skip = self.skip(x_orig)
68
+
69
+ out = (h + skip) / math.sqrt(2)
70
  return out
71
 
72
  class AdaIN(nn.Module):
73
  def __init__(self, num_features, style_dim):
74
+ super(AdaIN, self).__init__()
 
75
  self.fc = nn.Linear(style_dim, num_features * 2)
76
 
77
  def forward(self, x, s):
78
  h = self.fc(s)
79
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
80
  gamma = gamma.unsqueeze(2).unsqueeze(3)
81
  beta = beta.unsqueeze(2).unsqueeze(3)
82
+ return (1 + gamma) * x + beta
83
 
84
  class MappingNetwork(nn.Module):
85
  def __init__(self, latent_dim, style_dim, num_domains):
 
148
  self.encode = nn.Sequential(*blocks)
149
 
150
  self.decode = nn.ModuleList()
151
+ for i in range(repeat_num):
152
  dim_out = dim_in // 2
153
  self.decode += [AdainResBlk(dim_in, dim_out, style_dim, upsample=True)]
154
  dim_in = dim_out
 
217
  )
218
  return iface
219
 
220
+
221
  if __name__ == '__main__':
222
  #CARGAR EL MODELO ENTRENADO
223
+ checkpoint_path = 'iter/12000_nets_ema.ckpt'
224
  img_size = 128
225
+ style_dim = 64
226
+ num_domains = 3
227
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
228
 
229
  try:
230
  generator, style_encoder = load_pretrained_model(checkpoint_path, img_size, style_dim, num_domains, device)
231
  print("Modelo cargado exitosamente.")
232
 
233
+ # DEFINIR LOS NOMBRES DE LOS DOMINIOS
234
  domain_names = ["BMW", "Corvette", "Mazda"]
235
 
236
+ # CREAR E LANZAR LA INTERFAZ DE GRADIO
237
  iface = create_interface(generator, style_encoder, domain_names, device)
238
  iface.launch(share=True)
239