tincri commited on
Commit
2c87f60
·
1 Parent(s): d2c05b5

Fix #8 app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -167,8 +167,9 @@ class Generator(nn.Module):
167
 
168
  # FUNCIÓN PARA CARGAR EL MODELO
169
  def load_pretrained_model(ckpt_path, img_size=256, style_dim=64, num_domains=3, device='cpu'):
 
170
  G = Generator(img_size, style_dim).to(device)
171
- M = MappingNetwork(13, style_dim, num_domains).to(device) # Suponiendo latent_dim=16
172
  S = StyleEncoder(img_size, style_dim, num_domains).to(device)
173
  checkpoint = torch.load(ckpt_path, map_location=device)
174
  G.load_state_dict(checkpoint['generator'])
 
167
 
168
  # FUNCIÓN PARA CARGAR EL MODELO
169
  def load_pretrained_model(ckpt_path, img_size=256, style_dim=64, num_domains=3, device='cpu'):
170
+ latent_dim_for_mapping = 13
171
  G = Generator(img_size, style_dim).to(device)
172
+ M = MappingNetwork(latent_dim_for_mapping, style_dim, num_domains).to(device)
173
  S = StyleEncoder(img_size, style_dim, num_domains).to(device)
174
  checkpoint = torch.load(ckpt_path, map_location=device)
175
  G.load_state_dict(checkpoint['generator'])