import warnings

import gradio
import imageio
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch
from skimage import img_as_ubyte
from skimage.transform import resize

from demo import load_checkpoints, make_animation
def inference(source_image_path='./assets/source.png', driving_video_path='./assets/driving.mp4', dataset_name="vox"):
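    """Animate a source image with the motion of a driving video.

    Loads the config/checkpoint pair that matches ``dataset_name``, runs the
    motion-transfer pipeline from ``demo.py`` on CPU, writes the generated
    frames to ./generated.mp4, and returns the path of a side-by-side
    (source | driving | generated) preview video for the Gradio UI.
    """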
    # edit the config
    device = torch.device('cpu')
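    # A possible tweak (an assumption, not part of the original demo): pick the
    # GPU automatically when one is present, e.g.
    #   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')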
    output_video_path = './generated.mp4'
    pixel = 256  # for vox, taichi and mgif, the resolution is 256x256
    if dataset_name == 'ted':  # for ted, the resolution is 384x384
        pixel = 384
    config_path = f'config/{dataset_name}-{pixel}.yaml'
    checkpoint_path = f'checkpoints/{dataset_name}.pth.tar'
    predict_mode = 'relative'  # one of ['standard', 'relative', 'avd']
    warnings.filterwarnings("ignore")
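
    # Read the source image and all frames of the driving video, then resize
    # both to the resolution the selected model was trained on.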
    source_image = imageio.imread(source_image_path)
    reader = imageio.get_reader(driving_video_path)
    source_image = resize(source_image, (pixel, pixel))[..., :3]
    fps = reader.get_meta_data()['fps']
    driving_video = []
    try:
        for im in reader:
            driving_video.append(im)
    except RuntimeError:
        # imageio can raise RuntimeError on a truncated video; keep the
        # frames read so far.
        pass
    reader.close()
    driving_video = [resize(frame, (pixel, pixel))[..., :3] for frame in driving_video]
    # driving_video = driving_video[:10]  # uncomment to animate only the first 10 frames while debugging
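
    # Render source, driving frame and generated frame side by side as a
    # matplotlib animation; this is only used to build the preview video.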
    def display(source, driving, generated=None) -> animation.ArtistAnimation:
        fig = plt.figure(figsize=(8 + 4 * (generated is not None), 6))
        ims = []
        for i in range(len(driving)):
            cols = [source]
            cols.append(driving[i])
            if generated is not None:
                cols.append(generated[i])
            im = plt.imshow(np.concatenate(cols, axis=1), animated=True)
            plt.axis('off')
            ims.append([im])
        ani = animation.ArtistAnimation(fig, ims, interval=50, repeat_delay=1000)
        # plt.show()
        plt.close()
        return ani
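
    # Load the four pretrained modules from the checkpoint: the inpainting
    # network, keypoint detector, dense motion network and AVD network.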
    inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(
        config_path=config_path, checkpoint_path=checkpoint_path, device=device)

    predictions = make_animation(source_image, driving_video, inpainting, kp_detector, dense_motion_network,
                                 avd_network, device=device, mode=predict_mode)

    # save resulting video
    imageio.mimsave(output_video_path, [img_as_ubyte(frame) for frame in predictions], fps=fps)

    # Save the side-by-side preview; the ffmpeg writer is needed for .mp4
    # output, and reusing the driving video's fps preserves playback speed.
    ani = display(source_image, driving_video, predictions)
    ani.save('animation.mp4', writer='ffmpeg', fps=fps)
    return 'animation.mp4'
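

# The function can also be called directly, without the UI, e.g. with the demo
# assets referenced in the examples below:
#   inference('./assets/source.png', './assets/driving.mp4', dataset_name='vox')
# Note: the interface below uses the legacy pre-3.x gradio.inputs API.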
demo = gradio.Interface(
    fn=inference,
    inputs=[
        gradio.inputs.Image(type="filepath", label="Input image"),
        gradio.inputs.Video(label="Input video"),
        gradio.inputs.Dropdown(['vox', 'taichi', 'ted', 'mgif'], type="value", default="vox", label="Model",
                               optional=False),
    ],
    outputs=["video"],
    examples=[
        ['./assets/source.png', './assets/driving.mp4', "vox"],
        ['./assets/source_ted.png', './assets/driving_ted.mp4', "ted"],
    ],
)
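
# Launch the Gradio app when this file is executed as a script.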
if __name__ == "__main__":
    demo.launch()