smarques committed on
Commit fd0db3a · 1 Parent(s): 6022933

checkout InstantDrag

InstDrag/.gitignore ADDED
@@ -0,0 +1,163 @@
1
+ demo/results
2
+ demo/checkpoints
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # poetry
101
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105
+ #poetry.lock
106
+
107
+ # pdm
108
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109
+ #pdm.lock
110
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111
+ # in version control.
112
+ # https://pdm.fming.dev/#use-with-ide
113
+ .pdm.toml
114
+
115
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116
+ __pypackages__/
117
+
118
+ # Celery stuff
119
+ celerybeat-schedule
120
+ celerybeat.pid
121
+
122
+ # SageMath parsed files
123
+ *.sage.py
124
+
125
+ # Environments
126
+ .env
127
+ .venv
128
+ env/
129
+ venv/
130
+ ENV/
131
+ env.bak/
132
+ venv.bak/
133
+
134
+ # Spyder project settings
135
+ .spyderproject
136
+ .spyproject
137
+
138
+ # Rope project settings
139
+ .ropeproject
140
+
141
+ # mkdocs documentation
142
+ /site
143
+
144
+ # mypy
145
+ .mypy_cache/
146
+ .dmypy.json
147
+ dmypy.json
148
+
149
+ # Pyre type checker
150
+ .pyre/
151
+
152
+ # pytype static type analyzer
153
+ .pytype/
154
+
155
+ # Cython debug symbols
156
+ cython_debug/
157
+
158
+ # PyCharm
159
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
162
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ #.idea/
InstDrag/README.md ADDED
@@ -0,0 +1,76 @@
1
+ # InstantDrag
2
+
3
+ <p align="center">
4
+ <img src="assets/demo.gif" alt="Demo video">
5
+ </p>
6
+
7
+ <br/>
8
+
9
+ Official implementation of the paper **"InstantDrag: Improving Interactivity in Drag-based Image Editing"** (SIGGRAPH Asia 2024).
10
+
11
+ <p align="center">
12
+ <a href="https://arxiv.org/abs/2409.08857"><img src="https://img.shields.io/badge/arxiv-2409.08857-b31b1b"></a>
13
+ <a href="https://joonghyuk.com/instantdrag-web/"><img src="https://img.shields.io/badge/Project%20Page-InstantDrag-blue"></a>
14
+ <a href="https://huggingface.co/alex4727/InstantDrag"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Model-forestgreen"></a>
15
+ </p>
16
+
17
+ ---
18
+
19
+ ## Setup
20
+
21
+ 1. Create and activate a conda environment:
22
+ ```bash
23
+ conda create -n instantdrag python=3.10 -y
24
+ conda activate instantdrag
25
+ ```
26
+
27
+ 2. Install PyTorch:
28
+ ```bash
29
+ pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu121
30
+ ```
31
+
32
+ 3. Install other dependencies:
33
+ ```bash
34
+ pip install transformers==4.44.2 diffusers==0.30.1 accelerate==0.33.0 gradio==4.44.0 opencv-python
35
+ ```
36
+ **Note:** Exact version matching may not be necessary for all dependencies.
37
+
38
+ ## Demo
39
+
40
+ To run the demo:
41
+ ```bash
42
+ cd demo/
43
+ CUDA_VISIBLE_DEVICES=0 python run_demo.py
44
+ ```
45
+ ### Disclaimer
46
+
47
+ - Our **base** models are trained **solely** on real-world talking-head (facial) videos, with a focus on achieving **fast, fine-grained facial editing without metadata**. The preliminary signs of generalization to other types of scenes, without fine-tuning, should be considered an experimental byproduct and may not hold in many cases. Please see Appendix A of our paper for more information.
48
+ - This is a research project, **NOT** a commercial product. Use at your own risk.
49
+
50
+ ### Usage Instructions & Tips
51
+
52
+ - Upload and preprocess an image using the Gradio interface.
53
+ - Click to define source and target point pairs on the image.
54
+ - Adjust settings in the "Configs" tab.
55
+ - We provide two checkpoints for FlowGen: config-2 (default, used for most figures in the paper) and config-3 (used for the benchmark table in the paper). We generally recommend config-2 for most cases, including drags based on a few keypoints. For extremely fine-grained editing with many drags (e.g., the 68-keypoint drags used in the benchmark), config-3 may be better suited as it produces more local movements.
56
+ - If the image moves too much or too little, try adjusting the image or flow guidance scales (values of 1-2 usually work well; the flow guidance scale can be set higher).
57
+ - If you observe loss of identity or noisy artifacts, increasing the image guidance scale or the number of sampling steps can help (a [1.75, 1.5] guidance setting is also a good choice for facial images).
58
+ - Click `Run` to perform the editing.
59
+ - We recommend first viewing the example videos (on the project page or in the demo .gif) and the paper figures to understand the model's capabilities. Then begin with facial images and fine-grained keypoint drags before progressing to more complex motions.
60
+ - As noted in the paper, our model may struggle with large motions that exceed the capabilities of the optical flow estimation networks used for training data extraction.
61
+ - Notes on FlowGen Output Scale
62
+ - In many cases, especially for unseen domains, FlowGen's output does not precisely span the -1 to 1 range expected by FlowDiffusion's fixed-size normalization process. For all figures and benchmarks in our paper, we applied a static multiplier of 2, chosen from observations, to bring FlowGen's output into the expected range. However, we found that forcefully rescaling the output to -1 to 1 also works well, so we set this as the default behavior (when the value is -1). While not recommended, you can manually modify this value to scale FlowGen's output before it is fed to FlowDiffusion, producing larger or smaller motions.
63
+
64
+ **Note:** The initial run may take longer as the models are loaded onto the GPU.
65
+
66
+ ## BibTeX
67
+ If you find this work useful, please cite it as below!
68
+ ```
69
+ @inproceedings{shin2024instantdrag,
70
+ title = {{InstantDrag: Improving Interactivity in Drag-based Image Editing}},
71
+ author = {Shin, Joonghyuk and Choi, Daehyeon and Park, Jaesik},
72
+ booktitle = {ACM SIGGRAPH Asia 2024 Conference Proceedings},
73
+ year = {2024},
74
+ pages = {1--10},
75
+ }
76
+ ```
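For readers who prefer to script the editing step instead of clicking through Gradio, the `InstantDragPipeline` class added in `demo/demo_utils.py` below can be driven directly. A minimal sketch, run from the `demo/` directory: the checkpoint names and the click coordinates here are placeholders (the real filenames are whatever `snapshot_download` places under `demo/checkpoints/`), not values defined by this commit.

```python
import numpy as np
import torch
from PIL import Image

from demo_utils import InstantDragPipeline, LENGTH  # LENGTH == 512, the demo's working resolution

pipeline = InstantDragPipeline(seed=42, device="cuda", dtype=torch.float16)

# Square RGB input, resized to the 512x512 canvas the demo operates on.
image = np.array(Image.open("samples/monalisa.jpg").convert("RGB").resize((LENGTH, LENGTH)))

# Points alternate source, target in (x, y) pixel coordinates; one drag is shown here.
selected_points = [[256, 300], [256, 260]]  # hypothetical coordinates

edited = pipeline.run(
    image,
    selected_points,
    flowgen_ckpt="flowgen-config2.safetensors",  # placeholder file name inside demo/checkpoints/
    flowdiffusion_ckpt="flowdiffusion",          # placeholder directory name inside demo/checkpoints/
    image_guidance=1.5,
    flow_guidance=1.5,
    flowgen_output_scale=-1.0,
    num_steps=20,
    save_results=False,
)
Image.fromarray(edited).save("edited.png")
```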
InstDrag/demo/demo_utils.py ADDED
@@ -0,0 +1,242 @@
1
+ import sys
2
+ sys.path.append("../")
3
+
4
+ import os
5
+ import re
6
+ import time
7
+ import datetime
8
+ from copy import deepcopy
9
+
10
+ import numpy as np
11
+ import cv2
12
+ import torch
13
+ import torch.nn.functional as F
14
+ import gradio as gr
15
+ from PIL import Image
16
+ from PIL.ImageOps import exif_transpose
17
+ from safetensors.torch import load_file
18
+
19
+ from utils.flow_utils import flow_to_image, resize_flow
20
+ from flowgen.models import UnetGenerator
21
+ from flowdiffusion.pipeline import FlowDiffusionPipeline
22
+
23
+ LENGTH = 512
24
+ FLOWGAN_RESOLUTION = [256, 256] # HxW
25
+ FLOWDIFFUSION_RESOLUTION = [512, 512] # HxW
26
+
27
+ def process_img(image):
28
+ if image["composite"] is not None and not np.all(image["composite"] == 0):
29
+ original_image = Image.fromarray(image["composite"]).resize((LENGTH, LENGTH), Image.BICUBIC)
30
+ original_image = np.array(exif_transpose(original_image))
31
+ return original_image, [], gr.Image(value=deepcopy(original_image), interactive=False)
32
+ else:
33
+ return (
34
+ gr.Image(value=None, interactive=False),
35
+ [],
36
+ gr.Image(value=None, interactive=False),
37
+ )
38
+
39
+ def get_points(img, sel_pix, evt: gr.SelectData):
40
+ sel_pix.append(evt.index)
41
+ print(sel_pix)
42
+ points = []
43
+ for idx, point in enumerate(sel_pix):
44
+ if idx % 2 == 0:
45
+ cv2.circle(img, tuple(point), 4, (255, 0, 0), -1)
46
+ else:
47
+ cv2.circle(img, tuple(point), 4, (0, 0, 255), -1)
48
+ points.append(tuple(point))
49
+ if len(points) == 2:
50
+ cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 2, tipLength=0.5)
51
+ points = []
52
+ img = img if isinstance(img, np.ndarray) else np.array(img)
53
+ return img
54
+
55
+ def display_points(img, predefined_points, save_results):
56
+ selected_points = []  # fall back to an empty list when no predefined points are given
+ if predefined_points != "":
57
+ predefined_points = predefined_points.split()
58
+ predefined_points = [int(re.sub(r'[^0-9]', '', point)) for point in predefined_points]
59
+ processed_points = []
60
+ for i, point in enumerate(predefined_points):
61
+ if i % 2 == 0:
62
+ processed_points.append([point, predefined_points[i+1]])
63
+ selected_points = processed_points
64
+
65
+ print(selected_points)
66
+ points = []
67
+ for idx, point in enumerate(selected_points):
68
+ if idx % 2 == 0:
69
+ cv2.circle(img, tuple(point), 4, (255, 0, 0), -1)
70
+ else:
71
+ cv2.circle(img, tuple(point), 4, (0, 0, 255), -1)
72
+ points.append(tuple(point))
73
+ if len(points) == 2:
74
+ cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 2, tipLength=0.5)
75
+ points = []
76
+ img = img if isinstance(img, np.ndarray) else np.array(img)
77
+
78
+ if save_results:
79
+ if not os.path.isdir("results/drag_inst_viz"):
80
+ os.makedirs("results/drag_inst_viz")
81
+ save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
82
+ to_save_img = Image.fromarray(img)
83
+ to_save_img.save(f"results/drag_inst_viz/{save_prefix}.png")
84
+
85
+ return img
86
+
87
+ def undo_points_image(original_image):
88
+ if original_image is not None:
89
+ return original_image, []
90
+ else:
91
+ return gr.Image(value=None, interactive=False), []
92
+
93
+ def clear_all():
94
+ return (
95
+ gr.Image(value=None, interactive=True),
96
+ gr.Image(value=None, interactive=False),
97
+ gr.Image(value=None, interactive=False),
98
+ [],
99
+ None
100
+ )
101
+
102
+ class InstantDragPipeline:
103
+ def __init__(self, seed=9999, device="cuda", dtype=torch.float16):
104
+ self.seed = seed
105
+ self.device = device
106
+ self.dtype = dtype
107
+ self.generator = torch.Generator(device=device).manual_seed(seed)
108
+ self.flowgen_ckpt, self.flowdiffusion_ckpt = None, None
109
+ self.model_config = dict()
110
+
111
+ def build_model(self):
112
+ print("Building model...")
113
+ if self.flowgen_ckpt != self.model_config["flowgen_ckpt"]:
114
+ self.flowgen = UnetGenerator(input_nc=5, output_nc=2)
115
+ self.flowgen.load_state_dict(
116
+ load_file(os.path.join("checkpoints/", self.model_config["flowgen_ckpt"]), device="cpu")
117
+ )
118
+ self.flowgen.to(self.device)
119
+ self.flowgen.eval()
120
+ self.flowgen_ckpt = self.model_config["flowgen_ckpt"]
121
+
122
+ if self.flowdiffusion_ckpt != self.model_config["flowdiffusion_ckpt"]:
123
+ self.flowdiffusion = FlowDiffusionPipeline.from_pretrained(
124
+ os.path.join("checkpoints/", self.model_config["flowdiffusion_ckpt"]),
125
+ torch_dtype=self.dtype,
126
+ safety_checker=None
127
+ )
128
+ self.flowdiffusion.to(self.device)
129
+ self.flowdiffusion_ckpt = self.model_config["flowdiffusion_ckpt"]
130
+
131
+ def drag(self, original_image, selected_points, save_results):
132
+ scale = self.model_config["flowgen_output_scale"]
133
+ original_image = torch.tensor(original_image).permute(2, 0, 1).unsqueeze(0).float() # 1, 3, 512, 512
134
+ original_image = 2 * (original_image / 255.) - 1 # Normalize to [-1, 1]
135
+ original_image = original_image.to(self.device)
136
+
137
+ source_points = []
138
+ target_points = []
139
+ for idx, point in enumerate(selected_points):
140
+ cur_point = torch.tensor([point[0], point[1]]) # x, y
141
+ if idx % 2 == 0:
142
+ source_points.append(cur_point)
143
+ else:
144
+ target_points.append(cur_point)
145
+
146
+ torch.cuda.synchronize()
147
+ start_time = time.time()
148
+
149
+ # Generate sparse flow vectors
150
+ point_vector_map = torch.zeros((1, 2, LENGTH, LENGTH))
151
+ for source_point, target_point in zip(source_points, target_points):
152
+ cur_x, cur_y = source_point[0], source_point[1]
153
+ target_x, target_y = target_point[0], target_point[1]
154
+ vec_x = target_x - cur_x
155
+ vec_y = target_y - cur_y
156
+ point_vector_map[0, 0, int(cur_y), int(cur_x)] = vec_x
157
+ point_vector_map[0, 1, int(cur_y), int(cur_x)] = vec_y
158
+ point_vector_map = point_vector_map.to(self.device)
159
+
160
+ # Sample-wise normalize the flow vectors
161
+ factor_x = torch.amax(torch.abs(point_vector_map[:, 0, :, :]), dim=(1, 2)).view(-1, 1, 1).to(self.device)
162
+ factor_y = torch.amax(torch.abs(point_vector_map[:, 1, :, :]), dim=(1, 2)).view(-1, 1, 1).to(self.device)
163
+ if factor_x >= 1e-8: # Avoid division by zero
164
+ point_vector_map[:, 0, :, :] /= factor_x
165
+ if factor_y >= 1e-8: # Avoid division by zero
166
+ point_vector_map[:, 1, :, :] /= factor_y
167
+
168
+ with torch.inference_mode():
169
+ gan_input_image = F.interpolate(original_image, size=FLOWGAN_RESOLUTION, mode="bicubic") # 256 x 256
170
+ point_vector_map = F.interpolate(point_vector_map, size=FLOWGAN_RESOLUTION, mode="bicubic") # 256 x 256
171
+ gan_input = torch.cat([gan_input_image, point_vector_map], dim=1)
172
+ flow = self.flowgen(gan_input) # -1 ~ 1
173
+
174
+ if scale == -1.0:
175
+ flow[:, 0, :, :] *= 1.0 / torch.amax(torch.abs(flow[:, 0, :, :]), dim=(1, 2)).view(-1, 1, 1) # force the range to be [-1 ~ 1]
176
+ flow[:, 1, :, :] *= 1.0 / torch.amax(torch.abs(flow[:, 1, :, :]), dim=(1, 2)).view(-1, 1, 1) # force the range to be [-1 ~ 1]
177
+ else:
178
+ flow[:, 0, :, :] *= scale # manually adjust the scale
179
+ flow[:, 1, :, :] *= scale # manually adjust the scale
180
+
181
+ if factor_x >= 1e-8:
182
+ flow[:, 0, :, :] *= factor_x * (FLOWGAN_RESOLUTION[1] / original_image.shape[3]) # width
183
+ else:
184
+ flow[:, 0, :, :] *= 0
185
+ if factor_y >= 1e-8:
186
+ flow[:, 1, :, :] *= factor_y * (FLOWGAN_RESOLUTION[0] / original_image.shape[2]) # height
187
+ else:
188
+ flow[:, 1, :, :] *= 0
189
+
190
+ resized_flow = resize_flow(flow, (FLOWDIFFUSION_RESOLUTION[0]//8, FLOWDIFFUSION_RESOLUTION[1]//8), scale_type="normalize_fixed")
191
+
192
+ kwargs = {
193
+ "image": original_image.to(self.dtype),
194
+ "flow": resized_flow.to(self.dtype),
195
+ "num_inference_steps": self.model_config['n_inference_step'],
196
+ "image_guidance_scale": self.model_config['image_guidance'],
197
+ "flow_guidance_scale": self.model_config['flow_guidance'],
198
+ "generator": self.generator,
199
+ }
200
+ edited_image = self.flowdiffusion(**kwargs).images[0]
201
+
202
+ end_time = time.time()
203
+ inference_time = end_time - start_time
204
+ print(f"Inference Time: {inference_time} seconds")
205
+
206
+ if save_results:
207
+ save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
208
+ if not os.path.isdir("results/flows"):
209
+ os.makedirs("results/flows")
210
+ np.save(f"results/flows/{save_prefix}.npy", flow[0].detach().cpu().numpy())
211
+ if not os.path.isdir("results/flow_visualized"):
212
+ os.makedirs("results/flow_visualized")
213
+ flow_to_image(flow[0].detach()).save(f"results/flow_visualized/{save_prefix}.png")
214
+ if not os.path.isdir("results/edited_images"):
215
+ os.makedirs("results/edited_images")
216
+ edited_image.save(f"results/edited_images/{save_prefix}.png")
217
+ if not os.path.isdir("results/drag_instructions"):
218
+ os.makedirs("results/drag_instructions")
219
+ with open(f"results/drag_instructions/{save_prefix}.txt", "w") as f:
220
+ f.write(str(selected_points))
221
+
222
+ edited_image = np.array(edited_image)
223
+ return edited_image
224
+
225
+ def run(self, original_image, selected_points,
226
+ flowgen_ckpt, flowdiffusion_ckpt, image_guidance, flow_guidance, flowgen_output_scale,
227
+ num_steps, save_results):
228
+
229
+ self.model_config = {
230
+ "flowgen_ckpt": flowgen_ckpt,
231
+ "flowdiffusion_ckpt": flowdiffusion_ckpt,
232
+ "image_guidance": image_guidance,
233
+ "flow_guidance": flow_guidance,
234
+ "flowgen_output_scale": flowgen_output_scale,
235
+ "n_inference_step": num_steps
236
+ }
237
+
238
+ self.build_model()
239
+
240
+ edited_image = self.drag(original_image, selected_points, save_results)
241
+
242
+ return edited_image
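The heart of `drag()` above is the conversion of clicked point pairs into a sparse optical-flow map that FlowGen consumes. A compact, standalone restatement of that step (same alternating source/target convention; the function name and example coordinates are ours, not part of the repo) may help when preparing drag instructions programmatically:

```python
import torch

def sparse_flow_map(points, size=512):
    """Mirror of the sparse-flow construction in InstantDragPipeline.drag(): a 1x2xHxW map
    that is zero everywhere except at each source pixel, which stores the (dx, dy)
    displacement toward its paired target, normalized per axis to [-1, 1]."""
    flow = torch.zeros((1, 2, size, size))
    for src, tgt in zip(points[0::2], points[1::2]):   # points alternate source, target in (x, y)
        flow[0, 0, src[1], src[0]] = tgt[0] - src[0]   # horizontal displacement stored at the source pixel
        flow[0, 1, src[1], src[0]] = tgt[1] - src[1]   # vertical displacement stored at the source pixel
    for c in range(2):                                 # sample-wise, per-axis normalization
        m = flow[:, c].abs().amax(dim=(1, 2)).view(-1, 1, 1)
        if m >= 1e-8:                                  # skip all-zero channels, as drag() does
            flow[:, c] /= m
    return flow

# One drag from (100, 200) to (140, 200): a single +1 entry in the x-channel at the source pixel.
print(sparse_flow_map([[100, 200], [140, 200]]).nonzero())
```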
InstDrag/demo/run_demo.py ADDED
@@ -0,0 +1,226 @@
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ from huggingface_hub import snapshot_download
5
+ os.makedirs("checkpoints", exist_ok=True)
6
+ snapshot_download("alex4727/InstantDrag", local_dir="./checkpoints")
7
+
8
+ from demo_utils import (
9
+ process_img,
10
+ get_points,
11
+ undo_points_image,
12
+ clear_all,
13
+ InstantDragPipeline,
14
+ )
15
+
16
+ LENGTH = 480 # Length of the square area displaying/editing images
17
+
18
+ with gr.Blocks() as demo:
19
+ pipeline = InstantDragPipeline(seed=42, device="cuda", dtype=torch.float16)
20
+
21
+ # Layout definition
22
+ with gr.Row():
23
+ gr.Markdown(
24
+ """
25
+ # InstantDrag: Improving Interactivity in Drag-based Image Editing
26
+ """
27
+ )
28
+
29
+ with gr.Tab(label="InstantDrag Demo"):
30
+ selected_points = gr.State([]) # Store points
31
+ original_image = gr.State(value=None) # Store original input image
32
+
33
+ with gr.Row():
34
+ # Upload & Preprocess Image Column
35
+ with gr.Column():
36
+ gr.Markdown(
37
+ """<p style="text-align: center; font-size: 20px">Upload & Preprocess Image</p>"""
38
+ )
39
+ canvas = gr.ImageEditor(
40
+ height=LENGTH,
41
+ width=LENGTH,
42
+ type="numpy",
43
+ image_mode="RGB",
44
+ label="Preprocess Image",
45
+ show_label=True,
46
+ interactive=True,
47
+ )
48
+ with gr.Row():
49
+ save_results = gr.Checkbox(
50
+ value=False,
51
+ label="Save Results",
52
+ scale=1,
53
+ )
54
+ undo_button = gr.Button("Undo Clicked Points", scale=3)
55
+
56
+ # Click Points Column
57
+ with gr.Column():
58
+ gr.Markdown(
59
+ """<p style="text-align: center; font-size: 20px">Click Points</p>"""
60
+ )
61
+ input_image = gr.Image(
62
+ type="numpy",
63
+ label="Click Points",
64
+ show_label=True,
65
+ height=LENGTH,
66
+ width=LENGTH,
67
+ interactive=False,
68
+ show_fullscreen_button=False,
69
+ )
70
+ with gr.Row():
71
+ run_button = gr.Button("Run")
72
+
73
+ # Editing Results Column
74
+ with gr.Column():
75
+ gr.Markdown(
76
+ """<p style="text-align: center; font-size: 20px">Editing Results</p>"""
77
+ )
78
+ edited_image = gr.Image(
79
+ type="numpy",
80
+ label="Editing Results",
81
+ show_label=True,
82
+ height=LENGTH,
83
+ width=LENGTH,
84
+ interactive=False,
85
+ show_fullscreen_button=False,
86
+ )
87
+ with gr.Row():
88
+ clear_all_button = gr.Button("Clear All")
89
+
90
+ with gr.Tab("Configs - make sure to check README for details"):
91
+ with gr.Row():
92
+ with gr.Column():
93
+ with gr.Row():
94
+ flowgen_choices = sorted(
95
+ [model for model in os.listdir("checkpoints/") if "flowgen" in model]
96
+ )
97
+ flowgen_ckpt = gr.Dropdown(
98
+ value=flowgen_choices[0],
99
+ label="Select FlowGen to use",
100
+ choices=flowgen_choices,
101
+ info="config2 for most cases, config3 for more fine-grained dragging",
102
+ scale=2,
103
+ )
104
+ flowdiffusion_choices = sorted(
105
+ [model for model in os.listdir("checkpoints/") if "flowdiffusion" in model]
106
+ )
107
+ flowdiffusion_ckpt = gr.Dropdown(
108
+ value=flowdiffusion_choices[0],
109
+ label="Select FlowDiffusion to use",
110
+ choices=flowdiffusion_choices,
111
+ info="single model for all cases",
112
+ scale=1,
113
+ )
114
+ image_guidance = gr.Number(
115
+ value=1.5,
116
+ label="Image Guidance Scale",
117
+ precision=2,
118
+ step=0.1,
119
+ scale=1,
120
+ info="typically between 1.0-2.0.",
121
+ )
122
+ flow_guidance = gr.Number(
123
+ value=1.5,
124
+ label="Flow Guidance Scale",
125
+ precision=2,
126
+ step=0.1,
127
+ scale=1,
128
+ info="typically between 1.0-5.0",
129
+ )
130
+ num_steps = gr.Number(
131
+ value=20,
132
+ label="Inference Steps",
133
+ precision=0,
134
+ step=1,
135
+ scale=1,
136
+ info="typically between 20-50, 20 is usually enough",
137
+ )
138
+ flowgen_output_scale = gr.Number(
139
+ value=-1.0,
140
+ label="FlowGen Output Scale",
141
+ precision=1,
142
+ step=0.1,
143
+ scale=2,
144
+ info="-1.0, by default, forces flowgen's output to [-1, 1], could be adjusted to [0, ∞] for stronger/weaker effects",
145
+ )
146
+
147
+ gr.Markdown(
148
+ """
149
+ <p style="text-align: center; font-size: 18px;">Examples</p>
150
+ """
151
+ )
152
+ with gr.Row():
153
+ gr.Examples(
154
+ examples=[
155
+ "samples/airplane.jpg",
156
+ "samples/anime.jpg",
157
+ "samples/caligraphy.jpg",
158
+ "samples/crocodile.jpg",
159
+ "samples/elephant.jpg",
160
+ "samples/meteor.jpg",
161
+ "samples/monalisa.jpg",
162
+ "samples/portrait.jpg",
163
+ "samples/sketch.jpg",
164
+ "samples/surreal.jpg",
165
+ ],
166
+ inputs=[canvas],
167
+ outputs=[original_image, selected_points, input_image],
168
+ fn=process_img,
169
+ cache_examples=False,
170
+ examples_per_page=10,
171
+ )
172
+ gr.Markdown(
173
+ """
174
+ <p style="text-align: center; font-size: 9">[Important] Our base models are solely trained on real-world talking head (facial) videos, with a focus on achieving fine-grained facial editing. <br>
175
+ Their application to other types of scenes, without fine-tuning, should be considered more of an experimental byproduct and may not perform well in many cases (we currently support only square images).</p>
176
+ """
177
+ )
178
+
179
+ # Event Handlers
180
+ canvas.change(
181
+ process_img,
182
+ [canvas],
183
+ [original_image, selected_points, input_image],
184
+ )
185
+
186
+ input_image.select(
187
+ get_points,
188
+ [input_image, selected_points],
189
+ [input_image],
190
+ )
191
+
192
+ undo_button.click(
193
+ undo_points_image,
194
+ [original_image],
195
+ [input_image, selected_points],
196
+ )
197
+
198
+ run_button.click(
199
+ pipeline.run,
200
+ [
201
+ original_image,
202
+ selected_points,
203
+ flowgen_ckpt,
204
+ flowdiffusion_ckpt,
205
+ image_guidance,
206
+ flow_guidance,
207
+ flowgen_output_scale,
208
+ num_steps,
209
+ save_results,
210
+ ],
211
+ [edited_image],
212
+ )
213
+
214
+ clear_all_button.click(
215
+ clear_all,
216
+ [],
217
+ [
218
+ canvas,
219
+ input_image,
220
+ edited_image,
221
+ selected_points,
222
+ original_image,
223
+ ],
224
+ )
225
+
226
+ demo.queue().launch(share=False, debug=True)
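If the demo needs to be reachable from another machine, Gradio's `launch()` accepts the usual networking options; a hedged variant of the last line above (the port and bind address are arbitrary choices, not project defaults):

```python
# Replace the final line of run_demo.py with something like this to expose the demo on the local network.
demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)
```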
InstDrag/demo/samples/airplane.jpg ADDED
InstDrag/demo/samples/anime.jpg ADDED
InstDrag/demo/samples/caligraphy.jpg ADDED
InstDrag/demo/samples/crocodile.jpg ADDED
InstDrag/demo/samples/elephant.jpg ADDED
InstDrag/demo/samples/meteor.jpg ADDED
InstDrag/demo/samples/monalisa.jpg ADDED
InstDrag/demo/samples/portrait.jpg ADDED
InstDrag/demo/samples/sketch.jpg ADDED
InstDrag/demo/samples/surreal.jpg ADDED
InstDrag/flowdiffusion/pipeline.py ADDED
@@ -0,0 +1,495 @@
1
+ # This file is partially based on the diffusers library, which licensed the code under the following license:
2
+
3
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import inspect
18
+ from typing import Any, Callable, Dict, List, Optional, Union
19
+ import os
20
+ from pathlib import Path
21
+
22
+ import PIL.Image
23
+ import torch
24
+ from transformers import CLIPImageProcessor
25
+
26
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
27
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
28
+ from diffusers.loaders import StableDiffusionLoraLoaderMixin
29
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
30
+ from diffusers.schedulers import KarrasDiffusionSchedulers
31
+ from diffusers.utils import deprecate, logging
32
+ from diffusers.utils.torch_utils import randn_tensor
33
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
34
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
35
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
36
+
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
41
+ def retrieve_latents(
42
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
43
+ ):
44
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
45
+ return encoder_output.latent_dist.sample(generator)
46
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
47
+ return encoder_output.latent_dist.mode()
48
+ elif hasattr(encoder_output, "latents"):
49
+ return encoder_output.latents
50
+ else:
51
+ raise AttributeError("Could not access latents of provided encoder_output")
52
+
53
+
54
+ class FlowDiffusionPipeline(
55
+ DiffusionPipeline,
56
+ StableDiffusionMixin,
57
+ StableDiffusionLoraLoaderMixin,
58
+ ):
59
+ r"""
60
+ Pipeline for pixel-level image editing given optical flow as condition.
61
+
62
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
63
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
64
+
65
+ The pipeline also inherits the following loading methods:
66
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
67
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
68
+
69
+ Args:
70
+ vae ([`AutoencoderKL`]):
71
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
72
+ unet ([`UNet2DConditionModel`]):
73
+ A `UNet2DConditionModel` to denoise the encoded image latents.
74
+ scheduler ([`SchedulerMixin`]):
75
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
76
+ safety_checker ([`StableDiffusionSafetyChecker`]):
77
+ Classification module that estimates whether generated images could be considered offensive or harmful.
78
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
79
+ about a model's potential harms.
80
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
81
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
82
+ """
83
+
84
+ model_cpu_offload_seq = "unet->vae"
85
+ _optional_components = ["safety_checker", "feature_extractor"]
86
+ _exclude_from_cpu_offload = ["safety_checker"]
87
+ _callback_tensor_inputs = ["latents", "image_latents"]
88
+
89
+ def __init__(
90
+ self,
91
+ vae: AutoencoderKL,
92
+ unet: UNet2DConditionModel,
93
+ scheduler: KarrasDiffusionSchedulers,
94
+ safety_checker: StableDiffusionSafetyChecker,
95
+ feature_extractor: CLIPImageProcessor,
96
+ requires_safety_checker: bool = False,
97
+ null_prompt: str = "../utils/null_prompt.pt"
98
+ ):
99
+ super().__init__()
100
+
101
+ if safety_checker is None and requires_safety_checker:
102
+ logger.warning(
103
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
104
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
105
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
106
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
107
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
108
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
109
+ )
110
+
111
+ if safety_checker is not None and feature_extractor is None:
112
+ raise ValueError(
113
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
114
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
115
+ )
116
+
117
+ self.register_modules(
118
+ vae=vae,
119
+ unet=unet,
120
+ scheduler=scheduler,
121
+ safety_checker=safety_checker,
122
+ feature_extractor=feature_extractor,
123
+ )
124
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
125
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
126
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
127
+ self.null_prompt_embeds = torch.load(os.path.join(Path(__file__).parent.absolute(), null_prompt), map_location="cpu")
128
+
129
+ @torch.no_grad()
130
+ def __call__(
131
+ self,
132
+ image: PipelineImageInput = None,
133
+ flow: torch.Tensor = None,
134
+ num_inference_steps: int = 20,
135
+ image_guidance_scale: float = 1.5,
136
+ flow_guidance_scale: float = 1.5,
137
+ eta: float = 0.0,
138
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
139
+ latents: Optional[torch.Tensor] = None,
140
+ output_type: Optional[str] = "pil",
141
+ return_dict: bool = True,
142
+ callback_on_step_end: Optional[
143
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
144
+ ] = None,
145
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
146
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
147
+ **kwargs,
148
+ ):
149
+ r"""
150
+ The call function to the pipeline for generation.
151
+
152
+ Args:
153
+ image (`torch.Tensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
154
+ `Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept
155
+ image latents as `image`, but if passing latents directly it is not encoded again. We only support batch size of 1 for now.
156
+ flow (`torch.Tensor`):
157
+ Optical flow tensor to be used as a condition for the image generation. We only support batch size of 1 for now.
158
+ num_inference_steps (`int`, *optional*, defaults to 20):
159
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
160
+ expense of slower inference.
161
+ image_guidance_scale (`float`, *optional*, defaults to 1.5):
162
+ Push the generated image towards the initial `image`. Image guidance scale is enabled by setting
163
+ `image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely
164
+ linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a
165
+ value of at least `1`.
166
+ flow_guidance_scale (`float`, *optional*, defaults to 1.5):
167
+ Apply the flow guidance to the image generation. Higher values of `flow_guidance_scale` encourage
168
+ the model to follow the flow stronger.
169
+ eta (`float`, *optional*, defaults to 0.0):
170
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
171
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
172
+ generator (`torch.Generator`, *optional*):
173
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
174
+ generation deterministic.
175
+ latents (`torch.Tensor`, *optional*):
176
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
177
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
178
+ tensor is generated by sampling using the supplied random `generator`.
179
+ output_type (`str`, *optional*, defaults to `"pil"`):
180
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
181
+ return_dict (`bool`, *optional*, defaults to `True`):
182
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
183
+ plain tuple.
184
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
185
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
186
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
187
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
188
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
189
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
190
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
191
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
192
+ `._callback_tensor_inputs` attribute of your pipeline class.
193
+ cross_attention_kwargs (`dict`, *optional*):
194
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
195
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
196
+ """
197
+
198
+ callback = kwargs.pop("callback", None)
199
+ callback_steps = kwargs.pop("callback_steps", None)
200
+
201
+ if callback is not None:
202
+ deprecate(
203
+ "callback",
204
+ "1.0.0",
205
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
206
+ )
207
+ if callback_steps is not None:
208
+ deprecate(
209
+ "callback_steps",
210
+ "1.0.0",
211
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
212
+ )
213
+
214
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
215
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
216
+
217
+ # 0. Check inputs
218
+ self.check_inputs(
219
+ callback_steps,
220
+ callback_on_step_end_tensor_inputs,
221
+ )
222
+ self._image_guidance_scale = image_guidance_scale
223
+ self._flow_guidance_scale = flow_guidance_scale
224
+
225
+ device = self._execution_device
226
+
227
+ if image is None or flow is None:
228
+ raise ValueError("`image` or `flow` input cannot be undefined.")
229
+
230
+ # 1. Define call parameters
231
+
232
+ # 2. Encode input prompt
233
+ prompt_embeds = self._encode_prompt(
234
+ device,
235
+ self.do_classifier_free_guidance,
236
+ )
237
+
238
+ # 3. Preprocess image
239
+ image = self.image_processor.preprocess(image)
240
+ assert image.shape[0] == 1 and flow.shape[0] == 1, "Batch size must be 1 for now."
241
+
242
+ # 4. set timesteps
243
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
244
+ timesteps = self.scheduler.timesteps
245
+
246
+ # 5. Prepare Image latents
247
+ image_latents = self.prepare_image_latents(
248
+ image,
249
+ flow,
250
+ prompt_embeds.dtype,
251
+ device,
252
+ self.do_classifier_free_guidance,
253
+ )
254
+
255
+ height, width = image_latents.shape[-2:]
256
+ height = height * self.vae_scale_factor
257
+ width = width * self.vae_scale_factor
258
+
259
+ # 6. Prepare latent variables
260
+ num_channels_latents = self.vae.config.latent_channels
261
+ latents = self.prepare_latents(
262
+ num_channels_latents,
263
+ height,
264
+ width,
265
+ prompt_embeds.dtype,
266
+ device,
267
+ generator,
268
+ latents,
269
+ )
270
+
271
+ # 7. Check that shapes of latents and image match the UNet channels
272
+ num_channels_image = image_latents.shape[1]
273
+ if num_channels_latents + num_channels_image != self.unet.config.in_channels:
274
+ raise ValueError(
275
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
276
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
277
+ f" `num_channels_image`: {num_channels_image} "
278
+ f" = {num_channels_latents+num_channels_image}. Please verify the config of"
279
+ " `pipeline.unet` or your `image` input."
280
+ )
281
+
282
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
283
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
284
+
285
+ # 9. Denoising loop
286
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
287
+ self._num_timesteps = len(timesteps)
288
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
289
+ for i, t in enumerate(timesteps):
290
+ # Expand the latents if we are doing classifier free guidance.
291
+ # The latents are expanded 3 times because for image / flow guidance
292
+ latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents
293
+
294
+ # concat latents, image_latents in the channel dimension
295
+ scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
296
+ scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1)
297
+
298
+ # predict the noise residual
299
+ noise_pred = self.unet(
300
+ scaled_latent_model_input,
301
+ t,
302
+ encoder_hidden_states=prompt_embeds,
303
+ added_cond_kwargs=None,
304
+ cross_attention_kwargs=cross_attention_kwargs,
305
+ return_dict=False,
306
+ )[0]
307
+
308
+ # perform guidance
309
+ if self.do_classifier_free_guidance:
310
+ noise_pred_flow, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
311
+ noise_pred = (
312
+ noise_pred_uncond
313
+ + self._image_guidance_scale * (noise_pred_image - noise_pred_uncond)
314
+ + self._flow_guidance_scale * (noise_pred_flow - noise_pred_image)
315
+ )
316
+
317
+ # compute the previous noisy sample x_t -> x_t-1
318
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
319
+
320
+ if callback_on_step_end is not None:
321
+ callback_kwargs = {}
322
+ for k in callback_on_step_end_tensor_inputs:
323
+ callback_kwargs[k] = locals()[k]
324
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
325
+
326
+ latents = callback_outputs.pop("latents", latents)
327
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
328
+ image_latents = callback_outputs.pop("image_latents", image_latents)
329
+
330
+ # call the callback, if provided
331
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
332
+ progress_bar.update()
333
+ if callback is not None and i % callback_steps == 0:
334
+ step_idx = i // getattr(self.scheduler, "order", 1)
335
+ callback(step_idx, t, latents)
336
+
337
+ if not output_type == "latent":
338
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
339
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
340
+ else:
341
+ image = latents
342
+ has_nsfw_concept = None
343
+
344
+ if has_nsfw_concept is None:
345
+ do_denormalize = [True] * image.shape[0]
346
+ else:
347
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
348
+
349
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
350
+
351
+ # Offload all models
352
+ self.maybe_free_model_hooks()
353
+
354
+ if not return_dict:
355
+ return (image, has_nsfw_concept)
356
+
357
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
358
+
359
+ def _encode_prompt(
360
+ self,
361
+ device,
362
+ do_classifier_free_guidance,
363
+ ):
364
+ prompt_embeds = self.null_prompt_embeds.to(dtype=torch.float16, device=device) # 1 77 512
365
+
366
+ if do_classifier_free_guidance: # We are only doing cfg for image and flow
367
+ prompt_embeds = torch.cat([prompt_embeds, prompt_embeds, prompt_embeds]) # 3 77 512
368
+
369
+ return prompt_embeds
370
+
371
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
372
+ def run_safety_checker(self, image, device, dtype):
373
+ if self.safety_checker is None:
374
+ has_nsfw_concept = None
375
+ else:
376
+ if torch.is_tensor(image):
377
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
378
+ else:
379
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
380
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
381
+ image, has_nsfw_concept = self.safety_checker(
382
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
383
+ )
384
+ return image, has_nsfw_concept
385
+
386
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
387
+ def prepare_extra_step_kwargs(self, generator, eta):
388
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
389
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
390
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
391
+ # and should be between [0, 1]
392
+
393
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
394
+ extra_step_kwargs = {}
395
+ if accepts_eta:
396
+ extra_step_kwargs["eta"] = eta
397
+
398
+ # check if the scheduler accepts generator
399
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
400
+ if accepts_generator:
401
+ extra_step_kwargs["generator"] = generator
402
+ return extra_step_kwargs
403
+
404
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
405
+ def decode_latents(self, latents):
406
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
407
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
408
+
409
+ latents = 1 / self.vae.config.scaling_factor * latents
410
+ image = self.vae.decode(latents, return_dict=False)[0]
411
+ image = (image / 2 + 0.5).clamp(0, 1)
412
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
413
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
414
+ return image
415
+
416
+ def check_inputs(
417
+ self,
418
+ callback_steps,
419
+ callback_on_step_end_tensor_inputs=None,
420
+ ):
421
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
422
+ raise ValueError(
423
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
424
+ f" {type(callback_steps)}."
425
+ )
426
+
427
+ if callback_on_step_end_tensor_inputs is not None and not all(
428
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
429
+ ):
430
+ raise ValueError(
431
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
432
+ )
433
+
434
+ def prepare_latents(self, num_channels_latents, height, width, dtype, device, generator, latents=None):
435
+ shape = (
436
+ 1,
437
+ num_channels_latents,
438
+ int(height) // self.vae_scale_factor,
439
+ int(width) // self.vae_scale_factor,
440
+ )
441
+ if isinstance(generator, list) and len(generator) != 1:
442
+ raise ValueError(
443
+ f"You have passed a list of generators of length {len(generator)}, but we only support a single batch for now."
444
+ )
445
+
446
+ if latents is None:
447
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
448
+ else:
449
+ latents = latents.to(device)
450
+
451
+ # scale the initial noise by the standard deviation required by the scheduler
452
+ latents = latents * self.scheduler.init_noise_sigma
453
+ return latents
454
+
455
+ def prepare_image_latents(
456
+ self, image, flow, dtype, device, do_classifier_free_guidance, generator=None
457
+ ):
458
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
459
+ raise ValueError(
460
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
461
+ )
462
+
463
+ image = image.to(device=device, dtype=dtype)
464
+
465
+ if image.shape[1] == 4:
466
+ image_latents = image
467
+ else:
468
+ image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
469
+
470
+ image_latents_flow_cond = torch.cat([image_latents, flow.to(device)], dim=1)
471
+
472
+ if do_classifier_free_guidance:
473
+ image_latents_flow_uncond = torch.cat([image_latents, torch.zeros_like(flow).to(device)], dim=1)
474
+ image_latents_uncond = torch.zeros_like(image_latents_flow_cond)
475
+ image_latents_final = torch.cat([image_latents_flow_cond, image_latents_flow_uncond, image_latents_uncond], dim=0)
476
+ else:
477
+ image_latents_final = image_latents_flow_cond
478
+
479
+ return image_latents_final
480
+
481
+ @property
482
+ def image_guidance_scale(self):
483
+ return self._image_guidance_scale
484
+
485
+ @property
486
+ def flow_guidance_scale(self):
487
+ return self._flow_guidance_scale
488
+
489
+ @property
490
+ def num_timesteps(self):
491
+ return self._num_timesteps
492
+
493
+ @property
494
+ def do_classifier_free_guidance(self):
495
+ return self._image_guidance_scale > 1 or self._flow_guidance_scale > 1
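The denoising loop in `__call__` combines three UNet predictions: flow-and-image conditioned, image-only, and fully unconditional, matching the ordering built in `prepare_image_latents`. A small sketch of just that guidance arithmetic (the helper name is ours; tensor shapes are illustrative):

```python
import torch

def combine_guidance(noise_flow, noise_image, noise_uncond,
                     image_guidance_scale=1.5, flow_guidance_scale=1.5):
    """Same arithmetic as the do_classifier_free_guidance branch above: start from the
    unconditional prediction, push toward the image-conditioned one, then push further
    toward the flow-conditioned one."""
    return (noise_uncond
            + image_guidance_scale * (noise_image - noise_uncond)
            + flow_guidance_scale * (noise_flow - noise_image))

# Latent-space noise predictions for a 512x512 image are 1x4x64x64 with the 8x VAE downsampling.
preds = [torch.randn(1, 4, 64, 64) for _ in range(3)]
print(combine_guidance(*preds).shape)  # torch.Size([1, 4, 64, 64])
```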
InstDrag/flowgen/models.py ADDED
@@ -0,0 +1,161 @@
1
+ # Modified from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import functools
6
+
7
+ class UnetSkipConnectionBlock(nn.Module):
8
+ """Defines the Unet submodule with skip connection.
9
+ X -------------------identity----------------------
10
+ |-- downsampling -- |submodule| -- upsampling --|
11
+ """
12
+
13
+ def __init__(self, outer_nc, inner_nc, input_nc=None,
14
+ submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
15
+ """Construct a Unet submodule with skip connections.
16
+
17
+ Parameters:
18
+ outer_nc (int) -- the number of filters in the outer conv layer
19
+ inner_nc (int) -- the number of filters in the inner conv layer
20
+ input_nc (int) -- the number of channels in input images/features
21
+ submodule (UnetSkipConnectionBlock) -- previously defined submodules
22
+ outermost (bool) -- if this module is the outermost module
23
+ innermost (bool) -- if this module is the innermost module
24
+ norm_layer -- normalization layer
25
+ use_dropout (bool) -- if use dropout layers.
26
+ """
27
+ super(UnetSkipConnectionBlock, self).__init__()
28
+ self.outermost = outermost
29
+ if type(norm_layer) == functools.partial:
30
+ use_bias = norm_layer.func != nn.BatchNorm2d
31
+ else:
32
+ use_bias = norm_layer != nn.BatchNorm2d
33
+ if input_nc is None:
34
+ input_nc = outer_nc
35
+ downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
36
+ stride=2, padding=1, bias=use_bias)
37
+ downrelu = nn.LeakyReLU(0.2, True)
38
+
39
+ if norm_layer == nn.GroupNorm:
40
+ downnorm = norm_layer(32, inner_nc)
41
+ else: downnorm = norm_layer(inner_nc)
42
+ uprelu = nn.ReLU(True)
43
+ if norm_layer == nn.GroupNorm:
44
+ if outer_nc % 32 != 0:
45
+ upnorm = norm_layer(outer_nc, outer_nc) # one group per channel (InstanceNorm-like) when outer_nc is not divisible by 32
46
+ else:
47
+ upnorm = norm_layer(32, outer_nc)
48
+ else:
49
+ upnorm = norm_layer(outer_nc)
50
+
51
+ if outermost:
52
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
53
+ kernel_size=4, stride=2,
54
+ padding=1)
55
+ down = [downconv]
56
+ up = [uprelu, upconv, nn.Tanh()]
57
+ model = down + [submodule] + up
58
+ elif innermost:
59
+ upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
60
+ kernel_size=4, stride=2,
61
+ padding=1, bias=use_bias)
62
+ down = [downrelu, downconv]
63
+ up = [uprelu, upconv, upnorm]
64
+ model = down + up
65
+ else:
66
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
67
+ kernel_size=4, stride=2,
68
+ padding=1, bias=use_bias)
69
+ down = [downrelu, downconv, downnorm]
70
+ up = [uprelu, upconv, upnorm]
71
+
72
+ if use_dropout:
73
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
74
+ else:
75
+ model = down + [submodule] + up
76
+
77
+ self.model = nn.Sequential(*model)
78
+
79
+ def forward(self, x):
80
+ if self.outermost:
81
+ return self.model(x)
82
+ else: # add skip connections
83
+ return torch.cat([x, self.model(x)], 1)
84
+
85
+ class UnetGenerator(nn.Module):
86
+ """Create a Unet-based generator"""
87
+
88
+ def __init__(self, input_nc, output_nc=2, num_downs=8, ngf=64, norm_layer=nn.GroupNorm, use_dropout=True):
89
+ """Construct a Unet generator
90
+ Parameters:
91
+ input_nc (int) -- the number of channels in input images
92
+ output_nc (int) -- the number of channels in output images
93
+ num_downs (int) -- the number of downsamplings in the UNet. For example, if |num_downs| == 7,
94
+ an image of size 128x128 will become of size 1x1 at the bottleneck
95
+ ngf (int) -- the number of filters in the last conv layer
96
+ norm_layer -- normalization layer
97
+
98
+ We construct the U-Net from the innermost layer to the outermost layer.
99
+ It is a recursive process.
100
+ """
101
+ super(UnetGenerator, self).__init__()
102
+ # construct unet structure
103
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
104
+ for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
105
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
106
+ # gradually reduce the number of filters from ngf * 8 to ngf
107
+ unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
108
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
109
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
110
+ self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
111
+
112
+ def forward(self, input):
113
+ """Standard forward"""
114
+ return self.model(input)
115
+
116
+ class NLayerDiscriminator(nn.Module):
117
+ """Defines a PatchGAN discriminator"""
118
+
119
+ def __init__(self, input_nc, ndf=64, n_layers=6, norm_layer=nn.GroupNorm):
120
+ """Construct a PatchGAN discriminator
121
+
122
+ Parameters:
123
+ input_nc (int) -- the number of channels in input images
124
+ ndf (int) -- the number of filters in the last conv layer
125
+ n_layers (int) -- the number of conv layers in the discriminator
126
+ norm_layer -- normalization layer
127
+ """
128
+ super(NLayerDiscriminator, self).__init__()
129
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
130
+ use_bias = norm_layer.func != nn.BatchNorm2d
131
+ else:
132
+ use_bias = norm_layer != nn.BatchNorm2d
133
+
134
+ kw = 4
135
+ padw = 1
136
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
137
+ nf_mult = 1
138
+ nf_mult_prev = 1
139
+ for n in range(1, n_layers): # gradually increase the number of filters
140
+ nf_mult_prev = nf_mult
141
+ nf_mult = min(2 ** n, 8)
142
+ sequence += [
143
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
144
+ norm_layer(32, ndf * nf_mult) if norm_layer == nn.GroupNorm else norm_layer(ndf * nf_mult),
145
+ nn.LeakyReLU(0.2, True)
146
+ ]
147
+
148
+ nf_mult_prev = nf_mult
149
+ nf_mult = min(2 ** n_layers, 8)
150
+ sequence += [
151
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
152
+ norm_layer(32, ndf * nf_mult) if norm_layer == nn.GroupNorm else norm_layer(ndf * nf_mult),
153
+ nn.LeakyReLU(0.2, True)
154
+ ]
155
+
156
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
157
+ self.model = nn.Sequential(*sequence)
158
+
159
+ def forward(self, input):
160
+ """Standard forward."""
161
+ return self.model(input)
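FlowGen in the demo is this `UnetGenerator` with 5 input channels (3 RGB channels plus the 2-channel sparse flow map) and 2 output channels (dense flow). A quick shape check with a random tensor standing in for a real input, run from the repository root:

```python
import torch
from flowgen.models import UnetGenerator

# 5 input channels = RGB image + 2-channel sparse point-vector map; 2 output channels = dense (dx, dy) flow.
model = UnetGenerator(input_nc=5, output_nc=2).eval()

with torch.inference_mode():
    dummy = torch.randn(1, 5, 256, 256)  # FlowGen runs at 256x256 in the demo
    out = model(dummy)
print(out.shape)  # torch.Size([1, 2, 256, 256]); values lie in [-1, 1] because of the outermost Tanh
```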
InstDrag/utils/flow_utils.py ADDED
@@ -0,0 +1,143 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ def make_colorwheel():
7
+ """
8
+ Generates a color wheel for optical flow visualization as presented in:
9
+ Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
10
+ URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
11
+
12
+ Code follows the original C++ source code of Daniel Scharstein.
13
+ Code follows the Matlab source code of Deqing Sun.
14
+
15
+ Returns:
16
+ np.ndarray: Color wheel
17
+ """
18
+
19
+ RY = 15
20
+ YG = 6
21
+ GC = 4
22
+ CB = 11
23
+ BM = 13
24
+ MR = 6
25
+
26
+ ncols = RY + YG + GC + CB + BM + MR
27
+ colorwheel = np.zeros((ncols, 3))
28
+ col = 0
29
+
30
+ # RY
31
+ colorwheel[0:RY, 0] = 255
32
+ colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
33
+ col = col+RY
34
+ # YG
35
+ colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
36
+ colorwheel[col:col+YG, 1] = 255
37
+ col = col+YG
38
+ # GC
39
+ colorwheel[col:col+GC, 1] = 255
40
+ colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
41
+ col = col+GC
42
+ # CB
43
+ colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
44
+ colorwheel[col:col+CB, 2] = 255
45
+ col = col+CB
46
+ # BM
47
+ colorwheel[col:col+BM, 2] = 255
48
+ colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
49
+ col = col+BM
50
+ # MR
51
+ colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
52
+ colorwheel[col:col+MR, 0] = 255
53
+ return colorwheel
54
+
55
+ def flow_uv_to_colors(u, v, convert_to_bgr=False):
56
+ """
57
+ Applies the flow color wheel to (possibly clipped) flow components u and v.
58
+
59
+ According to the C++ source code of Daniel Scharstein
60
+ According to the Matlab source code of Deqing Sun
61
+
62
+ Args:
63
+ u (np.ndarray): Input horizontal flow of shape [H,W]
64
+ v (np.ndarray): Input vertical flow of shape [H,W]
65
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
66
+
67
+ Returns:
68
+ np.ndarray: Flow visualization image of shape [H,W,3]
69
+ """
70
+ flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
71
+ colorwheel = make_colorwheel() # shape [55x3]
72
+ ncols = colorwheel.shape[0]
73
+ rad = np.sqrt(np.square(u) + np.square(v))
74
+ a = np.arctan2(-v, -u)/np.pi
75
+ fk = (a+1) / 2*(ncols-1)
76
+ k0 = np.floor(fk).astype(np.int32)
77
+ k1 = k0 + 1
78
+ k1[k1 == ncols] = 0
79
+ f = fk - k0
80
+ for i in range(colorwheel.shape[1]):
81
+ tmp = colorwheel[:,i]
82
+ col0 = tmp[k0] / 255.0
83
+ col1 = tmp[k1] / 255.0
84
+ col = (1-f)*col0 + f*col1
85
+ idx = (rad <= 1)
86
+ col[idx] = 1 - rad[idx] * (1-col[idx])
87
+ col[~idx] = col[~idx] * 0.75 # out of range
88
+ # Note the 2-i => BGR instead of RGB
89
+ ch_idx = 2-i if convert_to_bgr else i
90
+ flow_image[:,:,ch_idx] = np.floor(255 * col)
91
+ return flow_image
92
+
93
+ def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False, max_flow=None):
94
+ """
95
+ Expects a flow tensor of shape [2,H,W] and returns a color visualization.
96
+
97
+ Args:
98
+ flow_uv (torch.Tensor): Flow UV image of shape [2,H,W]
99
+ clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
100
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
101
+
102
+ Returns:
103
+ PIL Image: Flow visualization image
104
+ """
105
+ flow_uv = flow_uv.permute(1, 2, 0).cpu().numpy() # change to [H,W,2] and convert to numpy
106
+
107
+ if clip_flow is not None:
108
+ flow_uv = np.clip(flow_uv, 0, clip_flow)
109
+ u = flow_uv[:,:,0]
110
+ v = flow_uv[:,:,1]
111
+ if max_flow is None:
112
+ rad = np.sqrt(np.square(u) + np.square(v))
113
+ rad_max = np.max(rad)
114
+ else:
115
+ rad_max = max_flow
116
+ epsilon = 1e-5
117
+ u = u / (rad_max + epsilon)
118
+ v = v / (rad_max + epsilon)
119
+ flow_image = flow_uv_to_colors(u, v, convert_to_bgr)
120
+
121
+ return Image.fromarray(flow_image)
122
+
123
+ def resize_flow(flow, size, scale_type="none", mode="bicubic"):
124
+ """
125
+ Resize the flow tensor (Bx2xHxW) to the given size (HxW).
126
+ flow tensor is in range of [-ori_w, ori_w] and [-ori_h, ori_h]
127
+ Size should be a tuple (H, W).
128
+ """
129
+ ori_h, ori_w = flow.shape[2:]
130
+ flow = F.interpolate(flow, size=size, mode=mode, align_corners=False)
131
+
132
+ if scale_type == "scale" and (ori_h != size[0] or ori_w != size[1]):
133
+ flow[:,0,:,:] *= size[1] / ori_w
134
+ flow[:,1,:,:] *= size[0] / ori_h
135
+ elif scale_type == "normalize_fixed": # normalize to -1 ~ 1
136
+ flow[:,0,:,:] /= ori_w
137
+ flow[:,1,:,:] /= ori_h
138
+ elif scale_type == "normalize_max":
139
+ max_flow_x = torch.amax(torch.abs(flow[:, 0, :, :]), dim=(1, 2))
140
+ max_flow_y = torch.amax(torch.abs(flow[:, 1, :, :]), dim=(1, 2))
141
+ flow[:, 0, :, :] /= max_flow_x.view(-1, 1, 1)
142
+ flow[:, 1, :, :] /= max_flow_y.view(-1, 1, 1)
143
+ return flow
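`flow_to_image` and `resize_flow` can be exercised on their own; a small sketch with a synthetic flow field (the constant rightward flow is just an arbitrary test pattern), run from the repository root:

```python
import torch
from utils.flow_utils import flow_to_image, resize_flow

# A constant rightward flow of 10 pixels over a 256x256 grid.
flow = torch.zeros(1, 2, 256, 256)
flow[:, 0] = 10.0

# Downsample to the 64x64 latent grid and normalize by the original width/height, as the demo does.
latent_flow = resize_flow(flow.clone(), (64, 64), scale_type="normalize_fixed")
print(latent_flow.shape, latent_flow[:, 0].max().item())  # torch.Size([1, 2, 64, 64]), ~0.039

flow_to_image(flow[0]).save("flow_vis.png")  # uniform color, since direction and magnitude are constant
```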
InstDrag/utils/null_prompt.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7eb3e5fc1308277b9288aa665562eb688e4aa36e6bcbc422083b707468e84d2a
3
+ size 237655