tincri commited on
Commit
4e3e87f
1 Parent(s): cf9984f

Fix #17 app.py

Browse files
Files changed (1) hide show
  1. app.py +224 -137
app.py CHANGED
@@ -1,15 +1,23 @@
1
- import gradio as gr
2
  import torch
3
- from torch import nn
 
 
4
  from torchvision import transforms
5
  from PIL import Image
6
- import numpy as np
7
  import os
 
8
  import random
9
- import torch.nn.functional as F
10
- from huggingface_hub import hf_hub_download
 
 
 
 
 
 
 
11
 
12
- # DEFINICI脫N DE BLOQUES DE RED
13
  class ResBlk(nn.Module):
14
  def __init__(self, dim_in, dim_out, normalize=False, downsample=False):
15
  super().__init__()
@@ -32,18 +40,28 @@ class ResBlk(nn.Module):
32
  skip = self.downsample_layer(skip)
33
  return (out + skip) / math.sqrt(2)
34
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  class AdainResBlk(nn.Module):
36
  def __init__(self, dim_in, dim_out, style_dim=64, w_hpf=1, upsample=False):
37
  super().__init__()
38
  self.upsample = upsample
39
  self.w_hpf = w_hpf
40
-
41
  self.norm1 = AdaIN(dim_in, style_dim)
42
  self.norm2 = AdaIN(dim_out, style_dim)
43
  self.actv = nn.LeakyReLU(0.2)
44
  self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
45
  self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
46
-
47
  if dim_in != dim_out:
48
  self.skip = nn.Conv2d(dim_in, dim_out, 1, 1, 0)
49
  else:
@@ -51,59 +69,74 @@ class AdainResBlk(nn.Module):
51
 
52
  def forward(self, x, s):
53
  x_orig = x
54
-
55
  if self.upsample:
56
  x = F.interpolate(x, scale_factor=2, mode='nearest')
57
  x_orig = F.interpolate(x_orig, scale_factor=2, mode='nearest')
58
-
59
  h = self.norm1(x, s)
60
  h = self.actv(h)
61
  h = self.conv1(h)
62
-
63
  h = self.norm2(h, s)
64
  h = self.actv(h)
65
  h = self.conv2(h)
66
-
67
  skip = self.skip(x_orig)
68
-
69
  out = (h + skip) / math.sqrt(2)
70
  return out
71
 
