diff --git "a/demo.ipynb" "b/demo.ipynb" new file mode 100644--- /dev/null +++ "b/demo.ipynb" @@ -0,0 +1,5113 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Config**" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "# edit the config\n", + "device = torch.device('cuda:0')\n", + "dataset_name = 'vox' # ['vox', 'taichi', 'ted', 'mgif']\n", + "source_image_path = './assets/source.png'\n", + "driving_video_path = './assets/driving.mp4'\n", + "output_video_path = './generated.mp4'\n", + "config_path = 'config/vox-256.yaml'\n", + "checkpoint_path = 'checkpoints/vox.pth.tar'\n", + "predict_mode = 'relative' # ['standard', 'relative', 'avd']\n", + "find_best_frame = False # when use the relative mode to animate a face, use 'find_best_frame=True' can get better quality result\n", + "\n", + "pixel = 256 # for vox, taichi and mgif, the resolution is 256*256\n", + "if(dataset_name == 'ted'): # for ted, the resolution is 384*384\n", + " pixel = 384\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Read image and video**" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 453 + }, + "id": "Oxi6-riLOgnm", + "outputId": "d38a8850-9eb1-4de4-9bf2-24cbd847ca1f" + }, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import imageio\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.animation as animation\n", + "from skimage.transform import resize\n", + "from IPython.display import HTML\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "source_image = imageio.imread(source_image_path)\n", + "reader = imageio.get_reader(driving_video_path)\n", + "\n", + "\n", + "source_image = resize(source_image, (pixel, pixel))[..., :3]\n", + "\n", + "fps = reader.get_meta_data()['fps']\n", + "driving_video = []\n", + "try:\n", + " for im in reader:\n", + " driving_video.append(im)\n", + "except RuntimeError:\n", + " pass\n", + "reader.close()\n", + "\n", + "driving_video = [resize(frame, (pixel, pixel))[..., :3] for frame in driving_video]\n", + "\n", + "def display(source, driving, generated=None):\n", + " fig = plt.figure(figsize=(8 + 4 * (generated is not None), 6))\n", + "\n", + " ims = []\n", + " for i in range(len(driving)):\n", + " cols = [source]\n", + " cols.append(driving[i])\n", + " if generated is not None:\n", + " cols.append(generated[i])\n", + " im = plt.imshow(np.concatenate(cols, axis=1), animated=True)\n", + " plt.axis('off')\n", + " ims.append([im])\n", + "\n", + " ani = animation.ArtistAnimation(fig, ims, interval=50, repeat_delay=1000)\n", + " plt.close()\n", + " return ani\n", + " \n", + "\n", + "HTML(display(source_image, driving_video).to_html5_video())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xjM7ubVfWrwT" + }, + "source": [ + "**Create a model and load checkpoints**" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "3FQiXqQPWt5B" + }, + "outputs": [], + "source": [ + "from demo import load_checkpoints\n", + "inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path = config_path, checkpoint_path = checkpoint_path, device = device)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fdFdasHEj3t7" + }, + "source": [ + "**Perform image animation**" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 471 + }, + "id": "SB12II11kF4c", + "outputId": "9e2274aa-fd55-4eed-cb50-bec72fcfb8b9" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 169/169 [00:10<00:00, 15.69it/s]\n" + ] + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from demo import make_animation\n", + "from skimage import img_as_ubyte\n", + "\n", + "if predict_mode=='relative' and find_best_frame:\n", + " from demo import find_best_frame as _find\n", + " i = _find(source_image, driving_video, device.type=='cpu')\n", + " print (\"Best frame: \" + str(i))\n", + " driving_forward = driving_video[i:]\n", + " driving_backward = driving_video[:(i+1)][::-1]\n", + " predictions_forward = make_animation(source_image, driving_forward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)\n", + " predictions_backward = make_animation(source_image, driving_backward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)\n", + " predictions = predictions_backward[::-1] + predictions_forward[1:]\n", + "else:\n", + " predictions = make_animation(source_image, driving_video, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)\n", + "\n", + "#save resulting video\n", + "imageio.mimsave(output_video_path, [img_as_ubyte(frame) for frame in predictions], fps=fps)\n", + "\n", + "HTML(display(source_image, driving_video, predictions).to_html5_video())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "include_colab_link": true, + "name": "first-order-model-demo.ipynb", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.10" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}