tincri commited on
Commit
40e8081
·
1 Parent(s): 70cf11c

Fix #14 app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -27
app.py CHANGED
@@ -13,31 +13,24 @@ from huggingface_hub import hf_hub_download
13
  class ResBlk(nn.Module):
14
  def __init__(self, dim_in, dim_out, normalize=False, downsample=False):
15
  super().__init__()
16
- self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
17
- self.norm1 = nn.InstanceNorm2d(dim_out, affine=True) if normalize else None
18
- self.relu1 = nn.ReLU(inplace=True)
19
- self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
20
- self.norm2 = nn.InstanceNorm2d(dim_out, affine=True) if normalize else None
21
- self.relu2 = nn.ReLU(inplace=True)
22
  self.downsample = downsample
23
- if self.downsample:
24
- self.avg_pool = nn.AvgPool2d(2)
 
 
 
 
 
 
 
25
 
26
  def forward(self, x):
27
- residual = x
28
- out = self.conv1(x)
29
- if self.norm1:
30
- out = self.norm1(out)
31
- out = self.relu1(out)
32
- out = self.conv2(out)
33
- if self.norm2:
34
- out = self.norm2(out)
35
- out = self.relu2(out)
36
- if self.downsample:
37
- out = self.avg_pool(out)
38
- residual = self.avg_pool(residual)
39
- out = out + residual
40
- return out
41
 
42
  class AdainResBlk(nn.Module):
43
  def __init__(self, dim_in, dim_out, style_dim=64, w_hpf=1, upsample=False):
@@ -109,8 +102,8 @@ class MappingNetwork(nn.Module):
109
  for layer in self.unshared:
110
  out += [layer(h)]
111
  out = torch.stack(out, dim=1) # (batch, num_domains, style_dim)
112
- idx = torch.arange(y.size(0)).to(y.device)
113
- s = out[idx, y]
114
  return s
115
 
116
  class StyleEncoder(nn.Module):
@@ -177,8 +170,8 @@ class Generator(nn.Module):
177
 
178
  # FUNCIÓN PARA CARGAR EL MODELO
179
  def load_pretrained_model(ckpt_path, img_size=256, style_dim=64, num_domains=3, device='cpu'):
180
- num_domains_mappin = 3
181
- latent_dim_for_mapping = 13
182
  G = Generator(img_size, style_dim).to(device)
183
  M = MappingNetwork(latent_dim_for_mapping, style_dim, num_domains).to(device)
184
  S = StyleEncoder(img_size, style_dim, num_domains).to(device)
@@ -235,7 +228,7 @@ if __name__ == '__main__':
235
  checkpoint_path = 'iter/12500_nets_ema.ckpt'
236
  img_size = 128
237
  style_dim = 64
238
- num_domains = 3
239
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
240
 
241
  try:
 
13
  class ResBlk(nn.Module):
14
  def __init__(self, dim_in, dim_out, normalize=False, downsample=False):
15
  super().__init__()
16
+ self.normalize = normalize
 
 
 
 
 
17
  self.downsample = downsample
18
+ self.main = nn.Sequential(
19
+ nn.Conv2d(dim_in, dim_out, 3, 1, 1),
20
+ nn.InstanceNorm2d(dim_out, affine=True) if normalize else nn.Identity(),
21
+ nn.ReLU(inplace=True),
22
+ nn.Conv2d(dim_out, dim_out, 3, 1, 1),
23
+ nn.InstanceNorm2d(dim_out, affine=True) if normalize else nn.Identity()
24
+ )
25
+ self.downsample_layer = nn.AvgPool2d(2) if downsample else nn.Identity()
26
+ self.skip = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
27
 
28
  def forward(self, x):
29
+ out = self.main(x)
30
+ out = self.downsample_layer(out)
31
+ skip = self.skip(x)
32
+ skip = self.downsample_layer(skip)
33
+ return (out + skip) / math.sqrt(2)
 
 
 
 
 
 
 
 
 
34
 
35
  class AdainResBlk(nn.Module):
36
  def __init__(self, dim_in, dim_out, style_dim=64, w_hpf=1, upsample=False):
 
102
  for layer in self.unshared:
103
  out += [layer(h)]
104
  out = torch.stack(out, dim=1) # (batch, num_domains, style_dim)
105
+ idx = torch.LongTensor(range(y.size(0))).unsqueeze(1).to(y.device)
106
+ s = torch.gather(out, 1, idx.unsqueeze(2).expand(-1, -1, out.size(2))).squeeze(1)
107
  return s
108
 
109
  class StyleEncoder(nn.Module):
 
170
 
171
  # FUNCIÓN PARA CARGAR EL MODELO
172
  def load_pretrained_model(ckpt_path, img_size=256, style_dim=64, num_domains=3, device='cpu'):
173
+ num_domains_mappin = 2
174
+ latent_dim_for_mapping = 14
175
  G = Generator(img_size, style_dim).to(device)
176
  M = MappingNetwork(latent_dim_for_mapping, style_dim, num_domains).to(device)
177
  S = StyleEncoder(img_size, style_dim, num_domains).to(device)
 
228
  checkpoint_path = 'iter/12500_nets_ema.ckpt'
229
  img_size = 128
230
  style_dim = 64
231
+ num_domains = 2
232
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
233
 
234
  try: