tincri commited on
Commit
17e6b36
·
1 Parent(s): ada74dc

Fix #19 app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -8,7 +8,7 @@ import os
8
  from types import SimpleNamespace
9
  import random
10
  from torchvision.utils import save_image
11
- import gradio as gr # Importamos Gradio
12
 
13
  # Asegúrate de que las funciones necesarias estén definidas (si no lo están ya)
14
  def resize(img, size):
@@ -274,12 +274,15 @@ class Solver(object):
274
  # Generar la imagen con el estilo transferido
275
  generated_image = self.G(source_image, s_ref)
276
 
277
- # Denormalizar la imagen para mostrarla o guardarla
278
  generated_image = denormalize(generated_image.squeeze(0)).cpu()
279
  return generated_image
280
 
281
  # Función principal para la inferencia
282
- def main(args, checkpoint_path, source_image, reference_image, target_domain_index): # Cambiamos los paths por las imágenes
 
 
 
283
  # Crear el solver
284
  solver = Solver(args)
285
  # Cargar el checkpoint
@@ -287,7 +290,6 @@ def main(args, checkpoint_path, source_image, reference_image, target_domain_ind
287
 
288
  # Realizar la transferencia de estilo
289
  generated_image = solver.transfer_style(source_image, reference_image, target_domain_index)
290
-
291
  return generated_image
292
 
293
  def gradio_interface():
@@ -327,4 +329,4 @@ def gradio_interface():
327
 
328
  if __name__ == '__main__':
329
  iface = gradio_interface()
330
- iface.launch()
 
8
  from types import SimpleNamespace
9
  import random
10
  from torchvision.utils import save_image
11
+ import gradio as gr
12
 
13
  # Asegúrate de que las funciones necesarias estén definidas (si no lo están ya)
14
  def resize(img, size):
 
274
  # Generar la imagen con el estilo transferido
275
  generated_image = self.G(source_image, s_ref)
276
 
277
+ # Denormalizar la imagen para mostrarla en la interfaz
278
  generated_image = denormalize(generated_image.squeeze(0)).cpu()
279
  return generated_image
280
 
281
  # Función principal para la inferencia
282
+ def main(source_image, reference_image, target_domain_index, checkpoint_path, args):
283
+ if source_image is None or reference_image is None:
284
+ raise gr.Error("Por favor, proporciona ambas imágenes (fuente y referencia).")
285
+
286
  # Crear el solver
287
  solver = Solver(args)
288
  # Cargar el checkpoint
 
290
 
291
  # Realizar la transferencia de estilo
292
  generated_image = solver.transfer_style(source_image, reference_image, target_domain_index)
 
293
  return generated_image
294
 
295
  def gradio_interface():
 
329
 
330
  if __name__ == '__main__':
331
  iface = gradio_interface()
332
+ iface.launch(share=True)