checkout InstantDrag
- InstDrag/.gitignore +163 -0
- InstDrag/README.md +76 -0
- InstDrag/demo/demo_utils.py +242 -0
- InstDrag/demo/run_demo.py +226 -0
- InstDrag/demo/samples/airplane.jpg +0 -0
- InstDrag/demo/samples/anime.jpg +0 -0
- InstDrag/demo/samples/caligraphy.jpg +0 -0
- InstDrag/demo/samples/crocodile.jpg +0 -0
- InstDrag/demo/samples/elephant.jpg +0 -0
- InstDrag/demo/samples/meteor.jpg +0 -0
- InstDrag/demo/samples/monalisa.jpg +0 -0
- InstDrag/demo/samples/portrait.jpg +0 -0
- InstDrag/demo/samples/sketch.jpg +0 -0
- InstDrag/demo/samples/surreal.jpg +0 -0
- InstDrag/flowdiffusion/pipeline.py +495 -0
- InstDrag/flowgen/models.py +161 -0
- InstDrag/utils/flow_utils.py +143 -0
- InstDrag/utils/null_prompt.pt +3 -0
InstDrag/.gitignore
ADDED
@@ -0,0 +1,163 @@
demo/results
demo/checkpoints

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
InstDrag/README.md
ADDED
@@ -0,0 +1,76 @@
# InstantDrag

<p align="center">
  <img src="assets/demo.gif" alt="Demo video">
</p>

<br/>

Official implementation of the paper **"InstantDrag: Improving Interactivity in Drag-based Image Editing"** (SIGGRAPH Asia 2024).

<p align="center">
  <a href="https://arxiv.org/abs/2409.08857"><img src="https://img.shields.io/badge/arxiv-2409.08857-b31b1b"></a>
  <a href="https://joonghyuk.com/instantdrag-web/"><img src="https://img.shields.io/badge/Project%20Page-InstantDrag-blue"></a>
  <a href="https://huggingface.co/alex4727/InstantDrag"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Model-forestgreen"></a>
</p>

---

## Setup

1. Create and activate a conda environment:
```bash
conda create -n instantdrag python=3.10 -y
conda activate instantdrag
```

2. Install PyTorch:
```bash
pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu121
```

3. Install other dependencies:
```bash
pip install transformers==4.44.2 diffusers==0.30.1 accelerate==0.33.0 gradio==4.44.0 opencv-python
```
**Note:** Exact version matching may not be necessary for all dependencies.

## Demo

To run the demo:
```bash
cd demo/
CUDA_VISIBLE_DEVICES=0 python run_demo.py
```
### Disclaimer

- Our **base** models are **solely** trained on real-world talking head (facial) videos, with a focus on achieving **fast fine-grained facial editing w/o metadata**. The preliminary signs of generalizability to other types of scenes, without fine-tuning, should be considered more of an experimental byproduct and may not perform well in many cases. Please check Appendix A of our paper for more information.
- This is a research project, **NOT** a commercial product. Use at your own risk.

### Usage Instructions & Tips

- Upload and preprocess an image using Gradio's interface.
- Click to define source and target point pairs on the image.
- Adjust settings in the "Configs" tab.
  - We provide two checkpoints for FlowGen: config-2 (default, used for most figures in the paper) and config-3 (used for the benchmark table in the paper). Generally, we recommend config-2 for most cases, including drags based on a few keypoints. For extremely fine-grained editing with many drags (e.g., the 68 keypoint drags used in the benchmark), config-3 may be better suited, as it produces more local movements.
  - If the image moves too much or too little, try adjusting the image or flow guidance scales (usually 1 to 2 is recommended, but flow guidance can be larger).
  - If you observe loss of identity or noisy artifacts, increasing the image guidance or the number of sampling steps can help (a [1.75, 1.5] scale is also a good choice for facial images).
- Click `Run` to perform the editing.
- We recommend first viewing the example videos (on the project page or in the .gif above) and the paper figures to understand the model's capabilities. Then, begin with facial images using fine-grained keypoint drags before progressing to more complex motions.
- As noted in the paper, our model may struggle with large motions that exceed the capabilities of the optical flow estimation networks used for training data extraction.
- Notes on FlowGen Output Scale
  - In many cases, especially for unseen domains, FlowGen's output doesn't precisely span the -1 to 1 range expected by FlowDiffusion's fixed-size normalization process. For all figures and benchmarks in our paper, we applied a static multiplier of 2, based on observations, to adjust FlowGen's output to the expected range. However, we found that forcefully rescaling the output to -1 to 1 also works well, so we set this as the default behavior (when the value is -1). While not recommended, you can manually modify this value to scale FlowGen's output before feeding it to FlowDiffusion for larger or smaller motions.

**Note:** The initial run may take longer as models are loaded onto the GPU.

## BibTeX
If you find this work useful, please cite it as below!
```
@inproceedings{shin2024instantdrag,
    title     = {{InstantDrag: Improving Interactivity in Drag-based Image Editing}},
    author    = {Shin, Joonghyuk and Choi, Daehyeon and Park, Jaesik},
    booktitle = {ACM SIGGRAPH Asia 2024 Conference Proceedings},
    year      = {2024},
    pages     = {1--10},
}
```
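
The "Notes on FlowGen Output Scale" tip in the README above corresponds to a single rescaling step applied to FlowGen's raw output in `demo/demo_utils.py`. A minimal sketch of that step, assuming a `(1, 2, H, W)` flow tensor; the standalone helper name and the zero-channel clamp are illustrative additions, not part of the repository code:

```python
import torch

def rescale_flowgen_output(flow: torch.Tensor, scale: float = -1.0) -> torch.Tensor:
    """Rescale FlowGen's raw x/y flow before it is passed to FlowDiffusion.

    scale == -1.0 forcefully renormalizes each channel to span [-1, 1]
    (the demo's default); any other value acts as a static multiplier
    (the paper's figures used 2).
    """
    if scale == -1.0:
        for c in range(2):
            # Per-channel max-abs normalization; clamp guards against an
            # all-zero channel (an added safeguard, not in the demo code).
            max_abs = torch.amax(torch.abs(flow[:, c, :, :]), dim=(1, 2)).view(-1, 1, 1)
            flow[:, c, :, :] *= 1.0 / max_abs.clamp_min(1e-8)
    else:
        flow *= scale  # manual adjustment for larger or smaller motions
    return flow
```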
InstDrag/demo/demo_utils.py
ADDED
@@ -0,0 +1,242 @@
import sys
sys.path.append("../")

import os
import re
import time
import datetime
from copy import deepcopy

import numpy as np
import cv2
import torch
import torch.nn.functional as F
import gradio as gr
from PIL import Image
from PIL.ImageOps import exif_transpose
from safetensors.torch import load_file

from utils.flow_utils import flow_to_image, resize_flow
from flowgen.models import UnetGenerator
from flowdiffusion.pipeline import FlowDiffusionPipeline

LENGTH = 512
FLOWGAN_RESOLUTION = [256, 256]  # HxW
FLOWDIFFUSION_RESOLUTION = [512, 512]  # HxW

def process_img(image):
    if image["composite"] is not None and not np.all(image["composite"] == 0):
        original_image = Image.fromarray(image["composite"]).resize((LENGTH, LENGTH), Image.BICUBIC)
        original_image = np.array(exif_transpose(original_image))
        return original_image, [], gr.Image(value=deepcopy(original_image), interactive=False)
    else:
        return (
            gr.Image(value=None, interactive=False),
            [],
            gr.Image(value=None, interactive=False),
        )

def get_points(img, sel_pix, evt: gr.SelectData):
    sel_pix.append(evt.index)
    print(sel_pix)
    points = []
    for idx, point in enumerate(sel_pix):
        if idx % 2 == 0:
            cv2.circle(img, tuple(point), 4, (255, 0, 0), -1)
        else:
            cv2.circle(img, tuple(point), 4, (0, 0, 255), -1)
        points.append(tuple(point))
        if len(points) == 2:
            cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 2, tipLength=0.5)
            points = []
    img = img if isinstance(img, np.ndarray) else np.array(img)
    return img

def display_points(img, predefined_points, save_results):
    if predefined_points != "":
        predefined_points = predefined_points.split()
        predefined_points = [int(re.sub(r'[^0-9]', '', point)) for point in predefined_points]
        processed_points = []
        for i, point in enumerate(predefined_points):
            if i % 2 == 0:
                processed_points.append([point, predefined_points[i+1]])
        selected_points = processed_points

    print(selected_points)
    points = []
    for idx, point in enumerate(selected_points):
        if idx % 2 == 0:
            cv2.circle(img, tuple(point), 4, (255, 0, 0), -1)
        else:
            cv2.circle(img, tuple(point), 4, (0, 0, 255), -1)
        points.append(tuple(point))
        if len(points) == 2:
            cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 2, tipLength=0.5)
            points = []
    img = img if isinstance(img, np.ndarray) else np.array(img)

    if save_results:
        if not os.path.isdir("results/drag_inst_viz"):
            os.makedirs("results/drag_inst_viz")
        save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
        to_save_img = Image.fromarray(img)
        to_save_img.save(f"results/drag_inst_viz/{save_prefix}.png")

    return img

def undo_points_image(original_image):
    if original_image is not None:
        return original_image, []
    else:
        return gr.Image(value=None, interactive=False), []

def clear_all():
    return (
        gr.Image(value=None, interactive=True),
        gr.Image(value=None, interactive=False),
        gr.Image(value=None, interactive=False),
        [],
        None
    )

class InstantDragPipeline:
    def __init__(self, seed=9999, device="cuda", dtype=torch.float16):
        self.seed = seed
        self.device = device
        self.dtype = dtype
        self.generator = torch.Generator(device=device).manual_seed(seed)
        self.flowgen_ckpt, self.flowdiffusion_ckpt = None, None
        self.model_config = dict()

    def build_model(self):
        print("Building model...")
        if self.flowgen_ckpt != self.model_config["flowgen_ckpt"]:
            self.flowgen = UnetGenerator(input_nc=5, output_nc=2)
            self.flowgen.load_state_dict(
                load_file(os.path.join("checkpoints/", self.model_config["flowgen_ckpt"]), device="cpu")
            )
            self.flowgen.to(self.device)
            self.flowgen.eval()
            self.flowgen_ckpt = self.model_config["flowgen_ckpt"]

        if self.flowdiffusion_ckpt != self.model_config["flowdiffusion_ckpt"]:
            self.flowdiffusion = FlowDiffusionPipeline.from_pretrained(
                os.path.join("checkpoints/", self.model_config["flowdiffusion_ckpt"]),
                torch_dtype=self.dtype,
                safety_checker=None
            )
            self.flowdiffusion.to(self.device)
            self.flowdiffusion_ckpt = self.model_config["flowdiffusion_ckpt"]

    def drag(self, original_image, selected_points, save_results):
        scale = self.model_config["flowgen_output_scale"]
        original_image = torch.tensor(original_image).permute(2, 0, 1).unsqueeze(0).float()  # 1, 3, 512, 512
        original_image = 2 * (original_image / 255.) - 1  # Normalize to [-1, 1]
        original_image = original_image.to(self.device)

        source_points = []
        target_points = []
        for idx, point in enumerate(selected_points):
            cur_point = torch.tensor([point[0], point[1]])  # x, y
            if idx % 2 == 0:
                source_points.append(cur_point)
            else:
                target_points.append(cur_point)

        torch.cuda.synchronize()
        start_time = time.time()

        # Generate sparse flow vectors
        point_vector_map = torch.zeros((1, 2, LENGTH, LENGTH))
        for source_point, target_point in zip(source_points, target_points):
            cur_x, cur_y = source_point[0], source_point[1]
            target_x, target_y = target_point[0], target_point[1]
            vec_x = target_x - cur_x
            vec_y = target_y - cur_y
            point_vector_map[0, 0, int(cur_y), int(cur_x)] = vec_x
            point_vector_map[0, 1, int(cur_y), int(cur_x)] = vec_y
        point_vector_map = point_vector_map.to(self.device)

        # Sample-wise normalize the flow vectors
        factor_x = torch.amax(torch.abs(point_vector_map[:, 0, :, :]), dim=(1, 2)).view(-1, 1, 1).to(self.device)
        factor_y = torch.amax(torch.abs(point_vector_map[:, 1, :, :]), dim=(1, 2)).view(-1, 1, 1).to(self.device)
        if factor_x >= 1e-8:  # Avoid division by zero
            point_vector_map[:, 0, :, :] /= factor_x
        if factor_y >= 1e-8:  # Avoid division by zero
            point_vector_map[:, 1, :, :] /= factor_y

        with torch.inference_mode():
            gan_input_image = F.interpolate(original_image, size=FLOWGAN_RESOLUTION, mode="bicubic")  # 256 x 256
            point_vector_map = F.interpolate(point_vector_map, size=FLOWGAN_RESOLUTION, mode="bicubic")  # 256 x 256
            gan_input = torch.cat([gan_input_image, point_vector_map], dim=1)
            flow = self.flowgen(gan_input)  # -1 ~ 1

            if scale == -1.0:
                flow[:, 0, :, :] *= 1.0 / torch.amax(torch.abs(flow[:, 0, :, :]), dim=(1, 2)).view(-1, 1, 1)  # force the range to be [-1 ~ 1]
                flow[:, 1, :, :] *= 1.0 / torch.amax(torch.abs(flow[:, 1, :, :]), dim=(1, 2)).view(-1, 1, 1)  # force the range to be [-1 ~ 1]
            else:
                flow[:, 0, :, :] *= scale  # manually adjust the scale
                flow[:, 1, :, :] *= scale  # manually adjust the scale

            if factor_x >= 1e-8:
                flow[:, 0, :, :] *= factor_x * (FLOWGAN_RESOLUTION[1] / original_image.shape[3])  # width
            else:
                flow[:, 0, :, :] *= 0
            if factor_y >= 1e-8:
                flow[:, 1, :, :] *= factor_y * (FLOWGAN_RESOLUTION[0] / original_image.shape[2])  # height
            else:
                flow[:, 1, :, :] *= 0

            resized_flow = resize_flow(flow, (FLOWDIFFUSION_RESOLUTION[0]//8, FLOWDIFFUSION_RESOLUTION[1]//8), scale_type="normalize_fixed")

            kwargs = {
                "image": original_image.to(self.dtype),
                "flow": resized_flow.to(self.dtype),
                "num_inference_steps": self.model_config['n_inference_step'],
                "image_guidance_scale": self.model_config['image_guidance'],
                "flow_guidance_scale": self.model_config['flow_guidance'],
                "generator": self.generator,
            }
            edited_image = self.flowdiffusion(**kwargs).images[0]

        end_time = time.time()
        inference_time = end_time - start_time
        print(f"Inference Time: {inference_time} seconds")

        if save_results:
            save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
            if not os.path.isdir("results/flows"):
                os.makedirs("results/flows")
            np.save(f"results/flows/{save_prefix}.npy", flow[0].detach().cpu().numpy())
            if not os.path.isdir("results/flow_visualized"):
                os.makedirs("results/flow_visualized")
            flow_to_image(flow[0].detach()).save(f"results/flow_visualized/{save_prefix}.png")
            if not os.path.isdir("results/edited_images"):
                os.makedirs("results/edited_images")
            edited_image.save(f"results/edited_images/{save_prefix}.png")
            if not os.path.isdir("results/drag_instructions"):
                os.makedirs("results/drag_instructions")
            with open(f"results/drag_instructions/{save_prefix}.txt", "w") as f:
                f.write(str(selected_points))

        edited_image = np.array(edited_image)
        return edited_image

    def run(self, original_image, selected_points,
            flowgen_ckpt, flowdiffusion_ckpt, image_guidance, flow_guidance, flowgen_output_scale,
            num_steps, save_results):

        self.model_config = {
            "flowgen_ckpt": flowgen_ckpt,
            "flowdiffusion_ckpt": flowdiffusion_ckpt,
            "image_guidance": image_guidance,
            "flow_guidance": flow_guidance,
            "flowgen_output_scale": flowgen_output_scale,
            "n_inference_step": num_steps
        }

        self.build_model()

        edited_image = self.drag(original_image, selected_points, save_results)

        return edited_image
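
`InstantDragPipeline.run` above can also be driven without the Gradio UI. A minimal sketch, assuming it is executed from the `demo/` directory (so the relative `checkpoints/` and `samples/` paths resolve); the checkpoint names are placeholders — substitute whatever `snapshot_download` places under `demo/checkpoints/`:

```python
import numpy as np
import torch
from PIL import Image
from demo_utils import InstantDragPipeline

pipeline = InstantDragPipeline(seed=42, device="cuda", dtype=torch.float16)

# The Gradio demo resizes inputs to 512x512 before editing; do the same here.
image = np.array(Image.open("samples/monalisa.jpg").convert("RGB").resize((512, 512)))

# Points alternate source, target in (x, y) pixel coordinates of the 512x512 image.
selected_points = [[256, 200], [256, 230]]

edited = pipeline.run(
    image, selected_points,
    flowgen_ckpt="flowgen-config2.safetensors",  # placeholder filename
    flowdiffusion_ckpt="flowdiffusion",          # placeholder directory name
    image_guidance=1.5, flow_guidance=1.5,
    flowgen_output_scale=-1.0, num_steps=20, save_results=False,
)
Image.fromarray(edited).save("edited.png")
```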
InstDrag/demo/run_demo.py
ADDED
@@ -0,0 +1,226 @@
import os
import torch
import gradio as gr
from huggingface_hub import snapshot_download
os.makedirs("checkpoints", exist_ok=True)
snapshot_download("alex4727/InstantDrag", local_dir="./checkpoints")

from demo_utils import (
    process_img,
    get_points,
    undo_points_image,
    clear_all,
    InstantDragPipeline,
)

LENGTH = 480  # Length of the square area displaying/editing images

with gr.Blocks() as demo:
    pipeline = InstantDragPipeline(seed=42, device="cuda", dtype=torch.float16)

    # Layout definition
    with gr.Row():
        gr.Markdown(
            """
            # InstantDrag: Improving Interactivity in Drag-based Image Editing
            """
        )

    with gr.Tab(label="InstantDrag Demo"):
        selected_points = gr.State([])  # Store points
        original_image = gr.State(value=None)  # Store original input image

        with gr.Row():
            # Upload & Preprocess Image Column
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 20px">Upload & Preprocess Image</p>"""
                )
                canvas = gr.ImageEditor(
                    height=LENGTH,
                    width=LENGTH,
                    type="numpy",
                    image_mode="RGB",
                    label="Preprocess Image",
                    show_label=True,
                    interactive=True,
                )
                with gr.Row():
                    save_results = gr.Checkbox(
                        value=False,
                        label="Save Results",
                        scale=1,
                    )
                    undo_button = gr.Button("Undo Clicked Points", scale=3)

            # Click Points Column
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 20px">Click Points</p>"""
                )
                input_image = gr.Image(
                    type="numpy",
                    label="Click Points",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    interactive=False,
                    show_fullscreen_button=False,
                )
                with gr.Row():
                    run_button = gr.Button("Run")

            # Editing Results Column
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 20px">Editing Results</p>"""
                )
                edited_image = gr.Image(
                    type="numpy",
                    label="Editing Results",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    interactive=False,
                    show_fullscreen_button=False,
                )
                with gr.Row():
                    clear_all_button = gr.Button("Clear All")

    with gr.Tab("Configs - make sure to check README for details"):
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    flowgen_choices = sorted(
                        [model for model in os.listdir("checkpoints/") if "flowgen" in model]
                    )
                    flowgen_ckpt = gr.Dropdown(
                        value=flowgen_choices[0],
                        label="Select FlowGen to use",
                        choices=flowgen_choices,
                        info="config2 for most cases, config3 for more fine-grained dragging",
                        scale=2,
                    )
                    flowdiffusion_choices = sorted(
                        [model for model in os.listdir("checkpoints/") if "flowdiffusion" in model]
                    )
                    flowdiffusion_ckpt = gr.Dropdown(
                        value=flowdiffusion_choices[0],
                        label="Select FlowDiffusion to use",
                        choices=flowdiffusion_choices,
                        info="single model for all cases",
                        scale=1,
                    )
                    image_guidance = gr.Number(
                        value=1.5,
                        label="Image Guidance Scale",
                        precision=2,
                        step=0.1,
                        scale=1,
                        info="typically between 1.0-2.0.",
                    )
                    flow_guidance = gr.Number(
                        value=1.5,
                        label="Flow Guidance Scale",
                        precision=2,
                        step=0.1,
                        scale=1,
                        info="typically between 1.0-5.0",
                    )
                    num_steps = gr.Number(
                        value=20,
                        label="Inference Steps",
                        precision=0,
                        step=1,
                        scale=1,
                        info="typically between 20-50, 20 is usually enough",
                    )
                    flowgen_output_scale = gr.Number(
                        value=-1.0,
                        label="FlowGen Output Scale",
                        precision=1,
                        step=0.1,
                        scale=2,
                        info="-1.0, by default, forces flowgen's output to [-1, 1], could be adjusted to [0, ∞] for stronger/weaker effects",
                    )

    gr.Markdown(
        """
        <p style="text-align: center; font-size: 18px;">Examples</p>
        """
    )
    with gr.Row():
        gr.Examples(
            examples=[
                "samples/airplane.jpg",
                "samples/anime.jpg",
                "samples/caligraphy.jpg",
                "samples/crocodile.jpg",
                "samples/elephant.jpg",
                "samples/meteor.jpg",
                "samples/monalisa.jpg",
                "samples/portrait.jpg",
                "samples/sketch.jpg",
                "samples/surreal.jpg",
            ],
            inputs=[canvas],
            outputs=[original_image, selected_points, input_image],
            fn=process_img,
            cache_examples=False,
            examples_per_page=10,
        )
    gr.Markdown(
        """
        <p style="text-align: center; font-size: 9">[Important] Our base models are solely trained on real-world talking head (facial) videos, with a focus on achieving fine-grained facial editing. <br>
        Their application to other types of scenes, without fine-tuning, should be considered more of an experimental byproduct and may not perform well in many cases (we currently support only square images).</p>
        """
    )

    # Event Handlers
    canvas.change(
        process_img,
        [canvas],
        [original_image, selected_points, input_image],
    )

    input_image.select(
        get_points,
        [input_image, selected_points],
        [input_image],
    )

    undo_button.click(
        undo_points_image,
        [original_image],
        [input_image, selected_points],
    )

    run_button.click(
        pipeline.run,
        [
            original_image,
            selected_points,
            flowgen_ckpt,
            flowdiffusion_ckpt,
            image_guidance,
            flow_guidance,
            flowgen_output_scale,
            num_steps,
            save_results,
        ],
        [edited_image],
    )

    clear_all_button.click(
        clear_all,
        [],
        [
            canvas,
            input_image,
            edited_image,
            selected_points,
            original_image,
        ],
    )

demo.queue().launch(share=False, debug=True)
InstDrag/demo/samples/airplane.jpg
ADDED
InstDrag/demo/samples/anime.jpg
ADDED
InstDrag/demo/samples/caligraphy.jpg
ADDED
InstDrag/demo/samples/crocodile.jpg
ADDED
InstDrag/demo/samples/elephant.jpg
ADDED
InstDrag/demo/samples/meteor.jpg
ADDED
InstDrag/demo/samples/monalisa.jpg
ADDED
InstDrag/demo/samples/portrait.jpg
ADDED
InstDrag/demo/samples/sketch.jpg
ADDED
InstDrag/demo/samples/surreal.jpg
ADDED
InstDrag/flowdiffusion/pipeline.py
ADDED
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file is partially based on the diffusers library, which licensed the code under the following license:
|
2 |
+
|
3 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import inspect
|
18 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
19 |
+
import os
|
20 |
+
from pathlib import Path
|
21 |
+
|
22 |
+
import PIL.Image
|
23 |
+
import torch
|
24 |
+
from transformers import CLIPImageProcessor
|
25 |
+
|
26 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
27 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
28 |
+
from diffusers.loaders import StableDiffusionLoraLoaderMixin
|
29 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
30 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
31 |
+
from diffusers.utils import deprecate, logging
|
32 |
+
from diffusers.utils.torch_utils import randn_tensor
|
33 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
34 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
35 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
36 |
+
|
37 |
+
|
38 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
39 |
+
|
40 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
41 |
+
def retrieve_latents(
|
42 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
43 |
+
):
|
44 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
45 |
+
return encoder_output.latent_dist.sample(generator)
|
46 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
47 |
+
return encoder_output.latent_dist.mode()
|
48 |
+
elif hasattr(encoder_output, "latents"):
|
49 |
+
return encoder_output.latents
|
50 |
+
else:
|
51 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
52 |
+
|
53 |
+
|
54 |
+
class FlowDiffusionPipeline(
|
55 |
+
DiffusionPipeline,
|
56 |
+
StableDiffusionMixin,
|
57 |
+
StableDiffusionLoraLoaderMixin,
|
58 |
+
):
|
59 |
+
r"""
|
60 |
+
Pipeline for pixel-level image editing given optical flow as condition.
|
61 |
+
|
62 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
63 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
64 |
+
|
65 |
+
The pipeline also inherits the following loading methods:
|
66 |
+
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
67 |
+
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
68 |
+
|
69 |
+
Args:
|
70 |
+
vae ([`AutoencoderKL`]):
|
71 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
72 |
+
unet ([`UNet2DConditionModel`]):
|
73 |
+
A `UNet2DConditionModel` to denoise the encoded image latents.
|
74 |
+
scheduler ([`SchedulerMixin`]):
|
75 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
|
76 |
+
safety_checker ([`StableDiffusionSafetyChecker`]):
|
77 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
78 |
+
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
79 |
+
about a model's potential harms.
|
80 |
+
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
81 |
+
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
82 |
+
"""
|
83 |
+
|
84 |
+
model_cpu_offload_seq = "unet->vae"
|
85 |
+
_optional_components = ["safety_checker", "feature_extractor"]
|
86 |
+
_exclude_from_cpu_offload = ["safety_checker"]
|
87 |
+
_callback_tensor_inputs = ["latents", "image_latents"]
|
88 |
+
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
vae: AutoencoderKL,
|
92 |
+
unet: UNet2DConditionModel,
|
93 |
+
scheduler: KarrasDiffusionSchedulers,
|
94 |
+
safety_checker: StableDiffusionSafetyChecker,
|
95 |
+
feature_extractor: CLIPImageProcessor,
|
96 |
+
requires_safety_checker: bool = False,
|
97 |
+
null_prompt: str = "../utils/null_prompt.pt"
|
98 |
+
):
|
99 |
+
super().__init__()
|
100 |
+
|
101 |
+
if safety_checker is None and requires_safety_checker:
|
102 |
+
logger.warning(
|
103 |
+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
104 |
+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
105 |
+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
106 |
+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
107 |
+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
108 |
+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
109 |
+
)
|
110 |
+
|
111 |
+
if safety_checker is not None and feature_extractor is None:
|
112 |
+
raise ValueError(
|
113 |
+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
114 |
+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
115 |
+
)
|
116 |
+
|
117 |
+
self.register_modules(
|
118 |
+
vae=vae,
|
119 |
+
unet=unet,
|
120 |
+
scheduler=scheduler,
|
121 |
+
safety_checker=safety_checker,
|
122 |
+
feature_extractor=feature_extractor,
|
123 |
+
)
|
124 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
125 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
126 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
127 |
+
self.null_prompt_embeds = torch.load(os.path.join(Path(__file__).parent.absolute(), null_prompt), map_location="cpu")
|
128 |
+
|
129 |
+
@torch.no_grad()
|
130 |
+
def __call__(
|
131 |
+
self,
|
132 |
+
image: PipelineImageInput = None,
|
133 |
+
flow: torch.Tensor = None,
|
134 |
+
num_inference_steps: int = 20,
|
135 |
+
image_guidance_scale: float = 1.5,
|
136 |
+
flow_guidance_scale: float = 1.5,
|
137 |
+
eta: float = 0.0,
|
138 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
139 |
+
latents: Optional[torch.Tensor] = None,
|
140 |
+
output_type: Optional[str] = "pil",
|
141 |
+
return_dict: bool = True,
|
142 |
+
callback_on_step_end: Optional[
|
143 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
144 |
+
] = None,
|
145 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
146 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
147 |
+
**kwargs,
|
148 |
+
):
|
149 |
+
r"""
|
150 |
+
The call function to the pipeline for generation.
|
151 |
+
|
152 |
+
Args:
|
153 |
+
image (`torch.Tensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
154 |
+
`Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept
|
155 |
+
image latents as `image`, but if passing latents directly it is not encoded again. We only support batch size of 1 for now.
|
156 |
+
flow: torch.Tensor = None,
|
157 |
+
Optical flow tensor to be used as a condition for the image generation. We only support batch size of 1 for now.
|
158 |
+
num_inference_steps (`int`, *optional*, defaults to 20):
|
159 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
160 |
+
expense of slower inference.
|
161 |
+
image_guidance_scale (`float`, *optional*, defaults to 1.5):
|
162 |
+
Push the generated image towards the initial `image`. Image guidance scale is enabled by setting
|
163 |
+
`image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely
|
164 |
+
linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a
|
165 |
+
value of at least `1`.
|
166 |
+
flow_guidance_scale (`float`, *optional*, defaults to 1.5):
|
167 |
+
Apply the flow guidance to the image generation. Higher values of `flow_guidance_scale` encourage
|
168 |
+
the model to follow the flow stronger.
|
169 |
+
eta (`float`, *optional*, defaults to 0.0):
|
170 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
171 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
172 |
+
generator (`torch.Generator`, *optional*):
|
173 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
174 |
+
generation deterministic.
|
175 |
+
latents (`torch.Tensor`, *optional*):
|
176 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
177 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
178 |
+
tensor is generated by sampling using the supplied random `generator`.
|
179 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
180 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
181 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
182 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
183 |
+
plain tuple.
|
184 |
+
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
185 |
+
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
186 |
+
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
187 |
+
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
188 |
+
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
189 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
190 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
191 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
192 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
193 |
+
cross_attention_kwargs (`dict`, *optional*):
|
194 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
195 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
196 |
+
"""
|
197 |
+
|
198 |
+
callback = kwargs.pop("callback", None)
|
199 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
200 |
+
|
201 |
+
if callback is not None:
|
202 |
+
deprecate(
|
203 |
+
"callback",
|
204 |
+
"1.0.0",
|
205 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
206 |
+
)
|
207 |
+
if callback_steps is not None:
|
208 |
+
deprecate(
|
209 |
+
"callback_steps",
|
210 |
+
"1.0.0",
|
211 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
212 |
+
)
|
213 |
+
|
214 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
215 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
216 |
+
|
217 |
+
# 0. Check inputs
|
218 |
+
self.check_inputs(
|
219 |
+
callback_steps,
|
220 |
+
callback_on_step_end_tensor_inputs,
|
221 |
+
)
|
222 |
+
self._image_guidance_scale = image_guidance_scale
|
223 |
+
self._flow_guidance_scale = flow_guidance_scale
|
224 |
+
|
225 |
+
device = self._execution_device
|
226 |
+
|
227 |
+
if image is None or flow is None:
|
228 |
+
raise ValueError("`image` or `flow` input cannot be undefined.")
|
229 |
+
|
230 |
+
# 1. Define call parameters
|
231 |
+
|
232 |
+
# 2. Encode input prompt
|
233 |
+
prompt_embeds = self._encode_prompt(
|
234 |
+
device,
|
235 |
+
self.do_classifier_free_guidance,
|
236 |
+
)
|
237 |
+
|
238 |
+
# 3. Preprocess image
|
239 |
+
image = self.image_processor.preprocess(image)
|
240 |
+
assert image.shape[0] == 1 and flow.shape[0] == 1, "Batch size must be 1 for now."
|
241 |
+
|
242 |
+
# 4. set timesteps
|
243 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
244 |
+
timesteps = self.scheduler.timesteps
|
245 |
+
|
246 |
+
# 5. Prepare Image latents
|
247 |
+
image_latents = self.prepare_image_latents(
|
248 |
+
image,
|
249 |
+
flow,
|
250 |
+
prompt_embeds.dtype,
|
251 |
+
device,
|
252 |
+
self.do_classifier_free_guidance,
|
253 |
+
)
|
254 |
+
|
255 |
+
height, width = image_latents.shape[-2:]
|
256 |
+
height = height * self.vae_scale_factor
|
257 |
+
width = width * self.vae_scale_factor
|
258 |
+
|
259 |
+
# 6. Prepare latent variables
|
260 |
+
num_channels_latents = self.vae.config.latent_channels
|
261 |
+
latents = self.prepare_latents(
|
262 |
+
num_channels_latents,
|
263 |
+
height,
|
264 |
+
width,
|
265 |
+
prompt_embeds.dtype,
|
266 |
+
device,
|
267 |
+
generator,
|
268 |
+
latents,
|
269 |
+
)
|
270 |
+
|
271 |
+
# 7. Check that shapes of latents and image match the UNet channels
|
272 |
+
num_channels_image = image_latents.shape[1]
|
273 |
+
if num_channels_latents + num_channels_image != self.unet.config.in_channels:
|
274 |
+
raise ValueError(
|
275 |
+
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
276 |
+
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
277 |
+
f" `num_channels_image`: {num_channels_image} "
|
278 |
+
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
|
279 |
+
" `pipeline.unet` or your `image` input."
|
280 |
+
)
|
281 |
+
|
282 |
+
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
283 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
284 |
+
|
285 |
+
# 9. Denoising loop
|
286 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
287 |
+
self._num_timesteps = len(timesteps)
|
288 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
289 |
+
for i, t in enumerate(timesteps):
|
290 |
+
# Expand the latents if we are doing classifier free guidance.
|
291 |
+
# The latents are expanded 3 times because for image / flow guidance
|
292 |
+
latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents
|
293 |
+
|
294 |
+
# concat latents, image_latents in the channel dimension
|
295 |
+
scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
296 |
+
scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1)
|
297 |
+
|
298 |
+
# predict the noise residual
|
299 |
+
noise_pred = self.unet(
|
300 |
+
scaled_latent_model_input,
|
301 |
+
t,
|
302 |
+
encoder_hidden_states=prompt_embeds,
|
303 |
+
added_cond_kwargs=None,
|
304 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
305 |
+
return_dict=False,
|
306 |
+
)[0]
|
307 |
+
|
308 |
+
# perform guidance
|
309 |
+
if self.do_classifier_free_guidance:
|
310 |
+
noise_pred_flow, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
|
311 |
+
noise_pred = (
|
312 |
+
noise_pred_uncond
|
313 |
+
+ self._image_guidance_scale * (noise_pred_image - noise_pred_uncond)
|
314 |
+
+ self._flow_guidance_scale * (noise_pred_flow - noise_pred_image)
|
315 |
+
)
|
316 |
+
|
317 |
+
# compute the previous noisy sample x_t -> x_t-1
|
318 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
319 |
+
|
320 |
+
if callback_on_step_end is not None:
|
321 |
+
callback_kwargs = {}
|
322 |
+
for k in callback_on_step_end_tensor_inputs:
|
323 |
+
callback_kwargs[k] = locals()[k]
|
324 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
325 |
+
|
326 |
+
latents = callback_outputs.pop("latents", latents)
|
327 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
328 |
+
image_latents = callback_outputs.pop("image_latents", image_latents)
|
329 |
+
|
330 |
+
# call the callback, if provided
|
331 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
332 |
+
progress_bar.update()
|
333 |
+
if callback is not None and i % callback_steps == 0:
|
334 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
335 |
+
callback(step_idx, t, latents)
|
336 |
+
|
337 |
+
if not output_type == "latent":
|
338 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
339 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
340 |
+
else:
|
341 |
+
image = latents
|
342 |
+
has_nsfw_concept = None
|
343 |
+
|
344 |
+
if has_nsfw_concept is None:
|
345 |
+
do_denormalize = [True] * image.shape[0]
|
346 |
+
else:
|
347 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
348 |
+
|
349 |
+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
350 |
+
|
351 |
+
# Offload all models
|
352 |
+
self.maybe_free_model_hooks()
|
353 |
+
|
354 |
+
if not return_dict:
|
355 |
+
return (image, has_nsfw_concept)
|
356 |
+
|
357 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
358 |
+
|
359 |
+
def _encode_prompt(
|
360 |
+
self,
|
361 |
+
device,
|
362 |
+
do_classifier_free_guidance,
|
363 |
+
):
|
364 |
+
prompt_embeds = self.null_prompt_embeds.to(dtype=torch.float16, device=device) # 1 77 512
|
365 |
+
|
366 |
+
if do_classifier_free_guidance: # We are only doing cfg for image and flow
|
367 |
+
prompt_embeds = torch.cat([prompt_embeds, prompt_embeds, prompt_embeds]) # 3 77 512
|
368 |
+
|
369 |
+
return prompt_embeds
|
370 |
+
|
371 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
372 |
+
def run_safety_checker(self, image, device, dtype):
|
373 |
+
if self.safety_checker is None:
|
374 |
+
has_nsfw_concept = None
|
375 |
+
else:
|
376 |
+
if torch.is_tensor(image):
|
377 |
+
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
378 |
+
else:
|
379 |
+
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
380 |
+
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
381 |
+
image, has_nsfw_concept = self.safety_checker(
|
382 |
+
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
383 |
+
)
|
384 |
+
return image, has_nsfw_concept
|
385 |
+
|
386 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
387 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
388 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
389 |
+
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def check_inputs(
        self,
        callback_steps,
        callback_on_step_end_tensor_inputs=None,
    ):
        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

    def prepare_latents(self, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (
            1,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != 1:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but we only support a single batch for now."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def prepare_image_latents(
        self, image, flow, dtype, device, do_classifier_free_guidance, generator=None
    ):
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        image = image.to(device=device, dtype=dtype)

        if image.shape[1] == 4:
            image_latents = image
        else:
            image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")

        image_latents_flow_cond = torch.cat([image_latents, flow.to(device)], dim=1)

        if do_classifier_free_guidance:
            image_latents_flow_uncond = torch.cat([image_latents, torch.zeros_like(flow).to(device)], dim=1)
            image_latents_uncond = torch.zeros_like(image_latents_flow_cond)
            image_latents_final = torch.cat([image_latents_flow_cond, image_latents_flow_uncond, image_latents_uncond], dim=0)
        else:
            image_latents_final = image_latents_flow_cond

        return image_latents_final

    @property
    def image_guidance_scale(self):
        return self._image_guidance_scale

    @property
    def flow_guidance_scale(self):
        return self._flow_guidance_scale

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def do_classifier_free_guidance(self):
        return self._image_guidance_scale > 1 or self._flow_guidance_scale > 1
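The three-way batch built in `prepare_image_latents` (flow-conditioned, image-only, fully unconditional) is what `do_classifier_free_guidance` enables. As an illustration only — the actual blending happens in the pipeline's denoising loop, which is not part of this excerpt — a dual-guidance combination in the InstructPix2Pix style could look like the sketch below; the helper name and chunk ordering are assumptions, not code from this repo.

```python
import torch

def combine_dual_cfg(noise_pred: torch.Tensor,
                     image_guidance_scale: float,
                     flow_guidance_scale: float) -> torch.Tensor:
    # Hypothetical helper: split a 3x-batched UNet prediction ordered as
    # [flow+image conditioned, image-only conditioned, unconditional],
    # mirroring the concatenation order used in prepare_image_latents.
    noise_pred_flow, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
    return (
        noise_pred_uncond
        + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
        + flow_guidance_scale * (noise_pred_flow - noise_pred_image)
    )
```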
InstDrag/flowgen/models.py
ADDED
@@ -0,0 +1,161 @@
# Modified from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py

import torch
import torch.nn as nn
import functools

class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.
        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet submodule with skip connections.

        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)    -- if this module is the outermost module
            innermost (bool)    -- if this module is the innermost module
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)

        if norm_layer == nn.GroupNorm:
            downnorm = norm_layer(32, inner_nc)
        else:
            downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        if norm_layer == nn.GroupNorm:
            if outer_nc % 32 != 0:
                upnorm = norm_layer(outer_nc, outer_nc)  # Layer Norm
            else:
                upnorm = norm_layer(32, outer_nc)
        else:
            upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:  # add skip connections
            return torch.cat([x, self.model(x)], 1)

class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc, output_nc=2, num_downs=8, ngf=64, norm_layer=nn.GroupNorm, use_dropout=True):
        """Construct a Unet generator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
                               an image of size 128x128 will become of size 1x1 at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer

        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super(UnetGenerator, self).__init__()
        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
        for i in range(num_downs - 5):  # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer

    def forward(self, input):
        """Standard forward"""
        return self.model(input)

class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=6, norm_layer=nn.GroupNorm):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(32, ndf * nf_mult) if norm_layer == nn.GroupNorm else norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(32, ndf * nf_mult) if norm_layer == nn.GroupNorm else norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
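A minimal sketch of how these pix2pix-style modules could be instantiated for flow generation; the channel counts, input layout, and import path are illustrative assumptions, not the repo's training or demo configuration.

```python
import torch
from flowgen.models import UnetGenerator, NLayerDiscriminator  # import path assumed

# Assumed layout: an RGB image concatenated with a 2-channel sparse drag-flow map
# as generator input, and a dense 2-channel optical flow as generator output.
netG = UnetGenerator(input_nc=5, output_nc=2, num_downs=8, ngf=64)
netD = NLayerDiscriminator(input_nc=2, ndf=64, n_layers=6)

x = torch.randn(1, 5, 512, 512)   # image + sparse flow (hypothetical input)
dense_flow = netG(x)              # [1, 2, 512, 512], squashed to [-1, 1] by the final Tanh
patch_logits = netD(dense_flow)   # PatchGAN map of per-patch realism scores
print(dense_flow.shape, patch_logits.shape)
```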
InstDrag/utils/flow_utils.py
ADDED
@@ -0,0 +1,143 @@
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F

def make_colorwheel():
    """
    Generates a color wheel for optical flow visualization as presented in:
        Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
        URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf

    Code follows the original C++ source code of Daniel Scharstein
    and the Matlab source code of Deqing Sun.

    Returns:
        np.ndarray: Color wheel
    """

    RY = 15
    YG = 6
    GC = 4
    CB = 11
    BM = 13
    MR = 6

    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = np.zeros((ncols, 3))
    col = 0

    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
    col = col+RY
    # YG
    colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
    colorwheel[col:col+YG, 1] = 255
    col = col+YG
    # GC
    colorwheel[col:col+GC, 1] = 255
    colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
    col = col+GC
    # CB
    colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
    colorwheel[col:col+CB, 2] = 255
    col = col+CB
    # BM
    colorwheel[col:col+BM, 2] = 255
    colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
    col = col+BM
    # MR
    colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
    colorwheel[col:col+MR, 0] = 255
    return colorwheel

def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Applies the flow color wheel to (possibly clipped) flow components u and v.

    According to the C++ source code of Daniel Scharstein
    and the Matlab source code of Deqing Sun.

    Args:
        u (np.ndarray): Input horizontal flow of shape [H,W]
        v (np.ndarray): Input vertical flow of shape [H,W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    colorwheel = make_colorwheel()  # shape [55x3]
    ncols = colorwheel.shape[0]
    rad = np.sqrt(np.square(u) + np.square(v))
    a = np.arctan2(-v, -u)/np.pi
    fk = (a+1) / 2*(ncols-1)
    k0 = np.floor(fk).astype(np.int32)
    k1 = k0 + 1
    k1[k1 == ncols] = 0
    f = fk - k0
    for i in range(colorwheel.shape[1]):
        tmp = colorwheel[:,i]
        col0 = tmp[k0] / 255.0
        col1 = tmp[k1] / 255.0
        col = (1-f)*col0 + f*col1
        idx = (rad <= 1)
        col[idx] = 1 - rad[idx] * (1-col[idx])
        col[~idx] = col[~idx] * 0.75  # out of range
        # Note the 2-i => BGR instead of RGB
        ch_idx = 2-i if convert_to_bgr else i
        flow_image[:,:,ch_idx] = np.floor(255 * col)
    return flow_image

def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False, max_flow=None):
    """
    Expects a flow image of shape [2,H,W].

    Args:
        flow_uv (torch.Tensor): Flow UV image of shape [2,H,W]
        clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        PIL Image: Flow visualization image
    """
    flow_uv = flow_uv.permute(1, 2, 0).cpu().numpy()  # change to [H,W,2] and convert to numpy

    if clip_flow is not None:
        flow_uv = np.clip(flow_uv, 0, clip_flow)
    u = flow_uv[:,:,0]
    v = flow_uv[:,:,1]
    if max_flow is None:
        rad = np.sqrt(np.square(u) + np.square(v))
        rad_max = np.max(rad)
    else:
        rad_max = max_flow
    epsilon = 1e-5
    u = u / (rad_max + epsilon)
    v = v / (rad_max + epsilon)
    flow_image = flow_uv_to_colors(u, v, convert_to_bgr)

    return Image.fromarray(flow_image)

def resize_flow(flow, size, scale_type="none", mode="bicubic"):
    """
    Resize the flow tensor (Bx2xHxW) to the given size (HxW).
    The flow tensor is in the range [-ori_w, ori_w] (u) and [-ori_h, ori_h] (v).
    Size should be a tuple (H, W).
    """
    ori_h, ori_w = flow.shape[2:]
    flow = F.interpolate(flow, size=size, mode=mode, align_corners=False)

    if scale_type == "scale" and (ori_h != size[0] or ori_w != size[1]):
        flow[:,0,:,:] *= size[1] / ori_w
        flow[:,1,:,:] *= size[0] / ori_h
    elif scale_type == "normalize_fixed":  # normalize to -1 ~ 1
        flow[:,0,:,:] /= ori_w
        flow[:,1,:,:] /= ori_h
    elif scale_type == "normalize_max":
        max_flow_x = torch.amax(torch.abs(flow[:, 0, :, :]), dim=(1, 2))
        max_flow_y = torch.amax(torch.abs(flow[:, 1, :, :]), dim=(1, 2))
        flow[:, 0, :, :] /= max_flow_x.view(-1, 1, 1)
        flow[:, 1, :, :] /= max_flow_y.view(-1, 1, 1)
    return flow
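A short usage sketch for these helpers; the tensor values, output filename, import path, and the `normalize_fixed` choice are illustrative assumptions rather than settings taken from elsewhere in the repo.

```python
import torch
from utils.flow_utils import flow_to_image, resize_flow  # import path assumed

# Visualize a random dense flow field with the Middlebury color coding.
flow = torch.randn(1, 2, 512, 512) * 20.0    # Bx2xHxW displacements in pixels
flow_to_image(flow[0]).save("flow_vis.png")  # takes one [2,H,W] tensor, returns a PIL image

# Downscale to a 64x64 grid; "normalize_fixed" divides u by the original width and
# v by the original height, bringing pixel-unit flow into roughly [-1, 1].
flow_small = resize_flow(flow, size=(64, 64), scale_type="normalize_fixed")
print(flow_small.shape)  # torch.Size([1, 2, 64, 64])
```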
InstDrag/utils/null_prompt.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7eb3e5fc1308277b9288aa665562eb688e4aa36e6bcbc422083b707468e84d2a
size 237655