diff --git "a/demo.ipynb" "b/demo.ipynb"
new file mode 100644--- /dev/null
+++ "b/demo.ipynb"
@@ -0,0 +1,5113 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Config**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "# edit the config\n",
+ "device = torch.device('cuda:0')\n",
+ "dataset_name = 'vox' # ['vox', 'taichi', 'ted', 'mgif']\n",
+ "source_image_path = './assets/source.png'\n",
+ "driving_video_path = './assets/driving.mp4'\n",
+ "output_video_path = './generated.mp4'\n",
+ "config_path = 'config/vox-256.yaml'\n",
+ "checkpoint_path = 'checkpoints/vox.pth.tar'\n",
+ "predict_mode = 'relative' # ['standard', 'relative', 'avd']\n",
+ "find_best_frame = False # when use the relative mode to animate a face, use 'find_best_frame=True' can get better quality result\n",
+ "\n",
+ "pixel = 256 # for vox, taichi and mgif, the resolution is 256*256\n",
+ "if(dataset_name == 'ted'): # for ted, the resolution is 384*384\n",
+ " pixel = 384\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Read image and video**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 453
+ },
+ "id": "Oxi6-riLOgnm",
+ "outputId": "d38a8850-9eb1-4de4-9bf2-24cbd847ca1f"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import imageio\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "import matplotlib.animation as animation\n",
+ "from skimage.transform import resize\n",
+ "from IPython.display import HTML\n",
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "source_image = imageio.imread(source_image_path)\n",
+ "reader = imageio.get_reader(driving_video_path)\n",
+ "\n",
+ "\n",
+ "source_image = resize(source_image, (pixel, pixel))[..., :3]\n",
+ "\n",
+ "fps = reader.get_meta_data()['fps']\n",
+ "driving_video = []\n",
+ "try:\n",
+ " for im in reader:\n",
+ " driving_video.append(im)\n",
+ "except RuntimeError:\n",
+ " pass\n",
+ "reader.close()\n",
+ "\n",
+ "driving_video = [resize(frame, (pixel, pixel))[..., :3] for frame in driving_video]\n",
+ "\n",
+ "def display(source, driving, generated=None):\n",
+ " fig = plt.figure(figsize=(8 + 4 * (generated is not None), 6))\n",
+ "\n",
+ " ims = []\n",
+ " for i in range(len(driving)):\n",
+ " cols = [source]\n",
+ " cols.append(driving[i])\n",
+ " if generated is not None:\n",
+ " cols.append(generated[i])\n",
+ " im = plt.imshow(np.concatenate(cols, axis=1), animated=True)\n",
+ " plt.axis('off')\n",
+ " ims.append([im])\n",
+ "\n",
+ " ani = animation.ArtistAnimation(fig, ims, interval=50, repeat_delay=1000)\n",
+ " plt.close()\n",
+ " return ani\n",
+ " \n",
+ "\n",
+ "HTML(display(source_image, driving_video).to_html5_video())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "xjM7ubVfWrwT"
+ },
+ "source": [
+ "**Create a model and load checkpoints**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "id": "3FQiXqQPWt5B"
+ },
+ "outputs": [],
+ "source": [
+ "from demo import load_checkpoints\n",
+ "inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path = config_path, checkpoint_path = checkpoint_path, device = device)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fdFdasHEj3t7"
+ },
+ "source": [
+ "**Perform image animation**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 471
+ },
+ "id": "SB12II11kF4c",
+ "outputId": "9e2274aa-fd55-4eed-cb50-bec72fcfb8b9"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 169/169 [00:10<00:00, 15.69it/s]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from demo import make_animation\n",
+ "from skimage import img_as_ubyte\n",
+ "\n",
+ "if predict_mode=='relative' and find_best_frame:\n",
+ " from demo import find_best_frame as _find\n",
+ " i = _find(source_image, driving_video, device.type=='cpu')\n",
+ " print (\"Best frame: \" + str(i))\n",
+ " driving_forward = driving_video[i:]\n",
+ " driving_backward = driving_video[:(i+1)][::-1]\n",
+ " predictions_forward = make_animation(source_image, driving_forward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)\n",
+ " predictions_backward = make_animation(source_image, driving_backward, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)\n",
+ " predictions = predictions_backward[::-1] + predictions_forward[1:]\n",
+ "else:\n",
+ " predictions = make_animation(source_image, driving_video, inpainting, kp_detector, dense_motion_network, avd_network, device = device, mode = predict_mode)\n",
+ "\n",
+ "#save resulting video\n",
+ "imageio.mimsave(output_video_path, [img_as_ubyte(frame) for frame in predictions], fps=fps)\n",
+ "\n",
+ "HTML(display(source_image, driving_video, predictions).to_html5_video())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "include_colab_link": true,
+ "name": "first-order-model-demo.ipynb",
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}