72
- class AdaIN(nn.Module):
73
- def __init__(self, num_features, style_dim):
74
- super(AdaIN, self).__init__()
75
- self.fc = nn.Linear(style_dim, num_features * 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  def forward(self, x, s):
78
- h = self.fc(s)
79
- gamma, beta = torch.chunk(h, chunks=2, dim=1)
80
- gamma = gamma.unsqueeze(2).unsqueeze(3)
81
- beta = beta.unsqueeze(2).unsqueeze(3)
82
- return (1 + gamma) * x + beta
83
 
84
  class MappingNetwork(nn.Module):
85
- def __init__(self, latent_dim, style_dim, num_domains):
86
- super().__init__()
87
- layers = []
88
- layers += [nn.Linear(latent_dim + num_domains, 512)]
89
- layers += [nn.ReLU()]
 
90
  for _ in range(3):
91
- layers += [nn.Linear(512, 512)]
92
- layers += [nn.ReLU()]
 
 
93
  self.shared = nn.Sequential(*layers)
94
  self.unshared = nn.ModuleList()
95
  for _ in range(num_domains):
96
- self.unshared += [nn.Linear(512, style_dim)]
97
 
98
  def forward(self, z, y):
99
- h = torch.cat([z, y], dim=1)
100
- h = self.shared(h)
101
  out = []
102
  for layer in self.unshared:
103
- out += [layer(h)]
104
- out = torch.stack(out, dim=1) # (batch, num_domains, style_dim)
105
- idx = torch.LongTensor(range(y.size(0))).unsqueeze(1).to(y.device)
106
- s = torch.gather(out, 1, idx.unsqueeze(2).expand(-1, -1, out.size(2))).squeeze(1)
107
  return s
108
 
109
  class StyleEncoder(nn.Module):
@@ -115,11 +148,10 @@ class StyleEncoder(nn.Module):
115
  repeat_num = int(np.log2(img_size)) - 2
116
  for _ in range(repeat_num):
117
  dim_out = min(dim_in*2, max_conv_dim)
118
- blocks += [ResBlk(dim_in, dim_out, downsample=True)]
119
  dim_in = dim_out
120
  blocks += [nn.LeakyReLU(0.2)]
121
  self.shared = nn.Sequential(*blocks)
122
-
123
  self.unshared = nn.ModuleList()
124
  for _ in range(num_domains):
125
  self.unshared += [nn.Linear(dim_in, style_dim)]
@@ -136,113 +168,168 @@ class StyleEncoder(nn.Module):
136
  s = out[idx, y]
137
  return s
138
 
139
- # DEFINICI脫N DEL GENERADOR
140
- class Generator(nn.Module):
141
- def __init__(self, img_size=256, style_dim=64, max_conv_dim=512):
142
- super().__init__()
143
- dim_in = 64
144
- blocks = []
145
- blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]
146
- repeat_num = int(np.log2(img_size)) - 4
147
- for _ in range(repeat_num):
148
- dim_out = min(dim_in*2, max_conv_dim)
149
- blocks += [ResBlk(dim_in, dim_out, normalize=True, downsample=True)]
150
- dim_in = dim_out
151
- self.encode = nn.Sequential(*blocks)
152
-
153
- self.decode = nn.ModuleList()
154
- for i in range(repeat_num):
155
- dim_out = dim_in // 2
156
- self.decode += [AdainResBlk(dim_in, dim_out, style_dim, upsample=True)]
157
- dim_in = dim_out
158
- self.to_rgb = nn.Sequential(
159
- nn.InstanceNorm2d(dim_in, affine=True),
160
- nn.ReLU(inplace=True),
161
- nn.Conv2d(dim_in, 3, 1, 1, 0)
162
- )
163
-
164
- def forward(self, x, s):
165
- x = self.encode(x)
166
- for block in self.decode:
167
- x = block(x, s)
168
- out = self.to_rgb(x)
169
- return out
170
-
171
- # FUNCI脫N PARA CARGAR EL MODELO
172
- def load_pretrained_model(ckpt_path, img_size=256, style_dim=64, num_domains=3, device='cpu'):
173
- num_domains_mappin = 3
174
- latent_dim_for_mapping = 13
175
- G = Generator(img_size, style_dim).to(device)
176
- M = MappingNetwork(latent_dim_for_mapping, style_dim, num_domains_mappin).to(device)
177
- S = StyleEncoder(img_size, style_dim, num_domains).to(device)
178
- checkpoint = torch.load(ckpt_path, map_location=device)
179
- G.load_state_dict(checkpoint['generator'])
180
- M.load_state_dict(checkpoint['mapping_network'])
181
- S.load_state_dict(checkpoint['style_encoder'])
182
- G.eval()
183
- S.eval()
184
- return G, S
185
-
186
- # FUNCI脫N PARA COMBINAR ESTILOS
187
- def combine_styles(source_image, reference_image, generator, style_encoder, target_domain_idx, device='cpu'):
188
  transform = transforms.Compose([
189
- transforms.Resize((256, 256)), # Ajustar al tama帽o de entrada de tu modelo
 
190
  transforms.ToTensor(),
191
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
192
  ])
 
 
 
193
 
