# -*- coding: UTF-8 -*-
import os
os.environ['HYDRA_FULL_ERROR']='1'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from huggingface_hub import snapshot_download
# Download weights
snapshot_download(
repo_id = "ChaolongYang/KDTalker",
local_dir = "./"
)
import argparse
import shutil
import uuid
import numpy as np
from tqdm import tqdm
import cv2
from rich.progress import track
import tyro
import gradio as gr
from PIL import Image
import time
import torch
import torch.nn.functional as F
from torch import nn
import imageio
from pydub import AudioSegment
from pykalman import KalmanFilter
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig
from src.live_portrait_pipeline import LivePortraitPipeline
from src.utils.camera import get_rotation_matrix
from dataset_process import audio
from dataset_process.croper import Croper
import spaces
def parse_audio_length(audio_length, sr, fps):
    # Number of audio samples that correspond to one video frame.
    samples_per_frame = sr / fps
    num_frames = int(audio_length / samples_per_frame)
    audio_length = int(num_frames * samples_per_frame)
    return audio_length, num_frames
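# Worked example: at sr=16000 and fps=25, each frame spans 16000 / 25 = 640 audio
# samples, so a 3-second clip (48000 samples) maps to exactly 75 frames.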
def crop_pad_audio(wav, audio_length):
if len(wav) > audio_length:
wav = wav[:audio_length]
elif len(wav) < audio_length:
wav = np.pad(wav, [0, audio_length - len(wav)], mode='constant', constant_values=0)
return wav
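# Wav2Lip-style building block: Conv2d + BatchNorm with an optional residual
# connection and ReLU activation, used by the audio encoder below.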
class Conv2d(nn.Module):
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act=True, *args, **kwargs):
super().__init__(*args, **kwargs)
self.conv_block = nn.Sequential(
nn.Conv2d(cin, cout, kernel_size, stride, padding),
nn.BatchNorm2d(cout)
)
self.act = nn.ReLU()
self.residual = residual
self.use_act = use_act
def forward(self, x):
out = self.conv_block(x)
if self.residual:
out += x
if self.use_act:
return self.act(out)
else:
return out
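# Audio feature extractor: the convolutional audio encoder from Wav2Lip, with
# weights loaded from a pre-trained Wav2Lip checkpoint. It maps each 80x16 mel
# chunk to a 512-dimensional embedding per video frame.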
class AudioEncoder(nn.Module):
def __init__(self, wav2lip_checkpoint, device):
super(AudioEncoder, self).__init__()
self.audio_encoder = nn.Sequential(
Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
#### load the pre-trained audio_encoder
wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
state_dict = self.audio_encoder.state_dict()
for k,v in wav2lip_state_dict.items():
if 'audio_encoder' in k:
state_dict[k.replace('module.audio_encoder.', '')] = v
self.audio_encoder.load_state_dict(state_dict)
def forward(self, audio_sequences):
B = audio_sequences.size(0)
audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
dim = audio_embedding.shape[1]
audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512
def partial_fields(target_class, kwargs):
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
def dct2device(dct: dict, device):
for key in dct:
dct[key] = torch.tensor(dct[key]).to(device)
return dct
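# Mux the generated (silent) video with the driving audio via ffmpeg. Despite
# the name, no watermark is applied; the video stream is copied as-is.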
def save_video_with_watermark(video, audio, save_path):
temp_file = str(uuid.uuid4())+'.mp4'
cmd = r'ffmpeg -y -i "%s" -i "%s" -vcodec copy "%s"' % (video, audio, temp_file)
os.system(cmd)
shutil.move(temp_file, save_path)
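# Wraps the full KDTalker inference stack: the keypoint diffusion model, the
# face croper (68-landmark predictor), the Wav2Lip audio encoder, and the
# LivePortrait rendering pipeline.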
class Inferencer(object):
def __init__(self):
st=time.time()
print('#'*25+'Start initialization'+'#'*25)
self.device = 'cuda'
from model import get_model
self.point_diffusion = get_model()
ckpt = torch.load('ckpts/KDTalker.pth')
self.point_diffusion.load_state_dict(ckpt['model'])
self.point_diffusion.eval()
self.point_diffusion.to(self.device)
lm_croper_checkpoint = 'ckpts/shape_predictor_68_face_landmarks.dat'
self.croper = Croper(lm_croper_checkpoint)
self.norm_info = dict(np.load('dataset_process/norm.npz'))
wav2lip_checkpoint = 'ckpts/wav2lip.pth'
self.wav2lip_model = AudioEncoder(wav2lip_checkpoint, 'cuda')
self.wav2lip_model.cuda()
self.wav2lip_model.eval()
# set tyro theme
tyro.extras.set_accent_color("bright_cyan")
args = tyro.cli(ArgumentConfig)
# specify configs for inference
self.inf_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
self.crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
self.live_portrait_pipeline = LivePortraitPipeline(inference_cfg=self.inf_cfg, crop_cfg=self.crop_cfg)
def _norm(self, data_dict):
for k in data_dict.keys():
if k in ['yaw', 'pitch', 'roll', 't', 'exp', 'scale', 'kp', ]:
v=data_dict[k]
data_dict[k] = (v - self.norm_info[k+'_mean'])/self.norm_info[k+'_std']
return data_dict
def _denorm(self, data_dict):
for k in data_dict.keys():
if k in ['yaw', 'pitch', 'roll', 't', 'exp', 'scale', 'kp']:
v=data_dict[k]
data_dict[k] = v * self.norm_info[k+'_std'] + self.norm_info[k+'_mean']
return data_dict
def output_to_dict(self, data):
output = {}
output['scale'] = data[:, 0]
output['yaw'] = data[:, 1, None]
output['pitch'] = data[:, 2, None]
output['roll'] = data[:, 3, None]
output['t'] = data[:, 4:7]
output['exp'] = data[:, 7:]
return output
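    # For every video frame, take a 16-step window of the mel spectrogram
    # (mel is computed at 80 frames per second, so 16 steps is about 0.2 s),
    # starting two video frames before the current one, which is the window
    # size expected by the Wav2Lip audio encoder.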
def extract_mel_from_audio(self, audio_file_path):
syncnet_mel_step_size = 16
fps = 25
wav = audio.load_wav(audio_file_path, 16000)
wav_length, num_frames = parse_audio_length(len(wav), 16000, 25)
wav = crop_pad_audio(wav, wav_length)
orig_mel = audio.melspectrogram(wav).T
spec = orig_mel.copy()
indiv_mels = []
for i in tqdm(range(num_frames), 'mel:'):
start_frame_num = i - 2
start_idx = int(80. * (start_frame_num / float(fps)))
end_idx = start_idx + syncnet_mel_step_size
seq = list(range(start_idx, end_idx))
seq = [min(max(item, 0), orig_mel.shape[0] - 1) for item in seq]
m = spec[seq, :]
indiv_mels.append(m.T)
indiv_mels = np.asarray(indiv_mels) # T 80 16
return indiv_mels
def extract_wav2lip_from_audio(self, audio_file_path):
asd_mel = self.extract_mel_from_audio(audio_file_path)
asd_mel = torch.FloatTensor(asd_mel).cuda().unsqueeze(0).unsqueeze(2)
with torch.no_grad():
hidden = self.wav2lip_model(asd_mel)
return hidden[0].cpu().detach().numpy()
def headpose_pred_to_degree(self, pred):
device = pred.device
idx_tensor = [idx for idx in range(66)]
idx_tensor = torch.FloatTensor(idx_tensor).to(device)
        pred = F.softmax(pred, dim=1)
degree = torch.sum(pred * idx_tensor, 1) * 3 - 99
return degree
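    # End-to-end generation: crop the source face, extract LivePortrait keypoints,
    # encode the audio, sample a keypoint motion sequence with the diffusion model,
    # smooth it, and render each frame with LivePortrait before muxing the audio.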
@torch.no_grad()
def generate_with_audio_img(self, image_path, audio_path, save_path):
image = np.array(Image.open(image_path).convert('RGB'))
cropped_image, crop, quad = self.croper.crop([image], still=False, xsize=512)
input_image = cv2.resize(cropped_image[0], (256, 256))
I_s = torch.FloatTensor(input_image.transpose((2, 0, 1))).unsqueeze(0).cuda() / 255
x_s_info = self.live_portrait_pipeline.live_portrait_wrapper.get_kp_info(I_s)
x_c_s = x_s_info['kp'].reshape(1, 21, -1)
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
f_s = self.live_portrait_pipeline.live_portrait_wrapper.extract_feature_3d(I_s)
x_s = self.live_portrait_pipeline.live_portrait_wrapper.transform_keypoint(x_s_info)
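        # x_c_s: canonical 3D keypoints of the source; R_s: source head rotation;
        # f_s: appearance features; x_s: source keypoints in image space.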
######## process driving info ########
kp_info = {}
for k in x_s_info.keys():
kp_info[k] = x_s_info[k].cpu().numpy()
kp_info = self._norm(kp_info)
ori_kp = torch.cat([torch.zeros([1, 7]), torch.Tensor(kp_info['kp'])], -1).cuda()
input_x = np.concatenate([kp_info[k] for k in ['scale', 'yaw', 'pitch', 'roll', 't', 'exp']], 1)
input_x = np.expand_dims(input_x, -1)
input_x = np.expand_dims(input_x, 0)
input_x = np.concatenate([input_x, input_x, input_x], -1)
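        # The reference motion vector concatenates the normalized source
        # scale (1) + yaw/pitch/roll (3) + translation (3) + expression (63)
        # into 70 values, matching the layout unpacked in output_to_dict.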
        aud_feat = self.extract_wav2lip_from_audio(audio_path)
        # Pad the audio features so their length is a multiple of the 64-frame
        # window processed by the diffusion model in each sampling step.
        sample_frame = 64
        padding_size = (sample_frame - aud_feat.shape[0] % sample_frame) % sample_frame
        if padding_size > 0:
            aud_feat = np.concatenate((aud_feat, aud_feat[:padding_size, :]), axis=0)
        outputs = [input_x]
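        # Autoregressive sampling: each 64-frame audio window conditions the
        # keypoint diffusion model, seeded with the last keypoint frame of the
        # previous window so consecutive windows join smoothly.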
for i in range(0, aud_feat.shape[0] - 1, sample_frame):
input_mel = torch.Tensor(aud_feat[i: i + sample_frame]).unsqueeze(0).cuda()
kp0 = torch.Tensor(outputs[-1])[:, -1].cuda()
pred_kp = self.point_diffusion.forward_sample(70, ref_kps=kp0, ori_kps=ori_kp, aud_feat=input_mel,
scheduler='ddim', num_inference_steps=50)
outputs.append(pred_kp.cpu().numpy())
outputs = np.mean(np.concatenate(outputs, 1)[0, 1:aud_feat.shape[0] - padding_size + 1], -1)
output_dict = self.output_to_dict(outputs)
output_dict = self._denorm(output_dict)
num_frame = output_dict['yaw'].shape[0]
x_d_info = {}
for key in output_dict:
x_d_info[key] = torch.tensor(output_dict[key]).cuda()
        # Temporally smooth the predicted head pose, translation and expression
        # with a Kalman smoother to suppress frame-to-frame jitter.
def smooth(sequence, n_dim_state=1):
kf = KalmanFilter(initial_state_mean=sequence[0],
transition_covariance=0.05 * np.eye(n_dim_state),
observation_covariance=0.001 * np.eye(n_dim_state))
state_means, _ = kf.smooth(sequence)
return state_means
yaw_data = x_d_info['yaw'].cpu().numpy()
pitch_data = x_d_info['pitch'].cpu().numpy()
roll_data = x_d_info['roll'].cpu().numpy()
t_data = x_d_info['t'].cpu().numpy()
exp_data = x_d_info['exp'].cpu().numpy()
smoothed_pitch = smooth(pitch_data, n_dim_state=1)
smoothed_yaw = smooth(yaw_data, n_dim_state=1)
smoothed_roll = smooth(roll_data, n_dim_state=1)
smoothed_t = smooth(t_data, n_dim_state=3)
smoothed_exp = smooth(exp_data, n_dim_state=63)
x_d_info['pitch'] = torch.Tensor(smoothed_pitch).cuda()
x_d_info['yaw'] = torch.Tensor(smoothed_yaw).cuda()
x_d_info['roll'] = torch.Tensor(smoothed_roll).cuda()
x_d_info['t'] = torch.Tensor(smoothed_t).cuda()
x_d_info['exp'] = torch.Tensor(smoothed_exp).cuda()
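        # Convert the smoothed motion parameters into per-frame LivePortrait
        # motion templates (rotation matrix, expression, translation, scale).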
template_dct = {'motion': [], 'c_d_eyes_lst': [], 'c_d_lip_lst': []}
for i in track(range(num_frame), description='Making motion templates...', total=num_frame):
x_d_i_info = x_d_info
R_d_i = get_rotation_matrix(x_d_i_info['pitch'][i], x_d_i_info['yaw'][i], x_d_i_info['roll'][i])
item_dct = {
'scale': x_d_i_info['scale'][i].cpu().numpy().astype(np.float32),
'R_d': R_d_i.cpu().numpy().astype(np.float32),
'exp': x_d_i_info['exp'][i].reshape(1, 21, -1).cpu().numpy().astype(np.float32),
't': x_d_i_info['t'][i].cpu().numpy().astype(np.float32),
}
template_dct['motion'].append(item_dct)
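        # Render each frame: combine the driving motion with the source keypoints
        # (optionally as motion relative to the first driving frame) and decode
        # the warped appearance features back into an image.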
I_p_lst = []
R_d_0, x_d_0_info = None, None
for i in track(range(num_frame), description='🚀Animating...', total=num_frame):
x_d_i_info = template_dct['motion'][i]
for key in x_d_i_info:
x_d_i_info[key] = torch.tensor(x_d_i_info[key]).cuda()
R_d_i = x_d_i_info['R_d']
if i == 0:
R_d_0 = R_d_i
x_d_0_info = x_d_i_info
if self.inf_cfg.flag_relative_motion:
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
delta_new = x_s_info['exp'].reshape(1, 21, -1) + (x_d_i_info['exp'] - x_d_0_info['exp'])
scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
else:
R_new = R_d_i
delta_new = x_d_i_info['exp']
scale_new = x_s_info['scale']
t_new = x_d_i_info['t']
t_new[..., 2].fill_(0)
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
out = self.live_portrait_pipeline.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
I_p_i = self.live_portrait_pipeline.live_portrait_wrapper.parse_output(out['out'])[0]
I_p_lst.append(I_p_i)
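        # Write the silent frames to a temporary video, trim the audio to the
        # generated length, and mux them into the final output file.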
video_name = save_path.split('/')[-1]
video_save_dir = os.path.dirname(save_path)
path = os.path.join(video_save_dir, 'temp_' + video_name)
imageio.mimsave(path, I_p_lst, fps=float(25))
audio_name = audio_path.split('/')[-1]
new_audio_path = os.path.join(video_save_dir, audio_name)
start_time = 0
sound = AudioSegment.from_file(audio_path)
end_time = start_time + num_frame * 1 / 25 * 1000
word1 = sound.set_frame_rate(16000)
word = word1[start_time:end_time]
word.export(new_audio_path, format="wav")
save_video_with_watermark(path, new_audio_path, save_path)
print(f'The generated video is named {video_save_dir}/{video_name}')
os.remove(path)
os.remove(new_audio_path)
@spaces.GPU()
def gradio_infer(source_image, driven_audio):
"""
Generate a talking-head video from a static source image and an audio file.
This function uses a pre-trained motion and lip-sync model to animate a face image so that it
appears to speak in sync with a given audio clip. The resulting video is saved
and returned.
Args:
source_image: A path to an input image (or uploaded image) of a person's face
that will be animated.
driven_audio: A path to an audio file (or uploaded audio) that will drive the
lip-sync and head movement of the animation.
Returns:
A file path to the generated .mp4 video, which shows the input face animated
to speak and move in sync with the audio.
Workflow:
1. Load and initialize the animation pipeline (Inferencer).
2. Process the image and audio.
3. Generate a talking-head animation using lip-sync and motion synthesis models.
4. Combine generated video frames with the original audio.
5. Return the video path to be displayed or downloaded.
"""
import tempfile
temp_dir = tempfile.mkdtemp()
output_path = f"{temp_dir}/output.mp4"
Infer = Inferencer()
Infer.generate_with_audio_img(source_image, driven_audio, output_path)
return output_path
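# Example: once the app is running, the endpoint can be called remotely with
# gradio_client. A minimal sketch (assuming a recent gradio_client and the
# default api_name derived from the function name):
#
#     from gradio_client import Client, handle_file
#     client = Client("fffiloni/KDTalker")
#     video_path = client.predict(
#         source_image=handle_file("example/source_image/WDA_BenCardin1_000.png"),
#         driven_audio=handle_file("example/audio_driven/WDA_BenCardin1_000.wav"),
#         api_name="/gradio_infer",
#     )
#     print(video_path)  # local path to the generated .mp4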
with gr.Blocks() as demo:
with gr.Column():
gr.Markdown("# KDTalker")
gr.Markdown("Unlock Pose Diversity: Accurate and Efficient Implicit Keypoint-based Spatiotemporal Diffusion for Audio-driven Talking Portrait")
gr.HTML("""
<div style="display:flex;column-gap:4px;">
<a href="https://github.com/chaolongy/KDTalker">
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
<a href="https://arxiv.org/abs/2503.12963">
<img src='https://img.shields.io/badge/ArXiv-Paper-red'>
</a>
<a href="https://huggingface.co/spaces/fffiloni/KDTalker?duplicate=true">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
</a>
</div>
""")
with gr.Row():
with gr.Column():
source_image = gr.Image(label="Source Image", type="filepath")
driven_audio = gr.Audio(label="Driven Audio", type="filepath")
submit_btn = gr.Button("Submit")
gr.Examples(
examples = [
["example/source_image/WDA_BenCardin1_000.png", "example/audio_driven/WDA_BenCardin1_000.wav"],
],
inputs = [source_image, driven_audio],
cache_examples = False
)
with gr.Column():
output_video = gr.Video(label="Output Video")
submit_btn.click(
fn = gradio_infer,
inputs = [source_image, driven_audio],
outputs = [output_video]
)
demo.launch(mcp_server=True)