Spaces:
Sleeping
Sleeping
Fix #18 app.py
Browse files
app.py
CHANGED
@@ -290,46 +290,41 @@ def main(args, checkpoint_path, source_image, reference_image, target_domain_ind
|
|
290 |
|
291 |
return generated_image
|
292 |
|
293 |
-
def gradio_interface(
|
294 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
295 |
inputs = [
|
296 |
-
gr.Image(label="Source Image
|
297 |
-
gr.Image(label="Reference Image
|
298 |
-
gr.Radio(choices=[
|
299 |
]
|
300 |
-
outputs = gr.Image(label="Generated Image")
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
domain_index = {"BMW": 0, "Corvette": 1, "Mazda": 2}[target_domain]
|
305 |
-
|
306 |
-
# Definir los argumentos (ajustados para la inferencia)
|
307 |
-
args = SimpleNamespace(
|
308 |
-
img_size=img_size, # Asegúrate de que esto coincida con el tamaño de imagen usado en el entrenamiento
|
309 |
-
num_domains=num_domains, #args.num_domains, # Cambiado a 3 para que coincida con el checkpoint del MappingNetwork
|
310 |
-
latent_dim=16, # Puedes ajustar esto si es necesario
|
311 |
-
style_dim=64,
|
312 |
-
num_workers=0, # Establecer en 0 para evitar problemas en algunos entornos
|
313 |
-
seed=8365,
|
314 |
-
)
|
315 |
-
try:
|
316 |
-
# Llamar a la función principal para realizar la inferencia
|
317 |
-
generated_image = main(args, checkpoint_path, source_image, reference_image, domain_index)
|
318 |
-
return generated_image
|
319 |
-
except Exception as e:
|
320 |
-
print(f"Error during processing: {e}")
|
321 |
-
return None # Devolvemos None en caso de error
|
322 |
|
|
|
323 |
iface = gr.Interface(
|
324 |
-
fn=
|
325 |
inputs=inputs,
|
326 |
outputs=outputs,
|
327 |
-
title=
|
328 |
-
description=
|
329 |
)
|
330 |
return iface
|
331 |
|
332 |
if __name__ == '__main__':
|
333 |
-
# Lanzar la interfaz de Gradio
|
334 |
iface = gradio_interface()
|
335 |
iface.launch()
|
|
|
290 |
|
291 |
return generated_image
|
292 |
|
293 |
+
def gradio_interface():
|
294 |
+
# Definir los argumentos (ajustados para la inferencia)
|
295 |
+
args = SimpleNamespace(
|
296 |
+
img_size=128,
|
297 |
+
num_domains=3,
|
298 |
+
latent_dim=16,
|
299 |
+
style_dim=64,
|
300 |
+
num_workers=0, # Establecer en 0 para evitar problemas en algunos entornos
|
301 |
+
seed=8365,
|
302 |
+
)
|
303 |
+
|
304 |
+
# Ruta al checkpoint
|
305 |
+
checkpoint_path = "iter/20500_nets_ema.ckpt" # Reemplaza con la ruta correcta a tu checkpoint
|
306 |
+
|
307 |
+
# Crear la interfaz de Gradio
|
308 |
inputs = [
|
309 |
+
gr.Image(label="Source Image (Car to change style)"),
|
310 |
+
gr.Image(label="Reference Image (Style to transfer)"),
|
311 |
+
gr.Radio(choices=[0, 1, 2], label="Target Domain (0: BMW, 1: Corvette, 2: Mazda)", value=0), # Cambiado a value
|
312 |
]
|
313 |
+
outputs = gr.Image(label="Generated Image (Car with transferred style)")
|
314 |
+
|
315 |
+
title = "AutoStyleGAN: Car Style Transfer"
|
316 |
+
description = "Transfer the style of one car to another. Upload a source car image and a reference car image. Select the target domain (car brand)."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
317 |
|
318 |
+
# Crear la interfaz de Gradio
|
319 |
iface = gr.Interface(
|
320 |
+
fn=lambda source_image, reference_image, target_domain_index: main(source_image, reference_image, target_domain_index, checkpoint_path, args),
|
321 |
inputs=inputs,
|
322 |
outputs=outputs,
|
323 |
+
title=title,
|
324 |
+
description=description,
|
325 |
)
|
326 |
return iface
|
327 |
|
328 |
if __name__ == '__main__':
|
|
|
329 |
iface = gradio_interface()
|
330 |
iface.launch()
|