Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,719 Bytes
a891a57 58ca92c 2214795 a891a57 58ca92c e18ffd5 a891a57 e18ffd5 a891a57 4f1874e a891a57 437577c a891a57 0bc2c6f a891a57 0bc2c6f dfa0990 a891a57 0bc2c6f a891a57 0bc2c6f 8f7fee0 a891a57 0bc2c6f 8f7fee0 a891a57 8f7fee0 a891a57 8f7fee0 a891a57 0bc2c6f 4f1874e a891a57 e3070b6 0bc2c6f a891a57 0bc2c6f 4f1874e a891a57 58ca92c 0bc2c6f a891a57 58ca92c a891a57 0bc2c6f a891a57 dfa0990 c75f14e e3070b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
# coding: utf-8
"""
Pipeline for gradio
"""
import gradio as gr
from .config.argument_config import ArgumentConfig
from .live_portrait_pipeline import LivePortraitPipeline
from .utils.io import load_img_online
from .utils.rprint import rlog as log
from .utils.crop import prepare_paste_back, paste_back
# from .utils.camera import get_rotation_matrix
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
def update_args(args, user_args):
    """Overwrite attributes of ``args`` with matching entries of ``user_args``.

    Only keys that already exist as attributes on ``args`` are applied; any
    other key in ``user_args`` is ignored. ``args`` is modified in place.

    Args:
        args: Object whose existing attributes may be overwritten.
        user_args: Mapping of attribute name to the new value.

    Returns:
        The same ``args`` object that was passed in.
    """
    # Lazy generator: each hasattr() check runs right before its own setattr().
    applicable = (
        (name, value) for name, value in user_args.items() if hasattr(args, name)
    )
    for name, value in applicable:
        setattr(args, name, value)
    return args
