FANG DAI
committed on
Upload 126 files
This view is limited to 50 files because it contains too many changes.
- Tiger Model/Coarse-Training.py +947 -0
- Tiger Model/Fine-Training.py +1246 -0
- Tiger Model/GP.py +266 -0
- Tiger Model/IS.py +109 -0
- Tiger Model/diffusiers-Tiger/CLIPTextModel.py +1326 -0
- Tiger Model/diffusiers-Tiger/__init__.py +293 -0
- Tiger Model/diffusiers-Tiger/__pycache__/__init__.cpython-38.pyc +0 -0
- Tiger Model/diffusiers-Tiger/__pycache__/configuration_utils.cpython-38.pyc +0 -0
- Tiger Model/diffusiers-Tiger/__pycache__/fuse.cpython-38.pyc +0 -0
- Tiger Model/diffusiers-Tiger/__pycache__/image_processor.cpython-38.pyc +0 -0
- Tiger Model/diffusiers-Tiger/__pycache__/loaders.cpython-38.pyc +0 -0
- Tiger Model/diffusiers-Tiger/__pycache__/optimization.cpython-38.pyc +0 -0
- Tiger Model/diffusiers-Tiger/__pycache__/training_utils.cpython-38.pyc +0 -0
- Tiger Model/diffusiers-Tiger/commands/__init__.py +27 -0
- Tiger Model/diffusiers-Tiger/commands/diffusers_cli.py +43 -0
- Tiger Model/diffusiers-Tiger/commands/env.py +84 -0
- Tiger Model/diffusiers-Tiger/commands/fp16_safetensors.py +133 -0
- Tiger Model/diffusiers-Tiger/configuration_utils.py +686 -0
- Tiger Model/diffusiers-Tiger/dependency_versions_check.py +47 -0
- Tiger Model/diffusiers-Tiger/dependency_versions_table.py +44 -0
- Tiger Model/diffusiers-Tiger/fuse.py +175 -0
- Tiger Model/diffusiers-Tiger/getWeight.py +88 -0
- Tiger Model/diffusiers-Tiger/image_processor.py +366 -0
- Tiger Model/diffusiers-Tiger/loaders.py +0 -0
- Tiger Model/diffusiers-Tiger/models/README.md +3 -0
- Tiger Model/diffusiers-Tiger/models/__init__.py +39 -0
- Tiger Model/diffusiers-Tiger/models/activations.py +14 -0
- Tiger Model/diffusiers-Tiger/models/adapter.py +291 -0
- Tiger Model/diffusiers-Tiger/models/attention.py +437 -0
- Tiger Model/diffusiers-Tiger/models/attention_processor.py +1716 -0
- Tiger Model/diffusiers-Tiger/models/autoencoder_asym_kl.py +180 -0
- Tiger Model/diffusiers-Tiger/models/autoencoder_kl.py +417 -0
- Tiger Model/diffusiers-Tiger/models/autoencoder_tiny.py +342 -0
- Tiger Model/diffusiers-Tiger/models/controlnet.py +762 -0
- Tiger Model/diffusiers-Tiger/models/dual_transformer_2d.py +151 -0
- Tiger Model/diffusiers-Tiger/models/embeddings.py +602 -0
- Tiger Model/diffusiers-Tiger/models/lora.py +117 -0
- Tiger Model/diffusiers-Tiger/models/modeling_utils.py +997 -0
- Tiger Model/diffusiers-Tiger/models/prior_transformer.py +364 -0
- Tiger Model/diffusiers-Tiger/models/resnet.py +878 -0
- Tiger Model/diffusiers-Tiger/models/t5_film_transformer.py +321 -0
- Tiger Model/diffusiers-Tiger/models/transformer_2d.py +359 -0
- Tiger Model/diffusiers-Tiger/models/transformer_temporal.py +179 -0
- Tiger Model/diffusiers-Tiger/models/unet_1d.py +255 -0
- Tiger Model/diffusiers-Tiger/models/unet_1d_blocks.py +656 -0
- Tiger Model/diffusiers-Tiger/models/unet_2d.py +329 -0
- Tiger Model/diffusiers-Tiger/models/unet_2d_blocks.py +0 -0
- Tiger Model/diffusiers-Tiger/models/unet_2d_condition.py +1009 -0
- Tiger Model/diffusiers-Tiger/models/unet_3d_blocks.py +679 -0
- Tiger Model/diffusiers-Tiger/models/unet_3d_condition.py +627 -0
Tiger Model/Coarse-Training.py
ADDED
@@ -0,0 +1,947 @@
# coding=utf-8
# Copyright 2024 Hui Lu, Fang Dai, Siqiong Yao.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
import math
import os
import random
import shutil
from pathlib import Path

import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from packaging import version
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, DPMSolverMultistepScheduler
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
import warnings
warnings.filterwarnings('ignore')

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.


logger = get_logger(__name__, log_level="INFO")

def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
    img_str = ""
    for i, image in enumerate(images):
        image.save(os.path.join(repo_folder, f"image_{i}.png"))
        img_str += f"\n"

    yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- lora
inference: true
---
    """
    model_card = f"""
# LoRA text2image fine-tuning - {repo_id}
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
{img_str}
"""
    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(yaml + model_card)

def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--image_column", type=str, default="image", help="The column of the dataset containing an image."
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
    parser.add_argument(
        "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=1,
        help=(
            "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`."
        ),
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-model-finetuned-lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--random_flip",
        action="store_true",
        help="whether to randomly flip images horizontally",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--snr_gamma",
        type=float,
        default=None,
        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
        "More details here: https://arxiv.org/abs/2303.09556.",
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")

    parser.add_argument(
        "--prediction_type",
        type=str,
        default=None,
        help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
    )
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
    parser.add_argument(
        "--rank",
        type=int,
        default=4,
        help=("The dimension of the LoRA update matrices."),
    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Sanity checks
    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("Need either a dataset name or a training folder.")

    return args


DATASET_NAME_MAPPING = {
    "lambdalabs/pokemon-blip-captions": ("image", "text"),
}

def main():
    args = parse_args()
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )
    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load scheduler, tokenizer and models.
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
    )
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    )
    # freeze parameters of models to save more memory
    unet.requires_grad_(False)
    vae.requires_grad_(False)

    text_encoder.requires_grad_(False)

    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move unet, vae and text_encoder to device and cast to weight_dtype
    unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    # now we will add new LoRA weights to the attention layers
    # It's important to realize here how many attention weights will be added and of which sizes
    # The sizes of the attention layers consist only of two different variables:
    # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
    # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.

    # Let's first see how many attention processors we will have to set.
    # For Stable Diffusion, it should be equal to:
    # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
    # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
    # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
    # => 32 layers

    # Set correct lora layers
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

        lora_attn_procs[name] = LoRAAttnProcessor(
            hidden_size=hidden_size,
            cross_attention_dim=cross_attention_dim,
            rank=args.rank,
        )

    unet.set_attn_processor(lora_attn_procs)


    def compute_snr(timesteps):
        """
        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
        """
        alphas_cumprod = noise_scheduler.alphas_cumprod
        sqrt_alphas_cumprod = alphas_cumprod**0.5
        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

        # Expand the tensors.
        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
        while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
        alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
        while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
        sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

        # Compute SNR.
        snr = (alpha / sigma) ** 2
        return snr

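    # In terms of the quantities above, compute_snr(timesteps) returns
    #     SNR(t) = (alpha_t / sigma_t) ** 2 = alphas_cumprod[t] / (1 - alphas_cumprod[t]).
    # When --snr_gamma is set, the training loop below rebalances the per-sample MSE loss with
    #     weight(t) = min(SNR(t), snr_gamma) / SNR(t)
    # following the Min-SNR weighting of https://arxiv.org/abs/2303.09556; otherwise a plain
    # mean MSE loss is used.
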
    lora_layers = AttnProcsLayers(unet.attn_processors)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        lora_layers.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
            data_dir=args.train_data_dir,
        )
    else:
        data_files = {}
        if args.train_data_dir is not None:
            data_files["train"] = os.path.join(args.train_data_dir, "**")
        dataset = load_dataset(
            "imagefolder",
            data_files=data_files,
            cache_dir=args.cache_dir,
        )
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names

    # 6. Get the column names for input/target.
    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
    if args.image_column is None:
        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        image_column = args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
            )
    if args.caption_column is None:
        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
            )

    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.
    def tokenize_captions(examples, is_train=True):
        captions = []
        for caption in examples[caption_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    # Preprocessing the datasets.
    train_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        examples["pixel_values"] = [train_transforms(image) for image in images]
        examples["input_ids"] = tokenize_captions(examples)
        return examples

    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
        # Set the training transforms
        train_dataset = dataset["train"].with_transform(preprocess_train)

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        input_ids = torch.stack([example["input_ids"] for example in examples])
        return {"pixel_values": pixel_values, "input_ids": input_ids}

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
    )

    # Prepare everything with our `accelerator`.
    lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        lora_layers, optimizer, train_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("text2image-fine-tune", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])
            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

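    # Each inner-loop step below: encode the batch to VAE latents, draw a random timestep t and
    # Gaussian noise, form the noisy latent with noise_scheduler.add_noise
    #     (x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise),
    # then train the LoRA-augmented UNet to predict the scheduler's target ("epsilon" or
    # "v_prediction") under the MSE loss described above.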
    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                if args.noise_offset:
                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
                    noise += args.noise_offset * torch.randn(
                        (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
                    )

                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Get the target for loss depending on the prediction type
                if args.prediction_type is not None:
                    # set prediction_type of scheduler if defined
                    noise_scheduler.register_to_config(prediction_type=args.prediction_type)

                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                # Predict the noise residual and compute loss
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                if args.snr_gamma is None:
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                else:
                    # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                    # This is discussed in Section 4.2 of the same paper.
                    snr = compute_snr(timesteps)
                    mse_loss_weights = (
                        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                    )
                    # We first calculate the original loss. Then we mean over the non-batch dimensions and
                    # rebalance the sample-wise losses with their respective loss weights.
                    # Finally, we take the mean of the rebalanced loss.
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                    loss = loss.mean()

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = lora_layers.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        logger.info(f"Saved state to {save_path}")

                        unet = unet.to(torch.float32)
                        unet.save_attn_procs(save_path)

                        # create pipeline
                        # pipeline = DiffusionPipeline.from_pretrained(
                        #     args.pretrained_model_name_or_path,
                        #     unet=accelerator.unwrap_model(unet),
                        #     revision=args.revision,
                        #     torch_dtype=weight_dtype,
                        # )
                        # pipeline = pipeline.to(accelerator.device)
                        # pipeline.set_progress_bar_config(disable=True)

                        # # run inference
                        # generator = torch.Generator(device=accelerator.device)

                        # images = []
                        # for i in range(args.num_validation_images):
                        #     if args.seed is not None:
                        #         generator = generator.manual_seed(args.seed + i + args.checkpointing_steps)
                        #     images.append(
                        #         pipeline(args.validation_prompt, num_inference_steps=30, generator=generator, guidance_scale=7).images[0]
                        #     )

                        if args.validation_prompt is not None:
                            logger.info(
                                f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                                f" {args.validation_prompt}."
                            )
                            print()
                            # create pipeline
                            pipeline = DiffusionPipeline.from_pretrained(
                                args.pretrained_model_name_or_path,
                                unet=accelerator.unwrap_model(unet),
                                revision=args.revision,
                                torch_dtype=weight_dtype,
                            )
                            pipeline = pipeline.to(accelerator.device)
                            pipeline.set_progress_bar_config(disable=True)
                            pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
                            # run inference
                            generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
                            images = []
                            for _ in range(args.num_validation_images):
                                images.append(
                                    pipeline(
                                        args.validation_prompt,
                                        height=args.resolution,
                                        width=args.resolution,
                                        num_inference_steps=45,
                                        generator=generator
                                    ).images[0]
                                )
                            image2 = pipeline(
                                'High quality photo of an astronaut riding a horse in space',
                                guidance_scale=7,
                                height=args.resolution,
                                width=args.resolution,
                                num_inference_steps=45
                            ).images[0]
                            images.append(image2)

                            for tracker in accelerator.trackers:
                                if tracker.name == "tensorboard":
                                    np_images = np.stack([np.asarray(img) for img in images])
                                    tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
                                if tracker.name == "wandb":
                                    tracker.log(
                                        {
                                            "validation": [
                                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                                                for i, image in enumerate(images)
                                            ]
                                        }
                                    )

                            del pipeline
                            torch.cuda.empty_cache()

    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = unet.to(torch.float32)
        unet.save_attn_procs(args.output_dir)


    # Final inference
    # Load previous pipeline
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
    )
    pipeline = pipeline.to(accelerator.device)

    # load attention processors
    pipeline.unet.load_attn_procs(args.output_dir)

    # run inference
    generator = torch.Generator(device=accelerator.device)
    if args.seed is not None:
        generator = generator.manual_seed(args.seed)
    images = []
    for _ in range(args.num_validation_images):
        images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])

    if accelerator.is_main_process:
        for tracker in accelerator.trackers:
            if len(images) != 0:
                if tracker.name == "tensorboard":
                    np_images = np.stack([np.asarray(img) for img in images])
                    tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
                if tracker.name == "wandb":
                    tracker.log(
                        {
                            "test": [
                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                                for i, image in enumerate(images)
                            ]
                        }
                    )

    accelerator.end_training()


if __name__ == "__main__":
    main()
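Coarse-Training.py stores the trained LoRA attention processors with unet.save_attn_procs(args.output_dir) and reloads them in its final-inference block through pipeline.unet.load_attn_procs. A minimal standalone sketch of that reload path follows; the base-model identifier and the prompt are placeholders (the directory shown is just the script's default --output_dir), so treat it as an illustration rather than part of the commit:

import torch
from diffusers import DiffusionPipeline

base_model = "path-or-hub-id-of-the-base-model"  # placeholder: whatever was passed as --pretrained_model_name_or_path
lora_dir = "sd-model-finetuned-lora"  # the script's default --output_dir

pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16).to("cuda")
pipe.unet.load_attn_procs(lora_dir)  # attach the LoRA attention processors saved during coarse training
image = pipe("a placeholder validation prompt", num_inference_steps=30).images[0]
image.save("lora_sample.png")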
Tiger Model/Fine-Training.py
ADDED
@@ -0,0 +1,1246 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2024 Hui Lu, Fang Dai, Siqiong Yao.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import argparse
|
18 |
+
import logging
|
19 |
+
import math
|
20 |
+
import os
|
21 |
+
import random
|
22 |
+
import shutil
|
23 |
+
from pathlib import Path
|
24 |
+
from pynvml import *
|
25 |
+
import accelerate
|
26 |
+
import numpy as np
|
27 |
+
import torch
|
28 |
+
import torch.nn.functional as F
|
29 |
+
import torch.utils.checkpoint
|
30 |
+
import transformers123
|
31 |
+
from accelerate import Accelerator
|
32 |
+
from accelerate.logging import get_logger
|
33 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
34 |
+
from datasets import load_dataset
|
35 |
+
from huggingface_hub import create_repo, upload_folder
|
36 |
+
from packaging import version
|
37 |
+
from PIL import Image
|
38 |
+
from torchvision import transforms
|
39 |
+
from tqdm.auto import tqdm
|
40 |
+
import transformers
|
41 |
+
from transformers import AutoTokenizer, PretrainedConfig
|
42 |
+
import tensorflow as tf
|
43 |
+
tf.get_logger().setLevel('ERROR')
|
44 |
+
from collections import Counter
|
45 |
+
import diffusers_Tiger
|
46 |
+
from diffusers_Tiger import (
|
47 |
+
AutoencoderKL,
|
48 |
+
ControlNetModel,
|
49 |
+
DDPMScheduler,
|
50 |
+
StableDiffusionControlNetPipeline,
|
51 |
+
StableDiffusionControlNetInpaintPipeline,
|
52 |
+
UNet2DConditionModel,
|
53 |
+
UniPCMultistepScheduler,
|
54 |
+
DDIMScheduler
|
55 |
+
)
|
56 |
+
from diffusers_Tiger.optimization import get_scheduler
|
57 |
+
from diffusers_Tiger.utils import check_min_version, is_wandb_available
|
58 |
+
from diffusers_Tiger.utils.import_utils import is_xformers_available
|
59 |
+
from diffusers_Tiger import fuse
|
60 |
+
|
61 |
+
if is_wandb_available():
|
62 |
+
import wandb
|
63 |
+
import warnings
|
64 |
+
warnings.filterwarnings('ignore')
|
65 |
+
|
66 |
+
# Will error if the minimal version of diffusers123 is not installed. Remove at your own risks.
|
67 |
+
check_min_version("0.19.0.dev0")
|
68 |
+
|
69 |
+
logger = get_logger(__name__)
|
70 |
+
|
71 |
+
|
72 |
+
def image_grid(imgs, rows, cols):
|
73 |
+
assert len(imgs) == rows * cols
|
74 |
+
|
75 |
+
w, h = imgs[0].size
|
76 |
+
grid = Image.new("RGB", size=(cols * w, rows * h))
|
77 |
+
|
78 |
+
for i, img in enumerate(imgs):
|
79 |
+
grid.paste(img, box=(i % cols * w, i // cols * h))
|
80 |
+
return grid
|
81 |
+
|
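A minimal usage sketch for image_grid; the images and output path below are placeholders, not values from the repository:
from PIL import Image

img_a = Image.new("RGB", (256, 256), "white")    # placeholder images just for illustration
img_b = Image.new("RGB", (256, 256), "black")
preview = image_grid([img_a, img_b], rows=1, cols=2)
preview.save("grid_preview.png")                 # hypothetical output path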
82 |
+
def make_inpaint_condition(image, image_mask):
|
83 |
+
image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
|
84 |
+
image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
|
85 |
+
|
86 |
+
assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
|
87 |
+
image[image_mask > 0.5] = -1.0 # set as masked pixel
|
88 |
+
image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
|
89 |
+
image = torch.from_numpy(image)
|
90 |
+
return image
|
91 |
+
|
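A usage sketch for make_inpaint_condition (file names are hypothetical): it expects an RGB image and a same-sized mask and returns a 1x3xHxW float tensor in which pixels under the mask are set to -1.0:
from PIL import Image

src = Image.open("example_image.png").convert("RGB").resize((512, 512))   # hypothetical path
msk = Image.open("example_mask.png").convert("L").resize((512, 512))      # hypothetical path
control = make_inpaint_condition(src, msk)   # torch.FloatTensor of shape (1, 3, 512, 512)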
92 |
+
def log_validation(vae, text_encoder, tokenizer, unet, controlnet_nd, controlnet_bg, args, accelerator, weight_dtype, step):
|
93 |
+
logger.info("Running validation... ")
|
94 |
+
|
95 |
+
controlnet_nd = accelerator.unwrap_model(controlnet_nd)
|
96 |
+
|
97 |
+
pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
98 |
+
args.pretrained_model_name_or_path,
|
99 |
+
vae=vae,
|
100 |
+
text_encoder=text_encoder,
|
101 |
+
tokenizer=tokenizer,
|
102 |
+
unet=unet,
|
103 |
+
controlnet=controlnet_nd,
|
104 |
+
safety_checker=None,
|
105 |
+
revision=args.revision,
|
106 |
+
torch_dtype=weight_dtype,
|
107 |
+
)
|
108 |
+
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
|
109 |
+
pipeline = pipeline.to(accelerator.device)
|
110 |
+
pipeline.set_progress_bar_config(disable=True)
|
111 |
+
|
112 |
+
if args.enable_xformers_memory_efficient_attention:
|
113 |
+
pipeline.enable_xformers_memory_efficient_attention()
|
114 |
+
|
115 |
+
if args.seed is None:
|
116 |
+
generator = None
|
117 |
+
else:
|
118 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
119 |
+
|
120 |
+
if len(args.validation_image) == len(args.validation_prompt):
|
121 |
+
validation_images = args.validation_image
|
122 |
+
validation_prompts = args.validation_prompt
|
123 |
+
elif len(args.validation_image) == 1:
|
124 |
+
validation_images = args.validation_image * len(args.validation_prompt)
|
125 |
+
validation_prompts = args.validation_prompt
|
126 |
+
elif len(args.validation_prompt) == 1:
|
127 |
+
validation_images = args.validation_image
|
128 |
+
validation_prompts = args.validation_prompt * len(args.validation_image)
|
129 |
+
else:
|
130 |
+
raise ValueError(
|
131 |
+
"number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
|
132 |
+
)
|
133 |
+
|
134 |
+
image_logs = []
|
135 |
+
images = []
|
136 |
+
for validation_prompt, validation_image1 in zip(validation_prompts, validation_images):
|
137 |
+
validation_image = Image.open(validation_image1).convert("RGB").resize((512, 512))
|
138 |
+
mask_image = Image.open(validation_image1).convert("RGB").resize((512, 512))
|
139 |
+
|
140 |
+
control_image = make_inpaint_condition(validation_image, mask_image)
|
141 |
+
|
142 |
+
for _ in range(args.num_validation_images):
|
143 |
+
with torch.autocast("cuda"):
|
144 |
+
seed = random.randint(1,1000000)
|
145 |
+
generator = torch.Generator(device='cuda').manual_seed(seed)
|
146 |
+
image = pipeline(
|
147 |
+
validation_prompt,
|
148 |
+
num_inference_steps=50,
|
149 |
+
generator=generator,
|
150 |
+
eta=1.0,
|
151 |
+
image=validation_image,
|
152 |
+
mask_image=mask_image,
|
153 |
+
control_image=control_image,
|
154 |
+
guidance_scale = 7
|
155 |
+
).images[0]
|
156 |
+
|
157 |
+
images.append(image)
|
158 |
+
|
159 |
+
|
160 |
+
image_logs.append(
|
161 |
+
{"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
|
162 |
+
)
|
163 |
+
|
164 |
+
for tracker in accelerator.trackers:
|
165 |
+
if tracker.name == "tensorboard":
|
166 |
+
for log in image_logs:
|
167 |
+
images = log["images"]
|
168 |
+
validation_prompt = log["validation_prompt"]
|
169 |
+
validation_image = log["validation_image"]
|
170 |
+
|
171 |
+
formatted_images = []
|
172 |
+
|
173 |
+
formatted_images.append(np.asarray(validation_image))
|
174 |
+
|
175 |
+
for image in images:
|
176 |
+
formatted_images.append(np.asarray(image))
|
177 |
+
|
178 |
+
formatted_images = np.stack(formatted_images)
|
179 |
+
|
180 |
+
tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
|
181 |
+
elif tracker.name == "wandb":
|
182 |
+
formatted_images = []
|
183 |
+
|
184 |
+
for log in image_logs:
|
185 |
+
images = log["images"]
|
186 |
+
validation_prompt = log["validation_prompt"]
|
187 |
+
validation_image = log["validation_image"]
|
188 |
+
|
189 |
+
formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
|
190 |
+
|
191 |
+
for image in images:
|
192 |
+
image = wandb.Image(image, caption=validation_prompt)
|
193 |
+
formatted_images.append(image)
|
194 |
+
|
195 |
+
tracker.log({"validation": formatted_images})
|
196 |
+
else:
|
197 |
+
logger.warn(f"image logging not implemented for {tracker.name}")
|
198 |
+
|
199 |
+
return image_logs
|
200 |
+
|
201 |
+
|
202 |
+
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
|
203 |
+
text_encoder_config = PretrainedConfig.from_pretrained(
|
204 |
+
pretrained_model_name_or_path,
|
205 |
+
subfolder="text_encoder",
|
206 |
+
revision=revision,
|
207 |
+
)
|
208 |
+
model_class = text_encoder_config.architectures[0]
|
209 |
+
|
210 |
+
if model_class == "CLIPTextModel":
|
211 |
+
from transformers import CLIPTextModel
|
212 |
+
|
213 |
+
return CLIPTextModel
|
214 |
+
elif model_class == "RobertaSeriesModelWithTransformation":
|
215 |
+
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
|
216 |
+
|
217 |
+
return RobertaSeriesModelWithTransformation
|
218 |
+
else:
|
219 |
+
raise ValueError(f"{model_class} is not supported.")
|
220 |
+
|
221 |
+
|
222 |
+
def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
|
223 |
+
img_str = ""
|
224 |
+
if image_logs is not None:
|
225 |
+
img_str = "You can find some example images below.\n"
|
226 |
+
for i, log in enumerate(image_logs):
|
227 |
+
images = log["images"]
|
228 |
+
validation_prompt = log["validation_prompt"]
|
229 |
+
validation_image = log["validation_image"]
|
230 |
+
validation_image.save(os.path.join(repo_folder, "image_control.png"))
|
231 |
+
img_str += f"prompt: {validation_prompt}\n"
|
232 |
+
images = [validation_image] + images
|
233 |
+
image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
|
234 |
+
img_str += f"\n"
|
235 |
+
|
236 |
+
yaml = f"""
|
237 |
+
---
|
238 |
+
license: creativeml-openrail-m
|
239 |
+
base_model: {base_model}
|
240 |
+
tags:
|
241 |
+
- stable-diffusion
|
242 |
+
- stable-diffusion-diffusers
|
243 |
+
- text-to-image
|
244 |
+
- diffusers
|
245 |
+
- controlnet
|
246 |
+
inference: true
|
247 |
+
---
|
248 |
+
"""
|
249 |
+
model_card = f"""
|
250 |
+
# controlnet-{repo_id}
|
251 |
+
|
252 |
+
These are controlnet weights trained on {base_model} with new type of conditioning.
|
253 |
+
{img_str}
|
254 |
+
"""
|
255 |
+
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
256 |
+
f.write(yaml + model_card)
|
257 |
+
|
258 |
+
|
259 |
+
def parse_args(input_args=None):
|
260 |
+
parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
|
261 |
+
parser.add_argument(
|
262 |
+
"--pretrained_model_name_or_path",
|
263 |
+
type=str,
|
264 |
+
default=None,
|
265 |
+
required=True,
|
266 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
267 |
+
)
|
268 |
+
parser.add_argument(
|
269 |
+
"--controlnet_model_name_or_path",
|
270 |
+
type=str,
|
271 |
+
default=None,
|
272 |
+
help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
|
273 |
+
" If not specified controlnet weights are initialized from unet.",
|
274 |
+
)
|
275 |
+
parser.add_argument(
|
276 |
+
"--revision",
|
277 |
+
type=str,
|
278 |
+
default=None,
|
279 |
+
required=False,
|
280 |
+
help=(
|
281 |
+
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
|
282 |
+
" float32 precision."
|
283 |
+
),
|
284 |
+
)
|
285 |
+
parser.add_argument(
|
286 |
+
"--tokenizer_name",
|
287 |
+
type=str,
|
288 |
+
default=None,
|
289 |
+
help="Pretrained tokenizer name or path if not the same as model_name",
|
290 |
+
)
|
291 |
+
parser.add_argument(
|
292 |
+
"--output_dir",
|
293 |
+
type=str,
|
294 |
+
default="controlnet-model",
|
295 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
296 |
+
)
|
297 |
+
parser.add_argument(
|
298 |
+
"--cache_dir",
|
299 |
+
type=str,
|
300 |
+
default="/export/home/daifang/Diffusion/own_code/dataset",
|
301 |
+
help="The directory where the downloaded models and datasets will be stored.",
|
302 |
+
)
|
303 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
304 |
+
parser.add_argument(
|
305 |
+
"--resolution",
|
306 |
+
type=int,
|
307 |
+
default=512,
|
308 |
+
help=(
|
309 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
310 |
+
" resolution"
|
311 |
+
),
|
312 |
+
)
|
313 |
+
parser.add_argument(
|
314 |
+
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
315 |
+
)
|
316 |
+
parser.add_argument("--num_train_epochs", type=int, default=1)
|
317 |
+
parser.add_argument(
|
318 |
+
"--max_train_steps",
|
319 |
+
type=int,
|
320 |
+
default=None,
|
321 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
322 |
+
)
|
323 |
+
parser.add_argument(
|
324 |
+
"--checkpointing_steps",
|
325 |
+
type=int,
|
326 |
+
default=500,
|
327 |
+
help=(
|
328 |
+
"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
|
329 |
+
"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
|
330 |
+
"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
|
331 |
+
"See https://huggingface.co/docs/diffusers123/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
|
332 |
+
"instructions."
|
333 |
+
),
|
334 |
+
)
|
335 |
+
parser.add_argument(
|
336 |
+
"--checkpoints_total_limit",
|
337 |
+
type=int,
|
338 |
+
default=None,
|
339 |
+
help=("Max number of checkpoints to store."),
|
340 |
+
)
|
341 |
+
parser.add_argument(
|
342 |
+
"--resume_from_checkpoint",
|
343 |
+
type=str,
|
344 |
+
default=None,
|
345 |
+
help=(
|
346 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
347 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
348 |
+
),
|
349 |
+
)
|
350 |
+
parser.add_argument(
|
351 |
+
"--gradient_accumulation_steps",
|
352 |
+
type=int,
|
353 |
+
default=1,
|
354 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
355 |
+
)
|
356 |
+
parser.add_argument(
|
357 |
+
"--gradient_checkpointing",
|
358 |
+
action="store_true",
|
359 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
360 |
+
)
|
361 |
+
parser.add_argument(
|
362 |
+
"--learning_rate",
|
363 |
+
type=float,
|
364 |
+
default=5e-6,
|
365 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
366 |
+
)
|
367 |
+
parser.add_argument(
|
368 |
+
"--scale_lr",
|
369 |
+
action="store_true",
|
370 |
+
default=False,
|
371 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
372 |
+
)
|
373 |
+
parser.add_argument(
|
374 |
+
"--lr_scheduler",
|
375 |
+
type=str,
|
376 |
+
default="constant",
|
377 |
+
help=(
|
378 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
379 |
+
' "constant", "constant_with_warmup"]'
|
380 |
+
),
|
381 |
+
)
|
382 |
+
parser.add_argument(
|
383 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
384 |
+
)
|
385 |
+
parser.add_argument(
|
386 |
+
"--lr_num_cycles",
|
387 |
+
type=int,
|
388 |
+
default=1,
|
389 |
+
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
|
390 |
+
)
|
391 |
+
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
|
392 |
+
parser.add_argument(
|
393 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
394 |
+
)
|
395 |
+
parser.add_argument(
|
396 |
+
"--dataloader_num_workers",
|
397 |
+
type=int,
|
398 |
+
default=0,
|
399 |
+
help=(
|
400 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
401 |
+
),
|
402 |
+
)
|
403 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
404 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
405 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
406 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
407 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
408 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
409 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
410 |
+
parser.add_argument(
|
411 |
+
"--hub_model_id",
|
412 |
+
type=str,
|
413 |
+
default=None,
|
414 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
415 |
+
)
|
416 |
+
parser.add_argument(
|
417 |
+
"--logging_dir",
|
418 |
+
type=str,
|
419 |
+
default="logs",
|
420 |
+
help=(
|
421 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
422 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
423 |
+
),
|
424 |
+
)
|
425 |
+
parser.add_argument(
|
426 |
+
"--allow_tf32",
|
427 |
+
action="store_true",
|
428 |
+
help=(
|
429 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
430 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
431 |
+
),
|
432 |
+
)
|
433 |
+
parser.add_argument(
|
434 |
+
"--report_to",
|
435 |
+
type=str,
|
436 |
+
default="tensorboard",
|
437 |
+
help=(
|
438 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
439 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
440 |
+
),
|
441 |
+
)
|
442 |
+
parser.add_argument(
|
443 |
+
"--mixed_precision",
|
444 |
+
type=str,
|
445 |
+
default="no",
|
446 |
+
choices=["no", "fp16", "bf16"],
|
447 |
+
help=(
|
448 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
449 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
450 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
451 |
+
),
|
452 |
+
)
|
453 |
+
parser.add_argument(
|
454 |
+
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
455 |
+
)
|
456 |
+
parser.add_argument(
|
457 |
+
"--set_grads_to_none",
|
458 |
+
action="store_true",
|
459 |
+
help=(
|
460 |
+
"Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
|
461 |
+
" behaviors, so disable this argument if it causes any problems. More info:"
|
462 |
+
" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
|
463 |
+
),
|
464 |
+
)
|
465 |
+
parser.add_argument(
|
466 |
+
"--dataset_name",
|
467 |
+
type=str,
|
468 |
+
default=None,
|
469 |
+
help=(
|
470 |
+
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
|
471 |
+
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
|
472 |
+
" or to a folder containing files that 🤗 Datasets can understand."
|
473 |
+
),
|
474 |
+
)
|
475 |
+
parser.add_argument(
|
476 |
+
"--dataset_config_name",
|
477 |
+
type=str,
|
478 |
+
default=None,
|
479 |
+
help="The config of the Dataset, leave as None if there's only one config.",
|
480 |
+
)
|
481 |
+
parser.add_argument(
|
482 |
+
"--train_data_dir",
|
483 |
+
type=str,
|
484 |
+
default=None,
|
485 |
+
help=(
|
486 |
+
"A folder containing the training data. Folder contents must follow the structure described in"
|
487 |
+
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
|
488 |
+
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
|
489 |
+
),
|
490 |
+
)
|
491 |
+
##############################################################################################################
|
492 |
+
parser.add_argument(
|
493 |
+
"--image_column", type=str, default="image", help="The column of the dataset containing the target image."
|
494 |
+
)
|
495 |
+
parser.add_argument(
|
496 |
+
"--conditioning_nd_column",
|
497 |
+
type=str,
|
498 |
+
default="condition_nd",
|
499 |
+
help="The column of the dataset containing the controlnet conditioning image.",
|
500 |
+
)
|
501 |
+
parser.add_argument(
|
502 |
+
"--conditioning_bg_column",
|
503 |
+
type=str,
|
504 |
+
default="condition_bg",
|
505 |
+
help="The column of the dataset containing the controlnet conditioning image.",
|
506 |
+
)
|
507 |
+
parser.add_argument(
|
508 |
+
"--caption_column_nd",
|
509 |
+
type=str,
|
510 |
+
default="text_nd",
|
511 |
+
help="The column of the dataset containing a caption or a list of captions.",
|
512 |
+
)
|
513 |
+
parser.add_argument(
|
514 |
+
"--caption_column_bg",
|
515 |
+
type=str,
|
516 |
+
default="text_nd",
|
517 |
+
help="The column of the dataset containing a caption or a list of captions.",
|
518 |
+
)
|
519 |
+
##############################################################################################################
|
520 |
+
parser.add_argument(
|
521 |
+
"--max_train_samples",
|
522 |
+
type=int,
|
523 |
+
default=None,
|
524 |
+
help=(
|
525 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
526 |
+
"value if set."
|
527 |
+
),
|
528 |
+
)
|
529 |
+
parser.add_argument(
|
530 |
+
"--proportion_empty_prompts",
|
531 |
+
type=float,
|
532 |
+
default=0,
|
533 |
+
help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
|
534 |
+
)
|
535 |
+
parser.add_argument(
|
536 |
+
"--validation_prompt",
|
537 |
+
type=str,
|
538 |
+
default=None,
|
539 |
+
nargs="+",
|
540 |
+
help=(
|
541 |
+
"A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
|
542 |
+
" Provide either a matching number of `--validation_image`s, a single `--validation_image`"
|
543 |
+
" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
|
544 |
+
),
|
545 |
+
)
|
546 |
+
parser.add_argument(
|
547 |
+
"--validation_image",
|
548 |
+
type=str,
|
549 |
+
default=None,
|
550 |
+
nargs="+",
|
551 |
+
help=(
|
552 |
+
"A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
|
553 |
+
" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
|
554 |
+
" a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
|
555 |
+
" `--validation_image` that will be used with all `--validation_prompt`s."
|
556 |
+
),
|
557 |
+
)
|
558 |
+
parser.add_argument(
|
559 |
+
"--num_validation_images",
|
560 |
+
type=int,
|
561 |
+
default=4,
|
562 |
+
help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
|
563 |
+
)
|
564 |
+
parser.add_argument(
|
565 |
+
"--validation_steps",
|
566 |
+
type=int,
|
567 |
+
default=100,
|
568 |
+
help=(
|
569 |
+
"Run validation every X steps. Validation consists of running the prompt"
|
570 |
+
" `args.validation_prompt` multiple times: `args.num_validation_images`"
|
571 |
+
" and logging the images."
|
572 |
+
),
|
573 |
+
)
|
574 |
+
parser.add_argument(
|
575 |
+
"--tracker_project_name",
|
576 |
+
type=str,
|
577 |
+
default="train_controlnet",
|
578 |
+
help=(
|
579 |
+
"The `project_name` argument passed to Accelerator.init_trackers for"
|
580 |
+
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
|
581 |
+
),
|
582 |
+
)
|
583 |
+
|
584 |
+
if input_args is not None:
|
585 |
+
args = parser.parse_args(input_args)
|
586 |
+
else:
|
587 |
+
args = parser.parse_args()
|
588 |
+
|
589 |
+
if args.dataset_name is None and args.train_data_dir is None:
|
590 |
+
raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
|
591 |
+
|
592 |
+
if args.dataset_name is not None and args.train_data_dir is not None:
|
593 |
+
raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
|
594 |
+
|
595 |
+
if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
|
596 |
+
raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
|
597 |
+
|
598 |
+
if args.validation_prompt is not None and args.validation_image is None:
|
599 |
+
raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
|
600 |
+
|
601 |
+
if args.validation_prompt is None and args.validation_image is not None:
|
602 |
+
raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
|
603 |
+
|
604 |
+
if (
|
605 |
+
args.validation_image is not None
|
606 |
+
and args.validation_prompt is not None
|
607 |
+
and len(args.validation_image) != 1
|
608 |
+
and len(args.validation_prompt) != 1
|
609 |
+
and len(args.validation_image) != len(args.validation_prompt)
|
610 |
+
):
|
611 |
+
raise ValueError(
|
612 |
+
"Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
|
613 |
+
" or the same number of `--validation_prompt`s and `--validation_image`s"
|
614 |
+
)
|
615 |
+
|
616 |
+
if args.resolution % 8 != 0:
|
617 |
+
raise ValueError(
|
618 |
+
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
|
619 |
+
)
|
620 |
+
|
621 |
+
return args
|
622 |
+
|
623 |
+
|
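Because parse_args accepts an optional input_args list, the argument handling can be exercised without a shell. A minimal sketch; every path and value below is a placeholder, not a setting used by the authors:
args = parse_args([
    "--pretrained_model_name_or_path", "runwayml/stable-diffusion-v1-5",   # hypothetical base model
    "--train_data_dir", "./dataset/train",                                  # hypothetical data folder
    "--resolution", "512",
    "--train_batch_size", "2",
    "--output_dir", "./controlnet-coarse",
])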
624 |
+
def make_train_dataset(args, tokenizer, accelerator):
|
625 |
+
if args.dataset_name is not None:
|
626 |
+
dataset = load_dataset(
|
627 |
+
args.dataset_name,
|
628 |
+
args.dataset_config_name,
|
629 |
+
cache_dir=args.cache_dir,
|
630 |
+
)
|
631 |
+
else:
|
632 |
+
if args.train_data_dir is not None:
|
633 |
+
dataset = load_dataset(
|
634 |
+
args.train_data_dir,
|
635 |
+
cache_dir=args.cache_dir,
|
636 |
+
)
|
637 |
+
column_names = dataset["train"].column_names
|
638 |
+
##########################################################################################################################################################################
|
639 |
+
# Get the column names for input/target.
|
640 |
+
# target image
|
641 |
+
if args.image_column is None:
|
642 |
+
image_column = column_names[0]
|
643 |
+
logger.info(f"image column defaulting to {image_column}")
|
644 |
+
else:
|
645 |
+
image_column = args.image_column
|
646 |
+
if image_column not in column_names:
|
647 |
+
raise ValueError(
|
648 |
+
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
|
649 |
+
)
|
650 |
+
# condition nodule image
|
651 |
+
if args.conditioning_nd_column is None:
|
652 |
+
|
653 |
+
conditioning_nd_column = column_names[1]
|
654 |
+
logger.info(f"conditioning image column defaulting to {conditioning_nd_column}")
|
655 |
+
else:
|
656 |
+
conditioning_nd_column = args.conditioning_nd_column
|
657 |
+
if conditioning_nd_column not in column_names:
|
658 |
+
raise ValueError(
|
659 |
+
f"`--conditioning_nd_column` value '{args.conditioning_nd_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
|
660 |
+
)
|
661 |
+
# condition background image
|
662 |
+
if args.conditioning_bg_column is None:
|
663 |
+
conditioning_bg_column = column_names[2]
|
664 |
+
logger.info(f"conditioning bg column defaulting to {conditioning_bg_column}")
|
665 |
+
else:
|
666 |
+
conditioning_bg_column = args.conditioning_bg_column
|
667 |
+
if conditioning_bg_column not in column_names:
|
668 |
+
raise ValueError(
|
669 |
+
f"`--conditioning_bg_column` value '{args.conditioning_bg_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
|
670 |
+
)
|
671 |
+
# condition nodule text
|
672 |
+
|
673 |
+
if args.caption_column_nd is None:
|
674 |
+
caption_column_nd = column_names[3]
|
675 |
+
logger.info(f"caption column defaulting to {caption_column_nd}")
|
676 |
+
else:
|
677 |
+
caption_column_nd = args.caption_column_nd
|
678 |
+
if caption_column_nd not in column_names:
|
679 |
+
raise ValueError(
|
680 |
+
f"`--caption_column` value '{args.caption_column_nd}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
|
681 |
+
)
|
682 |
+
# condition background text
|
683 |
+
if args.caption_column_bg is None:
|
684 |
+
caption_column_bg = column_names[4]
|
685 |
+
logger.info(f"caption column defaulting to {caption_column_bg}")
|
686 |
+
else:
|
687 |
+
caption_column_bg = args.caption_column_bg
|
688 |
+
if caption_column_bg not in column_names:
|
689 |
+
raise ValueError(
|
690 |
+
f"`--caption_column` value '{args.caption_column_bg}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
|
691 |
+
)
|
692 |
+
##########################################################################################################################################################################
|
693 |
+
|
694 |
+
def tokenize_captions(examples, caption_column, names, is_train=True):
|
695 |
+
captions = []
|
696 |
+
for caption in examples[caption_column]:
|
697 |
+
if random.random() < args.proportion_empty_prompts:
|
698 |
+
captions.append("")
|
699 |
+
elif isinstance(caption, str):
|
700 |
+
captions.append(caption)
|
701 |
+
elif isinstance(caption, (list, np.ndarray)):
|
702 |
+
|
703 |
+
# take a random caption if there are multiple
|
704 |
+
captions.append(random.choice(caption) if is_train else caption[0])
|
705 |
+
else:
|
706 |
+
raise ValueError(
|
707 |
+
f"Caption column `{caption_column_nd}` should contain either strings or lists of strings."
|
708 |
+
)
|
709 |
+
|
710 |
+
inputs = tokenizer(
|
711 |
+
captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
|
712 |
+
)
|
713 |
+
|
714 |
+
def calculate_word_frequencies(phrases):
|
715 |
+
total_counts = Counter()
|
716 |
+
total_words = 0
|
717 |
+
for phrase in phrases:
|
718 |
+
words = phrase.replace(',', '').split()
|
719 |
+
total_counts.update(words)
|
720 |
+
total_words += len(words)
|
721 |
+
frequencies = {word: count / total_words for word, count in total_counts.items()}
|
722 |
+
return frequencies, total_words
|
723 |
+
|
724 |
+
def calculate_average_frequencies(phrases, word_frequencies):
|
725 |
+
average_frequencies = []
|
726 |
+
for phrase in phrases:
|
727 |
+
words = phrase.replace(',', '').split()
|
728 |
+
total_freq = sum(word_frequencies[word] for word in words)
|
729 |
+
avg_freq = total_freq / len(words) if words else 0
|
730 |
+
average_frequencies.append((phrase, avg_freq))
|
731 |
+
return average_frequencies
|
732 |
+
if names == 'nd':
|
733 |
+
word_frequencies, total_word_count = calculate_word_frequencies(captions)
|
734 |
+
weight_matrix = calculate_average_frequencies(captions, word_frequencies)
|
735 |
+
# Extract the values to replace
|
736 |
+
values = [desc[1] for desc in weight_matrix]
|
737 |
+
# Replace the first token id in each row with the scaled per-caption weight
|
738 |
+
for i in range(inputs.input_ids.shape[0]):
|
739 |
+
weight = int(values[i]*10**5)
|
740 |
+
inputs.input_ids[i][0] = weight
|
741 |
+
assert not torch.isnan(inputs.input_ids).any(), "inputs.input_ids contains NaN values"
|
742 |
+
|
743 |
+
return inputs.input_ids
|
744 |
+
|
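# A worked example of the weighting above (values are illustrative): for the 'nd' captions
# ["small solid nodule", "small nodule"] there are 5 words in total, so the per-word
# frequencies are small=0.4, solid=0.2, nodule=0.4. The first caption's average frequency is
# (0.4 + 0.2 + 0.4) / 3 = 1/3, and tokenize_captions overwrites its first token id with
# int(1/3 * 10**5) = 33333; the training loop later divides that slot by 10**5 to recover the
# weight and restores the CLIP begin-of-text id (49406) before calling the text encoder.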
745 |
+
image_transforms = transforms.Compose(
|
746 |
+
[
|
747 |
+
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
748 |
+
transforms.CenterCrop(args.resolution),
|
749 |
+
transforms.ToTensor(),
|
750 |
+
transforms.Normalize([0.5], [0.5]),
|
751 |
+
]
|
752 |
+
)
|
753 |
+
|
754 |
+
conditioning_image_transforms = transforms.Compose(
|
755 |
+
[
|
756 |
+
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
757 |
+
transforms.CenterCrop(args.resolution),
|
758 |
+
transforms.ToTensor(),
|
759 |
+
]
|
760 |
+
)
|
761 |
+
|
762 |
+
def preprocess_train(examples):
|
763 |
+
images = [image.convert("RGB") for image in examples[image_column]]
|
764 |
+
images = [image_transforms(image) for image in images]
|
765 |
+
conditioning_nd = [Image.open(image).convert("RGB") for image in examples[conditioning_nd_column]]
|
766 |
+
conditioning_nd = [conditioning_image_transforms(image) for image in conditioning_nd]
|
767 |
+
|
768 |
+
conditioning_bg = [Image.open(image).convert("RGB") for image in examples[conditioning_bg_column]]
|
769 |
+
conditioning_bg = [conditioning_image_transforms(image) for image in conditioning_bg]
|
770 |
+
|
771 |
+
examples["pixel_values"] = images
|
772 |
+
examples["conditioning_pixel_values_nd"] = conditioning_nd
|
773 |
+
examples["conditioning_pixel_values_bg"] = conditioning_bg
|
774 |
+
examples["input_ids_nd"] = tokenize_captions(examples, caption_column = caption_column_nd, names = 'nd')
|
775 |
+
examples["input_ids_bg"] = tokenize_captions(examples, caption_column = caption_column_bg, names = 'bg')
|
776 |
+
|
777 |
+
return examples
|
778 |
+
|
779 |
+
with accelerator.main_process_first():
|
780 |
+
if args.max_train_samples is not None:
|
781 |
+
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
|
782 |
+
# Set the training transforms
|
783 |
+
train_dataset = dataset["train"].with_transform(preprocess_train)
|
784 |
+
|
785 |
+
return train_dataset
|
786 |
+
|
787 |
+
|
788 |
+
def collate_fn(examples):
|
789 |
+
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
790 |
+
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
791 |
+
|
792 |
+
conditioning_pixel_values_nd = torch.stack([example["conditioning_pixel_values_nd"] for example in examples])
|
793 |
+
conditioning_pixel_values_nd = conditioning_pixel_values_nd.to(memory_format=torch.contiguous_format).float()
|
794 |
+
|
795 |
+
conditioning_pixel_values_bg = torch.stack([example["conditioning_pixel_values_bg"] for example in examples])
|
796 |
+
conditioning_pixel_values_bg = conditioning_pixel_values_bg.to(memory_format=torch.contiguous_format).float()
|
797 |
+
|
798 |
+
input_ids_nd = torch.stack([example["input_ids_nd"] for example in examples])
|
799 |
+
input_ids_bg = torch.stack([example["input_ids_bg"] for example in examples])
|
800 |
+
|
801 |
+
return {
|
802 |
+
"pixel_values": pixel_values,
|
803 |
+
"conditioning_pixel_values_nd": conditioning_pixel_values_nd,
|
804 |
+
"conditioning_pixel_values_bg": conditioning_pixel_values_bg,
|
805 |
+
"input_ids_nd": input_ids_nd,
|
806 |
+
"input_ids_bg": input_ids_bg,
|
807 |
+
}
|
808 |
+
|
809 |
+
|
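For reference, a sketch of the batch collate_fn returns with train_batch_size=2 at 512x512 resolution; the sequence length of 77 assumes the standard CLIP tokenizer:
# batch["pixel_values"]                 -> float32 tensor, shape (2, 3, 512, 512)
# batch["conditioning_pixel_values_nd"] -> float32 tensor, shape (2, 3, 512, 512)
# batch["conditioning_pixel_values_bg"] -> float32 tensor, shape (2, 3, 512, 512)
# batch["input_ids_nd"]                 -> int64 tensor,   shape (2, 77)
# batch["input_ids_bg"]                 -> int64 tensor,   shape (2, 77)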
810 |
+
def main(args):
|
811 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
812 |
+
|
813 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
814 |
+
|
815 |
+
accelerator = Accelerator(
|
816 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
817 |
+
mixed_precision=args.mixed_precision,
|
818 |
+
log_with=args.report_to,
|
819 |
+
project_config=accelerator_project_config,
|
820 |
+
)
|
821 |
+
|
822 |
+
# Make one log on every process with the configuration for debugging.
|
823 |
+
logging.basicConfig(
|
824 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
825 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
826 |
+
level=logging.INFO,
|
827 |
+
)
|
828 |
+
logger.info(accelerator.state, main_process_only=False)
|
829 |
+
if accelerator.is_local_main_process:
|
830 |
+
transformers.utils.logging.set_verbosity_warning()
|
831 |
+
diffusers_Tiger.utils.logging.set_verbosity_info()
|
832 |
+
else:
|
833 |
+
transformers.utils.logging.set_verbosity_error()
|
834 |
+
diffusers_Tiger.utils.logging.set_verbosity_error()
|
835 |
+
|
836 |
+
# If passed along, set the training seed now.
|
837 |
+
if args.seed is not None:
|
838 |
+
set_seed(args.seed)
|
839 |
+
|
840 |
+
# Handle the repository creation
|
841 |
+
if accelerator.is_main_process:
|
842 |
+
if args.output_dir is not None:
|
843 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
844 |
+
|
845 |
+
# Load the tokenizer
|
846 |
+
if args.tokenizer_name:
|
847 |
+
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
|
848 |
+
elif args.pretrained_model_name_or_path:
|
849 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
850 |
+
args.pretrained_model_name_or_path,
|
851 |
+
subfolder="tokenizer",
|
852 |
+
revision=args.revision,
|
853 |
+
use_fast=False,
|
854 |
+
)
|
855 |
+
|
856 |
+
# import correct text encoder class
|
857 |
+
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
|
858 |
+
|
859 |
+
# Load scheduler and models
|
860 |
+
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
861 |
+
text_encoder = text_encoder_cls.from_pretrained(
|
862 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
863 |
+
)
|
864 |
+
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
|
865 |
+
unet = UNet2DConditionModel.from_pretrained(
|
866 |
+
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
867 |
+
)
|
868 |
+
|
869 |
+
if args.controlnet_model_name_or_path:
|
870 |
+
logger.info("Loading existing controlnet weights")
|
871 |
+
controlnet_nd = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
|
872 |
+
controlnet_bg = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
|
873 |
+
else:
|
874 |
+
logger.info("Initializing controlnet weights from unet")
|
875 |
+
controlnet_nd = ControlNetModel.from_unet(unet)
|
876 |
+
controlnet_bg = ControlNetModel.from_unet(unet)
|
877 |
+
|
878 |
+
|
879 |
+
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
880 |
+
def save_model_hook(models, weights, output_dir):
|
881 |
+
weights.pop()
|
882 |
+
model1 = models[0]
|
883 |
+
sub_dir = "controlnet_nd"
|
884 |
+
model1.save_pretrained(os.path.join(output_dir, sub_dir))
|
885 |
+
|
886 |
+
|
887 |
+
def load_model_hook(models, input_dir):
|
888 |
+
while len(models) > 0:
|
889 |
+
# pop models so that they are not loaded again
|
890 |
+
model = models.pop()
|
891 |
+
|
892 |
+
# load diffusers style weights into the model
|
893 |
+
load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
|
894 |
+
model.register_to_config(**load_model.config)
|
895 |
+
|
896 |
+
model.load_state_dict(load_model.state_dict())
|
897 |
+
del load_model
|
898 |
+
|
899 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
900 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
901 |
+
|
902 |
+
vae.requires_grad_(False)
|
903 |
+
unet.requires_grad_(False)
|
904 |
+
text_encoder.requires_grad_(False)
|
905 |
+
controlnet_nd.requires_grad_(True).train()
|
906 |
+
controlnet_bg.requires_grad_(True).train()
|
907 |
+
|
908 |
+
if args.gradient_checkpointing:
|
909 |
+
controlnet_nd.enable_gradient_checkpointing()
|
910 |
+
controlnet_bg.enable_gradient_checkpointing()
|
911 |
+
|
912 |
+
# Check that all trainable models are in full precision
|
913 |
+
low_precision_error_string = (
|
914 |
+
" Please make sure to always have all model weights in full float32 precision when starting training - even if"
|
915 |
+
" doing mixed precision training, copy of the weights should still be float32."
|
916 |
+
)
|
917 |
+
|
918 |
+
if accelerator.unwrap_model(controlnet_nd).dtype != torch.float32:
|
919 |
+
raise ValueError(
|
920 |
+
f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet_nd).dtype}. {low_precision_error_string}"
|
921 |
+
)
|
922 |
+
if accelerator.unwrap_model(controlnet_bg).dtype != torch.float32:
|
923 |
+
raise ValueError(
|
924 |
+
f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet_bg).dtype}. {low_precision_error_string}"
|
925 |
+
)
|
926 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
927 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
928 |
+
if args.allow_tf32:
|
929 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
930 |
+
|
931 |
+
if args.scale_lr:
|
932 |
+
args.learning_rate = (
|
933 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
934 |
+
)
|
935 |
+
|
936 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
937 |
+
if args.use_8bit_adam:
|
938 |
+
try:
|
939 |
+
import bitsandbytes as bnb
|
940 |
+
except ImportError:
|
941 |
+
raise ImportError(
|
942 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
943 |
+
)
|
944 |
+
|
945 |
+
optimizer_class = bnb.optim.AdamW8bit
|
946 |
+
else:
|
947 |
+
optimizer_class = torch.optim.AdamW
|
948 |
+
|
949 |
+
# Optimizer creation
|
950 |
+
params_to_optimize_nd = controlnet_nd.parameters()
|
951 |
+
params_to_optimize_bg = controlnet_bg.parameters()
|
952 |
+
|
953 |
+
optimizer_nd = optimizer_class(
|
954 |
+
params_to_optimize_nd,
|
955 |
+
lr=args.learning_rate,
|
956 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
957 |
+
weight_decay=args.adam_weight_decay,
|
958 |
+
eps=args.adam_epsilon,
|
959 |
+
)
|
960 |
+
optimizer_bg = optimizer_class(
|
961 |
+
params_to_optimize_bg,
|
962 |
+
lr=args.learning_rate,
|
963 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
964 |
+
weight_decay=args.adam_weight_decay,
|
965 |
+
eps=args.adam_epsilon,
|
966 |
+
)
|
967 |
+
train_dataset = make_train_dataset(args, tokenizer, accelerator)
|
968 |
+
|
969 |
+
train_dataloader = torch.utils.data.DataLoader(
|
970 |
+
train_dataset,
|
971 |
+
shuffle=True,
|
972 |
+
collate_fn=collate_fn,
|
973 |
+
batch_size=args.train_batch_size,
|
974 |
+
num_workers=args.dataloader_num_workers,
|
975 |
+
)
|
976 |
+
|
977 |
+
# Scheduler and math around the number of training steps.
|
978 |
+
overrode_max_train_steps = False
|
979 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
980 |
+
if args.max_train_steps is None:
|
981 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
982 |
+
overrode_max_train_steps = True
|
983 |
+
|
984 |
+
lr_scheduler = get_scheduler(
|
985 |
+
args.lr_scheduler,
|
986 |
+
optimizer=optimizer_nd,
|
987 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
988 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
989 |
+
num_cycles=args.lr_num_cycles,
|
990 |
+
power=args.lr_power)
|
991 |
+
|
992 |
+
# Prepare everything with our `accelerator`.
|
993 |
+
controlnet_nd, controlnet_bg, optimizer_nd, optimizer_bg, train_dataloader, lr_scheduler = accelerator.prepare(
|
994 |
+
controlnet_nd, controlnet_bg, optimizer_nd, optimizer_bg, train_dataloader, lr_scheduler
|
995 |
+
)
|
996 |
+
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
997 |
+
# as these models are only used for inference, keeping weights in full precision is not required.
|
998 |
+
weight_dtype = torch.float32
|
999 |
+
if accelerator.mixed_precision == "fp16":
|
1000 |
+
weight_dtype = torch.float16
|
1001 |
+
elif accelerator.mixed_precision == "bf16":
|
1002 |
+
weight_dtype = torch.bfloat16
|
1003 |
+
|
1004 |
+
# Move vae, unet and text_encoder to device and cast to weight_dtype
|
1005 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
1006 |
+
unet.to(accelerator.device, dtype=weight_dtype)
|
1007 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
1008 |
+
|
1009 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
1010 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
1011 |
+
if overrode_max_train_steps:
|
1012 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
1013 |
+
# Afterwards we recalculate our number of training epochs
|
1014 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
1015 |
+
|
1016 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
1017 |
+
# The trackers initialize automatically on the main process.
|
1018 |
+
if accelerator.is_main_process:
|
1019 |
+
tracker_config = dict(vars(args))
|
1020 |
+
|
1021 |
+
# tensorboard cannot handle list types for config
|
1022 |
+
tracker_config.pop("validation_prompt")
|
1023 |
+
tracker_config.pop("validation_image")
|
1024 |
+
|
1025 |
+
accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
|
1026 |
+
|
1027 |
+
# Train!
|
1028 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
1029 |
+
|
1030 |
+
logger.info("***** Running training *****")
|
1031 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
1032 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
1033 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
1034 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
1035 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
1036 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
1037 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
1038 |
+
global_step = 0
|
1039 |
+
first_epoch = 0
|
1040 |
+
|
1041 |
+
# Potentially load in the weights and states from a previous save
|
1042 |
+
if args.resume_from_checkpoint:
|
1043 |
+
if args.resume_from_checkpoint != "latest":
|
1044 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
1045 |
+
else:
|
1046 |
+
# Get the most recent checkpoint
|
1047 |
+
dirs = os.listdir(args.output_dir)
|
1048 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
1049 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
1050 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
1051 |
+
|
1052 |
+
if path is None:
|
1053 |
+
accelerator.print(
|
1054 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
1055 |
+
)
|
1056 |
+
args.resume_from_checkpoint = None
|
1057 |
+
initial_global_step = 0
|
1058 |
+
else:
|
1059 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
1060 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
1061 |
+
global_step = int(path.split("-")[1])
|
1062 |
+
|
1063 |
+
initial_global_step = global_step
|
1064 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
1065 |
+
else:
|
1066 |
+
initial_global_step = 0
|
1067 |
+
|
1068 |
+
progress_bar = tqdm(
|
1069 |
+
range(0, args.max_train_steps),
|
1070 |
+
initial=initial_global_step,
|
1071 |
+
desc="Steps",
|
1072 |
+
# Only show the progress bar once on each machine.
|
1073 |
+
disable=not accelerator.is_local_main_process,
|
1074 |
+
)
|
1075 |
+
|
1076 |
+
image_logs = None
|
1077 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
1078 |
+
for step, batch in enumerate(train_dataloader):
|
1079 |
+
# with accelerator.accumulate(controlnet_nd):
|
1080 |
+
# Convert images to latent space
|
1081 |
+
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
|
1082 |
+
latents = latents * vae.config.scaling_factor
|
1083 |
+
# Sample noise that we'll add to the latents
|
1084 |
+
noise = torch.randn_like(latents)
|
1085 |
+
bsz = latents.shape[0]
|
1086 |
+
# Sample a random timestep for each image
|
1087 |
+
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
1088 |
+
timesteps = timesteps.long()
|
1089 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
1090 |
+
# (this is the forward diffusion process)
|
1091 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
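# In DDPM terms this computes x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
# where x_0 are the clean latents, eps is the sampled noise and alpha_bar_t is the cumulative
# product of the scheduler's alphas at the sampled timestep t.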
1092 |
+
# Get the text embedding for conditioning
|
1093 |
+
|
1094 |
+
weight_nd = batch["input_ids_nd"][:, 0]
|
1095 |
+
weight_nd = weight_nd / 10**5
|
1096 |
+
batch["input_ids_nd"][:, 0] = 49406
|
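# The first token id was overwritten in tokenize_captions with int(avg_word_frequency * 10**5);
# it is read back here as a float weight (divided by 10**5) and the slot is then restored to
# the CLIP begin-of-text token id (49406) so the text encoder sees a valid token sequence.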
1097 |
+
encoder_hidden_states_nd = text_encoder(batch["input_ids_nd"])[0]
|
1098 |
+
encoder_hidden_states_bg = text_encoder(batch["input_ids_bg"])[0]
|
1099 |
+
controlnet_image_nd = batch["conditioning_pixel_values_nd"].to(dtype=weight_dtype)
|
1100 |
+
controlnet_image_bg = batch["conditioning_pixel_values_bg"].to(dtype=weight_dtype)
|
1101 |
+
# print(weight_nd)
|
1102 |
+
down_block_res_samples_nd, mid_block_res_sample_nd = controlnet_nd(
|
1103 |
+
noisy_latents,
|
1104 |
+
timesteps,
|
1105 |
+
encoder_hidden_states=encoder_hidden_states_nd, # text
|
1106 |
+
controlnet_cond=controlnet_image_nd,
|
1107 |
+
return_dict=False,
|
1108 |
+
weight=weight_nd)
|
1109 |
+
|
1110 |
+
|
1111 |
+
|
1112 |
+
down_block_res_samples_bg, mid_block_res_sample_bg = controlnet_bg(
|
1113 |
+
noisy_latents,
|
1114 |
+
timesteps,
|
1115 |
+
encoder_hidden_states=encoder_hidden_states_bg, # text
|
1116 |
+
controlnet_cond=controlnet_image_bg,
|
1117 |
+
return_dict=False)
|
1118 |
+
# Predict the noise residual
|
1119 |
+
samples_nd_list, samples_bg_list = [], []
|
1120 |
+
for number in range(len(down_block_res_samples_nd)):
|
1121 |
+
if number > 1 :
|
1122 |
+
sample = down_block_res_samples_nd[number]
|
1123 |
+
samples_nd = torch.stack((down_block_res_samples_nd[number][0].to('cpu'), \
|
1124 |
+
down_block_res_samples_nd[number][0].to('cpu')))
|
1125 |
+
samples_bg = torch.stack((down_block_res_samples_bg[number][0].to('cpu'), \
|
1126 |
+
down_block_res_samples_bg[number][0].to('cpu')))
|
1127 |
+
channels = sample.shape[1]
|
1128 |
+
model_fuse_down = fuse.AFF(channels=channels).to(device='cpu')
|
1129 |
+
output = model_fuse_down(samples_nd, samples_bg)[0].unsqueeze(0)
|
1130 |
+
|
1131 |
+
samples_nd_list.append(output)
|
1132 |
+
samples_bg_list.append(output)
|
1133 |
+
else:
|
1134 |
+
samples_nd_list.append(down_block_res_samples_nd[number])
|
1135 |
+
samples_bg_list.append(down_block_res_samples_bg[number])
|
1136 |
+
mid_block_res_sample = mid_block_res_sample_bg + mid_block_res_sample_nd
|
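# Fusion summary: for down-block residual index > 1 the nodule and background ControlNet
# features are combined on CPU by an attentional feature fusion (AFF) module and the fused
# tensor is appended to both branches; the first two residuals pass through unchanged, and
# the mid-block residuals of the two ControlNets are simply summed.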
1137 |
+
model_pred_nd = unet(
|
1138 |
+
noisy_latents,
|
1139 |
+
timesteps,
|
1140 |
+
encoder_hidden_states=encoder_hidden_states_nd.to('cuda'),
|
1141 |
+
down_block_additional_residuals=[
|
1142 |
+
sample.to(dtype=weight_dtype).to('cuda') for sample in samples_nd_list],
|
1143 |
+
mid_block_additional_residual=mid_block_res_sample.to('cuda').to(dtype=weight_dtype),
|
1144 |
+
).sample
|
1145 |
+
model_pred_bg = unet(
|
1146 |
+
noisy_latents,
|
1147 |
+
timesteps,
|
1148 |
+
encoder_hidden_states=encoder_hidden_states_bg.to('cuda'),
|
1149 |
+
down_block_additional_residuals=[
|
1150 |
+
sample.to(dtype=weight_dtype).to('cuda') for sample in samples_bg_list],
|
1151 |
+
mid_block_additional_residual=mid_block_res_sample.to('cuda').to(dtype=weight_dtype),
|
1152 |
+
).sample
|
1153 |
+
# Get the target for loss depending on the prediction type
|
1154 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
1155 |
+
target = noise
|
1156 |
+
elif noise_scheduler.config.prediction_type == "v_prediction": # use
|
1157 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
1158 |
+
else:
|
1159 |
+
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
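# With "epsilon" the UNet regresses the sampled noise directly; with "v_prediction" the
# target is v_t = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0, which is what
# noise_scheduler.get_velocity(latents, noise, timesteps) returns.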
1160 |
+
loss_nd = F.mse_loss(model_pred_nd.to('cuda').float(), target.float(), reduction="mean")
|
1161 |
+
loss_bg = F.mse_loss(model_pred_bg.to('cuda').float(), target.float(), reduction="mean")
|
1162 |
+
optimizer_nd.zero_grad(set_to_none=args.set_grads_to_none)
|
1163 |
+
optimizer_bg.zero_grad(set_to_none=args.set_grads_to_none)
|
1164 |
+
# h0, h1 = nvmlDeviceGetHandleByIndex(0), nvmlDeviceGetHandleByIndex(1)
|
1165 |
+
# info0, info1 = nvmlDeviceGetMemoryInfo(h0), nvmlDeviceGetMemoryInfo(h1)
|
1166 |
+
# print(f'0free : {info0.free} 1free : {info1.free}')
|
1167 |
+
loss = loss_nd + loss_bg
|
1168 |
+
accelerator.backward(loss)
|
1169 |
+
# loss_nd.backward()
|
1170 |
+
# loss_bg.backward()
|
1171 |
+
# if accelerator.sync_gradients:
|
1172 |
+
# params_to_clip_nd = controlnet_nd.parameters()
|
1173 |
+
# accelerator.clip_grad_norm_(params_to_clip_nd, args.max_grad_norm)
|
1174 |
+
# params_to_clip_bg = controlnet_bg.parameters()
|
1175 |
+
# accelerator.clip_grad_norm_(params_to_clip_bg, args.max_grad_norm)
|
1176 |
+
optimizer_nd.step()
|
1177 |
+
|
1178 |
+
optimizer_bg.step()
|
1179 |
+
|
1180 |
+
lr_scheduler.step()
|
1181 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
1182 |
+
if accelerator.sync_gradients:
|
1183 |
+
progress_bar.update(1)
|
1184 |
+
global_step += 1
|
1185 |
+
|
1186 |
+
if accelerator.is_main_process:
|
1187 |
+
if global_step % args.checkpointing_steps == 0:
|
1188 |
+
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
1189 |
+
if args.checkpoints_total_limit is not None:
|
1190 |
+
checkpoints = os.listdir(args.output_dir)
|
1191 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
1192 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
1193 |
+
|
1194 |
+
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
1195 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
1196 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
1197 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
1198 |
+
|
1199 |
+
logger.info(
|
1200 |
+
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
1201 |
+
)
|
1202 |
+
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
1203 |
+
|
1204 |
+
for removing_checkpoint in removing_checkpoints:
|
1205 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
1206 |
+
shutil.rmtree(removing_checkpoint)
|
1207 |
+
|
1208 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
1209 |
+
accelerator.save_state(save_path)
|
1210 |
+
logger.info(f"Saved state to {save_path}")
|
1211 |
+
|
1212 |
+
# if args.validation_prompt is not None :
|
1213 |
+
# image_logs = log_validation(
|
1214 |
+
# vae,
|
1215 |
+
# text_encoder,
|
1216 |
+
# tokenizer,
|
1217 |
+
# unet,
|
1218 |
+
# controlnet_nd,
|
1219 |
+
# controlnet_bg,
|
1220 |
+
# args,
|
1221 |
+
# accelerator,
|
1222 |
+
# weight_dtype,
|
1223 |
+
# global_step,
|
1224 |
+
# )
|
1225 |
+
|
1226 |
+
logs = {"loss": loss.detach().item()}
|
1227 |
+
progress_bar.set_postfix(**logs)
|
1228 |
+
accelerator.log(logs, step=global_step)
|
1229 |
+
|
1230 |
+
if global_step >= args.max_train_steps:
|
1231 |
+
break
|
1232 |
+
|
1233 |
+
# Create the pipeline using the trained modules and save it.
|
1234 |
+
# accelerator.wait_for_everyone()
|
1235 |
+
if accelerator.is_main_process:
|
1236 |
+
controlnet_nd = accelerator.unwrap_model(controlnet_nd)
|
1237 |
+
controlnet_nd.save_pretrained(args.output_dir)
|
1238 |
+
controlnet_bg = accelerator.unwrap_model(controlnet_bg)
|
1239 |
+
controlnet_bg.save_pretrained(args.output_dir)
|
1240 |
+
|
1241 |
+
accelerator.end_training()
|
1242 |
+
|
1243 |
+
|
1244 |
+
if __name__ == "__main__":
|
1245 |
+
args = parse_args()
|
1246 |
+
main(args)
|
Tiger Model/GP.py
ADDED
@@ -0,0 +1,266 @@
1 |
+
# Copyright 2024 Hui Lu, Fang Dai, Siqiong Yao.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
# import os
|
17 |
+
# import torch
|
18 |
+
# import numpy as np
|
19 |
+
# import torchvision.transforms as transforms
|
20 |
+
# from torch.utils.data import DataLoader, Dataset
|
21 |
+
# from PIL import Image
|
22 |
+
# from gtda.images import Binarizer, HeightFiltration
|
23 |
+
# from gtda.homology import CubicalPersistence
|
24 |
+
# from gtda.diagrams import Amplitude
|
25 |
+
# from sklearn.metrics import pairwise_distances
|
26 |
+
|
27 |
+
|
28 |
+
# transform = transforms.Compose([
|
29 |
+
# transforms.Resize((256, 256)),
|
30 |
+
# transforms.ToTensor(),
|
31 |
+
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
32 |
+
# ])
|
33 |
+
|
34 |
+
|
35 |
+
# class ImageFolderDataset(Dataset):
|
36 |
+
# def __init__(self, folder_path, transform=None):
|
37 |
+
# self.file_paths = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.endswith('.png')]
|
38 |
+
# self.transform = transform
|
39 |
+
|
40 |
+
# def __len__(self):
|
41 |
+
# return len(self.file_paths)
|
42 |
+
|
43 |
+
# def __getitem__(self, idx):
|
44 |
+
# img_path = self.file_paths[idx]
|
45 |
+
# image = Image.open(img_path).convert('RGB')
|
46 |
+
# if self.transform:
|
47 |
+
# image = self.transform(image)
|
48 |
+
# return image
|
49 |
+
|
50 |
+
# def load_data(folder_path):
|
51 |
+
# dataset = ImageFolderDataset(folder_path, transform=transform)
|
52 |
+
# loader = DataLoader(dataset, batch_size=10, shuffle=False)
|
53 |
+
# return loader
|
54 |
+
|
55 |
+
# # Compute the Diversity Score
|
56 |
+
# def calculate_diversity_score(features):
|
57 |
+
# distances = pairwise_distances(features, metric='euclidean')
|
58 |
+
# diversity_score = np.mean(distances)
|
59 |
+
# return diversity_score
|
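A small numeric check of the diversity score defined above (mean of the full pairwise Euclidean distance matrix, diagonal included), kept commented out like the rest of this file:
# import numpy as np
# from sklearn.metrics import pairwise_distances
# feats = np.array([[0.0, 0.0], [3.0, 4.0], [3.0, 0.0]])        # pairwise distances 5, 3 and 4
# print(pairwise_distances(feats, metric='euclidean').mean())   # 24 / 9 ≈ 2.67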
60 |
+
|
61 |
+
# # 计算Geometry Score
|
62 |
+
# def calculate_geometry_score(images):
|
63 |
+
# binarizer = Binarizer(threshold=0.5)
|
64 |
+
# height_filtration = HeightFiltration(direction=np.array([1, 1, 1]))
|
65 |
+
# cubical_persistence = CubicalPersistence(homology_dimensions=[0, 1], coeff=2)
|
66 |
+
# amplitude = Amplitude(metric='wasserstein', metric_params={'p': 2})
|
67 |
+
|
68 |
+
# # Preprocess images
|
69 |
+
# images = np.array([img.numpy() if isinstance(img, torch.Tensor) else img for img in images])
|
70 |
+
# images_binarized = binarizer.fit_transform(images)
|
71 |
+
# images_filtered = height_filtration.fit_transform(images_binarized)
|
72 |
+
# diagrams = cubical_persistence.fit_transform(images_filtered)
|
73 |
+
# gs_score = amplitude.fit_transform(diagrams)
|
74 |
+
# return gs_score.mean()
|
75 |
+
|
76 |
+
|
77 |
+
# generated_images_loader = load_data('../figure/1')
|
78 |
+
# real_images_loader = load_data('../figure/2')
|
79 |
+
|
80 |
+
|
81 |
+
# generated_features = []
|
82 |
+
# real_features = []
|
83 |
+
|
84 |
+
# for img_batch in generated_images_loader:
|
85 |
+
# generated_features.extend(img_batch.numpy())
|
86 |
+
|
87 |
+
# for img_batch in real_images_loader:
|
88 |
+
# real_features.extend(img_batch.numpy())
|
89 |
+
|
90 |
+
# generated_features = np.array(generated_features)
|
91 |
+
# real_features = np.array(real_features)
|
92 |
+
|
93 |
+
# # 计算Diversity Score
|
94 |
+
# generated_div_score = calculate_diversity_score(generated_features.reshape(len(generated_features), -1))
|
95 |
+
# real_div_score = calculate_diversity_score(real_features.reshape(len(real_features), -1))
|
96 |
+
|
97 |
+
# # 计算Geometry Score
|
98 |
+
# generated_gs_score = calculate_geometry_score(generated_features)
|
99 |
+
# real_gs_score = calculate_geometry_score(real_features)
|
100 |
+
|
101 |
+
# print(f"Generated Images Diversity Score: {generated_div_score}")
|
102 |
+
# print(f"Real Images Diversity Score: {real_div_score}")
|
103 |
+
# print(f"Generated Images Geometry Score: {generated_gs_score}")
|
104 |
+
# print(f"Real Images Geometry Score: {real_gs_score}")
|
105 |
+
|
106 |
+
|
107 |
+
# import torch
|
108 |
+
# import torch.nn.functional as F
|
109 |
+
# from torchvision import transforms
|
110 |
+
# from PIL import Image
|
111 |
+
# import numpy as np
|
112 |
+
# import os
|
113 |
+
|
114 |
+
# # Function to load and preprocess images
|
115 |
+
# def load_and_preprocess_image(img_path):
|
116 |
+
# img = Image.open(img_path).convert('RGB')
|
117 |
+
# preprocess = transforms.Compose([
|
118 |
+
# transforms.ToTensor(),
|
119 |
+
# ])
|
120 |
+
# img = preprocess(img).unsqueeze(0) # Add batch dimension
|
121 |
+
# return img
|
122 |
+
|
123 |
+
# # Function to compute image gradients
|
124 |
+
# def compute_gradients(img):
|
125 |
+
# grad_x = img[:, :, 1:, :] - img[:, :, :-1, :]
|
126 |
+
# grad_y = img[:, :, :, 1:] - img[:, :, :, :-1]
|
127 |
+
# return grad_x, grad_y
|
128 |
+
|
129 |
+
# # Function to calculate Gradient Similarity (GS)
|
130 |
+
# def gradient_similarity(real_img, gen_img):
|
131 |
+
# real_grad_x, real_grad_y = compute_gradients(real_img)
|
132 |
+
# gen_grad_x, gen_grad_y = compute_gradients(gen_img)
|
133 |
+
|
134 |
+
# grad_sim_x = F.cosine_similarity(real_grad_x, gen_grad_x, dim=1).mean()
|
135 |
+
# grad_sim_y = F.cosine_similarity(real_grad_y, gen_grad_y, dim=1).mean()
|
136 |
+
|
137 |
+
# gs = (grad_sim_x + grad_sim_y) / 2.0
|
138 |
+
# return gs.item()
|
139 |
+
|
140 |
+
# # Example usage
|
141 |
+
# real_img_dir = '../GS/real' # Replace with your real image directory
|
142 |
+
# gen_img_dir = '../GS/fake' # Replace with your generated image directory
|
143 |
+
|
144 |
+
# real_img_paths = [os.path.join(real_img_dir, fname) for fname in os.listdir(real_img_dir) if fname.endswith(('jpg', 'jpeg', 'png'))]
|
145 |
+
# gen_img_paths = [os.path.join(gen_img_dir, fname) for fname in os.listdir(gen_img_dir) if fname.endswith(('jpg', 'jpeg', 'png'))]
|
146 |
+
|
147 |
+
# # Ensure both directories have the same number of images
|
148 |
+
# assert len(real_img_paths) == len(gen_img_paths), "The number of images in both directories must be the same"
|
149 |
+
|
150 |
+
# gs_scores = []
|
151 |
+
|
152 |
+
# for real_img_path, gen_img_path in zip(real_img_paths, gen_img_paths):
|
153 |
+
# real_img = load_and_preprocess_image(real_img_path)
|
154 |
+
# gen_img = load_and_preprocess_image(gen_img_path)
|
155 |
+
|
156 |
+
# gs = gradient_similarity(real_img, gen_img)
|
157 |
+
# gs_scores.append(gs)
|
158 |
+
|
159 |
+
# print(f'Processed {real_img_path} and {gen_img_path}: GS = {gs:.3e}')
|
160 |
+
|
161 |
+
# mean_gs = np.mean(gs_scores)
|
162 |
+
# print(f'Mean Gradient Similarity (GS) score: {mean_gs:.3e}')
|
163 |
+
|
164 |
+
|
165 |
+
|
166 |
+
import torch
|
167 |
+
import torch.nn as nn
|
168 |
+
from torchvision import models, transforms
|
169 |
+
from PIL import Image
|
170 |
+
import numpy as np
|
171 |
+
import os
|
172 |
+
from prdc import compute_prdc
|
173 |
+
|
174 |
+
# Function to load and preprocess images
|
175 |
+
def load_and_preprocess_image(img_path):
|
176 |
+
img = Image.open(img_path).convert('RGB')
|
177 |
+
preprocess = transforms.Compose([
|
178 |
+
transforms.Resize((299, 299)),
|
179 |
+
transforms.ToTensor(),
|
180 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
181 |
+
])
|
182 |
+
img = preprocess(img).unsqueeze(0) # Add batch dimension
|
183 |
+
return img
|
184 |
+
|
185 |
+
# Function to extract features using InceptionV3
|
186 |
+
def extract_features(img_paths, model):
|
187 |
+
features = []
|
188 |
+
with torch.no_grad():
|
189 |
+
for img_path in img_paths:
|
190 |
+
img = load_and_preprocess_image(img_path)
|
191 |
+
feature = model(img).numpy().squeeze()
|
192 |
+
features.append(feature)
|
193 |
+
features = np.array(features)
|
194 |
+
return features
|
195 |
+
|
196 |
+
# Load the InceptionV3 model
|
197 |
+
model = models.resnet18(pretrained=False)
|
198 |
+
model.load_state_dict(torch.load('../modelsaved/Pretrained_InceptionV3.pth', map_location=lambda storage, loc: storage),strict=False)
|
199 |
+
model.fc = nn.Identity() # Remove the final classification layer
|
200 |
+
model.eval()
|
201 |
+
|
202 |
+
# Example usage
|
203 |
+
real_img_dir = '../dataset/1' # Replace with your real image directory
|
204 |
+
gen_img_dir = '../dataset/2' # Replace with your generated image directory
|
205 |
+
|
206 |
+
real_img_paths = [os.path.join(real_img_dir, fname) for fname in os.listdir(real_img_dir) if fname.endswith(('jpg', 'jpeg', 'png'))]
|
207 |
+
gen_img_paths = [os.path.join(gen_img_dir, fname) for fname in os.listdir(gen_img_dir) if fname.endswith(('jpg', 'jpeg', 'png'))]
|
208 |
+
|
209 |
+
# Extract features for real and generated images
|
210 |
+
real_features = extract_features(real_img_paths, model)
|
211 |
+
gen_features = extract_features(gen_img_paths, model)
|
212 |
+
|
213 |
+
# Calculate PRDC metrics
|
214 |
+
metrics = compute_prdc(real_features=real_features,
|
215 |
+
fake_features=gen_features,
|
216 |
+
nearest_k=2)
|
217 |
+
|
218 |
+
print(metrics)
|
219 |
+
|
220 |
+
|
221 |
+
|
222 |
+
# import torch
|
223 |
+
# from torch import nn
|
224 |
+
# from clip import clip
|
225 |
+
# import numpy as np
|
226 |
+
|
227 |
+
|
228 |
+
# clip_model, preprocess = clip.load("ViT-L/14@336px", device="cuda")
|
229 |
+
|
230 |
+
|
231 |
+
# def get_clip_embedding(images):
|
232 |
+
# with torch.no_grad():
|
233 |
+
# images = preprocess(images).unsqueeze(0).to("cuda")
|
234 |
+
# image_features = clip_model.encode_image(images)
|
235 |
+
# return image_features
|
236 |
+
|
237 |
+
|
238 |
+
# def compute_mmd(x, y, kernel):
|
239 |
+
|
240 |
+
# xx = kernel(x, x)
|
241 |
+
# yy = kernel(y, y)
|
242 |
+
# xy = kernel(x, y)
|
243 |
+
|
244 |
+
# mmd = torch.mean(xx) + torch.mean(yy) - 2 * torch.mean(xy)
|
245 |
+
# return mmd
|
246 |
+
|
247 |
+
# def gaussian_rbf_kernel(x, y, sigma=1.0):
|
248 |
+
|
249 |
+
# dist = torch.cdist(x, y, p=2.0)
|
250 |
+
|
251 |
+
# return torch.exp(-dist**2 / (2 * sigma**2))
|
252 |
+
|
253 |
+
|
254 |
+
# real_images = ...
|
255 |
+
# generated_images = ...
|
256 |
+
|
257 |
+
|
258 |
+
# real_features = get_clip_embedding(real_images)
|
259 |
+
# generated_features = get_clip_embedding(generated_images)
|
260 |
+
|
261 |
+
|
262 |
+
# sigma = 1.0
|
263 |
+
# mmd = compute_mmd(real_features, generated_features, lambda x, y: gaussian_rbf_kernel(x, y, sigma))
|
264 |
+
# cmmd = mmd * 1000
|
265 |
+
|
266 |
+
# print(f"CMMD: {cmmd.item()}")
|
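For reference, a minimal sketch of how the `compute_prdc` call above behaves, using random feature matrices in place of the extracted backbone features; the shapes and `nearest_k` value here are illustrative, not the repository's settings. The call returns a dictionary of the four fidelity/diversity metrics.

import numpy as np
from prdc import compute_prdc

# Stand-in feature matrices: 100 "real" and 100 "generated" 512-dim vectors.
rng = np.random.default_rng(0)
real_features = rng.normal(size=(100, 512))
fake_features = rng.normal(size=(100, 512))

# Returns {'precision': ..., 'recall': ..., 'density': ..., 'coverage': ...}.
metrics = compute_prdc(real_features=real_features,
                       fake_features=fake_features,
                       nearest_k=5)
print(metrics)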
Tiger Model/IS.py
ADDED
@@ -0,0 +1,109 @@
# Copyright 2024 Hui Lu, Fang Dai, Siqiong Yao.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datasets import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torch.autograd import Variable
from scipy.stats import entropy
from torchvision.models.inception import inception_v3

import os
import glob
import random

from PIL import Image


class ISImageDataset(Dataset):
    def __init__(self, root, transforms_=None):
        self.transform = transforms.Compose(transforms_)
        self.files = sorted(glob.glob(os.path.join(root) + "/*.png"))

    def __getitem__(self, index):
        img = Image.open(self.files[index % len(self.files)]).convert('RGB')
        item_image = self.transform(img)
        return item_image

    def __len__(self):
        return len(self.files)


# Count the images under the evaluation directory
path = '.../Figure/'
count = 0
for root, dirs, files in os.walk(path):
    for each in files:
        count += 1
print(count)

batch_size = 64
transforms_ = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
]

val_dataloader = DataLoader(
    ISImageDataset(path, transforms_=transforms_),
    batch_size=batch_size,
)

cuda = True if torch.cuda.is_available() else False
print('cuda: ', cuda)
tensor = torch.cuda.FloatTensor

inception_model = inception_v3(pretrained=True, transform_input=False).cuda()
inception_model.eval()
up = nn.Upsample(size=(299, 299), mode='bilinear', align_corners=False).cuda()


def get_pred(x):
    # Resize to the 299x299 input expected by Inception v3, then softmax over the 1000 classes
    x = up(x)
    x = inception_model(x)
    return F.softmax(x, dim=1).data.cpu().numpy()


print('Computing predictions using inception v3 model')
preds = np.zeros((count, 1000))

for i, data in enumerate(val_dataloader):
    data = data.type(tensor)
    batch_size_i = data.size()[0]
    preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(data)

print('Computing KL Divergence')
split_scores = []
splits = 10
N = count
for k in range(splits):
    part = preds[k * (N // splits): (k + 1) * (N // splits), :]
    py = np.mean(part, axis=0)
    scores = []
    for i in range(part.shape[0]):
        pyx = part[i, :]
        scores.append(entropy(pyx, py))
    split_scores.append(np.exp(np.mean(scores)))

mean, std = np.mean(split_scores), np.std(split_scores)
print('IS is %.4f' % mean)
print('The std is %.4f' % std)
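For reference, the per-split quantity computed above is the standard Inception Score, IS = exp(E_x[KL(p(y|x) || p(y))]). A minimal sketch on a toy class-probability matrix; the numbers are made up purely to illustrate the computation.

import numpy as np
from scipy.stats import entropy

# Toy p(y|x): 6 "images" over 4 classes; confident, varied rows raise the score.
preds = np.array([
    [0.90, 0.05, 0.03, 0.02],
    [0.05, 0.90, 0.03, 0.02],
    [0.02, 0.03, 0.90, 0.05],
    [0.02, 0.03, 0.05, 0.90],
    [0.25, 0.25, 0.25, 0.25],
    [0.40, 0.30, 0.20, 0.10],
])

py = preds.mean(axis=0)                          # marginal distribution p(y)
kls = [entropy(p_yx, py) for p_yx in preds]      # KL(p(y|x) || p(y)) per image
print('IS =', np.exp(np.mean(kls)))              # exponentiated mean KL divergence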
Tiger Model/diffusiers-Tiger/CLIPTextModel.py
ADDED
@@ -0,0 +1,1326 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 Hui Lu, Fang Dai, Siqiong Yao.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
""" PyTorch CLIP model."""
|
18 |
+
|
19 |
+
|
20 |
+
from dataclasses import dataclass
|
21 |
+
from typing import Any, Optional, Tuple, Union
|
22 |
+
|
23 |
+
import torch
|
24 |
+
import torch.utils.checkpoint
|
25 |
+
from torch import nn
|
26 |
+
|
27 |
+
from ...activations import ACT2FN
|
28 |
+
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
29 |
+
from ...modeling_utils import PreTrainedModel
|
30 |
+
from ...utils import (
|
31 |
+
ModelOutput,
|
32 |
+
add_start_docstrings,
|
33 |
+
add_start_docstrings_to_model_forward,
|
34 |
+
logging,
|
35 |
+
replace_return_docstrings,
|
36 |
+
)
|
37 |
+
from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
|
38 |
+
|
39 |
+
|
40 |
+
logger = logging.get_logger(__name__)
|
41 |
+
|
42 |
+
_CHECKPOINT_FOR_DOC = "openai/clip-vit-base-patch32"
|
43 |
+
|
44 |
+
CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
45 |
+
"openai/clip-vit-base-patch32",
|
46 |
+
# See all CLIP models at https://huggingface.co/models?filter=clip
|
47 |
+
]
|
48 |
+
|
49 |
+
|
50 |
+
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
51 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
52 |
+
"""
|
53 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
54 |
+
"""
|
55 |
+
bsz, src_len = mask.size()
|
56 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
57 |
+
|
58 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
59 |
+
|
60 |
+
inverted_mask = 1.0 - expanded_mask
|
61 |
+
|
62 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
63 |
+
|
64 |
+
|
65 |
+
# contrastive loss function, adapted from
|
66 |
+
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
|
67 |
+
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
|
68 |
+
return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
|
69 |
+
|
70 |
+
|
71 |
+
def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
|
72 |
+
caption_loss = contrastive_loss(similarity)
|
73 |
+
image_loss = contrastive_loss(similarity.t())
|
74 |
+
return (caption_loss + image_loss) / 2.0
|
75 |
+
|
76 |
+
|
77 |
+
@dataclass
|
78 |
+
class CLIPVisionModelOutput(ModelOutput):
|
79 |
+
"""
|
80 |
+
Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
|
84 |
+
The image embeddings obtained by applying the projection layer to the pooler_output.
|
85 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
86 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
87 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
88 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
89 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
90 |
+
|
91 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
92 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
93 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
94 |
+
sequence_length)`.
|
95 |
+
|
96 |
+
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
97 |
+
heads.
|
98 |
+
"""
|
99 |
+
|
100 |
+
image_embeds: Optional[torch.FloatTensor] = None
|
101 |
+
last_hidden_state: torch.FloatTensor = None
|
102 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
103 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
104 |
+
|
105 |
+
|
106 |
+
@dataclass
|
107 |
+
class CLIPTextModelOutput(ModelOutput):
|
108 |
+
"""
|
109 |
+
Base class for text model's outputs that also contains a pooling of the last hidden states.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
|
113 |
+
The text embeddings obtained by applying the projection layer to the pooler_output.
|
114 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
115 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
116 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
117 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
118 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
119 |
+
|
120 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
121 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
122 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
123 |
+
sequence_length)`.
|
124 |
+
|
125 |
+
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
126 |
+
heads.
|
127 |
+
"""
|
128 |
+
|
129 |
+
text_embeds: Optional[torch.FloatTensor] = None
|
130 |
+
last_hidden_state: torch.FloatTensor = None
|
131 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
132 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
133 |
+
|
134 |
+
|
135 |
+
@dataclass
|
136 |
+
class CLIPOutput(ModelOutput):
|
137 |
+
"""
|
138 |
+
Args:
|
139 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
|
140 |
+
Contrastive loss for image-text similarity.
|
141 |
+
logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
|
142 |
+
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
|
143 |
+
similarity scores.
|
144 |
+
logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
|
145 |
+
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
|
146 |
+
similarity scores.
|
147 |
+
text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
148 |
+
The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
|
149 |
+
image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
150 |
+
The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
|
151 |
+
text_model_output(`BaseModelOutputWithPooling`):
|
152 |
+
The output of the [`CLIPTextModel`].
|
153 |
+
vision_model_output(`BaseModelOutputWithPooling`):
|
154 |
+
The output of the [`CLIPVisionModel`].
|
155 |
+
"""
|
156 |
+
|
157 |
+
loss: Optional[torch.FloatTensor] = None
|
158 |
+
logits_per_image: torch.FloatTensor = None
|
159 |
+
logits_per_text: torch.FloatTensor = None
|
160 |
+
text_embeds: torch.FloatTensor = None
|
161 |
+
image_embeds: torch.FloatTensor = None
|
162 |
+
text_model_output: BaseModelOutputWithPooling = None
|
163 |
+
vision_model_output: BaseModelOutputWithPooling = None
|
164 |
+
|
165 |
+
def to_tuple(self) -> Tuple[Any]:
|
166 |
+
return tuple(
|
167 |
+
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
|
168 |
+
for k in self.keys()
|
169 |
+
)
|
170 |
+
|
171 |
+
|
172 |
+
class CLIPVisionEmbeddings(nn.Module):
|
173 |
+
def __init__(self, config: CLIPVisionConfig):
|
174 |
+
super().__init__()
|
175 |
+
self.config = config
|
176 |
+
self.embed_dim = config.hidden_size
|
177 |
+
self.image_size = config.image_size
|
178 |
+
self.patch_size = config.patch_size
|
179 |
+
|
180 |
+
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
|
181 |
+
|
182 |
+
self.patch_embedding = nn.Conv2d(
|
183 |
+
in_channels=config.num_channels,
|
184 |
+
out_channels=self.embed_dim,
|
185 |
+
kernel_size=self.patch_size,
|
186 |
+
stride=self.patch_size,
|
187 |
+
bias=False,
|
188 |
+
)
|
189 |
+
|
190 |
+
self.num_patches = (self.image_size // self.patch_size) ** 2
|
191 |
+
self.num_positions = self.num_patches + 1
|
192 |
+
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
193 |
+
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
|
194 |
+
|
195 |
+
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
196 |
+
batch_size = pixel_values.shape[0]
|
197 |
+
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
|
198 |
+
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
199 |
+
|
200 |
+
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
201 |
+
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
202 |
+
embeddings = embeddings + self.position_embedding(self.position_ids)
|
203 |
+
return embeddings
|
204 |
+
|
205 |
+
|
206 |
+
class CLIPTextEmbeddings(nn.Module):
|
207 |
+
def __init__(self, config: CLIPTextConfig):
|
208 |
+
super().__init__()
|
209 |
+
embed_dim = config.hidden_size
|
210 |
+
|
211 |
+
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
|
212 |
+
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
|
213 |
+
|
214 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
215 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
216 |
+
|
217 |
+
def forward(
|
218 |
+
self,
|
219 |
+
input_ids: Optional[torch.LongTensor] = None,
|
220 |
+
position_ids: Optional[torch.LongTensor] = None,
|
221 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
222 |
+
) -> torch.Tensor:
|
223 |
+
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
|
224 |
+
|
225 |
+
if position_ids is None:
|
226 |
+
position_ids = self.position_ids[:, :seq_length]
|
227 |
+
|
228 |
+
if inputs_embeds is None:
|
229 |
+
inputs_embeds = self.token_embedding(input_ids)
|
230 |
+
|
231 |
+
position_embeddings = self.position_embedding(position_ids)
|
232 |
+
embeddings = inputs_embeds + position_embeddings
|
233 |
+
|
234 |
+
return embeddings
|
235 |
+
|
236 |
+
|
237 |
+
class CLIPAttention(nn.Module):
|
238 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
239 |
+
|
240 |
+
def __init__(self, config):
|
241 |
+
super().__init__()
|
242 |
+
self.config = config
|
243 |
+
self.embed_dim = config.hidden_size
|
244 |
+
self.num_heads = config.num_attention_heads
|
245 |
+
self.head_dim = self.embed_dim // self.num_heads
|
246 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
247 |
+
raise ValueError(
|
248 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
249 |
+
f" {self.num_heads})."
|
250 |
+
)
|
251 |
+
self.scale = self.head_dim**-0.5
|
252 |
+
self.dropout = config.attention_dropout
|
253 |
+
|
254 |
+
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
255 |
+
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
256 |
+
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
257 |
+
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
258 |
+
|
259 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
260 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
261 |
+
|
262 |
+
def forward(
|
263 |
+
self,
|
264 |
+
hidden_states: torch.Tensor,
|
265 |
+
attention_mask: Optional[torch.Tensor] = None,
|
266 |
+
causal_attention_mask: Optional[torch.Tensor] = None,
|
267 |
+
output_attentions: Optional[bool] = False,
|
268 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
269 |
+
"""Input shape: Batch x Time x Channel"""
|
270 |
+
|
271 |
+
bsz, tgt_len, embed_dim = hidden_states.size()
|
272 |
+
|
273 |
+
# get query proj
|
274 |
+
query_states = self.q_proj(hidden_states) * self.scale
|
275 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
276 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
277 |
+
|
278 |
+
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
279 |
+
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
280 |
+
key_states = key_states.view(*proj_shape)
|
281 |
+
value_states = value_states.view(*proj_shape)
|
282 |
+
|
283 |
+
src_len = key_states.size(1)
|
284 |
+
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
285 |
+
|
286 |
+
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
287 |
+
raise ValueError(
|
288 |
+
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
289 |
+
f" {attn_weights.size()}"
|
290 |
+
)
|
291 |
+
|
292 |
+
# apply the causal_attention_mask first
|
293 |
+
if causal_attention_mask is not None:
|
294 |
+
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
295 |
+
raise ValueError(
|
296 |
+
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
297 |
+
f" {causal_attention_mask.size()}"
|
298 |
+
)
|
299 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
|
300 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
301 |
+
|
302 |
+
if attention_mask is not None:
|
303 |
+
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
304 |
+
raise ValueError(
|
305 |
+
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
306 |
+
)
|
307 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
308 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
309 |
+
|
310 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
311 |
+
|
312 |
+
if output_attentions:
|
313 |
+
# this operation is a bit awkward, but it's required to
|
314 |
+
# make sure that attn_weights keeps its gradient.
|
315 |
+
# In order to do so, attn_weights have to be reshaped
|
316 |
+
# twice and have to be reused in the following
|
317 |
+
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
318 |
+
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
319 |
+
else:
|
320 |
+
attn_weights_reshaped = None
|
321 |
+
|
322 |
+
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
323 |
+
|
324 |
+
attn_output = torch.bmm(attn_probs, value_states)
|
325 |
+
|
326 |
+
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
327 |
+
raise ValueError(
|
328 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
329 |
+
f" {attn_output.size()}"
|
330 |
+
)
|
331 |
+
|
332 |
+
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
333 |
+
attn_output = attn_output.transpose(1, 2)
|
334 |
+
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
335 |
+
|
336 |
+
attn_output = self.out_proj(attn_output)
|
337 |
+
|
338 |
+
return attn_output, attn_weights_reshaped
|
339 |
+
|
340 |
+
|
341 |
+
class CLIPMLP(nn.Module):
|
342 |
+
def __init__(self, config):
|
343 |
+
super().__init__()
|
344 |
+
self.config = config
|
345 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
346 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
347 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
348 |
+
|
349 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
350 |
+
hidden_states = self.fc1(hidden_states)
|
351 |
+
hidden_states = self.activation_fn(hidden_states)
|
352 |
+
hidden_states = self.fc2(hidden_states)
|
353 |
+
return hidden_states
|
354 |
+
|
355 |
+
|
356 |
+
class CLIPEncoderLayer(nn.Module):
|
357 |
+
def __init__(self, config: CLIPConfig):
|
358 |
+
super().__init__()
|
359 |
+
self.embed_dim = config.hidden_size
|
360 |
+
self.self_attn = CLIPAttention(config)
|
361 |
+
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
362 |
+
self.mlp = CLIPMLP(config)
|
363 |
+
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
364 |
+
|
365 |
+
def forward(
|
366 |
+
self,
|
367 |
+
hidden_states: torch.Tensor,
|
368 |
+
attention_mask: torch.Tensor,
|
369 |
+
causal_attention_mask: torch.Tensor,
|
370 |
+
output_attentions: Optional[bool] = False,
|
371 |
+
) -> Tuple[torch.FloatTensor]:
|
372 |
+
"""
|
373 |
+
Args:
|
374 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
375 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
376 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
377 |
+
`(config.encoder_attention_heads,)`.
|
378 |
+
output_attentions (`bool`, *optional*):
|
379 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
380 |
+
returned tensors for more detail.
|
381 |
+
"""
|
382 |
+
residual = hidden_states
|
383 |
+
|
384 |
+
hidden_states = self.layer_norm1(hidden_states)
|
385 |
+
hidden_states, attn_weights = self.self_attn(
|
386 |
+
hidden_states=hidden_states,
|
387 |
+
attention_mask=attention_mask,
|
388 |
+
causal_attention_mask=causal_attention_mask,
|
389 |
+
output_attentions=output_attentions,
|
390 |
+
)
|
391 |
+
hidden_states = residual + hidden_states
|
392 |
+
|
393 |
+
residual = hidden_states
|
394 |
+
hidden_states = self.layer_norm2(hidden_states)
|
395 |
+
hidden_states = self.mlp(hidden_states)
|
396 |
+
hidden_states = residual + hidden_states
|
397 |
+
|
398 |
+
outputs = (hidden_states,)
|
399 |
+
|
400 |
+
if output_attentions:
|
401 |
+
outputs += (attn_weights,)
|
402 |
+
|
403 |
+
return outputs
|
404 |
+
|
405 |
+
|
406 |
+
class CLIPPreTrainedModel(PreTrainedModel):
|
407 |
+
"""
|
408 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
409 |
+
models.
|
410 |
+
"""
|
411 |
+
|
412 |
+
config_class = CLIPConfig
|
413 |
+
base_model_prefix = "clip"
|
414 |
+
supports_gradient_checkpointing = True
|
415 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
416 |
+
|
417 |
+
def _init_weights(self, module):
|
418 |
+
"""Initialize the weights"""
|
419 |
+
factor = self.config.initializer_factor
|
420 |
+
if isinstance(module, CLIPTextEmbeddings):
|
421 |
+
module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
422 |
+
module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
423 |
+
elif isinstance(module, CLIPVisionEmbeddings):
|
424 |
+
factor = self.config.initializer_factor
|
425 |
+
nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
|
426 |
+
nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
|
427 |
+
nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
|
428 |
+
elif isinstance(module, CLIPAttention):
|
429 |
+
factor = self.config.initializer_factor
|
430 |
+
in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
|
431 |
+
out_proj_std = (module.embed_dim**-0.5) * factor
|
432 |
+
nn.init.normal_(module.q_proj.weight, std=in_proj_std)
|
433 |
+
nn.init.normal_(module.k_proj.weight, std=in_proj_std)
|
434 |
+
nn.init.normal_(module.v_proj.weight, std=in_proj_std)
|
435 |
+
nn.init.normal_(module.out_proj.weight, std=out_proj_std)
|
436 |
+
elif isinstance(module, CLIPMLP):
|
437 |
+
factor = self.config.initializer_factor
|
438 |
+
in_proj_std = (
|
439 |
+
(module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
|
440 |
+
)
|
441 |
+
fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
|
442 |
+
nn.init.normal_(module.fc1.weight, std=fc_std)
|
443 |
+
nn.init.normal_(module.fc2.weight, std=in_proj_std)
|
444 |
+
elif isinstance(module, CLIPModel):
|
445 |
+
nn.init.normal_(
|
446 |
+
module.text_projection.weight,
|
447 |
+
std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
|
448 |
+
)
|
449 |
+
nn.init.normal_(
|
450 |
+
module.visual_projection.weight,
|
451 |
+
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
|
452 |
+
)
|
453 |
+
elif isinstance(module, CLIPVisionModelWithProjection):
|
454 |
+
nn.init.normal_(
|
455 |
+
module.visual_projection.weight,
|
456 |
+
std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
|
457 |
+
)
|
458 |
+
elif isinstance(module, CLIPTextModelWithProjection):
|
459 |
+
nn.init.normal_(
|
460 |
+
module.text_projection.weight,
|
461 |
+
std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
|
462 |
+
)
|
463 |
+
|
464 |
+
if isinstance(module, nn.LayerNorm):
|
465 |
+
module.bias.data.zero_()
|
466 |
+
module.weight.data.fill_(1.0)
|
467 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
468 |
+
module.bias.data.zero_()
|
469 |
+
|
470 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
471 |
+
if isinstance(module, CLIPEncoder):
|
472 |
+
module.gradient_checkpointing = value
|
473 |
+
|
474 |
+
|
475 |
+
CLIP_START_DOCSTRING = r"""
|
476 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
477 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
478 |
+
etc.)
|
479 |
+
|
480 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
481 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
482 |
+
and behavior.
|
483 |
+
|
484 |
+
Parameters:
|
485 |
+
config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
|
486 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
487 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
488 |
+
"""
|
489 |
+
|
490 |
+
CLIP_TEXT_INPUTS_DOCSTRING = r"""
|
491 |
+
Args:
|
492 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
493 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
494 |
+
it.
|
495 |
+
|
496 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
497 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
498 |
+
|
499 |
+
[What are input IDs?](../glossary#input-ids)
|
500 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
501 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
502 |
+
|
503 |
+
- 1 for tokens that are **not masked**,
|
504 |
+
- 0 for tokens that are **masked**.
|
505 |
+
|
506 |
+
[What are attention masks?](../glossary#attention-mask)
|
507 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
508 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
509 |
+
config.max_position_embeddings - 1]`.
|
510 |
+
|
511 |
+
[What are position IDs?](../glossary#position-ids)
|
512 |
+
output_attentions (`bool`, *optional*):
|
513 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
514 |
+
tensors for more detail.
|
515 |
+
output_hidden_states (`bool`, *optional*):
|
516 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
517 |
+
more detail.
|
518 |
+
return_dict (`bool`, *optional*):
|
519 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
520 |
+
"""
|
521 |
+
|
522 |
+
CLIP_VISION_INPUTS_DOCSTRING = r"""
|
523 |
+
Args:
|
524 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
525 |
+
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
526 |
+
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
527 |
+
output_attentions (`bool`, *optional*):
|
528 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
529 |
+
tensors for more detail.
|
530 |
+
output_hidden_states (`bool`, *optional*):
|
531 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
532 |
+
more detail.
|
533 |
+
return_dict (`bool`, *optional*):
|
534 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
535 |
+
"""
|
536 |
+
|
537 |
+
CLIP_INPUTS_DOCSTRING = r"""
|
538 |
+
Args:
|
539 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
540 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
541 |
+
it.
|
542 |
+
|
543 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
544 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
545 |
+
|
546 |
+
[What are input IDs?](../glossary#input-ids)
|
547 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
548 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
549 |
+
|
550 |
+
- 1 for tokens that are **not masked**,
|
551 |
+
- 0 for tokens that are **masked**.
|
552 |
+
|
553 |
+
[What are attention masks?](../glossary#attention-mask)
|
554 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
555 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
556 |
+
config.max_position_embeddings - 1]`.
|
557 |
+
|
558 |
+
[What are position IDs?](../glossary#position-ids)
|
559 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
560 |
+
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
561 |
+
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
562 |
+
return_loss (`bool`, *optional*):
|
563 |
+
Whether or not to return the contrastive loss.
|
564 |
+
output_attentions (`bool`, *optional*):
|
565 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
566 |
+
tensors for more detail.
|
567 |
+
output_hidden_states (`bool`, *optional*):
|
568 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
569 |
+
more detail.
|
570 |
+
return_dict (`bool`, *optional*):
|
571 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
572 |
+
"""
|
573 |
+
|
574 |
+
|
575 |
+
class CLIPEncoder(nn.Module):
|
576 |
+
"""
|
577 |
+
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
578 |
+
[`CLIPEncoderLayer`].
|
579 |
+
|
580 |
+
Args:
|
581 |
+
config: CLIPConfig
|
582 |
+
"""
|
583 |
+
|
584 |
+
def __init__(self, config: CLIPConfig):
|
585 |
+
super().__init__()
|
586 |
+
self.config = config
|
587 |
+
self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
588 |
+
self.gradient_checkpointing = False
|
589 |
+
|
590 |
+
def forward(
|
591 |
+
self,
|
592 |
+
inputs_embeds,
|
593 |
+
attention_mask: Optional[torch.Tensor] = None,
|
594 |
+
causal_attention_mask: Optional[torch.Tensor] = None,
|
595 |
+
output_attentions: Optional[bool] = None,
|
596 |
+
output_hidden_states: Optional[bool] = None,
|
597 |
+
return_dict: Optional[bool] = None,
|
598 |
+
) -> Union[Tuple, BaseModelOutput]:
|
599 |
+
r"""
|
600 |
+
Args:
|
601 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
602 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
603 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
604 |
+
than the model's internal embedding lookup matrix.
|
605 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
606 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
607 |
+
|
608 |
+
- 1 for tokens that are **not masked**,
|
609 |
+
- 0 for tokens that are **masked**.
|
610 |
+
|
611 |
+
[What are attention masks?](../glossary#attention-mask)
|
612 |
+
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
613 |
+
Causal mask for the text model. Mask values selected in `[0, 1]`:
|
614 |
+
|
615 |
+
- 1 for tokens that are **not masked**,
|
616 |
+
- 0 for tokens that are **masked**.
|
617 |
+
|
618 |
+
[What are attention masks?](../glossary#attention-mask)
|
619 |
+
output_attentions (`bool`, *optional*):
|
620 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
621 |
+
returned tensors for more detail.
|
622 |
+
output_hidden_states (`bool`, *optional*):
|
623 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
624 |
+
for more detail.
|
625 |
+
return_dict (`bool`, *optional*):
|
626 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
627 |
+
"""
|
628 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
629 |
+
output_hidden_states = (
|
630 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
631 |
+
)
|
632 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
633 |
+
|
634 |
+
encoder_states = () if output_hidden_states else None
|
635 |
+
all_attentions = () if output_attentions else None
|
636 |
+
|
637 |
+
hidden_states = inputs_embeds
|
638 |
+
for idx, encoder_layer in enumerate(self.layers):
|
639 |
+
if output_hidden_states:
|
640 |
+
encoder_states = encoder_states + (hidden_states,)
|
641 |
+
if self.gradient_checkpointing and self.training:
|
642 |
+
|
643 |
+
def create_custom_forward(module):
|
644 |
+
def custom_forward(*inputs):
|
645 |
+
return module(*inputs, output_attentions)
|
646 |
+
|
647 |
+
return custom_forward
|
648 |
+
|
649 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
650 |
+
create_custom_forward(encoder_layer),
|
651 |
+
hidden_states,
|
652 |
+
attention_mask,
|
653 |
+
causal_attention_mask,
|
654 |
+
)
|
655 |
+
else:
|
656 |
+
layer_outputs = encoder_layer(
|
657 |
+
hidden_states,
|
658 |
+
attention_mask,
|
659 |
+
causal_attention_mask,
|
660 |
+
output_attentions=output_attentions,
|
661 |
+
)
|
662 |
+
|
663 |
+
hidden_states = layer_outputs[0]
|
664 |
+
|
665 |
+
if output_attentions:
|
666 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
667 |
+
|
668 |
+
if output_hidden_states:
|
669 |
+
encoder_states = encoder_states + (hidden_states,)
|
670 |
+
|
671 |
+
if not return_dict:
|
672 |
+
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
673 |
+
return BaseModelOutput(
|
674 |
+
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
675 |
+
)
|
676 |
+
|
677 |
+
|
678 |
+
class CLIPTextTransformer(nn.Module):
|
679 |
+
def __init__(self, config: CLIPTextConfig):
|
680 |
+
super().__init__()
|
681 |
+
self.config = config
|
682 |
+
embed_dim = config.hidden_size
|
683 |
+
self.embeddings = CLIPTextEmbeddings(config)
|
684 |
+
self.encoder = CLIPEncoder(config)
|
685 |
+
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
686 |
+
|
687 |
+
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
|
688 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
|
689 |
+
def forward(
|
690 |
+
self,
|
691 |
+
input_ids: Optional[torch.Tensor] = None,
|
692 |
+
attention_mask: Optional[torch.Tensor] = None,
|
693 |
+
position_ids: Optional[torch.Tensor] = None,
|
694 |
+
output_attentions: Optional[bool] = None,
|
695 |
+
output_hidden_states: Optional[bool] = None,
|
696 |
+
return_dict: Optional[bool] = None,
|
697 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
698 |
+
r"""
|
699 |
+
Returns:
|
700 |
+
|
701 |
+
"""
|
702 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
703 |
+
output_hidden_states = (
|
704 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
705 |
+
)
|
706 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
707 |
+
|
708 |
+
if input_ids is None:
|
709 |
+
raise ValueError("You have to specify input_ids")
|
710 |
+
|
711 |
+
input_shape = input_ids.size()
|
712 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
713 |
+
|
714 |
+
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
|
715 |
+
|
716 |
+
bsz, seq_len = input_shape
|
717 |
+
# CLIP's text model uses causal mask, prepare it here.
|
718 |
+
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
|
719 |
+
causal_attention_mask = self._build_causal_attention_mask(
|
720 |
+
bsz, seq_len, hidden_states.dtype, device=hidden_states.device
|
721 |
+
)
|
722 |
+
# expand attention_mask
|
723 |
+
if attention_mask is not None:
|
724 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
725 |
+
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
|
726 |
+
|
727 |
+
encoder_outputs = self.encoder(
|
728 |
+
inputs_embeds=hidden_states,
|
729 |
+
attention_mask=attention_mask,
|
730 |
+
causal_attention_mask=causal_attention_mask,
|
731 |
+
output_attentions=output_attentions,
|
732 |
+
output_hidden_states=output_hidden_states,
|
733 |
+
return_dict=return_dict,
|
734 |
+
)
|
735 |
+
|
736 |
+
last_hidden_state = encoder_outputs[0]
|
737 |
+
last_hidden_state = self.final_layer_norm(last_hidden_state)
|
738 |
+
|
739 |
+
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
740 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
741 |
+
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
742 |
+
pooled_output = last_hidden_state[
|
743 |
+
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
|
744 |
+
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
|
745 |
+
]
|
746 |
+
|
747 |
+
if not return_dict:
|
748 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
749 |
+
|
750 |
+
return BaseModelOutputWithPooling(
|
751 |
+
last_hidden_state=last_hidden_state,
|
752 |
+
pooler_output=pooled_output,
|
753 |
+
hidden_states=encoder_outputs.hidden_states,
|
754 |
+
attentions=encoder_outputs.attentions,
|
755 |
+
)
|
756 |
+
|
757 |
+
def _build_causal_attention_mask(self, bsz, seq_len, dtype, device=None):
|
758 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
759 |
+
# pytorch uses additive attention mask; fill with -inf
|
760 |
+
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=device)
|
761 |
+
mask.fill_(torch.finfo(dtype).min)
|
762 |
+
mask.triu_(1) # zero out the lower diagonal
|
763 |
+
mask = mask.unsqueeze(1) # expand mask
|
764 |
+
return mask
|
765 |
+
|
766 |
+
|
767 |
+
@add_start_docstrings(
|
768 |
+
"""The text model from CLIP without any head or projection on top.""",
|
769 |
+
CLIP_START_DOCSTRING,
|
770 |
+
)
|
771 |
+
class CLIPTextModel(CLIPPreTrainedModel):
|
772 |
+
config_class = CLIPTextConfig
|
773 |
+
|
774 |
+
_no_split_modules = ["CLIPEncoderLayer"]
|
775 |
+
|
776 |
+
def __init__(self, config: CLIPTextConfig):
|
777 |
+
super().__init__(config)
|
778 |
+
self.text_model = CLIPTextTransformer(config)
|
779 |
+
# Initialize weights and apply final processing
|
780 |
+
self.post_init()
|
781 |
+
|
782 |
+
def get_input_embeddings(self) -> nn.Module:
|
783 |
+
return self.text_model.embeddings.token_embedding
|
784 |
+
|
785 |
+
def set_input_embeddings(self, value):
|
786 |
+
self.text_model.embeddings.token_embedding = value
|
787 |
+
|
788 |
+
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
|
789 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
|
790 |
+
def forward(
|
791 |
+
self,
|
792 |
+
input_ids: Optional[torch.Tensor] = None,
|
793 |
+
attention_mask: Optional[torch.Tensor] = None,
|
794 |
+
position_ids: Optional[torch.Tensor] = None,
|
795 |
+
output_attentions: Optional[bool] = None,
|
796 |
+
output_hidden_states: Optional[bool] = None,
|
797 |
+
return_dict: Optional[bool] = None,
|
798 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
799 |
+
r"""
|
800 |
+
Returns:
|
801 |
+
|
802 |
+
Examples:
|
803 |
+
|
804 |
+
```python
|
805 |
+
>>> from transformers import AutoTokenizer, CLIPTextModel
|
806 |
+
|
807 |
+
>>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
|
808 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
809 |
+
|
810 |
+
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
811 |
+
|
812 |
+
>>> outputs = model(**inputs)
|
813 |
+
>>> last_hidden_state = outputs.last_hidden_state
|
814 |
+
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states
|
815 |
+
```"""
|
816 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
817 |
+
|
818 |
+
return self.text_model(
|
819 |
+
input_ids=input_ids,
|
820 |
+
attention_mask=attention_mask,
|
821 |
+
position_ids=position_ids,
|
822 |
+
output_attentions=output_attentions,
|
823 |
+
output_hidden_states=output_hidden_states,
|
824 |
+
return_dict=return_dict,
|
825 |
+
)
|
826 |
+
|
827 |
+
|
828 |
+
class CLIPVisionTransformer(nn.Module):
|
829 |
+
def __init__(self, config: CLIPVisionConfig):
|
830 |
+
super().__init__()
|
831 |
+
self.config = config
|
832 |
+
embed_dim = config.hidden_size
|
833 |
+
|
834 |
+
self.embeddings = CLIPVisionEmbeddings(config)
|
835 |
+
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
836 |
+
self.encoder = CLIPEncoder(config)
|
837 |
+
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
838 |
+
|
839 |
+
@add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
|
840 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
|
841 |
+
def forward(
|
842 |
+
self,
|
843 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
844 |
+
output_attentions: Optional[bool] = None,
|
845 |
+
output_hidden_states: Optional[bool] = None,
|
846 |
+
return_dict: Optional[bool] = None,
|
847 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
848 |
+
r"""
|
849 |
+
Returns:
|
850 |
+
|
851 |
+
"""
|
852 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
853 |
+
output_hidden_states = (
|
854 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
855 |
+
)
|
856 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
857 |
+
|
858 |
+
if pixel_values is None:
|
859 |
+
raise ValueError("You have to specify pixel_values")
|
860 |
+
|
861 |
+
hidden_states = self.embeddings(pixel_values)
|
862 |
+
hidden_states = self.pre_layrnorm(hidden_states)
|
863 |
+
|
864 |
+
encoder_outputs = self.encoder(
|
865 |
+
inputs_embeds=hidden_states,
|
866 |
+
output_attentions=output_attentions,
|
867 |
+
output_hidden_states=output_hidden_states,
|
868 |
+
return_dict=return_dict,
|
869 |
+
)
|
870 |
+
|
871 |
+
last_hidden_state = encoder_outputs[0]
|
872 |
+
pooled_output = last_hidden_state[:, 0, :]
|
873 |
+
pooled_output = self.post_layernorm(pooled_output)
|
874 |
+
|
875 |
+
if not return_dict:
|
876 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
877 |
+
|
878 |
+
return BaseModelOutputWithPooling(
|
879 |
+
last_hidden_state=last_hidden_state,
|
880 |
+
pooler_output=pooled_output,
|
881 |
+
hidden_states=encoder_outputs.hidden_states,
|
882 |
+
attentions=encoder_outputs.attentions,
|
883 |
+
)
|
884 |
+
|
885 |
+
|
886 |
+
@add_start_docstrings(
|
887 |
+
"""The vision model from CLIP without any head or projection on top.""",
|
888 |
+
CLIP_START_DOCSTRING,
|
889 |
+
)
|
890 |
+
class CLIPVisionModel(CLIPPreTrainedModel):
|
891 |
+
config_class = CLIPVisionConfig
|
892 |
+
main_input_name = "pixel_values"
|
893 |
+
|
894 |
+
def __init__(self, config: CLIPVisionConfig):
|
895 |
+
super().__init__(config)
|
896 |
+
self.vision_model = CLIPVisionTransformer(config)
|
897 |
+
# Initialize weights and apply final processing
|
898 |
+
self.post_init()
|
899 |
+
|
900 |
+
def get_input_embeddings(self) -> nn.Module:
|
901 |
+
return self.vision_model.embeddings.patch_embedding
|
902 |
+
|
903 |
+
@add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
|
904 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
|
905 |
+
def forward(
|
906 |
+
self,
|
907 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
908 |
+
output_attentions: Optional[bool] = None,
|
909 |
+
output_hidden_states: Optional[bool] = None,
|
910 |
+
return_dict: Optional[bool] = None,
|
911 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
912 |
+
r"""
|
913 |
+
Returns:
|
914 |
+
|
915 |
+
Examples:
|
916 |
+
|
917 |
+
```python
|
918 |
+
>>> from PIL import Image
|
919 |
+
>>> import requests
|
920 |
+
>>> from transformers import AutoProcessor, CLIPVisionModel
|
921 |
+
|
922 |
+
>>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
|
923 |
+
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
924 |
+
|
925 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
926 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
927 |
+
|
928 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
929 |
+
|
930 |
+
>>> outputs = model(**inputs)
|
931 |
+
>>> last_hidden_state = outputs.last_hidden_state
|
932 |
+
>>> pooled_output = outputs.pooler_output # pooled CLS states
|
933 |
+
```"""
|
934 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
935 |
+
|
936 |
+
return self.vision_model(
|
937 |
+
pixel_values=pixel_values,
|
938 |
+
output_attentions=output_attentions,
|
939 |
+
output_hidden_states=output_hidden_states,
|
940 |
+
return_dict=return_dict,
|
941 |
+
)
|
942 |
+
|
943 |
+
|
944 |
+
@add_start_docstrings(CLIP_START_DOCSTRING)
|
945 |
+
class CLIPModel(CLIPPreTrainedModel):
|
946 |
+
config_class = CLIPConfig
|
947 |
+
|
948 |
+
def __init__(self, config: CLIPConfig):
|
949 |
+
super().__init__(config)
|
950 |
+
|
951 |
+
if not isinstance(config.text_config, CLIPTextConfig):
|
952 |
+
raise ValueError(
|
953 |
+
"config.text_config is expected to be of type CLIPTextConfig but is of type"
|
954 |
+
f" {type(config.text_config)}."
|
955 |
+
)
|
956 |
+
|
957 |
+
if not isinstance(config.vision_config, CLIPVisionConfig):
|
958 |
+
raise ValueError(
|
959 |
+
"config.vision_config is expected to be of type CLIPVisionConfig but is of type"
|
960 |
+
f" {type(config.vision_config)}."
|
961 |
+
)
|
962 |
+
|
963 |
+
text_config = config.text_config
|
964 |
+
vision_config = config.vision_config
|
965 |
+
|
966 |
+
self.projection_dim = config.projection_dim
|
967 |
+
self.text_embed_dim = text_config.hidden_size
|
968 |
+
self.vision_embed_dim = vision_config.hidden_size
|
969 |
+
|
970 |
+
self.text_model = CLIPTextTransformer(text_config)
|
971 |
+
self.vision_model = CLIPVisionTransformer(vision_config)
|
972 |
+
|
973 |
+
self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
|
974 |
+
self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
|
975 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value)
|
976 |
+
|
977 |
+
# Initialize weights and apply final processing
|
978 |
+
self.post_init()
|
979 |
+
|
980 |
+
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
|
981 |
+
def get_text_features(
|
982 |
+
self,
|
983 |
+
input_ids: Optional[torch.Tensor] = None,
|
984 |
+
attention_mask: Optional[torch.Tensor] = None,
|
985 |
+
position_ids: Optional[torch.Tensor] = None,
|
986 |
+
output_attentions: Optional[bool] = None,
|
987 |
+
output_hidden_states: Optional[bool] = None,
|
988 |
+
return_dict: Optional[bool] = None,
|
989 |
+
) -> torch.FloatTensor:
|
990 |
+
r"""
|
991 |
+
Returns:
|
992 |
+
text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
|
993 |
+
applying the projection layer to the pooled output of [`CLIPTextModel`].
|
994 |
+
|
995 |
+
Examples:
|
996 |
+
|
997 |
+
```python
|
998 |
+
>>> from transformers import AutoTokenizer, CLIPModel
|
999 |
+
|
1000 |
+
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
1001 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
1002 |
+
|
1003 |
+
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
1004 |
+
>>> text_features = model.get_text_features(**inputs)
|
1005 |
+
```"""
|
1006 |
+
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
1007 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1008 |
+
output_hidden_states = (
|
1009 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1010 |
+
)
|
1011 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1012 |
+
|
1013 |
+
text_outputs = self.text_model(
|
1014 |
+
input_ids=input_ids,
|
1015 |
+
attention_mask=attention_mask,
|
1016 |
+
position_ids=position_ids,
|
1017 |
+
output_attentions=output_attentions,
|
1018 |
+
output_hidden_states=output_hidden_states,
|
1019 |
+
return_dict=return_dict,
|
1020 |
+
)
|
1021 |
+
|
1022 |
+
pooled_output = text_outputs[1]
|
1023 |
+
text_features = self.text_projection(pooled_output)
|
1024 |
+
|
1025 |
+
return text_features
|
1026 |
+
|
1027 |
+
@add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
|
1028 |
+
def get_image_features(
|
1029 |
+
self,
|
1030 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
1031 |
+
output_attentions: Optional[bool] = None,
|
1032 |
+
output_hidden_states: Optional[bool] = None,
|
1033 |
+
return_dict: Optional[bool] = None,
|
1034 |
+
) -> torch.FloatTensor:
|
1035 |
+
r"""
|
1036 |
+
Returns:
|
1037 |
+
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
|
1038 |
+
applying the projection layer to the pooled output of [`CLIPVisionModel`].
|
1039 |
+
|
1040 |
+
Examples:
|
1041 |
+
|
1042 |
+
```python
|
1043 |
+
>>> from PIL import Image
|
1044 |
+
>>> import requests
|
1045 |
+
>>> from transformers import AutoProcessor, CLIPModel
|
1046 |
+
|
1047 |
+
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
1048 |
+
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
1049 |
+
|
1050 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
1051 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1052 |
+
|
1053 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
1054 |
+
|
1055 |
+
>>> image_features = model.get_image_features(**inputs)
|
1056 |
+
```"""
|
1057 |
+
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
1058 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1059 |
+
output_hidden_states = (
|
1060 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1061 |
+
)
|
1062 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1063 |
+
|
1064 |
+
vision_outputs = self.vision_model(
|
1065 |
+
pixel_values=pixel_values,
|
1066 |
+
output_attentions=output_attentions,
|
1067 |
+
output_hidden_states=output_hidden_states,
|
1068 |
+
return_dict=return_dict,
|
1069 |
+
)
|
1070 |
+
|
1071 |
+
pooled_output = vision_outputs[1] # pooled_output
|
1072 |
+
image_features = self.visual_projection(pooled_output)
|
1073 |
+
|
1074 |
+
return image_features
|
1075 |
+
|
1076 |
+
@add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
|
1077 |
+
@replace_return_docstrings(output_type=CLIPOutput, config_class=CLIPConfig)
|
1078 |
+
def forward(
|
1079 |
+
self,
|
1080 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1081 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
1082 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1083 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1084 |
+
return_loss: Optional[bool] = None,
|
1085 |
+
output_attentions: Optional[bool] = None,
|
1086 |
+
output_hidden_states: Optional[bool] = None,
|
1087 |
+
return_dict: Optional[bool] = None,
|
1088 |
+
) -> Union[Tuple, CLIPOutput]:
|
1089 |
+
r"""
|
1090 |
+
Returns:
|
1091 |
+
|
1092 |
+
Examples:
|
1093 |
+
|
1094 |
+
```python
|
1095 |
+
>>> from PIL import Image
|
1096 |
+
>>> import requests
|
1097 |
+
>>> from transformers import AutoProcessor, CLIPModel
|
1098 |
+
|
1099 |
+
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
1100 |
+
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
1101 |
+
|
1102 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
1103 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1104 |
+
|
1105 |
+
>>> inputs = processor(
|
1106 |
+
... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
|
1107 |
+
... )
|
1108 |
+
|
1109 |
+
>>> outputs = model(**inputs)
|
1110 |
+
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
1111 |
+
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
|
1112 |
+
```"""
|
1113 |
+
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
1114 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1115 |
+
output_hidden_states = (
|
1116 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1117 |
+
)
|
1118 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1119 |
+
|
1120 |
+
vision_outputs = self.vision_model(
|
1121 |
+
pixel_values=pixel_values,
|
1122 |
+
output_attentions=output_attentions,
|
1123 |
+
output_hidden_states=output_hidden_states,
|
1124 |
+
return_dict=return_dict,
|
1125 |
+
)
|
1126 |
+
|
1127 |
+
text_outputs = self.text_model(
|
1128 |
+
input_ids=input_ids,
|
1129 |
+
attention_mask=attention_mask,
|
1130 |
+
position_ids=position_ids,
|
1131 |
+
output_attentions=output_attentions,
|
1132 |
+
output_hidden_states=output_hidden_states,
|
1133 |
+
return_dict=return_dict,
|
1134 |
+
)
|
1135 |
+
|
1136 |
+
image_embeds = vision_outputs[1]
|
1137 |
+
image_embeds = self.visual_projection(image_embeds)
|
1138 |
+
|
1139 |
+
text_embeds = text_outputs[1]
|
1140 |
+
text_embeds = self.text_projection(text_embeds)
|
1141 |
+
|
1142 |
+
# normalized features
|
1143 |
+
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
|
1144 |
+
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
|
1145 |
+
|
1146 |
+
# cosine similarity as logits
|
1147 |
+
logit_scale = self.logit_scale.exp()
|
1148 |
+
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
|
1149 |
+
logits_per_image = logits_per_text.t()
|
1150 |
+
|
1151 |
+
loss = None
|
1152 |
+
if return_loss:
|
1153 |
+
loss = clip_loss(logits_per_text)
|
1154 |
+
|
1155 |
+
if not return_dict:
|
1156 |
+
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
|
1157 |
+
return ((loss,) + output) if loss is not None else output
|
1158 |
+
|
1159 |
+
return CLIPOutput(
|
1160 |
+
loss=loss,
|
1161 |
+
logits_per_image=logits_per_image,
|
1162 |
+
logits_per_text=logits_per_text,
|
1163 |
+
text_embeds=text_embeds,
|
1164 |
+
image_embeds=image_embeds,
|
1165 |
+
text_model_output=text_outputs,
|
1166 |
+
vision_model_output=vision_outputs,
|
1167 |
+
)
|
1168 |
+
|
1169 |
+
|
1170 |
+
@add_start_docstrings(
|
1171 |
+
"""
|
1172 |
+
CLIP Text Model with a projection layer on top (a linear layer on top of the pooled output).
|
1173 |
+
""",
|
1174 |
+
CLIP_START_DOCSTRING,
|
1175 |
+
)
|
1176 |
+
class CLIPTextModelWithProjection(CLIPPreTrainedModel):
|
1177 |
+
config_class = CLIPTextConfig
|
1178 |
+
|
1179 |
+
_no_split_modules = ["CLIPEncoderLayer"]
|
1180 |
+
|
1181 |
+
def __init__(self, config: CLIPTextConfig):
|
1182 |
+
super().__init__(config)
|
1183 |
+
|
1184 |
+
self.text_model = CLIPTextTransformer(config)
|
1185 |
+
|
1186 |
+
self.text_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
|
1187 |
+
|
1188 |
+
# Initialize weights and apply final processing
|
1189 |
+
self.post_init()
|
1190 |
+
|
1191 |
+
def get_input_embeddings(self) -> nn.Module:
|
1192 |
+
return self.text_model.embeddings.token_embedding
|
1193 |
+
|
1194 |
+
def set_input_embeddings(self, value):
|
1195 |
+
self.text_model.embeddings.token_embedding = value
|
1196 |
+
|
1197 |
+
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
|
1198 |
+
@replace_return_docstrings(output_type=CLIPTextModelOutput, config_class=CLIPTextConfig)
|
1199 |
+
def forward(
|
1200 |
+
self,
|
1201 |
+
input_ids: Optional[torch.Tensor] = None,
|
1202 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1203 |
+
position_ids: Optional[torch.Tensor] = None,
|
1204 |
+
output_attentions: Optional[bool] = None,
|
1205 |
+
output_hidden_states: Optional[bool] = None,
|
1206 |
+
return_dict: Optional[bool] = None,
|
1207 |
+
) -> Union[Tuple, CLIPTextModelOutput]:
|
1208 |
+
r"""
|
1209 |
+
Returns:
|
1210 |
+
|
1211 |
+
Examples:
|
1212 |
+
|
1213 |
+
```python
|
1214 |
+
>>> from transformers import AutoTokenizer, CLIPTextModelWithProjection
|
1215 |
+
|
1216 |
+
>>> model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
|
1217 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
1218 |
+
|
1219 |
+
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
1220 |
+
|
1221 |
+
>>> outputs = model(**inputs)
|
1222 |
+
>>> text_embeds = outputs.text_embeds
|
1223 |
+
```"""
|
1224 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1225 |
+
|
1226 |
+
text_outputs = self.text_model(
|
1227 |
+
input_ids=input_ids,
|
1228 |
+
attention_mask=attention_mask,
|
1229 |
+
position_ids=position_ids,
|
1230 |
+
output_attentions=output_attentions,
|
1231 |
+
output_hidden_states=output_hidden_states,
|
1232 |
+
return_dict=return_dict,
|
1233 |
+
)
|
1234 |
+
|
1235 |
+
pooled_output = text_outputs[1]
|
1236 |
+
|
1237 |
+
text_embeds = self.text_projection(pooled_output)
|
1238 |
+
|
1239 |
+
if not return_dict:
|
1240 |
+
outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
|
1241 |
+
return tuple(output for output in outputs if output is not None)
|
1242 |
+
|
1243 |
+
return CLIPTextModelOutput(
|
1244 |
+
text_embeds=text_embeds,
|
1245 |
+
last_hidden_state=text_outputs.last_hidden_state,
|
1246 |
+
hidden_states=text_outputs.hidden_states,
|
1247 |
+
attentions=text_outputs.attentions,
|
1248 |
+
)
|
1249 |
+
|
1250 |
+
|
1251 |
+
@add_start_docstrings(
|
1252 |
+
"""
|
1253 |
+
CLIP Vision Model with a projection layer on top (a linear layer on top of the pooled output).
|
1254 |
+
""",
|
1255 |
+
CLIP_START_DOCSTRING,
|
1256 |
+
)
|
1257 |
+
class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
|
1258 |
+
config_class = CLIPVisionConfig
|
1259 |
+
main_input_name = "pixel_values"
|
1260 |
+
|
1261 |
+
def __init__(self, config: CLIPVisionConfig):
|
1262 |
+
super().__init__(config)
|
1263 |
+
|
1264 |
+
self.vision_model = CLIPVisionTransformer(config)
|
1265 |
+
|
1266 |
+
self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
|
1267 |
+
|
1268 |
+
# Initialize weights and apply final processing
|
1269 |
+
self.post_init()
|
1270 |
+
|
1271 |
+
def get_input_embeddings(self) -> nn.Module:
|
1272 |
+
return self.vision_model.embeddings.patch_embedding
|
1273 |
+
|
1274 |
+
@add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
|
1275 |
+
@replace_return_docstrings(output_type=CLIPVisionModelOutput, config_class=CLIPVisionConfig)
|
1276 |
+
def forward(
|
1277 |
+
self,
|
1278 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
1279 |
+
output_attentions: Optional[bool] = None,
|
1280 |
+
output_hidden_states: Optional[bool] = None,
|
1281 |
+
return_dict: Optional[bool] = None,
|
1282 |
+
) -> Union[Tuple, CLIPVisionModelOutput]:
|
1283 |
+
r"""
|
1284 |
+
Returns:
|
1285 |
+
|
1286 |
+
Examples:
|
1287 |
+
|
1288 |
+
```python
|
1289 |
+
>>> from PIL import Image
|
1290 |
+
>>> import requests
|
1291 |
+
>>> from transformers import AutoProcessor, CLIPVisionModelWithProjection
|
1292 |
+
|
1293 |
+
>>> model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
|
1294 |
+
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
1295 |
+
|
1296 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
1297 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1298 |
+
|
1299 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
1300 |
+
|
1301 |
+
>>> outputs = model(**inputs)
|
1302 |
+
>>> image_embeds = outputs.image_embeds
|
1303 |
+
```"""
|
1304 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1305 |
+
|
1306 |
+
vision_outputs = self.vision_model(
|
1307 |
+
pixel_values=pixel_values,
|
1308 |
+
output_attentions=output_attentions,
|
1309 |
+
output_hidden_states=output_hidden_states,
|
1310 |
+
return_dict=return_dict,
|
1311 |
+
)
|
1312 |
+
|
1313 |
+
pooled_output = vision_outputs[1] # pooled_output
|
1314 |
+
|
1315 |
+
image_embeds = self.visual_projection(pooled_output)
|
1316 |
+
|
1317 |
+
if not return_dict:
|
1318 |
+
outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:]
|
1319 |
+
return tuple(output for output in outputs if output is not None)
|
1320 |
+
|
1321 |
+
return CLIPVisionModelOutput(
|
1322 |
+
image_embeds=image_embeds,
|
1323 |
+
last_hidden_state=vision_outputs.last_hidden_state,
|
1324 |
+
hidden_states=vision_outputs.hidden_states,
|
1325 |
+
attentions=vision_outputs.attentions,
|
1326 |
+
)
|
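The classes above mirror the CLIP text/vision stack from `transformers`: EOS-token pooling in `CLIPTextTransformer`, the contrastive `CLIPModel`, and the `*WithProjection` heads. As a quick orientation only, the following is a minimal usage sketch; it assumes the upstream `transformers` package and the public `openai/clip-vit-base-patch32` checkpoint rather than this vendored copy.

```python
import torch
from transformers import AutoTokenizer, CLIPTextModel, CLIPTextModelWithProjection

tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
text_projector = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")

inputs = tokenizer(["a photo of a tiger"], padding=True, return_tensors="pt")
with torch.no_grad():
    encoder_out = text_encoder(**inputs)
    projected = text_projector(**inputs)

# Per-token hidden states (what a diffusion UNet cross-attends to) ...
print(encoder_out.last_hidden_state.shape)  # [batch, seq_len, hidden_size]
# ... and the pooled EOS-token embedding after the linear projection head.
print(projected.text_embeds.shape)          # [batch, projection_dim]
```

The pooled vector is exactly the `argmax`-based EOS indexing implemented in `CLIPTextTransformer.forward` above, followed by the bias-free `text_projection` linear layer.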
Tiger Model/diffusiers-Tiger/__init__.py
ADDED
@@ -0,0 +1,293 @@
1 |
+
__version__ = "0.21.0.dev0"
|
2 |
+
|
3 |
+
from .configuration_utils import ConfigMixin
|
4 |
+
from .utils import (
|
5 |
+
OptionalDependencyNotAvailable,
|
6 |
+
is_flax_available,
|
7 |
+
is_inflect_available,
|
8 |
+
is_invisible_watermark_available,
|
9 |
+
is_k_diffusion_available,
|
10 |
+
is_k_diffusion_version,
|
11 |
+
is_librosa_available,
|
12 |
+
is_note_seq_available,
|
13 |
+
is_onnx_available,
|
14 |
+
is_scipy_available,
|
15 |
+
is_torch_available,
|
16 |
+
is_torchsde_available,
|
17 |
+
is_transformers_available,
|
18 |
+
is_transformers_version,
|
19 |
+
is_unidecode_available,
|
20 |
+
logging,
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
try:
|
25 |
+
if not is_onnx_available():
|
26 |
+
raise OptionalDependencyNotAvailable()
|
27 |
+
except OptionalDependencyNotAvailable:
|
28 |
+
from .utils.dummy_onnx_objects import * # noqa F403
|
29 |
+
else:
|
30 |
+
from .pipelines import OnnxRuntimeModel
|
31 |
+
|
32 |
+
try:
|
33 |
+
if not is_torch_available():
|
34 |
+
raise OptionalDependencyNotAvailable()
|
35 |
+
except OptionalDependencyNotAvailable:
|
36 |
+
from .utils.dummy_pt_objects import * # noqa F403
|
37 |
+
else:
|
38 |
+
from .models import (
|
39 |
+
AsymmetricAutoencoderKL,
|
40 |
+
AutoencoderKL,
|
41 |
+
AutoencoderTiny,
|
42 |
+
ControlNetModel,
|
43 |
+
ModelMixin,
|
44 |
+
MultiAdapter,
|
45 |
+
PriorTransformer,
|
46 |
+
T2IAdapter,
|
47 |
+
T5FilmDecoder,
|
48 |
+
Transformer2DModel,
|
49 |
+
UNet1DModel,
|
50 |
+
UNet2DConditionModel,
|
51 |
+
UNet2DModel,
|
52 |
+
UNet3DConditionModel,
|
53 |
+
VQModel,
|
54 |
+
)
|
55 |
+
from .optimization import (
|
56 |
+
get_constant_schedule,
|
57 |
+
get_constant_schedule_with_warmup,
|
58 |
+
get_cosine_schedule_with_warmup,
|
59 |
+
get_cosine_with_hard_restarts_schedule_with_warmup,
|
60 |
+
get_linear_schedule_with_warmup,
|
61 |
+
get_polynomial_decay_schedule_with_warmup,
|
62 |
+
get_scheduler,
|
63 |
+
)
|
64 |
+
from .pipelines import (
|
65 |
+
AudioPipelineOutput,
|
66 |
+
AutoPipelineForImage2Image,
|
67 |
+
AutoPipelineForInpainting,
|
68 |
+
AutoPipelineForText2Image,
|
69 |
+
ConsistencyModelPipeline,
|
70 |
+
DanceDiffusionPipeline,
|
71 |
+
DDIMPipeline,
|
72 |
+
DDPMPipeline,
|
73 |
+
DiffusionPipeline,
|
74 |
+
DiTPipeline,
|
75 |
+
ImagePipelineOutput,
|
76 |
+
KarrasVePipeline,
|
77 |
+
LDMPipeline,
|
78 |
+
LDMSuperResolutionPipeline,
|
79 |
+
PNDMPipeline,
|
80 |
+
RePaintPipeline,
|
81 |
+
ScoreSdeVePipeline,
|
82 |
+
)
|
83 |
+
from .schedulers import (
|
84 |
+
CMStochasticIterativeScheduler,
|
85 |
+
DDIMInverseScheduler,
|
86 |
+
DDIMParallelScheduler,
|
87 |
+
DDIMScheduler,
|
88 |
+
DDPMParallelScheduler,
|
89 |
+
DDPMScheduler,
|
90 |
+
DEISMultistepScheduler,
|
91 |
+
DPMSolverMultistepInverseScheduler,
|
92 |
+
DPMSolverMultistepScheduler,
|
93 |
+
DPMSolverSinglestepScheduler,
|
94 |
+
EulerAncestralDiscreteScheduler,
|
95 |
+
EulerDiscreteScheduler,
|
96 |
+
HeunDiscreteScheduler,
|
97 |
+
IPNDMScheduler,
|
98 |
+
KarrasVeScheduler,
|
99 |
+
KDPM2AncestralDiscreteScheduler,
|
100 |
+
KDPM2DiscreteScheduler,
|
101 |
+
PNDMScheduler,
|
102 |
+
RePaintScheduler,
|
103 |
+
SchedulerMixin,
|
104 |
+
ScoreSdeVeScheduler,
|
105 |
+
UnCLIPScheduler,
|
106 |
+
UniPCMultistepScheduler,
|
107 |
+
VQDiffusionScheduler,
|
108 |
+
)
|
109 |
+
from .training_utils import EMAModel
|
110 |
+
|
111 |
+
try:
|
112 |
+
if not (is_torch_available() and is_scipy_available()):
|
113 |
+
raise OptionalDependencyNotAvailable()
|
114 |
+
except OptionalDependencyNotAvailable:
|
115 |
+
from .utils.dummy_torch_and_scipy_objects import * # noqa F403
|
116 |
+
else:
|
117 |
+
from .schedulers import LMSDiscreteScheduler
|
118 |
+
|
119 |
+
try:
|
120 |
+
if not (is_torch_available() and is_torchsde_available()):
|
121 |
+
raise OptionalDependencyNotAvailable()
|
122 |
+
except OptionalDependencyNotAvailable:
|
123 |
+
from .utils.dummy_torch_and_torchsde_objects import * # noqa F403
|
124 |
+
else:
|
125 |
+
from .schedulers import DPMSolverSDEScheduler
|
126 |
+
|
127 |
+
try:
|
128 |
+
if not (is_torch_available() and is_transformers_available()):
|
129 |
+
raise OptionalDependencyNotAvailable()
|
130 |
+
except OptionalDependencyNotAvailable:
|
131 |
+
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
|
132 |
+
else:
|
133 |
+
from .pipelines import (
|
134 |
+
AltDiffusionImg2ImgPipeline,
|
135 |
+
AltDiffusionPipeline,
|
136 |
+
AudioLDMPipeline,
|
137 |
+
CycleDiffusionPipeline,
|
138 |
+
IFImg2ImgPipeline,
|
139 |
+
IFImg2ImgSuperResolutionPipeline,
|
140 |
+
IFInpaintingPipeline,
|
141 |
+
IFInpaintingSuperResolutionPipeline,
|
142 |
+
IFPipeline,
|
143 |
+
IFSuperResolutionPipeline,
|
144 |
+
ImageTextPipelineOutput,
|
145 |
+
KandinskyCombinedPipeline,
|
146 |
+
KandinskyImg2ImgCombinedPipeline,
|
147 |
+
KandinskyImg2ImgPipeline,
|
148 |
+
KandinskyInpaintCombinedPipeline,
|
149 |
+
KandinskyInpaintPipeline,
|
150 |
+
KandinskyPipeline,
|
151 |
+
KandinskyPriorPipeline,
|
152 |
+
KandinskyV22CombinedPipeline,
|
153 |
+
KandinskyV22ControlnetImg2ImgPipeline,
|
154 |
+
KandinskyV22ControlnetPipeline,
|
155 |
+
KandinskyV22Img2ImgCombinedPipeline,
|
156 |
+
KandinskyV22Img2ImgPipeline,
|
157 |
+
KandinskyV22InpaintCombinedPipeline,
|
158 |
+
KandinskyV22InpaintPipeline,
|
159 |
+
KandinskyV22Pipeline,
|
160 |
+
KandinskyV22PriorEmb2EmbPipeline,
|
161 |
+
KandinskyV22PriorPipeline,
|
162 |
+
LDMTextToImagePipeline,
|
163 |
+
PaintByExamplePipeline,
|
164 |
+
SemanticStableDiffusionPipeline,
|
165 |
+
ShapEImg2ImgPipeline,
|
166 |
+
ShapEPipeline,
|
167 |
+
StableDiffusionAdapterPipeline,
|
168 |
+
StableDiffusionAttendAndExcitePipeline,
|
169 |
+
StableDiffusionControlNetImg2ImgPipeline,
|
170 |
+
StableDiffusionControlNetInpaintPipeline,
|
171 |
+
StableDiffusionControlNetPipeline,
|
172 |
+
StableDiffusionDepth2ImgPipeline,
|
173 |
+
StableDiffusionDiffEditPipeline,
|
174 |
+
StableDiffusionGLIGENPipeline,
|
175 |
+
StableDiffusionImageVariationPipeline,
|
176 |
+
StableDiffusionImg2ImgPipeline,
|
177 |
+
StableDiffusionInpaintPipeline,
|
178 |
+
StableDiffusionInpaintPipelineLegacy,
|
179 |
+
StableDiffusionInstructPix2PixPipeline,
|
180 |
+
StableDiffusionLatentUpscalePipeline,
|
181 |
+
StableDiffusionLDM3DPipeline,
|
182 |
+
StableDiffusionModelEditingPipeline,
|
183 |
+
StableDiffusionPanoramaPipeline,
|
184 |
+
StableDiffusionParadigmsPipeline,
|
185 |
+
StableDiffusionPipeline,
|
186 |
+
StableDiffusionPipelineSafe,
|
187 |
+
StableDiffusionPix2PixZeroPipeline,
|
188 |
+
StableDiffusionSAGPipeline,
|
189 |
+
StableDiffusionUpscalePipeline,
|
190 |
+
StableDiffusionXLControlNetPipeline,
|
191 |
+
StableDiffusionXLImg2ImgPipeline,
|
192 |
+
StableDiffusionXLInpaintPipeline,
|
193 |
+
StableDiffusionXLInstructPix2PixPipeline,
|
194 |
+
StableDiffusionXLPipeline,
|
195 |
+
StableUnCLIPImg2ImgPipeline,
|
196 |
+
StableUnCLIPPipeline,
|
197 |
+
TextToVideoSDPipeline,
|
198 |
+
TextToVideoZeroPipeline,
|
199 |
+
UnCLIPImageVariationPipeline,
|
200 |
+
UnCLIPPipeline,
|
201 |
+
UniDiffuserModel,
|
202 |
+
UniDiffuserPipeline,
|
203 |
+
UniDiffuserTextDecoder,
|
204 |
+
VersatileDiffusionDualGuidedPipeline,
|
205 |
+
VersatileDiffusionImageVariationPipeline,
|
206 |
+
VersatileDiffusionPipeline,
|
207 |
+
VersatileDiffusionTextToImagePipeline,
|
208 |
+
VideoToVideoSDPipeline,
|
209 |
+
VQDiffusionPipeline,
|
210 |
+
)
|
211 |
+
|
212 |
+
try:
|
213 |
+
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
214 |
+
raise OptionalDependencyNotAvailable()
|
215 |
+
except OptionalDependencyNotAvailable:
|
216 |
+
from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
|
217 |
+
else:
|
218 |
+
from .pipelines import StableDiffusionKDiffusionPipeline
|
219 |
+
|
220 |
+
try:
|
221 |
+
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
222 |
+
raise OptionalDependencyNotAvailable()
|
223 |
+
except OptionalDependencyNotAvailable:
|
224 |
+
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
|
225 |
+
else:
|
226 |
+
from .pipelines import (
|
227 |
+
OnnxStableDiffusionImg2ImgPipeline,
|
228 |
+
OnnxStableDiffusionInpaintPipeline,
|
229 |
+
OnnxStableDiffusionInpaintPipelineLegacy,
|
230 |
+
OnnxStableDiffusionPipeline,
|
231 |
+
OnnxStableDiffusionUpscalePipeline,
|
232 |
+
StableDiffusionOnnxPipeline,
|
233 |
+
)
|
234 |
+
|
235 |
+
try:
|
236 |
+
if not (is_torch_available() and is_librosa_available()):
|
237 |
+
raise OptionalDependencyNotAvailable()
|
238 |
+
except OptionalDependencyNotAvailable:
|
239 |
+
from .utils.dummy_torch_and_librosa_objects import * # noqa F403
|
240 |
+
else:
|
241 |
+
from .pipelines import AudioDiffusionPipeline, Mel
|
242 |
+
|
243 |
+
try:
|
244 |
+
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
|
245 |
+
raise OptionalDependencyNotAvailable()
|
246 |
+
except OptionalDependencyNotAvailable:
|
247 |
+
from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
|
248 |
+
else:
|
249 |
+
from .pipelines import SpectrogramDiffusionPipeline
|
250 |
+
|
251 |
+
try:
|
252 |
+
if not is_flax_available():
|
253 |
+
raise OptionalDependencyNotAvailable()
|
254 |
+
except OptionalDependencyNotAvailable:
|
255 |
+
from .utils.dummy_flax_objects import * # noqa F403
|
256 |
+
else:
|
257 |
+
from .models.controlnet_flax import FlaxControlNetModel
|
258 |
+
from .models.modeling_flax_utils import FlaxModelMixin
|
259 |
+
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
|
260 |
+
from .models.vae_flax import FlaxAutoencoderKL
|
261 |
+
from .pipelines import FlaxDiffusionPipeline
|
262 |
+
from .schedulers import (
|
263 |
+
FlaxDDIMScheduler,
|
264 |
+
FlaxDDPMScheduler,
|
265 |
+
FlaxDPMSolverMultistepScheduler,
|
266 |
+
FlaxKarrasVeScheduler,
|
267 |
+
FlaxLMSDiscreteScheduler,
|
268 |
+
FlaxPNDMScheduler,
|
269 |
+
FlaxSchedulerMixin,
|
270 |
+
FlaxScoreSdeVeScheduler,
|
271 |
+
)
|
272 |
+
|
273 |
+
|
274 |
+
try:
|
275 |
+
if not (is_flax_available() and is_transformers_available()):
|
276 |
+
raise OptionalDependencyNotAvailable()
|
277 |
+
except OptionalDependencyNotAvailable:
|
278 |
+
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
|
279 |
+
else:
|
280 |
+
from .pipelines import (
|
281 |
+
FlaxStableDiffusionControlNetPipeline,
|
282 |
+
FlaxStableDiffusionImg2ImgPipeline,
|
283 |
+
FlaxStableDiffusionInpaintPipeline,
|
284 |
+
FlaxStableDiffusionPipeline,
|
285 |
+
)
|
286 |
+
|
287 |
+
try:
|
288 |
+
if not (is_note_seq_available()):
|
289 |
+
raise OptionalDependencyNotAvailable()
|
290 |
+
except OptionalDependencyNotAvailable:
|
291 |
+
from .utils.dummy_note_seq_objects import * # noqa F403
|
292 |
+
else:
|
293 |
+
from .pipelines import MidiProcessor
|
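Every import in this `__init__.py` is wrapped in the same guard: probe for the optional backend, raise `OptionalDependencyNotAvailable` if it is missing, and fall back to a dummy-objects module so attribute access still resolves. A condensed sketch of that pattern, reusing the torch check already present above (a new backend would substitute its own `is_*_available()` probe and dummy module):

```python
from .utils import OptionalDependencyNotAvailable, is_torch_available

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Dummy placeholders raise an informative error only when actually instantiated.
    from .utils.dummy_pt_objects import *  # noqa F403
else:
    from .models import UNet2DConditionModel  # the real import only runs when torch is present
```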
Tiger Model/diffusiers-Tiger/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (8.47 kB).
Tiger Model/diffusiers-Tiger/__pycache__/configuration_utils.cpython-38.pyc
ADDED
Binary file (24 kB).
Tiger Model/diffusiers-Tiger/__pycache__/fuse.cpython-38.pyc
ADDED
Binary file (3.83 kB).
Tiger Model/diffusiers-Tiger/__pycache__/image_processor.cpython-38.pyc
ADDED
Binary file (12.7 kB).
Tiger Model/diffusiers-Tiger/__pycache__/loaders.cpython-38.pyc
ADDED
Binary file (78.3 kB).
Tiger Model/diffusiers-Tiger/__pycache__/optimization.cpython-38.pyc
ADDED
Binary file (12.8 kB).
Tiger Model/diffusiers-Tiger/__pycache__/training_utils.cpython-38.pyc
ADDED
Binary file (10.6 kB).
Tiger Model/diffusiers-Tiger/commands/__init__.py
ADDED
@@ -0,0 +1,27 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from argparse import ArgumentParser


class BaseDiffusersCLICommand(ABC):
    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        raise NotImplementedError()

    @abstractmethod
    def run(self):
        raise NotImplementedError()
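A hedged sketch of how this abstract base is meant to be used: each concrete command implements `register_subcommand` to wire itself into the argparse sub-parser and `run` to do the actual work. `HelloCommand` and its flag are hypothetical names for illustration, not part of the upload.

```python
from argparse import ArgumentParser, Namespace

from . import BaseDiffusersCLICommand  # as the real commands below do


def hello_command_factory(args: Namespace):
    return HelloCommand(args.name)


class HelloCommand(BaseDiffusersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        hello_parser = parser.add_parser("hello")                     # sub-command name
        hello_parser.add_argument("--name", type=str, default="world")
        hello_parser.set_defaults(func=hello_command_factory)         # factory stored on the namespace

    def __init__(self, name: str):
        self.name = name

    def run(self):
        print(f"Hello, {self.name}!")
```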
Tiger Model/diffusiers-Tiger/commands/diffusers_cli.py
ADDED
@@ -0,0 +1,43 @@
#!/usr/bin/env python
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argparse import ArgumentParser

from .env import EnvironmentCommand
from .fp16_safetensors import FP16SafetensorsCommand


def main():
    parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
    commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")

    # Register commands
    EnvironmentCommand.register_subcommand(commands_parser)
    FP16SafetensorsCommand.register_subcommand(commands_parser)

    # Let's go
    args = parser.parse_args()

    if not hasattr(args, "func"):
        parser.print_help()
        exit(1)

    # Run
    service = args.func(args)
    service.run()


if __name__ == "__main__":
    main()
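The last three lines of `main()` serve every command because `set_defaults(func=...)` stores a factory on the parsed namespace. The mechanism in isolation, with plain `argparse` and no diffusers imports (the lambda is only a stand-in for the real command factories):

```python
from argparse import ArgumentParser

parser = ArgumentParser("demo", usage="demo <command> [<args>]")
subparsers = parser.add_subparsers()
env_parser = subparsers.add_parser("env")
env_parser.set_defaults(func=lambda args: print("factory would return a command object here"))

args = parser.parse_args(["env"])
args.func(args)  # whatever set_defaults stored for the chosen sub-command is invoked with the namespace
```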
Tiger Model/diffusiers-Tiger/commands/env.py
ADDED
@@ -0,0 +1,84 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import platform
from argparse import ArgumentParser

import huggingface_hub

from .. import __version__ as version
from ..utils import is_accelerate_available, is_torch_available, is_transformers_available, is_xformers_available
from . import BaseDiffusersCLICommand


def info_command_factory(_):
    return EnvironmentCommand()


class EnvironmentCommand(BaseDiffusersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        download_parser = parser.add_parser("env")
        download_parser.set_defaults(func=info_command_factory)

    def run(self):
        hub_version = huggingface_hub.__version__

        pt_version = "not installed"
        pt_cuda_available = "NA"
        if is_torch_available():
            import torch

            pt_version = torch.__version__
            pt_cuda_available = torch.cuda.is_available()

        transformers_version = "not installed"
        if is_transformers_available():
            import transformers

            transformers_version = transformers.__version__

        accelerate_version = "not installed"
        if is_accelerate_available():
            import accelerate

            accelerate_version = accelerate.__version__

        xformers_version = "not installed"
        if is_xformers_available():
            import xformers

            xformers_version = xformers.__version__

        info = {
            "`diffusers` version": version,
            "Platform": platform.platform(),
            "Python version": platform.python_version(),
            "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
            "Huggingface_hub version": hub_version,
            "Transformers version": transformers_version,
            "Accelerate version": accelerate_version,
            "xFormers version": xformers_version,
            "Using GPU in script?": "<fill in>",
            "Using distributed or parallel set-up in script?": "<fill in>",
        }

        print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
        print(self.format_dict(info))

        return info

    @staticmethod
    def format_dict(d):
        return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
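Besides `diffusers-cli env`, the command can be exercised programmatically: `run()` both prints the report and returns the underlying dictionary, and `format_dict` is a pure helper. A small sketch, assuming the package is importable under the upstream `diffusers` name (this vendored copy is laid out differently):

```python
from diffusers.commands.env import EnvironmentCommand

# format_dict renders any mapping as the "- key: value" bullet list pasted into issues.
print(EnvironmentCommand.format_dict({"Platform": "Linux-5.15", "Python version": "3.8.18"}))

report = EnvironmentCommand().run()  # prints the full environment report and returns the dict
print(sorted(report))                # keys such as "Platform", "PyTorch version (GPU?)", ...
```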
Tiger Model/diffusiers-Tiger/commands/fp16_safetensors.py
ADDED
@@ -0,0 +1,133 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Usage example:
diffusers-cli fp16_safetensors --ckpt_id=openai/shap-e --fp16 --use_safetensors
"""

import glob
import json
from argparse import ArgumentParser, Namespace
from importlib import import_module

import huggingface_hub
import torch
from huggingface_hub import hf_hub_download
from packaging import version

from ..utils import logging
from . import BaseDiffusersCLICommand


def conversion_command_factory(args: Namespace):
    return FP16SafetensorsCommand(
        args.ckpt_id,
        args.fp16,
        args.use_safetensors,
        args.use_auth_token,
    )


class FP16SafetensorsCommand(BaseDiffusersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        conversion_parser = parser.add_parser("fp16_safetensors")
        conversion_parser.add_argument(
            "--ckpt_id",
            type=str,
            help="Repo id of the checkpoints on which to run the conversion. Example: 'openai/shap-e'.",
        )
        conversion_parser.add_argument(
            "--fp16", action="store_true", help="If serializing the variables in FP16 precision."
        )
        conversion_parser.add_argument(
            "--use_safetensors", action="store_true", help="If serializing in the safetensors format."
        )
        conversion_parser.add_argument(
            "--use_auth_token",
            action="store_true",
            help="When working with checkpoints having private visibility. When used `huggingface-cli login` needs to be run beforehand.",
        )
        conversion_parser.set_defaults(func=conversion_command_factory)

    def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool, use_auth_token: bool):
        self.logger = logging.get_logger("diffusers-cli/fp16_safetensors")
        self.ckpt_id = ckpt_id
        self.local_ckpt_dir = f"/tmp/{ckpt_id}"
        self.fp16 = fp16

        self.use_safetensors = use_safetensors

        if not self.use_safetensors and not self.fp16:
            raise NotImplementedError(
                "When `use_safetensors` and `fp16` both are False, then this command is of no use."
            )

        self.use_auth_token = use_auth_token

    def run(self):
        if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
            raise ImportError(
                "The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub"
                " installation."
            )
        else:
            from huggingface_hub import create_commit
            from huggingface_hub._commit_api import CommitOperationAdd

        model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json", token=self.use_auth_token)
        with open(model_index, "r") as f:
            pipeline_class_name = json.load(f)["_class_name"]
        pipeline_class = getattr(import_module("diffusers"), pipeline_class_name)
        self.logger.info(f"Pipeline class imported: {pipeline_class_name}.")

        # Load the appropriate pipeline. We could have used `DiffusionPipeline`
        # here, but just to avoid any rough edge cases.
        pipeline = pipeline_class.from_pretrained(
            self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32, use_auth_token=self.use_auth_token
        )
        pipeline.save_pretrained(
            self.local_ckpt_dir,
            safe_serialization=True if self.use_safetensors else False,
            variant="fp16" if self.fp16 else None,
        )
        self.logger.info(f"Pipeline locally saved to {self.local_ckpt_dir}.")

        # Fetch all the paths.
        if self.fp16:
            modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.fp16.*")
        elif self.use_safetensors:
            modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.safetensors")

        # Prepare for the PR.
        commit_message = f"Serialize variables with FP16: {self.fp16} and safetensors: {self.use_safetensors}."
        operations = []
        for path in modified_paths:
            operations.append(CommitOperationAdd(path_in_repo="/".join(path.split("/")[4:]), path_or_fileobj=path))

        # Open the PR.
        commit_description = (
            "Variables converted by the [`diffusers`' `fp16_safetensors`"
            " CLI](https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/fp16_safetensors.py)."
        )
        hub_pr_url = create_commit(
            repo_id=self.ckpt_id,
            operations=operations,
            commit_message=commit_message,
            commit_description=commit_description,
            repo_type="model",
            create_pr=True,
        ).pr_url
        self.logger.info(f"PR created here: {hub_pr_url}.")
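The command boils down to: download `model_index.json`, reload the pipeline in the requested precision/format, save it under `/tmp/<ckpt_id>`, and open a Hub pull request via `create_commit`. A programmatic equivalent of the CLI call shown in the module docstring (same `openai/shap-e` example id; it assumes the upstream `diffusers` package and a Hub login with permission to open PRs):

```python
from diffusers.commands.fp16_safetensors import FP16SafetensorsCommand

command = FP16SafetensorsCommand(
    ckpt_id="openai/shap-e",   # example repo id from the docstring above
    fp16=True,                 # serialize the weights in half precision
    use_safetensors=True,      # write .safetensors files instead of pickled .bin
    use_auth_token=False,      # True for private checkpoints after `huggingface-cli login`
)
command.run()                  # converts, saves locally, and opens a pull request on the Hub
```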
Tiger Model/diffusiers-Tiger/configuration_utils.py
ADDED
@@ -0,0 +1,686 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
""" ConfigMixin base class and utilities."""
|
17 |
+
import dataclasses
|
18 |
+
import functools
|
19 |
+
import importlib
|
20 |
+
import inspect
|
21 |
+
import json
|
22 |
+
import os
|
23 |
+
import re
|
24 |
+
from collections import OrderedDict
|
25 |
+
from pathlib import PosixPath
|
26 |
+
from typing import Any, Dict, Tuple, Union
|
27 |
+
|
28 |
+
import numpy as np
|
29 |
+
from huggingface_hub import create_repo, hf_hub_download
|
30 |
+
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
31 |
+
from requests import HTTPError
|
32 |
+
|
33 |
+
from . import __version__
|
34 |
+
from .utils import (
|
35 |
+
DIFFUSERS_CACHE,
|
36 |
+
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
37 |
+
DummyObject,
|
38 |
+
deprecate,
|
39 |
+
extract_commit_hash,
|
40 |
+
http_user_agent,
|
41 |
+
logging,
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
logger = logging.get_logger(__name__)
|
46 |
+
|
47 |
+
_re_configuration_file = re.compile(r"config\.(.*)\.json")
|
48 |
+
|
49 |
+
|
50 |
+
class FrozenDict(OrderedDict):
|
51 |
+
def __init__(self, *args, **kwargs):
|
52 |
+
super().__init__(*args, **kwargs)
|
53 |
+
|
54 |
+
        for key, value in self.items():
            setattr(self, key, value)

        self.__frozen = True

    def __delitem__(self, *args, **kwargs):
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __setattr__(self, name, value):
        if hasattr(self, "__frozen") and self.__frozen:
            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
        super().__setattr__(name, value)

    def __setitem__(self, name, value):
        if hasattr(self, "__frozen") and self.__frozen:
            raise Exception(f"You cannot use ``__setitem__`` on a {self.__class__.__name__} instance.")
        super().__setitem__(name, value)


class ConfigMixin:
    r"""
    Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
    provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
    saving classes that inherit from [`ConfigMixin`].

    Class attributes:
        - **config_name** (`str`) -- A filename under which the config should be stored when calling
          [`~ConfigMixin.save_config`] (should be overridden by parent class).
        - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
          overridden by subclass).
        - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
        - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
          should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
          subclass).
    """
    config_name = None
    ignore_for_config = []
    has_compatibles = False

    _deprecated_kwargs = []

    def register_to_config(self, **kwargs):
        if self.config_name is None:
            raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
        # Special case for `kwargs` used in deprecation warning added to schedulers
        # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
        # or solve in a more general way.
        kwargs.pop("kwargs", None)

        if not hasattr(self, "_internal_dict"):
            internal_dict = kwargs
        else:
            previous_dict = dict(self._internal_dict)
            internal_dict = {**self._internal_dict, **kwargs}
            logger.debug(f"Updating config from {previous_dict} to {internal_dict}")

        self._internal_dict = FrozenDict(internal_dict)

    def __getattr__(self, name: str) -> Any:
        """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
        config attributes directly. See https://github.com/huggingface/diffusers/pull/3129

        This function is mostly copied from PyTorch's __getattr__ overwrite:
        https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
        """

        is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
        is_attribute = name in self.__dict__

        if is_in_config and not is_attribute:
            deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
            deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
            return self._internal_dict[name]

        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
        [`~ConfigMixin.from_config`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file is saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        os.makedirs(save_directory, exist_ok=True)

        # If we save using the predefined names, we can load using `from_config`
        output_config_file = os.path.join(save_directory, self.config_name)

        self.to_json_file(output_config_file)
        logger.info(f"Configuration saved in {output_config_file}")

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            private = kwargs.pop("private", False)
            create_pr = kwargs.pop("create_pr", False)
            token = kwargs.pop("token", None)
            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id

            self._upload_folder(
                save_directory,
                repo_id,
                token=token,
                commit_message=commit_message,
                create_pr=create_pr,
            )

    @classmethod
    def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
        r"""
        Instantiate a Python class from a config dictionary.

        Parameters:
            config (`Dict[str, Any]`):
                A config dictionary from which the Python class is instantiated. Make sure to only load configuration
                files of compatible classes.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it is loaded) and initiate the Python class.
                `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
                overwrite the same named arguments in `config`.

        Returns:
            [`ModelMixin`] or [`SchedulerMixin`]:
                A model or scheduler object instantiated from a config dictionary.

        Examples:

        ```python
        >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler

        >>> # Download scheduler from huggingface.co and cache.
        >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")

        >>> # Instantiate DDIM scheduler class with same config as DDPM
        >>> scheduler = DDIMScheduler.from_config(scheduler.config)

        >>> # Instantiate PNDM scheduler class with same config as DDPM
        >>> scheduler = PNDMScheduler.from_config(scheduler.config)
        ```
        """
        # <===== TO BE REMOVED WITH DEPRECATION
        # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
        if "pretrained_model_name_or_path" in kwargs:
            config = kwargs.pop("pretrained_model_name_or_path")

        if config is None:
            raise ValueError("Please make sure to provide a config as the first positional argument.")
        # ======>

        if not isinstance(config, dict):
            deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
            if "Scheduler" in cls.__name__:
                deprecation_message += (
                    f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
                    " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
                    " be removed in v1.0.0."
                )
            elif "Model" in cls.__name__:
                deprecation_message += (
                    f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
                    f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
                    " instead. This functionality will be removed in v1.0.0."
                )
            deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
            config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)

        init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)

        # Allow dtype to be specified on initialization
        if "dtype" in unused_kwargs:
            init_dict["dtype"] = unused_kwargs.pop("dtype")

        # add possible deprecated kwargs
        for deprecated_kwarg in cls._deprecated_kwargs:
            if deprecated_kwarg in unused_kwargs:
                init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)

        # Return model and optionally state and/or unused_kwargs
        model = cls(**init_dict)

        # make sure to also save config parameters that might be used for compatible classes
        model.register_to_config(**hidden_dict)

        # add hidden kwargs of compatible classes to unused_kwargs
        unused_kwargs = {**unused_kwargs, **hidden_dict}

        if return_unused_kwargs:
            return (model, unused_kwargs)
        else:
            return model

    @classmethod
    def get_config_dict(cls, *args, **kwargs):
        deprecation_message = (
            f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
            " removed in version v1.0.0"
        )
        deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
        return cls.load_config(*args, **kwargs)

    @classmethod
    def load_config(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        return_unused_kwargs=False,
        return_commit_hash=False,
        **kwargs,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        r"""
        Load a model or scheduler configuration.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
                      [`~ConfigMixin.save_config`].

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether unused keyword arguments of the config are returned.
            return_commit_hash (`bool`, *optional*, defaults to `False`):
                Whether the `commit_hash` of the loaded configuration is returned.

        Returns:
            `dict`:
                A dictionary of all the parameters stored in a JSON configuration file.

        """
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        use_auth_token = kwargs.pop("use_auth_token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        _ = kwargs.pop("mirror", None)
        subfolder = kwargs.pop("subfolder", None)
        user_agent = kwargs.pop("user_agent", {})

        user_agent = {**user_agent, "file_type": "config"}
        user_agent = http_user_agent(user_agent)

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)

        if cls.config_name is None:
            raise ValueError(
                "`self.config_name` is not defined. Note that one should not load a config from "
                "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
            )

        if os.path.isfile(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        elif os.path.isdir(pretrained_model_name_or_path):
            if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
                # Load from a PyTorch checkpoint
                config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
            elif subfolder is not None and os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
            ):
                config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
            else:
                raise EnvironmentError(
                    f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
                )
        else:
            try:
                # Load from URL or cache if already cached
                config_file = hf_hub_download(
                    pretrained_model_name_or_path,
                    filename=cls.config_name,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    user_agent=user_agent,
                    subfolder=subfolder,
                    revision=revision,
                )
            except RepositoryNotFoundError:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
                    " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
                    " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
                    " login`."
                )
            except RevisionNotFoundError:
                raise EnvironmentError(
                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
                    " this model name. Check the model page at"
                    f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
                )
            except EntryNotFoundError:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
                )
            except HTTPError as err:
                raise EnvironmentError(
                    "There was a specific connection error when trying to load"
                    f" {pretrained_model_name_or_path}:\n{err}"
                )
            except ValueError:
                raise EnvironmentError(
                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
                    f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                    f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
                    " run the library in offline mode at"
                    " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
                )
            except EnvironmentError:
                raise EnvironmentError(
                    f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                    f"containing a {cls.config_name} file"
                )

        try:
            # Load config dict
            config_dict = cls._dict_from_json_file(config_file)

            commit_hash = extract_commit_hash(config_file)
        except (json.JSONDecodeError, UnicodeDecodeError):
            raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")

        if not (return_unused_kwargs or return_commit_hash):
            return config_dict

        outputs = (config_dict,)

        if return_unused_kwargs:
            outputs += (kwargs,)

        if return_commit_hash:
            outputs += (commit_hash,)

        return outputs

    @staticmethod
    def _get_init_keys(cls):
        return set(dict(inspect.signature(cls.__init__).parameters).keys())

    @classmethod
    def extract_init_dict(cls, config_dict, **kwargs):
        # Skip keys that were not present in the original config, so default __init__ values were used
        used_defaults = config_dict.get("_use_default_values", [])
        config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}

        # 0. Copy origin config dict
        original_dict = dict(config_dict.items())

        # 1. Retrieve expected config attributes from __init__ signature
        expected_keys = cls._get_init_keys(cls)
        expected_keys.remove("self")
        # remove general kwargs if present in dict
        if "kwargs" in expected_keys:
            expected_keys.remove("kwargs")
        # remove flax internal keys
        if hasattr(cls, "_flax_internal_args"):
            for arg in cls._flax_internal_args:
                expected_keys.remove(arg)

        # 2. Remove attributes that cannot be expected from expected config attributes
        # remove keys to be ignored
        if len(cls.ignore_for_config) > 0:
            expected_keys = expected_keys - set(cls.ignore_for_config)

        # load diffusers library to import compatible and original scheduler
        diffusers_library = importlib.import_module(__name__.split(".")[0])

        if cls.has_compatibles:
            compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
        else:
            compatible_classes = []

        expected_keys_comp_cls = set()
        for c in compatible_classes:
            expected_keys_c = cls._get_init_keys(c)
            expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
        expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
        config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}

        # remove attributes from orig class that cannot be expected
        orig_cls_name = config_dict.pop("_class_name", cls.__name__)
        if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
            orig_cls = getattr(diffusers_library, orig_cls_name)
            unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
            config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}

        # remove private attributes
        config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}

        # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
        init_dict = {}
        for key in expected_keys:
            # if config param is passed to kwarg and is present in config dict
            # it should overwrite existing config dict key
            if key in kwargs and key in config_dict:
                config_dict[key] = kwargs.pop(key)

            if key in kwargs:
                # overwrite key
                init_dict[key] = kwargs.pop(key)
            elif key in config_dict:
                # use value from config dict
                init_dict[key] = config_dict.pop(key)

        # 4. Give nice warning if unexpected values have been passed
        if len(config_dict) > 0:
            logger.warning(
                f"The config attributes {config_dict} were passed to {cls.__name__}, "
                "but are not expected and will be ignored. Please verify your "
                f"{cls.config_name} configuration file."
            )

        # 5. Give nice info if config attributes are initialized to default because they have not been passed
        passed_keys = set(init_dict.keys())
        if len(expected_keys - passed_keys) > 0:
            logger.info(
                f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
            )

        # 6. Define unused keyword arguments
        unused_kwargs = {**config_dict, **kwargs}

        # 7. Define "hidden" config parameters that were saved for compatible classes
        hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}

        return init_dict, unused_kwargs, hidden_config_dict

    @classmethod
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    @property
    def config(self) -> Dict[str, Any]:
        """
        Returns the config of the class as a frozen dictionary

        Returns:
            `Dict[str, Any]`: Config of the class.
        """
        return self._internal_dict

    def to_json_string(self) -> str:
        """
        Serializes the configuration instance to a JSON string.

        Returns:
            `str`:
                String containing all the attributes that make up the configuration instance in JSON format.
        """
        config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
        config_dict["_class_name"] = self.__class__.__name__
        config_dict["_diffusers_version"] = __version__

        def to_json_saveable(value):
            if isinstance(value, np.ndarray):
                value = value.tolist()
            elif isinstance(value, PosixPath):
                value = str(value)
            return value

        config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
        # Don't save "_ignore_files" or "_use_default_values"
        config_dict.pop("_ignore_files", None)
        config_dict.pop("_use_default_values", None)

        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        Save the configuration instance's parameters to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file to save a configuration instance's parameters.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string())


def register_to_config(init):
    r"""
    Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
    automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
    shouldn't be registered in the config, use the `ignore_for_config` class variable

    Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
    """

    @functools.wraps(init)
    def inner_init(self, *args, **kwargs):
        # Ignore private kwargs in the init.
        init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
        config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        ignore = getattr(self, "ignore_for_config", [])
        # Get positional arguments aligned with kwargs
        new_kwargs = {}
        signature = inspect.signature(init)
        parameters = {
            name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
        }
        for arg, name in zip(args, parameters.keys()):
            new_kwargs[name] = arg

        # Then add all kwargs
        new_kwargs.update(
            {
                k: init_kwargs.get(k, default)
                for k, default in parameters.items()
                if k not in ignore and k not in new_kwargs
            }
        )

        # Take note of the parameters that were not present in the loaded config
        if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
            new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))

        new_kwargs = {**config_init_kwargs, **new_kwargs}
        getattr(self, "register_to_config")(**new_kwargs)
        init(self, *args, **init_kwargs)

    return inner_init


def flax_register_to_config(cls):
    original_init = cls.__init__

    @functools.wraps(original_init)
    def init(self, *args, **kwargs):
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        # Ignore private kwargs in the init. Retrieve all passed attributes
        init_kwargs = dict(kwargs.items())

        # Retrieve default values
        fields = dataclasses.fields(self)
        default_kwargs = {}
        for field in fields:
            # ignore flax specific attributes
            if field.name in self._flax_internal_args:
                continue
            if type(field.default) == dataclasses._MISSING_TYPE:
                default_kwargs[field.name] = None
            else:
                default_kwargs[field.name] = getattr(self, field.name)

        # Make sure init_kwargs override default kwargs
        new_kwargs = {**default_kwargs, **init_kwargs}
        # dtype should be part of `init_kwargs`, but not `new_kwargs`
        if "dtype" in new_kwargs:
            new_kwargs.pop("dtype")

        # Get positional arguments aligned with kwargs
        for i, arg in enumerate(args):
            name = fields[i].name
            new_kwargs[name] = arg

        # Take note of the parameters that were not present in the loaded config
        if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
            new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))

        getattr(self, "register_to_config")(**new_kwargs)
        original_init(self, *args, **kwargs)

    cls.__init__ = init
    return cls
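A minimal usage sketch for the pieces above (illustrative only, not part of the uploaded file): `ToySampler` and its arguments are invented, while `ConfigMixin`, `register_to_config`, `save_config`, `load_config`, and `from_config` are the definitions from this module.

```python
class ToySampler(ConfigMixin):
    # Hypothetical example class; only ConfigMixin and @register_to_config come from the module above.
    config_name = "toy_sampler_config.json"

    @register_to_config
    def __init__(self, num_steps: int = 50, sigma: float = 1.0):
        pass


sampler = ToySampler(num_steps=25)
print(sampler.config.num_steps)        # 25 -- init arguments are captured in a FrozenDict
sampler.save_config("./toy_sampler")   # writes ./toy_sampler/toy_sampler_config.json
clone = ToySampler.from_config(ToySampler.load_config("./toy_sampler"))
```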
Tiger Model/diffusiers-Tiger/dependency_versions_check.py
ADDED
@@ -0,0 +1,47 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

from .dependency_versions_table import deps
from .utils.versions import require_version, require_version_core


# define which module versions we always want to check at run time
# (usually the ones defined in `install_requires` in setup.py)
#
# order specific notes:
# - tqdm must be checked before tokenizers

pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
if sys.version_info < (3, 7):
    pkgs_to_check_at_runtime.append("dataclasses")
if sys.version_info < (3, 8):
    pkgs_to_check_at_runtime.append("importlib_metadata")

for pkg in pkgs_to_check_at_runtime:
    if pkg in deps:
        if pkg == "tokenizers":
            # must be loaded here, or else tqdm check may fail
            from .utils import is_tokenizers_available

            if not is_tokenizers_available():
                continue  # not required, check version only if installed

        require_version_core(deps[pkg])
    else:
        raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")


def dep_version_check(pkg, hint=None):
    require_version(deps[pkg], hint)
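A short sketch of how the helper above would be called at runtime (assumption: the package is importable under a name such as `diffusers_tiger`; that module path is hypothetical, and the `"accelerate"` key comes from the pin table in the next file).

```python
from diffusers_tiger.dependency_versions_check import dep_version_check  # hypothetical import path

dep_version_check("accelerate", hint="pip install -U 'accelerate>=0.11.0'")
dep_version_check("transformers")  # raises if the installed version violates the pinned requirement
```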
Tiger Model/diffusiers-Tiger/dependency_versions_table.py
ADDED
@@ -0,0 +1,44 @@
# THIS FILE HAS BEEN AUTOGENERATED. To update:
# 1. modify the `_deps` dict in setup.py
# 2. run `make deps_table_update``
deps = {
    "Pillow": "Pillow",
    "accelerate": "accelerate>=0.11.0",
    "compel": "compel==0.1.8",
    "black": "black~=23.1",
    "datasets": "datasets",
    "filelock": "filelock",
    "flax": "flax>=0.4.1",
    "hf-doc-builder": "hf-doc-builder>=0.3.0",
    "huggingface-hub": "huggingface-hub>=0.13.2",
    "requests-mock": "requests-mock==1.10.0",
    "importlib_metadata": "importlib_metadata",
    "invisible-watermark": "invisible-watermark>=0.2.0",
    "isort": "isort>=5.5.4",
    "jax": "jax>=0.2.8,!=0.3.2",
    "jaxlib": "jaxlib>=0.1.65",
    "Jinja2": "Jinja2",
    "k-diffusion": "k-diffusion>=0.0.12",
    "torchsde": "torchsde",
    "note_seq": "note_seq",
    "librosa": "librosa",
    "numpy": "numpy",
    "omegaconf": "omegaconf",
    "parameterized": "parameterized",
    "protobuf": "protobuf>=3.20.3,<4",
    "pytest": "pytest",
    "pytest-timeout": "pytest-timeout",
    "pytest-xdist": "pytest-xdist",
    "ruff": "ruff==0.0.280",
    "safetensors": "safetensors>=0.3.1",
    "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
    "scipy": "scipy",
    "onnx": "onnx",
    "regex": "regex!=2019.12.17",
    "requests": "requests",
    "tensorboard": "tensorboard",
    "torch": "torch>=1.4",
    "torchvision": "torchvision",
    "transformers": "transformers>=4.25.1",
    "urllib3": "urllib3<=2.0.0",
}
Tiger Model/diffusiers-Tiger/fuse.py
ADDED
@@ -0,0 +1,175 @@
import torch
import torch.nn as nn


class DAF(nn.Module):
    '''
    DirectAddFuse: fuse two feature maps by direct element-wise addition.
    '''

    def __init__(self):
        super(DAF, self).__init__()

    def forward(self, x, residual):
        return x + residual


class iAFF(nn.Module):
    '''
    iAFF: iterative attentional fusion of multiple features.
    '''

    def __init__(self, channels=64, r=4):
        super(iAFF, self).__init__()
        inter_channels = int(channels // r)

        # local attention
        self.local_att = nn.Sequential(
            nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        )

        # global attention
        self.global_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        )

        # second local attention
        self.local_att2 = nn.Sequential(
            nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        )
        # second global attention
        self.global_att2 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, x, residual):
        xa = x + residual
        xl = self.local_att(xa)
        xg = self.global_att(xa)
        xlg = xl + xg
        wei = self.sigmoid(xlg)
        xi = x * wei + residual * (1 - wei)

        xl2 = self.local_att2(xi)
        xg2 = self.global_att2(xi)  # second-stage global attention (the original called self.global_att here)
        xlg2 = xl2 + xg2
        wei2 = self.sigmoid(xlg2)
        xo = x * wei2 + residual * (1 - wei2)
        return xo


class AFF(nn.Module):
    '''
    AFF: attentional fusion of multiple features.
    '''

    def __init__(self, channels=64, r=4):
        super(AFF, self).__init__()
        inter_channels = int(channels // r)

        self.local_att = nn.Sequential(
            nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        )

        self.global_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, x, residual):
        xa = x + residual
        xl = self.local_att(xa)
        xg = self.global_att(xa)
        xlg = xl + xg
        wei = self.sigmoid(xlg)

        xo = 2 * x * wei + 2 * residual * (1 - wei)
        return xo


class MS_CAM(nn.Module):
    '''
    MS-CAM: channel re-weighting of a single feature map, similar in effect to an SE block.
    '''

    def __init__(self, channels=64, r=4):
        super(MS_CAM, self).__init__()
        inter_channels = int(channels // r)

        self.local_att = nn.Sequential(
            nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        )

        self.global_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        xl = self.local_att(x)
        xg = self.global_att(x)
        xlg = xl + xg
        wei = self.sigmoid(xlg)
        return x * wei


if __name__ == '__main__':
    import os
    device = torch.device("cpu")
    x = torch.ones(1, 2, 2, 2).to(device)
    print(x)
    a = x[0]
    print(a)
    b = torch.ones(2, 2, 2)
    c = torch.stack((a, b))
    print(x.shape)
    # x, residual = torch.ones(1, 2, 2, 2).to(device), torch.ones(1, 64, 32, 32).to(device)
    # x = torch.cat(x, dim=1)
    # print(x.shape)
    # channels = x.shape[1]
    # print(channels)
    # model = AFF(channels=channels)
    # model = model.to(device).train()
    # output = model(x, residual)
    # print(output.shape)
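A short usage sketch for the fusion blocks above (not part of the file): shapes are arbitrary, and `AFF` expects both inputs to share the same shape and channel count.

```python
import torch

feat = torch.randn(2, 64, 32, 32)      # e.g. features from one branch
skip = torch.randn(2, 64, 32, 32)      # e.g. residual / skip features of the same shape
fusion = AFF(channels=64, r=4).eval()  # `channels` must match the inputs' channel dimension

with torch.no_grad():
    fused = fusion(feat, skip)

print(fused.shape)  # torch.Size([2, 64, 32, 32])
```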
Tiger Model/diffusiers-Tiger/getWeight.py
ADDED
@@ -0,0 +1,88 @@
import math
import os
import random
import shutil
from pathlib import Path
from pynvml import *
import accelerate
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers import AutoTokenizer, PretrainedConfig

# Example batch of CLIP token ids (6 captions, padded to length 77).
tensor1 = torch.tensor([[49406, 1884, 33667, 267, 21263, 268, 1126, 268, 7771, 267,
                         32955, 267, 38692, 267, 13989, 43204, 267, 1042, 13989, 49407,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0],
                        [49406, 1884, 33667, 267, 41122, 3633, 267, 21263, 268, 1126,
                         268, 7771, 267, 6148, 267, 32955, 267, 13989, 43204, 267,
                         1042, 13989, 267, 1579, 3396, 267, 2442, 1579, 3396, 49407,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0],
                        [49406, 1884, 33667, 267, 21263, 268, 1126, 268, 7771, 267,
                         3143, 267, 6307, 267, 1070, 1042, 13989, 49407, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0],
                        [49406, 1884, 33667, 267, 21263, 268, 1126, 268, 7771, 267,
                         46131, 267, 3143, 267, 6307, 49407, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0],
                        [49406, 1884, 33667, 267, 21263, 268, 1126, 268, 7771, 267,
                         6148, 267, 32955, 267, 38692, 267, 13989, 43204, 267, 1042,
                         13989, 267, 1579, 3396, 267, 5094, 268, 789, 1579, 3396,
                         49407, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0],
                        [49406, 1884, 33667, 267, 21263, 268, 1126, 268, 7771, 267,
                         32955, 267, 38692, 6448, 49407, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0, 0, 0, 0, 0, 0, 0]])


l = tensor1.tolist()
list_2 = sum(l, [])


def remove_item(n):
    # drop padding (0), the start/end tokens 49406/49407 and the comma token 267
    return n != 0 and n != 49407 and n != 49406 and n != 267


list_3 = list(filter(remove_item, list_2))

# count how often each remaining token id occurs
counts = {}
for key in list_3:
    counts[key] = counts.get(key, 0) + 1
print(counts)

revision = None
tokenizer = AutoTokenizer.from_pretrained(
    "/export/home/daifang/Diffusion/diffusers/model/sd-8_28",
    subfolder="tokenizer",
    revision=revision,
    use_fast=False,
)
# captions = ['papillary blood flow', 'malignant follicular, solid, unclear, irregular, hales, circular, enormous, white point', 'papillary, wider-than-tall, solid, unclear, irregular, echo uneven, low echo, white points, sand-like white points', 'papillary, wider-than-tall, solid, unclear, irregular, echo uneven, extremely low echo, white points, sand-like white points', 'papillary, wider-than-tall, solid, unclear, irregular, echo uneven, low echo', 'papillary, wider-than-tall, solid, unclear, irregular, echo uneven, low echo, white points, sand-like white points']

captions = ['papillary, taller-than-wide, solid, unclear, irregular, echo uneven, low echo', 'papillary, wider-than-tall, solid, unclear, irregular, echo uneven, low echo, white points, sand-like white points', 'papillary, wider-than-tall, unclear, irregular, echo uneven, low echo, white points, sand-like white points', 'papillary, taller-than-wide, solid, unclear, irregular, echo uneven, low echo', 'papillary, taller-than-wide, solid, unclear, irregular, echo uneven, low echo, white points, large white points', 'No focus']
inputs = tokenizer(
    captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt")
print(inputs)
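Since the script above only prints raw id frequencies, here is a small follow-up sketch (not in the file) that maps the counted ids back to sub-word strings; `counts` and `tokenizer` refer to the objects built above.

```python
# Sketch: print the most frequent sub-words behind the counted token ids.
for token_id, n in sorted(counts.items(), key=lambda kv: -kv[1]):
    print(tokenizer.convert_ids_to_tokens(token_id), n)
```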
Tiger Model/diffusiers-Tiger/image_processor.py
ADDED
@@ -0,0 +1,366 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import warnings
|
16 |
+
from typing import List, Optional, Union
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import PIL
|
20 |
+
import torch
|
21 |
+
from PIL import Image
|
22 |
+
|
23 |
+
from .configuration_utils import ConfigMixin, register_to_config
|
24 |
+
from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
|
25 |
+
|
26 |
+
|
27 |
+
class VaeImageProcessor(ConfigMixin):
|
28 |
+
"""
|
29 |
+
Image processor for VAE.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
33 |
+
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
|
34 |
+
`height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
|
35 |
+
vae_scale_factor (`int`, *optional*, defaults to `8`):
|
36 |
+
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
|
37 |
+
resample (`str`, *optional*, defaults to `lanczos`):
|
38 |
+
Resampling filter to use when resizing the image.
|
39 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
40 |
+
Whether to normalize the image to [-1,1].
|
41 |
+
do_convert_rgb (`bool`, *optional*, defaults to be `False`):
|
42 |
+
Whether to convert the images to RGB format.
|
43 |
+
"""
|
44 |
+
|
45 |
+
config_name = CONFIG_NAME
|
46 |
+
|
47 |
+
@register_to_config
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
do_resize: bool = True,
|
51 |
+
vae_scale_factor: int = 8,
|
52 |
+
resample: str = "lanczos",
|
53 |
+
do_normalize: bool = True,
|
54 |
+
do_convert_rgb: bool = False,
|
55 |
+
):
|
56 |
+
super().__init__()
|
57 |
+
|
58 |
+
@staticmethod
|
59 |
+
def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
|
60 |
+
"""
|
61 |
+
Convert a numpy image or a batch of images to a PIL image.
|
62 |
+
"""
|
63 |
+
if images.ndim == 3:
|
64 |
+
images = images[None, ...]
|
65 |
+
images = (images * 255).round().astype("uint8")
|
66 |
+
if images.shape[-1] == 1:
|
67 |
+
# special case for grayscale (single channel) images
|
68 |
+
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
|
69 |
+
else:
|
70 |
+
pil_images = [Image.fromarray(image) for image in images]
|
71 |
+
|
72 |
+
return pil_images
|
73 |
+
|
74 |
+
@staticmethod
|
75 |
+
def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
|
76 |
+
"""
|
77 |
+
Convert a PIL image or a list of PIL images to NumPy arrays.
|
78 |
+
"""
|
79 |
+
if not isinstance(images, list):
|
80 |
+
images = [images]
|
81 |
+
images = [np.array(image).astype(np.float32) / 255.0 for image in images]
|
82 |
+
images = np.stack(images, axis=0)
|
83 |
+
|
84 |
+
return images
|
85 |
+
|
86 |
+
@staticmethod
|
87 |
+
def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
|
88 |
+
"""
|
89 |
+
Convert a NumPy image to a PyTorch tensor.
|
90 |
+
"""
|
91 |
+
if images.ndim == 3:
|
92 |
+
images = images[..., None]
|
93 |
+
|
94 |
+
images = torch.from_numpy(images.transpose(0, 3, 1, 2))
|
95 |
+
return images
|
96 |
+
|
97 |
+
@staticmethod
|
98 |
+
def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
|
99 |
+
"""
|
100 |
+
Convert a PyTorch tensor to a NumPy image.
|
101 |
+
"""
|
102 |
+
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
|
103 |
+
return images
|
104 |
+
|
105 |
+
@staticmethod
|
106 |
+
def normalize(images):
|
107 |
+
"""
|
108 |
+
Normalize an image array to [-1,1].
|
109 |
+
"""
|
110 |
+
return 2.0 * images - 1.0
|
111 |
+
|
112 |
+
@staticmethod
|
113 |
+
def denormalize(images):
|
114 |
+
"""
|
115 |
+
Denormalize an image array to [0,1].
|
116 |
+
"""
|
117 |
+
return (images / 2 + 0.5).clamp(0, 1)
|
118 |
+
|
119 |
+
@staticmethod
|
120 |
+
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
|
121 |
+
"""
|
122 |
+
Converts an image to RGB format.
|
123 |
+
"""
|
124 |
+
image = image.convert("RGB")
|
125 |
+
return image
|
126 |
+
|
127 |
+
def resize(
|
128 |
+
self,
|
129 |
+
image: PIL.Image.Image,
|
130 |
+
height: Optional[int] = None,
|
131 |
+
width: Optional[int] = None,
|
132 |
+
) -> PIL.Image.Image:
|
133 |
+
"""
|
134 |
+
Resize a PIL image. Both height and width are downscaled to the next integer multiple of `vae_scale_factor`.
|
135 |
+
"""
|
136 |
+
if height is None:
|
137 |
+
height = image.height
|
138 |
+
if width is None:
|
139 |
+
width = image.width
|
140 |
+
|
141 |
+
width, height = (
|
142 |
+
x - x % self.config.vae_scale_factor for x in (width, height)
|
143 |
+
) # resize to integer multiple of vae_scale_factor
|
144 |
+
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
|
145 |
+
return image
|
146 |
+
|
147 |
+
def preprocess(
|
148 |
+
self,
|
149 |
+
image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
|
150 |
+
height: Optional[int] = None,
|
151 |
+
width: Optional[int] = None,
|
152 |
+
) -> torch.Tensor:
|
153 |
+
"""
|
154 |
+
Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
|
155 |
+
"""
|
156 |
+
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
|
157 |
+
if isinstance(image, supported_formats):
|
158 |
+
image = [image]
|
159 |
+
elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
|
160 |
+
raise ValueError(
|
161 |
+
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
|
162 |
+
)
|
163 |
+
|
164 |
+
if isinstance(image[0], PIL.Image.Image):
|
165 |
+
if self.config.do_convert_rgb:
|
166 |
+
image = [self.convert_to_rgb(i) for i in image]
|
167 |
+
if self.config.do_resize:
|
168 |
+
image = [self.resize(i, height, width) for i in image]
|
169 |
+
image = self.pil_to_numpy(image) # to np
|
170 |
+
image = self.numpy_to_pt(image) # to pt
|
171 |
+
|
172 |
+
elif isinstance(image[0], np.ndarray):
|
173 |
+
image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
|
174 |
+
image = self.numpy_to_pt(image)
|
175 |
+
_, _, height, width = image.shape
|
176 |
+
if self.config.do_resize and (
|
177 |
+
height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
|
178 |
+
):
|
179 |
+
raise ValueError(
|
180 |
+
f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}"
|
181 |
+
f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
|
182 |
+
)
|
183 |
+
|
184 |
+
elif isinstance(image[0], torch.Tensor):
|
185 |
+
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
|
186 |
+
_, channel, height, width = image.shape
|
187 |
+
|
188 |
+
# don't need any preprocess if the image is latents
|
189 |
+
if channel == 4:
|
190 |
+
return image
|
191 |
+
|
192 |
+
if self.config.do_resize and (
|
193 |
+
height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
|
194 |
+
):
|
195 |
+
raise ValueError(
|
196 |
+
f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}"
|
197 |
+
f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
|
198 |
+
)
|
199 |
+
|
200 |
+
# expected range [0,1], normalize to [-1,1]
|
201 |
+
do_normalize = self.config.do_normalize
|
202 |
+
if image.min() < 0:
|
203 |
+
warnings.warn(
|
204 |
+
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
|
205 |
+
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
|
206 |
+
FutureWarning,
|
207 |
+
)
|
208 |
+
do_normalize = False
|
209 |
+
|
210 |
+
if do_normalize:
|
211 |
+
image = self.normalize(image)
|
212 |
+
|
213 |
+
return image
|
214 |
+
|
215 |
+
def postprocess(
|
216 |
+
self,
|
217 |
+
image: torch.FloatTensor,
|
218 |
+
output_type: str = "pil",
|
219 |
+
do_denormalize: Optional[List[bool]] = None,
|
220 |
+
):
|
221 |
+
if not isinstance(image, torch.Tensor):
|
222 |
+
raise ValueError(
|
223 |
+
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
|
224 |
+
)
|
225 |
+
if output_type not in ["latent", "pt", "np", "pil"]:
|
226 |
+
deprecation_message = (
|
227 |
+
                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
                "`pil`, `np`, `pt`, `latent`"
            )
            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
            output_type = "np"

        if output_type == "latent":
            return image

        if do_denormalize is None:
            do_denormalize = [self.config.do_normalize] * image.shape[0]

        image = torch.stack(
            [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
        )

        if output_type == "pt":
            return image

        image = self.pt_to_numpy(image)

        if output_type == "np":
            return image

        if output_type == "pil":
            return self.numpy_to_pil(image)


class VaeImageProcessorLDM3D(VaeImageProcessor):
    """
    Image processor for VAE LDM3D.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image to [-1, 1].
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = True,
    ):
        super().__init__()

    @staticmethod
    def numpy_to_pil(images):
        """
        Convert a NumPy image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image[:, :, :3]) for image in images]

        return pil_images

    @staticmethod
    def rgblike_to_depthmap(image):
        """
        Args:
            image: RGB-like depth image

        Returns: depth map

        """
        return image[:, :, 1] * 2**8 + image[:, :, 2]

    def numpy_to_depth(self, images):
        """
        Convert a NumPy depth image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images_depth = images[:, :, :, 3:]
        if images.shape[-1] == 6:
            images_depth = (images_depth * 255).round().astype("uint8")
            pil_images = [
                Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
            ]
        elif images.shape[-1] == 4:
            images_depth = (images_depth * 65535.0).astype(np.uint16)
            pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
        else:
            raise Exception("Not supported")

        return pil_images

    def postprocess(
        self,
        image: torch.FloatTensor,
        output_type: str = "pil",
        do_denormalize: Optional[List[bool]] = None,
    ):
        if not isinstance(image, torch.Tensor):
            raise ValueError(
                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
            )
        if output_type not in ["latent", "pt", "np", "pil"]:
            deprecation_message = (
                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
                "`pil`, `np`, `pt`, `latent`"
            )
            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
            output_type = "np"

        if do_denormalize is None:
            do_denormalize = [self.config.do_normalize] * image.shape[0]

        image = torch.stack(
            [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
        )

        image = self.pt_to_numpy(image)

        if output_type == "np":
            if image.shape[-1] == 6:
                image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
            else:
                image_depth = image[:, :, :, 3:]
            return image[:, :, :, :3], image_depth

        if output_type == "pil":
            return self.numpy_to_pil(image), self.numpy_to_depth(image)
        else:
            raise Exception(f"This type {output_type} is not supported")
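The LDM3D post-processing above hands back the RGB channels and the depth channels as a pair. Below is a minimal, illustrative sketch of how it might be called; the tensor shapes and variable names are assumptions made for the example and are not part of the uploaded files.

import torch

# Illustrative only: construct the processor with its defaults.
processor = VaeImageProcessorLDM3D(vae_scale_factor=8)

# Pretend-decoded batch: 2 samples, 6 channels (3 RGB + 3 RGB-encoded depth), 64x64 pixels,
# already in the [-1, 1] range that `denormalize` undoes.
decoded = torch.rand(2, 6, 64, 64) * 2 - 1

rgb_np, depth_np = processor.postprocess(decoded, output_type="np")
# Expected shapes under these assumptions: rgb_np -> (2, 64, 64, 3), depth_np -> (2, 64, 64)

# `output_type="pil"` would instead return two lists of PIL images,
# produced by `numpy_to_pil` and `numpy_to_depth`.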
Tiger Model/diffusiers-Tiger/loaders.py
ADDED
The diff for this file is too large to render.
See raw diff
Tiger Model/diffusiers-Tiger/models/README.md
ADDED
@@ -0,0 +1,3 @@
# Models

For more detail on the models, please refer to the [docs](https://huggingface.co/docs/diffusers/api/models/overview).
Tiger Model/diffusiers-Tiger/models/__init__.py
ADDED
@@ -0,0 +1,39 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils import is_flax_available, is_torch_available


if is_torch_available():
    from .adapter import MultiAdapter, T2IAdapter
    from .autoencoder_asym_kl import AsymmetricAutoencoderKL
    from .autoencoder_kl import AutoencoderKL
    from .autoencoder_tiny import AutoencoderTiny
    from .controlnet import ControlNetModel
    from .dual_transformer_2d import DualTransformer2DModel
    from .modeling_utils import ModelMixin
    from .prior_transformer import PriorTransformer
    from .t5_film_transformer import T5FilmDecoder
    from .transformer_2d import Transformer2DModel
    from .unet_1d import UNet1DModel
    from .unet_2d import UNet2DModel
    from .unet_2d_condition import UNet2DConditionModel
    from .unet_3d_condition import UNet3DConditionModel
    from .vq_model import VQModel

if is_flax_available():
    from .controlnet_flax import FlaxControlNetModel
    from .unet_2d_condition_flax import FlaxUNet2DConditionModel
    from .vae_flax import FlaxAutoencoderKL
Tiger Model/diffusiers-Tiger/models/activations.py
ADDED
@@ -0,0 +1,14 @@
from torch import nn


def get_activation(act_fn):
    if act_fn in ["swish", "silu"]:
        return nn.SiLU()
    elif act_fn == "mish":
        return nn.Mish()
    elif act_fn == "gelu":
        return nn.GELU()
    elif act_fn == "relu":
        return nn.ReLU()
    else:
        raise ValueError(f"Unsupported activation function: {act_fn}")
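`get_activation` is a small string-to-module factory used by the other model files. A quick illustrative check of its behavior (not part of the upload):

from torch import nn

act = get_activation("silu")
assert isinstance(act, nn.SiLU)

# Unknown names fail loudly instead of silently falling back:
try:
    get_activation("tanh")
except ValueError as err:
    print(err)  # Unsupported activation function: tanh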
Tiger Model/diffusiers-Tiger/models/adapter.py
ADDED
@@ -0,0 +1,291 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from .modeling_utils import ModelMixin
from .resnet import Downsample2D


class MultiAdapter(ModelMixin):
    r"""
    MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
    user-assigned weighting.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all models (such as downloading or saving, etc.)

    Parameters:
        adapters (`List[T2IAdapter]`, *optional*, defaults to None):
            A list of `T2IAdapter` model instances.
    """

    def __init__(self, adapters: List["T2IAdapter"]):
        super(MultiAdapter, self).__init__()

        self.num_adapter = len(adapters)
        self.adapters = nn.ModuleList(adapters)

    def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
        r"""
        Args:
            xs (`torch.Tensor`):
                (batch, channel, height, width) input images for the adapter models, concatenated along dimension 1;
                `channel` should equal `num_adapter` * (number of channels per input image).
            adapter_weights (`List[float]`, *optional*, defaults to None):
                List of floats representing the weight by which each adapter's output is multiplied before the
                outputs are added together.
        """
        if adapter_weights is None:
            adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter)
        else:
            adapter_weights = torch.tensor(adapter_weights)

        if xs.shape[1] % self.num_adapter != 0:
            raise ValueError(
                f"Expecting the multi-adapter input to have a channel count that is evenly divisible "
                f"by num_adapter: {xs.shape[1]} % {self.num_adapter} != 0"
            )
        x_list = torch.chunk(xs, self.num_adapter, dim=1)
        accume_state = None
        for x, w, adapter in zip(x_list, adapter_weights, self.adapters):
            features = adapter(x)
            if accume_state is None:
                accume_state = features
            else:
                for i in range(len(features)):
                    accume_state[i] += w * features[i]
        return accume_state


class T2IAdapter(ModelMixin, ConfigMixin):
    r"""
    A simple ResNet-like model that accepts images containing control signals such as keyposes and depth. The model
    generates multiple feature maps that are used as additional conditioning in [`UNet2DConditionModel`]. The model's
    architecture follows the original implementation of
    [Adapter](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L97)
    and
    [AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235).

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all models (such as downloading or saving, etc.)

    Parameters:
        in_channels (`int`, *optional*, defaults to 3):
            Number of channels of the Adapter's input (*control image*). Set this parameter to 1 if you're using a
            grayscale image as the *control image*.
        channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The number of channels in each downsample block's output hidden state. `len(channels)` also determines
            the number of downsample blocks in the Adapter.
        num_res_blocks (`int`, *optional*, defaults to 2):
            Number of ResNet blocks in each downsample block
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        channels: List[int] = [320, 640, 1280, 1280],
        num_res_blocks: int = 2,
        downscale_factor: int = 8,
        adapter_type: str = "full_adapter",
    ):
        super().__init__()

        if adapter_type == "full_adapter":
            self.adapter = FullAdapter(in_channels, channels, num_res_blocks, downscale_factor)
        elif adapter_type == "light_adapter":
            self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor)
        else:
            raise ValueError(f"unknown adapter_type: {adapter_type}. Choose either 'full_adapter' or 'light_adapter'")

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        return self.adapter(x)

    @property
    def total_downscale_factor(self):
        return self.adapter.total_downscale_factor


# full adapter


class FullAdapter(nn.Module):
    def __init__(
        self,
        in_channels: int = 3,
        channels: List[int] = [320, 640, 1280, 1280],
        num_res_blocks: int = 2,
        downscale_factor: int = 8,
    ):
        super().__init__()

        in_channels = in_channels * downscale_factor**2

        self.unshuffle = nn.PixelUnshuffle(downscale_factor)
        self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)

        self.body = nn.ModuleList(
            [
                AdapterBlock(channels[0], channels[0], num_res_blocks),
                *[
                    AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True)
                    for i in range(1, len(channels))
                ],
            ]
        )

        self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        x = self.unshuffle(x)
        x = self.conv_in(x)

        features = []

        for block in self.body:
            x = block(x)
            features.append(x)

        return features


class AdapterBlock(nn.Module):
    def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
        super().__init__()

        self.downsample = None
        if down:
            self.downsample = Downsample2D(in_channels)

        self.in_conv = None
        if in_channels != out_channels:
            self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        self.resnets = nn.Sequential(
            *[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)],
        )

    def forward(self, x):
        if self.downsample is not None:
            x = self.downsample(x)

        if self.in_conv is not None:
            x = self.in_conv(x)

        x = self.resnets(x)

        return x


class AdapterResnetBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        h = x
        h = self.block1(h)
        h = self.act(h)
        h = self.block2(h)

        return h + x


# light adapter


class LightAdapter(nn.Module):
    def __init__(
        self,
        in_channels: int = 3,
        channels: List[int] = [320, 640, 1280],
        num_res_blocks: int = 4,
        downscale_factor: int = 8,
    ):
        super().__init__()

        in_channels = in_channels * downscale_factor**2

        self.unshuffle = nn.PixelUnshuffle(downscale_factor)

        self.body = nn.ModuleList(
            [
                LightAdapterBlock(in_channels, channels[0], num_res_blocks),
                *[
                    LightAdapterBlock(channels[i], channels[i + 1], num_res_blocks, down=True)
                    for i in range(len(channels) - 1)
                ],
                LightAdapterBlock(channels[-1], channels[-1], num_res_blocks, down=True),
            ]
        )

        self.total_downscale_factor = downscale_factor * (2 ** len(channels))

    def forward(self, x):
        x = self.unshuffle(x)

        features = []

        for block in self.body:
            x = block(x)
            features.append(x)

        return features


class LightAdapterBlock(nn.Module):
    def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
        super().__init__()
        mid_channels = out_channels // 4

        self.downsample = None
        if down:
            self.downsample = Downsample2D(in_channels)

        self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
        self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
        self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)

    def forward(self, x):
        if self.downsample is not None:
            x = self.downsample(x)

        x = self.in_conv(x)
        x = self.resnets(x)
        x = self.out_conv(x)

        return x


class LightAdapterResnetBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        h = x
        h = self.block1(h)
        h = self.act(h)
        h = self.block2(h)

        return h + x
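With the defaults, `FullAdapter` first pixel-unshuffles the control image by `downscale_factor` and then halves the resolution in every subsequent block, so a 512x512 control image yields four feature maps. The sketch below is illustrative only; the input size and the expected shapes are assumptions for the example, not something asserted by the uploaded code.

import torch

adapter = T2IAdapter()                  # full_adapter, channels (320, 640, 1280, 1280), downscale_factor 8
control = torch.randn(1, 3, 512, 512)   # one 3-channel control image (assumed size)

features = adapter(control)
for feature in features:
    print(tuple(feature.shape))
# Expected under these assumptions:
# (1, 320, 64, 64), (1, 640, 32, 32), (1, 1280, 16, 16), (1, 1280, 8, 8)

print(adapter.total_downscale_factor)   # 8 * 2**3 = 64

# MultiAdapter wraps several adapters: it splits a channel-concatenated input into one
# chunk per adapter and sums the per-adapter feature maps with the given weights.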
Tiger Model/diffusiers-Tiger/models/attention.py
ADDED
@@ -0,0 +1,437 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch import nn

from ..utils import maybe_allow_in_graph
from .activations import get_activation
from .attention_processor import Attention
from .embeddings import CombinedTimestepLabelEmbeddings
from .lora import LoRACompatibleLinear


@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        # we need a linear projection since we concatenate the visual features and the object features
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
        self.ff = FeedForward(query_dim, activation_fn="geglu")

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
        self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))

        self.enabled = True

    def forward(self, x, objs):
        if not self.enabled:
            return x

        n_visual = x.shape[1]
        objs = self.linear(objs)

        x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
        x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))

        return x


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
        attention_type: str = "default",
        weight: Optional[torch.LongTensor] = None,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # 4. Fuser
        if attention_type == "gated":
            self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
        weight: Optional[torch.LongTensor] = None,
    ):
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        # 0. Prepare GLIGEN inputs
        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            weight=weight,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # 1.5 GLIGEN Control
        if gligen_kwargs is not None:
            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
        # 1.5 ends

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        return hidden_states


class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


class GELU(nn.Module):
    r"""
    GELU activation function with tanh approximation support with `approximate="tanh"`.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)
        self.approximate = approximate

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate, approximate=self.approximate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states


class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)


class ApproximateGELU(nn.Module):
    """
    The approximate form of Gaussian Error Linear Unit (GELU)

    For more details, see section 2: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        x = self.proj(x)
        return x * torch.sigmoid(1.702 * x)


class AdaLayerNorm(nn.Module):
    """
    Norm layer modified to incorporate timestep embeddings.
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x, timestep):
        emb = self.linear(self.silu(self.emb(timestep)))
        scale, shift = torch.chunk(emb, 2)
        x = self.norm(x) * (1 + scale) + shift
        return x


class AdaLayerNormZero(nn.Module):
    """
    Norm layer adaptive layer norm zero (adaLN-Zero).
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()

        self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, timestep, class_labels, hidden_dtype=None):
        emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


class AdaGroupNorm(nn.Module):
    """
    GroupNorm layer modified to incorporate timestep embeddings.
    """

    def __init__(
        self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
    ):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps

        if act_fn is None:
            self.act = None
        else:
            self.act = get_activation(act_fn)

        self.linear = nn.Linear(embedding_dim, out_dim * 2)

    def forward(self, x, emb):
        if self.act:
            emb = self.act(emb)
        emb = self.linear(emb)
        emb = emb[:, :, None, None]
        scale, shift = emb.chunk(2, dim=1)

        x = F.group_norm(x, self.num_groups, eps=self.eps)
        x = x * (1 + scale) + shift
        return x
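`FeedForward` with the default `geglu` activation is the MLP used inside `BasicTransformerBlock`: a GEGLU projection in, dropout, and a LoRA-compatible linear projection out. A small illustrative sketch (the shapes are assumptions for the example, not part of the upload):

import torch

ff = FeedForward(dim=320)      # inner dimension 4 * 320, GEGLU in, linear out
x = torch.randn(2, 77, 320)    # (batch, sequence length, channels)
print(ff(x).shape)             # torch.Size([2, 77, 320])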
Tiger Model/diffusiers-Tiger/models/attention_processor.py
ADDED
@@ -0,0 +1,1716 @@
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Callable, Optional, Union
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
from ..utils import deprecate, logging, maybe_allow_in_graph
|
21 |
+
from ..utils.import_utils import is_xformers_available
|
22 |
+
from .lora import LoRALinearLayer
|
23 |
+
|
24 |
+
|
25 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
26 |
+
|
27 |
+
|
28 |
+
if is_xformers_available():
|
29 |
+
import xformers
|
30 |
+
import xformers.ops
|
31 |
+
else:
|
32 |
+
xformers = None
|
33 |
+
|
34 |
+
|
35 |
+
@maybe_allow_in_graph
|
36 |
+
class Attention(nn.Module):
|
37 |
+
r"""
|
38 |
+
A cross attention layer.
|
39 |
+
|
40 |
+
Parameters:
|
41 |
+
query_dim (`int`): The number of channels in the query.
|
42 |
+
cross_attention_dim (`int`, *optional*):
|
43 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
44 |
+
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
45 |
+
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
46 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
47 |
+
bias (`bool`, *optional*, defaults to False):
|
48 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
49 |
+
"""
|
50 |
+
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
query_dim: int,
|
54 |
+
cross_attention_dim: Optional[int] = None,
|
55 |
+
heads: int = 8,
|
56 |
+
dim_head: int = 64,
|
57 |
+
dropout: float = 0.0,
|
58 |
+
bias=False,
|
59 |
+
upcast_attention: bool = False,
|
60 |
+
upcast_softmax: bool = False,
|
61 |
+
cross_attention_norm: Optional[str] = None,
|
62 |
+
cross_attention_norm_num_groups: int = 32,
|
63 |
+
added_kv_proj_dim: Optional[int] = None,
|
64 |
+
norm_num_groups: Optional[int] = None,
|
65 |
+
spatial_norm_dim: Optional[int] = None,
|
66 |
+
out_bias: bool = True,
|
67 |
+
scale_qk: bool = True,
|
68 |
+
only_cross_attention: bool = False,
|
69 |
+
eps: float = 1e-5,
|
70 |
+
rescale_output_factor: float = 1.0,
|
71 |
+
residual_connection: bool = False,
|
72 |
+
_from_deprecated_attn_block=False,
|
73 |
+
processor: Optional["AttnProcessor"] = None,
|
74 |
+
weight : Optional[torch.FloatTensor] = None,
|
75 |
+
):
|
76 |
+
super().__init__()
|
77 |
+
inner_dim = dim_head * heads
|
78 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
79 |
+
self.upcast_attention = upcast_attention
|
80 |
+
self.upcast_softmax = upcast_softmax
|
81 |
+
self.rescale_output_factor = rescale_output_factor
|
82 |
+
self.residual_connection = residual_connection
|
83 |
+
self.dropout = dropout
|
84 |
+
self.weight = weight
|
85 |
+
# we make use of this private variable to know whether this class is loaded
|
86 |
+
# with an deprecated state dict so that we can convert it on the fly
|
87 |
+
self._from_deprecated_attn_block = _from_deprecated_attn_block
|
88 |
+
|
89 |
+
self.scale_qk = scale_qk
|
90 |
+
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
91 |
+
|
92 |
+
self.heads = heads
|
93 |
+
# for slice_size > 0 the attention score computation
|
94 |
+
# is split across the batch axis to save memory
|
95 |
+
# You can set slice_size with `set_attention_slice`
|
96 |
+
self.sliceable_head_dim = heads
|
97 |
+
|
98 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
99 |
+
self.only_cross_attention = only_cross_attention
|
100 |
+
|
101 |
+
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
102 |
+
raise ValueError(
|
103 |
+
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
104 |
+
)
|
105 |
+
|
106 |
+
if norm_num_groups is not None:
|
107 |
+
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
|
108 |
+
else:
|
109 |
+
self.group_norm = None
|
110 |
+
|
111 |
+
if spatial_norm_dim is not None:
|
112 |
+
self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
|
113 |
+
else:
|
114 |
+
self.spatial_norm = None
|
115 |
+
|
116 |
+
if cross_attention_norm is None:
|
117 |
+
self.norm_cross = None
|
118 |
+
elif cross_attention_norm == "layer_norm":
|
119 |
+
self.norm_cross = nn.LayerNorm(cross_attention_dim)
|
120 |
+
elif cross_attention_norm == "group_norm":
|
121 |
+
if self.added_kv_proj_dim is not None:
|
122 |
+
# The given `encoder_hidden_states` are initially of shape
|
123 |
+
# (batch_size, seq_len, added_kv_proj_dim) before being projected
|
124 |
+
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
|
125 |
+
# before the projection, so we need to use `added_kv_proj_dim` as
|
126 |
+
# the number of channels for the group norm.
|
127 |
+
norm_cross_num_channels = added_kv_proj_dim
|
128 |
+
else:
|
129 |
+
norm_cross_num_channels = cross_attention_dim
|
130 |
+
|
131 |
+
self.norm_cross = nn.GroupNorm(
|
132 |
+
num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
|
133 |
+
)
|
134 |
+
else:
|
135 |
+
raise ValueError(
|
136 |
+
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
|
137 |
+
)
|
138 |
+
|
139 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
140 |
+
|
141 |
+
if not self.only_cross_attention:
|
142 |
+
# only relevant for the `AddedKVProcessor` classes
|
143 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
144 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
145 |
+
else:
|
146 |
+
self.to_k = None
|
147 |
+
self.to_v = None
|
148 |
+
|
149 |
+
if self.added_kv_proj_dim is not None:
|
150 |
+
self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
|
151 |
+
self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)
|
152 |
+
|
153 |
+
self.to_out = nn.ModuleList([])
|
154 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
|
155 |
+
self.to_out.append(nn.Dropout(dropout))
|
156 |
+
|
157 |
+
# set attention processor
|
158 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
159 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
160 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
161 |
+
if processor is None:
|
162 |
+
processor = (
|
163 |
+
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
164 |
+
)
|
165 |
+
self.set_processor(processor)
|
166 |
+
|
167 |
+
def set_use_memory_efficient_attention_xformers(
|
168 |
+
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
|
169 |
+
):
|
170 |
+
is_lora = hasattr(self, "processor") and isinstance(
|
171 |
+
self.processor,
|
172 |
+
LORA_ATTENTION_PROCESSORS,
|
173 |
+
)
|
174 |
+
is_custom_diffusion = hasattr(self, "processor") and isinstance(
|
175 |
+
self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
|
176 |
+
)
|
177 |
+
is_added_kv_processor = hasattr(self, "processor") and isinstance(
|
178 |
+
self.processor,
|
179 |
+
(
|
180 |
+
AttnAddedKVProcessor,
|
181 |
+
AttnAddedKVProcessor2_0,
|
182 |
+
SlicedAttnAddedKVProcessor,
|
183 |
+
XFormersAttnAddedKVProcessor,
|
184 |
+
LoRAAttnAddedKVProcessor,
|
185 |
+
),
|
186 |
+
)
|
187 |
+
|
188 |
+
if use_memory_efficient_attention_xformers:
|
189 |
+
if is_added_kv_processor and (is_lora or is_custom_diffusion):
|
190 |
+
raise NotImplementedError(
|
191 |
+
f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
|
192 |
+
)
|
193 |
+
if not is_xformers_available():
|
194 |
+
raise ModuleNotFoundError(
|
195 |
+
(
|
196 |
+
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
197 |
+
" xformers"
|
198 |
+
),
|
199 |
+
name="xformers",
|
200 |
+
)
|
201 |
+
elif not torch.cuda.is_available():
|
202 |
+
raise ValueError(
|
203 |
+
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
|
204 |
+
" only available for GPU "
|
205 |
+
)
|
206 |
+
else:
|
207 |
+
try:
|
208 |
+
# Make sure we can run the memory efficient attention
|
209 |
+
_ = xformers.ops.memory_efficient_attention(
|
210 |
+
torch.randn((1, 2, 40), device="cuda"),
|
211 |
+
torch.randn((1, 2, 40), device="cuda"),
|
212 |
+
torch.randn((1, 2, 40), device="cuda"),
|
213 |
+
)
|
214 |
+
except Exception as e:
|
215 |
+
raise e
|
216 |
+
|
217 |
+
if is_lora:
|
218 |
+
# TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
|
219 |
+
# variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
|
220 |
+
processor = LoRAXFormersAttnProcessor(
|
221 |
+
hidden_size=self.processor.hidden_size,
|
222 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
223 |
+
rank=self.processor.rank,
|
224 |
+
attention_op=attention_op,
|
225 |
+
)
|
226 |
+
processor.load_state_dict(self.processor.state_dict())
|
227 |
+
processor.to(self.processor.to_q_lora.up.weight.device)
|
228 |
+
elif is_custom_diffusion:
|
229 |
+
processor = CustomDiffusionXFormersAttnProcessor(
|
230 |
+
train_kv=self.processor.train_kv,
|
231 |
+
train_q_out=self.processor.train_q_out,
|
232 |
+
hidden_size=self.processor.hidden_size,
|
233 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
234 |
+
attention_op=attention_op,
|
235 |
+
)
|
236 |
+
processor.load_state_dict(self.processor.state_dict())
|
237 |
+
if hasattr(self.processor, "to_k_custom_diffusion"):
|
238 |
+
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
239 |
+
elif is_added_kv_processor:
|
240 |
+
# throw warning
|
241 |
+
logger.info(
|
242 |
+
"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
|
243 |
+
)
|
244 |
+
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
|
245 |
+
else:
|
246 |
+
processor = XFormersAttnProcessor(attention_op=attention_op)
|
247 |
+
else:
|
248 |
+
if is_lora:
|
249 |
+
attn_processor_class = (
|
250 |
+
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
251 |
+
)
|
252 |
+
processor = attn_processor_class(
|
253 |
+
hidden_size=self.processor.hidden_size,
|
254 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
255 |
+
rank=self.processor.rank,
|
256 |
+
)
|
257 |
+
processor.load_state_dict(self.processor.state_dict())
|
258 |
+
processor.to(self.processor.to_q_lora.up.weight.device)
|
259 |
+
elif is_custom_diffusion:
|
260 |
+
processor = CustomDiffusionAttnProcessor(
|
261 |
+
train_kv=self.processor.train_kv,
|
262 |
+
train_q_out=self.processor.train_q_out,
|
263 |
+
hidden_size=self.processor.hidden_size,
|
264 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
265 |
+
)
|
266 |
+
processor.load_state_dict(self.processor.state_dict())
|
267 |
+
if hasattr(self.processor, "to_k_custom_diffusion"):
|
268 |
+
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
269 |
+
else:
|
270 |
+
processor = (
|
271 |
+
AttnProcessor2_0()
|
272 |
+
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
|
273 |
+
else AttnProcessor()
|
274 |
+
)
|
275 |
+
|
276 |
+
self.set_processor(processor)
|
277 |
+
|
278 |
+
def set_attention_slice(self, slice_size):
|
279 |
+
if slice_size is not None and slice_size > self.sliceable_head_dim:
|
280 |
+
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
|
281 |
+
|
282 |
+
if slice_size is not None and self.added_kv_proj_dim is not None:
|
283 |
+
processor = SlicedAttnAddedKVProcessor(slice_size)
|
284 |
+
elif slice_size is not None:
|
285 |
+
processor = SlicedAttnProcessor(slice_size)
|
286 |
+
elif self.added_kv_proj_dim is not None:
|
287 |
+
processor = AttnAddedKVProcessor()
|
288 |
+
else:
|
289 |
+
processor = (
|
290 |
+
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
291 |
+
)
|
292 |
+
|
293 |
+
self.set_processor(processor)
|
294 |
+
|
295 |
+
def set_processor(self, processor: "AttnProcessor"):
|
296 |
+
if (
|
297 |
+
hasattr(self, "processor")
|
298 |
+
and isinstance(self.processor, torch.nn.Module)
|
299 |
+
and not isinstance(processor, torch.nn.Module)
|
300 |
+
):
|
301 |
+
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
|
302 |
+
self._modules.pop("processor")
|
303 |
+
|
304 |
+
self.processor = processor
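# Hedged usage sketch: `set_processor` swaps the attention implementation in
# place; replacing an `nn.Module` processor that holds trained weights (e.g. a
# LoRA processor) with a stateless callable drops those weights, hence the
# logger notice above. Constructor arguments below are assumptions.
#
#   >>> attn = Attention(query_dim=320, heads=8, dim_head=40)
#   >>> attn.set_processor(SlicedAttnProcessor(slice_size=2))   # stateless, nothing registered as a submodule
#   >>> attn.set_processor(AttnProcessor())                     # restore the default implementation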
|
305 |
+
|
306 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, weight=None, **cross_attention_kwargs):
|
307 |
+
return self.processor(
|
308 |
+
self,
|
309 |
+
hidden_states,
|
310 |
+
encoder_hidden_states=encoder_hidden_states,
|
311 |
+
attention_mask=attention_mask,
|
312 |
+
weight=weight,
|
313 |
+
**cross_attention_kwargs,
|
314 |
+
)
|
315 |
+
|
316 |
+
def batch_to_head_dim(self, tensor):
|
317 |
+
head_size = self.heads
|
318 |
+
batch_size, seq_len, dim = tensor.shape
|
319 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
320 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
321 |
+
return tensor
|
322 |
+
|
323 |
+
def head_to_batch_dim(self, tensor, out_dim=3):
|
324 |
+
head_size = self.heads
|
325 |
+
batch_size, seq_len, dim = tensor.shape
|
326 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
327 |
+
tensor = tensor.permute(0, 2, 1, 3)
|
328 |
+
|
329 |
+
if out_dim == 3:
|
330 |
+
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
|
331 |
+
|
332 |
+
return tensor
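# Shape walkthrough for the two reshaping helpers (illustrative sizes, assuming
# attn.heads == 8 and an input of shape (batch, seq_len, heads * head_dim)):
#
#   >>> x = torch.randn(2, 64, 320)
#   >>> attn.head_to_batch_dim(x).shape                 # fold heads into the batch axis
#   torch.Size([16, 64, 40])
#   >>> attn.head_to_batch_dim(x, out_dim=4).shape      # keep heads as a separate axis
#   torch.Size([2, 8, 64, 40])
#   >>> attn.batch_to_head_dim(attn.head_to_batch_dim(x)).shape
#   torch.Size([2, 64, 320])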
|
333 |
+
|
334 |
+
def get_attention_scores(self, query, key, weight=None, attention_mask=None):
|
335 |
+
dtype = query.dtype
|
336 |
+
if self.upcast_attention:
|
337 |
+
query = query.float()
|
338 |
+
key = key.float()
|
339 |
+
if attention_mask is None:
|
340 |
+
baddbmm_input = torch.empty(
|
341 |
+
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
|
342 |
+
)
|
343 |
+
beta = 0
|
344 |
+
else:
|
345 |
+
baddbmm_input = attention_mask
|
346 |
+
beta = 1
|
347 |
+
|
348 |
+
attention_scores = torch.baddbmm(
|
349 |
+
baddbmm_input,
|
350 |
+
query,
|
351 |
+
key.transpose(-1, -2),
|
352 |
+
beta=beta,
|
353 |
+
alpha=self.scale,
|
354 |
+
)
|
355 |
+
|
356 |
+
del baddbmm_input
|
357 |
+
|
358 |
+
if self.upcast_softmax:
|
359 |
+
attention_scores = attention_scores.float()
|
360 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
361 |
+
del attention_scores
|
362 |
+
|
363 |
+
attention_probs = attention_probs.to(dtype)
|
364 |
+
|
365 |
+
return attention_probs
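# What the baddbmm/softmax pair above computes, spelled out (a sketch; `scale`
# is typically 1 / sqrt(head_dim) and the mask acts as an additive bias):
#
#   scores = attention_mask + self.scale * (query @ key.transpose(-1, -2))
#   probs  = scores.softmax(dim=-1)
#
# torch.baddbmm(input, b1, b2, beta=beta, alpha=alpha) evaluates
# beta * input + alpha * (b1 @ b2), so beta=0 simply drops the mask term. The
# extra `weight` parameter is accepted here but not used inside this method.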
|
366 |
+
|
367 |
+
def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
|
368 |
+
if batch_size is None:
|
369 |
+
deprecate(
|
370 |
+
"batch_size=None",
|
371 |
+
"0.0.15",
|
372 |
+
(
|
373 |
+
"Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
|
374 |
+
" attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
|
375 |
+
" `prepare_attention_mask` when preparing the attention_mask."
|
376 |
+
),
|
377 |
+
)
|
378 |
+
batch_size = 1
|
379 |
+
head_size = self.heads
|
380 |
+
if attention_mask is None:
|
381 |
+
return attention_mask
|
382 |
+
|
383 |
+
current_length: int = attention_mask.shape[-1]
|
384 |
+
if current_length != target_length:
|
385 |
+
if attention_mask.device.type == "mps":
|
386 |
+
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
|
387 |
+
# Instead, we can manually construct the padding tensor.
|
388 |
+
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
|
389 |
+
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
|
390 |
+
attention_mask = torch.cat([attention_mask, padding], dim=2)
|
391 |
+
else:
|
392 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
393 |
+
|
394 |
+
if out_dim == 3:
|
395 |
+
if attention_mask.shape[0] < batch_size * head_size:
|
396 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
|
397 |
+
elif out_dim == 4:
|
398 |
+
attention_mask = attention_mask.unsqueeze(1)
|
399 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
|
400 |
+
return attention_mask
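# Illustrative shapes (assumed values, with attn.heads == 8): a mask that
# already matches `target_length` is only repeated per attention head.
#
#   >>> mask = torch.zeros(2, 1, 64)
#   >>> attn.prepare_attention_mask(mask, target_length=64, batch_size=2).shape
#   torch.Size([16, 1, 64])      # out_dim=3: (batch * heads, 1, key_len)
#   >>> attn.prepare_attention_mask(mask, target_length=64, batch_size=2, out_dim=4).shape
#   torch.Size([2, 8, 1, 64])    # out_dim=4: (batch, heads, 1, key_len)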
|
401 |
+
|
402 |
+
def norm_encoder_hidden_states(self, encoder_hidden_states):
|
403 |
+
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
|
404 |
+
|
405 |
+
if isinstance(self.norm_cross, nn.LayerNorm):
|
406 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
407 |
+
elif isinstance(self.norm_cross, nn.GroupNorm):
|
408 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
409 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
410 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
411 |
+
else:
|
412 |
+
assert False
|
413 |
+
|
414 |
+
return encoder_hidden_states
|
415 |
+
|
416 |
+
class AttnProcessor:
|
417 |
+
r"""
|
418 |
+
Default processor for performing attention-related computations.
|
419 |
+
"""
|
420 |
+
|
421 |
+
def __call__(
|
422 |
+
self,
|
423 |
+
attn: Attention,
|
424 |
+
hidden_states,
|
425 |
+
encoder_hidden_states=None,
|
426 |
+
attention_mask=None,
|
427 |
+
temb=None,
|
428 |
+
weight=None,
):
|
429 |
+
residual = hidden_states
|
430 |
+
|
431 |
+
if attn.spatial_norm is not None:
|
432 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
433 |
+
|
434 |
+
input_ndim = hidden_states.ndim
|
435 |
+
|
436 |
+
if input_ndim == 4:
|
437 |
+
batch_size, channel, height, width = hidden_states.shape
|
438 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
439 |
+
|
440 |
+
batch_size, sequence_length, _ = (
|
441 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
442 |
+
)
|
443 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
444 |
+
if attn.group_norm is not None:
|
445 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
446 |
+
|
447 |
+
if weight is not None:
|
448 |
+
multiplier = weight.unsqueeze(1).unsqueeze(2)
|
449 |
+
hidden_states = hidden_states * multiplier
|
450 |
+
query = attn.to_q(hidden_states)
|
451 |
+
|
452 |
+
|
453 |
+
if encoder_hidden_states is None:
|
454 |
+
encoder_hidden_states = hidden_states
|
455 |
+
elif attn.norm_cross:
|
456 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
457 |
+
key = attn.to_k(encoder_hidden_states)
|
458 |
+
value = attn.to_v(encoder_hidden_states)
|
459 |
+
|
460 |
+
query = attn.head_to_batch_dim(query)
|
461 |
+
key = attn.head_to_batch_dim(key)
|
462 |
+
value = attn.head_to_batch_dim(value)
|
463 |
+
|
464 |
+
attention_probs = attn.get_attention_scores(query, key, weight, attention_mask)
|
465 |
+
hidden_states = torch.bmm(attention_probs, value)
|
466 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
467 |
+
|
468 |
+
# linear proj
|
469 |
+
hidden_states = attn.to_out[0](hidden_states)
|
470 |
+
# dropout
|
471 |
+
hidden_states = attn.to_out[1](hidden_states)
|
472 |
+
|
473 |
+
if input_ndim == 4:
|
474 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
475 |
+
|
476 |
+
if attn.residual_connection:
|
477 |
+
hidden_states = hidden_states + residual
|
478 |
+
|
479 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
480 |
+
|
481 |
+
return hidden_states
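# Hedged usage sketch for the modified default processor: the extra `weight`
# tensor (one scalar per batch element in this example) rescales the hidden
# states before the q/k/v projections. Constructor arguments and shapes are
# assumptions, not taken from the original upload.
#
#   >>> attn = Attention(query_dim=320, heads=8, dim_head=40)
#   >>> attn.set_processor(AttnProcessor())
#   >>> x = torch.randn(2, 64, 320)
#   >>> out = attn(x, weight=torch.tensor([1.0, 0.5]))   # Attention.forward routes `weight` to the processor
#   >>> out.shape
#   torch.Size([2, 64, 320])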
|
482 |
+
|
483 |
+
class Guid_AttnProcessor:
|
484 |
+
r"""
|
485 |
+
Variant of the default attention processor (per its name, for the guidance pass); identical to `AttnProcessor` except that it takes no `weight` argument.
|
486 |
+
"""
|
487 |
+
|
488 |
+
def __call__(
|
489 |
+
self,
|
490 |
+
attn: Attention,
|
491 |
+
hidden_states,
|
492 |
+
encoder_hidden_states=None,
|
493 |
+
attention_mask=None,
|
494 |
+
temb=None,
|
495 |
+
):
|
496 |
+
residual = hidden_states
|
497 |
+
|
498 |
+
if attn.spatial_norm is not None:
|
499 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
500 |
+
|
501 |
+
input_ndim = hidden_states.ndim
|
502 |
+
|
503 |
+
if input_ndim == 4:
|
504 |
+
batch_size, channel, height, width = hidden_states.shape
|
505 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
506 |
+
|
507 |
+
batch_size, sequence_length, _ = (
|
508 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
509 |
+
)
|
510 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
511 |
+
|
512 |
+
if attn.group_norm is not None:
|
513 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
514 |
+
|
515 |
+
query = attn.to_q(hidden_states)
|
516 |
+
|
517 |
+
if encoder_hidden_states is None:
|
518 |
+
encoder_hidden_states = hidden_states
|
519 |
+
elif attn.norm_cross:
|
520 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
521 |
+
|
522 |
+
key = attn.to_k(encoder_hidden_states)
|
523 |
+
value = attn.to_v(encoder_hidden_states)
|
524 |
+
|
525 |
+
query = attn.head_to_batch_dim(query)
|
526 |
+
key = attn.head_to_batch_dim(key)
|
527 |
+
value = attn.head_to_batch_dim(value)
|
528 |
+
|
529 |
+
attention_probs = attn.get_attention_scores(query, key, weight=None, attention_mask=attention_mask)
|
530 |
+
hidden_states = torch.bmm(attention_probs, value)
|
531 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
532 |
+
|
533 |
+
# linear proj
|
534 |
+
hidden_states = attn.to_out[0](hidden_states)
|
535 |
+
# dropout
|
536 |
+
hidden_states = attn.to_out[1](hidden_states)
|
537 |
+
|
538 |
+
if input_ndim == 4:
|
539 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
540 |
+
|
541 |
+
if attn.residual_connection:
|
542 |
+
hidden_states = hidden_states + residual
|
543 |
+
|
544 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
545 |
+
|
546 |
+
return hidden_states
|
547 |
+
|
548 |
+
class LoRAAttnProcessor(nn.Module):
|
549 |
+
r"""
|
550 |
+
Processor for implementing the LoRA attention mechanism.
|
551 |
+
|
552 |
+
Args:
|
553 |
+
hidden_size (`int`, *optional*):
|
554 |
+
The hidden size of the attention layer.
|
555 |
+
cross_attention_dim (`int`, *optional*):
|
556 |
+
The number of channels in the `encoder_hidden_states`.
|
557 |
+
rank (`int`, defaults to 4):
|
558 |
+
The dimension of the LoRA update matrices.
|
559 |
+
network_alpha (`int`, *optional*):
|
560 |
+
Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
|
561 |
+
"""
|
562 |
+
|
563 |
+
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs):
|
564 |
+
super().__init__()
|
565 |
+
|
566 |
+
self.hidden_size = hidden_size
|
567 |
+
self.cross_attention_dim = cross_attention_dim
|
568 |
+
self.rank = rank
|
569 |
+
|
570 |
+
q_rank = kwargs.pop("q_rank", None)
|
571 |
+
q_hidden_size = kwargs.pop("q_hidden_size", None)
|
572 |
+
q_rank = q_rank if q_rank is not None else rank
|
573 |
+
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
|
574 |
+
|
575 |
+
v_rank = kwargs.pop("v_rank", None)
|
576 |
+
v_hidden_size = kwargs.pop("v_hidden_size", None)
|
577 |
+
v_rank = v_rank if v_rank is not None else rank
|
578 |
+
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
|
579 |
+
|
580 |
+
out_rank = kwargs.pop("out_rank", None)
|
581 |
+
out_hidden_size = kwargs.pop("out_hidden_size", None)
|
582 |
+
out_rank = out_rank if out_rank is not None else rank
|
583 |
+
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
|
584 |
+
|
585 |
+
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
|
586 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
587 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
|
588 |
+
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
|
589 |
+
|
590 |
+
def __call__(
|
591 |
+
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
|
592 |
+
):
|
593 |
+
residual = hidden_states
|
594 |
+
|
595 |
+
if attn.spatial_norm is not None:
|
596 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
597 |
+
|
598 |
+
input_ndim = hidden_states.ndim
|
599 |
+
|
600 |
+
if input_ndim == 4:
|
601 |
+
batch_size, channel, height, width = hidden_states.shape
|
602 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
603 |
+
|
604 |
+
batch_size, sequence_length, _ = (
|
605 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
606 |
+
)
|
607 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
608 |
+
|
609 |
+
if attn.group_norm is not None:
|
610 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
611 |
+
|
612 |
+
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
613 |
+
query = attn.head_to_batch_dim(query)
|
614 |
+
|
615 |
+
if encoder_hidden_states is None:
|
616 |
+
encoder_hidden_states = hidden_states
|
617 |
+
elif attn.norm_cross:
|
618 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
619 |
+
|
620 |
+
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
|
621 |
+
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
|
622 |
+
|
623 |
+
key = attn.head_to_batch_dim(key)
|
624 |
+
value = attn.head_to_batch_dim(value)
|
625 |
+
|
626 |
+
attention_probs = attn.get_attention_scores(query, key, weight=None, attention_mask=attention_mask)
|
627 |
+
hidden_states = torch.bmm(attention_probs, value)
|
628 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
629 |
+
|
630 |
+
# linear proj
|
631 |
+
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
632 |
+
# dropout
|
633 |
+
hidden_states = attn.to_out[1](hidden_states)
|
634 |
+
|
635 |
+
if input_ndim == 4:
|
636 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
637 |
+
|
638 |
+
if attn.residual_connection:
|
639 |
+
hidden_states = hidden_states + residual
|
640 |
+
|
641 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
642 |
+
|
643 |
+
return hidden_states
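# The LoRA path above, written as math (a sketch; `LoRALinearLayer` is assumed
# to be the usual rank-`rank` down/up projection pair):
#
#   query = W_q(x) + scale * up_q(down_q(x))    # down: dim -> rank, up: rank -> dim
#
# so at scale=0.0 the processor reproduces the frozen base attention exactly,
# and only the small up/down matrices are trainable.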
|
644 |
+
|
645 |
+
class CustomDiffusionAttnProcessor(nn.Module):
|
646 |
+
r"""
|
647 |
+
Processor for implementing attention for the Custom Diffusion method.
|
648 |
+
|
649 |
+
Args:
|
650 |
+
train_kv (`bool`, defaults to `True`):
|
651 |
+
Whether to newly train the key and value matrices corresponding to the text features.
|
652 |
+
train_q_out (`bool`, defaults to `True`):
|
653 |
+
Whether to newly train query matrices corresponding to the latent image features.
|
654 |
+
hidden_size (`int`, *optional*, defaults to `None`):
|
655 |
+
The hidden size of the attention layer.
|
656 |
+
cross_attention_dim (`int`, *optional*, defaults to `None`):
|
657 |
+
The number of channels in the `encoder_hidden_states`.
|
658 |
+
out_bias (`bool`, defaults to `True`):
|
659 |
+
Whether to include the bias parameter in `train_q_out`.
|
660 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
661 |
+
The dropout probability to use.
|
662 |
+
"""
|
663 |
+
|
664 |
+
def __init__(
|
665 |
+
self,
|
666 |
+
train_kv=True,
|
667 |
+
train_q_out=True,
|
668 |
+
hidden_size=None,
|
669 |
+
cross_attention_dim=None,
|
670 |
+
out_bias=True,
|
671 |
+
dropout=0.0,
|
672 |
+
):
|
673 |
+
super().__init__()
|
674 |
+
self.train_kv = train_kv
|
675 |
+
self.train_q_out = train_q_out
|
676 |
+
|
677 |
+
self.hidden_size = hidden_size
|
678 |
+
self.cross_attention_dim = cross_attention_dim
|
679 |
+
|
680 |
+
# `_custom_diffusion` id for easy serialization and loading.
|
681 |
+
if self.train_kv:
|
682 |
+
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
683 |
+
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
684 |
+
if self.train_q_out:
|
685 |
+
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
|
686 |
+
self.to_out_custom_diffusion = nn.ModuleList([])
|
687 |
+
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
|
688 |
+
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
|
689 |
+
|
690 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
691 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
692 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
693 |
+
if self.train_q_out:
|
694 |
+
query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
|
695 |
+
else:
|
696 |
+
query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
|
697 |
+
|
698 |
+
if encoder_hidden_states is None:
|
699 |
+
crossattn = False
|
700 |
+
encoder_hidden_states = hidden_states
|
701 |
+
else:
|
702 |
+
crossattn = True
|
703 |
+
if attn.norm_cross:
|
704 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
705 |
+
|
706 |
+
if self.train_kv:
|
707 |
+
key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
|
708 |
+
value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
|
709 |
+
key = key.to(attn.to_q.weight.dtype)
|
710 |
+
value = value.to(attn.to_q.weight.dtype)
|
711 |
+
else:
|
712 |
+
key = attn.to_k(encoder_hidden_states)
|
713 |
+
value = attn.to_v(encoder_hidden_states)
|
714 |
+
|
715 |
+
if crossattn:
|
716 |
+
detach = torch.ones_like(key)
|
717 |
+
detach[:, :1, :] = detach[:, :1, :] * 0.0
|
718 |
+
key = detach * key + (1 - detach) * key.detach()
|
719 |
+
value = detach * value + (1 - detach) * value.detach()
|
720 |
+
|
721 |
+
query = attn.head_to_batch_dim(query)
|
722 |
+
key = attn.head_to_batch_dim(key)
|
723 |
+
value = attn.head_to_batch_dim(value)
|
724 |
+
|
725 |
+
attention_probs = attn.get_attention_scores(query, key, weight=None, attention_mask=attention_mask)
|
726 |
+
hidden_states = torch.bmm(attention_probs, value)
|
727 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
728 |
+
|
729 |
+
if self.train_q_out:
|
730 |
+
# linear proj
|
731 |
+
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
|
732 |
+
# dropout
|
733 |
+
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
|
734 |
+
else:
|
735 |
+
# linear proj
|
736 |
+
hidden_states = attn.to_out[0](hidden_states)
|
737 |
+
# dropout
|
738 |
+
hidden_states = attn.to_out[1](hidden_states)
|
739 |
+
|
740 |
+
return hidden_states
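# Reading of the `detach` trick above (a descriptive note, not a change): the
# mask is 1 everywhere except the first token position, so during cross-attention
# the key/value rows for the first (start-of-text) token are detached from the
# graph while gradients still flow through the newly trained K/V projections for
# the remaining text tokens.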
|
741 |
+
|
742 |
+
class AttnAddedKVProcessor:
|
743 |
+
r"""
|
744 |
+
Processor for performing attention-related computations with extra learnable key and value matrices for the text
|
745 |
+
encoder.
|
746 |
+
"""
|
747 |
+
|
748 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
749 |
+
residual = hidden_states
|
750 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
751 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
752 |
+
|
753 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
754 |
+
|
755 |
+
if encoder_hidden_states is None:
|
756 |
+
encoder_hidden_states = hidden_states
|
757 |
+
elif attn.norm_cross:
|
758 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
759 |
+
|
760 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
761 |
+
|
762 |
+
query = attn.to_q(hidden_states)
|
763 |
+
query = attn.head_to_batch_dim(query)
|
764 |
+
|
765 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
766 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
767 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
768 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
769 |
+
|
770 |
+
if not attn.only_cross_attention:
|
771 |
+
key = attn.to_k(hidden_states)
|
772 |
+
value = attn.to_v(hidden_states)
|
773 |
+
key = attn.head_to_batch_dim(key)
|
774 |
+
value = attn.head_to_batch_dim(value)
|
775 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
776 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
777 |
+
else:
|
778 |
+
key = encoder_hidden_states_key_proj
|
779 |
+
value = encoder_hidden_states_value_proj
|
780 |
+
|
781 |
+
attention_probs = attn.get_attention_scores(query, key, weight=None, attention_mask=attention_mask)
|
782 |
+
hidden_states = torch.bmm(attention_probs, value)
|
783 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
784 |
+
|
785 |
+
# linear proj
|
786 |
+
hidden_states = attn.to_out[0](hidden_states)
|
787 |
+
# dropout
|
788 |
+
hidden_states = attn.to_out[1](hidden_states)
|
789 |
+
|
790 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
791 |
+
hidden_states = hidden_states + residual
|
792 |
+
|
793 |
+
return hidden_states
|
794 |
+
|
795 |
+
class AttnAddedKVProcessor2_0:
|
796 |
+
r"""
|
797 |
+
Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
|
798 |
+
learnable key and value matrices for the text encoder.
|
799 |
+
"""
|
800 |
+
|
801 |
+
def __init__(self):
|
802 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
803 |
+
raise ImportError(
|
804 |
+
"AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
805 |
+
)
|
806 |
+
|
807 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
808 |
+
residual = hidden_states
|
809 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
810 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
811 |
+
|
812 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
|
813 |
+
|
814 |
+
if encoder_hidden_states is None:
|
815 |
+
encoder_hidden_states = hidden_states
|
816 |
+
elif attn.norm_cross:
|
817 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
818 |
+
|
819 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
820 |
+
|
821 |
+
query = attn.to_q(hidden_states)
|
822 |
+
query = attn.head_to_batch_dim(query, out_dim=4)
|
823 |
+
|
824 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
825 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
826 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
|
827 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
|
828 |
+
|
829 |
+
if not attn.only_cross_attention:
|
830 |
+
key = attn.to_k(hidden_states)
|
831 |
+
value = attn.to_v(hidden_states)
|
832 |
+
key = attn.head_to_batch_dim(key, out_dim=4)
|
833 |
+
value = attn.head_to_batch_dim(value, out_dim=4)
|
834 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
835 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
836 |
+
else:
|
837 |
+
key = encoder_hidden_states_key_proj
|
838 |
+
value = encoder_hidden_states_value_proj
|
839 |
+
|
840 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
841 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
842 |
+
hidden_states = F.scaled_dot_product_attention(
|
843 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
844 |
+
)
|
845 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
|
846 |
+
|
847 |
+
# linear proj
|
848 |
+
hidden_states = attn.to_out[0](hidden_states)
|
849 |
+
# dropout
|
850 |
+
hidden_states = attn.to_out[1](hidden_states)
|
851 |
+
|
852 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
853 |
+
hidden_states = hidden_states + residual
|
854 |
+
|
855 |
+
return hidden_states
|
856 |
+
|
857 |
+
class LoRAAttnAddedKVProcessor(nn.Module):
|
858 |
+
r"""
|
859 |
+
Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
|
860 |
+
encoder.
|
861 |
+
|
862 |
+
Args:
|
863 |
+
hidden_size (`int`, *optional*):
|
864 |
+
The hidden size of the attention layer.
|
865 |
+
cross_attention_dim (`int`, *optional*, defaults to `None`):
|
866 |
+
The number of channels in the `encoder_hidden_states`.
|
867 |
+
rank (`int`, defaults to 4):
|
868 |
+
The dimension of the LoRA update matrices.
|
869 |
+
|
870 |
+
"""
|
871 |
+
|
872 |
+
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
|
873 |
+
super().__init__()
|
874 |
+
|
875 |
+
self.hidden_size = hidden_size
|
876 |
+
self.cross_attention_dim = cross_attention_dim
|
877 |
+
self.rank = rank
|
878 |
+
|
879 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
880 |
+
self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
881 |
+
self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
882 |
+
self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
883 |
+
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
884 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
885 |
+
|
886 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
|
887 |
+
residual = hidden_states
|
888 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
889 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
890 |
+
|
891 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
892 |
+
|
893 |
+
if encoder_hidden_states is None:
|
894 |
+
encoder_hidden_states = hidden_states
|
895 |
+
elif attn.norm_cross:
|
896 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
897 |
+
|
898 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
899 |
+
|
900 |
+
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
901 |
+
query = attn.head_to_batch_dim(query)
|
902 |
+
|
903 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora(
|
904 |
+
encoder_hidden_states
|
905 |
+
)
|
906 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora(
|
907 |
+
encoder_hidden_states
|
908 |
+
)
|
909 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
910 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
911 |
+
|
912 |
+
if not attn.only_cross_attention:
|
913 |
+
key = attn.to_k(hidden_states) + scale * self.to_k_lora(hidden_states)
|
914 |
+
value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states)
|
915 |
+
key = attn.head_to_batch_dim(key)
|
916 |
+
value = attn.head_to_batch_dim(value)
|
917 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
918 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
919 |
+
else:
|
920 |
+
key = encoder_hidden_states_key_proj
|
921 |
+
value = encoder_hidden_states_value_proj
|
922 |
+
|
923 |
+
attention_probs = attn.get_attention_scores(query, key, weight=None, attention_mask=attention_mask)
|
924 |
+
hidden_states = torch.bmm(attention_probs, value)
|
925 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
926 |
+
|
927 |
+
# linear proj
|
928 |
+
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
929 |
+
# dropout
|
930 |
+
hidden_states = attn.to_out[1](hidden_states)
|
931 |
+
|
932 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
933 |
+
hidden_states = hidden_states + residual
|
934 |
+
|
935 |
+
return hidden_states
|
936 |
+
|
937 |
+
|
938 |
+
class XFormersAttnAddedKVProcessor:
|
939 |
+
r"""
|
940 |
+
Processor for implementing memory efficient attention using xFormers.
|
941 |
+
|
942 |
+
Args:
|
943 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
944 |
+
The base
|
945 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
946 |
+
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
947 |
+
operator.
|
948 |
+
"""
|
949 |
+
|
950 |
+
def __init__(self, attention_op: Optional[Callable] = None):
|
951 |
+
self.attention_op = attention_op
|
952 |
+
|
953 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
954 |
+
residual = hidden_states
|
955 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
956 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
957 |
+
|
958 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
959 |
+
|
960 |
+
if encoder_hidden_states is None:
|
961 |
+
encoder_hidden_states = hidden_states
|
962 |
+
elif attn.norm_cross:
|
963 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
964 |
+
|
965 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
966 |
+
|
967 |
+
query = attn.to_q(hidden_states)
|
968 |
+
query = attn.head_to_batch_dim(query)
|
969 |
+
|
970 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
971 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
972 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
973 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
974 |
+
|
975 |
+
if not attn.only_cross_attention:
|
976 |
+
key = attn.to_k(hidden_states)
|
977 |
+
value = attn.to_v(hidden_states)
|
978 |
+
key = attn.head_to_batch_dim(key)
|
979 |
+
value = attn.head_to_batch_dim(value)
|
980 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
981 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
982 |
+
else:
|
983 |
+
key = encoder_hidden_states_key_proj
|
984 |
+
value = encoder_hidden_states_value_proj
|
985 |
+
|
986 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
987 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
988 |
+
)
|
989 |
+
hidden_states = hidden_states.to(query.dtype)
|
990 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
991 |
+
|
992 |
+
# linear proj
|
993 |
+
hidden_states = attn.to_out[0](hidden_states)
|
994 |
+
# dropout
|
995 |
+
hidden_states = attn.to_out[1](hidden_states)
|
996 |
+
|
997 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
998 |
+
hidden_states = hidden_states + residual
|
999 |
+
|
1000 |
+
return hidden_states
|
1001 |
+
|
1002 |
+
|
1003 |
+
class XFormersAttnProcessor:
|
1004 |
+
r"""
|
1005 |
+
Processor for implementing memory efficient attention using xFormers.
|
1006 |
+
|
1007 |
+
Args:
|
1008 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
1009 |
+
The base
|
1010 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
1011 |
+
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
1012 |
+
operator.
|
1013 |
+
"""
|
1014 |
+
|
1015 |
+
def __init__(self, attention_op: Optional[Callable] = None):
|
1016 |
+
self.attention_op = attention_op
|
1017 |
+
|
1018 |
+
def __call__(
|
1019 |
+
self,
|
1020 |
+
attn: Attention,
|
1021 |
+
hidden_states: torch.FloatTensor,
|
1022 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1023 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1024 |
+
temb: Optional[torch.FloatTensor] = None,
|
1025 |
+
):
|
1026 |
+
residual = hidden_states
|
1027 |
+
|
1028 |
+
if attn.spatial_norm is not None:
|
1029 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1030 |
+
|
1031 |
+
input_ndim = hidden_states.ndim
|
1032 |
+
|
1033 |
+
if input_ndim == 4:
|
1034 |
+
batch_size, channel, height, width = hidden_states.shape
|
1035 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1036 |
+
|
1037 |
+
batch_size, key_tokens, _ = (
|
1038 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1039 |
+
)
|
1040 |
+
|
1041 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
|
1042 |
+
if attention_mask is not None:
|
1043 |
+
# expand our mask's singleton query_tokens dimension:
|
1044 |
+
# [batch*heads, 1, key_tokens] ->
|
1045 |
+
# [batch*heads, query_tokens, key_tokens]
|
1046 |
+
# so that it can be added as a bias onto the attention scores that xformers computes:
|
1047 |
+
# [batch*heads, query_tokens, key_tokens]
|
1048 |
+
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
1049 |
+
_, query_tokens, _ = hidden_states.shape
|
1050 |
+
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
1051 |
+
|
1052 |
+
if attn.group_norm is not None:
|
1053 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1054 |
+
|
1055 |
+
query = attn.to_q(hidden_states)
|
1056 |
+
|
1057 |
+
if encoder_hidden_states is None:
|
1058 |
+
encoder_hidden_states = hidden_states
|
1059 |
+
elif attn.norm_cross:
|
1060 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1061 |
+
|
1062 |
+
key = attn.to_k(encoder_hidden_states)
|
1063 |
+
value = attn.to_v(encoder_hidden_states)
|
1064 |
+
|
1065 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
1066 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
1067 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
1068 |
+
|
1069 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
1070 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
1071 |
+
)
|
1072 |
+
hidden_states = hidden_states.to(query.dtype)
|
1073 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1074 |
+
|
1075 |
+
# linear proj
|
1076 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1077 |
+
# dropout
|
1078 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1079 |
+
|
1080 |
+
if input_ndim == 4:
|
1081 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1082 |
+
|
1083 |
+
if attn.residual_connection:
|
1084 |
+
hidden_states = hidden_states + residual
|
1085 |
+
|
1086 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1087 |
+
|
1088 |
+
return hidden_states
|
1089 |
+
|
1090 |
+
|
1091 |
+
class AttnProcessor2_0:
|
1092 |
+
r"""
|
1093 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
1094 |
+
"""
|
1095 |
+
|
1096 |
+
def __init__(self):
|
1097 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
1098 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
1099 |
+
|
1100 |
+
def __call__(
|
1101 |
+
self,
|
1102 |
+
attn: Attention,
|
1103 |
+
hidden_states,
|
1104 |
+
encoder_hidden_states=None,
|
1105 |
+
attention_mask=None,
|
1106 |
+
temb=None,
|
1107 |
+
):
|
1108 |
+
residual = hidden_states
|
1109 |
+
|
1110 |
+
if attn.spatial_norm is not None:
|
1111 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1112 |
+
|
1113 |
+
input_ndim = hidden_states.ndim
|
1114 |
+
|
1115 |
+
if input_ndim == 4:
|
1116 |
+
batch_size, channel, height, width = hidden_states.shape
|
1117 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1118 |
+
|
1119 |
+
batch_size, sequence_length, _ = (
|
1120 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1121 |
+
)
|
1122 |
+
|
1123 |
+
if attention_mask is not None:
|
1124 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1125 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
1126 |
+
# (batch, heads, source_length, target_length)
|
1127 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
1128 |
+
|
1129 |
+
if attn.group_norm is not None:
|
1130 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1131 |
+
|
1132 |
+
query = attn.to_q(hidden_states)
|
1133 |
+
|
1134 |
+
if encoder_hidden_states is None:
|
1135 |
+
encoder_hidden_states = hidden_states
|
1136 |
+
elif attn.norm_cross:
|
1137 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1138 |
+
|
1139 |
+
key = attn.to_k(encoder_hidden_states)
|
1140 |
+
value = attn.to_v(encoder_hidden_states)
|
1141 |
+
|
1142 |
+
inner_dim = key.shape[-1]
|
1143 |
+
head_dim = inner_dim // attn.heads
|
1144 |
+
|
1145 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1146 |
+
|
1147 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1148 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1149 |
+
|
1150 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
1151 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
1152 |
+
hidden_states = F.scaled_dot_product_attention(
|
1153 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
1154 |
+
)
|
1155 |
+
|
1156 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1157 |
+
hidden_states = hidden_states.to(query.dtype)
|
1158 |
+
|
1159 |
+
# linear proj
|
1160 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1161 |
+
# dropout
|
1162 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1163 |
+
|
1164 |
+
if input_ndim == 4:
|
1165 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1166 |
+
|
1167 |
+
if attn.residual_connection:
|
1168 |
+
hidden_states = hidden_states + residual
|
1169 |
+
|
1170 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1171 |
+
|
1172 |
+
return hidden_states
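# Hedged equivalence note: with identical projections, the fused call above
# matches the manual path in `AttnProcessor` up to numerics, i.e.
#
#   F.scaled_dot_product_attention(q, k, v, attn_mask=m)
#     == softmax(m + (q @ k.transpose(-1, -2)) / sqrt(head_dim), dim=-1) @ v
#
# but it works in the (batch, heads, seq_len, head_dim) layout and never
# materializes the (batch * heads, q_len, k_len) probability tensor.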
|
1173 |
+
|
1174 |
+
|
1175 |
+
class LoRAXFormersAttnProcessor(nn.Module):
|
1176 |
+
r"""
|
1177 |
+
Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
|
1178 |
+
|
1179 |
+
Args:
|
1180 |
+
hidden_size (`int`, *optional*):
|
1181 |
+
The hidden size of the attention layer.
|
1182 |
+
cross_attention_dim (`int`, *optional*):
|
1183 |
+
The number of channels in the `encoder_hidden_states`.
|
1184 |
+
rank (`int`, defaults to 4):
|
1185 |
+
The dimension of the LoRA update matrices.
|
1186 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
1187 |
+
The base
|
1188 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
1189 |
+
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
1190 |
+
operator.
|
1191 |
+
network_alpha (`int`, *optional*):
|
1192 |
+
Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
|
1193 |
+
|
1194 |
+
"""
|
1195 |
+
|
1196 |
+
def __init__(
|
1197 |
+
self,
|
1198 |
+
hidden_size,
|
1199 |
+
cross_attention_dim,
|
1200 |
+
rank=4,
|
1201 |
+
attention_op: Optional[Callable] = None,
|
1202 |
+
network_alpha=None,
|
1203 |
+
**kwargs,
|
1204 |
+
):
|
1205 |
+
super().__init__()
|
1206 |
+
|
1207 |
+
self.hidden_size = hidden_size
|
1208 |
+
self.cross_attention_dim = cross_attention_dim
|
1209 |
+
self.rank = rank
|
1210 |
+
self.attention_op = attention_op
|
1211 |
+
|
1212 |
+
q_rank = kwargs.pop("q_rank", None)
|
1213 |
+
q_hidden_size = kwargs.pop("q_hidden_size", None)
|
1214 |
+
q_rank = q_rank if q_rank is not None else rank
|
1215 |
+
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
|
1216 |
+
|
1217 |
+
v_rank = kwargs.pop("v_rank", None)
|
1218 |
+
v_hidden_size = kwargs.pop("v_hidden_size", None)
|
1219 |
+
v_rank = v_rank if v_rank is not None else rank
|
1220 |
+
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
|
1221 |
+
|
1222 |
+
out_rank = kwargs.pop("out_rank", None)
|
1223 |
+
out_hidden_size = kwargs.pop("out_hidden_size", None)
|
1224 |
+
out_rank = out_rank if out_rank is not None else rank
|
1225 |
+
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
|
1226 |
+
|
1227 |
+
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
|
1228 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1229 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
|
1230 |
+
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
|
1231 |
+
|
1232 |
+
def __call__(
|
1233 |
+
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
|
1234 |
+
):
|
1235 |
+
residual = hidden_states
|
1236 |
+
|
1237 |
+
if attn.spatial_norm is not None:
|
1238 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1239 |
+
|
1240 |
+
input_ndim = hidden_states.ndim
|
1241 |
+
|
1242 |
+
if input_ndim == 4:
|
1243 |
+
batch_size, channel, height, width = hidden_states.shape
|
1244 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1245 |
+
|
1246 |
+
batch_size, sequence_length, _ = (
|
1247 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1248 |
+
)
|
1249 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1250 |
+
|
1251 |
+
if attn.group_norm is not None:
|
1252 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1253 |
+
|
1254 |
+
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
1255 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
1256 |
+
|
1257 |
+
if encoder_hidden_states is None:
|
1258 |
+
encoder_hidden_states = hidden_states
|
1259 |
+
elif attn.norm_cross:
|
1260 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1261 |
+
|
1262 |
+
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
|
1263 |
+
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
|
1264 |
+
|
1265 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
1266 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
1267 |
+
|
1268 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
1269 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
1270 |
+
)
|
1271 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1272 |
+
|
1273 |
+
# linear proj
|
1274 |
+
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
1275 |
+
# dropout
|
1276 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1277 |
+
|
1278 |
+
if input_ndim == 4:
|
1279 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1280 |
+
|
1281 |
+
if attn.residual_connection:
|
1282 |
+
hidden_states = hidden_states + residual
|
1283 |
+
|
1284 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1285 |
+
|
1286 |
+
return hidden_states
|
1287 |
+
|
1288 |
+
|
1289 |
+
class LoRAAttnProcessor2_0(nn.Module):
|
1290 |
+
r"""
|
1291 |
+
Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
|
1292 |
+
attention.
|
1293 |
+
|
1294 |
+
Args:
|
1295 |
+
hidden_size (`int`):
|
1296 |
+
The hidden size of the attention layer.
|
1297 |
+
cross_attention_dim (`int`, *optional*):
|
1298 |
+
The number of channels in the `encoder_hidden_states`.
|
1299 |
+
rank (`int`, defaults to 4):
|
1300 |
+
The dimension of the LoRA update matrices.
|
1301 |
+
network_alpha (`int`, *optional*):
|
1302 |
+
Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
|
1303 |
+
"""
|
1304 |
+
|
1305 |
+
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs):
|
1306 |
+
super().__init__()
|
1307 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
1308 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
1309 |
+
|
1310 |
+
self.hidden_size = hidden_size
|
1311 |
+
self.cross_attention_dim = cross_attention_dim
|
1312 |
+
self.rank = rank
|
1313 |
+
|
1314 |
+
q_rank = kwargs.pop("q_rank", None)
|
1315 |
+
q_hidden_size = kwargs.pop("q_hidden_size", None)
|
1316 |
+
q_rank = q_rank if q_rank is not None else rank
|
1317 |
+
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
|
1318 |
+
|
1319 |
+
v_rank = kwargs.pop("v_rank", None)
|
1320 |
+
v_hidden_size = kwargs.pop("v_hidden_size", None)
|
1321 |
+
v_rank = v_rank if v_rank is not None else rank
|
1322 |
+
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
|
1323 |
+
|
1324 |
+
out_rank = kwargs.pop("out_rank", None)
|
1325 |
+
out_hidden_size = kwargs.pop("out_hidden_size", None)
|
1326 |
+
out_rank = out_rank if out_rank is not None else rank
|
1327 |
+
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
|
1328 |
+
|
1329 |
+
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
|
1330 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1331 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
|
1332 |
+
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
|
1333 |
+
|
1334 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
|
1335 |
+
residual = hidden_states
|
1336 |
+
|
1337 |
+
input_ndim = hidden_states.ndim
|
1338 |
+
|
1339 |
+
if input_ndim == 4:
|
1340 |
+
batch_size, channel, height, width = hidden_states.shape
|
1341 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1342 |
+
|
1343 |
+
batch_size, sequence_length, _ = (
|
1344 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1345 |
+
)
|
1346 |
+
inner_dim = hidden_states.shape[-1]
|
1347 |
+
|
1348 |
+
if attention_mask is not None:
|
1349 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1350 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
1351 |
+
# (batch, heads, source_length, target_length)
|
1352 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
1353 |
+
|
1354 |
+
if attn.group_norm is not None:
|
1355 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1356 |
+
|
1357 |
+
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
1358 |
+
|
1359 |
+
if encoder_hidden_states is None:
|
1360 |
+
encoder_hidden_states = hidden_states
|
1361 |
+
elif attn.norm_cross:
|
1362 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1363 |
+
|
1364 |
+
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
|
1365 |
+
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
|
1366 |
+
|
1367 |
+
head_dim = inner_dim // attn.heads
|
1368 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1369 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1370 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1371 |
+
|
1372 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
1373 |
+
hidden_states = F.scaled_dot_product_attention(
|
1374 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
1375 |
+
)
|
1376 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1377 |
+
hidden_states = hidden_states.to(query.dtype)
|
1378 |
+
|
1379 |
+
# linear proj
|
1380 |
+
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
1381 |
+
# dropout
|
1382 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1383 |
+
|
1384 |
+
if input_ndim == 4:
|
1385 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1386 |
+
|
1387 |
+
if attn.residual_connection:
|
1388 |
+
hidden_states = hidden_states + residual
|
1389 |
+
|
1390 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1391 |
+
|
1392 |
+
return hidden_states
|
1393 |
+
|
1394 |
+
|
1395 |
+
class CustomDiffusionXFormersAttnProcessor(nn.Module):
|
1396 |
+
r"""
|
1397 |
+
Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
|
1398 |
+
|
1399 |
+
Args:
|
1400 |
+
train_kv (`bool`, defaults to `True`):
|
1401 |
+
Whether to newly train the key and value matrices corresponding to the text features.
|
1402 |
+
train_q_out (`bool`, defaults to `True`):
|
1403 |
+
Whether to newly train query matrices corresponding to the latent image features.
|
1404 |
+
hidden_size (`int`, *optional*, defaults to `None`):
|
1405 |
+
The hidden size of the attention layer.
|
1406 |
+
cross_attention_dim (`int`, *optional*, defaults to `None`):
|
1407 |
+
The number of channels in the `encoder_hidden_states`.
|
1408 |
+
out_bias (`bool`, defaults to `True`):
|
1409 |
+
Whether to include the bias parameter in `train_q_out`.
|
1410 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
1411 |
+
The dropout probability to use.
|
1412 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
1413 |
+
The base
|
1414 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
|
1415 |
+
as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
|
1416 |
+
"""
|
1417 |
+
|
1418 |
+
def __init__(
|
1419 |
+
self,
|
1420 |
+
train_kv=True,
|
1421 |
+
train_q_out=False,
|
1422 |
+
hidden_size=None,
|
1423 |
+
cross_attention_dim=None,
|
1424 |
+
out_bias=True,
|
1425 |
+
dropout=0.0,
|
1426 |
+
attention_op: Optional[Callable] = None,
|
1427 |
+
):
|
1428 |
+
super().__init__()
|
1429 |
+
self.train_kv = train_kv
|
1430 |
+
self.train_q_out = train_q_out
|
1431 |
+
|
1432 |
+
self.hidden_size = hidden_size
|
1433 |
+
self.cross_attention_dim = cross_attention_dim
|
1434 |
+
self.attention_op = attention_op
|
1435 |
+
|
1436 |
+
# `_custom_diffusion` id for easy serialization and loading.
|
1437 |
+
if self.train_kv:
|
1438 |
+
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
1439 |
+
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
1440 |
+
if self.train_q_out:
|
1441 |
+
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
|
1442 |
+
self.to_out_custom_diffusion = nn.ModuleList([])
|
1443 |
+
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
|
1444 |
+
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
|
1445 |
+
|
1446 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
1447 |
+
batch_size, sequence_length, _ = (
|
1448 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1449 |
+
)
|
1450 |
+
|
1451 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1452 |
+
|
1453 |
+
if self.train_q_out:
|
1454 |
+
query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
|
1455 |
+
        else:
            query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))

        if encoder_hidden_states is None:
            crossattn = False
            encoder_hidden_states = hidden_states
        else:
            crossattn = True
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if self.train_kv:
            key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
            value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
            key = key.to(attn.to_q.weight.dtype)
            value = value.to(attn.to_q.weight.dtype)
        else:
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)

        if crossattn:
            detach = torch.ones_like(key)
            detach[:, :1, :] = detach[:, :1, :] * 0.0
            key = detach * key + (1 - detach) * key.detach()
            value = detach * value + (1 - detach) * value.detach()

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if self.train_q_out:
            # linear proj
            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
            # dropout
            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
        else:
            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
        return hidden_states


class SlicedAttnProcessor:
    r"""
    Processor for implementing sliced attention.

    Args:
        slice_size (`int`, *optional*):
            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
            `attention_head_dim` must be a multiple of the `slice_size`.
    """

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        batch_size_attention, query_tokens, _ = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        for i in range(batch_size_attention // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
            ###########################################################################################################
            attn_slice = attn.get_attention_scores(query_slice, key_slice, weight, attn_mask_slice)

            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class SlicedAttnAddedKVProcessor:
    r"""
    Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.

    Args:
        slice_size (`int`, *optional*):
            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
            `attention_head_dim` must be a multiple of the `slice_size`.
    """

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)

        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        if not attn.only_cross_attention:
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
        else:
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        batch_size_attention, query_tokens, _ = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        for i in range(batch_size_attention // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
            ###########################################################################################################
            attn_slice = attn.get_attention_scores(query_slice, key_slice, weight, attn_mask_slice)

            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states


AttentionProcessor = Union[
    AttnProcessor,
    Guid_AttnProcessor,
    AttnProcessor2_0,
    XFormersAttnProcessor,
    SlicedAttnProcessor,
    AttnAddedKVProcessor,
    SlicedAttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    XFormersAttnAddedKVProcessor,
    LoRAAttnProcessor,
    LoRAXFormersAttnProcessor,
    LoRAAttnProcessor2_0,
    LoRAAttnAddedKVProcessor,
    CustomDiffusionAttnProcessor,
    CustomDiffusionXFormersAttnProcessor,
]

LORA_ATTENTION_PROCESSORS = (
    LoRAAttnProcessor,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    LoRAAttnAddedKVProcessor,
)


class SpatialNorm(nn.Module):
    """
    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
    """

    def __init__(
        self,
        f_channels,
        zq_channels,
    ):
        super().__init__()
        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
        self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
        self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, f, zq):
        f_size = f.shape[-2:]
        zq = F.interpolate(zq, size=f_size, mode="nearest")
        norm_f = self.norm_layer(f)
        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
        return new_f
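Note (not part of the uploaded file): sliced attention trades throughput for peak memory by computing attention in chunks of `slice_size` rows along the batch*heads dimension. A minimal usage sketch, assuming the upstream diffusers API (`UNet2DConditionModel.set_attn_processor`) rather than this repo's modified processor, with an illustrative checkpoint id:

# Hedged sketch: attach a sliced-attention processor to every Attention module of a UNet.
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import SlicedAttnProcessor

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # illustrative checkpoint
)
# Each attention call now processes `slice_size` batch*head rows at a time,
# lowering peak memory at the cost of extra kernel launches.
unet.set_attn_processor(SlicedAttnProcessor(slice_size=4))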
Tiger Model/diffusiers-Tiger/models/autoencoder_asym_kl.py
ADDED
@@ -0,0 +1,180 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import apply_forward_hook
from .autoencoder_kl import AutoencoderKLOutput
from .modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder


class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
    r"""
    Designing a Better Asymmetric VQGAN for StableDiffusion https://arxiv.org/abs/2306.04632. A VAE model with KL loss
    for encoding images into latents and decoding latent representations into images.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        down_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of down block output channels.
        layers_per_down_block (`int`, *optional*, defaults to `1`):
            Number of layers for each down block.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        up_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of up block output channels.
        layers_per_up_block (`int`, *optional*, defaults to `1`):
            Number of layers for each up block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
        norm_num_groups (`int`, *optional*, defaults to `32`):
            Number of groups to use for the first normalization layer in ResNet blocks.
        scaling_factor (`float`, *optional*, defaults to 0.18215):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        down_block_out_channels: Tuple[int] = (64,),
        layers_per_down_block: int = 1,
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        up_block_out_channels: Tuple[int] = (64,),
        layers_per_up_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
    ) -> None:
        super().__init__()

        # pass init params to Encoder
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=down_block_out_channels,
            layers_per_block=layers_per_down_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
        )

        # pass init params to Decoder
        self.decoder = MaskConditionDecoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=up_block_out_channels,
            layers_per_block=layers_per_up_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
        )

        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)

        self.use_slicing = False
        self.use_tiling = False

    @apply_forward_hook
    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(
        self,
        z: torch.FloatTensor,
        image: Optional[torch.FloatTensor] = None,
        mask: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        z = self.post_quant_conv(z)
        dec = self.decoder(z, image, mask)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    @apply_forward_hook
    def decode(
        self,
        z: torch.FloatTensor,
        image: Optional[torch.FloatTensor] = None,
        mask: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        decoded = self._decode(z, image, mask).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def forward(
        self,
        sample: torch.FloatTensor,
        mask: Optional[torch.FloatTensor] = None,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            mask (`torch.FloatTensor`, *optional*, defaults to `None`): Optional inpainting mask.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z, sample, mask).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
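Note (not part of the uploaded file): the asymmetric VAE encodes like a standard `AutoencoderKL`, but its `MaskConditionDecoder` can additionally condition on the original image and an inpainting mask. A rough sketch of that round trip; the toy default config and tensor shapes below are illustrative assumptions, and the mask-conditioned path may require a full SD-style checkpoint config rather than the one-block defaults.

# Hedged sketch of the encode/decode round trip with mask conditioning.
import torch

vae = AsymmetricAutoencoderKL()            # default (toy) config, illustrative only
image = torch.randn(1, 3, 64, 64)          # input batch
mask = torch.ones(1, 1, 64, 64)            # inpainting mask (all "keep", illustrative convention)
latents = vae.encode(image).latent_dist.sample()
recon = vae.decode(latents, image=image, mask=mask).sample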
Tiger Model/diffusiers-Tiger/models/autoencoder_kl.py
ADDED
@@ -0,0 +1,417 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalVAEMixin
from ..utils import BaseOutput, apply_forward_hook
from .attention_processor import AttentionProcessor, AttnProcessor
from .modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder


@dataclass
class AutoencoderKLOutput(BaseOutput):
    """
    Output of AutoencoderKL encoding method.

    Args:
        latent_dist (`DiagonalGaussianDistribution`):
            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
    """

    latent_dist: "DiagonalGaussianDistribution"


class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
    r"""
    A VAE model with KL loss for encoding images into latents and decoding latent representations into images.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of block output channels.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
        scaling_factor (`float`, *optional*, defaults to 0.18215):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
        force_upcast (`bool`, *optional*, default to `True`):
            If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
            can be fine-tuned / trained to a lower range without losing too much precision, in which case
            `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        force_upcast: float = True,
    ):
        super().__init__()

        # pass init params to Encoder
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
        )

        # pass init params to Decoder
        self.decoder = Decoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
        )

        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)

        self.use_slicing = False
        self.use_tiling = False

        # only relevant if vae tiling is enabled
        self.tile_sample_min_size = self.config.sample_size
        sample_size = (
            self.config.sample_size[0]
            if isinstance(self.config.sample_size, (list, tuple))
            else self.config.sample_size
        )
        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
        self.tile_overlap_factor = 0.25

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (Encoder, Decoder)):
            module.gradient_checkpointing = value

    def enable_tiling(self, use_tiling: bool = True):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.use_tiling = use_tiling

    def disable_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.enable_tiling(False)

    def enable_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    @property
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(AttnProcessor())

    @apply_forward_hook
    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
            return self.tiled_encode(x, return_dict=return_dict)

        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self.encoder(x)

        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
            return self.tiled_decode(z, return_dict=return_dict)

        z = self.post_quant_conv(z)
        dec = self.decoder(z)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    @apply_forward_hook
    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def blend_v(self, a, b, blend_extent):
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for y in range(blend_extent):
            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
        return b

    def blend_h(self, a, b, blend_extent):
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
        return b

    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        r"""Encode a batch of images using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
        output, but they should be much less noticeable.

        Args:
            x (`torch.FloatTensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
                `tuple` is returned.
        """
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the image into 512x512 tiles and encode them separately.
        rows = []
        for i in range(0, x.shape[2], overlap_size):
            row = []
            for j in range(0, x.shape[3], overlap_size):
                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        moments = torch.cat(result_rows, dim=2)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Decode a batch of images using a tiled decoder.

        Args:
            z (`torch.FloatTensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # Split z into overlapping 64x64 tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[2], overlap_size):
            row = []
            for j in range(0, z.shape[3], overlap_size):
                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        dec = torch.cat(result_rows, dim=2)
        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
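Note (not part of the uploaded file): the two memory-saving paths above are slicing (decode a batch one sample at a time) and tiling (split large inputs into overlapping tiles, then blend the seams with `blend_v`/`blend_h`). A minimal sketch, assuming a toy config just to keep the tensors small; tiling only actually engages when the latent exceeds `tile_latent_min_size`.

# Hedged sketch: memory-saving decode paths of AutoencoderKL.
import torch

vae = AutoencoderKL(block_out_channels=(64,), sample_size=256)  # illustrative toy config
vae.enable_slicing()   # decode the batch one sample at a time
vae.enable_tiling()    # tile + blend only for latents larger than tile_latent_min_size
latents = torch.randn(2, 4, 64, 64)
images = vae.decode(latents).sample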
Tiger Model/diffusiers-Tiger/models/autoencoder_tiny.py
ADDED
@@ -0,0 +1,342 @@
# Copyright 2023 Ollin Boer Bohan and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass
from typing import Tuple, Union

import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, apply_forward_hook
from .modeling_utils import ModelMixin
from .vae import DecoderOutput, DecoderTiny, EncoderTiny


@dataclass
class AutoencoderTinyOutput(BaseOutput):
    """
    Output of AutoencoderTiny encoding method.

    Args:
        latents (`torch.Tensor`): Encoded outputs of the `Encoder`.

    """

    latents: torch.Tensor


class AutoencoderTiny(ModelMixin, ConfigMixin):
    r"""
    A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.

    [`AutoencoderTiny`] is a wrapper around the original implementation of `TAESD`.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
    all models (such as downloading or saving).

    Parameters:
        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
        encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
            Tuple of integers representing the number of output channels for each encoder block. The length of the
            tuple should be equal to the number of encoder blocks.
        decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
            Tuple of integers representing the number of output channels for each decoder block. The length of the
            tuple should be equal to the number of decoder blocks.
        act_fn (`str`, *optional*, defaults to `"relu"`):
            Activation function to be used throughout the model.
        latent_channels (`int`, *optional*, defaults to 4):
            Number of channels in the latent representation. The latent space acts as a compressed representation of
            the input image.
        upsampling_scaling_factor (`int`, *optional*, defaults to 2):
            Scaling factor for upsampling in the decoder. It determines the size of the output image during the
            upsampling process.
        num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
            Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
            length of the tuple should be equal to the number of stages in the encoder. Each stage has a different
            number of encoder blocks.
        num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
            Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
            length of the tuple should be equal to the number of stages in the decoder. Each stage has a different
            number of decoder blocks.
        latent_magnitude (`float`, *optional*, defaults to 3.0):
            Magnitude of the latent representation. This parameter scales the latent representation values to control
            the extent of information preservation.
        latent_shift (float, *optional*, defaults to 0.5):
            Shift applied to the latent representation. This parameter controls the center of the latent space.
        scaling_factor (`float`, *optional*, defaults to 1.0):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. For this Autoencoder,
            however, no such scaling factor was used, hence the value of 1.0 as the default.
        force_upcast (`bool`, *optional*, default to `False`):
            If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
            can be fine-tuned / trained to a lower range without losing too much precision, in which case
            `force_upcast` can be set to `False` (see this fp16-friendly
            [AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
    """
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        encoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
        decoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
        act_fn: str = "relu",
        latent_channels: int = 4,
        upsampling_scaling_factor: int = 2,
        num_encoder_blocks: Tuple[int] = (1, 3, 3, 3),
        num_decoder_blocks: Tuple[int] = (3, 3, 3, 1),
        latent_magnitude: int = 3,
        latent_shift: float = 0.5,
        force_upcast: float = False,
        scaling_factor: float = 1.0,
    ):
        super().__init__()

        if len(encoder_block_out_channels) != len(num_encoder_blocks):
            raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.")
        if len(decoder_block_out_channels) != len(num_decoder_blocks):
            raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.")

        self.encoder = EncoderTiny(
            in_channels=in_channels,
            out_channels=latent_channels,
            num_blocks=num_encoder_blocks,
            block_out_channels=encoder_block_out_channels,
            act_fn=act_fn,
        )

        self.decoder = DecoderTiny(
            in_channels=latent_channels,
            out_channels=out_channels,
            num_blocks=num_decoder_blocks,
            block_out_channels=decoder_block_out_channels,
            upsampling_scaling_factor=upsampling_scaling_factor,
            act_fn=act_fn,
        )

        self.latent_magnitude = latent_magnitude
        self.latent_shift = latent_shift
        self.scaling_factor = scaling_factor

        self.use_slicing = False
        self.use_tiling = False

        # only relevant if vae tiling is enabled
        self.spatial_scale_factor = 2**out_channels
        self.tile_overlap_factor = 0.125
        self.tile_sample_min_size = 512
        self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (EncoderTiny, DecoderTiny)):
            module.gradient_checkpointing = value

    def scale_latents(self, x):
        """raw latents -> [0, 1]"""
        return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)

    def unscale_latents(self, x):
        """[0, 1] -> raw latents"""
        return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)

    def enable_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    def enable_tiling(self, use_tiling: bool = True):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.use_tiling = use_tiling

    def disable_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.enable_tiling(False)

    def _tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
        r"""Encode a batch of images using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
        tiles overlap and are blended together to form a smooth output.

        Args:
            x (`torch.FloatTensor`): Input batch of images.

        Returns:
            `torch.FloatTensor`: The tiled, blended encoder output.
        """
        # scale of encoder output relative to input
        sf = self.spatial_scale_factor
        tile_size = self.tile_sample_min_size

        # number of pixels to blend and to traverse between tiles
        blend_size = int(tile_size * self.tile_overlap_factor)
        traverse_size = tile_size - blend_size

        # tiles index (up/left)
        ti = range(0, x.shape[-2], traverse_size)
        tj = range(0, x.shape[-1], traverse_size)

        # mask for blending
        blend_masks = torch.stack(
            torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij")
        )
        blend_masks = blend_masks.clamp(0, 1).to(x.device)

        # output array
        out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device)
        for i in ti:
            for j in tj:
                tile_in = x[..., i : i + tile_size, j : j + tile_size]
                # tile result
                tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf]
                tile = self.encoder(tile_in)
                h, w = tile.shape[-2], tile.shape[-1]
                # blend tile result into output
                blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
                blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
                blend_mask = blend_mask_i * blend_mask_j
                tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w]
                tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
        return out

    def _tiled_decode(self, x: torch.FloatTensor) -> torch.FloatTensor:
        r"""Decode a batch of latents using a tiled decoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
        steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
        tiles overlap and are blended together to form a smooth output.

        Args:
            x (`torch.FloatTensor`): Input batch of latents.

        Returns:
            `torch.FloatTensor`: The tiled, blended decoder output.
        """
        # scale of decoder output relative to input
        sf = self.spatial_scale_factor
        tile_size = self.tile_latent_min_size

        # number of pixels to blend and to traverse between tiles
        blend_size = int(tile_size * self.tile_overlap_factor)
        traverse_size = tile_size - blend_size

        # tiles index (up/left)
        ti = range(0, x.shape[-2], traverse_size)
        tj = range(0, x.shape[-1], traverse_size)

        # mask for blending
        blend_masks = torch.stack(
            torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij")
        )
        blend_masks = blend_masks.clamp(0, 1).to(x.device)

        # output array
        out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device)
        for i in ti:
            for j in tj:
                tile_in = x[..., i : i + tile_size, j : j + tile_size]
                # tile result
                tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf]
                tile = self.decoder(tile_in)
                h, w = tile.shape[-2], tile.shape[-1]
                # blend tile result into output
                blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
                blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
                blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w]
                tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
        return out

    @apply_forward_hook
    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True
    ) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
        if self.use_slicing and x.shape[0] > 1:
            output = [self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x) for x_slice in x.split(1)]
            output = torch.cat(output)
        else:
            output = self._tiled_encode(x) if self.use_tiling else self.encoder(x)

        if not return_dict:
            return (output,)

        return AutoencoderTinyOutput(latents=output)

    @apply_forward_hook
    def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
        if self.use_slicing and x.shape[0] > 1:
            output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
            output = torch.cat(output)
        else:
            output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
        # Refer to the following discussion to know why this is needed.
        # https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854
        output = output.mul_(2).sub_(1)

        if not return_dict:
            return (output,)

        return DecoderOutput(sample=output)

    def forward(
        self,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        enc = self.encode(sample).latents
        scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
        unscaled_enc = self.unscale_latents(scaled_enc)
        dec = self.decode(unscaled_enc)

        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)
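Note (not part of the uploaded file): `forward` above quantizes latents through a [0, 1] range: `scale_latents` maps raw latents into [0, 1] using `latent_magnitude`/`latent_shift`, and `unscale_latents` inverts that mapping (values outside roughly ±latent_magnitude are clipped by the clamp). A minimal sketch of that round trip; the tensor shape is an illustrative assumption.

# Hedged sketch: the [0, 1] latent scaling round trip used by AutoencoderTiny.forward.
import torch

taesd = AutoencoderTiny()                     # default config, illustrative only
raw = torch.randn(1, 4, 32, 32)
scaled = taesd.scale_latents(raw)             # raw latents -> [0, 1]
restored = taesd.unscale_latents(scaled)      # [0, 1] -> raw latents (clipped values stay clipped)
print((raw - restored).abs().max())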
Tiger Model/diffusiers-Tiger/models/controlnet.py
ADDED
@@ -0,0 +1,762 @@
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import nn
|
19 |
+
from torch.nn import functional as F
|
20 |
+
|
21 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from ..loaders import FromOriginalControlnetMixin
|
23 |
+
from ..utils import BaseOutput, logging
|
24 |
+
from .attention_processor import AttentionProcessor, Guid_AttnProcessor
|
25 |
+
from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
26 |
+
from .modeling_utils import ModelMixin
|
27 |
+
from .unet_2d_blocks import (
|
28 |
+
CrossAttnDownBlock2D,
|
29 |
+
DownBlock2D,
|
30 |
+
UNetMidBlock2DCrossAttn,
|
31 |
+
get_down_block,
|
32 |
+
)
|
33 |
+
from .unet_2d_condition import UNet2DConditionModel
|
34 |
+
|
35 |
+
|
36 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
37 |
+
|
38 |
+
|
39 |
+
@dataclass
|
40 |
+
class ControlNetOutput(BaseOutput):
|
41 |
+
"""
|
42 |
+
The output of [`ControlNetModel`].
|
43 |
+
|
44 |
+
Args:
|
45 |
+
down_block_res_samples (`tuple[torch.Tensor]`):
|
46 |
+
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
|
47 |
+
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
48 |
+
used to condition the original UNet's downsampling activations.
|
49 |
+
mid_down_block_re_sample (`torch.Tensor`):
|
50 |
+
The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
|
51 |
+
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
|
52 |
+
Output can be used to condition the original UNet's middle block activation.
|
53 |
+
"""
|
54 |
+
|
55 |
+
down_block_res_samples: Tuple[torch.Tensor]
|
56 |
+
mid_block_res_sample: torch.Tensor
|
57 |
+
|
58 |
+
|
59 |
+
class ControlNetConditioningEmbedding(nn.Module):
|
60 |
+
"""
|
61 |
+
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
|
62 |
+
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
|
63 |
+
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
|
64 |
+
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
|
65 |
+
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
|
66 |
+
model) to encode image-space conditions ... into feature maps ..."
|
67 |
+
"""
|
68 |
+
|
69 |
+
def __init__(
|
70 |
+
self,
|
71 |
+
conditioning_embedding_channels: int,
|
72 |
+
conditioning_channels: int = 3,
|
73 |
+
block_out_channels: Tuple[int] = (16, 32, 96, 256),
|
74 |
+
):
|
75 |
+
super().__init__()
|
76 |
+
|
77 |
+
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
|
78 |
+
|
79 |
+
self.blocks = nn.ModuleList([])
|
80 |
+
|
81 |
+
for i in range(len(block_out_channels) - 1):
|
82 |
+
channel_in = block_out_channels[i]
|
83 |
+
channel_out = block_out_channels[i + 1]
|
84 |
+
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
|
85 |
+
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
|
86 |
+
|
87 |
+
self.conv_out = zero_module(
|
88 |
+
nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
|
89 |
+
)
|
90 |
+
|
91 |
+
def forward(self, conditioning):
|
92 |
+
embedding = self.conv_in(conditioning)
|
93 |
+
embedding = F.silu(embedding)
|
94 |
+
|
95 |
+
for block in self.blocks:
|
96 |
+
embedding = block(embedding)
|
97 |
+
embedding = F.silu(embedding)
|
98 |
+
|
99 |
+
embedding = self.conv_out(embedding)
|
100 |
+
|
101 |
+
return embedding
|
102 |
+
|
103 |
+
|
104 |
+
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
105 |
+
_supports_gradient_checkpointing = True
|
106 |
+
|
107 |
+
@register_to_config
|
108 |
+
def __init__(
|
109 |
+
self,
|
110 |
+
in_channels: int = 4,
|
111 |
+
conditioning_channels: int = 3,
|
112 |
+
flip_sin_to_cos: bool = True,
|
113 |
+
freq_shift: int = 0,
|
114 |
+
down_block_types: Tuple[str] = (
|
115 |
+
"CrossAttnDownBlock2D",
|
116 |
+
"CrossAttnDownBlock2D",
|
117 |
+
"CrossAttnDownBlock2D",
|
118 |
+
"DownBlock2D",
|
119 |
+
),
|
120 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
121 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
122 |
+
layers_per_block: int = 2,
|
123 |
+
downsample_padding: int = 1,
|
124 |
+
mid_block_scale_factor: float = 1,
|
125 |
+
act_fn: str = "silu",
|
126 |
+
norm_num_groups: Optional[int] = 32,
|
127 |
+
norm_eps: float = 1e-5,
|
128 |
+
cross_attention_dim: int = 1280,
|
129 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
130 |
+
encoder_hid_dim: Optional[int] = None,
|
131 |
+
encoder_hid_dim_type: Optional[str] = None,
|
132 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
133 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
134 |
+
use_linear_projection: bool = False,
|
135 |
+
class_embed_type: Optional[str] = None,
|
136 |
+
addition_embed_type: Optional[str] = None,
|
137 |
+
addition_time_embed_dim: Optional[int] = None,
|
138 |
+
num_class_embeds: Optional[int] = None,
|
139 |
+
upcast_attention: bool = False,
|
140 |
+
resnet_time_scale_shift: str = "default",
|
141 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
142 |
+
controlnet_conditioning_channel_order: str = "rgb",
|
143 |
+
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
144 |
+
global_pool_conditions: bool = False,
|
145 |
+
addition_embed_type_num_heads=64,
|
146 |
+
weight : Optional[torch.Tensor] = None,
|
147 |
+
):
|
148 |
+
super().__init__()
|
149 |
+
|
150 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
151 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
152 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
153 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
154 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
155 |
+
# which is why we correct for the naming here.
|
156 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
157 |
+
|
158 |
+
# Check inputs
|
159 |
+
if len(block_out_channels) != len(down_block_types):
|
160 |
+
raise ValueError(
|
161 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
162 |
+
)
|
163 |
+
|
164 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
165 |
+
raise ValueError(
|
166 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
167 |
+
)
|
168 |
+
|
169 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
170 |
+
raise ValueError(
|
171 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
172 |
+
)
|
173 |
+
|
174 |
+
if isinstance(transformer_layers_per_block, int):
|
175 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
176 |
+
|
177 |
+
# input
|
178 |
+
conv_in_kernel = 3
|
179 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
180 |
+
self.conv_in = nn.Conv2d(
|
181 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
182 |
+
)
|
183 |
+
|
184 |
+
# time
|
185 |
+
time_embed_dim = block_out_channels[0] * 4
|
186 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
187 |
+
timestep_input_dim = block_out_channels[0]
|
188 |
+
self.time_embedding = TimestepEmbedding(
|
189 |
+
timestep_input_dim,
|
190 |
+
time_embed_dim,
|
191 |
+
act_fn=act_fn,
|
192 |
+
)
|
193 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
194 |
+
encoder_hid_dim_type = "text_proj"
|
195 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
196 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
197 |
+
|
198 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
199 |
+
raise ValueError(
|
200 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
201 |
+
)
|
202 |
+
|
203 |
+
if encoder_hid_dim_type == "text_proj":
|
204 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
205 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
206 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
207 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
208 |
+
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
209 |
+
self.encoder_hid_proj = TextImageProjection(
|
210 |
+
text_embed_dim=encoder_hid_dim,
|
211 |
+
image_embed_dim=cross_attention_dim,
|
212 |
+
cross_attention_dim=cross_attention_dim,
|
213 |
+
)
|
214 |
+
|
215 |
+
elif encoder_hid_dim_type is not None:
|
216 |
+
raise ValueError(
|
217 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
218 |
+
)
|
219 |
+
else:
|
220 |
+
self.encoder_hid_proj = None
|
221 |
+
# class embedding
|
222 |
+
if class_embed_type is None and num_class_embeds is not None:
|
223 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
224 |
+
elif class_embed_type == "timestep":
|
225 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
226 |
+
elif class_embed_type == "identity":
|
227 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
228 |
+
elif class_embed_type == "projection":
|
229 |
+
if projection_class_embeddings_input_dim is None:
|
230 |
+
raise ValueError(
|
231 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
232 |
+
)
|
233 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
234 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
235 |
+
# 2. it projects from an arbitrary input dimension.
|
236 |
+
#
|
237 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
238 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
239 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
240 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
241 |
+
else:
|
242 |
+
self.class_embedding = None
|
243 |
+
|
244 |
+
if addition_embed_type == "text_nd":
|
245 |
+
if encoder_hid_dim is not None:
|
246 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
247 |
+
else:
|
248 |
+
text_time_embedding_from_dim = cross_attention_dim
|
249 |
+
|
250 |
+
self.add_embedding = TextTimeEmbedding(
|
251 |
+
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
252 |
+
)
|
253 |
+
elif addition_embed_type == "text_image":
|
254 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
255 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
256 |
+
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
257 |
+
self.add_embedding = TextImageTimeEmbedding(
|
258 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
259 |
+
)
|
260 |
+
elif addition_embed_type == "text_time":
|
261 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
262 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
263 |
+
|
264 |
+
elif addition_embed_type is not None:
|
265 |
+
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
266 |
+
|
267 |
+
# control net conditioning embedding
|
268 |
+
############################################################### ControlNetConditioningEmbedding #############################################################################
|
269 |
+
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
|
270 |
+
conditioning_embedding_channels=block_out_channels[0],
|
271 |
+
block_out_channels=conditioning_embedding_out_channels,
|
272 |
+
conditioning_channels=conditioning_channels,
|
273 |
+
)
|
274 |
+
self.down_blocks = nn.ModuleList([])
|
275 |
+
self.controlnet_down_blocks = nn.ModuleList([])
|
276 |
+
|
277 |
+
if isinstance(only_cross_attention, bool):
|
278 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
279 |
+
|
280 |
+
if isinstance(attention_head_dim, int):
|
281 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
282 |
+
|
283 |
+
if isinstance(num_attention_heads, int):
|
284 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
285 |
+
|
286 |
+
# down
|
287 |
+
output_channel = block_out_channels[0]
|
288 |
+
|
289 |
+
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
290 |
+
controlnet_block = zero_module(controlnet_block)
|
291 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
292 |
+
|
293 |
+
for i, down_block_type in enumerate(down_block_types):
|
294 |
+
input_channel = output_channel
|
295 |
+
output_channel = block_out_channels[i]
|
296 |
+
is_final_block = i == len(block_out_channels) - 1
|
297 |
+
|
298 |
+
down_block = get_down_block(
|
299 |
+
down_block_type,
|
300 |
+
num_layers=layers_per_block,
|
301 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
302 |
+
in_channels=input_channel,
|
303 |
+
out_channels=output_channel,
|
304 |
+
temb_channels=time_embed_dim,
|
305 |
+
add_downsample=not is_final_block,
|
306 |
+
resnet_eps=norm_eps,
|
307 |
+
resnet_act_fn=act_fn,
|
308 |
+
resnet_groups=norm_num_groups,
|
309 |
+
cross_attention_dim=cross_attention_dim,
|
310 |
+
num_attention_heads=num_attention_heads[i],
|
311 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
312 |
+
downsample_padding=downsample_padding,
|
313 |
+
use_linear_projection=use_linear_projection,
|
314 |
+
only_cross_attention=only_cross_attention[i],
|
315 |
+
upcast_attention=upcast_attention,
|
316 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
317 |
+
weight = weight,
|
318 |
+
)
|
319 |
+
|
320 |
+
self.down_blocks.append(down_block)
|
321 |
+
|
322 |
+
for _ in range(layers_per_block):
|
323 |
+
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
324 |
+
controlnet_block = zero_module(controlnet_block)
|
325 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
326 |
+
|
327 |
+
if not is_final_block:
|
328 |
+
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
329 |
+
controlnet_block = zero_module(controlnet_block)
|
330 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
331 |
+
|
332 |
+
# mid
|
333 |
+
mid_block_channel = block_out_channels[-1]
|
334 |
+
|
335 |
+
controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
|
336 |
+
controlnet_block = zero_module(controlnet_block)
|
337 |
+
self.controlnet_mid_block = controlnet_block
|
338 |
+
|
339 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
340 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
341 |
+
in_channels=mid_block_channel,
|
342 |
+
temb_channels=time_embed_dim,
|
343 |
+
resnet_eps=norm_eps,
|
344 |
+
resnet_act_fn=act_fn,
|
345 |
+
output_scale_factor=mid_block_scale_factor,
|
346 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
347 |
+
cross_attention_dim=cross_attention_dim,
|
348 |
+
num_attention_heads=num_attention_heads[-1],
|
349 |
+
resnet_groups=norm_num_groups,
|
350 |
+
use_linear_projection=use_linear_projection,
|
351 |
+
upcast_attention=upcast_attention,
|
352 |
+
)
|
353 |
+
|
354 |
+
@classmethod
|
355 |
+
def from_unet(
|
356 |
+
cls,
|
357 |
+
unet: UNet2DConditionModel,
|
358 |
+
controlnet_conditioning_channel_order: str = "rgb",
|
359 |
+
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
360 |
+
load_weights_from_unet: bool = True,
|
361 |
+
weight : Optional[torch.Tensor] = None,
|
362 |
+
):
|
363 |
+
r"""
|
364 |
+
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
|
365 |
+
|
366 |
+
Parameters:
|
367 |
+
unet (`UNet2DConditionModel`):
|
368 |
+
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
|
369 |
+
where applicable.
|
370 |
+
"""
|
371 |
+
transformer_layers_per_block = (
|
372 |
+
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
373 |
+
)
|
374 |
+
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
375 |
+
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
376 |
+
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
377 |
+
addition_time_embed_dim = (
|
378 |
+
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
379 |
+
)
|
380 |
+
controlnet = cls(
|
381 |
+
encoder_hid_dim=encoder_hid_dim,
|
382 |
+
encoder_hid_dim_type=encoder_hid_dim_type,
|
383 |
+
addition_embed_type=addition_embed_type,
|
384 |
+
addition_time_embed_dim=addition_time_embed_dim,
|
385 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
386 |
+
in_channels=unet.config.in_channels,
|
387 |
+
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
388 |
+
freq_shift=unet.config.freq_shift,
|
389 |
+
down_block_types=unet.config.down_block_types,
|
390 |
+
only_cross_attention=unet.config.only_cross_attention,
|
391 |
+
block_out_channels=unet.config.block_out_channels,
|
392 |
+
layers_per_block=unet.config.layers_per_block,
|
393 |
+
downsample_padding=unet.config.downsample_padding,
|
394 |
+
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
395 |
+
act_fn=unet.config.act_fn,
|
396 |
+
norm_num_groups=unet.config.norm_num_groups,
|
397 |
+
norm_eps=unet.config.norm_eps,
|
398 |
+
cross_attention_dim=unet.config.cross_attention_dim,
|
399 |
+
attention_head_dim=unet.config.attention_head_dim,
|
400 |
+
num_attention_heads=unet.config.num_attention_heads,
|
401 |
+
use_linear_projection=unet.config.use_linear_projection,
|
402 |
+
class_embed_type=unet.config.class_embed_type,
|
403 |
+
num_class_embeds=unet.config.num_class_embeds,
|
404 |
+
upcast_attention=unet.config.upcast_attention,
|
405 |
+
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
406 |
+
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
407 |
+
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
408 |
+
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
409 |
+
)
|
410 |
+
|
411 |
+
if load_weights_from_unet:
|
412 |
+
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
|
413 |
+
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
414 |
+
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
415 |
+
|
416 |
+
if controlnet.class_embedding:
|
417 |
+
controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
418 |
+
|
419 |
+
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
|
420 |
+
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
|
421 |
+
|
422 |
+
return controlnet
|
423 |
+
|
424 |
+
@property
|
425 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
426 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
427 |
+
r"""
|
428 |
+
Returns:
|
429 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
430 |
+
indexed by its weight name.
|
431 |
+
"""
|
432 |
+
# set recursively
|
433 |
+
processors = {}
|
434 |
+
|
435 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
436 |
+
if hasattr(module, "set_processor"):
|
437 |
+
processors[f"{name}.processor"] = module.processor
|
438 |
+
|
439 |
+
for sub_name, child in module.named_children():
|
440 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
441 |
+
|
442 |
+
return processors
|
443 |
+
|
444 |
+
for name, module in self.named_children():
|
445 |
+
fn_recursive_add_processors(name, module, processors)
|
446 |
+
|
447 |
+
return processors
|
448 |
+
|
449 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
450 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
451 |
+
r"""
|
452 |
+
Sets the attention processor to use to compute attention.
|
453 |
+
|
454 |
+
Parameters:
|
455 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
456 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
457 |
+
for **all** `Attention` layers.
|
458 |
+
|
459 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
460 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
461 |
+
|
462 |
+
"""
|
463 |
+
count = len(self.attn_processors.keys())
|
464 |
+
|
465 |
+
if isinstance(processor, dict) and len(processor) != count:
|
466 |
+
raise ValueError(
|
467 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
468 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
469 |
+
)
|
470 |
+
|
471 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
472 |
+
if hasattr(module, "set_processor"):
|
473 |
+
if not isinstance(processor, dict):
|
474 |
+
module.set_processor(processor)
|
475 |
+
else:
|
476 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
477 |
+
|
478 |
+
for sub_name, child in module.named_children():
|
479 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
480 |
+
for name, module in self.named_children():
|
481 |
+
fn_recursive_attn_processor(name, module, processor)
|
482 |
+
|
483 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
484 |
+
def set_default_attn_processor(self):
|
485 |
+
"""
|
486 |
+
Disables custom attention processors and sets the default attention implementation.
|
487 |
+
"""
|
488 |
+
self.set_attn_processor(AttnProcessor())
|
489 |
+
|
490 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
491 |
+
def set_attention_slice(self, slice_size):
|
492 |
+
r"""
|
493 |
+
Enable sliced attention computation.
|
494 |
+
|
495 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
496 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
497 |
+
|
498 |
+
Args:
|
499 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
500 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
501 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
502 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
503 |
+
must be a multiple of `slice_size`.
|
504 |
+
"""
|
505 |
+
sliceable_head_dims = []
|
506 |
+
|
507 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
508 |
+
if hasattr(module, "set_attention_slice"):
|
509 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
510 |
+
|
511 |
+
for child in module.children():
|
512 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
513 |
+
|
514 |
+
# retrieve number of attention layers
|
515 |
+
for module in self.children():
|
516 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
517 |
+
|
518 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
519 |
+
|
520 |
+
if slice_size == "auto":
|
521 |
+
# half the attention head size is usually a good trade-off between
|
522 |
+
# speed and memory
|
523 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
524 |
+
elif slice_size == "max":
|
525 |
+
# make smallest slice possible
|
526 |
+
slice_size = num_sliceable_layers * [1]
|
527 |
+
|
528 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
529 |
+
|
530 |
+
if len(slice_size) != len(sliceable_head_dims):
|
531 |
+
raise ValueError(
|
532 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
533 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
534 |
+
)
|
535 |
+
|
536 |
+
for i in range(len(slice_size)):
|
537 |
+
size = slice_size[i]
|
538 |
+
dim = sliceable_head_dims[i]
|
539 |
+
if size is not None and size > dim:
|
540 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
541 |
+
|
542 |
+
# Recursively walk through all the children.
|
543 |
+
# Any children which exposes the set_attention_slice method
|
544 |
+
# gets the message
|
545 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
546 |
+
if hasattr(module, "set_attention_slice"):
|
547 |
+
module.set_attention_slice(slice_size.pop())
|
548 |
+
|
549 |
+
for child in module.children():
|
550 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
551 |
+
|
552 |
+
reversed_slice_size = list(reversed(slice_size))
|
553 |
+
for module in self.children():
|
554 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
555 |
+
|
556 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
557 |
+
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
558 |
+
module.gradient_checkpointing = value
|
559 |
+
|
560 |
+
def forward(
|
561 |
+
self,
|
562 |
+
sample: torch.FloatTensor,
|
563 |
+
timestep: Union[torch.Tensor, float, int],
|
564 |
+
encoder_hidden_states: torch.Tensor,
|
565 |
+
controlnet_cond: torch.FloatTensor,
|
566 |
+
conditioning_scale: float = 1.0,
|
567 |
+
class_labels: Optional[torch.Tensor] = None,
|
568 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
569 |
+
attention_mask: Optional[torch.Tensor] = None,
|
570 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
571 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
572 |
+
guess_mode: bool = False,
|
573 |
+
return_dict: bool = True,
|
574 |
+
weight: Optional[torch.Tensor] = None,
|
575 |
+
) -> Union[ControlNetOutput, Tuple]:
|
576 |
+
"""
|
577 |
+
The [`ControlNetModel`] forward method.
|
578 |
+
|
579 |
+
Args:
|
580 |
+
sample (`torch.FloatTensor`):
|
581 |
+
The noisy input tensor.
|
582 |
+
timestep (`Union[torch.Tensor, float, int]`):
|
583 |
+
The number of timesteps to denoise an input.
|
584 |
+
encoder_hidden_states (`torch.Tensor`):
|
585 |
+
The encoder hidden states.
|
586 |
+
controlnet_cond (`torch.FloatTensor`):
|
587 |
+
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
588 |
+
conditioning_scale (`float`, defaults to `1.0`):
|
589 |
+
The scale factor for ControlNet outputs.
|
590 |
+
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
591 |
+
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
592 |
+
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
593 |
+
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
594 |
+
added_cond_kwargs (`dict`):
|
595 |
+
Additional conditions for the Stable Diffusion XL UNet.
|
596 |
+
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
597 |
+
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
598 |
+
guess_mode (`bool`, defaults to `False`):
|
599 |
+
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
600 |
+
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
601 |
+
return_dict (`bool`, defaults to `True`):
|
602 |
+
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
603 |
+
|
604 |
+
Returns:
|
605 |
+
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
606 |
+
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
607 |
+
returned where the first element is the sample tensor.
|
608 |
+
"""
|
609 |
+
# check channel order
|
610 |
+
channel_order = self.config.controlnet_conditioning_channel_order
|
611 |
+
|
612 |
+
if channel_order == "rgb":
|
613 |
+
# in rgb order by default
|
614 |
+
...
|
615 |
+
elif channel_order == "bgr":
|
616 |
+
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
|
617 |
+
else:
|
618 |
+
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
619 |
+
|
620 |
+
# prepare attention_mask
|
621 |
+
|
622 |
+
if attention_mask is not None:
|
623 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
624 |
+
attention_mask = attention_mask.unsqueeze(1)
|
625 |
+
|
626 |
+
# 1. time
|
627 |
+
timesteps = timestep
|
628 |
+
if not torch.is_tensor(timesteps):
|
629 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
630 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
631 |
+
is_mps = sample.device.type == "mps"
|
632 |
+
if isinstance(timestep, float):
|
633 |
+
dtype = torch.float32 if is_mps else torch.float64
|
634 |
+
else:
|
635 |
+
dtype = torch.int32 if is_mps else torch.int64
|
636 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
637 |
+
elif len(timesteps.shape) == 0:
|
638 |
+
timesteps = timesteps[None].to(sample.device)
|
639 |
+
|
640 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
641 |
+
timesteps = timesteps.expand(sample.shape[0])
|
642 |
+
|
643 |
+
t_emb = self.time_proj(timesteps)
|
644 |
+
|
645 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
646 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
647 |
+
# there might be better ways to encapsulate this.
|
648 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
649 |
+
|
650 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
651 |
+
aug_emb = None
|
652 |
+
|
653 |
+
if self.class_embedding is not None:
|
654 |
+
if class_labels is None:
|
655 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
656 |
+
|
657 |
+
if self.config.class_embed_type == "timestep":
|
658 |
+
class_labels = self.time_proj(class_labels)
|
659 |
+
|
660 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
661 |
+
emb = emb + class_emb
|
662 |
+
|
663 |
+
if "addition_embed_type" in self.config:
|
664 |
+
if self.config.addition_embed_type == "text":
|
665 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
666 |
+
|
667 |
+
elif self.config.addition_embed_type == "text_time":
|
668 |
+
if "text_embeds" not in added_cond_kwargs:
|
669 |
+
raise ValueError(
|
670 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
671 |
+
)
|
672 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
673 |
+
if "time_ids" not in added_cond_kwargs:
|
674 |
+
raise ValueError(
|
675 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
676 |
+
)
|
677 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
678 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
679 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
680 |
+
|
681 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
682 |
+
add_embeds = add_embeds.to(emb.dtype)
|
683 |
+
aug_emb = self.add_embedding(add_embeds)
|
684 |
+
|
685 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
686 |
+
|
687 |
+
# 2. pre-process
|
688 |
+
sample = self.conv_in(sample)
|
689 |
+
|
690 |
+
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
691 |
+
sample = sample + controlnet_cond
|
692 |
+
|
693 |
+
# 3. down
|
694 |
+
down_block_res_samples = (sample,)
|
695 |
+
for downsample_block in self.down_blocks:
|
696 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
697 |
+
print('controlnet encoder_hidden_states_nd',encoder_hidden_states.shape)
|
698 |
+
sample, res_samples = downsample_block(
|
699 |
+
hidden_states=sample,
|
700 |
+
temb=emb,
|
701 |
+
encoder_hidden_states=encoder_hidden_states,
|
702 |
+
attention_mask=attention_mask,
|
703 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
704 |
+
weight = weight
|
705 |
+
)
|
706 |
+
else:
|
707 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
708 |
+
|
709 |
+
down_block_res_samples += res_samples
|
710 |
+
|
711 |
+
|
712 |
+
# 4. mid
|
713 |
+
if self.mid_block is not None:
|
714 |
+
sample = self.mid_block(
|
715 |
+
sample,
|
716 |
+
emb,
|
717 |
+
encoder_hidden_states=encoder_hidden_states,
|
718 |
+
attention_mask=attention_mask,
|
719 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
720 |
+
)
|
721 |
+
|
722 |
+
# 5. Control net blocks
|
723 |
+
|
724 |
+
controlnet_down_block_res_samples = ()
|
725 |
+
|
726 |
+
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
727 |
+
down_block_res_sample = controlnet_block(down_block_res_sample)
|
728 |
+
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
729 |
+
|
730 |
+
down_block_res_samples = controlnet_down_block_res_samples
|
731 |
+
|
732 |
+
mid_block_res_sample = self.controlnet_mid_block(sample)
|
733 |
+
|
734 |
+
# 6. scaling
|
735 |
+
if guess_mode and not self.config.global_pool_conditions:
|
736 |
+
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
737 |
+
|
738 |
+
scales = scales * conditioning_scale
|
739 |
+
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
740 |
+
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
741 |
+
else:
|
742 |
+
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
743 |
+
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
744 |
+
|
745 |
+
if self.config.global_pool_conditions:
|
746 |
+
down_block_res_samples = [
|
747 |
+
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
748 |
+
]
|
749 |
+
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
750 |
+
|
751 |
+
if not return_dict:
|
752 |
+
return (down_block_res_samples, mid_block_res_sample)
|
753 |
+
|
754 |
+
return ControlNetOutput(
|
755 |
+
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
756 |
+
)
|
757 |
+
|
758 |
+
|
759 |
+
def zero_module(module):
|
760 |
+
for p in module.parameters():
|
761 |
+
nn.init.zeros_(p)
|
762 |
+
return module
|
Tiger Model/diffusiers-Tiger/models/dual_transformer_2d.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Optional
|
15 |
+
|
16 |
+
from torch import nn
|
17 |
+
|
18 |
+
from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
|
19 |
+
|
20 |
+
|
21 |
+
class DualTransformer2DModel(nn.Module):
|
22 |
+
"""
|
23 |
+
Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
|
24 |
+
|
25 |
+
Parameters:
|
26 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
27 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
28 |
+
in_channels (`int`, *optional*):
|
29 |
+
Pass if the input is continuous. The number of channels in the input and output.
|
30 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
31 |
+
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
32 |
+
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
33 |
+
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
34 |
+
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
35 |
+
`ImagePositionalEmbeddings`.
|
36 |
+
num_vector_embeds (`int`, *optional*):
|
37 |
+
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
|
38 |
+
Includes the class for the masked latent pixel.
|
39 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
40 |
+
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
|
41 |
+
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
|
42 |
+
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
|
43 |
+
up to but not more than steps than `num_embeds_ada_norm`.
|
44 |
+
attention_bias (`bool`, *optional*):
|
45 |
+
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
46 |
+
"""
|
47 |
+
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
num_attention_heads: int = 16,
|
51 |
+
attention_head_dim: int = 88,
|
52 |
+
in_channels: Optional[int] = None,
|
53 |
+
num_layers: int = 1,
|
54 |
+
dropout: float = 0.0,
|
55 |
+
norm_num_groups: int = 32,
|
56 |
+
cross_attention_dim: Optional[int] = None,
|
57 |
+
attention_bias: bool = False,
|
58 |
+
sample_size: Optional[int] = None,
|
59 |
+
num_vector_embeds: Optional[int] = None,
|
60 |
+
activation_fn: str = "geglu",
|
61 |
+
num_embeds_ada_norm: Optional[int] = None,
|
62 |
+
):
|
63 |
+
super().__init__()
|
64 |
+
self.transformers = nn.ModuleList(
|
65 |
+
[
|
66 |
+
Transformer2DModel(
|
67 |
+
num_attention_heads=num_attention_heads,
|
68 |
+
attention_head_dim=attention_head_dim,
|
69 |
+
in_channels=in_channels,
|
70 |
+
num_layers=num_layers,
|
71 |
+
dropout=dropout,
|
72 |
+
norm_num_groups=norm_num_groups,
|
73 |
+
cross_attention_dim=cross_attention_dim,
|
74 |
+
attention_bias=attention_bias,
|
75 |
+
sample_size=sample_size,
|
76 |
+
num_vector_embeds=num_vector_embeds,
|
77 |
+
activation_fn=activation_fn,
|
78 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
79 |
+
)
|
80 |
+
for _ in range(2)
|
81 |
+
]
|
82 |
+
)
|
83 |
+
|
84 |
+
# Variables that can be set by a pipeline:
|
85 |
+
|
86 |
+
# The ratio of transformer1 to transformer2's output states to be combined during inference
|
87 |
+
self.mix_ratio = 0.5
|
88 |
+
|
89 |
+
# The shape of `encoder_hidden_states` is expected to be
|
90 |
+
# `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
|
91 |
+
self.condition_lengths = [77, 257]
|
92 |
+
|
93 |
+
# Which transformer to use to encode which condition.
|
94 |
+
# E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
|
95 |
+
self.transformer_index_for_condition = [1, 0]
|
96 |
+
|
97 |
+
def forward(
|
98 |
+
self,
|
99 |
+
hidden_states,
|
100 |
+
encoder_hidden_states,
|
101 |
+
timestep=None,
|
102 |
+
attention_mask=None,
|
103 |
+
cross_attention_kwargs=None,
|
104 |
+
return_dict: bool = True,
|
105 |
+
):
|
106 |
+
"""
|
107 |
+
Args:
|
108 |
+
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
109 |
+
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
110 |
+
hidden_states
|
111 |
+
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
112 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
113 |
+
self-attention.
|
114 |
+
timestep ( `torch.long`, *optional*):
|
115 |
+
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
116 |
+
attention_mask (`torch.FloatTensor`, *optional*):
|
117 |
+
Optional attention mask to be applied in Attention
|
118 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
119 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
120 |
+
|
121 |
+
Returns:
|
122 |
+
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
|
123 |
+
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
124 |
+
returning a tuple, the first element is the sample tensor.
|
125 |
+
"""
|
126 |
+
input_states = hidden_states
|
127 |
+
|
128 |
+
encoded_states = []
|
129 |
+
tokens_start = 0
|
130 |
+
# attention_mask is not used yet
|
131 |
+
for i in range(2):
|
132 |
+
# for each of the two transformers, pass the corresponding condition tokens
|
133 |
+
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
|
134 |
+
transformer_index = self.transformer_index_for_condition[i]
|
135 |
+
encoded_state = self.transformers[transformer_index](
|
136 |
+
input_states,
|
137 |
+
encoder_hidden_states=condition_state,
|
138 |
+
timestep=timestep,
|
139 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
140 |
+
return_dict=False,
|
141 |
+
)[0]
|
142 |
+
encoded_states.append(encoded_state - input_states)
|
143 |
+
tokens_start += self.condition_lengths[i]
|
144 |
+
|
145 |
+
output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
|
146 |
+
output_states = output_states + input_states
|
147 |
+
|
148 |
+
if not return_dict:
|
149 |
+
return (output_states,)
|
150 |
+
|
151 |
+
return Transformer2DModelOutput(sample=output_states)
|
Tiger Model/diffusiers-Tiger/models/embeddings.py
ADDED
@@ -0,0 +1,602 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import math
|
15 |
+
from typing import Optional
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
from .activations import get_activation
|
22 |
+
|
23 |
+
|
24 |
+
def get_timestep_embedding(
|
25 |
+
timesteps: torch.Tensor,
|
26 |
+
embedding_dim: int,
|
27 |
+
flip_sin_to_cos: bool = False,
|
28 |
+
downscale_freq_shift: float = 1,
|
29 |
+
scale: float = 1,
|
30 |
+
max_period: int = 10000,
|
31 |
+
):
|
32 |
+
"""
|
33 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
34 |
+
|
35 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
36 |
+
These may be fractional.
|
37 |
+
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
38 |
+
embeddings. :return: an [N x dim] Tensor of positional embeddings.
|
39 |
+
"""
|
40 |
+
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
41 |
+
|
42 |
+
half_dim = embedding_dim // 2
|
43 |
+
exponent = -math.log(max_period) * torch.arange(
|
44 |
+
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
45 |
+
)
|
46 |
+
exponent = exponent / (half_dim - downscale_freq_shift)
|
47 |
+
|
48 |
+
emb = torch.exp(exponent)
|
49 |
+
emb = timesteps[:, None].float() * emb[None, :]
|
50 |
+
|
51 |
+
# scale embeddings
|
52 |
+
emb = scale * emb
|
53 |
+
|
54 |
+
# concat sine and cosine embeddings
|
55 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
56 |
+
|
57 |
+
# flip sine and cosine embeddings
|
58 |
+
if flip_sin_to_cos:
|
59 |
+
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
60 |
+
|
61 |
+
# zero pad
|
62 |
+
if embedding_dim % 2 == 1:
|
63 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
64 |
+
return emb
|
65 |
+
|
66 |
+
|
67 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
68 |
+
"""
|
69 |
+
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
|
70 |
+
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
71 |
+
"""
|
72 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
73 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
74 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
75 |
+
grid = np.stack(grid, axis=0)
|
76 |
+
|
77 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
78 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
79 |
+
if cls_token and extra_tokens > 0:
|
80 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
81 |
+
return pos_embed
|
82 |
+
|
83 |
+
|
84 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
85 |
+
if embed_dim % 2 != 0:
|
86 |
+
raise ValueError("embed_dim must be divisible by 2")
|
87 |
+
|
88 |
+
# use half of dimensions to encode grid_h
|
89 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
90 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
91 |
+
|
92 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
93 |
+
return emb
|
94 |
+
|
95 |
+
|
96 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
97 |
+
"""
|
98 |
+
embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
|
99 |
+
"""
|
100 |
+
if embed_dim % 2 != 0:
|
101 |
+
raise ValueError("embed_dim must be divisible by 2")
|
102 |
+
|
103 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
104 |
+
omega /= embed_dim / 2.0
|
105 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
106 |
+
|
107 |
+
pos = pos.reshape(-1) # (M,)
|
108 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
109 |
+
|
110 |
+
emb_sin = np.sin(out) # (M, D/2)
|
111 |
+
emb_cos = np.cos(out) # (M, D/2)
|
112 |
+
|
113 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
114 |
+
return emb
|
115 |
+
|
116 |
+
|
117 |
+
class PatchEmbed(nn.Module):
|
118 |
+
"""2D Image to Patch Embedding"""
|
119 |
+
|
120 |
+
def __init__(
|
121 |
+
self,
|
122 |
+
height=224,
|
123 |
+
width=224,
|
124 |
+
patch_size=16,
|
125 |
+
in_channels=3,
|
126 |
+
embed_dim=768,
|
127 |
+
layer_norm=False,
|
128 |
+
flatten=True,
|
129 |
+
bias=True,
|
130 |
+
):
|
131 |
+
super().__init__()
|
132 |
+
|
133 |
+
num_patches = (height // patch_size) * (width // patch_size)
|
134 |
+
self.flatten = flatten
|
135 |
+
self.layer_norm = layer_norm
|
136 |
+
|
137 |
+
self.proj = nn.Conv2d(
|
138 |
+
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
139 |
+
)
|
140 |
+
if layer_norm:
|
141 |
+
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
|
142 |
+
else:
|
143 |
+
self.norm = None
|
144 |
+
|
145 |
+
pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
|
146 |
+
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
|
147 |
+
|
148 |
+
def forward(self, latent):
|
149 |
+
latent = self.proj(latent)
|
150 |
+
if self.flatten:
|
151 |
+
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
|
152 |
+
if self.layer_norm:
|
153 |
+
latent = self.norm(latent)
|
154 |
+
return latent + self.pos_embed
|
155 |
+
|
156 |
+
|
157 |
+
class TimestepEmbedding(nn.Module):
|
158 |
+
def __init__(
|
159 |
+
self,
|
160 |
+
in_channels: int,
|
161 |
+
time_embed_dim: int,
|
162 |
+
act_fn: str = "silu",
|
163 |
+
out_dim: int = None,
|
164 |
+
post_act_fn: Optional[str] = None,
|
165 |
+
cond_proj_dim=None,
|
166 |
+
):
|
167 |
+
super().__init__()
|
168 |
+
|
169 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
170 |
+
|
171 |
+
if cond_proj_dim is not None:
|
172 |
+
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
173 |
+
else:
|
174 |
+
self.cond_proj = None
|
175 |
+
|
176 |
+
self.act = get_activation(act_fn)
|
177 |
+
|
178 |
+
if out_dim is not None:
|
179 |
+
time_embed_dim_out = out_dim
|
180 |
+
else:
|
181 |
+
time_embed_dim_out = time_embed_dim
|
182 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
183 |
+
|
184 |
+
if post_act_fn is None:
|
185 |
+
self.post_act = None
|
186 |
+
else:
|
187 |
+
self.post_act = get_activation(post_act_fn)
|
188 |
+
|
189 |
+
def forward(self, sample, condition=None):
|
190 |
+
if condition is not None:
|
191 |
+
sample = sample + self.cond_proj(condition)
|
192 |
+
sample = self.linear_1(sample)
|
193 |
+
|
194 |
+
if self.act is not None:
|
195 |
+
sample = self.act(sample)
|
196 |
+
|
197 |
+
sample = self.linear_2(sample)
|
198 |
+
|
199 |
+
if self.post_act is not None:
|
200 |
+
sample = self.post_act(sample)
|
201 |
+
return sample
|
202 |
+
|
203 |
+
|
204 |
+
class Timesteps(nn.Module):
|
205 |
+
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
|
206 |
+
super().__init__()
|
207 |
+
self.num_channels = num_channels
|
208 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
209 |
+
self.downscale_freq_shift = downscale_freq_shift
|
210 |
+
|
211 |
+
def forward(self, timesteps):
|
212 |
+
t_emb = get_timestep_embedding(
|
213 |
+
timesteps,
|
214 |
+
self.num_channels,
|
215 |
+
flip_sin_to_cos=self.flip_sin_to_cos,
|
216 |
+
downscale_freq_shift=self.downscale_freq_shift,
|
217 |
+
)
|
218 |
+
return t_emb
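def _example_timestep_conditioning():
    # Illustrative usage only (not part of the original file): the usual pairing of the
    # sinusoidal Timesteps projection with the learned TimestepEmbedding MLP defined above.
    time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
    time_embed = TimestepEmbedding(in_channels=256, time_embed_dim=1024)
    t = torch.randint(0, 1000, (4,))  # one diffusion timestep per sample
    t_emb = time_embed(time_proj(t).to(torch.float32))
    assert t_emb.shape == (4, 1024)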
|
219 |
+
|
220 |
+
|
221 |
+
class GaussianFourierProjection(nn.Module):
|
222 |
+
"""Gaussian Fourier embeddings for noise levels."""
|
223 |
+
|
224 |
+
def __init__(
|
225 |
+
self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
|
226 |
+
):
|
227 |
+
super().__init__()
|
228 |
+
self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
|
229 |
+
self.log = log
|
230 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
231 |
+
|
232 |
+
if set_W_to_weight:
|
233 |
+
# to delete later
|
234 |
+
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
|
235 |
+
|
236 |
+
self.weight = self.W
|
237 |
+
|
238 |
+
def forward(self, x):
|
239 |
+
if self.log:
|
240 |
+
x = torch.log(x)
|
241 |
+
|
242 |
+
x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
|
243 |
+
|
244 |
+
if self.flip_sin_to_cos:
|
245 |
+
out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
|
246 |
+
else:
|
247 |
+
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
|
248 |
+
return out
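def _example_gaussian_fourier_projection():
    # Illustrative usage only (not part of the original file): projects continuous noise
    # levels (sigmas) onto fixed random Fourier features, as used by continuous-time models.
    proj = GaussianFourierProjection(embedding_size=256, scale=16.0)
    sigmas = torch.rand(8) + 1e-3  # strictly positive, since log() is applied by default
    feats = proj(sigmas)
    assert feats.shape == (8, 512)  # sin and cos halves are concatenated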
|
249 |
+
|
250 |
+
|
251 |
+
class ImagePositionalEmbeddings(nn.Module):
|
252 |
+
"""
|
253 |
+
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
|
254 |
+
height and width of the latent space.
|
255 |
+
|
256 |
+
For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
|
257 |
+
|
258 |
+
For VQ-diffusion:
|
259 |
+
|
260 |
+
Output vector embeddings are used as input for the transformer.
|
261 |
+
|
262 |
+
Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
|
263 |
+
|
264 |
+
Args:
|
265 |
+
num_embed (`int`):
|
266 |
+
Number of embeddings for the latent pixels embeddings.
|
267 |
+
height (`int`):
|
268 |
+
Height of the latent image i.e. the number of height embeddings.
|
269 |
+
width (`int`):
|
270 |
+
Width of the latent image i.e. the number of width embeddings.
|
271 |
+
embed_dim (`int`):
|
272 |
+
Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
|
273 |
+
"""
|
274 |
+
|
275 |
+
def __init__(
|
276 |
+
self,
|
277 |
+
num_embed: int,
|
278 |
+
height: int,
|
279 |
+
width: int,
|
280 |
+
embed_dim: int,
|
281 |
+
):
|
282 |
+
super().__init__()
|
283 |
+
|
284 |
+
self.height = height
|
285 |
+
self.width = width
|
286 |
+
self.num_embed = num_embed
|
287 |
+
self.embed_dim = embed_dim
|
288 |
+
|
289 |
+
self.emb = nn.Embedding(self.num_embed, embed_dim)
|
290 |
+
self.height_emb = nn.Embedding(self.height, embed_dim)
|
291 |
+
self.width_emb = nn.Embedding(self.width, embed_dim)
|
292 |
+
|
293 |
+
def forward(self, index):
|
294 |
+
emb = self.emb(index)
|
295 |
+
|
296 |
+
height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
|
297 |
+
|
298 |
+
# 1 x H x D -> 1 x H x 1 x D
|
299 |
+
height_emb = height_emb.unsqueeze(2)
|
300 |
+
|
301 |
+
width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
|
302 |
+
|
303 |
+
# 1 x W x D -> 1 x 1 x W x D
|
304 |
+
width_emb = width_emb.unsqueeze(1)
|
305 |
+
|
306 |
+
pos_emb = height_emb + width_emb
|
307 |
+
|
308 |
+
# 1 x H x W x D -> 1 x L x D
|
309 |
+
pos_emb = pos_emb.view(1, self.height * self.width, -1)
|
310 |
+
|
311 |
+
emb = emb + pos_emb[:, : emb.shape[1], :]
|
312 |
+
|
313 |
+
return emb
|
314 |
+
|
315 |
+
|
316 |
+
class LabelEmbedding(nn.Module):
|
317 |
+
"""
|
318 |
+
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
319 |
+
|
320 |
+
Args:
|
321 |
+
num_classes (`int`): The number of classes.
|
322 |
+
hidden_size (`int`): The size of the vector embeddings.
|
323 |
+
dropout_prob (`float`): The probability of dropping a label.
|
324 |
+
"""
|
325 |
+
|
326 |
+
def __init__(self, num_classes, hidden_size, dropout_prob):
|
327 |
+
super().__init__()
|
328 |
+
use_cfg_embedding = dropout_prob > 0
|
329 |
+
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
|
330 |
+
self.num_classes = num_classes
|
331 |
+
self.dropout_prob = dropout_prob
|
332 |
+
|
333 |
+
def token_drop(self, labels, force_drop_ids=None):
|
334 |
+
"""
|
335 |
+
Drops labels to enable classifier-free guidance.
|
336 |
+
"""
|
337 |
+
if force_drop_ids is None:
|
338 |
+
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
|
339 |
+
else:
|
340 |
+
drop_ids = torch.tensor(force_drop_ids == 1)
|
341 |
+
labels = torch.where(drop_ids, self.num_classes, labels)
|
342 |
+
return labels
|
343 |
+
|
344 |
+
def forward(self, labels: torch.LongTensor, force_drop_ids=None):
|
345 |
+
use_dropout = self.dropout_prob > 0
|
346 |
+
if (self.training and use_dropout) or (force_drop_ids is not None):
|
347 |
+
labels = self.token_drop(labels, force_drop_ids)
|
348 |
+
embeddings = self.embedding_table(labels)
|
349 |
+
return embeddings
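def _example_label_embedding_cfg():
    # Illustrative usage only (not part of the original file): with dropout_prob > 0 the table
    # holds one extra "null" row (index num_classes) used for classifier-free guidance.
    label_emb = LabelEmbedding(num_classes=1000, hidden_size=768, dropout_prob=0.1)
    labels = torch.randint(0, 1000, (4,))
    # force_drop_ids == 1 replaces the corresponding labels with the null class
    dropped = label_emb(labels, force_drop_ids=torch.ones(4, dtype=torch.long))
    assert dropped.shape == (4, 768)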
|
350 |
+
|
351 |
+
|
352 |
+
class TextImageProjection(nn.Module):
|
353 |
+
def __init__(
|
354 |
+
self,
|
355 |
+
text_embed_dim: int = 1024,
|
356 |
+
image_embed_dim: int = 768,
|
357 |
+
cross_attention_dim: int = 768,
|
358 |
+
num_image_text_embeds: int = 10,
|
359 |
+
):
|
360 |
+
super().__init__()
|
361 |
+
|
362 |
+
self.num_image_text_embeds = num_image_text_embeds
|
363 |
+
self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
|
364 |
+
self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
|
365 |
+
|
366 |
+
def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
|
367 |
+
batch_size = text_embeds.shape[0]
|
368 |
+
|
369 |
+
# image
|
370 |
+
image_text_embeds = self.image_embeds(image_embeds)
|
371 |
+
image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
|
372 |
+
|
373 |
+
# text
|
374 |
+
text_embeds = self.text_proj(text_embeds)
|
375 |
+
|
376 |
+
return torch.cat([image_text_embeds, text_embeds], dim=1)
|
377 |
+
|
378 |
+
|
379 |
+
class ImageProjection(nn.Module):
|
380 |
+
def __init__(
|
381 |
+
self,
|
382 |
+
image_embed_dim: int = 768,
|
383 |
+
cross_attention_dim: int = 768,
|
384 |
+
num_image_text_embeds: int = 32,
|
385 |
+
):
|
386 |
+
super().__init__()
|
387 |
+
|
388 |
+
self.num_image_text_embeds = num_image_text_embeds
|
389 |
+
self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
|
390 |
+
self.norm = nn.LayerNorm(cross_attention_dim)
|
391 |
+
|
392 |
+
def forward(self, image_embeds: torch.FloatTensor):
|
393 |
+
batch_size = image_embeds.shape[0]
|
394 |
+
|
395 |
+
# image
|
396 |
+
image_embeds = self.image_embeds(image_embeds)
|
397 |
+
image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
|
398 |
+
image_embeds = self.norm(image_embeds)
|
399 |
+
return image_embeds
|
400 |
+
|
401 |
+
|
402 |
+
class CombinedTimestepLabelEmbeddings(nn.Module):
|
403 |
+
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
|
404 |
+
super().__init__()
|
405 |
+
|
406 |
+
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
407 |
+
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
408 |
+
self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
|
409 |
+
|
410 |
+
def forward(self, timestep, class_labels, hidden_dtype=None):
|
411 |
+
timesteps_proj = self.time_proj(timestep)
|
412 |
+
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
413 |
+
|
414 |
+
class_labels = self.class_embedder(class_labels) # (N, D)
|
415 |
+
|
416 |
+
conditioning = timesteps_emb + class_labels # (N, D)
|
417 |
+
|
418 |
+
return conditioning
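def _example_combined_timestep_label_embeddings():
    # Illustrative usage only (not part of the original file): fuses the timestep embedding
    # and the class-label embedding into a single conditioning vector, DiT-style.
    cond_embed = CombinedTimestepLabelEmbeddings(num_classes=1000, embedding_dim=1152)
    timesteps = torch.randint(0, 1000, (2,))
    class_labels = torch.randint(0, 1000, (2,))
    conditioning = cond_embed(timesteps, class_labels, hidden_dtype=torch.float32)
    assert conditioning.shape == (2, 1152)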
|
419 |
+
|
420 |
+
|
421 |
+
class TextTimeEmbedding(nn.Module):
|
422 |
+
def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
|
423 |
+
super().__init__()
|
424 |
+
self.norm1 = nn.LayerNorm(encoder_dim)
|
425 |
+
self.pool = AttentionPooling(num_heads, encoder_dim)
|
426 |
+
self.proj = nn.Linear(encoder_dim, time_embed_dim)
|
427 |
+
self.norm2 = nn.LayerNorm(time_embed_dim)
|
428 |
+
|
429 |
+
def forward(self, hidden_states):
|
430 |
+
hidden_states = self.norm1(hidden_states)
|
431 |
+
hidden_states = self.pool(hidden_states)
|
432 |
+
hidden_states = self.proj(hidden_states)
|
433 |
+
hidden_states = self.norm2(hidden_states)
|
434 |
+
return hidden_states
|
435 |
+
|
436 |
+
|
437 |
+
class TextImageTimeEmbedding(nn.Module):
|
438 |
+
def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
|
439 |
+
super().__init__()
|
440 |
+
self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
|
441 |
+
self.text_norm = nn.LayerNorm(time_embed_dim)
|
442 |
+
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
|
443 |
+
|
444 |
+
def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
|
445 |
+
# text
|
446 |
+
time_text_embeds = self.text_proj(text_embeds)
|
447 |
+
time_text_embeds = self.text_norm(time_text_embeds)
|
448 |
+
|
449 |
+
# image
|
450 |
+
time_image_embeds = self.image_proj(image_embeds)
|
451 |
+
|
452 |
+
return time_image_embeds + time_text_embeds
|
453 |
+
|
454 |
+
|
455 |
+
class ImageTimeEmbedding(nn.Module):
|
456 |
+
def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
|
457 |
+
super().__init__()
|
458 |
+
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
|
459 |
+
self.image_norm = nn.LayerNorm(time_embed_dim)
|
460 |
+
|
461 |
+
def forward(self, image_embeds: torch.FloatTensor):
|
462 |
+
# image
|
463 |
+
time_image_embeds = self.image_proj(image_embeds)
|
464 |
+
time_image_embeds = self.image_norm(time_image_embeds)
|
465 |
+
return time_image_embeds
|
466 |
+
|
467 |
+
|
468 |
+
class ImageHintTimeEmbedding(nn.Module):
|
469 |
+
def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
|
470 |
+
super().__init__()
|
471 |
+
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
|
472 |
+
self.image_norm = nn.LayerNorm(time_embed_dim)
|
473 |
+
self.input_hint_block = nn.Sequential(
|
474 |
+
nn.Conv2d(3, 16, 3, padding=1),
|
475 |
+
nn.SiLU(),
|
476 |
+
nn.Conv2d(16, 16, 3, padding=1),
|
477 |
+
nn.SiLU(),
|
478 |
+
nn.Conv2d(16, 32, 3, padding=1, stride=2),
|
479 |
+
nn.SiLU(),
|
480 |
+
nn.Conv2d(32, 32, 3, padding=1),
|
481 |
+
nn.SiLU(),
|
482 |
+
nn.Conv2d(32, 96, 3, padding=1, stride=2),
|
483 |
+
nn.SiLU(),
|
484 |
+
nn.Conv2d(96, 96, 3, padding=1),
|
485 |
+
nn.SiLU(),
|
486 |
+
nn.Conv2d(96, 256, 3, padding=1, stride=2),
|
487 |
+
nn.SiLU(),
|
488 |
+
nn.Conv2d(256, 4, 3, padding=1),
|
489 |
+
)
|
490 |
+
|
491 |
+
def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
|
492 |
+
# image
|
493 |
+
time_image_embeds = self.image_proj(image_embeds)
|
494 |
+
time_image_embeds = self.image_norm(time_image_embeds)
|
495 |
+
hint = self.input_hint_block(hint)
|
496 |
+
return time_image_embeds, hint
|
497 |
+
|
498 |
+
|
499 |
+
class AttentionPooling(nn.Module):
|
500 |
+
# Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
|
501 |
+
|
502 |
+
def __init__(self, num_heads, embed_dim, dtype=None):
|
503 |
+
super().__init__()
|
504 |
+
self.dtype = dtype
|
505 |
+
self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
|
506 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
|
507 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
|
508 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
|
509 |
+
self.num_heads = num_heads
|
510 |
+
self.dim_per_head = embed_dim // self.num_heads
|
511 |
+
|
512 |
+
def forward(self, x):
|
513 |
+
bs, length, width = x.size()
|
514 |
+
|
515 |
+
def shape(x):
|
516 |
+
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
517 |
+
x = x.view(bs, -1, self.num_heads, self.dim_per_head)
|
518 |
+
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
519 |
+
x = x.transpose(1, 2)
|
520 |
+
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
521 |
+
x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
|
522 |
+
# (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
|
523 |
+
x = x.transpose(1, 2)
|
524 |
+
return x
|
525 |
+
|
526 |
+
class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
|
527 |
+
x = torch.cat([class_token, x], dim=1) # (bs, length+1, width)
|
528 |
+
|
529 |
+
# (bs*n_heads, class_token_length, dim_per_head)
|
530 |
+
q = shape(self.q_proj(class_token))
|
531 |
+
# (bs*n_heads, length+class_token_length, dim_per_head)
|
532 |
+
k = shape(self.k_proj(x))
|
533 |
+
v = shape(self.v_proj(x))
|
534 |
+
|
535 |
+
# (bs*n_heads, class_token_length, length+class_token_length):
|
536 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
|
537 |
+
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
|
538 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
539 |
+
|
540 |
+
# (bs*n_heads, dim_per_head, class_token_length)
|
541 |
+
a = torch.einsum("bts,bcs->bct", weight, v)
|
542 |
+
|
543 |
+
# (bs, length+1, width)
|
544 |
+
a = a.reshape(bs, -1, 1).transpose(1, 2)
|
545 |
+
|
546 |
+
return a[:, 0, :] # cls_token
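def _example_attention_pooling():
    # Illustrative usage only (not part of the original file): pools a sequence of encoder
    # states into a single vector by attending from a mean-initialised class token.
    pool = AttentionPooling(num_heads=8, embed_dim=768)
    hidden_states = torch.randn(2, 77, 768)  # e.g. a text-encoder output
    pooled = pool(hidden_states)
    assert pooled.shape == (2, 768)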
|
547 |
+
|
548 |
+
|
549 |
+
class FourierEmbedder(nn.Module):
|
550 |
+
def __init__(self, num_freqs=64, temperature=100):
|
551 |
+
super().__init__()
|
552 |
+
|
553 |
+
self.num_freqs = num_freqs
|
554 |
+
self.temperature = temperature
|
555 |
+
|
556 |
+
freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
|
557 |
+
freq_bands = freq_bands[None, None, None]
|
558 |
+
self.register_buffer("freq_bands", freq_bands, persistent=False)
|
559 |
+
|
560 |
+
def __call__(self, x):
|
561 |
+
x = self.freq_bands * x.unsqueeze(-1)
|
562 |
+
return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
|
563 |
+
|
564 |
+
|
565 |
+
class PositionNet(nn.Module):
|
566 |
+
def __init__(self, positive_len, out_dim, fourier_freqs=8):
|
567 |
+
super().__init__()
|
568 |
+
self.positive_len = positive_len
|
569 |
+
self.out_dim = out_dim
|
570 |
+
|
571 |
+
self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
|
572 |
+
self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
|
573 |
+
|
574 |
+
if isinstance(out_dim, tuple):
|
575 |
+
out_dim = out_dim[0]
|
576 |
+
self.linears = nn.Sequential(
|
577 |
+
nn.Linear(self.positive_len + self.position_dim, 512),
|
578 |
+
nn.SiLU(),
|
579 |
+
nn.Linear(512, 512),
|
580 |
+
nn.SiLU(),
|
581 |
+
nn.Linear(512, out_dim),
|
582 |
+
)
|
583 |
+
|
584 |
+
self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
|
585 |
+
self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
|
586 |
+
|
587 |
+
def forward(self, boxes, masks, positive_embeddings):
|
588 |
+
masks = masks.unsqueeze(-1)
|
589 |
+
|
590 |
+
# embed the box positions (they may include padding used as a placeholder)
|
591 |
+
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 -> B*N*C
|
592 |
+
|
593 |
+
# learnable null embedding
|
594 |
+
positive_null = self.null_positive_feature.view(1, 1, -1)
|
595 |
+
xyxy_null = self.null_position_feature.view(1, 1, -1)
|
596 |
+
|
597 |
+
# replace padding with learnable null embedding
|
598 |
+
positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
|
599 |
+
xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
|
600 |
+
|
601 |
+
objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
|
602 |
+
return objs
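def _example_position_net():
    # Illustrative usage only (not part of the original file): turns grounding inputs
    # (boxes plus phrase embeddings) into conditioning tokens. positive_len is assumed to
    # match the text-encoder width; masks flag which boxes are real and which are padding.
    position_net = PositionNet(positive_len=768, out_dim=768)
    boxes = torch.rand(2, 5, 4)  # (batch, num_boxes, xyxy)
    masks = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]], dtype=torch.float32)
    phrase_embeds = torch.randn(2, 5, 768)
    objs = position_net(boxes, masks, phrase_embeds)
    assert objs.shape == (2, 5, 768)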
|
Tiger Model/diffusiers-Tiger/models/lora.py
ADDED
@@ -0,0 +1,117 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch.nn.functional as F
from torch import nn


class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
        super().__init__()

        self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
        self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        return up_hidden_states.to(orig_dtype)


class LoRAConv2dLayer(nn.Module):
    def __init__(
        self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None
    ):
        super().__init__()

        self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        # according to the official kohya_ss trainer, kernel_size is always fixed for the up layer
        # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
        self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)

        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        return up_hidden_states.to(orig_dtype)


class LoRACompatibleConv(nn.Conv2d):
    """
    A convolutional layer that can be used with LoRA.
    """

    def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer

    def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
        self.lora_layer = lora_layer

    def forward(self, x):
        if self.lora_layer is None:
            # make sure to use the functional conv2d call, as otherwise torch.compile's graph will break
            # see: https://github.com/huggingface/diffusers/pull/4315
            return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        else:
            return super().forward(x) + self.lora_layer(x)


class LoRACompatibleLinear(nn.Linear):
    """
    A Linear layer that can be used with LoRA.
    """

    def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer

    def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
        self.lora_layer = lora_layer

    def forward(self, x):
        if self.lora_layer is None:
            return super().forward(x)
        else:
            return super().forward(x) + self.lora_layer(x)
Tiger Model/diffusiers-Tiger/models/modeling_utils.py
ADDED
@@ -0,0 +1,997 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import inspect
|
18 |
+
import itertools
|
19 |
+
import os
|
20 |
+
import re
|
21 |
+
from functools import partial
|
22 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
23 |
+
|
24 |
+
import safetensors
|
25 |
+
import torch
|
26 |
+
from huggingface_hub import create_repo
|
27 |
+
from torch import Tensor, device, nn
|
28 |
+
|
29 |
+
from .. import __version__
|
30 |
+
from ..utils import (
|
31 |
+
CONFIG_NAME,
|
32 |
+
DIFFUSERS_CACHE,
|
33 |
+
FLAX_WEIGHTS_NAME,
|
34 |
+
HF_HUB_OFFLINE,
|
35 |
+
SAFETENSORS_WEIGHTS_NAME,
|
36 |
+
WEIGHTS_NAME,
|
37 |
+
_add_variant,
|
38 |
+
_get_model_file,
|
39 |
+
deprecate,
|
40 |
+
is_accelerate_available,
|
41 |
+
is_torch_version,
|
42 |
+
logging,
|
43 |
+
)
|
44 |
+
from ..utils.hub_utils import PushToHubMixin
|
45 |
+
|
46 |
+
|
47 |
+
logger = logging.get_logger(__name__)
|
48 |
+
|
49 |
+
|
50 |
+
if is_torch_version(">=", "1.9.0"):
|
51 |
+
_LOW_CPU_MEM_USAGE_DEFAULT = True
|
52 |
+
else:
|
53 |
+
_LOW_CPU_MEM_USAGE_DEFAULT = False
|
54 |
+
|
55 |
+
|
56 |
+
if is_accelerate_available():
|
57 |
+
import accelerate
|
58 |
+
from accelerate.utils import set_module_tensor_to_device
|
59 |
+
from accelerate.utils.versions import is_torch_version
|
60 |
+
|
61 |
+
|
62 |
+
def get_parameter_device(parameter: torch.nn.Module):
|
63 |
+
try:
|
64 |
+
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
|
65 |
+
return next(parameters_and_buffers).device
|
66 |
+
except StopIteration:
|
67 |
+
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
68 |
+
|
69 |
+
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
70 |
+
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
71 |
+
return tuples
|
72 |
+
|
73 |
+
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
74 |
+
first_tuple = next(gen)
|
75 |
+
return first_tuple[1].device
|
76 |
+
|
77 |
+
|
78 |
+
def get_parameter_dtype(parameter: torch.nn.Module):
|
79 |
+
try:
|
80 |
+
params = tuple(parameter.parameters())
|
81 |
+
if len(params) > 0:
|
82 |
+
return params[0].dtype
|
83 |
+
|
84 |
+
buffers = tuple(parameter.buffers())
|
85 |
+
if len(buffers) > 0:
|
86 |
+
return buffers[0].dtype
|
87 |
+
|
88 |
+
except StopIteration:
|
89 |
+
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
90 |
+
|
91 |
+
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
92 |
+
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
93 |
+
return tuples
|
94 |
+
|
95 |
+
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
96 |
+
first_tuple = next(gen)
|
97 |
+
return first_tuple[1].dtype
|
98 |
+
|
99 |
+
|
100 |
+
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
|
101 |
+
"""
|
102 |
+
Reads a checkpoint file, returning properly formatted errors if they arise.
|
103 |
+
"""
|
104 |
+
try:
|
105 |
+
if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
|
106 |
+
return torch.load(checkpoint_file, map_location="cpu")
|
107 |
+
else:
|
108 |
+
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
109 |
+
except Exception as e:
|
110 |
+
try:
|
111 |
+
with open(checkpoint_file) as f:
|
112 |
+
if f.read().startswith("version"):
|
113 |
+
raise OSError(
|
114 |
+
"You seem to have cloned a repository without having git-lfs installed. Please install "
|
115 |
+
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
|
116 |
+
"you cloned."
|
117 |
+
)
|
118 |
+
else:
|
119 |
+
raise ValueError(
|
120 |
+
f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
|
121 |
+
"model. Make sure you have saved the model properly."
|
122 |
+
) from e
|
123 |
+
except (UnicodeDecodeError, ValueError):
|
124 |
+
raise OSError(
|
125 |
+
f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
|
126 |
+
f"at '{checkpoint_file}'. "
|
127 |
+
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
|
128 |
+
)
|
129 |
+
|
130 |
+
|
131 |
+
def _load_state_dict_into_model(model_to_load, state_dict):
|
132 |
+
# Convert old format to new format if needed from a PyTorch state_dict
|
133 |
+
# copy state_dict so _load_from_state_dict can modify it
|
134 |
+
state_dict = state_dict.copy()
|
135 |
+
error_msgs = []
|
136 |
+
|
137 |
+
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
138 |
+
# so we need to apply the function recursively.
|
139 |
+
def load(module: torch.nn.Module, prefix=""):
|
140 |
+
args = (state_dict, prefix, {}, True, [], [], error_msgs)
|
141 |
+
module._load_from_state_dict(*args)
|
142 |
+
|
143 |
+
for name, child in module._modules.items():
|
144 |
+
if child is not None:
|
145 |
+
load(child, prefix + name + ".")
|
146 |
+
|
147 |
+
load(model_to_load)
|
148 |
+
|
149 |
+
return error_msgs
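def _example_manual_weight_loading(model: torch.nn.Module, checkpoint_file: str):
    # Illustrative usage only (not part of the original file): a minimal sketch of the manual
    # loading path used below when accelerate is unavailable - read the checkpoint with
    # load_state_dict() and copy it into an instantiated module, collecting any errors.
    state_dict = load_state_dict(checkpoint_file)
    error_msgs = _load_state_dict_into_model(model, state_dict)
    if error_msgs:
        raise RuntimeError("Error(s) while loading state dict:\n" + "\n".join(error_msgs))
    return model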
|
150 |
+
|
151 |
+
|
152 |
+
class ModelMixin(torch.nn.Module, PushToHubMixin):
|
153 |
+
r"""
|
154 |
+
Base class for all models.
|
155 |
+
|
156 |
+
[`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
|
157 |
+
saving models.
|
158 |
+
|
159 |
+
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
|
160 |
+
"""
|
161 |
+
config_name = CONFIG_NAME
|
162 |
+
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
|
163 |
+
_supports_gradient_checkpointing = False
|
164 |
+
_keys_to_ignore_on_load_unexpected = None
|
165 |
+
|
166 |
+
def __init__(self):
|
167 |
+
super().__init__()
|
168 |
+
|
169 |
+
def __getattr__(self, name: str) -> Any:
|
170 |
+
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
|
171 |
+
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
|
172 |
+
__getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
|
173 |
+
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
174 |
+
"""
|
175 |
+
|
176 |
+
is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
|
177 |
+
is_attribute = name in self.__dict__
|
178 |
+
|
179 |
+
if is_in_config and not is_attribute:
|
180 |
+
deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
|
181 |
+
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
|
182 |
+
return self._internal_dict[name]
|
183 |
+
|
184 |
+
# call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
185 |
+
return super().__getattr__(name)
|
186 |
+
|
187 |
+
@property
|
188 |
+
def is_gradient_checkpointing(self) -> bool:
|
189 |
+
"""
|
190 |
+
Whether gradient checkpointing is activated for this model or not.
|
191 |
+
"""
|
192 |
+
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
|
193 |
+
|
194 |
+
def enable_gradient_checkpointing(self):
|
195 |
+
"""
|
196 |
+
Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
|
197 |
+
*checkpoint activations* in other frameworks).
|
198 |
+
"""
|
199 |
+
if not self._supports_gradient_checkpointing:
|
200 |
+
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
201 |
+
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
202 |
+
|
203 |
+
def disable_gradient_checkpointing(self):
|
204 |
+
"""
|
205 |
+
Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
|
206 |
+
*checkpoint activations* in other frameworks).
|
207 |
+
"""
|
208 |
+
if self._supports_gradient_checkpointing:
|
209 |
+
self.apply(partial(self._set_gradient_checkpointing, value=False))
|
210 |
+
|
211 |
+
def set_use_memory_efficient_attention_xformers(
|
212 |
+
self, valid: bool, attention_op: Optional[Callable] = None
|
213 |
+
) -> None:
|
214 |
+
# Recursively walk through all the children.
|
215 |
+
# Any children which exposes the set_use_memory_efficient_attention_xformers method
|
216 |
+
# gets the message
|
217 |
+
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
218 |
+
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
219 |
+
module.set_use_memory_efficient_attention_xformers(valid, attention_op)
|
220 |
+
|
221 |
+
for child in module.children():
|
222 |
+
fn_recursive_set_mem_eff(child)
|
223 |
+
|
224 |
+
for module in self.children():
|
225 |
+
if isinstance(module, torch.nn.Module):
|
226 |
+
fn_recursive_set_mem_eff(module)
|
227 |
+
|
228 |
+
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
229 |
+
r"""
|
230 |
+
Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
|
231 |
+
|
232 |
+
When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
|
233 |
+
inference. Speed up during training is not guaranteed.
|
234 |
+
|
235 |
+
<Tip warning={true}>
|
236 |
+
|
237 |
+
⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
|
238 |
+
precedent.
|
239 |
+
|
240 |
+
</Tip>
|
241 |
+
|
242 |
+
Parameters:
|
243 |
+
attention_op (`Callable`, *optional*):
|
244 |
+
Override the default `None` operator for use as `op` argument to the
|
245 |
+
[`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
|
246 |
+
function of xFormers.
|
247 |
+
|
248 |
+
Examples:
|
249 |
+
|
250 |
+
```py
|
251 |
+
>>> import torch
|
252 |
+
>>> from diffusers import UNet2DConditionModel
|
253 |
+
>>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
|
254 |
+
|
255 |
+
>>> model = UNet2DConditionModel.from_pretrained(
|
256 |
+
... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
|
257 |
+
... )
|
258 |
+
>>> model = model.to("cuda")
|
259 |
+
>>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
|
260 |
+
```
|
261 |
+
"""
|
262 |
+
self.set_use_memory_efficient_attention_xformers(True, attention_op)
|
263 |
+
|
264 |
+
def disable_xformers_memory_efficient_attention(self):
|
265 |
+
r"""
|
266 |
+
Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
|
267 |
+
"""
|
268 |
+
self.set_use_memory_efficient_attention_xformers(False)
|
269 |
+
|
270 |
+
def save_pretrained(
|
271 |
+
self,
|
272 |
+
save_directory: Union[str, os.PathLike],
|
273 |
+
is_main_process: bool = True,
|
274 |
+
save_function: Callable = None,
|
275 |
+
safe_serialization: bool = True,
|
276 |
+
variant: Optional[str] = None,
|
277 |
+
push_to_hub: bool = False,
|
278 |
+
**kwargs,
|
279 |
+
):
|
280 |
+
"""
|
281 |
+
Save a model and its configuration file to a directory so that it can be reloaded using the
|
282 |
+
[`~models.ModelMixin.from_pretrained`] class method.
|
283 |
+
|
284 |
+
Arguments:
|
285 |
+
save_directory (`str` or `os.PathLike`):
|
286 |
+
Directory to save a model and its configuration file to. Will be created if it doesn't exist.
|
287 |
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
288 |
+
Whether the process calling this is the main process or not. Useful during distributed training and you
|
289 |
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
290 |
+
process to avoid race conditions.
|
291 |
+
save_function (`Callable`):
|
292 |
+
The function to use to save the state dictionary. Useful during distributed training when you need to
|
293 |
+
replace `torch.save` with another method. Can be configured with the environment variable
|
294 |
+
`DIFFUSERS_SAVE_MODE`.
|
295 |
+
safe_serialization (`bool`, *optional*, defaults to `True`):
|
296 |
+
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
297 |
+
variant (`str`, *optional*):
|
298 |
+
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
|
299 |
+
push_to_hub (`bool`, *optional*, defaults to `False`):
|
300 |
+
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
|
301 |
+
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
302 |
+
namespace).
|
303 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
304 |
+
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
305 |
+
"""
|
306 |
+
if os.path.isfile(save_directory):
|
307 |
+
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
308 |
+
return
|
309 |
+
|
310 |
+
os.makedirs(save_directory, exist_ok=True)
|
311 |
+
|
312 |
+
if push_to_hub:
|
313 |
+
commit_message = kwargs.pop("commit_message", None)
|
314 |
+
private = kwargs.pop("private", False)
|
315 |
+
create_pr = kwargs.pop("create_pr", False)
|
316 |
+
token = kwargs.pop("token", None)
|
317 |
+
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
318 |
+
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
|
319 |
+
|
320 |
+
# Only save the model itself if we are using distributed training
|
321 |
+
model_to_save = self
|
322 |
+
|
323 |
+
# Attach architecture to the config
|
324 |
+
# Save the config
|
325 |
+
if is_main_process:
|
326 |
+
model_to_save.save_config(save_directory)
|
327 |
+
|
328 |
+
# Save the model
|
329 |
+
state_dict = model_to_save.state_dict()
|
330 |
+
|
331 |
+
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
332 |
+
weights_name = _add_variant(weights_name, variant)
|
333 |
+
|
334 |
+
# Save the model
|
335 |
+
if safe_serialization:
|
336 |
+
safetensors.torch.save_file(
|
337 |
+
state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
|
338 |
+
)
|
339 |
+
else:
|
340 |
+
torch.save(state_dict, os.path.join(save_directory, weights_name))
|
341 |
+
|
342 |
+
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
|
343 |
+
|
344 |
+
if push_to_hub:
|
345 |
+
self._upload_folder(
|
346 |
+
save_directory,
|
347 |
+
repo_id,
|
348 |
+
token=token,
|
349 |
+
commit_message=commit_message,
|
350 |
+
create_pr=create_pr,
|
351 |
+
)
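# Illustrative usage only (not part of the original file): any ModelMixin subclass can
# round-trip its weights and config with save_pretrained / from_pretrained, e.g.
#
#     model.save_pretrained("./my-model", safe_serialization=True)
#     reloaded = MyModelClass.from_pretrained("./my-model")
#
# where MyModelClass stands for a hypothetical subclass of ModelMixin defined elsewhere.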
|
352 |
+
|
353 |
+
@classmethod
|
354 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
355 |
+
r"""
|
356 |
+
Instantiate a pretrained PyTorch model from a pretrained model configuration.
|
357 |
+
|
358 |
+
The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
|
359 |
+
train the model, set it back in training mode with `model.train()`.
|
360 |
+
|
361 |
+
Parameters:
|
362 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
363 |
+
Can be either:
|
364 |
+
|
365 |
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
366 |
+
the Hub.
|
367 |
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
368 |
+
with [`~ModelMixin.save_pretrained`].
|
369 |
+
|
370 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
371 |
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
372 |
+
is not used.
|
373 |
+
torch_dtype (`str` or `torch.dtype`, *optional*):
|
374 |
+
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
375 |
+
dtype is automatically derived from the model's weights.
|
376 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
377 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
378 |
+
cached versions if they exist.
|
379 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
380 |
+
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
381 |
+
incompletely downloaded files are deleted.
|
382 |
+
proxies (`Dict[str, str]`, *optional*):
|
383 |
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
384 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
385 |
+
output_loading_info (`bool`, *optional*, defaults to `False`):
|
386 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
387 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
388 |
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
389 |
+
won't be downloaded from the Hub.
|
390 |
+
use_auth_token (`str` or *bool*, *optional*):
|
391 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
392 |
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
393 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
394 |
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
395 |
+
allowed by Git.
|
396 |
+
from_flax (`bool`, *optional*, defaults to `False`):
|
397 |
+
Load the model weights from a Flax checkpoint save file.
|
398 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
399 |
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
400 |
+
mirror (`str`, *optional*):
|
401 |
+
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
|
402 |
+
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
403 |
+
information.
|
404 |
+
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
405 |
+
A map that specifies where each submodule should go. It doesn't need to be defined for each
|
406 |
+
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
|
407 |
+
same device.
|
408 |
+
|
409 |
+
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
|
410 |
+
more information about each option see [designing a device
|
411 |
+
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
412 |
+
max_memory (`Dict`, *optional*):
|
413 |
+
A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
|
414 |
+
each GPU and the available CPU RAM if unset.
|
415 |
+
offload_folder (`str` or `os.PathLike`, *optional*):
|
416 |
+
The path to offload weights if `device_map` contains the value `"disk"`.
|
417 |
+
offload_state_dict (`bool`, *optional*):
|
418 |
+
If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
|
419 |
+
the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
|
420 |
+
when there is some disk offload.
|
421 |
+
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
422 |
+
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
423 |
+
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
424 |
+
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
425 |
+
argument to `True` will raise an error.
|
426 |
+
variant (`str`, *optional*):
|
427 |
+
Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
|
428 |
+
loading `from_flax`.
|
429 |
+
use_safetensors (`bool`, *optional*, defaults to `None`):
|
430 |
+
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
|
431 |
+
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
|
432 |
+
weights. If set to `False`, `safetensors` weights are not loaded.
|
433 |
+
|
434 |
+
<Tip>
|
435 |
+
|
436 |
+
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
|
437 |
+
`huggingface-cli login`. You can also activate the special
|
438 |
+
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
|
439 |
+
firewalled environment.
|
440 |
+
|
441 |
+
</Tip>
|
442 |
+
|
443 |
+
Example:
|
444 |
+
|
445 |
+
```py
|
446 |
+
from diffusers import UNet2DConditionModel
|
447 |
+
|
448 |
+
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
|
449 |
+
```
|
450 |
+
|
451 |
+
If you get the error message below, you need to finetune the weights for your downstream task:
|
452 |
+
|
453 |
+
```bash
|
454 |
+
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
|
455 |
+
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
|
456 |
+
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
457 |
+
```
|
458 |
+
"""
|
459 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
460 |
+
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
461 |
+
force_download = kwargs.pop("force_download", False)
|
462 |
+
from_flax = kwargs.pop("from_flax", False)
|
463 |
+
resume_download = kwargs.pop("resume_download", False)
|
464 |
+
proxies = kwargs.pop("proxies", None)
|
465 |
+
output_loading_info = kwargs.pop("output_loading_info", False)
|
466 |
+
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
467 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
468 |
+
revision = kwargs.pop("revision", None)
|
469 |
+
torch_dtype = kwargs.pop("torch_dtype", None)
|
470 |
+
subfolder = kwargs.pop("subfolder", None)
|
471 |
+
device_map = kwargs.pop("device_map", None)
|
472 |
+
max_memory = kwargs.pop("max_memory", None)
|
473 |
+
offload_folder = kwargs.pop("offload_folder", None)
|
474 |
+
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
475 |
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
476 |
+
variant = kwargs.pop("variant", None)
|
477 |
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
478 |
+
|
479 |
+
allow_pickle = False
|
480 |
+
if use_safetensors is None:
|
481 |
+
use_safetensors = True
|
482 |
+
allow_pickle = True
|
483 |
+
|
484 |
+
if low_cpu_mem_usage and not is_accelerate_available():
|
485 |
+
low_cpu_mem_usage = False
|
486 |
+
logger.warning(
|
487 |
+
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
488 |
+
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
489 |
+
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
490 |
+
" install accelerate\n```\n."
|
491 |
+
)
|
492 |
+
|
493 |
+
if device_map is not None and not is_accelerate_available():
|
494 |
+
raise NotImplementedError(
|
495 |
+
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
|
496 |
+
" `device_map=None`. You can install accelerate with `pip install accelerate`."
|
497 |
+
)
|
498 |
+
|
499 |
+
# Check if we can handle device_map and dispatching the weights
|
500 |
+
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
501 |
+
raise NotImplementedError(
|
502 |
+
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
503 |
+
" `device_map=None`."
|
504 |
+
)
|
505 |
+
|
506 |
+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
507 |
+
raise NotImplementedError(
|
508 |
+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
509 |
+
" `low_cpu_mem_usage=False`."
|
510 |
+
)
|
511 |
+
|
512 |
+
if low_cpu_mem_usage is False and device_map is not None:
|
513 |
+
raise ValueError(
|
514 |
+
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
|
515 |
+
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
516 |
+
)
|
517 |
+
|
518 |
+
# Load config if we don't provide a configuration
|
519 |
+
config_path = pretrained_model_name_or_path
|
520 |
+
|
521 |
+
user_agent = {
|
522 |
+
"diffusers": __version__,
|
523 |
+
"file_type": "model",
|
524 |
+
"framework": "pytorch",
|
525 |
+
}
|
526 |
+
|
527 |
+
# load config
|
528 |
+
config, unused_kwargs, commit_hash = cls.load_config(
|
529 |
+
config_path,
|
530 |
+
cache_dir=cache_dir,
|
531 |
+
return_unused_kwargs=True,
|
532 |
+
return_commit_hash=True,
|
533 |
+
force_download=force_download,
|
534 |
+
resume_download=resume_download,
|
535 |
+
proxies=proxies,
|
536 |
+
local_files_only=local_files_only,
|
537 |
+
use_auth_token=use_auth_token,
|
538 |
+
revision=revision,
|
539 |
+
subfolder=subfolder,
|
540 |
+
device_map=device_map,
|
541 |
+
max_memory=max_memory,
|
542 |
+
offload_folder=offload_folder,
|
543 |
+
offload_state_dict=offload_state_dict,
|
544 |
+
user_agent=user_agent,
|
545 |
+
**kwargs,
|
546 |
+
)
|
547 |
+
|
548 |
+
# load model
|
549 |
+
model_file = None
|
550 |
+
if from_flax:
|
551 |
+
model_file = _get_model_file(
|
552 |
+
pretrained_model_name_or_path,
|
553 |
+
weights_name=FLAX_WEIGHTS_NAME,
|
554 |
+
cache_dir=cache_dir,
|
555 |
+
force_download=force_download,
|
556 |
+
resume_download=resume_download,
|
557 |
+
proxies=proxies,
|
558 |
+
local_files_only=local_files_only,
|
559 |
+
use_auth_token=use_auth_token,
|
560 |
+
revision=revision,
|
561 |
+
subfolder=subfolder,
|
562 |
+
user_agent=user_agent,
|
563 |
+
commit_hash=commit_hash,
|
564 |
+
)
|
565 |
+
model = cls.from_config(config, **unused_kwargs)
|
566 |
+
|
567 |
+
# Convert the weights
|
568 |
+
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
|
569 |
+
|
570 |
+
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
|
571 |
+
else:
|
572 |
+
if use_safetensors:
|
573 |
+
try:
|
574 |
+
model_file = _get_model_file(
|
575 |
+
pretrained_model_name_or_path,
|
576 |
+
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
577 |
+
cache_dir=cache_dir,
|
578 |
+
force_download=force_download,
|
579 |
+
resume_download=resume_download,
|
580 |
+
proxies=proxies,
|
581 |
+
local_files_only=local_files_only,
|
582 |
+
use_auth_token=use_auth_token,
|
583 |
+
revision=revision,
|
584 |
+
subfolder=subfolder,
|
585 |
+
user_agent=user_agent,
|
586 |
+
commit_hash=commit_hash,
|
587 |
+
)
|
588 |
+
except IOError as e:
|
589 |
+
if not allow_pickle:
|
590 |
+
raise e
|
591 |
+
pass
|
592 |
+
if model_file is None:
|
593 |
+
model_file = _get_model_file(
|
594 |
+
pretrained_model_name_or_path,
|
595 |
+
weights_name=_add_variant(WEIGHTS_NAME, variant),
|
596 |
+
cache_dir=cache_dir,
|
597 |
+
force_download=force_download,
|
598 |
+
resume_download=resume_download,
|
599 |
+
proxies=proxies,
|
600 |
+
local_files_only=local_files_only,
|
601 |
+
use_auth_token=use_auth_token,
|
602 |
+
revision=revision,
|
603 |
+
subfolder=subfolder,
|
604 |
+
user_agent=user_agent,
|
605 |
+
commit_hash=commit_hash,
|
606 |
+
)
|
607 |
+
|
608 |
+
if low_cpu_mem_usage:
|
609 |
+
# Instantiate model with empty weights
|
610 |
+
with accelerate.init_empty_weights():
|
611 |
+
model = cls.from_config(config, **unused_kwargs)
|
612 |
+
|
613 |
+
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
614 |
+
if device_map is None:
|
615 |
+
param_device = "cpu"
|
616 |
+
state_dict = load_state_dict(model_file, variant=variant)
|
617 |
+
model._convert_deprecated_attention_blocks(state_dict)
|
618 |
+
# move the params from meta device to cpu
|
619 |
+
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
620 |
+
if len(missing_keys) > 0:
|
621 |
+
raise ValueError(
|
622 |
+
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
|
623 |
+
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
|
624 |
+
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
|
625 |
+
" those weights or else make sure your checkpoint file is correct."
|
626 |
+
)
|
627 |
+
unexpected_keys = []
|
628 |
+
|
629 |
+
empty_state_dict = model.state_dict()
|
630 |
+
for param_name, param in state_dict.items():
|
631 |
+
accepts_dtype = "dtype" in set(
|
632 |
+
inspect.signature(set_module_tensor_to_device).parameters.keys()
|
633 |
+
)
|
634 |
+
|
635 |
+
if param_name not in empty_state_dict:
|
636 |
+
unexpected_keys.append(param_name)
|
637 |
+
continue
|
638 |
+
|
639 |
+
if empty_state_dict[param_name].shape != param.shape:
|
640 |
+
raise ValueError(
|
641 |
+
f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
642 |
+
)
|
643 |
+
|
644 |
+
if accepts_dtype:
|
645 |
+
set_module_tensor_to_device(
|
646 |
+
model, param_name, param_device, value=param, dtype=torch_dtype
|
647 |
+
)
|
648 |
+
else:
|
649 |
+
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
650 |
+
|
651 |
+
if cls._keys_to_ignore_on_load_unexpected is not None:
|
652 |
+
for pat in cls._keys_to_ignore_on_load_unexpected:
|
653 |
+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
654 |
+
|
655 |
+
if len(unexpected_keys) > 0:
|
656 |
+
logger.warn(
|
657 |
+
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
658 |
+
)
|
659 |
+
|
660 |
+
else: # else let accelerate handle loading and dispatching.
|
661 |
+
# Load weights and dispatch according to the device_map
|
662 |
+
# by default the device_map is None and the weights are loaded on the CPU
|
663 |
+
try:
|
664 |
+
accelerate.load_checkpoint_and_dispatch(
|
665 |
+
model,
|
666 |
+
model_file,
|
667 |
+
device_map,
|
668 |
+
max_memory=max_memory,
|
669 |
+
offload_folder=offload_folder,
|
670 |
+
offload_state_dict=offload_state_dict,
|
671 |
+
dtype=torch_dtype,
|
672 |
+
)
|
673 |
+
except AttributeError as e:
|
674 |
+
# When using accelerate loading, we do not have the ability to load the state
|
675 |
+
# dict and rename the weight names manually. Additionally, accelerate skips
|
676 |
+
# torch loading conventions and directly writes into `module.{_buffers, _parameters}`
|
677 |
+
# (which look like they should be private variables?), so we can't use the standard hooks
|
678 |
+
# to rename parameters on load. We need to mimic the original weight names so the correct
|
679 |
+
# attributes are available. After we have loaded the weights, we convert the deprecated
|
680 |
+
# names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
|
681 |
+
# the weights so we don't have to do this again.
|
682 |
+
|
683 |
+
if "'Attention' object has no attribute" in str(e):
|
684 |
+
logger.warn(
|
685 |
+
f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
|
686 |
+
" was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
|
687 |
+
" names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
|
688 |
+
" so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
|
689 |
+
" please also re-upload it or open a PR on the original repository."
|
690 |
+
)
|
691 |
+
model._temp_convert_self_to_deprecated_attention_blocks()
|
692 |
+
accelerate.load_checkpoint_and_dispatch(
|
693 |
+
model,
|
694 |
+
model_file,
|
695 |
+
device_map,
|
696 |
+
max_memory=max_memory,
|
697 |
+
offload_folder=offload_folder,
|
698 |
+
offload_state_dict=offload_state_dict,
|
699 |
+
dtype=torch_dtype,
|
700 |
+
)
|
701 |
+
model._undo_temp_convert_self_to_deprecated_attention_blocks()
|
702 |
+
else:
|
703 |
+
raise e
|
704 |
+
|
705 |
+
loading_info = {
|
706 |
+
"missing_keys": [],
|
707 |
+
"unexpected_keys": [],
|
708 |
+
"mismatched_keys": [],
|
709 |
+
"error_msgs": [],
|
710 |
+
}
|
711 |
+
else:
|
712 |
+
model = cls.from_config(config, **unused_kwargs)
|
713 |
+
|
714 |
+
state_dict = load_state_dict(model_file, variant=variant)
|
715 |
+
model._convert_deprecated_attention_blocks(state_dict)
|
716 |
+
|
717 |
+
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
718 |
+
model,
|
719 |
+
state_dict,
|
720 |
+
model_file,
|
721 |
+
pretrained_model_name_or_path,
|
722 |
+
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
723 |
+
)
|
724 |
+
|
725 |
+
loading_info = {
|
726 |
+
"missing_keys": missing_keys,
|
727 |
+
"unexpected_keys": unexpected_keys,
|
728 |
+
"mismatched_keys": mismatched_keys,
|
729 |
+
"error_msgs": error_msgs,
|
730 |
+
}
|
731 |
+
|
732 |
+
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
733 |
+
raise ValueError(
|
734 |
+
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
|
735 |
+
)
|
736 |
+
elif torch_dtype is not None:
|
737 |
+
model = model.to(torch_dtype)
|
738 |
+
|
739 |
+
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
740 |
+
|
741 |
+
# Set model in evaluation mode to deactivate DropOut modules by default
|
742 |
+
model.eval()
|
743 |
+
if output_loading_info:
|
744 |
+
return model, loading_info
|
745 |
+
|
746 |
+
return model
|
747 |
+
|
748 |
+
@classmethod
|
749 |
+
def _load_pretrained_model(
|
750 |
+
cls,
|
751 |
+
model,
|
752 |
+
state_dict,
|
753 |
+
resolved_archive_file,
|
754 |
+
pretrained_model_name_or_path,
|
755 |
+
ignore_mismatched_sizes=False,
|
756 |
+
):
|
757 |
+
# Retrieve missing & unexpected_keys
|
758 |
+
model_state_dict = model.state_dict()
|
759 |
+
loaded_keys = list(state_dict.keys())
|
760 |
+
|
761 |
+
expected_keys = list(model_state_dict.keys())
|
762 |
+
|
763 |
+
original_loaded_keys = loaded_keys
|
764 |
+
|
765 |
+
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
766 |
+
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
767 |
+
|
768 |
+
# Make sure we are able to load base models as well as derived models (with heads)
|
769 |
+
model_to_load = model
|
770 |
+
|
771 |
+
def _find_mismatched_keys(
|
772 |
+
state_dict,
|
773 |
+
model_state_dict,
|
774 |
+
loaded_keys,
|
775 |
+
ignore_mismatched_sizes,
|
776 |
+
):
|
777 |
+
mismatched_keys = []
|
778 |
+
if ignore_mismatched_sizes:
|
779 |
+
for checkpoint_key in loaded_keys:
|
780 |
+
model_key = checkpoint_key
|
781 |
+
|
782 |
+
if (
|
783 |
+
model_key in model_state_dict
|
784 |
+
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
785 |
+
):
|
786 |
+
mismatched_keys.append(
|
787 |
+
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
788 |
+
)
|
789 |
+
del state_dict[checkpoint_key]
|
790 |
+
return mismatched_keys
|
791 |
+
|
792 |
+
if state_dict is not None:
|
793 |
+
# Whole checkpoint
|
794 |
+
mismatched_keys = _find_mismatched_keys(
|
795 |
+
state_dict,
|
796 |
+
model_state_dict,
|
797 |
+
original_loaded_keys,
|
798 |
+
ignore_mismatched_sizes,
|
799 |
+
)
|
800 |
+
error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
|
801 |
+
|
802 |
+
if len(error_msgs) > 0:
|
803 |
+
error_msg = "\n\t".join(error_msgs)
|
804 |
+
if "size mismatch" in error_msg:
|
805 |
+
error_msg += (
|
806 |
+
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
|
807 |
+
)
|
808 |
+
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
|
809 |
+
|
810 |
+
if len(unexpected_keys) > 0:
|
811 |
+
logger.warning(
|
812 |
+
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
813 |
+
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
814 |
+
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
|
815 |
+
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
|
816 |
+
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
|
817 |
+
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
|
818 |
+
" identical (initializing a BertForSequenceClassification model from a"
|
819 |
+
" BertForSequenceClassification model)."
|
820 |
+
)
|
821 |
+
else:
|
822 |
+
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
823 |
+
if len(missing_keys) > 0:
|
824 |
+
logger.warning(
|
825 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
826 |
+
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
|
827 |
+
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
828 |
+
)
|
829 |
+
elif len(mismatched_keys) == 0:
|
830 |
+
logger.info(
|
831 |
+
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
832 |
+
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
|
833 |
+
f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
|
834 |
+
" without further training."
|
835 |
+
)
|
836 |
+
if len(mismatched_keys) > 0:
|
837 |
+
mismatched_warning = "\n".join(
|
838 |
+
[
|
839 |
+
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
|
840 |
+
for key, shape1, shape2 in mismatched_keys
|
841 |
+
]
|
842 |
+
)
|
843 |
+
logger.warning(
|
844 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
845 |
+
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
|
846 |
+
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
|
847 |
+
" able to use it for predictions and inference."
|
848 |
+
)
|
849 |
+
|
850 |
+
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
|
851 |
+
|
852 |
+
@property
|
853 |
+
def device(self) -> device:
|
854 |
+
"""
|
855 |
+
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
|
856 |
+
device).
|
857 |
+
"""
|
858 |
+
return get_parameter_device(self)
|
859 |
+
|
860 |
+
@property
|
861 |
+
def dtype(self) -> torch.dtype:
|
862 |
+
"""
|
863 |
+
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
864 |
+
"""
|
865 |
+
return get_parameter_dtype(self)
|
866 |
+
|
867 |
+
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
|
868 |
+
"""
|
869 |
+
Get number of (trainable or non-embedding) parameters in the module.
|
870 |
+
|
871 |
+
Args:
|
872 |
+
only_trainable (`bool`, *optional*, defaults to `False`):
|
873 |
+
Whether or not to return only the number of trainable parameters.
|
874 |
+
exclude_embeddings (`bool`, *optional*, defaults to `False`):
|
875 |
+
Whether or not to return only the number of non-embedding parameters.
|
876 |
+
|
877 |
+
Returns:
|
878 |
+
`int`: The number of parameters.
|
879 |
+
|
880 |
+
Example:
|
881 |
+
|
882 |
+
```py
|
883 |
+
from diffusers import UNet2DConditionModel
|
884 |
+
|
885 |
+
model_id = "runwayml/stable-diffusion-v1-5"
|
886 |
+
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
|
887 |
+
unet.num_parameters(only_trainable=True)
|
888 |
+
859520964
|
889 |
+
```
|
890 |
+
"""
|
891 |
+
|
892 |
+
if exclude_embeddings:
|
893 |
+
embedding_param_names = [
|
894 |
+
f"{name}.weight"
|
895 |
+
for name, module_type in self.named_modules()
|
896 |
+
if isinstance(module_type, torch.nn.Embedding)
|
897 |
+
]
|
898 |
+
non_embedding_parameters = [
|
899 |
+
parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
|
900 |
+
]
|
901 |
+
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
|
902 |
+
else:
|
903 |
+
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
|
904 |
+
|
905 |
+
def _convert_deprecated_attention_blocks(self, state_dict):
|
906 |
+
deprecated_attention_block_paths = []
|
907 |
+
|
908 |
+
def recursive_find_attn_block(name, module):
|
909 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
910 |
+
deprecated_attention_block_paths.append(name)
|
911 |
+
|
912 |
+
for sub_name, sub_module in module.named_children():
|
913 |
+
sub_name = sub_name if name == "" else f"{name}.{sub_name}"
|
914 |
+
recursive_find_attn_block(sub_name, sub_module)
|
915 |
+
|
916 |
+
recursive_find_attn_block("", self)
|
917 |
+
|
918 |
+
# NOTE: we have to check if the deprecated parameters are in the state dict
|
919 |
+
# because it is possible we are loading from a state dict that was already
|
920 |
+
# converted
|
921 |
+
|
922 |
+
for path in deprecated_attention_block_paths:
|
923 |
+
# group_norm path stays the same
|
924 |
+
|
925 |
+
# query -> to_q
|
926 |
+
if f"{path}.query.weight" in state_dict:
|
927 |
+
state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
|
928 |
+
if f"{path}.query.bias" in state_dict:
|
929 |
+
state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
|
930 |
+
|
931 |
+
# key -> to_k
|
932 |
+
if f"{path}.key.weight" in state_dict:
|
933 |
+
state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
|
934 |
+
if f"{path}.key.bias" in state_dict:
|
935 |
+
state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
|
936 |
+
|
937 |
+
# value -> to_v
|
938 |
+
if f"{path}.value.weight" in state_dict:
|
939 |
+
state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
|
940 |
+
if f"{path}.value.bias" in state_dict:
|
941 |
+
state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
|
942 |
+
|
943 |
+
# proj_attn -> to_out.0
|
944 |
+
if f"{path}.proj_attn.weight" in state_dict:
|
945 |
+
state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
|
946 |
+
if f"{path}.proj_attn.bias" in state_dict:
|
947 |
+
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
|
948 |
+
|
949 |
+
def _temp_convert_self_to_deprecated_attention_blocks(self):
|
950 |
+
deprecated_attention_block_modules = []
|
951 |
+
|
952 |
+
def recursive_find_attn_block(module):
|
953 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
954 |
+
deprecated_attention_block_modules.append(module)
|
955 |
+
|
956 |
+
for sub_module in module.children():
|
957 |
+
recursive_find_attn_block(sub_module)
|
958 |
+
|
959 |
+
recursive_find_attn_block(self)
|
960 |
+
|
961 |
+
for module in deprecated_attention_block_modules:
|
962 |
+
module.query = module.to_q
|
963 |
+
module.key = module.to_k
|
964 |
+
module.value = module.to_v
|
965 |
+
module.proj_attn = module.to_out[0]
|
966 |
+
|
967 |
+
# We don't _have_ to delete the old attributes, but it's helpful to ensure
|
968 |
+
# that _all_ the weights are loaded into the new attributes and we're not
|
969 |
+
# making an incorrect assumption that this model should be converted when
|
970 |
+
# it really shouldn't be.
|
971 |
+
del module.to_q
|
972 |
+
del module.to_k
|
973 |
+
del module.to_v
|
974 |
+
del module.to_out
|
975 |
+
|
976 |
+
def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
|
977 |
+
deprecated_attention_block_modules = []
|
978 |
+
|
979 |
+
def recursive_find_attn_block(module):
|
980 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
981 |
+
deprecated_attention_block_modules.append(module)
|
982 |
+
|
983 |
+
for sub_module in module.children():
|
984 |
+
recursive_find_attn_block(sub_module)
|
985 |
+
|
986 |
+
recursive_find_attn_block(self)
|
987 |
+
|
988 |
+
for module in deprecated_attention_block_modules:
|
989 |
+
module.to_q = module.query
|
990 |
+
module.to_k = module.key
|
991 |
+
module.to_v = module.value
|
992 |
+
module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
|
993 |
+
|
994 |
+
del module.query
|
995 |
+
del module.key
|
996 |
+
del module.value
|
997 |
+
del module.proj_attn
|
Tiger Model/diffusiers-Tiger/models/prior_transformer.py
ADDED
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Dict, Optional, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
9 |
+
from ..utils import BaseOutput
|
10 |
+
from .attention import BasicTransformerBlock
|
11 |
+
from .attention_processor import AttentionProcessor, AttnProcessor
|
12 |
+
from .embeddings import TimestepEmbedding, Timesteps
|
13 |
+
from .modeling_utils import ModelMixin
|
14 |
+
|
15 |
+
|
16 |
+
@dataclass
|
17 |
+
class PriorTransformerOutput(BaseOutput):
|
18 |
+
"""
|
19 |
+
The output of [`PriorTransformer`].
|
20 |
+
|
21 |
+
Args:
|
22 |
+
predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
23 |
+
The predicted CLIP image embedding conditioned on the CLIP text embedding input.
|
24 |
+
"""
|
25 |
+
|
26 |
+
predicted_image_embedding: torch.FloatTensor
|
27 |
+
|
28 |
+
|
29 |
+
class PriorTransformer(ModelMixin, ConfigMixin):
|
30 |
+
"""
|
31 |
+
A Prior Transformer model.
|
32 |
+
|
33 |
+
Parameters:
|
34 |
+
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
|
35 |
+
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
36 |
+
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
|
37 |
+
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
|
38 |
+
num_embeddings (`int`, *optional*, defaults to 77):
|
39 |
+
The number of embeddings of the model input `hidden_states`
|
40 |
+
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
|
41 |
+
projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
|
42 |
+
additional_embeddings`.
|
43 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
44 |
+
time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
|
45 |
+
The activation function to use to create timestep embeddings.
|
46 |
+
norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
|
47 |
+
passing to Transformer blocks. Set it to `None` if normalization is not needed.
|
48 |
+
embedding_proj_norm_type (`str`, *optional*, defaults to None):
|
49 |
+
The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
|
50 |
+
needed.
|
51 |
+
encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
|
52 |
+
The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
|
53 |
+
`encoder_hidden_states` is `None`.
|
54 |
+
added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
|
55 |
+
Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
|
56 |
+
product between the text embedding and image embedding as proposed in the unclip paper
|
57 |
+
https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
|
58 |
+
time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
|
59 |
+
If None, will be set to `num_attention_heads * attention_head_dim`
|
60 |
+
embedding_proj_dim (`int`, *optional*, default to None):
|
61 |
+
The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
|
62 |
+
clip_embed_dim (`int`, *optional*, default to None):
|
63 |
+
The dimension of the output. If None, will be set to `embedding_dim`.
|
64 |
+
"""
|
65 |
+
|
66 |
+
@register_to_config
|
67 |
+
def __init__(
|
68 |
+
self,
|
69 |
+
num_attention_heads: int = 32,
|
70 |
+
attention_head_dim: int = 64,
|
71 |
+
num_layers: int = 20,
|
72 |
+
embedding_dim: int = 768,
|
73 |
+
num_embeddings=77,
|
74 |
+
additional_embeddings=4,
|
75 |
+
dropout: float = 0.0,
|
76 |
+
time_embed_act_fn: str = "silu",
|
77 |
+
norm_in_type: Optional[str] = None, # layer
|
78 |
+
embedding_proj_norm_type: Optional[str] = None, # layer
|
79 |
+
encoder_hid_proj_type: Optional[str] = "linear", # linear
|
80 |
+
added_emb_type: Optional[str] = "prd", # prd
|
81 |
+
time_embed_dim: Optional[int] = None,
|
82 |
+
embedding_proj_dim: Optional[int] = None,
|
83 |
+
clip_embed_dim: Optional[int] = None,
|
84 |
+
):
|
85 |
+
super().__init__()
|
86 |
+
self.num_attention_heads = num_attention_heads
|
87 |
+
self.attention_head_dim = attention_head_dim
|
88 |
+
inner_dim = num_attention_heads * attention_head_dim
|
89 |
+
self.additional_embeddings = additional_embeddings
|
90 |
+
|
91 |
+
time_embed_dim = time_embed_dim or inner_dim
|
92 |
+
embedding_proj_dim = embedding_proj_dim or embedding_dim
|
93 |
+
clip_embed_dim = clip_embed_dim or embedding_dim
|
94 |
+
|
95 |
+
self.time_proj = Timesteps(inner_dim, True, 0)
|
96 |
+
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
|
97 |
+
|
98 |
+
self.proj_in = nn.Linear(embedding_dim, inner_dim)
|
99 |
+
|
100 |
+
if embedding_proj_norm_type is None:
|
101 |
+
self.embedding_proj_norm = None
|
102 |
+
elif embedding_proj_norm_type == "layer":
|
103 |
+
self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
|
104 |
+
else:
|
105 |
+
raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
|
106 |
+
|
107 |
+
self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
|
108 |
+
|
109 |
+
if encoder_hid_proj_type is None:
|
110 |
+
self.encoder_hidden_states_proj = None
|
111 |
+
elif encoder_hid_proj_type == "linear":
|
112 |
+
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
|
113 |
+
else:
|
114 |
+
raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
|
115 |
+
|
116 |
+
self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
|
117 |
+
|
118 |
+
if added_emb_type == "prd":
|
119 |
+
self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
|
120 |
+
elif added_emb_type is None:
|
121 |
+
self.prd_embedding = None
|
122 |
+
else:
|
123 |
+
raise ValueError(
|
124 |
+
f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
|
125 |
+
)
|
126 |
+
|
127 |
+
self.transformer_blocks = nn.ModuleList(
|
128 |
+
[
|
129 |
+
BasicTransformerBlock(
|
130 |
+
inner_dim,
|
131 |
+
num_attention_heads,
|
132 |
+
attention_head_dim,
|
133 |
+
dropout=dropout,
|
134 |
+
activation_fn="gelu",
|
135 |
+
attention_bias=True,
|
136 |
+
)
|
137 |
+
for d in range(num_layers)
|
138 |
+
]
|
139 |
+
)
|
140 |
+
|
141 |
+
if norm_in_type == "layer":
|
142 |
+
self.norm_in = nn.LayerNorm(inner_dim)
|
143 |
+
elif norm_in_type is None:
|
144 |
+
self.norm_in = None
|
145 |
+
else:
|
146 |
+
raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
|
147 |
+
|
148 |
+
self.norm_out = nn.LayerNorm(inner_dim)
|
149 |
+
|
150 |
+
self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
|
151 |
+
|
152 |
+
causal_attention_mask = torch.full(
|
153 |
+
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
|
154 |
+
)
|
155 |
+
causal_attention_mask.triu_(1)
|
156 |
+
causal_attention_mask = causal_attention_mask[None, ...]
|
157 |
+
self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
|
158 |
+
|
159 |
+
self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
160 |
+
self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
161 |
+
|
162 |
+
@property
|
163 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
164 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
165 |
+
r"""
|
166 |
+
Returns:
|
167 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
168 |
+
indexed by its weight name.
|
169 |
+
"""
|
170 |
+
# set recursively
|
171 |
+
processors = {}
|
172 |
+
|
173 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
174 |
+
if hasattr(module, "set_processor"):
|
175 |
+
processors[f"{name}.processor"] = module.processor
|
176 |
+
|
177 |
+
for sub_name, child in module.named_children():
|
178 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
179 |
+
|
180 |
+
return processors
|
181 |
+
|
182 |
+
for name, module in self.named_children():
|
183 |
+
fn_recursive_add_processors(name, module, processors)
|
184 |
+
|
185 |
+
return processors
|
186 |
+
|
187 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
188 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
189 |
+
r"""
|
190 |
+
Sets the attention processor to use to compute attention.
|
191 |
+
|
192 |
+
Parameters:
|
193 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
194 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
195 |
+
for **all** `Attention` layers.
|
196 |
+
|
197 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
198 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
199 |
+
|
200 |
+
"""
|
201 |
+
count = len(self.attn_processors.keys())
|
202 |
+
|
203 |
+
if isinstance(processor, dict) and len(processor) != count:
|
204 |
+
raise ValueError(
|
205 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
206 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
207 |
+
)
|
208 |
+
|
209 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
210 |
+
if hasattr(module, "set_processor"):
|
211 |
+
if not isinstance(processor, dict):
|
212 |
+
module.set_processor(processor)
|
213 |
+
else:
|
214 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
215 |
+
|
216 |
+
for sub_name, child in module.named_children():
|
217 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
218 |
+
|
219 |
+
for name, module in self.named_children():
|
220 |
+
fn_recursive_attn_processor(name, module, processor)
|
221 |
+
|
222 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
223 |
+
def set_default_attn_processor(self):
|
224 |
+
"""
|
225 |
+
Disables custom attention processors and sets the default attention implementation.
|
226 |
+
"""
|
227 |
+
self.set_attn_processor(AttnProcessor())
|
228 |
+
|
229 |
+
def forward(
|
230 |
+
self,
|
231 |
+
hidden_states,
|
232 |
+
timestep: Union[torch.Tensor, float, int],
|
233 |
+
proj_embedding: torch.FloatTensor,
|
234 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
235 |
+
attention_mask: Optional[torch.BoolTensor] = None,
|
236 |
+
return_dict: bool = True,
|
237 |
+
):
|
238 |
+
"""
|
239 |
+
The [`PriorTransformer`] forward method.
|
240 |
+
|
241 |
+
Args:
|
242 |
+
hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
243 |
+
The currently predicted image embeddings.
|
244 |
+
timestep (`torch.LongTensor`):
|
245 |
+
Current denoising step.
|
246 |
+
proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
247 |
+
Projected embedding vector the denoising process is conditioned on.
|
248 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
|
249 |
+
Hidden states of the text embeddings the denoising process is conditioned on.
|
250 |
+
attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
|
251 |
+
Text mask for the text embeddings.
|
252 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
253 |
+
Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
|
254 |
+
tuple.
|
255 |
+
|
256 |
+
Returns:
|
257 |
+
[`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
|
258 |
+
If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
|
259 |
+
tuple is returned where the first element is the sample tensor.
|
260 |
+
"""
|
261 |
+
batch_size = hidden_states.shape[0]
|
262 |
+
|
263 |
+
timesteps = timestep
|
264 |
+
if not torch.is_tensor(timesteps):
|
265 |
+
timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
|
266 |
+
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
267 |
+
timesteps = timesteps[None].to(hidden_states.device)
|
268 |
+
|
269 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
270 |
+
timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
|
271 |
+
|
272 |
+
timesteps_projected = self.time_proj(timesteps)
|
273 |
+
|
274 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
275 |
+
# but time_embedding might be fp16, so we need to cast here.
|
276 |
+
timesteps_projected = timesteps_projected.to(dtype=self.dtype)
|
277 |
+
time_embeddings = self.time_embedding(timesteps_projected)
|
278 |
+
|
279 |
+
if self.embedding_proj_norm is not None:
|
280 |
+
proj_embedding = self.embedding_proj_norm(proj_embedding)
|
281 |
+
|
282 |
+
proj_embeddings = self.embedding_proj(proj_embedding)
|
283 |
+
if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
|
284 |
+
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
|
285 |
+
elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
|
286 |
+
raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
|
287 |
+
|
288 |
+
hidden_states = self.proj_in(hidden_states)
|
289 |
+
|
290 |
+
positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
|
291 |
+
|
292 |
+
additional_embeds = []
|
293 |
+
additional_embeddings_len = 0
|
294 |
+
|
295 |
+
if encoder_hidden_states is not None:
|
296 |
+
additional_embeds.append(encoder_hidden_states)
|
297 |
+
additional_embeddings_len += encoder_hidden_states.shape[1]
|
298 |
+
|
299 |
+
if len(proj_embeddings.shape) == 2:
|
300 |
+
proj_embeddings = proj_embeddings[:, None, :]
|
301 |
+
|
302 |
+
if len(hidden_states.shape) == 2:
|
303 |
+
hidden_states = hidden_states[:, None, :]
|
304 |
+
|
305 |
+
additional_embeds = additional_embeds + [
|
306 |
+
proj_embeddings,
|
307 |
+
time_embeddings[:, None, :],
|
308 |
+
hidden_states,
|
309 |
+
]
|
310 |
+
|
311 |
+
if self.prd_embedding is not None:
|
312 |
+
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
|
313 |
+
additional_embeds.append(prd_embedding)
|
314 |
+
|
315 |
+
hidden_states = torch.cat(
|
316 |
+
additional_embeds,
|
317 |
+
dim=1,
|
318 |
+
)
|
319 |
+
|
320 |
+
# Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens
|
321 |
+
additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
|
322 |
+
if positional_embeddings.shape[1] < hidden_states.shape[1]:
|
323 |
+
positional_embeddings = F.pad(
|
324 |
+
positional_embeddings,
|
325 |
+
(
|
326 |
+
0,
|
327 |
+
0,
|
328 |
+
additional_embeddings_len,
|
329 |
+
self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
|
330 |
+
),
|
331 |
+
value=0.0,
|
332 |
+
)
|
333 |
+
|
334 |
+
hidden_states = hidden_states + positional_embeddings
|
335 |
+
|
336 |
+
if attention_mask is not None:
|
337 |
+
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
338 |
+
attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
|
339 |
+
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
|
340 |
+
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
|
341 |
+
|
342 |
+
if self.norm_in is not None:
|
343 |
+
hidden_states = self.norm_in(hidden_states)
|
344 |
+
|
345 |
+
for block in self.transformer_blocks:
|
346 |
+
hidden_states = block(hidden_states, attention_mask=attention_mask)
|
347 |
+
|
348 |
+
hidden_states = self.norm_out(hidden_states)
|
349 |
+
|
350 |
+
if self.prd_embedding is not None:
|
351 |
+
hidden_states = hidden_states[:, -1]
|
352 |
+
else:
|
353 |
+
hidden_states = hidden_states[:, additional_embeddings_len:]
|
354 |
+
|
355 |
+
predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
|
356 |
+
|
357 |
+
if not return_dict:
|
358 |
+
return (predicted_image_embedding,)
|
359 |
+
|
360 |
+
return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
|
361 |
+
|
362 |
+
def post_process_latents(self, prior_latents):
|
363 |
+
prior_latents = (prior_latents * self.clip_std) + self.clip_mean
|
364 |
+
return prior_latents
|
Tiger Model/diffusiers-Tiger/models/resnet.py
ADDED
@@ -0,0 +1,878 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from functools import partial
|
17 |
+
from typing import Optional
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
import torch.nn.functional as F
|
22 |
+
|
23 |
+
from .activations import get_activation
|
24 |
+
from .attention import AdaGroupNorm
|
25 |
+
from .attention_processor import SpatialNorm
|
26 |
+
from .lora import LoRACompatibleConv, LoRACompatibleLinear
|
27 |
+
|
28 |
+
|
29 |
+
class Upsample1D(nn.Module):
|
30 |
+
"""A 1D upsampling layer with an optional convolution.
|
31 |
+
|
32 |
+
Parameters:
|
33 |
+
channels (`int`):
|
34 |
+
number of channels in the inputs and outputs.
|
35 |
+
use_conv (`bool`, default `False`):
|
36 |
+
option to use a convolution.
|
37 |
+
use_conv_transpose (`bool`, default `False`):
|
38 |
+
option to use a convolution transpose.
|
39 |
+
out_channels (`int`, optional):
|
40 |
+
number of output channels. Defaults to `channels`.
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
44 |
+
super().__init__()
|
45 |
+
self.channels = channels
|
46 |
+
self.out_channels = out_channels or channels
|
47 |
+
self.use_conv = use_conv
|
48 |
+
self.use_conv_transpose = use_conv_transpose
|
49 |
+
self.name = name
|
50 |
+
|
51 |
+
self.conv = None
|
52 |
+
if use_conv_transpose:
|
53 |
+
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
|
54 |
+
elif use_conv:
|
55 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
|
56 |
+
|
57 |
+
def forward(self, inputs):
|
58 |
+
assert inputs.shape[1] == self.channels
|
59 |
+
if self.use_conv_transpose:
|
60 |
+
return self.conv(inputs)
|
61 |
+
|
62 |
+
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
|
63 |
+
|
64 |
+
if self.use_conv:
|
65 |
+
outputs = self.conv(outputs)
|
66 |
+
|
67 |
+
return outputs
|
68 |
+
|
69 |
+
|
70 |
+
class Downsample1D(nn.Module):
|
71 |
+
"""A 1D downsampling layer with an optional convolution.
|
72 |
+
|
73 |
+
Parameters:
|
74 |
+
channels (`int`):
|
75 |
+
number of channels in the inputs and outputs.
|
76 |
+
use_conv (`bool`, default `False`):
|
77 |
+
option to use a convolution.
|
78 |
+
out_channels (`int`, optional):
|
79 |
+
number of output channels. Defaults to `channels`.
|
80 |
+
padding (`int`, default `1`):
|
81 |
+
padding for the convolution.
|
82 |
+
"""
|
83 |
+
|
84 |
+
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
85 |
+
super().__init__()
|
86 |
+
self.channels = channels
|
87 |
+
self.out_channels = out_channels or channels
|
88 |
+
self.use_conv = use_conv
|
89 |
+
self.padding = padding
|
90 |
+
stride = 2
|
91 |
+
self.name = name
|
92 |
+
|
93 |
+
if use_conv:
|
94 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
95 |
+
else:
|
96 |
+
assert self.channels == self.out_channels
|
97 |
+
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
|
98 |
+
|
99 |
+
def forward(self, inputs):
|
100 |
+
assert inputs.shape[1] == self.channels
|
101 |
+
return self.conv(inputs)
|
102 |
+
|
103 |
+
|
104 |
+
class Upsample2D(nn.Module):
|
105 |
+
"""A 2D upsampling layer with an optional convolution.
|
106 |
+
|
107 |
+
Parameters:
|
108 |
+
channels (`int`):
|
109 |
+
number of channels in the inputs and outputs.
|
110 |
+
use_conv (`bool`, default `False`):
|
111 |
+
option to use a convolution.
|
112 |
+
use_conv_transpose (`bool`, default `False`):
|
113 |
+
option to use a convolution transpose.
|
114 |
+
out_channels (`int`, optional):
|
115 |
+
number of output channels. Defaults to `channels`.
|
116 |
+
"""
|
117 |
+
|
118 |
+
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
119 |
+
super().__init__()
|
120 |
+
self.channels = channels
|
121 |
+
self.out_channels = out_channels or channels
|
122 |
+
self.use_conv = use_conv
|
123 |
+
self.use_conv_transpose = use_conv_transpose
|
124 |
+
self.name = name
|
125 |
+
|
126 |
+
conv = None
|
127 |
+
if use_conv_transpose:
|
128 |
+
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
|
129 |
+
elif use_conv:
|
130 |
+
conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1)
|
131 |
+
|
132 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
133 |
+
if name == "conv":
|
134 |
+
self.conv = conv
|
135 |
+
else:
|
136 |
+
self.Conv2d_0 = conv
|
137 |
+
|
138 |
+
def forward(self, hidden_states, output_size=None):
|
139 |
+
assert hidden_states.shape[1] == self.channels
|
140 |
+
|
141 |
+
if self.use_conv_transpose:
|
142 |
+
return self.conv(hidden_states)
|
143 |
+
|
144 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
145 |
+
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
146 |
+
# https://github.com/pytorch/pytorch/issues/86679
|
147 |
+
dtype = hidden_states.dtype
|
148 |
+
if dtype == torch.bfloat16:
|
149 |
+
hidden_states = hidden_states.to(torch.float32)
|
150 |
+
|
151 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
152 |
+
if hidden_states.shape[0] >= 64:
|
153 |
+
hidden_states = hidden_states.contiguous()
|
154 |
+
|
155 |
+
# if `output_size` is passed we force the interpolation output
|
156 |
+
# size and do not make use of `scale_factor=2`
|
157 |
+
if output_size is None:
|
158 |
+
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
159 |
+
else:
|
160 |
+
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
161 |
+
|
162 |
+
# If the input is bfloat16, we cast back to bfloat16
|
163 |
+
if dtype == torch.bfloat16:
|
164 |
+
hidden_states = hidden_states.to(dtype)
|
165 |
+
|
166 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
167 |
+
if self.use_conv:
|
168 |
+
if self.name == "conv":
|
169 |
+
hidden_states = self.conv(hidden_states)
|
170 |
+
else:
|
171 |
+
hidden_states = self.Conv2d_0(hidden_states)
|
172 |
+
|
173 |
+
return hidden_states
|
174 |
+
|
175 |
+
|
176 |
+
class Downsample2D(nn.Module):
|
177 |
+
"""A 2D downsampling layer with an optional convolution.
|
178 |
+
|
179 |
+
Parameters:
|
180 |
+
channels (`int`):
|
181 |
+
number of channels in the inputs and outputs.
|
182 |
+
use_conv (`bool`, default `False`):
|
183 |
+
option to use a convolution.
|
184 |
+
out_channels (`int`, optional):
|
185 |
+
number of output channels. Defaults to `channels`.
|
186 |
+
padding (`int`, default `1`):
|
187 |
+
padding for the convolution.
|
188 |
+
"""
|
189 |
+
|
190 |
+
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
191 |
+
super().__init__()
|
192 |
+
self.channels = channels
|
193 |
+
self.out_channels = out_channels or channels
|
194 |
+
self.use_conv = use_conv
|
195 |
+
self.padding = padding
|
196 |
+
stride = 2
|
197 |
+
self.name = name
|
198 |
+
|
199 |
+
if use_conv:
|
200 |
+
conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
201 |
+
else:
|
202 |
+
assert self.channels == self.out_channels
|
203 |
+
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
|
204 |
+
|
205 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
206 |
+
if name == "conv":
|
207 |
+
self.Conv2d_0 = conv
|
208 |
+
self.conv = conv
|
209 |
+
elif name == "Conv2d_0":
|
210 |
+
self.conv = conv
|
211 |
+
else:
|
212 |
+
self.conv = conv
|
213 |
+
|
214 |
+
def forward(self, hidden_states):
|
215 |
+
assert hidden_states.shape[1] == self.channels
|
216 |
+
if self.use_conv and self.padding == 0:
|
217 |
+
pad = (0, 1, 0, 1)
|
218 |
+
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
|
219 |
+
|
220 |
+
assert hidden_states.shape[1] == self.channels
|
221 |
+
hidden_states = self.conv(hidden_states)
|
222 |
+
|
223 |
+
return hidden_states
|
224 |
+
|
225 |
+
|
226 |
+
class FirUpsample2D(nn.Module):
|
227 |
+
"""A 2D FIR upsampling layer with an optional convolution.
|
228 |
+
|
229 |
+
Parameters:
|
230 |
+
channels (`int`):
|
231 |
+
number of channels in the inputs and outputs.
|
232 |
+
use_conv (`bool`, default `False`):
|
233 |
+
option to use a convolution.
|
234 |
+
out_channels (`int`, optional):
|
235 |
+
number of output channels. Defaults to `channels`.
|
236 |
+
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
|
237 |
+
kernel for the FIR filter.
|
238 |
+
"""
|
239 |
+
|
240 |
+
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
|
241 |
+
super().__init__()
|
242 |
+
out_channels = out_channels if out_channels else channels
|
243 |
+
if use_conv:
|
244 |
+
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
245 |
+
self.use_conv = use_conv
|
246 |
+
self.fir_kernel = fir_kernel
|
247 |
+
self.out_channels = out_channels
|
248 |
+
|
249 |
+
def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
|
250 |
+
"""Fused `upsample_2d()` followed by `Conv2d()`.
|
251 |
+
|
252 |
+
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
253 |
+
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
254 |
+
arbitrary order.
|
255 |
+
|
256 |
+
Args:
|
257 |
+
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
258 |
+
weight: Weight tensor of the shape `[filterH, filterW, inChannels,
|
259 |
+
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
|
260 |
+
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
|
261 |
+
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
|
262 |
+
factor: Integer upsampling factor (default: 2).
|
263 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
264 |
+
|
265 |
+
Returns:
|
266 |
+
output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
|
267 |
+
datatype as `hidden_states`.
|
268 |
+
"""
|
269 |
+
|
270 |
+
assert isinstance(factor, int) and factor >= 1
|
271 |
+
|
272 |
+
# Setup filter kernel.
|
273 |
+
if kernel is None:
|
274 |
+
kernel = [1] * factor
|
275 |
+
|
276 |
+
# setup kernel
|
277 |
+
kernel = torch.tensor(kernel, dtype=torch.float32)
|
278 |
+
if kernel.ndim == 1:
|
279 |
+
kernel = torch.outer(kernel, kernel)
|
280 |
+
kernel /= torch.sum(kernel)
|
281 |
+
|
282 |
+
kernel = kernel * (gain * (factor**2))
|
283 |
+
|
284 |
+
if self.use_conv:
|
285 |
+
convH = weight.shape[2]
|
286 |
+
convW = weight.shape[3]
|
287 |
+
inC = weight.shape[1]
|
288 |
+
|
289 |
+
pad_value = (kernel.shape[0] - factor) - (convW - 1)
|
290 |
+
|
291 |
+
stride = (factor, factor)
|
292 |
+
# Determine data dimensions.
|
293 |
+
output_shape = (
|
294 |
+
(hidden_states.shape[2] - 1) * factor + convH,
|
295 |
+
(hidden_states.shape[3] - 1) * factor + convW,
|
296 |
+
)
|
297 |
+
output_padding = (
|
298 |
+
output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
|
299 |
+
output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
|
300 |
+
)
|
301 |
+
assert output_padding[0] >= 0 and output_padding[1] >= 0
|
302 |
+
num_groups = hidden_states.shape[1] // inC
|
303 |
+
|
304 |
+
# Transpose weights.
|
305 |
+
weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
|
306 |
+
weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
|
307 |
+
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
|
308 |
+
|
309 |
+
inverse_conv = F.conv_transpose2d(
|
310 |
+
hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
|
311 |
+
)
|
312 |
+
|
313 |
+
output = upfirdn2d_native(
|
314 |
+
inverse_conv,
|
315 |
+
torch.tensor(kernel, device=inverse_conv.device),
|
316 |
+
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
|
317 |
+
)
|
318 |
+
else:
|
319 |
+
pad_value = kernel.shape[0] - factor
|
320 |
+
output = upfirdn2d_native(
|
321 |
+
hidden_states,
|
322 |
+
torch.tensor(kernel, device=hidden_states.device),
|
323 |
+
up=factor,
|
324 |
+
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
|
325 |
+
)
|
326 |
+
|
327 |
+
return output
|
328 |
+
|
329 |
+
def forward(self, hidden_states):
|
330 |
+
if self.use_conv:
|
331 |
+
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
|
332 |
+
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
333 |
+
else:
|
334 |
+
height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
335 |
+
|
336 |
+
return height
|
337 |
+
|
338 |
+
|
339 |
+
class FirDownsample2D(nn.Module):
|
340 |
+
"""A 2D FIR downsampling layer with an optional convolution.
|
341 |
+
|
342 |
+
Parameters:
|
343 |
+
channels (`int`):
|
344 |
+
number of channels in the inputs and outputs.
|
345 |
+
use_conv (`bool`, default `False`):
|
346 |
+
option to use a convolution.
|
347 |
+
out_channels (`int`, optional):
|
348 |
+
number of output channels. Defaults to `channels`.
|
349 |
+
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
|
350 |
+
kernel for the FIR filter.
|
351 |
+
"""
|
352 |
+
|
353 |
+
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
|
354 |
+
super().__init__()
|
355 |
+
out_channels = out_channels if out_channels else channels
|
356 |
+
if use_conv:
|
357 |
+
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
358 |
+
self.fir_kernel = fir_kernel
|
359 |
+
self.use_conv = use_conv
|
360 |
+
self.out_channels = out_channels
|
361 |
+
|
362 |
+
def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
|
363 |
+
"""Fused `Conv2d()` followed by `downsample_2d()`.
|
364 |
+
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
365 |
+
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
366 |
+
arbitrary order.
|
367 |
+
|
368 |
+
Args:
|
369 |
+
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
370 |
+
weight:
|
371 |
+
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
|
372 |
+
performed by `inChannels = x.shape[0] // numGroups`.
|
373 |
+
kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
|
374 |
+
factor`, which corresponds to average pooling.
|
375 |
+
factor: Integer downsampling factor (default: 2).
|
376 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
377 |
+
|
378 |
+
Returns:
|
379 |
+
output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
|
380 |
+
same datatype as `x`.
|
381 |
+
"""
|
382 |
+
|
383 |
+
assert isinstance(factor, int) and factor >= 1
|
384 |
+
if kernel is None:
|
385 |
+
kernel = [1] * factor
|
386 |
+
|
387 |
+
# setup kernel
|
388 |
+
kernel = torch.tensor(kernel, dtype=torch.float32)
|
389 |
+
if kernel.ndim == 1:
|
390 |
+
kernel = torch.outer(kernel, kernel)
|
391 |
+
kernel /= torch.sum(kernel)
|
392 |
+
|
393 |
+
kernel = kernel * gain
|
394 |
+
|
395 |
+
if self.use_conv:
|
396 |
+
_, _, convH, convW = weight.shape
|
397 |
+
pad_value = (kernel.shape[0] - factor) + (convW - 1)
|
398 |
+
stride_value = [factor, factor]
|
399 |
+
upfirdn_input = upfirdn2d_native(
|
400 |
+
hidden_states,
|
401 |
+
torch.tensor(kernel, device=hidden_states.device),
|
402 |
+
pad=((pad_value + 1) // 2, pad_value // 2),
|
403 |
+
)
|
404 |
+
output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
|
405 |
+
else:
|
406 |
+
pad_value = kernel.shape[0] - factor
|
407 |
+
output = upfirdn2d_native(
|
408 |
+
hidden_states,
|
409 |
+
torch.tensor(kernel, device=hidden_states.device),
|
410 |
+
down=factor,
|
411 |
+
pad=((pad_value + 1) // 2, pad_value // 2),
|
412 |
+
)
|
413 |
+
|
414 |
+
return output
|
415 |
+
|
416 |
+
def forward(self, hidden_states):
|
417 |
+
if self.use_conv:
|
418 |
+
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
|
419 |
+
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
420 |
+
else:
|
421 |
+
hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
422 |
+
|
423 |
+
return hidden_states
|
424 |
+
|
425 |
+
|
426 |
+
# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead
class KDownsample2D(nn.Module):
    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, inputs):
        inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(inputs.shape[1], device=inputs.device)
        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
        weight[indices, indices] = kernel
        return F.conv2d(inputs, weight, stride=2)


class KUpsample2D(nn.Module):
    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, inputs):
        inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(inputs.shape[1], device=inputs.device)
        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
        weight[indices, indices] = kernel
        return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)

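A minimal round-trip sketch for the two parameter-free resampling blocks above: KDownsample2D halves H and W with the fixed binomial kernel, and KUpsample2D doubles them back (sizes chosen only for illustration):

import torch

down, up = KDownsample2D(), KUpsample2D()
x = torch.randn(1, 4, 16, 16)
print(down(x).shape)                       # torch.Size([1, 4, 8, 8])
print(up(down(x)).shape)                   # torch.Size([1, 4, 16, 16])
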
class ResnetBlock2D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, default to None):
            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
        time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
            "ada_group" for a stronger conditioning with scale and shift.
        kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
            [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
        output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
        use_in_shortcut (`bool`, *optional*, default to `True`):
            If `True`, add a 1x1 nn.conv2d layer for skip-connection.
        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
        conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        skip_time_act=False,
        time_embedding_norm="default",  # default, scale_shift, ada_group, spatial
        kernel=None,
        output_scale_factor=1.0,
        use_in_shortcut=None,
        up=False,
        down=False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm
        self.skip_time_act = skip_time_act

        if groups_out is None:
            groups_out = groups

        if self.time_embedding_norm == "ada_group":
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm1 = SpatialNorm(in_channels, temb_channels)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
                self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels)
            elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
                self.time_emb_proj = None
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
        else:
            self.time_emb_proj = None

        if self.time_embedding_norm == "ada_group":
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm2 = SpatialNorm(out_channels, temb_channels)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = get_activation(non_linearity)

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = LoRACompatibleConv(
                in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
            )

    def forward(self, input_tensor, temb):
        hidden_states = input_tensor

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm1(hidden_states, temb)
        else:
            hidden_states = self.norm1(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if self.time_emb_proj is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            temb = self.time_emb_proj(temb)[:, :, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
            hidden_states = self.norm2(hidden_states, temb)
        else:
            hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor

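A minimal sketch of the block above with the default "default" time_embedding_norm: the timestep embedding is projected to out_channels and added after conv1, while the 1x1 conv_shortcut reconciles the channel change on the residual path (dimensions chosen for illustration):

import torch

block = ResnetBlock2D(in_channels=64, out_channels=128, temb_channels=512)
x = torch.randn(2, 64, 32, 32)             # input feature map
temb = torch.randn(2, 512)                 # timestep embedding
out = block(x, temb)
print(out.shape)                           # torch.Size([2, 128, 32, 32])
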
# unet_rl.py
def rearrange_dims(tensor):
    if len(tensor.shape) == 2:
        return tensor[:, :, None]
    if len(tensor.shape) == 3:
        return tensor[:, :, None, :]
    elif len(tensor.shape) == 4:
        return tensor[:, :, 0, :]
    else:
        raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = nn.Mish()

    def forward(self, inputs):
        intermediate_repr = self.conv1d(inputs)
        intermediate_repr = rearrange_dims(intermediate_repr)
        intermediate_repr = self.group_norm(intermediate_repr)
        intermediate_repr = rearrange_dims(intermediate_repr)
        output = self.mish(intermediate_repr)
        return output


# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = nn.Mish()
        self.time_emb = nn.Linear(embed_dim, out_channels)

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, inputs, t):
        """
        Args:
            inputs : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]

        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        t = self.time_emb_act(t)
        t = self.time_emb(t)
        out = self.conv_in(inputs) + rearrange_dims(t)
        out = self.conv_out(out)
        return out + self.residual_conv(inputs)

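A minimal sketch of the 1D residual block above for trajectory-style inputs; the Mish-activated time embedding is broadcast over the horizon dimension (sizes chosen for illustration):

import torch

block = ResidualTemporalBlock1D(inp_channels=14, out_channels=32, embed_dim=128)
traj = torch.randn(4, 14, 16)              # [batch_size, inp_channels, horizon]
t_emb = torch.randn(4, 128)                # [batch_size, embed_dim]
print(block(traj, t_emb).shape)            # torch.Size([4, 32, 16])
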
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Upsample2D a batch of 2D images with the given filter.
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
    a multiple of the upsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
            (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * (gain * (factor**2))
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        kernel.to(device=hidden_states.device),
        up=factor,
        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
    )
    return output


def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Downsample2D a batch of 2D images with the given filter.
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
    shape is a multiple of the downsampling factor.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
            (separable). The default is `[1] * factor`, which corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
        kernel = [1] * factor

    kernel = torch.tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    kernel /= torch.sum(kernel)

    kernel = kernel * gain
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
    )
    return output


def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
    up_x = up_y = up
    down_x = down_y = down
    pad_x0 = pad_y0 = pad[0]
    pad_x1 = pad_y1 = pad[1]

    _, channel, in_h, in_w = tensor.shape
    tensor = tensor.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = tensor.shape
    kernel_h, kernel_w = kernel.shape

    out = tensor.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out.to(tensor.device)  # Move back to mps if necessary
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)

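To make the padding arithmetic above concrete, a minimal round-trip sketch with the (1, 3, 3, 1) FIR kernel: downsample_2d halves the resolution and upsample_2d restores it (sizes chosen for illustration):

import torch

x = torch.randn(1, 3, 8, 8)
down = downsample_2d(x, kernel=[1, 3, 3, 1], factor=2)
up = upsample_2d(down, kernel=[1, 3, 3, 1], factor=2)
print(down.shape, up.shape)                # torch.Size([1, 3, 4, 4]) torch.Size([1, 3, 8, 8])
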
class TemporalConvLayer(nn.Module):
    """
    Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
    """

    def __init__(self, in_dim, out_dim=None, dropout=0.0):
        super().__init__()
        out_dim = out_dim or in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
        )
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, hidden_states, num_frames=1):
        hidden_states = (
            hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
        )

        identity = hidden_states
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.conv2(hidden_states)
        hidden_states = self.conv3(hidden_states)
        hidden_states = self.conv4(hidden_states)

        hidden_states = identity + hidden_states

        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
            (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
        )
        return hidden_states
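A minimal sketch of the temporal layer above: frames are passed flattened into the batch dimension together with num_frames, and because conv4 is zero-initialised the layer starts as an identity mapping (sizes chosen for illustration):

import torch

layer = TemporalConvLayer(in_dim=32)
frames = torch.randn(2 * 8, 32, 16, 16)    # 2 clips x 8 frames, flattened into the batch
out = layer(frames, num_frames=8)
print(out.shape)                           # torch.Size([16, 32, 16, 16])
print(torch.allclose(out, frames))         # True right after initialisation (conv4 is zeroed)
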
Tiger Model/diffusiers-Tiger/models/t5_film_transformer.py
ADDED
@@ -0,0 +1,321 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch
from torch import nn

from ..configuration_utils import ConfigMixin, register_to_config
from .attention_processor import Attention
from .embeddings import get_timestep_embedding
from .modeling_utils import ModelMixin


class T5FilmDecoder(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        input_dims: int = 128,
        targets_length: int = 256,
        max_decoder_noise_time: float = 2000.0,
        d_model: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        d_kv: int = 64,
        d_ff: int = 2048,
        dropout_rate: float = 0.1,
    ):
        super().__init__()

        self.conditioning_emb = nn.Sequential(
            nn.Linear(d_model, d_model * 4, bias=False),
            nn.SiLU(),
            nn.Linear(d_model * 4, d_model * 4, bias=False),
            nn.SiLU(),
        )

        self.position_encoding = nn.Embedding(targets_length, d_model)
        self.position_encoding.weight.requires_grad = False

        self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False)

        self.dropout = nn.Dropout(p=dropout_rate)

        self.decoders = nn.ModuleList()
        for lyr_num in range(num_layers):
            # FiLM conditional T5 decoder
            lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate)
            self.decoders.append(lyr)

        self.decoder_norm = T5LayerNorm(d_model)

        self.post_dropout = nn.Dropout(p=dropout_rate)
        self.spec_out = nn.Linear(d_model, input_dims, bias=False)

    def encoder_decoder_mask(self, query_input, key_input):
        mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
        return mask.unsqueeze(-3)

    def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time):
        batch, _, _ = decoder_input_tokens.shape
        assert decoder_noise_time.shape == (batch,)

        # decoder_noise_time is in [0, 1), so rescale to expected timing range.
        time_steps = get_timestep_embedding(
            decoder_noise_time * self.config.max_decoder_noise_time,
            embedding_dim=self.config.d_model,
            max_period=self.config.max_decoder_noise_time,
        ).to(dtype=self.dtype)

        conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1)

        assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4)

        seq_length = decoder_input_tokens.shape[1]

        # If we want to use relative positions for audio context, we can just offset
        # this sequence by the length of encodings_and_masks.
        decoder_positions = torch.broadcast_to(
            torch.arange(seq_length, device=decoder_input_tokens.device),
            (batch, seq_length),
        )

        position_encodings = self.position_encoding(decoder_positions)

        inputs = self.continuous_inputs_projection(decoder_input_tokens)
        inputs += position_encodings
        y = self.dropout(inputs)

        # decoder: No padding present.
        decoder_mask = torch.ones(
            decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype
        )

        # Translate encoding masks to encoder-decoder masks.
        encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks]

        # cross attend style: concat encodings
        encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1)
        encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1)

        for lyr in self.decoders:
            y = lyr(
                y,
                conditioning_emb=conditioning_emb,
                encoder_hidden_states=encoded,
                encoder_attention_mask=encoder_decoder_mask,
            )[0]

        y = self.decoder_norm(y)
        y = self.post_dropout(y)

        spec_out = self.spec_out(y)
        return spec_out


class DecoderLayer(nn.Module):
    def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer_norm_epsilon=1e-6):
        super().__init__()
        self.layer = nn.ModuleList()

        # cond self attention: layer 0
        self.layer.append(
            T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate)
        )

        # cross attention: layer 1
        self.layer.append(
            T5LayerCrossAttention(
                d_model=d_model,
                d_kv=d_kv,
                num_heads=num_heads,
                dropout_rate=dropout_rate,
                layer_norm_epsilon=layer_norm_epsilon,
            )
        )

        # Film Cond MLP + dropout: last layer
        self.layer.append(
            T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon)
        )

    def forward(
        self,
        hidden_states,
        conditioning_emb=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
    ):
        hidden_states = self.layer[0](
            hidden_states,
            conditioning_emb=conditioning_emb,
            attention_mask=attention_mask,
        )

        if encoder_hidden_states is not None:
            encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to(
                encoder_hidden_states.dtype
            )

            hidden_states = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_extended_attention_mask,
            )

        # Apply Film Conditional Feed Forward layer
        hidden_states = self.layer[-1](hidden_states, conditioning_emb)

        return (hidden_states,)


class T5LayerSelfAttentionCond(nn.Module):
    def __init__(self, d_model, d_kv, num_heads, dropout_rate):
        super().__init__()
        self.layer_norm = T5LayerNorm(d_model)
        self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
        self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(
        self,
        hidden_states,
        conditioning_emb=None,
        attention_mask=None,
    ):
        # pre_self_attention_layer_norm
        normed_hidden_states = self.layer_norm(hidden_states)

        if conditioning_emb is not None:
            normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb)

        # Self-attention block
        attention_output = self.attention(normed_hidden_states)

        hidden_states = hidden_states + self.dropout(attention_output)

        return hidden_states


class T5LayerCrossAttention(nn.Module):
    def __init__(self, d_model, d_kv, num_heads, dropout_rate, layer_norm_epsilon):
        super().__init__()
        self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
        self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states=None,
        attention_mask=None,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.attention(
            normed_hidden_states,
            encoder_hidden_states=key_value_states,
            attention_mask=attention_mask.squeeze(1),
        )
        layer_output = hidden_states + self.dropout(attention_output)
        return layer_output


class T5LayerFFCond(nn.Module):
    def __init__(self, d_model, d_ff, dropout_rate, layer_norm_epsilon):
        super().__init__()
        self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
        self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
        self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, hidden_states, conditioning_emb=None):
        forwarded_states = self.layer_norm(hidden_states)
        if conditioning_emb is not None:
            forwarded_states = self.film(forwarded_states, conditioning_emb)

        forwarded_states = self.DenseReluDense(forwarded_states)
        hidden_states = hidden_states + self.dropout(forwarded_states)
        return hidden_states


class T5DenseGatedActDense(nn.Module):
    def __init__(self, d_model, d_ff, dropout_rate):
        super().__init__()
        self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
        self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
        self.wo = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout_rate)
        self.act = NewGELUActivation()

    def forward(self, hidden_states):
        hidden_gelu = self.act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)

        hidden_states = self.wo(hidden_states)
        return hidden_states


class T5LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
        # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
        # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
        # half-precision inputs is done in fp32

        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states


class NewGELUActivation(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))


class T5FiLMLayer(nn.Module):
    """
    FiLM Layer
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)

    def forward(self, x, conditioning_emb):
        emb = self.scale_bias(conditioning_emb)
        scale, shift = torch.chunk(emb, 2, -1)
        x = x * (1 + scale) + shift
        return x
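A minimal sketch of the FiLM mechanism used throughout this decoder: T5FiLMLayer maps the pooled conditioning embedding to per-channel scale and shift terms and applies them to the normalised activations (the dimensions here follow the class defaults and are chosen only for illustration):

import torch

film = T5FiLMLayer(in_features=768 * 4, out_features=768)
x = torch.randn(2, 256, 768)               # [batch, targets_length, d_model]
cond = torch.randn(2, 1, 768 * 4)          # pooled conditioning embedding
print(film(x, cond).shape)                 # torch.Size([2, 256, 768])
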
Tiger Model/diffusiers-Tiger/models/transformer_2d.py
ADDED
@@ -0,0 +1,359 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch import nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import BaseOutput, deprecate
from .attention import BasicTransformerBlock
from .embeddings import PatchEmbed
from .lora import LoRACompatibleConv, LoRACompatibleLinear
from .modeling_utils import ModelMixin


@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    The output of [`Transformer2DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    sample: torch.FloatTensor


class Transformer2DModel(ModelMixin, ConfigMixin):
    """
    A 2D Transformer model for image-like data.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
            This is fixed during training since it is used to learn a number of position embeddings.
        num_vector_embeds (`int`, *optional*):
            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*):
            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
            added to the hidden states.

            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        patch_size: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
        attention_type: str = "default",
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
        # Define whether input is continuous or discrete depending on configuration
        self.is_input_continuous = (in_channels is not None) and (patch_size is None)
        self.is_input_vectorized = num_vector_embeds is not None
        self.is_input_patches = in_channels is not None and patch_size is not None

        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
            deprecation_message = (
                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
                " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
            )
            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
            norm_type = "ada_norm"

        if self.is_input_continuous and self.is_input_vectorized:
            raise ValueError(
                f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is None."
            )
        elif self.is_input_vectorized and self.is_input_patches:
            raise ValueError(
                f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
                " sure that either `num_vector_embeds` or `num_patches` is None."
            )
        elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
            raise ValueError(
                f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
                f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
            )

        # 2. Define input layers
        if self.is_input_continuous:
            self.in_channels = in_channels

            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
            if use_linear_projection:
                self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
            else:
                self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"

            self.height = sample_size
            self.width = sample_size
            self.num_vector_embeds = num_vector_embeds
            self.num_latent_pixels = self.height * self.width

            self.latent_image_embedding = ImagePositionalEmbeddings(
                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
            )
        elif self.is_input_patches:
            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"

            self.height = sample_size
            self.width = sample_size

            self.patch_size = patch_size
            self.pos_embed = PatchEmbed(
                height=sample_size,
                width=sample_size,
                patch_size=patch_size,
                in_channels=in_channels,
                embed_dim=inner_dim,
            )

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                    attention_type=attention_type,
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.is_input_continuous:
            # TODO: should use out_channels for continuous projections
            if use_linear_projection:
                self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
            else:
                self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            self.norm_out = nn.LayerNorm(inner_dim)
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
        elif self.is_input_patches:
            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`Transformer2DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        # [batch, key_tokens]
        # adds singleton query_tokens dimension:
        # [batch, 1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            # (1 = keep, 0 = discard)
            # convert mask into a bias that can be added to attention scores:
            # (keep = +0, discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        if self.is_input_continuous:
            batch, _, height, width = hidden_states.shape
            residual = hidden_states

            hidden_states = self.norm(hidden_states)
            if not self.use_linear_projection:
                hidden_states = self.proj_in(hidden_states)
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            else:
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
                hidden_states = self.proj_in(hidden_states)
        elif self.is_input_vectorized:
            hidden_states = self.latent_image_embedding(hidden_states)
        elif self.is_input_patches:
            hidden_states = self.pos_embed(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            if self.training and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    block,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    class_labels,
                    use_reentrant=False,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=class_labels,
                )

        # 3. Output
        if self.is_input_continuous:
            if not self.use_linear_projection:
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
                hidden_states = self.proj_out(hidden_states)
            else:
                hidden_states = self.proj_out(hidden_states)
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

            output = hidden_states + residual
        elif self.is_input_vectorized:
            hidden_states = self.norm_out(hidden_states)
            logits = self.out(hidden_states)
            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
            logits = logits.permute(0, 2, 1)

            # log(p(x_0))
            output = F.log_softmax(logits.double(), dim=1).float()
        elif self.is_input_patches:
            # TODO: cleanup!
            conditioning = self.transformer_blocks[0].norm1.emb(
                timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
            hidden_states = self.proj_out_2(hidden_states)

            # unpatchify
            height = width = int(hidden_states.shape[1] ** 0.5)
            hidden_states = hidden_states.reshape(
                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
            )
            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
            output = hidden_states.reshape(
                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
            )

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
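A minimal sketch of the continuous-input path of Transformer2DModel above: with in_channels set and patch_size unset, the input is group-normalised, projected to inner_dim, run through the transformer blocks and projected back, so the output sample keeps the input shape (the small sizes here are chosen only for illustration):

import torch

model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=16,
    in_channels=32,
    num_layers=1,
    norm_num_groups=8,
)
x = torch.randn(1, 32, 16, 16)
out = model(x).sample
print(out.shape)                           # torch.Size([1, 32, 16, 16])
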
Tiger Model/diffusiers-Tiger/models/transformer_temporal.py
ADDED
@@ -0,0 +1,179 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .attention import BasicTransformerBlock
from .modeling_utils import ModelMixin


@dataclass
class TransformerTemporalModelOutput(BaseOutput):
    """
    The output of [`TransformerTemporalModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input.
    """

    sample: torch.FloatTensor


class TransformerTemporalModel(ModelMixin, ConfigMixin):
    """
    A Transformer model for video-like data.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
            This is fixed during training since it is used to learn a number of position embeddings.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlock` attention should contain a bias parameter.
        double_self_attention (`bool`, *optional*):
            Configure if each `TransformerBlock` should contain two self-attention layers.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        activation_fn: str = "geglu",
        norm_elementwise_affine: bool = True,
        double_self_attention: bool = True,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    double_self_attention=double_self_attention,
                    norm_elementwise_affine=norm_elementwise_affine,
                )
                for d in range(num_layers)
            ]
        )

        self.proj_out = nn.Linear(inner_dim, in_channels)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        timestep=None,
        class_labels=None,
        num_frames=1,
        cross_attention_kwargs=None,
        return_dict: bool = True,
    ):
        """
        The [`TransformerTemporal`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input hidden_states.
            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.long`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            return_dict (`bool`, *optional*, defaults to `True`):
|
132 |
+
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
133 |
+
tuple.
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
137 |
+
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
138 |
+
returned, otherwise a `tuple` where the first element is the sample tensor.
|
139 |
+
"""
|
140 |
+
# 1. Input
|
141 |
+
batch_frames, channel, height, width = hidden_states.shape
|
142 |
+
batch_size = batch_frames // num_frames
|
143 |
+
|
144 |
+
residual = hidden_states
|
145 |
+
|
146 |
+
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
|
147 |
+
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
148 |
+
|
149 |
+
hidden_states = self.norm(hidden_states)
|
150 |
+
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
|
151 |
+
|
152 |
+
hidden_states = self.proj_in(hidden_states)
|
153 |
+
|
154 |
+
# 2. Blocks
|
155 |
+
for block in self.transformer_blocks:
|
156 |
+
hidden_states = block(
|
157 |
+
hidden_states,
|
158 |
+
encoder_hidden_states=encoder_hidden_states,
|
159 |
+
timestep=timestep,
|
160 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
161 |
+
class_labels=class_labels,
|
162 |
+
)
|
163 |
+
|
164 |
+
# 3. Output
|
165 |
+
hidden_states = self.proj_out(hidden_states)
|
166 |
+
hidden_states = (
|
167 |
+
hidden_states[None, None, :]
|
168 |
+
.reshape(batch_size, height, width, channel, num_frames)
|
169 |
+
.permute(0, 3, 4, 1, 2)
|
170 |
+
.contiguous()
|
171 |
+
)
|
172 |
+
hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
|
173 |
+
|
174 |
+
output = hidden_states + residual
|
175 |
+
|
176 |
+
if not return_dict:
|
177 |
+
return (output,)
|
178 |
+
|
179 |
+
return TransformerTemporalModelOutput(sample=output)
|
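A minimal usage sketch for the class above (illustrative, not part of the upload; it assumes the module's relative imports resolve as in stock diffusers, and the head/channel/frame counts are arbitrary small values):

import torch

model = TransformerTemporalModel(num_attention_heads=2, attention_head_dim=16, in_channels=32)
# frames are flattened into the batch dimension: (batch_size * num_frames, channels, height, width)
frames = torch.randn(2 * 4, 32, 8, 8)
out = model(frames, num_frames=4).sample
print(out.shape)  # torch.Size([8, 32, 8, 8]) - attention mixes information across the 4 frames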
Tiger Model/diffusiers-Tiger/models/unet_1d.py
ADDED
@@ -0,0 +1,255 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block


@dataclass
class UNet1DOutput(BaseOutput):
    """
    The output of [`UNet1DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
            The hidden states output from the last layer of the model.
    """

    sample: torch.FloatTensor


class UNet1DModel(ModelMixin, ConfigMixin):
    r"""
    A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
        in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
        extra_in_channels (`int`, *optional*, defaults to 0):
            Number of additional channels to be added to the input of the first down block. Useful for cases where the
            input data has more channels than what the model was initially designed for.
        time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
        freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
            Whether to flip sin to cos for Fourier time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
            Tuple of block output channels.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
        out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
        act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
        norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
        layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
        downsample_each_block (`int`, *optional*, defaults to `False`):
            Experimental feature for using a UNet without upsampling.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: int = 65536,
        sample_rate: Optional[int] = None,
        in_channels: int = 2,
        out_channels: int = 2,
        extra_in_channels: int = 0,
        time_embedding_type: str = "fourier",
        flip_sin_to_cos: bool = True,
        use_timestep_embedding: bool = False,
        freq_shift: float = 0.0,
        down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
        up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
        mid_block_type: Tuple[str] = "UNetMidBlock1D",
        out_block_type: str = None,
        block_out_channels: Tuple[int] = (32, 32, 64),
        act_fn: str = None,
        norm_num_groups: int = 8,
        layers_per_block: int = 1,
        downsample_each_block: bool = False,
    ):
        super().__init__()
        self.sample_size = sample_size

        # time
        if time_embedding_type == "fourier":
            self.time_proj = GaussianFourierProjection(
                embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            timestep_input_dim = 2 * block_out_channels[0]
        elif time_embedding_type == "positional":
            self.time_proj = Timesteps(
                block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
            )
            timestep_input_dim = block_out_channels[0]

        if use_timestep_embedding:
            time_embed_dim = block_out_channels[0] * 4
            self.time_mlp = TimestepEmbedding(
                in_channels=timestep_input_dim,
                time_embed_dim=time_embed_dim,
                act_fn=act_fn,
                out_dim=block_out_channels[0],
            )

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])
        self.out_block = None

        # down
        output_channel = in_channels
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]

            if i == 0:
                input_channel += extra_in_channels

            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=block_out_channels[0],
                add_downsample=not is_final_block or downsample_each_block,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = get_mid_block(
            mid_block_type,
            in_channels=block_out_channels[-1],
            mid_channels=block_out_channels[-1],
            out_channels=block_out_channels[-1],
            embed_dim=block_out_channels[0],
            num_layers=layers_per_block,
            add_downsample=downsample_each_block,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        if out_block_type is None:
            final_upsample_channels = out_channels
        else:
            final_upsample_channels = block_out_channels[0]

        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = (
                reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
            )

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                temb_channels=block_out_channels[0],
                add_upsample=not is_final_block,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
        self.out_block = get_out_block(
            out_block_type=out_block_type,
            num_groups_out=num_groups_out,
            embed_dim=block_out_channels[0],
            out_channels=out_channels,
            act_fn=act_fn,
            fc_dim=block_out_channels[-1] // 4,
        )

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        return_dict: bool = True,
    ) -> Union[UNet1DOutput, Tuple]:
        r"""
        The [`UNet1DModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_1d.UNet1DOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
                returned where the first element is the sample tensor.
        """

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        timestep_embed = self.time_proj(timesteps)
        if self.config.use_timestep_embedding:
            timestep_embed = self.time_mlp(timestep_embed)
        else:
            timestep_embed = timestep_embed[..., None]
            timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
            timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))

        # 2. down
        down_block_res_samples = ()
        for downsample_block in self.down_blocks:
            sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
            down_block_res_samples += res_samples

        # 3. mid
        if self.mid_block:
            sample = self.mid_block(sample, timestep_embed)

        # 4. up
        for i, upsample_block in enumerate(self.up_blocks):
            res_samples = down_block_res_samples[-1:]
            down_block_res_samples = down_block_res_samples[:-1]
            sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)

        # 5. post-process
        if self.out_block:
            sample = self.out_block(sample, timestep_embed)

        if not return_dict:
            return (sample,)

        return UNet1DOutput(sample=sample)
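A minimal usage sketch for UNet1DModel (illustrative, not part of the upload). It assumes the module's relative imports (embeddings, unet_1d_blocks, etc.) resolve; extra_in_channels=16 is chosen to match the 16-channel Fourier time embedding that DownBlock1DNoSkip concatenates onto the input, and the other sizes are arbitrary small values:

import torch

unet = UNet1DModel(
    sample_size=256,
    in_channels=2,
    out_channels=2,
    extra_in_channels=16,            # matches the 2 * embedding_size Fourier channels
    block_out_channels=(32, 32, 64),
)
noisy = torch.randn(1, 2, 256)       # (batch, channels, length); length divisible by 4 for the two downsamples
denoised = unet(noisy, timestep=10).sample
print(denoised.shape)                # torch.Size([1, 2, 256])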
Tiger Model/diffusiers-Tiger/models/unet_1d_blocks.py
ADDED
@@ -0,0 +1,656 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch
import torch.nn.functional as F
from torch import nn

from .activations import get_activation
from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims


class DownResnetBlock1D(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels=None,
        num_layers=1,
        conv_shortcut=False,
        temb_channels=32,
        groups=32,
        groups_out=None,
        non_linearity=None,
        time_embedding_norm="default",
        output_scale_factor=1.0,
        add_downsample=True,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.add_downsample = add_downsample
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        # there will always be at least one resnet
        resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]

        for _ in range(num_layers):
            resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))

        self.resnets = nn.ModuleList(resnets)

        if non_linearity is None:
            self.nonlinearity = None
        else:
            self.nonlinearity = get_activation(non_linearity)

        self.downsample = None
        if add_downsample:
            self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)

    def forward(self, hidden_states, temb=None):
        output_states = ()

        hidden_states = self.resnets[0](hidden_states, temb)
        for resnet in self.resnets[1:]:
            hidden_states = resnet(hidden_states, temb)

        output_states += (hidden_states,)

        if self.nonlinearity is not None:
            hidden_states = self.nonlinearity(hidden_states)

        if self.downsample is not None:
            hidden_states = self.downsample(hidden_states)

        return hidden_states, output_states


class UpResnetBlock1D(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels=None,
        num_layers=1,
        temb_channels=32,
        groups=32,
        groups_out=None,
        non_linearity=None,
        time_embedding_norm="default",
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.time_embedding_norm = time_embedding_norm
        self.add_upsample = add_upsample
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        # there will always be at least one resnet
        resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]

        for _ in range(num_layers):
            resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))

        self.resnets = nn.ModuleList(resnets)

        if non_linearity is None:
            self.nonlinearity = None
        else:
            self.nonlinearity = get_activation(non_linearity)

        self.upsample = None
        if add_upsample:
            self.upsample = Upsample1D(out_channels, use_conv_transpose=True)

    def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None):
        if res_hidden_states_tuple is not None:
            res_hidden_states = res_hidden_states_tuple[-1]
            hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)

        hidden_states = self.resnets[0](hidden_states, temb)
        for resnet in self.resnets[1:]:
            hidden_states = resnet(hidden_states, temb)

        if self.nonlinearity is not None:
            hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            hidden_states = self.upsample(hidden_states)

        return hidden_states


class ValueFunctionMidBlock1D(nn.Module):
    def __init__(self, in_channels, out_channels, embed_dim):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.embed_dim = embed_dim

        self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
        self.down1 = Downsample1D(out_channels // 2, use_conv=True)
        self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
        self.down2 = Downsample1D(out_channels // 4, use_conv=True)

    def forward(self, x, temb=None):
        x = self.res1(x, temb)
        x = self.down1(x)
        x = self.res2(x, temb)
        x = self.down2(x)
        return x


class MidResTemporalBlock1D(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        embed_dim,
        num_layers: int = 1,
        add_downsample: bool = False,
        add_upsample: bool = False,
        non_linearity=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.add_downsample = add_downsample

        # there will always be at least one resnet
        resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]

        for _ in range(num_layers):
            resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))

        self.resnets = nn.ModuleList(resnets)

        if non_linearity is None:
            self.nonlinearity = None
        else:
            self.nonlinearity = get_activation(non_linearity)

        self.upsample = None
        if add_upsample:
            self.upsample = Downsample1D(out_channels, use_conv=True)

        self.downsample = None
        if add_downsample:
            self.downsample = Downsample1D(out_channels, use_conv=True)

        if self.upsample and self.downsample:
            raise ValueError("Block cannot downsample and upsample")

    def forward(self, hidden_states, temb):
        hidden_states = self.resnets[0](hidden_states, temb)
        for resnet in self.resnets[1:]:
            hidden_states = resnet(hidden_states, temb)

        if self.upsample:
            hidden_states = self.upsample(hidden_states)
        if self.downsample:
            self.downsample = self.downsample(hidden_states)

        return hidden_states


class OutConv1DBlock(nn.Module):
    def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
        super().__init__()
        self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
        self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
        self.final_conv1d_act = get_activation(act_fn)
        self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)

    def forward(self, hidden_states, temb=None):
        hidden_states = self.final_conv1d_1(hidden_states)
        hidden_states = rearrange_dims(hidden_states)
        hidden_states = self.final_conv1d_gn(hidden_states)
        hidden_states = rearrange_dims(hidden_states)
        hidden_states = self.final_conv1d_act(hidden_states)
        hidden_states = self.final_conv1d_2(hidden_states)
        return hidden_states


class OutValueFunctionBlock(nn.Module):
    def __init__(self, fc_dim, embed_dim, act_fn="mish"):
        super().__init__()
        self.final_block = nn.ModuleList(
            [
                nn.Linear(fc_dim + embed_dim, fc_dim // 2),
                get_activation(act_fn),
                nn.Linear(fc_dim // 2, 1),
            ]
        )

    def forward(self, hidden_states, temb):
        hidden_states = hidden_states.view(hidden_states.shape[0], -1)
        hidden_states = torch.cat((hidden_states, temb), dim=-1)
        for layer in self.final_block:
            hidden_states = layer(hidden_states)

        return hidden_states


_kernels = {
    "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
    "cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
    "lanczos3": [
        0.003689131001010537,
        0.015056144446134567,
        -0.03399861603975296,
        -0.066637322306633,
        0.13550527393817902,
        0.44638532400131226,
        0.44638532400131226,
        0.13550527393817902,
        -0.066637322306633,
        -0.03399861603975296,
        0.015056144446134567,
        0.003689131001010537,
    ],
}


class Downsample1d(nn.Module):
    def __init__(self, kernel="linear", pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor(_kernels[kernel])
        self.pad = kernel_1d.shape[0] // 2 - 1
        self.register_buffer("kernel", kernel_1d)

    def forward(self, hidden_states):
        hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
        weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
        indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
        kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
        weight[indices, indices] = kernel
        return F.conv1d(hidden_states, weight, stride=2)


class Upsample1d(nn.Module):
    def __init__(self, kernel="linear", pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor(_kernels[kernel]) * 2
        self.pad = kernel_1d.shape[0] // 2 - 1
        self.register_buffer("kernel", kernel_1d)

    def forward(self, hidden_states, temb=None):
        hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
        weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
        indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
        kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
        weight[indices, indices] = kernel
        return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)


class SelfAttention1d(nn.Module):
    def __init__(self, in_channels, n_head=1, dropout_rate=0.0):
        super().__init__()
        self.channels = in_channels
        self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
        self.num_heads = n_head

        self.query = nn.Linear(self.channels, self.channels)
        self.key = nn.Linear(self.channels, self.channels)
        self.value = nn.Linear(self.channels, self.channels)

        self.proj_attn = nn.Linear(self.channels, self.channels, bias=True)

        self.dropout = nn.Dropout(dropout_rate, inplace=True)

    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
        new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
        return new_projection

    def forward(self, hidden_states):
        residual = hidden_states
        batch, channel_dim, seq = hidden_states.shape

        hidden_states = self.group_norm(hidden_states)
        hidden_states = hidden_states.transpose(1, 2)

        query_proj = self.query(hidden_states)
        key_proj = self.key(hidden_states)
        value_proj = self.value(hidden_states)

        query_states = self.transpose_for_scores(query_proj)
        key_states = self.transpose_for_scores(key_proj)
        value_states = self.transpose_for_scores(value_proj)

        scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1]))

        attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
        attention_probs = torch.softmax(attention_scores, dim=-1)

        # compute attention output
        hidden_states = torch.matmul(attention_probs, value_states)

        hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
        new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
        hidden_states = hidden_states.view(new_hidden_states_shape)

        # compute next hidden_states
        hidden_states = self.proj_attn(hidden_states)
        hidden_states = hidden_states.transpose(1, 2)
        hidden_states = self.dropout(hidden_states)

        output = hidden_states + residual

        return output


class ResConvBlock(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels, is_last=False):
        super().__init__()
        self.is_last = is_last
        self.has_conv_skip = in_channels != out_channels

        if self.has_conv_skip:
            self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)

        self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
        self.group_norm_1 = nn.GroupNorm(1, mid_channels)
        self.gelu_1 = nn.GELU()
        self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)

        if not self.is_last:
            self.group_norm_2 = nn.GroupNorm(1, out_channels)
            self.gelu_2 = nn.GELU()

    def forward(self, hidden_states):
        residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states

        hidden_states = self.conv_1(hidden_states)
        hidden_states = self.group_norm_1(hidden_states)
        hidden_states = self.gelu_1(hidden_states)
        hidden_states = self.conv_2(hidden_states)

        if not self.is_last:
            hidden_states = self.group_norm_2(hidden_states)
            hidden_states = self.gelu_2(hidden_states)

        output = hidden_states + residual
        return output


class UNetMidBlock1D(nn.Module):
    def __init__(self, mid_channels, in_channels, out_channels=None):
        super().__init__()

        out_channels = in_channels if out_channels is None else out_channels

        # there is always at least one resnet
        self.down = Downsample1d("cubic")
        resnets = [
            ResConvBlock(in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]
        attentions = [
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(out_channels, out_channels // 32),
        ]
        self.up = Upsample1d(kernel="cubic")

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None):
        hidden_states = self.down(hidden_states)
        for attn, resnet in zip(self.attentions, self.resnets):
            hidden_states = resnet(hidden_states)
            hidden_states = attn(hidden_states)

        hidden_states = self.up(hidden_states)

        return hidden_states


class AttnDownBlock1D(nn.Module):
    def __init__(self, out_channels, in_channels, mid_channels=None):
        super().__init__()
        mid_channels = out_channels if mid_channels is None else mid_channels

        self.down = Downsample1d("cubic")
        resnets = [
            ResConvBlock(in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]
        attentions = [
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(out_channels, out_channels // 32),
        ]

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None):
        hidden_states = self.down(hidden_states)

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states)
            hidden_states = attn(hidden_states)

        return hidden_states, (hidden_states,)


class DownBlock1D(nn.Module):
    def __init__(self, out_channels, in_channels, mid_channels=None):
        super().__init__()
        mid_channels = out_channels if mid_channels is None else mid_channels

        self.down = Downsample1d("cubic")
        resnets = [
            ResConvBlock(in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]

        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None):
        hidden_states = self.down(hidden_states)

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)

        return hidden_states, (hidden_states,)


class DownBlock1DNoSkip(nn.Module):
    def __init__(self, out_channels, in_channels, mid_channels=None):
        super().__init__()
        mid_channels = out_channels if mid_channels is None else mid_channels

        resnets = [
            ResConvBlock(in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]

        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None):
        hidden_states = torch.cat([hidden_states, temb], dim=1)
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)

        return hidden_states, (hidden_states,)


class AttnUpBlock1D(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = out_channels if mid_channels is None else mid_channels

        resnets = [
            ResConvBlock(2 * in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]
        attentions = [
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(mid_channels, mid_channels // 32),
            SelfAttention1d(out_channels, out_channels // 32),
        ]

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.up = Upsample1d(kernel="cubic")

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
        res_hidden_states = res_hidden_states_tuple[-1]
        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states)
            hidden_states = attn(hidden_states)

        hidden_states = self.up(hidden_states)

        return hidden_states


class UpBlock1D(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = in_channels if mid_channels is None else mid_channels

        resnets = [
            ResConvBlock(2 * in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels),
        ]

        self.resnets = nn.ModuleList(resnets)
        self.up = Upsample1d(kernel="cubic")

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
        res_hidden_states = res_hidden_states_tuple[-1]
        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)

        hidden_states = self.up(hidden_states)

        return hidden_states


class UpBlock1DNoSkip(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = in_channels if mid_channels is None else mid_channels

        resnets = [
            ResConvBlock(2 * in_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, mid_channels),
            ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
        ]

        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
        res_hidden_states = res_hidden_states_tuple[-1]
        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)

        return hidden_states


def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample):
    if down_block_type == "DownResnetBlock1D":
        return DownResnetBlock1D(
            in_channels=in_channels,
            num_layers=num_layers,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
        )
    elif down_block_type == "DownBlock1D":
        return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
    elif down_block_type == "AttnDownBlock1D":
        return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
    elif down_block_type == "DownBlock1DNoSkip":
        return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
    raise ValueError(f"{down_block_type} does not exist.")


def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample):
    if up_block_type == "UpResnetBlock1D":
        return UpResnetBlock1D(
            in_channels=in_channels,
            num_layers=num_layers,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
        )
    elif up_block_type == "UpBlock1D":
        return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
    elif up_block_type == "AttnUpBlock1D":
        return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
    elif up_block_type == "UpBlock1DNoSkip":
        return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
    raise ValueError(f"{up_block_type} does not exist.")


def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample):
    if mid_block_type == "MidResTemporalBlock1D":
        return MidResTemporalBlock1D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            embed_dim=embed_dim,
            add_downsample=add_downsample,
        )
    elif mid_block_type == "ValueFunctionMidBlock1D":
        return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
    elif mid_block_type == "UNetMidBlock1D":
        return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
    raise ValueError(f"{mid_block_type} does not exist.")


def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim):
    if out_block_type == "OutConv1DBlock":
        return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
    elif out_block_type == "ValueFunction":
        return OutValueFunctionBlock(fc_dim, embed_dim, act_fn)
    return None
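A minimal sketch of the factory helpers at the bottom of this file, wiring a single down/up pair by hand (illustrative, not part of the upload; the channel sizes and sequence length are assumed values):

import torch

down = get_down_block("DownBlock1D", num_layers=1, in_channels=32, out_channels=64,
                      temb_channels=32, add_downsample=True)
up = get_up_block("UpBlock1D", num_layers=1, in_channels=64, out_channels=32,
                  temb_channels=32, add_upsample=True)

x = torch.randn(1, 32, 128)
x, skips = down(x)   # cubic downsample halves the length: (1, 64, 64)
x = up(x, skips)     # skip concat, resnets, then upsample back to (1, 32, 128)
print(x.shape)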
Tiger Model/diffusiers-Tiger/models/unet_2d.py
ADDED
@@ -0,0 +1,329 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block


@dataclass
class UNet2DOutput(BaseOutput):
    """
    The output of [`UNet2DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output from the last layer of the model.
    """

    sample: torch.FloatTensor


class UNet2DModel(ModelMixin, ConfigMixin):
    r"""
    A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
            1)`.
        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
        time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
        freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip sin to cos for Fourier time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
            Tuple of downsample block types.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
            Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
            Tuple of block output channels.
        layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
        mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
        downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
        downsample_type (`str`, *optional*, defaults to `conv`):
            The downsample type for downsampling layers. Choose between "conv" and "resnet"
        upsample_type (`str`, *optional*, defaults to `conv`):
            The upsample type for upsampling layers. Choose between "conv" and "resnet"
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
        norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
        norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
            `"timestep"`, or `"identity"`.
        num_class_embeds (`int`, *optional*, defaults to `None`):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
            conditioning with `class_embed_type` equal to `None`.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
        in_channels: int = 3,
        out_channels: int = 3,
        center_input_sample: bool = False,
        time_embedding_type: str = "positional",
        freq_shift: int = 0,
        flip_sin_to_cos: bool = True,
        down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
        up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
        block_out_channels: Tuple[int] = (224, 448, 672, 896),
        layers_per_block: int = 2,
        mid_block_scale_factor: float = 1,
        downsample_padding: int = 1,
        downsample_type: str = "conv",
        upsample_type: str = "conv",
        act_fn: str = "silu",
        attention_head_dim: Optional[int] = 8,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-5,
        resnet_time_scale_shift: str = "default",
        add_attention: bool = True,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
    ):
        super().__init__()

        self.sample_size = sample_size
        time_embed_dim = block_out_channels[0] * 4

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        # input
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))

        # time
        if time_embedding_type == "fourier":
            self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
            timestep_input_dim = 2 * block_out_channels[0]
        elif time_embedding_type == "positional":
            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        else:
            self.class_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                downsample_padding=downsample_padding,
                resnet_time_scale_shift=resnet_time_scale_shift,
                downsample_type=downsample_type,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
            resnet_groups=norm_num_groups,
            add_attention=add_attention,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                resnet_time_scale_shift=resnet_time_scale_shift,
                upsample_type=upsample_type,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def forward(
|
229 |
+
self,
|
230 |
+
sample: torch.FloatTensor,
|
231 |
+
timestep: Union[torch.Tensor, float, int],
|
232 |
+
class_labels: Optional[torch.Tensor] = None,
|
233 |
+
return_dict: bool = True,
|
234 |
+
) -> Union[UNet2DOutput, Tuple]:
|
235 |
+
r"""
|
236 |
+
The [`UNet2DModel`] forward method.
|
237 |
+
|
238 |
+
Args:
|
239 |
+
sample (`torch.FloatTensor`):
|
240 |
+
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
241 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
242 |
+
class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
|
243 |
+
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
244 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
245 |
+
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
[`~models.unet_2d.UNet2DOutput`] or `tuple`:
|
249 |
+
If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
|
250 |
+
returned where the first element is the sample tensor.
|
251 |
+
"""
|
252 |
+
# 0. center input if necessary
|
253 |
+
if self.config.center_input_sample:
|
254 |
+
sample = 2 * sample - 1.0
|
255 |
+
|
256 |
+
# 1. time
|
257 |
+
timesteps = timestep
|
258 |
+
if not torch.is_tensor(timesteps):
|
259 |
+
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
260 |
+
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
261 |
+
timesteps = timesteps[None].to(sample.device)
|
262 |
+
|
263 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
264 |
+
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
|
265 |
+
|
266 |
+
t_emb = self.time_proj(timesteps)
|
267 |
+
|
268 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
269 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
270 |
+
# there might be better ways to encapsulate this.
|
271 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
272 |
+
emb = self.time_embedding(t_emb)
|
273 |
+
|
274 |
+
if self.class_embedding is not None:
|
275 |
+
if class_labels is None:
|
276 |
+
raise ValueError("class_labels should be provided when doing class conditioning")
|
277 |
+
|
278 |
+
if self.config.class_embed_type == "timestep":
|
279 |
+
class_labels = self.time_proj(class_labels)
|
280 |
+
|
281 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
282 |
+
emb = emb + class_emb
|
283 |
+
|
284 |
+
# 2. pre-process
|
285 |
+
skip_sample = sample
|
286 |
+
sample = self.conv_in(sample)
|
287 |
+
|
288 |
+
# 3. down
|
289 |
+
down_block_res_samples = (sample,)
|
290 |
+
for downsample_block in self.down_blocks:
|
291 |
+
if hasattr(downsample_block, "skip_conv"):
|
292 |
+
sample, res_samples, skip_sample = downsample_block(
|
293 |
+
hidden_states=sample, temb=emb, skip_sample=skip_sample
|
294 |
+
)
|
295 |
+
else:
|
296 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
297 |
+
|
298 |
+
down_block_res_samples += res_samples
|
299 |
+
|
300 |
+
# 4. mid
|
301 |
+
sample = self.mid_block(sample, emb)
|
302 |
+
|
303 |
+
# 5. up
|
304 |
+
skip_sample = None
|
305 |
+
for upsample_block in self.up_blocks:
|
306 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
307 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
308 |
+
|
309 |
+
if hasattr(upsample_block, "skip_conv"):
|
310 |
+
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
|
311 |
+
else:
|
312 |
+
sample = upsample_block(sample, res_samples, emb)
|
313 |
+
|
314 |
+
# 6. post-process
|
315 |
+
sample = self.conv_norm_out(sample)
|
316 |
+
sample = self.conv_act(sample)
|
317 |
+
sample = self.conv_out(sample)
|
318 |
+
|
319 |
+
if skip_sample is not None:
|
320 |
+
sample += skip_sample
|
321 |
+
|
322 |
+
if self.config.time_embedding_type == "fourier":
|
323 |
+
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
|
324 |
+
sample = sample / timesteps
|
325 |
+
|
326 |
+
if not return_dict:
|
327 |
+
return (sample,)
|
328 |
+
|
329 |
+
return UNet2DOutput(sample=sample)
|
Tiger Model/diffusiers-Tiger/models/unet_2d_blocks.py
ADDED
The diff for this file is too large to render.
See raw diff
Tiger Model/diffusiers-Tiger/models/unet_2d_condition.py
ADDED
@@ -0,0 +1,1009 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint

from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging
from .activations import get_activation
from .attention_processor import AttentionProcessor, AttnProcessor
from .embeddings import (
    GaussianFourierProjection,
    ImageHintTimeEmbedding,
    ImageProjection,
    ImageTimeEmbedding,
    PositionNet,
    TextImageProjection,
    TextImageTimeEmbedding,
    TextTimeEmbedding,
    TimestepEmbedding,
    Timesteps,
)
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
    UNetMidBlock2DCrossAttn,
    UNetMidBlock2DSimpleCrossAttn,
    get_down_block,
    get_up_block,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    The output of [`UNet2DConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    sample: torch.FloatTensor = None


class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a
    sample-shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
            Block type for the middle of the UNet, it can be either `UNetMidBlock2DCrossAttn` or
            `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
            The tuple of upsample blocks to use.
        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
            Whether to include self-attention in the basic transformer blocks, see
            [`~models.attention.BasicTransformerBlock`].
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, normalization and activation layers are skipped in post-processing.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
        encoder_hid_dim (`int`, *optional*, defaults to `None`):
            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
            dimension to `cross_attention_dim`.
        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
        num_attention_heads (`int`, *optional*):
            The number of attention heads. If not defined, defaults to `attention_head_dim`.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from
            `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
        addition_embed_type (`str`, *optional*, defaults to `None`):
            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
            `"text"`. `"text"` will use the `TextTimeEmbedding` layer.
        addition_time_embed_dim (`int`, *optional*, defaults to `None`):
            Dimension for the timestep embeddings.
        num_class_embeds (`int`, *optional*, defaults to `None`):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing
            class conditioning with `class_embed_type` equal to `None`.
        time_embedding_type (`str`, *optional*, defaults to `positional`):
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
        time_embedding_dim (`int`, *optional*, defaults to `None`):
            An optional override for the dimension of the projected time embedding.
        time_embedding_act_fn (`str`, *optional*, defaults to `None`):
            Optional activation function to use only once on the time embeddings before they are passed to the rest
            of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
        timestep_post_act (`str`, *optional*, defaults to `None`):
            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
            The dimension of `cond_proj` layer in the timestep embedding.
        conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_in` layer.
        conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_out` layer.
        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
            `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
            embeddings with the class embeddings.
        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to
            `False` otherwise.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
|
154 |
+
def __init__(
|
155 |
+
self,
|
156 |
+
sample_size: Optional[int] = None,
|
157 |
+
in_channels: int = 4,
|
158 |
+
out_channels: int = 4,
|
159 |
+
center_input_sample: bool = False,
|
160 |
+
flip_sin_to_cos: bool = True,
|
161 |
+
freq_shift: int = 0,
|
162 |
+
down_block_types: Tuple[str] = (
|
163 |
+
"CrossAttnDownBlock2D",
|
164 |
+
"CrossAttnDownBlock2D",
|
165 |
+
"CrossAttnDownBlock2D",
|
166 |
+
"DownBlock2D",
|
167 |
+
),
|
168 |
+
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
169 |
+
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
170 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
171 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
172 |
+
layers_per_block: Union[int, Tuple[int]] = 2,
|
173 |
+
downsample_padding: int = 1,
|
174 |
+
mid_block_scale_factor: float = 1,
|
175 |
+
act_fn: str = "silu",
|
176 |
+
norm_num_groups: Optional[int] = 32,
|
177 |
+
norm_eps: float = 1e-5,
|
178 |
+
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
179 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
180 |
+
encoder_hid_dim: Optional[int] = None,
|
181 |
+
encoder_hid_dim_type: Optional[str] = None,
|
182 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
183 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
184 |
+
dual_cross_attention: bool = False,
|
185 |
+
use_linear_projection: bool = False,
|
186 |
+
class_embed_type: Optional[str] = None,
|
187 |
+
addition_embed_type: Optional[str] = None,
|
188 |
+
addition_time_embed_dim: Optional[int] = None,
|
189 |
+
num_class_embeds: Optional[int] = None,
|
190 |
+
upcast_attention: bool = False,
|
191 |
+
resnet_time_scale_shift: str = "default",
|
192 |
+
resnet_skip_time_act: bool = False,
|
193 |
+
resnet_out_scale_factor: int = 1.0,
|
194 |
+
time_embedding_type: str = "positional",
|
195 |
+
time_embedding_dim: Optional[int] = None,
|
196 |
+
time_embedding_act_fn: Optional[str] = None,
|
197 |
+
timestep_post_act: Optional[str] = None,
|
198 |
+
time_cond_proj_dim: Optional[int] = None,
|
199 |
+
conv_in_kernel: int = 3,
|
200 |
+
conv_out_kernel: int = 3,
|
201 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
202 |
+
attention_type: str = "default",
|
203 |
+
class_embeddings_concat: bool = False,
|
204 |
+
mid_block_only_cross_attention: Optional[bool] = None,
|
205 |
+
cross_attention_norm: Optional[str] = None,
|
206 |
+
addition_embed_type_num_heads=64,
|
207 |
+
):
|
208 |
+
super().__init__()
|
209 |
+
|
210 |
+
self.sample_size = sample_size
|
211 |
+
|
212 |
+
if num_attention_heads is not None:
|
213 |
+
raise ValueError(
|
214 |
+
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
215 |
+
)
|
216 |
+
|
217 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
218 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
219 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
220 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
221 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
222 |
+
# which is why we correct for the naming here.
|
223 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
224 |
+
|
225 |
+
# Check inputs
|
226 |
+
if len(down_block_types) != len(up_block_types):
|
227 |
+
raise ValueError(
|
228 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
229 |
+
)
|
230 |
+
|
231 |
+
if len(block_out_channels) != len(down_block_types):
|
232 |
+
raise ValueError(
|
233 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
234 |
+
)
|
235 |
+
|
236 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
237 |
+
raise ValueError(
|
238 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
239 |
+
)
|
240 |
+
|
241 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
242 |
+
raise ValueError(
|
243 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
244 |
+
)
|
245 |
+
|
246 |
+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
|
247 |
+
raise ValueError(
|
248 |
+
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
249 |
+
)
|
250 |
+
|
251 |
+
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
|
252 |
+
raise ValueError(
|
253 |
+
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
254 |
+
)
|
255 |
+
|
256 |
+
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
|
257 |
+
raise ValueError(
|
258 |
+
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
259 |
+
)
|
260 |
+
|
261 |
+
# input
|
262 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
263 |
+
self.conv_in = nn.Conv2d(
|
264 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
265 |
+
)
|
266 |
+
|
267 |
+
# time
|
268 |
+
if time_embedding_type == "fourier":
|
269 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
|
270 |
+
if time_embed_dim % 2 != 0:
|
271 |
+
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
|
272 |
+
self.time_proj = GaussianFourierProjection(
|
273 |
+
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
274 |
+
)
|
275 |
+
timestep_input_dim = time_embed_dim
|
276 |
+
elif time_embedding_type == "positional":
|
277 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
278 |
+
|
279 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
280 |
+
timestep_input_dim = block_out_channels[0]
|
281 |
+
else:
|
282 |
+
raise ValueError(
|
283 |
+
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
|
284 |
+
)
|
285 |
+
|
286 |
+
self.time_embedding = TimestepEmbedding(
|
287 |
+
timestep_input_dim,
|
288 |
+
time_embed_dim,
|
289 |
+
act_fn=act_fn,
|
290 |
+
post_act_fn=timestep_post_act,
|
291 |
+
cond_proj_dim=time_cond_proj_dim,
|
292 |
+
)
|
293 |
+
|
294 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
295 |
+
encoder_hid_dim_type = "text_proj"
|
296 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
297 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
298 |
+
|
299 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
300 |
+
raise ValueError(
|
301 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
302 |
+
)
|
303 |
+
|
304 |
+
if encoder_hid_dim_type == "text_proj":
|
305 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
306 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
307 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
308 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
309 |
+
# case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
|
310 |
+
self.encoder_hid_proj = TextImageProjection(
|
311 |
+
text_embed_dim=encoder_hid_dim,
|
312 |
+
image_embed_dim=cross_attention_dim,
|
313 |
+
cross_attention_dim=cross_attention_dim,
|
314 |
+
)
|
315 |
+
elif encoder_hid_dim_type == "image_proj":
|
316 |
+
# Kandinsky 2.2
|
317 |
+
self.encoder_hid_proj = ImageProjection(
|
318 |
+
image_embed_dim=encoder_hid_dim,
|
319 |
+
cross_attention_dim=cross_attention_dim,
|
320 |
+
)
|
321 |
+
elif encoder_hid_dim_type is not None:
|
322 |
+
raise ValueError(
|
323 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
324 |
+
)
|
325 |
+
else:
|
326 |
+
self.encoder_hid_proj = None
|
327 |
+
|
328 |
+
# class embedding
|
329 |
+
if class_embed_type is None and num_class_embeds is not None:
|
330 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
331 |
+
elif class_embed_type == "timestep":
|
332 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
|
333 |
+
elif class_embed_type == "identity":
|
334 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
335 |
+
elif class_embed_type == "projection":
|
336 |
+
if projection_class_embeddings_input_dim is None:
|
337 |
+
raise ValueError(
|
338 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
339 |
+
)
|
340 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
341 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
342 |
+
# 2. it projects from an arbitrary input dimension.
|
343 |
+
#
|
344 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
345 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
346 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
347 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
348 |
+
elif class_embed_type == "simple_projection":
|
349 |
+
if projection_class_embeddings_input_dim is None:
|
350 |
+
raise ValueError(
|
351 |
+
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
|
352 |
+
)
|
353 |
+
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
|
354 |
+
else:
|
355 |
+
self.class_embedding = None
|
356 |
+
|
357 |
+
if addition_embed_type == "text":
|
358 |
+
if encoder_hid_dim is not None:
|
359 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
360 |
+
else:
|
361 |
+
text_time_embedding_from_dim = cross_attention_dim
|
362 |
+
|
363 |
+
self.add_embedding = TextTimeEmbedding(
|
364 |
+
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
365 |
+
)
|
366 |
+
elif addition_embed_type == "text_image":
|
367 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
368 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
369 |
+
# case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
|
370 |
+
self.add_embedding = TextImageTimeEmbedding(
|
371 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
372 |
+
)
|
373 |
+
elif addition_embed_type == "text_time":
|
374 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
375 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
376 |
+
elif addition_embed_type == "image":
|
377 |
+
# Kandinsky 2.2
|
378 |
+
self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
379 |
+
elif addition_embed_type == "image_hint":
|
380 |
+
# Kandinsky 2.2 ControlNet
|
381 |
+
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
382 |
+
elif addition_embed_type is not None:
|
383 |
+
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
384 |
+
|
385 |
+
if time_embedding_act_fn is None:
|
386 |
+
self.time_embed_act = None
|
387 |
+
else:
|
388 |
+
self.time_embed_act = get_activation(time_embedding_act_fn)
|
389 |
+
|
390 |
+
self.down_blocks = nn.ModuleList([])
|
391 |
+
self.up_blocks = nn.ModuleList([])
|
392 |
+
|
393 |
+
if isinstance(only_cross_attention, bool):
|
394 |
+
if mid_block_only_cross_attention is None:
|
395 |
+
mid_block_only_cross_attention = only_cross_attention
|
396 |
+
|
397 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
398 |
+
|
399 |
+
if mid_block_only_cross_attention is None:
|
400 |
+
mid_block_only_cross_attention = False
|
401 |
+
|
402 |
+
if isinstance(num_attention_heads, int):
|
403 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
404 |
+
|
405 |
+
if isinstance(attention_head_dim, int):
|
406 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
407 |
+
|
408 |
+
if isinstance(cross_attention_dim, int):
|
409 |
+
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
410 |
+
|
411 |
+
if isinstance(layers_per_block, int):
|
412 |
+
layers_per_block = [layers_per_block] * len(down_block_types)
|
413 |
+
|
414 |
+
if isinstance(transformer_layers_per_block, int):
|
415 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
416 |
+
|
417 |
+
if class_embeddings_concat:
|
418 |
+
# The time embeddings are concatenated with the class embeddings. The dimension of the
|
419 |
+
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
|
420 |
+
# regular time embeddings
|
421 |
+
blocks_time_embed_dim = time_embed_dim * 2
|
422 |
+
else:
|
423 |
+
blocks_time_embed_dim = time_embed_dim
|
424 |
+
|
425 |
+
# down
|
426 |
+
output_channel = block_out_channels[0]
|
427 |
+
for i, down_block_type in enumerate(down_block_types):
|
428 |
+
input_channel = output_channel
|
429 |
+
output_channel = block_out_channels[i]
|
430 |
+
is_final_block = i == len(block_out_channels) - 1
|
431 |
+
|
432 |
+
down_block = get_down_block(
|
433 |
+
down_block_type,
|
434 |
+
num_layers=layers_per_block[i],
|
435 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
436 |
+
in_channels=input_channel,
|
437 |
+
out_channels=output_channel,
|
438 |
+
temb_channels=blocks_time_embed_dim,
|
439 |
+
add_downsample=not is_final_block,
|
440 |
+
resnet_eps=norm_eps,
|
441 |
+
resnet_act_fn=act_fn,
|
442 |
+
resnet_groups=norm_num_groups,
|
443 |
+
cross_attention_dim=cross_attention_dim[i],
|
444 |
+
num_attention_heads=num_attention_heads[i],
|
445 |
+
downsample_padding=downsample_padding,
|
446 |
+
dual_cross_attention=dual_cross_attention,
|
447 |
+
use_linear_projection=use_linear_projection,
|
448 |
+
only_cross_attention=only_cross_attention[i],
|
449 |
+
upcast_attention=upcast_attention,
|
450 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
451 |
+
attention_type=attention_type,
|
452 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
453 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
454 |
+
cross_attention_norm=cross_attention_norm,
|
455 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
456 |
+
)
|
457 |
+
self.down_blocks.append(down_block)
|
458 |
+
|
459 |
+
# mid
|
460 |
+
if mid_block_type == "UNetMidBlock2DCrossAttn":
|
461 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
462 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
463 |
+
in_channels=block_out_channels[-1],
|
464 |
+
temb_channels=blocks_time_embed_dim,
|
465 |
+
resnet_eps=norm_eps,
|
466 |
+
resnet_act_fn=act_fn,
|
467 |
+
output_scale_factor=mid_block_scale_factor,
|
468 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
469 |
+
cross_attention_dim=cross_attention_dim[-1],
|
470 |
+
num_attention_heads=num_attention_heads[-1],
|
471 |
+
resnet_groups=norm_num_groups,
|
472 |
+
dual_cross_attention=dual_cross_attention,
|
473 |
+
use_linear_projection=use_linear_projection,
|
474 |
+
upcast_attention=upcast_attention,
|
475 |
+
attention_type=attention_type,
|
476 |
+
)
|
477 |
+
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
|
478 |
+
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
|
479 |
+
in_channels=block_out_channels[-1],
|
480 |
+
temb_channels=blocks_time_embed_dim,
|
481 |
+
resnet_eps=norm_eps,
|
482 |
+
resnet_act_fn=act_fn,
|
483 |
+
output_scale_factor=mid_block_scale_factor,
|
484 |
+
cross_attention_dim=cross_attention_dim[-1],
|
485 |
+
attention_head_dim=attention_head_dim[-1],
|
486 |
+
resnet_groups=norm_num_groups,
|
487 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
488 |
+
skip_time_act=resnet_skip_time_act,
|
489 |
+
only_cross_attention=mid_block_only_cross_attention,
|
490 |
+
cross_attention_norm=cross_attention_norm,
|
491 |
+
)
|
492 |
+
elif mid_block_type is None:
|
493 |
+
self.mid_block = None
|
494 |
+
else:
|
495 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
496 |
+
|
497 |
+
# count how many layers upsample the images
|
498 |
+
self.num_upsamplers = 0
|
499 |
+
|
500 |
+
# up
|
501 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
502 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
503 |
+
reversed_layers_per_block = list(reversed(layers_per_block))
|
504 |
+
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
505 |
+
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
|
506 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
507 |
+
|
508 |
+
output_channel = reversed_block_out_channels[0]
|
509 |
+
for i, up_block_type in enumerate(up_block_types):
|
510 |
+
is_final_block = i == len(block_out_channels) - 1
|
511 |
+
|
512 |
+
prev_output_channel = output_channel
|
513 |
+
output_channel = reversed_block_out_channels[i]
|
514 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
515 |
+
|
516 |
+
# add upsample block for all BUT final layer
|
517 |
+
if not is_final_block:
|
518 |
+
add_upsample = True
|
519 |
+
self.num_upsamplers += 1
|
520 |
+
else:
|
521 |
+
add_upsample = False
|
522 |
+
|
523 |
+
up_block = get_up_block(
|
524 |
+
up_block_type,
|
525 |
+
num_layers=reversed_layers_per_block[i] + 1,
|
526 |
+
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
527 |
+
in_channels=input_channel,
|
528 |
+
out_channels=output_channel,
|
529 |
+
prev_output_channel=prev_output_channel,
|
530 |
+
temb_channels=blocks_time_embed_dim,
|
531 |
+
add_upsample=add_upsample,
|
532 |
+
resnet_eps=norm_eps,
|
533 |
+
resnet_act_fn=act_fn,
|
534 |
+
resnet_groups=norm_num_groups,
|
535 |
+
cross_attention_dim=reversed_cross_attention_dim[i],
|
536 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
537 |
+
dual_cross_attention=dual_cross_attention,
|
538 |
+
use_linear_projection=use_linear_projection,
|
539 |
+
only_cross_attention=only_cross_attention[i],
|
540 |
+
upcast_attention=upcast_attention,
|
541 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
542 |
+
attention_type=attention_type,
|
543 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
544 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
545 |
+
cross_attention_norm=cross_attention_norm,
|
546 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
547 |
+
)
|
548 |
+
self.up_blocks.append(up_block)
|
549 |
+
prev_output_channel = output_channel
|
550 |
+
|
551 |
+
# out
|
552 |
+
if norm_num_groups is not None:
|
553 |
+
self.conv_norm_out = nn.GroupNorm(
|
554 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
555 |
+
)
|
556 |
+
|
557 |
+
self.conv_act = get_activation(act_fn)
|
558 |
+
|
559 |
+
else:
|
560 |
+
self.conv_norm_out = None
|
561 |
+
self.conv_act = None
|
562 |
+
|
563 |
+
conv_out_padding = (conv_out_kernel - 1) // 2
|
564 |
+
self.conv_out = nn.Conv2d(
|
565 |
+
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
566 |
+
)
|
567 |
+
|
568 |
+
if attention_type == "gated":
|
569 |
+
positive_len = 768
|
570 |
+
if isinstance(cross_attention_dim, int):
|
571 |
+
positive_len = cross_attention_dim
|
572 |
+
elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
|
573 |
+
positive_len = cross_attention_dim[0]
|
574 |
+
self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim)
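For reference only, a hedged construction-and-call sketch for this conditional UNet follows; it is not taken from the Tiger training scripts. The latent shape, batch size, and the 77-token, 768-dim text-encoder output are assumptions in the style of Stable Diffusion checkpoints, and the stock diffusers import path is assumed rather than the vendored copy.

# Hedged usage sketch only; assumes the constructor defaults and forward signature shown in this file.
import torch
from diffusers import UNet2DConditionModel  # stock import path assumed

unet = UNet2DConditionModel(
    sample_size=64,
    in_channels=4,
    out_channels=4,
    cross_attention_dim=768,
)

latents = torch.randn(2, 4, 64, 64)      # noisy latents (batch, channel, height, width)
timesteps = torch.randint(0, 1000, (2,)) # assumed schedule length
text_emb = torch.randn(2, 77, 768)       # encoder_hidden_states (batch, sequence_length, feature_dim)

with torch.no_grad():
    noise_pred = unet(latents, timesteps, encoder_hidden_states=text_emb).sample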
|
575 |
+
|
576 |
+
@property
|
577 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
578 |
+
r"""
|
579 |
+
Returns:
|
580 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model,
|
581 |
+
indexed by their weight name.
|
582 |
+
"""
|
583 |
+
# set recursively
|
584 |
+
processors = {}
|
585 |
+
|
586 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
587 |
+
if hasattr(module, "set_processor"):
|
588 |
+
processors[f"{name}.processor"] = module.processor
|
589 |
+
|
590 |
+
for sub_name, child in module.named_children():
|
591 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
592 |
+
|
593 |
+
return processors
|
594 |
+
|
595 |
+
for name, module in self.named_children():
|
596 |
+
fn_recursive_add_processors(name, module, processors)
|
597 |
+
|
598 |
+
return processors
|
599 |
+
|
600 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
601 |
+
r"""
|
602 |
+
Sets the attention processor to use to compute attention.
|
603 |
+
|
604 |
+
Parameters:
|
605 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
606 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
607 |
+
for **all** `Attention` layers.
|
608 |
+
|
609 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
610 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
611 |
+
|
612 |
+
"""
|
613 |
+
count = len(self.attn_processors.keys())
|
614 |
+
|
615 |
+
if isinstance(processor, dict) and len(processor) != count:
|
616 |
+
raise ValueError(
|
617 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
618 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
619 |
+
)
|
620 |
+
|
621 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
622 |
+
if hasattr(module, "set_processor"):
|
623 |
+
if not isinstance(processor, dict):
|
624 |
+
module.set_processor(processor)
|
625 |
+
else:
|
626 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
627 |
+
|
628 |
+
for sub_name, child in module.named_children():
|
629 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
630 |
+
|
631 |
+
for name, module in self.named_children():
|
632 |
+
fn_recursive_attn_processor(name, module, processor)
|
633 |
+
|
634 |
+
def set_default_attn_processor(self):
|
635 |
+
"""
|
636 |
+
Disables custom attention processors and sets the default attention implementation.
|
637 |
+
"""
|
638 |
+
self.set_attn_processor(AttnProcessor())
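As a usage note that is not part of the uploaded file, the attention-processor hooks above are typically exercised as in the hedged sketch below; `unet` is assumed to be an already constructed model of this class, and the stock diffusers import path is assumed.

# Hedged sketch: swap in a dict of processors keyed by the ".processor" paths, then restore defaults.
from diffusers.models.attention_processor import AttnProcessor  # stock import path assumed

custom_processors = {name: AttnProcessor() for name in unet.attn_processors.keys()}
unet.set_attn_processor(custom_processors)  # dict keys must match the paths reported by `attn_processors`
unet.set_default_attn_processor()           # back to the default attention implementation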
|
639 |
+
|
640 |
+
def set_attention_slice(self, slice_size):
|
641 |
+
r"""
|
642 |
+
Enable sliced attention computation.
|
643 |
+
|
644 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
645 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
646 |
+
|
647 |
+
Args:
|
648 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
649 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
650 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
651 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
652 |
+
must be a multiple of `slice_size`.
|
653 |
+
"""
|
654 |
+
sliceable_head_dims = []
|
655 |
+
|
656 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
657 |
+
if hasattr(module, "set_attention_slice"):
|
658 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
659 |
+
|
660 |
+
for child in module.children():
|
661 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
662 |
+
|
663 |
+
# retrieve number of attention layers
|
664 |
+
for module in self.children():
|
665 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
666 |
+
|
667 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
668 |
+
|
669 |
+
if slice_size == "auto":
|
670 |
+
# half the attention head size is usually a good trade-off between
|
671 |
+
# speed and memory
|
672 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
673 |
+
elif slice_size == "max":
|
674 |
+
# make smallest slice possible
|
675 |
+
slice_size = num_sliceable_layers * [1]
|
676 |
+
|
677 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
678 |
+
|
679 |
+
if len(slice_size) != len(sliceable_head_dims):
|
680 |
+
raise ValueError(
|
681 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
682 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
683 |
+
)
|
684 |
+
|
685 |
+
for i in range(len(slice_size)):
|
686 |
+
size = slice_size[i]
|
687 |
+
dim = sliceable_head_dims[i]
|
688 |
+
if size is not None and size > dim:
|
689 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
690 |
+
|
691 |
+
# Recursively walk through all the children.
|
692 |
+
# Any children which exposes the set_attention_slice method
|
693 |
+
# gets the message
|
694 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
695 |
+
if hasattr(module, "set_attention_slice"):
|
696 |
+
module.set_attention_slice(slice_size.pop())
|
697 |
+
|
698 |
+
for child in module.children():
|
699 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
700 |
+
|
701 |
+
reversed_slice_size = list(reversed(slice_size))
|
702 |
+
for module in self.children():
|
703 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
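Purely for illustration, the snippet below replays the slice-size resolution implemented above on a hypothetical set of sliceable head dims; the numbers are made up and not taken from any Tiger configuration.

# Illustrative only: how `set_attention_slice` resolves "auto", "max", or an int for three layers.
sliceable_head_dims = [8, 16, 32]  # hypothetical per-layer head dims

def resolve(slice_size):
    if slice_size == "auto":
        return [dim // 2 for dim in sliceable_head_dims]  # halve every head dim
    if slice_size == "max":
        return len(sliceable_head_dims) * [1]             # smallest possible slices
    return len(sliceable_head_dims) * [slice_size]        # one fixed size per layer

print(resolve("auto"))  # [4, 8, 16]
print(resolve("max"))   # [1, 1, 1]
print(resolve(4))       # [4, 4, 4]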
|
704 |
+
|
705 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
706 |
+
if hasattr(module, "gradient_checkpointing"):
|
707 |
+
module.gradient_checkpointing = value
|
708 |
+
|
709 |
+
def forward(
|
710 |
+
self,
|
711 |
+
sample: torch.FloatTensor,
|
712 |
+
timestep: Union[torch.Tensor, float, int],
|
713 |
+
encoder_hidden_states: torch.Tensor,
|
714 |
+
class_labels: Optional[torch.Tensor] = None,
|
715 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
716 |
+
attention_mask: Optional[torch.Tensor] = None,
|
717 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
718 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
719 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
720 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
721 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
722 |
+
return_dict: bool = True,
|
723 |
+
) -> Union[UNet2DConditionOutput, Tuple]:
|
724 |
+
r"""
|
725 |
+
The [`UNet2DConditionModel`] forward method.
|
726 |
+
|
727 |
+
Args:
|
728 |
+
sample (`torch.FloatTensor`):
|
729 |
+
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
730 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
731 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
732 |
+
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
733 |
+
encoder_attention_mask (`torch.Tensor`):
|
734 |
+
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
735 |
+
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
736 |
+
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
737 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
738 |
+
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
739 |
+
tuple.
|
740 |
+
cross_attention_kwargs (`dict`, *optional*):
|
741 |
+
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
742 |
+
added_cond_kwargs: (`dict`, *optional*):
|
743 |
+
A kwargs dictionary containing additional embeddings that, if specified, are added to the embeddings that
|
744 |
+
are passed along to the UNet blocks.
|
745 |
+
|
746 |
+
Returns:
|
747 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
748 |
+
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
749 |
+
a `tuple` is returned where the first element is the sample tensor.
|
750 |
+
"""
|
751 |
+
# By default, samples have to be at least a multiple of the overall upsampling factor.
|
752 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
753 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
754 |
+
# on the fly if necessary.
|
755 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
756 |
+
|
757 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
758 |
+
forward_upsample_size = False
|
759 |
+
upsample_size = None
|
760 |
+
|
761 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
762 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
763 |
+
forward_upsample_size = True
|
764 |
+
|
765 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
766 |
+
# expects mask of shape:
|
767 |
+
# [batch, key_tokens]
|
768 |
+
# adds singleton query_tokens dimension:
|
769 |
+
# [batch, 1, key_tokens]
|
770 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
771 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
772 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
773 |
+
if attention_mask is not None:
|
774 |
+
# assume that mask is expressed as:
|
775 |
+
# (1 = keep, 0 = discard)
|
776 |
+
# convert mask into a bias that can be added to attention scores:
|
777 |
+
# (keep = +0, discard = -10000.0)
|
778 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
779 |
+
attention_mask = attention_mask.unsqueeze(1)
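# e.g. an `attention_mask` of [1, 1, 0] becomes the additive bias [0.0, 0.0, -10000.0], and
# unsqueeze(1) adds the singleton query_tokens dimension so it broadcasts over attention scores.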
|
780 |
+
|
        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
        elif self.config.addition_embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embs = added_cond_kwargs.get("image_embeds")
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
            aug_emb = self.add_embedding(text_embs, image_embs)
        elif self.config.addition_embed_type == "text_time":
            # SDXL - style
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
        elif self.config.addition_embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            hint = added_cond_kwargs.get("hint")
            aug_emb, hint = self.add_embedding(image_embs, hint)
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kadinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
        # 2. pre-process
        sample = self.conv_in(sample)

        # 2.5 GLIGEN position net
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            gligen_args = cross_attention_kwargs.pop("gligen")
            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

        # 3. down

        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None

        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                # For t2i-adapter CrossAttnDownBlock2D
                additional_residuals = {}
                if is_adapter and len(down_block_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)

                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

                if is_adapter and len(down_block_additional_residuals) > 0:
                    sample += down_block_additional_residuals.pop(0)

            down_block_res_samples += res_samples

        if is_controlnet:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )

        if is_controlnet:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )

        # 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample)
Tiger Model/diffusiers-Tiger/models/unet_3d_blocks.py
ADDED
@@ -0,0 +1,679 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn

from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
from .transformer_2d import Transformer2DModel
from .transformer_temporal import TransformerTemporalModel


def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    num_attention_heads,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=True,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
):
    if down_block_type == "DownBlock3D":
        return DownBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "CrossAttnDownBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
        return CrossAttnDownBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    raise ValueError(f"{down_block_type} does not exist.")


def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    num_attention_heads,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=True,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
):
    if up_block_type == "UpBlock3D":
        return UpBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif up_block_type == "CrossAttnUpBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
        return CrossAttnUpBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    raise ValueError(f"{up_block_type} does not exist.")


class UNetMidBlock3DCrossAttn(nn.Module):
    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=True,
        upcast_attention=False,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        # there is always at least one resnet
        resnets = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        temp_convs = [
            TemporalConvLayer(
                in_channels,
                in_channels,
                dropout=0.1,
            )
        ]
        attentions = []
        temp_attentions = []

        for _ in range(num_layers):
            attentions.append(
                Transformer2DModel(
                    in_channels // num_attention_heads,
                    num_attention_heads,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                )
            )
            temp_attentions.append(
                TransformerTemporalModel(
                    in_channels // num_attention_heads,
                    num_attention_heads,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            )
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_convs.append(
                TemporalConvLayer(
                    in_channels,
                    in_channels,
                    dropout=0.1,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)
        self.attentions = nn.ModuleList(attentions)
        self.temp_attentions = nn.ModuleList(temp_attentions)

    def forward(
        self,
        hidden_states,
        temb=None,
        encoder_hidden_states=None,
        attention_mask=None,
        num_frames=1,
        cross_attention_kwargs=None,
    ):
        hidden_states = self.resnets[0](hidden_states, temb)
        hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
        for attn, temp_attn, resnet, temp_conv in zip(
            self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
        ):
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
            hidden_states = temp_attn(
                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
            )[0]
            hidden_states = resnet(hidden_states, temb)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)

        return hidden_states


class CrossAttnDownBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
    ):
        super().__init__()
        resnets = []
        attentions = []
        temp_attentions = []
        temp_convs = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_convs.append(
                TemporalConvLayer(
                    out_channels,
                    out_channels,
                    dropout=0.1,
                )
            )
            attentions.append(
                Transformer2DModel(
                    out_channels // num_attention_heads,
                    num_attention_heads,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                )
            )
            temp_attentions.append(
                TransformerTemporalModel(
                    out_channels // num_attention_heads,
                    num_attention_heads,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            )
        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)
        self.attentions = nn.ModuleList(attentions)
        self.temp_attentions = nn.ModuleList(temp_attentions)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        temb=None,
        encoder_hidden_states=None,
        attention_mask=None,
        num_frames=1,
        cross_attention_kwargs=None,
    ):
        # TODO(Patrick, William) - attention mask is not used
        output_states = ()

        for resnet, temp_conv, attn, temp_attn in zip(
            self.resnets, self.temp_convs, self.attentions, self.temp_attentions
        ):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
            hidden_states = temp_attn(
                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
            )[0]

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states


class DownBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()
        resnets = []
        temp_convs = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_convs.append(
                TemporalConvLayer(
                    out_channels,
                    out_channels,
                    dropout=0.1,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, num_frames=1):
        output_states = ()

        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states


class CrossAttnUpBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
    ):
        super().__init__()
        resnets = []
        temp_convs = []
        attentions = []
        temp_attentions = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_convs.append(
                TemporalConvLayer(
                    out_channels,
                    out_channels,
                    dropout=0.1,
                )
            )
            attentions.append(
                Transformer2DModel(
                    out_channels // num_attention_heads,
                    num_attention_heads,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                )
            )
            temp_attentions.append(
                TransformerTemporalModel(
                    out_channels // num_attention_heads,
                    num_attention_heads,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            )
        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)
        self.attentions = nn.ModuleList(attentions)
        self.temp_attentions = nn.ModuleList(temp_attentions)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
        attention_mask=None,
        num_frames=1,
        cross_attention_kwargs=None,
    ):
        # TODO(Patrick, William) - attention mask is not used
        for resnet, temp_conv, attn, temp_attn in zip(
            self.resnets, self.temp_convs, self.attentions, self.temp_attentions
        ):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            hidden_states = resnet(hidden_states, temb)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
            hidden_states = temp_attn(
                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
            )[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states


class UpBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super().__init__()
        resnets = []
        temp_convs = []

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_convs.append(
                TemporalConvLayer(
                    out_channels,
                    out_channels,
                    dropout=0.1,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            hidden_states = resnet(hidden_states, temb)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
Tiger Model/diffusiers-Tiger/models/unet_3d_condition.py
ADDED
@@ -0,0 +1,627 @@
1 |
+
# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
|
2 |
+
# Copyright 2023 The ModelScope Team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
from dataclasses import dataclass
|
16 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.utils.checkpoint
|
21 |
+
|
22 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
23 |
+
from ..loaders import UNet2DConditionLoadersMixin
|
24 |
+
from ..utils import BaseOutput, logging
|
25 |
+
from .attention_processor import AttentionProcessor, AttnProcessor
|
26 |
+
from .embeddings import TimestepEmbedding, Timesteps
|
27 |
+
from .modeling_utils import ModelMixin
|
28 |
+
from .transformer_temporal import TransformerTemporalModel
|
29 |
+
from .unet_3d_blocks import (
|
30 |
+
CrossAttnDownBlock3D,
|
31 |
+
CrossAttnUpBlock3D,
|
32 |
+
DownBlock3D,
|
33 |
+
UNetMidBlock3DCrossAttn,
|
34 |
+
UpBlock3D,
|
35 |
+
get_down_block,
|
36 |
+
get_up_block,
|
37 |
+
)
|
38 |
+
|
39 |
+
|
40 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
41 |
+
|
42 |
+
|
43 |
+
@dataclass
|
44 |
+
class UNet3DConditionOutput(BaseOutput):
|
45 |
+
"""
|
46 |
+
The output of [`UNet3DConditionModel`].
|
47 |
+
|
48 |
+
Args:
|
49 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
|
50 |
+
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
51 |
+
"""
|
52 |
+
|
53 |
+
sample: torch.FloatTensor
|
54 |
+
|
55 |
+
|
56 |
+
class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
57 |
+
r"""
|
58 |
+
A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
|
59 |
+
shaped output.
|
60 |
+
|
61 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
62 |
+
for all models (such as downloading or saving).
|
63 |
+
|
64 |
+
Parameters:
|
65 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
66 |
+
Height and width of input/output sample.
|
67 |
+
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
|
68 |
+
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
|
69 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
70 |
+
The tuple of downsample blocks to use.
|
71 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
72 |
+
The tuple of upsample blocks to use.
|
73 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
74 |
+
The tuple of output channels for each block.
|
75 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
76 |
+
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
77 |
+
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
78 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
79 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
80 |
+
If `None`, normalization and activation layers is skipped in post-processing.
|
81 |
+
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
82 |
+
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
|
83 |
+
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
84 |
+
num_attention_heads (`int`, *optional*): The number of attention heads.
|
85 |
+
"""
|
86 |
+
|
87 |
+
_supports_gradient_checkpointing = False
|
88 |
+
|
89 |
+
@register_to_config
|
90 |
+
def __init__(
|
91 |
+
self,
|
92 |
+
sample_size: Optional[int] = None,
|
93 |
+
in_channels: int = 4,
|
94 |
+
out_channels: int = 4,
|
95 |
+
down_block_types: Tuple[str] = (
|
96 |
+
"CrossAttnDownBlock3D",
|
97 |
+
"CrossAttnDownBlock3D",
|
98 |
+
"CrossAttnDownBlock3D",
|
99 |
+
"DownBlock3D",
|
100 |
+
),
|
101 |
+
up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
|
102 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
103 |
+
layers_per_block: int = 2,
|
104 |
+
downsample_padding: int = 1,
|
105 |
+
mid_block_scale_factor: float = 1,
|
106 |
+
act_fn: str = "silu",
|
107 |
+
norm_num_groups: Optional[int] = 32,
|
108 |
+
norm_eps: float = 1e-5,
|
109 |
+
cross_attention_dim: int = 1024,
|
110 |
+
attention_head_dim: Union[int, Tuple[int]] = 64,
|
111 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
112 |
+
):
|
113 |
+
super().__init__()
|
114 |
+
|
115 |
+
self.sample_size = sample_size
|
116 |
+
|
117 |
+
if num_attention_heads is not None:
|
118 |
+
raise NotImplementedError(
|
119 |
+
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
120 |
+
)
|
121 |
+
|
122 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
123 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
124 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
125 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
126 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
127 |
+
# which is why we correct for the naming here.
|
128 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
129 |
+
|
130 |
+
# Check inputs
|
131 |
+
if len(down_block_types) != len(up_block_types):
|
132 |
+
raise ValueError(
|
133 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
134 |
+
)
|
135 |
+
|
136 |
+
if len(block_out_channels) != len(down_block_types):
|
137 |
+
raise ValueError(
|
138 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
139 |
+
)
|
140 |
+
|
141 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
142 |
+
raise ValueError(
|
143 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
144 |
+
)
|
145 |
+
|
146 |
+
# input
|
147 |
+
conv_in_kernel = 3
|
148 |
+
conv_out_kernel = 3
|
149 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
150 |
+
self.conv_in = nn.Conv2d(
|
151 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
152 |
+
)
|
153 |
+
|
154 |
+
# time
|
155 |
+
time_embed_dim = block_out_channels[0] * 4
|
156 |
+
self.time_proj = Timesteps(block_out_channels[0], True, 0)
|
157 |
+
timestep_input_dim = block_out_channels[0]
|
158 |
+
|
159 |
+
self.time_embedding = TimestepEmbedding(
|
160 |
+
timestep_input_dim,
|
161 |
+
time_embed_dim,
|
162 |
+
act_fn=act_fn,
|
163 |
+
)
|
164 |
+
|
165 |
+
self.transformer_in = TransformerTemporalModel(
|
166 |
+
num_attention_heads=8,
|
167 |
+
attention_head_dim=attention_head_dim,
|
168 |
+
in_channels=block_out_channels[0],
|
169 |
+
num_layers=1,
|
170 |
+
)
|
171 |
+
|
172 |
+
# class embedding
|
173 |
+
self.down_blocks = nn.ModuleList([])
|
174 |
+
self.up_blocks = nn.ModuleList([])
|
175 |
+
|
176 |
+
if isinstance(num_attention_heads, int):
|
177 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
178 |
+
|
179 |
+
# down
|
180 |
+
output_channel = block_out_channels[0]
|
181 |
+
for i, down_block_type in enumerate(down_block_types):
|
182 |
+
input_channel = output_channel
|
183 |
+
output_channel = block_out_channels[i]
|
184 |
+
is_final_block = i == len(block_out_channels) - 1
|
185 |
+
|
186 |
+
down_block = get_down_block(
|
187 |
+
down_block_type,
|
188 |
+
num_layers=layers_per_block,
|
189 |
+
in_channels=input_channel,
|
190 |
+
out_channels=output_channel,
|
191 |
+
temb_channels=time_embed_dim,
|
192 |
+
add_downsample=not is_final_block,
|
193 |
+
resnet_eps=norm_eps,
|
194 |
+
resnet_act_fn=act_fn,
|
195 |
+
resnet_groups=norm_num_groups,
|
196 |
+
cross_attention_dim=cross_attention_dim,
|
197 |
+
num_attention_heads=num_attention_heads[i],
|
198 |
+
downsample_padding=downsample_padding,
|
199 |
+
dual_cross_attention=False,
|
200 |
+
)
|
201 |
+
self.down_blocks.append(down_block)
|
202 |
+
|
203 |
+
# mid
|
204 |
+
self.mid_block = UNetMidBlock3DCrossAttn(
|
205 |
+
in_channels=block_out_channels[-1],
|
206 |
+
temb_channels=time_embed_dim,
|
207 |
+
resnet_eps=norm_eps,
|
208 |
+
resnet_act_fn=act_fn,
|
209 |
+
output_scale_factor=mid_block_scale_factor,
|
210 |
+
cross_attention_dim=cross_attention_dim,
|
211 |
+
num_attention_heads=num_attention_heads[-1],
|
212 |
+
resnet_groups=norm_num_groups,
|
213 |
+
dual_cross_attention=False,
|
214 |
+
)
|
215 |
+
|
216 |
+
# count how many layers upsample the images
|
217 |
+
self.num_upsamplers = 0
|
218 |
+
|
219 |
+
# up
|
220 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
221 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
222 |
+
|
223 |
+
output_channel = reversed_block_out_channels[0]
|
224 |
+
for i, up_block_type in enumerate(up_block_types):
|
225 |
+
is_final_block = i == len(block_out_channels) - 1
|
226 |
+
|
227 |
+
prev_output_channel = output_channel
|
228 |
+
output_channel = reversed_block_out_channels[i]
|
229 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
230 |
+
|
231 |
+
# add upsample block for all BUT final layer
|
232 |
+
if not is_final_block:
|
233 |
+
add_upsample = True
|
234 |
+
self.num_upsamplers += 1
|
235 |
+
else:
|
236 |
+
add_upsample = False
|
237 |
+
|
238 |
+
up_block = get_up_block(
|
239 |
+
up_block_type,
|
240 |
+
num_layers=layers_per_block + 1,
|
241 |
+
in_channels=input_channel,
|
242 |
+
out_channels=output_channel,
|
243 |
+
prev_output_channel=prev_output_channel,
|
244 |
+
temb_channels=time_embed_dim,
|
245 |
+
add_upsample=add_upsample,
|
246 |
+
resnet_eps=norm_eps,
|
247 |
+
resnet_act_fn=act_fn,
|
248 |
+
resnet_groups=norm_num_groups,
|
249 |
+
cross_attention_dim=cross_attention_dim,
|
250 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
251 |
+
dual_cross_attention=False,
|
252 |
+
)
|
253 |
+
self.up_blocks.append(up_block)
|
254 |
+
prev_output_channel = output_channel
|
255 |
+
|
256 |
+
# out
|
257 |
+
if norm_num_groups is not None:
|
258 |
+
self.conv_norm_out = nn.GroupNorm(
|
259 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
260 |
+
)
|
261 |
+
self.conv_act = nn.SiLU()
|
262 |
+
else:
|
263 |
+
self.conv_norm_out = None
|
264 |
+
self.conv_act = None
|
265 |
+
|
266 |
+
conv_out_padding = (conv_out_kernel - 1) // 2
|
267 |
+
self.conv_out = nn.Conv2d(
|
268 |
+
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
269 |
+
)
|
270 |
+
|
271 |
+
@property
|
272 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
273 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
274 |
+
r"""
|
275 |
+
Returns:
|
276 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
277 |
+
indexed by its weight name.
|
278 |
+
"""
|
279 |
+
# set recursively
|
280 |
+
processors = {}
|
281 |
+
|
282 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
283 |
+
if hasattr(module, "set_processor"):
|
284 |
+
processors[f"{name}.processor"] = module.processor
|
285 |
+
|
286 |
+
for sub_name, child in module.named_children():
|
287 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
288 |
+
|
289 |
+
return processors
|
290 |
+
|
291 |
+
for name, module in self.named_children():
|
292 |
+
fn_recursive_add_processors(name, module, processors)
|
293 |
+
|
294 |
+
return processors
|
295 |
+
|
296 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
297 |
+
def set_attention_slice(self, slice_size):
|
298 |
+
r"""
|
299 |
+
Enable sliced attention computation.
|
300 |
+
|
301 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
302 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
303 |
+
|
304 |
+
Args:
|
305 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
306 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
307 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
308 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
309 |
+
must be a multiple of `slice_size`.
|
310 |
+
"""
|
311 |
+
sliceable_head_dims = []
|
312 |
+
|
313 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
314 |
+
if hasattr(module, "set_attention_slice"):
|
315 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
316 |
+
|
317 |
+
for child in module.children():
|
318 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
319 |
+
|
320 |
+
# retrieve number of attention layers
|
321 |
+
for module in self.children():
|
322 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
323 |
+
|
324 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
325 |
+
|
326 |
+
if slice_size == "auto":
|
327 |
+
# half the attention head size is usually a good trade-off between
|
328 |
+
# speed and memory
|
329 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
330 |
+
elif slice_size == "max":
|
331 |
+
# make smallest slice possible
|
332 |
+
slice_size = num_sliceable_layers * [1]
|
333 |
+
|
334 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
335 |
+
|
336 |
+
if len(slice_size) != len(sliceable_head_dims):
|
337 |
+
raise ValueError(
|
338 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
339 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
340 |
+
)
|
341 |
+
|
342 |
+
for i in range(len(slice_size)):
|
343 |
+
size = slice_size[i]
|
344 |
+
dim = sliceable_head_dims[i]
|
345 |
+
if size is not None and size > dim:
|
346 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
347 |
+
|
348 |
+
# Recursively walk through all the children.
|
349 |
+
# Any children which exposes the set_attention_slice method
|
350 |
+
# gets the message
|
351 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
352 |
+
if hasattr(module, "set_attention_slice"):
|
353 |
+
module.set_attention_slice(slice_size.pop())
|
354 |
+
|
355 |
+
for child in module.children():
|
356 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
357 |
+
|
358 |
+
reversed_slice_size = list(reversed(slice_size))
|
359 |
+
for module in self.children():
|
360 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
361 |
+
|
362 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
363 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
364 |
+
r"""
|
365 |
+
Sets the attention processor to use to compute attention.
|
366 |
+
|
367 |
+
Parameters:
|
368 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
369 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
370 |
+
for **all** `Attention` layers.
|
371 |
+
|
372 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
373 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
374 |
+
|
375 |
+
"""
|
376 |
+
count = len(self.attn_processors.keys())
|
377 |
+
|
378 |
+
if isinstance(processor, dict) and len(processor) != count:
|
379 |
+
raise ValueError(
|
380 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
381 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
382 |
+
)
|
383 |
+
|
384 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
385 |
+
if hasattr(module, "set_processor"):
|
386 |
+
if not isinstance(processor, dict):
|
387 |
+
module.set_processor(processor)
|
388 |
+
else:
|
389 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
390 |
+
|
391 |
+
for sub_name, child in module.named_children():
|
392 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
393 |
+
|
394 |
+
for name, module in self.named_children():
|
395 |
+
fn_recursive_attn_processor(name, module, processor)
|
396 |
+
|
397 |
+
def enable_forward_chunking(self, chunk_size=None, dim=0):
|
398 |
+
"""
|
399 |
+
Sets the attention processor to use [feed forward
|
400 |
+
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
|
401 |
+
|
402 |
+
Parameters:
|
403 |
+
chunk_size (`int`, *optional*):
|
404 |
+
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
|
405 |
+
over each tensor of dim=`dim`.
|
406 |
+
dim (`int`, *optional*, defaults to `0`):
|
407 |
+
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
|
408 |
+
or dim=1 (sequence length).
|
409 |
+
"""
|
410 |
+
if dim not in [0, 1]:
|
411 |
+
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
|
412 |
+
|
413 |
+
# By default chunk size is 1
|
414 |
+
chunk_size = chunk_size or 1
|
415 |
+
|
416 |
+
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
417 |
+
if hasattr(module, "set_chunk_feed_forward"):
|
418 |
+
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
419 |
+
|
420 |
+
for child in module.children():
|
421 |
+
fn_recursive_feed_forward(child, chunk_size, dim)
|
422 |
+
|
423 |
+
for module in self.children():
|
424 |
+
fn_recursive_feed_forward(module, chunk_size, dim)
|
425 |
+
    def disable_forward_chunking(self):
        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, None, 0)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(AttnProcessor())

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
            module.gradient_checkpointing = value

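    # Example (a minimal sketch, assuming `unet` is an instantiated UNet3DConditionModel): the
    # `_set_gradient_checkpointing` hook above is what the `enable_gradient_checkpointing()` /
    # `disable_gradient_checkpointing()` helpers provided by diffusers' `ModelMixin` toggle on the 3D blocks:
    #
    #     unet.set_default_attn_processor()      # drop any custom attention processors
    #     unet.enable_gradient_checkpointing()   # recompute block activations in backward to save memory
    #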
    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet3DConditionOutput, Tuple]:
        r"""
        The [`UNet3DConditionModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, num_channels, num_frames, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.FloatTensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
                tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].

        Returns:
            [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise
                a `tuple` is returned where the first element is the sample tensor.
        """
        # By default samples have to be at least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        num_frames = sample.shape[2]
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # `timesteps` does not contain any weights and will always return f32 tensors,
        # but the time embedding might actually be running in fp16, so we need to cast here.
        # There might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
        encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)

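        # Note: the time embedding and text conditioning are repeated per frame above because, in the
        # pre-processing step below, the frame axis is folded into the batch axis so the 2D convolution and
        # attention blocks run once per frame; the temporal layers use `num_frames` to unfold it again.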
        # 2. pre-process
        sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
        sample = self.conv_in(sample)

        sample = self.transformer_in(
            sample,
            num_frames=num_frames,
            cross_attention_kwargs=cross_attention_kwargs,
            return_dict=False,
        )[0]

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    num_frames=num_frames,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)

            down_block_res_samples += res_samples

        if down_block_additional_residuals is not None:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples += (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                num_frames=num_frames,
                cross_attention_kwargs=cross_attention_kwargs,
            )

        if mid_block_additional_residual is not None:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    num_frames=num_frames,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    num_frames=num_frames,
                )

        # 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)

        sample = self.conv_out(sample)

        # reshape back to (batch, channel, num_frames, height, width)
        sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)

        if not return_dict:
            return (sample,)

        return UNet3DConditionOutput(sample=sample)
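A minimal smoke-test sketch for the forward pass above. The constructor arguments, latent sizes, and the
768-dimensional text embedding are illustrative assumptions, not values taken from this repository's
training scripts:

    import torch

    # assumes UNet3DConditionModel is importable from this module
    unet = UNet3DConditionModel(sample_size=32, cross_attention_dim=768)
    latents = torch.randn(1, 4, 8, 32, 32)    # (batch, num_channels, num_frames, height, width)
    text_emb = torch.randn(1, 77, 768)        # (batch, sequence_length, feature_dim)
    out = unet(latents, timestep=10, encoder_hidden_states=text_emb).sample
    print(out.shape)                          # same shape as the input latents with the default in/out channels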