194
- source_img = transform(source_image).unsqueeze(0).to(device)
195
- reference_img = transform(reference_image).unsqueeze(0).to(device)
196
- target_domain = torch.tensor([target_domain_idx]).unsqueeze(0).to(device) # Crear un tensor para el dominio objetivo
197
-
198
- with torch.no_grad():
199
- style_ref = style_encoder(reference_img, target_domain) # Usar el mismo 铆ndice de dominio que la referencia
200
- generated_image = generator(source_img, style_ref)
201
- generated_image = (generated_image + 1) / 2.0 # Desnormalizar a [0, 1]
202
- generated_image = generated_image.squeeze(0).cpu().permute(1, 2, 0).numpy()
203
- generated_image = (generated_image * 255).astype(np.uint8)
204
- return Image.fromarray(generated_image)
205
-
206
- # CONFIGURACI脫N DE GRADIO
207
- def create_interface(generator, style_encoder, domain_names, device='cpu'):
208
- def predict(source_img, ref_img, target_domain):
209
- target_domain_idx = domain_names.index(target_domain)
210
- return combine_styles(source_img, ref_img, generator, style_encoder, target_domain_idx, device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
  iface = gr.Interface(
213
- fn=predict,
214
- inputs=[
215
- gr.Image(label="Imagen Fuente"),
216
- gr.Image(label="Imagen de Referencia"),
217
- gr.Dropdown(choices=domain_names, label="Dominio de Referencia (para el estilo)"),
218
- ],
219
- outputs=gr.Image(label="Imagen Generada"),
220
- title="AutoStyleGAN - Transferencia de Estilo de Carros",
221
- description="Selecciona una imagen de carro fuente y una imagen de carro de referencia para transferir el estilo de la referencia a la fuente."
222
  )
223
  return iface
224
 
225
-
226
  if __name__ == '__main__':
227
- #CARGAR EL MODELO ENTRENADO
228
- checkpoint_path = 'iter/12500_nets_ema.ckpt'
229
- img_size = 128
230
- style_dim = 64
231
- num_domains = 2
232
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
233
-
234
- try:
235
- generator, style_encoder = load_pretrained_model(checkpoint_path, img_size, style_dim, num_domains, device)
236
- print("Modelo cargado exitosamente.")
237
-
238
- # DEFINIR LOS NOMBRES DE LOS DOMINIOS
239
- domain_names = ["BMW", "Corvette", "Mazda"]
240
-
241
- # CREAR E LANZAR LA INTERFAZ DE GRADIO
242
- iface = create_interface(generator, style_encoder, domain_names, device)
243
- iface.launch(share=True)
244
-
245
- except FileNotFoundError:
246
- print(f"Error: No se encontr贸 el archivo de checkpoint en '{checkpoint_path}'. Aseg煤rate de proporcionar la ruta correcta.")
247
- except Exception as e:
248
- print(f"Ocurri贸 un error al cargar el modelo: {e}")
 
 
1
  import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.utils.data import Dataset, DataLoader
5
  from torchvision import transforms
6
  from PIL import Image
 
7
  import os
8
+ from types import SimpleNamespace
9
  import random
10
+ from torchvision.utils import save_image
11
+ import gradio as gr # Importamos Gradio
12
+
13
+ # Aseg煤rate de que las funciones necesarias est茅n definidas (si no lo est谩n ya)
14
+ def resize(img, size):
15
+ return F.interpolate(img, size=size, mode='bilinear', align_corners=False)
16
+
17
+ def denormalize(x):
18
+ return (x + 1) / 2
19
 
20
+ # Definici贸n de las clases de los modelos (Generator, StyleEncoder, MappingNetwork, ResBlk, AdaIN, AdainResBlk)
21
  class ResBlk(nn.Module):
22
  def __init__(self, dim_in, dim_out, normalize=False, downsample=False):
23
  super().__init__()
 
40
  skip = self.downsample_layer(skip)
41
  return (out + skip) / math.sqrt(2)
42
 
43
+ class AdaIN(nn.Module):
44
+ def __init__(self, num_features, style_dim):
45
+ super(AdaIN, self).__init__()
46
+ self.fc = nn.Linear(style_dim, num_features * 2)
47
+
48
+ def forward(self, x, s):
49
+ h = self.fc(s)
50
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
51
+ gamma = gamma.unsqueeze(2).unsqueeze(3)
52
+ beta = beta.unsqueeze(2).unsqueeze(3)
53
+ return (1 + gamma) * x + beta
54
+
55
  class AdainResBlk(nn.Module):
56
  def __init__(self, dim_in, dim_out, style_dim=64, w_hpf=1, upsample=False):
57
  super().__init__()
58
  self.upsample = upsample
59
  self.w_hpf = w_hpf
 
60
  self.norm1 = AdaIN(dim_in, style_dim)
61
  self.norm2 = AdaIN(dim_out, style_dim)
62
  self.actv = nn.LeakyReLU(0.2)
63
  self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
64
  self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
 
65
  if dim_in != dim_out:
66
  self.skip = nn.Conv2d(dim_in, dim_out, 1, 1, 0)
67
  else:
 
69
 
70
  def forward(self, x, s):
71
  x_orig = x
 
72
  if self.upsample:
73
  x = F.interpolate(x, scale_factor=2, mode='nearest')
74
  x_orig = F.interpolate(x_orig, scale_factor=2, mode='nearest')
 
75
  h = self.norm1(x, s)
76
  h = self.actv(h)
77
  h = self.conv1(h)
 
78
  h = self.norm2(h, s)
79
  h = self.actv(h)
80
  h = self.conv2(h)
 
81
  skip = self.skip(x_orig)
 
82
  out = (h + skip) / math.sqrt(2)
83
  return out
84
 
85
+ class Generator(nn.Module):
86
+ def __init__(self, img_size=256, style_dim=64, max_conv_dim=512):
87
+ super().__init__()
88
+ dim_in = 64
89
+ blocks = []
90
+ blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]
91
+ repeat_num = int(np.log2(img_size)) - 4
92
+ for _ in range(repeat_num):
93
+ dim_out = min(dim_in*2, max_conv_dim)
94
+ blocks += [ResBlk(dim_in, dim_out, normalize=True, downsample=True)]
95
+ dim_in = dim_out
96
+ self.encode = nn.Sequential(*blocks)
97
+ self.decode = nn.ModuleList()
98
+ for _ in range(repeat_num):
99
+ dim_out = dim_in // 2
100
+ self.decode += [AdainResBlk(dim_in, dim_out, style_dim, upsample=True)]
101
+ dim_in = dim_out
102
+ self.to_rgb = nn.Sequential(
103
+ nn.InstanceNorm2d(dim_in, affine=True),
104
+ nn.ReLU(inplace=True),
105
+ nn.Conv2d(dim_in, 3, 1, 1, 0)
106
+ )
107
 
