tincri commited on
Commit
ada74dc
·
1 Parent(s): 56710ee

Fix #18 app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -31
app.py CHANGED
@@ -290,46 +290,41 @@ def main(args, checkpoint_path, source_image, reference_image, target_domain_ind
290
 
291
  return generated_image
292
 
293
- def gradio_interface(checkpoint_path="iter/20500_nets_ema.ckpt", img_size=128, num_domains=3): # Agregamos los valores por defecto
294
- # Interfaz de Gradio
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  inputs = [
296
- gr.Image(label="Source Image", type="pil"), # Especificamos el tipo de imagen como PIL
297
- gr.Image(label="Reference Image", type="pil"),
298
- gr.Radio(choices=["BMW", "Corvette", "Mazda"], label="Target Domain", default="BMW")
299
  ]
300
- outputs = gr.Image(label="Generated Image")
301
-
302
- def process_images(source_image, reference_image, target_domain):
303
- # Mapear el dominio seleccionado a un índice
304
- domain_index = {"BMW": 0, "Corvette": 1, "Mazda": 2}[target_domain]
305
-
306
- # Definir los argumentos (ajustados para la inferencia)
307
- args = SimpleNamespace(
308
- img_size=img_size, # Asegúrate de que esto coincida con el tamaño de imagen usado en el entrenamiento
309
- num_domains=num_domains, #args.num_domains, # Cambiado a 3 para que coincida con el checkpoint del MappingNetwork
310
- latent_dim=16, # Puedes ajustar esto si es necesario
311
- style_dim=64,
312
- num_workers=0, # Establecer en 0 para evitar problemas en algunos entornos
313
- seed=8365,
314
- )
315
- try:
316
- # Llamar a la función principal para realizar la inferencia
317
- generated_image = main(args, checkpoint_path, source_image, reference_image, domain_index)
318
- return generated_image
319
- except Exception as e:
320
- print(f"Error during processing: {e}")
321
- return None # Devolvemos None en caso de error
322
 
 
323
  iface = gr.Interface(
324
- fn=process_images,
325
  inputs=inputs,
326
  outputs=outputs,
327
- title="AutoStyleGAN Demo",
328
- description="Transfer the style of a reference car image to a source car image. Select the target car domain.",
329
  )
330
  return iface
331
 
332
  if __name__ == '__main__':
333
- # Lanzar la interfaz de Gradio
334
  iface = gradio_interface()
335
  iface.launch()
 
290
 
291
  return generated_image
292
 
293
+ def gradio_interface():
294
+ # Definir los argumentos (ajustados para la inferencia)
295
+ args = SimpleNamespace(
296
+ img_size=128,
297
+ num_domains=3,
298
+ latent_dim=16,
299
+ style_dim=64,
300
+ num_workers=0, # Establecer en 0 para evitar problemas en algunos entornos
301
+ seed=8365,
302
+ )
303
+
304
+ # Ruta al checkpoint
305
+ checkpoint_path = "iter/20500_nets_ema.ckpt" # Reemplaza con la ruta correcta a tu checkpoint
306
+
307
+ # Crear la interfaz de Gradio
308
  inputs = [
309
+ gr.Image(label="Source Image (Car to change style)"),
310
+ gr.Image(label="Reference Image (Style to transfer)"),
311
+ gr.Radio(choices=[0, 1, 2], label="Target Domain (0: BMW, 1: Corvette, 2: Mazda)", value=0), # Cambiado a value
312
  ]
313
+ outputs = gr.Image(label="Generated Image (Car with transferred style)")
314
+
315
+ title = "AutoStyleGAN: Car Style Transfer"
316
+ description = "Transfer the style of one car to another. Upload a source car image and a reference car image. Select the target domain (car brand)."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
 
318
+ # Crear la interfaz de Gradio
319
  iface = gr.Interface(
320
+ fn=lambda source_image, reference_image, target_domain_index: main(source_image, reference_image, target_domain_index, checkpoint_path, args),
321
  inputs=inputs,
322
  outputs=outputs,
323
+ title=title,
324
+ description=description,
325
  )
326
  return iface
327
 
328
  if __name__ == '__main__':
 
329
  iface = gradio_interface()
330
  iface.launch()