diff --git a/.DS_Store b/.DS_Store
index 0b4ccc22b2339ffd7784d2d95769685d08543865..cba5e29f184117d551fda66406c0a96002fd9473 100644
Binary files a/.DS_Store and b/.DS_Store differ
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..e8711cb816d416037617dc0b72b33d866790c3d4
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,57 @@
+FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
+ENV DEBIAN_FRONTEND=noninteractive
+RUN apt-get update && \
+ apt-get upgrade -y && \
+ apt-get install -y --no-install-recommends \
+ git \
+ git-lfs \
+ wget \
+ curl \
+ ffmpeg \
+ x264 \
+ # python build dependencies \
+ build-essential \
+ libssl-dev \
+ zlib1g-dev \
+ libbz2-dev \
+ libreadline-dev \
+ libsqlite3-dev \
+ libncursesw5-dev \
+ xz-utils \
+ tk-dev \
+ libxml2-dev \
+ libxmlsec1-dev \
+ libffi-dev \
+ liblzma-dev && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+RUN useradd -m -u 1000 user
+USER user
+ENV HOME=/home/user \
+ PATH=/home/user/.local/bin:${PATH}
+WORKDIR ${HOME}/app
+
+RUN curl https://pyenv.run | bash
+ENV PATH=${HOME}/.pyenv/shims:${HOME}/.pyenv/bin:${PATH}
+ENV PYTHON_VERSION=3.10.9
+RUN pyenv install ${PYTHON_VERSION} && \
+ pyenv global ${PYTHON_VERSION} && \
+ pyenv rehash && \
+ pip install --no-cache-dir -U pip setuptools wheel
+
+RUN pip install --no-cache-dir -U torch==1.13.1 torchvision==0.14.1
+COPY --chown=1000 requirements.txt /tmp/requirements.txt
+RUN pip install --no-cache-dir -U -r /tmp/requirements.txt
+
+COPY --chown=1000 . ${HOME}/app
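+# Apply the bundled patch, which disables the timestamped output sub-directory in Tune-A-Video's train_tuneavideo.py.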
+RUN cd Tune-A-Video && patch -p1 < ../patch
+ENV PYTHONPATH=${HOME}/app \
+ PYTHONUNBUFFERED=1 \
+ GRADIO_ALLOW_FLAGGING=never \
+ GRADIO_NUM_PORTS=1 \
+ GRADIO_SERVER_NAME=0.0.0.0 \
+ GRADIO_THEME=huggingface \
+ SYSTEM=spaces
+CMD ["python", "app.py"]
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..74a6f299d51d145387ae5ab220e9730719614111
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 hysts
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/Video-P2P-Demo/.DS_Store b/Video-P2P-Demo/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..0b4ccc22b2339ffd7784d2d95769685d08543865
Binary files /dev/null and b/Video-P2P-Demo/.DS_Store differ
diff --git a/Video-P2P-Demo/README.md b/Video-P2P-Demo/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c25d5d0c391136488b32f92d51c853d075c28f2a
--- /dev/null
+++ b/Video-P2P-Demo/README.md
@@ -0,0 +1,35 @@
+---
+title: Video-P2P Demo
+emoji: 🐶
+colorFrom: blue
+colorTo: pink
+sdk: gradio
+app_file: app.py
+pinned: false
+---
+
+# Video-P2P
+
+## Setup
+
+All required packages are listed in `requirements.txt`.
+The code was tested on a Tesla V100 (32GB) but should run on any card with at least **16GB** of VRAM.
+
+## Quickstart
+
+``` bash
+bash script.sh
+```
+
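+This presumably runs the two Video-P2P stages back to back: first fine-tuning the base model on the example clip, then applying the prompt-to-prompt edit. A rough sketch of that pipeline, assuming `run_tuning.py` and `run_videop2p.py` accept a `--config` flag and using the bundled example configs, is:
+
+``` bash
+# Sketch only; see script.sh for the exact commands.
+python run_tuning.py --config configs/rabbit-jump-tune.yaml
+python run_videop2p.py --config configs/rabbit-jump-p2p.yaml
+```
+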
+## References
+* prompt-to-prompt: https://github.com/google/prompt-to-prompt
+* Tune-A-Video: https://github.com/showlab/Tune-A-Video
+* diffusers: https://github.com/huggingface/diffusers
\ No newline at end of file
diff --git a/configs/.DS_Store b/Video-P2P-Demo/configs/.DS_Store
similarity index 100%
rename from configs/.DS_Store
rename to Video-P2P-Demo/configs/.DS_Store
diff --git a/configs/man-motor-tune.yaml b/Video-P2P-Demo/configs/man-motor-tune.yaml
similarity index 100%
rename from configs/man-motor-tune.yaml
rename to Video-P2P-Demo/configs/man-motor-tune.yaml
diff --git a/configs/rabbit-jump-p2p.yaml b/Video-P2P-Demo/configs/rabbit-jump-p2p.yaml
similarity index 100%
rename from configs/rabbit-jump-p2p.yaml
rename to Video-P2P-Demo/configs/rabbit-jump-p2p.yaml
diff --git a/configs/rabbit-jump-tune.yaml b/Video-P2P-Demo/configs/rabbit-jump-tune.yaml
similarity index 100%
rename from configs/rabbit-jump-tune.yaml
rename to Video-P2P-Demo/configs/rabbit-jump-tune.yaml
diff --git a/data/.DS_Store b/Video-P2P-Demo/data/.DS_Store
similarity index 100%
rename from data/.DS_Store
rename to Video-P2P-Demo/data/.DS_Store
diff --git a/data/motorbike/.DS_Store b/Video-P2P-Demo/data/motorbike/.DS_Store
similarity index 100%
rename from data/motorbike/.DS_Store
rename to Video-P2P-Demo/data/motorbike/.DS_Store
diff --git a/data/motorbike/1.jpg b/Video-P2P-Demo/data/motorbike/1.jpg
similarity index 100%
rename from data/motorbike/1.jpg
rename to Video-P2P-Demo/data/motorbike/1.jpg
diff --git a/data/motorbike/2.jpg b/Video-P2P-Demo/data/motorbike/2.jpg
similarity index 100%
rename from data/motorbike/2.jpg
rename to Video-P2P-Demo/data/motorbike/2.jpg
diff --git a/data/motorbike/3.jpg b/Video-P2P-Demo/data/motorbike/3.jpg
similarity index 100%
rename from data/motorbike/3.jpg
rename to Video-P2P-Demo/data/motorbike/3.jpg
diff --git a/data/motorbike/4.jpg b/Video-P2P-Demo/data/motorbike/4.jpg
similarity index 100%
rename from data/motorbike/4.jpg
rename to Video-P2P-Demo/data/motorbike/4.jpg
diff --git a/data/motorbike/5.jpg b/Video-P2P-Demo/data/motorbike/5.jpg
similarity index 100%
rename from data/motorbike/5.jpg
rename to Video-P2P-Demo/data/motorbike/5.jpg
diff --git a/data/motorbike/6.jpg b/Video-P2P-Demo/data/motorbike/6.jpg
similarity index 100%
rename from data/motorbike/6.jpg
rename to Video-P2P-Demo/data/motorbike/6.jpg
diff --git a/data/motorbike/7.jpg b/Video-P2P-Demo/data/motorbike/7.jpg
similarity index 100%
rename from data/motorbike/7.jpg
rename to Video-P2P-Demo/data/motorbike/7.jpg
diff --git a/data/motorbike/8.jpg b/Video-P2P-Demo/data/motorbike/8.jpg
similarity index 100%
rename from data/motorbike/8.jpg
rename to Video-P2P-Demo/data/motorbike/8.jpg
diff --git a/data/rabbit/1.jpg b/Video-P2P-Demo/data/rabbit/1.jpg
similarity index 100%
rename from data/rabbit/1.jpg
rename to Video-P2P-Demo/data/rabbit/1.jpg
diff --git a/data/rabbit/2.jpg b/Video-P2P-Demo/data/rabbit/2.jpg
similarity index 100%
rename from data/rabbit/2.jpg
rename to Video-P2P-Demo/data/rabbit/2.jpg
diff --git a/data/rabbit/3.jpg b/Video-P2P-Demo/data/rabbit/3.jpg
similarity index 100%
rename from data/rabbit/3.jpg
rename to Video-P2P-Demo/data/rabbit/3.jpg
diff --git a/data/rabbit/4.jpg b/Video-P2P-Demo/data/rabbit/4.jpg
similarity index 100%
rename from data/rabbit/4.jpg
rename to Video-P2P-Demo/data/rabbit/4.jpg
diff --git a/data/rabbit/5.jpg b/Video-P2P-Demo/data/rabbit/5.jpg
similarity index 100%
rename from data/rabbit/5.jpg
rename to Video-P2P-Demo/data/rabbit/5.jpg
diff --git a/data/rabbit/6.jpg b/Video-P2P-Demo/data/rabbit/6.jpg
similarity index 100%
rename from data/rabbit/6.jpg
rename to Video-P2P-Demo/data/rabbit/6.jpg
diff --git a/data/rabbit/7.jpg b/Video-P2P-Demo/data/rabbit/7.jpg
similarity index 100%
rename from data/rabbit/7.jpg
rename to Video-P2P-Demo/data/rabbit/7.jpg
diff --git a/data/rabbit/8.jpg b/Video-P2P-Demo/data/rabbit/8.jpg
similarity index 100%
rename from data/rabbit/8.jpg
rename to Video-P2P-Demo/data/rabbit/8.jpg
diff --git a/ptp_utils.py b/Video-P2P-Demo/ptp_utils.py
similarity index 100%
rename from ptp_utils.py
rename to Video-P2P-Demo/ptp_utils.py
diff --git a/Video-P2P-Demo/requirements.txt b/Video-P2P-Demo/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fbaa5d73d3c11116aa7d1703c54111f81571bc63
--- /dev/null
+++ b/Video-P2P-Demo/requirements.txt
@@ -0,0 +1,15 @@
+torch==1.12.1
+torchvision==0.13.1
+diffusers[torch]==0.11.1
+transformers>=4.25.1
+bitsandbytes==0.35.4
+decord==0.6.0
+accelerate
+tensorboard
+modelcards
+omegaconf
+einops
+imageio
+ftfy
+opencv-python
+ipywidgets
\ No newline at end of file
diff --git a/run_tuning.py b/Video-P2P-Demo/run_tuning.py
similarity index 100%
rename from run_tuning.py
rename to Video-P2P-Demo/run_tuning.py
diff --git a/run_videop2p.py b/Video-P2P-Demo/run_videop2p.py
similarity index 100%
rename from run_videop2p.py
rename to Video-P2P-Demo/run_videop2p.py
diff --git a/script.sh b/Video-P2P-Demo/script.sh
similarity index 100%
rename from script.sh
rename to Video-P2P-Demo/script.sh
diff --git a/seq_aligner.py b/Video-P2P-Demo/seq_aligner.py
similarity index 100%
rename from seq_aligner.py
rename to Video-P2P-Demo/seq_aligner.py
diff --git a/tuneavideo/data/dataset.py b/Video-P2P-Demo/tuneavideo/data/dataset.py
similarity index 100%
rename from tuneavideo/data/dataset.py
rename to Video-P2P-Demo/tuneavideo/data/dataset.py
diff --git a/tuneavideo/models/attention.py b/Video-P2P-Demo/tuneavideo/models/attention.py
similarity index 100%
rename from tuneavideo/models/attention.py
rename to Video-P2P-Demo/tuneavideo/models/attention.py
diff --git a/tuneavideo/models/resnet.py b/Video-P2P-Demo/tuneavideo/models/resnet.py
similarity index 100%
rename from tuneavideo/models/resnet.py
rename to Video-P2P-Demo/tuneavideo/models/resnet.py
diff --git a/tuneavideo/models/unet.py b/Video-P2P-Demo/tuneavideo/models/unet.py
similarity index 100%
rename from tuneavideo/models/unet.py
rename to Video-P2P-Demo/tuneavideo/models/unet.py
diff --git a/tuneavideo/models/unet_blocks.py b/Video-P2P-Demo/tuneavideo/models/unet_blocks.py
similarity index 100%
rename from tuneavideo/models/unet_blocks.py
rename to Video-P2P-Demo/tuneavideo/models/unet_blocks.py
diff --git a/tuneavideo/pipelines/pipeline_tuneavideo.py b/Video-P2P-Demo/tuneavideo/pipelines/pipeline_tuneavideo.py
similarity index 100%
rename from tuneavideo/pipelines/pipeline_tuneavideo.py
rename to Video-P2P-Demo/tuneavideo/pipelines/pipeline_tuneavideo.py
diff --git a/tuneavideo/util.py b/Video-P2P-Demo/tuneavideo/util.py
similarity index 100%
rename from tuneavideo/util.py
rename to Video-P2P-Demo/tuneavideo/util.py
diff --git a/app.py b/app.py
new file mode 100755
index 0000000000000000000000000000000000000000..3989b9cac50d39f3348061b4e346bda641265a82
--- /dev/null
+++ b/app.py
@@ -0,0 +1,84 @@
+#!/usr/bin/env python
+
+from __future__ import annotations
+
+import os
+from subprocess import getoutput
+
+import gradio as gr
+import torch
+
+# from app_inference import create_inference_demo
+from app_training import create_training_demo
+# from app_upload import create_upload_demo
+from inference import InferencePipeline
+from trainer import Trainer
+
+TITLE = '# [Video-P2P](https://video-p2p.github.io/) UI'
+
+ORIGINAL_SPACE_ID = 'Shaldon/Video-P2P-Training-UI'
+SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
+GPU_DATA = getoutput('nvidia-smi')
+SHARED_UI_WARNING = '''## Attention - Training does not work in this shared UI. Duplicate this Space and run it on a paid private T4 GPU.
+
+
+'''
+
+if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
+    SETTINGS = f'[Settings](https://huggingface.co/spaces/{SPACE_ID}/settings)'
+else:
+ SETTINGS = 'Settings'
+
+INVALID_GPU_WARNING = '''## Attention - A T4 GPU was not detected, so training may not work. Make sure you have selected a `T4 GPU` for this task.'''
+
+CUDA_NOT_AVAILABLE_WARNING = f'''## Attention - Running on CPU.
+
+You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
+You can use "T4 small/medium" to run this demo.
+
+'''
+
+HF_TOKEN_NOT_SPECIFIED_WARNING = f'''The environment variable `HF_TOKEN` is not set. Setting a Hugging Face token with write permission as a secret saves you from entering it manually for every run.
+
+You can check and create your Hugging Face tokens at https://huggingface.co/settings/tokens.
+You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
+
+'''
+
+HF_TOKEN = os.getenv('HF_TOKEN')
+
+
+def show_warning(warning_text: str) -> gr.Blocks:
+ with gr.Blocks() as demo:
+ with gr.Box():
+ gr.Markdown(warning_text)
+ return demo
+
+
+pipe = InferencePipeline(HF_TOKEN)
+trainer = Trainer(HF_TOKEN)
+
+with gr.Blocks(css='style.css') as demo:
+ if SPACE_ID == ORIGINAL_SPACE_ID:
+ show_warning(SHARED_UI_WARNING)
+ elif not torch.cuda.is_available():
+ show_warning(CUDA_NOT_AVAILABLE_WARNING)
+    elif 'T4' not in GPU_DATA:
+ show_warning(INVALID_GPU_WARNING)
+
+ gr.Markdown(TITLE)
+ with gr.Tabs():
+ with gr.TabItem('Train'):
+ create_training_demo(trainer, pipe)
+ # with gr.TabItem('Run'):
+ # create_inference_demo(pipe, HF_TOKEN)
+ # with gr.TabItem('Upload'):
+ # gr.Markdown('''
+ # - You can use this tab to upload models later if you choose not to upload models in training time or if upload in training time failed.
+ # ''')
+ # create_upload_demo(HF_TOKEN)
+
+ if not HF_TOKEN:
+ show_warning(HF_TOKEN_NOT_SPECIFIED_WARNING)
+
+demo.queue(max_size=1).launch(share=False)
diff --git a/app_inference.py b/app_inference.py
new file mode 100755
index 0000000000000000000000000000000000000000..d705504e5bc7a8938e1b5fcfb207f4cb731c866b
--- /dev/null
+++ b/app_inference.py
@@ -0,0 +1,170 @@
+#!/usr/bin/env python
+
+from __future__ import annotations
+
+import enum
+
+import gradio as gr
+from huggingface_hub import HfApi
+
+from constants import MODEL_LIBRARY_ORG_NAME, UploadTarget
+from inference import InferencePipeline
+from utils import find_exp_dirs
+
+
+class ModelSource(enum.Enum):
+ HUB_LIB = UploadTarget.MODEL_LIBRARY.value
+ LOCAL = 'Local'
+
+
+class InferenceUtil:
+ def __init__(self, hf_token: str | None):
+ self.hf_token = hf_token
+
+ def load_hub_model_list(self) -> dict:
+ api = HfApi(token=self.hf_token)
+ choices = [
+ info.modelId
+ for info in api.list_models(author=MODEL_LIBRARY_ORG_NAME)
+ ]
+ return gr.update(choices=choices,
+ value=choices[0] if choices else None)
+
+ @staticmethod
+ def load_local_model_list() -> dict:
+ choices = find_exp_dirs()
+ return gr.update(choices=choices,
+ value=choices[0] if choices else None)
+
+ def reload_model_list(self, model_source: str) -> dict:
+ if model_source == ModelSource.HUB_LIB.value:
+ return self.load_hub_model_list()
+ elif model_source == ModelSource.LOCAL.value:
+ return self.load_local_model_list()
+ else:
+            raise ValueError(f'Unknown model source: {model_source}')
+
+ def load_model_info(self, model_id: str) -> tuple[str, str]:
+ try:
+ card = InferencePipeline.get_model_card(model_id, self.hf_token)
+ except Exception:
+ return '', ''
+ base_model = getattr(card.data, 'base_model', '')
+ training_prompt = getattr(card.data, 'training_prompt', '')
+ return base_model, training_prompt
+
+ def reload_model_list_and_update_model_info(
+ self, model_source: str) -> tuple[dict, str, str]:
+ model_list_update = self.reload_model_list(model_source)
+ model_list = model_list_update['choices']
+ model_info = self.load_model_info(model_list[0] if model_list else '')
+ return model_list_update, *model_info
+
+
+def create_inference_demo(pipe: InferencePipeline,
+ hf_token: str | None = None) -> gr.Blocks:
+ app = InferenceUtil(hf_token)
+
+ with gr.Blocks() as demo:
+ with gr.Row():
+ with gr.Column():
+ with gr.Box():
+ model_source = gr.Radio(
+ label='Model Source',
+ choices=[_.value for _ in ModelSource],
+ value=ModelSource.HUB_LIB.value)
+ reload_button = gr.Button('Reload Model List')
+ model_id = gr.Dropdown(label='Model ID',
+ choices=None,
+ value=None)
+ with gr.Accordion(
+ label=
+ 'Model info (Base model and prompt used for training)',
+ open=False):
+ with gr.Row():
+ base_model_used_for_training = gr.Text(
+ label='Base model', interactive=False)
+ prompt_used_for_training = gr.Text(
+ label='Training prompt', interactive=False)
+ prompt = gr.Textbox(
+ label='Prompt',
+ max_lines=1,
+ placeholder='Example: "A panda is surfing"')
+ video_length = gr.Slider(label='Video length',
+ minimum=4,
+ maximum=12,
+ step=1,
+ value=8)
+ fps = gr.Slider(label='FPS',
+ minimum=1,
+ maximum=12,
+ step=1,
+ value=1)
+ seed = gr.Slider(label='Seed',
+ minimum=0,
+ maximum=100000,
+ step=1,
+ value=0)
+ with gr.Accordion('Other Parameters', open=False):
+ num_steps = gr.Slider(label='Number of Steps',
+ minimum=0,
+ maximum=100,
+ step=1,
+ value=50)
+ guidance_scale = gr.Slider(label='CFG Scale',
+ minimum=0,
+ maximum=50,
+ step=0.1,
+ value=7.5)
+
+ run_button = gr.Button('Generate')
+
+ gr.Markdown('''
+                - After training, you can press the "Reload Model List" button to load your trained model names.
+                - It takes a few minutes to download the model the first time.
+                - Expected time to generate an 8-frame video: 70 seconds on a T4, 24 seconds on an A10G, 10 seconds on an A100.
+ ''')
+ with gr.Column():
+ result = gr.Video(label='Result')
+
+ model_source.change(fn=app.reload_model_list_and_update_model_info,
+ inputs=model_source,
+ outputs=[
+ model_id,
+ base_model_used_for_training,
+ prompt_used_for_training,
+ ])
+ reload_button.click(fn=app.reload_model_list_and_update_model_info,
+ inputs=model_source,
+ outputs=[
+ model_id,
+ base_model_used_for_training,
+ prompt_used_for_training,
+ ])
+ model_id.change(fn=app.load_model_info,
+ inputs=model_id,
+ outputs=[
+ base_model_used_for_training,
+ prompt_used_for_training,
+ ])
+ inputs = [
+ model_id,
+ prompt,
+ video_length,
+ fps,
+ seed,
+ num_steps,
+ guidance_scale,
+ ]
+ prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
+ run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
+ return demo
+
+
+if __name__ == '__main__':
+ import os
+
+ hf_token = os.getenv('HF_TOKEN')
+ pipe = InferencePipeline(hf_token)
+ demo = create_inference_demo(pipe, hf_token)
+ demo.queue(max_size=10).launch(share=False)
diff --git a/app_training.py b/app_training.py
new file mode 100755
index 0000000000000000000000000000000000000000..edf85e9bc3f92ce1940904210830faa762a89bbd
--- /dev/null
+++ b/app_training.py
@@ -0,0 +1,136 @@
+#!/usr/bin/env python
+
+from __future__ import annotations
+
+import os
+
+import gradio as gr
+
+from constants import MODEL_LIBRARY_ORG_NAME, SAMPLE_MODEL_REPO, UploadTarget
+from inference import InferencePipeline
+from trainer import Trainer
+
+
+def create_training_demo(trainer: Trainer,
+ pipe: InferencePipeline | None = None) -> gr.Blocks:
+ hf_token = os.getenv('HF_TOKEN')
+ with gr.Blocks() as demo:
+ with gr.Row():
+ with gr.Column():
+ with gr.Box():
+ gr.Markdown('Training Data')
+ training_video = gr.File(label='Training video')
+ training_prompt = gr.Textbox(
+ label='Training prompt',
+ max_lines=1,
+ placeholder='A rabbit is jumping on the grass')
+ gr.Markdown('''
+ - Upload a video and write a `Training Prompt` that describes the video.
+ ''')
+
+ with gr.Column():
+ with gr.Box():
+ gr.Markdown('Training Parameters')
+ with gr.Row():
+ base_model = gr.Text(
+ label='Base Model',
+                            value='runwayml/stable-diffusion-v1-5',
+ max_lines=1)
+ resolution = gr.Dropdown(choices=['512', '768'],
+ value='512',
+ label='Resolution',
+ visible=False)
+
+ input_token = gr.Text(label='Hugging Face Write Token',
+ placeholder='',
+ visible=False if hf_token else True)
+ with gr.Accordion('Advanced settings', open=False):
+ num_training_steps = gr.Number(
+ label='Number of Training Steps',
+ value=300,
+ precision=0)
+ learning_rate = gr.Number(label='Learning Rate',
+ value=0.000035)
+ gradient_accumulation = gr.Number(
+                            label='Gradient Accumulation Steps',
+ value=1,
+ precision=0)
+ seed = gr.Slider(label='Seed',
+ minimum=0,
+ maximum=100000,
+ step=1,
+ randomize=True,
+ value=0)
+ fp16 = gr.Checkbox(label='FP16', value=True)
+ use_8bit_adam = gr.Checkbox(label='Use 8bit Adam',
+ value=False)
+ checkpointing_steps = gr.Number(
+ label='Checkpointing Steps',
+ value=1000,
+ precision=0)
+ validation_epochs = gr.Number(
+ label='Validation Epochs', value=100, precision=0)
+ gr.Markdown('''
+                    - The base model must be a Stable Diffusion model compatible with the [diffusers](https://github.com/huggingface/diffusers) library.
+                    - Expected time to train a model for 300 steps: ~20 minutes on a T4.
+ - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
+ ''')
+
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown('Output Model')
+ output_model_name = gr.Text(label='Name of your model',
+ placeholder='The surfer man',
+ max_lines=1)
+ validation_prompt = gr.Text(
+ label='Validation Prompt',
+ placeholder=
+                    'Prompt to test the model, e.g. "A dog is surfing"')
+ with gr.Column():
+ gr.Markdown('Upload Settings')
+ with gr.Row():
+ upload_to_hub = gr.Checkbox(label='Upload model to Hub',
+ value=True)
+ use_private_repo = gr.Checkbox(label='Private', value=True)
+ delete_existing_repo = gr.Checkbox(
+ label='Delete existing repo of the same name',
+ value=False)
+ upload_to = gr.Radio(
+ label='Upload to',
+ choices=[_.value for _ in UploadTarget],
+ value=UploadTarget.MODEL_LIBRARY.value)
+
+ remove_gpu_after_training = gr.Checkbox(
+ label='Remove GPU after training',
+ value=False,
+ interactive=bool(os.getenv('SPACE_ID')),
+ visible=False)
+ run_button = gr.Button('Start Training')
+
+ with gr.Box():
+ gr.Markdown('Output message')
+ output_message = gr.Markdown()
+
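+        # Clicking "Start Training" first clears the loaded inference pipeline to free GPU memory, then starts the training run.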
+ if pipe is not None:
+ run_button.click(fn=pipe.clear)
+ run_button.click(
+ fn=trainer.run,
+ inputs=[
+ training_video, training_prompt, output_model_name,
+ delete_existing_repo, validation_prompt, base_model,
+ resolution, num_training_steps, learning_rate,
+ gradient_accumulation, seed, fp16, use_8bit_adam,
+ checkpointing_steps, validation_epochs, upload_to_hub,
+ use_private_repo, delete_existing_repo, upload_to,
+ remove_gpu_after_training, input_token
+ ],
+ outputs=output_message)
+ return demo
+
+
+if __name__ == '__main__':
+ hf_token = os.getenv('HF_TOKEN')
+ trainer = Trainer(hf_token)
+ demo = create_training_demo(trainer)
+ demo.queue(max_size=1).launch(share=False)
diff --git a/app_upload.py b/app_upload.py
new file mode 100755
index 0000000000000000000000000000000000000000..f672f555512b456d95d8f674fa832b1c9bf34309
--- /dev/null
+++ b/app_upload.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python
+
+from __future__ import annotations
+
+import pathlib
+
+import gradio as gr
+import slugify
+
+from constants import MODEL_LIBRARY_ORG_NAME, UploadTarget
+from uploader import Uploader
+from utils import find_exp_dirs
+
+
+class ModelUploader(Uploader):
+ def upload_model(
+ self,
+ folder_path: str,
+ repo_name: str,
+ upload_to: str,
+ private: bool,
+ delete_existing_repo: bool,
+ input_token: str | None = None,
+ ) -> str:
+ if not folder_path:
+            raise ValueError('folder_path must not be empty.')
+ if not repo_name:
+ repo_name = pathlib.Path(folder_path).name
+ repo_name = slugify.slugify(repo_name)
+
+ if upload_to == UploadTarget.PERSONAL_PROFILE.value:
+ organization = ''
+ elif upload_to == UploadTarget.MODEL_LIBRARY.value:
+ organization = MODEL_LIBRARY_ORG_NAME
+ else:
+ raise ValueError
+
+ return self.upload(folder_path,
+ repo_name,
+ organization=organization,
+ private=private,
+ delete_existing_repo=delete_existing_repo,
+ input_token=input_token)
+
+
+def load_local_model_list() -> dict:
+ choices = find_exp_dirs()
+ return gr.update(choices=choices, value=choices[0] if choices else None)
+
+
+def create_upload_demo(hf_token: str | None) -> gr.Blocks:
+ uploader = ModelUploader(hf_token)
+ model_dirs = find_exp_dirs()
+
+ with gr.Blocks() as demo:
+ with gr.Box():
+ gr.Markdown('Local Models')
+ reload_button = gr.Button('Reload Model List')
+ model_dir = gr.Dropdown(
+ label='Model names',
+ choices=model_dirs,
+ value=model_dirs[0] if model_dirs else None)
+ with gr.Box():
+ gr.Markdown('Upload Settings')
+ with gr.Row():
+ use_private_repo = gr.Checkbox(label='Private', value=True)
+ delete_existing_repo = gr.Checkbox(
+ label='Delete existing repo of the same name', value=False)
+ upload_to = gr.Radio(label='Upload to',
+ choices=[_.value for _ in UploadTarget],
+ value=UploadTarget.MODEL_LIBRARY.value)
+ model_name = gr.Textbox(label='Model Name')
+ input_token = gr.Text(label='Hugging Face Write Token',
+ placeholder='',
+ visible=False if hf_token else True)
+ upload_button = gr.Button('Upload')
+ gr.Markdown(f'''
+ - You can upload your trained model to your personal profile (i.e. https://huggingface.co/{{your_username}}/{{model_name}}) or to the public [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (i.e. https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}/{{model_name}}).
+ ''')
+ with gr.Box():
+ gr.Markdown('Output message')
+ output_message = gr.Markdown()
+
+ reload_button.click(fn=load_local_model_list,
+ inputs=None,
+ outputs=model_dir)
+ upload_button.click(fn=uploader.upload_model,
+ inputs=[
+ model_dir,
+ model_name,
+ upload_to,
+ use_private_repo,
+ delete_existing_repo,
+ input_token,
+ ],
+ outputs=output_message)
+
+ return demo
+
+
+if __name__ == '__main__':
+ import os
+
+ hf_token = os.getenv('HF_TOKEN')
+ demo = create_upload_demo(hf_token)
+ demo.queue(max_size=1).launch(share=False)
diff --git a/constants.py b/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fb6e1f7ea852e729e950861e4e5beb4e1e38b75
--- /dev/null
+++ b/constants.py
@@ -0,0 +1,10 @@
+import enum
+
+
+class UploadTarget(enum.Enum):
+ PERSONAL_PROFILE = 'Personal Profile'
+ MODEL_LIBRARY = 'Tune-A-Video Library'
+
+
+MODEL_LIBRARY_ORG_NAME = 'Tune-A-Video-library'
+SAMPLE_MODEL_REPO = 'Tune-A-Video-library/a-man-is-surfing'
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..65ea5c55ce0fdacb37cf6945699824368bb5ff70
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,111 @@
+from __future__ import annotations
+
+import gc
+import pathlib
+import sys
+import tempfile
+
+import gradio as gr
+import imageio
+import PIL.Image
+import torch
+from diffusers.utils.import_utils import is_xformers_available
+from einops import rearrange
+from huggingface_hub import ModelCard
+
+sys.path.append('Tune-A-Video')
+
+from tuneavideo.models.unet import UNet3DConditionModel
+from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
+
+
+class InferencePipeline:
+ def __init__(self, hf_token: str | None = None):
+ self.hf_token = hf_token
+ self.pipe = None
+ self.device = torch.device(
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
+ self.model_id = None
+
+ def clear(self) -> None:
+ self.model_id = None
+ del self.pipe
+ self.pipe = None
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ @staticmethod
+ def check_if_model_is_local(model_id: str) -> bool:
+ return pathlib.Path(model_id).exists()
+
+ @staticmethod
+ def get_model_card(model_id: str,
+ hf_token: str | None = None) -> ModelCard:
+ if InferencePipeline.check_if_model_is_local(model_id):
+ card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
+ else:
+ card_path = model_id
+ return ModelCard.load(card_path, token=hf_token)
+
+ @staticmethod
+ def get_base_model_info(model_id: str, hf_token: str | None = None) -> str:
+ card = InferencePipeline.get_model_card(model_id, hf_token)
+ return card.data.base_model
+
+ def load_pipe(self, model_id: str) -> None:
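+        # Re-use the cached pipeline when the requested model is already loaded.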
+ if model_id == self.model_id:
+ return
+ base_model_id = self.get_base_model_info(model_id, self.hf_token)
+ unet = UNet3DConditionModel.from_pretrained(
+ model_id,
+ subfolder='unet',
+ torch_dtype=torch.float16,
+ use_auth_token=self.hf_token)
+ pipe = TuneAVideoPipeline.from_pretrained(base_model_id,
+ unet=unet,
+ torch_dtype=torch.float16,
+ use_auth_token=self.hf_token)
+ pipe = pipe.to(self.device)
+ if is_xformers_available():
+ pipe.unet.enable_xformers_memory_efficient_attention()
+ self.pipe = pipe
+ self.model_id = model_id # type: ignore
+
+ def run(
+ self,
+ model_id: str,
+ prompt: str,
+ video_length: int,
+ fps: int,
+ seed: int,
+ n_steps: int,
+ guidance_scale: float,
+    ) -> str:
+ if not torch.cuda.is_available():
+ raise gr.Error('CUDA is not available.')
+
+ self.load_pipe(model_id)
+
+ generator = torch.Generator(device=self.device).manual_seed(seed)
+ out = self.pipe(
+ prompt,
+ video_length=video_length,
+ width=512,
+ height=512,
+ num_inference_steps=n_steps,
+ guidance_scale=guidance_scale,
+ generator=generator,
+ ) # type: ignore
+
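+        # out.videos[0] is a (c, t, h, w) tensor in [0, 1]; convert it to per-frame HWC uint8 arrays for the video writer.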
+ frames = rearrange(out.videos[0], 'c t h w -> t h w c')
+ frames = (frames * 255).to(torch.uint8).numpy()
+
+ out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
+ writer = imageio.get_writer(out_file.name, fps=fps)
+ for frame in frames:
+ writer.append_data(frame)
+ writer.close()
+
+ return out_file.name
diff --git a/packages.txt b/packages.txt
new file mode 100644
index 0000000000000000000000000000000000000000..20645e641240cb419f5fc66c14c1447e91daf669
--- /dev/null
+++ b/packages.txt
@@ -0,0 +1 @@
+ffmpeg
diff --git a/patch b/patch
new file mode 100644
index 0000000000000000000000000000000000000000..5f637d39325bf19d3fa51ee953d6dacd993b51ea
--- /dev/null
+++ b/patch
@@ -0,0 +1,15 @@
+diff --git a/train_tuneavideo.py b/train_tuneavideo.py
+index 66d51b2..86b2a5d 100644
+--- a/train_tuneavideo.py
++++ b/train_tuneavideo.py
+@@ -94,8 +94,8 @@ def main(
+
+ # Handle the output folder creation
+ if accelerator.is_main_process:
+- now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+- output_dir = os.path.join(output_dir, now)
++ #now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
++ #output_dir = os.path.join(output_dir, now)
+ os.makedirs(output_dir, exist_ok=True)
+ OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
+
diff --git a/style.css b/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..c4739b4ea5fc35e774a049e3dacc443f7f0eac19
--- /dev/null
+++ b/style.css
@@ -0,0 +1,3 @@
+h1 {
+ text-align: center;
+}
diff --git a/trainer.py b/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f91dd390a535becbf2d0f83cb854ab04a5319dee
--- /dev/null
+++ b/trainer.py
@@ -0,0 +1,168 @@
+from __future__ import annotations
+
+import datetime
+import os
+import pathlib
+import shlex
+import shutil
+import subprocess
+import sys
+
+import gradio as gr
+import slugify
+import torch
+from huggingface_hub import HfApi
+from omegaconf import OmegaConf
+
+from app_upload import ModelUploader
+from utils import save_model_card
+
+sys.path.append('Video-P2P-Demo')
+
+# URL_TO_JOIN_MODEL_LIBRARY_ORG = 'https://huggingface.co/organizations/Tune-A-Video-library/share/YjTcaNJmKyeHFpMBioHhzBcTzCYddVErEk'
+ORIGINAL_SPACE_ID = 'Shaldon/Video-P2P-Training-UI'
+SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
+
+
+class Trainer:
+ def __init__(self, hf_token: str | None = None):
+ self.hf_token = hf_token
+ self.model_uploader = ModelUploader(hf_token)
+
+ self.checkpoint_dir = pathlib.Path('checkpoints')
+ self.checkpoint_dir.mkdir(exist_ok=True)
+
+ def download_base_model(self, base_model_id: str) -> str:
+ model_dir = self.checkpoint_dir / base_model_id
+ if not model_dir.exists():
+ org_name = base_model_id.split('/')[0]
+ org_dir = self.checkpoint_dir / org_name
+ org_dir.mkdir(exist_ok=True)
+ subprocess.run(shlex.split(
+ f'git clone https://huggingface.co/{base_model_id}'),
+ cwd=org_dir)
+ return model_dir.as_posix()
+
+ # def join_model_library_org(self, token: str) -> None:
+ # subprocess.run(
+ # shlex.split(
+ # f'curl -X POST -H "Authorization: Bearer {token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}'
+ # ))
+
+ def run(
+ self,
+ training_video: str,
+ training_prompt: str,
+ output_model_name: str,
+ overwrite_existing_model: bool,
+ validation_prompt: str,
+ base_model: str,
+ resolution_s: str,
+ n_steps: int,
+ learning_rate: float,
+ gradient_accumulation: int,
+ seed: int,
+ fp16: bool,
+ use_8bit_adam: bool,
+ checkpointing_steps: int,
+ validation_epochs: int,
+ upload_to_hub: bool,
+ use_private_repo: bool,
+ delete_existing_repo: bool,
+ upload_to: str,
+ remove_gpu_after_training: bool,
+ input_token: str,
+ ) -> str:
+ if SPACE_ID == ORIGINAL_SPACE_ID:
+ raise gr.Error(
+                'Training does not work in this shared UI. Duplicate the Space and assign a GPU to it.'
+ )
+ if not torch.cuda.is_available():
+ raise gr.Error('CUDA is not available.')
+ if training_video is None:
+ raise gr.Error('You need to upload a video.')
+ if not training_prompt:
+ raise gr.Error('The training prompt is missing.')
+ if not validation_prompt:
+ raise gr.Error('The validation prompt is missing.')
+
+ resolution = int(resolution_s)
+
+ if not output_model_name:
+ timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
+ output_model_name = f'video-p2p-{timestamp}'
+ output_model_name = slugify.slugify(output_model_name)
+
+ repo_dir = pathlib.Path(__file__).parent
+ output_dir = repo_dir / 'experiments' / output_model_name
+ if overwrite_existing_model or upload_to_hub:
+ shutil.rmtree(output_dir, ignore_errors=True)
+ output_dir.mkdir(parents=True)
+
+ # if upload_to_hub:
+ # self.join_model_library_org(
+ # self.hf_token if self.hf_token else input_token)
+
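+        # Load the bundled example config and override it with the values chosen in the UI.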
+ config = OmegaConf.load('Video-P2P-Demo/configs/rabbit-jump-tune.yaml')
+ config.pretrained_model_path = self.download_base_model(base_model)
+ config.output_dir = output_dir.as_posix()
+ config.train_data.video_path = training_video.name # type: ignore
+ config.train_data.prompt = training_prompt
+ config.train_data.n_sample_frames = 8
+ config.train_data.width = resolution
+ config.train_data.height = resolution
+ config.train_data.sample_start_idx = 0
+ config.train_data.sample_frame_rate = 1
+ config.validation_data.prompts = [validation_prompt]
+ config.validation_data.video_length = 8
+ config.validation_data.width = resolution
+ config.validation_data.height = resolution
+ config.validation_data.num_inference_steps = 50
+ config.validation_data.guidance_scale = 7.5
+ config.learning_rate = learning_rate
+ config.gradient_accumulation_steps = gradient_accumulation
+ config.train_batch_size = 1
+ config.max_train_steps = n_steps
+ config.checkpointing_steps = checkpointing_steps
+ config.validation_steps = validation_epochs
+ config.seed = seed
+ config.mixed_precision = 'fp16' if fp16 else ''
+ config.use_8bit_adam = use_8bit_adam
+
+ config_path = output_dir / 'config.yaml'
+ with open(config_path, 'w') as f:
+ OmegaConf.save(config, f)
+
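+        # Launch fine-tuning in a subprocess via accelerate, pointing it at the config written above.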
+ command = f'accelerate launch Video-P2P-Demo/train_tuneavideo.py --config {config_path}'
+ subprocess.run(shlex.split(command))
+ save_model_card(save_dir=output_dir,
+ base_model=base_model,
+ training_prompt=training_prompt,
+ test_prompt=validation_prompt,
+ test_image_dir='samples')
+
+ message = 'Training completed!'
+ print(message)
+
+ if upload_to_hub:
+ upload_message = self.model_uploader.upload_model(
+ folder_path=output_dir.as_posix(),
+ repo_name=output_model_name,
+ upload_to=upload_to,
+ private=use_private_repo,
+ delete_existing_repo=delete_existing_repo,
+ input_token=input_token)
+ print(upload_message)
+ message = message + '\n' + upload_message
+
+ if remove_gpu_after_training:
+ space_id = os.getenv('SPACE_ID')
+ if space_id:
+ api = HfApi(
+ token=self.hf_token if self.hf_token else input_token)
+ api.request_space_hardware(repo_id=space_id,
+ hardware='cpu-basic')
+
+ return message
diff --git a/uploader.py b/uploader.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9e06ec02127db34016d3d7b550e88f820a737fe
--- /dev/null
+++ b/uploader.py
@@ -0,0 +1,45 @@
+from __future__ import annotations
+
+from huggingface_hub import HfApi
+
+
+class Uploader:
+ def __init__(self, hf_token: str | None):
+ self.hf_token = hf_token
+
+ def upload(self,
+ folder_path: str,
+ repo_name: str,
+ organization: str = '',
+ repo_type: str = 'model',
+ private: bool = True,
+ delete_existing_repo: bool = False,
+ input_token: str | None = None) -> str:
+
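+        # Prefer the token given at construction time; fall back to the token entered in the UI.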
+ api = HfApi(token=self.hf_token if self.hf_token else input_token)
+
+ if not folder_path:
+            raise ValueError('folder_path must not be empty.')
+ if not repo_name:
+            raise ValueError('repo_name must not be empty.')
+ if not organization:
+ organization = api.whoami()['name']
+
+ repo_id = f'{organization}/{repo_name}'
+ if delete_existing_repo:
+ try:
+ api.delete_repo(repo_id, repo_type=repo_type)
+ except Exception:
+ pass
+ try:
+ api.create_repo(repo_id, repo_type=repo_type, private=private)
+ api.upload_folder(repo_id=repo_id,
+ folder_path=folder_path,
+ path_in_repo='.',
+ repo_type=repo_type)
+ url = f'https://huggingface.co/{repo_id}'
+ message = f'Your model was successfully uploaded to {url}.'
+ except Exception as e:
+ message = str(e)
+ return message
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9a1a0a57c02181c4a0dd93b397fb9dc85f51956
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,65 @@
+from __future__ import annotations
+
+import pathlib
+
+
+def find_exp_dirs() -> list[str]:
+ repo_dir = pathlib.Path(__file__).parent
+ exp_root_dir = repo_dir / 'experiments'
+ if not exp_root_dir.exists():
+ return []
+ exp_dirs = sorted(exp_root_dir.glob('*'))
+ exp_dirs = [
+ exp_dir for exp_dir in exp_dirs
+ if (exp_dir / 'model_index.json').exists()
+ ]
+ return [path.relative_to(repo_dir).as_posix() for path in exp_dirs]
+
+
+def save_model_card(
+ save_dir: pathlib.Path,
+ base_model: str,
+ training_prompt: str,
+ test_prompt: str = '',
+ test_image_dir: str = '',
+) -> None:
+ image_str = ''
+ if test_prompt and test_image_dir:
+ image_paths = sorted((save_dir / test_image_dir).glob('*.gif'))
+ if image_paths:
+ image_path = image_paths[-1]
+ rel_path = image_path.relative_to(save_dir)
+ image_str = f'''## Samples
+Test prompt: {test_prompt}
+![sample]({rel_path})
+'''
+
+ model_card = f'''---
+license: creativeml-openrail-m
+base_model: {base_model}
+training_prompt: {training_prompt}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+- text-to-video
+- tune-a-video
+inference: false
+---
+
+# Tune-A-Video - {save_dir.name}
+
+## Model description
+- Base model: [{base_model}](https://huggingface.co/{base_model})
+- Training prompt: {training_prompt}
+
+{image_str}
+
+## Related papers:
+- [Tune-A-Video](https://arxiv.org/abs/2212.11565): One-Shot Tuning of Image Diffusion Models for Text-to-Video Generation
+- [Stable-Diffusion](https://arxiv.org/abs/2112.10752): High-Resolution Image Synthesis with Latent Diffusion Models
+'''
+
+ with open(save_dir / 'README.md', 'w') as f:
+ f.write(model_card)