tincri commited on
Commit
b3a710d
·
1 Parent(s): f627f24

Fix #12 app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -102,8 +102,8 @@ class MappingNetwork(nn.Module):
102
  for layer in self.unshared:
103
  out += [layer(h)]
104
  out = torch.stack(out, dim=1) # (batch, num_domains, style_dim)
105
- idx = torch.LongTensor(range(y.size(0))).unsqueeze(1).to(y.device)
106
- s = torch.gather(out, 1, idx.unsqueeze(2).expand(-1, -1, out.size(2))).squeeze(1)
107
  return s
108
 
109
  class StyleEncoder(nn.Module):
@@ -170,8 +170,8 @@ class Generator(nn.Module):
170
 
171
  # FUNCIÓN PARA CARGAR EL MODELO
172
  def load_pretrained_model(ckpt_path, img_size=256, style_dim=64, num_domains=3, device='cpu'):
173
- num_domains_mappin = 2
174
- latent_dim_for_mapping = 14
175
  G = Generator(img_size, style_dim).to(device)
176
  M = MappingNetwork(latent_dim_for_mapping, style_dim, num_domains).to(device)
177
  S = StyleEncoder(img_size, style_dim, num_domains).to(device)
@@ -228,7 +228,7 @@ if __name__ == '__main__':
228
  checkpoint_path = 'iter/12500_nets_ema.ckpt'
229
  img_size = 128
230
  style_dim = 64
231
- num_domains = 2
232
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
233
 
234
  try:
 
102
  for layer in self.unshared:
103
  out += [layer(h)]
104
  out = torch.stack(out, dim=1) # (batch, num_domains, style_dim)
105
+ idx = torch.arange(y.size(0)).to(y.device)
106
+ s = out[idx, y]
107
  return s
108
 
109
  class StyleEncoder(nn.Module):
 
170
 
171
  # FUNCIÓN PARA CARGAR EL MODELO
172
  def load_pretrained_model(ckpt_path, img_size=256, style_dim=64, num_domains=3, device='cpu'):
173
+ num_domains_mappin = 3
174
+ latent_dim_for_mapping = 13
175
  G = Generator(img_size, style_dim).to(device)
176
  M = MappingNetwork(latent_dim_for_mapping, style_dim, num_domains).to(device)
177
  S = StyleEncoder(img_size, style_dim, num_domains).to(device)
 
228
  checkpoint_path = 'iter/12500_nets_ema.ckpt'
229
  img_size = 128
230
  style_dim = 64
231
+ num_domains = 3
232
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
233
 
234
  try: