|
import os |
|
import sys |
|
import numpy as np |
|
import torch |
|
import cv2 |
|
from PIL import Image |
|
import folder_paths |
|
import comfy.utils |
|
import time |
|
import copy |
|
import dill |
|
import yaml |
|
from ultralytics import YOLO |
|
|
|
# Absolute path of this module and the folder that contains it; bundled
# resources (e.g. the model config) are resolved relative to this folder.
current_file_path = os.path.abspath(__file__)
current_directory = os.path.split(current_file_path)[0]
|
|
|
from .LivePortrait.live_portrait_wrapper import LivePortraitWrapper |
|
from .LivePortrait.utils.camera import get_rotation_matrix |
|
from .LivePortrait.config.inference_config import InferenceConfig |
|
|
|
from .LivePortrait.modules.spade_generator import SPADEDecoder |
|
from .LivePortrait.modules.warping_network import WarpingNetwork |
|
from .LivePortrait.modules.motion_extractor import MotionExtractor |
|
from .LivePortrait.modules.appearance_feature_extractor import AppearanceFeatureExtractor |
|
from .LivePortrait.modules.stitching_retargeting_network import StitchingRetargetingNetwork |
|
from collections import OrderedDict |
|
|
|
# Process-wide cache of the selected torch device (filled in by get_device()).
cur_device = None


def get_device():
    """Return the torch device used for inference, choosing it on first call.

    Preference order is CUDA, then MPS, then CPU. The choice is printed once
    and cached in the module-level ``cur_device`` so later calls are free.

    Returns:
        torch.device: the cached device.
    """
    global cur_device
    if cur_device is None:
        # Some torch builds do not ship the MPS backend module at all; treat a
        # missing backend the same as an unavailable one instead of crashing.
        mps_backend = getattr(torch.backends, "mps", None)
        if torch.cuda.is_available():
            cur_device = torch.device('cuda')
            print("Uses CUDA device.")
        elif mps_backend is not None and mps_backend.is_available():
            cur_device = torch.device('mps')
            print("Uses MPS device.")
        else:
            cur_device = torch.device('cpu')
            print("Uses CPU device.")
    return cur_device
