tincri commited on
Commit
dccf04b
·
1 Parent(s): 7bcd511
Files changed (1) hide show
  1. app.py +15 -19
app.py CHANGED
@@ -10,6 +10,7 @@ import random
10
  from torchvision.utils import save_image
11
  import gradio as gr
12
  import numpy as np
 
13
 
14
  # Asegúrate de que las funciones necesarias estén definidas (si no lo están ya)
15
  def resize(img, size):
@@ -251,7 +252,7 @@ class Solver(object):
251
  print(f"Error al cargar el checkpoint: {e}.")
252
  raise Exception(f"Error al cargar el checkpoint: {e}")
253
 
254
- def transfer_style(self, source_image, reference_image): # Eliminado target_domain_index
255
  # Asegúrate de que los modelos estén en modo de evaluación
256
  self.G.eval()
257
  self.S.eval()
@@ -264,17 +265,14 @@ class Solver(object):
264
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
265
  ])
266
  # Convertir a PIL image antes de la transformación
267
- source_image = Image.fromarray(source_image)
268
- reference_image = Image.fromarray(reference_image)
269
 
270
  source_image = transform(source_image).unsqueeze(0).to(self.device)
271
  reference_image = transform(reference_image).unsqueeze(0).to(self.device)
272
 
273
- # Crear el tensor de dominio objetivo
274
- # target_domain = torch.tensor([target_domain_index]).to(self.device) # Eliminado
275
-
276
  # Codificar el estilo de la imagen de referencia
277
- s_ref = self.S(reference_image, torch.tensor([0]).to(self.device)) # Simplificado
278
 
279
  # Generar la imagen con el estilo transferido
280
  generated_image = self.G(source_image, s_ref)
@@ -284,7 +282,7 @@ class Solver(object):
284
  return generated_image
285
 
286
  # Función principal para la inferencia
287
- def main(source_image, reference_image, checkpoint_path, args): # Eliminado target_domain_index
288
  if source_image is None or reference_image is None:
289
  raise gr.Error("Por favor, proporciona ambas imágenes (fuente y referencia).")
290
 
@@ -294,37 +292,35 @@ def main(source_image, reference_image, checkpoint_path, args): # Eliminado targ
294
  solver.load_checkpoint(checkpoint_path)
295
 
296
  # Realizar la transferencia de estilo
297
- generated_image = solver.transfer_style(source_image, reference_image) # Eliminado target_domain_index
298
  return generated_image
299
 
300
  def gradio_interface():
301
  # Definir los argumentos (ajustados para la inferencia)
302
  args = SimpleNamespace(
303
- img_size=128, # Asegúrate de que esto coincida con el tamaño de imagen usado en el entrenamiento
304
- num_domains=3, # Cambiado a 3 para que coincida con el checkpoint del MappingNetwork
305
- latent_dim=16, # Puedes ajustar esto si es necesario
306
  style_dim=64,
307
- num_workers=0, # Establecer en 0 para evitar problemas en algunos entornos
308
  seed=8365,
309
  )
310
 
311
  # Ruta al checkpoint
312
- checkpoint_path = "iter/10500_nets_ema.ckpt" # Reemplaza con la ruta correcta a tu checkpoint
313
 
314
  # Crear la interfaz de Gradio
315
  inputs = [
316
  gr.Image(label="Source Image (Car to change style)"),
317
  gr.Image(label="Reference Image (Style to transfer)"),
318
- # gr.Radio(choices=[0, 1, 2], label="Target Domain (0: BMW, 1: Corvette, 2: Mazda)", value=0), # Eliminado
319
  ]
320
  outputs = gr.Image(label="Generated Image (Car with transferred style)")
321
 
322
  title = "AutoStyleGAN: Car Style Transfer"
323
- description = "Transfer the style of one car to another. Upload a source car image and a reference car image." # Modificado
324
 
325
- # Crear la interfaz de Gradio
326
  iface = gr.Interface(
327
- fn=lambda source_image, reference_image: main(source_image, reference_image, checkpoint_path, args), # Eliminado target_domain_index
328
  inputs=inputs,
329
  outputs=outputs,
330
  title=title,
@@ -334,4 +330,4 @@ def gradio_interface():
334
 
335
  if __name__ == '__main__':
336
  iface = gradio_interface()
337
- iface.launch(share=True)
 
10
  from torchvision.utils import save_image
11
  import gradio as gr
12
  import numpy as np
13
+ import io
14
 
15
  # Asegúrate de que las funciones necesarias estén definidas (si no lo están ya)
16
  def resize(img, size):
 
252
  print(f"Error al cargar el checkpoint: {e}.")
253
  raise Exception(f"Error al cargar el checkpoint: {e}")
254
 
255
+ def transfer_style(self, source_image, reference_image):
256
  # Asegúrate de que los modelos estén en modo de evaluación
257
  self.G.eval()
258
  self.S.eval()
 
265
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
266
  ])
267
  # Convertir a PIL image antes de la transformación
268
+ source_image = Image.open(io.BytesIO(source_image)) # Use BytesIO
269
+ reference_image = Image.open(io.BytesIO(reference_image))
270
 
271
  source_image = transform(source_image).unsqueeze(0).to(self.device)
272
  reference_image = transform(reference_image).unsqueeze(0).to(self.device)
273
 
 
 
 
274
  # Codificar el estilo de la imagen de referencia
275
+ s_ref = self.S(reference_image, torch.tensor([0]).to(self.device))
276
 
277
  # Generar la imagen con el estilo transferido
278
  generated_image = self.G(source_image, s_ref)
 
282
  return generated_image
283
 
284
  # Función principal para la inferencia
285
+ def main(source_image, reference_image, checkpoint_path, args):
286
  if source_image is None or reference_image is None:
287
  raise gr.Error("Por favor, proporciona ambas imágenes (fuente y referencia).")
288
 
 
292
  solver.load_checkpoint(checkpoint_path)
293
 
294
  # Realizar la transferencia de estilo
295
+ generated_image = solver.transfer_style(source_image, reference_image)
296
  return generated_image
297
 
298
  def gradio_interface():
299
  # Definir los argumentos (ajustados para la inferencia)
300
  args = SimpleNamespace(
301
+ img_size=128,
302
+ num_domains=3,
303
+ latent_dim=16,
304
  style_dim=64,
305
+ num_workers=0,
306
  seed=8365,
307
  )
308
 
309
  # Ruta al checkpoint
310
+ checkpoint_path = "iter/20500_nets_ema.ckpt"
311
 
312
  # Crear la interfaz de Gradio
313
  inputs = [
314
  gr.Image(label="Source Image (Car to change style)"),
315
  gr.Image(label="Reference Image (Style to transfer)"),
 
316
  ]
317
  outputs = gr.Image(label="Generated Image (Car with transferred style)")
318
 
319
  title = "AutoStyleGAN: Car Style Transfer"
320
+ description = "Transfer the style of one car to another. Upload a source car image and a reference car image."
321
 
 
322
  iface = gr.Interface(
323
+ fn=lambda source_image, reference_image: main(source_image, reference_image, checkpoint_path, args),
324
  inputs=inputs,
325
  outputs=outputs,
326
  title=title,
 
330
 
331
  if __name__ == '__main__':
332
  iface = gradio_interface()
333
+ iface.launch(share=True)