radames commited on
Commit
67eaa47
·
1 Parent(s): ea476a5

fix end of line

Browse files

fix latents loading

Files changed (2) hide show
  1. interface/app.py +5 -5
  2. interface/model_loader.py +240 -242
interface/app.py CHANGED
@@ -28,7 +28,7 @@ def random_sample(model_name: str):
28
  return pil_img, model_name, latents
29
 
30
 
31
- def zoom(dx, dy, dz, model_state, latents_state):
32
  model = models[model_state]
33
  dx = dx
34
  dy = dy
@@ -43,7 +43,7 @@ def zoom(dx, dy, dz, model_state, latents_state):
43
  return pil_img, latents_state
44
 
45
 
46
- def translate(dx, dy, dz, model_state, latents_state):
47
  model = models[model_state]
48
 
49
  dx = dx
@@ -130,19 +130,19 @@ with gr.Blocks() as block:
130
  )
131
  dx.change(
132
  translate,
133
- inputs=[dx, dy, dz, model_state, latents_state],
134
  outputs=[image, latents_state],
135
  show_progress=False,
136
  )
137
  dy.change(
138
  translate,
139
- inputs=[dx, dy, dz, model_state, latents_state],
140
  outputs=[image, latents_state],
141
  show_progress=False,
142
  )
143
  dz.change(
144
  zoom,
145
- inputs=[dx, dy, dz, model_state, latents_state],
146
  outputs=[image, latents_state],
147
  show_progress=False,
148
  )
 
28
  return pil_img, model_name, latents
29
 
30
 
31
+ def zoom(model_state, latents_state, dx=0, dy=0, dz=0):
32
  model = models[model_state]
33
  dx = dx
34
  dy = dy
 
43
  return pil_img, latents_state
44
 
45
 
46
+ def translate(model_state, latents_state, dx=0, dy=0, dz=0):
47
  model = models[model_state]
48
 
49
  dx = dx
 
130
  )
131
  dx.change(
132
  translate,
133
+ inputs=[model_state, latents_state, dx, dy, dz],
134
  outputs=[image, latents_state],
135
  show_progress=False,
136
  )
137
  dy.change(
138
  translate,
139
+ inputs=[model_state, latents_state, dx, dy, dz],
140
  outputs=[image, latents_state],
141
  show_progress=False,
142
  )
143
  dz.change(
144
  zoom,
145
+ inputs=[model_state, latents_state, dx, dy, dz],
146
  outputs=[image, latents_state],
147
  show_progress=False,
148
  )
