skytnt committed
Commit fc2d897
1 Parent(s): 90074d8

update model version 2

app.py CHANGED
@@ -3,6 +3,7 @@ import imageio
 import numpy as np
 import onnx
 import onnxruntime as rt
+import huggingface_hub
 from numpy.random import RandomState
 from skimage import transform
 
@@ -74,55 +75,48 @@ def nms(pred, conf_thres, iou_thres, max_instance=20): # pred (anchor_num, 5 +
 
 class Model:
     def __init__(self):
-        self.img_avg = None
         self.detector = None
         self.encoder = None
         self.g_synthesis = None
         self.g_mapping = None
-        self.w_avg = None
         self.detector_stride = None
         self.detector_imgsz = None
         self.detector_class_names = None
-        self.load_models("./models/")
+        self.w_avg = None
+        self.load_models("skytnt/fbanime-gan")
+
+    def load_models(self, repo):
+        g_mapping_path = huggingface_hub.hf_hub_download(repo, "g_mapping.onnx")
+        g_synthesis_path = huggingface_hub.hf_hub_download(repo, "g_synthesis.onnx")
+        encoder_path = huggingface_hub.hf_hub_download(repo, "encoder.onnx")
+        detector_path = huggingface_hub.hf_hub_download(repo, "waifu_dect.onnx")
 
-    def load_models(self, model_dir):
         providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
-        g_mapping = onnx.load(model_dir + "g_mapping.onnx")
+        g_mapping = onnx.load(g_mapping_path)
         w_avg = [x for x in g_mapping.graph.initializer if x.name == "w_avg"][0]
         w_avg = np.frombuffer(w_avg.raw_data, dtype=np.float32)[np.newaxis, :]
         w_avg = w_avg.repeat(16, axis=0)[np.newaxis, :]
         self.w_avg = w_avg
-        self.g_mapping = rt.InferenceSession(model_dir + "g_mapping.onnx", providers=providers)
-        self.g_synthesis = rt.InferenceSession(model_dir + "g_synthesis.onnx", providers=providers)
-        self.encoder = rt.InferenceSession(model_dir + "fb_encoder.onnx", providers=providers)
-        self.detector = rt.InferenceSession(model_dir + "waifu_dect.onnx", providers=providers)
+        self.g_mapping = rt.InferenceSession(g_mapping_path, providers=providers)
+        self.g_synthesis = rt.InferenceSession(g_synthesis_path, providers=providers)
+        self.encoder = rt.InferenceSession(encoder_path, providers=providers)
+        self.detector = rt.InferenceSession(detector_path, providers=providers)
         detector_meta = self.detector.get_modelmeta().custom_metadata_map
         self.detector_stride = int(detector_meta['stride'])
         self.detector_imgsz = 1088
         self.detector_class_names = eval(detector_meta['names'])
 
-        self.img_avg = transform.resize(self.g_synthesis.run(None, {'w': w_avg})[0][0].transpose(1, 2, 0),
-                                        (256, 256)).transpose(2, 0, 1)[np.newaxis, :]
-
-    def get_img(self, w):
-        img = self.g_synthesis.run(None, {'w': w})[0]
+    def get_img(self, w, noise=0):
+        img = self.g_synthesis.run(None, {'w': w, "noise": np.asarray([noise], dtype=np.float32)})[0]
         return (img.transpose(0, 2, 3, 1) * 127.5 + 128).clip(0, 255).astype(np.uint8)[0]
 
-    def get_w(self, z, psi):
-        return self.g_mapping.run(None, {'z': z, 'psi': np.asarray([psi], dtype=np.float32)})[0]
+    def get_w(self, z, psi1, psi2):
+        return self.g_mapping.run(None, {'z': z, 'psi': np.asarray([psi1, psi2], dtype=np.float32)})[0]
 
-    def encode_img(self, img, iteration=5):
-        target_img = transform.resize(((img / 255 - 0.5) / 0.5), (256, 256)).transpose(2, 0, 1)[np.newaxis, :].astype(
+    def encode_img(self, img):
+        img = transform.resize(((img / 255 - 0.5) / 0.5), (256, 256)).transpose(2, 0, 1)[np.newaxis, :].astype(
            np.float32)
-        w = self.w_avg.copy()
-        from_img = self.img_avg.copy()
-        for i in range(iteration):
-            dimg = np.concatenate([target_img, from_img], axis=1)
-            dw = self.encoder.run(None, {'dimg': dimg})[0]
-            w += dw
-            from_img = transform.resize(self.g_synthesis.run(None, {'w': w})[0][0].transpose(1, 2, 0),
-                                        (256, 256)).transpose(2, 0, 1)[np.newaxis, :]
-        return w
+        return self.encoder.run(None, {'img': img})[0] + self.w_avg
 
     def detect(self, im0, conf_thres, iou_thres, detail=False):
         if im0 is None:
@@ -217,11 +211,11 @@ class Model:
             imgs.append(temp_img)
         return imgs
 
-    def gen_video(self, w1, w2, path, frame_num=10):
+    def gen_video(self, w1, w2, noise, path, frame_num=10):
         video = imageio.get_writer(path, mode='I', fps=frame_num // 2, codec='libx264', bitrate='16M')
         lin = np.linspace(0, 1, frame_num)
         for i in range(0, frame_num):
-            img = self.get_img(((1 - lin[i]) * w1) + (lin[i] * w2))
+            img = self.get_img(((1 - lin[i]) * w1) + (lin[i] * w2), noise)
             video.append_data(img)
         video.close()
 
@@ -232,10 +226,10 @@ def get_thumbnail(img):
     return img_new
 
 
-def gen_fn(method, seed, psi):
+def gen_fn(method, seed, psi1, psi2, noise):
     z = RandomState(int(seed) + 2 ** 31).randn(1, 512) if method == 1 else np.random.randn(1, 512)
-    w = model.get_w(z.astype(dtype=np.float32), psi)
-    img_out = model.get_img(w)
+    w = model.get_w(z.astype(dtype=np.float32), psi1, psi2)
+    img_out = model.get_img(w, noise)
     return img_out, w, get_thumbnail(img_out)
 
 
@@ -250,11 +244,10 @@ def encode_img_fn(img):
     return "success", imgs[0], img_out, w, get_thumbnail(img_out)
 
 
-def gen_video_fn(w1, w2, frame):
+def gen_video_fn(w1, w2, noise, frame):
     if w1 is None or w2 is None:
         return None
-    model.gen_video(w1, w2, "video.mp4",
-                    int(frame))
+    model.gen_video(w1, w2, noise, "video.mp4", int(frame))
     return "video.mp4"
 
 
@@ -274,7 +267,9 @@ if __name__ == '__main__':
                 with gr.Row():
                     gen_input1 = gr.Radio(label="method", choices=["random", "use seed"], type="index")
                     gen_input2 = gr.Number(value=1, label="seed ( int between -2^31 and 2^31 - 1 )")
-                    gen_input3 = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.6, label="truncation psi")
+                    gen_input3 = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.6, label="truncation psi 1")
+                    gen_input4 = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, label="truncation psi 2")
+                    gen_input5 = gr.Slider(minimum=0, maximum=1, step=0.01, value=1, label="noise strength")
                 with gr.Group():
                     gen_submit = gr.Button("Generate", variant="primary")
             with gr.Column():
@@ -327,7 +322,7 @@ if __name__ == '__main__':
                 generate_video_button = gr.Button("Generate", variant="primary")
            with gr.Column():
                generate_video_output = gr.Video(label="output video")
-    gen_submit.click(gen_fn, [gen_input1, gen_input2, gen_input3],
+    gen_submit.click(gen_fn, [gen_input1, gen_input2, gen_input3, gen_input4, gen_input5],
                     [gen_output1, select_img_input_w1, select_img_input_img1])
    encode_img_submit.click(encode_img_fn, [encode_img_input],
                            [encode_img_output1, encode_img_output2, encode_img_output3, select_img_input_w2,
@@ -341,6 +336,7 @@ if __name__ == '__main__':
                           [select_img2_dropdown, select_img_input_img1, select_img_input_img2,
                            select_img_input_w1, select_img_input_w2],
                           [select_img2_output_img, select_img2_output_w])
-    generate_video_button.click(gen_video_fn, [select_img1_output_w, select_img2_output_w, generate_video_frame],
+    generate_video_button.click(gen_video_fn,
+                                [select_img1_output_w, select_img2_output_w, gen_input5, generate_video_frame],
                                 [generate_video_output])
     app.launch()
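
For readers following along outside the Space, here is a minimal standalone sketch of the new generation path this commit introduces: the ONNX files are now fetched from the Hub with huggingface_hub.hf_hub_download instead of being bundled in models/, g_mapping takes two truncation values, and g_synthesis takes an explicit noise strength. The repo id, file names, and input names ('z', 'psi', 'w', 'noise') are taken from the diff above; everything else is illustrative, not part of the commit.

    import huggingface_hub
    import numpy as np
    import onnxruntime as rt

    repo = "skytnt/fbanime-gan"
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    # Download from the Hub (cached locally) instead of reading ./models/
    g_mapping = rt.InferenceSession(huggingface_hub.hf_hub_download(repo, "g_mapping.onnx"),
                                    providers=providers)
    g_synthesis = rt.InferenceSession(huggingface_hub.hf_hub_download(repo, "g_synthesis.onnx"),
                                      providers=providers)

    z = np.random.randn(1, 512).astype(np.float32)
    psi = np.asarray([0.6, 1.0], dtype=np.float32)  # truncation psi 1 and psi 2 (new: two values)
    noise = np.asarray([1.0], dtype=np.float32)     # noise strength (new input in version 2)
    w = g_mapping.run(None, {'z': z, 'psi': psi})[0]
    img = g_synthesis.run(None, {'w': w, 'noise': noise})[0]
    # Same post-processing as Model.get_img: NCHW float -> HWC uint8
    img = (img.transpose(0, 2, 3, 1) * 127.5 + 128).clip(0, 255).astype(np.uint8)[0]
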
models/fb_encoder.onnx DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:1e8d206092b3e686b6d4798f8976e154413b12316161b5e1b077a493a41d75e4
-size 706114106

models/g_mapping.onnx DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b6e5918a214bb2b1cbdecb76f9c2124fd5fa2cb88e02de16d10530f7441fb205
-size 8410285

models/g_synthesis.onnx DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:12929197c3eeb423c5987303995ef640eb5e2e44638cd3c0657a8aed67fc2aab
-size 112794026

models/waifu_dect.onnx DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4a5de6949912bf94c3307f2b18ebc7b49f309e713b1799d29805ccd882e327d3
-size 83550422
requirements.txt CHANGED
@@ -2,3 +2,4 @@ onnx
 onnxruntime-gpu
 scikit-image
 imageio-ffmpeg
+huggingface_hub
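
The encoder change is easy to miss in the app.py hunk above: the old fb_encoder.onnx ran a five-iteration refinement loop over a (target, current) image pair, while the new encoder.onnx predicts a delta from w_avg in a single pass. Below is a hedged sketch of that new path, again assuming the repo id, file names, and the 'img' input name shown in the diff; the input image here is a placeholder.

    import huggingface_hub
    import numpy as np
    import onnx
    import onnxruntime as rt
    from skimage import transform

    repo = "skytnt/fbanime-gan"
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    encoder = rt.InferenceSession(huggingface_hub.hf_hub_download(repo, "encoder.onnx"),
                                  providers=providers)

    # w_avg is still read out of g_mapping.onnx's initializers, as in load_models()
    g_mapping = onnx.load(huggingface_hub.hf_hub_download(repo, "g_mapping.onnx"))
    raw = [x for x in g_mapping.graph.initializer if x.name == "w_avg"][0]
    w_avg = np.frombuffer(raw.raw_data, dtype=np.float32)[np.newaxis, :]
    w_avg = w_avg.repeat(16, axis=0)[np.newaxis, :]

    img = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder for a real RGB image
    x = transform.resize((img / 255 - 0.5) / 0.5, (256, 256))  # normalize to [-1, 1], resize
    x = x.transpose(2, 0, 1)[np.newaxis, :].astype(np.float32)  # HWC -> NCHW
    w = encoder.run(None, {'img': x})[0] + w_avg  # single pass: predicted delta plus average w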