|
|
|
def tensor2pil(image):
    """Convert a tensor image to a PIL image.

    The tensor is moved to CPU, squeezed, scaled by 255, clipped to
    [0, 255] and cast to uint8 before being handed to ``Image.fromarray``.
    """
    scaled = 255. * image.cpu().numpy().squeeze()
    pixels = np.clip(scaled, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
|
def pil2tensor(image):
    """Convert an image (anything ``np.array`` accepts) to a float tensor.

    Values are cast to float32, divided by 255 and a leading dimension of
    size 1 is added.
    """
    normalized = np.array(image).astype(np.float32) / 255.0
    tensor = torch.from_numpy(normalized)
    return torch.unsqueeze(tensor, 0)
|
def rgb_crop(rgb, region):
    """Slice ``rgb`` to ``region``.

    ``region`` is indexed as ``[x1, y1, x2, y2]``; rows are taken from
    ``y1:y2`` and columns from ``x1:x2``.
    """
    left, top, right, bottom = region[0], region[1], region[2], region[3]
    return rgb[top:bottom, left:right]
|
|
|
def rgb_crop_batch(rgbs, region):
    """Slice every item of a batch to ``region``.

    The first axis is kept whole; ``region`` is indexed as
    ``[x1, y1, x2, y2]`` and applied to the second (``y1:y2``) and third
    (``x1:x2``) axes.
    """
    left, top, right, bottom = region[0], region[1], region[2], region[3]
    return rgbs[:, top:bottom, left:right]
|
def get_rgb_size(rgb):
    """Return ``(shape[1], shape[0])`` of ``rgb``, i.e. the first two axes swapped."""
    first_axis, second_axis = rgb.shape[0], rgb.shape[1]
    return second_axis, first_axis
|
def create_transform_matrix(x, y, s_x, s_y):
    """Build the 2x3 float32 matrix ``[[s_x, 0, x], [0, s_y, y]]``.

    This is a scale by ``(s_x, s_y)`` combined with a translation by
    ``(x, y)``.
    """
    rows = [
        [s_x, 0, x],
        [0, s_y, y],
    ]
    return np.array(rows, dtype=np.float32)
|
|
|
def get_model_dir(m):
    """Return the directory used for models of kind ``m``.

    The first path reported by ``folder_paths.get_folder_paths(m)`` is
    preferred. If that lookup fails for any ordinary reason, the function
    falls back to ``<folder_paths.models_dir>/<m>``.

    Args:
        m: folder/model kind name, e.g. ``"liveportrait"``.

    Returns:
        str: directory path (not guaranteed to exist).
    """
    try:
        return folder_paths.get_folder_paths(m)[0]
    except Exception:
        # Deliberate best-effort fallback. Only ordinary errors are caught so
        # KeyboardInterrupt / SystemExit still propagate.
        return os.path.join(folder_paths.models_dir, m)
|
|
|
def calc_crop_limit(center, img_size, crop_size):
    """Fit a crop span centred on ``center`` into ``[0, img_size]`` along one axis.

    When the span sticks out on either side, it is shrunk by twice the
    overshoot on that side.

    Returns:
        tuple: ``(start, end, crop_size)`` after the adjustments.
    """
    start = center - crop_size / 2
    if start < 0:
        # ``start`` is negative here, so this reduces the size by 2 * |start|.
        crop_size += start * 2
        start = 0

    end = start + crop_size

    if end > img_size:
        crop_size -= (end - img_size) * 2
        end = img_size
        start = end - crop_size

    return start, end, crop_size
|
|
|
def retargeting(delta_out, driving_exp, factor, idxes):
    """Add ``factor``-scaled rows of ``driving_exp`` into ``delta_out`` in place.

    For every index in ``idxes``, ``driving_exp[0, idx] * factor`` is added
    to ``delta_out[0, idx]``. Nothing is returned.
    """
    for point in idxes:
        contribution = driving_exp[0, point] * factor
        delta_out[0, point] += contribution
|
|
|
class PreparedSrcImg:
    """Plain container bundling everything prepared for one source frame.

    The constructor stores its six arguments unchanged under attributes of
    the same names.
    """

    def __init__(self, src_rgb, crop_trans_m, x_s_info, f_s_user, x_s_user, mask_ori):
        # Image-space data: the frame itself, the transform matrix and mask
        # associated with its crop.
        self.src_rgb, self.crop_trans_m, self.mask_ori = src_rgb, crop_trans_m, mask_ori
        # Values produced by the pipeline for this frame.
        self.x_s_info, self.f_s_user, self.x_s_user = x_s_info, f_s_user, x_s_user
|
|
|
import requests |
|
from tqdm import tqdm |
|
|
|
class LP_Engine:
    """Lazily loads the LivePortrait models and offers face-preparation helpers.

    The pipeline, the face detector and the mask template are created on
    first use and cached on the instance.
    """

    # Class-level defaults; instances overwrite them once the resource is loaded.
    pipeline = None
    detect_model = None
    mask_img = None
    temp_img_idx = 0

    def get_temp_img_name(self):
        """Return a new preview file name, incrementing this engine's counter."""
        self.temp_img_idx += 1
        return "expression_edit_preview" + str(self.temp_img_idx) + ".png"

    def download_model(_, file_path, model_url):
        """Download ``model_url`` to ``file_path`` while showing a progress bar.

        This is best effort: every failure is reported on stdout and the
        method returns normally, so the caller decides what to do when the
        file is still missing afterwards.

        Args:
            file_path: destination path of the model file.
            model_url: URL to fetch.
        """
        print('AdvancedLivePortrait: Downloading model...')
        # Stream into a side file and rename at the end, so an interrupted
        # download never leaves a truncated file at the final path.
        partial_path = file_path + ".part"
        try:
            # The request itself is inside the try block so connection errors
            # are reported like any other download failure.
            response = requests.get(model_url, stream=True)
            if response.status_code != 200:
                print(f'AdvancedLivePortrait: Model download failed: HTTP {response.status_code}')
                print(f'AdvancedLivePortrait: Download it manually from: {model_url}')
                print(f'AdvancedLivePortrait: And put it in {file_path}')
                return

            total_size = int(response.headers.get('content-length', 0))
            block_size = 1024  # bytes requested per chunk

            with open(partial_path, 'wb') as file, tqdm(
                    desc='Downloading',
                    total=total_size,
                    unit='iB',
                    unit_scale=True,
                    unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(block_size):
                    bar.update(len(data))
                    file.write(data)
            os.replace(partial_path, file_path)

        except requests.exceptions.RequestException as err:
            print(f'AdvancedLivePortrait: Model download failed: {err}')
            print(f'AdvancedLivePortrait: Download it manually from: {model_url}')
            print(f'AdvancedLivePortrait: And put it in {file_path}')
        except Exception as e:
            print(f'AdvancedLivePortrait: An unexpected error occurred: {e}')

    def remove_ddp_dumplicate_key(_, state_dict):
        """Return an ``OrderedDict`` copy of ``state_dict`` with ``'module.'`` removed from every key."""
        return OrderedDict((key.replace('module.', ''), value) for key, value in state_dict.items())

    def filter_for_model(_, checkpoint, prefix):
        """Keep the entries whose key starts with ``prefix``.

        ``"<prefix>_module."`` is removed from the kept keys.
        """
        marker = prefix + "_module."
        return {key.replace(marker, ""): value
                for key, value in checkpoint.items() if key.startswith(prefix)}

    def load_model(self, model_config, model_type):
        """Build one LivePortrait sub-model and load its weights.

        Weights are looked up as ``<model dir>/base_models/<type>.pth`` (or
        ``retargeting_models`` for the stitching module); when that file is
        absent, ``<model dir>/<type>.safetensors`` is used and downloaded if
        missing.

        Args:
            model_config: parsed model configuration; must contain
                ``['model_params'][f'{model_type}_params']``.
            model_type: name of the sub-model to build.

        Returns:
            The module in eval mode on the active device, or for
            ``'stitching_retargeting_module'`` a dict ``{'stitching': module}``.

        Raises:
            ValueError: if ``model_type`` is not one of the known types.
        """
        device = get_device()
        model_root = get_model_dir("liveportrait")

        is_stitching = model_type == 'stitching_retargeting_module'
        sub_dir = "retargeting_models" if is_stitching else "base_models"
        ckpt_path = os.path.join(model_root, sub_dir, model_type + ".pth")

        is_safetensors = False
        if not os.path.isfile(ckpt_path):
            is_safetensors = True
            ckpt_path = os.path.join(model_root, model_type + ".safetensors")
            if not os.path.isfile(ckpt_path):
                self.download_model(
                    ckpt_path,
                    "https://huggingface.co/Kijai/LivePortrait_safetensors/resolve/main/" + model_type + ".safetensors")

        model_params = model_config['model_params'][f'{model_type}_params']

        if is_stitching:
            # Only the 'stitching' sub-network is built; its weights live under
            # the 'retarget_shoulder' entry/prefix of the checkpoint.
            checkpoint = comfy.utils.load_torch_file(ckpt_path)
            stitcher = StitchingRetargetingNetwork(**model_params.get('stitching'))
            if is_safetensors:
                stitcher.load_state_dict(self.filter_for_model(checkpoint, 'retarget_shoulder'))
            else:
                stitcher.load_state_dict(self.remove_ddp_dumplicate_key(checkpoint['retarget_shoulder']))
            stitcher = stitcher.to(device)
            stitcher.eval()
            return {
                'stitching': stitcher,
            }

        builders = {
            'appearance_feature_extractor': AppearanceFeatureExtractor,
            'motion_extractor': MotionExtractor,
            'warping_module': WarpingNetwork,
            'spade_generator': SPADEDecoder,
        }
        if model_type not in builders:
            raise ValueError(f"Unknown model type: {model_type}")

        model = builders[model_type](**model_params).to(device)
        model.load_state_dict(comfy.utils.load_torch_file(ckpt_path))
        model.eval()
        return model

    def load_models(self):
        """Load all sub-models and assemble ``self.pipeline``."""
        model_path = get_model_dir("liveportrait")
        # makedirs also creates missing parent directories and tolerates an
        # already existing folder.
        os.makedirs(model_path, exist_ok=True)

        model_config_path = os.path.join(current_directory, 'LivePortrait', 'config', 'models.yaml')
        # Context manager so the config file handle is closed after parsing.
        with open(model_config_path, 'r') as config_file:
            model_config = yaml.safe_load(config_file)

        appearance_feature_extractor = self.load_model(model_config, 'appearance_feature_extractor')
        motion_extractor = self.load_model(model_config, 'motion_extractor')
        warping_module = self.load_model(model_config, 'warping_module')
        spade_generator = self.load_model(model_config, 'spade_generator')
        stitching_retargeting_module = self.load_model(model_config, 'stitching_retargeting_module')

        self.pipeline = LivePortraitWrapper(InferenceConfig(), appearance_feature_extractor, motion_extractor,
                                            warping_module, spade_generator, stitching_retargeting_module)

    def get_detect_model(self):
        """Return the YOLO face detector, downloading and loading it on first use."""
        if self.detect_model is None:
            model_dir = get_model_dir("ultralytics")
            os.makedirs(model_dir, exist_ok=True)
            model_path = os.path.join(model_dir, "face_yolov8n.pt")
            if not os.path.exists(model_path):
                self.download_model(model_path, "https://huggingface.co/Bingsu/adetailer/resolve/main/face_yolov8n.pt")
            self.detect_model = YOLO(model_path)

        return self.detect_model

    def get_face_bboxes(self, image_rgb):
        """Run the face detector (confidence 0.7) and return the ``xyxy`` boxes of the first result as a numpy array."""
        detect_model = self.get_detect_model()
        pred = detect_model(image_rgb, conf=0.7, device="")
        return pred[0].boxes.xyxy.cpu().numpy()

    @staticmethod
    def _shift_into_bounds(low, high, limit):
        """Slide the span ``[low, high]`` towards ``[0, limit]`` without resizing it.

        A span that is longer than ``limit`` ends up starting at 0 and
        extending past ``limit``.
        """
        if low < 0:
            high -= low
            low = 0
        elif limit < high:
            low -= (high - limit)
            high = limit
            if low < 0:
                high -= low
                low = 0
        return low, high

    def detect_face(self, image_rgb, crop_factor, sort = True):
        """Pick the detected face closest to the horizontal centre and return a square crop box.

        Boxes narrower than 30 pixels are ignored. The box side is the larger
        bounding-box dimension multiplied by ``crop_factor``.

        Args:
            image_rgb: image array passed to the detector.
            crop_factor: multiplier applied to the detected box size.
            sort: when false, the raw square is returned without moving it
                into the image.

        Returns:
            list: ``[x1, y1, x2, y2]`` as ints; ``[0, 0, w, h]`` when no face
            qualifies.
        """
        bboxes = self.get_face_bboxes(image_rgb)
        w, h = get_rgb_size(image_rgb)

        print(f"w, h:{w, h}")

        cx = w / 2
        min_diff = w
        best_box = None
        for x1, y1, x2, y2 in bboxes:
            bbox_w = x2 - x1
            if bbox_w < 30:
                continue
            diff = abs(cx - (x1 + bbox_w / 2))
            if diff < min_diff:
                best_box = [x1, y1, x2, y2]
                print(f"diff, min_diff, best_box:{diff, min_diff, best_box}")
                min_diff = diff

        if best_box is None:
            print("Failed to detect face!!")
            return [0, 0, w, h]

        x1, y1, x2, y2 = best_box

        bbox_w = x2 - x1
        bbox_h = y2 - y1

        # Square crop: both sides use the larger scaled dimension.
        crop_w = max(bbox_h * crop_factor, bbox_w * crop_factor)
        crop_h = crop_w

        kernel_x = int(x1 + bbox_w / 2)
        kernel_y = int(y1 + bbox_h / 2)

        new_x1 = int(kernel_x - crop_w / 2)
        new_x2 = int(kernel_x + crop_w / 2)
        new_y1 = int(kernel_y - crop_h / 2)
        new_y2 = int(kernel_y + crop_h / 2)

        if not sort:
            return [int(new_x1), int(new_y1), int(new_x2), int(new_y2)]

        new_x1, new_x2 = self._shift_into_bounds(new_x1, new_x2, w)
        new_y1, new_y2 = self._shift_into_bounds(new_y1, new_y2, h)

        # If the square still overflows on both axes, shrink it by the smaller overflow.
        if w < new_x2 and h < new_y2:
            over_min = min(new_x2 - w, new_y2 - h)
            new_x2 -= over_min
            new_y2 -= over_min

        return [int(new_x1), int(new_y1), int(new_x2), int(new_y2)]

    def calc_face_region(self, square, dsize):
        """Clamp the far edges of ``square`` to ``dsize``.

        Args:
            square: ``[x1, y1, x2, y2]``; not modified.
            dsize: sequence whose items 0 and 1 are the upper limits for
                ``x2`` and ``y2``.

        Returns:
            tuple: ``(region, is_changed)`` where ``region`` is a copy of
            ``square`` with ``x2``/``y2`` limited and ``is_changed`` tells
            whether any limit was applied.
        """
        region = copy.deepcopy(square)
        is_changed = False
        if dsize[0] < region[2]:
            region[2] = dsize[0]
            is_changed = True
        if dsize[1] < region[3]:
            region[3] = dsize[1]
            is_changed = True

        return region, is_changed

    def expand_img(self, rgb_img, square):
        """Place ``rgb_img`` on a canvas of the size of ``square``.

        The image is translated by ``max(-x1, 0), max(-y1, 0)`` so that parts
        of the square lying before the origin become empty border.
        """
        crop_trans_m = create_transform_matrix(max(-square[0], 0), max(-square[1], 0), 1, 1)
        # The interpolation mode must be passed by keyword: the 4th positional
        # parameter of cv2.warpAffine is ``dst``, not ``flags``.
        new_img = cv2.warpAffine(rgb_img, crop_trans_m, (square[2] - square[0], square[3] - square[1]),
                                 flags=cv2.INTER_LINEAR)
        return new_img

    def get_pipeline(self):
        """Return the LivePortrait pipeline, loading the models on first use."""
        if self.pipeline is None:
            print("Load pipeline...")
            self.load_models()

        return self.pipeline

    def prepare_src_image(self, img):
        """Resize ``img`` to 256x256 and convert it to a normalised tensor.

        Args:
            img: numpy image with 3 dimensions, or 4 dimensions for a batch.

        Returns:
            torch.Tensor: float32 values in [0, 1], axes permuted with
            ``(0, 3, 1, 2)``, on the active device.

        Raises:
            ValueError: if the array is neither 3- nor 4-dimensional.
        """
        h, w = img.shape[:2]
        input_shape = [256, 256]
        if h != input_shape[0] or w != input_shape[1]:
            # INTER_AREA when the source is taller than the target, INTER_LINEAR otherwise.
            interpolation = cv2.INTER_AREA if 256 < h else cv2.INTER_LINEAR
            x = cv2.resize(img, (input_shape[0], input_shape[1]), interpolation=interpolation)
        else:
            x = img.copy()

        if x.ndim == 3:
            x = x[np.newaxis].astype(np.float32) / 255.  # add the leading batch axis
        elif x.ndim == 4:
            x = x.astype(np.float32) / 255.
        else:
            raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
        x = np.clip(x, 0, 1)
        x = torch.from_numpy(x).permute(0, 3, 1, 2)
        x = x.to(get_device())
        return x

    def GetMaskImg(self):
        """Return the mask template image, reading it from disk on first use."""
        if self.mask_img is None:
            path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "./LivePortrait/utils/resources/mask_template.png")
            self.mask_img = cv2.imread(path, cv2.IMREAD_COLOR)
        return self.mask_img

    def crop_face(self, img_rgb, crop_factor):
        """Detect the face in ``img_rgb`` and return its square crop.

        The crop is padded via ``expand_img`` when the detected square had to
        be clamped to the image.
        """
        crop_region = self.detect_face(img_rgb, crop_factor)
        face_region, is_changed = self.calc_face_region(crop_region, get_rgb_size(img_rgb))
        face_img = rgb_crop(img_rgb, face_region)
        if is_changed:
            face_img = self.expand_img(face_img, crop_region)
        return face_img

    def prepare_source(self, source_image, crop_factor, is_video = False, tracking = False):
        """Prepare source frame(s) for animation.

        Args:
            source_image: tensor batch of frames; scaled by 255 and converted
                to a byte numpy array.
            crop_factor: passed to ``detect_face``.
            is_video: when false, only the first frame is processed and a
                single ``PreparedSrcImg`` is returned.
            tracking: when true, the face is re-detected on every frame;
                otherwise the crop of the first frame is reused.

        Returns:
            ``PreparedSrcImg`` when ``is_video`` is false, else a list with
            one ``PreparedSrcImg`` per frame.
        """
        print("Prepare source...")
        engine = self.get_pipeline()
        source_image_np = (source_image * 255).byte().numpy()

        psi_list = []
        for img_rgb in source_image_np:
            if tracking or len(psi_list) == 0:
                crop_region = self.detect_face(img_rgb, crop_factor)
                face_region, is_changed = self.calc_face_region(crop_region, get_rgb_size(img_rgb))

                # The mask is warped with per-axis scales taken from the clamped region.
                s_x = (face_region[2] - face_region[0]) / 512.
                s_y = (face_region[3] - face_region[1]) / 512.
                crop_trans_m = create_transform_matrix(crop_region[0], crop_region[1], s_x, s_y)
                # flags passed by keyword: the 4th positional parameter of warpAffine is ``dst``.
                mask_ori = cv2.warpAffine(self.GetMaskImg(), crop_trans_m, get_rgb_size(img_rgb),
                                          flags=cv2.INTER_LINEAR)
                mask_ori = mask_ori.astype(np.float32) / 255.

                if is_changed:
                    # The padded crop is square, so the stored transform uses one uniform scale.
                    s = (crop_region[2] - crop_region[0]) / 512.
                    crop_trans_m = create_transform_matrix(crop_region[0], crop_region[1], s, s)

            face_img = rgb_crop(img_rgb, face_region)
            if is_changed:
                face_img = self.expand_img(face_img, crop_region)
            i_s = self.prepare_src_image(face_img)
            x_s_info = engine.get_kp_info(i_s)
            f_s_user = engine.extract_feature_3d(i_s)
            x_s_user = engine.transform_keypoint(x_s_info)
            psi = PreparedSrcImg(img_rgb, crop_trans_m, x_s_info, f_s_user, x_s_user, mask_ori)
            if not is_video:
                return psi
            psi_list.append(psi)

        return psi_list

    def prepare_driving_video(self, face_images):
        """Return the pipeline's keypoint info for every frame of ``face_images``."""
        print("Prepare driving video...")
        pipeline = self.get_pipeline()
        f_img_np = (face_images * 255).byte().numpy()

        out_list = []
        for f_img in f_img_np:
            i_d = self.prepare_src_image(f_img)
            d_info = pipeline.get_kp_info(i_d)
            out_list.append(d_info)

        return out_list

    def calc_fe(_, x_d_new, eyes, eyebrow, wink, pupil_x, pupil_y, mouth, eee, woo, smile,
                rotate_pitch, rotate_yaw, rotate_roll):
        """Apply the expression sliders to ``x_d_new`` in place.

        Each slider adds a fixed multiple of its value to selected
        ``x_d_new[0, index, axis]`` entries; some sliders also adjust the
        rotation values.

        Returns:
            torch.Tensor: ``[rotate_pitch, rotate_yaw, rotate_roll]`` after
            the adjustments.
        """
        # smile
        x_d_new[0, 20, 1] += smile * -0.01
        x_d_new[0, 14, 1] += smile * -0.02
        x_d_new[0, 17, 1] += smile * 0.0065
        x_d_new[0, 17, 2] += smile * 0.003
        x_d_new[0, 13, 1] += smile * -0.00275
        x_d_new[0, 16, 1] += smile * -0.00275
        x_d_new[0, 3, 1] += smile * -0.0035
        x_d_new[0, 7, 1] += smile * -0.0035

        # mouth (also changes the pitch)
        x_d_new[0, 19, 1] += mouth * 0.001
        x_d_new[0, 19, 2] += mouth * 0.0001
        x_d_new[0, 17, 1] += mouth * -0.0001
        rotate_pitch -= mouth * 0.05

        # eee
        x_d_new[0, 20, 2] += eee * -0.001
        x_d_new[0, 20, 1] += eee * -0.001
        x_d_new[0, 14, 1] += eee * -0.001

        # woo
        x_d_new[0, 14, 1] += woo * 0.001
        x_d_new[0, 3, 1] += woo * -0.0005
        x_d_new[0, 7, 1] += woo * -0.0005
        x_d_new[0, 17, 2] += woo * -0.0005

        # wink (also changes roll and yaw)
        x_d_new[0, 11, 1] += wink * 0.001
        x_d_new[0, 13, 1] += wink * -0.0003
        x_d_new[0, 17, 0] += wink * 0.0003
        x_d_new[0, 17, 1] += wink * 0.0003
        x_d_new[0, 3, 1] += wink * -0.0003
        rotate_roll -= wink * 0.1
        rotate_yaw -= wink * 0.1

        # pupil_x: the two affected points use different gains depending on the sign
        if 0 < pupil_x:
            x_d_new[0, 11, 0] += pupil_x * 0.0007
            x_d_new[0, 15, 0] += pupil_x * 0.001
        else:
            x_d_new[0, 11, 0] += pupil_x * 0.001
            x_d_new[0, 15, 0] += pupil_x * 0.0007

        # pupil_y (half of it is also subtracted from the eyes value)
        x_d_new[0, 11, 1] += pupil_y * -0.001
        x_d_new[0, 15, 1] += pupil_y * -0.001
        eyes -= pupil_y / 2.

        # eyes
        x_d_new[0, 11, 1] += eyes * -0.001
        x_d_new[0, 13, 1] += eyes * 0.0003
        x_d_new[0, 15, 1] += eyes * -0.001
        x_d_new[0, 16, 1] += eyes * 0.0003
        x_d_new[0, 1, 1] += eyes * -0.00025
        x_d_new[0, 2, 1] += eyes * 0.00025

        # eyebrow: positive and non-positive values move different axes
        if 0 < eyebrow:
            x_d_new[0, 1, 1] += eyebrow * 0.001
            x_d_new[0, 2, 1] += eyebrow * -0.001
        else:
            x_d_new[0, 1, 0] += eyebrow * -0.001
            x_d_new[0, 2, 0] += eyebrow * 0.001
            x_d_new[0, 1, 1] += eyebrow * 0.0003
            x_d_new[0, 2, 1] += eyebrow * -0.0003

        return torch.Tensor([rotate_pitch, rotate_yaw, rotate_roll])


# Shared engine instance used by the node classes in this module.
g_engine = LP_Engine()
|
|
|
class ExpressionSet:
    """Bundle of four values ``e``, ``r``, ``s`` and ``t`` with in-place arithmetic.

    The default instance holds a zero ``(1, 21, 3)`` float tensor for ``e``,
    a zero 3-element tensor for ``r`` and ``0`` for both ``s`` and ``t``.
    """

    def __init__(self, erst = None, es = None):
        """Create a set.

        Args:
            erst: optional sequence ``(e, r, s, t)`` whose items are stored
                as given (not copied).
            es: optional ``ExpressionSet`` to deep-copy; takes precedence
                over ``erst``.
        """
        # Identity checks: ``!= None`` would invoke an overloaded ``__ne__``
        # when an array-like object is passed.
        if es is not None:
            self.e = copy.deepcopy(es.e)
            self.r = copy.deepcopy(es.r)
            self.s = copy.deepcopy(es.s)
            self.t = copy.deepcopy(es.t)
        elif erst is not None:
            self.e = erst[0]
            self.r = erst[1]
            self.s = erst[2]
            self.t = erst[3]
        else:
            self.e = torch.from_numpy(np.zeros((1, 21, 3))).float().to(get_device())
            self.r = torch.Tensor([0, 0, 0])
            self.s = 0
            self.t = 0

    def div(self, value):
        """Divide all four components by ``value`` in place."""
        self.e /= value
        self.r /= value
        self.s /= value
        self.t /= value

    def add(self, other):
        """Add the components of ``other`` to this set in place."""
        self.e += other.e
        self.r += other.r
        self.s += other.s
        self.t += other.t

    def sub(self, other):
        """Subtract the components of ``other`` from this set in place."""
        self.e -= other.e
        self.r -= other.r
        self.s -= other.s
        self.t -= other.t

    def mul(self, value):
        """Multiply all four components by ``value`` in place."""
        self.e *= value
        self.r *= value
        self.s *= value
        self.t *= value
|
|
|
|
|
|
|
def logging_time(original_fn):
    """Decorator that prints how long each call of ``original_fn`` takes.

    The wrapped function's name and docstring are preserved, and its return
    value is passed through unchanged.
    """
    import functools

    @functools.wraps(original_fn)
    def wrapper_fn(*args, **kwargs):
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments.
        start_time = time.perf_counter()
        result = original_fn(*args, **kwargs)
        end_time = time.perf_counter()
        print("WorkingTime[{}]: {} sec".format(original_fn.__name__, end_time - start_time))
        return result

    return wrapper_fn
|
|
|
|
|
|
|
# Folder holding the saved expression (.exp) files, inside the output directory.
exp_data_dir = os.path.join(folder_paths.output_directory, "exp_data")
# makedirs(exist_ok=True) also creates a missing output directory and does not
# fail when the folder appears between an existence check and its creation.
os.makedirs(exp_data_dir, exist_ok=True)
|
class SaveExpData:
    """Node that writes an expression set to ``<exp_data_dir>/<file_name>.exp``."""

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: a required file name and an optional expression set."""
        return {"required": {
            "file_name": ("STRING", {"multiline": False, "default": ""}),
        },
            "optional": {"save_exp": ("EXP_DATA",), }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("file_name",)
    FUNCTION = "run"
    CATEGORY = "AdvancedLivePortrait"
    OUTPUT_NODE = True

    def run(self, file_name, save_exp: "ExpressionSet" = None):
        """Serialise ``save_exp`` with dill when both inputs are provided.

        Nothing is written when ``save_exp`` is missing or ``file_name`` is
        empty.

        Returns:
            tuple: ``(file_name,)``. A 1-tuple matches ``RETURN_TYPES``; a
            bare string would be read as a sequence of outputs, one per
            character.
        """
        if save_exp is None or file_name == "":
            return (file_name,)

        # NOTE(review): file_name is joined into the path as given; confirm
        # whether path separators in it should be rejected.
        with open(os.path.join(exp_data_dir, file_name + ".exp"), "wb") as f:
            dill.dump(save_exp, f)

        return (file_name,)
|
|
|
class LoadExpData:
    """Node that reads an expression set back from ``exp_data_dir``."""

    @classmethod
    def INPUT_TYPES(s):
        """Offer every ``.exp`` file of ``exp_data_dir`` (extension removed, sorted
        case-insensitively) plus a ``ratio`` between 0 and 1."""
        stems = []
        for entry in os.listdir(exp_data_dir):
            if entry.endswith('.exp'):
                stems.append(os.path.splitext(entry)[0])
        stems.sort(key=str.lower)
        required = {
            "file_name": (stems,),
            "ratio": ("FLOAT", {"default": 1, "min": 0, "max": 1, "step": 0.01}),
        }
        return {"required": required}

    RETURN_TYPES = ("EXP_DATA",)
    RETURN_NAMES = ("exp",)
    FUNCTION = "run"
    CATEGORY = "AdvancedLivePortrait"

    def run(self, file_name, ratio):
        """Load ``<file_name>.exp``, scale it by ``ratio`` and return it as a 1-tuple."""
        exp_path = os.path.join(exp_data_dir, file_name + ".exp")
        # NOTE: dill deserialisation can execute code stored in the file;
        # only load .exp files from trusted sources.
        with open(exp_path, 'rb') as f:
            loaded = dill.load(f)
        loaded.mul(ratio)
        return (loaded,)
|
|
|
class ExpData:
    """Node that adds up to five hand-picked offsets to an expression set."""

    @classmethod
    def INPUT_TYPES(s):
        """Describe five ``codeN``/``valueN`` input pairs and an optional ``add_exp``."""
        required = {}
        for n in range(1, 6):
            required[f"code{n}"] = ("INT", {"default": 0})
            required[f"value{n}"] = ("FLOAT", {"default": 0, "min": -100, "max": 100, "step": 0.1})
        return {"required": required,
                "optional": {"add_exp": ("EXP_DATA",), }
                }

    RETURN_TYPES = ("EXP_DATA",)
    RETURN_NAMES = ("exp",)
    FUNCTION = "run"
    CATEGORY = "AdvancedLivePortrait"

    def run(self, code1, value1, code2, value2, code3, value3, code4, value4, code5, value5, add_exp=None):
        """Apply each ``(code, value)`` pair to a new expression set.

        A code is split into ``int(code / 10)`` (second index of ``e``) and
        ``code % 10`` (third index); ``value * 0.001`` is added there.

        Args:
            add_exp: optional expression set to start from (it is copied).

        Returns:
            tuple: ``(expression_set,)``.
        """
        es = ExpressionSet() if add_exp is None else ExpressionSet(es = add_exp)

        codes = (code1, code2, code3, code4, code5)
        values = (value1, value2, value3, value4, value5)
        for code, value in zip(codes, values):
            idx = int(code / 10)
            r = code % 10
            es.e[0, idx, r] += value * 0.001

        return (es,)
|
|
|
class PrintExpData:
    """Node that prints the largest entries of an expression set and passes it through."""

    @classmethod
    def INPUT_TYPES(s):
        """Describe the inputs: a ``cut_noise`` threshold and an optional expression set."""
        return {"required": {
            "cut_noise": ("FLOAT", {"default": 0, "min": 0, "max": 100, "step": 0.1}),
        },
            "optional": {"exp": ("EXP_DATA",), }
        }

    RETURN_TYPES = ("EXP_DATA",)
    RETURN_NAMES = ("exp",)
    FUNCTION = "run"
    CATEGORY = "AdvancedLivePortrait"
    OUTPUT_NODE = True

    def run(self, cut_noise, exp = None):
        """Print ``[code, value]`` pairs whose magnitude exceeds ``cut_noise``.

        Values are the entries of ``exp.e`` multiplied by 1000; ``code`` is
        ``index * 10 + axis``. Pairs are listed by decreasing magnitude.

        Returns:
            tuple: ``(exp,)`` unchanged.
        """
        if exp is None:
            return (exp,)

        cuted_list = []
        # The expression values are stored in the ``e`` attribute of
        # ExpressionSet (it has no ``exp`` attribute).
        e = exp.e * 1000
        for idx in range(21):
            for r in range(3):
                a = abs(e[0, idx, r])
                if cut_noise < a:
                    cuted_list.append((a, e[0, idx, r], idx * 10 + r))

        sorted_list = sorted(cuted_list, reverse=True, key=lambda item: item[0])
        print(f"sorted_list: {[[item[2], round(float(item[1]),1)] for item in sorted_list]}")
        return (exp,)
|
|
|
class Command:
    """One parsed animation command: an expression set plus two frame counters."""

    def __init__(self, es, change, keep):
        # ``es`` is the ExpressionSet attached to this command.
        self.es: "ExpressionSet" = es
        # Frame counters stored exactly as passed in.
        self.change, self.keep = change, keep
|
|
|
# Settings of the ``crop_factor`` input shared by the node definitions:
# allowed range first, then the initial value.
crop_factor_min, crop_factor_max = 1.5, 2.5
crop_factor_default = 1.7
|
|
|
class AdvancedLivePortrait:
    """Node that renders an animation from source image(s), a command script
    and/or driving images."""

    def __init__(self):
        # Last inputs seen; used to skip re-preparation when they are unchanged.
        self.src_images = None
        self.driving_images = None
        self.pbar = comfy.utils.ProgressBar(1)
        self.crop_factor = None

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs."""
        return {
            "required": {
                "retargeting_eyes": ("FLOAT", {"default": 0, "min": 0, "max": 1, "step": 0.01}),
                "retargeting_mouth": ("FLOAT", {"default": 0, "min": 0, "max": 1, "step": 0.01}),
                "crop_factor": ("FLOAT", {"default": crop_factor_default,
                                          "min": crop_factor_min, "max": crop_factor_max, "step": 0.1}),
                "turn_on": ("BOOLEAN", {"default": True}),
                "tracking_src_vid": ("BOOLEAN", {"default": False}),
                "animate_without_vid": ("BOOLEAN", {"default": False}),
                "command": ("STRING", {"multiline": True, "default": ""}),
            },
            "optional": {
                "src_images": ("IMAGE",),
                "motion_link": ("EDITOR_LINK",),
                "driving_images": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "run"
    OUTPUT_NODE = True
    CATEGORY = "AdvancedLivePortrait"

    def parsing_command(self, command, motoin_link):
        """Parse the command script into ``Command`` objects.

        Each non-empty line has the form ``<index>=<change>:<keep>``. Index 0
        means an empty expression set; any other index selects
        ``motoin_link[index]``. Spaces are ignored.

        Args:
            command: multi-line command text.
            motoin_link: sequence indexed by the command's index (may be
                ``None`` when only index 0 is used).

        Returns:
            tuple: ``(cmd_list, total_length)`` where ``total_length`` is the
            sum of all ``change + keep`` values.

        Raises:
            ValueError: if a line cannot be parsed or its change length is 0.
        """
        cmd_list = []
        total_length = 0

        for line_no, raw_line in enumerate(command.split('\n'), start=1):
            # str.replace returns a new string, so the cleaned text has to be
            # kept; strip() additionally drops stray '\r' or tab characters.
            line = raw_line.replace(' ', '').strip()
            if line == '':
                continue
            try:
                cmds = line.split('=')
                idx = int(cmds[0])
                if idx == 0:
                    es = ExpressionSet()
                else:
                    es = ExpressionSet(es = motoin_link[idx])
                cmds = cmds[1].split(':')
                change = int(cmds[0])
                keep = int(cmds[1])
            except Exception as err:
                # Raised explicitly: an ``assert`` would be stripped under -O.
                raise ValueError(f"(AdvancedLivePortrait) Command Err Line {line_no}: {raw_line}") from err

            if change == 0:
                # The per-frame step is the expression divided by ``change``.
                raise ValueError(
                    f"(AdvancedLivePortrait) Command Err Line {line_no}: {raw_line} (change length must not be 0)")

            total_length += change + keep
            es.div(change)
            cmd_list.append(Command(es, change, keep))

        return cmd_list, total_length

    def run(self, retargeting_eyes, retargeting_mouth, turn_on, tracking_src_vid, animate_without_vid, command, crop_factor,
            src_images=None, driving_images=None, motion_link=None):
        """Render the output frames.

        The number of frames is the larger of the source and driving lengths,
        extended to the command length when ``animate_without_vid`` is set.

        Returns:
            tuple: ``(images,)`` with the frames concatenated into one tensor;
            ``(None, None)`` when the node is turned off or has no source;
            ``(None,)`` when no frame was produced.
        """
        if not turn_on:
            return (None, None)
        src_length = 1

        if src_images is None:
            if motion_link is None:
                return (None, None)
            self.psi_list = [motion_link[0]]
        else:
            src_length = len(src_images)
            # Re-prepare only when a different tensor object or crop factor arrives.
            if src_images is not self.src_images or self.crop_factor != crop_factor:
                self.crop_factor = crop_factor
                self.src_images = src_images
                if 1 < src_length:
                    self.psi_list = g_engine.prepare_source(src_images, crop_factor, True, tracking_src_vid)
                else:
                    self.psi_list = [g_engine.prepare_source(src_images, crop_factor)]

        cmd_list, cmd_length = self.parsing_command(command, motion_link)
        if cmd_list is None:
            return (None, None)
        cmd_idx = 0

        driving_length = 0
        if driving_images is not None:
            if driving_images is not self.driving_images:
                self.driving_images = driving_images
                self.driving_values = g_engine.prepare_driving_video(driving_images)
            driving_length = len(self.driving_values)

        total_length = max(driving_length, src_length)

        if animate_without_vid:
            total_length = max(total_length, cmd_length)

        c_i_es = ExpressionSet()  # accumulated command offset
        c_o_es = ExpressionSet()  # per-frame share of the previous offset, removed during a change
        d_0_es = None             # first driving frame, used as reference
        out_list = []

        psi = None
        pipeline = g_engine.get_pipeline()
        for i in range(total_length):

            if i < src_length:
                psi = self.psi_list[i]
                s_info = psi.x_s_info
                s_es = ExpressionSet(erst=(s_info['kp'] + s_info['exp'], torch.Tensor([0, 0, 0]), s_info['scale'], s_info['t']))

            new_es = ExpressionSet(es = s_es)

            if i < cmd_length:
                cmd = cmd_list[cmd_idx]
                if 0 < cmd.change:
                    cmd.change -= 1
                    c_i_es.add(cmd.es)
                    c_i_es.sub(c_o_es)
                elif 0 < cmd.keep:
                    cmd.keep -= 1

                new_es.add(c_i_es)

                if cmd.change == 0 and cmd.keep == 0:
                    # Command finished: move on and spread the current offset
                    # over the change frames of the next command.
                    cmd_idx += 1
                    if cmd_idx < len(cmd_list):
                        c_o_es = ExpressionSet(es = c_i_es)
                        cmd = cmd_list[cmd_idx]
                        c_o_es.div(cmd.change)
            elif 0 < cmd_length:
                new_es.add(c_i_es)

            if i < driving_length:
                d_i_info = self.driving_values[i]
                d_i_r = torch.Tensor([d_i_info['pitch'], d_i_info['yaw'], d_i_info['roll']])

                if d_0_es is None:
                    d_0_es = ExpressionSet(erst = (d_i_info['exp'], d_i_r, d_i_info['scale'], d_i_info['t']))

                    retargeting(s_es.e, d_0_es.e, retargeting_eyes, (11, 13, 15, 16))
                    retargeting(s_es.e, d_0_es.e, retargeting_mouth, (14, 17, 19, 20))

                # Driving motion is applied relative to the first driving frame.
                new_es.e += d_i_info['exp'] - d_0_es.e
                new_es.r += d_i_r - d_0_es.r
                new_es.t += d_i_info['t'] - d_0_es.t

            r_new = get_rotation_matrix(
                s_info['pitch'] + new_es.r[0], s_info['yaw'] + new_es.r[1], s_info['roll'] + new_es.r[2])
            d_new = new_es.s * (new_es.e @ r_new) + new_es.t
            d_new = pipeline.stitching(psi.x_s_user, d_new)
            crop_out = pipeline.warp_decode(psi.f_s_user, psi.x_s_user, d_new)
            crop_out = pipeline.parse_output(crop_out['out'])[0]

            # flags passed by keyword: the 4th positional parameter of warpAffine is ``dst``.
            crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb),
                                                flags=cv2.INTER_LINEAR)
            out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(
                np.uint8)
            out_list.append(out)

            self.pbar.update_absolute(i+1, total_length, ("PNG", Image.fromarray(crop_out), None))

        if len(out_list) == 0:
            return (None,)

        out_imgs = torch.cat([pil2tensor(img_rgb) for img_rgb in out_list])
        return (out_imgs,)
|
|
|
class ExpressionEditor:
    """Node that edits the expression of a single source face with sliders."""

    def __init__(self):
        # Last inputs seen; used to skip re-preparation when they are unchanged.
        self.sample_image = None
        self.src_image = None
        self.crop_factor = None

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs."""
        display = "number"

        return {
            "required": {

                "rotate_pitch": ("FLOAT", {"default": 0, "min": -20, "max": 20, "step": 0.5, "display": display}),
                "rotate_yaw": ("FLOAT", {"default": 0, "min": -20, "max": 20, "step": 0.5, "display": display}),
                "rotate_roll": ("FLOAT", {"default": 0, "min": -20, "max": 20, "step": 0.5, "display": display}),

                "blink": ("FLOAT", {"default": 0, "min": -20, "max": 5, "step": 0.5, "display": display}),
                "eyebrow": ("FLOAT", {"default": 0, "min": -10, "max": 15, "step": 0.5, "display": display}),
                "wink": ("FLOAT", {"default": 0, "min": 0, "max": 25, "step": 0.5, "display": display}),
                "pupil_x": ("FLOAT", {"default": 0, "min": -15, "max": 15, "step": 0.5, "display": display}),
                "pupil_y": ("FLOAT", {"default": 0, "min": -15, "max": 15, "step": 0.5, "display": display}),
                "aaa": ("FLOAT", {"default": 0, "min": -30, "max": 120, "step": 1, "display": display}),
                "eee": ("FLOAT", {"default": 0, "min": -20, "max": 15, "step": 0.2, "display": display}),
                "woo": ("FLOAT", {"default": 0, "min": -20, "max": 15, "step": 0.2, "display": display}),
                "smile": ("FLOAT", {"default": 0, "min": -0.3, "max": 1.3, "step": 0.01, "display": display}),

                "src_ratio": ("FLOAT", {"default": 1, "min": 0, "max": 1, "step": 0.01, "display": display}),
                "sample_ratio": ("FLOAT", {"default": 1, "min": -0.2, "max": 1.2, "step": 0.01, "display": display}),
                "sample_parts": (["OnlyExpression", "OnlyRotation", "OnlyMouth", "OnlyEyes", "All"],),
                "crop_factor": ("FLOAT", {"default": crop_factor_default,
                                          "min": crop_factor_min, "max": crop_factor_max, "step": 0.1}),
            },

            "optional": {"src_image": ("IMAGE",), "motion_link": ("EDITOR_LINK",),
                         "sample_image": ("IMAGE",), "add_exp": ("EXP_DATA",),
                         },
        }

    RETURN_TYPES = ("IMAGE", "EDITOR_LINK", "EXP_DATA")
    RETURN_NAMES = ("image", "motion_link", "save_exp")

    FUNCTION = "run"

    OUTPUT_NODE = True

    CATEGORY = "AdvancedLivePortrait"

    def run(self, rotate_pitch, rotate_yaw, rotate_roll, blink, eyebrow, wink, pupil_x, pupil_y, aaa, eee, woo, smile,
            src_ratio, sample_ratio, sample_parts, crop_factor, src_image=None, sample_image=None, motion_link=None, add_exp=None):
        """Render the edited face and return it with the updated link and expression.

        The source comes from ``motion_link[0]`` when a link is given,
        otherwise from ``src_image``. Slider values, an optional sample image
        and an optional ``add_exp`` are combined into one expression set.

        Returns:
            dict: ``{"ui": {"images": [...]}, "result": (image, link, exp)}``,
            or ``(None, None, None)`` when neither a link nor a source image
            is connected.
        """
        rotate_yaw = -rotate_yaw

        new_editor_link = None
        if motion_link is not None:
            self.psi = motion_link[0]
            new_editor_link = motion_link.copy()
        elif src_image is not None:
            # Re-prepare only when a different tensor object or crop factor arrives.
            if src_image is not self.src_image or self.crop_factor != crop_factor:
                self.crop_factor = crop_factor
                self.psi = g_engine.prepare_source(src_image, crop_factor)
                self.src_image = src_image
            new_editor_link = []
            new_editor_link.append(self.psi)
        else:
            # One placeholder per declared output, so every output slot exists.
            return (None, None, None)

        pipeline = g_engine.get_pipeline()

        psi = self.psi
        s_info = psi.x_s_info

        # Source expression scaled by src_ratio; entry 5 keeps its original value.
        s_exp = s_info['exp'] * src_ratio
        s_exp[0, 5] = s_info['exp'][0, 5]
        s_exp += s_info['kp']

        es = ExpressionSet()

        if sample_image is not None:
            if self.sample_image is not sample_image:
                self.sample_image = sample_image
                d_image_np = (sample_image * 255).byte().numpy()
                d_face = g_engine.crop_face(d_image_np[0], 1.7)
                i_d = g_engine.prepare_src_image(d_face)
                self.d_info = pipeline.get_kp_info(i_d)
                self.d_info['exp'][0, 5, 0] = 0
                self.d_info['exp'][0, 5, 1] = 0

            # sample_parts: "OnlyExpression", "OnlyRotation", "OnlyMouth", "OnlyEyes", "All"
            if sample_parts == "OnlyExpression" or sample_parts == "All":
                es.e += self.d_info['exp'] * sample_ratio
            if sample_parts == "OnlyRotation" or sample_parts == "All":
                rotate_pitch += self.d_info['pitch'] * sample_ratio
                rotate_yaw += self.d_info['yaw'] * sample_ratio
                rotate_roll += self.d_info['roll'] * sample_ratio
            elif sample_parts == "OnlyMouth":
                retargeting(es.e, self.d_info['exp'], sample_ratio, (14, 17, 19, 20))
            elif sample_parts == "OnlyEyes":
                retargeting(es.e, self.d_info['exp'], sample_ratio, (1, 2, 11, 13, 15, 16))

        es.r = g_engine.calc_fe(es.e, blink, eyebrow, wink, pupil_x, pupil_y, aaa, eee, woo, smile,
                                rotate_pitch, rotate_yaw, rotate_roll)

        if add_exp is not None:
            es.add(add_exp)

        new_rotate = get_rotation_matrix(s_info['pitch'] + es.r[0], s_info['yaw'] + es.r[1],
                                         s_info['roll'] + es.r[2])
        x_d_new = (s_info['scale'] * (1 + es.s)) * ((s_exp + es.e) @ new_rotate) + s_info['t']

        x_d_new = pipeline.stitching(psi.x_s_user, x_d_new)

        crop_out = pipeline.warp_decode(psi.f_s_user, psi.x_s_user, x_d_new)
        crop_out = pipeline.parse_output(crop_out['out'])[0]

        # flags passed by keyword: the 4th positional parameter of warpAffine is ``dst``.
        crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb),
                                            flags=cv2.INTER_LINEAR)
        out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(np.uint8)

        out_img = pil2tensor(out)

        # Save the cropped result as a temp preview image for the UI.
        filename = g_engine.get_temp_img_name()
        folder_paths.get_save_image_path(filename, folder_paths.get_temp_directory())
        img = Image.fromarray(crop_out)
        img.save(os.path.join(folder_paths.get_temp_directory(), filename), compress_level=1)
        results = [{"filename": filename, "type": "temp"}]

        new_editor_link.append(es)

        return {"ui": {"images": results}, "result": (out_img, new_editor_link, es)}
|
|
|
# Node identifier -> implementing class for every node exported by this module.
NODE_CLASS_MAPPINGS = {
    "AdvancedLivePortrait": AdvancedLivePortrait,
    "ExpressionEditor": ExpressionEditor,
    "LoadExpData": LoadExpData,
    "SaveExpData": SaveExpData,
    "ExpData": ExpData,
    # NOTE(review): this key ends with a colon, unlike the others. It looks like
    # a typo, but renaming it changes the registered identifier — confirm first.
    "PrintExpData:": PrintExpData,
}
|
|
|
# Node identifier -> human-readable title. Keys must match the identifiers used
# in NODE_CLASS_MAPPINGS exactly (including the trailing colon of
# "PrintExpData:"), so every registered node gets a title.
NODE_DISPLAY_NAME_MAPPINGS = {
    "AdvancedLivePortrait": "Advanced Live Portrait (PHM)",
    "ExpressionEditor": "Expression Editor (PHM)",
    "LoadExpData": "Load Exp Data (PHM)",
    "SaveExpData": "Save Exp Data (PHM)",
    "ExpData": "Exp Data (PHM)",
    "PrintExpData:": "Print Exp Data (PHM)",
}