{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Config**"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# edit the config\n",
"device = torch.device('cuda:0')\n",
"dataset_name = 'vox' # ['vox', 'taichi', 'ted', 'mgif']\n",
"source_image_path = './assets/source.png'\n",
"driving_video_path = './assets/driving.mp4'\n",
"output_video_path = './generated.mp4'\n",
"config_path = 'config/vox-256.yaml'\n",
"checkpoint_path = 'checkpoints/vox.pth.tar'\n",
"predict_mode = 'relative' # ['standard', 'relative', 'avd']\n",
"find_best_frame = False # when use the relative mode to animate a face, use 'find_best_frame=True' can get better quality result\n",
"\n",
"pixel = 256 # for vox, taichi and mgif, the resolution is 256*256\n",
"if(dataset_name == 'ted'): # for ted, the resolution is 384*384\n",
" pixel = 384\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Read image and video**"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 453
},
"id": "Oxi6-riLOgnm",
"outputId": "d38a8850-9eb1-4de4-9bf2-24cbd847ca1f"
},
"outputs": [
{
"data": {
"text/html": [
""
],
"text/plain": [
""
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import imageio\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.animation as animation\n",
"from skimage.transform import resize\n",
"from IPython.display import HTML\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"source_image = imageio.imread(source_image_path)\n",
"reader = imageio.get_reader(driving_video_path)\n",
"\n",
"\n",
"source_image = resize(source_image, (pixel, pixel))[..., :3]\n",
"\n",
"fps = reader.get_meta_data()['fps']\n",
"driving_video = []\n",
"try:\n",
" for im in reader:\n",
" driving_video.append(im)\n",
"except RuntimeError:\n",
" pass\n",
"reader.close()\n",
"\n",
"driving_video = [resize(frame, (pixel, pixel))[..., :3] for frame in driving_video]\n",
"\n",
"def display(source, driving, generated=None):\n",
" fig = plt.figure(figsize=(8 + 4 * (generated is not None), 6))\n",
"\n",
" ims = []\n",
" for i in range(len(driving)):\n",
" cols = [source]\n",
" cols.append(driving[i])\n",
" if generated is not None:\n",
" cols.append(generated[i])\n",
" im = plt.imshow(np.concatenate(cols, axis=1), animated=True)\n",
" plt.axis('off')\n",
" ims.append([im])\n",
"\n",
" ani = animation.ArtistAnimation(fig, ims, interval=50, repeat_delay=1000)\n",
" plt.close()\n",
" return ani\n",
" \n",
"\n",
"HTML(display(source_image, driving_video).to_html5_video())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xjM7ubVfWrwT"
},
"source": [
"**Create a model and load checkpoints**"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "3FQiXqQPWt5B"
},
"outputs": [],
"source": [
"from demo import load_checkpoints\n",
"inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path = config_path, checkpoint_path = checkpoint_path, device = device)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fdFdasHEj3t7"
},
"source": [
"**Perform image animation**"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 471
},
"id": "SB12II11kF4c",
"outputId": "9e2274aa-fd55-4eed-cb50-bec72fcfb8b9"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 169/169 [00:10<00:00, 15.69it/s]\n"
]
},
{
"data": {
"text/html": [
""
],
"text/plain": [
""
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from demo import make_animation\n",
"from skimage import img_as_ubyte\n",
"\n",
"if predict_mode=='relative' and find_best_frame:\n",
" from demo import find_best_frame as _find\n",
" i = _find(source_image, driving_video, device.type=='cpu')\n",
" print (\"Best frame: \" + str(i))\n",
" driving_forward = driving_video[i:]\n",
" driving_backward = driving_video[:(i+1)][::-1]\n",
" predictions_forward = make_animation(source_image, driving_forward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)\n",
" predictions_backward = make_animation(source_image, driving_backward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)\n",
" predictions = predictions_backward[::-1] + predictions_forward[1:]\n",
"else:\n",
" predictions = make_animation(source_image, driving_video, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)\n",
"\n",
"#save resulting video\n",
"imageio.mimsave(output_video_path, [img_as_ubyte(frame) for frame in predictions], fps=fps)\n",
"\n",
"HTML(display(source_image, driving_video, predictions).to_html5_video())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"include_colab_link": true,
"name": "first-order-model-demo.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}