tincri commited on
Commit
b53dc3f
·
1 Parent(s): 40e8081

Fix #15 app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -18
app.py CHANGED
@@ -13,24 +13,31 @@ from huggingface_hub import hf_hub_download
13
  class ResBlk(nn.Module):
14
  def __init__(self, dim_in, dim_out, normalize=False, downsample=False):
15
  super().__init__()
16
- self.normalize = normalize
 
 
 
 
 
17
  self.downsample = downsample
18
- self.main = nn.Sequential(
19
- nn.Conv2d(dim_in, dim_out, 3, 1, 1),
20
- nn.InstanceNorm2d(dim_out, affine=True) if normalize else nn.Identity(),
21
- nn.ReLU(inplace=True),
22
- nn.Conv2d(dim_out, dim_out, 3, 1, 1),
23
- nn.InstanceNorm2d(dim_out, affine=True) if normalize else nn.Identity()
24
- )
25
- self.downsample_layer = nn.AvgPool2d(2) if downsample else nn.Identity()
26
- self.skip = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
27
 
28
  def forward(self, x):
29
- out = self.main(x)
30
- out = self.downsample_layer(out)
31
- skip = self.skip(x)
32
- skip = self.downsample_layer(skip)
33
- return (out + skip) / math.sqrt(2)
 
 
 
 
 
 
 
 
 
34
 
35
  class AdainResBlk(nn.Module):
36
  def __init__(self, dim_in, dim_out, style_dim=64, w_hpf=1, upsample=False):
@@ -170,10 +177,10 @@ class Generator(nn.Module):
170
 
171
  # FUNCIÓN PARA CARGAR EL MODELO
172
  def load_pretrained_model(ckpt_path, img_size=256, style_dim=64, num_domains=3, device='cpu'):
173
- num_domains_mappin = 2
174
- latent_dim_for_mapping = 14
175
  G = Generator(img_size, style_dim).to(device)
176
- M = MappingNetwork(latent_dim_for_mapping, style_dim, num_domains).to(device)
177
  S = StyleEncoder(img_size, style_dim, num_domains).to(device)
178
  checkpoint = torch.load(ckpt_path, map_location=device)
179
  G.load_state_dict(checkpoint['generator'])
 
13
  class ResBlk(nn.Module):
14
  def __init__(self, dim_in, dim_out, normalize=False, downsample=False):
15
  super().__init__()
16
+ self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
17
+ self.norm1 = nn.InstanceNorm2d(dim_out, affine=True) if normalize else None
18
+ self.relu1 = nn.ReLU(inplace=True)
19
+ self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
20
+ self.norm2 = nn.InstanceNorm2d(dim_out, affine=True) if normalize else None
21
+ self.relu2 = nn.ReLU(inplace=True)
22
  self.downsample = downsample
23
+ if self.downsample:
24
+ self.avg_pool = nn.AvgPool2d(2)
 
 
 
 
 
 
 
25
 
26
  def forward(self, x):
27
+ residual = x
28
+ out = self.conv1(x)
29
+ if self.norm1:
30
+ out = self.norm1(out)
31
+ out = self.relu1(out)
32
+ out = self.conv2(out) # <--- Corrección aquí
33
+ if self.norm2:
34
+ out = self.norm2(out)
35
+ out = self.relu2(out)
36
+ if self.downsample:
37
+ out = self.avg_pool(out)
38
+ residual = self.avg_pool(residual)
39
+ out = out + residual
40
+ return out
41
 
42
  class AdainResBlk(nn.Module):
43
  def __init__(self, dim_in, dim_out, style_dim=64, w_hpf=1, upsample=False):
 
177
 
178
  # FUNCIÓN PARA CARGAR EL MODELO
179
  def load_pretrained_model(ckpt_path, img_size=256, style_dim=64, num_domains=3, device='cpu'):
180
+ num_domains_mappin = 3
181
+ latent_dim_for_mapping = 13
182
  G = Generator(img_size, style_dim).to(device)
183
+ M = MappingNetwork(latent_dim_for_mapping, style_dim, num_domains_mappin).to(device)
184
  S = StyleEncoder(img_size, style_dim, num_domains).to(device)
185
  checkpoint = torch.load(ckpt_path, map_location=device)
186
  G.load_state_dict(checkpoint['generator'])