class GradioPipeline(LivePortraitPipeline):
    """Gradio-facing front end of ``LivePortraitPipeline``.

    Provides the callbacks used by the UI: video-driven animation
    (``execute_video``) and single-image eye/lip retargeting
    (``execute_image``, ``execute_image_lip``). Invalid or missing user input
    is reported by raising ``gr.Error``.
    """

    def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
        """Build the underlying pipeline and keep ``args`` for later updates."""
        super().__init__(inference_cfg, crop_cfg)
        self.args = args

    def execute_video(
        self,
        input_image_path,
        input_video_path,
        flag_relative_input,
        flag_do_crop_input,
        flag_remap_input,
    ):
        """Run video-driven portrait animation.

        The user inputs are written into ``self.args`` (as ``source_image``,
        ``driving_info``, ``flag_relative``, ``flag_do_crop`` and
        ``flag_pasteback``), pushed to the wrapper and the cropper, and then
        ``self.execute`` is run.

        Returns:
            The ``(video_path, video_path_concat)`` pair produced by
            ``self.execute``.

        Raises:
            gr.Error: If the source image or the driving video is missing.
        """
        if input_image_path is None or input_video_path is None:
            raise gr.Error("Please upload the source portrait and driving video 🤗🤗🤗", duration=5)

        args_user = {
            'source_image': input_image_path,
            'driving_info': input_video_path,
            'flag_relative': flag_relative_input,
            'flag_do_crop': flag_do_crop_input,
            'flag_pasteback': flag_remap_input,
        }
        # update config from user input
        self.args = update_args(self.args, args_user)
        self.live_portrait_wrapper.update_config(self.args.__dict__)
        self.cropper.update_config(self.args.__dict__)
        # video driven animation
        video_path, video_path_concat = self.execute(self.args)
        return video_path, video_path_concat

    def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_image, flag_do_crop=True):
        """Retarget both eyes and lips of a single source image.

        The ratios are validated before the image is processed, so an invalid
        request does not pay for loading, cropping and feature extraction.

        Returns:
            ``(out, out_to_ori_blend)``: the parsed decoder output and that
            output pasted back onto the loaded source image.

        Raises:
            gr.Error: If either ratio is ``None``, or (from
                ``prepare_retargeting``) if ``input_image`` is ``None``.
        """
        if input_eye_ratio is None or input_lip_ratio is None:
            raise gr.Error("Invalid ratio input 💥!", duration=5)

        # disposable feature
        f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
            self.prepare_retargeting(input_image, flag_do_crop)
        # NOTE(review): the device is hard-coded to CUDA here.
        x_s_user = x_s_user.to("cuda")
        f_s_user = f_s_user.to("cuda")
        # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
        combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], source_lmk_user)
        eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
        # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
        combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], source_lmk_user)
        lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
        out, out_to_ori_blend = self._warp_and_paste_back(
            f_s_user, x_s_user, (eyes_delta, lip_delta), crop_M_c2o, img_rgb, mask_ori
        )
        return out, out_to_ori_blend

    def execute_image_lip(self, input_lip_ratio: float, input_image, flag_do_crop=True):
        """Retarget only the lips of a single source image.

        The ratio is validated before the image is processed.

        Returns:
            The decoder output pasted back onto the loaded source image.

        Raises:
            gr.Error: If ``input_lip_ratio`` is ``None``, or (from
                ``prepare_retargeting``) if ``input_image`` is ``None``.
        """
        if input_lip_ratio is None:
            raise gr.Error("Invalid ratio input 💥!", duration=5)

        # disposable feature
        f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
            self.prepare_retargeting(input_image, flag_do_crop)
        # NOTE(review): the device is hard-coded to CUDA here.
        x_s_user = x_s_user.to("cuda")
        f_s_user = f_s_user.to("cuda")
        # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
        combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], source_lmk_user)
        lip_delta = self.live_portrait_wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
        _, out_to_ori_blend = self._warp_and_paste_back(
            f_s_user, x_s_user, (lip_delta,), crop_M_c2o, img_rgb, mask_ori
        )
        return out_to_ori_blend

    def _warp_and_paste_back(self, f_s_user, x_s_user, deltas, crop_M_c2o, img_rgb, mask_ori):
        """Add keypoint deltas to ``x_s_user``, warp/decode, and paste the result back.

        ``deltas`` is a non-empty sequence; each entry is reshaped to
        ``(-1, num_kp, 3)`` and added to the source keypoints in order.

        Returns:
            ``(out, out_to_ori_blend)``: the first parsed decoder output and
            the result of ``paste_back`` on it.
        """
        num_kp = x_s_user.shape[1]
        # default: use x_s
        x_d_new = x_s_user
        for delta in deltas:
            x_d_new = x_d_new + delta.reshape(-1, num_kp, 3)
        # D(W(f_s; x_s, x′_d))
        out = self.live_portrait_wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
        out = self.live_portrait_wrapper.parse_output(out['out'])[0]
        out_to_ori_blend = paste_back(out, crop_M_c2o, img_rgb, mask_ori)
        return out, out_to_ori_blend

    def prepare_retargeting(self, input_image, flag_do_crop=True):
        """Load a source portrait and compute everything retargeting needs.

        The cropper always runs, because its landmarks and crop-to-original
        matrix are returned in both modes. ``flag_do_crop`` only selects
        whether the 256x256 crop or the full loaded image is fed to the model.

        Returns:
            A 6-tuple ``(f_s_user, x_s_user, source_lmk_user, crop_M_c2o,
            mask_ori, img_rgb)``: the 3D feature, the transformed keypoints,
            ``crop_info['lmk_crop']``, ``crop_info['M_c2o']``, the paste-back
            mask sized to the loaded image, and the loaded RGB image.

        Raises:
            gr.Error: If ``input_image`` is ``None``.
        """
        if input_image is None:
            # when press the clear button, go here
            raise gr.Error("Please upload a source portrait as the retargeting input 🤗🤗🤗", duration=5)

        inference_cfg = self.live_portrait_wrapper.cfg
        ######## process source portrait ########
        img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=1)  # n=1 means do not trim the pixels
        log(f"Load source image from {input_image}.")
        crop_info = self.cropper.crop_single_image(img_rgb)
        source_for_model = crop_info['img_crop_256x256'] if flag_do_crop else img_rgb
        I_s = self.live_portrait_wrapper.prepare_source(source_for_model)
        x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
        ############################################
        f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
        x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
        source_lmk_user = crop_info['lmk_crop']
        crop_M_c2o = crop_info['M_c2o']
        # dsize is (width, height) of the loaded image
        mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_M_c2o, dsize=(img_rgb.shape[1], img_rgb.shape[0]))
        return f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
|