108
  def forward(self, x, s):
109
+ x = self.encode(x)
110
+ for block in self.decode:
111
+ x = block(x, s)
112
+ out = self.to_rgb(x)
113
+ return out
114
 
115
  class MappingNetwork(nn.Module):
116
+ def __init__(self, latent_dim=16, style_dim=64, num_domains=2, hidden_dim=512):
117
+ super(MappingNetwork, self).__init__()
118
+ layers = [
119
+ nn.Linear(latent_dim, hidden_dim),
120
+ nn.ReLU()
121
+ ]
122
  for _ in range(3):
123
+ layers += [
124
+ nn.Linear(hidden_dim, hidden_dim),
125
+ nn.ReLU()
126
+ ]
127
  self.shared = nn.Sequential(*layers)
128
  self.unshared = nn.ModuleList()
129
  for _ in range(num_domains):
130
+ self.unshared.append(nn.Linear(hidden_dim, style_dim))
131
 
132
  def forward(self, z, y):
133
+ h = self.shared(z)
 
134
  out = []
135
  for layer in self.unshared:
136
+ out.append(layer(h))
137
+ out = torch.stack(out, dim=1)
138
+ idx = torch.arange(y.size(0)).to(y.device)
139
+ s = out[idx, y]
140
  return s
141
 
142
  class StyleEncoder(nn.Module):
 
148
  repeat_num = int(np.log2(img_size)) - 2
149
  for _ in range(repeat_num):
150
  dim_out = min(dim_in*2, max_conv_dim)
151
+ blocks += [ResBlk(dim_in, dim_out, normalize=True, downsample=True)]
152
  dim_in = dim_out
153
  blocks += [nn.LeakyReLU(0.2)]
154
  self.shared = nn.Sequential(*blocks)
 
155
  self.unshared = nn.ModuleList()
156
  for _ in range(num_domains):