interface/model_loader.py CHANGED
@@ -1,242 +1,240 @@
1
- import os
2
- from argparse import Namespace
3
- import numpy as np
4
- import torch
5
-
6
- from models.StyleGANControler import StyleGANControler
7
-
8
-
9
- class Model:
10
- def __init__(
11
- self, checkpoint_path, truncation=0.5, use_average_code_as_input=False
12
- ):
13
- self.truncation = truncation
14
- self.use_average_code_as_input = use_average_code_as_input
15
- ckpt = torch.load(checkpoint_path, map_location="cpu")
16
- opts = ckpt["opts"]
17
- opts["checkpoint_path"] = checkpoint_path
18
- self.opts = Namespace(**ckpt["opts"])
19
- self.net = StyleGANControler(self.opts)
20
- self.net.eval()
21
- self.net.cuda()
22
- self.target_layers = [0, 1, 2, 3, 4, 5]
23
-
24
- def random_sample(self):
25
- z1 = torch.randn(1, 512).to("cuda")
26
- x1, w1, f1 = self.net.decoder(
27
- [z1],
28
- input_is_latent=False,
29
- randomize_noise=False,
30
- return_feature_map=True,
31
- return_latents=True,
32
- truncation=self.truncation,
33
- truncation_latent=self.net.latent_avg[0],
34
- )
35
- w1_initial = w1.clone()
36
- x1 = self.net.face_pool(x1)
37
- image = (
38
- ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
39
- )
40
- return (
41
- image,
42
- {
43
- "w1": w1.cpu().detach().numpy(),
44
- "w1_initial": w1_initial.cpu().detach().numpy(),
45
- },
46
- ) # return latent vector along with the image
47
-
48
- def latents_to_tensor(self, latents):
49
- w1 = latents["w1"]
50
- w1_initial = latents["w1_initial"]
51
-
52
- w1 = torch.tensor(w1).to("cuda")
53
- w1_initial = torch.tensor(w1_initial).to("cuda")
54
-
55
- x1, w1 = self.net.decoder(
56
- [w1],
57
- input_is_latent=True,
58
- randomize_noise=False,
59
- return_feature_map=False,
60
- return_latents=True,
61
- truncation=self.truncation,
62
- truncation_latent=self.net.latent_avg[0],
63
- )
64
- x1, _, f1 = self.net.decoder(
65
- [w1_initial],
66
- input_is_latent=False,
67
- randomize_noise=False,
68
- return_feature_map=True,
69
- return_latents=True,
70
- truncation=self.truncation,
71
- truncation_latent=self.net.latent_avg[0],
72
- )
73
- return (w1, w1_initial, f1)
74
-
75
- def zoom(self, latents, dz, sxsy=[0, 0], stop_points=[]):
76
- w1, w1_initial, f1 = self.latents_to_tensor(latents)
77
-
78
- vec_num = abs(dz) / 5
79
- dz = 100 * np.sign(dz)
80
- x = torch.from_numpy(np.array([[[1.0, 0, dz]]], dtype=np.float32)).cuda()
81
- f1 = torch.nn.functional.interpolate(f1, (256, 256))
82
- y = f1[:, :, sxsy[1], sxsy[0]].unsqueeze(0)
83
-
84
- if len(stop_points) > 0:
85
- x = torch.cat(
86
- [x, torch.zeros(x.shape[0], len(stop_points), x.shape[2]).cuda()], dim=1
87
- )
88
- tmp = []
89
- for sp in stop_points:
90
- tmp.append(f1[:, :, sp[1], sp[0]].unsqueeze(1))
91
- y = torch.cat([y, torch.cat(tmp, dim=1)], dim=1)
92
-
93
- if not self.use_average_code_as_input:
94
- w_hat = self.net.encoder(
95
- w1[:, self.target_layers].detach(),
96
- x.detach(),
97
- y.detach(),
98
- alpha=vec_num,
99
- )
100
- w1 = w1.clone()
101
- w1[:, self.target_layers] = w_hat
102
- else:
103
- w_hat = self.net.encoder(
104
- self.net.latent_avg.unsqueeze(0)[:, self.target_layers].detach(),
105
- x.detach(),
106
- y.detach(),
107
- alpha=vec_num,
108
- )
109
- w1 = w1.clone()
110
- w1[:, self.target_layers] = (
111
- w1.clone()[:, self.target_layers]
112
- + w_hat
113
- - self.net.latent_avg.unsqueeze(0)[:, self.target_layers]
114
- )
115
-
116
- x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
117
-
118
- x1 = self.net.face_pool(x1)
119
- result = (
120
- ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
121
- )
122
- return (
123
- result,
124
- {
125
- "w1": w1.cpu().detach().numpy(),
126
- "w1_initial": w1_initial.cpu().detach().numpy(),
127
- },
128
- ) # return latent vector along with the image
129
-
130
- def translate(
131
- self, latents, dxy, sxsy=[0, 0], stop_points=[], zoom_in=False, zoom_out=False
132
- ):
133
- w1, w1_initial, f1 = self.latents_to_tensor(latents)
134
-
135
- dz = -5.0 if zoom_in else 0.0
136
- dz = 5.0 if zoom_out else dz
137
-
138
- dxyz = np.array([dxy[0], dxy[1], dz], dtype=np.float32)
139
- dxy_norm = np.linalg.norm(dxyz[:2], ord=2)
140
- dxyz[:2] = dxyz[:2] / dxy_norm
141
- vec_num = dxy_norm / 10
142
-
143
- x = torch.from_numpy(np.array([[dxyz]], dtype=np.float32)).cuda()
144
- f1 = torch.nn.functional.interpolate(f1, (256, 256))
145
- y = f1[:, :, sxsy[1], sxsy[0]].unsqueeze(0)
146
-
147
- if len(stop_points) > 0:
148
- x = torch.cat(
149
- [x, torch.zeros(x.shape[0], len(stop_points), x.shape[2]).cuda()], dim=1
150
- )
151
- tmp = []
152
- for sp in stop_points:
153
- tmp.append(f1[:, :, sp[1], sp[0]].unsqueeze(1))
154
- y = torch.cat([y, torch.cat(tmp, dim=1)], dim=1)
155
-
156
- if not self.use_average_code_as_input:
157
- w_hat = self.net.encoder(
158
- w1[:, self.target_layers].detach(),
159
- x.detach(),
160
- y.detach(),
161
- alpha=vec_num,
162
- )
163
- w1 = w1.clone()
164
- w1[:, self.target_layers] = w_hat
165
- else:
166
- w_hat = self.net.encoder(
167
- self.net.latent_avg.unsqueeze(0)[:, self.target_layers].detach(),
168
- x.detach(),
169
- y.detach(),
170
- alpha=vec_num,
171
- )
172
- w1 = w1.clone()
173
- w1[:, self.target_layers] = (
174
- w1.clone()[:, self.target_layers]
175
- + w_hat
176
- - self.net.latent_avg.unsqueeze(0)[:, self.target_layers]
177
- )
178
-
179
- x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
180
-
181
- x1 = self.net.face_pool(x1)
182
- result = (
183
- ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
184
- )
185
- return (
186
- result,
187
- {
188
- "w1": w1.cpu().detach().numpy(),
189
- "w1_initial": w1_initial.cpu().detach().numpy(),
190
- },
191
- )
192
-
193
- def change_style(self, latents):
194
- w1, w1_initial, f1 = self.latents_to_tensor(latents)
195
-
196
- z1 = torch.randn(1, 512).to("cuda")
197
- x1, w2 = self.net.decoder(
198
- [z1],
199
- input_is_latent=False,
200
- randomize_noise=False,
201
- return_latents=True,
202
- truncation=self.truncation,
203
- truncation_latent=self.net.latent_avg[0],
204
- )
205
- w1[:, 6:] = w2.detach()[:, 0]
206
- x1, w1_new, f1 = self.net.decoder(
207
- [w1],
208
- input_is_latent=True,
209
- randomize_noise=False,
210
- return_feature_map=True,
211
- return_latents=True,
212
- )
213
- result = (
214
- ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
215
- )
216
- return (
217
- result,
218
- {
219
- "w1": w1_new.cpu().detach().numpy(),
220
- "w1_initial": w1_initial.cpu().detach().numpy(),
221
- },
222
- )
223
-
224
- def reset(self, latents):
225
- w1, w1_initial, f1 = self.latents_to_tensor(latents)
226
- x1, w1_new, f1 = self.net.decoder(
227
- [w1_initial],
228
- input_is_latent=True,
229
- randomize_noise=False,
230
- return_feature_map=True,
231
- return_latents=True,
232
- )
233
- result = (
234
- ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
235
- )
236
- return (
237
- result,
238
- {
239
- "w1": w1_new.cpu().detach().numpy(),
240
- "w1_initial": w1_initial.cpu().detach().numpy(),
241
- },
242
- )
 
