tincri commited on
Commit
27ffbc0
1 Parent(s): dccf04b

Fix #23 app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -11,6 +11,7 @@ from torchvision.utils import save_image
11
  import gradio as gr
12
  import numpy as np
13
  import io
 
14
 
15
  # Aseg煤rate de que las funciones necesarias est茅n definidas (si no lo est谩n ya)
16
  def resize(img, size):
@@ -265,11 +266,15 @@ class Solver(object):
265
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
266
  ])
267
  # Convertir a PIL image antes de la transformaci贸n
268
- source_image = Image.open(io.BytesIO(source_image)) # Use BytesIO
269
- reference_image = Image.open(io.BytesIO(reference_image))
 
 
 
 
270
 
271
- source_image = transform(source_image).unsqueeze(0).to(self.device)
272
- reference_image = transform(reference_image).unsqueeze(0).to(self.device)
273
 
274
  # Codificar el estilo de la imagen de referencia
275
  s_ref = self.S(reference_image, torch.tensor([0]).to(self.device))
@@ -307,7 +312,7 @@ def gradio_interface():
307
  )
308
 
309
  # Ruta al checkpoint
310
- checkpoint_path = "iter/20500_nets_ema.ckpt"
311
 
312
  # Crear la interfaz de Gradio
313
  inputs = [
@@ -330,4 +335,4 @@ def gradio_interface():
330
 
331
  if __name__ == '__main__':
332
  iface = gradio_interface()
333
- iface.launch(share=True)
 
11
  import gradio as gr
12
  import numpy as np
13
  import io
14
+ import tempfile # Importar tempfile
15
 
16
  # Aseg煤rate de que las funciones necesarias est茅n definidas (si no lo est谩n ya)
17
  def resize(img, size):
 
266
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
267
  ])
268
  # Convertir a PIL image antes de la transformaci贸n
269
+ with tempfile.NamedTemporaryFile(suffix=".jpg") as source_temp_file, \
270
+ tempfile.NamedTemporaryFile(suffix=".jpg") as reference_temp_file:
271
+ source_temp_file.write(source_image)
272
+ reference_temp_file.write(reference_image)
273
+ source_image_pil = Image.open(source_temp_file.name)
274
+ reference_image_pil = Image.open(reference_temp_file.name)
275
 
276
+ source_image = transform(source_image_pil).unsqueeze(0).to(self.device)
277
+ reference_image = transform(reference_image_pil).unsqueeze(0).to(self.device)
278
 
279
  # Codificar el estilo de la imagen de referencia
280
  s_ref = self.S(reference_image, torch.tensor([0]).to(self.device))
 
312
  )
313
 
314
  # Ruta al checkpoint
315
+ checkpoint_path = "iter/10500_nets_ema.ckpt"
316
 
317
  # Crear la interfaz de Gradio
318
  inputs = [
 
335
 
336
  if __name__ == '__main__':
337
  iface = gradio_interface()
338
+ iface.launch(share=True)