Spaces:
Sleeping
Sleeping
Tony Lian
commited on
Commit
•
1f39cf9
1
Parent(s):
58524a7
Add stage 2
Browse files- .gitignore +160 -0
- README.md +4 -2
- app.py +105 -17
- baseline.py +41 -0
- generation.py +179 -0
- models/__init__.py +1 -0
- models/attention.py +392 -0
- models/attention_processor.py +508 -0
- models/modeling_utils.py +874 -0
- models/models.py +97 -0
- models/pipelines.py +246 -0
- models/sam.py +179 -0
- models/transformer_2d.py +367 -0
- models/unet_2d_blocks.py +793 -0
- models/unet_2d_condition.py +980 -0
- requirements.txt +11 -0
- shared.py +11 -0
- utils/__init__.py +1 -0
- utils/latents.py +151 -0
- utils/parse.py +284 -0
- utils/utils.py +165 -0
.gitignore
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/#use-with-ide
|
110 |
+
.pdm.toml
|
111 |
+
|
112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
+
__pypackages__/
|
114 |
+
|
115 |
+
# Celery stuff
|
116 |
+
celerybeat-schedule
|
117 |
+
celerybeat.pid
|
118 |
+
|
119 |
+
# SageMath parsed files
|
120 |
+
*.sage.py
|
121 |
+
|
122 |
+
# Environments
|
123 |
+
.env
|
124 |
+
.venv
|
125 |
+
env/
|
126 |
+
venv/
|
127 |
+
ENV/
|
128 |
+
env.bak/
|
129 |
+
venv.bak/
|
130 |
+
|
131 |
+
# Spyder project settings
|
132 |
+
.spyderproject
|
133 |
+
.spyproject
|
134 |
+
|
135 |
+
# Rope project settings
|
136 |
+
.ropeproject
|
137 |
+
|
138 |
+
# mkdocs documentation
|
139 |
+
/site
|
140 |
+
|
141 |
+
# mypy
|
142 |
+
.mypy_cache/
|
143 |
+
.dmypy.json
|
144 |
+
dmypy.json
|
145 |
+
|
146 |
+
# Pyre type checker
|
147 |
+
.pyre/
|
148 |
+
|
149 |
+
# pytype static type analyzer
|
150 |
+
.pytype/
|
151 |
+
|
152 |
+
# Cython debug symbols
|
153 |
+
cython_debug/
|
154 |
+
|
155 |
+
# PyCharm
|
156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
+
#.idea/
|
README.md
CHANGED
@@ -4,10 +4,12 @@ emoji: 😊
|
|
4 |
colorFrom: red
|
5 |
colorTo: pink
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
app_file: app.py
|
9 |
pinned: true
|
10 |
tags: [llm, diffusion, grounding, grounded, llm-grounded, text-to-image, language, large language models, layout, generation, generative, customization, personalization, prompting, chatgpt, gpt-3.5, gpt-4]
|
11 |
---
|
12 |
|
13 |
-
|
|
|
|
|
|
4 |
colorFrom: red
|
5 |
colorTo: pink
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.34.0
|
8 |
app_file: app.py
|
9 |
pinned: true
|
10 |
tags: [llm, diffusion, grounding, grounded, llm-grounded, text-to-image, language, large language models, layout, generation, generative, customization, personalization, prompting, chatgpt, gpt-3.5, gpt-4]
|
11 |
---
|
12 |
|
13 |
+
Credits:
|
14 |
+
|
15 |
+
This space uses code from [diffusers](https://huggingface.co/docs/diffusers/index), [GLIGEN](https://github.com/gligen/GLIGEN), and [layout-guidance](https://github.com/silent-chen/layout-guidance). Using their code means adhering to their license.
|
app.py
CHANGED
@@ -4,13 +4,21 @@ import ast
|
|
4 |
from matplotlib.patches import Polygon
|
5 |
from matplotlib.collections import PatchCollection
|
6 |
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
box_scale = (512, 512)
|
9 |
size = box_scale
|
10 |
|
11 |
bg_prompt_text = "Background prompt: "
|
12 |
|
13 |
-
simplified_prompt = """You are an intelligent bounding box generator. I will provide you with a caption for a photo, image, or painting. Your task is to generate the bounding boxes for the objects mentioned in the caption, along with a background prompt describing the scene. The images are of size 512x512, and the bounding boxes should not overlap or go beyond the image boundaries. Each bounding box should be in the format of (object name, [top-left x coordinate, top-left y coordinate, box width, box height]) and include exactly one object. Do not put objects that are already provided in the bounding boxes into the background prompt. If needed, you can make reasonable guesses. Please refer to the example below for the desired format.
|
14 |
|
15 |
Caption: A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky
|
16 |
Objects: [('a green car', [21, 181, 211, 159]), ('a blue truck', [269, 181, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]
|
@@ -43,12 +51,20 @@ Background prompt: An oil painting of a living room scene
|
|
43 |
Caption: {prompt}
|
44 |
Objects: """
|
45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
def get_lmd_prompt(prompt):
|
47 |
if prompt == "":
|
48 |
-
prompt =
|
49 |
return simplified_prompt.format(prompt=prompt)
|
50 |
|
51 |
def get_layout_image(response):
|
|
|
|
|
52 |
gen_boxes, bg_prompt = parse_input(response)
|
53 |
fig = plt.figure(figsize=(8, 8))
|
54 |
# https://stackoverflow.com/questions/7821518/save-plot-to-numpy-array
|
@@ -63,6 +79,35 @@ def get_layout_image(response):
|
|
63 |
plt.clf()
|
64 |
return data
|
65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
def parse_input(text=None):
|
67 |
try:
|
68 |
if "Objects: " in text:
|
@@ -130,30 +175,73 @@ def show_boxes(gen_boxes, bg_prompt=None):
|
|
130 |
|
131 |
draw_boxes(anns)
|
132 |
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
|
|
|
|
|
|
|
|
137 |
<p><b>Tips:</b><p>
|
138 |
<p>1. If ChatGPT doesn't generate layout, add/remove the trailing space (added by default) and/or use GPT-4.</p>
|
139 |
<p>2. You can perform multi-round specification by giving ChatGPT follow-up requests (e.g., make the object boxes bigger).</p>
|
140 |
-
<p>3. You can also try prompts in Simplified Chinese. If you want to try prompts in another language, translate the first line of last example to your language.<p>
|
141 |
-
|
|
|
|
|
|
|
142 |
with gr.Row():
|
143 |
with gr.Column(scale=1):
|
144 |
-
prompt = gr.Textbox(lines=2, label="Prompt for Layout Generation", placeholder=
|
145 |
-
|
146 |
with gr.Column(scale=1):
|
147 |
-
output = gr.Textbox(label="Paste this into ChatGPT (GPT-4
|
148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
|
150 |
-
with gr.Tab("
|
151 |
with gr.Row():
|
152 |
with gr.Column(scale=1):
|
153 |
-
|
154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
with gr.Column(scale=1):
|
156 |
-
|
157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
|
159 |
g.launch()
|
|
|
4 |
from matplotlib.patches import Polygon
|
5 |
from matplotlib.collections import PatchCollection
|
6 |
import matplotlib.pyplot as plt
|
7 |
+
from utils.parse import filter_boxes
|
8 |
+
from generation import run as run_ours
|
9 |
+
from baseline import run as run_baseline
|
10 |
+
import torch
|
11 |
+
|
12 |
+
print(f"Is CUDA available: {torch.cuda.is_available()}")
|
13 |
+
if torch.cuda.is_available():
|
14 |
+
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
15 |
|
16 |
box_scale = (512, 512)
|
17 |
size = box_scale
|
18 |
|
19 |
bg_prompt_text = "Background prompt: "
|
20 |
|
21 |
+
simplified_prompt = """You are an intelligent bounding box generator. I will provide you with a caption for a photo, image, or painting. Your task is to generate the bounding boxes for the objects mentioned in the caption, along with a background prompt describing the scene. The images are of size 512x512, and the bounding boxes should not overlap or go beyond the image boundaries. Each bounding box should be in the format of (object name, [top-left x coordinate, top-left y coordinate, box width, box height]) and include exactly one object. Do not put objects that are already provided in the bounding boxes into the background prompt. If needed, you can make reasonable guesses. Generate the object descriptions and background prompts in English even if the caption might not be in English. Please refer to the example below for the desired format.
|
22 |
|
23 |
Caption: A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky
|
24 |
Objects: [('a green car', [21, 181, 211, 159]), ('a blue truck', [269, 181, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]
|
|
|
51 |
Caption: {prompt}
|
52 |
Objects: """
|
53 |
|
54 |
+
prompt_placeholder = "A realistic photo of a gray cat and an orange dog on the grass."
|
55 |
+
|
56 |
+
layout_placeholder = """Caption: A realistic photo of a gray cat and an orange dog on the grass.
|
57 |
+
Objects: [('a gray cat', [67, 243, 120, 126]), ('an orange dog', [265, 193, 190, 210])]
|
58 |
+
Background prompt: A realistic photo of a grassy area."""
|
59 |
+
|
60 |
def get_lmd_prompt(prompt):
|
61 |
if prompt == "":
|
62 |
+
prompt = prompt_placeholder
|
63 |
return simplified_prompt.format(prompt=prompt)
|
64 |
|
65 |
def get_layout_image(response):
|
66 |
+
if response == "":
|
67 |
+
response = layout_placeholder
|
68 |
gen_boxes, bg_prompt = parse_input(response)
|
69 |
fig = plt.figure(figsize=(8, 8))
|
70 |
# https://stackoverflow.com/questions/7821518/save-plot-to-numpy-array
|
|
|
79 |
plt.clf()
|
80 |
return data
|
81 |
|
82 |
+
def get_layout_image_gallery(response):
|
83 |
+
return [get_layout_image(response)]
|
84 |
+
|
85 |
+
def get_ours_image(response, seed, fg_seed_start, fg_blending_ratio=0.1, frozen_step_ratio=0.4, gligen_scheduled_sampling_beta=0.3, show_so_imgs=False, scale_boxes=False, gallery=None):
|
86 |
+
if response == "":
|
87 |
+
response = layout_placeholder
|
88 |
+
gen_boxes, bg_prompt = parse_input(response)
|
89 |
+
gen_boxes = filter_boxes(gen_boxes, scale_boxes=scale_boxes)
|
90 |
+
spec = {
|
91 |
+
# prompt is unused
|
92 |
+
'prompt': '',
|
93 |
+
'gen_boxes': gen_boxes,
|
94 |
+
'bg_prompt': bg_prompt
|
95 |
+
}
|
96 |
+
image_np, so_img_list = run_ours(
|
97 |
+
spec, bg_seed=seed, fg_seed_start=fg_seed_start,
|
98 |
+
fg_blending_ratio=fg_blending_ratio,frozen_step_ratio=frozen_step_ratio,
|
99 |
+
gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta)
|
100 |
+
images = [image_np]
|
101 |
+
if show_so_imgs:
|
102 |
+
images.extend([np.asarray(so_img) for so_img in so_img_list])
|
103 |
+
return images
|
104 |
+
|
105 |
+
def get_baseline_image(prompt, seed):
|
106 |
+
if prompt == "":
|
107 |
+
prompt = prompt_placeholder
|
108 |
+
image_np = run_baseline(prompt, bg_seed=seed)
|
109 |
+
return [image_np]
|
110 |
+
|
111 |
def parse_input(text=None):
|
112 |
try:
|
113 |
if "Objects: " in text:
|
|
|
175 |
|
176 |
draw_boxes(anns)
|
177 |
|
178 |
+
duplicate_html = '<a style="display:inline-block" href="https://huggingface.co/spaces/longlian/llm-grounded-diffusion?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a>'
|
179 |
+
|
180 |
+
with gr.Blocks(
|
181 |
+
title="LLM-grounded Diffusion: Enhancing Prompt Understanding of Text-to-Image Diffusion Models with Large Language Models"
|
182 |
+
) as g:
|
183 |
+
gr.HTML(f"""<h1>LLM-grounded Diffusion: Enhancing Prompt Understanding of Text-to-Image Diffusion Models with Large Language Models</h1>
|
184 |
+
<h2>LLM + Stable Diffusion => better prompt understanding in text2image generation 🤩</h2>
|
185 |
+
<h2><a href='https://llm-grounded-diffusion.github.io/'>Project Page</a> | <a href='https://bair.berkeley.edu/blog/2023/05/23/lmd/'>5-minute Blog Post</a> | <a href='https://arxiv.org/pdf/2305.13655.pdf'>ArXiv Paper</a> | <a href='https://github.com/TonyLianLong/LLM-groundedDiffusion'>Github</a> | <a href='https://llm-grounded-diffusion.github.io/#citation'>Cite our work</a> if our ideas inspire you.</h2>
|
186 |
<p><b>Tips:</b><p>
|
187 |
<p>1. If ChatGPT doesn't generate layout, add/remove the trailing space (added by default) and/or use GPT-4.</p>
|
188 |
<p>2. You can perform multi-round specification by giving ChatGPT follow-up requests (e.g., make the object boxes bigger).</p>
|
189 |
+
<p>3. You can also try prompts in Simplified Chinese. If you want to try prompts in another language, translate the first line of last example to your language.<p>
|
190 |
+
<p>4. Duplicate this space and add GPU to skip the queue and run our model faster. {duplicate_html}</p>
|
191 |
+
<br/>
|
192 |
+
<p>Implementation note: In this demo, we replace the attention manipulation in our layout-guided Stable Diffusion described in our paper with GLIGEN due to much faster inference speed (<b>FlashAttention supported, no backprop needed</b> during inference). Compared to vanilla GLIGEN, we have better coherence. Other parts of text-to-image pipeline, including single object generation and SAM, remain the same. The settings and examples in the prompt are simplified in this demo.</p>""")
|
193 |
+
with gr.Tab("Stage 1. Image Prompt to ChatGPT"):
|
194 |
with gr.Row():
|
195 |
with gr.Column(scale=1):
|
196 |
+
prompt = gr.Textbox(lines=2, label="Prompt for Layout Generation", placeholder=prompt_placeholder)
|
197 |
+
generate_btn = gr.Button("Generate Prompt")
|
198 |
with gr.Column(scale=1):
|
199 |
+
output = gr.Textbox(label="Paste this into ChatGPT (GPT-4 preferred; on Mac, click text and press Command+A and Command+C to copy all)")
|
200 |
+
generate_btn.click(fn=get_lmd_prompt, inputs=prompt, outputs=output, api_name="get_lmd_prompt")
|
201 |
+
|
202 |
+
# with gr.Tab("(Optional) Visualize ChatGPT-generated Layout"):
|
203 |
+
# with gr.Row():
|
204 |
+
# with gr.Column(scale=1):
|
205 |
+
# response = gr.Textbox(lines=5, label="Paste ChatGPT response here", placeholder=layout_placeholder)
|
206 |
+
# visualize_btn = gr.Button("Visualize Layout")
|
207 |
+
# with gr.Column(scale=1):
|
208 |
+
# output = gr.Image(shape=(512, 512), elem_classes="img", elem_id="img", css="img {width: 300px}")
|
209 |
+
# visualize_btn.click(fn=get_layout_image, inputs=response, outputs=output, api_name="visualize-layout")
|
210 |
|
211 |
+
with gr.Tab("Stage 2 (New). Layout to Image generation"):
|
212 |
with gr.Row():
|
213 |
with gr.Column(scale=1):
|
214 |
+
response = gr.Textbox(lines=5, label="Paste ChatGPT response here (no original caption needed)", placeholder=layout_placeholder)
|
215 |
+
visualize_btn = gr.Button("Visualize Layout")
|
216 |
+
generate_btn = gr.Button("Generate Image from Layout", variant='primary')
|
217 |
+
with gr.Accordion("Advanced options", open=False):
|
218 |
+
seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
|
219 |
+
fg_seed_start = gr.Slider(0, 10000, value=20, step=1, label="Seed for foreground variation")
|
220 |
+
fg_blending_ratio = gr.Slider(0, 1, value=0.1, step=0.01, label="Variations added to foreground for single object generation (0: no variation, 1: max variation)")
|
221 |
+
frozen_step_ratio = gr.Slider(0, 1, value=0.4, step=0.1, label="Foreground frozen steps ratio (higher: preserve object attributes; lower: higher coherence; set to 0: (almost) equivalent to vanilla GLIGEN except details)")
|
222 |
+
gligen_scheduled_sampling_beta = gr.Slider(0, 1, value=0.3, step=0.1, label="GLIGEN guidance steps ratio (the beta value)")
|
223 |
+
show_so_imgs = gr.Checkbox(label="Show annotated single object generations", show_label=False)
|
224 |
with gr.Column(scale=1):
|
225 |
+
gallery = gr.Gallery(
|
226 |
+
label="Generated image", show_label=False, elem_id="gallery"
|
227 |
+
).style(columns=[1], rows=[1], object_fit="contain", preview=True)
|
228 |
+
visualize_btn.click(fn=get_layout_image_gallery, inputs=response, outputs=gallery, api_name="visualize-layout")
|
229 |
+
generate_btn.click(fn=get_ours_image, inputs=[response, seed, fg_seed_start, fg_blending_ratio, frozen_step_ratio, gligen_scheduled_sampling_beta, show_so_imgs], outputs=gallery, api_name="layout-to-image")
|
230 |
+
|
231 |
+
with gr.Tab("Baseline: Stable Diffusion"):
|
232 |
+
with gr.Row():
|
233 |
+
with gr.Column(scale=1):
|
234 |
+
sd_prompt = gr.Textbox(lines=2, label="Prompt for baseline SD", placeholder=prompt_placeholder)
|
235 |
+
generate_btn = gr.Button("Generate")
|
236 |
+
with gr.Accordion("Advanced options", open=False):
|
237 |
+
seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
|
238 |
+
# with gr.Column(scale=1):
|
239 |
+
# output = gr.Image(shape=(512, 512), elem_classes="img", elem_id="img")
|
240 |
+
with gr.Column(scale=1):
|
241 |
+
gallery = gr.Gallery(
|
242 |
+
label="Generated image", show_label=False, elem_id="gallery2"
|
243 |
+
).style(columns=[1], rows=[1], object_fit="contain", preview=True)
|
244 |
+
generate_btn.click(fn=get_baseline_image, inputs=[sd_prompt, seed], outputs=gallery, api_name="baseline")
|
245 |
+
|
246 |
|
247 |
g.launch()
|
baseline.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Original Stable Diffusion (1.4)
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import models
|
5 |
+
from models import pipelines
|
6 |
+
from shared import model_dict
|
7 |
+
|
8 |
+
vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
|
9 |
+
|
10 |
+
torch.set_grad_enabled(False)
|
11 |
+
|
12 |
+
height = 512 # default height of Stable Diffusion
|
13 |
+
width = 512 # default width of Stable Diffusion
|
14 |
+
num_inference_steps = 20 # Number of denoising steps
|
15 |
+
guidance_scale = 7.5 # Scale for classifier-free guidance
|
16 |
+
batch_size = 1
|
17 |
+
|
18 |
+
# h, w
|
19 |
+
image_scale = (512, 512)
|
20 |
+
|
21 |
+
bg_negative = 'artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image, bad proportions, duplicate'
|
22 |
+
|
23 |
+
def run(prompt, bg_seed=1):
|
24 |
+
print(f"prompt: {prompt}")
|
25 |
+
generator = torch.Generator(models.torch_device).manual_seed(bg_seed)
|
26 |
+
|
27 |
+
prompts = [prompt]
|
28 |
+
input_embeddings = models.encode_prompts(prompts=prompts, tokenizer=tokenizer, text_encoder=text_encoder, negative_prompt=bg_negative)
|
29 |
+
|
30 |
+
generator = torch.manual_seed(1) # Seed generator to create the inital latent noise
|
31 |
+
latents = models.get_unscaled_latents(batch_size, unet.config.in_channels, height, width, generator, dtype)
|
32 |
+
|
33 |
+
latents = latents * scheduler.init_noise_sigma
|
34 |
+
|
35 |
+
pipelines.gligen_enable_fuser(model_dict['unet'], enabled=False)
|
36 |
+
_, images = pipelines.generate(
|
37 |
+
model_dict, latents, input_embeddings, num_inference_steps,
|
38 |
+
guidance_scale=guidance_scale
|
39 |
+
)
|
40 |
+
|
41 |
+
return images[0]
|
generation.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
version = "v3.0"
|
2 |
+
|
3 |
+
from PIL import Image
|
4 |
+
import torch
|
5 |
+
import models
|
6 |
+
from models import load_sd
|
7 |
+
import utils
|
8 |
+
from models import pipelines, sam
|
9 |
+
from utils import parse, latents
|
10 |
+
from shared import model_dict, sam_model_dict
|
11 |
+
|
12 |
+
verbose = False
|
13 |
+
|
14 |
+
vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
|
15 |
+
|
16 |
+
model_dict.update(sam_model_dict)
|
17 |
+
|
18 |
+
|
19 |
+
# Hyperparams
|
20 |
+
height = 512 # default height of Stable Diffusion
|
21 |
+
width = 512 # default width of Stable Diffusion
|
22 |
+
H, W = height // 8, width // 8 # size of the latent
|
23 |
+
num_inference_steps = 20 # Number of denoising steps
|
24 |
+
guidance_scale = 7.5 # Scale for classifier-free guidance
|
25 |
+
|
26 |
+
# batch size that is not 1 is not supported
|
27 |
+
so_batch_size = 1
|
28 |
+
overall_batch_size = 1
|
29 |
+
|
30 |
+
# discourage masks with confidence below
|
31 |
+
discourage_mask_below_confidence = 0.85
|
32 |
+
|
33 |
+
# discourage masks with iou (with coarse binarized attention mask) below
|
34 |
+
discourage_mask_below_coarse_iou = 0.25
|
35 |
+
|
36 |
+
run_ind = None
|
37 |
+
|
38 |
+
|
39 |
+
def generate_single_object_with_box(prompt, box, phrase, word, input_latents, input_embeddings,
|
40 |
+
sam_refine_kwargs, gligen_scheduled_sampling_beta=0.3,
|
41 |
+
verbose=False, visualize=True):
|
42 |
+
|
43 |
+
bboxes, phrases, words = [box], [phrase], [word]
|
44 |
+
|
45 |
+
latents, single_object_images, single_object_pil_images_box_ann, latents_all = pipelines.generate_gligen(
|
46 |
+
model_dict, input_latents, input_embeddings, num_inference_steps, bboxes, phrases, gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
|
47 |
+
guidance_scale=guidance_scale, return_saved_cross_attn=False,
|
48 |
+
return_box_vis=True, save_all_latents=True
|
49 |
+
)
|
50 |
+
|
51 |
+
mask_selected, conf_score_selected = sam.sam_refine_box(sam_input_image=single_object_images[0], box=box, model_dict=model_dict, verbose=verbose, **sam_refine_kwargs)
|
52 |
+
|
53 |
+
mask_selected_tensor = torch.tensor(mask_selected)
|
54 |
+
|
55 |
+
return latents_all, mask_selected_tensor, single_object_pil_images_box_ann[0]
|
56 |
+
|
57 |
+
def get_masked_latents_all_list(so_prompt_phrase_word_box_list, input_latents_list, so_input_embeddings, verbose=False, **kwargs):
|
58 |
+
latents_all_list, mask_tensor_list, so_img_list = [], [], []
|
59 |
+
|
60 |
+
if not so_prompt_phrase_word_box_list:
|
61 |
+
return latents_all_list, mask_tensor_list
|
62 |
+
|
63 |
+
so_uncond_embeddings, so_cond_embeddings = so_input_embeddings
|
64 |
+
|
65 |
+
for idx, ((prompt, phrase, word, box), input_latents) in enumerate(zip(so_prompt_phrase_word_box_list, input_latents_list)):
|
66 |
+
so_current_cond_embeddings = so_cond_embeddings[idx:idx+1]
|
67 |
+
so_current_text_embeddings = torch.cat([so_uncond_embeddings, so_current_cond_embeddings], dim=0)
|
68 |
+
so_current_input_embeddings = so_current_text_embeddings, so_uncond_embeddings, so_current_cond_embeddings
|
69 |
+
|
70 |
+
latents_all, mask_tensor, so_img = generate_single_object_with_box(prompt, box, phrase, word, input_latents, input_embeddings=so_current_input_embeddings, verbose=verbose, **kwargs)
|
71 |
+
latents_all_list.append(latents_all)
|
72 |
+
mask_tensor_list.append(mask_tensor)
|
73 |
+
so_img_list.append(so_img)
|
74 |
+
|
75 |
+
return latents_all_list, mask_tensor_list, so_img_list
|
76 |
+
|
77 |
+
|
78 |
+
# Note: need to keep the supervision, especially the box corrdinates, corresponds to each other in single object and overall.
|
79 |
+
|
80 |
+
def run(
|
81 |
+
spec, bg_seed = 1, fg_seed_start = 20, frozen_step_ratio=0.4, gligen_scheduled_sampling_beta = 0.3,
|
82 |
+
so_center_box = False, fg_blending_ratio = 0.1, so_horizontal_center_only = True,
|
83 |
+
align_with_overall_bboxes = False, horizontal_shift_only = True
|
84 |
+
):
|
85 |
+
"""
|
86 |
+
so_center_box: using centered box in single object generation
|
87 |
+
so_horizontal_center_only: move to the center horizontally only
|
88 |
+
|
89 |
+
align_with_overall_bboxes: Align the center of the mask, latents, and cross-attention with the center of the box in overall bboxes
|
90 |
+
horizontal_shift_only: only shift horizontally for the alignment of mask, latents, and cross-attention
|
91 |
+
"""
|
92 |
+
|
93 |
+
print("generation:", spec, bg_seed, fg_seed_start, frozen_step_ratio, gligen_scheduled_sampling_beta)
|
94 |
+
|
95 |
+
frozen_step_ratio = min(max(frozen_step_ratio, 0.), 1.)
|
96 |
+
frozen_steps = int(num_inference_steps * frozen_step_ratio)
|
97 |
+
|
98 |
+
if True:
|
99 |
+
so_prompt_phrase_word_box_list, overall_prompt, overall_phrases_words_bboxes = parse.convert_spec(spec, height, width, verbose=verbose)
|
100 |
+
|
101 |
+
overall_phrases, overall_words, overall_bboxes = [item[0] for item in overall_phrases_words_bboxes], [item[1] for item in overall_phrases_words_bboxes], [item[2] for item in overall_phrases_words_bboxes]
|
102 |
+
|
103 |
+
# The so box is centered but the overall boxes are not (since we need to place to the right place).
|
104 |
+
if so_center_box:
|
105 |
+
so_prompt_phrase_word_box_list = [(prompt, phrase, word, utils.get_centered_box(bbox, horizontal_center_only=so_horizontal_center_only)) for prompt, phrase, word, bbox in so_prompt_phrase_word_box_list]
|
106 |
+
if verbose:
|
107 |
+
print(f"centered so_prompt_phrase_word_box_list: {so_prompt_phrase_word_box_list}")
|
108 |
+
so_boxes = [item[-1] for item in so_prompt_phrase_word_box_list]
|
109 |
+
|
110 |
+
if True:
|
111 |
+
so_negative_prompt = "artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image, bad proportions, duplicate, two, many, group, occlusion, occluded, side, border, collate"
|
112 |
+
overall_negative_prompt = "artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image, bad proportions, duplicate"
|
113 |
+
else:
|
114 |
+
so_negative_prompt = ""
|
115 |
+
overall_negative_prompt = ""
|
116 |
+
|
117 |
+
sam_refine_kwargs = dict(
|
118 |
+
discourage_mask_below_confidence=discourage_mask_below_confidence, discourage_mask_below_coarse_iou=discourage_mask_below_coarse_iou,
|
119 |
+
height=height, width=width, H=H, W=W
|
120 |
+
)
|
121 |
+
|
122 |
+
|
123 |
+
# Note that so and overall use different negative prompts
|
124 |
+
|
125 |
+
so_prompts = [item[0] for item in so_prompt_phrase_word_box_list]
|
126 |
+
if so_prompts:
|
127 |
+
so_input_embeddings = models.encode_prompts(prompts=so_prompts, tokenizer=tokenizer, text_encoder=text_encoder, negative_prompt=so_negative_prompt, one_uncond_input_only=True)
|
128 |
+
else:
|
129 |
+
so_input_embeddings = []
|
130 |
+
|
131 |
+
overall_input_embeddings = models.encode_prompts(prompts=[overall_prompt], tokenizer=tokenizer, negative_prompt=overall_negative_prompt, text_encoder=text_encoder)
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
input_latents_list, latents_bg = latents.get_input_latents_list(
|
137 |
+
model_dict, bg_seed=bg_seed, fg_seed_start=fg_seed_start,
|
138 |
+
so_boxes=so_boxes, fg_blending_ratio=fg_blending_ratio, height=height, width=width, verbose=False
|
139 |
+
)
|
140 |
+
latents_all_list, mask_tensor_list, so_img_list = get_masked_latents_all_list(
|
141 |
+
so_prompt_phrase_word_box_list, input_latents_list,
|
142 |
+
gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
|
143 |
+
sam_refine_kwargs=sam_refine_kwargs, so_input_embeddings=so_input_embeddings, verbose=verbose
|
144 |
+
)
|
145 |
+
|
146 |
+
|
147 |
+
|
148 |
+
composed_latents, foreground_indices, offset_list = latents.compose_latents_with_alignment(
|
149 |
+
model_dict, latents_all_list, mask_tensor_list, num_inference_steps,
|
150 |
+
overall_batch_size, height, width, latents_bg=latents_bg,
|
151 |
+
align_with_overall_bboxes=align_with_overall_bboxes, overall_bboxes=overall_bboxes,
|
152 |
+
horizontal_shift_only=horizontal_shift_only
|
153 |
+
)
|
154 |
+
|
155 |
+
overall_bboxes_flattened, overall_phrases_flattened = [], []
|
156 |
+
for overall_bboxes_item, overall_phrase in zip(overall_bboxes, overall_phrases):
|
157 |
+
for overall_bbox in overall_bboxes_item:
|
158 |
+
overall_bboxes_flattened.append(overall_bbox)
|
159 |
+
overall_phrases_flattened.append(overall_phrase)
|
160 |
+
|
161 |
+
# Generate with composed latents
|
162 |
+
|
163 |
+
# Foreground should be frozen
|
164 |
+
frozen_mask = foreground_indices != 0
|
165 |
+
|
166 |
+
regen_latents, images = pipelines.generate_gligen(
|
167 |
+
model_dict, composed_latents, overall_input_embeddings, num_inference_steps,
|
168 |
+
overall_bboxes_flattened, overall_phrases_flattened, guidance_scale=guidance_scale,
|
169 |
+
gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
|
170 |
+
frozen_steps=frozen_steps, frozen_mask=frozen_mask
|
171 |
+
)
|
172 |
+
|
173 |
+
print(f"Generation with spatial guidance from input latents and first {frozen_steps} steps frozen (directly from the composed latents input)")
|
174 |
+
print("Generation from composed latents (with semantic guidance)")
|
175 |
+
|
176 |
+
# display(Image.fromarray(images[0]), "img", run_ind)
|
177 |
+
|
178 |
+
return images[0], so_img_list
|
179 |
+
|
models/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .models import *
|
models/attention.py
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Any, Dict, Optional
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
from diffusers.utils import maybe_allow_in_graph
|
21 |
+
from .attention_processor import Attention
|
22 |
+
from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings
|
23 |
+
|
24 |
+
# https://github.com/gligen/diffusers/blob/23a9a0fab1b48752c7b9bcc98f6fe3b1d8fa7990/src/diffusers/models/attention.py
|
25 |
+
class GatedSelfAttentionDense(nn.Module):
|
26 |
+
def __init__(self, query_dim, context_dim, n_heads, d_head):
|
27 |
+
super().__init__()
|
28 |
+
|
29 |
+
# we need a linear projection since we need cat visual feature and obj feature
|
30 |
+
self.linear = nn.Linear(context_dim, query_dim)
|
31 |
+
|
32 |
+
self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
|
33 |
+
self.ff = FeedForward(query_dim, activation_fn="geglu")
|
34 |
+
|
35 |
+
self.norm1 = nn.LayerNorm(query_dim)
|
36 |
+
self.norm2 = nn.LayerNorm(query_dim)
|
37 |
+
|
38 |
+
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
|
39 |
+
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
|
40 |
+
|
41 |
+
self.enabled = True
|
42 |
+
|
43 |
+
def forward(self, x, objs, fuser_attn_kwargs={}):
|
44 |
+
if not self.enabled:
|
45 |
+
return x
|
46 |
+
|
47 |
+
n_visual = x.shape[1]
|
48 |
+
objs = self.linear(objs)
|
49 |
+
|
50 |
+
x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)), **fuser_attn_kwargs)[:, :n_visual, :]
|
51 |
+
x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
|
52 |
+
|
53 |
+
return x
|
54 |
+
|
55 |
+
@maybe_allow_in_graph
|
56 |
+
class BasicTransformerBlock(nn.Module):
|
57 |
+
r"""
|
58 |
+
A basic Transformer block.
|
59 |
+
|
60 |
+
Parameters:
|
61 |
+
dim (`int`): The number of channels in the input and output.
|
62 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
63 |
+
attention_head_dim (`int`): The number of channels in each head.
|
64 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
65 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
66 |
+
only_cross_attention (`bool`, *optional*):
|
67 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
68 |
+
double_self_attention (`bool`, *optional*):
|
69 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
70 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
71 |
+
num_embeds_ada_norm (:
|
72 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
73 |
+
attention_bias (:
|
74 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
75 |
+
"""
|
76 |
+
|
77 |
+
def __init__(
|
78 |
+
self,
|
79 |
+
dim: int,
|
80 |
+
num_attention_heads: int,
|
81 |
+
attention_head_dim: int,
|
82 |
+
dropout=0.0,
|
83 |
+
cross_attention_dim: Optional[int] = None,
|
84 |
+
activation_fn: str = "geglu",
|
85 |
+
num_embeds_ada_norm: Optional[int] = None,
|
86 |
+
attention_bias: bool = False,
|
87 |
+
only_cross_attention: bool = False,
|
88 |
+
double_self_attention: bool = False,
|
89 |
+
upcast_attention: bool = False,
|
90 |
+
norm_elementwise_affine: bool = True,
|
91 |
+
norm_type: str = "layer_norm",
|
92 |
+
final_dropout: bool = False,
|
93 |
+
use_gated_attention: bool = False,
|
94 |
+
):
|
95 |
+
super().__init__()
|
96 |
+
self.only_cross_attention = only_cross_attention
|
97 |
+
|
98 |
+
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
99 |
+
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
100 |
+
|
101 |
+
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
102 |
+
raise ValueError(
|
103 |
+
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
104 |
+
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
105 |
+
)
|
106 |
+
|
107 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
108 |
+
# 1. Self-Attn
|
109 |
+
if self.use_ada_layer_norm:
|
110 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
111 |
+
elif self.use_ada_layer_norm_zero:
|
112 |
+
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
113 |
+
else:
|
114 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
115 |
+
self.attn1 = Attention(
|
116 |
+
query_dim=dim,
|
117 |
+
heads=num_attention_heads,
|
118 |
+
dim_head=attention_head_dim,
|
119 |
+
dropout=dropout,
|
120 |
+
bias=attention_bias,
|
121 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
122 |
+
upcast_attention=upcast_attention,
|
123 |
+
)
|
124 |
+
|
125 |
+
# 2. Cross-Attn
|
126 |
+
if cross_attention_dim is not None or double_self_attention:
|
127 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
128 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
129 |
+
# the second cross attention block.
|
130 |
+
self.norm2 = (
|
131 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
132 |
+
if self.use_ada_layer_norm
|
133 |
+
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
134 |
+
)
|
135 |
+
self.attn2 = Attention(
|
136 |
+
query_dim=dim,
|
137 |
+
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
138 |
+
heads=num_attention_heads,
|
139 |
+
dim_head=attention_head_dim,
|
140 |
+
dropout=dropout,
|
141 |
+
bias=attention_bias,
|
142 |
+
upcast_attention=upcast_attention,
|
143 |
+
) # is self-attn if encoder_hidden_states is none
|
144 |
+
else:
|
145 |
+
self.norm2 = None
|
146 |
+
self.attn2 = None
|
147 |
+
|
148 |
+
# 3. Feed-forward
|
149 |
+
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
150 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
|
151 |
+
|
152 |
+
# 4. Fuser
|
153 |
+
if use_gated_attention:
|
154 |
+
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
|
155 |
+
|
156 |
+
def forward(
|
157 |
+
self,
|
158 |
+
hidden_states: torch.FloatTensor,
|
159 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
160 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
161 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
162 |
+
timestep: Optional[torch.LongTensor] = None,
|
163 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
164 |
+
class_labels: Optional[torch.LongTensor] = None,
|
165 |
+
return_cross_attention_probs: bool = None,
|
166 |
+
):
|
167 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
168 |
+
|
169 |
+
# 0. Prepare GLIGEN inputs
|
170 |
+
if 'gligen' in cross_attention_kwargs:
|
171 |
+
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
172 |
+
gligen_kwargs = cross_attention_kwargs.pop('gligen', None)
|
173 |
+
else:
|
174 |
+
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
175 |
+
gligen_kwargs = None
|
176 |
+
|
177 |
+
# 1. Self-Attention
|
178 |
+
if self.use_ada_layer_norm:
|
179 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
180 |
+
elif self.use_ada_layer_norm_zero:
|
181 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
182 |
+
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
183 |
+
)
|
184 |
+
else:
|
185 |
+
norm_hidden_states = self.norm1(hidden_states)
|
186 |
+
|
187 |
+
attn_output = self.attn1(
|
188 |
+
norm_hidden_states,
|
189 |
+
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
190 |
+
attention_mask=attention_mask,
|
191 |
+
**cross_attention_kwargs,
|
192 |
+
)
|
193 |
+
if self.use_ada_layer_norm_zero:
|
194 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
195 |
+
hidden_states = attn_output + hidden_states
|
196 |
+
|
197 |
+
# 1.5 GLIGEN Control
|
198 |
+
if gligen_kwargs is not None:
|
199 |
+
# print(gligen_kwargs)
|
200 |
+
hidden_states = self.fuser(hidden_states, gligen_kwargs['objs'], fuser_attn_kwargs=gligen_kwargs.get("fuser_attn_kwargs", {}))
|
201 |
+
# 1.5 ends
|
202 |
+
|
203 |
+
# 2. Cross-Attention
|
204 |
+
if self.attn2 is not None:
|
205 |
+
norm_hidden_states = (
|
206 |
+
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
207 |
+
)
|
208 |
+
|
209 |
+
attn_output = self.attn2(
|
210 |
+
norm_hidden_states,
|
211 |
+
encoder_hidden_states=encoder_hidden_states,
|
212 |
+
attention_mask=encoder_attention_mask,
|
213 |
+
return_attntion_probs=return_cross_attention_probs,
|
214 |
+
**cross_attention_kwargs,
|
215 |
+
)
|
216 |
+
|
217 |
+
if return_cross_attention_probs:
|
218 |
+
attn_output, cross_attention_probs = attn_output
|
219 |
+
|
220 |
+
hidden_states = attn_output + hidden_states
|
221 |
+
|
222 |
+
# 3. Feed-forward
|
223 |
+
norm_hidden_states = self.norm3(hidden_states)
|
224 |
+
|
225 |
+
if self.use_ada_layer_norm_zero:
|
226 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
227 |
+
|
228 |
+
ff_output = self.ff(norm_hidden_states)
|
229 |
+
|
230 |
+
if self.use_ada_layer_norm_zero:
|
231 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
232 |
+
|
233 |
+
hidden_states = ff_output + hidden_states
|
234 |
+
|
235 |
+
if return_cross_attention_probs and self.attn2 is not None:
|
236 |
+
return hidden_states, cross_attention_probs
|
237 |
+
return hidden_states
|
238 |
+
|
239 |
+
|
240 |
+
class FeedForward(nn.Module):
|
241 |
+
r"""
|
242 |
+
A feed-forward layer.
|
243 |
+
|
244 |
+
Parameters:
|
245 |
+
dim (`int`): The number of channels in the input.
|
246 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
247 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
248 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
249 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
250 |
+
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
251 |
+
"""
|
252 |
+
|
253 |
+
def __init__(
|
254 |
+
self,
|
255 |
+
dim: int,
|
256 |
+
dim_out: Optional[int] = None,
|
257 |
+
mult: int = 4,
|
258 |
+
dropout: float = 0.0,
|
259 |
+
activation_fn: str = "geglu",
|
260 |
+
final_dropout: bool = False,
|
261 |
+
):
|
262 |
+
super().__init__()
|
263 |
+
inner_dim = int(dim * mult)
|
264 |
+
dim_out = dim_out if dim_out is not None else dim
|
265 |
+
|
266 |
+
if activation_fn == "gelu":
|
267 |
+
act_fn = GELU(dim, inner_dim)
|
268 |
+
if activation_fn == "gelu-approximate":
|
269 |
+
act_fn = GELU(dim, inner_dim, approximate="tanh")
|
270 |
+
elif activation_fn == "geglu":
|
271 |
+
act_fn = GEGLU(dim, inner_dim)
|
272 |
+
elif activation_fn == "geglu-approximate":
|
273 |
+
act_fn = ApproximateGELU(dim, inner_dim)
|
274 |
+
|
275 |
+
self.net = nn.ModuleList([])
|
276 |
+
# project in
|
277 |
+
self.net.append(act_fn)
|
278 |
+
# project dropout
|
279 |
+
self.net.append(nn.Dropout(dropout))
|
280 |
+
# project out
|
281 |
+
self.net.append(nn.Linear(inner_dim, dim_out))
|
282 |
+
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
283 |
+
if final_dropout:
|
284 |
+
self.net.append(nn.Dropout(dropout))
|
285 |
+
|
286 |
+
def forward(self, hidden_states):
|
287 |
+
for module in self.net:
|
288 |
+
hidden_states = module(hidden_states)
|
289 |
+
return hidden_states
|
290 |
+
|
291 |
+
|
292 |
+
class GELU(nn.Module):
|
293 |
+
r"""
|
294 |
+
GELU activation function with tanh approximation support with `approximate="tanh"`.
|
295 |
+
"""
|
296 |
+
|
297 |
+
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
|
298 |
+
super().__init__()
|
299 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
300 |
+
self.approximate = approximate
|
301 |
+
|
302 |
+
def gelu(self, gate):
|
303 |
+
if gate.device.type != "mps":
|
304 |
+
return F.gelu(gate, approximate=self.approximate)
|
305 |
+
# mps: gelu is not implemented for float16
|
306 |
+
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
|
307 |
+
|
308 |
+
def forward(self, hidden_states):
|
309 |
+
hidden_states = self.proj(hidden_states)
|
310 |
+
hidden_states = self.gelu(hidden_states)
|
311 |
+
return hidden_states
|
312 |
+
|
313 |
+
|
314 |
+
class GEGLU(nn.Module):
|
315 |
+
r"""
|
316 |
+
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
317 |
+
|
318 |
+
Parameters:
|
319 |
+
dim_in (`int`): The number of channels in the input.
|
320 |
+
dim_out (`int`): The number of channels in the output.
|
321 |
+
"""
|
322 |
+
|
323 |
+
def __init__(self, dim_in: int, dim_out: int):
|
324 |
+
super().__init__()
|
325 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
326 |
+
|
327 |
+
def gelu(self, gate):
|
328 |
+
if gate.device.type != "mps":
|
329 |
+
return F.gelu(gate)
|
330 |
+
# mps: gelu is not implemented for float16
|
331 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
332 |
+
|
333 |
+
def forward(self, hidden_states):
|
334 |
+
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
335 |
+
return hidden_states * self.gelu(gate)
|
336 |
+
|
337 |
+
|
338 |
+
class ApproximateGELU(nn.Module):
|
339 |
+
"""
|
340 |
+
The approximate form of Gaussian Error Linear Unit (GELU)
|
341 |
+
|
342 |
+
For more details, see section 2: https://arxiv.org/abs/1606.08415
|
343 |
+
"""
|
344 |
+
|
345 |
+
def __init__(self, dim_in: int, dim_out: int):
|
346 |
+
super().__init__()
|
347 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
348 |
+
|
349 |
+
def forward(self, x):
|
350 |
+
x = self.proj(x)
|
351 |
+
return x * torch.sigmoid(1.702 * x)
|
352 |
+
|
353 |
+
|
354 |
+
class AdaLayerNorm(nn.Module):
|
355 |
+
"""
|
356 |
+
Norm layer modified to incorporate timestep embeddings.
|
357 |
+
"""
|
358 |
+
|
359 |
+
def __init__(self, embedding_dim, num_embeddings):
|
360 |
+
super().__init__()
|
361 |
+
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
362 |
+
self.silu = nn.SiLU()
|
363 |
+
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
364 |
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
|
365 |
+
|
366 |
+
def forward(self, x, timestep):
|
367 |
+
emb = self.linear(self.silu(self.emb(timestep)))
|
368 |
+
scale, shift = torch.chunk(emb, 2)
|
369 |
+
x = self.norm(x) * (1 + scale) + shift
|
370 |
+
return x
|
371 |
+
|
372 |
+
|
373 |
+
class AdaLayerNormZero(nn.Module):
|
374 |
+
"""
|
375 |
+
Norm layer adaptive layer norm zero (adaLN-Zero).
|
376 |
+
"""
|
377 |
+
|
378 |
+
def __init__(self, embedding_dim, num_embeddings):
|
379 |
+
super().__init__()
|
380 |
+
|
381 |
+
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
|
382 |
+
|
383 |
+
self.silu = nn.SiLU()
|
384 |
+
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
385 |
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
386 |
+
|
387 |
+
def forward(self, x, timestep, class_labels, hidden_dtype=None):
|
388 |
+
emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
|
389 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
|
390 |
+
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
391 |
+
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
392 |
+
|
models/attention_processor.py
ADDED
@@ -0,0 +1,508 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import warnings
|
15 |
+
from typing import Callable, Optional, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
from diffusers.utils import deprecate, logging, maybe_allow_in_graph
|
22 |
+
|
23 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
24 |
+
|
25 |
+
@maybe_allow_in_graph
|
26 |
+
class Attention(nn.Module):
|
27 |
+
r"""
|
28 |
+
A cross attention layer.
|
29 |
+
|
30 |
+
Parameters:
|
31 |
+
query_dim (`int`): The number of channels in the query.
|
32 |
+
cross_attention_dim (`int`, *optional*):
|
33 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
34 |
+
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
35 |
+
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
36 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
37 |
+
bias (`bool`, *optional*, defaults to False):
|
38 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
39 |
+
"""
|
40 |
+
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
query_dim: int,
|
44 |
+
cross_attention_dim: Optional[int] = None,
|
45 |
+
heads: int = 8,
|
46 |
+
dim_head: int = 64,
|
47 |
+
dropout: float = 0.0,
|
48 |
+
bias=False,
|
49 |
+
upcast_attention: bool = False,
|
50 |
+
upcast_softmax: bool = False,
|
51 |
+
cross_attention_norm: Optional[str] = None,
|
52 |
+
cross_attention_norm_num_groups: int = 32,
|
53 |
+
added_kv_proj_dim: Optional[int] = None,
|
54 |
+
norm_num_groups: Optional[int] = None,
|
55 |
+
spatial_norm_dim: Optional[int] = None,
|
56 |
+
out_bias: bool = True,
|
57 |
+
scale_qk: bool = True,
|
58 |
+
only_cross_attention: bool = False,
|
59 |
+
eps: float = 1e-5,
|
60 |
+
rescale_output_factor: float = 1.0,
|
61 |
+
residual_connection: bool = False,
|
62 |
+
_from_deprecated_attn_block=False,
|
63 |
+
processor: Optional["AttnProcessor"] = None,
|
64 |
+
):
|
65 |
+
super().__init__()
|
66 |
+
inner_dim = dim_head * heads
|
67 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
68 |
+
self.upcast_attention = upcast_attention
|
69 |
+
self.upcast_softmax = upcast_softmax
|
70 |
+
self.rescale_output_factor = rescale_output_factor
|
71 |
+
self.residual_connection = residual_connection
|
72 |
+
|
73 |
+
# we make use of this private variable to know whether this class is loaded
|
74 |
+
# with an deprecated state dict so that we can convert it on the fly
|
75 |
+
self._from_deprecated_attn_block = _from_deprecated_attn_block
|
76 |
+
|
77 |
+
self.scale_qk = scale_qk
|
78 |
+
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
79 |
+
|
80 |
+
self.heads = heads
|
81 |
+
# for slice_size > 0 the attention score computation
|
82 |
+
# is split across the batch axis to save memory
|
83 |
+
# You can set slice_size with `set_attention_slice`
|
84 |
+
self.sliceable_head_dim = heads
|
85 |
+
|
86 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
87 |
+
self.only_cross_attention = only_cross_attention
|
88 |
+
|
89 |
+
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
90 |
+
raise ValueError(
|
91 |
+
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
92 |
+
)
|
93 |
+
|
94 |
+
if norm_num_groups is not None:
|
95 |
+
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
|
96 |
+
else:
|
97 |
+
self.group_norm = None
|
98 |
+
|
99 |
+
if spatial_norm_dim is not None:
|
100 |
+
self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
|
101 |
+
else:
|
102 |
+
self.spatial_norm = None
|
103 |
+
|
104 |
+
if cross_attention_norm is None:
|
105 |
+
self.norm_cross = None
|
106 |
+
elif cross_attention_norm == "layer_norm":
|
107 |
+
self.norm_cross = nn.LayerNorm(cross_attention_dim)
|
108 |
+
elif cross_attention_norm == "group_norm":
|
109 |
+
if self.added_kv_proj_dim is not None:
|
110 |
+
# The given `encoder_hidden_states` are initially of shape
|
111 |
+
# (batch_size, seq_len, added_kv_proj_dim) before being projected
|
112 |
+
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
|
113 |
+
# before the projection, so we need to use `added_kv_proj_dim` as
|
114 |
+
# the number of channels for the group norm.
|
115 |
+
norm_cross_num_channels = added_kv_proj_dim
|
116 |
+
else:
|
117 |
+
norm_cross_num_channels = cross_attention_dim
|
118 |
+
|
119 |
+
self.norm_cross = nn.GroupNorm(
|
120 |
+
num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
|
121 |
+
)
|
122 |
+
else:
|
123 |
+
raise ValueError(
|
124 |
+
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
|
125 |
+
)
|
126 |
+
|
127 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
128 |
+
|
129 |
+
if not self.only_cross_attention:
|
130 |
+
# only relevant for the `AddedKVProcessor` classes
|
131 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
132 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
133 |
+
else:
|
134 |
+
self.to_k = None
|
135 |
+
self.to_v = None
|
136 |
+
|
137 |
+
if self.added_kv_proj_dim is not None:
|
138 |
+
self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
|
139 |
+
self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)
|
140 |
+
|
141 |
+
self.to_out = nn.ModuleList([])
|
142 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
|
143 |
+
self.to_out.append(nn.Dropout(dropout))
|
144 |
+
|
145 |
+
# set attention processor
|
146 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
147 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
148 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
149 |
+
if processor is None:
|
150 |
+
# processor = (
|
151 |
+
# AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
152 |
+
# )
|
153 |
+
# Note: efficient attention is not used. We can use efficient attention to speed up.
|
154 |
+
processor = AttnProcessor()
|
155 |
+
self.set_processor(processor)
|
156 |
+
|
157 |
+
def set_processor(self, processor: "AttnProcessor"):
|
158 |
+
# if current processor is in `self._modules` and if passed `processor` is not, we need to
|
159 |
+
# pop `processor` from `self._modules`
|
160 |
+
if (
|
161 |
+
hasattr(self, "processor")
|
162 |
+
and isinstance(self.processor, torch.nn.Module)
|
163 |
+
and not isinstance(processor, torch.nn.Module)
|
164 |
+
):
|
165 |
+
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
|
166 |
+
self._modules.pop("processor")
|
167 |
+
|
168 |
+
self.processor = processor
|
169 |
+
|
170 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, return_attntion_probs=False, **cross_attention_kwargs):
|
171 |
+
# The `Attention` class can call different attention processors / attention functions
|
172 |
+
# here we simply pass along all tensors to the selected processor class
|
173 |
+
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
174 |
+
return self.processor(
|
175 |
+
self,
|
176 |
+
hidden_states,
|
177 |
+
encoder_hidden_states=encoder_hidden_states,
|
178 |
+
attention_mask=attention_mask,
|
179 |
+
return_attntion_probs=return_attntion_probs,
|
180 |
+
**cross_attention_kwargs,
|
181 |
+
)
|
182 |
+
|
183 |
+
def batch_to_head_dim(self, tensor):
|
184 |
+
head_size = self.heads
|
185 |
+
batch_size, seq_len, dim = tensor.shape
|
186 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
187 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
188 |
+
return tensor
|
189 |
+
|
190 |
+
def head_to_batch_dim(self, tensor, out_dim=3):
|
191 |
+
head_size = self.heads
|
192 |
+
batch_size, seq_len, dim = tensor.shape
|
193 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
194 |
+
tensor = tensor.permute(0, 2, 1, 3)
|
195 |
+
|
196 |
+
if out_dim == 3:
|
197 |
+
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
|
198 |
+
|
199 |
+
return tensor
|
200 |
+
|
201 |
+
def get_attention_scores(self, query, key, attention_mask=None):
|
202 |
+
dtype = query.dtype
|
203 |
+
if self.upcast_attention:
|
204 |
+
query = query.float()
|
205 |
+
key = key.float()
|
206 |
+
|
207 |
+
if attention_mask is None:
|
208 |
+
baddbmm_input = torch.empty(
|
209 |
+
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
|
210 |
+
)
|
211 |
+
beta = 0
|
212 |
+
else:
|
213 |
+
baddbmm_input = attention_mask
|
214 |
+
beta = 1
|
215 |
+
|
216 |
+
attention_scores = torch.baddbmm(
|
217 |
+
baddbmm_input,
|
218 |
+
query,
|
219 |
+
key.transpose(-1, -2),
|
220 |
+
beta=beta,
|
221 |
+
alpha=self.scale,
|
222 |
+
)
|
223 |
+
del baddbmm_input
|
224 |
+
|
225 |
+
if self.upcast_softmax:
|
226 |
+
attention_scores = attention_scores.float()
|
227 |
+
|
228 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
229 |
+
del attention_scores
|
230 |
+
|
231 |
+
attention_probs = attention_probs.to(dtype)
|
232 |
+
|
233 |
+
return attention_probs
|
234 |
+
|
235 |
+
def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
|
236 |
+
if batch_size is None:
|
237 |
+
deprecate(
|
238 |
+
"batch_size=None",
|
239 |
+
"0.0.15",
|
240 |
+
(
|
241 |
+
"Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
|
242 |
+
" attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
|
243 |
+
" `prepare_attention_mask` when preparing the attention_mask."
|
244 |
+
),
|
245 |
+
)
|
246 |
+
batch_size = 1
|
247 |
+
|
248 |
+
head_size = self.heads
|
249 |
+
if attention_mask is None:
|
250 |
+
return attention_mask
|
251 |
+
|
252 |
+
current_length: int = attention_mask.shape[-1]
|
253 |
+
if current_length != target_length:
|
254 |
+
if attention_mask.device.type == "mps":
|
255 |
+
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
|
256 |
+
# Instead, we can manually construct the padding tensor.
|
257 |
+
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
|
258 |
+
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
|
259 |
+
attention_mask = torch.cat([attention_mask, padding], dim=2)
|
260 |
+
else:
|
261 |
+
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
|
262 |
+
# we want to instead pad by (0, remaining_length), where remaining_length is:
|
263 |
+
# remaining_length: int = target_length - current_length
|
264 |
+
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
|
265 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
266 |
+
|
267 |
+
if out_dim == 3:
|
268 |
+
if attention_mask.shape[0] < batch_size * head_size:
|
269 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
|
270 |
+
elif out_dim == 4:
|
271 |
+
attention_mask = attention_mask.unsqueeze(1)
|
272 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
|
273 |
+
|
274 |
+
return attention_mask
|
275 |
+
|
276 |
+
def norm_encoder_hidden_states(self, encoder_hidden_states):
|
277 |
+
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
|
278 |
+
|
279 |
+
if isinstance(self.norm_cross, nn.LayerNorm):
|
280 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
281 |
+
elif isinstance(self.norm_cross, nn.GroupNorm):
|
282 |
+
# Group norm norms along the channels dimension and expects
|
283 |
+
# input to be in the shape of (N, C, *). In this case, we want
|
284 |
+
# to norm along the hidden dimension, so we need to move
|
285 |
+
# (batch_size, sequence_length, hidden_size) ->
|
286 |
+
# (batch_size, hidden_size, sequence_length)
|
287 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
288 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
289 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
290 |
+
else:
|
291 |
+
assert False
|
292 |
+
|
293 |
+
return encoder_hidden_states
|
294 |
+
|
295 |
+
|
296 |
+
class AttnProcessor:
|
297 |
+
r"""
|
298 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
299 |
+
"""
|
300 |
+
|
301 |
+
def __init__(self):
|
302 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
303 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
304 |
+
|
305 |
+
def __call_fast__(
|
306 |
+
self,
|
307 |
+
attn: Attention,
|
308 |
+
hidden_states,
|
309 |
+
encoder_hidden_states=None,
|
310 |
+
attention_mask=None,
|
311 |
+
temb=None,
|
312 |
+
):
|
313 |
+
residual = hidden_states
|
314 |
+
|
315 |
+
if attn.spatial_norm is not None:
|
316 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
317 |
+
|
318 |
+
input_ndim = hidden_states.ndim
|
319 |
+
|
320 |
+
if input_ndim == 4:
|
321 |
+
batch_size, channel, height, width = hidden_states.shape
|
322 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
323 |
+
|
324 |
+
batch_size, sequence_length, _ = (
|
325 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
326 |
+
)
|
327 |
+
inner_dim = hidden_states.shape[-1]
|
328 |
+
|
329 |
+
if attention_mask is not None:
|
330 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
331 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
332 |
+
# (batch, heads, source_length, target_length)
|
333 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
334 |
+
|
335 |
+
if attn.group_norm is not None:
|
336 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
337 |
+
|
338 |
+
query = attn.to_q(hidden_states)
|
339 |
+
|
340 |
+
if encoder_hidden_states is None:
|
341 |
+
encoder_hidden_states = hidden_states
|
342 |
+
elif attn.norm_cross:
|
343 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
344 |
+
|
345 |
+
key = attn.to_k(encoder_hidden_states)
|
346 |
+
value = attn.to_v(encoder_hidden_states)
|
347 |
+
|
348 |
+
head_dim = inner_dim // attn.heads
|
349 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
350 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
351 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
352 |
+
|
353 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
354 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
355 |
+
hidden_states = F.scaled_dot_product_attention(
|
356 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
357 |
+
)
|
358 |
+
|
359 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
360 |
+
hidden_states = hidden_states.to(query.dtype)
|
361 |
+
|
362 |
+
# linear proj
|
363 |
+
hidden_states = attn.to_out[0](hidden_states)
|
364 |
+
# dropout
|
365 |
+
hidden_states = attn.to_out[1](hidden_states)
|
366 |
+
|
367 |
+
if input_ndim == 4:
|
368 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
369 |
+
|
370 |
+
if attn.residual_connection:
|
371 |
+
hidden_states = hidden_states + residual
|
372 |
+
|
373 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
374 |
+
|
375 |
+
return hidden_states
|
376 |
+
|
377 |
+
def __call__(
|
378 |
+
self,
|
379 |
+
attn: Attention,
|
380 |
+
hidden_states,
|
381 |
+
encoder_hidden_states=None,
|
382 |
+
attention_mask=None,
|
383 |
+
temb=None,
|
384 |
+
return_attntion_probs=False,
|
385 |
+
attn_key=None,
|
386 |
+
attn_process_fn=None,
|
387 |
+
return_cond_ca_only=False,
|
388 |
+
return_token_ca_only=None,
|
389 |
+
offload_cross_attn_to_cpu=False,
|
390 |
+
save_attn_to_dict=None,
|
391 |
+
save_keys=None,
|
392 |
+
enable_flash_attn=True,
|
393 |
+
):
|
394 |
+
"""
|
395 |
+
attn_key: current key (a tuple of hierarchy index (up/mid/down, stage id, block id, sub-block id), sub block id should always be 0 in SD UNet)
|
396 |
+
save_attn_to_dict: pass in a dict to save to dict
|
397 |
+
"""
|
398 |
+
cross_attn = encoder_hidden_states is not None
|
399 |
+
|
400 |
+
if (not cross_attn) or (
|
401 |
+
(attn_process_fn is None)
|
402 |
+
and not (save_attn_to_dict is not None and (save_keys is None or (tuple(attn_key) in save_keys)))
|
403 |
+
and not return_attntion_probs):
|
404 |
+
with torch.backends.cuda.sdp_kernel(enable_flash=enable_flash_attn, enable_math=True, enable_mem_efficient=enable_flash_attn):
|
405 |
+
return self.__call_fast__(attn, hidden_states, encoder_hidden_states, attention_mask, temb)
|
406 |
+
|
407 |
+
residual = hidden_states
|
408 |
+
|
409 |
+
if attn.spatial_norm is not None:
|
410 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
411 |
+
|
412 |
+
input_ndim = hidden_states.ndim
|
413 |
+
|
414 |
+
if input_ndim == 4:
|
415 |
+
batch_size, channel, height, width = hidden_states.shape
|
416 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
417 |
+
|
418 |
+
batch_size, sequence_length, _ = (
|
419 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
420 |
+
)
|
421 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
422 |
+
|
423 |
+
if attn.group_norm is not None:
|
424 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
425 |
+
|
426 |
+
query = attn.to_q(hidden_states)
|
427 |
+
|
428 |
+
if encoder_hidden_states is None:
|
429 |
+
encoder_hidden_states = hidden_states
|
430 |
+
elif attn.norm_cross:
|
431 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
432 |
+
|
433 |
+
key = attn.to_k(encoder_hidden_states)
|
434 |
+
value = attn.to_v(encoder_hidden_states)
|
435 |
+
|
436 |
+
query = attn.head_to_batch_dim(query)
|
437 |
+
key = attn.head_to_batch_dim(key)
|
438 |
+
value = attn.head_to_batch_dim(value)
|
439 |
+
|
440 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
441 |
+
# Currently only process cross-attention
|
442 |
+
if attn_process_fn is not None and cross_attn:
|
443 |
+
attention_probs_before_process = attention_probs.clone()
|
444 |
+
attention_probs = attn_process_fn(attention_probs, query, key, value, attn_key=attn_key, cross_attn=cross_attn, batch_size=batch_size, heads=attn.heads)
|
445 |
+
else:
|
446 |
+
attention_probs_before_process = attention_probs
|
447 |
+
hidden_states = torch.bmm(attention_probs, value)
|
448 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
449 |
+
|
450 |
+
# linear proj
|
451 |
+
hidden_states = attn.to_out[0](hidden_states)
|
452 |
+
# dropout
|
453 |
+
hidden_states = attn.to_out[1](hidden_states)
|
454 |
+
|
455 |
+
if input_ndim == 4:
|
456 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
457 |
+
|
458 |
+
if attn.residual_connection:
|
459 |
+
hidden_states = hidden_states + residual
|
460 |
+
|
461 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
462 |
+
|
463 |
+
if return_attntion_probs or save_attn_to_dict is not None:
|
464 |
+
# Recover batch dimension: (batch_size, heads, flattened_2d, text_tokens)
|
465 |
+
attention_probs_unflattened = attention_probs_before_process.unflatten(dim=0, sizes=(batch_size, attn.heads))
|
466 |
+
if return_token_ca_only is not None:
|
467 |
+
# (batch size, n heads, 2d dimension, num text tokens)
|
468 |
+
if isinstance(return_token_ca_only, int):
|
469 |
+
# return_token_ca_only: an integer
|
470 |
+
attention_probs_unflattened = attention_probs_unflattened[:, :, :, return_token_ca_only:return_token_ca_only+1]
|
471 |
+
else:
|
472 |
+
# return_token_ca_only: A 1d index tensor
|
473 |
+
attention_probs_unflattened = attention_probs_unflattened[:, :, :, return_token_ca_only]
|
474 |
+
if return_cond_ca_only:
|
475 |
+
assert batch_size % 2 == 0, f"Samples are not in pairs: {batch_size} samples"
|
476 |
+
attention_probs_unflattened = attention_probs_unflattened[batch_size // 2:]
|
477 |
+
if offload_cross_attn_to_cpu:
|
478 |
+
attention_probs_unflattened = attention_probs_unflattened.cpu()
|
479 |
+
if save_attn_to_dict is not None and (save_keys is None or (tuple(attn_key) in save_keys)):
|
480 |
+
save_attn_to_dict[tuple(attn_key)] = attention_probs_unflattened
|
481 |
+
if return_attntion_probs:
|
482 |
+
return hidden_states, attention_probs_unflattened
|
483 |
+
return hidden_states
|
484 |
+
|
485 |
+
# For typing
|
486 |
+
AttentionProcessor = AttnProcessor
|
487 |
+
|
488 |
+
class SpatialNorm(nn.Module):
|
489 |
+
"""
|
490 |
+
Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
|
491 |
+
"""
|
492 |
+
|
493 |
+
def __init__(
|
494 |
+
self,
|
495 |
+
f_channels,
|
496 |
+
zq_channels,
|
497 |
+
):
|
498 |
+
super().__init__()
|
499 |
+
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
|
500 |
+
self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
|
501 |
+
self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
|
502 |
+
|
503 |
+
def forward(self, f, zq):
|
504 |
+
f_size = f.shape[-2:]
|
505 |
+
zq = F.interpolate(zq, size=f_size, mode="nearest")
|
506 |
+
norm_f = self.norm_layer(f)
|
507 |
+
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
508 |
+
return new_f
|
models/modeling_utils.py
ADDED
@@ -0,0 +1,874 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import inspect
|
18 |
+
import itertools
|
19 |
+
import os
|
20 |
+
from functools import partial
|
21 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
22 |
+
|
23 |
+
import torch
|
24 |
+
from torch import Tensor, device
|
25 |
+
|
26 |
+
from diffusers import __version__
|
27 |
+
from diffusers.utils import (
|
28 |
+
CONFIG_NAME,
|
29 |
+
DIFFUSERS_CACHE,
|
30 |
+
FLAX_WEIGHTS_NAME,
|
31 |
+
HF_HUB_OFFLINE,
|
32 |
+
SAFETENSORS_WEIGHTS_NAME,
|
33 |
+
WEIGHTS_NAME,
|
34 |
+
_add_variant,
|
35 |
+
_get_model_file,
|
36 |
+
deprecate,
|
37 |
+
is_accelerate_available,
|
38 |
+
is_safetensors_available,
|
39 |
+
is_torch_version,
|
40 |
+
logging,
|
41 |
+
)
|
42 |
+
|
43 |
+
|
44 |
+
logger = logging.get_logger(__name__)
|
45 |
+
|
46 |
+
|
47 |
+
if is_torch_version(">=", "1.9.0"):
|
48 |
+
_LOW_CPU_MEM_USAGE_DEFAULT = True
|
49 |
+
else:
|
50 |
+
_LOW_CPU_MEM_USAGE_DEFAULT = False
|
51 |
+
|
52 |
+
|
53 |
+
if is_accelerate_available():
|
54 |
+
import accelerate
|
55 |
+
from accelerate.utils import set_module_tensor_to_device
|
56 |
+
from accelerate.utils.versions import is_torch_version
|
57 |
+
|
58 |
+
if is_safetensors_available():
|
59 |
+
import safetensors
|
60 |
+
|
61 |
+
|
62 |
+
def get_parameter_device(parameter: torch.nn.Module):
|
63 |
+
try:
|
64 |
+
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
|
65 |
+
return next(parameters_and_buffers).device
|
66 |
+
except StopIteration:
|
67 |
+
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
68 |
+
|
69 |
+
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
70 |
+
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
71 |
+
return tuples
|
72 |
+
|
73 |
+
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
74 |
+
first_tuple = next(gen)
|
75 |
+
return first_tuple[1].device
|
76 |
+
|
77 |
+
|
78 |
+
def get_parameter_dtype(parameter: torch.nn.Module):
|
79 |
+
try:
|
80 |
+
params = tuple(parameter.parameters())
|
81 |
+
if len(params) > 0:
|
82 |
+
return params[0].dtype
|
83 |
+
|
84 |
+
buffers = tuple(parameter.buffers())
|
85 |
+
if len(buffers) > 0:
|
86 |
+
return buffers[0].dtype
|
87 |
+
|
88 |
+
except StopIteration:
|
89 |
+
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
90 |
+
|
91 |
+
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
92 |
+
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
93 |
+
return tuples
|
94 |
+
|
95 |
+
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
96 |
+
first_tuple = next(gen)
|
97 |
+
return first_tuple[1].dtype
|
98 |
+
|
99 |
+
|
100 |
+
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
|
101 |
+
"""
|
102 |
+
Reads a checkpoint file, returning properly formatted errors if they arise.
|
103 |
+
"""
|
104 |
+
try:
|
105 |
+
if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
|
106 |
+
return torch.load(checkpoint_file, map_location="cpu")
|
107 |
+
else:
|
108 |
+
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
109 |
+
except Exception as e:
|
110 |
+
try:
|
111 |
+
with open(checkpoint_file) as f:
|
112 |
+
if f.read().startswith("version"):
|
113 |
+
raise OSError(
|
114 |
+
"You seem to have cloned a repository without having git-lfs installed. Please install "
|
115 |
+
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
|
116 |
+
"you cloned."
|
117 |
+
)
|
118 |
+
else:
|
119 |
+
raise ValueError(
|
120 |
+
f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
|
121 |
+
"model. Make sure you have saved the model properly."
|
122 |
+
) from e
|
123 |
+
except (UnicodeDecodeError, ValueError):
|
124 |
+
raise OSError(
|
125 |
+
f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
|
126 |
+
f"at '{checkpoint_file}'. "
|
127 |
+
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
|
128 |
+
)
|
129 |
+
|
130 |
+
|
131 |
+
def _load_state_dict_into_model(model_to_load, state_dict):
|
132 |
+
# Convert old format to new format if needed from a PyTorch state_dict
|
133 |
+
# copy state_dict so _load_from_state_dict can modify it
|
134 |
+
state_dict = state_dict.copy()
|
135 |
+
error_msgs = []
|
136 |
+
|
137 |
+
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
138 |
+
# so we need to apply the function recursively.
|
139 |
+
def load(module: torch.nn.Module, prefix=""):
|
140 |
+
args = (state_dict, prefix, {}, True, [], [], error_msgs)
|
141 |
+
module._load_from_state_dict(*args)
|
142 |
+
|
143 |
+
for name, child in module._modules.items():
|
144 |
+
if child is not None:
|
145 |
+
load(child, prefix + name + ".")
|
146 |
+
|
147 |
+
load(model_to_load)
|
148 |
+
|
149 |
+
return error_msgs
|
150 |
+
|
151 |
+
|
152 |
+
class ModelMixin(torch.nn.Module):
|
153 |
+
r"""
|
154 |
+
Base class for all models.
|
155 |
+
|
156 |
+
[`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
|
157 |
+
and saving models.
|
158 |
+
|
159 |
+
- **config_name** ([`str`]) -- A filename under which the model should be stored when calling
|
160 |
+
[`~models.ModelMixin.save_pretrained`].
|
161 |
+
"""
|
162 |
+
config_name = CONFIG_NAME
|
163 |
+
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
|
164 |
+
_supports_gradient_checkpointing = False
|
165 |
+
|
166 |
+
def __init__(self):
|
167 |
+
super().__init__()
|
168 |
+
|
169 |
+
def __getattr__(self, name: str) -> Any:
|
170 |
+
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
|
171 |
+
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
|
172 |
+
__getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
|
173 |
+
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
174 |
+
"""
|
175 |
+
|
176 |
+
is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
|
177 |
+
is_attribute = name in self.__dict__
|
178 |
+
|
179 |
+
if is_in_config and not is_attribute:
|
180 |
+
deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
|
181 |
+
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
|
182 |
+
return self._internal_dict[name]
|
183 |
+
|
184 |
+
# call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
185 |
+
return super().__getattr__(name)
|
186 |
+
|
187 |
+
@property
|
188 |
+
def is_gradient_checkpointing(self) -> bool:
|
189 |
+
"""
|
190 |
+
Whether gradient checkpointing is activated for this model or not.
|
191 |
+
|
192 |
+
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
|
193 |
+
activations".
|
194 |
+
"""
|
195 |
+
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
|
196 |
+
|
197 |
+
def enable_gradient_checkpointing(self):
|
198 |
+
"""
|
199 |
+
Activates gradient checkpointing for the current model.
|
200 |
+
|
201 |
+
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
|
202 |
+
activations".
|
203 |
+
"""
|
204 |
+
if not self._supports_gradient_checkpointing:
|
205 |
+
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
206 |
+
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
207 |
+
|
208 |
+
def disable_gradient_checkpointing(self):
|
209 |
+
"""
|
210 |
+
Deactivates gradient checkpointing for the current model.
|
211 |
+
|
212 |
+
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
|
213 |
+
activations".
|
214 |
+
"""
|
215 |
+
if self._supports_gradient_checkpointing:
|
216 |
+
self.apply(partial(self._set_gradient_checkpointing, value=False))
|
217 |
+
|
218 |
+
def set_use_memory_efficient_attention_xformers(
|
219 |
+
self, valid: bool, attention_op: Optional[Callable] = None
|
220 |
+
) -> None:
|
221 |
+
# Recursively walk through all the children.
|
222 |
+
# Any children which exposes the set_use_memory_efficient_attention_xformers method
|
223 |
+
# gets the message
|
224 |
+
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
225 |
+
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
226 |
+
module.set_use_memory_efficient_attention_xformers(valid, attention_op)
|
227 |
+
|
228 |
+
for child in module.children():
|
229 |
+
fn_recursive_set_mem_eff(child)
|
230 |
+
|
231 |
+
for module in self.children():
|
232 |
+
if isinstance(module, torch.nn.Module):
|
233 |
+
fn_recursive_set_mem_eff(module)
|
234 |
+
|
235 |
+
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
236 |
+
r"""
|
237 |
+
Enable memory efficient attention as implemented in xformers.
|
238 |
+
|
239 |
+
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
|
240 |
+
time. Speed up at training time is not guaranteed.
|
241 |
+
|
242 |
+
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
243 |
+
is used.
|
244 |
+
|
245 |
+
Parameters:
|
246 |
+
attention_op (`Callable`, *optional*):
|
247 |
+
Override the default `None` operator for use as `op` argument to the
|
248 |
+
[`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
|
249 |
+
function of xFormers.
|
250 |
+
|
251 |
+
Examples:
|
252 |
+
|
253 |
+
```py
|
254 |
+
>>> import torch
|
255 |
+
>>> from diffusers import UNet2DConditionModel
|
256 |
+
>>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
|
257 |
+
|
258 |
+
>>> model = UNet2DConditionModel.from_pretrained(
|
259 |
+
... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
|
260 |
+
... )
|
261 |
+
>>> model = model.to("cuda")
|
262 |
+
>>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
|
263 |
+
```
|
264 |
+
"""
|
265 |
+
self.set_use_memory_efficient_attention_xformers(True, attention_op)
|
266 |
+
|
267 |
+
def disable_xformers_memory_efficient_attention(self):
|
268 |
+
r"""
|
269 |
+
Disable memory efficient attention as implemented in xformers.
|
270 |
+
"""
|
271 |
+
self.set_use_memory_efficient_attention_xformers(False)
|
272 |
+
|
273 |
+
def save_pretrained(
|
274 |
+
self,
|
275 |
+
save_directory: Union[str, os.PathLike],
|
276 |
+
is_main_process: bool = True,
|
277 |
+
save_function: Callable = None,
|
278 |
+
safe_serialization: bool = False,
|
279 |
+
variant: Optional[str] = None,
|
280 |
+
):
|
281 |
+
"""
|
282 |
+
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
283 |
+
`[`~models.ModelMixin.from_pretrained`]` class method.
|
284 |
+
|
285 |
+
Arguments:
|
286 |
+
save_directory (`str` or `os.PathLike`):
|
287 |
+
Directory to which to save. Will be created if it doesn't exist.
|
288 |
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
289 |
+
Whether the process calling this is the main process or not. Useful when in distributed training like
|
290 |
+
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
|
291 |
+
the main process to avoid race conditions.
|
292 |
+
save_function (`Callable`):
|
293 |
+
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
|
294 |
+
need to replace `torch.save` by another method. Can be configured with the environment variable
|
295 |
+
`DIFFUSERS_SAVE_MODE`.
|
296 |
+
safe_serialization (`bool`, *optional*, defaults to `False`):
|
297 |
+
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
298 |
+
variant (`str`, *optional*):
|
299 |
+
If specified, weights are saved in the format pytorch_model.<variant>.bin.
|
300 |
+
"""
|
301 |
+
if safe_serialization and not is_safetensors_available():
|
302 |
+
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
|
303 |
+
|
304 |
+
if os.path.isfile(save_directory):
|
305 |
+
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
306 |
+
return
|
307 |
+
|
308 |
+
os.makedirs(save_directory, exist_ok=True)
|
309 |
+
|
310 |
+
model_to_save = self
|
311 |
+
|
312 |
+
# Attach architecture to the config
|
313 |
+
# Save the config
|
314 |
+
if is_main_process:
|
315 |
+
model_to_save.save_config(save_directory)
|
316 |
+
|
317 |
+
# Save the model
|
318 |
+
state_dict = model_to_save.state_dict()
|
319 |
+
|
320 |
+
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
321 |
+
weights_name = _add_variant(weights_name, variant)
|
322 |
+
|
323 |
+
# Save the model
|
324 |
+
if safe_serialization:
|
325 |
+
safetensors.torch.save_file(
|
326 |
+
state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
|
327 |
+
)
|
328 |
+
else:
|
329 |
+
torch.save(state_dict, os.path.join(save_directory, weights_name))
|
330 |
+
|
331 |
+
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
|
332 |
+
|
333 |
+
@classmethod
|
334 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
335 |
+
r"""
|
336 |
+
Instantiate a pretrained pytorch model from a pre-trained model configuration.
|
337 |
+
|
338 |
+
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
|
339 |
+
the model, you should first set it back in training mode with `model.train()`.
|
340 |
+
|
341 |
+
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
|
342 |
+
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
|
343 |
+
task.
|
344 |
+
|
345 |
+
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
|
346 |
+
weights are discarded.
|
347 |
+
|
348 |
+
Parameters:
|
349 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
350 |
+
Can be either:
|
351 |
+
|
352 |
+
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
353 |
+
Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
|
354 |
+
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
|
355 |
+
`./my_model_directory/`.
|
356 |
+
|
357 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
358 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
359 |
+
standard cache should not be used.
|
360 |
+
torch_dtype (`str` or `torch.dtype`, *optional*):
|
361 |
+
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
362 |
+
will be automatically derived from the model's weights.
|
363 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
364 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
365 |
+
cached versions if they exist.
|
366 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
367 |
+
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
368 |
+
file exists.
|
369 |
+
proxies (`Dict[str, str]`, *optional*):
|
370 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
371 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
372 |
+
output_loading_info(`bool`, *optional*, defaults to `False`):
|
373 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
374 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
375 |
+
Whether or not to only look at local files (i.e., do not try to download the model).
|
376 |
+
use_auth_token (`str` or *bool*, *optional*):
|
377 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
378 |
+
when running `diffusers-cli login` (stored in `~/.huggingface`).
|
379 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
380 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
381 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
382 |
+
identifier allowed by git.
|
383 |
+
from_flax (`bool`, *optional*, defaults to `False`):
|
384 |
+
Load the model weights from a Flax checkpoint save file.
|
385 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
386 |
+
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
387 |
+
huggingface.co or downloaded locally), you can specify the folder name here.
|
388 |
+
|
389 |
+
mirror (`str`, *optional*):
|
390 |
+
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
391 |
+
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
392 |
+
Please refer to the mirror site for more information.
|
393 |
+
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
394 |
+
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
395 |
+
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
396 |
+
same device.
|
397 |
+
|
398 |
+
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
|
399 |
+
more information about each option see [designing a device
|
400 |
+
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
401 |
+
max_memory (`Dict`, *optional*):
|
402 |
+
A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
|
403 |
+
GPU and the available CPU RAM if unset.
|
404 |
+
offload_folder (`str` or `os.PathLike`, *optional*):
|
405 |
+
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
|
406 |
+
offload_state_dict (`bool`, *optional*):
|
407 |
+
If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
|
408 |
+
RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
|
409 |
+
`True` when there is some disk offload.
|
410 |
+
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
411 |
+
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
|
412 |
+
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
413 |
+
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
414 |
+
setting this argument to `True` will raise an error.
|
415 |
+
variant (`str`, *optional*):
|
416 |
+
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
|
417 |
+
ignored when using `from_flax`.
|
418 |
+
use_safetensors (`bool`, *optional*, defaults to `None`):
|
419 |
+
If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
|
420 |
+
`safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
|
421 |
+
`safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
|
422 |
+
|
423 |
+
<Tip>
|
424 |
+
|
425 |
+
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
|
426 |
+
models](https://huggingface.co/docs/hub/models-gated#gated-models).
|
427 |
+
|
428 |
+
</Tip>
|
429 |
+
|
430 |
+
<Tip>
|
431 |
+
|
432 |
+
Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
|
433 |
+
this method in a firewalled environment.
|
434 |
+
|
435 |
+
</Tip>
|
436 |
+
|
437 |
+
"""
|
438 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
439 |
+
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
440 |
+
force_download = kwargs.pop("force_download", False)
|
441 |
+
from_flax = kwargs.pop("from_flax", False)
|
442 |
+
resume_download = kwargs.pop("resume_download", False)
|
443 |
+
proxies = kwargs.pop("proxies", None)
|
444 |
+
output_loading_info = kwargs.pop("output_loading_info", False)
|
445 |
+
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
446 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
447 |
+
revision = kwargs.pop("revision", None)
|
448 |
+
torch_dtype = kwargs.pop("torch_dtype", None)
|
449 |
+
subfolder = kwargs.pop("subfolder", None)
|
450 |
+
device_map = kwargs.pop("device_map", None)
|
451 |
+
max_memory = kwargs.pop("max_memory", None)
|
452 |
+
offload_folder = kwargs.pop("offload_folder", None)
|
453 |
+
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
454 |
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
455 |
+
variant = kwargs.pop("variant", None)
|
456 |
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
457 |
+
|
458 |
+
if use_safetensors and not is_safetensors_available():
|
459 |
+
raise ValueError(
|
460 |
+
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
|
461 |
+
)
|
462 |
+
|
463 |
+
allow_pickle = False
|
464 |
+
if use_safetensors is None:
|
465 |
+
use_safetensors = is_safetensors_available()
|
466 |
+
allow_pickle = True
|
467 |
+
|
468 |
+
if low_cpu_mem_usage and not is_accelerate_available():
|
469 |
+
low_cpu_mem_usage = False
|
470 |
+
logger.warning(
|
471 |
+
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
472 |
+
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
473 |
+
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
474 |
+
" install accelerate\n```\n."
|
475 |
+
)
|
476 |
+
|
477 |
+
if device_map is not None and not is_accelerate_available():
|
478 |
+
raise NotImplementedError(
|
479 |
+
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
|
480 |
+
" `device_map=None`. You can install accelerate with `pip install accelerate`."
|
481 |
+
)
|
482 |
+
|
483 |
+
# Check if we can handle device_map and dispatching the weights
|
484 |
+
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
485 |
+
raise NotImplementedError(
|
486 |
+
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
487 |
+
" `device_map=None`."
|
488 |
+
)
|
489 |
+
|
490 |
+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
491 |
+
raise NotImplementedError(
|
492 |
+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
493 |
+
" `low_cpu_mem_usage=False`."
|
494 |
+
)
|
495 |
+
|
496 |
+
if low_cpu_mem_usage is False and device_map is not None:
|
497 |
+
raise ValueError(
|
498 |
+
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
|
499 |
+
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
500 |
+
)
|
501 |
+
|
502 |
+
# Load config if we don't provide a configuration
|
503 |
+
config_path = pretrained_model_name_or_path
|
504 |
+
|
505 |
+
user_agent = {
|
506 |
+
"diffusers": __version__,
|
507 |
+
"file_type": "model",
|
508 |
+
"framework": "pytorch",
|
509 |
+
}
|
510 |
+
|
511 |
+
# load config
|
512 |
+
config, unused_kwargs, commit_hash = cls.load_config(
|
513 |
+
config_path,
|
514 |
+
cache_dir=cache_dir,
|
515 |
+
return_unused_kwargs=True,
|
516 |
+
return_commit_hash=True,
|
517 |
+
force_download=force_download,
|
518 |
+
resume_download=resume_download,
|
519 |
+
proxies=proxies,
|
520 |
+
local_files_only=local_files_only,
|
521 |
+
use_auth_token=use_auth_token,
|
522 |
+
revision=revision,
|
523 |
+
subfolder=subfolder,
|
524 |
+
device_map=device_map,
|
525 |
+
max_memory=max_memory,
|
526 |
+
offload_folder=offload_folder,
|
527 |
+
offload_state_dict=offload_state_dict,
|
528 |
+
user_agent=user_agent,
|
529 |
+
**kwargs,
|
530 |
+
)
|
531 |
+
|
532 |
+
# load model
|
533 |
+
model_file = None
|
534 |
+
if from_flax:
|
535 |
+
model_file = _get_model_file(
|
536 |
+
pretrained_model_name_or_path,
|
537 |
+
weights_name=FLAX_WEIGHTS_NAME,
|
538 |
+
cache_dir=cache_dir,
|
539 |
+
force_download=force_download,
|
540 |
+
resume_download=resume_download,
|
541 |
+
proxies=proxies,
|
542 |
+
local_files_only=local_files_only,
|
543 |
+
use_auth_token=use_auth_token,
|
544 |
+
revision=revision,
|
545 |
+
subfolder=subfolder,
|
546 |
+
user_agent=user_agent,
|
547 |
+
commit_hash=commit_hash,
|
548 |
+
)
|
549 |
+
model = cls.from_config(config, **unused_kwargs)
|
550 |
+
|
551 |
+
# Convert the weights
|
552 |
+
from diffusers.models.modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
|
553 |
+
|
554 |
+
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
|
555 |
+
else:
|
556 |
+
if use_safetensors:
|
557 |
+
try:
|
558 |
+
model_file = _get_model_file(
|
559 |
+
pretrained_model_name_or_path,
|
560 |
+
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
561 |
+
cache_dir=cache_dir,
|
562 |
+
force_download=force_download,
|
563 |
+
resume_download=resume_download,
|
564 |
+
proxies=proxies,
|
565 |
+
local_files_only=local_files_only,
|
566 |
+
use_auth_token=use_auth_token,
|
567 |
+
revision=revision,
|
568 |
+
subfolder=subfolder,
|
569 |
+
user_agent=user_agent,
|
570 |
+
commit_hash=commit_hash,
|
571 |
+
)
|
572 |
+
except IOError as e:
|
573 |
+
if not allow_pickle:
|
574 |
+
raise e
|
575 |
+
pass
|
576 |
+
if model_file is None:
|
577 |
+
model_file = _get_model_file(
|
578 |
+
pretrained_model_name_or_path,
|
579 |
+
weights_name=_add_variant(WEIGHTS_NAME, variant),
|
580 |
+
cache_dir=cache_dir,
|
581 |
+
force_download=force_download,
|
582 |
+
resume_download=resume_download,
|
583 |
+
proxies=proxies,
|
584 |
+
local_files_only=local_files_only,
|
585 |
+
use_auth_token=use_auth_token,
|
586 |
+
revision=revision,
|
587 |
+
subfolder=subfolder,
|
588 |
+
user_agent=user_agent,
|
589 |
+
commit_hash=commit_hash,
|
590 |
+
)
|
591 |
+
|
592 |
+
if low_cpu_mem_usage:
|
593 |
+
# Instantiate model with empty weights
|
594 |
+
with accelerate.init_empty_weights():
|
595 |
+
model = cls.from_config(config, **unused_kwargs)
|
596 |
+
|
597 |
+
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
598 |
+
if device_map is None:
|
599 |
+
param_device = "cpu"
|
600 |
+
state_dict = load_state_dict(model_file, variant=variant)
|
601 |
+
model._convert_deprecated_attention_blocks(state_dict)
|
602 |
+
# move the params from meta device to cpu
|
603 |
+
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
604 |
+
if len(missing_keys) > 0:
|
605 |
+
raise ValueError(
|
606 |
+
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
|
607 |
+
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
|
608 |
+
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
|
609 |
+
" those weights or else make sure your checkpoint file is correct."
|
610 |
+
)
|
611 |
+
|
612 |
+
empty_state_dict = model.state_dict()
|
613 |
+
for param_name, param in state_dict.items():
|
614 |
+
accepts_dtype = "dtype" in set(
|
615 |
+
inspect.signature(set_module_tensor_to_device).parameters.keys()
|
616 |
+
)
|
617 |
+
|
618 |
+
if empty_state_dict[param_name].shape != param.shape:
|
619 |
+
raise ValueError(
|
620 |
+
f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
621 |
+
)
|
622 |
+
|
623 |
+
if accepts_dtype:
|
624 |
+
set_module_tensor_to_device(
|
625 |
+
model, param_name, param_device, value=param, dtype=torch_dtype
|
626 |
+
)
|
627 |
+
else:
|
628 |
+
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
629 |
+
else: # else let accelerate handle loading and dispatching.
|
630 |
+
# Load weights and dispatch according to the device_map
|
631 |
+
# by default the device_map is None and the weights are loaded on the CPU
|
632 |
+
accelerate.load_checkpoint_and_dispatch(
|
633 |
+
model,
|
634 |
+
model_file,
|
635 |
+
device_map,
|
636 |
+
max_memory=max_memory,
|
637 |
+
offload_folder=offload_folder,
|
638 |
+
offload_state_dict=offload_state_dict,
|
639 |
+
dtype=torch_dtype,
|
640 |
+
)
|
641 |
+
|
642 |
+
loading_info = {
|
643 |
+
"missing_keys": [],
|
644 |
+
"unexpected_keys": [],
|
645 |
+
"mismatched_keys": [],
|
646 |
+
"error_msgs": [],
|
647 |
+
}
|
648 |
+
else:
|
649 |
+
model = cls.from_config(config, **unused_kwargs)
|
650 |
+
|
651 |
+
state_dict = load_state_dict(model_file, variant=variant)
|
652 |
+
model._convert_deprecated_attention_blocks(state_dict)
|
653 |
+
|
654 |
+
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
655 |
+
model,
|
656 |
+
state_dict,
|
657 |
+
model_file,
|
658 |
+
pretrained_model_name_or_path,
|
659 |
+
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
660 |
+
)
|
661 |
+
|
662 |
+
loading_info = {
|
663 |
+
"missing_keys": missing_keys,
|
664 |
+
"unexpected_keys": unexpected_keys,
|
665 |
+
"mismatched_keys": mismatched_keys,
|
666 |
+
"error_msgs": error_msgs,
|
667 |
+
}
|
668 |
+
|
669 |
+
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
670 |
+
raise ValueError(
|
671 |
+
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
|
672 |
+
)
|
673 |
+
elif torch_dtype is not None:
|
674 |
+
model = model.to(torch_dtype)
|
675 |
+
|
676 |
+
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
677 |
+
|
678 |
+
# Set model in evaluation mode to deactivate DropOut modules by default
|
679 |
+
model.eval()
|
680 |
+
if output_loading_info:
|
681 |
+
return model, loading_info
|
682 |
+
|
683 |
+
return model
|
684 |
+
|
685 |
+
@classmethod
|
686 |
+
def _load_pretrained_model(
|
687 |
+
cls,
|
688 |
+
model,
|
689 |
+
state_dict,
|
690 |
+
resolved_archive_file,
|
691 |
+
pretrained_model_name_or_path,
|
692 |
+
ignore_mismatched_sizes=False,
|
693 |
+
):
|
694 |
+
# Retrieve missing & unexpected_keys
|
695 |
+
model_state_dict = model.state_dict()
|
696 |
+
loaded_keys = list(state_dict.keys())
|
697 |
+
|
698 |
+
expected_keys = list(model_state_dict.keys())
|
699 |
+
|
700 |
+
original_loaded_keys = loaded_keys
|
701 |
+
|
702 |
+
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
703 |
+
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
704 |
+
|
705 |
+
# Make sure we are able to load base models as well as derived models (with heads)
|
706 |
+
model_to_load = model
|
707 |
+
|
708 |
+
def _find_mismatched_keys(
|
709 |
+
state_dict,
|
710 |
+
model_state_dict,
|
711 |
+
loaded_keys,
|
712 |
+
ignore_mismatched_sizes,
|
713 |
+
):
|
714 |
+
mismatched_keys = []
|
715 |
+
if ignore_mismatched_sizes:
|
716 |
+
for checkpoint_key in loaded_keys:
|
717 |
+
model_key = checkpoint_key
|
718 |
+
|
719 |
+
if (
|
720 |
+
model_key in model_state_dict
|
721 |
+
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
722 |
+
):
|
723 |
+
mismatched_keys.append(
|
724 |
+
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
725 |
+
)
|
726 |
+
del state_dict[checkpoint_key]
|
727 |
+
return mismatched_keys
|
728 |
+
|
729 |
+
if state_dict is not None:
|
730 |
+
# Whole checkpoint
|
731 |
+
mismatched_keys = _find_mismatched_keys(
|
732 |
+
state_dict,
|
733 |
+
model_state_dict,
|
734 |
+
original_loaded_keys,
|
735 |
+
ignore_mismatched_sizes,
|
736 |
+
)
|
737 |
+
error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
|
738 |
+
|
739 |
+
if len(error_msgs) > 0:
|
740 |
+
error_msg = "\n\t".join(error_msgs)
|
741 |
+
if "size mismatch" in error_msg:
|
742 |
+
error_msg += (
|
743 |
+
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
|
744 |
+
)
|
745 |
+
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
|
746 |
+
|
747 |
+
if len(unexpected_keys) > 0:
|
748 |
+
logger.warning(
|
749 |
+
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
750 |
+
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
751 |
+
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
|
752 |
+
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
|
753 |
+
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
|
754 |
+
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
|
755 |
+
" identical (initializing a BertForSequenceClassification model from a"
|
756 |
+
" BertForSequenceClassification model)."
|
757 |
+
)
|
758 |
+
else:
|
759 |
+
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
760 |
+
if len(missing_keys) > 0:
|
761 |
+
logger.warning(
|
762 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
763 |
+
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
|
764 |
+
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
765 |
+
)
|
766 |
+
elif len(mismatched_keys) == 0:
|
767 |
+
logger.info(
|
768 |
+
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
769 |
+
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
|
770 |
+
f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
|
771 |
+
" without further training."
|
772 |
+
)
|
773 |
+
if len(mismatched_keys) > 0:
|
774 |
+
mismatched_warning = "\n".join(
|
775 |
+
[
|
776 |
+
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
|
777 |
+
for key, shape1, shape2 in mismatched_keys
|
778 |
+
]
|
779 |
+
)
|
780 |
+
logger.warning(
|
781 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
782 |
+
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
|
783 |
+
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
|
784 |
+
" able to use it for predictions and inference."
|
785 |
+
)
|
786 |
+
|
787 |
+
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
|
788 |
+
|
789 |
+
@property
|
790 |
+
def device(self) -> device:
|
791 |
+
"""
|
792 |
+
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
|
793 |
+
device).
|
794 |
+
"""
|
795 |
+
return get_parameter_device(self)
|
796 |
+
|
797 |
+
@property
|
798 |
+
def dtype(self) -> torch.dtype:
|
799 |
+
"""
|
800 |
+
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
801 |
+
"""
|
802 |
+
return get_parameter_dtype(self)
|
803 |
+
|
804 |
+
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
|
805 |
+
"""
|
806 |
+
Get number of (optionally, trainable or non-embeddings) parameters in the module.
|
807 |
+
|
808 |
+
Args:
|
809 |
+
only_trainable (`bool`, *optional*, defaults to `False`):
|
810 |
+
Whether or not to return only the number of trainable parameters
|
811 |
+
|
812 |
+
exclude_embeddings (`bool`, *optional*, defaults to `False`):
|
813 |
+
Whether or not to return only the number of non-embeddings parameters
|
814 |
+
|
815 |
+
Returns:
|
816 |
+
`int`: The number of parameters.
|
817 |
+
"""
|
818 |
+
|
819 |
+
if exclude_embeddings:
|
820 |
+
embedding_param_names = [
|
821 |
+
f"{name}.weight"
|
822 |
+
for name, module_type in self.named_modules()
|
823 |
+
if isinstance(module_type, torch.nn.Embedding)
|
824 |
+
]
|
825 |
+
non_embedding_parameters = [
|
826 |
+
parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
|
827 |
+
]
|
828 |
+
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
|
829 |
+
else:
|
830 |
+
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
|
831 |
+
|
832 |
+
def _convert_deprecated_attention_blocks(self, state_dict):
|
833 |
+
deprecated_attention_block_paths = []
|
834 |
+
|
835 |
+
def recursive_find_attn_block(name, module):
|
836 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
837 |
+
deprecated_attention_block_paths.append(name)
|
838 |
+
|
839 |
+
for sub_name, sub_module in module.named_children():
|
840 |
+
sub_name = sub_name if name == "" else f"{name}.{sub_name}"
|
841 |
+
recursive_find_attn_block(sub_name, sub_module)
|
842 |
+
|
843 |
+
recursive_find_attn_block("", self)
|
844 |
+
|
845 |
+
# NOTE: we have to check if the deprecated parameters are in the state dict
|
846 |
+
# because it is possible we are loading from a state dict that was already
|
847 |
+
# converted
|
848 |
+
|
849 |
+
for path in deprecated_attention_block_paths:
|
850 |
+
# group_norm path stays the same
|
851 |
+
|
852 |
+
# query -> to_q
|
853 |
+
if f"{path}.query.weight" in state_dict:
|
854 |
+
state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
|
855 |
+
if f"{path}.query.bias" in state_dict:
|
856 |
+
state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
|
857 |
+
|
858 |
+
# key -> to_k
|
859 |
+
if f"{path}.key.weight" in state_dict:
|
860 |
+
state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
|
861 |
+
if f"{path}.key.bias" in state_dict:
|
862 |
+
state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
|
863 |
+
|
864 |
+
# value -> to_v
|
865 |
+
if f"{path}.value.weight" in state_dict:
|
866 |
+
state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
|
867 |
+
if f"{path}.value.bias" in state_dict:
|
868 |
+
state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
|
869 |
+
|
870 |
+
# proj_attn -> to_out.0
|
871 |
+
if f"{path}.proj_attn.weight" in state_dict:
|
872 |
+
state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
|
873 |
+
if f"{path}.proj_attn.bias" in state_dict:
|
874 |
+
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
|
models/models.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
3 |
+
from diffusers import AutoencoderKL, DDIMScheduler, DDIMInverseScheduler, DPMSolverMultistepScheduler
|
4 |
+
from .unet_2d_condition import UNet2DConditionModel
|
5 |
+
from easydict import EasyDict
|
6 |
+
import numpy as np
|
7 |
+
# For compatibility
|
8 |
+
from utils.latents import get_unscaled_latents, get_scaled_latents, blend_latents
|
9 |
+
from utils import torch_device
|
10 |
+
|
11 |
+
def load_sd(key="runwayml/stable-diffusion-v1-5", use_fp16=False, load_inverse_scheduler=True, use_dpm_multistep_scheduler=False):
|
12 |
+
"""
|
13 |
+
Keys:
|
14 |
+
key = "CompVis/stable-diffusion-v1-4"
|
15 |
+
key = "runwayml/stable-diffusion-v1-5"
|
16 |
+
key = "stabilityai/stable-diffusion-2-1-base"
|
17 |
+
|
18 |
+
Unpack with:
|
19 |
+
```
|
20 |
+
model_dict = load_sd(key=key, use_fp16=use_fp16)
|
21 |
+
vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
|
22 |
+
```
|
23 |
+
|
24 |
+
use_fp16: fp16 might have degraded performance
|
25 |
+
use_dpm_multistep_scheduler: DPMSolverMultistepScheduler
|
26 |
+
"""
|
27 |
+
|
28 |
+
# run final results in fp32
|
29 |
+
if use_fp16:
|
30 |
+
dtype = torch.float16
|
31 |
+
revision = "fp16"
|
32 |
+
else:
|
33 |
+
dtype = torch.float
|
34 |
+
revision = "main"
|
35 |
+
|
36 |
+
vae = AutoencoderKL.from_pretrained(key, subfolder="vae", revision=revision, torch_dtype=dtype).to(torch_device)
|
37 |
+
tokenizer = CLIPTokenizer.from_pretrained(key, subfolder="tokenizer", revision=revision, torch_dtype=dtype)
|
38 |
+
text_encoder = CLIPTextModel.from_pretrained(key, subfolder="text_encoder", revision=revision, torch_dtype=dtype).to(torch_device)
|
39 |
+
unet = UNet2DConditionModel.from_pretrained(key, subfolder="unet", revision=revision, torch_dtype=dtype).to(torch_device)
|
40 |
+
if use_dpm_multistep_scheduler:
|
41 |
+
scheduler = DPMSolverMultistepScheduler.from_pretrained(key, subfolder="scheduler", revision=revision, torch_dtype=dtype)
|
42 |
+
else:
|
43 |
+
scheduler = DDIMScheduler.from_pretrained(key, subfolder="scheduler", revision=revision, torch_dtype=dtype)
|
44 |
+
|
45 |
+
model_dict = EasyDict(vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, dtype=dtype)
|
46 |
+
|
47 |
+
if load_inverse_scheduler:
|
48 |
+
inverse_scheduler = DDIMInverseScheduler.from_config(scheduler.config)
|
49 |
+
model_dict.inverse_scheduler = inverse_scheduler
|
50 |
+
|
51 |
+
return model_dict
|
52 |
+
|
53 |
+
def encode_prompts(tokenizer, text_encoder, prompts, negative_prompt="", return_full_only=False, one_uncond_input_only=False):
|
54 |
+
if negative_prompt == "":
|
55 |
+
print("Note that negative_prompt is an empty string")
|
56 |
+
|
57 |
+
text_input = tokenizer(
|
58 |
+
prompts, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
|
59 |
+
)
|
60 |
+
|
61 |
+
max_length = text_input.input_ids.shape[-1]
|
62 |
+
if one_uncond_input_only:
|
63 |
+
num_uncond_input = 1
|
64 |
+
else:
|
65 |
+
num_uncond_input = len(prompts)
|
66 |
+
uncond_input = tokenizer([negative_prompt] * num_uncond_input, padding="max_length", max_length=max_length, return_tensors="pt")
|
67 |
+
|
68 |
+
with torch.no_grad():
|
69 |
+
uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
|
70 |
+
cond_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
|
71 |
+
|
72 |
+
if one_uncond_input_only:
|
73 |
+
return uncond_embeddings, cond_embeddings
|
74 |
+
|
75 |
+
text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])
|
76 |
+
|
77 |
+
if return_full_only:
|
78 |
+
return text_embeddings
|
79 |
+
return text_embeddings, uncond_embeddings, cond_embeddings
|
80 |
+
|
81 |
+
def attn_list_to_tensor(cross_attention_probs):
|
82 |
+
# timestep, CrossAttnBlock, Transformer2DModel, 1xBasicTransformerBlock
|
83 |
+
|
84 |
+
num_cross_attn_block = len(cross_attention_probs[0])
|
85 |
+
cross_attention_probs_all = []
|
86 |
+
|
87 |
+
for i in range(num_cross_attn_block):
|
88 |
+
# cross_attention_probs_timestep[i]: Transformer2DModel
|
89 |
+
# 1xBasicTransformerBlock is skipped
|
90 |
+
cross_attention_probs_current = []
|
91 |
+
for cross_attention_probs_timestep in cross_attention_probs:
|
92 |
+
cross_attention_probs_current.append(torch.stack([item for item in cross_attention_probs_timestep[i]], dim=0))
|
93 |
+
|
94 |
+
cross_attention_probs_current = torch.stack(cross_attention_probs_current, dim=0)
|
95 |
+
cross_attention_probs_all.append(cross_attention_probs_current)
|
96 |
+
|
97 |
+
return cross_attention_probs_all
|
models/pipelines.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from tqdm import tqdm
|
3 |
+
import utils
|
4 |
+
from PIL import Image
|
5 |
+
import gc
|
6 |
+
import numpy as np
|
7 |
+
from .attention import GatedSelfAttentionDense
|
8 |
+
from .models import torch_device
|
9 |
+
|
10 |
+
@torch.no_grad()
|
11 |
+
def encode(model_dict, image, generator):
|
12 |
+
"""
|
13 |
+
image should be a PIL object or numpy array with range 0 to 255
|
14 |
+
"""
|
15 |
+
|
16 |
+
vae, dtype = model_dict.vae, model_dict.dtype
|
17 |
+
|
18 |
+
if isinstance(image, Image.Image):
|
19 |
+
w, h = image.size
|
20 |
+
assert w % 8 == 0 and h % 8 == 0, f"h ({h}) and w ({w}) should be a multiple of 8"
|
21 |
+
# w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
|
22 |
+
# image = np.array(image.resize((w, h), resample=Image.Resampling.LANCZOS))[None, :]
|
23 |
+
image = np.array(image)
|
24 |
+
|
25 |
+
if isinstance(image, np.ndarray):
|
26 |
+
assert image.dtype == np.uint8, f"Should have dtype uint8 (dtype: {image.dtype})"
|
27 |
+
image = image.astype(np.float32) / 255.0
|
28 |
+
image = image[None, ...]
|
29 |
+
image = image.transpose(0, 3, 1, 2)
|
30 |
+
image = 2.0 * image - 1.0
|
31 |
+
image = torch.from_numpy(image)
|
32 |
+
|
33 |
+
assert isinstance(image, torch.Tensor), f"type of image: {type(image)}"
|
34 |
+
|
35 |
+
image = image.to(device=torch_device, dtype=dtype)
|
36 |
+
latents = vae.encode(image).latent_dist.sample(generator)
|
37 |
+
|
38 |
+
latents = vae.config.scaling_factor * latents
|
39 |
+
|
40 |
+
return latents
|
41 |
+
|
42 |
+
@torch.no_grad()
|
43 |
+
def decode(vae, latents):
|
44 |
+
# scale and decode the image latents with vae
|
45 |
+
scaled_latents = 1 / 0.18215 * latents
|
46 |
+
with torch.no_grad():
|
47 |
+
image = vae.decode(scaled_latents).sample
|
48 |
+
|
49 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
50 |
+
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
|
51 |
+
images = (image * 255).round().astype("uint8")
|
52 |
+
|
53 |
+
return images
|
54 |
+
|
55 |
+
@torch.no_grad()
|
56 |
+
def generate(model_dict, latents, input_embeddings, num_inference_steps, guidance_scale = 7.5, no_set_timesteps=False):
|
57 |
+
vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
|
58 |
+
text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings
|
59 |
+
|
60 |
+
if not no_set_timesteps:
|
61 |
+
scheduler.set_timesteps(num_inference_steps)
|
62 |
+
|
63 |
+
for t in tqdm(scheduler.timesteps):
|
64 |
+
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
65 |
+
latent_model_input = torch.cat([latents] * 2)
|
66 |
+
|
67 |
+
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
|
68 |
+
|
69 |
+
# predict the noise residual
|
70 |
+
with torch.no_grad():
|
71 |
+
noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
72 |
+
|
73 |
+
# perform guidance
|
74 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
75 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
76 |
+
|
77 |
+
# compute the previous noisy sample x_t -> x_t-1
|
78 |
+
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
79 |
+
|
80 |
+
images = decode(vae, latents)
|
81 |
+
|
82 |
+
ret = [latents, images]
|
83 |
+
|
84 |
+
return tuple(ret)
|
85 |
+
|
86 |
+
def gligen_enable_fuser(unet, enabled=True):
|
87 |
+
for module in unet.modules():
|
88 |
+
if isinstance(module, GatedSelfAttentionDense):
|
89 |
+
module.enabled = enabled
|
90 |
+
|
91 |
+
@torch.no_grad()
|
92 |
+
def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps, bboxes, phrases, num_images_per_prompt=1, gligen_scheduled_sampling_beta: float = 0.3, guidance_scale=7.5,
|
93 |
+
frozen_steps=20, frozen_mask=None,
|
94 |
+
return_saved_cross_attn=False, saved_cross_attn_keys=None, return_cond_ca_only=False, return_token_ca_only=None,
|
95 |
+
offload_cross_attn_to_cpu=False, offload_latents_to_cpu=True,
|
96 |
+
semantic_guidance=False, semantic_guidance_bboxes=None, semantic_guidance_object_positions=None, semantic_guidance_kwargs=None,
|
97 |
+
return_box_vis=False, show_progress=True, save_all_latents=False):
|
98 |
+
"""
|
99 |
+
The `bboxes` should be a list, rather than a list of lists (one box per phrase, we can have multiple duplicated phrases).
|
100 |
+
"""
|
101 |
+
vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
|
102 |
+
text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings
|
103 |
+
|
104 |
+
if latents.dim() == 5:
|
105 |
+
# latents_all from the input side, different from the latents_all to be saved
|
106 |
+
latents_all_input = latents
|
107 |
+
latents = latents[0]
|
108 |
+
else:
|
109 |
+
latents_all_input = None
|
110 |
+
|
111 |
+
# Just in case that we have in-place ops
|
112 |
+
latents = latents.clone()
|
113 |
+
|
114 |
+
if save_all_latents:
|
115 |
+
# offload to cpu to save space
|
116 |
+
if offload_latents_to_cpu:
|
117 |
+
latents_all = [latents.cpu()]
|
118 |
+
else:
|
119 |
+
latents_all = [latents]
|
120 |
+
|
121 |
+
scheduler.set_timesteps(num_inference_steps)
|
122 |
+
|
123 |
+
if frozen_mask is not None:
|
124 |
+
frozen_mask = frozen_mask.to(dtype=dtype).clamp(0., 1.)
|
125 |
+
|
126 |
+
batch_size = 1
|
127 |
+
|
128 |
+
# 5.1 Prepare GLIGEN variables
|
129 |
+
assert len(phrases) == len(bboxes)
|
130 |
+
# assert batch_size == 1
|
131 |
+
max_objs = 30
|
132 |
+
_boxes = bboxes
|
133 |
+
|
134 |
+
n_objs = min(len(_boxes), max_objs)
|
135 |
+
boxes = torch.zeros(max_objs, 4, device=torch_device, dtype=dtype)
|
136 |
+
phrase_embeddings = torch.zeros(max_objs, 768, device=torch_device, dtype=dtype)
|
137 |
+
masks = torch.zeros(max_objs, device=torch_device, dtype=dtype)
|
138 |
+
|
139 |
+
if n_objs > 0:
|
140 |
+
boxes[:n_objs] = torch.tensor(_boxes[:n_objs])
|
141 |
+
tokenizer_inputs = tokenizer(phrases, padding=True, return_tensors="pt").to(torch_device)
|
142 |
+
_phrase_embeddings = text_encoder(**tokenizer_inputs).pooler_output
|
143 |
+
phrase_embeddings[:n_objs] = _phrase_embeddings[:n_objs]
|
144 |
+
masks[:n_objs] = 1
|
145 |
+
|
146 |
+
# Classifier-free guidance
|
147 |
+
repeat_batch = batch_size * num_images_per_prompt * 2
|
148 |
+
|
149 |
+
boxes = boxes.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
|
150 |
+
phrase_embeddings = phrase_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
|
151 |
+
masks = masks.unsqueeze(0).expand(repeat_batch, -1).clone()
|
152 |
+
masks[:repeat_batch // 2] = 0
|
153 |
+
|
154 |
+
if semantic_guidance_bboxes and semantic_guidance:
|
155 |
+
loss = torch.tensor(10000.)
|
156 |
+
# TODO: we can also save necessary tokens only to save memory.
|
157 |
+
# offload_guidance_cross_attn_to_cpu does not save too much since we only store attention map for each timestep.
|
158 |
+
guidance_cross_attention_kwargs = {
|
159 |
+
'offload_cross_attn_to_cpu': False,
|
160 |
+
'enable_flash_attn': False,
|
161 |
+
'gligen': {
|
162 |
+
'boxes': boxes[:repeat_batch // 2],
|
163 |
+
'positive_embeddings': phrase_embeddings[:repeat_batch // 2],
|
164 |
+
'masks': masks[:repeat_batch // 2],
|
165 |
+
'fuser_attn_kwargs': {
|
166 |
+
'enable_flash_attn': False,
|
167 |
+
}
|
168 |
+
}
|
169 |
+
}
|
170 |
+
|
171 |
+
if return_saved_cross_attn:
|
172 |
+
saved_attns = []
|
173 |
+
|
174 |
+
main_cross_attention_kwargs = {
|
175 |
+
'offload_cross_attn_to_cpu': offload_cross_attn_to_cpu,
|
176 |
+
'return_cond_ca_only': return_cond_ca_only,
|
177 |
+
'return_token_ca_only': return_token_ca_only,
|
178 |
+
'save_keys': saved_cross_attn_keys,
|
179 |
+
'gligen': {
|
180 |
+
'boxes': boxes,
|
181 |
+
'positive_embeddings': phrase_embeddings,
|
182 |
+
'masks': masks
|
183 |
+
}
|
184 |
+
}
|
185 |
+
|
186 |
+
timesteps = scheduler.timesteps
|
187 |
+
|
188 |
+
num_grounding_steps = int(gligen_scheduled_sampling_beta * len(timesteps))
|
189 |
+
gligen_enable_fuser(unet, True)
|
190 |
+
|
191 |
+
for index, t in enumerate(tqdm(timesteps, disable=not show_progress)):
|
192 |
+
# Scheduled sampling
|
193 |
+
if index == num_grounding_steps:
|
194 |
+
gligen_enable_fuser(unet, False)
|
195 |
+
|
196 |
+
if semantic_guidance_bboxes and semantic_guidance:
|
197 |
+
with torch.enable_grad():
|
198 |
+
latents, loss = latent_backward_guidance(scheduler, unet, cond_embeddings, index, semantic_guidance_bboxes, semantic_guidance_object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)
|
199 |
+
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
200 |
+
latent_model_input = torch.cat([latents] * 2)
|
201 |
+
|
202 |
+
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
|
203 |
+
|
204 |
+
main_cross_attention_kwargs['save_attn_to_dict'] = {}
|
205 |
+
|
206 |
+
# predict the noise residual
|
207 |
+
noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings,
|
208 |
+
cross_attention_kwargs=main_cross_attention_kwargs).sample
|
209 |
+
|
210 |
+
if return_saved_cross_attn:
|
211 |
+
saved_attns.append(main_cross_attention_kwargs['save_attn_to_dict'])
|
212 |
+
|
213 |
+
del main_cross_attention_kwargs['save_attn_to_dict']
|
214 |
+
|
215 |
+
# perform guidance
|
216 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
217 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
218 |
+
|
219 |
+
# compute the previous noisy sample x_t -> x_t-1
|
220 |
+
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
221 |
+
|
222 |
+
if frozen_mask is not None and index < frozen_steps:
|
223 |
+
latents = latents_all_input[index+1] * frozen_mask + latents * (1. - frozen_mask)
|
224 |
+
|
225 |
+
if save_all_latents:
|
226 |
+
if offload_latents_to_cpu:
|
227 |
+
latents_all.append(latents.cpu())
|
228 |
+
else:
|
229 |
+
latents_all.append(latents)
|
230 |
+
|
231 |
+
# Turn off fuser for typical SD
|
232 |
+
gligen_enable_fuser(unet, False)
|
233 |
+
images = decode(vae, latents)
|
234 |
+
|
235 |
+
ret = [latents, images]
|
236 |
+
if return_saved_cross_attn:
|
237 |
+
ret.append(saved_attns)
|
238 |
+
if return_box_vis:
|
239 |
+
pil_images = [utils.draw_box(Image.fromarray(image), bboxes, phrases) for image in images]
|
240 |
+
ret.append(pil_images)
|
241 |
+
if save_all_latents:
|
242 |
+
latents_all = torch.stack(latents_all, dim=0)
|
243 |
+
ret.append(latents_all)
|
244 |
+
|
245 |
+
return tuple(ret)
|
246 |
+
|
models/sam.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from models import torch_device
|
6 |
+
from transformers import SamModel, SamProcessor
|
7 |
+
import utils
|
8 |
+
import cv2
|
9 |
+
from scipy import ndimage
|
10 |
+
|
11 |
+
def load_sam():
|
12 |
+
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(torch_device)
|
13 |
+
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
|
14 |
+
|
15 |
+
sam_model_dict = dict(
|
16 |
+
sam_model = sam_model, sam_processor = sam_processor
|
17 |
+
)
|
18 |
+
|
19 |
+
return sam_model_dict
|
20 |
+
|
21 |
+
# Not fully backward compatible with the previous implementation
|
22 |
+
# Reference: lmdv2/notebooks/gen_masked_latents_multi_object_ref_ca_loss_modular.ipynb
|
23 |
+
def sam(sam_model_dict, image, input_points=None, input_boxes=None, target_mask_shape=None):
|
24 |
+
"""target_mask_shape: (h, w)"""
|
25 |
+
sam_model, sam_processor = sam_model_dict['sam_model'], sam_model_dict['sam_processor']
|
26 |
+
|
27 |
+
with torch.no_grad():
|
28 |
+
with torch.autocast(torch_device):
|
29 |
+
inputs = sam_processor(image, input_points=input_points, input_boxes=input_boxes, return_tensors="pt").to(torch_device)
|
30 |
+
outputs = sam_model(**inputs)
|
31 |
+
masks = sam_processor.image_processor.post_process_masks(
|
32 |
+
outputs.pred_masks.cpu().float(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
|
33 |
+
)
|
34 |
+
conf_scores = outputs.iou_scores.to(device="cpu", dtype=torch.float32).numpy()[0,0]
|
35 |
+
del inputs, outputs
|
36 |
+
|
37 |
+
gc.collect()
|
38 |
+
if torch_device == "cuda":
|
39 |
+
torch.cuda.empty_cache()
|
40 |
+
|
41 |
+
masks = masks[0][0].numpy()
|
42 |
+
|
43 |
+
if target_mask_shape is not None:
|
44 |
+
masks = np.array([cv2.resize(mask.astype(np.uint8) * 255, target_mask_shape[::-1], cv2.INTER_LINEAR).astype(bool) for mask in masks])
|
45 |
+
|
46 |
+
return masks, conf_scores
|
47 |
+
|
48 |
+
def sam_point_input(sam_model_dict, image, input_points, **kwargs):
|
49 |
+
return sam(sam_model_dict, image, input_points=input_points, **kwargs)
|
50 |
+
|
51 |
+
def sam_box_input(sam_model_dict, image, input_boxes, **kwargs):
|
52 |
+
return sam(sam_model_dict, image, input_boxes=input_boxes, **kwargs)
|
53 |
+
|
54 |
+
def get_iou_with_resize(mask, masks, masks_shape):
|
55 |
+
masks = np.array([cv2.resize(mask.astype(np.uint8) * 255, masks_shape[::-1], cv2.INTER_LINEAR).astype(bool) for mask in masks])
|
56 |
+
return utils.iou(mask, masks)
|
57 |
+
|
58 |
+
def select_mask(masks, conf_scores, coarse_ious=None, rule="largest_over_conf", discourage_mask_below_confidence=0.85, discourage_mask_below_coarse_iou=0.2, verbose=False):
|
59 |
+
"""masks: numpy bool array"""
|
60 |
+
mask_sizes = masks.sum(axis=(1, 2))
|
61 |
+
|
62 |
+
# Another possible rule: iou with the attention mask
|
63 |
+
if rule == "largest_over_conf":
|
64 |
+
# Use the largest segmentation
|
65 |
+
# Discourage selecting masks with conf too low or coarse iou is too low
|
66 |
+
max_mask_size = np.max(mask_sizes)
|
67 |
+
if coarse_ious is not None:
|
68 |
+
scores = mask_sizes - (conf_scores < discourage_mask_below_confidence) * max_mask_size - (coarse_ious < discourage_mask_below_coarse_iou) * max_mask_size
|
69 |
+
else:
|
70 |
+
scores = mask_sizes - (conf_scores < discourage_mask_below_confidence) * max_mask_size
|
71 |
+
if verbose:
|
72 |
+
print(f"mask_sizes: {mask_sizes}, scores: {scores}")
|
73 |
+
else:
|
74 |
+
raise ValueError(f"Unknown rule: {rule}")
|
75 |
+
|
76 |
+
mask_id = np.argmax(scores)
|
77 |
+
mask = masks[mask_id]
|
78 |
+
|
79 |
+
selection_conf = conf_scores[mask_id]
|
80 |
+
|
81 |
+
if coarse_ious is not None:
|
82 |
+
selection_coarse_iou = coarse_ious[mask_id]
|
83 |
+
else:
|
84 |
+
selection_coarse_iou = None
|
85 |
+
|
86 |
+
if verbose:
|
87 |
+
# print(f"Confidences: {conf_scores}")
|
88 |
+
print(f"Selected a mask with confidence: {selection_conf}, coarse_iou: {selection_coarse_iou}")
|
89 |
+
|
90 |
+
if verbose:
|
91 |
+
plt.figure(figsize=(10, 8))
|
92 |
+
# plt.suptitle("After SAM")
|
93 |
+
for ind in range(3):
|
94 |
+
plt.subplot(1, 3, ind+1)
|
95 |
+
# This is obtained before resize.
|
96 |
+
plt.title(f"Mask {ind}, score {scores[ind]}, conf {conf_scores[ind]:.2f}, iou {coarse_ious[ind] if coarse_ious is not None else None:.2f}")
|
97 |
+
plt.imshow(masks[ind])
|
98 |
+
plt.tight_layout()
|
99 |
+
plt.show()
|
100 |
+
|
101 |
+
return mask, selection_conf
|
102 |
+
|
103 |
+
def preprocess_mask(token_attn_np_smooth, mask_th, n_erode_dilate_mask=0):
|
104 |
+
token_attn_np_smooth_normalized = token_attn_np_smooth - token_attn_np_smooth.min()
|
105 |
+
token_attn_np_smooth_normalized /= token_attn_np_smooth_normalized.max()
|
106 |
+
mask_thresholded = token_attn_np_smooth_normalized > mask_th
|
107 |
+
|
108 |
+
if n_erode_dilate_mask:
|
109 |
+
mask_thresholded = ndimage.binary_erosion(mask_thresholded, iterations=n_erode_dilate_mask)
|
110 |
+
mask_thresholded = ndimage.binary_dilation(mask_thresholded, iterations=n_erode_dilate_mask)
|
111 |
+
|
112 |
+
return mask_thresholded
|
113 |
+
|
114 |
+
# The overall pipeline to refine the attention mask
|
115 |
+
def sam_refine_attn(sam_input_image, token_attn_np, model_dict, height, width, H, W, use_box_input, gaussian_sigma, mask_th_for_box, n_erode_dilate_mask_for_box, mask_th_for_point, discourage_mask_below_confidence, discourage_mask_below_coarse_iou, verbose):
|
116 |
+
|
117 |
+
# token_attn_np is for visualizations
|
118 |
+
token_attn_np_smooth = ndimage.gaussian_filter(token_attn_np, sigma=gaussian_sigma)
|
119 |
+
|
120 |
+
# (w, h)
|
121 |
+
mask_size_scale = height // token_attn_np_smooth.shape[1], width // token_attn_np_smooth.shape[0]
|
122 |
+
|
123 |
+
if use_box_input:
|
124 |
+
# box input
|
125 |
+
mask_binary = preprocess_mask(token_attn_np_smooth, mask_th_for_box, n_erode_dilate_mask=n_erode_dilate_mask_for_box)
|
126 |
+
|
127 |
+
input_boxes = utils.binary_mask_to_box(mask_binary, w_scale=mask_size_scale[0], h_scale=mask_size_scale[1])
|
128 |
+
input_boxes = [input_boxes]
|
129 |
+
|
130 |
+
masks, conf_scores = sam_box_input(model_dict, image=sam_input_image, input_boxes=input_boxes, target_mask_shape=(H, W))
|
131 |
+
else:
|
132 |
+
# point input
|
133 |
+
mask_binary = preprocess_mask(token_attn_np_smooth, mask_th_for_point, n_erode_dilate_mask=0)
|
134 |
+
|
135 |
+
# Uses the max coordinate only
|
136 |
+
max_coord = np.unravel_index(token_attn_np_smooth.argmax(), token_attn_np_smooth.shape)
|
137 |
+
# print("max_coord:", max_coord)
|
138 |
+
input_points = [[[max_coord[1] * mask_size_scale[1], max_coord[0] * mask_size_scale[0]]]]
|
139 |
+
|
140 |
+
masks, conf_scores = sam_point_input(model_dict, image=sam_input_image, input_points=input_points, target_mask_shape=(H, W))
|
141 |
+
|
142 |
+
if verbose:
|
143 |
+
plt.title("Coarse binary mask (for box for box input and for iou)")
|
144 |
+
plt.imshow(mask_binary)
|
145 |
+
plt.show()
|
146 |
+
|
147 |
+
coarse_ious = get_iou_with_resize(mask_binary, masks, masks_shape=mask_binary.shape)
|
148 |
+
|
149 |
+
mask_selected, conf_score_selected = select_mask(masks, conf_scores, coarse_ious=coarse_ious,
|
150 |
+
rule="largest_over_conf",
|
151 |
+
discourage_mask_below_confidence=discourage_mask_below_confidence,
|
152 |
+
discourage_mask_below_coarse_iou=discourage_mask_below_coarse_iou,
|
153 |
+
verbose=True)
|
154 |
+
|
155 |
+
return mask_selected, conf_score_selected
|
156 |
+
|
157 |
+
def sam_refine_box(sam_input_image, box, model_dict, height, width, H, W, discourage_mask_below_confidence, discourage_mask_below_coarse_iou, verbose):
|
158 |
+
# (w, h)
|
159 |
+
input_boxes = utils.scale_proportion(box, H=height, W=width)
|
160 |
+
input_boxes = [input_boxes]
|
161 |
+
|
162 |
+
masks, conf_scores = sam_box_input(model_dict, image=sam_input_image, input_boxes=input_boxes, target_mask_shape=(H, W))
|
163 |
+
|
164 |
+
mask_binary = utils.proportion_to_mask(box, H, W, return_np=True)
|
165 |
+
if verbose:
|
166 |
+
# Also the box is the input for SAM
|
167 |
+
plt.title("Binary mask from input box (for iou)")
|
168 |
+
plt.imshow(mask_binary)
|
169 |
+
plt.show()
|
170 |
+
|
171 |
+
coarse_ious = get_iou_with_resize(mask_binary, masks, masks_shape=mask_binary.shape)
|
172 |
+
|
173 |
+
mask_selected, conf_score_selected = select_mask(masks, conf_scores, coarse_ious=coarse_ious,
|
174 |
+
rule="largest_over_conf",
|
175 |
+
discourage_mask_below_confidence=discourage_mask_below_confidence,
|
176 |
+
discourage_mask_below_coarse_iou=discourage_mask_below_coarse_iou,
|
177 |
+
verbose=True)
|
178 |
+
|
179 |
+
return mask_selected, conf_score_selected
|
models/transformer_2d.py
ADDED
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Any, Dict, Optional
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from diffusers.models.embeddings import ImagePositionalEmbeddings
|
23 |
+
from diffusers.utils import BaseOutput, deprecate
|
24 |
+
from .attention import BasicTransformerBlock
|
25 |
+
from diffusers.models.embeddings import PatchEmbed
|
26 |
+
from diffusers.models.modeling_utils import ModelMixin
|
27 |
+
|
28 |
+
|
29 |
+
@dataclass
|
30 |
+
class Transformer2DModelOutput(BaseOutput):
|
31 |
+
"""
|
32 |
+
Args:
|
33 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
34 |
+
Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
|
35 |
+
for the unnoised latent pixels.
|
36 |
+
"""
|
37 |
+
|
38 |
+
sample: torch.FloatTensor
|
39 |
+
|
40 |
+
|
41 |
+
class Transformer2DModel(ModelMixin, ConfigMixin):
|
42 |
+
"""
|
43 |
+
Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
|
44 |
+
embeddings) inputs.
|
45 |
+
|
46 |
+
When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
|
47 |
+
transformer action. Finally, reshape to image.
|
48 |
+
|
49 |
+
When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
|
50 |
+
embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
|
51 |
+
classes of unnoised image.
|
52 |
+
|
53 |
+
Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
|
54 |
+
image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
|
55 |
+
|
56 |
+
Parameters:
|
57 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
58 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
59 |
+
in_channels (`int`, *optional*):
|
60 |
+
Pass if the input is continuous. The number of channels in the input and output.
|
61 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
62 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
63 |
+
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
64 |
+
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
65 |
+
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
66 |
+
`ImagePositionalEmbeddings`.
|
67 |
+
num_vector_embeds (`int`, *optional*):
|
68 |
+
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
|
69 |
+
Includes the class for the masked latent pixel.
|
70 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
71 |
+
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
|
72 |
+
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
|
73 |
+
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
|
74 |
+
up to but not more than steps than `num_embeds_ada_norm`.
|
75 |
+
attention_bias (`bool`, *optional*):
|
76 |
+
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
77 |
+
"""
|
78 |
+
|
79 |
+
@register_to_config
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
num_attention_heads: int = 16,
|
83 |
+
attention_head_dim: int = 88,
|
84 |
+
in_channels: Optional[int] = None,
|
85 |
+
out_channels: Optional[int] = None,
|
86 |
+
num_layers: int = 1,
|
87 |
+
dropout: float = 0.0,
|
88 |
+
norm_num_groups: int = 32,
|
89 |
+
cross_attention_dim: Optional[int] = None,
|
90 |
+
attention_bias: bool = False,
|
91 |
+
sample_size: Optional[int] = None,
|
92 |
+
num_vector_embeds: Optional[int] = None,
|
93 |
+
patch_size: Optional[int] = None,
|
94 |
+
activation_fn: str = "geglu",
|
95 |
+
num_embeds_ada_norm: Optional[int] = None,
|
96 |
+
use_linear_projection: bool = False,
|
97 |
+
only_cross_attention: bool = False,
|
98 |
+
upcast_attention: bool = False,
|
99 |
+
norm_type: str = "layer_norm",
|
100 |
+
norm_elementwise_affine: bool = True,
|
101 |
+
use_gated_attention: bool = False,
|
102 |
+
):
|
103 |
+
super().__init__()
|
104 |
+
self.use_linear_projection = use_linear_projection
|
105 |
+
self.num_attention_heads = num_attention_heads
|
106 |
+
self.attention_head_dim = attention_head_dim
|
107 |
+
inner_dim = num_attention_heads * attention_head_dim
|
108 |
+
|
109 |
+
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
110 |
+
# Define whether input is continuous or discrete depending on configuration
|
111 |
+
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
|
112 |
+
self.is_input_vectorized = num_vector_embeds is not None
|
113 |
+
self.is_input_patches = in_channels is not None and patch_size is not None
|
114 |
+
|
115 |
+
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
|
116 |
+
deprecation_message = (
|
117 |
+
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
|
118 |
+
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
|
119 |
+
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
|
120 |
+
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
|
121 |
+
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
|
122 |
+
)
|
123 |
+
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
|
124 |
+
norm_type = "ada_norm"
|
125 |
+
|
126 |
+
if self.is_input_continuous and self.is_input_vectorized:
|
127 |
+
raise ValueError(
|
128 |
+
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
129 |
+
" sure that either `in_channels` or `num_vector_embeds` is None."
|
130 |
+
)
|
131 |
+
elif self.is_input_vectorized and self.is_input_patches:
|
132 |
+
raise ValueError(
|
133 |
+
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
|
134 |
+
" sure that either `num_vector_embeds` or `num_patches` is None."
|
135 |
+
)
|
136 |
+
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
|
137 |
+
raise ValueError(
|
138 |
+
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
|
139 |
+
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
|
140 |
+
)
|
141 |
+
|
142 |
+
# 2. Define input layers
|
143 |
+
if self.is_input_continuous:
|
144 |
+
self.in_channels = in_channels
|
145 |
+
|
146 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
147 |
+
if use_linear_projection:
|
148 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
149 |
+
else:
|
150 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
151 |
+
elif self.is_input_vectorized:
|
152 |
+
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
153 |
+
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
|
154 |
+
|
155 |
+
self.height = sample_size
|
156 |
+
self.width = sample_size
|
157 |
+
self.num_vector_embeds = num_vector_embeds
|
158 |
+
self.num_latent_pixels = self.height * self.width
|
159 |
+
|
160 |
+
self.latent_image_embedding = ImagePositionalEmbeddings(
|
161 |
+
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
|
162 |
+
)
|
163 |
+
elif self.is_input_patches:
|
164 |
+
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
|
165 |
+
|
166 |
+
self.height = sample_size
|
167 |
+
self.width = sample_size
|
168 |
+
|
169 |
+
self.patch_size = patch_size
|
170 |
+
self.pos_embed = PatchEmbed(
|
171 |
+
height=sample_size,
|
172 |
+
width=sample_size,
|
173 |
+
patch_size=patch_size,
|
174 |
+
in_channels=in_channels,
|
175 |
+
embed_dim=inner_dim,
|
176 |
+
)
|
177 |
+
|
178 |
+
# 3. Define transformers blocks
|
179 |
+
self.transformer_blocks = nn.ModuleList(
|
180 |
+
[
|
181 |
+
BasicTransformerBlock(
|
182 |
+
inner_dim,
|
183 |
+
num_attention_heads,
|
184 |
+
attention_head_dim,
|
185 |
+
dropout=dropout,
|
186 |
+
cross_attention_dim=cross_attention_dim,
|
187 |
+
activation_fn=activation_fn,
|
188 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
189 |
+
attention_bias=attention_bias,
|
190 |
+
only_cross_attention=only_cross_attention,
|
191 |
+
upcast_attention=upcast_attention,
|
192 |
+
norm_type=norm_type,
|
193 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
194 |
+
use_gated_attention=use_gated_attention,
|
195 |
+
)
|
196 |
+
for d in range(num_layers)
|
197 |
+
]
|
198 |
+
)
|
199 |
+
|
200 |
+
# 4. Define output layers
|
201 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
202 |
+
if self.is_input_continuous:
|
203 |
+
# TODO: should use out_channels for continuous projections
|
204 |
+
if use_linear_projection:
|
205 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
206 |
+
else:
|
207 |
+
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
208 |
+
elif self.is_input_vectorized:
|
209 |
+
self.norm_out = nn.LayerNorm(inner_dim)
|
210 |
+
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
211 |
+
elif self.is_input_patches:
|
212 |
+
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
213 |
+
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
|
214 |
+
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
215 |
+
|
216 |
+
def forward(
|
217 |
+
self,
|
218 |
+
hidden_states: torch.Tensor,
|
219 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
220 |
+
timestep: Optional[torch.LongTensor] = None,
|
221 |
+
class_labels: Optional[torch.LongTensor] = None,
|
222 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
223 |
+
attention_mask: Optional[torch.Tensor] = None,
|
224 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
225 |
+
return_dict: bool = True,
|
226 |
+
return_cross_attention_probs: bool = False,
|
227 |
+
):
|
228 |
+
"""
|
229 |
+
Args:
|
230 |
+
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
231 |
+
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
232 |
+
hidden_states
|
233 |
+
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
234 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
235 |
+
self-attention.
|
236 |
+
timestep ( `torch.LongTensor`, *optional*):
|
237 |
+
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
238 |
+
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
239 |
+
Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
|
240 |
+
conditioning.
|
241 |
+
encoder_attention_mask ( `torch.Tensor`, *optional* ).
|
242 |
+
Cross-attention mask, applied to encoder_hidden_states. Two formats supported:
|
243 |
+
Mask `(batch, sequence_length)` True = keep, False = discard. Bias `(batch, 1, sequence_length)` 0
|
244 |
+
= keep, -10000 = discard.
|
245 |
+
If ndim == 2: will be interpreted as a mask, then converted into a bias consistent with the format
|
246 |
+
above. This bias will be added to the cross-attention scores.
|
247 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
248 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
249 |
+
|
250 |
+
Returns:
|
251 |
+
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
|
252 |
+
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
253 |
+
returning a tuple, the first element is the sample tensor.
|
254 |
+
"""
|
255 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
256 |
+
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
257 |
+
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
258 |
+
# expects mask of shape:
|
259 |
+
# [batch, key_tokens]
|
260 |
+
# adds singleton query_tokens dimension:
|
261 |
+
# [batch, 1, key_tokens]
|
262 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
263 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
264 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
265 |
+
if attention_mask is not None and attention_mask.ndim == 2:
|
266 |
+
# assume that mask is expressed as:
|
267 |
+
# (1 = keep, 0 = discard)
|
268 |
+
# convert mask into a bias that can be added to attention scores:
|
269 |
+
# (keep = +0, discard = -10000.0)
|
270 |
+
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
271 |
+
attention_mask = attention_mask.unsqueeze(1)
|
272 |
+
|
273 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
274 |
+
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
275 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
276 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
277 |
+
|
278 |
+
# 1. Input
|
279 |
+
if self.is_input_continuous:
|
280 |
+
batch, _, height, width = hidden_states.shape
|
281 |
+
residual = hidden_states
|
282 |
+
|
283 |
+
hidden_states = self.norm(hidden_states)
|
284 |
+
if not self.use_linear_projection:
|
285 |
+
hidden_states = self.proj_in(hidden_states)
|
286 |
+
inner_dim = hidden_states.shape[1]
|
287 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
288 |
+
else:
|
289 |
+
inner_dim = hidden_states.shape[1]
|
290 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
291 |
+
hidden_states = self.proj_in(hidden_states)
|
292 |
+
elif self.is_input_vectorized:
|
293 |
+
hidden_states = self.latent_image_embedding(hidden_states)
|
294 |
+
elif self.is_input_patches:
|
295 |
+
hidden_states = self.pos_embed(hidden_states)
|
296 |
+
|
297 |
+
base_attn_key = cross_attention_kwargs["attn_key"]
|
298 |
+
|
299 |
+
# 2. Blocks
|
300 |
+
cross_attention_probs_all = []
|
301 |
+
for block_ind, block in enumerate(self.transformer_blocks):
|
302 |
+
cross_attention_kwargs["attn_key"] = base_attn_key + [block_ind]
|
303 |
+
|
304 |
+
hidden_states = block(
|
305 |
+
hidden_states,
|
306 |
+
attention_mask=attention_mask,
|
307 |
+
encoder_hidden_states=encoder_hidden_states,
|
308 |
+
encoder_attention_mask=encoder_attention_mask,
|
309 |
+
timestep=timestep,
|
310 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
311 |
+
class_labels=class_labels,
|
312 |
+
return_cross_attention_probs=return_cross_attention_probs,
|
313 |
+
)
|
314 |
+
if return_cross_attention_probs:
|
315 |
+
hidden_states, cross_attention_probs = hidden_states
|
316 |
+
cross_attention_probs_all.append(cross_attention_probs)
|
317 |
+
|
318 |
+
# 3. Output
|
319 |
+
if self.is_input_continuous:
|
320 |
+
if not self.use_linear_projection:
|
321 |
+
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
322 |
+
hidden_states = self.proj_out(hidden_states)
|
323 |
+
else:
|
324 |
+
hidden_states = self.proj_out(hidden_states)
|
325 |
+
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
326 |
+
|
327 |
+
output = hidden_states + residual
|
328 |
+
elif self.is_input_vectorized:
|
329 |
+
hidden_states = self.norm_out(hidden_states)
|
330 |
+
logits = self.out(hidden_states)
|
331 |
+
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
332 |
+
logits = logits.permute(0, 2, 1)
|
333 |
+
|
334 |
+
# log(p(x_0))
|
335 |
+
output = F.log_softmax(logits.double(), dim=1).float()
|
336 |
+
elif self.is_input_patches:
|
337 |
+
# TODO: cleanup!
|
338 |
+
conditioning = self.transformer_blocks[0].norm1.emb(
|
339 |
+
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
340 |
+
)
|
341 |
+
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
342 |
+
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
343 |
+
hidden_states = self.proj_out_2(hidden_states)
|
344 |
+
|
345 |
+
# unpatchify
|
346 |
+
height = width = int(hidden_states.shape[1] ** 0.5)
|
347 |
+
hidden_states = hidden_states.reshape(
|
348 |
+
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
349 |
+
)
|
350 |
+
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
351 |
+
output = hidden_states.reshape(
|
352 |
+
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
353 |
+
)
|
354 |
+
|
355 |
+
if len(cross_attention_probs_all) == 1:
|
356 |
+
# If we only have one transformer block in a Transformer2DModel, we do not create another nested level.
|
357 |
+
cross_attention_probs_all = cross_attention_probs_all[0]
|
358 |
+
|
359 |
+
if not return_dict:
|
360 |
+
if return_cross_attention_probs:
|
361 |
+
return (output, cross_attention_probs_all)
|
362 |
+
return (output,)
|
363 |
+
|
364 |
+
output = Transformer2DModelOutput(sample=output)
|
365 |
+
if return_cross_attention_probs:
|
366 |
+
return output, cross_attention_probs_all
|
367 |
+
return output
|
models/unet_2d_blocks.py
ADDED
@@ -0,0 +1,793 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Any, Dict, Optional, Tuple
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
from diffusers.utils import is_torch_version
|
22 |
+
from diffusers.models.dual_transformer_2d import DualTransformer2DModel
|
23 |
+
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
24 |
+
from .transformer_2d import Transformer2DModel
|
25 |
+
|
26 |
+
|
27 |
+
def get_down_block(
|
28 |
+
down_block_type,
|
29 |
+
num_layers,
|
30 |
+
in_channels,
|
31 |
+
out_channels,
|
32 |
+
temb_channels,
|
33 |
+
add_downsample,
|
34 |
+
resnet_eps,
|
35 |
+
resnet_act_fn,
|
36 |
+
attn_num_head_channels,
|
37 |
+
resnet_groups=None,
|
38 |
+
cross_attention_dim=None,
|
39 |
+
downsample_padding=None,
|
40 |
+
dual_cross_attention=False,
|
41 |
+
use_linear_projection=False,
|
42 |
+
only_cross_attention=False,
|
43 |
+
upcast_attention=False,
|
44 |
+
resnet_time_scale_shift="default",
|
45 |
+
resnet_skip_time_act=False,
|
46 |
+
resnet_out_scale_factor=1.0,
|
47 |
+
cross_attention_norm=None,
|
48 |
+
use_gated_attention=False,
|
49 |
+
):
|
50 |
+
down_block_type = down_block_type[7:] if down_block_type.startswith(
|
51 |
+
"UNetRes") else down_block_type
|
52 |
+
if down_block_type == "DownBlock2D":
|
53 |
+
return DownBlock2D(
|
54 |
+
num_layers=num_layers,
|
55 |
+
in_channels=in_channels,
|
56 |
+
out_channels=out_channels,
|
57 |
+
temb_channels=temb_channels,
|
58 |
+
add_downsample=add_downsample,
|
59 |
+
resnet_eps=resnet_eps,
|
60 |
+
resnet_act_fn=resnet_act_fn,
|
61 |
+
resnet_groups=resnet_groups,
|
62 |
+
downsample_padding=downsample_padding,
|
63 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
64 |
+
)
|
65 |
+
elif down_block_type == "CrossAttnDownBlock2D":
|
66 |
+
if cross_attention_dim is None:
|
67 |
+
raise ValueError(
|
68 |
+
"cross_attention_dim must be specified for CrossAttnDownBlock2D")
|
69 |
+
return CrossAttnDownBlock2D(
|
70 |
+
num_layers=num_layers,
|
71 |
+
in_channels=in_channels,
|
72 |
+
out_channels=out_channels,
|
73 |
+
temb_channels=temb_channels,
|
74 |
+
add_downsample=add_downsample,
|
75 |
+
resnet_eps=resnet_eps,
|
76 |
+
resnet_act_fn=resnet_act_fn,
|
77 |
+
resnet_groups=resnet_groups,
|
78 |
+
downsample_padding=downsample_padding,
|
79 |
+
cross_attention_dim=cross_attention_dim,
|
80 |
+
attn_num_head_channels=attn_num_head_channels,
|
81 |
+
dual_cross_attention=dual_cross_attention,
|
82 |
+
use_linear_projection=use_linear_projection,
|
83 |
+
only_cross_attention=only_cross_attention,
|
84 |
+
upcast_attention=upcast_attention,
|
85 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
86 |
+
use_gated_attention=use_gated_attention,
|
87 |
+
)
|
88 |
+
|
89 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
90 |
+
|
91 |
+
|
92 |
+
def get_up_block(
|
93 |
+
up_block_type,
|
94 |
+
num_layers,
|
95 |
+
in_channels,
|
96 |
+
out_channels,
|
97 |
+
prev_output_channel,
|
98 |
+
temb_channels,
|
99 |
+
add_upsample,
|
100 |
+
resnet_eps,
|
101 |
+
resnet_act_fn,
|
102 |
+
attn_num_head_channels,
|
103 |
+
resnet_groups=None,
|
104 |
+
cross_attention_dim=None,
|
105 |
+
dual_cross_attention=False,
|
106 |
+
use_linear_projection=False,
|
107 |
+
only_cross_attention=False,
|
108 |
+
upcast_attention=False,
|
109 |
+
resnet_time_scale_shift="default",
|
110 |
+
resnet_skip_time_act=False,
|
111 |
+
resnet_out_scale_factor=1.0,
|
112 |
+
cross_attention_norm=None,
|
113 |
+
use_gated_attention=False,
|
114 |
+
):
|
115 |
+
up_block_type = up_block_type[7:] if up_block_type.startswith(
|
116 |
+
"UNetRes") else up_block_type
|
117 |
+
if up_block_type == "UpBlock2D":
|
118 |
+
return UpBlock2D(
|
119 |
+
num_layers=num_layers,
|
120 |
+
in_channels=in_channels,
|
121 |
+
out_channels=out_channels,
|
122 |
+
prev_output_channel=prev_output_channel,
|
123 |
+
temb_channels=temb_channels,
|
124 |
+
add_upsample=add_upsample,
|
125 |
+
resnet_eps=resnet_eps,
|
126 |
+
resnet_act_fn=resnet_act_fn,
|
127 |
+
resnet_groups=resnet_groups,
|
128 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
129 |
+
)
|
130 |
+
elif up_block_type == "CrossAttnUpBlock2D":
|
131 |
+
if cross_attention_dim is None:
|
132 |
+
raise ValueError(
|
133 |
+
"cross_attention_dim must be specified for CrossAttnUpBlock2D")
|
134 |
+
return CrossAttnUpBlock2D(
|
135 |
+
num_layers=num_layers,
|
136 |
+
in_channels=in_channels,
|
137 |
+
out_channels=out_channels,
|
138 |
+
prev_output_channel=prev_output_channel,
|
139 |
+
temb_channels=temb_channels,
|
140 |
+
add_upsample=add_upsample,
|
141 |
+
resnet_eps=resnet_eps,
|
142 |
+
resnet_act_fn=resnet_act_fn,
|
143 |
+
resnet_groups=resnet_groups,
|
144 |
+
cross_attention_dim=cross_attention_dim,
|
145 |
+
attn_num_head_channels=attn_num_head_channels,
|
146 |
+
dual_cross_attention=dual_cross_attention,
|
147 |
+
use_linear_projection=use_linear_projection,
|
148 |
+
only_cross_attention=only_cross_attention,
|
149 |
+
upcast_attention=upcast_attention,
|
150 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
151 |
+
use_gated_attention=use_gated_attention,
|
152 |
+
)
|
153 |
+
|
154 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
155 |
+
|
156 |
+
|
157 |
+
class UNetMidBlock2DCrossAttn(nn.Module):
|
158 |
+
def __init__(
|
159 |
+
self,
|
160 |
+
in_channels: int,
|
161 |
+
temb_channels: int,
|
162 |
+
dropout: float = 0.0,
|
163 |
+
num_layers: int = 1,
|
164 |
+
resnet_eps: float = 1e-6,
|
165 |
+
resnet_time_scale_shift: str = "default",
|
166 |
+
resnet_act_fn: str = "swish",
|
167 |
+
resnet_groups: int = 32,
|
168 |
+
resnet_pre_norm: bool = True,
|
169 |
+
attn_num_head_channels=1,
|
170 |
+
output_scale_factor=1.0,
|
171 |
+
cross_attention_dim=1280,
|
172 |
+
dual_cross_attention=False,
|
173 |
+
use_linear_projection=False,
|
174 |
+
upcast_attention=False,
|
175 |
+
use_gated_attention=False,
|
176 |
+
):
|
177 |
+
super().__init__()
|
178 |
+
|
179 |
+
self.has_cross_attention = True
|
180 |
+
self.attn_num_head_channels = attn_num_head_channels
|
181 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(
|
182 |
+
in_channels // 4, 32)
|
183 |
+
|
184 |
+
# there is always at least one resnet
|
185 |
+
resnets = [
|
186 |
+
ResnetBlock2D(
|
187 |
+
in_channels=in_channels,
|
188 |
+
out_channels=in_channels,
|
189 |
+
temb_channels=temb_channels,
|
190 |
+
eps=resnet_eps,
|
191 |
+
groups=resnet_groups,
|
192 |
+
dropout=dropout,
|
193 |
+
time_embedding_norm=resnet_time_scale_shift,
|
194 |
+
non_linearity=resnet_act_fn,
|
195 |
+
output_scale_factor=output_scale_factor,
|
196 |
+
pre_norm=resnet_pre_norm,
|
197 |
+
)
|
198 |
+
]
|
199 |
+
attentions = []
|
200 |
+
|
201 |
+
for _ in range(num_layers):
|
202 |
+
if not dual_cross_attention:
|
203 |
+
attentions.append(
|
204 |
+
Transformer2DModel(
|
205 |
+
attn_num_head_channels,
|
206 |
+
in_channels // attn_num_head_channels,
|
207 |
+
in_channels=in_channels,
|
208 |
+
num_layers=1,
|
209 |
+
cross_attention_dim=cross_attention_dim,
|
210 |
+
norm_num_groups=resnet_groups,
|
211 |
+
use_linear_projection=use_linear_projection,
|
212 |
+
upcast_attention=upcast_attention,
|
213 |
+
use_gated_attention=use_gated_attention,
|
214 |
+
)
|
215 |
+
)
|
216 |
+
else:
|
217 |
+
attentions.append(
|
218 |
+
DualTransformer2DModel(
|
219 |
+
attn_num_head_channels,
|
220 |
+
in_channels // attn_num_head_channels,
|
221 |
+
in_channels=in_channels,
|
222 |
+
num_layers=1,
|
223 |
+
cross_attention_dim=cross_attention_dim,
|
224 |
+
norm_num_groups=resnet_groups,
|
225 |
+
)
|
226 |
+
)
|
227 |
+
resnets.append(
|
228 |
+
ResnetBlock2D(
|
229 |
+
in_channels=in_channels,
|
230 |
+
out_channels=in_channels,
|
231 |
+
temb_channels=temb_channels,
|
232 |
+
eps=resnet_eps,
|
233 |
+
groups=resnet_groups,
|
234 |
+
dropout=dropout,
|
235 |
+
time_embedding_norm=resnet_time_scale_shift,
|
236 |
+
non_linearity=resnet_act_fn,
|
237 |
+
output_scale_factor=output_scale_factor,
|
238 |
+
pre_norm=resnet_pre_norm,
|
239 |
+
)
|
240 |
+
)
|
241 |
+
|
242 |
+
self.attentions = nn.ModuleList(attentions)
|
243 |
+
self.resnets = nn.ModuleList(resnets)
|
244 |
+
|
245 |
+
def forward(
|
246 |
+
self,
|
247 |
+
hidden_states: torch.FloatTensor,
|
248 |
+
temb: Optional[torch.FloatTensor] = None,
|
249 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
250 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
251 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
252 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
253 |
+
return_cross_attention_probs: bool = False,
|
254 |
+
) -> torch.FloatTensor:
|
255 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
256 |
+
cross_attention_probs_all = []
|
257 |
+
base_attn_key = cross_attention_kwargs["attn_key"]
|
258 |
+
for attn_key, (attn, resnet) in enumerate(zip(self.attentions, self.resnets[1:])):
|
259 |
+
cross_attention_kwargs["attn_key"] = base_attn_key + [attn_key]
|
260 |
+
hidden_states = attn(
|
261 |
+
hidden_states,
|
262 |
+
encoder_hidden_states=encoder_hidden_states,
|
263 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
264 |
+
attention_mask=attention_mask,
|
265 |
+
encoder_attention_mask=encoder_attention_mask,
|
266 |
+
return_dict=False,
|
267 |
+
return_cross_attention_probs=return_cross_attention_probs,
|
268 |
+
)
|
269 |
+
if return_cross_attention_probs:
|
270 |
+
hidden_states, cross_attention_probs = hidden_states
|
271 |
+
cross_attention_probs_all.append(cross_attention_probs)
|
272 |
+
else:
|
273 |
+
hidden_states = hidden_states[0]
|
274 |
+
hidden_states = resnet(hidden_states, temb)
|
275 |
+
|
276 |
+
if return_cross_attention_probs:
|
277 |
+
return hidden_states, cross_attention_probs_all
|
278 |
+
return hidden_states
|
279 |
+
|
280 |
+
|
281 |
+
class CrossAttnDownBlock2D(nn.Module):
|
282 |
+
def __init__(
|
283 |
+
self,
|
284 |
+
in_channels: int,
|
285 |
+
out_channels: int,
|
286 |
+
temb_channels: int,
|
287 |
+
dropout: float = 0.0,
|
288 |
+
num_layers: int = 1,
|
289 |
+
resnet_eps: float = 1e-6,
|
290 |
+
resnet_time_scale_shift: str = "default",
|
291 |
+
resnet_act_fn: str = "swish",
|
292 |
+
resnet_groups: int = 32,
|
293 |
+
resnet_pre_norm: bool = True,
|
294 |
+
attn_num_head_channels=1,
|
295 |
+
cross_attention_dim=1280,
|
296 |
+
output_scale_factor=1.0,
|
297 |
+
downsample_padding=1,
|
298 |
+
add_downsample=True,
|
299 |
+
dual_cross_attention=False,
|
300 |
+
use_linear_projection=False,
|
301 |
+
only_cross_attention=False,
|
302 |
+
upcast_attention=False,
|
303 |
+
use_gated_attention=False,
|
304 |
+
):
|
305 |
+
super().__init__()
|
306 |
+
resnets = []
|
307 |
+
attentions = []
|
308 |
+
|
309 |
+
self.has_cross_attention = True
|
310 |
+
self.attn_num_head_channels = attn_num_head_channels
|
311 |
+
|
312 |
+
for i in range(num_layers):
|
313 |
+
in_channels = in_channels if i == 0 else out_channels
|
314 |
+
resnets.append(
|
315 |
+
ResnetBlock2D(
|
316 |
+
in_channels=in_channels,
|
317 |
+
out_channels=out_channels,
|
318 |
+
temb_channels=temb_channels,
|
319 |
+
eps=resnet_eps,
|
320 |
+
groups=resnet_groups,
|
321 |
+
dropout=dropout,
|
322 |
+
time_embedding_norm=resnet_time_scale_shift,
|
323 |
+
non_linearity=resnet_act_fn,
|
324 |
+
output_scale_factor=output_scale_factor,
|
325 |
+
pre_norm=resnet_pre_norm,
|
326 |
+
)
|
327 |
+
)
|
328 |
+
if not dual_cross_attention:
|
329 |
+
attentions.append(
|
330 |
+
Transformer2DModel(
|
331 |
+
attn_num_head_channels,
|
332 |
+
out_channels // attn_num_head_channels,
|
333 |
+
in_channels=out_channels,
|
334 |
+
num_layers=1,
|
335 |
+
cross_attention_dim=cross_attention_dim,
|
336 |
+
norm_num_groups=resnet_groups,
|
337 |
+
use_linear_projection=use_linear_projection,
|
338 |
+
only_cross_attention=only_cross_attention,
|
339 |
+
upcast_attention=upcast_attention,
|
340 |
+
use_gated_attention=use_gated_attention
|
341 |
+
)
|
342 |
+
)
|
343 |
+
else:
|
344 |
+
attentions.append(
|
345 |
+
DualTransformer2DModel(
|
346 |
+
attn_num_head_channels,
|
347 |
+
out_channels // attn_num_head_channels,
|
348 |
+
in_channels=out_channels,
|
349 |
+
num_layers=1,
|
350 |
+
cross_attention_dim=cross_attention_dim,
|
351 |
+
norm_num_groups=resnet_groups,
|
352 |
+
)
|
353 |
+
)
|
354 |
+
self.attentions = nn.ModuleList(attentions)
|
355 |
+
self.resnets = nn.ModuleList(resnets)
|
356 |
+
|
357 |
+
if add_downsample:
|
358 |
+
self.downsamplers = nn.ModuleList(
|
359 |
+
[
|
360 |
+
Downsample2D(
|
361 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
362 |
+
)
|
363 |
+
]
|
364 |
+
)
|
365 |
+
else:
|
366 |
+
self.downsamplers = None
|
367 |
+
|
368 |
+
self.gradient_checkpointing = False
|
369 |
+
|
370 |
+
def forward(
|
371 |
+
self,
|
372 |
+
hidden_states: torch.FloatTensor,
|
373 |
+
temb: Optional[torch.FloatTensor] = None,
|
374 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
375 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
376 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
377 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
378 |
+
return_cross_attention_probs: bool = False,
|
379 |
+
):
|
380 |
+
output_states = ()
|
381 |
+
cross_attention_probs_all = []
|
382 |
+
base_attn_key = cross_attention_kwargs["attn_key"]
|
383 |
+
|
384 |
+
for attn_key, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
|
385 |
+
|
386 |
+
cross_attention_kwargs["attn_key"] = base_attn_key + [attn_key]
|
387 |
+
|
388 |
+
if self.training and self.gradient_checkpointing:
|
389 |
+
|
390 |
+
def create_custom_forward(module, return_dict=None):
|
391 |
+
def custom_forward(*inputs):
|
392 |
+
if return_dict is not None:
|
393 |
+
return module(*inputs, return_dict=return_dict)
|
394 |
+
else:
|
395 |
+
return module(*inputs)
|
396 |
+
|
397 |
+
return custom_forward
|
398 |
+
|
399 |
+
ckpt_kwargs: Dict[str, Any] = {
|
400 |
+
"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
401 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
402 |
+
create_custom_forward(resnet),
|
403 |
+
hidden_states,
|
404 |
+
temb,
|
405 |
+
**ckpt_kwargs,
|
406 |
+
)
|
407 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
408 |
+
create_custom_forward(attn, return_dict=False),
|
409 |
+
hidden_states,
|
410 |
+
encoder_hidden_states,
|
411 |
+
None, # timestep
|
412 |
+
None, # class_labels
|
413 |
+
cross_attention_kwargs,
|
414 |
+
attention_mask,
|
415 |
+
encoder_attention_mask,
|
416 |
+
return_cross_attention_probs=return_cross_attention_probs,
|
417 |
+
**ckpt_kwargs,
|
418 |
+
)
|
419 |
+
if return_cross_attention_probs:
|
420 |
+
hidden_states, cross_attention_probs = hidden_states
|
421 |
+
cross_attention_probs_all.append(cross_attention_probs)
|
422 |
+
else:
|
423 |
+
hidden_states = hidden_states[0]
|
424 |
+
else:
|
425 |
+
hidden_states = resnet(hidden_states, temb)
|
426 |
+
hidden_states = attn(
|
427 |
+
hidden_states,
|
428 |
+
encoder_hidden_states=encoder_hidden_states,
|
429 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
430 |
+
attention_mask=attention_mask,
|
431 |
+
encoder_attention_mask=encoder_attention_mask,
|
432 |
+
return_dict=False,
|
433 |
+
return_cross_attention_probs=return_cross_attention_probs,
|
434 |
+
)
|
435 |
+
if return_cross_attention_probs:
|
436 |
+
hidden_states, cross_attention_probs = hidden_states
|
437 |
+
cross_attention_probs_all.append(cross_attention_probs)
|
438 |
+
else:
|
439 |
+
hidden_states = hidden_states[0]
|
440 |
+
|
441 |
+
output_states = output_states + (hidden_states,)
|
442 |
+
|
443 |
+
if self.downsamplers is not None:
|
444 |
+
for downsampler in self.downsamplers:
|
445 |
+
hidden_states = downsampler(hidden_states)
|
446 |
+
|
447 |
+
output_states = output_states + (hidden_states,)
|
448 |
+
|
449 |
+
if return_cross_attention_probs:
|
450 |
+
return hidden_states, output_states, cross_attention_probs_all
|
451 |
+
return hidden_states, output_states
|
452 |
+
|
453 |
+
|
454 |
+
class DownBlock2D(nn.Module):
|
455 |
+
def __init__(
|
456 |
+
self,
|
457 |
+
in_channels: int,
|
458 |
+
out_channels: int,
|
459 |
+
temb_channels: int,
|
460 |
+
dropout: float = 0.0,
|
461 |
+
num_layers: int = 1,
|
462 |
+
resnet_eps: float = 1e-6,
|
463 |
+
resnet_time_scale_shift: str = "default",
|
464 |
+
resnet_act_fn: str = "swish",
|
465 |
+
resnet_groups: int = 32,
|
466 |
+
resnet_pre_norm: bool = True,
|
467 |
+
output_scale_factor=1.0,
|
468 |
+
add_downsample=True,
|
469 |
+
downsample_padding=1,
|
470 |
+
):
|
471 |
+
super().__init__()
|
472 |
+
resnets = []
|
473 |
+
|
474 |
+
for i in range(num_layers):
|
475 |
+
in_channels = in_channels if i == 0 else out_channels
|
476 |
+
resnets.append(
|
477 |
+
ResnetBlock2D(
|
478 |
+
in_channels=in_channels,
|
479 |
+
out_channels=out_channels,
|
480 |
+
temb_channels=temb_channels,
|
481 |
+
eps=resnet_eps,
|
482 |
+
groups=resnet_groups,
|
483 |
+
dropout=dropout,
|
484 |
+
time_embedding_norm=resnet_time_scale_shift,
|
485 |
+
non_linearity=resnet_act_fn,
|
486 |
+
output_scale_factor=output_scale_factor,
|
487 |
+
pre_norm=resnet_pre_norm,
|
488 |
+
)
|
489 |
+
)
|
490 |
+
|
491 |
+
self.resnets = nn.ModuleList(resnets)
|
492 |
+
|
493 |
+
if add_downsample:
|
494 |
+
self.downsamplers = nn.ModuleList(
|
495 |
+
[
|
496 |
+
Downsample2D(
|
497 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
498 |
+
)
|
499 |
+
]
|
500 |
+
)
|
501 |
+
else:
|
502 |
+
self.downsamplers = None
|
503 |
+
|
504 |
+
self.gradient_checkpointing = False
|
505 |
+
|
506 |
+
def forward(self, hidden_states, temb=None):
|
507 |
+
output_states = ()
|
508 |
+
|
509 |
+
for resnet in self.resnets:
|
510 |
+
if self.training and self.gradient_checkpointing:
|
511 |
+
|
512 |
+
def create_custom_forward(module):
|
513 |
+
def custom_forward(*inputs):
|
514 |
+
return module(*inputs)
|
515 |
+
|
516 |
+
return custom_forward
|
517 |
+
|
518 |
+
if is_torch_version(">=", "1.11.0"):
|
519 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
520 |
+
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
521 |
+
)
|
522 |
+
else:
|
523 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
524 |
+
create_custom_forward(resnet), hidden_states, temb
|
525 |
+
)
|
526 |
+
else:
|
527 |
+
hidden_states = resnet(hidden_states, temb)
|
528 |
+
|
529 |
+
output_states = output_states + (hidden_states,)
|
530 |
+
|
531 |
+
if self.downsamplers is not None:
|
532 |
+
for downsampler in self.downsamplers:
|
533 |
+
hidden_states = downsampler(hidden_states)
|
534 |
+
|
535 |
+
output_states = output_states + (hidden_states,)
|
536 |
+
|
537 |
+
return hidden_states, output_states
|
538 |
+
|
539 |
+
|
540 |
+
class CrossAttnUpBlock2D(nn.Module):
|
541 |
+
def __init__(
|
542 |
+
self,
|
543 |
+
in_channels: int,
|
544 |
+
out_channels: int,
|
545 |
+
prev_output_channel: int,
|
546 |
+
temb_channels: int,
|
547 |
+
dropout: float = 0.0,
|
548 |
+
num_layers: int = 1,
|
549 |
+
resnet_eps: float = 1e-6,
|
550 |
+
resnet_time_scale_shift: str = "default",
|
551 |
+
resnet_act_fn: str = "swish",
|
552 |
+
resnet_groups: int = 32,
|
553 |
+
resnet_pre_norm: bool = True,
|
554 |
+
attn_num_head_channels=1,
|
555 |
+
cross_attention_dim=1280,
|
556 |
+
output_scale_factor=1.0,
|
557 |
+
add_upsample=True,
|
558 |
+
dual_cross_attention=False,
|
559 |
+
use_linear_projection=False,
|
560 |
+
only_cross_attention=False,
|
561 |
+
upcast_attention=False,
|
562 |
+
use_gated_attention=False,
|
563 |
+
):
|
564 |
+
super().__init__()
|
565 |
+
resnets = []
|
566 |
+
attentions = []
|
567 |
+
|
568 |
+
self.has_cross_attention = True
|
569 |
+
self.attn_num_head_channels = attn_num_head_channels
|
570 |
+
|
571 |
+
for i in range(num_layers):
|
572 |
+
res_skip_channels = in_channels if (
|
573 |
+
i == num_layers - 1) else out_channels
|
574 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
575 |
+
|
576 |
+
resnets.append(
|
577 |
+
ResnetBlock2D(
|
578 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
579 |
+
out_channels=out_channels,
|
580 |
+
temb_channels=temb_channels,
|
581 |
+
eps=resnet_eps,
|
582 |
+
groups=resnet_groups,
|
583 |
+
dropout=dropout,
|
584 |
+
time_embedding_norm=resnet_time_scale_shift,
|
585 |
+
non_linearity=resnet_act_fn,
|
586 |
+
output_scale_factor=output_scale_factor,
|
587 |
+
pre_norm=resnet_pre_norm,
|
588 |
+
)
|
589 |
+
)
|
590 |
+
if not dual_cross_attention:
|
591 |
+
attentions.append(
|
592 |
+
Transformer2DModel(
|
593 |
+
attn_num_head_channels,
|
594 |
+
out_channels // attn_num_head_channels,
|
595 |
+
in_channels=out_channels,
|
596 |
+
num_layers=1,
|
597 |
+
cross_attention_dim=cross_attention_dim,
|
598 |
+
norm_num_groups=resnet_groups,
|
599 |
+
use_linear_projection=use_linear_projection,
|
600 |
+
only_cross_attention=only_cross_attention,
|
601 |
+
upcast_attention=upcast_attention,
|
602 |
+
use_gated_attention=use_gated_attention,
|
603 |
+
)
|
604 |
+
)
|
605 |
+
else:
|
606 |
+
attentions.append(
|
607 |
+
DualTransformer2DModel(
|
608 |
+
attn_num_head_channels,
|
609 |
+
out_channels // attn_num_head_channels,
|
610 |
+
in_channels=out_channels,
|
611 |
+
num_layers=1,
|
612 |
+
cross_attention_dim=cross_attention_dim,
|
613 |
+
norm_num_groups=resnet_groups,
|
614 |
+
)
|
615 |
+
)
|
616 |
+
self.attentions = nn.ModuleList(attentions)
|
617 |
+
self.resnets = nn.ModuleList(resnets)
|
618 |
+
|
619 |
+
if add_upsample:
|
620 |
+
self.upsamplers = nn.ModuleList(
|
621 |
+
[Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
622 |
+
else:
|
623 |
+
self.upsamplers = None
|
624 |
+
|
625 |
+
self.gradient_checkpointing = False
|
626 |
+
|
627 |
+
def forward(
|
628 |
+
self,
|
629 |
+
hidden_states: torch.FloatTensor,
|
630 |
+
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
631 |
+
temb: Optional[torch.FloatTensor] = None,
|
632 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
633 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
634 |
+
upsample_size: Optional[int] = None,
|
635 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
636 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
637 |
+
return_cross_attention_probs: bool = False,
|
638 |
+
):
|
639 |
+
cross_attention_probs_all = []
|
640 |
+
base_attn_key = cross_attention_kwargs["attn_key"]
|
641 |
+
|
642 |
+
for attn_key, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
|
643 |
+
cross_attention_kwargs["attn_key"] = base_attn_key + [attn_key]
|
644 |
+
|
645 |
+
# pop res hidden states
|
646 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
647 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
648 |
+
hidden_states = torch.cat(
|
649 |
+
[hidden_states, res_hidden_states], dim=1)
|
650 |
+
|
651 |
+
if self.training and self.gradient_checkpointing:
|
652 |
+
|
653 |
+
def create_custom_forward(module, return_dict=None):
|
654 |
+
def custom_forward(*inputs):
|
655 |
+
if return_dict is not None:
|
656 |
+
return module(*inputs, return_dict=return_dict)
|
657 |
+
else:
|
658 |
+
return module(*inputs)
|
659 |
+
|
660 |
+
return custom_forward
|
661 |
+
|
662 |
+
ckpt_kwargs: Dict[str, Any] = {
|
663 |
+
"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
664 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
665 |
+
create_custom_forward(resnet),
|
666 |
+
hidden_states,
|
667 |
+
temb,
|
668 |
+
**ckpt_kwargs,
|
669 |
+
)
|
670 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
671 |
+
create_custom_forward(attn, return_dict=False),
|
672 |
+
hidden_states,
|
673 |
+
encoder_hidden_states,
|
674 |
+
None, # timestep
|
675 |
+
None, # class_labels
|
676 |
+
cross_attention_kwargs,
|
677 |
+
attention_mask,
|
678 |
+
encoder_attention_mask,
|
679 |
+
**ckpt_kwargs,
|
680 |
+
)
|
681 |
+
if return_cross_attention_probs:
|
682 |
+
hidden_states, cross_attention_probs = hidden_states
|
683 |
+
cross_attention_probs_all.append(cross_attention_probs)
|
684 |
+
else:
|
685 |
+
hidden_states = hidden_states[0]
|
686 |
+
else:
|
687 |
+
hidden_states = resnet(hidden_states, temb)
|
688 |
+
hidden_states = attn(
|
689 |
+
hidden_states,
|
690 |
+
encoder_hidden_states=encoder_hidden_states,
|
691 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
692 |
+
attention_mask=attention_mask,
|
693 |
+
encoder_attention_mask=encoder_attention_mask,
|
694 |
+
return_dict=False,
|
695 |
+
return_cross_attention_probs=return_cross_attention_probs,
|
696 |
+
)
|
697 |
+
if return_cross_attention_probs:
|
698 |
+
hidden_states, cross_attention_probs = hidden_states
|
699 |
+
cross_attention_probs_all.append(cross_attention_probs)
|
700 |
+
else:
|
701 |
+
hidden_states = hidden_states[0]
|
702 |
+
|
703 |
+
if self.upsamplers is not None:
|
704 |
+
for upsampler in self.upsamplers:
|
705 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
706 |
+
|
707 |
+
if return_cross_attention_probs:
|
708 |
+
return hidden_states, cross_attention_probs_all
|
709 |
+
return hidden_states
|
710 |
+
|
711 |
+
|
712 |
+
class UpBlock2D(nn.Module):
|
713 |
+
def __init__(
|
714 |
+
self,
|
715 |
+
in_channels: int,
|
716 |
+
prev_output_channel: int,
|
717 |
+
out_channels: int,
|
718 |
+
temb_channels: int,
|
719 |
+
dropout: float = 0.0,
|
720 |
+
num_layers: int = 1,
|
721 |
+
resnet_eps: float = 1e-6,
|
722 |
+
resnet_time_scale_shift: str = "default",
|
723 |
+
resnet_act_fn: str = "swish",
|
724 |
+
resnet_groups: int = 32,
|
725 |
+
resnet_pre_norm: bool = True,
|
726 |
+
output_scale_factor=1.0,
|
727 |
+
add_upsample=True,
|
728 |
+
):
|
729 |
+
super().__init__()
|
730 |
+
resnets = []
|
731 |
+
|
732 |
+
for i in range(num_layers):
|
733 |
+
res_skip_channels = in_channels if (
|
734 |
+
i == num_layers - 1) else out_channels
|
735 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
736 |
+
|
737 |
+
resnets.append(
|
738 |
+
ResnetBlock2D(
|
739 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
740 |
+
out_channels=out_channels,
|
741 |
+
temb_channels=temb_channels,
|
742 |
+
eps=resnet_eps,
|
743 |
+
groups=resnet_groups,
|
744 |
+
dropout=dropout,
|
745 |
+
time_embedding_norm=resnet_time_scale_shift,
|
746 |
+
non_linearity=resnet_act_fn,
|
747 |
+
output_scale_factor=output_scale_factor,
|
748 |
+
pre_norm=resnet_pre_norm,
|
749 |
+
)
|
750 |
+
)
|
751 |
+
|
752 |
+
self.resnets = nn.ModuleList(resnets)
|
753 |
+
|
754 |
+
if add_upsample:
|
755 |
+
self.upsamplers = nn.ModuleList(
|
756 |
+
[Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
757 |
+
else:
|
758 |
+
self.upsamplers = None
|
759 |
+
|
760 |
+
self.gradient_checkpointing = False
|
761 |
+
|
762 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
763 |
+
for resnet in self.resnets:
|
764 |
+
# pop res hidden states
|
765 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
766 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
767 |
+
hidden_states = torch.cat(
|
768 |
+
[hidden_states, res_hidden_states], dim=1)
|
769 |
+
|
770 |
+
if self.training and self.gradient_checkpointing:
|
771 |
+
|
772 |
+
def create_custom_forward(module):
|
773 |
+
def custom_forward(*inputs):
|
774 |
+
return module(*inputs)
|
775 |
+
|
776 |
+
return custom_forward
|
777 |
+
|
778 |
+
if is_torch_version(">=", "1.11.0"):
|
779 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
780 |
+
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
781 |
+
)
|
782 |
+
else:
|
783 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
784 |
+
create_custom_forward(resnet), hidden_states, temb
|
785 |
+
)
|
786 |
+
else:
|
787 |
+
hidden_states = resnet(hidden_states, temb)
|
788 |
+
|
789 |
+
if self.upsamplers is not None:
|
790 |
+
for upsampler in self.upsamplers:
|
791 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
792 |
+
|
793 |
+
return hidden_states
|
models/unet_2d_condition.py
ADDED
@@ -0,0 +1,980 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import torch.nn.functional as F
|
20 |
+
import torch.utils.checkpoint
|
21 |
+
|
22 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
23 |
+
from diffusers.loaders import UNet2DConditionLoadersMixin
|
24 |
+
from diffusers.utils import BaseOutput, logging
|
25 |
+
from diffusers.models.embeddings import (
|
26 |
+
GaussianFourierProjection,
|
27 |
+
TextImageProjection,
|
28 |
+
TextImageTimeEmbedding,
|
29 |
+
TextTimeEmbedding,
|
30 |
+
TimestepEmbedding,
|
31 |
+
Timesteps,
|
32 |
+
)
|
33 |
+
from diffusers.models.modeling_utils import ModelMixin
|
34 |
+
from .unet_2d_blocks import (
|
35 |
+
CrossAttnDownBlock2D,
|
36 |
+
CrossAttnUpBlock2D,
|
37 |
+
DownBlock2D,
|
38 |
+
UNetMidBlock2DCrossAttn,
|
39 |
+
UpBlock2D,
|
40 |
+
get_down_block,
|
41 |
+
get_up_block,
|
42 |
+
)
|
43 |
+
from .attention_processor import AttentionProcessor, AttnProcessor
|
44 |
+
|
45 |
+
|
46 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
47 |
+
|
48 |
+
|
49 |
+
@dataclass
|
50 |
+
class UNet2DConditionOutput(BaseOutput):
|
51 |
+
"""
|
52 |
+
Args:
|
53 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
54 |
+
Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
55 |
+
"""
|
56 |
+
|
57 |
+
sample: torch.FloatTensor
|
58 |
+
cross_attention_probs_down: List[Any]
|
59 |
+
cross_attention_probs_mid: List[Any]
|
60 |
+
cross_attention_probs_up: List[Any]
|
61 |
+
|
62 |
+
|
63 |
+
class FourierEmbedder(nn.Module):
|
64 |
+
def __init__(self, num_freqs=64, temperature=100):
|
65 |
+
super().__init__()
|
66 |
+
|
67 |
+
self.num_freqs = num_freqs
|
68 |
+
self.temperature = temperature
|
69 |
+
|
70 |
+
freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
|
71 |
+
freq_bands = freq_bands[None, None, None]
|
72 |
+
self.register_buffer('freq_bands', freq_bands, persistent=False)
|
73 |
+
|
74 |
+
def __call__(self, x):
|
75 |
+
x = self.freq_bands * x.unsqueeze(-1)
|
76 |
+
return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
|
77 |
+
|
78 |
+
|
79 |
+
class PositionNet(nn.Module):
|
80 |
+
def __init__(self, positive_len, out_dim, fourier_freqs=8):
|
81 |
+
super().__init__()
|
82 |
+
self.positive_len = positive_len
|
83 |
+
self.out_dim = out_dim
|
84 |
+
|
85 |
+
self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
|
86 |
+
self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
|
87 |
+
|
88 |
+
self.linears = nn.Sequential(
|
89 |
+
nn.Linear(self.positive_len + self.position_dim, 512),
|
90 |
+
nn.SiLU(),
|
91 |
+
nn.Linear(512, 512),
|
92 |
+
nn.SiLU(),
|
93 |
+
nn.Linear(512, out_dim),
|
94 |
+
)
|
95 |
+
|
96 |
+
self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
|
97 |
+
self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
|
98 |
+
|
99 |
+
def forward(self, boxes, masks, positive_embeddings):
|
100 |
+
masks = masks.unsqueeze(-1)
|
101 |
+
|
102 |
+
# embedding position (it may includes padding as placeholder)
|
103 |
+
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 -> B*N*C
|
104 |
+
|
105 |
+
# learnable null embedding
|
106 |
+
positive_null = self.null_positive_feature.view(1, 1, -1)
|
107 |
+
xyxy_null = self.null_position_feature.view(1, 1, -1)
|
108 |
+
|
109 |
+
# replace padding with learnable null embedding
|
110 |
+
positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
|
111 |
+
xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
|
112 |
+
|
113 |
+
objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
|
114 |
+
return objs
|
115 |
+
|
116 |
+
|
117 |
+
|
118 |
+
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
119 |
+
r"""
|
120 |
+
UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
|
121 |
+
and returns sample shaped output.
|
122 |
+
|
123 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
124 |
+
implements for all the models (such as downloading or saving, etc.)
|
125 |
+
|
126 |
+
Parameters:
|
127 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
128 |
+
Height and width of input/output sample.
|
129 |
+
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
|
130 |
+
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
|
131 |
+
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
132 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
133 |
+
Whether to flip the sin to cos in the time embedding.
|
134 |
+
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
135 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
136 |
+
The tuple of downsample blocks to use.
|
137 |
+
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
138 |
+
The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the
|
139 |
+
mid block layer if `None`.
|
140 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
|
141 |
+
The tuple of upsample blocks to use.
|
142 |
+
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
|
143 |
+
Whether to include self-attention in the basic transformer blocks, see
|
144 |
+
[`~models.attention.BasicTransformerBlock`].
|
145 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
146 |
+
The tuple of output channels for each block.
|
147 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
148 |
+
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
149 |
+
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
150 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
151 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
152 |
+
If `None`, it will skip the normalization and activation layers in post-processing
|
153 |
+
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
154 |
+
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
155 |
+
The dimension of the cross attention features.
|
156 |
+
encoder_hid_dim (`int`, *optional*, defaults to None):
|
157 |
+
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
158 |
+
dimension to `cross_attention_dim`.
|
159 |
+
encoder_hid_dim_type (`str`, *optional*, defaults to None):
|
160 |
+
If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text
|
161 |
+
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
162 |
+
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
163 |
+
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
164 |
+
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
|
165 |
+
class_embed_type (`str`, *optional*, defaults to None):
|
166 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
167 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
168 |
+
addition_embed_type (`str`, *optional*, defaults to None):
|
169 |
+
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
170 |
+
"text". "text" will use the `TextTimeEmbedding` layer.
|
171 |
+
num_class_embeds (`int`, *optional*, defaults to None):
|
172 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
173 |
+
class conditioning with `class_embed_type` equal to `None`.
|
174 |
+
time_embedding_type (`str`, *optional*, default to `positional`):
|
175 |
+
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
176 |
+
time_embedding_dim (`int`, *optional*, default to `None`):
|
177 |
+
An optional override for the dimension of the projected time embedding.
|
178 |
+
time_embedding_act_fn (`str`, *optional*, default to `None`):
|
179 |
+
Optional activation function to use on the time embeddings only one time before they as passed to the rest
|
180 |
+
of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
|
181 |
+
timestep_post_act (`str, *optional*, default to `None`):
|
182 |
+
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
183 |
+
time_cond_proj_dim (`int`, *optional*, default to `None`):
|
184 |
+
The dimension of `cond_proj` layer in timestep embedding.
|
185 |
+
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
|
186 |
+
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
|
187 |
+
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
|
188 |
+
using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
|
189 |
+
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
|
190 |
+
embeddings with the class embeddings.
|
191 |
+
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
|
192 |
+
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
|
193 |
+
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
|
194 |
+
`only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will
|
195 |
+
default to `False`.
|
196 |
+
"""
|
197 |
+
|
198 |
+
_supports_gradient_checkpointing = True
|
199 |
+
|
200 |
+
@register_to_config
|
201 |
+
def __init__(
|
202 |
+
self,
|
203 |
+
sample_size: Optional[int] = None,
|
204 |
+
in_channels: int = 4,
|
205 |
+
out_channels: int = 4,
|
206 |
+
center_input_sample: bool = False,
|
207 |
+
flip_sin_to_cos: bool = True,
|
208 |
+
freq_shift: int = 0,
|
209 |
+
down_block_types: Tuple[str] = (
|
210 |
+
"CrossAttnDownBlock2D",
|
211 |
+
"CrossAttnDownBlock2D",
|
212 |
+
"CrossAttnDownBlock2D",
|
213 |
+
"DownBlock2D",
|
214 |
+
),
|
215 |
+
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
216 |
+
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
217 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
218 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
219 |
+
layers_per_block: Union[int, Tuple[int]] = 2,
|
220 |
+
downsample_padding: int = 1,
|
221 |
+
mid_block_scale_factor: float = 1,
|
222 |
+
act_fn: str = "silu",
|
223 |
+
norm_num_groups: Optional[int] = 32,
|
224 |
+
norm_eps: float = 1e-5,
|
225 |
+
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
226 |
+
encoder_hid_dim: Optional[int] = None,
|
227 |
+
encoder_hid_dim_type: Optional[str] = None,
|
228 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
229 |
+
dual_cross_attention: bool = False,
|
230 |
+
use_linear_projection: bool = False,
|
231 |
+
class_embed_type: Optional[str] = None,
|
232 |
+
addition_embed_type: Optional[str] = None,
|
233 |
+
num_class_embeds: Optional[int] = None,
|
234 |
+
upcast_attention: bool = False,
|
235 |
+
resnet_time_scale_shift: str = "default",
|
236 |
+
resnet_skip_time_act: bool = False,
|
237 |
+
resnet_out_scale_factor: int = 1.0,
|
238 |
+
time_embedding_type: str = "positional",
|
239 |
+
time_embedding_dim: Optional[int] = None,
|
240 |
+
time_embedding_act_fn: Optional[str] = None,
|
241 |
+
timestep_post_act: Optional[str] = None,
|
242 |
+
time_cond_proj_dim: Optional[int] = None,
|
243 |
+
conv_in_kernel: int = 3,
|
244 |
+
conv_out_kernel: int = 3,
|
245 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
246 |
+
class_embeddings_concat: bool = False,
|
247 |
+
mid_block_only_cross_attention: Optional[bool] = None,
|
248 |
+
cross_attention_norm: Optional[str] = None,
|
249 |
+
addition_embed_type_num_heads=64,
|
250 |
+
use_gated_attention: bool = False,
|
251 |
+
):
|
252 |
+
super().__init__()
|
253 |
+
|
254 |
+
self.sample_size = sample_size
|
255 |
+
|
256 |
+
# Check inputs
|
257 |
+
if len(down_block_types) != len(up_block_types):
|
258 |
+
raise ValueError(
|
259 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
260 |
+
)
|
261 |
+
|
262 |
+
if len(block_out_channels) != len(down_block_types):
|
263 |
+
raise ValueError(
|
264 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
265 |
+
)
|
266 |
+
|
267 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
268 |
+
raise ValueError(
|
269 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
270 |
+
)
|
271 |
+
|
272 |
+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
|
273 |
+
raise ValueError(
|
274 |
+
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
275 |
+
)
|
276 |
+
|
277 |
+
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
|
278 |
+
raise ValueError(
|
279 |
+
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
280 |
+
)
|
281 |
+
|
282 |
+
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
|
283 |
+
raise ValueError(
|
284 |
+
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
285 |
+
)
|
286 |
+
|
287 |
+
# input
|
288 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
289 |
+
self.conv_in = nn.Conv2d(
|
290 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
291 |
+
)
|
292 |
+
|
293 |
+
# time
|
294 |
+
if time_embedding_type == "fourier":
|
295 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
|
296 |
+
if time_embed_dim % 2 != 0:
|
297 |
+
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
|
298 |
+
self.time_proj = GaussianFourierProjection(
|
299 |
+
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
300 |
+
)
|
301 |
+
timestep_input_dim = time_embed_dim
|
302 |
+
elif time_embedding_type == "positional":
|
303 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
304 |
+
|
305 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
306 |
+
timestep_input_dim = block_out_channels[0]
|
307 |
+
else:
|
308 |
+
raise ValueError(
|
309 |
+
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
|
310 |
+
)
|
311 |
+
|
312 |
+
self.time_embedding = TimestepEmbedding(
|
313 |
+
timestep_input_dim,
|
314 |
+
time_embed_dim,
|
315 |
+
act_fn=act_fn,
|
316 |
+
post_act_fn=timestep_post_act,
|
317 |
+
cond_proj_dim=time_cond_proj_dim,
|
318 |
+
)
|
319 |
+
|
320 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
321 |
+
encoder_hid_dim_type = "text_proj"
|
322 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
323 |
+
|
324 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
325 |
+
raise ValueError(
|
326 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
327 |
+
)
|
328 |
+
|
329 |
+
if encoder_hid_dim_type == "text_proj":
|
330 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
331 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
332 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
333 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
334 |
+
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
335 |
+
self.encoder_hid_proj = TextImageProjection(
|
336 |
+
text_embed_dim=encoder_hid_dim,
|
337 |
+
image_embed_dim=cross_attention_dim,
|
338 |
+
cross_attention_dim=cross_attention_dim,
|
339 |
+
)
|
340 |
+
|
341 |
+
elif encoder_hid_dim_type is not None:
|
342 |
+
raise ValueError(
|
343 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
344 |
+
)
|
345 |
+
else:
|
346 |
+
self.encoder_hid_proj = None
|
347 |
+
|
348 |
+
# class embedding
|
349 |
+
if class_embed_type is None and num_class_embeds is not None:
|
350 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
351 |
+
elif class_embed_type == "timestep":
|
352 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
|
353 |
+
elif class_embed_type == "identity":
|
354 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
355 |
+
elif class_embed_type == "projection":
|
356 |
+
if projection_class_embeddings_input_dim is None:
|
357 |
+
raise ValueError(
|
358 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
359 |
+
)
|
360 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
361 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
362 |
+
# 2. it projects from an arbitrary input dimension.
|
363 |
+
#
|
364 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
365 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
366 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
367 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
368 |
+
elif class_embed_type == "simple_projection":
|
369 |
+
if projection_class_embeddings_input_dim is None:
|
370 |
+
raise ValueError(
|
371 |
+
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
|
372 |
+
)
|
373 |
+
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
|
374 |
+
else:
|
375 |
+
self.class_embedding = None
|
376 |
+
|
377 |
+
if addition_embed_type == "text":
|
378 |
+
if encoder_hid_dim is not None:
|
379 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
380 |
+
else:
|
381 |
+
text_time_embedding_from_dim = cross_attention_dim
|
382 |
+
|
383 |
+
self.add_embedding = TextTimeEmbedding(
|
384 |
+
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
385 |
+
)
|
386 |
+
elif addition_embed_type == "text_image":
|
387 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
388 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
389 |
+
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
390 |
+
self.add_embedding = TextImageTimeEmbedding(
|
391 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
392 |
+
)
|
393 |
+
elif addition_embed_type is not None:
|
394 |
+
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
395 |
+
|
396 |
+
if time_embedding_act_fn is None:
|
397 |
+
self.time_embed_act = None
|
398 |
+
elif time_embedding_act_fn == "swish":
|
399 |
+
self.time_embed_act = lambda x: F.silu(x)
|
400 |
+
elif time_embedding_act_fn == "mish":
|
401 |
+
self.time_embed_act = nn.Mish()
|
402 |
+
elif time_embedding_act_fn == "silu":
|
403 |
+
self.time_embed_act = nn.SiLU()
|
404 |
+
elif time_embedding_act_fn == "gelu":
|
405 |
+
self.time_embed_act = nn.GELU()
|
406 |
+
else:
|
407 |
+
raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
|
408 |
+
|
409 |
+
self.down_blocks = nn.ModuleList([])
|
410 |
+
self.up_blocks = nn.ModuleList([])
|
411 |
+
|
412 |
+
if isinstance(only_cross_attention, bool):
|
413 |
+
if mid_block_only_cross_attention is None:
|
414 |
+
mid_block_only_cross_attention = only_cross_attention
|
415 |
+
|
416 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
417 |
+
|
418 |
+
if mid_block_only_cross_attention is None:
|
419 |
+
mid_block_only_cross_attention = False
|
420 |
+
|
421 |
+
if isinstance(attention_head_dim, int):
|
422 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
423 |
+
|
424 |
+
if isinstance(cross_attention_dim, int):
|
425 |
+
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
426 |
+
else:
|
427 |
+
assert not use_gated_attention, f"use_gated_attention is not supported with varying cross_attention_dim: {cross_attention_dim}"
|
428 |
+
|
429 |
+
if isinstance(layers_per_block, int):
|
430 |
+
layers_per_block = [layers_per_block] * len(down_block_types)
|
431 |
+
|
432 |
+
if class_embeddings_concat:
|
433 |
+
# The time embeddings are concatenated with the class embeddings. The dimension of the
|
434 |
+
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
|
435 |
+
# regular time embeddings
|
436 |
+
blocks_time_embed_dim = time_embed_dim * 2
|
437 |
+
else:
|
438 |
+
blocks_time_embed_dim = time_embed_dim
|
439 |
+
|
440 |
+
# down
|
441 |
+
output_channel = block_out_channels[0]
|
442 |
+
for i, down_block_type in enumerate(down_block_types):
|
443 |
+
input_channel = output_channel
|
444 |
+
output_channel = block_out_channels[i]
|
445 |
+
is_final_block = i == len(block_out_channels) - 1
|
446 |
+
|
447 |
+
down_block = get_down_block(
|
448 |
+
down_block_type,
|
449 |
+
num_layers=layers_per_block[i],
|
450 |
+
in_channels=input_channel,
|
451 |
+
out_channels=output_channel,
|
452 |
+
temb_channels=blocks_time_embed_dim,
|
453 |
+
add_downsample=not is_final_block,
|
454 |
+
resnet_eps=norm_eps,
|
455 |
+
resnet_act_fn=act_fn,
|
456 |
+
resnet_groups=norm_num_groups,
|
457 |
+
cross_attention_dim=cross_attention_dim[i],
|
458 |
+
attn_num_head_channels=attention_head_dim[i],
|
459 |
+
downsample_padding=downsample_padding,
|
460 |
+
dual_cross_attention=dual_cross_attention,
|
461 |
+
use_linear_projection=use_linear_projection,
|
462 |
+
only_cross_attention=only_cross_attention[i],
|
463 |
+
upcast_attention=upcast_attention,
|
464 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
465 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
466 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
467 |
+
cross_attention_norm=cross_attention_norm,
|
468 |
+
use_gated_attention=use_gated_attention,
|
469 |
+
)
|
470 |
+
self.down_blocks.append(down_block)
|
471 |
+
|
472 |
+
# mid
|
473 |
+
if mid_block_type == "UNetMidBlock2DCrossAttn":
|
474 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
475 |
+
in_channels=block_out_channels[-1],
|
476 |
+
temb_channels=blocks_time_embed_dim,
|
477 |
+
resnet_eps=norm_eps,
|
478 |
+
resnet_act_fn=act_fn,
|
479 |
+
output_scale_factor=mid_block_scale_factor,
|
480 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
481 |
+
cross_attention_dim=cross_attention_dim[-1],
|
482 |
+
attn_num_head_channels=attention_head_dim[-1],
|
483 |
+
resnet_groups=norm_num_groups,
|
484 |
+
dual_cross_attention=dual_cross_attention,
|
485 |
+
use_linear_projection=use_linear_projection,
|
486 |
+
upcast_attention=upcast_attention,
|
487 |
+
use_gated_attention=use_gated_attention,
|
488 |
+
)
|
489 |
+
elif mid_block_type is None:
|
490 |
+
self.mid_block = None
|
491 |
+
else:
|
492 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
493 |
+
|
494 |
+
# count how many layers upsample the images
|
495 |
+
self.num_upsamplers = 0
|
496 |
+
|
497 |
+
# up
|
498 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
499 |
+
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
500 |
+
reversed_layers_per_block = list(reversed(layers_per_block))
|
501 |
+
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
502 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
503 |
+
|
504 |
+
output_channel = reversed_block_out_channels[0]
|
505 |
+
for i, up_block_type in enumerate(up_block_types):
|
506 |
+
is_final_block = i == len(block_out_channels) - 1
|
507 |
+
|
508 |
+
prev_output_channel = output_channel
|
509 |
+
output_channel = reversed_block_out_channels[i]
|
510 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
511 |
+
|
512 |
+
# add upsample block for all BUT final layer
|
513 |
+
if not is_final_block:
|
514 |
+
add_upsample = True
|
515 |
+
self.num_upsamplers += 1
|
516 |
+
else:
|
517 |
+
add_upsample = False
|
518 |
+
|
519 |
+
up_block = get_up_block(
|
520 |
+
up_block_type,
|
521 |
+
num_layers=reversed_layers_per_block[i] + 1,
|
522 |
+
in_channels=input_channel,
|
523 |
+
out_channels=output_channel,
|
524 |
+
prev_output_channel=prev_output_channel,
|
525 |
+
temb_channels=blocks_time_embed_dim,
|
526 |
+
add_upsample=add_upsample,
|
527 |
+
resnet_eps=norm_eps,
|
528 |
+
resnet_act_fn=act_fn,
|
529 |
+
resnet_groups=norm_num_groups,
|
530 |
+
cross_attention_dim=reversed_cross_attention_dim[i],
|
531 |
+
attn_num_head_channels=reversed_attention_head_dim[i],
|
532 |
+
dual_cross_attention=dual_cross_attention,
|
533 |
+
use_linear_projection=use_linear_projection,
|
534 |
+
only_cross_attention=only_cross_attention[i],
|
535 |
+
upcast_attention=upcast_attention,
|
536 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
537 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
538 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
539 |
+
cross_attention_norm=cross_attention_norm,
|
540 |
+
use_gated_attention=use_gated_attention,
|
541 |
+
)
|
542 |
+
self.up_blocks.append(up_block)
|
543 |
+
prev_output_channel = output_channel
|
544 |
+
|
545 |
+
# out
|
546 |
+
if norm_num_groups is not None:
|
547 |
+
self.conv_norm_out = nn.GroupNorm(
|
548 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
549 |
+
)
|
550 |
+
|
551 |
+
if act_fn == "swish":
|
552 |
+
self.conv_act = lambda x: F.silu(x)
|
553 |
+
elif act_fn == "mish":
|
554 |
+
self.conv_act = nn.Mish()
|
555 |
+
elif act_fn == "silu":
|
556 |
+
self.conv_act = nn.SiLU()
|
557 |
+
elif act_fn == "gelu":
|
558 |
+
self.conv_act = nn.GELU()
|
559 |
+
else:
|
560 |
+
raise ValueError(f"Unsupported activation function: {act_fn}")
|
561 |
+
|
562 |
+
else:
|
563 |
+
self.conv_norm_out = None
|
564 |
+
self.conv_act = None
|
565 |
+
|
566 |
+
conv_out_padding = (conv_out_kernel - 1) // 2
|
567 |
+
self.conv_out = nn.Conv2d(
|
568 |
+
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
569 |
+
)
|
570 |
+
|
571 |
+
if use_gated_attention:
|
572 |
+
self.position_net = PositionNet(positive_len=768, out_dim=cross_attention_dim[-1])
|
573 |
+
|
574 |
+
|
575 |
+
@property
|
576 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
577 |
+
r"""
|
578 |
+
Returns:
|
579 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
580 |
+
indexed by its weight name.
|
581 |
+
"""
|
582 |
+
# set recursively
|
583 |
+
processors = {}
|
584 |
+
|
585 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
586 |
+
if hasattr(module, "set_processor"):
|
587 |
+
processors[f"{name}.processor"] = module.processor
|
588 |
+
|
589 |
+
for sub_name, child in module.named_children():
|
590 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
591 |
+
|
592 |
+
return processors
|
593 |
+
|
594 |
+
for name, module in self.named_children():
|
595 |
+
fn_recursive_add_processors(name, module, processors)
|
596 |
+
|
597 |
+
return processors
|
598 |
+
|
599 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
600 |
+
r"""
|
601 |
+
Parameters:
|
602 |
+
`processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
|
603 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
604 |
+
of **all** `Attention` layers.
|
605 |
+
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
|
606 |
+
|
607 |
+
"""
|
608 |
+
count = len(self.attn_processors.keys())
|
609 |
+
|
610 |
+
if isinstance(processor, dict) and len(processor) != count:
|
611 |
+
raise ValueError(
|
612 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
613 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
614 |
+
)
|
615 |
+
|
616 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
617 |
+
if hasattr(module, "set_processor"):
|
618 |
+
if not isinstance(processor, dict):
|
619 |
+
module.set_processor(processor)
|
620 |
+
else:
|
621 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
622 |
+
|
623 |
+
for sub_name, child in module.named_children():
|
624 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
625 |
+
|
626 |
+
for name, module in self.named_children():
|
627 |
+
fn_recursive_attn_processor(name, module, processor)
|
628 |
+
|
629 |
+
def set_default_attn_processor(self):
|
630 |
+
"""
|
631 |
+
Disables custom attention processors and sets the default attention implementation.
|
632 |
+
"""
|
633 |
+
self.set_attn_processor(AttnProcessor())
|
634 |
+
|
635 |
+
def set_attention_slice(self, slice_size):
|
636 |
+
r"""
|
637 |
+
Enable sliced attention computation.
|
638 |
+
|
639 |
+
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
640 |
+
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
641 |
+
|
642 |
+
Args:
|
643 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
644 |
+
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
645 |
+
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
|
646 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
647 |
+
must be a multiple of `slice_size`.
|
648 |
+
"""
|
649 |
+
sliceable_head_dims = []
|
650 |
+
|
651 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
652 |
+
if hasattr(module, "set_attention_slice"):
|
653 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
654 |
+
|
655 |
+
for child in module.children():
|
656 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
657 |
+
|
658 |
+
# retrieve number of attention layers
|
659 |
+
for module in self.children():
|
660 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
661 |
+
|
662 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
663 |
+
|
664 |
+
if slice_size == "auto":
|
665 |
+
# half the attention head size is usually a good trade-off between
|
666 |
+
# speed and memory
|
667 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
668 |
+
elif slice_size == "max":
|
669 |
+
# make smallest slice possible
|
670 |
+
slice_size = num_sliceable_layers * [1]
|
671 |
+
|
672 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
673 |
+
|
674 |
+
if len(slice_size) != len(sliceable_head_dims):
|
675 |
+
raise ValueError(
|
676 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
677 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
678 |
+
)
|
679 |
+
|
680 |
+
for i in range(len(slice_size)):
|
681 |
+
size = slice_size[i]
|
682 |
+
dim = sliceable_head_dims[i]
|
683 |
+
if size is not None and size > dim:
|
684 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
685 |
+
|
686 |
+
# Recursively walk through all the children.
|
687 |
+
# Any children which exposes the set_attention_slice method
|
688 |
+
# gets the message
|
689 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
690 |
+
if hasattr(module, "set_attention_slice"):
|
691 |
+
module.set_attention_slice(slice_size.pop())
|
692 |
+
|
693 |
+
for child in module.children():
|
694 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
695 |
+
|
696 |
+
reversed_slice_size = list(reversed(slice_size))
|
697 |
+
for module in self.children():
|
698 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
699 |
+
|
700 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
701 |
+
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
|
702 |
+
module.gradient_checkpointing = value
|
703 |
+
|
704 |
+
def forward(
|
705 |
+
self,
|
706 |
+
sample: torch.FloatTensor,
|
707 |
+
timestep: Union[torch.Tensor, float, int],
|
708 |
+
encoder_hidden_states: torch.Tensor,
|
709 |
+
class_labels: Optional[torch.Tensor] = None,
|
710 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
711 |
+
attention_mask: Optional[torch.Tensor] = None,
|
712 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
713 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
714 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
715 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
716 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
717 |
+
return_dict: bool = True,
|
718 |
+
return_cross_attention_probs: bool = False
|
719 |
+
) -> Union[UNet2DConditionOutput, Tuple]:
|
720 |
+
r"""
|
721 |
+
Args:
|
722 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
723 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
724 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
725 |
+
encoder_attention_mask (`torch.Tensor`):
|
726 |
+
(batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False =
|
727 |
+
discard. Mask will be converted into a bias, which adds large negative values to attention scores
|
728 |
+
corresponding to "discard" tokens.
|
729 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
730 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
731 |
+
cross_attention_kwargs (`dict`, *optional*):
|
732 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
733 |
+
`self.processor` in
|
734 |
+
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
735 |
+
added_cond_kwargs (`dict`, *optional*):
|
736 |
+
A kwargs dictionary that if specified includes additonal conditions that can be used for additonal time
|
737 |
+
embeddings or encoder hidden states projections. See the configurations `encoder_hid_dim_type` and
|
738 |
+
`addition_embed_type` for more information.
|
739 |
+
|
740 |
+
Returns:
|
741 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
742 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
743 |
+
returning a tuple, the first element is the sample tensor.
|
744 |
+
"""
|
745 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
746 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
747 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
748 |
+
# on the fly if necessary.
|
749 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
750 |
+
|
751 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
752 |
+
forward_upsample_size = False
|
753 |
+
upsample_size = None
|
754 |
+
|
755 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
756 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
757 |
+
forward_upsample_size = True
|
758 |
+
|
759 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
760 |
+
# expects mask of shape:
|
761 |
+
# [batch, key_tokens]
|
762 |
+
# adds singleton query_tokens dimension:
|
763 |
+
# [batch, 1, key_tokens]
|
764 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
765 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
766 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
767 |
+
if attention_mask is not None:
|
768 |
+
# assume that mask is expressed as:
|
769 |
+
# (1 = keep, 0 = discard)
|
770 |
+
# convert mask into a bias that can be added to attention scores:
|
771 |
+
# (keep = +0, discard = -10000.0)
|
772 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
773 |
+
attention_mask = attention_mask.unsqueeze(1)
|
774 |
+
|
775 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
776 |
+
if encoder_attention_mask is not None:
|
777 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
778 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
779 |
+
|
780 |
+
# 0. center input if necessary
|
781 |
+
if self.config.center_input_sample:
|
782 |
+
sample = 2 * sample - 1.0
|
783 |
+
|
784 |
+
# 1. time
|
785 |
+
timesteps = timestep
|
786 |
+
if not torch.is_tensor(timesteps):
|
787 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
788 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
789 |
+
is_mps = sample.device.type == "mps"
|
790 |
+
if isinstance(timestep, float):
|
791 |
+
dtype = torch.float32 if is_mps else torch.float64
|
792 |
+
else:
|
793 |
+
dtype = torch.int32 if is_mps else torch.int64
|
794 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
795 |
+
elif len(timesteps.shape) == 0:
|
796 |
+
timesteps = timesteps[None].to(sample.device)
|
797 |
+
|
798 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
799 |
+
timesteps = timesteps.expand(sample.shape[0])
|
800 |
+
|
801 |
+
t_emb = self.time_proj(timesteps)
|
802 |
+
|
803 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
804 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
805 |
+
# there might be better ways to encapsulate this.
|
806 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
807 |
+
|
808 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
809 |
+
|
810 |
+
if self.class_embedding is not None:
|
811 |
+
if class_labels is None:
|
812 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
813 |
+
|
814 |
+
if self.config.class_embed_type == "timestep":
|
815 |
+
class_labels = self.time_proj(class_labels)
|
816 |
+
|
817 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
818 |
+
# there might be better ways to encapsulate this.
|
819 |
+
class_labels = class_labels.to(dtype=sample.dtype)
|
820 |
+
|
821 |
+
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
|
822 |
+
|
823 |
+
if self.config.class_embeddings_concat:
|
824 |
+
emb = torch.cat([emb, class_emb], dim=-1)
|
825 |
+
else:
|
826 |
+
emb = emb + class_emb
|
827 |
+
|
828 |
+
if self.config.addition_embed_type == "text":
|
829 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
830 |
+
emb = emb + aug_emb
|
831 |
+
elif self.config.addition_embed_type == "text_image":
|
832 |
+
# Kadinsky 2.1 - style
|
833 |
+
if "image_embeds" not in added_cond_kwargs:
|
834 |
+
raise ValueError(
|
835 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
836 |
+
)
|
837 |
+
|
838 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
839 |
+
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
|
840 |
+
|
841 |
+
aug_emb = self.add_embedding(text_embs, image_embs)
|
842 |
+
emb = emb + aug_emb
|
843 |
+
|
844 |
+
if self.time_embed_act is not None:
|
845 |
+
emb = self.time_embed_act(emb)
|
846 |
+
|
847 |
+
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
|
848 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
849 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
|
850 |
+
# Kadinsky 2.1 - style
|
851 |
+
if "image_embeds" not in added_cond_kwargs:
|
852 |
+
raise ValueError(
|
853 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
854 |
+
)
|
855 |
+
|
856 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
857 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
|
858 |
+
|
859 |
+
# 2. pre-process
|
860 |
+
sample = self.conv_in(sample)
|
861 |
+
|
862 |
+
# 2.5 GLIGEN position net
|
863 |
+
if cross_attention_kwargs is not None and cross_attention_kwargs.get('gligen', None) is not None:
|
864 |
+
cross_attention_kwargs = cross_attention_kwargs.copy()
|
865 |
+
cross_attention_kwargs['gligen'] = {
|
866 |
+
'objs': self.position_net(
|
867 |
+
boxes=cross_attention_kwargs['gligen']['boxes'],
|
868 |
+
masks=cross_attention_kwargs['gligen']['masks'],
|
869 |
+
positive_embeddings=cross_attention_kwargs['gligen']['positive_embeddings']
|
870 |
+
),
|
871 |
+
'fuser_attn_kwargs': cross_attention_kwargs['gligen'].get('fuser_attn_kwargs', {})
|
872 |
+
}
|
873 |
+
|
874 |
+
# 3. down
|
875 |
+
down_block_res_samples = (sample,)
|
876 |
+
cross_attention_probs_down = []
|
877 |
+
if cross_attention_kwargs is None:
|
878 |
+
cross_attention_kwargs = {}
|
879 |
+
|
880 |
+
for i, downsample_block in enumerate(self.down_blocks):
|
881 |
+
cross_attention_kwargs["attn_key"] = ["down", i]
|
882 |
+
|
883 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
884 |
+
downsample_block_output = downsample_block(
|
885 |
+
hidden_states=sample,
|
886 |
+
temb=emb,
|
887 |
+
encoder_hidden_states=encoder_hidden_states,
|
888 |
+
attention_mask=attention_mask,
|
889 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
890 |
+
encoder_attention_mask=encoder_attention_mask,
|
891 |
+
return_cross_attention_probs=return_cross_attention_probs,
|
892 |
+
)
|
893 |
+
if return_cross_attention_probs:
|
894 |
+
sample, res_samples, cross_attention_probs = downsample_block_output
|
895 |
+
cross_attention_probs_down.append(cross_attention_probs)
|
896 |
+
else:
|
897 |
+
sample, res_samples = downsample_block_output
|
898 |
+
else:
|
899 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
900 |
+
|
901 |
+
down_block_res_samples += res_samples
|
902 |
+
|
903 |
+
if down_block_additional_residuals is not None:
|
904 |
+
new_down_block_res_samples = ()
|
905 |
+
|
906 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
907 |
+
down_block_res_samples, down_block_additional_residuals
|
908 |
+
):
|
909 |
+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
910 |
+
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
911 |
+
|
912 |
+
down_block_res_samples = new_down_block_res_samples
|
913 |
+
|
914 |
+
# 4. mid
|
915 |
+
cross_attention_probs_mid = []
|
916 |
+
if self.mid_block is not None:
|
917 |
+
cross_attention_kwargs["attn_key"] = ["mid", 0]
|
918 |
+
|
919 |
+
sample = self.mid_block(
|
920 |
+
sample,
|
921 |
+
emb,
|
922 |
+
encoder_hidden_states=encoder_hidden_states,
|
923 |
+
attention_mask=attention_mask,
|
924 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
925 |
+
encoder_attention_mask=encoder_attention_mask,
|
926 |
+
return_cross_attention_probs=return_cross_attention_probs,
|
927 |
+
)
|
928 |
+
if return_cross_attention_probs:
|
929 |
+
sample, cross_attention_probs = sample
|
930 |
+
cross_attention_probs_mid.append(cross_attention_probs)
|
931 |
+
|
932 |
+
|
933 |
+
if mid_block_additional_residual is not None:
|
934 |
+
sample = sample + mid_block_additional_residual
|
935 |
+
|
936 |
+
cross_attention_probs_up = []
|
937 |
+
# 5. up
|
938 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
939 |
+
cross_attention_kwargs["attn_key"] = ["up", i]
|
940 |
+
|
941 |
+
is_final_block = i == len(self.up_blocks) - 1
|
942 |
+
|
943 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
944 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
945 |
+
|
946 |
+
# if we have not reached the final block and need to forward the
|
947 |
+
# upsample size, we do it here
|
948 |
+
if not is_final_block and forward_upsample_size:
|
949 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
950 |
+
|
951 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
952 |
+
sample = upsample_block(
|
953 |
+
hidden_states=sample,
|
954 |
+
temb=emb,
|
955 |
+
res_hidden_states_tuple=res_samples,
|
956 |
+
encoder_hidden_states=encoder_hidden_states,
|
957 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
958 |
+
upsample_size=upsample_size,
|
959 |
+
attention_mask=attention_mask,
|
960 |
+
encoder_attention_mask=encoder_attention_mask,
|
961 |
+
return_cross_attention_probs=return_cross_attention_probs,
|
962 |
+
)
|
963 |
+
if return_cross_attention_probs:
|
964 |
+
sample, cross_attention_probs = sample
|
965 |
+
cross_attention_probs_up.append(cross_attention_probs)
|
966 |
+
else:
|
967 |
+
sample = upsample_block(
|
968 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
969 |
+
)
|
970 |
+
|
971 |
+
# 6. post-process
|
972 |
+
if self.conv_norm_out:
|
973 |
+
sample = self.conv_norm_out(sample)
|
974 |
+
sample = self.conv_act(sample)
|
975 |
+
sample = self.conv_out(sample)
|
976 |
+
|
977 |
+
if not return_dict:
|
978 |
+
return (sample,)
|
979 |
+
|
980 |
+
return UNet2DConditionOutput(sample=sample, cross_attention_probs_down=cross_attention_probs_down, cross_attention_probs_mid=cross_attention_probs_mid, cross_attention_probs_up=cross_attention_probs_up)
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu113
|
2 |
+
numpy
|
3 |
+
scipy
|
4 |
+
torch==2.0.0
|
5 |
+
diffusers==0.17.0
|
6 |
+
transformers==4.29.2
|
7 |
+
opencv-python==4.7.0.72
|
8 |
+
opencv-contrib-python==4.7.0.72
|
9 |
+
inflect==6.0.4
|
10 |
+
easydict
|
11 |
+
accelerate==0.18.0
|
shared.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from models import load_sd, sam
|
2 |
+
|
3 |
+
use_fp16 = False
|
4 |
+
use_dpm = True
|
5 |
+
|
6 |
+
sd_key = "gligen/diffusers-generation-text-box"
|
7 |
+
|
8 |
+
print(f"Using SD: {sd_key}")
|
9 |
+
model_dict = load_sd(key=sd_key, use_fp16=use_fp16, use_dpm_multistep_scheduler=use_dpm, load_inverse_scheduler=False)
|
10 |
+
|
11 |
+
sam_model_dict = sam.load_sam()
|
utils/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .utils import *
|
utils/latents.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from . import utils
|
4 |
+
from utils import torch_device
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
|
7 |
+
def get_unscaled_latents(batch_size, in_channels, height, width, generator, dtype):
|
8 |
+
"""
|
9 |
+
in_channels: often obtained with `unet.config.in_channels`
|
10 |
+
"""
|
11 |
+
# Obtain with torch.float32 and cast to float16 if needed
|
12 |
+
# Directly obtaining latents in float16 will lead to different latents
|
13 |
+
latents_base = torch.randn(
|
14 |
+
(batch_size, in_channels, height // 8, width // 8),
|
15 |
+
generator=generator, dtype=dtype
|
16 |
+
).to(torch_device, dtype=dtype)
|
17 |
+
|
18 |
+
return latents_base
|
19 |
+
|
20 |
+
def get_scaled_latents(batch_size, in_channels, height, width, generator, dtype, scheduler):
|
21 |
+
latents_base = get_unscaled_latents(batch_size, in_channels, height, width, generator, dtype)
|
22 |
+
latents_base = latents_base * scheduler.init_noise_sigma
|
23 |
+
return latents_base
|
24 |
+
|
25 |
+
def blend_latents(latents_bg, latents_fg, fg_mask, fg_blending_ratio=0.01):
|
26 |
+
"""
|
27 |
+
in_channels: often obtained with `unet.config.in_channels`
|
28 |
+
"""
|
29 |
+
assert not torch.allclose(latents_bg, latents_fg), "latents_bg should be independent with latents_fg"
|
30 |
+
|
31 |
+
dtype = latents_bg.dtype
|
32 |
+
latents = latents_bg * (1. - fg_mask) + (latents_bg * np.sqrt(1. - fg_blending_ratio) + latents_fg * np.sqrt(fg_blending_ratio)) * fg_mask
|
33 |
+
latents = latents.to(dtype=dtype)
|
34 |
+
|
35 |
+
return latents
|
36 |
+
|
37 |
+
@torch.no_grad()
|
38 |
+
def compose_latents(model_dict, latents_all_list, mask_tensor_list, num_inference_steps, overall_batch_size, height, width, latents_bg=None, bg_seed=None, compose_box_to_bg=True):
|
39 |
+
unet, scheduler, dtype = model_dict.unet, model_dict.scheduler, model_dict.dtype
|
40 |
+
|
41 |
+
if latents_bg is None:
|
42 |
+
generator = torch.manual_seed(bg_seed) # Seed generator to create the inital latent noise
|
43 |
+
latents_bg = get_scaled_latents(overall_batch_size, unet.config.in_channels, height, width, generator, dtype, scheduler)
|
44 |
+
|
45 |
+
# Other than t=T (idx=0), we only have masked latents. This is to prevent accidentally loading from non-masked part. Use same mask as the one used to compose the latents.
|
46 |
+
composed_latents = torch.zeros((num_inference_steps + 1, *latents_bg.shape), dtype=dtype)
|
47 |
+
composed_latents[0] = latents_bg
|
48 |
+
|
49 |
+
foreground_indices = torch.zeros(latents_bg.shape[-2:], dtype=torch.long)
|
50 |
+
|
51 |
+
mask_size = np.array([mask_tensor.sum().item() for mask_tensor in mask_tensor_list])
|
52 |
+
# Compose the largest mask first
|
53 |
+
mask_order = np.argsort(-mask_size)
|
54 |
+
|
55 |
+
if compose_box_to_bg:
|
56 |
+
# This has two functionalities:
|
57 |
+
# 1. copies the right initial latents from the right place (for centered so generation), 2. copies the right initial latents (since we have foreground blending) for centered/original so generation.
|
58 |
+
for mask_idx in mask_order:
|
59 |
+
latents_all, mask_tensor = latents_all_list[mask_idx], mask_tensor_list[mask_idx]
|
60 |
+
|
61 |
+
# Note: need to be careful to not copy from zeros due to shifting.
|
62 |
+
mask_tensor = utils.binary_mask_to_box_mask(mask_tensor, to_device=False)
|
63 |
+
|
64 |
+
mask_tensor_expanded = mask_tensor[None, None, None, ...].to(dtype)
|
65 |
+
composed_latents[0] = composed_latents[0] * (1. - mask_tensor_expanded) + latents_all[0] * mask_tensor_expanded
|
66 |
+
|
67 |
+
# This is still needed with `compose_box_to_bg` to ensure the foreground latent is still visible and to compute foreground indices.
|
68 |
+
for mask_idx in mask_order:
|
69 |
+
latents_all, mask_tensor = latents_all_list[mask_idx], mask_tensor_list[mask_idx]
|
70 |
+
foreground_indices = foreground_indices * (~mask_tensor) + (mask_idx + 1) * mask_tensor
|
71 |
+
mask_tensor_expanded = mask_tensor[None, None, None, ...].to(dtype)
|
72 |
+
composed_latents = composed_latents * (1. - mask_tensor_expanded) + latents_all * mask_tensor_expanded
|
73 |
+
|
74 |
+
composed_latents, foreground_indices = composed_latents.to(torch_device), foreground_indices.to(torch_device)
|
75 |
+
return composed_latents, foreground_indices
|
76 |
+
|
77 |
+
def align_with_bboxes(latents_all_list, mask_tensor_list, bboxes, horizontal_shift_only=False):
|
78 |
+
"""
|
79 |
+
Each offset in `offset_list` is `(x_offset, y_offset)` (normalized).
|
80 |
+
"""
|
81 |
+
new_latents_all_list, new_mask_tensor_list, offset_list = [], [], []
|
82 |
+
for latents_all, mask_tensor, bbox in zip(latents_all_list, mask_tensor_list, bboxes):
|
83 |
+
x_src_center, y_src_center = utils.binary_mask_to_center(mask_tensor, normalize=True)
|
84 |
+
x_min_dest, y_min_dest, x_max_dest, y_max_dest = bbox
|
85 |
+
x_dest_center, y_dest_center = (x_min_dest + x_max_dest) / 2, (y_min_dest + y_max_dest) / 2
|
86 |
+
# print("src (x,y):", x_src_center, y_src_center, "dest (x,y):", x_dest_center, y_dest_center)
|
87 |
+
x_offset, y_offset = x_dest_center - x_src_center, y_dest_center - y_src_center
|
88 |
+
if horizontal_shift_only:
|
89 |
+
y_offset = 0.
|
90 |
+
offset = x_offset, y_offset
|
91 |
+
latents_all = utils.shift_tensor(latents_all, x_offset, y_offset, offset_normalized=True)
|
92 |
+
mask_tensor = utils.shift_tensor(mask_tensor, x_offset, y_offset, offset_normalized=True)
|
93 |
+
new_latents_all_list.append(latents_all)
|
94 |
+
new_mask_tensor_list.append(mask_tensor)
|
95 |
+
offset_list.append(offset)
|
96 |
+
|
97 |
+
return new_latents_all_list, new_mask_tensor_list, offset_list
|
98 |
+
|
99 |
+
@torch.no_grad()
|
100 |
+
def compose_latents_with_alignment(
|
101 |
+
model_dict, latents_all_list, mask_tensor_list, num_inference_steps, overall_batch_size, height, width,
|
102 |
+
align_with_overall_bboxes=True, overall_bboxes=None, horizontal_shift_only=False, **kwargs
|
103 |
+
):
|
104 |
+
if align_with_overall_bboxes and len(latents_all_list):
|
105 |
+
expanded_overall_bboxes = utils.expand_overall_bboxes(overall_bboxes)
|
106 |
+
latents_all_list, mask_tensor_list, offset_list = align_with_bboxes(latents_all_list, mask_tensor_list, bboxes=expanded_overall_bboxes, horizontal_shift_only=horizontal_shift_only)
|
107 |
+
else:
|
108 |
+
offset_list = [(0., 0.) for _ in range(len(latents_all_list))]
|
109 |
+
composed_latents, foreground_indices = compose_latents(model_dict, latents_all_list, mask_tensor_list, num_inference_steps, overall_batch_size, height, width, **kwargs)
|
110 |
+
return composed_latents, foreground_indices, offset_list
|
111 |
+
|
112 |
+
def get_input_latents_list(model_dict, bg_seed, fg_seed_start, fg_blending_ratio, height, width, so_prompt_phrase_box_list=None, so_boxes=None, verbose=False):
|
113 |
+
"""
|
114 |
+
Note: the returned input latents are scaled by `scheduler.init_noise_sigma`
|
115 |
+
"""
|
116 |
+
unet, scheduler, dtype = model_dict.unet, model_dict.scheduler, model_dict.dtype
|
117 |
+
|
118 |
+
generator_bg = torch.manual_seed(bg_seed) # Seed generator to create the inital latent noise
|
119 |
+
latents_bg = get_unscaled_latents(batch_size=1, in_channels=unet.config.in_channels, height=height, width=width, generator=generator_bg, dtype=dtype)
|
120 |
+
|
121 |
+
input_latents_list = []
|
122 |
+
|
123 |
+
if so_boxes is None:
|
124 |
+
# For compatibility
|
125 |
+
so_boxes = [item[-1] for item in so_prompt_phrase_box_list]
|
126 |
+
|
127 |
+
# change this changes the foreground initial noise
|
128 |
+
for idx, obj_box in enumerate(so_boxes):
|
129 |
+
H, W = height // 8, width // 8
|
130 |
+
fg_mask = utils.proportion_to_mask(obj_box, H, W)
|
131 |
+
|
132 |
+
if verbose:
|
133 |
+
plt.imshow(fg_mask.cpu().numpy())
|
134 |
+
plt.show()
|
135 |
+
|
136 |
+
fg_seed = fg_seed_start + idx
|
137 |
+
if fg_seed == bg_seed:
|
138 |
+
# We should have different seeds for foreground and background
|
139 |
+
fg_seed += 12345
|
140 |
+
|
141 |
+
generator_fg = torch.manual_seed(fg_seed)
|
142 |
+
latents_fg = get_unscaled_latents(batch_size=1, in_channels=unet.config.in_channels, height=height, width=width, generator=generator_fg, dtype=dtype)
|
143 |
+
|
144 |
+
input_latents = blend_latents(latents_bg, latents_fg, fg_mask, fg_blending_ratio=fg_blending_ratio)
|
145 |
+
|
146 |
+
input_latents = input_latents * scheduler.init_noise_sigma
|
147 |
+
|
148 |
+
input_latents_list.append(input_latents)
|
149 |
+
|
150 |
+
return input_latents_list, latents_bg
|
151 |
+
|
utils/parse.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
from matplotlib.patches import Polygon
|
5 |
+
from matplotlib.collections import PatchCollection
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import numpy as np
|
8 |
+
import cv2
|
9 |
+
import inflect
|
10 |
+
|
11 |
+
p = inflect.engine()
|
12 |
+
|
13 |
+
img_dir = "imgs"
|
14 |
+
bg_prompt_text = "Background prompt: "
|
15 |
+
# h, w
|
16 |
+
box_scale = (512, 512)
|
17 |
+
size = box_scale
|
18 |
+
size_h, size_w = size
|
19 |
+
print(f"Using box scale: {box_scale}")
|
20 |
+
|
21 |
+
def parse_input(text=None, no_input=False):
|
22 |
+
if not text:
|
23 |
+
if no_input:
|
24 |
+
return
|
25 |
+
|
26 |
+
text = input("Enter the response: ")
|
27 |
+
if "Objects: " in text:
|
28 |
+
text = text.split("Objects: ")[1]
|
29 |
+
|
30 |
+
text_split = text.split(bg_prompt_text)
|
31 |
+
if len(text_split) == 2:
|
32 |
+
gen_boxes, bg_prompt = text_split
|
33 |
+
elif len(text_split) == 1:
|
34 |
+
if no_input:
|
35 |
+
return
|
36 |
+
gen_boxes = text
|
37 |
+
bg_prompt = ""
|
38 |
+
while not bg_prompt:
|
39 |
+
# Ignore the empty lines in the response
|
40 |
+
bg_prompt = input("Enter the background prompt: ").strip()
|
41 |
+
if bg_prompt_text in bg_prompt:
|
42 |
+
bg_prompt = bg_prompt.split(bg_prompt_text)[1]
|
43 |
+
else:
|
44 |
+
raise ValueError(f"text: {text}")
|
45 |
+
try:
|
46 |
+
gen_boxes = ast.literal_eval(gen_boxes)
|
47 |
+
except SyntaxError as e:
|
48 |
+
# Sometimes the response is in plain text
|
49 |
+
if "No objects" in gen_boxes:
|
50 |
+
gen_boxes = []
|
51 |
+
else:
|
52 |
+
raise e
|
53 |
+
bg_prompt = bg_prompt.strip()
|
54 |
+
|
55 |
+
return gen_boxes, bg_prompt
|
56 |
+
|
57 |
+
def filter_boxes(gen_boxes, scale_boxes=True, ignore_background=True, max_scale=3):
|
58 |
+
if len(gen_boxes) == 0:
|
59 |
+
return []
|
60 |
+
|
61 |
+
box_dict_format = False
|
62 |
+
gen_boxes_new = []
|
63 |
+
for gen_box in gen_boxes:
|
64 |
+
if isinstance(gen_box, dict):
|
65 |
+
name, [bbox_x, bbox_y, bbox_w, bbox_h] = gen_box['name'], gen_box['bounding_box']
|
66 |
+
box_dict_format = True
|
67 |
+
else:
|
68 |
+
name, [bbox_x, bbox_y, bbox_w, bbox_h] = gen_box
|
69 |
+
if bbox_w <= 0 or bbox_h <= 0:
|
70 |
+
# Empty boxes
|
71 |
+
continue
|
72 |
+
if ignore_background:
|
73 |
+
if (bbox_w >= size[1] and bbox_h >= size[0]) or bbox_x > size[1] or bbox_y > size[0]:
|
74 |
+
# Ignore the background boxes
|
75 |
+
continue
|
76 |
+
gen_boxes_new.append(gen_box)
|
77 |
+
|
78 |
+
gen_boxes = gen_boxes_new
|
79 |
+
|
80 |
+
if len(gen_boxes) == 0:
|
81 |
+
return []
|
82 |
+
|
83 |
+
filtered_gen_boxes = []
|
84 |
+
if box_dict_format:
|
85 |
+
# For compatibility
|
86 |
+
bbox_left_x_min = min([gen_box['bounding_box'][0] for gen_box in gen_boxes])
|
87 |
+
bbox_right_x_max = max([gen_box['bounding_box'][0] + gen_box['bounding_box'][2] for gen_box in gen_boxes])
|
88 |
+
bbox_top_y_min = min([gen_box['bounding_box'][1] for gen_box in gen_boxes])
|
89 |
+
bbox_bottom_y_max = max([gen_box['bounding_box'][1] + gen_box['bounding_box'][3] for gen_box in gen_boxes])
|
90 |
+
else:
|
91 |
+
bbox_left_x_min = min([gen_box[1][0] for gen_box in gen_boxes])
|
92 |
+
bbox_right_x_max = max([gen_box[1][0] + gen_box[1][2] for gen_box in gen_boxes])
|
93 |
+
bbox_top_y_min = min([gen_box[1][1] for gen_box in gen_boxes])
|
94 |
+
bbox_bottom_y_max = max([gen_box[1][1] + gen_box[1][3] for gen_box in gen_boxes])
|
95 |
+
|
96 |
+
# All boxes are empty
|
97 |
+
if (bbox_right_x_max - bbox_left_x_min) == 0:
|
98 |
+
return []
|
99 |
+
|
100 |
+
# Used if scale_boxes is True
|
101 |
+
shift = -bbox_left_x_min
|
102 |
+
scale = size_w / (bbox_right_x_max - bbox_left_x_min)
|
103 |
+
|
104 |
+
scale = min(scale, max_scale)
|
105 |
+
|
106 |
+
for gen_box in gen_boxes:
|
107 |
+
if box_dict_format:
|
108 |
+
name, [bbox_x, bbox_y, bbox_w, bbox_h] = gen_box['name'], gen_box['bounding_box']
|
109 |
+
else:
|
110 |
+
name, [bbox_x, bbox_y, bbox_w, bbox_h] = gen_box
|
111 |
+
|
112 |
+
if scale_boxes:
|
113 |
+
# Vertical: move the boxes if out of bound
|
114 |
+
# Horizontal: move and scale the boxes so it spans the horizontal line
|
115 |
+
|
116 |
+
bbox_x = (bbox_x + shift) * scale
|
117 |
+
bbox_y = bbox_y * scale
|
118 |
+
bbox_w, bbox_h = bbox_w * scale, bbox_h * scale
|
119 |
+
# TODO: verify this makes the y center not moving
|
120 |
+
bbox_y_offset = 0
|
121 |
+
if bbox_top_y_min * scale + bbox_y_offset < 0:
|
122 |
+
bbox_y_offset -= bbox_top_y_min * scale
|
123 |
+
if bbox_bottom_y_max * scale + bbox_y_offset >= size_h:
|
124 |
+
bbox_y_offset -= bbox_bottom_y_max * scale - size_h
|
125 |
+
bbox_y += bbox_y_offset
|
126 |
+
|
127 |
+
if bbox_y < 0:
|
128 |
+
bbox_y, bbox_h = 0, bbox_h - bbox_y
|
129 |
+
|
130 |
+
name = name.rstrip(".")
|
131 |
+
bounding_box = (int(np.round(bbox_x)), int(np.round(bbox_y)), int(np.round(bbox_w)), int(np.round(bbox_h)))
|
132 |
+
if box_dict_format:
|
133 |
+
gen_box = {
|
134 |
+
'name': name,
|
135 |
+
'bounding_box': bounding_box
|
136 |
+
}
|
137 |
+
else:
|
138 |
+
gen_box = (name, bounding_box)
|
139 |
+
|
140 |
+
filtered_gen_boxes.append(gen_box)
|
141 |
+
|
142 |
+
return filtered_gen_boxes
|
143 |
+
|
144 |
+
def draw_boxes(anns):
|
145 |
+
ax = plt.gca()
|
146 |
+
ax.set_autoscale_on(False)
|
147 |
+
polygons = []
|
148 |
+
color = []
|
149 |
+
for ann in anns:
|
150 |
+
c = (np.random.random((1, 3))*0.6+0.4)
|
151 |
+
[bbox_x, bbox_y, bbox_w, bbox_h] = ann['bbox']
|
152 |
+
poly = [[bbox_x, bbox_y], [bbox_x, bbox_y+bbox_h],
|
153 |
+
[bbox_x+bbox_w, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y]]
|
154 |
+
np_poly = np.array(poly).reshape((4, 2))
|
155 |
+
polygons.append(Polygon(np_poly))
|
156 |
+
color.append(c)
|
157 |
+
|
158 |
+
# print(ann)
|
159 |
+
name = ann['name'] if 'name' in ann else str(ann['category_id'])
|
160 |
+
ax.text(bbox_x, bbox_y, name, style='italic',
|
161 |
+
bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 5})
|
162 |
+
|
163 |
+
p = PatchCollection(polygons, facecolor='none',
|
164 |
+
edgecolors=color, linewidths=2)
|
165 |
+
ax.add_collection(p)
|
166 |
+
|
167 |
+
|
168 |
+
def show_boxes(gen_boxes, bg_prompt=None, ind=None, show=False):
|
169 |
+
if len(gen_boxes) == 0:
|
170 |
+
return
|
171 |
+
|
172 |
+
if isinstance(gen_boxes[0], dict):
|
173 |
+
anns = [{'name': gen_box['name'], 'bbox': gen_box['bounding_box']}
|
174 |
+
for gen_box in gen_boxes]
|
175 |
+
else:
|
176 |
+
anns = [{'name': gen_box[0], 'bbox': gen_box[1]} for gen_box in gen_boxes]
|
177 |
+
|
178 |
+
# White background (to allow line to show on the edge)
|
179 |
+
I = np.ones((size[0]+4, size[1]+4, 3), dtype=np.uint8) * 255
|
180 |
+
|
181 |
+
plt.imshow(I)
|
182 |
+
plt.axis('off')
|
183 |
+
|
184 |
+
if bg_prompt is not None:
|
185 |
+
ax = plt.gca()
|
186 |
+
ax.text(0, 0, bg_prompt, style='italic',
|
187 |
+
bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 5})
|
188 |
+
|
189 |
+
c = (np.zeros((1, 3)))
|
190 |
+
[bbox_x, bbox_y, bbox_w, bbox_h] = (0, 0, size[1], size[0])
|
191 |
+
poly = [[bbox_x, bbox_y], [bbox_x, bbox_y+bbox_h],
|
192 |
+
[bbox_x+bbox_w, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y]]
|
193 |
+
np_poly = np.array(poly).reshape((4, 2))
|
194 |
+
polygons = [Polygon(np_poly)]
|
195 |
+
color = [c]
|
196 |
+
p = PatchCollection(polygons, facecolor='none',
|
197 |
+
edgecolors=color, linewidths=2)
|
198 |
+
ax.add_collection(p)
|
199 |
+
|
200 |
+
draw_boxes(anns)
|
201 |
+
if show:
|
202 |
+
plt.show()
|
203 |
+
else:
|
204 |
+
print("Saved to", f"{img_dir}/boxes.png", f"ind: {ind}")
|
205 |
+
if ind is not None:
|
206 |
+
plt.savefig(f"{img_dir}/boxes_{ind}.png")
|
207 |
+
plt.savefig(f"{img_dir}/boxes.png")
|
208 |
+
|
209 |
+
|
210 |
+
def show_masks(masks):
|
211 |
+
masks_to_show = np.zeros((*size, 3), dtype=np.float32)
|
212 |
+
for mask in masks:
|
213 |
+
c = (np.random.random((3,))*0.6+0.4)
|
214 |
+
|
215 |
+
masks_to_show += mask[..., None] * c[None, None, :]
|
216 |
+
plt.imshow(masks_to_show)
|
217 |
+
plt.savefig(f"{img_dir}/masks.png")
|
218 |
+
plt.show()
|
219 |
+
plt.clf()
|
220 |
+
|
221 |
+
def convert_box(box, height, width):
|
222 |
+
# box: x, y, w, h (in 512 format) -> x_min, y_min, x_max, y_max
|
223 |
+
x_min, y_min = box[0] / width, box[1] / height
|
224 |
+
w_box, h_box = box[2] / width, box[3] / height
|
225 |
+
|
226 |
+
x_max, y_max = x_min + w_box, y_min + h_box
|
227 |
+
|
228 |
+
return x_min, y_min, x_max, y_max
|
229 |
+
|
230 |
+
def convert_spec(spec, height, width, include_counts=True, verbose=False):
|
231 |
+
# Infer from spec
|
232 |
+
prompt, gen_boxes, bg_prompt = spec['prompt'], spec['gen_boxes'], spec['bg_prompt']
|
233 |
+
|
234 |
+
# This ensures the same objects appear together because flattened `overall_phrases_bboxes` should EXACTLY correspond to `so_prompt_phrase_box_list`.
|
235 |
+
gen_boxes = sorted(gen_boxes, key=lambda gen_box: gen_box[0])
|
236 |
+
|
237 |
+
gen_boxes = [(name, convert_box(box, height=height, width=width)) for name, box in gen_boxes]
|
238 |
+
|
239 |
+
# NOTE: so phrase should include all the words associated to the object (otherwise "an orange dog" may be recognized as "an orange" by the model generating the background).
|
240 |
+
# so word should have one token that includes the word to transfer cross attention (the object name).
|
241 |
+
# Currently using the last word of the object name as word.
|
242 |
+
if bg_prompt:
|
243 |
+
so_prompt_phrase_word_box_list = [(f"{bg_prompt} with {name}", name, name.split(" ")[-1], box) for name, box in gen_boxes]
|
244 |
+
else:
|
245 |
+
so_prompt_phrase_word_box_list = [(f"{name}", name, name.split(" ")[-1], box) for name, box in gen_boxes]
|
246 |
+
|
247 |
+
objects = [gen_box[0] for gen_box in gen_boxes]
|
248 |
+
|
249 |
+
objects_unique, objects_count = np.unique(objects, return_counts=True)
|
250 |
+
|
251 |
+
num_total_matched_boxes = 0
|
252 |
+
overall_phrases_words_bboxes = []
|
253 |
+
for ind, object_name in enumerate(objects_unique):
|
254 |
+
bboxes = [box for name, box in gen_boxes if name == object_name]
|
255 |
+
|
256 |
+
if objects_count[ind] > 1:
|
257 |
+
phrase = p.plural_noun(object_name.replace("an ", "").replace("a ", ""))
|
258 |
+
if include_counts:
|
259 |
+
phrase = p.number_to_words(objects_count[ind]) + " " + phrase
|
260 |
+
else:
|
261 |
+
phrase = object_name
|
262 |
+
# Currently using the last word of the phrase as word.
|
263 |
+
word = phrase.split(' ')[-1]
|
264 |
+
|
265 |
+
num_total_matched_boxes += len(bboxes)
|
266 |
+
overall_phrases_words_bboxes.append((phrase, word, bboxes))
|
267 |
+
|
268 |
+
assert num_total_matched_boxes == len(gen_boxes), f"{num_total_matched_boxes} != {len(gen_boxes)}"
|
269 |
+
|
270 |
+
objects_str = ", ".join([phrase for phrase, _, _ in overall_phrases_words_bboxes])
|
271 |
+
if objects_str:
|
272 |
+
if bg_prompt:
|
273 |
+
overall_prompt = f"{bg_prompt} with {objects_str}"
|
274 |
+
else:
|
275 |
+
overall_prompt = objects_str
|
276 |
+
else:
|
277 |
+
overall_prompt = bg_prompt
|
278 |
+
|
279 |
+
if verbose:
|
280 |
+
print("so_prompt_phrase_word_box_list:", so_prompt_phrase_word_box_list)
|
281 |
+
print("overall_prompt:", overall_prompt)
|
282 |
+
print("overall_phrases_words_bboxes:", overall_phrases_words_bboxes)
|
283 |
+
|
284 |
+
return so_prompt_phrase_word_box_list, overall_prompt, overall_phrases_words_bboxes
|
utils/utils.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import ImageDraw
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import gc
|
6 |
+
|
7 |
+
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
8 |
+
|
9 |
+
def draw_box(pil_img, bboxes, phrases):
|
10 |
+
draw = ImageDraw.Draw(pil_img)
|
11 |
+
# font = ImageFont.truetype('./FreeMono.ttf', 25)
|
12 |
+
|
13 |
+
for obj_bbox, phrase in zip(bboxes, phrases):
|
14 |
+
x_0, y_0, x_1, y_1 = obj_bbox[0], obj_bbox[1], obj_bbox[2], obj_bbox[3]
|
15 |
+
draw.rectangle([int(x_0 * 512), int(y_0 * 512), int(x_1 * 512), int(y_1 * 512)], outline='red', width=5)
|
16 |
+
draw.text((int(x_0 * 512) + 5, int(y_0 * 512) + 5), phrase, font=None, fill=(255, 0, 0))
|
17 |
+
|
18 |
+
return pil_img
|
19 |
+
|
20 |
+
def get_centered_box(box, horizontal_center_only=True):
|
21 |
+
x_min, y_min, x_max, y_max = box
|
22 |
+
w = x_max - x_min
|
23 |
+
|
24 |
+
if horizontal_center_only:
|
25 |
+
return [0.5 - w/2, y_min, 0.5 + w/2, y_max]
|
26 |
+
|
27 |
+
h = y_max - y_min
|
28 |
+
|
29 |
+
return [0.5 - w/2, 0.5 - h/2, 0.5 + w/2, 0.5 + h/2]
|
30 |
+
|
31 |
+
# NOTE: this changes the behavior of the function
|
32 |
+
def proportion_to_mask(obj_box, H, W, use_legacy=False, return_np=False):
|
33 |
+
x_min, y_min, x_max, y_max = scale_proportion(obj_box, H, W, use_legacy)
|
34 |
+
if return_np:
|
35 |
+
mask = np.zeros((H, W))
|
36 |
+
else:
|
37 |
+
mask = torch.zeros(H, W).to(torch_device)
|
38 |
+
mask[y_min: y_max, x_min: x_max] = 1.
|
39 |
+
|
40 |
+
return mask
|
41 |
+
|
42 |
+
def scale_proportion(obj_box, H, W, use_legacy=False):
|
43 |
+
if use_legacy:
|
44 |
+
# Bias towards the top-left corner
|
45 |
+
x_min, y_min, x_max, y_max = int(obj_box[0] * W), int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
|
46 |
+
else:
|
47 |
+
# Separately rounding box_w and box_h to allow shift invariant box sizes. Otherwise box sizes may change when both coordinates being rounded end with ".5".
|
48 |
+
x_min, y_min = round(obj_box[0] * W), round(obj_box[1] * H)
|
49 |
+
box_w, box_h = round((obj_box[2] - obj_box[0]) * W), round((obj_box[3] - obj_box[1]) * H)
|
50 |
+
x_max, y_max = x_min + box_w, y_min + box_h
|
51 |
+
|
52 |
+
x_min, y_min = max(x_min, 0), max(y_min, 0)
|
53 |
+
x_max, y_max = min(x_max, W), min(y_max, H)
|
54 |
+
|
55 |
+
return x_min, y_min, x_max, y_max
|
56 |
+
|
57 |
+
def binary_mask_to_box(mask, enlarge_box_by_one=True, w_scale=1, h_scale=1):
|
58 |
+
if isinstance(mask, torch.Tensor):
|
59 |
+
mask_loc = torch.where(mask)
|
60 |
+
else:
|
61 |
+
mask_loc = np.where(mask)
|
62 |
+
height, width = mask.shape
|
63 |
+
if len(mask_loc) == 0:
|
64 |
+
raise ValueError('The mask is empty')
|
65 |
+
if enlarge_box_by_one:
|
66 |
+
ymin, ymax = max(min(mask_loc[0]) - 1, 0), min(max(mask_loc[0]) + 1, height)
|
67 |
+
xmin, xmax = max(min(mask_loc[1]) - 1, 0), min(max(mask_loc[1]) + 1, width)
|
68 |
+
else:
|
69 |
+
ymin, ymax = min(mask_loc[0]), max(mask_loc[0])
|
70 |
+
xmin, xmax = min(mask_loc[1]), max(mask_loc[1])
|
71 |
+
box = [xmin * w_scale, ymin * h_scale, xmax * w_scale, ymax * h_scale]
|
72 |
+
|
73 |
+
return box
|
74 |
+
|
75 |
+
def binary_mask_to_box_mask(mask, to_device=True):
|
76 |
+
box = binary_mask_to_box(mask)
|
77 |
+
x_min, y_min, x_max, y_max = box
|
78 |
+
|
79 |
+
H, W = mask.shape
|
80 |
+
mask = torch.zeros(H, W)
|
81 |
+
if to_device:
|
82 |
+
mask = mask.to(torch_device)
|
83 |
+
mask[y_min: y_max+1, x_min: x_max+1] = 1.
|
84 |
+
|
85 |
+
return mask
|
86 |
+
|
87 |
+
def binary_mask_to_center(mask, normalize=False):
|
88 |
+
"""
|
89 |
+
This computes the mass center of the mask.
|
90 |
+
normalize: the coords range from 0 to 1
|
91 |
+
|
92 |
+
Reference: https://stackoverflow.com/a/66184125
|
93 |
+
"""
|
94 |
+
h, w = mask.shape
|
95 |
+
|
96 |
+
total = mask.sum()
|
97 |
+
if isinstance(mask, torch.Tensor):
|
98 |
+
x_coord = ((mask.sum(dim=0) @ torch.arange(w)) / total).item()
|
99 |
+
y_coord = ((mask.sum(dim=1) @ torch.arange(h)) / total).item()
|
100 |
+
else:
|
101 |
+
x_coord = (mask.sum(axis=0) @ np.arange(w)) / total
|
102 |
+
y_coord = (mask.sum(axis=1) @ np.arange(h)) / total
|
103 |
+
|
104 |
+
if normalize:
|
105 |
+
x_coord, y_coord = x_coord / w, y_coord / h
|
106 |
+
return x_coord, y_coord
|
107 |
+
|
108 |
+
|
109 |
+
def iou(mask, masks, eps=1e-6):
|
110 |
+
# mask: [h, w], masks: [n, h, w]
|
111 |
+
mask = mask[None].astype(bool)
|
112 |
+
masks = masks.astype(bool)
|
113 |
+
i = (mask & masks).sum(axis=(1,2))
|
114 |
+
u = (mask | masks).sum(axis=(1,2))
|
115 |
+
|
116 |
+
return i / (u + eps)
|
117 |
+
|
118 |
+
def free_memory():
|
119 |
+
gc.collect()
|
120 |
+
torch.cuda.empty_cache()
|
121 |
+
|
122 |
+
def expand_overall_bboxes(overall_bboxes):
|
123 |
+
"""
|
124 |
+
Expand overall bboxes from a 3d list to 2d list:
|
125 |
+
Input: [[box 1 for phrase 1, box 2 for phrase 1], ...]
|
126 |
+
Output: [box 1, box 2, ...]
|
127 |
+
"""
|
128 |
+
return sum(overall_bboxes, start=[])
|
129 |
+
|
130 |
+
def shift_tensor(tensor, x_offset, y_offset, base_w=8, base_h=8, offset_normalized=False, ignore_last_dim=False):
|
131 |
+
"""base_w and base_h: make sure the shift is aligned in the latent and multiple levels of cross attention"""
|
132 |
+
if ignore_last_dim:
|
133 |
+
tensor_h, tensor_w = tensor.shape[-3:-1]
|
134 |
+
else:
|
135 |
+
tensor_h, tensor_w = tensor.shape[-2:]
|
136 |
+
if offset_normalized:
|
137 |
+
assert tensor_h % base_h == 0 and tensor_w % base_w == 0, f"{tensor_h, tensor_w} is not a multiple of {base_h, base_w}"
|
138 |
+
scale_from_base_h, scale_from_base_w = tensor_h // base_h, tensor_w // base_w
|
139 |
+
x_offset, y_offset = round(x_offset * base_w) * scale_from_base_w, round(y_offset * base_h) * scale_from_base_h
|
140 |
+
new_tensor = torch.zeros_like(tensor)
|
141 |
+
|
142 |
+
overlap_w = tensor_w - abs(x_offset)
|
143 |
+
overlap_h = tensor_h - abs(y_offset)
|
144 |
+
|
145 |
+
if y_offset >= 0:
|
146 |
+
y_src_start = 0
|
147 |
+
y_dest_start = y_offset
|
148 |
+
else:
|
149 |
+
y_src_start = -y_offset
|
150 |
+
y_dest_start = 0
|
151 |
+
|
152 |
+
if x_offset >= 0:
|
153 |
+
x_src_start = 0
|
154 |
+
x_dest_start = x_offset
|
155 |
+
else:
|
156 |
+
x_src_start = -x_offset
|
157 |
+
x_dest_start = 0
|
158 |
+
|
159 |
+
if ignore_last_dim:
|
160 |
+
# For cross attention maps, the third to last and the second to last are the 2D dimensions after unflatten.
|
161 |
+
new_tensor[..., y_dest_start:y_dest_start+overlap_h, x_dest_start:x_dest_start+overlap_w, :] = tensor[..., y_src_start:y_src_start+overlap_h, x_src_start:x_src_start+overlap_w, :]
|
162 |
+
else:
|
163 |
+
new_tensor[..., y_dest_start:y_dest_start+overlap_h, x_dest_start:x_dest_start+overlap_w] = tensor[..., y_src_start:y_src_start+overlap_h, x_src_start:x_src_start+overlap_w]
|
164 |
+
|
165 |
+
return new_tensor
|