157
  self.unshared += [nn.Linear(dim_in, style_dim)]
 
168
  s = out[idx, y]
169
  return s
170
 
171
+ # Clase para cargar imagenes
172
+ class ImageFolder(Dataset):
173
+ def __init__(self, root, transform, mode, which='source'):
174
+ self.transform = transform
175
+ self.paths = []
176
+ domains = sorted(os.listdir(root))
177
+ for domain in domains:
178
+ if os.path.isdir(os.path.join(root, domain)):
179
+ files = os.listdir(os.path.join(root, domain))
180
+ files = [os.path.join(root, domain, f) for f in files]
181
+ self.paths += [(f, domains.index(domain)) for f in files]
182
+ if mode == 'train' and which == 'reference':
183
+ random.shuffle(self.paths)
184
+
185
+ def __getitem__(self, index):
186
+ path, label = self.paths[index]
187
+ img = Image.open(path).convert('RGB')
188
+ return self.transform(img), label
189
+
190
+ def __len__(self):
191
+ return len(self.paths)
192
+
193
+ # Funciones para obtener los data loaders
194
+ def get_transform(img_size, mode='train', prob=0.5):
195
+ transform = []
196
+ transform.append(transforms.Resize((img_size, img_size)))
197
+ if mode == 'train':
198
+ transform.append(transforms.RandomHorizontalFlip())
199
+ transform.append(transforms.RandomApply([
200
+ transforms.RandomResizedCrop(img_size, scale=(0.8, 1.0))
201
+ ], p=prob))
202
+ transform.append(transforms.ToTensor())
203
+ transform.append(transforms.Normalize(mean=[0.5, 0.5, 0.5],
204
+ std=[0.5, 0.5, 0.5]))
205
+ return transforms.Compose(transform)
206
+
207
+ def get_train_loader(root, which='source', img_size=256, batch_size=8, prob=0.5, num_workers=4):
 
 
 
 
 
 
 
 
 
 
 
 
208
  transform = transforms.Compose([
209
+ transforms.Resize((img_size, img_size)),
210
+ transforms.RandomHorizontalFlip(p=prob),
211
  transforms.ToTensor(),
212
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
213
  ])
214
+ dataset = ImageFolder(root=root, transform=transform, mode=which)
215
+ loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
216
+ return loader
217
 
