tincri commited on
Commit
7bcd511
1 Parent(s): 1f14b39
Files changed (1) hide show
  1. app.py +17 -13
app.py CHANGED
@@ -251,7 +251,7 @@ class Solver(object):
251
  print(f"Error al cargar el checkpoint: {e}.")
252
  raise Exception(f"Error al cargar el checkpoint: {e}")
253
 
254
- def transfer_style(self, source_image, reference_image, target_domain_index):
255
  # Aseg煤rate de que los modelos est茅n en modo de evaluaci贸n
256
  self.G.eval()
257
  self.S.eval()
@@ -263,14 +263,18 @@ class Solver(object):
263
  transforms.ToTensor(),
264
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
265
  ])
 
 
 
 
266
  source_image = transform(source_image).unsqueeze(0).to(self.device)
267
  reference_image = transform(reference_image).unsqueeze(0).to(self.device)
268
 
269
  # Crear el tensor de dominio objetivo
270
- target_domain = torch.tensor([target_domain_index]).to(self.device)
271
 
272
  # Codificar el estilo de la imagen de referencia
273
- s_ref = self.S(reference_image, target_domain)
274
 
275
  # Generar la imagen con el estilo transferido
276
  generated_image = self.G(source_image, s_ref)
@@ -280,7 +284,7 @@ class Solver(object):
280
  return generated_image
281
 
282
  # Funci贸n principal para la inferencia
283
- def main(source_image, reference_image, target_domain_index, checkpoint_path, args):
284
  if source_image is None or reference_image is None:
285
  raise gr.Error("Por favor, proporciona ambas im谩genes (fuente y referencia).")
286
 
@@ -290,37 +294,37 @@ def main(source_image, reference_image, target_domain_index, checkpoint_path, ar
290
  solver.load_checkpoint(checkpoint_path)
291
 
292
  # Realizar la transferencia de estilo
293
- generated_image = solver.transfer_style(source_image, reference_image, target_domain_index)
294
  return generated_image
295
 
296
  def gradio_interface():
297
  # Definir los argumentos (ajustados para la inferencia)
298
  args = SimpleNamespace(
299
- img_size=128,
300
- num_domains=3,
301
- latent_dim=16,
302
  style_dim=64,
303
  num_workers=0, # Establecer en 0 para evitar problemas en algunos entornos
304
  seed=8365,
305
  )
306
 
307
  # Ruta al checkpoint
308
- checkpoint_path = "iter/20500_nets_ema.ckpt" # Reemplaza con la ruta correcta a tu checkpoint
309
 
310
  # Crear la interfaz de Gradio
311
  inputs = [
312
  gr.Image(label="Source Image (Car to change style)"),
313
  gr.Image(label="Reference Image (Style to transfer)"),
314
- gr.Radio(choices=[0, 1, 2], label="Target Domain (0: BMW, 1: Corvette, 2: Mazda)", value=0), # Cambiado a value
315
  ]
316
  outputs = gr.Image(label="Generated Image (Car with transferred style)")
317
 
318
  title = "AutoStyleGAN: Car Style Transfer"
319
- description = "Transfer the style of one car to another. Upload a source car image and a reference car image. Select the target domain (car brand)."
320
 
321
  # Crear la interfaz de Gradio
322
  iface = gr.Interface(
323
- fn=lambda source_image, reference_image, target_domain_index: main(source_image, reference_image, target_domain_index, checkpoint_path, args),
324
  inputs=inputs,
325
  outputs=outputs,
326
  title=title,
@@ -330,4 +334,4 @@ def gradio_interface():
330
 
331
  if __name__ == '__main__':
332
  iface = gradio_interface()
333
- iface.launch(share=True)
 
251
  print(f"Error al cargar el checkpoint: {e}.")
252
  raise Exception(f"Error al cargar el checkpoint: {e}")
253
 
254
+ def transfer_style(self, source_image, reference_image): # Eliminado target_domain_index
255
  # Aseg煤rate de que los modelos est茅n en modo de evaluaci贸n
256
  self.G.eval()
257
  self.S.eval()
 
263
  transforms.ToTensor(),
264
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
265
  ])
266
+ # Convertir a PIL image antes de la transformaci贸n
267
+ source_image = Image.fromarray(source_image)
268
+ reference_image = Image.fromarray(reference_image)
269
+
270
  source_image = transform(source_image).unsqueeze(0).to(self.device)
271
  reference_image = transform(reference_image).unsqueeze(0).to(self.device)
272
 
273
  # Crear el tensor de dominio objetivo
274
+ # target_domain = torch.tensor([target_domain_index]).to(self.device) # Eliminado
275
 
276
  # Codificar el estilo de la imagen de referencia
277
+ s_ref = self.S(reference_image, torch.tensor([0]).to(self.device)) # Simplificado
278
 
279
  # Generar la imagen con el estilo transferido
280
  generated_image = self.G(source_image, s_ref)
 
284
  return generated_image
285
 
286
  # Funci贸n principal para la inferencia
287
+ def main(source_image, reference_image, checkpoint_path, args): # Eliminado target_domain_index
288
  if source_image is None or reference_image is None:
289
  raise gr.Error("Por favor, proporciona ambas im谩genes (fuente y referencia).")
290
 
 
294
  solver.load_checkpoint(checkpoint_path)
295
 
296
  # Realizar la transferencia de estilo
297
+ generated_image = solver.transfer_style(source_image, reference_image) # Eliminado target_domain_index
298
  return generated_image
299
 
300
  def gradio_interface():
301
  # Definir los argumentos (ajustados para la inferencia)
302
  args = SimpleNamespace(
303
+ img_size=128, # Aseg煤rate de que esto coincida con el tama帽o de imagen usado en el entrenamiento
304
+ num_domains=3, # Cambiado a 3 para que coincida con el checkpoint del MappingNetwork
305
+ latent_dim=16, # Puedes ajustar esto si es necesario
306
  style_dim=64,
307
  num_workers=0, # Establecer en 0 para evitar problemas en algunos entornos
308
  seed=8365,
309
  )
310
 
311
  # Ruta al checkpoint
312
+ checkpoint_path = "iter/10500_nets_ema.ckpt" # Reemplaza con la ruta correcta a tu checkpoint
313
 
314
  # Crear la interfaz de Gradio
315
  inputs = [
316
  gr.Image(label="Source Image (Car to change style)"),
317
  gr.Image(label="Reference Image (Style to transfer)"),
318
+ # gr.Radio(choices=[0, 1, 2], label="Target Domain (0: BMW, 1: Corvette, 2: Mazda)", value=0), # Eliminado
319
  ]
320
  outputs = gr.Image(label="Generated Image (Car with transferred style)")
321
 
322
  title = "AutoStyleGAN: Car Style Transfer"
323
+ description = "Transfer the style of one car to another. Upload a source car image and a reference car image." # Modificado
324
 
325
  # Crear la interfaz de Gradio
326
  iface = gr.Interface(
327
+ fn=lambda source_image, reference_image: main(source_image, reference_image, checkpoint_path, args), # Eliminado target_domain_index
328
  inputs=inputs,
329
  outputs=outputs,
330
  title=title,
 
334
 
335
  if __name__ == '__main__':
336
  iface = gradio_interface()
337
+ iface.launch(share=True)