1
+ import os
2
+ from argparse import Namespace
3
+ import numpy as np
4
+ import torch
5
+
6
+ from models.StyleGANControler import StyleGANControler
7
+
8
+
9
+ class Model:
10
+ def __init__(
11
+ self, checkpoint_path, truncation=0.5, use_average_code_as_input=False
12
+ ):
13
+ self.truncation = truncation
14
+ self.use_average_code_as_input = use_average_code_as_input
15
+ ckpt = torch.load(checkpoint_path, map_location="cpu")
16
+ opts = ckpt["opts"]
17
+ opts["checkpoint_path"] = checkpoint_path
18
+ self.opts = Namespace(**ckpt["opts"])
19
+ self.net = StyleGANControler(self.opts)
20
+ self.net.eval()
21
+ self.net.cuda()
22
+ self.target_layers = [0, 1, 2, 3, 4, 5]
23
+
24
+ def random_sample(self):
25
+ z1 = torch.randn(1, 512).to("cuda")
26
+ x1, w1, f1 = self.net.decoder(
27
+ [z1],
28
+ input_is_latent=False,
29
+ randomize_noise=False,
30
+ return_feature_map=True,
31
+ return_latents=True,
32
+ truncation=self.truncation,
33
+ truncation_latent=self.net.latent_avg[0],
34
+ )
35
+ w1_initial = w1.clone()
36
+ x1 = self.net.face_pool(x1)
37
+ image = (
38
+ ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
39
+ )
40
+ return (
41
+ image,
42
+ {
43
+ "w1": w1.cpu().detach().numpy(),
44
+ "w1_initial": w1_initial.cpu().detach().numpy(),
45
+ },
46
+ ) # return latent vector along with the image
47
+
48
+ def latents_to_tensor(self, latents):
49
+ w1 = latents["w1"]
50
+ w1_initial = latents["w1_initial"]
51
+
52
+ w1 = torch.tensor(w1).to("cuda")
53
+ w1_initial = torch.tensor(w1_initial).to("cuda")
54
+
55
+ x1, w1, f1 = self.net.decoder(
56
+ [w1],
57
+ input_is_latent=True,
58
+ randomize_noise=False,
59
+ return_feature_map=True,
60
+ return_latents=True,
61
+ )
62
+ x1, w1_initial, f1 = self.net.decoder(
63
+ [w1_initial],
64
+ input_is_latent=True,
65
+ randomize_noise=False,
66
+ return_feature_map=True,
67
+ return_latents=True,
68
+ )
69
+
70
+ return (w1, w1_initial, f1)
71
+
72
+ def zoom(self, latents, dz, sxsy=[0, 0], stop_points=[]):
73
+ w1, w1_initial, f1 = self.latents_to_tensor(latents)
74
+ w1 = w1_initial.clone()
75
+
76
+ vec_num = abs(dz) / 5
77
+ dz = 100 * np.sign(dz)
78
+ x = torch.from_numpy(np.array([[[1.0, 0, dz]]], dtype=np.float32)).cuda()
79
+ f1 = torch.nn.functional.interpolate(f1, (256, 256))
80
+ y = f1[:, :, sxsy[1], sxsy[0]].unsqueeze(0)
81
+
82
+ if len(stop_points) > 0:
83
+ x = torch.cat(
84
+ [x, torch.zeros(x.shape[0], len(stop_points), x.shape[2]).cuda()], dim=1
85
+ )
86
+ tmp = []
87
+ for sp in stop_points:
88
+ tmp.append(f1[:, :, sp[1], sp[0]].unsqueeze(1))
89
+ y = torch.cat([y, torch.cat(tmp, dim=1)], dim=1)
90
+
91
+ if not self.use_average_code_as_input:
92
+ w_hat = self.net.encoder(
93
+ w1[:, self.target_layers].detach(),
94
+ x.detach(),
95
+ y.detach(),
96
+ alpha=vec_num,
97
+ )
98
+ w1 = w1.clone()
99
+ w1[:, self.target_layers] = w_hat
100
+ else:
101
+ w_hat = self.net.encoder(
102
+ self.net.latent_avg.unsqueeze(0)[:, self.target_layers].detach(),
103
+ x.detach(),
104
+ y.detach(),
105
+ alpha=vec_num,
106
+ )
107
+ w1 = w1.clone()
108
+ w1[:, self.target_layers] = (
109
+ w1.clone()[:, self.target_layers]
110
+ + w_hat
111
+ - self.net.latent_avg.unsqueeze(0)[:, self.target_layers]
112
+ )
113
+
114
+ x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
115
+
116
+ x1 = self.net.face_pool(x1)
117
+ result = (
118
+ ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
119
+ )
120
+ return (
121
+ result,
122
+ {
123
+ "w1": w1.cpu().detach().numpy(),
124
+ "w1_initial": w1_initial.cpu().detach().numpy(),
125
+ },
126
+ ) # return latent vector along with the image
127
+
128
+ def translate(
129
+ self, latents, dxy, sxsy=[0, 0], stop_points=[], zoom_in=False, zoom_out=False
130
+ ):
131
+ w1, w1_initial, f1 = self.latents_to_tensor(latents)
132
+ w1 = w1_initial.clone()
133
+ dz = -5.0 if zoom_in else 0.0
134
+ dz = 5.0 if zoom_out else dz
135
+
136
+ dxyz = np.array([dxy[0], dxy[1], dz], dtype=np.float32)
137
+ dxy_norm = np.linalg.norm(dxyz[:2], ord=2)
138
+ dxyz[:2] = dxyz[:2] / dxy_norm
139
+ vec_num = dxy_norm / 10
140
+
141
+ x = torch.from_numpy(np.array([[dxyz]], dtype=np.float32)).cuda()
142
+ f1 = torch.nn.functional.interpolate(f1, (256, 256))
143
+ y = f1[:, :, sxsy[1], sxsy[0]].unsqueeze(0)
144
+
145
+ if len(stop_points) > 0:
146
+ x = torch.cat(
147
+ [x, torch.zeros(x.shape[0], len(stop_points), x.shape[2]).cuda()], dim=1
148
+ )
149
+ tmp = []
150
+ for sp in stop_points:
151
+ tmp.append(f1[:, :, sp[1], sp[0]].unsqueeze(1))
152
+ y = torch.cat([y, torch.cat(tmp, dim=1)], dim=1)
153
+
154
+ if not self.use_average_code_as_input:
155
+ w_hat = self.net.encoder(
156
+ w1[:, self.target_layers].detach(),
157
+ x.detach(),
158
+ y.detach(),
159
+ alpha=vec_num,
160
+ )
161
+ w1 = w1.clone()
162
+ w1[:, self.target_layers] = w_hat
163
+ else:
164
+ w_hat = self.net.encoder(
165
+ self.net.latent_avg.unsqueeze(0)[:, self.target_layers].detach(),
166
+ x.detach(),
167
+ y.detach(),
168
+ alpha=vec_num,
169
+ )
170
+ w1 = w1.clone()
171
+ w1[:, self.target_layers] = (
172
+ w1.clone()[:, self.target_layers]
173
+ + w_hat
174
+ - self.net.latent_avg.unsqueeze(0)[:, self.target_layers]
175
+ )
176
+
177
+ x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
178
+
179
+ x1 = self.net.face_pool(x1)
180
+ result = (
181
+ ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
182
+ )
183
+ return (
184
+ result,
185
+ {
186
+ "w1": w1.cpu().detach().numpy(),
187
+ "w1_initial": w1_initial.cpu().detach().numpy(),
188
+ },
189
+ )
190
+
191
+ def change_style(self, latents):
192
+ w1, w1_initial, f1 = self.latents_to_tensor(latents)
193
+ w1 = w1_initial.clone()
194
+
195
+ z1 = torch.randn(1, 512).to("cuda")
196
+ x1, w2 = self.net.decoder(
197
+ [z1],
198
+ input_is_latent=False,
199
+ randomize_noise=False,
200
+ return_latents=True,
201
+ truncation=self.truncation,
202
+ truncation_latent=self.net.latent_avg[0],
203
+ )
204
+ w1[:, 6:] = w2.detach()[:, 0]
205
+ x1, w1_new = self.net.decoder(
206
+ [w1],
207
+ input_is_latent=True,
208
+ randomize_noise=False,
209
+ return_latents=True,
210
+ )
211
+ result = (
212
+ ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
213
+ )
214
+ return (
215
+ result,
216
+ {
217
+ "w1": w1_new.cpu().detach().numpy(),
218
+ "w1_initial": w1_new.cpu().detach().numpy(),
219
+ },
220
+ )
221
+
222
+ def reset(self, latents):
223
+ w1, w1_initial, f1 = self.latents_to_tensor(latents)
224
+ x1, w1_new, f1 = self.net.decoder(
225
+ [w1_initial],
226
+ input_is_latent=True,
227
+ randomize_noise=False,
228
+ return_feature_map=True,
229
+ return_latents=True,
230
+ )
231
+ result = (
232
+ ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
233
+ )
234
+ return (
235
+ result,
236
+ {
237
+ "w1": w1_new.cpu().detach().numpy(),
238
+ "w1_initial": w1_new.cpu().detach().numpy(),
239
+ },
240
+ )