update model version 2
Files changed:
- app.py +34 -38
- models/fb_encoder.onnx +0 -3
- models/g_mapping.onnx +0 -3
- models/g_synthesis.onnx +0 -3
- models/waifu_dect.onnx +0 -3
- requirements.txt +1 -0
app.py
CHANGED
@@ -3,6 +3,7 @@ import imageio
 import numpy as np
 import onnx
 import onnxruntime as rt
+import huggingface_hub
 from numpy.random import RandomState
 from skimage import transform
 
@@ -74,55 +75,48 @@ def nms(pred, conf_thres, iou_thres, max_instance=20):  # pred (anchor_num, 5 +
 
 class Model:
     def __init__(self):
-        self.img_avg = None
         self.detector = None
         self.encoder = None
         self.g_synthesis = None
         self.g_mapping = None
-        self.w_avg = None
         self.detector_stride = None
         self.detector_imgsz = None
         self.detector_class_names = None
-        self.
+        self.w_avg = None
+        self.load_models("skytnt/fbanime-gan")
 
-    def load_models(self, model_dir):
+    def load_models(self, repo):
+        g_mapping_path = huggingface_hub.hf_hub_download(repo, "g_mapping.onnx")
+        g_synthesis_path = huggingface_hub.hf_hub_download(repo, "g_synthesis.onnx")
+        encoder_path = huggingface_hub.hf_hub_download(repo, "encoder.onnx")
+        detector_path = huggingface_hub.hf_hub_download(repo, "waifu_dect.onnx")
 
         providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
-        g_mapping = onnx.load(
+        g_mapping = onnx.load(g_mapping_path)
         w_avg = [x for x in g_mapping.graph.initializer if x.name == "w_avg"][0]
         w_avg = np.frombuffer(w_avg.raw_data, dtype=np.float32)[np.newaxis, :]
         w_avg = w_avg.repeat(16, axis=0)[np.newaxis, :]
         self.w_avg = w_avg
-        self.g_mapping = rt.InferenceSession(
-        self.g_synthesis = rt.InferenceSession(
-        self.encoder = rt.InferenceSession(
-        self.detector = rt.InferenceSession(
+        self.g_mapping = rt.InferenceSession(g_mapping_path, providers=providers)
+        self.g_synthesis = rt.InferenceSession(g_synthesis_path, providers=providers)
+        self.encoder = rt.InferenceSession(encoder_path, providers=providers)
+        self.detector = rt.InferenceSession(detector_path, providers=providers)
         detector_meta = self.detector.get_modelmeta().custom_metadata_map
         self.detector_stride = int(detector_meta['stride'])
         self.detector_imgsz = 1088
         self.detector_class_names = eval(detector_meta['names'])
 
-
-
-
-    def get_img(self, w):
-        img = self.g_synthesis.run(None, {'w': w})[0]
+    def get_img(self, w, noise=0):
+        img = self.g_synthesis.run(None, {'w': w, "noise": np.asarray([noise], dtype=np.float32)})[0]
         return (img.transpose(0, 2, 3, 1) * 127.5 + 128).clip(0, 255).astype(np.uint8)[0]
 
-    def get_w(self, z,
-        return self.g_mapping.run(None, {'z': z, 'psi': np.asarray([
+    def get_w(self, z, psi1, psi2):
+        return self.g_mapping.run(None, {'z': z, 'psi': np.asarray([psi1, psi2], dtype=np.float32)})[0]
 
-    def encode_img(self, img
-
+    def encode_img(self, img):
+        img = transform.resize(((img / 255 - 0.5) / 0.5), (256, 256)).transpose(2, 0, 1)[np.newaxis, :].astype(
             np.float32)
-
-        from_img = self.img_avg.copy()
-        for i in range(iteration):
-            dimg = np.concatenate([target_img, from_img], axis=1)
-            dw = self.encoder.run(None, {'dimg': dimg})[0]
-            w += dw
-            from_img = transform.resize(self.g_synthesis.run(None, {'w': w})[0][0].transpose(1, 2, 0),
-                                        (256, 256)).transpose(2, 0, 1)[np.newaxis, :]
-        return w
+        return self.encoder.run(None, {'img': img})[0] + self.w_avg
 
     def detect(self, im0, conf_thres, iou_thres, detail=False):
         if im0 is None:
@@ -217,11 +211,11 @@ class Model:
         imgs.append(temp_img)
         return imgs
 
-    def gen_video(self, w1, w2, path, frame_num=10):
+    def gen_video(self, w1, w2, noise, path, frame_num=10):
        video = imageio.get_writer(path, mode='I', fps=frame_num // 2, codec='libx264', bitrate='16M')
        lin = np.linspace(0, 1, frame_num)
        for i in range(0, frame_num):
-            img = self.get_img(((1 - lin[i]) * w1) + (lin[i] * w2))
+            img = self.get_img(((1 - lin[i]) * w1) + (lin[i] * w2), noise)
            video.append_data(img)
        video.close()
 
@@ -232,10 +226,10 @@ def get_thumbnail(img):
     return img_new
 
 
-def gen_fn(method, seed,
+def gen_fn(method, seed, psi1, psi2, noise):
     z = RandomState(int(seed) + 2 ** 31).randn(1, 512) if method == 1 else np.random.randn(1, 512)
-    w = model.get_w(z.astype(dtype=np.float32),
-    img_out = model.get_img(w)
+    w = model.get_w(z.astype(dtype=np.float32), psi1, psi2)
+    img_out = model.get_img(w, noise)
     return img_out, w, get_thumbnail(img_out)
 
@@ -250,11 +244,10 @@ def encode_img_fn(img):
     return "success", imgs[0], img_out, w, get_thumbnail(img_out)
 
 
-def gen_video_fn(w1, w2, frame):
+def gen_video_fn(w1, w2, noise, frame):
     if w1 is None or w2 is None:
         return None
-    model.gen_video(w1, w2, "video.mp4",
-                    int(frame))
+    model.gen_video(w1, w2, noise, "video.mp4", int(frame))
     return "video.mp4"
 
@@ -274,7 +267,9 @@ if __name__ == '__main__':
         with gr.Row():
             gen_input1 = gr.Radio(label="method", choices=["random", "use seed"], type="index")
             gen_input2 = gr.Number(value=1, label="seed ( int between -2^31 and 2^31 - 1 )")
-            gen_input3 = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.6, label="truncation psi")
+            gen_input3 = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.6, label="truncation psi 1")
+            gen_input4 = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, label="truncation psi 2")
+            gen_input5 = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, label="noise strength")
         with gr.Group():
             gen_submit = gr.Button("Generate", variant="primary")
         with gr.Column():
@@ -327,7 +322,7 @@ if __name__ == '__main__':
             generate_video_button = gr.Button("Generate", variant="primary")
         with gr.Column():
             generate_video_output = gr.Video(label="output video")
-    gen_submit.click(gen_fn, [gen_input1, gen_input2, gen_input3],
+    gen_submit.click(gen_fn, [gen_input1, gen_input2, gen_input3, gen_input4, gen_input5],
                      [gen_output1, select_img_input_w1, select_img_input_img1])
     encode_img_submit.click(encode_img_fn, [encode_img_input],
                             [encode_img_output1, encode_img_output2, encode_img_output3, select_img_input_w2,
@@ -341,6 +336,7 @@ if __name__ == '__main__':
                             [select_img2_dropdown, select_img_input_img1, select_img_input_img2,
                              select_img_input_w1, select_img_input_w2],
                             [select_img2_output_img, select_img2_output_w])
-    generate_video_button.click(gen_video_fn,
+    generate_video_button.click(gen_video_fn,
+                                [select_img1_output_w, select_img2_output_w, gen_input5, generate_video_frame],
                                 [generate_video_output])
     app.launch()
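In short, the commit stops shipping the four ONNX models inside the Space (the Git LFS files deleted below) and instead downloads them from the skytnt/fbanime-gan model repo at startup, while the generator gains a second truncation psi and an explicit noise-strength input. Below is a minimal sketch of the new generation path outside the Gradio app; the repo id, file names, and ONNX input names ('z', 'psi', 'w', 'noise') are taken from the diff above, while the variable names and the example seed/psi values are purely illustrative:

    # Sketch: exercise the relocated models directly, mirroring load_models/get_w/get_img.
    import numpy as np
    import onnxruntime as rt
    import huggingface_hub

    repo = "skytnt/fbanime-gan"
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    # hf_hub_download fetches the file into the local cache and returns its path
    g_mapping = rt.InferenceSession(
        huggingface_hub.hf_hub_download(repo, "g_mapping.onnx"), providers=providers)
    g_synthesis = rt.InferenceSession(
        huggingface_hub.hf_hub_download(repo, "g_synthesis.onnx"), providers=providers)

    z = np.random.RandomState(1 + 2 ** 31).randn(1, 512).astype(np.float32)
    psi = np.asarray([0.6, 1.0], dtype=np.float32)  # the two truncation psis
    w = g_mapping.run(None, {'z': z, 'psi': psi})[0]
    noise = np.asarray([1.0], dtype=np.float32)     # new explicit noise strength
    img = g_synthesis.run(None, {'w': w, 'noise': noise})[0]
    # NCHW float -> HWC uint8, same post-processing as Model.get_img
    img = (img.transpose(0, 2, 3, 1) * 127.5 + 128).clip(0, 255).astype(np.uint8)[0]

Since hf_hub_download caches downloads locally, the Space only pays the transfer cost on its first boot; later calls reuse the cached files.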
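The encoder changes in the same spirit: the old iterative loop (rendering from_img, concatenating it with the target into 'dimg', and accumulating dw) becomes a single forward pass whose output is offset by w_avg, giving encode_img a fixed cost per image. A sketch under the same assumptions, reusing repo and providers from the snippet above; the 'img' input name, the normalization, and the w_avg extraction mirror the diff, while the encode helper itself is hypothetical:

    import onnx
    from skimage import transform

    encoder = rt.InferenceSession(
        huggingface_hub.hf_hub_download(repo, "encoder.onnx"), providers=providers)
    # w_avg ships as an initializer inside g_mapping.onnx, as load_models shows
    gm = onnx.load(huggingface_hub.hf_hub_download(repo, "g_mapping.onnx"))
    w_avg = [x for x in gm.graph.initializer if x.name == "w_avg"][0]
    w_avg = np.frombuffer(w_avg.raw_data, dtype=np.float32)[np.newaxis, :]
    w_avg = w_avg.repeat(16, axis=0)[np.newaxis, :]  # tile to (1, 16, w_dim)

    def encode(img_uint8):
        # normalize to [-1, 1], resize to 256x256, HWC -> NCHW
        x = transform.resize((img_uint8 / 255 - 0.5) / 0.5, (256, 256))
        x = x.transpose(2, 0, 1)[np.newaxis, :].astype(np.float32)
        return encoder.run(None, {'img': x})[0] + w_avg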
models/fb_encoder.onnx
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:1e8d206092b3e686b6d4798f8976e154413b12316161b5e1b077a493a41d75e4
-size 706114106

models/g_mapping.onnx
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b6e5918a214bb2b1cbdecb76f9c2124fd5fa2cb88e02de16d10530f7441fb205
-size 8410285

models/g_synthesis.onnx
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:12929197c3eeb423c5987303995ef640eb5e2e44638cd3c0657a8aed67fc2aab
-size 112794026

models/waifu_dect.onnx
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4a5de6949912bf94c3307f2b18ebc7b49f309e713b1799d29805ccd882e327d3
-size 83550422
requirements.txt
CHANGED
@@ -2,3 +2,4 @@ onnx
 onnxruntime-gpu
 scikit-image
 imageio-ffmpeg
+huggingface_hub