tincri committed on
Commit
fd00255
1 Parent(s): 4a42a42

app.py and dependencies

Files changed (2)
  1. app.py +239 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,239 @@
+ import gradio as gr
+ import torch
+ from torch import nn
+ from torchvision import transforms
+ from PIL import Image
+ import numpy as np
+ import os
+ import random
+ import torch.nn.functional as F
+ from huggingface_hub import hf_hub_download
+
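+ # Note: hf_hub_download is imported but unused below; it could fetch the checkpoint
+ # from the Hub instead of relying on a local file, e.g. (hypothetical repo id):
+ #   checkpoint_path = hf_hub_download(repo_id="<user>/<repo>", filename="10000_nets_ema.ckpt")
+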
+ # NETWORK BUILDING BLOCKS
+ class ResBlk(nn.Module):
+     def __init__(self, dim_in, dim_out, normalize=False, downsample=False):
+         super().__init__()
+         self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
+         self.norm1 = nn.InstanceNorm2d(dim_out, affine=True) if normalize else None
+         self.relu1 = nn.ReLU(inplace=True)
+         self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
+         self.norm2 = nn.InstanceNorm2d(dim_out, affine=True) if normalize else None
+         self.relu2 = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         if self.downsample:
+             self.avg_pool = nn.AvgPool2d(2)
+         # 1x1 conv on the skip path so channel counts match when dim_in != dim_out
+         # (without it, out + residual fails whenever the block changes width)
+         self.learned_sc = dim_in != dim_out
+         if self.learned_sc:
+             self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
+
+     def forward(self, x):
+         residual = x
+         if self.learned_sc:
+             residual = self.conv1x1(residual)
+         out = self.conv1(x)
+         if self.norm1:
+             out = self.norm1(out)
+         out = self.relu1(out)
+         out = self.conv2(out)
+         if self.norm2:
+             out = self.norm2(out)
+         out = self.relu2(out)
+         if self.downsample:
+             out = self.avg_pool(out)
+             residual = self.avg_pool(residual)
+         out = out + residual
+         return out
+
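+ # The skip path mirrors the main path: the 1x1 conv added above aligns channel
+ # counts when dim_in != dim_out, and the same average pooling halves its spatial
+ # resolution, so the two tensors can be summed. (StarGAN v2's block additionally
+ # rescales the sum by 1/sqrt(2); that detail is omitted in this variant.)
+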
+ class AdainResBlk(nn.Module):
+     def __init__(self, dim_in, dim_out, style_dim, upsample=False):
+         super().__init__()
+         self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
+         self.norm1 = AdaIN(dim_out, style_dim)
+         self.relu1 = nn.ReLU(inplace=True)
+         self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
+         self.norm2 = AdaIN(dim_out, style_dim)
+         self.relu2 = nn.ReLU(inplace=True)
+         self.upsample = upsample
+         # 1x1 conv on the skip path so channel counts match when dim_in != dim_out
+         self.learned_sc = dim_in != dim_out
+         if self.learned_sc:
+             self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
+
+     def forward(self, x, s):
+         residual = x
+         if self.learned_sc:
+             residual = self.conv1x1(residual)
+         if self.upsample:
+             residual = F.interpolate(residual, scale_factor=2, mode='nearest')
+         out = self.conv1(x)
+         out = self.norm1(out, s)
+         out = self.relu1(out)
+         if self.upsample:
+             out = F.interpolate(out, scale_factor=2, mode='nearest')
+         out = self.conv2(out)
+         out = self.norm2(out, s)
+         out = self.relu2(out)
+         out = out + residual
+         return out
+
+ class AdaIN(nn.Module):
+     def __init__(self, num_features, style_dim):
+         super().__init__()
+         self.norm = nn.InstanceNorm2d(num_features, affine=False)
+         self.fc = nn.Linear(style_dim, num_features * 2)
+
+     def forward(self, x, s):
+         h = self.fc(s)
+         gamma, beta = torch.chunk(h, 2, dim=1)
+         gamma = gamma.unsqueeze(2).unsqueeze(3)
+         beta = beta.unsqueeze(2).unsqueeze(3)
+         return (1 + gamma) * self.norm(x) + beta
+
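+ # AdaIN computes (1 + gamma) * InstanceNorm(x) + beta, where the linear layer
+ # predicts a per-channel scale gamma and shift beta from the style code s, so
+ # the style modulates the normalized content features.
+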
+ class MappingNetwork(nn.Module):
+     def __init__(self, latent_dim, style_dim, num_domains):
+         super().__init__()
+         layers = []
+         layers += [nn.Linear(latent_dim + num_domains, 512)]
+         layers += [nn.ReLU()]
+         for _ in range(3):
+             layers += [nn.Linear(512, 512)]
+             layers += [nn.ReLU()]
+         self.shared = nn.Sequential(*layers)
+         self.unshared = nn.ModuleList()
+         for _ in range(num_domains):
+             self.unshared += [nn.Linear(512, style_dim)]
+
+     def forward(self, z, y):
+         # y is a one-hot domain vector (batch, num_domains), matching the input layer above
+         h = torch.cat([z, y], dim=1)
+         h = self.shared(h)
+         out = []
+         for layer in self.unshared:
+             out += [layer(h)]
+         out = torch.stack(out, dim=1)  # (batch, num_domains, style_dim)
+         # pick each sample's head by its domain label (the original gather indexed
+         # by batch position, which ignored y entirely)
+         idx = torch.arange(y.size(0), device=y.device)
+         s = out[idx, y.argmax(dim=1)]
+         return s
+
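+ # The mapping network turns a random latent z (plus the domain code) into a style
+ # vector via a shared MLP and one output head per domain. In this app it is only
+ # instantiated so the checkpoint loads; inference uses the StyleEncoder instead.
+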
+ class StyleEncoder(nn.Module):
+     def __init__(self, img_size=256, style_dim=64, num_domains=3, max_conv_dim=512):
+         super().__init__()
+         dim_in = 64
+         blocks = []
+         blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]
+         repeat_num = int(np.log2(img_size)) - 2
+         for _ in range(repeat_num):
+             dim_out = min(dim_in*2, max_conv_dim)
+             blocks += [ResBlk(dim_in, dim_out, downsample=True)]
+             dim_in = dim_out
+         self.shared = nn.Sequential(*blocks)
+         self.unshared = nn.ModuleList()
+         for _ in range(num_domains):
+             self.unshared += [nn.Linear(dim_in * (img_size // (2**repeat_num))**2, style_dim)]
+
+     def forward(self, x, y):
+         h = self.shared(x)
+         h = h.view(h.size(0), -1)
+         out = []
+         for layer in self.unshared:
+             out += [layer(h)]
+         out = torch.stack(out, dim=1)  # (batch, num_domains, style_dim)
+         # pick each sample's head by its integer domain label y
+         # (the original gather indexed by batch position instead)
+         idx = torch.arange(y.size(0), device=y.device)
+         s = out[idx, y.view(-1)]
+         return s
+
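+ # With the settings used below (img_size=128), repeat_num is 5, so the shared
+ # trunk downsamples 128 -> 4 spatially and each domain head is a Linear over
+ # 512*4*4 = 8192 features producing a 64-dim style vector.
+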
+ # GENERATOR DEFINITION
+ class Generator(nn.Module):
+     def __init__(self, img_size=256, style_dim=64, max_conv_dim=512):
+         super().__init__()
+         dim_in = 64
+         blocks = []
+         blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]
+         repeat_num = int(np.log2(img_size)) - 4
+         for _ in range(repeat_num):
+             dim_out = min(dim_in*2, max_conv_dim)
+             blocks += [ResBlk(dim_in, dim_out, normalize=True, downsample=True)]
+             dim_in = dim_out
+         self.encode = nn.Sequential(*blocks)
+
+         self.decode = nn.ModuleList()
+         for _ in range(repeat_num):
+             dim_out = dim_in // 2
+             self.decode += [AdainResBlk(dim_in, dim_out, style_dim, upsample=True)]
+             dim_in = dim_out
+         self.to_rgb = nn.Sequential(
+             nn.InstanceNorm2d(dim_in, affine=True),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(dim_in, 3, 1, 1, 0)
+         )
+
+     def forward(self, x, s):
+         x = self.encode(x)
+         for block in self.decode:
+             x = block(x, s)
+         out = self.to_rgb(x)
+         return out
+
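+ # For img_size=128 the encoder stacks 3 downsampling ResBlks (128 -> 16 spatial,
+ # 64 -> 512 channels) and the decoder mirrors them with style-conditioned
+ # AdainResBlks back up to 128, so every decoding step is modulated by s.
+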
+ # MODEL-LOADING FUNCTION
+ def load_pretrained_model(ckpt_path, img_size=256, style_dim=64, num_domains=3, device='cpu'):
+     G = Generator(img_size, style_dim).to(device)
+     M = MappingNetwork(16, style_dim, num_domains).to(device)  # assuming latent_dim=16
+     S = StyleEncoder(img_size, style_dim, num_domains).to(device)
+     checkpoint = torch.load(ckpt_path, map_location=device)
+     G.load_state_dict(checkpoint['generator'])
+     M.load_state_dict(checkpoint['mapping_network'])
+     S.load_state_dict(checkpoint['style_encoder'])
+     G.eval()
+     S.eval()
+     return G, S
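+
+ # Note: this assumes the checkpoint stores state dicts under the keys 'generator',
+ # 'mapping_network', and 'style_encoder'. A StarGAN v2 '*_nets_ema.ckpt' file may
+ # use different key names and module layouts, so adapt the loading code if needed.
+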
+ # STYLE-COMBINATION FUNCTION
+ def combine_styles(source_image, reference_image, generator, style_encoder, target_domain_idx, device='cpu', img_size=256):
+     transform = transforms.Compose([
+         transforms.Resize((img_size, img_size)),  # must match the size the networks were built with
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+     ])
+
+     source_img = transform(source_image).unsqueeze(0).to(device)
+     reference_img = transform(reference_image).unsqueeze(0).to(device)
+     target_domain = torch.tensor([target_domain_idx], device=device)  # domain label tensor for the target domain
+
+     with torch.no_grad():
+         style_ref = style_encoder(reference_img, target_domain)  # use the reference image's domain index
+         generated_image = generator(source_img, style_ref)
+         generated_image = (generated_image + 1) / 2.0  # denormalize to [0, 1]
+         generated_image = generated_image.squeeze(0).cpu().permute(1, 2, 0).numpy()
+         generated_image = (np.clip(generated_image, 0, 1) * 255).astype(np.uint8)
+     return Image.fromarray(generated_image)
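+
+ # Example usage (hypothetical file names), assuming the models are already loaded:
+ #   src = Image.open('source_car.jpg').convert('RGB')
+ #   ref = Image.open('reference_car.jpg').convert('RGB')
+ #   out = combine_styles(src, ref, generator, style_encoder, target_domain_idx=0,
+ #                        device=device, img_size=128)
+ #   out.save('stylized.png')
+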
+ # GRADIO SETUP
+ def create_interface(generator, style_encoder, domain_names, device='cpu', img_size=256):
+     def predict(source_img, ref_img, target_domain):
+         target_domain_idx = domain_names.index(target_domain)
+         return combine_styles(source_img, ref_img, generator, style_encoder, target_domain_idx, device, img_size)
+
+     iface = gr.Interface(
+         fn=predict,
+         inputs=[
+             gr.Image(type="pil", label="Source Image"),  # type="pil" so combine_styles receives PIL images
+             gr.Image(type="pil", label="Reference Image"),
+             gr.Dropdown(choices=domain_names, label="Reference Domain (for the style)"),
+         ],
+         outputs=gr.Image(label="Generated Image"),
+         title="AutoStyleGAN - Car Style Transfer",
+         description="Select a source car image and a reference car image to transfer the reference's style onto the source."
+     )
+     return iface
+
+ if __name__ == '__main__':
+     # LOAD THE TRAINED MODEL
+     checkpoint_path = '10000_nets_ema.ckpt'
+     img_size = 128
+     style_dim = 64
+     num_domains = 3
+     device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+     try:
+         generator, style_encoder = load_pretrained_model(checkpoint_path, img_size, style_dim, num_domains, device)
+         print("Model loaded successfully.")
+
+         # DEFINE THE DOMAIN NAMES
+         domain_names = ["BMW", "Corvette", "Mazda"]
+
+         # CREATE AND LAUNCH THE GRADIO INTERFACE
+         # pass img_size so the preprocessing matches the networks built above
+         iface = create_interface(generator, style_encoder, domain_names, device, img_size=img_size)
+         iface.launch(share=True)
+
+     except FileNotFoundError:
+         print(f"Error: checkpoint file not found at '{checkpoint_path}'. Make sure to provide the correct path.")
+     except Exception as e:
+         print(f"An error occurred while loading the model: {e}")
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio
+ torch
+ torchvision
+ Pillow
+ numpy
+ huggingface_hub