{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "**Config**" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# edit the config\n", "device = torch.device('cuda:0')\n", "dataset_name = 'vox' # ['vox', 'taichi', 'ted', 'mgif']\n", "source_image_path = './assets/source.png'\n", "driving_video_path = './assets/driving.mp4'\n", "output_video_path = './generated.mp4'\n", "config_path = 'config/vox-256.yaml'\n", "checkpoint_path = 'checkpoints/vox.pth.tar'\n", "predict_mode = 'relative' # ['standard', 'relative', 'avd']\n", "find_best_frame = False # when using the relative mode to animate a face, setting find_best_frame=True can give a better quality result\n", "\n", "pixel = 256 # for vox, taichi and mgif, the resolution is 256*256\n", "if(dataset_name == 'ted'): # for ted, the resolution is 384*384\n", " pixel = 384\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Read image and video**" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 453 }, "id": "Oxi6-riLOgnm", "outputId": "d38a8850-9eb1-4de4-9bf2-24cbd847ca1f" }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import imageio\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import matplotlib.animation as animation\n", "from skimage.transform import resize\n", "from IPython.display import HTML\n", "import warnings\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "source_image = imageio.imread(source_image_path)\n", "reader = imageio.get_reader(driving_video_path)\n", "\n", "\n", "source_image = resize(source_image, (pixel, pixel))[..., :3]\n", "\n", "fps = reader.get_meta_data()['fps']\n", "driving_video = []\n", "try:\n", " for im in reader:\n", " driving_video.append(im)\n", "except RuntimeError:\n", " pass\n", 
"reader.close()\n", "\n", "driving_video = [resize(frame, (pixel, pixel))[..., :3] for frame in driving_video]\n", "\n", "def display(source, driving, generated=None):\n", " fig = plt.figure(figsize=(8 + 4 * (generated is not None), 6))\n", "\n", " ims = []\n", " for i in range(len(driving)):\n", " cols = [source]\n", " cols.append(driving[i])\n", " if generated is not None:\n", " cols.append(generated[i])\n", " im = plt.imshow(np.concatenate(cols, axis=1), animated=True)\n", " plt.axis('off')\n", " ims.append([im])\n", "\n", " ani = animation.ArtistAnimation(fig, ims, interval=50, repeat_delay=1000)\n", " plt.close()\n", " return ani\n", " \n", "\n", "HTML(display(source_image, driving_video).to_html5_video())" ] }, { "cell_type": "markdown", "metadata": { "id": "xjM7ubVfWrwT" }, "source": [ "**Create a model and load checkpoints**" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "3FQiXqQPWt5B" }, "outputs": [], "source": [ "from demo import load_checkpoints\n", "inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path = config_path, checkpoint_path = checkpoint_path, device = device)" ] }, { "cell_type": "markdown", "metadata": { "id": "fdFdasHEj3t7" }, "source": [ "**Perform image animation**" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 471 }, "id": "SB12II11kF4c", "outputId": "9e2274aa-fd55-4eed-cb50-bec72fcfb8b9" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 169/169 [00:10<00:00, 15.69it/s]\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from demo import make_animation\n", "from skimage import img_as_ubyte\n", "\n", "if 
predict_mode=='relative' and find_best_frame:\n", " from demo import find_best_frame as _find\n", " i = _find(source_image, driving_video, device.type=='cpu')\n", " print (\"Best frame: \" + str(i))\n", " driving_forward = driving_video[i:]\n", " driving_backward = driving_video[:(i+1)][::-1]\n", " predictions_forward = make_animation(source_image, driving_forward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)\n", " predictions_backward = make_animation(source_image, driving_backward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)\n", " predictions = predictions_backward[::-1] + predictions_forward[1:]\n", "else:\n", " predictions = make_animation(source_image, driving_video, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)\n", "\n", "# save the resulting video\n", "imageio.mimsave(output_video_path, [img_as_ubyte(frame) for frame in predictions], fps=fps)\n", "\n", "HTML(display(source_image, driving_video, predictions).to_html5_video())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "include_colab_link": true, "name": "first-order-model-demo.ipynb", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.10" } }, "nbformat": 4, "nbformat_minor": 4 }