218
+ def get_test_loader(root, img_size=256, batch_size=8, shuffle=False, num_workers=4, mode='reference'):
219
+ transform = transforms.Compose([
220
+ transforms.Resize((img_size, img_size)),
221
+ transforms.ToTensor(),
222
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
223
+ ])
224
+ dataset = ImageFolder(root=root, transform=transform, mode=mode)
225
+ loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=False)
226
+ return loader
227
+
228
+ # Clase Solver (adaptada para la inferencia)
229
+ class Solver(object):
230
+ def __init__(self, args):
231
+ self.args = args
232
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
233
+
234
+ # Definir los modelos
235
+ self.G = Generator(args.img_size, args.style_dim).to(self.device)
236
+ self.M = MappingNetwork(args.latent_dim, args.style_dim, args.num_domains).to(self.device)
237
+ self.S = StyleEncoder(args.img_size, args.style_dim, args.num_domains).to(self.device)
238
+
239
+ def load_checkpoint(self, checkpoint_path):
240
+ try:
241
+ checkpoint = torch.load(checkpoint_path, map_location=self.device)
242
+ self.G.load_state_dict(checkpoint['generator'])
243
+ self.M.load_state_dict(checkpoint['mapping_network'])
244
+ self.S.load_state_dict(checkpoint['style_encoder'])
245
+ print(f"Checkpoint cargado exitosamente desde {checkpoint_path}.")
246
+ except FileNotFoundError:
247
+ print(f"Error: No se encontr贸 el checkpoint en {checkpoint_path}.")
248
+ raise FileNotFoundError(f"No se encontr贸 el checkpoint en {checkpoint_path}")
249
+ except Exception as e:
250
+ print(f"Error al cargar el checkpoint: {e}.")
251
+ raise Exception(f"Error al cargar el checkpoint: {e}")
252
+
253
+ def transfer_style(self, source_image, reference_image, target_domain_index):
254
+ # Aseg煤rate de que los modelos est茅n en modo de evaluaci贸n
255
+ self.G.eval()
256
+ self.S.eval()
257
+
258
+ with torch.no_grad():
259
+ # Preprocesar las im谩genes de entrada
260
+ transform = transforms.Compose([
261
+ transforms.Resize((self.args.img_size, self.args.img_size)),
262
+ transforms.ToTensor(),
263
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
264
+ ])
265
+ source_image = transform(source_image).unsqueeze(0).to(self.device)
266
+ reference_image = transform(reference_image).unsqueeze(0).to(self.device)
267
+
268
+ # Crear el tensor de dominio objetivo
269
+ target_domain = torch.tensor([target_domain_index]).to(self.device)
270
+
271
+ # Codificar el estilo de la imagen de referencia
272
+ s_ref = self.S(reference_image, target_domain)
273
+
274
+ # Generar la imagen con el estilo transferido
275
+ generated_image = self.G(source_image, s_ref)
276
+
277
+ # Denormalizar la imagen para mostrarla o guardarla
278
+ generated_image = denormalize(generated_image.squeeze(0)).cpu()
279
+ return generated_image
280
+
281
+ # Funci贸n principal para la inferencia
282
+ def main(args, checkpoint_path, source_image, reference_image, target_domain_index): # Cambiamos los paths por las im谩genes
283
+ # Crear el solver
284
+ solver = Solver(args)
285
+ # Cargar el checkpoint
286
+ solver.load_checkpoint(checkpoint_path)
287
+
288
+ # Realizar la transferencia de estilo
289
+ generated_image = solver.transfer_style(source_image, reference_image, target_domain_index)
290
+
291
+ return generated_image
292
+
293
+ def gradio_interface(checkpoint_path="iter/20500_nets_ema.ckpt", img_size=128, num_domains=3): # Agregamos los valores por defecto
294
+ # Interfaz de Gradio
295
+ inputs = [
296
+ gr.Image(label="Source Image", type="pil"), # Especificamos el tipo de imagen como PIL
297
+ gr.Image(label="Reference Image", type="pil"),
298
+ gr.Radio(choices=["BMW", "Corvette", "Mazda"], label="Target Domain", default="BMW")
299
+ ]
300
+ outputs = gr.Image(label="Generated Image")
301
+
302
+ def process_images(source_image, reference_image, target_domain):
303
+ # Mapear el dominio seleccionado a un 铆ndice
304
+ domain_index = {"BMW": 0, "Corvette": 1, "Mazda": 2}[target_domain]
305
+
306
+ # Definir los argumentos (ajustados para la inferencia)
307
+ args = SimpleNamespace(
308
+ img_size=img_size, # Aseg煤rate de que esto coincida con el tama帽o de imagen usado en el entrenamiento
309
+ num_domains=num_domains, #args.num_domains, # Cambiado a 3 para que coincida con el checkpoint del MappingNetwork
310
+ latent_dim=16, # Puedes ajustar esto si es necesario
311
+ style_dim=64,
312
+ num_workers=0, # Establecer en 0 para evitar problemas en algunos entornos
313
+ seed=8365,
314
+ )
315
+ try:
316
+ # Llamar a la funci贸n principal para realizar la inferencia
317
+ generated_image = main(args, checkpoint_path, source_image, reference_image, domain_index)
318
+ return generated_image
319
+ except Exception as e:
320
+ print(f"Error during processing: {e}")
321
+ return None # Devolvemos None en caso de error
322
 
323
  iface = gr.Interface(
324
+ fn=process_images,
325
+ inputs=inputs,
326
+ outputs=outputs,
327
+ title="AutoStyleGAN Demo",
328
+ description="Transfer the style of a reference car image to a source car image. Select the target car domain.",
 
 
 
 
329
  )
330
  return iface
331
 
 
332
  if __name__ == '__main__':
333
+ # Lanzar la interfaz de Gradio
334
+ iface = gradio_interface()
335
+ iface.launch()