diff --git a/FateZero b/FateZero deleted file mode 160000 index 6992d238770f464c03a0a74cbcec4f99da4635ec..0000000000000000000000000000000000000000 --- a/FateZero +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 6992d238770f464c03a0a74cbcec4f99da4635ec diff --git a/FateZero/.gitignore b/FateZero/.gitignore new file mode 100755 index 0000000000000000000000000000000000000000..5deb6916e9f35749728ac97fed3ea3d8166adabc --- /dev/null +++ b/FateZero/.gitignore @@ -0,0 +1,176 @@ +start_hold +chenyangqi +trash/** +runs*/** +result/** +ckpt/** +ckpt +**.whl +stable-diffusion-v1-4 +trash +# data/** + +# Initially taken from Github's Python gitignore file + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# tests and logs +tests/fixtures/cached_*_text.txt +logs/ +lightning_logs/ +lang_code_data/ + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# vscode +.vs +.vscode + +# Pycharm +.idea + +# TF code +tensorflow_code + +# Models +proc_data + +# examples +runs +/runs_old +/wandb +/examples/runs +/examples/**/*.args +/examples/rag/sweep + +# emacs +*.*~ +debug.env + +# vim +.*.swp + +#ctags +tags + +# pre-commit +.pre-commit* + +# .lock +*.lock + +# DS_Store (MacOS) +.DS_Store +# RL pipelines may produce mp4 outputs +*.mp4 diff --git a/FateZero/LICENSE.md b/FateZero/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..1b4153021303b3da833d534ca9b712943e9b402d --- /dev/null +++ b/FateZero/LICENSE.md @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Chenyang QI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/FateZero/README.md b/FateZero/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d8923420a37f269314868f7fa9dd67fdb8fa0e9a --- /dev/null +++ b/FateZero/README.md @@ -0,0 +1,393 @@ +## FateZero: Fusing Attentions for Zero-shot Text-based Video Editing + +[Chenyang Qi](https://chenyangqiqi.github.io/), [Xiaodong Cun](http://vinthony.github.io/), [Yong Zhang](https://yzhang2016.github.io), [Chenyang Lei](https://chenyanglei.github.io/), [Xintao Wang](https://xinntao.github.io/), [Ying Shan](https://scholar.google.com/citations?hl=zh-CN&user=4oXBp9UAAAAJ), and [Qifeng Chen](https://cqf.io) + + + [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ChenyangQiQi/FateZero/blob/main/colab_fatezero.ipynb) +[![GitHub](https://img.shields.io/github/stars/ChenyangQiQi/FateZero?style=social)](https://github.com/ChenyangQiQi/FateZero) + + + + + + + + + + + + +
"silver jeep ➜ posche car""+ Van Gogh style"
+ +## Abstract +TL;DR: FateZero edits your video with pretrained diffusion models, without any training. + +
CLICK for full abstract + + +> Diffusion-based generative models have achieved remarkable success in text-based image generation. However, since generation involves enormous randomness, it is still challenging to apply such models to real-world visual content editing, especially in videos. In this paper, we propose FateZero, a zero-shot text-based editing method for real-world videos that requires neither per-prompt training nor a user-specific mask. To edit videos consistently, we propose several techniques based on pre-trained models. First, in contrast to the straightforward DDIM inversion technique, our approach captures intermediate attention maps during inversion, which effectively retain both structural and motion information. These maps are directly fused into the editing process rather than generated during denoising. To further minimize semantic leakage from the source video, we then fuse self-attention maps with a blending mask obtained from the cross-attention features of the source prompt. Furthermore, we reform the self-attention mechanism in the denoising UNet by introducing spatial-temporal attention to ensure frame consistency. Despite its simplicity, our method is the first to demonstrate zero-shot text-driven video style and local attribute editing with a trained text-to-image model. We also achieve better zero-shot shape-aware editing with a text-to-video model. Extensive experiments demonstrate the superior temporal consistency and editing capability of our method over previous works.
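The core mechanism described above, storing attention maps during DDIM inversion and fusing them back during editing under a cross-attention blending mask, can be illustrated with a minimal tensor-level sketch. This is an illustration of the idea only, with made-up shapes and a made-up threshold, not the code in this repository:

```python
# Minimal, self-contained sketch (NOT the repository's implementation) of the idea:
# attention maps recorded during DDIM inversion are stored per step and later fused
# into the editing pass, using a mask from the source prompt's cross-attention on
# the edited word. Shapes and the threshold below are illustrative assumptions.
import torch

def record_inversion_attention(num_steps=50, heads=8, hw=64, tokens=77):
    """Stand-in for DDIM inversion: store self-/cross-attention at every step."""
    return [{"self": torch.rand(heads, hw, hw),
             "cross": torch.rand(heads, hw, tokens)} for _ in range(num_steps)]

def fuse_self_attention(edited_self, stored, word_idx, threshold=0.3):
    """One denoising step: keep the edited attention inside the word's mask,
    reuse the inverted (source) attention everywhere else."""
    word_attn = stored["cross"][..., word_idx].mean(dim=0)   # (hw,)
    mask = (word_attn > threshold).float().view(1, -1, 1)    # (1, hw, 1)
    return mask * edited_self + (1.0 - mask) * stored["self"]

store = record_inversion_attention()
edited_self = torch.rand(8, 64, 64)   # self-attention produced by the editing pass
fused = fuse_self_attention(edited_self, store[-1], word_idx=5)
print(fused.shape)                    # torch.Size([8, 64, 64])
```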
+ +## Changelog
+- 2023.03.27 Release the [`attribute editing config`](config/attribute) and the [`data`](https://github.com/ChenyangQiQi/FateZero/releases/download/v0.0.1/attribute.zip) used in the paper.
+- 2023.03.22 Upload a `colab notebook` [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ChenyangQiQi/FateZero/blob/main/colab_fatezero.ipynb). Enjoy zero-shot video editing for free!
+- 2023.03.22 Release the [`style editing config`](config/style) and the [`data`](https://github.com/ChenyangQiQi/FateZero/releases/download/v0.0.1/style.zip) used in the paper.
+- 2023.03.21 [Editing guidance](docs/EditingGuidance.md) is provided to help users edit in-the-wild videos. Welcome to play with it and give feedback!
+- 2023.03.21 Update the `codebase and configuration`. It can now run with lower resources (a 16G GPU and less than 16G of CPU RAM) using the [new configuration](config/low_resource_teaser) in `config/low_resource_teaser`.
+- 2023.03.17 Release code and paper!
+
+## Todo
+
+- [x] Release the edit config for the teaser
+- [x] Memory and runtime profiling
+- [x] Hands-on guidance for hyperparameter tuning
+- [x] Colab
+- [x] Release configs for other results and the in-the-wild dataset
+- [ ] Hugging Face: in progress
+- [ ] Tune-A-Video optimization and shape editing configs
+- [ ] Release more applications
+
+## Setup Environment
+Our method is tested with CUDA 11, fp16 via accelerate, and xformers on a single A100 or 3090.
+
+```bash
+conda create -n fatezero38 python=3.8
+conda activate fatezero38
+
+pip install -r requirements.txt
+```
+
+`xformers` is recommended on A100 GPUs to save memory and running time.
+
Click for xformers installation
+
+We find its installation to be unstable. You may try the following wheel:
+```bash
+wget https://github.com/ShivamShrirao/xformers-wheels/releases/download/4c06c79/xformers-0.0.15.dev0+4c06c79.d20221201-cp38-cp38-linux_x86_64.whl
+pip install xformers-0.0.15.dev0+4c06c79.d20221201-cp38-cp38-linux_x86_64.whl
+```
+
+ +Validate the installation by
+```
+python test_install.py
+```
+
+Our environment is similar to Tune-A-Video ([official](https://github.com/showlab/Tune-A-Video), [unofficial](https://github.com/bryandlee/Tune-A-Video)) and [prompt-to-prompt](https://github.com/google/prompt-to-prompt/). You may check them for more details.
+
+## FateZero Editing
+
+#### Style and Attribute Editing in Teaser
+
+Download [stable diffusion v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) (or another image diffusion model of interest) and put it in `./ckpt/stable-diffusion-v1-4`.
+
Click for bash command:
+
+```
+mkdir ./ckpt
+# download from Hugging Face; takes about 20 GB of disk space
+git lfs install
+git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
+cd ./ckpt
+ln -s ../stable-diffusion-v1-4 .
+```
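If `git lfs` is not available, a Python alternative is to let `diffusers` fetch and save the weights. This is only a sketch; it downloads to the Hugging Face cache first, so roughly double the disk space is needed temporarily:

```python
# Hypothetical alternative to the git-lfs clone above: download the checkpoint
# with diffusers and save it into the ./ckpt layout the configs expect.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe.save_pretrained("./ckpt/stable-diffusion-v1-4")
```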
+ +Then, you can reproduce the style and attribute editing results in our teaser by running:
+
+```bash
+accelerate launch test_fatezero.py --config config/teaser/jeep_watercolor.yaml
+# or CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/teaser/jeep_watercolor.yaml
+```
+
The result is saved at `./result` . (Click for directory structure) + +``` +result +├── teaser +│ ├── jeep_posche +│ ├── jeep_watercolor +│ ├── cross-attention # visualization of cross-attention during inversion +│ ├── sample # result +│ ├── train_samples # the input video + +``` + +
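To quickly locate the newest edited clips after a run, a small glob over the structure above works; the `sample/*.mp4` pattern matches the layout shown here, and any other naming is an assumption about your run:

```python
# List the most recent edited videos under ./result (sketch, not part of the repo).
import glob

clips = sorted(glob.glob("result/**/sample/*.mp4", recursive=True))
print("\n".join(clips[-5:]) if clips else "no results found yet")
```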
+ +Editing 8 frames on an NVIDIA 3090 uses about `100G CPU memory and 12G GPU memory`. We also provide some [`low-cost settings`](config/low_resource_teaser) for style editing with different hyper-parameters on a 16GB GPU.
+You may try these low-cost settings on Colab.
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ChenyangQiQi/FateZero/blob/main/colab_fatezero.ipynb)
+
+More speed and hardware benchmarks are available [here](docs/EditingGuidance.md#ddim-hyperparameters).
+
+#### Shape and large motion editing with Tune-A-Video
+
+Besides the style and attribute editing above, we also provide a `Tune-A-Video` [checkpoint](https://hkustconnect-my.sharepoint.com/:f:/g/personal/cqiaa_connect_ust_hk/EviSTWoAOs1EmHtqZruq50kBZu1E8gxDknCPigSvsS96uQ?e=492khj). You may download it and move it to `./ckpt/jeep_tuned_200/`.
+
The directory structure should look like this: (Click for directory structure)
+
+```
+ckpt
+├── stable-diffusion-v1-4
+├── jeep_tuned_200
+...
+data
+├── car-turn
+│ ├── 00000000.png
+│ ├── 00000001.png
+│ ├── ...
+video_diffusion
+```
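Before launching the shape-editing config, a quick sanity check that the directories in the tree above are in place can save a failed run. This checks only the paths shown above and assumes nothing else:

```python
# Verify the checkpoint and data layout shown in the directory tree above.
from pathlib import Path

for p in ["ckpt/stable-diffusion-v1-4", "ckpt/jeep_tuned_200", "data/car-turn"]:
    print(p, "OK" if Path(p).exists() else "MISSING")
```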
+ +You could reproduce the shape editing result in our teaser by running: + +```bash +accelerate launch test_fatezero.py --config config/teaser/jeep_posche.yaml +``` + + +### Reproduce other results in the paper (in progress) + +Download the data of style editing and attribute editing +from [onedrive](https://hkustconnect-my.sharepoint.com/:f:/g/personal/cqiaa_connect_ust_hk/EkIeHj3CQiBNhm6iEEhJQZwBEBJNCGt3FsANmyqeAYbuXQ?e=FxYtJk) or from Github [Release](https://github.com/ChenyangQiQi/FateZero/releases/tag/v0.0.1). +
Click for wget bash command: + +``` +wget https://github.com/ChenyangQiQi/FateZero/releases/download/v0.0.1/attribute.zip +wget https://github.com/ChenyangQiQi/FateZero/releases/download/v0.0.1/style.zip +``` +
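A small Python helper can unpack both archives straight into `./data`, assuming the zips extract cleanly into that folder as the next step expects:

```python
# Unzip the downloaded style/attribute data into ./data (illustrative sketch).
import zipfile
from pathlib import Path

Path("data").mkdir(exist_ok=True)
for archive in ["attribute.zip", "style.zip"]:
    with zipfile.ZipFile(archive) as zf:
        zf.extractall("data")
```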
+ +Unzip and place it in ['./data'](data). Then use the commands in ['config/style'](config/style) and ['config/attribute'](config/attribute) to get the results.
+
+The configs for our Tune-A-Video checkpoints will be updated later.
+
+## Tuning guidance to edit YOUR video
+We provide tuning guidance for editing in-the-wild videos [here](./docs/EditingGuidance.md). The work is still in progress; welcome to give your feedback in the issues.
+
+## Style Editing Results with Stable Diffusion
+We show the difference between the source prompt and the target prompt in the box below each video.
+
+Note that the mp4 and gif files on this GitHub page are compressed.
+Please check our [Project Page](https://fate-zero-edit.github.io/) for the original mp4 video editing results.
+
"+ Ukiyo-e style""+ watercolor painting""+ Monet style"
"+ Pokémon cartoon style""+ Makoto Shinkai style""+ cartoon style"
+ +## Attribute Editing Results with Stable Diffusion + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
"rabbit, strawberry ➜ white rabbit, flower""rabbit, strawberry ➜ squirrel, carrot""rabbit, strawberry ➜ white rabbit, leaves"
"squirrel ➜ robot squirrel""squirrel, Carrot ➜ rabbit, eggplant""squirrel, Carrot ➜ robot mouse, screwdriver"
"bear ➜ a red tiger""bear ➜ a yellow leopard""bear ➜ a brown lion"
"cat ➜ black cat, grass...""cat ➜ red tiger""cat ➜ Shiba-Inu"
"bus ➜ GPU""gray dog ➜ yellow corgi""gray dog ➜ robotic dog"
"white duck ➜ yellow rubber duck""grass ➜ snow""white fox ➜ grey wolf"
+ +## Shape and large motion editing with Tune-A-Video + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
"silver jeep ➜ posche car""Swan ➜ White Duck""Swan ➜ Pink flamingo"
"A man ➜ A Batman""A man ➜ A Wonder Woman, With cowboy hat""A man ➜ A Spider-Man"
+ + +## Demo Video + +https://user-images.githubusercontent.com/45789244/225698509-79c14793-3153-4bba-9d6e-ede7d811d7f8.mp4 + +The video here is compressed due to the size limit of github. +The original full resolution video is [here](https://hkustconnect-my.sharepoint.com/:v:/g/personal/cqiaa_connect_ust_hk/EXKDI_nahEhKtiYPvvyU9SkBDTG2W4G1AZ_vkC7ekh3ENw?e=Xhgtmk). + + +## Citation + +``` +@misc{qi2023fatezero, + title={FateZero: Fusing Attentions for Zero-shot Text-based Video Editing}, + author={Chenyang Qi and Xiaodong Cun and Yong Zhang and Chenyang Lei and Xintao Wang and Ying Shan and Qifeng Chen}, + year={2023}, + eprint={2303.09535}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` + + +## Acknowledgements + +This repository borrows heavily from [Tune-A-Video](https://github.com/showlab/Tune-A-Video) and [prompt-to-prompt](https://github.com/google/prompt-to-prompt/). thanks the authors for sharing their code and models. + +## Maintenance + +This is the codebase for our research work. We are still working hard to update this repo and more details are coming in days. If you have any questions or ideas to discuss, feel free to contact [Chenyang Qi](cqiaa@connect.ust.hk) or [Xiaodong Cun](vinthony@gmail.com). + diff --git a/FateZero/colab_fatezero.ipynb b/FateZero/colab_fatezero.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..548fdb611e7725d3b808122a59d92592bf95d08f --- /dev/null +++ b/FateZero/colab_fatezero.ipynb @@ -0,0 +1,528 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "fZ_xQvU70UQc" + }, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ChenyangQiQi/FateZero/blob/main/colab_fatezero.ipynb)\n", + "\n", + "# FateZero: Fusing Attentions for Zero-shot Text-based Video Editing\n", + "\n", + "[Chenyang Qi](https://chenyangqiqi.github.io/), [Xiaodong Cun](http://vinthony.github.io/), [Yong Zhang](https://yzhang2016.github.io), [Chenyang Lei](https://chenyanglei.github.io/), [Xintao Wang](https://xinntao.github.io/), [Ying Shan](https://scholar.google.com/citations?hl=zh-CN&user=4oXBp9UAAAAJ), and [Qifeng Chen](https://cqf.io)\n", + "\n", + "\n", + "[![Project Website](https://img.shields.io/badge/Project-Website-orange)](https://fate-zero-edit.github.io/)\n", + "[![arXiv](https://img.shields.io/badge/arXiv-2303.09535-b31b1b.svg)](https://arxiv.org/abs/2303.09535)\n", + "[![GitHub](https://img.shields.io/github/stars/ChenyangQiQi/FateZero?style=social)](https://github.com/ChenyangQiQi/FateZero)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "XU7NuMAA2drw", + "outputId": "82c4a90d-0ed6-4ad5-c74d-0a0ed3d98bbe" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tesla T4, 15360 MiB, 15101 MiB\n" + ] + } + ], + "source": [ + "#@markdown Check type of GPU and VRAM available.\n", + "!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "D1PRgre3Gt5U", + "outputId": "ac1db329-a373-4c82-9b0d-77f4e5cb7140" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cloning into '/content/FateZero'...\n", + "remote: Enumerating objects: 332, done.\u001b[K\n", + "remote: Counting objects: 100% (53/53), done.\u001b[K\n", + "remote: 
Compressing objects: 100% (7/7), done.\u001b[K\n", + "remote: Total 332 (delta 50), reused 47 (delta 46), pack-reused 279\u001b[K\n", + "Receiving objects: 100% (332/332), 34.21 MiB | 14.26 MiB/s, done.\n", + "Resolving deltas: 100% (157/157), done.\n", + "/content/FateZero\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m63.3/63.3 MB\u001b[0m \u001b[31m15.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m145.0/145.0 KB\u001b[0m \u001b[31m18.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Building wheel for lit (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m524.9/524.9 KB\u001b[0m \u001b[31m35.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.3/6.3 MB\u001b[0m \u001b[31m74.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.5/62.5 MB\u001b[0m \u001b[31m13.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.6/13.6 MB\u001b[0m \u001b[31m96.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.8/212.8 KB\u001b[0m \u001b[31m25.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m79.5/79.5 KB\u001b[0m \u001b[31m9.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m41.6/41.6 KB\u001b[0m \u001b[31m5.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m53.1/53.1 KB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m15.8/15.8 MB\u001b[0m \u001b[31m88.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m26.9/26.9 MB\u001b[0m \u001b[31m55.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m50.9/50.9 MB\u001b[0m \u001b[31m14.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m199.8/199.8 KB\u001b[0m \u001b[31m23.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m105.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m117.0/117.0 KB\u001b[0m \u001b[31m15.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m71.5/71.5 KB\u001b[0m \u001b[31m8.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Preparing metadata (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.5/84.5 KB\u001b[0m \u001b[31m8.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m106.5/106.5 KB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m143.5/143.5 KB\u001b[0m \u001b[31m18.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m64.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m45.7/45.7 KB\u001b[0m \u001b[31m5.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.8/57.8 KB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.1/57.1 KB\u001b[0m \u001b[31m7.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m50.5/50.5 KB\u001b[0m \u001b[31m7.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m114.2/114.2 KB\u001b[0m \u001b[31m14.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m158.8/158.8 KB\u001b[0m \u001b[31m20.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m264.6/264.6 KB\u001b[0m \u001b[31m26.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m66.9/66.9 KB\u001b[0m \u001b[31m7.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m69.6/69.6 KB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.3/58.3 KB\u001b[0m \u001b[31m3.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for antlr4-python3-runtime (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Building wheel for ffmpy (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n" + ] + } + ], + "source": [ + "#@title Install requirements\n", + "\n", + "!git clone https://github.com/ChenyangQiQi/FateZero /content/FateZero\n", + "%cd /content/FateZero\n", + "# %pip install -r requirements.txt\n", + "%pip install -q -U --pre triton\n", + "%pip install -q diffusers[torch]==0.11.1 transformers==4.26.0 bitsandbytes==0.35.4 \\\n", + "decord accelerate omegaconf einops ftfy gradio imageio-ffmpeg xformers" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "m6I6kZNG3Inb", + "outputId": "f3bcb6eb-a79c-4810-d575-e926c8e7564f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Updated git hooks.\n", + "Git LFS initialized.\n", + "Cloning into 'ckpt/CompVis/stable-diffusion-v1-4'...\n", + "remote: Enumerating objects: 738, done.\u001b[K\n", + "remote: Counting objects: 100% (12/12), done.\u001b[K\n", + "remote: Compressing objects: 100% (12/12), done.\u001b[K\n", + "remote: Total 738 (delta 3), reused 1 (delta 0), pack-reused 726\u001b[K\n", + "Receiving objects: 100% (738/738), 682.52 KiB | 954.00 KiB/s, done.\n", + "Resolving deltas: 100% (123/123), done.\n", + "Filtering content: 100% (8/8), 10.20 GiB | 63.59 MiB/s, done.\n", + "[*] MODEL_NAME=./ckpt/CompVis/stable-diffusion-v1-4\n" + ] + } + ], + "source": [ + "#@title Download pretrained model\n", + "\n", + "#@markdown Name/Path of the initial model.\n", + "MODEL_NAME = \"CompVis/stable-diffusion-v1-4\" #@param {type:\"string\"}\n", + "\n", + "#@markdown If model should be download from a remote repo. Untick it if the model is loaded from a local path.\n", + "download_pretrained_model = True #@param {type:\"boolean\"}\n", + "if download_pretrained_model:\n", + " !git lfs install\n", + " !git clone https://huggingface.co/$MODEL_NAME ckpt/$MODEL_NAME\n", + " MODEL_NAME = f\"./ckpt/{MODEL_NAME}\"\n", + "print(f\"[*] MODEL_NAME={MODEL_NAME}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qn5ILIyDJIcX" + }, + "source": [ + "# **Usage**\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "i4L2yDXGflaC" + }, + "source": [ + "## FateZero Edit with low resource cost\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fXZs1veYIMMw", + "outputId": "c665eaba-ef12-498e-d173-6432e977fc07" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "save new configue to config/car-turn.yaml\n" + ] + } + ], + "source": [ + "#@markdown Edit config\n", + "\n", + "#@markdown More details of the configuration will be given soon.\n", + "\n", + "from omegaconf import OmegaConf\n", + "\n", + "VIDEO_FILE = 'data/car-turn' #@param {type:\"string\"}\n", + "\n", + "VIDEO_ID = VIDEO_FILE.split('/')[-1]\n", + "\n", + "RESULT_DIR = 'result/'+VIDEO_ID\n", + "\n", + "CONFIG_NAME = \"config/\"+VIDEO_ID+\".yaml\" \n", + "\n", + "source_prompt = \"a silver jeep driving down a curvy road in the countryside\" #@param {type:\"string\"}\n", + "edit_prompt = \"watercolor painting of a silver jeep driving down a curvy road in the countryside\" #@param {type:\"string\"}\n", + "EMPHYSIS_WORD = \"watercolor\" #@param {type:\"string\"}\n", + "EMPHYSIS_VALUE = 10 #@param {type:\"number\"}\n", + "video_length = 8 #@param {type:\"number\"}\n", + "INVERSION_STEP = 8 #@param {type:\"number\"}\n", + "REPLACE_STRENGTH = 0.8 #@param {type:\"slider\", min:0, max:1, 
step:0.1}\n", + "STORE_ATTENTION_ON_disk = False #@param {type:\"boolean\"}\n", + "width = 512 \n", + "height = 512 \n", + "\n", + "config = {\n", + " \"pretrained_model_path\": MODEL_NAME,\n", + " \"logdir\": RESULT_DIR,\n", + " \"train_dataset\": {\n", + " \"path\": VIDEO_FILE,\n", + " \"prompt\": source_prompt,\n", + " \"n_sample_frame\": video_length,\n", + " \"sampling_rate\": 1,\n", + " \"stride\": 80,\n", + " \"offset\": \n", + " {\n", + " \"left\": 0,\n", + " \"right\": 0,\n", + " \"top\": 0,\n", + " \"bottom\": 0,\n", + " }\n", + " },\n", + " \"validation_sample_logger_config\":{\n", + " \"use_train_latents\": True,\n", + " \"use_inversion_attention\": True,\n", + " \"guidance_scale\": 7.5,\n", + " \"prompts\":[\n", + " source_prompt,\n", + " edit_prompt,\n", + " ],\n", + " \"p2p_config\":[ \n", + " {\n", + " \"cross_replace_steps\":{\n", + " \"default_\":0.8\n", + " },\n", + " \"self_replace_steps\": 0.8,\n", + " \"masked_self_attention\": True,\n", + " \"bend_th\": [2, 2],\n", + " \"is_replace_controller\": False \n", + " },\n", + " {\n", + " \"cross_replace_steps\":{\n", + " \"default_\":0.8\n", + " },\n", + " \"self_replace_steps\": 0.8,\n", + " \"eq_params\":{\n", + " \"words\":[EMPHYSIS_WORD],\n", + " \"values\": [EMPHYSIS_VALUE]\n", + " },\n", + " \"use_inversion_attention\": True,\n", + " \"is_replace_controller\": False \n", + " }]\n", + " ,\n", + " \"clip_length\": \"${..train_dataset.n_sample_frame}\",\n", + " \"sample_seeds\": [0],\n", + " \"num_inference_steps\": INVERSION_STEP,\n", + " \"prompt2prompt_edit\": True\n", + " },\n", + " \"disk_store\": STORE_ATTENTION_ON_disk,\n", + " \"model_config\":{\n", + " \"lora\": 160,\n", + " \"SparseCausalAttention_index\": ['mid'],\n", + " \"least_sc_channel\": 640\n", + " },\n", + " \"test_pipeline_config\":{\n", + " \"target\": \"video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline\",\n", + " \"num_inference_steps\": \"${..validation_sample_logger.num_inference_steps}\"\n", + " },\n", + " \"epsilon\": 1e-5,\n", + " \"train_steps\": 10,\n", + " \"seed\": 0,\n", + " \"learning_rate\": 1e-5,\n", + " \"train_temporal_conv\": False,\n", + " \"guidance_scale\": \"${validation_sample_logger_config.guidance_scale}\"\n", + "}\n", + "\n", + "OmegaConf.save(config, CONFIG_NAME)\n", + "print('save new configue to ', CONFIG_NAME)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "jjcSXTp-u-Eg", + "outputId": "194d964e-08dc-4d3d-c0fd-7e56ed2eb187" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-03-22 09:04:20.819710: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n", + "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "2023-03-22 09:04:24.565385: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/lib64-nvidia\n", + "2023-03-22 09:04:24.565750: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; 
LD_LIBRARY_PATH: /usr/lib64-nvidia\n", + "2023-03-22 09:04:24.565782: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n", + "The following values were not passed to `accelerate launch` and had defaults used instead:\n", + "\t`--num_processes` was set to a value of `1`\n", + "\t`--num_machines` was set to a value of `1`\n", + "\t`--mixed_precision` was set to a value of `'no'`\n", + "\t`--dynamo_backend` was set to a value of `'no'`\n", + "To avoid this warning pass in values for each of the problematic parameters or run `accelerate config`.\n", + "2023-03-22 09:04:31.342590: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/lib64-nvidia\n", + "2023-03-22 09:04:31.342704: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/lib64-nvidia\n", + "2023-03-22 09:04:31.342734: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n", + "The config attributes {'scaling_factor': 0.18215} were passed to AutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.\n", + "use fp16\n", + "Number of attention layer registered 32\n", + " Invert clean image to noise latents by DDIM and Unet\n", + "100% 8/8 [00:25<00:00, 3.19s/it]\n", + "IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (3328, 307) to (3328, 320) to ensure video compatibility with most codecs and players. 
To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to 1 (risking incompatibility).\n", + "Number of attention layer registered 32\n", + "Generating sample images: 0% 0/2 [00:00\n", + " \n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.display import HTML\n", + "from base64 import b64encode\n", + "import os, sys\n", + "import glob\n", + "\n", + "# get the last from results\n", + "mp4_name = sorted(glob.glob('./result/*/sample/step_0.mp4'))[-1]\n", + "\n", + "print(mp4_name)\n", + "mp4 = open('{}'.format(mp4_name),'rb').read()\n", + "data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", + "\n", + "print('Display animation: {}'.format(mp4_name), file=sys.stderr)\n", + "display(HTML(\"\"\"\n", + " \n", + " \"\"\" % data_url))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cBb3wTEXfhRo" + }, + "source": [ + "## Edit your video" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mQR2cjDZV9tu" + }, + "outputs": [], + "source": [ + "#@markdown Upload your video(.mp4) by running this cell or skip this cell using the default data\n", + "\n", + "import os\n", + "from google.colab import files\n", + "import shutil\n", + "from IPython.display import HTML\n", + "from base64 import b64encode\n", + "\n", + "uploaded = files.upload()\n", + "for filename in uploaded.keys():\n", + " dst_path = os.path.join(\"data\", filename)\n", + " shutil.move(filename, dst_path)\n", + " \n", + "file_id = dst_path.replace('.mp4', '')\n", + "\n", + "! mkdir -p $file_id\n", + "! ffmpeg -hide_banner -loglevel error -i $dst_path -vf scale=\"512:512\" -vf fps=25 $file_id/%05d.png\n", + "\n", + "mp4 = open('{}'.format(dst_path),'rb').read()\n", + "data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", + "\n", + "display(HTML(\"\"\"\n", + " \n", + " \"\"\" % data_url))\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [] + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "2.7.18 (default, Jul 1 2022, 12:27:04) \n[GCC 9.4.0]" + }, + "vscode": { + "interpreter": { + "hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a" + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/FateZero/config/.gitignore b/FateZero/config/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..de7a3dfd5079436802059004b14e4afbc657110d --- /dev/null +++ b/FateZero/config/.gitignore @@ -0,0 +1 @@ +# debug/** \ No newline at end of file diff --git a/FateZero/config/attribute/bear_tiger_lion_leopard.yaml b/FateZero/config/attribute/bear_tiger_lion_leopard.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5aa85a5d28b1fb884e7b6963af5844a1dca0f2b6 --- /dev/null +++ b/FateZero/config/attribute/bear_tiger_lion_leopard.yaml @@ -0,0 +1,108 @@ +# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/bear_tiger_lion_leopard.yaml + +pretrained_model_path: "./ckpt/stable-diffusion-v1-4" + + +train_dataset: + path: "data/attribute/bear_tiger_lion_leopard" + prompt: "a brown bear 
walking on the rock against a wall" + n_sample_frame: 8 + # n_sample_frame: 22 + sampling_rate: 1 + stride: 80 + offset: + left: 0 + right: 0 + top: 0 + bottom: 0 + +validation_sample_logger_config: + use_train_latents: True + use_inversion_attention: True + guidance_scale: 7.5 + prompts: [ + # source prompt + a brown bear walking on the rock against a wall, + + # foreground texture style + a red tiger walking on the rock against a wall, + a yellow leopard walking on the rock against a wall, + a brown lion walking on the rock against a wall, + ] + p2p_config: + 0: + # Whether to directly copy the cross attention from source + # True: directly copy, better for object replacement + # False: keep source attention, better for style + is_replace_controller: False + + # Semantic preserving and replacement Debug me + cross_replace_steps: + default_: 0.8 + + # Source background structure preserving, in [0, 1]. + # e.g., =0.6 Replace the first 60% steps self-attention + self_replace_steps: 0.6 + + + # Amplify the target-words cross attention, larger value, more close to target + eq_params: + words: ["silver", "sculpture"] + values: [2,2] + + # Target structure-divergence hyperparames + # If you change the shape of object better to use all three line, otherwise, no need. + # Without following three lines, all self-attention will be replaced + blend_words: [['cat',], ["cat",]] + masked_self_attention: True + # masked_latents: False # performance not so good in our case, need debug + bend_th: [2, 2] + # preserve source structure of blend_words , [0, 1] + # default is bend_th: [2, 2] # preserve all source self-attention + # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention + + + 1: + is_replace_controller: true + cross_replace_steps: + default_: 0.7 + self_replace_steps: 0.7 + 2: + is_replace_controller: true + cross_replace_steps: + default_: 0.7 + self_replace_steps: 0.7 + 3: + is_replace_controller: true + cross_replace_steps: + default_: 0.7 + self_replace_steps: 0.7 + + + + + clip_length: "${..train_dataset.n_sample_frame}" + sample_seeds: [0] + val_all_frames: False + + num_inference_steps: 50 + prompt2prompt_edit: True + + +model_config: + lora: 160 + # temporal_downsample_time: 4 + SparseCausalAttention_index: ['mid'] + least_sc_channel: 640 + # least_sc_channel: 100000 + +test_pipeline_config: + target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline + num_inference_steps: "${..validation_sample_logger.num_inference_steps}" + +epsilon: 1e-5 +train_steps: 10 +seed: 0 +learning_rate: 1e-5 +train_temporal_conv: False +guidance_scale: "${validation_sample_logger_config.guidance_scale}" \ No newline at end of file diff --git a/FateZero/config/attribute/bus_gpu.yaml b/FateZero/config/attribute/bus_gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..33f6a899d16fe67c6751e55bef0fd9a1930995af --- /dev/null +++ b/FateZero/config/attribute/bus_gpu.yaml @@ -0,0 +1,100 @@ +# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/bus_gpu.yaml + +pretrained_model_path: "./ckpt/stable-diffusion-v1-4" + + +train_dataset: + path: "data/attribute/bus_gpu" + prompt: "a white and blue bus on the road" + n_sample_frame: 8 + # n_sample_frame: 22 + sampling_rate: 1 + stride: 80 + offset: + left: 0 + right: 0 + top: 0 + bottom: 0 + +validation_sample_logger_config: + use_train_latents: True + use_inversion_attention: True + guidance_scale: 7.5 + prompts: [ + # source prompt + a white 
and blue bus on the road, + + # foreground texture style + a black and green GPU on the road + ] + p2p_config: + 0: + # Whether to directly copy the cross attention from source + # True: directly copy, better for object replacement + # False: keep source attention, better for style + is_replace_controller: False + + # Semantic preserving and replacement Debug me + cross_replace_steps: + default_: 0.8 + + # Source background structure preserving, in [0, 1]. + # e.g., =0.6 Replace the first 60% steps self-attention + self_replace_steps: 0.6 + + + # Amplify the target-words cross attention, larger value, more close to target + eq_params: + words: ["silver", "sculpture"] + values: [2,2] + + # Target structure-divergence hyperparames + # If you change the shape of object better to use all three line, otherwise, no need. + # Without following three lines, all self-attention will be replaced + blend_words: [['cat',], ["cat",]] + masked_self_attention: True + # masked_latents: False # performance not so good in our case, need debug + bend_th: [2, 2] + # preserve source structure of blend_words , [0, 1] + # default is bend_th: [2, 2] # preserve all source self-attention + # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention + + + 1: + is_replace_controller: true + cross_replace_steps: + default_: 0.1 + self_replace_steps: 0.1 + + eq_params: + words: ["Nvidia", "GPU"] + values: [10, 10] # amplify attention to the word "tiger" by *2 + + + + + clip_length: "${..train_dataset.n_sample_frame}" + sample_seeds: [0] + val_all_frames: False + + num_inference_steps: 50 + prompt2prompt_edit: True + + +model_config: + lora: 160 + # temporal_downsample_time: 4 + SparseCausalAttention_index: ['mid'] + least_sc_channel: 640 + # least_sc_channel: 100000 + +test_pipeline_config: + target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline + num_inference_steps: "${..validation_sample_logger.num_inference_steps}" + +epsilon: 1e-5 +train_steps: 10 +seed: 0 +learning_rate: 1e-5 +train_temporal_conv: False +guidance_scale: "${validation_sample_logger_config.guidance_scale}" \ No newline at end of file diff --git a/FateZero/config/attribute/cat_tiger_leopard_grass.yaml b/FateZero/config/attribute/cat_tiger_leopard_grass.yaml new file mode 100644 index 0000000000000000000000000000000000000000..508ffca956b2c677bb622eb72f96536e59de37f0 --- /dev/null +++ b/FateZero/config/attribute/cat_tiger_leopard_grass.yaml @@ -0,0 +1,112 @@ +# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/cat_tiger_leopard_grass.yaml + +pretrained_model_path: "./ckpt/stable-diffusion-v1-4" + + +train_dataset: + path: "data/attribute/cat_tiger_leopard_grass" + prompt: "A black cat walking on the floor next to a wall" + n_sample_frame: 8 + # n_sample_frame: 22 + sampling_rate: 1 + stride: 80 + offset: + left: 0 + right: 0 + top: 0 + bottom: 0 + +validation_sample_logger_config: + use_train_latents: True + use_inversion_attention: True + guidance_scale: 7.5 + prompts: [ + # source prompt + A black cat walking on the floor next to a wall, + A black cat walking on the grass next to a wall, + A red tiger walking on the floor next to a wall, + a yellow cute Shiba-Inu walking on the floor next to a wall, + a yellow cute leopard walking on the floor next to a wall, + ] + p2p_config: + 0: + # Whether to directly copy the cross attention from source + # True: directly copy, better for object replacement + # False: keep source attention, better for 
style + is_replace_controller: False + + # Semantic preserving and replacement Debug me + cross_replace_steps: + default_: 0.8 + + # Source background structure preserving, in [0, 1]. + # e.g., =0.6 Replace the first 60% steps self-attention + self_replace_steps: 0.6 + + + # Amplify the target-words cross attention, larger value, more close to target + eq_params: + words: ["silver", "sculpture"] + values: [2,2] + + # Target structure-divergence hyperparames + # If you change the shape of object better to use all three line, otherwise, no need. + # Without following three lines, all self-attention will be replaced + blend_words: [['cat',], ["cat",]] + masked_self_attention: True + # masked_latents: False # performance not so good in our case, need debug + bend_th: [2, 2] + # preserve source structure of blend_words , [0, 1] + # default is bend_th: [2, 2] # preserve all source self-attention + # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention + + + 1: + is_replace_controller: false + cross_replace_steps: + default_: 0.5 + self_replace_steps: 0.5 + 2: + is_replace_controller: false + cross_replace_steps: + default_: 0.5 + self_replace_steps: 0.5 + 3: + is_replace_controller: false + cross_replace_steps: + default_: 0.5 + self_replace_steps: 0.5 + 4: + is_replace_controller: false + cross_replace_steps: + default_: 0.7 + self_replace_steps: 0.7 + + + + + clip_length: "${..train_dataset.n_sample_frame}" + sample_seeds: [0] + val_all_frames: False + + num_inference_steps: 50 + prompt2prompt_edit: True + + +model_config: + lora: 160 + # temporal_downsample_time: 4 + SparseCausalAttention_index: ['mid'] + least_sc_channel: 640 + # least_sc_channel: 100000 + +test_pipeline_config: + target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline + num_inference_steps: "${..validation_sample_logger.num_inference_steps}" + +epsilon: 1e-5 +train_steps: 10 +seed: 0 +learning_rate: 1e-5 +train_temporal_conv: False +guidance_scale: "${validation_sample_logger_config.guidance_scale}" \ No newline at end of file diff --git a/FateZero/config/attribute/dog_robotic_corgi.yaml b/FateZero/config/attribute/dog_robotic_corgi.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b7fbe83a2610be865e51b57a09d01d7d88df7332 --- /dev/null +++ b/FateZero/config/attribute/dog_robotic_corgi.yaml @@ -0,0 +1,103 @@ +# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/dog_robotic_corgi.yaml + +pretrained_model_path: "./ckpt/stable-diffusion-v1-4" + +train_dataset: + path: "data/attribute/gray_dog" + prompt: "A gray dog sitting on the mat" + n_sample_frame: 8 + # n_sample_frame: 22 + sampling_rate: 1 + stride: 80 + offset: + left: 0 + right: 0 + top: 0 + bottom: 0 + +validation_sample_logger_config: + use_train_latents: True + use_inversion_attention: True + guidance_scale: 7.5 + prompts: [ + # source prompt + A gray dog sitting on the mat, + + # foreground texture style + A robotic dog sitting on the mat, + A yellow corgi sitting on the mat + ] + p2p_config: + 0: + # Whether to directly copy the cross attention from source + # True: directly copy, better for object replacement + # False: keep source attention, better for style + is_replace_controller: False + + # Semantic preserving and replacement Debug me + cross_replace_steps: + default_: 0.8 + + # Source background structure preserving, in [0, 1]. 
+ # e.g., =0.6 Replace the first 60% steps self-attention + self_replace_steps: 0.6 + + + # Amplify the target-words cross attention, larger value, more close to target + eq_params: + words: ["silver", "sculpture"] + values: [2,2] + + # Target structure-divergence hyperparames + # If you change the shape of object better to use all three line, otherwise, no need. + # Without following three lines, all self-attention will be replaced + blend_words: [['cat',], ["cat",]] + masked_self_attention: True + # masked_latents: False # performance not so good in our case, need debug + bend_th: [2, 2] + # preserve source structure of blend_words , [0, 1] + # default is bend_th: [2, 2] # preserve all source self-attention + # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention + + + 1: + is_replace_controller: false + cross_replace_steps: + default_: 0.5 + self_replace_steps: 0.5 + + eq_params: + words: ["robotic"] + values: [10] # amplify attention to the word "tiger" by *2 + + 2: + is_replace_controller: false + cross_replace_steps: + default_: 0.5 + self_replace_steps: 0.5 + + clip_length: "${..train_dataset.n_sample_frame}" + sample_seeds: [0] + val_all_frames: False + + num_inference_steps: 50 + prompt2prompt_edit: True + + +model_config: + lora: 160 + # temporal_downsample_time: 4 + SparseCausalAttention_index: ['mid'] + least_sc_channel: 640 + # least_sc_channel: 100000 + +test_pipeline_config: + target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline + num_inference_steps: "${..validation_sample_logger.num_inference_steps}" + +epsilon: 1e-5 +train_steps: 10 +seed: 0 +learning_rate: 1e-5 +train_temporal_conv: False +guidance_scale: "${validation_sample_logger_config.guidance_scale}" \ No newline at end of file diff --git a/FateZero/config/attribute/duck_rubber.yaml b/FateZero/config/attribute/duck_rubber.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a2b422424b0014fbbc8e109f9f931d5df60221f7 --- /dev/null +++ b/FateZero/config/attribute/duck_rubber.yaml @@ -0,0 +1,99 @@ +# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/duck_rubber.yaml + +pretrained_model_path: "./ckpt/stable-diffusion-v1-4" + +train_dataset: + path: "data/attribute/duck_rubber" + prompt: "a sleepy white duck" + n_sample_frame: 8 + # n_sample_frame: 22 + sampling_rate: 1 + stride: 80 + offset: + left: 0 + right: 0 + top: 0 + bottom: 0 + +validation_sample_logger_config: + use_train_latents: True + use_inversion_attention: True + guidance_scale: 7.5 + prompts: [ + # source prompt + a sleepy white duck, + + # foreground texture style + a sleepy yellow rubber duck + ] + p2p_config: + 0: + # Whether to directly copy the cross attention from source + # True: directly copy, better for object replacement + # False: keep source attention, better for style + is_replace_controller: False + + # Semantic preserving and replacement Debug me + cross_replace_steps: + default_: 0.8 + + # Source background structure preserving, in [0, 1]. + # e.g., =0.6 Replace the first 60% steps self-attention + self_replace_steps: 0.6 + + + # Amplify the target-words cross attention, larger value, more close to target + eq_params: + words: ["silver", "sculpture"] + values: [2,2] + + # Target structure-divergence hyperparames + # If you change the shape of object better to use all three line, otherwise, no need. 
+ # Without following three lines, all self-attention will be replaced + blend_words: [['cat',], ["cat",]] + masked_self_attention: True + # masked_latents: False # performance not so good in our case, need debug + bend_th: [2, 2] + # preserve source structure of blend_words , [0, 1] + # default is bend_th: [2, 2] # preserve all source self-attention + # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention + + + 1: + is_replace_controller: False + cross_replace_steps: + default_: 0.7 + self_replace_steps: 0.7 + + # eq_params: + # words: ["yellow", "rubber"] + # values: [10, 10] # amplify attention to the word "tiger" by *2 + + + + + clip_length: "${..train_dataset.n_sample_frame}" + sample_seeds: [0] + val_all_frames: False + + num_inference_steps: 50 + prompt2prompt_edit: True + + +model_config: + lora: 160 + # temporal_downsample_time: 4 + SparseCausalAttention_index: ['mid'] + least_sc_channel: 640 + # least_sc_channel: 100000 + +test_pipeline_config: + target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline + num_inference_steps: "${..validation_sample_logger.num_inference_steps}" + +epsilon: 1e-5 +train_steps: 10 +seed: 0 +learning_rate: 1e-5 +train_temporal_conv: False +guidance_scale: "${validation_sample_logger_config.guidance_scale}" \ No newline at end of file diff --git a/FateZero/config/attribute/fox_wolf_snow.yaml b/FateZero/config/attribute/fox_wolf_snow.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b7e20984e6275eb350a6855eaf51d0c1ea1c3ffb --- /dev/null +++ b/FateZero/config/attribute/fox_wolf_snow.yaml @@ -0,0 +1,107 @@ +# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/fox_wolf_snow.yaml + +pretrained_model_path: "./ckpt/stable-diffusion-v1-4" + +train_dataset: + path: "data/attribute/fox_wolf_snow" + prompt: "a white fox sitting in the grass" + n_sample_frame: 8 + # n_sample_frame: 22 + sampling_rate: 1 + stride: 80 + offset: + left: 0 + right: 0 + top: 0 + bottom: 0 + +validation_sample_logger_config: + use_train_latents: True + use_inversion_attention: True + guidance_scale: 7.5 + prompts: [ + # source prompt + a white fox sitting in the grass, + + # foreground texture style + a grey wolf sitting in the grass, + a white fox sitting in the snow + ] + p2p_config: + 0: + # Whether to directly copy the cross attention from source + # True: directly copy, better for object replacement + # False: keep source attention, better for style + is_replace_controller: False + + # Semantic preserving and replacement Debug me + cross_replace_steps: + default_: 0.8 + + # Source background structure preserving, in [0, 1]. + # e.g., =0.6 Replace the first 60% steps self-attention + self_replace_steps: 0.6 + + + # Amplify the target-words cross attention, larger value, more close to target + eq_params: + words: ["silver", "sculpture"] + values: [2,2] + + # Target structure-divergence hyperparames + # If you change the shape of object better to use all three line, otherwise, no need. 
+ # Without following three lines, all self-attention will be replaced + blend_words: [['cat',], ["cat",]] + masked_self_attention: True + # masked_latents: False # performance not so good in our case, need debug + bend_th: [2, 2] + # preserve source structure of blend_words , [0, 1] + # default is bend_th: [2, 2] # preserve all source self-attention + # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention + + + 1: + is_replace_controller: false + cross_replace_steps: + default_: 0.5 + self_replace_steps: 0.5 + + eq_params: + words: ["robotic"] + values: [10] # amplify attention to the word "tiger" by *2 + + 2: + is_replace_controller: false + cross_replace_steps: + default_: 0.5 + self_replace_steps: 0.5 + eq_params: + words: ["snow"] + values: [10] # amplify attention to the word "tiger" by *2 + + + clip_length: "${..train_dataset.n_sample_frame}" + sample_seeds: [0] + val_all_frames: False + + num_inference_steps: 50 + prompt2prompt_edit: True + + +model_config: + lora: 160 + # temporal_downsample_time: 4 + SparseCausalAttention_index: ['mid'] + least_sc_channel: 640 + # least_sc_channel: 100000 + +test_pipeline_config: + target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline + num_inference_steps: "${..validation_sample_logger.num_inference_steps}" + +epsilon: 1e-5 +train_steps: 10 +seed: 0 +learning_rate: 1e-5 +train_temporal_conv: False +guidance_scale: "${validation_sample_logger_config.guidance_scale}" \ No newline at end of file diff --git a/FateZero/config/attribute/rabbit_straberry_leaves_flowers.yaml b/FateZero/config/attribute/rabbit_straberry_leaves_flowers.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c579eb66dea4cb3953694a200629896d4a9270ef --- /dev/null +++ b/FateZero/config/attribute/rabbit_straberry_leaves_flowers.yaml @@ -0,0 +1,114 @@ +# CUDA_VISIBLE_DEVICES=1 python test_fatezero.py --config config/attribute/rabbit_straberry_leaves_flowers.yaml + +pretrained_model_path: "./ckpt/stable-diffusion-v1-4" + + +train_dataset: + path: "data/attribute/rabbit_strawberry" + prompt: "A rabbit is eating strawberries" + n_sample_frame: 8 + # n_sample_frame: 22 + sampling_rate: 1 + stride: 80 + offset: + left: 0 + right: 0 + top: 0 + bottom: 0 + +validation_sample_logger_config: + use_train_latents: True + use_inversion_attention: True + guidance_scale: 7.5 + prompts: [ + # source prompt + A rabbit is eating strawberries, + + # foreground texture style + A white rabbit is eating leaves, + A white rabbit is eating flower, + A white rabbit is eating orange, + + # a brown lion walking on the rock against a wall, + ] + p2p_config: + 0: + # Whether to directly copy the cross attention from source + # True: directly copy, better for object replacement + # False: keep source attention, better for style + is_replace_controller: False + + # Semantic preserving and replacement Debug me + cross_replace_steps: + default_: 0.8 + + # Source background structure preserving, in [0, 1]. + # e.g., =0.6 Replace the first 60% steps self-attention + self_replace_steps: 0.6 + + + # Amplify the target-words cross attention, larger value, more close to target + eq_params: + words: ["silver", "sculpture"] + values: [2,2] + + # Target structure-divergence hyperparames + # If you change the shape of object better to use all three line, otherwise, no need. 
+ # Without following three lines, all self-attention will be replaced + blend_words: [['cat',], ["cat",]] + masked_self_attention: True + # masked_latents: False # performance not so good in our case, need debug + bend_th: [2, 2] + # preserve source structure of blend_words , [0, 1] + # default is bend_th: [2, 2] # preserve all source self-attention + # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention + 1: + is_replace_controller: false + cross_replace_steps: + default_: 0.5 + self_replace_steps: 0.5 + eq_params: + words: ["leaves"] + values: [10] + 2: + is_replace_controller: false + cross_replace_steps: + default_: 0.5 + self_replace_steps: 0.5 + eq_params: + words: ["flower"] + values: [10] + 3: + is_replace_controller: false + cross_replace_steps: + default_: 0.5 + self_replace_steps: 0.5 + eq_params: + words: ["orange"] + values: [10] + + clip_length: "${..train_dataset.n_sample_frame}" + sample_seeds: [0] + val_all_frames: False + + num_inference_steps: 50 + prompt2prompt_edit: True + + +model_config: + lora: 160 + # temporal_downsample_time: 4 + SparseCausalAttention_index: ['mid'] + least_sc_channel: 640 + # least_sc_channel: 100000 + +test_pipeline_config: + target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline + num_inference_steps: "${..validation_sample_logger.num_inference_steps}" + +epsilon: 1e-5 +train_steps: 10 +seed: 0 +learning_rate: 1e-5 +train_temporal_conv: False +guidance_scale: "${validation_sample_logger_config.guidance_scale}" \ No newline at end of file diff --git a/FateZero/config/attribute/squ_carrot_robot_eggplant.yaml b/FateZero/config/attribute/squ_carrot_robot_eggplant.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b186d9114abfc649678b9a0bca34b23b871cb0b3 --- /dev/null +++ b/FateZero/config/attribute/squ_carrot_robot_eggplant.yaml @@ -0,0 +1,123 @@ +# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/squ_carrot_robot_eggplant.yaml + +pretrained_model_path: "./ckpt/stable-diffusion-v1-4" + + +train_dataset: + path: "data/attribute/squirrel_carrot" + prompt: "A squirrel is eating a carrot" + n_sample_frame: 8 + # n_sample_frame: 22 + sampling_rate: 1 + stride: 80 + offset: + left: 0 + right: 0 + top: 0 + bottom: 0 + +validation_sample_logger_config: + use_train_latents: True + use_inversion_attention: True + guidance_scale: 7.5 + prompts: [ + # source prompt + A squirrel is eating a carrot, + A robot squirrel is eating a carrot, + A rabbit is eating a eggplant, + A robot mouse is eating a screwdriver, + A white mouse is eating a peanut, + ] + p2p_config: + 0: + # Whether to directly copy the cross attention from source + # True: directly copy, better for object replacement + # False: keep source attention, better for style + is_replace_controller: False + + # Semantic preserving and replacement Debug me + cross_replace_steps: + default_: 0.8 + + # Source background structure preserving, in [0, 1]. + # e.g., =0.6 Replace the first 60% steps self-attention + self_replace_steps: 0.6 + + + # Amplify the target-words cross attention, larger value, more close to target + eq_params: + words: ["silver", "sculpture"] + values: [2,2] + + # Target structure-divergence hyperparames + # If you change the shape of object better to use all three line, otherwise, no need. 
+ # Without following three lines, all self-attention will be replaced + blend_words: [['cat',], ["cat",]] + masked_self_attention: True + # masked_latents: False # performance not so good in our case, need debug + bend_th: [2, 2] + # preserve source structure of blend_words , [0, 1] + # default is bend_th: [2, 2] # preserve all source self-attention + # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention + + + 1: + is_replace_controller: false + cross_replace_steps: + default_: 0.5 + self_replace_steps: 0.4 + eq_params: + words: ["rabbit", "mouse", "robot", "eggplant", "peanut", "screwdriver"] + values: [10, 10, 20, 10, 10, 10] + 2: + is_replace_controller: false + cross_replace_steps: + default_: 0.5 + self_replace_steps: 0.5 + eq_params: + words: ["rabbit", "mouse", "robot", "eggplant", "peanut", "screwdriver"] + values: [10, 10, 20, 10, 10, 10] + 3: + is_replace_controller: false + cross_replace_steps: + default_: 0.5 + self_replace_steps: 0.5 + eq_params: + words: ["rabbit", "mouse", "robot", "eggplant", "peanut", "screwdriver"] + values: [10, 10, 20, 10, 10, 10] + 4: + is_replace_controller: false + cross_replace_steps: + default_: 0.5 + self_replace_steps: 0.5 + eq_params: + words: ["rabbit", "mouse", "robot", "eggplant", "peanut", "screwdriver"] + values: [10, 10, 20, 10, 10, 10] + + + + clip_length: "${..train_dataset.n_sample_frame}" + sample_seeds: [0] + val_all_frames: False + + num_inference_steps: 50 + prompt2prompt_edit: True + + +model_config: + lora: 160 + # temporal_downsample_time: 4 + SparseCausalAttention_index: ['mid'] + least_sc_channel: 640 + # least_sc_channel: 100000 + +test_pipeline_config: + target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline + num_inference_steps: "${..validation_sample_logger.num_inference_steps}" + +epsilon: 1e-5 +train_steps: 10 +seed: 0 +learning_rate: 1e-5 +train_temporal_conv: False +guidance_scale: "${validation_sample_logger_config.guidance_scale}" \ No newline at end of file diff --git a/FateZero/config/attribute/swan_swa.yaml b/FateZero/config/attribute/swan_swa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cb64d1dfa268916b9e31bcd0b3a080704b3df13c --- /dev/null +++ b/FateZero/config/attribute/swan_swa.yaml @@ -0,0 +1,102 @@ +# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/attribute/swan_swa.yaml + +pretrained_model_path: "./ckpt/stable-diffusion-v1-4" + + +train_dataset: + path: "data/attribute/swan_swarov" + prompt: "a black swan with a red beak swimming in a river near a wall and bushes," + n_sample_frame: 8 + # n_sample_frame: 22 + sampling_rate: 1 + stride: 80 + offset: + left: 0 + right: 0 + top: 0 + bottom: 0 + +use_train_latents: True + +validation_sample_logger_config: + use_train_latents: True + use_inversion_attention: True + guidance_scale: 7.5 + prompts: [ + # source prompt + a black swan with a red beak swimming in a river near a wall and bushes, + + # foreground texture style + a Swarovski crystal swan with a red beak swimming in a river near a wall and bushes, + ] + p2p_config: + 0: + # Whether to directly copy the cross attention from source + # True: directly copy, better for object replacement + # False: keep source attention, better for style + is_replace_controller: False + + # Semantic preserving and replacement Debug me + cross_replace_steps: + default_: 0.8 + + # Source background structure preserving, in [0, 1]. 
+ # e.g., =0.6 Replace the first 60% steps self-attention + self_replace_steps: 0.6 + + + # Amplify the target-words cross attention, larger value, more close to target + eq_params: + words: ["silver", "sculpture"] + values: [2,2] + + # Target structure-divergence hyperparames + # If you change the shape of object better to use all three line, otherwise, no need. + # Without following three lines, all self-attention will be replaced + blend_words: [['cat',], ["cat",]] + masked_self_attention: True + # masked_latents: False # performance not so good in our case, need debug + bend_th: [2, 2] + # preserve source structure of blend_words , [0, 1] + # default is bend_th: [2, 2] # preserve all source self-attention + # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention + + + 1: + is_replace_controller: False + cross_replace_steps: + default_: 0.8 + self_replace_steps: 0.6 + + eq_params: + words: ["Swarovski", "crystal"] + values: [5, 5] # amplify attention to the word "tiger" by *2 + use_inversion_attention: True + + + + clip_length: "${..train_dataset.n_sample_frame}" + sample_seeds: [0] + val_all_frames: False + + num_inference_steps: 50 + prompt2prompt_edit: True + + +model_config: + lora: 160 + # temporal_downsample_time: 4 + SparseCausalAttention_index: ['mid'] + least_sc_channel: 1280 + # least_sc_channel: 100000 + +test_pipeline_config: + target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline + num_inference_steps: "${..validation_sample_logger.num_inference_steps}" + +epsilon: 1e-5 +train_steps: 10 +seed: 0 +learning_rate: 1e-5 +train_temporal_conv: False +guidance_scale: "${validation_sample_logger_config.guidance_scale}" \ No newline at end of file diff --git a/FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml b/FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ea3a3437c33bbdb7d90943e532a7d785dc3f8c06 --- /dev/null +++ b/FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml @@ -0,0 +1,83 @@ +# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/low_resource_teaser/jeep_watercolor.yaml + +pretrained_model_path: "FateZero/ckpt/stable-diffusion-v1-4" + +train_dataset: + path: "FateZero/data/teaser_car-turn" + prompt: "a silver jeep driving down a curvy road in the countryside" + n_sample_frame: 8 + sampling_rate: 1 + stride: 80 + offset: + left: 0 + right: 0 + top: 0 + bottom: 0 + + +validation_sample_logger_config: + use_train_latents: true + use_inversion_attention: true + guidance_scale: 7.5 + source_prompt: "${train_dataset.prompt}" + prompts: [ + # a silver jeep driving down a curvy road in the countryside, + watercolor painting of a silver jeep driving down a curvy road in the countryside, + ] + p2p_config: + 0: + # Whether to directly copy the cross attention from source + # True: directly copy, better for object replacement + # False: keep source attention, better for style + + is_replace_controller: False + + # Semantic layout preserving. High steps, replace more cross attention to preserve semantic layout + cross_replace_steps: + default_: 0.8 + + # Source background structure preserving, in [0, 1]. 
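+      # (For reference: this config sets num_inference_steps: 10 below, so the value 0.8 here
+      # presumably reuses the source self-attention for roughly the first 8 of the 10 denoising steps.)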
+ # e.g., =0.6 Replace the first 60% steps self-attention + self_replace_steps: 0.8 + + + # Amplify the target-words cross attention, larger value, more close to target + eq_params: + words: ["watercolor"] + values: [10,10] + + # Target structure-divergence hyperparames + # If you change the shape of object better to use all three line, otherwise, no need. + # Without following three lines, all self-attention will be replaced + # blend_words: [['jeep',], ["car",]] + # masked_self_attention: True + # masked_latents: False # performance not so good in our case, need debug + # bend_th: [2, 2] + # preserve source structure of blend_words , [0, 1] + # default is bend_th: [2, 2] # replace full-resolution edit source with self-attention + # bend_th-> [0.0, 0.0], mask -> 1, use more edit self-attention, more generated shape, less source acttention + + + clip_length: "${..train_dataset.n_sample_frame}" + sample_seeds: [0] + + num_inference_steps: 10 + prompt2prompt_edit: True + +model_config: + lora: 160 + # temporal_downsample_time: 4 + SparseCausalAttention_index: ['mid'] + least_sc_channel: 640 + # least_sc_channel: 100000 + +test_pipeline_config: + target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline + num_inference_steps: "${..validation_sample_logger.num_inference_steps}" + +epsilon: 1e-5 +train_steps: 10 +seed: 0 +learning_rate: 1e-5 +train_temporal_conv: False +guidance_scale: "${validation_sample_logger_config.guidance_scale}" \ No newline at end of file diff --git a/FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps_disk_store.yaml b/FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps_disk_store.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3fe8b2f58bb1bc70b968f6e73f7a5ccbcff22188 --- /dev/null +++ b/FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps_disk_store.yaml @@ -0,0 +1,84 @@ +# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/low_resource_teaser/jeep_watercolor_ddim_10_steps_disk_store.yaml + +pretrained_model_path: "./ckpt/stable-diffusion-v1-4" + +train_dataset: + path: "data/teaser_car-turn" + prompt: "a silver jeep driving down a curvy road in the countryside" + n_sample_frame: 8 + sampling_rate: 1 + stride: 80 + offset: + left: 0 + right: 0 + top: 0 + bottom: 0 + + +validation_sample_logger_config: + use_train_latents: true + use_inversion_attention: true + guidance_scale: 7.5 + source_prompt: "${train_dataset.prompt}" + prompts: [ + # a silver jeep driving down a curvy road in the countryside, + watercolor painting of a silver jeep driving down a curvy road in the countryside, + ] + p2p_config: + 0: + # Whether to directly copy the cross attention from source + # True: directly copy, better for object replacement + # False: keep source attention, better for style + + is_replace_controller: False + + # Semantic layout preserving. High steps, replace more cross attention to preserve semantic layout + cross_replace_steps: + default_: 0.8 + + # Source background structure preserving, in [0, 1]. + # e.g., =0.6 Replace the first 60% steps self-attention + self_replace_steps: 0.8 + + + # Amplify the target-words cross attention, larger value, more close to target + eq_params: + words: ["watercolor"] + values: [10,10] + + # Target structure-divergence hyperparames + # If you change the shape of object better to use all three line, otherwise, no need. 
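+      # Note: this file mirrors the 10-step low-resource config and additionally sets disk_store: True
+      # near the end, which (per the cost table in docs/EditingGuidance.md) keeps the inversion
+      # attention maps on disk instead of CPU RAM, trading extra time for lower memory.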
+ # Without following three lines, all self-attention will be replaced + # blend_words: [['jeep',], ["car",]] + # masked_self_attention: True + # masked_latents: False # performance not so good in our case, need debug + # bend_th: [2, 2] + # preserve source structure of blend_words , [0, 1] + # default is bend_th: [2, 2] # replace full-resolution edit source with self-attention + # bend_th-> [0.0, 0.0], mask -> 1, use more edit self-attention, more generated shape, less source acttention + + + clip_length: "${..train_dataset.n_sample_frame}" + sample_seeds: [0] + + num_inference_steps: 10 + prompt2prompt_edit: True + +disk_store: True +model_config: + lora: 160 + # temporal_downsample_time: 4 + SparseCausalAttention_index: ['mid'] + least_sc_channel: 640 + # least_sc_channel: 100000 + +test_pipeline_config: + target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline + num_inference_steps: "${..validation_sample_logger.num_inference_steps}" + +epsilon: 1e-5 +train_steps: 10 +seed: 0 +learning_rate: 1e-5 +train_temporal_conv: False +guidance_scale: "${validation_sample_logger_config.guidance_scale}" \ No newline at end of file diff --git a/FateZero/config/style/jeep_watercolor.yaml b/FateZero/config/style/jeep_watercolor.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b4e3e0b1057f4ce0416aec2e46425372b9d5ecf2 --- /dev/null +++ b/FateZero/config/style/jeep_watercolor.yaml @@ -0,0 +1,94 @@ +# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/teaser/jeep_watercolor.yaml + +pretrained_model_path: "./ckpt/stable-diffusion-v1-4" + +train_dataset: + path: "data/teaser_car-turn" + prompt: "a silver jeep driving down a curvy road in the countryside" + n_sample_frame: 8 + sampling_rate: 1 + stride: 80 + offset: + left: 0 + right: 0 + top: 0 + bottom: 0 + + +validation_sample_logger_config: + use_train_latents: true + use_inversion_attention: true + guidance_scale: 7.5 + prompts: [ + a silver jeep driving down a curvy road in the countryside, + watercolor painting of a silver jeep driving down a curvy road in the countryside, + ] + p2p_config: + 0: + # Whether to directly copy the cross attention from source + # True: directly copy, better for object replacement + # False: keep source attention, better for style + is_replace_controller: False + + # Semantic layout preserving. High steps, replace more cross attention to preserve semantic layout + cross_replace_steps: + default_: 0.8 + + # Source background structure preserving, in [0, 1]. + # e.g., =0.6 Replace the first 60% steps self-attention + self_replace_steps: 0.9 + + + # Amplify the target-words cross attention, larger value, more close to target + # eq_params: + # words: ["", ""] + # values: [10,10] + + # Target structure-divergence hyperparames + # If you change the shape of object better to use all three line, otherwise, no need. 
+ # Without following three lines, all self-attention will be replaced + # blend_words: [['jeep',], ["car",]] + masked_self_attention: True + # masked_latents: False # Directly copy the latents, performance not so good in our case + bend_th: [2, 2] + # preserve source structure of blend_words , [0, 1] + # default is bend_th: [2, 2] # replace full-resolution edit source with self-attention + # bend_th-> [0.0, 0.0], mask -> 1, use more edit self-attention, more generated shape, less source acttention + + + 1: + cross_replace_steps: + default_: 0.8 + self_replace_steps: 0.8 + + eq_params: + words: ["watercolor"] + values: [10] # amplify attention to the word "tiger" by *2 + use_inversion_attention: True + is_replace_controller: False + + + clip_length: "${..train_dataset.n_sample_frame}" + sample_seeds: [0] + + num_inference_steps: 50 + prompt2prompt_edit: True + + +model_config: + lora: 160 + # temporal_downsample_time: 4 + SparseCausalAttention_index: ['mid'] + least_sc_channel: 640 + # least_sc_channel: 100000 + +test_pipeline_config: + target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline + num_inference_steps: "${..validation_sample_logger.num_inference_steps}" + +epsilon: 1e-5 +train_steps: 10 +seed: 0 +learning_rate: 1e-5 +train_temporal_conv: False +guidance_scale: "${validation_sample_logger_config.guidance_scale}" \ No newline at end of file diff --git a/FateZero/config/style/lily_monet.yaml b/FateZero/config/style/lily_monet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4cd00b743daee0d83707632bc78f156cb02c54c1 --- /dev/null +++ b/FateZero/config/style/lily_monet.yaml @@ -0,0 +1,93 @@ +pretrained_model_path: "./ckpt/stable-diffusion-v1-4" + + +train_dataset: + path: "data/style/red_water_lily_opening" + prompt: "a pink water lily" + start_sample_frame: 1 + n_sample_frame: 8 + # n_sample_frame: 22 + sampling_rate: 20 + stride: 8000 + # offset: + # left: 300 + # right: 0 + # top: 0 + # bottom: 0 + +validation_sample_logger_config: + use_train_latents: True + use_inversion_attention: True + guidance_scale: 7.5 + prompts: [ + a pink water lily, + Claude Monet painting of a pink water lily, + ] + p2p_config: + 0: + # Whether to directly copy the cross attention from source + # True: directly copy, better for object replacement + # False: keep source attention, better for style + is_replace_controller: False + + # Semantic preserving and replacement Debug me + cross_replace_steps: + default_: 0.7 + + # Source background structure preserving, in [0, 1]. + # e.g., =0.6 Replace the first 60% steps self-attention + self_replace_steps: 0.7 + + + # Amplify the target-words cross attention, larger value, more close to target + eq_params: + words: ["silver", "sculpture"] + values: [2,2] + + # Target structure-divergence hyperparames + # If you change the shape of object better to use all three line, otherwise, no need. 
+ # Without following three lines, all self-attention will be replaced + blend_words: [['cat',], ["cat",]] + masked_self_attention: True + # masked_latents: False # performance not so good in our case, need debug + bend_th: [2, 2] + # preserve source structure of blend_words , [0, 1] + # default is bend_th: [2, 2] # preserve all source self-attention + # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention + + + 1: + is_replace_controller: False + cross_replace_steps: + default_: 0.5 + self_replace_steps: 0.5 + + eq_params: + words: ["Monet"] + values: [10] + + clip_length: "${..train_dataset.n_sample_frame}" + sample_seeds: [0] + val_all_frames: False + + num_inference_steps: 50 + prompt2prompt_edit: True + + +model_config: + lora: 160 + # temporal_downsample_time: 4 + SparseCausalAttention_index: ['mid'] + least_sc_channel: 1280 + # least_sc_channel: 100000 + +test_pipeline_config: + target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline + num_inference_steps: "${..validation_sample_logger.num_inference_steps}" + +epsilon: 1e-5 +train_steps: 10 +seed: 0 +learning_rate: 1e-5 +train_temporal_conv: False +guidance_scale: "${validation_sample_logger_config.guidance_scale}" \ No newline at end of file diff --git a/FateZero/config/style/rabit_pokemon.yaml b/FateZero/config/style/rabit_pokemon.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bab8e40423d7b9936ccbf138718afdc4ec201fd6 --- /dev/null +++ b/FateZero/config/style/rabit_pokemon.yaml @@ -0,0 +1,92 @@ +pretrained_model_path: "./ckpt/stable-diffusion-v1-4" + + +train_dataset: + path: "data/style/rabit" + prompt: "A rabbit is eating a watermelon" + n_sample_frame: 8 + # n_sample_frame: 22 + sampling_rate: 3 + stride: 80 + + +validation_sample_logger_config: + use_train_latents: True + use_inversion_attention: True + guidance_scale: 7.5 + prompts: [ + # source prompt + A rabbit is eating a watermelon, + # overall style + pokemon cartoon of A rabbit is eating a watermelon, + ] + p2p_config: + 0: + # Whether to directly copy the cross attention from source + # True: directly copy, better for object replacement + # False: keep source attention, better for style + is_replace_controller: False + + # Semantic preserving and replacement Debug me + cross_replace_steps: + default_: 0.8 + + # Source background structure preserving, in [0, 1]. + # e.g., =0.6 Replace the first 60% steps self-attention + self_replace_steps: 0.6 + + + # Amplify the target-words cross attention, larger value, more close to target + eq_params: + words: ["silver", "sculpture"] + values: [2,2] + + # Target structure-divergence hyperparames + # If you change the shape of object better to use all three line, otherwise, no need. 
+ # Without following three lines, all self-attention will be replaced + blend_words: [['cat',], ["cat",]] + masked_self_attention: True + # masked_latents: False # performance not so good in our case, need debug + bend_th: [2, 2] + # preserve source structure of blend_words , [0, 1] + # default is bend_th: [2, 2] # preserve all source self-attention + # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention + + + 1: + is_replace_controller: False + cross_replace_steps: + default_: 0.7 + self_replace_steps: 0.7 + + eq_params: + words: ["pokemon", "cartoon"] + values: [3, 3] # amplify attention to the word "tiger" by *2 + + + + clip_length: "${..train_dataset.n_sample_frame}" + sample_seeds: [0] + val_all_frames: False + + num_inference_steps: 50 + prompt2prompt_edit: True + + +model_config: + # lora: 160 + # temporal_downsample_time: 4 + # SparseCausalAttention_index: ['mid'] + # least_sc_channel: 640 + # least_sc_channel: 100000 + +test_pipeline_config: + target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline + num_inference_steps: "${..validation_sample_logger.num_inference_steps}" + +epsilon: 1e-5 +train_steps: 50 +seed: 0 +learning_rate: 1e-5 +train_temporal_conv: False +guidance_scale: "${validation_sample_logger_config.guidance_scale}" \ No newline at end of file diff --git a/FateZero/config/style/sun_flower_van_gogh.yaml b/FateZero/config/style/sun_flower_van_gogh.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f115905ca37a85e9ca7df5bc77cb4d51e20351ac --- /dev/null +++ b/FateZero/config/style/sun_flower_van_gogh.yaml @@ -0,0 +1,86 @@ +pretrained_model_path: "./ckpt/stable-diffusion-v1-4" + +train_dataset: + path: "data/style/sunflower" + prompt: "a yellow sunflower" + start_sample_frame: 0 + n_sample_frame: 8 + sampling_rate: 1 + + +validation_sample_logger_config: + use_train_latents: True + use_inversion_attention: True + guidance_scale: 7.5 + prompts: [ + a yellow sunflower, + van gogh style painting of a yellow sunflower, + ] + p2p_config: + 0: + # Whether to directly copy the cross attention from source + # True: directly copy, better for object replacement + # False: keep source attention, better for style + is_replace_controller: False + + # Semantic preserving and replacement Debug me + cross_replace_steps: + default_: 0.7 + + # Source background structure preserving, in [0, 1]. + # e.g., =0.6 Replace the first 60% steps self-attention + self_replace_steps: 0.7 + + + # Amplify the target-words cross attention, larger value, more close to target + eq_params: + words: ["silver", "sculpture"] + values: [2,2] + + # Target structure-divergence hyperparames + # If you change the shape of object better to use all three line, otherwise, no need. 
+ # Without following three lines, all self-attention will be replaced + blend_words: [['cat',], ["cat",]] + masked_self_attention: True + # masked_latents: False # performance not so good in our case, need debug + bend_th: [2, 2] + # preserve source structure of blend_words , [0, 1] + # default is bend_th: [2, 2] # preserve all source self-attention + # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention + + + 1: + is_replace_controller: False + cross_replace_steps: + default_: 0.5 + self_replace_steps: 0.5 + + eq_params: + words: ["van", "gogh"] + values: [10, 10] # amplify attention to the word "tiger" by *2 + + clip_length: "${..train_dataset.n_sample_frame}" + sample_seeds: [0] + val_all_frames: False + + num_inference_steps: 50 + prompt2prompt_edit: True + + +model_config: + lora: 160 + # temporal_downsample_time: 4 + SparseCausalAttention_index: ['mid'] + least_sc_channel: 640 + # least_sc_channel: 100000 + +test_pipeline_config: + target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline + num_inference_steps: "${..validation_sample_logger.num_inference_steps}" + +epsilon: 1e-5 +train_steps: 10 +seed: 0 +learning_rate: 1e-5 +train_temporal_conv: False +guidance_scale: "${validation_sample_logger_config.guidance_scale}" \ No newline at end of file diff --git a/FateZero/config/style/surf_ukiyo.yaml b/FateZero/config/style/surf_ukiyo.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a398bb9d87bb0f356b5d26448ce07ef18961ab2d --- /dev/null +++ b/FateZero/config/style/surf_ukiyo.yaml @@ -0,0 +1,90 @@ +pretrained_model_path: "./ckpt/stable-diffusion-v1-4" + +train_dataset: + path: "data/style/surf" + prompt: "a man with round helmet surfing on a white wave in blue ocean with a rope" + n_sample_frame: 1 + + sampling_rate: 8 + + +# use_train_latents: True + +validation_sample_logger_config: + use_train_latents: true + use_inversion_attention: true + guidance_scale: 7.5 + prompts: [ + a man with round helmet surfing on a white wave in blue ocean with a rope, + The Ukiyo-e style painting of a man with round helmet surfing on a white wave in blue ocean with a rope + ] + p2p_config: + 0: + # Whether to directly copy the cross attention from source + # True: directly copy, better for object replacement + # False: keep source attention, better for style + is_replace_controller: False + + # Semantic preserving and replacement Debug me + cross_replace_steps: + default_: 0.8 + + # Source background structure preserving, in [0, 1]. + # e.g., =0.6 Replace the first 60% steps self-attention + self_replace_steps: 0.8 + + + # Amplify the target-words cross attention, larger value, more close to target + eq_params: + words: ["silver", "sculpture"] + values: [2,2] + + # Target structure-divergence hyperparames + # If you change the shape of object better to use all three line, otherwise, no need. 
+ # Without following three lines, all self-attention will be replaced + blend_words: [['cat',], ["cat",]] + masked_self_attention: True + # masked_latents: False # performance not so good in our case, need debug + bend_th: [2, 2] + # preserve source structure of blend_words , [0, 1] + # default is bend_th: [2, 2] # preserve all source self-attention + # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention + + 1: + is_replace_controller: False + cross_replace_steps: + default_: 0.9 + self_replace_steps: 0.9 + + eq_params: + words: ["Ukiyo-e"] + values: [10, 10] # amplify attention to the word "tiger" by *2 + + + + + clip_length: "${..train_dataset.n_sample_frame}" + sample_seeds: [0] + val_all_frames: False + + num_inference_steps: 50 + prompt2prompt_edit: True + + +model_config: + # lora: 160 + # temporal_downsample_time: 4 + SparseCausalAttention_index: ['mid'] + least_sc_channel: 640 + # least_sc_channel: 100000 + +test_pipeline_config: + target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline + num_inference_steps: "${..validation_sample_logger.num_inference_steps}" + +epsilon: 1e-5 +train_steps: 50 +seed: 0 +learning_rate: 1e-5 +train_temporal_conv: False +guidance_scale: "${validation_sample_logger_config.guidance_scale}" \ No newline at end of file diff --git a/FateZero/config/style/swan_cartoon.yaml b/FateZero/config/style/swan_cartoon.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ca282e1c56c99e93b453b1e76ab649af0d0f0ab3 --- /dev/null +++ b/FateZero/config/style/swan_cartoon.yaml @@ -0,0 +1,101 @@ +pretrained_model_path: "./ckpt/stable-diffusion-v1-4" + + +train_dataset: + path: "data/style/blackswan" + prompt: "a black swan with a red beak swimming in a river near a wall and bushes," + n_sample_frame: 8 + # n_sample_frame: 22 + sampling_rate: 6 + stride: 80 + offset: + left: 0 + right: 0 + top: 0 + bottom: 0 + +# use_train_latents: True + +validation_sample_logger_config: + use_train_latents: true + use_inversion_attention: true + guidance_scale: 7.5 + prompts: [ + # source prompt + a black swan with a red beak swimming in a river near a wall and bushes, + cartoon photo of a black swan with a red beak swimming in a river near a wall and bushes, + ] + p2p_config: + 0: + # Whether to directly copy the cross attention from source + # True: directly copy, better for object replacement + # False: keep source attention, better for style + is_replace_controller: False + + # Semantic preserving and replacement Debug me + cross_replace_steps: + default_: 0.8 + + # Source background structure preserving, in [0, 1]. + # e.g., =0.6 Replace the first 60% steps self-attention + self_replace_steps: 0.6 + + + # Amplify the target-words cross attention, larger value, more close to target + eq_params: + words: ["silver", "sculpture"] + values: [2,2] + + # Target structure-divergence hyperparames + # If you change the shape of object better to use all three line, otherwise, no need. 
+ # Without following three lines, all self-attention will be replaced + blend_words: [['cat',], ["cat",]] + masked_self_attention: True + # masked_latents: False # performance not so good in our case, need debug + bend_th: [2, 2] + # preserve source structure of blend_words , [0, 1] + # default is bend_th: [2, 2] # preserve all source self-attention + # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention + + # Fixed hyperparams + use_inversion_attention: True + + 1: + is_replace_controller: False + cross_replace_steps: + default_: 0.8 + self_replace_steps: 0.7 + + eq_params: + words: ["cartoon"] + values: [10] # amplify attention to the word "tiger" by *2 + use_inversion_attention: True + + + + clip_length: "${..train_dataset.n_sample_frame}" + sample_seeds: [0] + val_all_frames: False + + num_inference_steps: 50 + # guidance_scale: 7.5 + prompt2prompt_edit: True + + +model_config: + lora: 160 + # temporal_downsample_time: 4 + SparseCausalAttention_index: ['mid'] + least_sc_channel: 640 + # least_sc_channel: 100000 + +test_pipeline_config: + target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline + num_inference_steps: "${..validation_sample_logger.num_inference_steps}" + +epsilon: 1e-5 +train_steps: 10 +seed: 0 +learning_rate: 1e-5 +train_temporal_conv: False +guidance_scale: "${validation_sample_logger_config.guidance_scale}" \ No newline at end of file diff --git a/FateZero/config/style/train_shinkai.yaml b/FateZero/config/style/train_shinkai.yaml new file mode 100644 index 0000000000000000000000000000000000000000..04a4b61b2ea1fb542408914f94520e77aaa9afa4 --- /dev/null +++ b/FateZero/config/style/train_shinkai.yaml @@ -0,0 +1,97 @@ +pretrained_model_path: "./ckpt/stable-diffusion-v1-4" + +train_dataset: + path: "data/style/train" + prompt: "a train traveling down tracks next to a forest filled with trees and flowers and a man on the side of the track" + n_sample_frame: 32 + # n_sample_frame: 22 + sampling_rate: 7 + stride: 80 + # offset: + # left: 300 + # right: 0 + # top: 0 + # bottom: 0 + +use_train_latents: True + +validation_sample_logger_config: + use_train_latents: True + use_inversion_attention: True + guidance_scale: 7.5 + prompts: [ + a train traveling down tracks next to a forest filled with trees and flowers and a man on the side of the track, + a train traveling down tracks next to a forest filled with trees and flowers and a man on the side of the track Makoto Shinkai style + + ] + p2p_config: + 0: + # Whether to directly copy the cross attention from source + # True: directly copy, better for object replacement + # False: keep source attention, better for style + is_replace_controller: False + + # Semantic preserving and replacement Debug me + cross_replace_steps: + default_: 1.0 + + # Source background structure preserving, in [0, 1]. + # e.g., =0.6 Replace the first 60% steps self-attention + self_replace_steps: 1.0 + + + # Amplify the target-words cross attention, larger value, more close to target + # eq_params: + # words: ["silver", "sculpture"] + # values: [2,2] + + # Target structure-divergence hyperparames + # If you change the shape of object better to use all three line, otherwise, no need. 
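+      # (With cross_replace_steps and self_replace_steps both at 1.0 above, the source attention is
+      # reused over the full 50-step trajectory, presumably the strongest structure preservation,
+      # which suits this longer 32-frame clip.)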
+ # Without following three lines, all self-attention will be replaced + # blend_words: [['cat',], ["cat",]] + # masked_self_attention: True + # # masked_latents: False # performance not so good in our case, need debug + # bend_th: [2, 2] + # preserve source structure of blend_words , [0, 1] + # default is bend_th: [2, 2] # preserve all source self-attention + # bend_th : [0.0, 0.0], mask -> 1, use more att_replace, more generated attention, less source acttention + + + 1: + is_replace_controller: False + cross_replace_steps: + default_: 1.0 + self_replace_steps: 0.9 + + eq_params: + words: ["Makoto", "Shinkai"] + values: [10, 10] # amplify attention to the word "tiger" by *2 + + + + + clip_length: "${..train_dataset.n_sample_frame}" + sample_seeds: [0] + val_all_frames: False + + num_inference_steps: 50 + prompt2prompt_edit: True + + +model_config: + lora: 160 + # temporal_downsample_time: 4 + SparseCausalAttention_index: ['mid'] + least_sc_channel: 1280 + # least_sc_channel: 100000 + +test_pipeline_config: + target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline + num_inference_steps: "${..validation_sample_logger.num_inference_steps}" + +epsilon: 1e-5 +train_steps: 10 +seed: 0 +learning_rate: 1e-5 +train_temporal_conv: False +guidance_scale: "${validation_sample_logger_config.guidance_scale}" \ No newline at end of file diff --git a/FateZero/config/teaser/jeep_posche.yaml b/FateZero/config/teaser/jeep_posche.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9cb9fe0370558b58f2230e5ac02c84d8637faabe --- /dev/null +++ b/FateZero/config/teaser/jeep_posche.yaml @@ -0,0 +1,93 @@ +# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/teaser/jeep_posche.yaml + +pretrained_model_path: "./ckpt/jeep_tuned_200" + +train_dataset: + path: "data/teaser_car-turn" + prompt: "a silver jeep driving down a curvy road in the countryside," + n_sample_frame: 8 + sampling_rate: 1 + stride: 80 + offset: + left: 0 + right: 0 + top: 0 + bottom: 0 + + +validation_sample_logger_config: + use_train_latents: true + use_inversion_attention: true + guidance_scale: 7.5 + prompts: [ + a silver jeep driving down a curvy road in the countryside, + a Porsche car driving down a curvy road in the countryside, + ] + p2p_config: + 0: + # Whether to directly copy the cross attention from source + # True: directly copy, better for object replacement + # False: keep source attention, better for style + is_replace_controller: False + + # Semantic layout preserving. High steps, replace more cross attention to preserve semantic layout + cross_replace_steps: + default_: 0.8 + + # Source background structure preserving, in [0, 1]. + # e.g., =0.6 Replace the first 60% steps self-attention + self_replace_steps: 0.9 + + + # Amplify the target-words cross attention, larger value, more close to target + # Usefull in style editing + eq_params: + words: ["watercolor", "painting"] + values: [10,10] + + # Target structure-divergence hyperparames + # If you change the shape of object better to use all three line, otherwise, no need. 
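+      # Note: this is the shape-editing example, so the three lines below are actually enabled, and
+      # pretrained_model_path at the top of this file points to a video-specific checkpoint
+      # (./ckpt/jeep_tuned_200) instead of the vanilla stable-diffusion-v1-4 used by the style configs.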
+ # Without following three lines, all self-attention will be replaced + # Usefull in shape editing + blend_words: [['jeep',], ["car",]] + masked_self_attention: True + # masked_latents: False # Directly copy the latents, performance not so good in our case + + # preserve source structure of blend_words , [0, 1] + # bend_th-> [1.0, 1.0], mask -> 0, use inversion-time attention, the structure is similar to the input + # bend_th-> [0.0, 0.0], mask -> 1, use more edit self-attention, more generated shape, less source acttention + bend_th: [0.3, 0.3] + + 1: + cross_replace_steps: + default_: 0.5 + self_replace_steps: 0.5 + + use_inversion_attention: True + is_replace_controller: True + + blend_words: [['silver', 'jeep'], ["Porsche", 'car']] # for local edit. If it is not local yet - use only the source object: blend_word = ((('cat',), ("cat",))). + masked_self_attention: True + bend_th: [0.3, 0.3] + + clip_length: "${..train_dataset.n_sample_frame}" + sample_seeds: [0] + + num_inference_steps: 50 + prompt2prompt_edit: True + + +model_config: + lora: 160 + + +test_pipeline_config: + target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline + num_inference_steps: "${..validation_sample_logger.num_inference_steps}" + +epsilon: 1e-5 +train_steps: 10 +seed: 0 +learning_rate: 1e-5 +train_temporal_conv: False +guidance_scale: "${validation_sample_logger_config.guidance_scale}" \ No newline at end of file diff --git a/FateZero/config/teaser/jeep_watercolor.yaml b/FateZero/config/teaser/jeep_watercolor.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8e72607ee652c04ab6f81a43d9e60d9524f9c961 --- /dev/null +++ b/FateZero/config/teaser/jeep_watercolor.yaml @@ -0,0 +1,94 @@ +# CUDA_VISIBLE_DEVICES=0 python test_fatezero.py --config config/teaser/jeep_watercolor.yaml + +pretrained_model_path: "FateZero/ckpt/stable-diffusion-v1-4" + +train_dataset: + path: "FateZero/data/teaser_car-turn" + prompt: "a silver jeep driving down a curvy road in the countryside" + n_sample_frame: 8 + sampling_rate: 1 + stride: 80 + offset: + left: 0 + right: 0 + top: 0 + bottom: 0 + + +validation_sample_logger_config: + use_train_latents: true + use_inversion_attention: true + guidance_scale: 7.5 + prompts: [ + a silver jeep driving down a curvy road in the countryside, + watercolor painting of a silver jeep driving down a curvy road in the countryside, + ] + p2p_config: + 0: + # Whether to directly copy the cross attention from source + # True: directly copy, better for object replacement + # False: keep source attention, better for style + is_replace_controller: False + + # Semantic layout preserving. High steps, replace more cross attention to preserve semantic layout + cross_replace_steps: + default_: 0.8 + + # Source background structure preserving, in [0, 1]. + # e.g., =0.6 Replace the first 60% steps self-attention + self_replace_steps: 0.9 + + + # Amplify the target-words cross attention, larger value, more close to target + # eq_params: + # words: ["", ""] + # values: [10,10] + + # Target structure-divergence hyperparames + # If you change the shape of object better to use all three line, otherwise, no need. 
+ # Without following three lines, all self-attention will be replaced + # blend_words: [['jeep',], ["car",]] + masked_self_attention: True + # masked_latents: False # Directly copy the latents, performance not so good in our case + bend_th: [2, 2] + # preserve source structure of blend_words , [0, 1] + # default is bend_th: [2, 2] # replace full-resolution edit source with self-attention + # bend_th-> [0.0, 0.0], mask -> 1, use more edit self-attention, more generated shape, less source acttention + + + 1: + cross_replace_steps: + default_: 0.8 + self_replace_steps: 0.8 + + eq_params: + words: ["watercolor"] + values: [10] # amplify attention to the word "tiger" by *2 + use_inversion_attention: True + is_replace_controller: False + + + clip_length: "${..train_dataset.n_sample_frame}" + sample_seeds: [0] + + num_inference_steps: 50 + prompt2prompt_edit: True + + +model_config: + lora: 160 + # temporal_downsample_time: 4 + SparseCausalAttention_index: ['mid'] + least_sc_channel: 640 + # least_sc_channel: 100000 + +test_pipeline_config: + target: video_diffusion.pipelines.p2pDDIMSpatioTemporalPipeline.p2pDDIMSpatioTemporalPipeline + num_inference_steps: "${..validation_sample_logger.num_inference_steps}" + +epsilon: 1e-5 +train_steps: 10 +seed: 0 +learning_rate: 1e-5 +train_temporal_conv: False +guidance_scale: "${validation_sample_logger_config.guidance_scale}" \ No newline at end of file diff --git a/FateZero/data/.gitignore b/FateZero/data/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..9025a3bd663543b8544571de73115a2a16be7a30 --- /dev/null +++ b/FateZero/data/.gitignore @@ -0,0 +1,4 @@ +* +!teaser_car-turn +!teaser_car-turn/* +!.gitignore \ No newline at end of file diff --git a/FateZero/data/teaser_car-turn/00000.png b/FateZero/data/teaser_car-turn/00000.png new file mode 100644 index 0000000000000000000000000000000000000000..aab14fc2632209038caf957f7bad990d0ac3572f Binary files /dev/null and b/FateZero/data/teaser_car-turn/00000.png differ diff --git a/FateZero/data/teaser_car-turn/00001.png b/FateZero/data/teaser_car-turn/00001.png new file mode 100644 index 0000000000000000000000000000000000000000..7a30cd0379351556985f2df389a459b6c69bf9a3 Binary files /dev/null and b/FateZero/data/teaser_car-turn/00001.png differ diff --git a/FateZero/data/teaser_car-turn/00002.png b/FateZero/data/teaser_car-turn/00002.png new file mode 100644 index 0000000000000000000000000000000000000000..dbbb8284c228c6789b0e0b935607b2a1548ceee0 Binary files /dev/null and b/FateZero/data/teaser_car-turn/00002.png differ diff --git a/FateZero/data/teaser_car-turn/00003.png b/FateZero/data/teaser_car-turn/00003.png new file mode 100644 index 0000000000000000000000000000000000000000..4b65cad8d0f7424bd5489792dcce520ac49d8f29 Binary files /dev/null and b/FateZero/data/teaser_car-turn/00003.png differ diff --git a/FateZero/data/teaser_car-turn/00004.png b/FateZero/data/teaser_car-turn/00004.png new file mode 100644 index 0000000000000000000000000000000000000000..29c9723d34d9d7eca08ac06ff10d1aa549ab2b36 Binary files /dev/null and b/FateZero/data/teaser_car-turn/00004.png differ diff --git a/FateZero/data/teaser_car-turn/00005.png b/FateZero/data/teaser_car-turn/00005.png new file mode 100644 index 0000000000000000000000000000000000000000..edfa4a1b612dcd2cd1ad05c19dc165cf5ee5d286 Binary files /dev/null and b/FateZero/data/teaser_car-turn/00005.png differ diff --git a/FateZero/data/teaser_car-turn/00006.png b/FateZero/data/teaser_car-turn/00006.png new file mode 100644 index 
0000000000000000000000000000000000000000..4ab30fe31c6f2bcb7a84c7bcb98b6fb3ad438f23 Binary files /dev/null and b/FateZero/data/teaser_car-turn/00006.png differ
diff --git a/FateZero/data/teaser_car-turn/00007.png b/FateZero/data/teaser_car-turn/00007.png
new file mode 100644
index 0000000000000000000000000000000000000000..edbb6a392dd2ef3fff01276f67a26db5d69abf22
Binary files /dev/null and b/FateZero/data/teaser_car-turn/00007.png differ
diff --git a/FateZero/docs/EditingGuidance.md b/FateZero/docs/EditingGuidance.md
new file mode 100644
index 0000000000000000000000000000000000000000..c33b110aa1bd21cf810699a4c900edf195fb5d0d
--- /dev/null
+++ b/FateZero/docs/EditingGuidance.md
@@ -0,0 +1,65 @@
+# EditingGuidance
+
+## Prompt Engineering
+For the results in the paper and webpage, we obtain the source prompt with the BLIP model embedded in the [Stable Diffusion WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui/).
+
+Click "Interrogate CLIP", and a source prompt is generated automatically. Then, we remove the last few useless words.
+
+
+
+During stylization, you may use a very simple source prompt such as "A photo" as a baseline if your input video is too complicated to describe in one sentence.
+
+### Validate the prompt
+
+- Put the source prompt into Stable Diffusion. If the generated image is close to the input video, it is a good source prompt.
+- A good prompt describes each frame and most objects in the video. In particular, it should contain the objects or attributes that we want to edit or preserve.
+- Put the target prompt into Stable Diffusion to check the upper bound of the editing effect. A reasonable combination of video and prompt may achieve better results (e.g., the "sunflower" video with a "Van Gogh" prompt works better than "sunflower" with "Monet").
+
+
+
+
+
+
+## FateZero hyperparameters
+We give a simple analysis of the involved hyperparameters as follows:
+``` yaml
+# Whether to directly copy the cross attention from source
+# True: directly copy, better for object replacement
+# False: keep source attention, better for style
+is_replace_controller: False
+
+# Semantic layout preserving. Higher steps replace more cross attention to preserve the semantic layout
+cross_replace_steps:
+    default_: 0.8
+
+# Source background structure preserving, in [0, 1].
+# e.g., =0.6 replaces the self-attention of the first 60% of steps
+self_replace_steps: 0.8
+
+
+# Amplify the cross attention of the target words; a larger value pushes the result closer to the target
+# eq_params:
+#     words: ["", ""]
+#     values: [10,10]
+
+# Target structure-divergence hyperparams
+# If you change the shape of the object, it is better to use all three lines; otherwise, there is no need.
+# Without the following three lines, all self-attention will be replaced
+blend_words: [['jeep',], ["car",]]
+masked_self_attention: True
+# masked_latents: False # Directly copy the latents; performance is not so good in our case
+bend_th: [2, 2]
+# preserve source structure of blend_words in [0, 1]
+# default is bend_th: [2, 2] # replace full-resolution edit source with self-attention
+# bend_th -> [0.0, 0.0], mask -> 1, use more edit self-attention, more generated shape, less source attention
+```
+
+## DDIM hyperparameters
+
+We profile the cost of editing 8 frames on an NVIDIA 3090 with fp16 in accelerate and xformers enabled.
+
+| Configs | Attention location | DDIM Inver.
Step | CPU memory | GPU memory | Inversion time | Editing time time | Quality +|------------------|------------------ |------------------|------------------|------------------|------------------|----| ---- | +| [basic](../config/teaser/jeep_watercolor.yaml) | RAM | 50 | 100G | 12G | 60s | 40s | Full support +| [low cost](../config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml) | RAM | 10 | 15G | 12G | 10s | 10s | OK for Style, not work for shape +| [lower cost](../config/low_resource_teaser/jeep_watercolor_ddim_10_steps_disk_store.yaml) | DISK | 10 | 6G | 12G | 33 s | 100s | OK for Style, not work for shape diff --git a/FateZero/docs/OpenSans-Regular.ttf b/FateZero/docs/OpenSans-Regular.ttf new file mode 100644 index 0000000000000000000000000000000000000000..ae02899e818bb4a6c0b7302ea392a62e5453eebd Binary files /dev/null and b/FateZero/docs/OpenSans-Regular.ttf differ diff --git a/FateZero/requirements.txt b/FateZero/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ff2d18eb96f17e14099eaa8321bfd6d18ed7fec7 --- /dev/null +++ b/FateZero/requirements.txt @@ -0,0 +1,17 @@ +--extra-index-url https://download.pytorch.org/whl/cu113 +torch==1.12.1+cu113 # --index-url https://download.pytorch.org/whl/cu113 +torchvision==0.13.1+cu113 # --index-url https://download.pytorch.org/whl/cu113 +diffusers[torch]==0.11.1 +accelerate==0.15.0 +transformers==4.25.1 +bitsandbytes==0.35.4 +einops +omegaconf +ftfy +tensorboard +modelcards +imageio +triton +click +opencv-python +imageio[ffmpeg] \ No newline at end of file diff --git a/FateZero/test_fatezero.py b/FateZero/test_fatezero.py new file mode 100644 index 0000000000000000000000000000000000000000..e8f31ce937f4332caa9a8e94ca35179f9160732e --- /dev/null +++ b/FateZero/test_fatezero.py @@ -0,0 +1,290 @@ +import os +from glob import glob +import copy +from typing import Optional,Dict +from tqdm.auto import tqdm +from omegaconf import OmegaConf +import click + +import torch +import torch.utils.data +import torch.utils.checkpoint + +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from diffusers import ( + AutoencoderKL, + DDIMScheduler, +) +from diffusers.utils.import_utils import is_xformers_available +from transformers import AutoTokenizer, CLIPTextModel +from einops import rearrange + +import sys +sys.path.append('FateZero') +from video_diffusion.models.unet_3d_condition import UNetPseudo3DConditionModel +from video_diffusion.data.dataset import ImageSequenceDataset +from video_diffusion.common.util import get_time_string, get_function_args +from video_diffusion.common.image_util import log_train_samples +from video_diffusion.common.instantiate_from_config import instantiate_from_config +from video_diffusion.pipelines.p2pvalidation_loop import p2pSampleLogger + +logger = get_logger(__name__) + + +def collate_fn(examples): + """Concat a batch of sampled image in dataloader + """ + batch = { + "prompt_ids": torch.cat([example["prompt_ids"] for example in examples], dim=0), + "images": torch.stack([example["images"] for example in examples]), + } + return batch + + + +def test( + config: str, + pretrained_model_path: str, + train_dataset: Dict, + logdir: str = None, + validation_sample_logger_config: Optional[Dict] = None, + test_pipeline_config: Optional[Dict] = None, + gradient_accumulation_steps: int = 1, + seed: Optional[int] = None, + mixed_precision: Optional[str] = "fp16", + train_batch_size: int = 1, + model_config: dict={}, + verbose: 
bool=True, + **kwargs + +): + args = get_function_args() + + time_string = get_time_string() + if logdir is None: + logdir = config.replace('config', 'result').replace('.yml', '').replace('.yaml', '') + logdir += f"_{time_string}" + + accelerator = Accelerator( + gradient_accumulation_steps=gradient_accumulation_steps, + mixed_precision=mixed_precision, + ) + if accelerator.is_main_process: + os.makedirs(logdir, exist_ok=True) + OmegaConf.save(args, os.path.join(logdir, "config.yml")) + + if seed is not None: + set_seed(seed) + + # Load the tokenizer + tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_path, + subfolder="tokenizer", + use_fast=False, + ) + + # Load models and create wrapper for stable diffusion + text_encoder = CLIPTextModel.from_pretrained( + pretrained_model_path, + subfolder="text_encoder", + ) + + vae = AutoencoderKL.from_pretrained( + pretrained_model_path, + subfolder="vae", + ) + + unet = UNetPseudo3DConditionModel.from_2d_model( + os.path.join(pretrained_model_path, "unet"), model_config=model_config + ) + + if 'target' not in test_pipeline_config: + test_pipeline_config['target'] = 'video_diffusion.pipelines.stable_diffusion.SpatioTemporalStableDiffusionPipeline' + + pipeline = instantiate_from_config( + test_pipeline_config, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=DDIMScheduler.from_pretrained( + pretrained_model_path, + subfolder="scheduler", + ), + disk_store=kwargs.get('disk_store', False) + ) + pipeline.scheduler.set_timesteps(validation_sample_logger_config['num_inference_steps']) + pipeline.set_progress_bar_config(disable=True) + + + if is_xformers_available(): + try: + pipeline.enable_xformers_memory_efficient_attention() + except Exception as e: + logger.warning( + "Could not enable memory efficient attention. Make sure xformers is installed" + f" correctly and a GPU is available: {e}" + ) + + vae.requires_grad_(False) + unet.requires_grad_(False) + text_encoder.requires_grad_(False) + prompt_ids = tokenizer( + train_dataset["prompt"], + truncation=True, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + train_dataset = ImageSequenceDataset(**train_dataset, prompt_ids=prompt_ids) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=train_batch_size, + shuffle=True, + num_workers=4, + collate_fn=collate_fn, + ) + train_sample_save_path = os.path.join(logdir, "train_samples.gif") + log_train_samples(save_path=train_sample_save_path, train_dataloader=train_dataloader) + + unet, train_dataloader = accelerator.prepare( + unet, train_dataloader + ) + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + print('use fp16') + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move text_encode and vae to gpu. + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # These models are only used for inference, keeping weights in full precision is not required. + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. 
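+    # (Rough overview of the remaining flow: initialize the trackers, build the prompt-to-prompt
+    # sample logger, DDIM-invert the training clip to get the per-step latents and, optionally, the
+    # inversion attention maps, then run the edited sampling pass and save the results.)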
+ if accelerator.is_main_process: + accelerator.init_trackers("video") # , config=vars(args)) + logger.info("***** wait to fix the logger path *****") + + if validation_sample_logger_config is not None and accelerator.is_main_process: + validation_sample_logger = p2pSampleLogger(**validation_sample_logger_config, logdir=logdir) + # validation_sample_logger.log_sample_images( + # pipeline=pipeline, + # device=accelerator.device, + # step=0, + # ) + def make_data_yielder(dataloader): + while True: + for batch in dataloader: + yield batch + accelerator.wait_for_everyone() + + train_data_yielder = make_data_yielder(train_dataloader) + + + batch = next(train_data_yielder) + if validation_sample_logger_config.get('use_train_latents', False): + # Precompute the latents for this video to align the initial latents in training and test + assert batch["images"].shape[0] == 1, "Only support, overfiting on a single video" + # we only inference for latents, no training + vae.eval() + text_encoder.eval() + unet.eval() + + text_embeddings = pipeline._encode_prompt( + train_dataset.prompt, + device = accelerator.device, + num_images_per_prompt = 1, + do_classifier_free_guidance = True, + negative_prompt=None + ) + + use_inversion_attention = validation_sample_logger_config.get('use_inversion_attention', False) + batch['latents_all_step'] = pipeline.prepare_latents_ddim_inverted( + rearrange(batch["images"].to(dtype=weight_dtype), "b c f h w -> (b f) c h w"), + batch_size = 1, + num_images_per_prompt = 1, # not sure how to use it + text_embeddings = text_embeddings, + prompt = train_dataset.prompt, + store_attention=use_inversion_attention, + LOW_RESOURCE = True, # not classifier-free guidance + save_path = logdir if verbose else None + ) + + batch['ddim_init_latents'] = batch['latents_all_step'][-1] + + else: + batch['ddim_init_latents'] = None + + vae.eval() + text_encoder.eval() + unet.eval() + + # with accelerator.accumulate(unet): + # Convert images to latent space + images = batch["images"].to(dtype=weight_dtype) + images = rearrange(images, "b c f h w -> (b f) c h w") + + + if accelerator.is_main_process: + + if validation_sample_logger is not None: + unet.eval() + samples_all, save_path = validation_sample_logger.log_sample_images( + image=images, # torch.Size([8, 3, 512, 512]) + pipeline=pipeline, + device=accelerator.device, + step=0, + latents = batch['ddim_init_latents'], + save_dir = logdir if verbose else None + ) + # accelerator.log(logs, step=step) + print('accelerator.end_training()') + accelerator.end_training() + return save_path + + +# @click.command() +# @click.option("--config", type=str, default="FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml") +def run(config='FateZero/config/low_resource_teaser/jeep_watercolor_ddim_10_steps.yaml'): + print(f'in run function {config}') + Omegadict = OmegaConf.load(config) + if 'unet' in os.listdir(Omegadict['pretrained_model_path']): + test(config=config, **Omegadict) + print('test finished') + return '/home/cqiaa/diffusion/hugging_face/Tune-A-Video-inference/FateZero/result/low_resource_teaser/jeep_watercolor_ddim_10_steps_230327-200651/sample/step_0_0_0.mp4' + else: + # Go through all ckpt if possible + checkpoint_list = sorted(glob(os.path.join(Omegadict['pretrained_model_path'], 'checkpoint_*'))) + print('checkpoint to evaluate:') + for checkpoint in checkpoint_list: + epoch = checkpoint.split('_')[-1] + + for checkpoint in tqdm(checkpoint_list): + epoch = checkpoint.split('_')[-1] + if 'pretrained_epoch_list' not in 
Omegadict or int(epoch) in Omegadict['pretrained_epoch_list']: + print(f'Evaluate {checkpoint}') + # Update saving dir and ckpt + Omegadict_checkpoint = copy.deepcopy(Omegadict) + Omegadict_checkpoint['pretrained_model_path'] = checkpoint + + if 'logdir' not in Omegadict_checkpoint: + logdir = config.replace('config', 'result').replace('.yml', '').replace('.yaml', '') + logdir += f"/{os.path.basename(checkpoint)}" + + Omegadict_checkpoint['logdir'] = logdir + print(f'Saving at {logdir}') + + test(config=config, **Omegadict_checkpoint) + + +if __name__ == "__main__": + run('FateZero/config/teaser/jeep_watercolor.yaml') diff --git a/FateZero/test_fatezero_dataset.py b/FateZero/test_fatezero_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..95d777ff31161bed2ad54c06ec81e47d5065c664 --- /dev/null +++ b/FateZero/test_fatezero_dataset.py @@ -0,0 +1,52 @@ + + +from test_fatezero import * +from glob import glob +import copy + +@click.command() +@click.option("--edit_config", type=str, default="config/supp/style/0313_style_edit_warp_640.yaml") +@click.option("--dataset_config", type=str, default="data/supp_edit_dataset/dataset_prompt.yaml") +def run(edit_config, dataset_config): + Omegadict_edit_config = OmegaConf.load(edit_config) + Omegadict_dataset_config = OmegaConf.load(dataset_config) + + # Go trough all data sample + data_sample_list = sorted(Omegadict_dataset_config.keys()) + print(f'Datasample to evaluate: {data_sample_list}') + dataset_time_string = get_time_string() + for data_sample in data_sample_list: + print(f'Evaluate {data_sample}') + + for p2p_config_index, p2p_config in Omegadict_edit_config['validation_sample_logger_config']['p2p_config'].items(): + edit_config_now = copy.deepcopy(Omegadict_edit_config) + edit_config_now['train_dataset'] = copy.deepcopy(Omegadict_dataset_config[data_sample]) + edit_config_now['train_dataset'].pop('target') + if 'eq_params' in edit_config_now['train_dataset']: + edit_config_now['train_dataset'].pop('eq_params') + # edit_config_now['train_dataset']['prompt'] = Omegadict_dataset_config[data_sample]['source'] + + edit_config_now['validation_sample_logger_config']['prompts'] \ + = copy.deepcopy( [Omegadict_dataset_config[data_sample]['prompt'],]+ OmegaConf.to_object(Omegadict_dataset_config[data_sample]['target'])) + p2p_config_now = dict() + for i in range(len(edit_config_now['validation_sample_logger_config']['prompts'])): + p2p_config_now[i] = p2p_config + if 'eq_params' in Omegadict_dataset_config[data_sample]: + p2p_config_now[i]['eq_params'] = Omegadict_dataset_config[data_sample]['eq_params'] + + edit_config_now['validation_sample_logger_config']['p2p_config'] = copy.deepcopy(p2p_config_now) + edit_config_now['validation_sample_logger_config']['source_prompt'] = Omegadict_dataset_config[data_sample]['prompt'] + # edit_config_now['validation_sample_logger_config']['source_prompt'] = Omegadict_dataset_config[data_sample]['eq_params'] + + + # if 'logdir' not in edit_config_now: + logdir = edit_config.replace('config', 'result').replace('.yml', '').replace('.yaml', '')+f'_config_{p2p_config_index}'+f'_{os.path.basename(dataset_config)[:-5]}'+f'_{dataset_time_string}' + logdir += f"/{data_sample}" + edit_config_now['logdir'] = logdir + print(f'Saving at {logdir}') + + test(config=edit_config, **edit_config_now) + + +if __name__ == "__main__": + run() diff --git a/FateZero/test_install.py b/FateZero/test_install.py new file mode 100644 index 
0000000000000000000000000000000000000000..9acea8be04e21364980e6a1abf47eeeb74bc82a5 --- /dev/null +++ b/FateZero/test_install.py @@ -0,0 +1,23 @@ +import torch +import os + +import sys +print(f"python version {sys.version}") +print(f"torch version {torch.__version__}") +print(f"validate gpu status:") +print( torch.tensor(1.0).cuda()*2) +os.system("nvcc --version") + +import diffusers +print(diffusers.__version__) +print(diffusers.__file__) + +try: + import bitsandbytes + print(bitsandbytes.__file__) +except: + print("fail to import bitsandbytes") + +os.system("accelerate env") + +os.system("python -m xformers.info") \ No newline at end of file diff --git a/FateZero/train_tune_a_video.py b/FateZero/train_tune_a_video.py new file mode 100644 index 0000000000000000000000000000000000000000..588b22bbae94147c0c7551e3e859dc3065074f7a --- /dev/null +++ b/FateZero/train_tune_a_video.py @@ -0,0 +1,426 @@ +import os,copy +import inspect +from typing import Optional, List, Dict, Union +import PIL +import click +from omegaconf import OmegaConf + +import torch +import torch.utils.data +import torch.nn.functional as F +import torch.utils.checkpoint + +from accelerate import Accelerator +from accelerate.utils import set_seed +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DDIMScheduler, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.utils.import_utils import is_xformers_available +from diffusers.pipeline_utils import DiffusionPipeline + +from tqdm.auto import tqdm +from transformers import AutoTokenizer, CLIPTextModel +from einops import rearrange + +from video_diffusion.models.unet_3d_condition import UNetPseudo3DConditionModel +from video_diffusion.data.dataset import ImageSequenceDataset +from video_diffusion.common.util import get_time_string, get_function_args +from video_diffusion.common.logger import get_logger_config_path +from video_diffusion.common.image_util import log_train_samples, log_train_reg_samples +from video_diffusion.common.instantiate_from_config import instantiate_from_config, get_obj_from_str +from video_diffusion.pipelines.validation_loop import SampleLogger + + +def collate_fn(examples): + batch = { + "prompt_ids": torch.cat([example["prompt_ids"] for example in examples], dim=0), + "images": torch.stack([example["images"] for example in examples]), + + } + if "class_images" in examples[0]: + batch["class_prompt_ids"] = torch.cat([example["class_prompt_ids"] for example in examples], dim=0) + batch["class_images"] = torch.stack([example["class_images"] for example in examples]) + return batch + + + +def train( + config: str, + pretrained_model_path: str, + train_dataset: Dict, + logdir: str = None, + train_steps: int = 300, + validation_steps: int = 1000, + validation_sample_logger_config: Optional[Dict] = None, + test_pipeline_config: Optional[Dict] = dict(), + trainer_pipeline_config: Optional[Dict] = dict(), + gradient_accumulation_steps: int = 1, + seed: Optional[int] = None, + mixed_precision: Optional[str] = "fp16", + enable_xformers: bool = True, + train_batch_size: int = 1, + learning_rate: float = 3e-5, + scale_lr: bool = False, + lr_scheduler: str = "constant", # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"] + lr_warmup_steps: int = 0, + use_8bit_adam: bool = True, + adam_beta1: float = 0.9, + adam_beta2: float = 0.999, + adam_weight_decay: float = 1e-2, + adam_epsilon: float = 1e-08, + max_grad_norm: float = 1.0, + gradient_checkpointing: bool = False, + 
train_temporal_conv: bool = False, + checkpointing_steps: int = 1000, + model_config: dict={}, + # use_train_latents: bool=False, + # kwr + # **kwargs +): + args = get_function_args() + # args.update(kwargs) + train_dataset_config = copy.deepcopy(train_dataset) + time_string = get_time_string() + if logdir is None: + logdir = config.replace('config', 'result').replace('.yml', '').replace('.yaml', '') + logdir += f"_{time_string}" + + accelerator = Accelerator( + gradient_accumulation_steps=gradient_accumulation_steps, + mixed_precision=mixed_precision, + ) + if accelerator.is_main_process: + os.makedirs(logdir, exist_ok=True) + OmegaConf.save(args, os.path.join(logdir, "config.yml")) + logger = get_logger_config_path(logdir) + if seed is not None: + set_seed(seed) + + # Load the tokenizer + tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_path, + subfolder="tokenizer", + use_fast=False, + ) + + # Load models and create wrapper for stable diffusion + text_encoder = CLIPTextModel.from_pretrained( + pretrained_model_path, + subfolder="text_encoder", + ) + + vae = AutoencoderKL.from_pretrained( + pretrained_model_path, + subfolder="vae", + ) + + unet = UNetPseudo3DConditionModel.from_2d_model( + os.path.join(pretrained_model_path, "unet"), model_config=model_config + ) + + + if 'target' not in test_pipeline_config: + test_pipeline_config['target'] = 'video_diffusion.pipelines.stable_diffusion.SpatioTemporalStableDiffusionPipeline' + + pipeline = instantiate_from_config( + test_pipeline_config, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=DDIMScheduler.from_pretrained( + pretrained_model_path, + subfolder="scheduler", + ), + ) + pipeline.scheduler.set_timesteps(validation_sample_logger_config['num_inference_steps']) + pipeline.set_progress_bar_config(disable=True) + + + if is_xformers_available() and enable_xformers: + # if False: # Disable xformers for null inversion + try: + pipeline.enable_xformers_memory_efficient_attention() + print('enable xformers in the training and testing') + except Exception as e: + logger.warning( + "Could not enable memory efficient attention. Make sure xformers is installed" + f" correctly and a GPU is available: {e}" + ) + + vae.requires_grad_(False) + unet.requires_grad_(False) + text_encoder.requires_grad_(False) + + # Start of config trainable parameters in Unet and optimizer + trainable_modules = ("attn_temporal", ".to_q") + if train_temporal_conv: + trainable_modules += ("conv_temporal",) + for name, module in unet.named_modules(): + if name.endswith(trainable_modules): + for params in module.parameters(): + params.requires_grad = True + + + if gradient_checkpointing: + print('enable gradient checkpointing in the training and testing') + unet.enable_gradient_checkpointing() + + if scale_lr: + learning_rate = ( + learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
+ ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + params_to_optimize = unet.parameters() + num_trainable_modules = 0 + num_trainable_params = 0 + num_unet_params = 0 + for params in params_to_optimize: + num_unet_params += params.numel() + if params.requires_grad == True: + num_trainable_modules +=1 + num_trainable_params += params.numel() + + logger.info(f"Num of trainable modules: {num_trainable_modules}") + logger.info(f"Num of trainable params: {num_trainable_params/(1024*1024):.2f} M") + logger.info(f"Num of unet params: {num_unet_params/(1024*1024):.2f} M ") + + + params_to_optimize = unet.parameters() + optimizer = optimizer_class( + params_to_optimize, + lr=learning_rate, + betas=(adam_beta1, adam_beta2), + weight_decay=adam_weight_decay, + eps=adam_epsilon, + ) + # End of config trainable parameters in Unet and optimizer + + + prompt_ids = tokenizer( + train_dataset["prompt"], + truncation=True, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + if 'class_data_root' in train_dataset_config: + if 'class_data_prompt' not in train_dataset_config: + train_dataset_config['class_data_prompt'] = train_dataset_config['prompt'] + class_prompt_ids = tokenizer( + train_dataset_config["class_data_prompt"], + truncation=True, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + else: + class_prompt_ids = None + train_dataset = ImageSequenceDataset(**train_dataset, prompt_ids=prompt_ids, class_prompt_ids=class_prompt_ids) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=train_batch_size, + shuffle=True, + num_workers=16, + collate_fn=collate_fn, + ) + + train_sample_save_path = os.path.join(logdir, "train_samples.gif") + log_train_samples(save_path=train_sample_save_path, train_dataloader=train_dataloader) + if 'class_data_root' in train_dataset_config: + log_train_reg_samples(save_path=train_sample_save_path.replace('train_samples', 'class_data_samples'), train_dataloader=train_dataloader) + + # Prepare learning rate scheduler in accelerate config + lr_scheduler = get_scheduler( + lr_scheduler, + optimizer=optimizer, + num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps, + num_training_steps=train_steps * gradient_accumulation_steps, + ) + + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + accelerator.register_for_checkpointing(lr_scheduler) + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + print('enable float16 in the training and testing') + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move text_encode and vae to gpu. + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. 
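Earlier in train(), the VAE, text encoder and UNet are frozen and gradients are re-enabled only for modules whose qualified names end with the entries of trainable_modules ("attn_temporal", ".to_q"). A minimal, self-contained sketch of that selective-unfreezing pattern, not taken from this diff; ToyAttention/ToyNet are illustrative stand-ins for the pseudo-3D UNet:

from torch import nn


class ToyAttention(nn.Module):
    """Stand-in for a UNet attention block (names are illustrative only)."""
    def __init__(self, dim: int = 8):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)


class ToyNet(nn.Module):
    """Stand-in for the pseudo-3D UNet."""
    def __init__(self, dim: int = 8):
        super().__init__()
        self.attn1 = ToyAttention(dim)          # only its .to_q should train
        self.attn_temporal = ToyAttention(dim)  # the whole temporal block trains
        self.proj = nn.Linear(dim, dim)         # stays frozen


def select_trainable(model: nn.Module, suffixes=("attn_temporal", ".to_q")):
    """Freeze everything, then re-enable gradients only for modules whose
    qualified name ends with one of `suffixes` (the pattern used in train())."""
    model.requires_grad_(False)
    for name, module in model.named_modules():
        if name.endswith(suffixes):
            for param in module.parameters():
                param.requires_grad = True
    return sorted(n for n, p in model.named_parameters() if p.requires_grad)


if __name__ == "__main__":
    print(select_trainable(ToyNet()))
    # ['attn1.to_q.bias', 'attn1.to_q.weight', 'attn_temporal.to_k.bias', ...]

When train_temporal_conv is set, train() additionally appends "conv_temporal" to the suffix tuple so the temporal convolutions are fine-tuned as well.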
+ if accelerator.is_main_process: + accelerator.init_trackers("video") # , config=vars(args)) + + # Start of config trainer + trainer = instantiate_from_config( + trainer_pipeline_config, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler= DDPMScheduler.from_pretrained( + pretrained_model_path, + subfolder="scheduler", + ), + # training hyperparams + weight_dtype=weight_dtype, + accelerator=accelerator, + optimizer=optimizer, + max_grad_norm=max_grad_norm, + lr_scheduler=lr_scheduler, + prior_preservation=None + ) + trainer.print_pipeline(logger) + # Train! + total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Instantaneous batch size per device = {train_batch_size}") + logger.info( + f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" + ) + logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {train_steps}") + step = 0 + # End of config trainer + + if validation_sample_logger_config is not None and accelerator.is_main_process: + validation_sample_logger = SampleLogger(**validation_sample_logger_config, logdir=logdir) + + + # Only show the progress bar once on each machine. + progress_bar = tqdm( + range(step, train_steps), + disable=not accelerator.is_local_main_process, + ) + progress_bar.set_description("Steps") + + def make_data_yielder(dataloader): + while True: + for batch in dataloader: + yield batch + accelerator.wait_for_everyone() + + train_data_yielder = make_data_yielder(train_dataloader) + + + assert(train_dataset.overfit_length == 1), "Only support overfiting on a single video" + # batch = next(train_data_yielder) + + + while step < train_steps: + batch = next(train_data_yielder) + """************************* start of an iteration*******************************""" + loss = trainer.step(batch) + # torch.cuda.empty_cache() + + """************************* end of an iteration*******************************""" + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + step += 1 + + if accelerator.is_main_process: + + if validation_sample_logger is not None and (step % validation_steps == 0): + unet.eval() + + val_image = rearrange(batch["images"].to(dtype=weight_dtype), "b c f h w -> (b f) c h w") + + # Unet is changing in different iteration; we should invert online + if validation_sample_logger_config.get('use_train_latents', False): + # Precompute the latents for this video to align the initial latents in training and test + assert batch["images"].shape[0] == 1, "Only support, overfiting on a single video" + # we only inference for latents, no training + vae.eval() + text_encoder.eval() + unet.eval() + + text_embeddings = pipeline._encode_prompt( + train_dataset.prompt, + device = accelerator.device, + num_images_per_prompt = 1, + do_classifier_free_guidance = True, + negative_prompt=None + ) + batch['latents_all_step'] = pipeline.prepare_latents_ddim_inverted( + rearrange(batch["images"].to(dtype=weight_dtype), "b c f h w -> (b f) c h w"), + batch_size = 1 , + num_images_per_prompt = 1, # not sure how to use it + text_embeddings = text_embeddings + ) + batch['ddim_init_latents'] = batch['latents_all_step'][-1] + else: + 
batch['ddim_init_latents'] = None + + + + validation_sample_logger.log_sample_images( + # image=rearrange(train_dataset.get_all()["images"].to(accelerator.device, dtype=weight_dtype), "c f h w -> f c h w"), # torch.Size([8, 3, 512, 512]) + image= val_image, # torch.Size([8, 3, 512, 512]) + pipeline=pipeline, + device=accelerator.device, + step=step, + latents = batch['ddim_init_latents'], + ) + torch.cuda.empty_cache() + unet.train() + + if step % checkpointing_steps == 0: + accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set( + inspect.signature(accelerator.unwrap_model).parameters.keys() + ) + extra_args = {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {} + pipeline_save = get_obj_from_str(test_pipeline_config["target"]).from_pretrained( + pretrained_model_path, + unet=accelerator.unwrap_model(unet, **extra_args), + ) + checkpoint_save_path = os.path.join(logdir, f"checkpoint_{step}") + pipeline_save.save_pretrained(checkpoint_save_path) + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=step) + + accelerator.end_training() + + +@click.command() +@click.option("--config", type=str, default="config/sample.yml") +def run(config): + train(config=config, **OmegaConf.load(config)) + + +if __name__ == "__main__": + run() diff --git a/FateZero/video_diffusion/common/image_util.py b/FateZero/video_diffusion/common/image_util.py new file mode 100644 index 0000000000000000000000000000000000000000..f5258a4c7d49ca266eb73287c09aa7ee18fa9421 --- /dev/null +++ b/FateZero/video_diffusion/common/image_util.py @@ -0,0 +1,203 @@ +import os +import math +import textwrap + +import imageio +import numpy as np +from typing import Sequence +import requests +import cv2 +from PIL import Image, ImageDraw, ImageFont + +import torch +from torchvision import transforms +from einops import rearrange + + + + + + +IMAGE_EXTENSION = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") + +FONT_URL = "https://raw.github.com/googlefonts/opensans/main/fonts/ttf/OpenSans-Regular.ttf" +FONT_PATH = "./docs/OpenSans-Regular.ttf" + + +def pad(image: Image.Image, top=0, right=0, bottom=0, left=0, color=(255, 255, 255)) -> Image.Image: + new_image = Image.new(image.mode, (image.width + right + left, image.height + top + bottom), color) + new_image.paste(image, (left, top)) + return new_image + + +def download_font_opensans(path=FONT_PATH): + font_url = FONT_URL + response = requests.get(font_url) + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, "wb") as f: + f.write(response.content) + + +def annotate_image_with_font(image: Image.Image, text: str, font: ImageFont.FreeTypeFont) -> Image.Image: + image_w = image.width + _, _, text_w, text_h = font.getbbox(text) + line_size = math.floor(len(text) * image_w / text_w) + + lines = textwrap.wrap(text, width=line_size) + padding = text_h * len(lines) + image = pad(image, top=padding + 3) + + ImageDraw.Draw(image).text((0, 0), "\n".join(lines), fill=(0, 0, 0), font=font) + return image + + +def annotate_image(image: Image.Image, text: str, font_size: int = 15): + if not os.path.isfile(FONT_PATH): + download_font_opensans() + font = ImageFont.truetype(FONT_PATH, size=font_size) + return annotate_image_with_font(image=image, text=text, font=font) + + +def make_grid(images: Sequence[Image.Image], rows=None, cols=None) -> Image.Image: + if isinstance(images[0], np.ndarray): + images = [Image.fromarray(i) for i in images] + + if rows is None: + 
assert cols is not None + rows = math.ceil(len(images) / cols) + else: + cols = math.ceil(len(images) / rows) + + w, h = images[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + for i, image in enumerate(images): + if image.size != (w, h): + image = image.resize((w, h)) + grid.paste(image, box=(i % cols * w, i // cols * h)) + return grid + + +def save_images_as_gif( + images: Sequence[Image.Image], + save_path: str, + loop=0, + duration=100, + optimize=False, +) -> None: + + images[0].save( + save_path, + save_all=True, + append_images=images[1:], + optimize=optimize, + loop=loop, + duration=duration, + ) + +def save_images_as_mp4( + images: Sequence[Image.Image], + save_path: str, +) -> None: + # images[0].save( + # save_path, + # save_all=True, + # append_images=images[1:], + # optimize=optimize, + # loop=loop, + # duration=duration, + # ) + writer_edit = imageio.get_writer( + save_path, + fps=10) + for i in images: + init_image = i.convert("RGB") + writer_edit.append_data(np.array(init_image)) + writer_edit.close() + + + +def save_images_as_folder( + images: Sequence[Image.Image], + save_path: str, +) -> None: + os.makedirs(save_path, exist_ok=True) + for index, image in enumerate(images): + init_image = image + if len(np.array(init_image).shape) == 3: + cv2.imwrite(os.path.join(save_path, f"{index:05d}.png"), np.array(init_image)[:, :, ::-1]) + else: + cv2.imwrite(os.path.join(save_path, f"{index:05d}.png"), np.array(init_image)) + +def log_train_samples( + train_dataloader, + save_path, + num_batch: int = 4, +): + train_samples = [] + for idx, batch in enumerate(train_dataloader): + if idx >= num_batch: + break + train_samples.append(batch["images"]) + + train_samples = torch.cat(train_samples).numpy() + train_samples = rearrange(train_samples, "b c f h w -> b f h w c") + train_samples = (train_samples * 0.5 + 0.5).clip(0, 1) + train_samples = numpy_batch_seq_to_pil(train_samples) + train_samples = [make_grid(images, cols=int(np.ceil(np.sqrt(len(train_samples))))) for images in zip(*train_samples)] + # save_images_as_gif(train_samples, save_path) + save_gif_mp4_folder_type(train_samples, save_path) + +def log_train_reg_samples( + train_dataloader, + save_path, + num_batch: int = 4, +): + train_samples = [] + for idx, batch in enumerate(train_dataloader): + if idx >= num_batch: + break + train_samples.append(batch["class_images"]) + + train_samples = torch.cat(train_samples).numpy() + train_samples = rearrange(train_samples, "b c f h w -> b f h w c") + train_samples = (train_samples * 0.5 + 0.5).clip(0, 1) + train_samples = numpy_batch_seq_to_pil(train_samples) + train_samples = [make_grid(images, cols=int(np.ceil(np.sqrt(len(train_samples))))) for images in zip(*train_samples)] + # save_images_as_gif(train_samples, save_path) + save_gif_mp4_folder_type(train_samples, save_path) + + +def save_gif_mp4_folder_type(images, save_path, save_gif=False): + + if isinstance(images[0], np.ndarray): + images = [Image.fromarray(i) for i in images] + elif isinstance(images[0], torch.Tensor): + images = [transforms.ToPILImage()(i.cpu().clone()[0]) for i in images] + save_path_mp4 = save_path.replace('gif', 'mp4') + save_path_folder = save_path.replace('.gif', '') + if save_gif: save_images_as_gif(images, save_path) + save_images_as_mp4(images, save_path_mp4) + save_images_as_folder(images, save_path_folder) + +# copy from video_diffusion/pipelines/stable_diffusion.py +def numpy_seq_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. 
+    """
+    if images.ndim == 3:
+        images = images[None, ...]
+    images = (images * 255).round().astype("uint8")
+    if images.shape[-1] == 1:
+        # special case for grayscale (single channel) images
+        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
+    else:
+        pil_images = [Image.fromarray(image) for image in images]
+
+    return pil_images
+
+# copy from diffusers-0.11.1/src/diffusers/pipeline_utils.py
+def numpy_batch_seq_to_pil(images):
+    pil_images = []
+    for sequence in images:
+        pil_images.append(numpy_seq_to_pil(sequence))
+    return pil_images
diff --git a/FateZero/video_diffusion/common/instantiate_from_config.py b/FateZero/video_diffusion/common/instantiate_from_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c410d1ba6f0073fada0bbdb056cbad4abed3aa9
--- /dev/null
+++ b/FateZero/video_diffusion/common/instantiate_from_config.py
@@ -0,0 +1,33 @@
+"""
+Copied from Stable Diffusion
+"""
+import importlib
+
+
+def instantiate_from_config(config: dict, **args_from_code):
+    """Utility function to instantiate different modules from a config.
+
+    Args:
+        config (dict): dict with keys "target" and "params",
+            typically loaded from a YAML file
+        args_from_code: additional keyword arguments supplied from code
+
+
+    Returns:
+        an instantiated module, e.g. a validation/training pipeline
+    """
+    if "target" not in config:
+        if config == '__is_first_stage__':
+            return None
+        elif config == "__is_unconditional__":
+            return None
+        raise KeyError("Expected key `target` to instantiate.")
+    return get_obj_from_str(config["target"])(**config.get("params", dict()), **args_from_code)
+
+
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
diff --git a/FateZero/video_diffusion/common/logger.py b/FateZero/video_diffusion/common/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed344e5f4d377540e96fbd6dc00f1d9edc7201dd
--- /dev/null
+++ b/FateZero/video_diffusion/common/logger.py
@@ -0,0 +1,17 @@
+import os
+import logging, logging.handlers
+from accelerate.logging import get_logger
+
+def get_logger_config_path(logdir):
+    # accelerate handles the logger in multiprocessing
+    logger = get_logger(__name__)
+    logging.basicConfig(
+        level=logging.INFO,
+        format='%(asctime)s:%(levelname)s : %(message)s',
+        datefmt='%a, %d %b %Y %H:%M:%S',
+        filename=os.path.join(logdir, 'log.log'),
+        filemode='w')
+    chlr = logging.StreamHandler()
+    chlr.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s : %(message)s'))
+    logger.logger.addHandler(chlr)
+    return logger
\ No newline at end of file
diff --git a/FateZero/video_diffusion/common/set_seed.py b/FateZero/video_diffusion/common/set_seed.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f30dbf3028fc884adcd3ed0ffb317f2220ac32a
--- /dev/null
+++ b/FateZero/video_diffusion/common/set_seed.py
@@ -0,0 +1,28 @@
+import os
+os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
+
+import torch
+import numpy as np
+import random
+
+from accelerate.utils import set_seed
+
+
+def video_set_seed(seed: int):
+    """
+    Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
+
+    Args:
+        seed (`int`): The seed to set.
+        device_specific (`bool`, *optional*, defaults to `False`):
+            Whether to differ the seed on each device slightly with `self.process_index`.
+    """
+    set_seed(seed)
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.benchmark = False
+    # torch.use_deterministic_algorithms(True, warn_only=True)
+    # [W Context.cpp:82] Warning: efficient_attention_forward_cutlass does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True, warn_only=True)'. You can file an issue at https://github.com/pytorch/pytorch/issues to help us prioritize adding deterministic support for this operation. (function alertNotDeterministic)
+
diff --git a/FateZero/video_diffusion/common/util.py b/FateZero/video_diffusion/common/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..b393ba6745e6737d1476626e95b3a40137e2982d
--- /dev/null
+++ b/FateZero/video_diffusion/common/util.py
@@ -0,0 +1,73 @@
+import os
+import sys
+import copy
+import inspect
+import datetime
+from typing import List, Tuple, Optional, Dict
+
+
+def glob_files(
+    root_path: str,
+    extensions: Tuple[str],
+    recursive: bool = True,
+    skip_hidden_directories: bool = True,
+    max_directories: Optional[int] = None,
+    max_files: Optional[int] = None,
+    relative_path: bool = False,
+) -> Tuple[List[str], bool, bool]:
+    """Glob files with the specified extensions.
+
+    Args:
+        root_path (str): root directory to search.
+        extensions (Tuple[str]): file extensions to match, compared against lower-cased file names.
+        recursive (bool, optional): descend into subdirectories. Defaults to True.
+        skip_hidden_directories (bool, optional): skip directories whose name starts with ".". Defaults to True.
+        max_directories (Optional[int], optional): max number of directories to search. Defaults to None.
+        max_files (Optional[int], optional): max number of files to return. Defaults to None.
+        relative_path (bool, optional): return paths relative to root_path. Defaults to False.
+
+    Returns:
+        Tuple[List[str], bool, bool]: (matched paths, hit_max_directories, hit_max_files)
+    """
+    paths = []
+    hit_max_directories = False
+    hit_max_files = False
+    for directory_idx, (directory, _, fnames) in enumerate(os.walk(root_path, followlinks=True)):
+        if skip_hidden_directories and os.path.basename(directory).startswith("."):
+            continue
+
+        if max_directories is not None and directory_idx >= max_directories:
+            hit_max_directories = True
+            break
+
+        paths += [
+            os.path.join(directory, fname)
+            for fname in sorted(fnames)
+            if fname.lower().endswith(extensions)
+        ]
+
+        if not recursive:
+            break
+
+        if max_files is not None and len(paths) > max_files:
+            hit_max_files = True
+            paths = paths[:max_files]
+            break
+
+    if relative_path:
+        paths = [os.path.relpath(p, root_path) for p in paths]
+
+    return paths, hit_max_directories, hit_max_files
+
+
+def get_time_string() -> str:
+    x = datetime.datetime.now()
+    return f"{(x.year - 2000):02d}{x.month:02d}{x.day:02d}-{x.hour:02d}{x.minute:02d}{x.second:02d}"
+
+
+def get_function_args() -> Dict:
+    frame = sys._getframe(1)
+    args, _, _, values = inspect.getargvalues(frame)
+    args_dict = copy.deepcopy({arg: values[arg] for arg in args})
+
+    return args_dict
diff --git a/FateZero/video_diffusion/data/dataset.py b/FateZero/video_diffusion/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a0ee0bb2ee10ab75d2e7427005eab7fe4938675
--- /dev/null
+++ b/FateZero/video_diffusion/data/dataset.py
@@ -0,0 +1,158 @@
+import os
+
+import numpy as np
+from PIL import Image
+from einops import rearrange
+from pathlib import Path
+
+import torch
+from torch.utils.data import Dataset
+
+from .transform import short_size_scale, random_crop, center_crop, offset_crop
+from ..common.image_util import IMAGE_EXTENSION
+
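util.py above provides the small helpers used by the train/test entry points: get_time_string() produces the timestamp suffix appended to logdir names, and get_function_args() snapshots the calling function's named arguments so train()/test() can dump their full configuration to config.yml. A minimal usage sketch, assuming the FateZero directory is on sys.path; demo_train and its parameters are hypothetical:

from video_diffusion.common.util import get_function_args, get_time_string


def demo_train(pretrained_model_path: str, learning_rate: float = 3e-5, seed: int = 42):
    # get_function_args() inspects the caller's frame, so it returns exactly
    # the named arguments demo_train was invoked with.
    return get_function_args()


print(demo_train("stable-diffusion-v1-4"))
# {'pretrained_model_path': 'stable-diffusion-v1-4', 'learning_rate': 3e-05, 'seed': 42}
print(get_time_string())  # e.g. '230327-200651' (YYMMDD-HHMMSS)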
+import sys +sys.path.append('FateZero') + +class ImageSequenceDataset(Dataset): + def __init__( + self, + path: str, + prompt_ids: torch.Tensor, + prompt: str, + start_sample_frame: int=0, + n_sample_frame: int = 8, + sampling_rate: int = 1, + stride: int = 1, + image_mode: str = "RGB", + image_size: int = 512, + crop: str = "center", + + class_data_root: str = None, + class_prompt_ids: torch.Tensor = None, + + offset: dict = { + "left": 0, + "right": 0, + "top": 0, + "bottom": 0 + } + ): + self.path = path + self.images = self.get_image_list(path) + self.n_images = len(self.images) + self.offset = offset + + if n_sample_frame < 0: + n_sample_frame = len(self.images) + self.start_sample_frame = start_sample_frame + + self.n_sample_frame = n_sample_frame + self.sampling_rate = sampling_rate + + self.sequence_length = (n_sample_frame - 1) * sampling_rate + 1 + if self.n_images < self.sequence_length: + raise ValueError("self.n_images < self.sequence_length") + self.stride = stride + + self.image_mode = image_mode + self.image_size = image_size + crop_methods = { + "center": center_crop, + "random": random_crop, + } + if crop not in crop_methods: + raise ValueError + self.crop = crop_methods[crop] + + self.prompt = prompt + self.prompt_ids = prompt_ids + self.overfit_length = (self.n_images - self.sequence_length) // self.stride + 1 + # Negative prompt for regularization + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_images_path = sorted(list(self.class_data_root.iterdir())) + self.num_class_images = len(self.class_images_path) + self.class_prompt_ids = class_prompt_ids + + self.video_len = (self.n_images - self.sequence_length) // self.stride + 1 + + def __len__(self): + max_len = (self.n_images - self.sequence_length) // self.stride + 1 + + if hasattr(self, 'num_class_images'): + max_len = max(max_len, self.num_class_images) + # return (self.n_images - self.sequence_length) // self.stride + 1 + return max_len + + def __getitem__(self, index): + return_batch = {} + frame_indices = self.get_frame_indices(index%self.video_len) + frames = [self.load_frame(i) for i in frame_indices] + frames = self.transform(frames) + + return_batch.update( + { + "images": frames, + "prompt_ids": self.prompt_ids, + } + ) + + if hasattr(self, 'class_data_root'): + class_index = index % (self.num_class_images - self.n_sample_frame) + class_indices = self.get_class_indices(class_index) + frames = [self.load_class_frame(i) for i in class_indices] + return_batch["class_images"] = self.tensorize_frames(frames) + return_batch["class_prompt_ids"] = self.class_prompt_ids + return return_batch + + def get_all(self, val_length=None): + if val_length is None: + val_length = len(self.images) + frame_indices = (i for i in range(val_length)) + frames = [self.load_frame(i) for i in frame_indices] + frames = self.transform(frames) + + return { + "images": frames, + "prompt_ids": self.prompt_ids, + } + + def transform(self, frames): + frames = self.tensorize_frames(frames) + frames = offset_crop(frames, **self.offset) + frames = short_size_scale(frames, size=self.image_size) + frames = self.crop(frames, height=self.image_size, width=self.image_size) + return frames + + @staticmethod + def tensorize_frames(frames): + frames = rearrange(np.stack(frames), "f h w c -> c f h w") + return torch.from_numpy(frames).div(255) * 2 - 1 + + def load_frame(self, index): + image_path = os.path.join(self.path, self.images[index]) + return Image.open(image_path).convert(self.image_mode) + + def 
load_class_frame(self, index): + image_path = self.class_images_path[index] + return Image.open(image_path).convert(self.image_mode) + + def get_frame_indices(self, index): + if self.start_sample_frame is not None: + frame_start = self.start_sample_frame + self.stride * index + else: + frame_start = self.stride * index + return (frame_start + i * self.sampling_rate for i in range(self.n_sample_frame)) + + def get_class_indices(self, index): + frame_start = index + return (frame_start + i for i in range(self.n_sample_frame)) + + @staticmethod + def get_image_list(path): + images = [] + for file in sorted(os.listdir(path)): + if file.endswith(IMAGE_EXTENSION): + images.append(file) + return images diff --git a/FateZero/video_diffusion/data/transform.py b/FateZero/video_diffusion/data/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..043097b4b1eb0108ba1c5430ddd11702dfe9a9b6 --- /dev/null +++ b/FateZero/video_diffusion/data/transform.py @@ -0,0 +1,48 @@ +import random + +import torch + + +def short_size_scale(images, size): + h, w = images.shape[-2:] + short, long = (h, w) if h < w else (w, h) + + scale = size / short + long_target = int(scale * long) + + target_size = (size, long_target) if h < w else (long_target, size) + + return torch.nn.functional.interpolate( + input=images, size=target_size, mode="bilinear", antialias=True + ) + + +def random_short_side_scale(images, size_min, size_max): + size = random.randint(size_min, size_max) + return short_size_scale(images, size) + + +def random_crop(images, height, width): + image_h, image_w = images.shape[-2:] + h_start = random.randint(0, image_h - height) + w_start = random.randint(0, image_w - width) + return images[:, :, h_start : h_start + height, w_start : w_start + width] + + +def center_crop(images, height, width): + # offset_crop(images, 0,0, 200, 0) + image_h, image_w = images.shape[-2:] + h_start = (image_h - height) // 2 + w_start = (image_w - width) // 2 + return images[:, :, h_start : h_start + height, w_start : w_start + width] + +def offset_crop(image, left=0, right=0, top=200, bottom=0): + + n, c, h, w = image.shape + left = min(left, w-1) + right = min(right, w - left - 1) + top = min(top, h - 1) + bottom = min(bottom, h - top - 1) + image = image[:, :, top:h-bottom, left:w-right] + + return image \ No newline at end of file diff --git a/FateZero/video_diffusion/models/attention.py b/FateZero/video_diffusion/models/attention.py new file mode 100755 index 0000000000000000000000000000000000000000..d38d90764accd6729441dc715f90845da1d8b513 --- /dev/null +++ b/FateZero/video_diffusion/models/attention.py @@ -0,0 +1,482 @@ +# code mostly taken from https://github.com/huggingface/diffusers +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.modeling_utils import ModelMixin +from diffusers.models.attention import FeedForward, CrossAttention, AdaLayerNorm +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available + +from einops import rearrange + + +@dataclass +class SpatioTemporalTransformerModelOutput(BaseOutput): + """torch.FloatTensor of shape [batch x channel x frames x height x width]""" + + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class SpatioTemporalTransformerModel(ModelMixin, ConfigMixin): + @register_to_config + def 
__init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + model_config: dict = {}, + **transformer_kwargs, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + SpatioTemporalTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + model_config=model_config, + **transformer_kwargs, + ) + for d in range(num_layers) + ] + ) + + # Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward( + self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True + ): + # 1. Input + clip_length = None + is_video = hidden_states.ndim == 5 + if is_video: + clip_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + encoder_hidden_states = encoder_hidden_states.repeat_interleave(clip_length, 0) + else: + # To adapt to classifier-free guidance where encoder_hidden_states=2 + batch_size = hidden_states.shape[0]//encoder_hidden_states.shape[0] + encoder_hidden_states = encoder_hidden_states.repeat_interleave(batch_size, 0) + *_, h, w = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c") # (bf) (hw) c + else: + hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c") + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, # [16, 4096, 320] + encoder_hidden_states=encoder_hidden_states, # ([1, 77, 768] + timestep=timestep, + clip_length=clip_length, + ) + + # 3. 
Output + if not self.use_linear_projection: + hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=h, w=w).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=h, w=w).contiguous() + + output = hidden_states + residual + if is_video: + output = rearrange(output, "(b f) c h w -> b c f h w", f=clip_length) + + if not return_dict: + return (output,) + + return SpatioTemporalTransformerModelOutput(sample=output) + +import copy +class SpatioTemporalTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_sparse_causal_attention: bool = True, + temporal_attention_position: str = "after_feedforward", + model_config: dict = {} + ): + super().__init__() + + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm = num_embeds_ada_norm is not None + self.use_sparse_causal_attention = use_sparse_causal_attention + # For safety, freeze the model_config + self.model_config = copy.deepcopy(model_config) + if 'least_sc_channel' in model_config: + if dim< model_config['least_sc_channel']: + self.model_config['SparseCausalAttention_index'] = [] + + self.temporal_attention_position = temporal_attention_position + temporal_attention_positions = ["after_spatial", "after_cross", "after_feedforward"] + if temporal_attention_position not in temporal_attention_positions: + raise ValueError( + f"`temporal_attention_position` must be one of {temporal_attention_positions}" + ) + + # 1. Spatial-Attn + spatial_attention = SparseCausalAttention if use_sparse_causal_attention else CrossAttention + self.attn1 = spatial_attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) # is a self-attention + self.norm1 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + ) + + # 2. Cross-Attn + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + ) + else: + self.attn2 = None + self.norm2 = None + + # 3. Temporal-Attn + self.attn_temporal = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + nn.init.zeros_(self.attn_temporal.to_out[0].weight.data) # initialize as an identity function + self.norm_temporal = ( + AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + ) + # efficient_attention_backward_cutlass is not implemented for large channels + self.use_xformers = (dim <= 320) or "3090" not in torch.cuda.get_device_name(0) + + # 4. 
Feed-forward + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + if use_memory_efficient_attention_xformers is True: + + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + else: + + pass + except Exception as e: + raise e + # self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + # self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers and self.use_xformers + self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers and self.use_xformers + # self.attn_temporal._use_memory_efficient_attention_xformers = ( + # use_memory_efficient_attention_xformers + # ), # FIXME: enabling this raises CUDA ERROR. Gotta dig in. + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + attention_mask=None, + clip_length=None, + ): + # 1. Self-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + + kwargs = dict( + hidden_states=norm_hidden_states, + attention_mask=attention_mask, + ) + if self.only_cross_attention: + kwargs.update(encoder_hidden_states=encoder_hidden_states) + if self.use_sparse_causal_attention: + kwargs.update(clip_length=clip_length) + if 'SparseCausalAttention_index' in self.model_config.keys(): + kwargs.update(SparseCausalAttention_index = self.model_config['SparseCausalAttention_index']) + + hidden_states = hidden_states + self.attn1(**kwargs) + + if clip_length is not None and self.temporal_attention_position == "after_spatial": + hidden_states = self.apply_temporal_attention(hidden_states, timestep, clip_length) + + if self.attn2 is not None: + # 2. Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm2(hidden_states) + ) + hidden_states = ( + self.attn2( + norm_hidden_states, # [16, 4096, 320] + encoder_hidden_states=encoder_hidden_states, # [1, 77, 768] + attention_mask=attention_mask, + ) + + hidden_states + ) + + if clip_length is not None and self.temporal_attention_position == "after_cross": + hidden_states = self.apply_temporal_attention(hidden_states, timestep, clip_length) + + # 3. 
Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + if clip_length is not None and self.temporal_attention_position == "after_feedforward": + hidden_states = self.apply_temporal_attention(hidden_states, timestep, clip_length) + + return hidden_states + + def apply_temporal_attention(self, hidden_states, timestep, clip_length): + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=clip_length) + norm_hidden_states = ( + self.norm_temporal(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm_temporal(hidden_states) + ) + hidden_states = self.attn_temporal(norm_hidden_states) + hidden_states + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + return hidden_states + + +class SparseCausalAttention(CrossAttention): + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + clip_length: int = None, + SparseCausalAttention_index: list = [-1, 'first'] + ): + if ( + self.added_kv_proj_dim is not None + or encoder_hidden_states is not None + or attention_mask is not None + ): + raise NotImplementedError + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + if clip_length is not None: + key = rearrange(key, "(b f) d c -> b f d c", f=clip_length) + value = rearrange(value, "(b f) d c -> b f d c", f=clip_length) + + + # ***********************Start of SparseCausalAttention_index********** + frame_index_list = [] + # print(f'SparseCausalAttention_index {str(SparseCausalAttention_index)}') + if len(SparseCausalAttention_index) > 0: + for index in SparseCausalAttention_index: + if isinstance(index, str): + if index == 'first': + frame_index = [0] * clip_length + if index == 'last': + frame_index = [clip_length-1] * clip_length + if (index == 'mid') or (index == 'middle'): + frame_index = [int(clip_length-1)//2] * clip_length + else: + assert isinstance(index, int), 'relative index must be int' + frame_index = torch.arange(clip_length) + index + frame_index = frame_index.clip(0, clip_length-1) + + frame_index_list.append(frame_index) + + key = torch.cat([ key[:, frame_index] for frame_index in frame_index_list + ], dim=2) + value = torch.cat([ value[:, frame_index] for frame_index in frame_index_list + ], dim=2) + + + # ***********************End of SparseCausalAttention_index********** + key = rearrange(key, "b f d c -> (b f) d c", f=clip_length) + value = rearrange(value, "b f d c -> (b f) d c", f=clip_length) + + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention( + query, key, value, hidden_states.shape[1], dim, attention_mask + ) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = 
self.to_out[1](hidden_states) + return hidden_states + +# FIXME +class SparseCausalAttention_fixme(CrossAttention): + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + clip_length: int = None, + ): + if ( + self.added_kv_proj_dim is not None + or encoder_hidden_states is not None + or attention_mask is not None + ): + raise NotImplementedError + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + prev_frame_index = torch.arange(clip_length) - 1 + prev_frame_index[0] = 0 + + key = rearrange(key, "(b f) d c -> b f d c", f=clip_length) + key = torch.cat([key[:, [0] * clip_length], key[:, prev_frame_index]], dim=2) + key = rearrange(key, "b f d c -> (b f) d c", f=clip_length) + + value = rearrange(value, "(b f) d c -> b f d c", f=clip_length) + value = torch.cat([value[:, [0] * clip_length], value[:, prev_frame_index]], dim=2) + value = rearrange(value, "b f d c -> (b f) d c", f=clip_length) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention( + query, key, value, hidden_states.shape[1], dim, attention_mask + ) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states diff --git a/FateZero/video_diffusion/models/lora.py b/FateZero/video_diffusion/models/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..2692dd91c314586c8b3c7d8fc51ca5541ee2669d --- /dev/null +++ b/FateZero/video_diffusion/models/lora.py @@ -0,0 +1,131 @@ +from typing import Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +# from diffusers.utils +from diffusers.utils import deprecate, logging +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import FeedForward, CrossAttention, AdaLayerNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + +class LoRALinearLayer(nn.Module): + def __init__(self, in_features, out_features, rank=4, stride=1): + super().__init__() + + if rank > min(in_features, out_features): + Warning(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}, reset to {min(in_features, out_features)//2}") + rank = min(in_features, out_features)//2 + + + self.down = nn.Conv1d(in_features, rank, bias=False, + kernel_size=3, + stride = stride, + padding=1,) + self.up = nn.Conv1d(rank, out_features, bias=False, + kernel_size=3, + padding=1,) + + nn.init.normal_(self.down.weight, std=1 / rank) + # nn.init.zeros_(self.down.bias.data) + + nn.init.zeros_(self.up.weight) + # nn.init.zeros_(self.up.bias.data) + if stride > 1: + self.skip = 
nn.AvgPool1d(kernel_size=3, stride=2, padding=1) + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + if hasattr(self, 'skip'): + hidden_states=self.skip(hidden_states) + return up_hidden_states.to(orig_dtype)+hidden_states + + +class LoRACrossAttnProcessor(nn.Module): + def __init__(self, hidden_size, cross_attention_dim=None, rank=4): + super().__init__() + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + + def __call__( + self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0 + ): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) + query = attn.head_to_batch_dim(query) + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + + key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) + + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + + + +class LoRAXFormersCrossAttnProcessor(nn.Module): + def __init__(self, hidden_size, cross_attention_dim, rank=4): + super().__init__() + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + + def __call__( + self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0 + ): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) + query = attn.head_to_batch_dim(query).contiguous() + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + + key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) + + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = 
attn.to_out[1](hidden_states) + + return hidden_states diff --git a/FateZero/video_diffusion/models/resnet.py b/FateZero/video_diffusion/models/resnet.py new file mode 100755 index 0000000000000000000000000000000000000000..162b46a9d66fa6e329661f7145cedf07f773cf8d --- /dev/null +++ b/FateZero/video_diffusion/models/resnet.py @@ -0,0 +1,518 @@ +# code mostly taken from https://github.com/huggingface/diffusers +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +import copy + +from einops import rearrange +from .lora import LoRALinearLayer, LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor + +class PseudoConv3d(nn.Conv2d): + def __init__(self, in_channels, out_channels, kernel_size, temporal_kernel_size=None, model_config: dict={}, temporal_downsample=False, **kwargs): + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + **kwargs, + ) + if temporal_kernel_size is None: + temporal_kernel_size = kernel_size + + if temporal_downsample is True: + temporal_stride = 2 + else: + temporal_stride = 1 + + + if 'lora' in model_config.keys() : + self.conv_temporal = ( + LoRALinearLayer( + out_channels, + out_channels, + rank=model_config['lora'], + stride=temporal_stride + + ) + if kernel_size > 1 + else None + ) + else: + self.conv_temporal = ( + nn.Conv1d( + out_channels, + out_channels, + kernel_size=temporal_kernel_size, + padding=temporal_kernel_size // 2, + ) + if kernel_size > 1 + else None + ) + + if self.conv_temporal is not None: + nn.init.dirac_(self.conv_temporal.weight.data) # initialized to be identity + nn.init.zeros_(self.conv_temporal.bias.data) + + def forward(self, x): + b = x.shape[0] + + is_video = x.ndim == 5 + if is_video: + x = rearrange(x, "b c f h w -> (b f) c h w") + + x = super().forward(x) + + if is_video: + x = rearrange(x, "(b f) c h w -> b c f h w", b=b) + + if self.conv_temporal is None or not is_video: + return x + + *_, h, w = x.shape + + x = rearrange(x, "b c f h w -> (b h w) c f") + + x = self.conv_temporal(x) + + x = rearrange(x, "(b h w) c f -> b c f h w", h=h, w=w) + + return x + + +class UpsamplePseudo3D(nn.Module): + """ + An upsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. 
+ use_conv_transpose: + out_channels: + """ + + def __init__( + self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv", model_config: dict={}, **kwargs + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + self.model_config = copy.deepcopy(model_config) + + conv = None + if use_conv_transpose: + raise NotImplementedError + conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + # Do NOT downsample in upsample block + td = False + + conv = PseudoConv3d(self.channels, self.out_channels, 3, padding=1, + model_config=model_config, temporal_downsample=td) + # conv = PseudoConv3d(self.channels, self.out_channels, 3, kwargs['lora'], padding=1) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch + # https://github.com/pytorch/pytorch/issues/86679 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + b = hidden_states.shape[0] + is_video = hidden_states.ndim == 5 + if is_video: + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + if is_video: + td = ('temporal_downsample' in self.model_config and self.model_config['temporal_downsample'] is True) + + + if td: + hidden_states = rearrange(hidden_states, " (b f) c h w -> b c h w f ", b=b) + t_b, t_c, t_h, t_w, t_f = hidden_states.shape + hidden_states = rearrange(hidden_states, " b c h w f -> (b c) (h w) f ", b=b) + + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="linear") + hidden_states = rearrange(hidden_states, " (b c) (h w) f -> (b f) c h w ", b=t_b, h=t_h) + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + if is_video: + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", b=b) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + + +class DownsamplePseudo3D(nn.Module): + """ + A downsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. 
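# --- Illustrative sketch (editor's note, not part of the original diff) ---
# A condensed, standalone version of the resizing logic in UpsamplePseudo3D
# above: frames are upsampled spatially with nearest-neighbour interpolation,
# and optionally the frame axis itself is doubled with linear interpolation,
# mirroring the "(b c) (h w) f" reshuffle used by the module. Names are
# illustrative only.
import torch
import torch.nn.functional as F
from einops import rearrange

def upsample_video(x: torch.Tensor, temporal: bool = False) -> torch.Tensor:
    # x: (b, c, f, h, w)
    b = x.shape[0]
    x = rearrange(x, "b c f h w -> (b f) c h w")
    x = F.interpolate(x, scale_factor=2.0, mode="nearest")      # 2x in height/width
    if temporal:
        x = rearrange(x, "(b f) c h w -> b c h w f", b=b)
        bb, c, h, w, f = x.shape
        x = rearrange(x, "b c h w f -> (b c) (h w) f")
        x = F.interpolate(x, scale_factor=2.0, mode="linear")   # 2x in frames
        x = rearrange(x, "(b c) (h w) f -> (b f) c h w", b=bb, h=h)
    return rearrange(x, "(b f) c h w -> b c f h w", b=b)

if __name__ == "__main__":
    video = torch.randn(1, 4, 4, 8, 8)
    print(upsample_video(video).shape)                 # torch.Size([1, 4, 4, 16, 16])
    print(upsample_video(video, temporal=True).shape)  # torch.Size([1, 4, 8, 16, 16])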
+ out_channels: + padding: + """ + + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, model_config: dict={}, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + self.model_config = copy.deepcopy(model_config) + # self.model_config = copy.deepcopy(model_config) + + if use_conv: + td = ('temporal_downsample' in model_config and model_config['temporal_downsample'] is True) + + conv = PseudoConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding, + model_config=model_config, temporal_downsample=td) + else: + assert self.channels == self.out_channels + conv = nn.AvgPool2d(kernel_size=stride, stride=stride) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + pad = (0, 1, 0, 1) + hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) + + assert hidden_states.shape[1] == self.channels + if self.use_conv: + hidden_states = self.conv(hidden_states) + else: + b = hidden_states.shape[0] + is_video = hidden_states.ndim == 5 + if is_video: + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = self.conv(hidden_states) + if is_video: + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", b=b) + + return hidden_states + + +class ResnetBlockPseudo3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + kernel=None, + output_scale_factor=1.0, + use_in_shortcut=None, + up=False, + down=False, + model_config: dict={}, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.up = up + self.down = down + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + + self.conv1 = PseudoConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, model_config=model_config) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + + self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm( + num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True + ) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = PseudoConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, model_config=model_config) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + 
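            # Mish(x) = x * tanh(softplus(x)); the small Mish module used here
            # is defined near the bottom of this file.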
self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.upsample = self.downsample = None + if self.up: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") + else: + self.upsample = UpsamplePseudo3D(in_channels, use_conv=False, model_config=model_config) + elif self.down: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) + else: + self.downsample = DownsamplePseudo3D(in_channels, use_conv=False, padding=1, name="op", model_config=model_config) + + self.use_in_shortcut = ( + self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut + ) + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = PseudoConv3d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0, model_config=model_config + ) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() + input_tensor = self.upsample(input_tensor) + hidden_states = self.upsample(hidden_states) + elif self.downsample is not None: + input_tensor = self.downsample(input_tensor) + hidden_states = self.downsample(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + + if temb is not None and self.time_embedding_norm == "default": + is_video = hidden_states.ndim == 5 + if is_video: + b, c, f, h, w = hidden_states.shape + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + temb = temb.repeat_interleave(f, 0) + + hidden_states = hidden_states + temb + + if is_video: + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", b=b) + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + is_video = hidden_states.ndim == 5 + if is_video: + b, c, f, h, w = hidden_states.shape + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + temb = temb.repeat_interleave(f, 0) + + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + if is_video: + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", b=b) + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class Mish(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) + + +def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): + r"""Upsample2D a batch of 2D images with the given filter. 
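# --- Illustrative sketch (editor's note, not part of the original diff) ---
# How ResnetBlockPseudo3D above injects the timestep embedding into a video
# tensor: the (batch, channels) embedding gets singleton spatial dims and is
# repeated once per frame so it lines up with the "(b f) c h w" layout.
# Shapes and names are illustrative.
import torch
from einops import rearrange

b, c, f, h, w = 2, 8, 5, 4, 4
hidden = torch.randn(b, c, f, h, w)
temb = torch.randn(b, c)                      # already projected to `c` channels

temb = temb[:, :, None, None]                 # (b, c, 1, 1)
frames = rearrange(hidden, "b c f h w -> (b f) c h w")
temb = temb.repeat_interleave(f, dim=0)       # (b*f, c, 1, 1), one copy per frame
frames = frames + temb                        # broadcast over h, w
hidden = rearrange(frames, "(b f) c h w -> b c f h w", b=b)
print(hidden.shape)                           # torch.Size([2, 8, 5, 4, 4])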
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given + filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified + `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is + a: multiple of the upsampling factor. + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + output: Tensor of the shape `[N, C, H * factor, W * factor]` + """ + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * (gain * (factor**2)) + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + kernel.to(device=hidden_states.device), + up=factor, + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), + ) + return output + + +def downsample_2d(hidden_states, kernel=None, factor=2, gain=1): + r"""Downsample2D a batch of 2D images with the given filter. + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the + given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the + specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its + shape is a multiple of the downsampling factor. + + Args: + hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + kernel: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). 
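# --- Illustrative sketch (editor's note, not part of the original diff) ---
# How the default FIR kernel is built in upsample_2d / downsample_2d: a 1D
# kernel of ones becomes a separable 2D kernel via an outer product and is
# normalised to sum to one, so with factor=2 the filter is a plain 2x2 box
# (average pooling when downsampling; after the gain * factor**2 rescale it
# reproduces nearest-neighbour replication when upsampling).
import torch

factor, gain = 2, 1.0
kernel = torch.tensor([1.0] * factor)
kernel = torch.outer(kernel, kernel)
kernel /= kernel.sum()                 # 2x2 box filter, each tap = 0.25
print(kernel)
print(kernel * (gain * factor**2))     # upsampling variant: each tap = 1.0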
+ + Returns: + output: Tensor of the shape `[N, C, H // factor, W // factor]` + """ + + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + + kernel = torch.tensor(kernel, dtype=torch.float32) + if kernel.ndim == 1: + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) + + kernel = kernel * gain + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + kernel.to(device=hidden_states.device), + down=factor, + pad=((pad_value + 1) // 2, pad_value // 2), + ) + return output + + +def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)): + up_x = up_y = up + down_x = down_y = down + pad_x0 = pad_y0 = pad[0] + pad_x1 = pad_y1 = pad[1] + + _, channel, in_h, in_w = tensor.shape + tensor = tensor.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = tensor.shape + kernel_h, kernel_w = kernel.shape + + out = tensor.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out.to(tensor.device) # Move back to mps if necessary + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/FateZero/video_diffusion/models/unet_3d_blocks.py b/FateZero/video_diffusion/models/unet_3d_blocks.py new file mode 100755 index 0000000000000000000000000000000000000000..9e6285cf1416c7e1be444cc0be4b4575c7eedb0b --- /dev/null +++ b/FateZero/video_diffusion/models/unet_3d_blocks.py @@ -0,0 +1,631 @@ +# code mostly taken from https://github.com/huggingface/diffusers +import torch +from torch import nn + +from .attention import SpatioTemporalTransformerModel +from .resnet import DownsamplePseudo3D, ResnetBlockPseudo3D, UpsamplePseudo3D + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + model_config: dict={} +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlockPseudo3D": + return DownBlockPseudo3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + model_config=model_config + ) + elif down_block_type == "CrossAttnDownBlockPseudo3D": + if cross_attention_dim is 
None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockPseudo3D") + return CrossAttnDownBlockPseudo3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + model_config=model_config + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + model_config: dict={} +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlockPseudo3D": + return UpBlockPseudo3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + model_config=model_config + ) + elif up_block_type == "CrossAttnUpBlockPseudo3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockPseudo3D") + return CrossAttnUpBlockPseudo3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + model_config=model_config + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlockPseudo3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + model_config: dict={} + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlockPseudo3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + 
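                # model_config is threaded into every resnet / attention block
                # so that pseudo-3D and LoRA-related options reach all layers.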
eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + model_config=model_config + ) + ] + attentions = [] + + for _ in range(num_layers): + if dual_cross_attention: + raise NotImplementedError + attentions.append( + SpatioTemporalTransformerModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + model_config=model_config + ) + ) + resnets.append( + ResnetBlockPseudo3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + model_config=model_config + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + # TODO(Patrick, William) - attention_mask is currently not used. Implement once used + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnDownBlockPseudo3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + model_config: dict={} + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockPseudo3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + model_config=model_config + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + SpatioTemporalTransformerModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + model_config=model_config + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + DownsamplePseudo3D( + out_channels, + use_conv=True, + 
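                        # use_conv=True selects the strided PseudoConv3d path inside
                        # DownsamplePseudo3D (see resnet.py above); otherwise an
                        # AvgPool2d is applied per frame.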
out_channels=out_channels, + padding=downsample_padding, + name="op", + model_config=model_config + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + # TODO(Patrick, William) - attention mask is not used + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlockPseudo3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + model_config: dict={} + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockPseudo3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + model_config=model_config + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + DownsamplePseudo3D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + model_config=model_config + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlockPseudo3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + 
resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + model_config: dict={}, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + self.model_config = model_config + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlockPseudo3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + model_config=model_config + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + SpatioTemporalTransformerModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + model_config=model_config + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [UpsamplePseudo3D(out_channels, use_conv=True, out_channels=out_channels, model_config=model_config)] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + ): + # TODO(Patrick, William) - attention mask is not used + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlockPseudo3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = 
"swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + model_config: dict={}, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlockPseudo3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + model_config=model_config + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [UpsamplePseudo3D(out_channels, use_conv=True, out_channels=out_channels, model_config=model_config)] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/FateZero/video_diffusion/models/unet_3d_condition.py b/FateZero/video_diffusion/models/unet_3d_condition.py new file mode 100755 index 0000000000000000000000000000000000000000..cb4f510f44a1fb1a386fe802a51420ead321c29d --- /dev/null +++ b/FateZero/video_diffusion/models/unet_3d_condition.py @@ -0,0 +1,501 @@ +# code mostly taken from https://github.com/huggingface/diffusers +import os +import glob +import json +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union +import copy + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.modeling_utils import ModelMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from .unet_3d_blocks import ( + CrossAttnDownBlockPseudo3D, + CrossAttnUpBlockPseudo3D, + DownBlockPseudo3D, + UNetMidBlockPseudo3DCrossAttn, + UpBlockPseudo3D, + get_down_block, + get_up_block, +) +from .resnet import PseudoConv3d + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNetPseudo3DConditionOutput(BaseOutput): + sample: torch.FloatTensor + + +class UNetPseudo3DConditionModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockPseudo3D", + "CrossAttnDownBlockPseudo3D", + 
"CrossAttnDownBlockPseudo3D", + "DownBlockPseudo3D", + ), + mid_block_type: str = "UNetMidBlockPseudo3DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlockPseudo3D", + "CrossAttnUpBlockPseudo3D", + "CrossAttnUpBlockPseudo3D", + "CrossAttnUpBlockPseudo3D", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + **kwargs + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + if 'temporal_downsample' in kwargs and kwargs['temporal_downsample'] is True: + kwargs['temporal_downsample_time'] = 3 + self.temporal_downsample_time = kwargs.get('temporal_downsample_time', 0) + + # input + self.conv_in = PseudoConv3d(in_channels, block_out_channels[0], + kernel_size=3, padding=(1, 1), model_config=kwargs) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + kwargs_copy=copy.deepcopy(kwargs) + temporal_downsample_i = ((i >= (len(down_block_types)-self.temporal_downsample_time)) + and (not is_final_block)) + kwargs_copy.update({'temporal_downsample': temporal_downsample_i} ) + + # kwargs_copy.update({'SparseCausalAttention_index': temporal_downsample_i} ) + if temporal_downsample_i: + print(f'Initialize model temporal downsample at layer {i}') + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + 
resnet_time_scale_shift=resnet_time_scale_shift, + model_config=kwargs_copy + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlockPseudo3DCrossAttn": + self.mid_block = UNetMidBlockPseudo3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + model_config=kwargs + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + kwargs_copy=copy.deepcopy(kwargs) + kwargs_copy.update({'temporal_downsample': + i < (self.temporal_downsample_time-1)}) + if i < (self.temporal_downsample_time-1): + print(f'Initialize model temporal updample at layer {i}') + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + model_config=kwargs_copy + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + self.conv_act = nn.SiLU() + self.conv_out = PseudoConv3d(block_out_channels[0], out_channels, + kernel_size=3, padding=1, model_config=kwargs) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. 
In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = ( + num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + ) + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance( + module, + (CrossAttnDownBlockPseudo3D, DownBlockPseudo3D, CrossAttnUpBlockPseudo3D, UpBlockPseudo3D), + ): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, # None + attention_mask: Optional[torch.Tensor] = None, # None + return_dict: bool = True, + ) -> Union[UNetPseudo3DConditionOutput, Tuple]: + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: # None + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: # False + sample = 2 * sample - 1.0 + + # 1. 
time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + # for i in down_block_res_samples: print(i.shape) + # torch.Size([1, 320, 16, 64, 64]) + # torch.Size([1, 320, 16, 64, 64]) + # torch.Size([1, 320, 16, 64, 64]) + # torch.Size([1, 320, 8, 32, 32]) + # torch.Size([1, 640, 8, 32, 32]) + # torch.Size([1, 640, 8, 32, 32]) + # torch.Size([1, 640, 4, 16, 16]) + # torch.Size([1, 1280, 4, 16, 16]) + # torch.Size([1, 1280, 4, 16, 16]) + # torch.Size([1, 1280, 2, 8, 8]) + # torch.Size([1, 1280, 2, 8, 8]) + # torch.Size([1, 1280, 2, 8, 8]) + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) + # 6. 
post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNetPseudo3DConditionOutput(sample=sample) + + @classmethod + def from_2d_model(cls, model_path, model_config): + config_path = os.path.join(model_path, "config.json") + if not os.path.isfile(config_path): + raise RuntimeError(f"{config_path} does not exist") + with open(config_path, "r") as f: + config = json.load(f) + + config.pop("_class_name") + config.pop("_diffusers_version") + + block_replacer = { + "CrossAttnDownBlock2D": "CrossAttnDownBlockPseudo3D", + "DownBlock2D": "DownBlockPseudo3D", + "UpBlock2D": "UpBlockPseudo3D", + "CrossAttnUpBlock2D": "CrossAttnUpBlockPseudo3D", + } + + def convert_2d_to_3d_block(block): + return block_replacer[block] if block in block_replacer else block + + config["down_block_types"] = [ + convert_2d_to_3d_block(block) for block in config["down_block_types"] + ] + config["up_block_types"] = [convert_2d_to_3d_block(block) for block in config["up_block_types"]] + if model_config is not None: + config.update(model_config) + + model = cls(**config) + + state_dict_path_condidates = glob.glob(os.path.join(model_path, "*.bin")) + if state_dict_path_condidates: + state_dict = torch.load(state_dict_path_condidates[0], map_location="cpu") + model.load_2d_state_dict(state_dict=state_dict) + + return model + + def load_2d_state_dict(self, state_dict, **kwargs): + state_dict_3d = self.state_dict() + + for k, v in state_dict.items(): + if k not in state_dict_3d: + raise KeyError(f"2d state_dict key {k} does not exist in 3d model") + elif v.shape != state_dict_3d[k].shape: + raise ValueError(f"state_dict shape mismatch, 2d {v.shape}, 3d {state_dict_3d[k].shape}") + + for k, v in state_dict_3d.items(): + if "_temporal" in k: + continue + if k not in state_dict: + raise KeyError(f"3d state_dict key {k} does not exist in 2d model") + + state_dict_3d.update(state_dict) + self.load_state_dict(state_dict_3d, **kwargs) diff --git a/FateZero/video_diffusion/pipelines/DDIMSpatioTemporalStableDiffusionPipeline.py b/FateZero/video_diffusion/pipelines/DDIMSpatioTemporalStableDiffusionPipeline.py new file mode 100755 index 0000000000000000000000000000000000000000..5228dfa27bdf54081b9a075f3d4d7ea7a437d42f --- /dev/null +++ b/FateZero/video_diffusion/pipelines/DDIMSpatioTemporalStableDiffusionPipeline.py @@ -0,0 +1,300 @@ +# code mostly taken from https://github.com/huggingface/diffusers +import inspect +from typing import Callable, List, Optional, Union +import PIL +import torch +import numpy as np +from einops import rearrange +from tqdm import trange, tqdm + +from diffusers.utils import deprecate, logging +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + +from ..models.unet_3d_condition import UNetPseudo3DConditionModel +from .stable_diffusion import SpatioTemporalStableDiffusionPipeline + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class DDIMSpatioTemporalStableDiffusionPipeline(SpatioTemporalStableDiffusionPipeline): + r""" + Pipeline for text-to-video generation using Spatio-Temporal Stable Diffusion. 
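    Example (editor's illustrative sketch, not part of the original API): the
    pseudo-3D UNet consumed by this pipeline can be inflated from pretrained 2D
    Stable Diffusion weights via `UNetPseudo3DConditionModel.from_2d_model`.
    The checkpoint path below is a placeholder, and the import assumes the
    `video_diffusion` package is on the Python path:

        from video_diffusion.models.unet_3d_condition import UNetPseudo3DConditionModel

        unet = UNetPseudo3DConditionModel.from_2d_model(
            "./stable-diffusion-v1-4/unet",               # placeholder path
            model_config={"temporal_downsample": False},
        )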
+ """ + + def check_inputs(self, prompt, height, width, callback_steps, strength=None): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if strength is not None: + if strength <= 0 or strength > 1: + raise ValueError(f"The value of strength should in (0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + + + def prepare_latents_ddim_inverted(self, image, batch_size, num_images_per_prompt, + # dtype, device, + text_embeddings, + generator=None): + + # Not sure if image need to change device and type + # image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + init_latents = 0.18215 * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
+ ) + else: + init_latents = torch.cat([init_latents], dim=0) + + # get latents + init_latents_bcfhw = rearrange(init_latents, "(b f) c h w -> b c f h w", b=batch_size) + ddim_latents_all_step = self.ddim_clean2noisy_loop(init_latents_bcfhw, text_embeddings) + return ddim_latents_all_step + + @torch.no_grad() + def ddim_clean2noisy_loop(self, latent, text_embeddings): + weight_dtype = latent.dtype + uncond_embeddings, cond_embeddings = text_embeddings.chunk(2) + all_latent = [latent] + latent = latent.clone().detach() + print(' Invert clean image to noise latents by DDIM and Unet') + for i in trange(len(self.scheduler.timesteps)): + t = self.scheduler.timesteps[len(self.scheduler.timesteps) - i - 1] + # noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings) + noise_pred = self.unet(latent, t, encoder_hidden_states=cond_embeddings)["sample"] # [1, 4, 8, 64, 64] -> [1, 4, 8, 64, 64]) + latent = self.next_clean2noise_step(noise_pred, t, latent) + all_latent.append(latent.to(dtype=weight_dtype)) + + return all_latent + + def next_clean2noise_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]): + """ + Assume the eta in DDIM=0 + """ + timestep, next_timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod + alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] + beta_prod_t = 1 - alpha_prod_t + next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 + next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output + next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction + return next_sample + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = None, + num_inference_steps: int = 50, + clip_length: int = 8, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **args + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. Only used in DDIM or strength<1.0 + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. 
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 1.0): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 
+ When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps, strength) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + # if strength <1.0: + # timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + timesteps = self.scheduler.timesteps + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + if latents is None: + ddim_latents_all_step = self.prepare_latents_ddim_inverted( + image, batch_size, num_images_per_prompt, + # text_embeddings.dtype, device, + text_embeddings, + generator, + ) + latents = ddim_latents_all_step[-1] + + latents_dtype = latents.dtype + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(tqdm(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, t, encoder_hidden_states=text_embeddings + ).sample.to(dtype=latents_dtype) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + has_nsfw_concept = None + + # 10. 
Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + torch.cuda.empty_cache() + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/FateZero/video_diffusion/pipelines/p2pDDIMSpatioTemporalPipeline.py b/FateZero/video_diffusion/pipelines/p2pDDIMSpatioTemporalPipeline.py new file mode 100755 index 0000000000000000000000000000000000000000..d33dd73c74b8236be71554829748541ab4de9725 --- /dev/null +++ b/FateZero/video_diffusion/pipelines/p2pDDIMSpatioTemporalPipeline.py @@ -0,0 +1,437 @@ +# code mostly taken from https://github.com/huggingface/diffusers + +from typing import Callable, List, Optional, Union +import os +import PIL +import torch +import numpy as np +from einops import rearrange +from tqdm import trange, tqdm + +from transformers import CLIPTextModel, CLIPTokenizer + +from diffusers.utils import deprecate, logging +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.models import AutoencoderKL +from diffusers.schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) + +from video_diffusion.prompt_attention.attention_util import make_controller +from ..models.unet_3d_condition import UNetPseudo3DConditionModel +from .stable_diffusion import SpatioTemporalStableDiffusionPipeline +from ..prompt_attention import attention_util +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class p2pDDIMSpatioTemporalPipeline(SpatioTemporalStableDiffusionPipeline): + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNetPseudo3DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler,], + disk_store: bool=False + ): + super().__init__(vae, text_encoder, tokenizer, unet, scheduler) + self.store_controller = attention_util.AttentionStore(disk_store=disk_store) + self.empty_controller = attention_util.EmptyControl() + r""" + Pipeline for text-to-video generation using Spatio-Temporal Stable Diffusion. + """ + + def check_inputs(self, prompt, height, width, callback_steps, strength=None): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if strength is not None: + if strength <= 0 or strength > 1: + raise ValueError(f"The value of strength should in (0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." 
+ ) + + @torch.no_grad() + def prepare_latents_ddim_inverted(self, image, batch_size, num_images_per_prompt, + text_embeddings, + store_attention=False, prompt=None, + generator=None, + LOW_RESOURCE = True, + save_path = None + ): + self.prepare_before_train_loop() + if store_attention: + attention_util.register_attention_control(self, self.store_controller) + resource_default_value = self.store_controller.LOW_RESOURCE + self.store_controller.LOW_RESOURCE = LOW_RESOURCE # in inversion, no CFG, record all latents attention + batch_size = batch_size * num_images_per_prompt + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + init_latents = 0.18215 * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
+ ) + else: + init_latents = torch.cat([init_latents], dim=0) + + # get latents + init_latents_bcfhw = rearrange(init_latents, "(b f) c h w -> b c f h w", b=batch_size) + ddim_latents_all_step = self.ddim_clean2noisy_loop(init_latents_bcfhw, text_embeddings, self.store_controller) + if store_attention and (save_path is not None) : + os.makedirs(save_path+'/cross_attention') + attention_output = attention_util.show_cross_attention(self.tokenizer, prompt, + self.store_controller, 16, ["up", "down"], + save_path = save_path+'/cross_attention') + + # Detach the controller for safety + attention_util.register_attention_control(self, self.empty_controller) + self.store_controller.LOW_RESOURCE = resource_default_value + + return ddim_latents_all_step + + @torch.no_grad() + def ddim_clean2noisy_loop(self, latent, text_embeddings, controller:attention_util.AttentionControl=None): + weight_dtype = latent.dtype + uncond_embeddings, cond_embeddings = text_embeddings.chunk(2) + all_latent = [latent] + latent = latent.clone().detach() + print(' Invert clean image to noise latents by DDIM and Unet') + for i in trange(len(self.scheduler.timesteps)): + t = self.scheduler.timesteps[len(self.scheduler.timesteps) - i - 1] + + # [1, 4, 8, 64, 64] -> [1, 4, 8, 64, 64]) + noise_pred = self.unet(latent, t, encoder_hidden_states=cond_embeddings)["sample"] + + latent = self.next_clean2noise_step(noise_pred, t, latent) + if controller is not None: controller.step_callback(latent) + all_latent.append(latent.to(dtype=weight_dtype)) + + return all_latent + + def next_clean2noise_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]): + """ + Assume the eta in DDIM=0 + """ + timestep, next_timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod + alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] + beta_prod_t = 1 - alpha_prod_t + next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 + next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output + next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction + return next_sample + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + def p2preplace_edit(self, **kwargs): + # Edit controller during inference + # The controller must know the source prompt for replace mapping + + len_source = {len(kwargs['source_prompt'].split(' '))} + len_target = {len(kwargs['prompt'].split(' '))} + equal_length = (len_source == len_target) + print(f" len_source: {len_source}, len_target: {len_target}, equal_length: {equal_length}") + edit_controller = make_controller( + self.tokenizer, + [ kwargs['source_prompt'], kwargs['prompt']], + NUM_DDIM_STEPS = kwargs['num_inference_steps'], + is_replace_controller=kwargs.get('is_replace_controller', True) and equal_length, + cross_replace_steps=kwargs['cross_replace_steps'], + self_replace_steps=kwargs['self_replace_steps'], + blend_words=kwargs.get('blend_words', None), + equilizer_params=kwargs.get('eq_params', None), + 
additional_attention_store=self.store_controller, + use_inversion_attention = kwargs['use_inversion_attention'], + bend_th = kwargs.get('bend_th', (0.3, 0.3)), + masked_self_attention = kwargs.get('masked_self_attention', None), + masked_latents=kwargs.get('masked_latents', None), + save_path=kwargs.get('save_path', None), + save_self_attention = kwargs.get('save_self_attention', True), + disk_store = kwargs.get('disk_store', False) + ) + + attention_util.register_attention_control(self, edit_controller) + + + # In ddim inferece, no need source prompt + sdimage_output = self.sd_ddim_pipeline( + controller = edit_controller, + # target_prompt = kwargs['prompts'][1], + **kwargs) + if hasattr(edit_controller.local_blend, 'mask_list'): + mask_list = edit_controller.local_blend.mask_list + else: + mask_list = None + if len(edit_controller.attention_store.keys()) > 0: + attention_output = attention_util.show_cross_attention(self.tokenizer, kwargs['prompt'], + edit_controller, 16, ["up", "down"]) + else: + attention_output = None + dict_output = { + "sdimage_output" : sdimage_output, + "attention_output" : attention_output, + "mask_list" : mask_list, + } + attention_util.register_attention_control(self, self.empty_controller) + return dict_output + + + + + @torch.no_grad() + def __call__(self, **kwargs): + edit_type = kwargs['edit_type'] + assert edit_type in ['save', 'swap', None] + if edit_type is None: + return self.sd_ddim_pipeline(controller = None, **kwargs) + + if edit_type == 'save': + del self.store_controller + self.store_controller = attention_util.AttentionStore() + attention_util.register_attention_control(self, self.store_controller) + sdimage_output = self.sd_ddim_pipeline(controller = self.store_controller, **kwargs) + + mask_list = None + + attention_output = attention_util.show_cross_attention(self.tokenizer, kwargs['prompt'], self.store_controller, 16, ["up", "down"]) + + + dict_output = { + "sdimage_output" : sdimage_output, + "attention_output" : attention_output, + 'mask_list': mask_list + } + + # Detach the controller for safety + attention_util.register_attention_control(self, self.empty_controller) + return dict_output + + if edit_type == 'swap': + + return self.p2preplace_edit(**kwargs) + + + @torch.no_grad() + def sd_ddim_pipeline( + self, + prompt: Union[str, List[str]], + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + controller: attention_util.AttentionControl = None, + **args + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. Only used in DDIM or strength<1.0 + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. 
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 1.0): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 
+ When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps, strength) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + if latents is None: + ddim_latents_all_step = self.prepare_latents_ddim_inverted( + image, batch_size, num_images_per_prompt, + text_embeddings, + store_attention=False, # avoid recording attention in first inversion + generator = generator, + ) + latents = ddim_latents_all_step[-1] + else: + ddim_latents_all_step=None + + latents_dtype = latents.dtype + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(tqdm(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, t, encoder_hidden_states=text_embeddings + ).sample.to(dtype=latents_dtype) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # Edit the latents using attention map + if controller is not None: + latents_old = latents + dtype = latents.dtype + latents_new = controller.step_callback(latents) + latents = latents_new.to(dtype) + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + torch.cuda.empty_cache() + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + has_nsfw_concept = None + + # 10. 
Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + torch.cuda.empty_cache() + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/FateZero/video_diffusion/pipelines/p2pvalidation_loop.py b/FateZero/video_diffusion/pipelines/p2pvalidation_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..125f416bee3bf6cfaf64567142e3a366b596eaad --- /dev/null +++ b/FateZero/video_diffusion/pipelines/p2pvalidation_loop.py @@ -0,0 +1,174 @@ +import os +import numpy as np +from typing import List, Union +import PIL +import copy +from einops import rearrange + +import torch +import torch.utils.data +import torch.utils.checkpoint + +from diffusers.pipeline_utils import DiffusionPipeline +from tqdm.auto import tqdm +from video_diffusion.common.image_util import make_grid, annotate_image +from video_diffusion.common.image_util import save_gif_mp4_folder_type + + +class p2pSampleLogger: + def __init__( + self, + prompts: List[str], + clip_length: int, + logdir: str, + subdir: str = "sample", + num_samples_per_prompt: int = 1, + sample_seeds: List[int] = None, + num_inference_steps: int = 20, + guidance_scale: float = 7, + strength: float = None, + annotate: bool = True, + annotate_size: int = 15, + use_make_grid: bool = True, + grid_column_size: int = 2, + prompt2prompt_edit: bool=False, + p2p_config: dict = None, + use_inversion_attention: bool = True, + source_prompt: str = None, + traverse_p2p_config: bool = False, + **args + ) -> None: + self.prompts = prompts + self.clip_length = clip_length + self.guidance_scale = guidance_scale + self.num_inference_steps = num_inference_steps + self.strength = strength + + if sample_seeds is None: + max_num_samples_per_prompt = int(1e5) + if num_samples_per_prompt > max_num_samples_per_prompt: + raise ValueError + sample_seeds = torch.randint(0, max_num_samples_per_prompt, (num_samples_per_prompt,)) + sample_seeds = sorted(sample_seeds.numpy().tolist()) + self.sample_seeds = sample_seeds + + self.logdir = os.path.join(logdir, subdir) + os.makedirs(self.logdir) + + self.annotate = annotate + self.annotate_size = annotate_size + self.make_grid = use_make_grid + self.grid_column_size = grid_column_size + self.prompt2prompt_edit = prompt2prompt_edit + self.p2p_config = p2p_config + self.use_inversion_attention = use_inversion_attention + self.source_prompt = source_prompt + self.traverse_p2p_config =traverse_p2p_config + + def log_sample_images( + self, pipeline: DiffusionPipeline, + device: torch.device, step: int, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + latents: torch.FloatTensor = None, + uncond_embeddings_list: List[torch.FloatTensor] = None, + save_dir = None, + ): + torch.cuda.empty_cache() + samples_all = [] + attention_all = [] + # handle input image + if image is not None: + input_pil_images = pipeline.numpy_to_pil(tensor_to_numpy(image))[0] + samples_all.append([ + annotate_image(image, "input sequence", font_size=self.annotate_size) for image in input_pil_images + ]) + for idx, prompt in enumerate(tqdm(self.prompts, desc="Generating sample images")): + if self.prompt2prompt_edit: + if self.traverse_p2p_config: + p2p_config_now = copy.deepcopy(self.p2p_config[idx]) + else: + p2p_config_now = copy.deepcopy(self.p2p_config[idx]) + + if idx == 0 and not self.use_inversion_attention: + edit_type = 'save' + p2p_config_now.update({'save_self_attention': True}) + print('Reflash the attention 
map in pipeline') + + else: + edit_type = 'swap' + p2p_config_now.update({'save_self_attention': False}) + + p2p_config_now.update({'use_inversion_attention': self.use_inversion_attention}) + else: + edit_type = None + + input_prompt = prompt + for seed in self.sample_seeds: + generator = torch.Generator(device=device) + generator.manual_seed(seed) + sequence_return = pipeline( + prompt=input_prompt, + source_prompt = self.prompts[0] if self.source_prompt is None else self.source_prompt, + edit_type = edit_type, + image=image, # torch.Size([8, 3, 512, 512]) + strength=self.strength, + generator=generator, + num_inference_steps=self.num_inference_steps, + clip_length=self.clip_length, + guidance_scale=self.guidance_scale, + num_images_per_prompt=1, + # used in null inversion + latents = latents, + uncond_embeddings_list = uncond_embeddings_list, + save_path = save_dir, + **p2p_config_now, + ) + if self.prompt2prompt_edit: + sequence = sequence_return['sdimage_output'].images[0] + attention_output = sequence_return['attention_output'] + mask_list = sequence_return.get('mask_list', None) + else: + sequence = sequence_return.images[0] + torch.cuda.empty_cache() + + if self.annotate: + images = [ + annotate_image(image, prompt, font_size=self.annotate_size) for image in sequence + ] + + if self.make_grid: + samples_all.append(images) + if self.prompt2prompt_edit: + if attention_output is not None: + attention_all.append(attention_output) + + save_path = os.path.join(self.logdir, f"step_{step}_{idx}_{seed}.gif") + save_gif_mp4_folder_type(images, save_path) + + if self.prompt2prompt_edit: + if mask_list is not None and len(mask_list) > 0: + save_gif_mp4_folder_type(mask_list, save_path.replace('.gif', 'mask.gif')) + if attention_output is not None: + save_gif_mp4_folder_type(attention_output, save_path.replace('.gif', 'atten.gif')) + + if self.make_grid: + samples_all = [make_grid(images, cols=int(np.ceil(np.sqrt(len(samples_all))))) for images in zip(*samples_all)] + save_path = os.path.join(self.logdir, f"step_{step}.gif") + save_gif_mp4_folder_type(samples_all, save_path) + if self.prompt2prompt_edit: + if len(attention_all) > 0 : + attention_all = [make_grid(images, cols=1) for images in zip(*attention_all)] + if len(attention_all) > 0: + save_gif_mp4_folder_type(attention_all, save_path.replace('.gif', 'atten.gif')) + return samples_all, save_path + + + + +def tensor_to_numpy(image, b=1): + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + + image = image.cpu().float().numpy() + image = rearrange(image, "(b f) c h w -> b f h w c", b=b) + return image diff --git a/FateZero/video_diffusion/pipelines/stable_diffusion.py b/FateZero/video_diffusion/pipelines/stable_diffusion.py new file mode 100755 index 0000000000000000000000000000000000000000..c594aa9b63ae3fb4d635bc582103a187c2598cdd --- /dev/null +++ b/FateZero/video_diffusion/pipelines/stable_diffusion.py @@ -0,0 +1,610 @@ +# code mostly taken from https://github.com/huggingface/diffusers +import inspect +from typing import Callable, List, Optional, Union +import os, sys + +import torch +from einops import rearrange + +from diffusers.utils import is_accelerate_available +from packaging import version +from transformers import CLIPTextModel, CLIPTokenizer + +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import ( + 
DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from diffusers.utils import deprecate, logging +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + +from ..models.unet_3d_condition import UNetPseudo3DConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class SpatioTemporalStableDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Spatio-Temporal Stable Diffusion. + """ + _optional_components = [] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNetPseudo3DConditionModel, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = ( + hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. 
Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + def prepare_before_train_loop(self, params_to_optimize=None): + # Set xformers in train.py + + # self.disable_xformers_memory_efficient_attention() + + self.vae.requires_grad_(False) + self.unet.requires_grad_(False) + self.text_encoder.requires_grad_(False) + + self.vae.eval() + self.unet.eval() + self.text_encoder.eval() + + if params_to_optimize is not None: + params_to_optimize.requires_grad = True + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. 
Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. 
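+            # During sampling the batched prediction is split back into `noise_pred_uncond` and
+            # `noise_pred_text` and combined in the denoising loop as
+            #   noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)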
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def decode_latents(self, latents): + is_video = (latents.dim() == 5) + b = latents.shape[0] + latents = 1 / 0.18215 * latents + + if is_video: + latents = rearrange(latents, "b c f h w -> (b f) c h w") # torch.Size([70, 4, 64, 64]) + + latents_split = torch.split(latents, 16, dim=0) + image = torch.cat([self.vae.decode(l).sample for l in latents_split], dim=0) + + # image_full = self.vae.decode(latents).sample + # RuntimeError: upsample_nearest_nhwc only supports output tensors with less than INT_MAX elements + # Pytorch upsample alogrithm not work for batch size 32 -> 64 + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + + image = image.cpu().float().numpy() + if is_video: + image = rearrange(image, "(b f) c h w -> b f h w c", b=b) + else: + image = rearrange(image, "b c h w -> b h w c", b=b) + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + clip_length, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_channels_latents, + clip_length, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + rand_device = "cpu" if device.type == "mps" else device + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to( + device + ) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + clip_length: int = 8, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. 
+ output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.in_channels + # [1, 4, 8, 64, 64] + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + clip_length, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + latents_dtype = latents.dtype + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + # [2, 4, 8, 64, 64] + noise_pred = self.unet( + latent_model_input, t, encoder_hidden_states=text_embeddings + ).sample.to(dtype=latents_dtype) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 [1, 4, 8, 64, 64] + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + has_nsfw_concept = None + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + torch.cuda.empty_cache() + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + @staticmethod + def numpy_to_pil(images): + # (1, 16, 512, 512, 3) + pil_images = [] + is_video = (len(images.shape)==5) + if is_video: + for sequence in images: + pil_images.append(DiffusionPipeline.numpy_to_pil(sequence)) + else: + pil_images.append(DiffusionPipeline.numpy_to_pil(images)) + return pil_images + + def print_pipeline(self, logger): + print('Overview function of pipeline: ') + print(self.__class__) + + print(self) + + expected_modules, optional_parameters = self._get_signature_keys(self) + components_details = { + k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters + } + import json + logger.info(str(components_details)) + # logger.info(str(json.dumps(components_details, indent = 4))) + # print(str(components_details)) + # print(self._optional_components) + + print(f"python version {sys.version}") + print(f"torch version {torch.__version__}") + print(f"validate gpu status:") + print( torch.tensor(1.0).cuda()*2) + os.system("nvcc --version") + + import diffusers + print(diffusers.__version__) + print(diffusers.__file__) + + try: + import bitsandbytes + print(bitsandbytes.__file__) + except: + print("fail to import bitsandbytes") + # os.system("accelerate env") + # os.system("python -m xformers.info") diff --git a/FateZero/video_diffusion/pipelines/validation_loop.py b/FateZero/video_diffusion/pipelines/validation_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..f9373d17f154752a8767302389a2caa33b9922cd --- /dev/null +++ b/FateZero/video_diffusion/pipelines/validation_loop.py @@ -0,0 +1,154 @@ +import os +import numpy as np +from typing import Callable, List, Optional, Union +import PIL + + +import torch +import torch.utils.data +import torch.utils.checkpoint + +from diffusers.pipeline_utils import DiffusionPipeline +from tqdm.auto import tqdm +from video_diffusion.common.image_util import make_grid, 
annotate_image +from video_diffusion.common.image_util import save_gif_mp4_folder_type + + +class SampleLogger: + def __init__( + self, + prompts: List[str], + clip_length: int, + logdir: str, + subdir: str = "sample", + num_samples_per_prompt: int = 1, + sample_seeds: List[int] = None, + num_inference_steps: int = 20, + guidance_scale: float = 7, + strength: float = None, + annotate: bool = True, + annotate_size: int = 15, + make_grid: bool = True, + grid_column_size: int = 2, + prompt2prompt_edit: bool=False, + **args + + ) -> None: + self.prompts = prompts + self.clip_length = clip_length + self.guidance_scale = guidance_scale + self.num_inference_steps = num_inference_steps + self.strength = strength + + if sample_seeds is None: + max_num_samples_per_prompt = int(1e5) + if num_samples_per_prompt > max_num_samples_per_prompt: + raise ValueError + sample_seeds = torch.randint(0, max_num_samples_per_prompt, (num_samples_per_prompt,)) + sample_seeds = sorted(sample_seeds.numpy().tolist()) + self.sample_seeds = sample_seeds + + self.logdir = os.path.join(logdir, subdir) + os.makedirs(self.logdir) + + self.annotate = annotate + self.annotate_size = annotate_size + self.make_grid = make_grid + self.grid_column_size = grid_column_size + self.prompt2prompt_edit = prompt2prompt_edit + + def log_sample_images( + self, pipeline: DiffusionPipeline, + device: torch.device, step: int, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + latents: torch.FloatTensor = None, + uncond_embeddings_list: List[torch.FloatTensor] = None, + ): + torch.cuda.empty_cache() + samples_all = [] + attention_all = [] + # handle input image + if image is not None: + input_pil_images = pipeline.numpy_to_pil(tensor_to_numpy(image))[0] + samples_all.append([ + annotate_image(image, "input sequence", font_size=self.annotate_size) for image in input_pil_images + ]) + for idx, prompt in enumerate(tqdm(self.prompts, desc="Generating sample images")): + if self.prompt2prompt_edit: + if idx == 0: + edit_type = 'save' + else: + edit_type = 'swap' + else: + edit_type = None + for seed in self.sample_seeds: + generator = torch.Generator(device=device) + generator.manual_seed(seed) + # if isinstance(pipeline, SDeditSpatioTemporalStableDiffusionPipeline): + sequence_return = pipeline( + prompt=prompt, + edit_type = edit_type, + image=image, # torch.Size([8, 3, 512, 512]) + strength=self.strength, + generator=generator, + num_inference_steps=self.num_inference_steps, + clip_length=self.clip_length, + guidance_scale=self.guidance_scale, + num_images_per_prompt=1, + # used in null inversion + latents = latents, + uncond_embeddings_list = uncond_embeddings_list, + # Put the source prompt at the first one, when using p2p + # edit_type = edit_type + ) + if self.prompt2prompt_edit: + sequence = sequence_return['sdimage_output'].images[0] + attention_output = sequence_return['attention_output'] + if ddim_latents_all_step in sequence_return: + ddim_latents_all_step = sequence_return['ddim_latents_all_step'] + else: + sequence = sequence_return.images[0] + torch.cuda.empty_cache() + + if self.annotate: + images = [ + annotate_image(image, prompt, font_size=self.annotate_size) for image in sequence + ] + + if self.make_grid: + samples_all.append(images) + if self.prompt2prompt_edit: + attention_all.append(attention_output) + # else: + save_path = os.path.join(self.logdir, f"step_{step}_{idx}_{seed}.gif") + # save_path_mp4 = save_path.replace('gif', 'mp4') + # save_path_folder = save_path.replace('.gif', '') + # 
save_images_as_gif(images, save_path) + # save_images_as_mp4(images, save_path_mp4) + # save_images_as_folder(images, save_path_folder) + save_gif_mp4_folder_type(images, save_path) + if self.prompt2prompt_edit: + save_gif_mp4_folder_type(attention_output, save_path.replace('.gif', 'atten.gif')) + + if self.make_grid: + samples_all = [make_grid(images, cols=int(np.ceil(np.sqrt(len(samples_all))))) for images in zip(*samples_all)] + save_path = os.path.join(self.logdir, f"step_{step}.gif") + # save_images_as_gif(samples_all, save_path) + save_gif_mp4_folder_type(samples_all, save_path) + if self.prompt2prompt_edit: + attention_all = [make_grid(images, cols=1) for images in zip(*attention_all)] + # save_path = os.path.join(self.logdir, f"step_{step}.gif") + # save_images_as_gif(samples_all, save_path) + save_gif_mp4_folder_type(attention_all, save_path.replace('.gif', 'atten.gif')) + return samples_all + + +from einops import rearrange + +def tensor_to_numpy(image, b=1): + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + + image = image.cpu().float().numpy() + image = rearrange(image, "(b f) c h w -> b f h w c", b=b) + return image \ No newline at end of file diff --git a/FateZero/video_diffusion/prompt_attention/attention_util.py b/FateZero/video_diffusion/prompt_attention/attention_util.py new file mode 100644 index 0000000000000000000000000000000000000000..4a5ff77716dddeafa6c04d63da89d97ffd3826ef --- /dev/null +++ b/FateZero/video_diffusion/prompt_attention/attention_util.py @@ -0,0 +1,1077 @@ +""" +Code for prompt2prompt local editing and attention visualization + +""" + +from typing import Optional, Union, Tuple, List, Dict +import abc +import os +import datetime +import numpy as np +from PIL import Image +import copy +import torchvision.utils as tvu +from einops import rearrange + +import torch +import torch.nn.functional as F + +from video_diffusion.common.util import get_time_string +import video_diffusion.prompt_attention.ptp_utils as ptp_utils +import video_diffusion.prompt_attention.seq_aligner as seq_aligner +from video_diffusion.common.image_util import save_gif_mp4_folder_type +device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + + +class LocalBlend: + """Called in make_controller + self.alpha_layers.shape = torch.Size([2, 1, 1, 1, 1, 77]), 1 denotes the world to be replaced + """ + def get_mask(self, maps, alpha, use_pool, x_t, step_in_store: int=None, prompt_choose='source'): + k = 1 + # ([2, 40, 4, 16, 16, 77]) * ([2, 1, 1, 1, 1, 77]) -> [2, 1, 16, 16] + if maps.dim() == 5: alpha = alpha[:, None, ...] 
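# A minimal, self-contained sketch of the idea behind LocalBlend.get_mask: keep only the
# cross-attention paid to the selected (edited) words, pool and upsample it to the latent
# resolution, then binarise it into a blend mask. Shapes and the 0.3 threshold are
# illustrative assumptions, not FateZero's exact configuration.
import torch
import torch.nn.functional as F

def attention_to_mask(maps: torch.Tensor, word_alpha: torch.Tensor,
                      latent_res: int = 64, threshold: float = 0.3) -> torch.Tensor:
    """maps: (prompts, heads, frames, 16, 16, words); word_alpha: (prompts, 1, 1, 1, 1, words)."""
    # Weight the attention by the word indicator, then average over heads.
    maps = (maps * word_alpha).sum(-1).mean(1)              # (prompts, frames, 16, 16)
    # Dilate slightly so the mask covers the whole object, then upsample to latent size.
    maps = F.max_pool2d(maps, kernel_size=3, stride=1, padding=1)
    mask = F.interpolate(maps, size=(latent_res, latent_res), mode="bilinear", align_corners=False)
    # Normalise each frame to [0, 1] and binarise.
    mask = mask / mask.amax(dim=(-2, -1), keepdim=True).clamp(min=1e-8)
    return mask.gt(threshold)

maps = torch.rand(2, 8, 4, 16, 16, 77)
alpha = torch.zeros(2, 1, 1, 1, 1, 77)
alpha[..., 5] = 1.0                          # pretend token 5 is the edited word
print(attention_to_mask(maps, alpha).shape)  # torch.Size([2, 4, 64, 64])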
+ maps = (maps * alpha).sum(-1).mean(1) + if use_pool: + maps = F.max_pool2d(maps, (k * 2 + 1, k * 2 +1), (1, 1), padding=(k, k)) + mask = F.interpolate(maps, size=(x_t.shape[-2:])) + mask = mask / mask.max(-2, keepdims=True)[0].max(-1, keepdims=True)[0] + mask = mask.gt(self.th[1-int(use_pool)]) + mask = mask[:1] + mask + if self.save_path is not None: + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + + save_path = f'{self.save_path}/{prompt_choose}/' + if step_in_store is not None: + save_path += f'step_in_store_{step_in_store:04d}' + # f'{self.save_path}/step_in_store_{step_in_store:04d}/mask_{now}_{self.count:02d}.png' + save_path +=f'/mask_{now}_{self.count:02d}.png' + os.makedirs(os.path.dirname(save_path), exist_ok=True) + tvu.save_image(rearrange(mask[1:].float(), "c p h w -> p c h w"), save_path,normalize=True) + self.count +=1 + return mask + + def __call__(self, x_t, attention_store): + """_summary_ + + Args: + x_t (_type_): [1,4,8,64,64] # (prompt, channel, clip_length, res, res) + attention_store (_type_): _description_ + + Returns: + _type_: _description_ + """ + self.counter += 1 + if (self.counter > self.start_blend) and (self.counter < self.end_blend): + + maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3] + if maps[0].dim() == 4: + (ph, c, r, w)= maps[0].shape + assert r == 16*16 + # a list of len(5), elements has shape [16, 256, 77] + maps = [rearrange(item, "(p h) c (res_h res_w) w -> p h c res_h res_w w ", + p=self.alpha_layers.shape[0], res_h=16, res_w=16) for item in maps] + maps = torch.cat(maps, dim=1) + mask = self.get_mask(maps, self.alpha_layers, True, x_t) + if self.substruct_layers is not None: + maps_sub = ~self.get_mask(maps, self.substruct_layers, False) + mask = mask * maps_sub + mask = mask.float() + # only for debug + # mask = torch.zeros_like(mask) + # "mask is one: use geenerated information" + # "mask is zero: use geenerated information" + self.mask_list.append(mask[0][:, None, :, :].float().cpu().detach()) + if x_t.dim()==5: + mask = mask[:, None, ...] 
+ # x_t [2,4,2,64,64] + x_t = x_t[:1] + mask * (x_t - x_t[:1]) + else: + (ph, r, w)= maps[0].shape + # a list of len(5), elements has shape [16, 256, 77] + + maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, self.MAX_NUM_WORDS) for item in maps] + maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, self.MAX_NUM_WORDS) for item in maps] + maps = torch.cat(maps, dim=1) + mask = self.get_mask(maps, self.alpha_layers, True, x_t) + if self.substruct_layers is not None: + maps_sub = ~self.get_mask(maps, self.substruct_layers, False) + mask = mask * maps_sub + mask = mask.float() + x_t = x_t[:1] + mask * (x_t - x_t[:1]) + + return x_t + + def __init__(self, prompts: List[str], words: [List[List[str]]], substruct_words=None, + start_blend=0.2, end_blend=0.8, + th=(0.9, 0.9), tokenizer=None, NUM_DDIM_STEPS =None, + save_path =None): + self.count = 0 + self.MAX_NUM_WORDS = 77 + self.NUM_DDIM_STEPS = NUM_DDIM_STEPS + if save_path is not None: + self.save_path = save_path+'/latents_mask' + os.makedirs(self.save_path, exist_ok='True') + else: + self.save_path = None + alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, self.MAX_NUM_WORDS) + for i, (prompt, words_) in enumerate(zip(prompts, words)): + if type(words_) is str: + words_ = [words_] + for word in words_: + # debug me + ind = ptp_utils.get_word_inds(prompt, word, tokenizer) + alpha_layers[i, :, :, :, :, ind] = 1 + + if substruct_words is not None: + substruct_layers = torch.zeros(len(prompts), 1, 1, 1, 1, self.MAX_NUM_WORDS) + for i, (prompt, words_) in enumerate(zip(prompts, substruct_words)): + if type(words_) is str: + words_ = [words_] + for word in words_: + ind = ptp_utils.get_word_inds(prompt, word, tokenizer) + substruct_layers[i, :, :, :, :, ind] = 1 + self.substruct_layers = substruct_layers.to(device) + else: + self.substruct_layers = None + + self.alpha_layers = alpha_layers.to(device) + self.start_blend = int(start_blend * self.NUM_DDIM_STEPS) + self.end_blend = int(end_blend * self.NUM_DDIM_STEPS) + self.counter = 0 + self.th=th + self.mask_list = [] + + + +class MaskBlend: + """ + First, we consider only source prompt + Called in make_controller + self.alpha_layers.shape = torch.Size([2, 1, 1, 1, 1, 77]), 1 denotes the world to be replaced + """ + def get_mask(self, maps, alpha, use_pool, h=None, w=None, step_in_store: int=None, prompt_choose='source'): + """ + # ([1, 40, 2, 16, 16, 77]) * ([1, 1, 1, 1, 1, 77]) -> [2, 1, 16, 16] + mask have dimension of [clip_length, dim, res, res] + """ + k = 1 + + if maps.dim() == 5: alpha = alpha[:, None, ...] 
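# A minimal sketch of how the per-word indicator tensor (alpha_layers) is built in
# LocalBlend.__init__ / MaskBlend.__init__: for each prompt, the token positions of the
# blend words are set to 1 and everything else stays 0. The whitespace "tokenizer" below
# is a stand-in so the example runs without CLIP; the real code uses
# ptp_utils.get_word_inds with the CLIP tokenizer, where a word may span several tokens.
import torch
from typing import List

MAX_NUM_WORDS = 77

def word_indices(prompt: str, word: str) -> List[int]:
    # Stand-in for get_word_inds: position i+1 accounts for the BOS token at index 0.
    return [i + 1 for i, w in enumerate(prompt.split(" ")) if w == word]

def build_alpha_layers(prompts: List[str], blend_words: List[List[str]]) -> torch.Tensor:
    alpha = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS)
    for i, (prompt, words) in enumerate(zip(prompts, blend_words)):
        for word in ([words] if isinstance(words, str) else words):
            alpha[i, :, :, :, :, word_indices(prompt, word)] = 1.0
    return alpha

prompts = ["a silver jeep driving down a curvy road",
           "a Porsche car driving down a curvy road"]
alpha_layers = build_alpha_layers(prompts, [["jeep"], ["Porsche", "car"]])
print(alpha_layers.shape, alpha_layers[1, 0, 0, 0, 0, :6])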
+ maps = (maps * alpha).sum(-1).mean(1) + if use_pool: + maps = F.max_pool2d(maps, (k * 2 + 1, k * 2 +1), (1, 1), padding=(k, k)) + mask = F.interpolate(maps, size=(h, w)) + mask = mask / mask.max(-2, keepdims=True)[0].max(-1, keepdims=True)[0] + mask = mask.gt(self.th[1-int(use_pool)]) + if self.save_path is not None: + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + + save_path = f'{self.save_path}/{prompt_choose}/' + if step_in_store is not None: + save_path += f'step_in_store_{step_in_store:04d}' + save_path +=f'/mask_{now}_{self.count:02d}.png' + os.makedirs(os.path.dirname(save_path), exist_ok=True) + tvu.save_image(rearrange(mask.float(), "c p h w -> p c h w"), save_path,normalize=True) + self.count +=1 + return mask + + def __call__(self, target_h, target_w, attention_store, step_in_store: int=None): + """ + input has shape (heads) clip res words + one meens using target self-attention, zero is using source + Previous implementation us all zeros + mask should be repeat. + + Args: + x_t (_type_): [1,4,8,64,64] # (prompt, channel, clip_length, res, res) + attention_store (_type_): _description_ + + Returns: + _type_: _description_ + """ + + maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3] + + # maps = attention_store # [2,8,1024, 77] = [frames, head, (res, res), word_embedding] + assert maps[0].dim() == 4, "only support temporal data" + ( c, heads, r, w)= maps[0].shape + res_h = int(np.sqrt(r)) + assert r == res_h* res_h + # a list of len(5), elements has shape [16, 256, 77] + target_device = self.alpha_layers.device + target_dtype = self.alpha_layers.dtype + maps = [rearrange(item, " c h (res_h res_w) w -> h c res_h res_w w ", + h=heads, res_h=res_h, res_w=res_h)[None, ...].to(target_device, dtype=target_dtype) + for item in maps] + + + maps = torch.cat(maps, dim=1) + # We only support self-attention blending using source prompt + masked_alpah_layers = self.alpha_layers[0:1] + mask = self.get_mask(maps, masked_alpah_layers, True, target_h, target_w, step_in_store=step_in_store, prompt_choose='source') + + if self.substruct_layers is not None: + maps_sub = ~self.get_mask(maps, self.substruct_layers, False) + mask = mask * maps_sub + mask = mask.float() + + # "mask is one: use geenerated information" + # "mask is zero: use geenerated information" + self.mask_list.append(mask[0][:, None, :, :].float().cpu().detach()) + + return mask + + def __init__(self, prompts: List[str], words: [List[List[str]]], substruct_words=None, + start_blend=0.2, end_blend=0.8, + th=(0.9, 0.9), tokenizer=None, NUM_DDIM_STEPS =None, + save_path = None): + self.count = 0 + # self.config_dict = copy.deepcopy(config_dict) + self.MAX_NUM_WORDS = 77 + self.NUM_DDIM_STEPS = NUM_DDIM_STEPS + if save_path is not None: + self.save_path = save_path+'/blend_mask' + os.makedirs(self.save_path, exist_ok='True') + else: + self.save_path = None + alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, self.MAX_NUM_WORDS) + for i, (prompt, words_) in enumerate(zip(prompts, words)): + if type(words_) is str: + words_ = [words_] + for word in words_: + # debug me + ind = ptp_utils.get_word_inds(prompt, word, tokenizer) + alpha_layers[i, :, :, :, :, ind] = 1 + + if substruct_words is not None: + substruct_layers = torch.zeros(len(prompts), 1, 1, 1, 1, self.MAX_NUM_WORDS) + for i, (prompt, words_) in enumerate(zip(prompts, substruct_words)): + if type(words_) is str: + words_ = [words_] + for word in words_: + ind = ptp_utils.get_word_inds(prompt, word, tokenizer) + substruct_layers[i, :, :, 
:, :, ind] = 1 + self.substruct_layers = substruct_layers.to(device) + else: + self.substruct_layers = None + + self.alpha_layers = alpha_layers.to(device) + print('the index mask of edited word in the prompt') + print(self.alpha_layers[0][..., 0:(len(prompts[0].split(" "))+2)]) + print(self.alpha_layers[1][..., 0:(len(prompts[1].split(" "))+2)]) + + self.start_blend = int(start_blend * self.NUM_DDIM_STEPS) + self.end_blend = int(end_blend * self.NUM_DDIM_STEPS) + self.counter = 0 + self.th=th + self.mask_list = [] + + + + +class EmptyControl: + + + def step_callback(self, x_t): + return x_t + + def between_steps(self): + return + + def __call__(self, attn, is_cross: bool, place_in_unet: str): + return attn + + +class AttentionControl(abc.ABC): + + def step_callback(self, x_t): + self.cur_att_layer = 0 + self.cur_step += 1 + self.between_steps() + return x_t + + def between_steps(self): + return + + @property + def num_uncond_att_layers(self): + """I guess the diffusion of google has some unconditional attention layer + No unconditional attention layer in Stable diffusion + + Returns: + _type_: _description_ + """ + # return self.num_att_layers if config_dict['LOW_RESOURCE'] else 0 + return 0 + + @abc.abstractmethod + def forward (self, attn, is_cross: bool, place_in_unet: str): + raise NotImplementedError + + def __call__(self, attn, is_cross: bool, place_in_unet: str): + if self.cur_att_layer >= self.num_uncond_att_layers: + if self.LOW_RESOURCE: + # For inversion without null text file + attn = self.forward(attn, is_cross, place_in_unet) + else: + # For classifier-free guidance scale!=1 + h = attn.shape[0] + attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) + self.cur_att_layer += 1 + + return attn + + def reset(self): + self.cur_step = 0 + self.cur_att_layer = 0 + + def __init__(self, + ): + self.LOW_RESOURCE = False # assume the edit have cfg + self.cur_step = 0 + self.num_att_layers = -1 + self.cur_att_layer = 0 + +class SpatialReplace(EmptyControl): + + def step_callback(self, x_t): + if self.cur_step < self.stop_inject: + b = x_t.shape[0] + x_t = x_t[:1].expand(b, *x_t.shape[1:]) + return x_t + + def __init__(self, stop_inject: float, NUM_DDIM_STEPS=None): + super(SpatialReplace, self).__init__() + self.stop_inject = int((1 - stop_inject) * NUM_DDIM_STEPS) + + +class AttentionStore(AttentionControl): + def step_callback(self, x_t): + + + x_t = super().step_callback(x_t) + self.latents_store.append(x_t.cpu().detach()) + return x_t + + @staticmethod + def get_empty_store(): + return {"down_cross": [], "mid_cross": [], "up_cross": [], + "down_self": [], "mid_self": [], "up_self": []} + + @staticmethod + def get_empty_cross_store(): + return {"down_cross": [], "mid_cross": [], "up_cross": [], + } + + def forward(self, attn, is_cross: bool, place_in_unet: str): + key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" + if attn.shape[-2] <= 32 ** 2: # avoid memory overhead + # print(f"Store attention map {key} of shape {attn.shape}") + if is_cross or self.save_self_attention: + if attn.shape[-2] == 32**2: + append_tensor = attn.cpu().detach() + else: + append_tensor = attn + self.step_store[key].append(copy.deepcopy(append_tensor)) + return attn + + def between_steps(self): + if len(self.attention_store) == 0: + self.attention_store = self.step_store + else: + for key in self.attention_store: + for i in range(len(self.attention_store[key])): + self.attention_store[key][i] += self.step_store[key][i] + + if self.disk_store: + path = self.store_dir + 
f'/{self.cur_step:03d}.pt' + torch.save(copy.deepcopy(self.step_store), path) + self.attention_store_all_step.append(path) + else: + self.attention_store_all_step.append(copy.deepcopy(self.step_store)) + self.step_store = self.get_empty_store() + + def get_average_attention(self): + "divide the attention map value in attention store by denoising steps" + average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store} + return average_attention + + + def reset(self): + super(AttentionStore, self).reset() + self.step_store = self.get_empty_store() + self.attention_store_all_step = [] + self.attention_store = {} + + def __init__(self, save_self_attention:bool=True, disk_store=False): + super(AttentionStore, self).__init__() + self.disk_store = disk_store + if self.disk_store: + time_string = get_time_string() + path = f'./trash/attention_cache_{time_string}' + os.makedirs(path, exist_ok=True) + self.store_dir = path + else: + self.store_dir =None + self.step_store = self.get_empty_store() + self.attention_store = {} + self.save_self_attention = save_self_attention + self.latents_store = [] + self.attention_store_all_step = [] + + +class AttentionControlEdit(AttentionStore, abc.ABC): + """Decide self or cross-attention. Call the reweighting cross attention module + + Args: + AttentionStore (_type_): ([1, 4, 8, 64, 64]) + abc (_type_): [8, 8, 1024, 77] + """ + + def step_callback(self, x_t): + x_t = super().step_callback(x_t) + x_t_device = x_t.device + x_t_dtype = x_t.dtype + if self.local_blend is not None: + if self.use_inversion_attention: + step_in_store = len(self.additional_attention_store.latents_store) - self.cur_step + else: + step_in_store = self.cur_step + + inverted_latents = self.additional_attention_store.latents_store[step_in_store] + inverted_latents = inverted_latents.to(device =x_t_device, dtype=x_t_dtype) + # [prompt, channel, clip, res, res] = [1, 4, 2, 64, 64] + + blend_dict = self.get_empty_cross_store() + # each element in blend_dict have (prompt head) clip_length (res res) words, + # to better align with (b c f h w) + + attention_store_step = self.additional_attention_store.attention_store_all_step[step_in_store] + if isinstance(place_in_unet_cross_atten_list, str): attention_store_step = torch.load(attention_store_step) + + for key in blend_dict.keys(): + place_in_unet_cross_atten_list = attention_store_step[key] + for i, attention in enumerate(place_in_unet_cross_atten_list): + + concate_attention = torch.cat([attention[None, ...], self.attention_store[key][i][None, ...]], dim=0) + blend_dict[key].append(copy.deepcopy(rearrange(concate_attention, ' p c h res words -> (p h) c res words'))) + x_t = self.local_blend(copy.deepcopy(torch.cat([inverted_latents, x_t], dim=0)), copy.deepcopy(blend_dict)) + return x_t[1:, ...] 
+ else: + return x_t + + def replace_self_attention(self, attn_base, att_replace, reshaped_mask=None): + if att_replace.shape[-2] <= 32 ** 2: + target_device = att_replace.device + target_dtype = att_replace.dtype + attn_base = attn_base.to(target_device, dtype=target_dtype) + attn_base = attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) + if reshaped_mask is not None: + return_attention = reshaped_mask*att_replace + (1-reshaped_mask)*attn_base + return return_attention + else: + return attn_base + else: + return att_replace + + @abc.abstractmethod + def replace_cross_attention(self, attn_base, att_replace): + raise NotImplementedError + + def update_attention_position_dict(self, current_attention_key): + self.attention_position_counter_dict[current_attention_key] +=1 + + + def forward(self, attn, is_cross: bool, place_in_unet: str): + super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) + if attn.shape[-2] <= 32 ** 2: + key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" + current_pos = self.attention_position_counter_dict[key] + + if self.use_inversion_attention: + step_in_store = len(self.additional_attention_store.attention_store_all_step) - self.cur_step -1 + else: + step_in_store = self.cur_step + + place_in_unet_cross_atten_list = self.additional_attention_store.attention_store_all_step[step_in_store] + if isinstance(place_in_unet_cross_atten_list, str): place_in_unet_cross_atten_list = torch.load(place_in_unet_cross_atten_list) + # breakpoint() + # Note that attn is append to step_store, + # if attn is get through clean -> noisy, we should inverse it + attn_base = place_in_unet_cross_atten_list[key][current_pos] + + self.update_attention_position_dict(key) + # save in format of [temporal, head, resolution, text_embedding] + if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]): + clip_length = attn.shape[0] // (self.batch_size) + attn = attn.reshape(self.batch_size, clip_length, *attn.shape[1:]) + # Replace att_replace with attn_base + attn_base, attn_repalce = attn_base, attn[0:] + if is_cross: + alpha_words = self.cross_replace_alpha[self.cur_step] + attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + (1 - alpha_words) * attn_repalce + attn[0:] = attn_repalce_new # b t h p n = [1, 1, 8, 1024, 77] + else: + + # start of masked self-attention + if self.MB is not None and attn_repalce.shape[-2] <= 32 ** 2: + # ca_this_step = place_in_unet_cross_atten_list + # query 1024, key 2048 + h = int(np.sqrt(attn_repalce.shape[-2])) + w = h + mask = self.MB(target_h = h, target_w =w, attention_store= place_in_unet_cross_atten_list, step_in_store=step_in_store) + # reshape from ([ 1, 2, 32, 32]) -> [2, 1, 1024, 1] + reshaped_mask = rearrange(mask, "d c h w -> c d (h w)")[..., None] + + # input has shape (h) c res words + # one meens using target self-attention, zero is using source + # Previous implementation us all zeros + # mask should be repeat. 
+ else: + reshaped_mask = None + attn[0:] = self.replace_self_attention(attn_base, attn_repalce, reshaped_mask) + + + + attn = attn.reshape(self.batch_size * clip_length, *attn.shape[2:]) + # save in format of [temporal, head, resolution, text_embedding] + + return attn + def between_steps(self): + + super().between_steps() + self.step_store = self.get_empty_store() + + self.attention_position_counter_dict = { + 'down_cross': 0, + 'mid_cross': 0, + 'up_cross': 0, + 'down_self': 0, + 'mid_self': 0, + 'up_self': 0, + } + return + def __init__(self, prompts, num_steps: int, + cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]], + self_replace_steps: Union[float, Tuple[float, float]], + local_blend: Optional[LocalBlend], tokenizer=None, + additional_attention_store: AttentionStore =None, + use_inversion_attention: bool=False, + MB: MaskBlend= None, + save_self_attention: bool=True, + disk_store=False + ): + super(AttentionControlEdit, self).__init__( + save_self_attention=save_self_attention, + disk_store=disk_store) + self.additional_attention_store = additional_attention_store + self.batch_size = len(prompts) + self.MB = MB + if self.additional_attention_store is not None: + # the attention_store is provided outside, only pass in one promp + self.batch_size = len(prompts) //2 + assert self.batch_size==1, 'Only support single video editing with additional attention_store' + + self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, tokenizer).to(device) + if type(self_replace_steps) is float: + self_replace_steps = 0, self_replace_steps + self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1]) + self.local_blend = local_blend + # We need to know the current position in attention + self.prev_attention_key_name = 0 + self.use_inversion_attention = use_inversion_attention + self.attention_position_counter_dict = { + 'down_cross': 0, + 'mid_cross': 0, + 'up_cross': 0, + 'down_self': 0, + 'mid_self': 0, + 'up_self': 0, + } + +class AttentionReplace(AttentionControlEdit): + + def replace_cross_attention(self, attn_base, att_replace): + # torch.Size([8, 4096, 77]), torch.Size([1, 77, 77]) -> [1, 8, 4096, 77] + # Can be extend to temporal, use temporal as batch size + target_device = att_replace.device + target_dtype = att_replace.dtype + attn_base = attn_base.to(target_device, dtype=target_dtype) + + if attn_base.dim()==3: + return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper) + elif attn_base.dim()==4: + return torch.einsum('thpw,bwn->bthpn', attn_base, self.mapper) + + def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, + local_blend: Optional[LocalBlend] = None, tokenizer=None, + additional_attention_store=None, + use_inversion_attention = False, + MB: MaskBlend=None, + save_self_attention: bool = True, + disk_store=False): + super(AttentionReplace, self).__init__( + prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer=tokenizer, + additional_attention_store=additional_attention_store, use_inversion_attention = use_inversion_attention, + MB=MB, + save_self_attention = save_self_attention, + disk_store=disk_store + ) + self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device) + +class AttentionRefine(AttentionControlEdit): + + def replace_cross_attention(self, attn_base, att_replace): + + target_device = att_replace.device + target_dtype = att_replace.dtype + 
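# A minimal sketch of prompt-to-prompt cross-attention replacement: the source prompt's
# cross-attention maps are re-indexed onto the target prompt's tokens through a
# (source_tokens x target_tokens) mapper matrix, which is the einsum pattern used by
# AttentionReplace.replace_cross_attention. The mapper here is a toy permutation, not
# one produced by seq_aligner.
import torch

heads, pixels, words = 8, 1024, 77
attn_base = torch.softmax(torch.randn(heads, pixels, words), dim=-1)  # source-prompt attention

# Toy mapper: identity, except target token 3 reuses the attention of source token 2.
mapper = torch.eye(words).unsqueeze(0)       # (batch=1, words, words)
mapper[0, :, 3] = 0.0
mapper[0, 2, 3] = 1.0

attn_replaced = torch.einsum('hpw,bwn->bhpn', attn_base, mapper)
print(attn_replaced.shape)                   # torch.Size([1, 8, 1024, 77])
assert torch.allclose(attn_replaced[0, :, :, 3], attn_base[:, :, 2])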
attn_base = attn_base.to(target_device, dtype=target_dtype) + if attn_base.dim()==3: + attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3) + elif attn_base.dim()==4: + attn_base_replace = attn_base[:, :, :, self.mapper].permute(3, 0, 1, 2, 4) + attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas) + return attn_replace + + def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, + local_blend: Optional[LocalBlend] = None, tokenizer=None, + additional_attention_store=None, + use_inversion_attention = False, + MB: MaskBlend=None, + save_self_attention : bool=True, + disk_store = False + ): + super(AttentionRefine, self).__init__( + prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer=tokenizer, + additional_attention_store=additional_attention_store, use_inversion_attention = use_inversion_attention, + MB=MB, + save_self_attention = save_self_attention, + disk_store = disk_store + ) + self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer) + self.mapper, alphas = self.mapper.to(device), alphas.to(device) + self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1]) + + +class AttentionReweight(AttentionControlEdit): + """First replace the weight, than increase the attention at a area + + Args: + AttentionControlEdit (_type_): _description_ + """ + + def replace_cross_attention(self, attn_base, att_replace): + if self.prev_controller is not None: + attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace) + attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :] + return attn_replace + + def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer, + local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None, tokenizer=None, + additional_attention_store=None, + use_inversion_attention = False, + MB: MaskBlend=None, + save_self_attention:bool = True, + disk_store = False + ): + super(AttentionReweight, self).__init__( + prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer=tokenizer, + additional_attention_store=additional_attention_store, + use_inversion_attention = use_inversion_attention, + MB=MB, + save_self_attention=save_self_attention, + disk_store = disk_store + ) + self.equalizer = equalizer.to(device) + self.prev_controller = controller + +def get_equalizer(text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], + Tuple[float, ...]], tokenizer=None): + if type(word_select) is int or type(word_select) is str: + word_select = (word_select,) + equalizer = torch.ones(1, 77) + + for word, val in zip(word_select, values): + inds = ptp_utils.get_word_inds(text, word, tokenizer) + equalizer[:, inds] = val + return equalizer + +def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int): + out = [] + attention_maps = attention_store.get_average_attention() + num_pixels = res ** 2 + for location in from_where: + for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: + if item.dim() == 3: + if item.shape[1] == num_pixels: + cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select] + out.append(cross_maps) + elif item.dim() == 4: + t, h, res_sq, token = item.shape + if item.shape[2] == num_pixels: + cross_maps = item.reshape(len(prompts), t, -1, res, res, 
item.shape[-1])[select] + out.append(cross_maps) + + out = torch.cat(out, dim=-4) + out = out.sum(-4) / out.shape[-4] + return out.cpu() + + +def make_controller(tokenizer, prompts: List[str], is_replace_controller: bool, + cross_replace_steps: Dict[str, float], self_replace_steps: float=0.0, + blend_words=None, equilizer_params=None, + additional_attention_store=None, use_inversion_attention = False, bend_th: float=(0.3, 0.3), + NUM_DDIM_STEPS=None, + masked_latents = False, + masked_self_attention=False, + save_path = None, + save_self_attention = True, + disk_store = False + ) -> AttentionControlEdit: + if (blend_words is None) or (blend_words == 'None'): + lb = None + MB =None + else: + if masked_latents: + lb = LocalBlend( prompts, blend_words, tokenizer=tokenizer, th=bend_th, NUM_DDIM_STEPS=NUM_DDIM_STEPS, + save_path=save_path) + else: + lb = None + if masked_self_attention: + MB = MaskBlend( prompts, blend_words, tokenizer=tokenizer, th=bend_th, NUM_DDIM_STEPS=NUM_DDIM_STEPS, + save_path=save_path) + print(f'Control self attention mask with threshold {bend_th}') + else: + MB = None + if is_replace_controller: + print('use replace controller') + controller = AttentionReplace(prompts, NUM_DDIM_STEPS, + cross_replace_steps=cross_replace_steps, self_replace_steps=self_replace_steps, + local_blend=lb, tokenizer=tokenizer, + additional_attention_store=additional_attention_store, + use_inversion_attention = use_inversion_attention, + MB=MB, + save_self_attention = save_self_attention, + disk_store=disk_store + ) + else: + print('use refine controller') + controller = AttentionRefine(prompts, NUM_DDIM_STEPS, + cross_replace_steps=cross_replace_steps, self_replace_steps=self_replace_steps, + local_blend=lb, tokenizer=tokenizer, + additional_attention_store=additional_attention_store, + use_inversion_attention = use_inversion_attention, + MB=MB, + save_self_attention = save_self_attention, + disk_store=disk_store + ) + if equilizer_params is not None: + eq = get_equalizer(prompts[1], equilizer_params["words"], equilizer_params["values"], tokenizer=tokenizer) + controller = AttentionReweight(prompts, NUM_DDIM_STEPS, + cross_replace_steps=cross_replace_steps, self_replace_steps=self_replace_steps, + equalizer=eq, local_blend=lb, controller=controller, + tokenizer=tokenizer, + additional_attention_store=additional_attention_store, + use_inversion_attention = use_inversion_attention, + MB=MB, + save_self_attention = save_self_attention, + disk_store=disk_store + ) + return controller + + +def show_cross_attention(tokenizer, prompts, attention_store: AttentionStore, + res: int, from_where: List[str], select: int = 0, save_path = None): + """_summary_ + + tokenizer (_type_): _description_ + prompts (_type_): _description_ + attention_store (AttentionStore): _description_ + ["down", "mid", "up"] X ["self", "cross"] + 4, 1, 6 + head*res*text_token_len = 8*res*77 + res=1024 -> 64 -> 1024 + res (int): res + from_where (List[str]): "up", "down' + select (int, optional): _description_. Defaults to 0. + """ + if isinstance(prompts, str): + prompts = [prompts,] + tokens = tokenizer.encode(prompts[select]) # list of length 9, [0-49 K] + decoder = tokenizer.decode + # 16, 16, 7, 7 + attention_maps = aggregate_attention(prompts, attention_store, res, from_where, True, select) + os.makedirs('trash', exist_ok=True) + attention_list = [] + if attention_maps.dim()==3: attention_maps=attention_maps[None, ...] 
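# A minimal sketch of per-token cross-attention visualisation along the lines of
# show_cross_attention: one heat-map per text token, normalised to [0, 255] and upsampled
# for inspection. The random map and the fake token list are placeholders; in FateZero
# the maps come from aggregate_attention and the labels from the tokenizer.
import numpy as np
from PIL import Image

def token_heatmaps(attention: np.ndarray, tokens: list, size: int = 256) -> list:
    """attention: (res, res, n_tokens) averaged cross-attention."""
    images = []
    for i, tok in enumerate(tokens):
        amap = attention[:, :, i]
        amap = 255.0 * amap / max(amap.max(), 1e-8)
        img = Image.fromarray(amap.astype(np.uint8)).convert("RGB").resize((size, size))
        images.append((tok, img))
    return images

attention = np.random.rand(16, 16, 4)
for tok, img in token_heatmaps(attention, ["<bos>", "a", "jeep", "<eos>"]):
    print(tok, img.size)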
+ for j in range(attention_maps.shape[0]): + images = [] + for i in range(len(tokens)): + image = attention_maps[j, :, :, i] + image = 255 * image / image.max() + image = image.unsqueeze(-1).expand(*image.shape, 3) + image = image.numpy().astype(np.uint8) + image = np.array(Image.fromarray(image).resize((256, 256))) + image = ptp_utils.text_under_image(image, decoder(int(tokens[i]))) + images.append(image) + ptp_utils.view_images(np.stack(images, axis=0), save_path=save_path) + atten_j = np.concatenate(images, axis=1) + attention_list.append(atten_j) + if save_path is not None: + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + video_save_path = f'{save_path}/{now}.gif' + save_gif_mp4_folder_type(attention_list, video_save_path) + return attention_list + + +def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str], + max_com=10, select: int = 0): + attention_maps = aggregate_attention(attention_store, res, from_where, False, select).numpy().reshape((res ** 2, res ** 2)) + u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True)) + images = [] + for i in range(max_com): + image = vh[i].reshape(res, res) + image = image - image.min() + image = 255 * image / image.max() + image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8) + image = Image.fromarray(image).resize((256, 256)) + image = np.array(image) + images.append(image) + ptp_utils.view_images(np.concatenate(images, axis=1)) + + +def register_attention_control(model, controller): + "Connect a model with a controller" + def ca_forward(self, place_in_unet, attention_type='cross'): + to_out = self.to_out + if type(to_out) is torch.nn.modules.container.ModuleList: + to_out = self.to_out[0] + else: + to_out = self.to_out + + def _attention( query, key, value, is_cross, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # KEY FUNCTION: + # Record and edit the attention probs + attention_probs_th = reshape_batch_dim_to_temporal_heads(attention_probs) + attention_probs = controller(reshape_batch_dim_to_temporal_heads(attention_probs), + is_cross, place_in_unet) + attention_probs = reshape_temporal_heads_to_batch_dim(attention_probs_th) + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def reshape_temporal_heads_to_batch_dim( tensor): + head_size = self.heads + tensor = rearrange(tensor, " b h s t -> (b h) s t ", h = head_size) + return tensor + + def reshape_batch_dim_to_temporal_heads(tensor): + head_size = self.heads + tensor = rearrange(tensor, "(b h) s t -> b h s t", h = head_size) + return tensor + + def forward(hidden_states, encoder_hidden_states=None, attention_mask=None): + # hidden_states: torch.Size([16, 4096, 320]) + # encoder_hidden_states: torch.Size([16, 77, 768]) + is_cross = encoder_hidden_states is not None + + 
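# A minimal sketch of the controller-hook pattern that register_attention_control installs:
# attention probabilities are computed as usual, handed to a controller that may record or
# edit them, and only then multiplied with the values. StoreController is a toy stand-in
# for AttentionStore / AttentionControlEdit, and the shapes are illustrative.
import torch

class StoreController:
    def __init__(self):
        self.stored = []
    def __call__(self, probs: torch.Tensor, is_cross: bool, place_in_unet: str) -> torch.Tensor:
        self.stored.append((place_in_unet, is_cross, probs.detach().cpu()))
        return probs                      # an editing controller would modify probs here

def hooked_attention(q, k, v, controller, is_cross: bool, place_in_unet: str):
    scale = q.shape[-1] ** -0.5
    probs = torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1)   # (heads, q_len, k_len)
    probs = controller(probs, is_cross, place_in_unet)
    return probs @ v

controller = StoreController()
q = torch.randn(8, 1024, 40); k = torch.randn(8, 77, 40); v = torch.randn(8, 77, 40)
out = hooked_attention(q, k, v, controller, is_cross=True, place_in_unet="down")
print(out.shape, len(controller.stored))   # torch.Size([8, 1024, 40]) 1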
encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + + if self._use_memory_efficient_attention_xformers and query.shape[-2] > 32 ** 2: + # for large attention map of 64X64, use xformers to save memory + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + + hidden_states = _attention(query, key, value, is_cross=is_cross, attention_mask=attention_mask) + # else: + # hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = to_out(hidden_states) + + # dropout + # hidden_states = self.to_out[1](hidden_states) + return hidden_states + + + def scforward( + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + clip_length: int = None, + SparseCausalAttention_index: list = [-1, 'first'] + ): + if ( + self.added_kv_proj_dim is not None + or encoder_hidden_states is not None + or attention_mask is not None + ): + raise NotImplementedError + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + query = self.reshape_heads_to_batch_dim(query) + + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + if clip_length is not None: + key = rearrange(key, "(b f) d c -> b f d c", f=clip_length) + value = rearrange(value, "(b f) d c -> b f d c", f=clip_length) + + + # ***********************Start of SparseCausalAttention_index********** + frame_index_list = [] + # print(f'SparseCausalAttention_index {str(SparseCausalAttention_index)}') + if len(SparseCausalAttention_index) > 0: + for index in SparseCausalAttention_index: + if isinstance(index, str): + if index == 'first': + frame_index = [0] * clip_length + if index == 'last': + frame_index = [clip_length-1] * clip_length + if (index == 'mid') or (index == 'middle'): + frame_index = [int((clip_length-1)//2)] * clip_length + else: + assert 
isinstance(index, int), 'relative index must be int' + frame_index = torch.arange(clip_length) + index + frame_index = frame_index.clip(0, clip_length-1) + + frame_index_list.append(frame_index) + key = torch.cat([ key[:, frame_index] for frame_index in frame_index_list + ], dim=2) + value = torch.cat([ value[:, frame_index] for frame_index in frame_index_list + ], dim=2) + + + # ***********************End of SparseCausalAttention_index********** + key = rearrange(key, "b f d c -> (b f) d c", f=clip_length) + value = rearrange(value, "b f d c -> (b f) d c", f=clip_length) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if self._use_memory_efficient_attention_xformers and query.shape[-2] > 32 ** 2: + # for large attention map of 64X64, use xformers to save memory + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + # if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = _attention(query, key, value, attention_mask=attention_mask, is_cross=False) + # else: + # hidden_states = self._sliced_attention( + # query, key, value, hidden_states.shape[1], dim, attention_mask + # ) + + # linear proj + hidden_states = to_out(hidden_states) + + # dropout + # hidden_states = self.to_out[1](hidden_states) + return hidden_states + if attention_type == 'CrossAttention': + return forward + elif attention_type == "SparseCausalAttention": + return scforward + + class DummyController: + + def __call__(self, *args): + return args[0] + + def __init__(self): + self.num_att_layers = 0 + + if controller is None: + controller = DummyController() + + def register_recr(net_, count, place_in_unet): + if net_[1].__class__.__name__ == 'CrossAttention' \ + or net_[1].__class__.__name__ == 'SparseCausalAttention': + net_[1].forward = ca_forward(net_[1], place_in_unet, attention_type = net_[1].__class__.__name__) + return count + 1 + elif hasattr(net_[1], 'children'): + for net in net_[1].named_children(): + if net[0] !='attn_temporal': + + count = register_recr(net, count, place_in_unet) + + return count + + cross_att_count = 0 + sub_nets = model.unet.named_children() + for net in sub_nets: + if "down" in net[0]: + cross_att_count += register_recr(net, 0, "down") + elif "up" in net[0]: + cross_att_count += register_recr(net, 0, "up") + elif "mid" in net[0]: + cross_att_count += register_recr(net, 0, "mid") + print(f"Number of attention layer registered {cross_att_count}") + controller.num_att_layers = cross_att_count diff --git a/FateZero/video_diffusion/prompt_attention/ptp_utils.py b/FateZero/video_diffusion/prompt_attention/ptp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f18cefdfc2e952f2349656655cfcbbf8ac5d7da6 --- /dev/null +++ b/FateZero/video_diffusion/prompt_attention/ptp_utils.py @@ -0,0 +1,253 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +from PIL import Image, ImageDraw, ImageFont +import cv2 +from typing import Optional, Union, Tuple, List, Callable, Dict +# from IPython.display import display +from tqdm.notebook import tqdm +import datetime +import random + +def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)): + h, w, c = image.shape + offset = int(h * .2) + img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 + font = cv2.FONT_HERSHEY_SIMPLEX + # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size) + img[:h] = image + textsize = cv2.getTextSize(text, font, 1, 2)[0] + text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 + cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2) + return img + + +def view_images(images, num_rows=1, offset_ratio=0.02, save_path=None): + if type(images) is list: + num_empty = len(images) % num_rows + elif images.ndim == 4: + num_empty = images.shape[0] % num_rows + else: + images = [images] + num_empty = 0 + + empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 + images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty + num_items = len(images) + + h, w, c = images[0].shape + offset = int(h * offset_ratio) + num_cols = num_items // num_rows + image_ = np.ones((h * num_rows + offset * (num_rows - 1), + w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 + for i in range(num_rows): + for j in range(num_cols): + image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ + i * num_cols + j] + + if save_path is not None: + pil_img = Image.fromarray(image_) + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + pil_img.save(f'{save_path}/{now}.png') + # display(pil_img) + +def load_512(image_path, left=0, right=0, top=0, bottom=0): + if type(image_path) is str: + image = np.array(Image.open(image_path))[:, :, :3] + else: + image = image_path + h, w, c = image.shape + left = min(left, w-1) + right = min(right, w - left - 1) + top = min(top, h - left - 1) + bottom = min(bottom, h - top - 1) + image = image[top:h-bottom, left:w-right] + h, w, c = image.shape + if h < w: + offset = (w - h) // 2 + image = image[:, offset:offset + h] + elif w < h: + offset = (h - w) // 2 + image = image[offset:offset + w] + image = np.array(Image.fromarray(image).resize((512, 512))) + return image + +def set_seed(seed: int): + """ + Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. + + Args: + seed (`int`): The seed to set. + device_specific (`bool`, *optional*, defaults to `False`): + Whether to differ the seed on each device slightly with `self.process_index`. 
+ """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.benchmark = False + torch.use_deterministic_algorithms(True) + + + +def latent2image(vae, latents): + latents = 1 / 0.18215 * latents + image = vae.decode(latents)['sample'] + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + image = (image * 255).astype(np.uint8) + return image + + +def init_latent(latent, model, height, width, generator, batch_size): + # Expand latent with given shape, or randonly initialize it + if latent is None: + latent = torch.randn( + (1, model.unet.in_channels, height // 8, width // 8), + generator=generator, + ) + latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device) + return latent, latents + + + +def register_attention_control(model, controller): + "Connect a model with a controller" + def ca_forward(self, place_in_unet): + to_out = self.to_out + if type(to_out) is torch.nn.modules.container.ModuleList: + to_out = self.to_out[0] + else: + to_out = self.to_out + + # def forward(x, encoder_hidden_states=None, attention_mask=None): + def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + query = self.to_q(hidden_states) + query = self.head_to_batch_dim(query) + + is_cross = encoder_hidden_states is not None + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + key = self.head_to_batch_dim(key) + value = self.head_to_batch_dim(value) + + attention_probs = self.get_attention_scores(query, key, attention_mask) # [16, 4096, 4096] + attention_probs = controller(attention_probs, is_cross, place_in_unet) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = self.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + return hidden_states + + return forward + + class DummyController: + + def __call__(self, *args): + return args[0] + + def __init__(self): + self.num_att_layers = 0 + + if controller is None: + controller = DummyController() + + def register_recr(net_, count, place_in_unet): + if net_.__class__.__name__ == 'CrossAttention': + net_.forward = ca_forward(net_, place_in_unet) + return count + 1 + elif hasattr(net_, 'children'): + for net__ in net_.children(): + count = register_recr(net__, count, place_in_unet) + return count + + cross_att_count = 0 + sub_nets = model.unet.named_children() + for net in sub_nets: + if "down" in net[0]: + cross_att_count += register_recr(net[1], 0, "down") + elif "up" in net[0]: + cross_att_count += register_recr(net[1], 0, "up") + elif "mid" in net[0]: + cross_att_count += register_recr(net[1], 0, "mid") + + controller.num_att_layers = cross_att_count + + +def get_word_inds(text: str, word_place: int, tokenizer): + split_text = text.split(" ") + if type(word_place) is str: + word_place = [i for i, word in enumerate(split_text) if word_place == word] + elif type(word_place) is int: + word_place = [word_place] + out = [] + if len(word_place) > 0: + words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] + cur_len, ptr = 0, 0 + + 
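# A minimal sketch of the post-processing used by latent2image (and tensor_to_numpy in the
# validation loop): Stable Diffusion latents are rescaled by 1/0.18215 before VAE decoding,
# and decoded images move from [-1, 1] to uint8 [0, 255]. The "decoded" tensor below is
# random so the example runs without loading a VAE.
import torch
import numpy as np

SD_LATENT_SCALE = 0.18215

def unscale_latents(latents: torch.Tensor) -> torch.Tensor:
    # Inverse of the 0.18215 scaling applied when images were encoded to latents.
    return latents / SD_LATENT_SCALE

def to_uint8_images(decoded: torch.Tensor) -> np.ndarray:
    # decoded: (b, c, h, w) in [-1, 1]  ->  (b, h, w, c) uint8 in [0, 255]
    images = (decoded / 2 + 0.5).clamp(0, 1)
    images = images.cpu().permute(0, 2, 3, 1).numpy()
    return (images * 255).round().astype(np.uint8)

latents = torch.randn(1, 4, 64, 64)
decoded = torch.tanh(torch.randn(1, 3, 512, 512))   # stands in for vae.decode(...).sample
print(unscale_latents(latents).std(), to_uint8_images(decoded).shape)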
for i in range(len(words_encode)): + cur_len += len(words_encode[i]) + if ptr in word_place: + out.append(i + 1) + if cur_len >= len(split_text[ptr]): + ptr += 1 + cur_len = 0 + return np.array(out) + + +def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, + word_inds: Optional[torch.Tensor]=None): + # Edit the alpha map during attention map editing + if type(bounds) is float: + bounds = 0, bounds + start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0]) + if word_inds is None: + word_inds = torch.arange(alpha.shape[2]) + alpha[: start, prompt_ind, word_inds] = 0 + alpha[start: end, prompt_ind, word_inds] = 1 + alpha[end:, prompt_ind, word_inds] = 0 + return alpha + +import omegaconf +def get_time_words_attention_alpha(prompts, num_steps, + cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], + tokenizer, max_num_words=77): + # Not understand + if (type(cross_replace_steps) is not dict) and \ + (type(cross_replace_steps) is not omegaconf.dictconfig.DictConfig): + cross_replace_steps = {"default_": cross_replace_steps} + if "default_" not in cross_replace_steps: + cross_replace_steps["default_"] = (0., 1.) + alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words) + for i in range(len(prompts) - 1): + alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], + i) + for key, item in cross_replace_steps.items(): + if key != "default_": + inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))] + for i, ind in enumerate(inds): + if len(ind) > 0: + alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind) + alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) + return alpha_time_words diff --git a/FateZero/video_diffusion/prompt_attention/seq_aligner.py b/FateZero/video_diffusion/prompt_attention/seq_aligner.py new file mode 100644 index 0000000000000000000000000000000000000000..684036b77b137bbbe1be3d15a56e8a56b62fca9a --- /dev/null +++ b/FateZero/video_diffusion/prompt_attention/seq_aligner.py @@ -0,0 +1,196 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
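# A minimal sketch of the per-step replacement schedule built by update_alpha_time_word
# above: a (start, end) fraction of the denoising steps is mapped to step indices, and the
# alpha weight is 1 only inside that window, so source cross-attention is injected for part
# of the trajectory and the edit is left free afterwards. The (0.0, 0.8) bounds are an
# illustrative choice, not a FateZero default.
import torch

def replacement_schedule(num_steps: int, bounds=(0.0, 0.8), num_words: int = 77) -> torch.Tensor:
    start, end = int(bounds[0] * num_steps), int(bounds[1] * num_steps)
    alpha = torch.zeros(num_steps + 1, 1, num_words)
    alpha[start:end, :, :] = 1.0          # inject source attention only in [start, end)
    return alpha

alpha = replacement_schedule(num_steps=50, bounds=(0.0, 0.8))
print(alpha.shape, alpha[:, 0, 0].sum().item())   # torch.Size([51, 1, 77]) 40.0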
+import torch +import numpy as np + + +class ScoreParams: + + def __init__(self, gap, match, mismatch): + self.gap = gap + self.match = match + self.mismatch = mismatch + + def mis_match_char(self, x, y): + if x != y: + return self.mismatch + else: + return self.match + + +def get_matrix(size_x, size_y, gap): + matrix = [] + for i in range(len(size_x) + 1): + sub_matrix = [] + for j in range(len(size_y) + 1): + sub_matrix.append(0) + matrix.append(sub_matrix) + for j in range(1, len(size_y) + 1): + matrix[0][j] = j*gap + for i in range(1, len(size_x) + 1): + matrix[i][0] = i*gap + return matrix + + +def get_matrix(size_x, size_y, gap): + matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32) + matrix[0, 1:] = (np.arange(size_y) + 1) * gap + matrix[1:, 0] = (np.arange(size_x) + 1) * gap + return matrix + + +def get_traceback_matrix(size_x, size_y): + matrix = np.zeros((size_x + 1, size_y +1), dtype=np.int32) + matrix[0, 1:] = 1 + matrix[1:, 0] = 2 + matrix[0, 0] = 4 + return matrix + + +def global_align(x, y, score): + matrix = get_matrix(len(x), len(y), score.gap) + trace_back = get_traceback_matrix(len(x), len(y)) + for i in range(1, len(x) + 1): + for j in range(1, len(y) + 1): + left = matrix[i, j - 1] + score.gap + up = matrix[i - 1, j] + score.gap + diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1]) + matrix[i, j] = max(left, up, diag) + if matrix[i, j] == left: + trace_back[i, j] = 1 + elif matrix[i, j] == up: + trace_back[i, j] = 2 + else: + trace_back[i, j] = 3 + return matrix, trace_back + + +def get_aligned_sequences(x, y, trace_back): + x_seq = [] + y_seq = [] + i = len(x) + j = len(y) + mapper_y_to_x = [] + while i > 0 or j > 0: + if trace_back[i, j] == 3: + x_seq.append(x[i-1]) + y_seq.append(y[j-1]) + i = i-1 + j = j-1 + mapper_y_to_x.append((j, i)) + elif trace_back[i][j] == 1: + x_seq.append('-') + y_seq.append(y[j-1]) + j = j-1 + mapper_y_to_x.append((j, -1)) + elif trace_back[i][j] == 2: + x_seq.append(x[i-1]) + y_seq.append('-') + i = i-1 + elif trace_back[i][j] == 4: + break + mapper_y_to_x.reverse() + return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64) + + +def get_mapper(x: str, y: str, tokenizer, max_len=77): + x_seq = tokenizer.encode(x) + y_seq = tokenizer.encode(y) + score = ScoreParams(0, 1, -1) + matrix, trace_back = global_align(x_seq, y_seq, score) + mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1] + alphas = torch.ones(max_len) + alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float() + mapper = torch.zeros(max_len, dtype=torch.int64) + mapper[:mapper_base.shape[0]] = mapper_base[:, 1] + mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq)) + return mapper, alphas + + +def get_refinement_mapper(prompts, tokenizer, max_len=77): + x_seq = prompts[0] + mappers, alphas = [], [] + for i in range(1, len(prompts)): + mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len) + mappers.append(mapper) + alphas.append(alpha) + return torch.stack(mappers), torch.stack(alphas) + + +def get_word_inds(text: str, word_place: int, tokenizer): + split_text = text.split(" ") + if type(word_place) is str: + word_place = [i for i, word in enumerate(split_text) if word_place == word] + elif type(word_place) is int: + word_place = [word_place] + out = [] + if len(word_place) > 0: + words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] + cur_len, ptr = 0, 0 + + for i in range(len(words_encode)): + cur_len += len(words_encode[i]) + if ptr in 
word_place: + out.append(i + 1) + if cur_len >= len(split_text[ptr]): + ptr += 1 + cur_len = 0 + return np.array(out) + + +def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77): + words_x = x.split(' ') + words_y = y.split(' ') + if len(words_x) != len(words_y): + raise ValueError(f"attention replacement edit can only be applied on prompts with the same length" + f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.") + inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]] + inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace] + inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace] + mapper = np.zeros((max_len, max_len)) + i = j = 0 + cur_inds = 0 + while i < max_len and j < max_len: + if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i: + inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds] + if len(inds_source_) == len(inds_target_): + mapper[inds_source_, inds_target_] = 1 + else: + ratio = 1 / len(inds_target_) + for i_t in inds_target_: + mapper[inds_source_, i_t] = ratio + cur_inds += 1 + i += len(inds_source_) + j += len(inds_target_) + elif cur_inds < len(inds_source): + mapper[i, j] = 1 + i += 1 + j += 1 + else: + mapper[j, j] = 1 + i += 1 + j += 1 + + return torch.from_numpy(mapper).float() + + + +def get_replacement_mapper(prompts, tokenizer, max_len=77): + x_seq = prompts[0] + mappers = [] + for i in range(1, len(prompts)): + mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len) + mappers.append(mapper) + return torch.stack(mappers) + diff --git a/FateZero/video_diffusion/trainer/ddpm_trainer.py b/FateZero/video_diffusion/trainer/ddpm_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5f498edddbf9c611a51360f3fc2b1d309b0243d2 --- /dev/null +++ b/FateZero/video_diffusion/trainer/ddpm_trainer.py @@ -0,0 +1,184 @@ +import inspect +from typing import Callable, List, Optional, Union + +import torch +import torch.nn.functional as F +from einops import rearrange + +from diffusers.utils import is_accelerate_available +from packaging import version +from transformers import CLIPTextModel, CLIPTokenizer + +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from diffusers.utils import deprecate, logging +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from ..models.unet_3d_condition import UNetPseudo3DConditionModel +from video_diffusion.pipelines.stable_diffusion import SpatioTemporalStableDiffusionPipeline + +class DDPMTrainer(SpatioTemporalStableDiffusionPipeline): + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNetPseudo3DConditionModel, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + **kwargs + ): + super().__init__( + vae, + text_encoder, + tokenizer, + unet, + scheduler, + ) + for name, module in kwargs.items(): + setattr(self, name, module) + + def step(self, + batch: dict = dict()): + if 'class_images' in batch: + self.step2d(batch["class_images"], batch["class_prompt_ids"]) + self.vae.eval() + 
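# A minimal sketch of the module setup implied by DDPMTrainer.step: the VAE and text encoder
# stay in eval mode (and are typically frozen), while only the video UNet is trained. The
# tiny modules below are placeholders so the example runs without loading Stable Diffusion
# weights; whether and where the freezing happens in FateZero's trainer is an assumption
# here, only the eval() calls appear in the code above.
import torch
import torch.nn as nn

vae = nn.Conv2d(3, 4, 3, padding=1)          # stands in for AutoencoderKL
text_encoder = nn.Linear(77, 768)            # stands in for CLIPTextModel
unet = nn.Conv3d(4, 4, 3, padding=1)         # stands in for the pseudo-3D UNet

for frozen in (vae, text_encoder):
    frozen.requires_grad_(False)
    frozen.eval()
unet.train()

optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-5)
print(sum(p.requires_grad for p in vae.parameters()),
      sum(p.requires_grad for p in unet.parameters()))   # 0 trainable vs. 2 trainable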
self.text_encoder.eval() + self.unet.train() + if self.prior_preservation is not None: + print('Use prior_preservation loss') + self.unet2d.eval() + + # with accelerator.accumulate(unet): + # Convert images to latent space + images = batch["images"].to(dtype=self.weight_dtype) + b = images.shape[0] + images = rearrange(images, "b c f h w -> (b f) c h w") + latents = self.vae.encode(images).latent_dist.sample() # shape=torch.Size([8, 3, 512, 512]), min=-1.00, max=0.98, var=0.21, -0.96875 + latents = rearrange(latents, "(b f) c h w -> b c f h w", b=b) + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, self.scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = self.scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = self.text_encoder(batch["prompt_ids"])[0] + + # Predict the noise residual + model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if self.scheduler.config.prediction_type == "epsilon": + target = noise + elif self.scheduler.config.prediction_type == "v_prediction": + target = self.scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {self.scheduler.config.prediction_type}") + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + if self.prior_preservation is not None: + model_pred_2d = self.unet2d(noisy_latents[:, :, 0], timesteps, encoder_hidden_states).sample + loss = ( + loss + + F.mse_loss(model_pred[:, :, 0].float(), model_pred_2d.float(), reduction="mean") + * self.prior_preservation + ) + + self.accelerator.backward(loss) + if self.accelerator.sync_gradients: + self.accelerator.clip_grad_norm_(self.unet.parameters(), self.max_grad_norm) + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + + return loss + + def step2d(self, class_images, prompt_ids + # batch: dict = dict() + ): + + self.vae.eval() + self.text_encoder.eval() + self.unet.train() + if self.prior_preservation is not None: + self.unet2d.eval() + + # with accelerator.accumulate(unet): + # Convert images to latent space + images = class_images.to(dtype=self.weight_dtype) + b = images.shape[0] + images = rearrange(images, "b c f h w -> (b f) c h w") + latents = self.vae.encode(images).latent_dist.sample() # shape=torch.Size([8, 3, 512, 512]), min=-1.00, max=0.98, var=0.21, -0.96875 + + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, self.scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = self.scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = self.text_encoder(prompt_ids)[0] + + # Predict the noise residual + model_pred = self.unet(noisy_latents, timesteps, 
encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if self.scheduler.config.prediction_type == "epsilon": + target = noise + elif self.scheduler.config.prediction_type == "v_prediction": + target = self.scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {self.scheduler.config.prediction_type}") + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + if self.prior_preservation is not None: + model_pred_2d = self.unet2d(noisy_latents[:, :, 0], timesteps, encoder_hidden_states).sample + loss = ( + loss + + F.mse_loss(model_pred[:, :, 0].float(), model_pred_2d.float(), reduction="mean") + * self.prior_preservation + ) + + self.accelerator.backward(loss) + if self.accelerator.sync_gradients: + self.accelerator.clip_grad_norm_(self.unet.parameters(), self.max_grad_norm) + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + + return loss \ No newline at end of file
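The alignment utilities added above mirror prompt-to-prompt's seq_aligner: get_replacement_mapper handles word-for-word edits between prompts with the same word count, while get_refinement_mapper aligns prompts of different lengths and marks newly introduced tokens through the alpha weights. Below is a minimal usage sketch, not code from this diff: the import path and prompts are illustrative assumptions; the tokenizer checkpoint (openai/clip-vit-large-patch14) is the one used by Stable Diffusion v1-4.

from transformers import CLIPTokenizer

# Assumed import path for the utilities defined in the hunk above; adjust to the
# actual module location in the repository.
from video_diffusion.prompt_attention.seq_aligner import (
    get_refinement_mapper,
    get_replacement_mapper,
)

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

source = "a silver jeep driving down a curvy road"           # illustrative source prompt
replace = "a Porsche car driving down a curvy road"          # same word count: replacement edit
refine = "a silver jeep driving down a curvy road at night"  # extra words: refinement edit

# Replacement edit: one (max_len, max_len) soft assignment matrix per edited
# prompt; words that tokenize to different lengths get fractional weights.
mapper = get_replacement_mapper([source, replace], tokenizer, max_len=77)
print(mapper.shape)  # torch.Size([1, 77, 77])

# Refinement edit: global alignment (gap 0, match 1, mismatch -1) gives a token
# index mapper plus alphas that are 0 for tokens newly added in the edited prompt.
mappers, alphas = get_refinement_mapper([source, refine], tokenizer, max_len=77)
print(mappers.shape, alphas.shape)  # torch.Size([1, 77]) torch.Size([1, 77])

The resulting mappers and alphas are what the cross-attention control code consumes when swapping or blending attention maps between the source and edited prompts.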