FANG DAI committed
Commit 2ad255e · verified · 1 Parent(s): 0f1163d

Upload 126 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.

Files changed (50)
  1. Tiger Model/Coarse-Training.py +947 -0
  2. Tiger Model/Fine-Training.py +1246 -0
  3. Tiger Model/GP.py +266 -0
  4. Tiger Model/IS.py +109 -0
  5. Tiger Model/diffusiers-Tiger/CLIPTextModel.py +1326 -0
  6. Tiger Model/diffusiers-Tiger/__init__.py +293 -0
  7. Tiger Model/diffusiers-Tiger/__pycache__/__init__.cpython-38.pyc +0 -0
  8. Tiger Model/diffusiers-Tiger/__pycache__/configuration_utils.cpython-38.pyc +0 -0
  9. Tiger Model/diffusiers-Tiger/__pycache__/fuse.cpython-38.pyc +0 -0
  10. Tiger Model/diffusiers-Tiger/__pycache__/image_processor.cpython-38.pyc +0 -0
  11. Tiger Model/diffusiers-Tiger/__pycache__/loaders.cpython-38.pyc +0 -0
  12. Tiger Model/diffusiers-Tiger/__pycache__/optimization.cpython-38.pyc +0 -0
  13. Tiger Model/diffusiers-Tiger/__pycache__/training_utils.cpython-38.pyc +0 -0
  14. Tiger Model/diffusiers-Tiger/commands/__init__.py +27 -0
  15. Tiger Model/diffusiers-Tiger/commands/diffusers_cli.py +43 -0
  16. Tiger Model/diffusiers-Tiger/commands/env.py +84 -0
  17. Tiger Model/diffusiers-Tiger/commands/fp16_safetensors.py +133 -0
  18. Tiger Model/diffusiers-Tiger/configuration_utils.py +686 -0
  19. Tiger Model/diffusiers-Tiger/dependency_versions_check.py +47 -0
  20. Tiger Model/diffusiers-Tiger/dependency_versions_table.py +44 -0
  21. Tiger Model/diffusiers-Tiger/fuse.py +175 -0
  22. Tiger Model/diffusiers-Tiger/getWeight.py +88 -0
  23. Tiger Model/diffusiers-Tiger/image_processor.py +366 -0
  24. Tiger Model/diffusiers-Tiger/loaders.py +0 -0
  25. Tiger Model/diffusiers-Tiger/models/README.md +3 -0
  26. Tiger Model/diffusiers-Tiger/models/__init__.py +39 -0
  27. Tiger Model/diffusiers-Tiger/models/activations.py +14 -0
  28. Tiger Model/diffusiers-Tiger/models/adapter.py +291 -0
  29. Tiger Model/diffusiers-Tiger/models/attention.py +437 -0
  30. Tiger Model/diffusiers-Tiger/models/attention_processor.py +1716 -0
  31. Tiger Model/diffusiers-Tiger/models/autoencoder_asym_kl.py +180 -0
  32. Tiger Model/diffusiers-Tiger/models/autoencoder_kl.py +417 -0
  33. Tiger Model/diffusiers-Tiger/models/autoencoder_tiny.py +342 -0
  34. Tiger Model/diffusiers-Tiger/models/controlnet.py +762 -0
  35. Tiger Model/diffusiers-Tiger/models/dual_transformer_2d.py +151 -0
  36. Tiger Model/diffusiers-Tiger/models/embeddings.py +602 -0
  37. Tiger Model/diffusiers-Tiger/models/lora.py +117 -0
  38. Tiger Model/diffusiers-Tiger/models/modeling_utils.py +997 -0
  39. Tiger Model/diffusiers-Tiger/models/prior_transformer.py +364 -0
  40. Tiger Model/diffusiers-Tiger/models/resnet.py +878 -0
  41. Tiger Model/diffusiers-Tiger/models/t5_film_transformer.py +321 -0
  42. Tiger Model/diffusiers-Tiger/models/transformer_2d.py +359 -0
  43. Tiger Model/diffusiers-Tiger/models/transformer_temporal.py +179 -0
  44. Tiger Model/diffusiers-Tiger/models/unet_1d.py +255 -0
  45. Tiger Model/diffusiers-Tiger/models/unet_1d_blocks.py +656 -0
  46. Tiger Model/diffusiers-Tiger/models/unet_2d.py +329 -0
  47. Tiger Model/diffusiers-Tiger/models/unet_2d_blocks.py +0 -0
  48. Tiger Model/diffusiers-Tiger/models/unet_2d_condition.py +1009 -0
  49. Tiger Model/diffusiers-Tiger/models/unet_3d_blocks.py +679 -0
  50. Tiger Model/diffusiers-Tiger/models/unet_3d_condition.py +627 -0
Tiger Model/Coarse-Training.py ADDED
@@ -0,0 +1,947 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Hui Lu, Fang Dai, Siqiong Yao.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import argparse
18
+ import logging
19
+ import math
20
+ import os
21
+ import random
22
+ import shutil
23
+ from pathlib import Path
24
+
25
+ import datasets
26
+ import numpy as np
27
+ import torch
28
+ import torch.nn.functional as F
29
+ import torch.utils.checkpoint
30
+ import transformers
31
+ from accelerate import Accelerator
32
+ from accelerate.logging import get_logger
33
+ from accelerate.utils import ProjectConfiguration, set_seed
34
+ from datasets import load_dataset
35
+ from packaging import version
36
+ from torchvision import transforms
37
+ from tqdm.auto import tqdm
38
+ from transformers import CLIPTextModel, CLIPTokenizer
39
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
40
+ import diffusers
41
+ from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, DPMSolverMultistepScheduler
42
+ from diffusers.loaders import AttnProcsLayers
43
+ from diffusers.models.attention_processor import LoRAAttnProcessor
44
+ from diffusers.optimization import get_scheduler
45
+ from diffusers.utils import check_min_version, is_wandb_available
46
+ from diffusers.utils.import_utils import is_xformers_available
47
+ import warnings
48
+ warnings.filterwarnings('ignore')
49
+
50
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
51
+
52
+
53
+ logger = get_logger(__name__, log_level="INFO")
54
+
55
+
56
+ def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
57
+ img_str = ""
58
+ for i, image in enumerate(images):
59
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
60
+ img_str += f"![img_{i}](./image_{i}.png)\n"
61
+
62
+ yaml = f"""
63
+ ---
64
+ license: creativeml-openrail-m
65
+ base_model: {base_model}
66
+ tags:
67
+ - stable-diffusion
68
+ - stable-diffusion-diffusers
69
+ - text-to-image
70
+ - diffusers
71
+ - lora
72
+ inference: true
73
+ ---
74
+ """
75
+ model_card = f"""
76
+ # LoRA text2image fine-tuning - {repo_id}
77
+ These are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images below.\n
78
+ {img_str}
79
+ """
80
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
81
+ f.write(yaml + model_card)
82
+
83
+
84
+ def parse_args():
85
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
86
+ parser.add_argument(
87
+ "--pretrained_model_name_or_path",
88
+ type=str,
89
+ default=None,
90
+ required=True,
91
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
92
+ )
93
+ parser.add_argument(
94
+ "--revision",
95
+ type=str,
96
+ default=None,
97
+ required=False,
98
+ help="Revision of pretrained model identifier from huggingface.co/models.",
99
+ )
100
+ parser.add_argument(
101
+ "--dataset_name",
102
+ type=str,
103
+ default=None,
104
+ help=(
105
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
106
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
107
+ " or to a folder containing files that 🤗 Datasets can understand."
108
+ ),
109
+ )
110
+ parser.add_argument(
111
+ "--dataset_config_name",
112
+ type=str,
113
+ default=None,
114
+ help="The config of the Dataset, leave as None if there's only one config.",
115
+ )
116
+ parser.add_argument(
117
+ "--train_data_dir",
118
+ type=str,
119
+ default=None,
120
+ help=(
121
+ "A folder containing the training data. Folder contents must follow the structure described in"
122
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
123
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
124
+ ),
125
+ )
126
+ parser.add_argument(
127
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
128
+ )
129
+ parser.add_argument(
130
+ "--caption_column",
131
+ type=str,
132
+ default="text",
133
+ help="The column of the dataset containing a caption or a list of captions.",
134
+ )
135
+ parser.add_argument(
136
+ "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
137
+ )
138
+ parser.add_argument(
139
+ "--num_validation_images",
140
+ type=int,
141
+ default=4,
142
+ help="Number of images that should be generated during validation with `validation_prompt`.",
143
+ )
144
+ parser.add_argument(
145
+ "--validation_epochs",
146
+ type=int,
147
+ default=1,
148
+ help=(
149
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
150
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
151
+ ),
152
+ )
153
+ parser.add_argument(
154
+ "--max_train_samples",
155
+ type=int,
156
+ default=None,
157
+ help=(
158
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
159
+ "value if set."
160
+ ),
161
+ )
162
+ parser.add_argument(
163
+ "--output_dir",
164
+ type=str,
165
+ default="sd-model-finetuned-lora",
166
+ help="The output directory where the model predictions and checkpoints will be written.",
167
+ )
168
+ parser.add_argument(
169
+ "--cache_dir",
170
+ type=str,
171
+ default=None,
172
+ help="The directory where the downloaded models and datasets will be stored.",
173
+ )
174
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
175
+ parser.add_argument(
176
+ "--resolution",
177
+ type=int,
178
+ default=512,
179
+ help=(
180
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
181
+ " resolution"
182
+ ),
183
+ )
184
+ parser.add_argument(
185
+ "--center_crop",
186
+ default=False,
187
+ action="store_true",
188
+ help=(
189
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
190
+ " cropped. The images will be resized to the resolution first before cropping."
191
+ ),
192
+ )
193
+ parser.add_argument(
194
+ "--random_flip",
195
+ action="store_true",
196
+ help="whether to randomly flip images horizontally",
197
+ )
198
+ parser.add_argument(
199
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
200
+ )
201
+ parser.add_argument("--num_train_epochs", type=int, default=100)
202
+ parser.add_argument(
203
+ "--max_train_steps",
204
+ type=int,
205
+ default=None,
206
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
207
+ )
208
+ parser.add_argument(
209
+ "--gradient_accumulation_steps",
210
+ type=int,
211
+ default=1,
212
+ help="Number of update steps to accumulate before performing a backward/update pass.",
213
+ )
214
+ parser.add_argument(
215
+ "--gradient_checkpointing",
216
+ action="store_true",
217
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
218
+ )
219
+ parser.add_argument(
220
+ "--learning_rate",
221
+ type=float,
222
+ default=1e-4,
223
+ help="Initial learning rate (after the potential warmup period) to use.",
224
+ )
225
+ parser.add_argument(
226
+ "--scale_lr",
227
+ action="store_true",
228
+ default=False,
229
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
230
+ )
231
+ parser.add_argument(
232
+ "--lr_scheduler",
233
+ type=str,
234
+ default="constant",
235
+ help=(
236
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
237
+ ' "constant", "constant_with_warmup"]'
238
+ ),
239
+ )
240
+ parser.add_argument(
241
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
242
+ )
243
+ parser.add_argument(
244
+ "--snr_gamma",
245
+ type=float,
246
+ default=None,
247
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
248
+ "More details here: https://arxiv.org/abs/2303.09556.",
249
+ )
250
+ parser.add_argument(
251
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
252
+ )
253
+ parser.add_argument(
254
+ "--allow_tf32",
255
+ action="store_true",
256
+ help=(
257
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
258
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
259
+ ),
260
+ )
261
+ parser.add_argument(
262
+ "--dataloader_num_workers",
263
+ type=int,
264
+ default=0,
265
+ help=(
266
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
267
+ ),
268
+ )
269
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
270
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
271
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
272
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
273
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
274
+
275
+ parser.add_argument(
276
+ "--prediction_type",
277
+ type=str,
278
+ default=None,
279
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction', or leave `None`. If left as `None`, the default prediction type of the scheduler, `noise_scheduler.config.prediction_type`, is chosen.",
280
+ )
281
+ parser.add_argument(
282
+ "--hub_model_id",
283
+ type=str,
284
+ default=None,
285
+ help="The name of the repository to keep in sync with the local `output_dir`.",
286
+ )
287
+ parser.add_argument(
288
+ "--logging_dir",
289
+ type=str,
290
+ default="logs",
291
+ help=(
292
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
293
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
294
+ ),
295
+ )
296
+ parser.add_argument(
297
+ "--mixed_precision",
298
+ type=str,
299
+ default="no",
300
+ choices=["no", "fp16", "bf16"],
301
+ help=(
302
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
303
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
304
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
305
+ ),
306
+ )
307
+ parser.add_argument(
308
+ "--report_to",
309
+ type=str,
310
+ default="tensorboard",
311
+ help=(
312
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
313
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
314
+ ),
315
+ )
316
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
317
+ parser.add_argument(
318
+ "--checkpointing_steps",
319
+ type=int,
320
+ default=500,
321
+ help=(
322
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
323
+ " training using `--resume_from_checkpoint`."
324
+ ),
325
+ )
326
+ parser.add_argument(
327
+ "--checkpoints_total_limit",
328
+ type=int,
329
+ default=None,
330
+ help=("Max number of checkpoints to store."),
331
+ )
332
+ parser.add_argument(
333
+ "--resume_from_checkpoint",
334
+ type=str,
335
+ default=None,
336
+ help=(
337
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
338
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
339
+ ),
340
+ )
341
+ parser.add_argument(
342
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
343
+ )
344
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
345
+ parser.add_argument(
346
+ "--rank",
347
+ type=int,
348
+ default=4,
349
+ help=("The dimension of the LoRA update matrices."),
350
+ )
351
+
352
+ args = parser.parse_args()
353
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
354
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
355
+ args.local_rank = env_local_rank
356
+
357
+ # Sanity checks
358
+ if args.dataset_name is None and args.train_data_dir is None:
359
+ raise ValueError("Need either a dataset name or a training folder.")
360
+
361
+ return args
362
+
363
+
364
+ DATASET_NAME_MAPPING = {
365
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
366
+ }
367
+
368
+
369
+ def main():
370
+ args = parse_args()
371
+ logging_dir = Path(args.output_dir, args.logging_dir)
372
+
373
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
374
+
375
+ accelerator = Accelerator(
376
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
377
+ mixed_precision=args.mixed_precision,
378
+ log_with=args.report_to,
379
+ project_config=accelerator_project_config,
380
+ )
381
+ if args.report_to == "wandb":
382
+ if not is_wandb_available():
383
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
384
+ import wandb
385
+
386
+ # Make one log on every process with the configuration for debugging.
387
+ logging.basicConfig(
388
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
389
+ datefmt="%m/%d/%Y %H:%M:%S",
390
+ level=logging.INFO,
391
+ )
392
+ logger.info(accelerator.state, main_process_only=False)
393
+ if accelerator.is_local_main_process:
394
+ datasets.utils.logging.set_verbosity_warning()
395
+ transformers.utils.logging.set_verbosity_warning()
396
+ diffusers.utils.logging.set_verbosity_info()
397
+ else:
398
+ datasets.utils.logging.set_verbosity_error()
399
+ transformers.utils.logging.set_verbosity_error()
400
+ diffusers.utils.logging.set_verbosity_error()
401
+
402
+ # If passed along, set the training seed now.
403
+ if args.seed is not None:
404
+ set_seed(args.seed)
405
+
406
+ # Handle the repository creation
407
+ if accelerator.is_main_process:
408
+ if args.output_dir is not None:
409
+ os.makedirs(args.output_dir, exist_ok=True)
410
+
411
+ # Load scheduler, tokenizer and models.
412
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
413
+ tokenizer = CLIPTokenizer.from_pretrained(
414
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
415
+ )
416
+ text_encoder = CLIPTextModel.from_pretrained(
417
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
418
+ )
419
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
420
+ unet = UNet2DConditionModel.from_pretrained(
421
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
422
+ )
423
+ # freeze parameters of models to save more memory
424
+ unet.requires_grad_(False)
425
+ vae.requires_grad_(False)
426
+
427
+ text_encoder.requires_grad_(False)
428
+
429
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
430
+ # as these weights are only used for inference, keeping weights in full precision is not required.
431
+ weight_dtype = torch.float32
432
+ if accelerator.mixed_precision == "fp16":
433
+ weight_dtype = torch.float16
434
+ elif accelerator.mixed_precision == "bf16":
435
+ weight_dtype = torch.bfloat16
436
+
437
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
438
+ unet.to(accelerator.device, dtype=weight_dtype)
439
+ vae.to(accelerator.device, dtype=weight_dtype)
440
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
441
+
442
+ # now we will add new LoRA weights to the attention layers
443
+ # It's important to realize here how many attention weights will be added and of which sizes
444
+ # The sizes of the attention layers consist only of two different variables:
445
+ # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
446
+ # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
447
+
448
+ # Let's first see how many attention processors we will have to set.
449
+ # For Stable Diffusion, it should be equal to:
450
+ # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
451
+ # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
452
+ # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
453
+ # => 32 layers
454
+
455
+ # Set correct lora layers
456
+ lora_attn_procs = {}
457
+ for name in unet.attn_processors.keys():
458
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
459
+ if name.startswith("mid_block"):
460
+ hidden_size = unet.config.block_out_channels[-1]
461
+ elif name.startswith("up_blocks"):
462
+ block_id = int(name[len("up_blocks.")])
463
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
464
+ elif name.startswith("down_blocks"):
465
+ block_id = int(name[len("down_blocks.")])
466
+ hidden_size = unet.config.block_out_channels[block_id]
467
+
468
+ lora_attn_procs[name] = LoRAAttnProcessor(
469
+ hidden_size=hidden_size,
470
+ cross_attention_dim=cross_attention_dim,
471
+ rank=args.rank,
472
+ )
473
+
474
+ unet.set_attn_processor(lora_attn_procs)
475
+
476
+
477
+ def compute_snr(timesteps):
478
+ """
479
+ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
480
+ """
481
+ alphas_cumprod = noise_scheduler.alphas_cumprod
482
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
483
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
484
+
485
+ # Expand the tensors.
486
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
487
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
488
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
489
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
490
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
491
+
492
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
493
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
494
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
495
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
496
+
497
+ # Compute SNR.
498
+ snr = (alpha / sigma) ** 2
499
+ return snr
500
+
501
+ lora_layers = AttnProcsLayers(unet.attn_processors)
502
+
503
+ # Enable TF32 for faster training on Ampere GPUs,
504
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
505
+ if args.allow_tf32:
506
+ torch.backends.cuda.matmul.allow_tf32 = True
507
+
508
+ if args.scale_lr:
509
+ args.learning_rate = (
510
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
511
+ )
512
+
513
+ # Initialize the optimizer
514
+ if args.use_8bit_adam:
515
+ try:
516
+ import bitsandbytes as bnb
517
+ except ImportError:
518
+ raise ImportError(
519
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
520
+ )
521
+
522
+ optimizer_cls = bnb.optim.AdamW8bit
523
+ else:
524
+ optimizer_cls = torch.optim.AdamW
525
+
526
+ optimizer = optimizer_cls(
527
+ lora_layers.parameters(),
528
+ lr=args.learning_rate,
529
+ betas=(args.adam_beta1, args.adam_beta2),
530
+ weight_decay=args.adam_weight_decay,
531
+ eps=args.adam_epsilon,
532
+ )
533
+
534
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
535
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
536
+
537
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
538
+ # download the dataset.
539
+ if args.dataset_name is not None:
540
+ # Downloading and loading a dataset from the hub.
541
+ dataset = load_dataset(
542
+ args.dataset_name,
543
+ args.dataset_config_name,
544
+ cache_dir=args.cache_dir,
545
+ data_dir=args.train_data_dir,
546
+ )
547
+ else:
548
+ data_files = {}
549
+ if args.train_data_dir is not None:
550
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
551
+ dataset = load_dataset(
552
+ "imagefolder",
553
+ data_files=data_files,
554
+ cache_dir=args.cache_dir,
555
+ )
556
+ # See more about loading custom images at
557
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
558
+
559
+ # Preprocessing the datasets.
560
+ # We need to tokenize inputs and targets.
561
+ column_names = dataset["train"].column_names
562
+
563
+ # 6. Get the column names for input/target.
564
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
565
+ if args.image_column is None:
566
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
567
+ else:
568
+ image_column = args.image_column
569
+ if image_column not in column_names:
570
+ raise ValueError(
571
+ f"`--image_column` value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
572
+ )
573
+ if args.caption_column is None:
574
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
575
+ else:
576
+ caption_column = args.caption_column
577
+ if caption_column not in column_names:
578
+ raise ValueError(
579
+ f"`--caption_column` value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
580
+ )
581
+
582
+ # Preprocessing the datasets.
583
+ # We need to tokenize input captions and transform the images.
584
+ def tokenize_captions(examples, is_train=True):
585
+ captions = []
586
+ for caption in examples[caption_column]:
587
+ if isinstance(caption, str):
588
+ captions.append(caption)
589
+ elif isinstance(caption, (list, np.ndarray)):
590
+ # take a random caption if there are multiple
591
+ captions.append(random.choice(caption) if is_train else caption[0])
592
+ else:
593
+ raise ValueError(
594
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
595
+ )
596
+ inputs = tokenizer(
597
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
598
+ )
599
+ return inputs.input_ids
600
+
601
+ # Preprocessing the datasets.
602
+ train_transforms = transforms.Compose(
603
+ [
604
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
605
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
606
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
607
+ transforms.ToTensor(),
608
+ transforms.Normalize([0.5], [0.5]),
609
+ ]
610
+ )
611
+
612
+ def preprocess_train(examples):
613
+ images = [image.convert("RGB") for image in examples[image_column]]
614
+ examples["pixel_values"] = [train_transforms(image) for image in images]
615
+ examples["input_ids"] = tokenize_captions(examples)
616
+ return examples
617
+
618
+ with accelerator.main_process_first():
619
+ if args.max_train_samples is not None:
620
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
621
+ # Set the training transforms
622
+ train_dataset = dataset["train"].with_transform(preprocess_train)
623
+
624
+ def collate_fn(examples):
625
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
626
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
627
+ input_ids = torch.stack([example["input_ids"] for example in examples])
628
+ return {"pixel_values": pixel_values, "input_ids": input_ids}
629
+
630
+ # DataLoaders creation:
631
+ train_dataloader = torch.utils.data.DataLoader(
632
+ train_dataset,
633
+ shuffle=True,
634
+ collate_fn=collate_fn,
635
+ batch_size=args.train_batch_size,
636
+ num_workers=args.dataloader_num_workers,
637
+ )
638
+
639
+ # Scheduler and math around the number of training steps.
640
+ overrode_max_train_steps = False
641
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
642
+ if args.max_train_steps is None:
643
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
644
+ overrode_max_train_steps = True
645
+
646
+ lr_scheduler = get_scheduler(
647
+ args.lr_scheduler,
648
+ optimizer=optimizer,
649
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
650
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
651
+ )
652
+
653
+ # Prepare everything with our `accelerator`.
654
+ lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
655
+ lora_layers, optimizer, train_dataloader, lr_scheduler
656
+ )
657
+
658
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
659
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
660
+ if overrode_max_train_steps:
661
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
662
+ # Afterwards we recalculate our number of training epochs
663
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
664
+
665
+ # We need to initialize the trackers we use, and also store our configuration.
666
+ # The trackers initialize automatically on the main process.
667
+ if accelerator.is_main_process:
668
+ accelerator.init_trackers("text2image-fine-tune", config=vars(args))
669
+
670
+ # Train!
671
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
672
+
673
+ logger.info("***** Running training *****")
674
+ logger.info(f" Num examples = {len(train_dataset)}")
675
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
676
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
677
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
678
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
679
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
680
+ global_step = 0
681
+ first_epoch = 0
682
+
683
+ # Potentially load in the weights and states from a previous save
684
+ if args.resume_from_checkpoint:
685
+ if args.resume_from_checkpoint != "latest":
686
+ path = os.path.basename(args.resume_from_checkpoint)
687
+ else:
688
+ # Get the most recent checkpoint
689
+ dirs = os.listdir(args.output_dir)
690
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
691
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
692
+ path = dirs[-1] if len(dirs) > 0 else None
693
+
694
+ if path is None:
695
+ accelerator.print(
696
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
697
+ )
698
+ args.resume_from_checkpoint = None
699
+ else:
700
+ accelerator.print(f"Resuming from checkpoint {path}")
701
+ accelerator.load_state(os.path.join(args.output_dir, path))
702
+ global_step = int(path.split("-")[1])
703
+ resume_global_step = global_step * args.gradient_accumulation_steps
704
+ first_epoch = global_step // num_update_steps_per_epoch
705
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
706
+
707
+ # Only show the progress bar once on each machine.
708
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
709
+ progress_bar.set_description("Steps")
710
+
711
+ for epoch in range(first_epoch, args.num_train_epochs):
712
+ unet.train()
713
+ train_loss = 0.0
714
+ for step, batch in enumerate(train_dataloader):
715
+ # Skip steps until we reach the resumed step
716
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
717
+ if step % args.gradient_accumulation_steps == 0:
718
+ progress_bar.update(1)
719
+ continue
720
+
721
+ with accelerator.accumulate(unet):
722
+ # Convert images to latent space
723
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
724
+ latents = latents * vae.config.scaling_factor
725
+
726
+ # Sample noise that we'll add to the latents
727
+ noise = torch.randn_like(latents)
728
+ if args.noise_offset:
729
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
730
+ noise += args.noise_offset * torch.randn(
731
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
732
+ )
733
+
734
+ bsz = latents.shape[0]
735
+ # Sample a random timestep for each image
736
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
737
+ timesteps = timesteps.long()
738
+
739
+ # Add noise to the latents according to the noise magnitude at each timestep
740
+ # (this is the forward diffusion process)
741
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
742
+
743
+ # Get the text embedding for conditioning
744
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
745
+
746
+ # Get the target for loss depending on the prediction type
747
+ if args.prediction_type is not None:
748
+ # set prediction_type of scheduler if defined
749
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
750
+
751
+ if noise_scheduler.config.prediction_type == "epsilon":
752
+ target = noise
753
+ elif noise_scheduler.config.prediction_type == "v_prediction":
754
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
755
+ else:
756
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
757
+
758
+ # Predict the noise residual and compute loss
759
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
760
+
761
+ if args.snr_gamma is None:
762
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
763
+ else:
764
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
765
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
766
+ # This is discussed in Section 4.2 of the same paper.
767
+ snr = compute_snr(timesteps)
768
+ mse_loss_weights = (
769
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
770
+ )
771
+ # We first calculate the original loss. Then we mean over the non-batch dimensions and
772
+ # rebalance the sample-wise losses with their respective loss weights.
773
+ # Finally, we take the mean of the rebalanced loss.
774
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
775
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
776
+ loss = loss.mean()
777
+
778
+ # Gather the losses across all processes for logging (if we use distributed training).
779
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
780
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
781
+
782
+ # Backpropagate
783
+ accelerator.backward(loss)
784
+ if accelerator.sync_gradients:
785
+ params_to_clip = lora_layers.parameters()
786
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
787
+ optimizer.step()
788
+ lr_scheduler.step()
789
+ optimizer.zero_grad()
790
+
791
+ # Checks if the accelerator has performed an optimization step behind the scenes
792
+ if accelerator.sync_gradients:
793
+ progress_bar.update(1)
794
+ global_step += 1
795
+ accelerator.log({"train_loss": train_loss}, step=global_step)
796
+ train_loss = 0.0
797
+
798
+ if global_step % args.checkpointing_steps == 0:
799
+ if accelerator.is_main_process:
800
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
801
+ if args.checkpoints_total_limit is not None:
802
+ checkpoints = os.listdir(args.output_dir)
803
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
804
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
805
+
806
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
807
+ if len(checkpoints) >= args.checkpoints_total_limit:
808
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
809
+ removing_checkpoints = checkpoints[0:num_to_remove]
810
+
811
+ logger.info(
812
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
813
+ )
814
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
815
+
816
+ for removing_checkpoint in removing_checkpoints:
817
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
818
+ shutil.rmtree(removing_checkpoint)
819
+
820
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
821
+ logger.info(f"Saved state to {save_path}")
822
+
823
+ unet = unet.to(torch.float32)
824
+ unet.save_attn_procs(save_path)
825
+
826
+ # create pipeline
827
+ # pipeline = DiffusionPipeline.from_pretrained(
828
+ # args.pretrained_model_name_or_path,
829
+ # unet=accelerator.unwrap_model(unet),
830
+ # revision=args.revision,
831
+ # torch_dtype=weight_dtype,
832
+ # )
833
+ # pipeline = pipeline.to(accelerator.device)
834
+ # pipeline.set_progress_bar_config(disable=True)
835
+
836
+ # # run inference
837
+ # generator = torch.Generator(device=accelerator.device)
838
+
839
+ # images = []
840
+ # for i in range(args.num_validation_images):
841
+ # if args.seed is not None:
842
+ # generator = generator.manual_seed(args.seed + i + args.checkpointing_steps)
843
+ # images.append(
844
+ # pipeline(args.validation_prompt, num_inference_steps=30, generator=generator, guidance_scale=7).images[0]
845
+ # )
846
+
847
+ if args.validation_prompt is not None:
848
+ logger.info(
849
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
850
+ f" {args.validation_prompt}."
851
+ )
852
+ print()
853
+ # create pipeline
854
+ pipeline = DiffusionPipeline.from_pretrained(
855
+ args.pretrained_model_name_or_path,
856
+ unet=accelerator.unwrap_model(unet),
857
+ revision=args.revision,
858
+ torch_dtype=weight_dtype,
859
+ )
860
+ pipeline = pipeline.to(accelerator.device)
861
+ pipeline.set_progress_bar_config(disable=True)
862
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
863
+ # run inference
864
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
865
+ images = []
866
+ for _ in range(args.num_validation_images):
867
+ images.append(
868
+ pipeline(
869
+ args.validation_prompt,
870
+ height=args.resolution,
871
+ width=args.resolution,
872
+ num_inference_steps=45,
873
+ generator=generator
874
+ ).images[0]
875
+ )
876
+ image2 = pipeline(
877
+ 'High quality photo of an astronaut riding a horse in space',
878
+ guidance_scale=7,
879
+ height=args.resolution,
880
+ width=args.resolution,
881
+ num_inference_steps=45
882
+ ).images[0]
883
+ images.append(image2)
884
+
885
+ for tracker in accelerator.trackers:
886
+ if tracker.name == "tensorboard":
887
+ np_images = np.stack([np.asarray(img) for img in images])
888
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
889
+ if tracker.name == "wandb":
890
+ tracker.log(
891
+ {
892
+ "validation": [
893
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
894
+ for i, image in enumerate(images)
895
+ ]
896
+ }
897
+ )
898
+
899
+ del pipeline
900
+ torch.cuda.empty_cache()
901
+
902
+ # Save the lora layers
903
+ accelerator.wait_for_everyone()
904
+ if accelerator.is_main_process:
905
+ unet = unet.to(torch.float32)
906
+ unet.save_attn_procs(args.output_dir)
907
+
908
+
909
+ # Final inference
910
+ # Load previous pipeline
911
+ pipeline = DiffusionPipeline.from_pretrained(
912
+ args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
913
+ )
914
+ pipeline = pipeline.to(accelerator.device)
915
+
916
+ # load attention processors
917
+ pipeline.unet.load_attn_procs(args.output_dir)
918
+
919
+ # run inference
920
+ generator = torch.Generator(device=accelerator.device)
921
+ if args.seed is not None:
922
+ generator = generator.manual_seed(args.seed)
923
+ images = []
924
+ for _ in range(args.num_validation_images):
925
+ images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
926
+
927
+ if accelerator.is_main_process:
928
+ for tracker in accelerator.trackers:
929
+ if len(images) != 0:
930
+ if tracker.name == "tensorboard":
931
+ np_images = np.stack([np.asarray(img) for img in images])
932
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
933
+ if tracker.name == "wandb":
934
+ tracker.log(
935
+ {
936
+ "test": [
937
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
938
+ for i, image in enumerate(images)
939
+ ]
940
+ }
941
+ )
942
+
943
+ accelerator.end_training()
944
+
945
+
946
+ if __name__ == "__main__":
947
+ main()
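
For reference, a minimal sketch (not part of this commit) of how the LoRA attention weights that Coarse-Training.py saves with `unet.save_attn_procs` could be reloaded for standalone inference through the standard diffusers API. The base-model id, output directory, and prompt below are placeholder assumptions, not values taken from this repository.

# Hypothetical inference with the LoRA weights saved by Coarse-Training.py.
# "runwayml/stable-diffusion-v1-5", "sd-model-finetuned-lora", and the prompt are placeholders.
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipeline.unet.load_attn_procs("sd-model-finetuned-lora")  # the script's default --output_dir
pipeline = pipeline.to("cuda")

image = pipeline("a validation prompt", num_inference_steps=30).images[0]
image.save("lora_sample.png")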
Tiger Model/Fine-Training.py ADDED
@@ -0,0 +1,1246 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 Hui Lu, Fang Dai, Siqiong Yao.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import argparse
18
+ import logging
19
+ import math
20
+ import os
21
+ import random
22
+ import shutil
23
+ from pathlib import Path
24
+ from pynvml import *
25
+ import accelerate
26
+ import numpy as np
27
+ import torch
28
+ import torch.nn.functional as F
29
+ import torch.utils.checkpoint
30
+ import transformers
31
+ from accelerate import Accelerator
32
+ from accelerate.logging import get_logger
33
+ from accelerate.utils import ProjectConfiguration, set_seed
34
+ from datasets import load_dataset
35
+ from huggingface_hub import create_repo, upload_folder
36
+ from packaging import version
37
+ from PIL import Image
38
+ from torchvision import transforms
39
+ from tqdm.auto import tqdm
40
+ import transformers
41
+ from transformers import AutoTokenizer, PretrainedConfig
42
+ import tensorflow as tf
43
+ tf.get_logger().setLevel('ERROR')
44
+ from collections import Counter
45
+ import diffusers_Tiger
46
+ from diffusers_Tiger import (
47
+ AutoencoderKL,
48
+ ControlNetModel,
49
+ DDPMScheduler,
50
+ StableDiffusionControlNetPipeline,
51
+ StableDiffusionControlNetInpaintPipeline,
52
+ UNet2DConditionModel,
53
+ UniPCMultistepScheduler,
54
+ DDIMScheduler
55
+ )
56
+ from diffusers_Tiger.optimization import get_scheduler
57
+ from diffusers_Tiger.utils import check_min_version, is_wandb_available
58
+ from diffusers_Tiger.utils.import_utils import is_xformers_available
59
+ from diffusers_Tiger import fuse
60
+
61
+ if is_wandb_available():
62
+ import wandb
63
+ import warnings
64
+ warnings.filterwarnings('ignore')
65
+
66
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
67
+ check_min_version("0.19.0.dev0")
68
+
69
+ logger = get_logger(__name__)
70
+
71
+
72
+ def image_grid(imgs, rows, cols):
73
+ assert len(imgs) == rows * cols
74
+
75
+ w, h = imgs[0].size
76
+ grid = Image.new("RGB", size=(cols * w, rows * h))
77
+
78
+ for i, img in enumerate(imgs):
79
+ grid.paste(img, box=(i % cols * w, i // cols * h))
80
+ return grid
81
+
82
+ def make_inpaint_condition(image, image_mask):
83
+ image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
84
+ image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
85
+
86
+ assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
87
+ image[image_mask > 0.5] = -1.0 # set as masked pixel
88
+ image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
89
+ image = torch.from_numpy(image)
90
+ return image
91
+
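
A minimal standalone sketch of how `make_inpaint_condition` above could be exercised on its own; the file names are placeholders, not paths from this repository. It produces the (1, 3, H, W) control tensor that the inpaint ControlNet pipeline later receives via `control_image=`, with masked pixels set to -1.0.

# Hypothetical usage of make_inpaint_condition; "sample.png" / "sample_mask.png" are placeholders.
from PIL import Image

example_image = Image.open("sample.png").convert("RGB").resize((512, 512))
example_mask = Image.open("sample_mask.png").convert("L").resize((512, 512))
control = make_inpaint_condition(example_image, example_mask)  # masked pixels become -1.0
print(control.shape)  # torch.Size([1, 3, 512, 512])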
92
+ def log_validation(vae, text_encoder, tokenizer, unet, controlnet_nd, controlnet_bg, args, accelerator, weight_dtype, step):
93
+ logger.info("Running validation... ")
94
+
95
+ controlnet = accelerator.unwrap_model(controlnet_nd)
96
+
97
+ pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
98
+ args.pretrained_model_name_or_path,
99
+ vae=vae,
100
+ text_encoder=text_encoder,
101
+ tokenizer=tokenizer,
102
+ unet=unet,
103
+ controlnet=controlnet,
104
+ safety_checker=None,
105
+ revision=args.revision,
106
+ torch_dtype=weight_dtype,
107
+ )
108
+ pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
109
+ pipeline = pipeline.to(accelerator.device)
110
+ pipeline.set_progress_bar_config(disable=True)
111
+
112
+ if args.enable_xformers_memory_efficient_attention:
113
+ pipeline.enable_xformers_memory_efficient_attention()
114
+
115
+ if args.seed is None:
116
+ generator = None
117
+ else:
118
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
119
+
120
+ if len(args.validation_image) == len(args.validation_prompt):
121
+ validation_images = args.validation_image
122
+ validation_prompts = args.validation_prompt
123
+ elif len(args.validation_image) == 1:
124
+ validation_images = args.validation_image * len(args.validation_prompt)
125
+ validation_prompts = args.validation_prompt
126
+ elif len(args.validation_prompt) == 1:
127
+ validation_images = args.validation_image
128
+ validation_prompts = args.validation_prompt * len(args.validation_image)
129
+ else:
130
+ raise ValueError(
131
+ "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
132
+ )
133
+
134
+ image_logs = []
135
+ images = []
136
+ for validation_prompt, validation_image1 in zip(validation_prompts, validation_images):
137
+ validation_image = Image.open(validation_image1).convert("RGB").resize((512, 512))
138
+ mask_image = Image.open(validation_image1).convert("RGB").resize((512, 512))
139
+
140
+ control_image = make_inpaint_condition(validation_image, mask_image)
141
+
142
+ for _ in range(args.num_validation_images):
143
+ with torch.autocast("cuda"):
144
+ seed = random.randint(1,1000000)
145
+ generator = torch.Generator(device='cuda').manual_seed(seed)
146
+ image = pipeline(
147
+ validation_prompt,
148
+ num_inference_steps=50,
149
+ generator=generator,
150
+ eta=1.0,
151
+ image=validation_image,
152
+ mask_image=mask_image,
153
+ control_image=control_image,
154
+ guidance_scale = 7
155
+ ).images[0]
156
+
157
+ images.append(image)
158
+
159
+
160
+ image_logs.append(
161
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
162
+ )
163
+
164
+ for tracker in accelerator.trackers:
165
+ if tracker.name == "tensorboard":
166
+ for log in image_logs:
167
+ images = log["images"]
168
+ validation_prompt = log["validation_prompt"]
169
+ validation_image = log["validation_image"]
170
+
171
+ formatted_images = []
172
+
173
+ formatted_images.append(np.asarray(validation_image))
174
+
175
+ for image in images:
176
+ formatted_images.append(np.asarray(image))
177
+
178
+ formatted_images = np.stack(formatted_images)
179
+
180
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
181
+ elif tracker.name == "wandb":
182
+ formatted_images = []
183
+
184
+ for log in image_logs:
185
+ images = log["images"]
186
+ validation_prompt = log["validation_prompt"]
187
+ validation_image = log["validation_image"]
188
+
189
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
190
+
191
+ for image in images:
192
+ image = wandb.Image(image, caption=validation_prompt)
193
+ formatted_images.append(image)
194
+
195
+ tracker.log({"validation": formatted_images})
196
+ else:
197
+ logger.warn(f"image logging not implemented for {tracker.name}")
198
+
199
+ return image_logs
200
+
201
+
202
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
203
+ text_encoder_config = PretrainedConfig.from_pretrained(
204
+ pretrained_model_name_or_path,
205
+ subfolder="text_encoder",
206
+ revision=revision,
207
+ )
208
+ model_class = text_encoder_config.architectures[0]
209
+
210
+ if model_class == "CLIPTextModel":
211
+ from transformers import CLIPTextModel
212
+
213
+ return CLIPTextModel
214
+ elif model_class == "RobertaSeriesModelWithTransformation":
215
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
216
+
217
+ return RobertaSeriesModelWithTransformation
218
+ else:
219
+ raise ValueError(f"{model_class} is not supported.")
220
+
221
+
222
+ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
223
+ img_str = ""
224
+ if image_logs is not None:
225
+ img_str = "You can find some example images below.\n"
226
+ for i, log in enumerate(image_logs):
227
+ images = log["images"]
228
+ validation_prompt = log["validation_prompt"]
229
+ validation_image = log["validation_image"]
230
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
231
+ img_str += f"prompt: {validation_prompt}\n"
232
+ images = [validation_image] + images
233
+ image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
234
+ img_str += f"![images_{i})](./images_{i}.png)\n"
235
+
236
+ yaml = f"""
237
+ ---
238
+ license: creativeml-openrail-m
239
+ base_model: {base_model}
240
+ tags:
241
+ - stable-diffusion
242
+ - stable-diffusion-diffusers
243
+ - text-to-image
244
+ - diffusers
245
+ - controlnet
246
+ inference: true
247
+ ---
248
+ """
249
+ model_card = f"""
250
+ # controlnet-{repo_id}
251
+
252
+ These are controlnet weights trained on {base_model} with new type of conditioning.
253
+ {img_str}
254
+ """
255
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
256
+ f.write(yaml + model_card)
257
+
258
+
259
+ def parse_args(input_args=None):
260
+ parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
261
+ parser.add_argument(
262
+ "--pretrained_model_name_or_path",
263
+ type=str,
264
+ default=None,
265
+ required=True,
266
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
267
+ )
268
+ parser.add_argument(
269
+ "--controlnet_model_name_or_path",
270
+ type=str,
271
+ default=None,
272
+ help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
273
+ " If not specified controlnet weights are initialized from unet.",
274
+ )
275
+ parser.add_argument(
276
+ "--revision",
277
+ type=str,
278
+ default=None,
279
+ required=False,
280
+ help=(
281
+ "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
282
+ " float32 precision."
283
+ ),
284
+ )
285
+ parser.add_argument(
286
+ "--tokenizer_name",
287
+ type=str,
288
+ default=None,
289
+ help="Pretrained tokenizer name or path if not the same as model_name",
290
+ )
291
+ parser.add_argument(
292
+ "--output_dir",
293
+ type=str,
294
+ default="controlnet-model",
295
+ help="The output directory where the model predictions and checkpoints will be written.",
296
+ )
297
+ parser.add_argument(
298
+ "--cache_dir",
299
+ type=str,
300
+ default="/export/home/daifang/Diffusion/own_code/dataset",
301
+ help="The directory where the downloaded models and datasets will be stored.",
302
+ )
303
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
304
+ parser.add_argument(
305
+ "--resolution",
306
+ type=int,
307
+ default=512,
308
+ help=(
309
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
310
+ " resolution"
311
+ ),
312
+ )
313
+ parser.add_argument(
314
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
315
+ )
316
+ parser.add_argument("--num_train_epochs", type=int, default=1)
317
+ parser.add_argument(
318
+ "--max_train_steps",
319
+ type=int,
320
+ default=None,
321
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
322
+ )
323
+ parser.add_argument(
324
+ "--checkpointing_steps",
325
+ type=int,
326
+ default=500,
327
+ help=(
328
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
329
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
330
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
331
+ "See https://huggingface.co/docs/diffusers123/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
332
+ "instructions."
333
+ ),
334
+ )
335
+ parser.add_argument(
336
+ "--checkpoints_total_limit",
337
+ type=int,
338
+ default=None,
339
+ help=("Max number of checkpoints to store."),
340
+ )
341
+ parser.add_argument(
342
+ "--resume_from_checkpoint",
343
+ type=str,
344
+ default=None,
345
+ help=(
346
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
347
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
348
+ ),
349
+ )
350
+ parser.add_argument(
351
+ "--gradient_accumulation_steps",
352
+ type=int,
353
+ default=1,
354
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
355
+ )
356
+ parser.add_argument(
357
+ "--gradient_checkpointing",
358
+ action="store_true",
359
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
360
+ )
361
+ parser.add_argument(
362
+ "--learning_rate",
363
+ type=float,
364
+ default=5e-6,
365
+ help="Initial learning rate (after the potential warmup period) to use.",
366
+ )
367
+ parser.add_argument(
368
+ "--scale_lr",
369
+ action="store_true",
370
+ default=False,
371
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
372
+ )
373
+ parser.add_argument(
374
+ "--lr_scheduler",
375
+ type=str,
376
+ default="constant",
377
+ help=(
378
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
379
+ ' "constant", "constant_with_warmup"]'
380
+ ),
381
+ )
382
+ parser.add_argument(
383
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
384
+ )
385
+ parser.add_argument(
386
+ "--lr_num_cycles",
387
+ type=int,
388
+ default=1,
389
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
390
+ )
391
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
392
+ parser.add_argument(
393
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
394
+ )
395
+ parser.add_argument(
396
+ "--dataloader_num_workers",
397
+ type=int,
398
+ default=0,
399
+ help=(
400
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
401
+ ),
402
+ )
403
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
404
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
405
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
406
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
407
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
408
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
409
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
410
+ parser.add_argument(
411
+ "--hub_model_id",
412
+ type=str,
413
+ default=None,
414
+ help="The name of the repository to keep in sync with the local `output_dir`.",
415
+ )
416
+ parser.add_argument(
417
+ "--logging_dir",
418
+ type=str,
419
+ default="logs",
420
+ help=(
421
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
422
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
423
+ ),
424
+ )
425
+ parser.add_argument(
426
+ "--allow_tf32",
427
+ action="store_true",
428
+ help=(
429
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
430
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
431
+ ),
432
+ )
433
+ parser.add_argument(
434
+ "--report_to",
435
+ type=str,
436
+ default="tensorboard",
437
+ help=(
438
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
439
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
440
+ ),
441
+ )
442
+ parser.add_argument(
443
+ "--mixed_precision",
444
+ type=str,
445
+ default="no",
446
+ choices=["no", "fp16", "bf16"],
447
+ help=(
448
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
449
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
450
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
451
+ ),
452
+ )
453
+ parser.add_argument(
454
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
455
+ )
456
+ parser.add_argument(
457
+ "--set_grads_to_none",
458
+ action="store_true",
459
+ help=(
460
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
461
+ " behaviors, so disable this argument if it causes any problems. More info:"
462
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
463
+ ),
464
+ )
465
+ parser.add_argument(
466
+ "--dataset_name",
467
+ type=str,
468
+ default=None,
469
+ help=(
470
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
471
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
472
+ " or to a folder containing files that 🤗 Datasets can understand."
473
+ ),
474
+ )
475
+ parser.add_argument(
476
+ "--dataset_config_name",
477
+ type=str,
478
+ default=None,
479
+ help="The config of the Dataset, leave as None if there's only one config.",
480
+ )
481
+ parser.add_argument(
482
+ "--train_data_dir",
483
+ type=str,
484
+ default=None,
485
+ help=(
486
+ "A folder containing the training data. Folder contents must follow the structure described in"
487
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
488
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
489
+ ),
490
+ )
491
+ ##############################################################################################################
492
+ parser.add_argument(
493
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
494
+ )
495
+ parser.add_argument(
496
+ "--conditioning_nd_column",
497
+ type=str,
498
+ default="condition_nd",
499
+ help="The column of the dataset containing the controlnet conditioning image.",
500
+ )
501
+ parser.add_argument(
502
+ "--conditioning_bg_column",
503
+ type=str,
504
+ default="condition_bg",
505
+ help="The column of the dataset containing the controlnet conditioning image.",
506
+ )
507
+ parser.add_argument(
508
+ "--caption_column_nd",
509
+ type=str,
510
+ default="text_nd",
511
+ help="The column of the dataset containing a caption or a list of captions.",
512
+ )
513
+ parser.add_argument(
514
+ "--caption_column_bg",
515
+ type=str,
516
+ default="text_nd",
517
+ help="The column of the dataset containing a caption or a list of captions.",
518
+ )
519
+ ##############################################################################################################
520
+ parser.add_argument(
521
+ "--max_train_samples",
522
+ type=int,
523
+ default=None,
524
+ help=(
525
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
526
+ "value if set."
527
+ ),
528
+ )
529
+ parser.add_argument(
530
+ "--proportion_empty_prompts",
531
+ type=float,
532
+ default=0,
533
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
534
+ )
535
+ parser.add_argument(
536
+ "--validation_prompt",
537
+ type=str,
538
+ default=None,
539
+ nargs="+",
540
+ help=(
541
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
542
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
543
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
544
+ ),
545
+ )
546
+ parser.add_argument(
547
+ "--validation_image",
548
+ type=str,
549
+ default=None,
550
+ nargs="+",
551
+ help=(
552
+ "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
553
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
554
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
555
+ " `--validation_image` that will be used with all `--validation_prompt`s."
556
+ ),
557
+ )
558
+ parser.add_argument(
559
+ "--num_validation_images",
560
+ type=int,
561
+ default=4,
562
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
563
+ )
564
+ parser.add_argument(
565
+ "--validation_steps",
566
+ type=int,
567
+ default=100,
568
+ help=(
569
+ "Run validation every X steps. Validation consists of running the prompt"
570
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
571
+ " and logging the images."
572
+ ),
573
+ )
574
+ parser.add_argument(
575
+ "--tracker_project_name",
576
+ type=str,
577
+ default="train_controlnet",
578
+ help=(
579
+ "The `project_name` argument passed to Accelerator.init_trackers for"
580
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
581
+ ),
582
+ )
583
+
584
+ if input_args is not None:
585
+ args = parser.parse_args(input_args)
586
+ else:
587
+ args = parser.parse_args()
588
+
589
+ if args.dataset_name is None and args.train_data_dir is None:
590
+ raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
591
+
592
+ if args.dataset_name is not None and args.train_data_dir is not None:
593
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
594
+
595
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
596
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
597
+
598
+ if args.validation_prompt is not None and args.validation_image is None:
599
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
600
+
601
+ if args.validation_prompt is None and args.validation_image is not None:
602
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
603
+
604
+ if (
605
+ args.validation_image is not None
606
+ and args.validation_prompt is not None
607
+ and len(args.validation_image) != 1
608
+ and len(args.validation_prompt) != 1
609
+ and len(args.validation_image) != len(args.validation_prompt)
610
+ ):
611
+ raise ValueError(
612
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
613
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
614
+ )
615
+
616
+ if args.resolution % 8 != 0:
617
+ raise ValueError(
618
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
619
+ )
620
+
621
+ return args
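+
+ # Illustrative usage sketch (not part of the original script): a typical invocation using only
+ # flags defined in parse_args() above; the base model id and the data path are placeholders.
+ #
+ # accelerate launch Coarse-Training.py \
+ #     --pretrained_model_name_or_path=<base stable-diffusion checkpoint> \
+ #     --train_data_dir=<folder with images and a metadata.jsonl> \
+ #     --resolution=512 --train_batch_size=4 \
+ #     --learning_rate=5e-6 --checkpointing_steps=500 \
+ #     --output_dir=controlnet-model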
622
+
623
+
624
+ def make_train_dataset(args, tokenizer, accelerator):
625
+ if args.dataset_name is not None:
626
+ dataset = load_dataset(
627
+ args.dataset_name,
628
+ args.dataset_config_name,
629
+ cache_dir=args.cache_dir,
630
+ )
631
+ else:
632
+ if args.train_data_dir is not None:
633
+ dataset = load_dataset(
634
+ args.train_data_dir,
635
+ cache_dir=args.cache_dir,
636
+ )
637
+ column_names = dataset["train"].column_names
638
+ ##########################################################################################################################################################################
639
+ # Get the column names for input/target.
640
+ # target image
641
+ if args.image_column is None:
642
+ image_column = column_names[0]
643
+ logger.info(f"image column defaulting to {image_column}")
644
+ else:
645
+ image_column = args.image_column
646
+ if image_column not in column_names:
647
+ raise ValueError(
648
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
649
+ )
650
+ # condition nodule image
651
+ if args.conditioning_nd_column is None:
652
+
653
+ conditioning_nd_column = column_names[1]
654
+ logger.info(f"conditioning image column defaulting to {conditioning_nd_column}")
655
+ else:
656
+ conditioning_nd_column = args.conditioning_nd_column
657
+ if conditioning_nd_column not in column_names:
658
+ raise ValueError(
659
+ f"`--conditioning_nd_column` value '{args.conditioning_nd_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
660
+ )
661
+ # condition background image
662
+ if args.conditioning_bg_column is None:
663
+ conditioning_bg_column = column_names[2]
664
+ logger.info(f"conditioning bg column defaulting to {conditioning_bg_column}")
665
+ else:
666
+ conditioning_bg_column = args.conditioning_bg_column
667
+ if conditioning_bg_column not in column_names:
668
+ raise ValueError(
669
+ f"`--conditioning_bg_column` value '{args.conditioning_bg_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
670
+ )
671
+ # condition nodule text
672
+
673
+ if args.caption_column_nd is None:
674
+ caption_column_nd = column_names[3]
675
+ logger.info(f"caption column defaulting to {caption_column_nd}")
676
+ else:
677
+ caption_column_nd = args.caption_column_nd
678
+ if caption_column_nd not in column_names:
679
+ raise ValueError(
680
+ f"`--caption_column` value '{args.caption_column_nd}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
681
+ )
682
+ # condition background text
683
+ if args.caption_column_bg is None:
684
+ caption_column_bg = column_names[4]
685
+ logger.info(f"caption column defaulting to {caption_column_bg}")
686
+ else:
687
+ caption_column_bg = args.caption_column_bg
688
+ if caption_column_bg not in column_names:
689
+ raise ValueError(
690
+ f"`--caption_column` value '{args.caption_column_bg}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
691
+ )
692
+ ##########################################################################################################################################################################
693
+
694
+ def tokenize_captions(examples, caption_column, names, is_train=True):
695
+ captions = []
696
+ for caption in examples[caption_column]:
697
+ if random.random() < args.proportion_empty_prompts:
698
+ captions.append("")
699
+ elif isinstance(caption, str):
700
+ captions.append(caption)
701
+ elif isinstance(caption, (list, np.ndarray)):
702
+
703
+ # take a random caption if there are multiple
704
+ captions.append(random.choice(caption) if is_train else caption[0])
705
+ else:
706
+ raise ValueError(
707
+ f"Caption column `{caption_column_nd}` should contain either strings or lists of strings."
708
+ )
709
+
710
+ inputs = tokenizer(
711
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
712
+ )
713
+
714
+ def calculate_word_frequencies(phrases):
715
+ total_counts = Counter()
716
+ total_words = 0
717
+ for phrase in phrases:
718
+ words = phrase.replace(',', '').split()
719
+ total_counts.update(words)
720
+ total_words += len(words)
721
+ frequencies = {word: count / total_words for word, count in total_counts.items()}
722
+ return frequencies, total_words
723
+
724
+ def calculate_average_frequencies(phrases, word_frequencies):
725
+ average_frequencies = []
726
+ for phrase in phrases:
727
+ words = phrase.replace(',', '').split()
728
+ total_freq = sum(word_frequencies[word] for word in words)
729
+ avg_freq = total_freq / len(words) if words else 0
730
+ average_frequencies.append((phrase, avg_freq))
731
+ return average_frequencies
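+ # Worked example (hypothetical captions, not from the dataset): for the two captions
+ # "round nodule" and "small round nodule" there are 5 words in total, so the word
+ # frequencies are {"round": 0.4, "nodule": 0.4, "small": 0.2}; the per-caption average
+ # frequency of "round nodule" is (0.4 + 0.4) / 2 = 0.4, which is then scaled by 10**5
+ # below and written into the first token slot of the corresponding input_ids row.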
732
+ if names == 'nd':
733
+ word_frequencies, total_word_count = calculate_word_frequencies(captions)
734
+ weight_matrix = calculate_average_frequencies(captions, word_frequencies)
735
+ # Extract the values to replace
736
+ values = [desc[1] for desc in weight_matrix]
737
+ # Replace the first zero in each row with the corresponding value
738
+ for i in range(inputs.input_ids.shape[0]):
739
+ weight = int(values[i]*10**5)
740
+ inputs.input_ids[i][0] = weight
741
+ assert not torch.isnan(inputs.input_ids).any(), "inputs.input_ids contains NaN values"
742
+
743
+ return inputs.input_ids
744
+
745
+ image_transforms = transforms.Compose(
746
+ [
747
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
748
+ transforms.CenterCrop(args.resolution),
749
+ transforms.ToTensor(),
750
+ transforms.Normalize([0.5], [0.5]),
751
+ ]
752
+ )
753
+
754
+ conditioning_image_transforms = transforms.Compose(
755
+ [
756
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
757
+ transforms.CenterCrop(args.resolution),
758
+ transforms.ToTensor(),
759
+ ]
760
+ )
761
+
762
+ def preprocess_train(examples):
763
+ images = [image.convert("RGB") for image in examples[image_column]]
764
+ images = [image_transforms(image) for image in images]
765
+ conditioning_nd = [Image.open(image).convert("RGB") for image in examples[conditioning_nd_column]]
766
+ conditioning_nd = [conditioning_image_transforms(image) for image in conditioning_nd]
767
+
768
+ conditioning_bg = [Image.open(image).convert("RGB") for image in examples[conditioning_bg_column]]
769
+ conditioning_bg = [conditioning_image_transforms(image) for image in conditioning_bg]
770
+
771
+ examples["pixel_values"] = images
772
+ examples["conditioning_pixel_values_nd"] = conditioning_nd
773
+ examples["conditioning_pixel_values_bg"] = conditioning_bg
774
+ examples["input_ids_nd"] = tokenize_captions(examples, caption_column = caption_column_nd, names = 'nd')
775
+ examples["input_ids_bg"] = tokenize_captions(examples, caption_column = caption_column_bg, names = 'bg')
776
+
777
+ return examples
778
+
779
+ with accelerator.main_process_first():
780
+ if args.max_train_samples is not None:
781
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
782
+ # Set the training transforms
783
+ train_dataset = dataset["train"].with_transform(preprocess_train)
784
+
785
+ return train_dataset
786
+
787
+
788
+ def collate_fn(examples):
789
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
790
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
791
+
792
+ conditioning_pixel_values_nd = torch.stack([example["conditioning_pixel_values_nd"] for example in examples])
793
+ conditioning_pixel_values_nd = conditioning_pixel_values_nd.to(memory_format=torch.contiguous_format).float()
794
+
795
+ conditioning_pixel_values_bg = torch.stack([example["conditioning_pixel_values_bg"] for example in examples])
796
+ conditioning_pixel_values_bg = conditioning_pixel_values_bg.to(memory_format=torch.contiguous_format).float()
797
+
798
+ input_ids_nd = torch.stack([example["input_ids_nd"] for example in examples])
799
+ input_ids_bg = torch.stack([example["input_ids_bg"] for example in examples])
800
+
801
+ return {
802
+ "pixel_values": pixel_values,
803
+ "conditioning_pixel_values_nd": conditioning_pixel_values_nd,
804
+ "conditioning_pixel_values_bg": conditioning_pixel_values_bg,
805
+ "input_ids_nd": input_ids_nd,
806
+ "input_ids_bg": input_ids_bg,
807
+ }
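+ # Shape note (illustrative): with the default --resolution=512 and --train_batch_size=4,
+ # "pixel_values" and both conditioning tensors are [4, 3, 512, 512] float batches, while
+ # "input_ids_nd"/"input_ids_bg" are [4, sequence_length] token-id tensors padded to the
+ # tokenizer's model_max_length (77 for the CLIP tokenizer used by Stable Diffusion).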
808
+
809
+
810
+ def main(args):
811
+ logging_dir = Path(args.output_dir, args.logging_dir)
812
+
813
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
814
+
815
+ accelerator = Accelerator(
816
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
817
+ mixed_precision=args.mixed_precision,
818
+ log_with=args.report_to,
819
+ project_config=accelerator_project_config,
820
+ )
821
+
822
+ # Make one log on every process with the configuration for debugging.
823
+ logging.basicConfig(
824
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
825
+ datefmt="%m/%d/%Y %H:%M:%S",
826
+ level=logging.INFO,
827
+ )
828
+ logger.info(accelerator.state, main_process_only=False)
829
+ if accelerator.is_local_main_process:
830
+ transformers.utils.logging.set_verbosity_warning()
831
+ diffusers_Tiger.utils.logging.set_verbosity_info()
832
+ else:
833
+ transformers.utils.logging.set_verbosity_error()
834
+ diffusers_Tiger.utils.logging.set_verbosity_error()
835
+
836
+ # If passed along, set the training seed now.
837
+ if args.seed is not None:
838
+ set_seed(args.seed)
839
+
840
+ # Handle the repository creation
841
+ if accelerator.is_main_process:
842
+ if args.output_dir is not None:
843
+ os.makedirs(args.output_dir, exist_ok=True)
844
+
845
+ # Load the tokenizer
846
+ if args.tokenizer_name:
847
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
848
+ elif args.pretrained_model_name_or_path:
849
+ tokenizer = AutoTokenizer.from_pretrained(
850
+ args.pretrained_model_name_or_path,
851
+ subfolder="tokenizer",
852
+ revision=args.revision,
853
+ use_fast=False,
854
+ )
855
+
856
+ # import correct text encoder class
857
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
858
+
859
+ # Load scheduler and models
860
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
861
+ text_encoder = text_encoder_cls.from_pretrained(
862
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
863
+ )
864
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
865
+ unet = UNet2DConditionModel.from_pretrained(
866
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
867
+ )
868
+
869
+ if args.controlnet_model_name_or_path:
870
+ logger.info("Loading existing controlnet weights")
871
+ controlnet_nd = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
872
+ controlnet_bg = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
873
+ else:
874
+ logger.info("Initializing controlnet weights from unet")
875
+ controlnet_nd = ControlNetModel.from_unet(unet)
876
+ controlnet_bg = ControlNetModel.from_unet(unet)
877
+
878
+
879
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
880
+ def save_model_hook(models, weights, output_dir):
881
+ weights.pop()
882
+ model1 = models[0]
883
+ sub_dir = "controlnet_nd"
884
+ model1.save_pretrained(os.path.join(output_dir, sub_dir))
885
+
886
+
887
+ def load_model_hook(models, input_dir):
888
+ while len(models) > 0:
889
+ # pop models so that they are not loaded again
890
+ model = models.pop()
891
+
892
+ # load diffusers123 style into model
893
+ load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
894
+ model.register_to_config(**load_model.config)
895
+
896
+ model.load_state_dict(load_model.state_dict())
897
+ del load_model
898
+
899
+ accelerator.register_save_state_pre_hook(save_model_hook)
900
+ accelerator.register_load_state_pre_hook(load_model_hook)
901
+
902
+ vae.requires_grad_(False)
903
+ unet.requires_grad_(False)
904
+ text_encoder.requires_grad_(False)
905
+ controlnet_nd.requires_grad_(True).train()
906
+ controlnet_bg.requires_grad_(True).train()
907
+
908
+ if args.gradient_checkpointing:
909
+ controlnet_nd.enable_gradient_checkpointing()
910
+ controlnet_bg.enable_gradient_checkpointing()
911
+
912
+ # Check that all trainable models are in full precision
913
+ low_precision_error_string = (
914
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
915
+ " doing mixed precision training, copy of the weights should still be float32."
916
+ )
917
+
918
+ if accelerator.unwrap_model(controlnet_nd).dtype != torch.float32:
919
+ raise ValueError(
920
+ f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet_nd).dtype}. {low_precision_error_string}"
921
+ )
922
+ if accelerator.unwrap_model(controlnet_bg).dtype != torch.float32:
923
+ raise ValueError(
924
+ f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet_bg).dtype}. {low_precision_error_string}"
925
+ )
926
+ # Enable TF32 for faster training on Ampere GPUs,
927
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
928
+ if args.allow_tf32:
929
+ torch.backends.cuda.matmul.allow_tf32 = True
930
+
931
+ if args.scale_lr:
932
+ args.learning_rate = (
933
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
934
+ )
935
+
936
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
937
+ if args.use_8bit_adam:
938
+ try:
939
+ import bitsandbytes as bnb
940
+ except ImportError:
941
+ raise ImportError(
942
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
943
+ )
944
+
945
+ optimizer_class = bnb.optim.AdamW8bit
946
+ else:
947
+ optimizer_class = torch.optim.AdamW
948
+
949
+ # Optimizer creation
950
+ params_to_optimize_nd = controlnet_nd.parameters()
951
+ params_to_optimize_bg = controlnet_bg.parameters()
952
+
953
+ optimizer_nd = optimizer_class(
954
+ params_to_optimize_nd,
955
+ lr=args.learning_rate,
956
+ betas=(args.adam_beta1, args.adam_beta2),
957
+ weight_decay=args.adam_weight_decay,
958
+ eps=args.adam_epsilon,
959
+ )
960
+ optimizer_bg = optimizer_class(
961
+ params_to_optimize_bg,
962
+ lr=args.learning_rate,
963
+ betas=(args.adam_beta1, args.adam_beta2),
964
+ weight_decay=args.adam_weight_decay,
965
+ eps=args.adam_epsilon,
966
+ )
967
+ train_dataset = make_train_dataset(args, tokenizer, accelerator)
968
+
969
+ train_dataloader = torch.utils.data.DataLoader(
970
+ train_dataset,
971
+ shuffle=True,
972
+ collate_fn=collate_fn,
973
+ batch_size=args.train_batch_size,
974
+ num_workers=args.dataloader_num_workers,
975
+ )
976
+
977
+ # Scheduler and math around the number of training steps.
978
+ overrode_max_train_steps = False
979
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
980
+ if args.max_train_steps is None:
981
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
982
+ overrode_max_train_steps = True
983
+
984
+ lr_scheduler = get_scheduler(
985
+ args.lr_scheduler,
986
+ optimizer=optimizer_nd,
987
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
988
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
989
+ num_cycles=args.lr_num_cycles,
990
+ power=args.lr_power)
991
+
992
+ # Prepare everything with our `accelerator`.
993
+ controlnet_nd, controlnet_bg, optimizer_nd, optimizer_bg, train_dataloader, lr_scheduler = accelerator.prepare(
994
+ controlnet_nd, controlnet_bg, optimizer_nd, optimizer_bg, train_dataloader, lr_scheduler
995
+ )
996
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
997
+ # as these models are only used for inference, keeping weights in full precision is not required.
998
+ weight_dtype = torch.float32
999
+ if accelerator.mixed_precision == "fp16":
1000
+ weight_dtype = torch.float16
1001
+ elif accelerator.mixed_precision == "bf16":
1002
+ weight_dtype = torch.bfloat16
1003
+
1004
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
1005
+ vae.to(accelerator.device, dtype=weight_dtype)
1006
+ unet.to(accelerator.device, dtype=weight_dtype)
1007
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
1008
+
1009
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1010
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1011
+ if overrode_max_train_steps:
1012
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1013
+ # Afterwards we recalculate our number of training epochs
1014
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1015
+
1016
+ # We need to initialize the trackers we use, and also store our configuration.
1017
+ # The trackers initialize automatically on the main process.
1018
+ if accelerator.is_main_process:
1019
+ tracker_config = dict(vars(args))
1020
+
1021
+ # tensorboard cannot handle list types for config
1022
+ tracker_config.pop("validation_prompt")
1023
+ tracker_config.pop("validation_image")
1024
+
1025
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
1026
+
1027
+ # Train!
1028
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1029
+
1030
+ logger.info("***** Running training *****")
1031
+ logger.info(f" Num examples = {len(train_dataset)}")
1032
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1033
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1034
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1035
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1036
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1037
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1038
+ global_step = 0
1039
+ first_epoch = 0
1040
+
1041
+ # Potentially load in the weights and states from a previous save
1042
+ if args.resume_from_checkpoint:
1043
+ if args.resume_from_checkpoint != "latest":
1044
+ path = os.path.basename(args.resume_from_checkpoint)
1045
+ else:
1046
+ # Get the most recent checkpoint
1047
+ dirs = os.listdir(args.output_dir)
1048
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1049
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1050
+ path = dirs[-1] if len(dirs) > 0 else None
1051
+
1052
+ if path is None:
1053
+ accelerator.print(
1054
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1055
+ )
1056
+ args.resume_from_checkpoint = None
1057
+ initial_global_step = 0
1058
+ else:
1059
+ accelerator.print(f"Resuming from checkpoint {path}")
1060
+ accelerator.load_state(os.path.join(args.output_dir, path))
1061
+ global_step = int(path.split("-")[1])
1062
+
1063
+ initial_global_step = global_step
1064
+ first_epoch = global_step // num_update_steps_per_epoch
1065
+ else:
1066
+ initial_global_step = 0
1067
+
1068
+ progress_bar = tqdm(
1069
+ range(0, args.max_train_steps),
1070
+ initial=initial_global_step,
1071
+ desc="Steps",
1072
+ # Only show the progress bar once on each machine.
1073
+ disable=not accelerator.is_local_main_process,
1074
+ )
1075
+
1076
+ image_logs = None
1077
+ for epoch in range(first_epoch, args.num_train_epochs):
1078
+ for step, batch in enumerate(train_dataloader):
1079
+ # with accelerator.accumulate(controlnet_nd):
1080
+ # Convert images to latent space
1081
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
1082
+ latents = latents * vae.config.scaling_factor
1083
+ # Sample noise that we'll add to the latents
1084
+ noise = torch.randn_like(latents)
1085
+ bsz = latents.shape[0]
1086
+ # Sample a random timestep for each image
1087
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
1088
+ timesteps = timesteps.long()
1089
+ # Add noise to the latents according to the noise magnitude at each timestep
1090
+ # (this is the forward diffusion process)
1091
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
1092
+ # Get the text embedding for conditioning
1093
+
1094
+ weight_nd = batch["input_ids_nd"][:, 0]
1095
+ weight_nd = weight_nd / 10**5
1096
+ batch["input_ids_nd"][:, 0] = 49406
1097
+ encoder_hidden_states_nd = text_encoder(batch["input_ids_nd"])[0]
1098
+ encoder_hidden_states_bg = text_encoder(batch["input_ids_bg"])[0]
1099
+ controlnet_image_nd = batch["conditioning_pixel_values_nd"].to(dtype=weight_dtype)
1100
+ controlnet_image_bg = batch["conditioning_pixel_values_bg"].to(dtype=weight_dtype)
1101
+ # print(weight_nd)
1102
+ down_block_res_samples_nd, mid_block_res_sample_nd = controlnet_nd(
1103
+ noisy_latents,
1104
+ timesteps,
1105
+ encoder_hidden_states=encoder_hidden_states_nd, # text
1106
+ controlnet_cond=controlnet_image_nd,
1107
+ return_dict=False,
1108
+ weight=weight_nd)
1109
+
1110
+
1111
+
1112
+ down_block_res_samples_bg, mid_block_res_sample_bg = controlnet_bg(
1113
+ noisy_latents,
1114
+ timesteps,
1115
+ encoder_hidden_states=encoder_hidden_states_bg, # text
1116
+ controlnet_cond=controlnet_image_bg,
1117
+ return_dict=False)
1118
+ # Fuse the residuals from the two ControlNet branches, then predict the noise residual with the frozen UNet
1119
+ samples_nd_list, samples_bg_list = [], []
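+ # Descriptive note: for the deeper down-block residuals (index > 1) the nodule- and
+ # background-branch features are fused with an attentional feature fusion (AFF) module
+ # and the fused tensor is fed to both branches, while the two shallowest residuals are
+ # passed through unchanged; the mid-block residuals are simply summed further below.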
1120
+ for number in range(len(down_block_res_samples_nd)):
1121
+ if number > 1 :
1122
+ sample = down_block_res_samples_nd[number]
1123
+ samples_nd = torch.stack((down_block_res_samples_nd[number][0].to('cpu'), \
1124
+ down_block_res_samples_nd[number][0].to('cpu')))
1125
+ samples_bg = torch.stack((down_block_res_samples_bg[number][0].to('cpu'), \
1126
+ down_block_res_samples_bg[number][0].to('cpu')))
1127
+ channels = sample.shape[1]
1128
+ model_fuse_down = fuse.AFF(channels=channels).to(device='cpu')
1129
+ output = model_fuse_down(samples_nd, samples_bg)[0].unsqueeze(0)
1130
+
1131
+ samples_nd_list.append(output)
1132
+ samples_bg_list.append(output)
1133
+ else:
1134
+ samples_nd_list.append(down_block_res_samples_nd[number])
1135
+ samples_bg_list.append(down_block_res_samples_bg[number])
1136
+ mid_block_res_sample = mid_block_res_sample_bg + mid_block_res_sample_nd
1137
+ model_pred_nd = unet(
1138
+ noisy_latents,
1139
+ timesteps,
1140
+ encoder_hidden_states=encoder_hidden_states_nd.to('cuda'),
1141
+ down_block_additional_residuals=[
1142
+ sample.to(dtype=weight_dtype).to('cuda') for sample in samples_nd_list],
1143
+ mid_block_additional_residual=mid_block_res_sample.to('cuda').to(dtype=weight_dtype),
1144
+ ).sample
1145
+ model_pred_bg = unet(
1146
+ noisy_latents,
1147
+ timesteps,
1148
+ encoder_hidden_states=encoder_hidden_states_bg.to('cuda'),
1149
+ down_block_additional_residuals=[
1150
+ sample.to(dtype=weight_dtype).to('cuda') for sample in samples_bg_list],
1151
+ mid_block_additional_residual=mid_block_res_sample.to('cuda').to(dtype=weight_dtype),
1152
+ ).sample
1153
+ # Get the target for loss depending on the prediction type
1154
+ if noise_scheduler.config.prediction_type == "epsilon":
1155
+ target = noise
1156
+ elif noise_scheduler.config.prediction_type == "v_prediction": # use
1157
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
1158
+ else:
1159
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1160
+ loss_nd = F.mse_loss(model_pred_nd.to('cuda').float(), target.float(), reduction="mean")
1161
+ loss_bg = F.mse_loss(model_pred_bg.to('cuda').float(), target.float(), reduction="mean")
1162
+ optimizer_nd.zero_grad(set_to_none=args.set_grads_to_none)
1163
+ optimizer_bg.zero_grad(set_to_none=args.set_grads_to_none)
1164
+ # h0, h1 = nvmlDeviceGetHandleByIndex(0), nvmlDeviceGetHandleByIndex(1)
1165
+ # info0, info1 = nvmlDeviceGetMemoryInfo(h0), nvmlDeviceGetMemoryInfo(h1)
1166
+ # print(f'0free : {info0.free} 1free : {info1.free}')
1167
+ loss = loss_nd + loss_bg
1168
+ accelerator.backward(loss)
1169
+ # loss_nd.backward()
1170
+ # loss_bg.backward()
1171
+ # if accelerator.sync_gradients:
1172
+ # params_to_clip_nd = controlnet_nd.parameters()
1173
+ # accelerator.clip_grad_norm_(params_to_clip_nd, args.max_grad_norm)
1174
+ # params_to_clip_bg = controlnet_bg.parameters()
1175
+ # accelerator.clip_grad_norm_(params_to_clip_bg, args.max_grad_norm)
1176
+ optimizer_nd.step()
1177
+
1178
+ optimizer_bg.step()
1179
+
1180
+ lr_scheduler.step()
1181
+ # Checks if the accelerator has performed an optimization step behind the scenes
1182
+ if accelerator.sync_gradients:
1183
+ progress_bar.update(1)
1184
+ global_step += 1
1185
+
1186
+ if accelerator.is_main_process:
1187
+ if global_step % args.checkpointing_steps == 0:
1188
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1189
+ if args.checkpoints_total_limit is not None:
1190
+ checkpoints = os.listdir(args.output_dir)
1191
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1192
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1193
+
1194
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1195
+ if len(checkpoints) >= args.checkpoints_total_limit:
1196
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1197
+ removing_checkpoints = checkpoints[0:num_to_remove]
1198
+
1199
+ logger.info(
1200
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1201
+ )
1202
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1203
+
1204
+ for removing_checkpoint in removing_checkpoints:
1205
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1206
+ shutil.rmtree(removing_checkpoint)
1207
+
1208
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1209
+ accelerator.save_state(save_path)
1210
+ logger.info(f"Saved state to {save_path}")
1211
+
1212
+ # if args.validation_prompt is not None :
1213
+ # image_logs = log_validation(
1214
+ # vae,
1215
+ # text_encoder,
1216
+ # tokenizer,
1217
+ # unet,
1218
+ # controlnet_nd,
1219
+ # controlnet_bg,
1220
+ # args,
1221
+ # accelerator,
1222
+ # weight_dtype,
1223
+ # global_step,
1224
+ # )
1225
+
1226
+ logs = {"loss": loss.detach().item()}
1227
+ progress_bar.set_postfix(**logs)
1228
+ accelerator.log(logs, step=global_step)
1229
+
1230
+ if global_step >= args.max_train_steps:
1231
+ break
1232
+
1233
+ # Create the pipeline using the trained modules and save it.
1234
+ # accelerator.wait_for_everyone()
1235
+ if accelerator.is_main_process:
1236
+ controlnet_nd = accelerator.unwrap_model(controlnet_nd)
1237
+ controlnet_nd.save_pretrained(args.output_dir)
1238
+ controlnet_bg = accelerator.unwrap_model(controlnet_bg)
1239
+ controlnet_bg.save_pretrained(args.output_dir)
1240
+
1241
+ accelerator.end_training()
1242
+
1243
+
1244
+ if __name__ == "__main__":
1245
+ args = parse_args()
1246
+ main(args)
Tiger Model/GP.py ADDED
@@ -0,0 +1,266 @@
1
+ # Copyright 2024 Hui Lu, Fang Dai, Siqiong Yao.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ # import os
17
+ # import torch
18
+ # import numpy as np
19
+ # import torchvision.transforms as transforms
20
+ # from torch.utils.data import DataLoader, Dataset
21
+ # from PIL import Image
22
+ # from gtda.images import Binarizer, HeightFiltration
23
+ # from gtda.homology import CubicalPersistence
24
+ # from gtda.diagrams import Amplitude
25
+ # from sklearn.metrics import pairwise_distances
26
+
27
+
28
+ # transform = transforms.Compose([
29
+ # transforms.Resize((256, 256)),
30
+ # transforms.ToTensor(),
31
+ # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
32
+ # ])
33
+
34
+
35
+ # class ImageFolderDataset(Dataset):
36
+ # def __init__(self, folder_path, transform=None):
37
+ # self.file_paths = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.endswith('.png')]
38
+ # self.transform = transform
39
+
40
+ # def __len__(self):
41
+ # return len(self.file_paths)
42
+
43
+ # def __getitem__(self, idx):
44
+ # img_path = self.file_paths[idx]
45
+ # image = Image.open(img_path).convert('RGB')
46
+ # if self.transform:
47
+ # image = self.transform(image)
48
+ # return image
49
+
50
+ # def load_data(folder_path):
51
+ # dataset = ImageFolderDataset(folder_path, transform=transform)
52
+ # loader = DataLoader(dataset, batch_size=10, shuffle=False)
53
+ # return loader
54
+
55
+ # # Compute the Diversity Score
56
+ # def calculate_diversity_score(features):
57
+ # distances = pairwise_distances(features, metric='euclidean')
58
+ # diversity_score = np.mean(distances)
59
+ # return diversity_score
60
+
61
+ # # Compute the Geometry Score
62
+ # def calculate_geometry_score(images):
63
+ # binarizer = Binarizer(threshold=0.5)
64
+ # height_filtration = HeightFiltration(direction=np.array([1, 1, 1]))
65
+ # cubical_persistence = CubicalPersistence(homology_dimensions=[0, 1], coeff=2)
66
+ # amplitude = Amplitude(metric='wasserstein', metric_params={'p': 2})
67
+
68
+ # # Preprocess images
69
+ # images = np.array([img.numpy() if isinstance(img, torch.Tensor) else img for img in images])
70
+ # images_binarized = binarizer.fit_transform(images)
71
+ # images_filtered = height_filtration.fit_transform(images_binarized)
72
+ # diagrams = cubical_persistence.fit_transform(images_filtered)
73
+ # gs_score = amplitude.fit_transform(diagrams)
74
+ # return gs_score.mean()
75
+
76
+
77
+ # generated_images_loader = load_data('../figure/1')
78
+ # real_images_loader = load_data('../figure/2')
79
+
80
+
81
+ # generated_features = []
82
+ # real_features = []
83
+
84
+ # for img_batch in generated_images_loader:
85
+ # generated_features.extend(img_batch.numpy())
86
+
87
+ # for img_batch in real_images_loader:
88
+ # real_features.extend(img_batch.numpy())
89
+
90
+ # generated_features = np.array(generated_features)
91
+ # real_features = np.array(real_features)
92
+
93
+ # # Compute the Diversity Score
94
+ # generated_div_score = calculate_diversity_score(generated_features.reshape(len(generated_features), -1))
95
+ # real_div_score = calculate_diversity_score(real_features.reshape(len(real_features), -1))
96
+
97
+ # # Compute the Geometry Score
98
+ # generated_gs_score = calculate_geometry_score(generated_features)
99
+ # real_gs_score = calculate_geometry_score(real_features)
100
+
101
+ # print(f"Generated Images Diversity Score: {generated_div_score}")
102
+ # print(f"Real Images Diversity Score: {real_div_score}")
103
+ # print(f"Generated Images Geometry Score: {generated_gs_score}")
104
+ # print(f"Real Images Geometry Score: {real_gs_score}")
105
+
106
+
107
+ # import torch
108
+ # import torch.nn.functional as F
109
+ # from torchvision import transforms
110
+ # from PIL import Image
111
+ # import numpy as np
112
+ # import os
113
+
114
+ # # Function to load and preprocess images
115
+ # def load_and_preprocess_image(img_path):
116
+ # img = Image.open(img_path).convert('RGB')
117
+ # preprocess = transforms.Compose([
118
+ # transforms.ToTensor(),
119
+ # ])
120
+ # img = preprocess(img).unsqueeze(0) # Add batch dimension
121
+ # return img
122
+
123
+ # # Function to compute image gradients
124
+ # def compute_gradients(img):
125
+ # grad_x = img[:, :, 1:, :] - img[:, :, :-1, :]
126
+ # grad_y = img[:, :, :, 1:] - img[:, :, :, :-1]
127
+ # return grad_x, grad_y
128
+
129
+ # # Function to calculate Gradient Similarity (GS)
130
+ # def gradient_similarity(real_img, gen_img):
131
+ # real_grad_x, real_grad_y = compute_gradients(real_img)
132
+ # gen_grad_x, gen_grad_y = compute_gradients(gen_img)
133
+
134
+ # grad_sim_x = F.cosine_similarity(real_grad_x, gen_grad_x, dim=1).mean()
135
+ # grad_sim_y = F.cosine_similarity(real_grad_y, gen_grad_y, dim=1).mean()
136
+
137
+ # gs = (grad_sim_x + grad_sim_y) / 2.0
138
+ # return gs.item()
139
+
140
+ # # Example usage
141
+ # real_img_dir = '../GS/real' # Replace with your real image directory
142
+ # gen_img_dir = '../GS/fake' # Replace with your generated image directory
143
+
144
+ # real_img_paths = [os.path.join(real_img_dir, fname) for fname in os.listdir(real_img_dir) if fname.endswith(('jpg', 'jpeg', 'png'))]
145
+ # gen_img_paths = [os.path.join(gen_img_dir, fname) for fname in os.listdir(gen_img_dir) if fname.endswith(('jpg', 'jpeg', 'png'))]
146
+
147
+ # # Ensure both directories have the same number of images
148
+ # assert len(real_img_paths) == len(gen_img_paths), "The number of images in both directories must be the same"
149
+
150
+ # gs_scores = []
151
+
152
+ # for real_img_path, gen_img_path in zip(real_img_paths, gen_img_paths):
153
+ # real_img = load_and_preprocess_image(real_img_path)
154
+ # gen_img = load_and_preprocess_image(gen_img_path)
155
+
156
+ # gs = gradient_similarity(real_img, gen_img)
157
+ # gs_scores.append(gs)
158
+
159
+ # print(f'Processed {real_img_path} and {gen_img_path}: GS = {gs:.3e}')
160
+
161
+ # mean_gs = np.mean(gs_scores)
162
+ # print(f'Mean Gradient Similarity (GS) score: {mean_gs:.3e}')
163
+
164
+
165
+
166
+ import torch
167
+ import torch.nn as nn
168
+ from torchvision import models, transforms
169
+ from PIL import Image
170
+ import numpy as np
171
+ import os
172
+ from prdc import compute_prdc
173
+
174
+ # Function to load and preprocess images
175
+ def load_and_preprocess_image(img_path):
176
+ img = Image.open(img_path).convert('RGB')
177
+ preprocess = transforms.Compose([
178
+ transforms.Resize((299, 299)),
179
+ transforms.ToTensor(),
180
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
181
+ ])
182
+ img = preprocess(img).unsqueeze(0) # Add batch dimension
183
+ return img
184
+
185
+ # Function to extract features using InceptionV3
186
+ def extract_features(img_paths, model):
187
+ features = []
188
+ with torch.no_grad():
189
+ for img_path in img_paths:
190
+ img = load_and_preprocess_image(img_path)
191
+ feature = model(img).numpy().squeeze()
192
+ features.append(feature)
193
+ features = np.array(features)
194
+ return features
195
+
196
+ # Load the feature extractor (a ResNet-18 backbone here; the pretrained checkpoint is loaded with strict=False)
197
+ model = models.resnet18(pretrained=False)
198
+ model.load_state_dict(torch.load('../modelsaved/Pretrained_InceptionV3.pth', map_location=lambda storage, loc: storage),strict=False)
199
+ model.fc = nn.Identity() # Remove the final classification layer
200
+ model.eval()
201
+
202
+ # Example usage
203
+ real_img_dir = '../dataset/1' # Replace with your real image directory
204
+ gen_img_dir = '../dataset/2' # Replace with your generated image directory
205
+
206
+ real_img_paths = [os.path.join(real_img_dir, fname) for fname in os.listdir(real_img_dir) if fname.endswith(('jpg', 'jpeg', 'png'))]
207
+ gen_img_paths = [os.path.join(gen_img_dir, fname) for fname in os.listdir(gen_img_dir) if fname.endswith(('jpg', 'jpeg', 'png'))]
208
+
209
+ # Extract features for real and generated images
210
+ real_features = extract_features(real_img_paths, model)
211
+ gen_features = extract_features(gen_img_paths, model)
212
+
213
+ # Calculate PRDC metrics
214
+ metrics = compute_prdc(real_features=real_features,
215
+ fake_features=gen_features,
216
+ nearest_k=2)
217
+
218
+ print(metrics)
219
+
220
+
221
+
222
+ # import torch
223
+ # from torch import nn
224
+ # from clip import clip
225
+ # import numpy as np
226
+
227
+
228
+ # clip_model, preprocess = clip.load("ViT-L/14@336px", device="cuda")
229
+
230
+
231
+ # def get_clip_embedding(images):
232
+ # with torch.no_grad():
233
+ # images = preprocess(images).unsqueeze(0).to("cuda")
234
+ # image_features = clip_model.encode_image(images)
235
+ # return image_features
236
+
237
+
238
+ # def compute_mmd(x, y, kernel):
239
+
240
+ # xx = kernel(x, x)
241
+ # yy = kernel(y, y)
242
+ # xy = kernel(x, y)
243
+
244
+ # mmd = torch.mean(xx) + torch.mean(yy) - 2 * torch.mean(xy)
245
+ # return mmd
246
+
247
+ # def gaussian_rbf_kernel(x, y, sigma=1.0):
248
+
249
+ # dist = torch.cdist(x, y, p=2.0)
250
+
251
+ # return torch.exp(-dist**2 / (2 * sigma**2))
252
+
253
+
254
+ # real_images = ...
255
+ # generated_images = ...
256
+
257
+
258
+ # real_features = get_clip_embedding(real_images)
259
+ # generated_features = get_clip_embedding(generated_images)
260
+
261
+
262
+ # sigma = 1.0
263
+ # mmd = compute_mmd(real_features, generated_features, lambda x, y: gaussian_rbf_kernel(x, y, sigma))
264
+ # cmmd = mmd * 1000
265
+
266
+ # print(f"CMMD: {cmmd.item()}")
Tiger Model/IS.py ADDED
@@ -0,0 +1,109 @@
1
+ # Copyright 2024 Hui Lu, Fang Dai, Siqiong Yao.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from datasets import *
16
+
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ import torch
20
+ import numpy as np
21
+ import torchvision.transforms as transforms
22
+ from torch.utils.data import DataLoader
23
+ from torch.autograd import Variable
24
+ from torch.nn import functional as F
25
+ import torch.utils.data
26
+ from scipy.stats import entropy
27
+ from torchvision.models.inception import inception_v3
28
+
29
+ import os
30
+ import glob
31
+ import random
32
+ import os
33
+ import numpy as np
34
+
35
+ from torch.utils.data import Dataset
36
+ from PIL import Image
37
+ import torchvision.transforms as transforms
38
+
39
+ class ISImageDataset(Dataset):
40
+ def __init__(self, root, transforms_=None):
41
+ self.transform = transforms.Compose(transforms_)
42
+
43
+ self.files = sorted(glob.glob(os.path.join(root) + "/*.png"))
44
+
45
+ def __getitem__(self, index):
46
+ img = Image.open(self.files[index % len(self.files)]).convert('RGB')
47
+ item_image = self.transform(img)
48
+ return item_image
49
+
50
+ def __len__(self):
51
+ return len(self.files)
52
+
53
+ path = '.../Figure/'
54
+ count = 0
55
+ for root,dirs,files in os.walk(path):
56
+ for each in files:
57
+ count += 1
58
+ print(count)
59
+ batch_size = 64
60
+ transforms_ = [
61
+ transforms.Resize((256, 256)),
62
+ transforms.ToTensor(),
63
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
64
+ ]
65
+
66
+ val_dataloader = DataLoader(
67
+ ISImageDataset(path, transforms_=transforms_),
68
+ batch_size = batch_size,
69
+ )
70
+
71
+ cuda = torch.cuda.is_available()
72
+ print('cuda: ',cuda)
73
+ tensor = torch.cuda.FloatTensor
74
+
75
+ inception_model = inception_v3(pretrained=True, transform_input=False).cuda()
76
+ inception_model.eval()
77
+ up = nn.Upsample(size=(299, 299), mode='bilinear', align_corners=False).cuda()
78
+
79
+ def get_pred(x):
80
+ if True:
81
+ x = up(x)
82
+ x = inception_model(x)
83
+ return F.softmax(x, dim=1).data.cpu().numpy()
84
+
85
+ print('Computing predictions using inception v3 model')
86
+ preds = np.zeros((count, 1000))
87
+
88
+ for i, data in enumerate(val_dataloader):
89
+ data = data.type(tensor)
90
+ batch_size_i = data.size()[0]
91
+ preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(data)
92
+
93
+ print('Computing KL Divergence')
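+ # The Inception Score is exp(E_x[ KL(p(y|x) || p(y)) ]); below, the predictions are split
+ # into `splits` chunks, each image's KL divergence from the chunk's marginal class
+ # distribution is averaged and exponentiated per chunk, and the mean/std over chunks is reported.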
94
+ split_scores = []
95
+ splits=10
96
+ N = count
97
+ for k in range(splits):
98
+ part = preds[k * (N // splits): (k + 1) * (N // splits), :]
99
+ py = np.mean(part, axis=0)
100
+ scores = []
101
+ for i in range(part.shape[0]):
102
+ pyx = part[i, :]
103
+ scores.append(entropy(pyx, py))
104
+ split_scores.append(np.exp(np.mean(scores)))
105
+
106
+
107
+ mean, std = np.mean(split_scores), np.std(split_scores)
108
+ print('IS is %.4f' % mean)
109
+ print('The std is %.4f' % std)
Tiger Model/diffusiers-Tiger/CLIPTextModel.py ADDED
@@ -0,0 +1,1326 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Hui Lu, Fang Dai, Siqiong Yao.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ """ PyTorch CLIP model."""
18
+
19
+
20
+ from dataclasses import dataclass
21
+ from typing import Any, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+
27
+ from ...activations import ACT2FN
28
+ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
29
+ from ...modeling_utils import PreTrainedModel
30
+ from ...utils import (
31
+ ModelOutput,
32
+ add_start_docstrings,
33
+ add_start_docstrings_to_model_forward,
34
+ logging,
35
+ replace_return_docstrings,
36
+ )
37
+ from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
38
+
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+ _CHECKPOINT_FOR_DOC = "openai/clip-vit-base-patch32"
43
+
44
+ CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
45
+ "openai/clip-vit-base-patch32",
46
+ # See all CLIP models at https://huggingface.co/models?filter=clip
47
+ ]
48
+
49
+
50
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
51
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
52
+ """
53
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
54
+ """
55
+ bsz, src_len = mask.size()
56
+ tgt_len = tgt_len if tgt_len is not None else src_len
57
+
58
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
59
+
60
+ inverted_mask = 1.0 - expanded_mask
61
+
62
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
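As an illustration of the expansion (toy values only; this snippet is an assumption, not part of the upload), a `[bsz, seq_len]` padding mask becomes an additive bias of shape `[bsz, 1, tgt_len, src_len]`, with 0 where attention is allowed and the dtype minimum where it is masked:

```python
# Hypothetical example: batch of 1, sequence of 3 where the last token is padding.
# Assumes _expand_mask from this module is in scope.
import torch

bias = _expand_mask(torch.tensor([[1, 1, 0]]), torch.float32)
print(bias.shape)   # torch.Size([1, 1, 3, 3])
print(bias[0, 0])   # every row is [0, 0, ~-3.4e38]; the padded position is blocked
```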
63
+
64
+
65
+ # contrastive loss function, adapted from
66
+ # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
67
+ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
68
+ return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
69
+
70
+
71
+ def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
72
+ caption_loss = contrastive_loss(similarity)
73
+ image_loss = contrastive_loss(similarity.t())
74
+ return (caption_loss + image_loss) / 2.0
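For intuition (toy numbers, assuming `clip_loss` above is in scope), the loss is the mean of two cross-entropies whose targets are the diagonal of the similarity matrix, so it is small when the i-th text matches the i-th image and equals log(N) for uniform logits:

```python
# Toy sketch of the symmetric contrastive loss on a 3x3 text-image similarity matrix.
import torch

similarity = torch.tensor([[5.0, 0.1, 0.2],
                           [0.0, 4.5, 0.3],
                           [0.1, 0.2, 6.0]])
print(clip_loss(similarity))         # small: the matches sit on the diagonal
print(clip_loss(torch.zeros(3, 3)))  # uniform logits -> log(3) ~= 1.0986
```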
75
+
76
+
77
+ @dataclass
78
+ class CLIPVisionModelOutput(ModelOutput):
79
+ """
80
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
81
+
82
+ Args:
83
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
84
+ The image embeddings obtained by applying the projection layer to the pooler_output.
85
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
86
+ Sequence of hidden-states at the output of the last layer of the model.
87
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
88
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
89
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
90
+
91
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
92
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
93
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
94
+ sequence_length)`.
95
+
96
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
97
+ heads.
98
+ """
99
+
100
+ image_embeds: Optional[torch.FloatTensor] = None
101
+ last_hidden_state: torch.FloatTensor = None
102
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
103
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
104
+
105
+
106
+ @dataclass
107
+ class CLIPTextModelOutput(ModelOutput):
108
+ """
109
+ Base class for text model's outputs that also contains a pooling of the last hidden states.
110
+
111
+ Args:
112
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
113
+ The text embeddings obtained by applying the projection layer to the pooler_output.
114
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
115
+ Sequence of hidden-states at the output of the last layer of the model.
116
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
117
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
118
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
119
+
120
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
121
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
122
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
123
+ sequence_length)`.
124
+
125
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
126
+ heads.
127
+ """
128
+
129
+ text_embeds: Optional[torch.FloatTensor] = None
130
+ last_hidden_state: torch.FloatTensor = None
131
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
132
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
133
+
134
+
135
+ @dataclass
136
+ class CLIPOutput(ModelOutput):
137
+ """
138
+ Args:
139
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
140
+ Contrastive loss for image-text similarity.
141
+ logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
142
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
143
+ similarity scores.
144
+ logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
145
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
146
+ similarity scores.
147
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
148
+ The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
149
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
150
+ The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
151
+ text_model_output(`BaseModelOutputWithPooling`):
152
+ The output of the [`CLIPTextModel`].
153
+ vision_model_output(`BaseModelOutputWithPooling`):
154
+ The output of the [`CLIPVisionModel`].
155
+ """
156
+
157
+ loss: Optional[torch.FloatTensor] = None
158
+ logits_per_image: torch.FloatTensor = None
159
+ logits_per_text: torch.FloatTensor = None
160
+ text_embeds: torch.FloatTensor = None
161
+ image_embeds: torch.FloatTensor = None
162
+ text_model_output: BaseModelOutputWithPooling = None
163
+ vision_model_output: BaseModelOutputWithPooling = None
164
+
165
+ def to_tuple(self) -> Tuple[Any]:
166
+ return tuple(
167
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
168
+ for k in self.keys()
169
+ )
170
+
171
+
172
+ class CLIPVisionEmbeddings(nn.Module):
173
+ def __init__(self, config: CLIPVisionConfig):
174
+ super().__init__()
175
+ self.config = config
176
+ self.embed_dim = config.hidden_size
177
+ self.image_size = config.image_size
178
+ self.patch_size = config.patch_size
179
+
180
+ self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
181
+
182
+ self.patch_embedding = nn.Conv2d(
183
+ in_channels=config.num_channels,
184
+ out_channels=self.embed_dim,
185
+ kernel_size=self.patch_size,
186
+ stride=self.patch_size,
187
+ bias=False,
188
+ )
189
+
190
+ self.num_patches = (self.image_size // self.patch_size) ** 2
191
+ self.num_positions = self.num_patches + 1
192
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
193
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
194
+
195
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
196
+ batch_size = pixel_values.shape[0]
197
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
198
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
199
+
200
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
201
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
202
+ embeddings = embeddings + self.position_embedding(self.position_ids)
203
+ return embeddings
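As a hedged shape check (toy configuration values chosen for illustration; they are not dictated by this upload), an `image_size=224`, `patch_size=32` input yields `(224 // 32)**2 = 49` patch tokens plus one class token, so `forward` returns `[batch, 50, hidden_size]`:

```python
# Assumes CLIPVisionConfig and CLIPVisionEmbeddings from this module are in scope.
import torch

config = CLIPVisionConfig(image_size=224, patch_size=32, hidden_size=768, num_channels=3)
embeddings = CLIPVisionEmbeddings(config)
out = embeddings(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 50, 768]): 1 class token + 49 patches
```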
204
+
205
+
206
+ class CLIPTextEmbeddings(nn.Module):
207
+ def __init__(self, config: CLIPTextConfig):
208
+ super().__init__()
209
+ embed_dim = config.hidden_size
210
+
211
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
212
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
213
+
214
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
215
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
216
+
217
+ def forward(
218
+ self,
219
+ input_ids: Optional[torch.LongTensor] = None,
220
+ position_ids: Optional[torch.LongTensor] = None,
221
+ inputs_embeds: Optional[torch.FloatTensor] = None,
222
+ ) -> torch.Tensor:
223
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
224
+
225
+ if position_ids is None:
226
+ position_ids = self.position_ids[:, :seq_length]
227
+
228
+ if inputs_embeds is None:
229
+ inputs_embeds = self.token_embedding(input_ids)
230
+
231
+ position_embeddings = self.position_embedding(position_ids)
232
+ embeddings = inputs_embeds + position_embeddings
233
+
234
+ return embeddings
235
+
236
+
237
+ class CLIPAttention(nn.Module):
238
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
239
+
240
+ def __init__(self, config):
241
+ super().__init__()
242
+ self.config = config
243
+ self.embed_dim = config.hidden_size
244
+ self.num_heads = config.num_attention_heads
245
+ self.head_dim = self.embed_dim // self.num_heads
246
+ if self.head_dim * self.num_heads != self.embed_dim:
247
+ raise ValueError(
248
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
249
+ f" {self.num_heads})."
250
+ )
251
+ self.scale = self.head_dim**-0.5
252
+ self.dropout = config.attention_dropout
253
+
254
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
255
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
256
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
257
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
258
+
259
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
260
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
261
+
262
+ def forward(
263
+ self,
264
+ hidden_states: torch.Tensor,
265
+ attention_mask: Optional[torch.Tensor] = None,
266
+ causal_attention_mask: Optional[torch.Tensor] = None,
267
+ output_attentions: Optional[bool] = False,
268
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
269
+ """Input shape: Batch x Time x Channel"""
270
+
271
+ bsz, tgt_len, embed_dim = hidden_states.size()
272
+
273
+ # get query proj
274
+ query_states = self.q_proj(hidden_states) * self.scale
275
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
276
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
277
+
278
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
279
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
280
+ key_states = key_states.view(*proj_shape)
281
+ value_states = value_states.view(*proj_shape)
282
+
283
+ src_len = key_states.size(1)
284
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
285
+
286
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
287
+ raise ValueError(
288
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
289
+ f" {attn_weights.size()}"
290
+ )
291
+
292
+ # apply the causal_attention_mask first
293
+ if causal_attention_mask is not None:
294
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
295
+ raise ValueError(
296
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
297
+ f" {causal_attention_mask.size()}"
298
+ )
299
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
300
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
301
+
302
+ if attention_mask is not None:
303
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
304
+ raise ValueError(
305
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
306
+ )
307
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
308
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
309
+
310
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
311
+
312
+ if output_attentions:
313
+ # this operation is a bit awkward, but it's required to
314
+ # make sure that attn_weights keeps its gradient.
315
+ # In order to do so, attn_weights has to be reshaped
316
+ # twice and reused in the following
317
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
318
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
319
+ else:
320
+ attn_weights_reshaped = None
321
+
322
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
323
+
324
+ attn_output = torch.bmm(attn_probs, value_states)
325
+
326
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
327
+ raise ValueError(
328
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
329
+ f" {attn_output.size()}"
330
+ )
331
+
332
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
333
+ attn_output = attn_output.transpose(1, 2)
334
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
335
+
336
+ attn_output = self.out_proj(attn_output)
337
+
338
+ return attn_output, attn_weights_reshaped
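A minimal shape check of the head reshaping above (toy sizes; the config values are assumptions for illustration): the output keeps the `[batch, seq, embed_dim]` shape, and with `output_attentions=True` the per-head weights come back as `[batch, num_heads, seq, seq]`:

```python
# Assumes CLIPTextConfig and CLIPAttention from this module are in scope.
import torch

config = CLIPTextConfig(hidden_size=64, num_attention_heads=4)  # toy sizes
attn = CLIPAttention(config)
out, weights = attn(torch.randn(2, 7, 64), output_attentions=True)
print(out.shape)      # torch.Size([2, 7, 64])
print(weights.shape)  # torch.Size([2, 4, 7, 7])
```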
339
+
340
+
341
+ class CLIPMLP(nn.Module):
342
+ def __init__(self, config):
343
+ super().__init__()
344
+ self.config = config
345
+ self.activation_fn = ACT2FN[config.hidden_act]
346
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
347
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
348
+
349
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
350
+ hidden_states = self.fc1(hidden_states)
351
+ hidden_states = self.activation_fn(hidden_states)
352
+ hidden_states = self.fc2(hidden_states)
353
+ return hidden_states
354
+
355
+
356
+ class CLIPEncoderLayer(nn.Module):
357
+ def __init__(self, config: CLIPConfig):
358
+ super().__init__()
359
+ self.embed_dim = config.hidden_size
360
+ self.self_attn = CLIPAttention(config)
361
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
362
+ self.mlp = CLIPMLP(config)
363
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
364
+
365
+ def forward(
366
+ self,
367
+ hidden_states: torch.Tensor,
368
+ attention_mask: torch.Tensor,
369
+ causal_attention_mask: torch.Tensor,
370
+ output_attentions: Optional[bool] = False,
371
+ ) -> Tuple[torch.FloatTensor]:
372
+ """
373
+ Args:
374
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
375
+ attention_mask (`torch.FloatTensor`): attention mask of size
376
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
377
+ `(config.encoder_attention_heads,)`.
378
+ output_attentions (`bool`, *optional*):
379
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
380
+ returned tensors for more detail.
381
+ """
382
+ residual = hidden_states
383
+
384
+ hidden_states = self.layer_norm1(hidden_states)
385
+ hidden_states, attn_weights = self.self_attn(
386
+ hidden_states=hidden_states,
387
+ attention_mask=attention_mask,
388
+ causal_attention_mask=causal_attention_mask,
389
+ output_attentions=output_attentions,
390
+ )
391
+ hidden_states = residual + hidden_states
392
+
393
+ residual = hidden_states
394
+ hidden_states = self.layer_norm2(hidden_states)
395
+ hidden_states = self.mlp(hidden_states)
396
+ hidden_states = residual + hidden_states
397
+
398
+ outputs = (hidden_states,)
399
+
400
+ if output_attentions:
401
+ outputs += (attn_weights,)
402
+
403
+ return outputs
404
+
405
+
406
+ class CLIPPreTrainedModel(PreTrainedModel):
407
+ """
408
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
409
+ models.
410
+ """
411
+
412
+ config_class = CLIPConfig
413
+ base_model_prefix = "clip"
414
+ supports_gradient_checkpointing = True
415
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
416
+
417
+ def _init_weights(self, module):
418
+ """Initialize the weights"""
419
+ factor = self.config.initializer_factor
420
+ if isinstance(module, CLIPTextEmbeddings):
421
+ module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
422
+ module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
423
+ elif isinstance(module, CLIPVisionEmbeddings):
424
+ factor = self.config.initializer_factor
425
+ nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
426
+ nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
427
+ nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
428
+ elif isinstance(module, CLIPAttention):
429
+ factor = self.config.initializer_factor
430
+ in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
431
+ out_proj_std = (module.embed_dim**-0.5) * factor
432
+ nn.init.normal_(module.q_proj.weight, std=in_proj_std)
433
+ nn.init.normal_(module.k_proj.weight, std=in_proj_std)
434
+ nn.init.normal_(module.v_proj.weight, std=in_proj_std)
435
+ nn.init.normal_(module.out_proj.weight, std=out_proj_std)
436
+ elif isinstance(module, CLIPMLP):
437
+ factor = self.config.initializer_factor
438
+ in_proj_std = (
439
+ (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
440
+ )
441
+ fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
442
+ nn.init.normal_(module.fc1.weight, std=fc_std)
443
+ nn.init.normal_(module.fc2.weight, std=in_proj_std)
444
+ elif isinstance(module, CLIPModel):
445
+ nn.init.normal_(
446
+ module.text_projection.weight,
447
+ std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
448
+ )
449
+ nn.init.normal_(
450
+ module.visual_projection.weight,
451
+ std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
452
+ )
453
+ elif isinstance(module, CLIPVisionModelWithProjection):
454
+ nn.init.normal_(
455
+ module.visual_projection.weight,
456
+ std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
457
+ )
458
+ elif isinstance(module, CLIPTextModelWithProjection):
459
+ nn.init.normal_(
460
+ module.text_projection.weight,
461
+ std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
462
+ )
463
+
464
+ if isinstance(module, nn.LayerNorm):
465
+ module.bias.data.zero_()
466
+ module.weight.data.fill_(1.0)
467
+ if isinstance(module, nn.Linear) and module.bias is not None:
468
+ module.bias.data.zero_()
469
+
470
+ def _set_gradient_checkpointing(self, module, value=False):
471
+ if isinstance(module, CLIPEncoder):
472
+ module.gradient_checkpointing = value
473
+
474
+
475
+ CLIP_START_DOCSTRING = r"""
476
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
477
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
478
+ etc.)
479
+
480
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
481
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
482
+ and behavior.
483
+
484
+ Parameters:
485
+ config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
486
+ Initializing with a config file does not load the weights associated with the model, only the
487
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
488
+ """
489
+
490
+ CLIP_TEXT_INPUTS_DOCSTRING = r"""
491
+ Args:
492
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
493
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
494
+ it.
495
+
496
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
497
+ [`PreTrainedTokenizer.__call__`] for details.
498
+
499
+ [What are input IDs?](../glossary#input-ids)
500
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
501
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
502
+
503
+ - 1 for tokens that are **not masked**,
504
+ - 0 for tokens that are **masked**.
505
+
506
+ [What are attention masks?](../glossary#attention-mask)
507
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
508
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
509
+ config.max_position_embeddings - 1]`.
510
+
511
+ [What are position IDs?](../glossary#position-ids)
512
+ output_attentions (`bool`, *optional*):
513
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
514
+ tensors for more detail.
515
+ output_hidden_states (`bool`, *optional*):
516
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
517
+ more detail.
518
+ return_dict (`bool`, *optional*):
519
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
520
+ """
521
+
522
+ CLIP_VISION_INPUTS_DOCSTRING = r"""
523
+ Args:
524
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
525
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
526
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
527
+ output_attentions (`bool`, *optional*):
528
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
529
+ tensors for more detail.
530
+ output_hidden_states (`bool`, *optional*):
531
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
532
+ more detail.
533
+ return_dict (`bool`, *optional*):
534
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
535
+ """
536
+
537
+ CLIP_INPUTS_DOCSTRING = r"""
538
+ Args:
539
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
540
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
541
+ it.
542
+
543
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
544
+ [`PreTrainedTokenizer.__call__`] for details.
545
+
546
+ [What are input IDs?](../glossary#input-ids)
547
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
548
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
549
+
550
+ - 1 for tokens that are **not masked**,
551
+ - 0 for tokens that are **masked**.
552
+
553
+ [What are attention masks?](../glossary#attention-mask)
554
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
555
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
556
+ config.max_position_embeddings - 1]`.
557
+
558
+ [What are position IDs?](../glossary#position-ids)
559
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
560
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
561
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
562
+ return_loss (`bool`, *optional*):
563
+ Whether or not to return the contrastive loss.
564
+ output_attentions (`bool`, *optional*):
565
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
566
+ tensors for more detail.
567
+ output_hidden_states (`bool`, *optional*):
568
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
569
+ more detail.
570
+ return_dict (`bool`, *optional*):
571
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
572
+ """
573
+
574
+
575
+ class CLIPEncoder(nn.Module):
576
+ """
577
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
578
+ [`CLIPEncoderLayer`].
579
+
580
+ Args:
581
+ config: CLIPConfig
582
+ """
583
+
584
+ def __init__(self, config: CLIPConfig):
585
+ super().__init__()
586
+ self.config = config
587
+ self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
588
+ self.gradient_checkpointing = False
589
+
590
+ def forward(
591
+ self,
592
+ inputs_embeds,
593
+ attention_mask: Optional[torch.Tensor] = None,
594
+ causal_attention_mask: Optional[torch.Tensor] = None,
595
+ output_attentions: Optional[bool] = None,
596
+ output_hidden_states: Optional[bool] = None,
597
+ return_dict: Optional[bool] = None,
598
+ ) -> Union[Tuple, BaseModelOutput]:
599
+ r"""
600
+ Args:
601
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
602
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
603
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
604
+ than the model's internal embedding lookup matrix.
605
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
606
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
607
+
608
+ - 1 for tokens that are **not masked**,
609
+ - 0 for tokens that are **masked**.
610
+
611
+ [What are attention masks?](../glossary#attention-mask)
612
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
613
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
614
+
615
+ - 1 for tokens that are **not masked**,
616
+ - 0 for tokens that are **masked**.
617
+
618
+ [What are attention masks?](../glossary#attention-mask)
619
+ output_attentions (`bool`, *optional*):
620
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
621
+ returned tensors for more detail.
622
+ output_hidden_states (`bool`, *optional*):
623
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
624
+ for more detail.
625
+ return_dict (`bool`, *optional*):
626
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
627
+ """
628
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
629
+ output_hidden_states = (
630
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
631
+ )
632
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
633
+
634
+ encoder_states = () if output_hidden_states else None
635
+ all_attentions = () if output_attentions else None
636
+
637
+ hidden_states = inputs_embeds
638
+ for idx, encoder_layer in enumerate(self.layers):
639
+ if output_hidden_states:
640
+ encoder_states = encoder_states + (hidden_states,)
641
+ if self.gradient_checkpointing and self.training:
642
+
643
+ def create_custom_forward(module):
644
+ def custom_forward(*inputs):
645
+ return module(*inputs, output_attentions)
646
+
647
+ return custom_forward
648
+
649
+ layer_outputs = torch.utils.checkpoint.checkpoint(
650
+ create_custom_forward(encoder_layer),
651
+ hidden_states,
652
+ attention_mask,
653
+ causal_attention_mask,
654
+ )
655
+ else:
656
+ layer_outputs = encoder_layer(
657
+ hidden_states,
658
+ attention_mask,
659
+ causal_attention_mask,
660
+ output_attentions=output_attentions,
661
+ )
662
+
663
+ hidden_states = layer_outputs[0]
664
+
665
+ if output_attentions:
666
+ all_attentions = all_attentions + (layer_outputs[1],)
667
+
668
+ if output_hidden_states:
669
+ encoder_states = encoder_states + (hidden_states,)
670
+
671
+ if not return_dict:
672
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
673
+ return BaseModelOutput(
674
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
675
+ )
676
+
677
+
678
+ class CLIPTextTransformer(nn.Module):
679
+ def __init__(self, config: CLIPTextConfig):
680
+ super().__init__()
681
+ self.config = config
682
+ embed_dim = config.hidden_size
683
+ self.embeddings = CLIPTextEmbeddings(config)
684
+ self.encoder = CLIPEncoder(config)
685
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
686
+
687
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
688
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
689
+ def forward(
690
+ self,
691
+ input_ids: Optional[torch.Tensor] = None,
692
+ attention_mask: Optional[torch.Tensor] = None,
693
+ position_ids: Optional[torch.Tensor] = None,
694
+ output_attentions: Optional[bool] = None,
695
+ output_hidden_states: Optional[bool] = None,
696
+ return_dict: Optional[bool] = None,
697
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
698
+ r"""
699
+ Returns:
700
+
701
+ """
702
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
703
+ output_hidden_states = (
704
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
705
+ )
706
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
707
+
708
+ if input_ids is None:
709
+ raise ValueError("You have to specify input_ids")
710
+
711
+ input_shape = input_ids.size()
712
+ input_ids = input_ids.view(-1, input_shape[-1])
713
+
714
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
715
+
716
+ bsz, seq_len = input_shape
717
+ # CLIP's text model uses causal mask, prepare it here.
718
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
719
+ causal_attention_mask = self._build_causal_attention_mask(
720
+ bsz, seq_len, hidden_states.dtype, device=hidden_states.device
721
+ )
722
+ # expand attention_mask
723
+ if attention_mask is not None:
724
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
725
+ attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
726
+
727
+ encoder_outputs = self.encoder(
728
+ inputs_embeds=hidden_states,
729
+ attention_mask=attention_mask,
730
+ causal_attention_mask=causal_attention_mask,
731
+ output_attentions=output_attentions,
732
+ output_hidden_states=output_hidden_states,
733
+ return_dict=return_dict,
734
+ )
735
+
736
+ last_hidden_state = encoder_outputs[0]
737
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
738
+
739
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
740
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
741
+ # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
742
+ pooled_output = last_hidden_state[
743
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
744
+ input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
745
+ ]
746
+
747
+ if not return_dict:
748
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
749
+
750
+ return BaseModelOutputWithPooling(
751
+ last_hidden_state=last_hidden_state,
752
+ pooler_output=pooled_output,
753
+ hidden_states=encoder_outputs.hidden_states,
754
+ attentions=encoder_outputs.attentions,
755
+ )
756
+
757
+ def _build_causal_attention_mask(self, bsz, seq_len, dtype, device=None):
758
+ # lazily create causal attention mask, with full attention between the vision tokens
759
+ # pytorch uses additive attention mask; fill with -inf
760
+ mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=device)
761
+ mask.fill_(torch.finfo(dtype).min)
762
+ mask.triu_(1) # zero out the lower diagonal
763
+ mask = mask.unsqueeze(1) # expand mask
764
+ return mask
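For intuition (toy size, illustrative values), the additive mask built here for `seq_len=3` is zero on and below the diagonal and the dtype minimum above it, so each position can attend only to itself and earlier tokens:

```python
# Standalone sketch of the mask construction above for bsz=1, seq_len=3, float32.
import torch

mask = torch.empty(1, 3, 3).fill_(torch.finfo(torch.float32).min).triu_(1)
print(mask[0])
# row 0: [0, ~-3.4e38, ~-3.4e38]
# row 1: [0,        0, ~-3.4e38]
# row 2: [0,        0,        0]
```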
765
+
766
+
767
+ @add_start_docstrings(
768
+ """The text model from CLIP without any head or projection on top.""",
769
+ CLIP_START_DOCSTRING,
770
+ )
771
+ class CLIPTextModel(CLIPPreTrainedModel):
772
+ config_class = CLIPTextConfig
773
+
774
+ _no_split_modules = ["CLIPEncoderLayer"]
775
+
776
+ def __init__(self, config: CLIPTextConfig):
777
+ super().__init__(config)
778
+ self.text_model = CLIPTextTransformer(config)
779
+ # Initialize weights and apply final processing
780
+ self.post_init()
781
+
782
+ def get_input_embeddings(self) -> nn.Module:
783
+ return self.text_model.embeddings.token_embedding
784
+
785
+ def set_input_embeddings(self, value):
786
+ self.text_model.embeddings.token_embedding = value
787
+
788
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
789
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
790
+ def forward(
791
+ self,
792
+ input_ids: Optional[torch.Tensor] = None,
793
+ attention_mask: Optional[torch.Tensor] = None,
794
+ position_ids: Optional[torch.Tensor] = None,
795
+ output_attentions: Optional[bool] = None,
796
+ output_hidden_states: Optional[bool] = None,
797
+ return_dict: Optional[bool] = None,
798
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
799
+ r"""
800
+ Returns:
801
+
802
+ Examples:
803
+
804
+ ```python
805
+ >>> from transformers import AutoTokenizer, CLIPTextModel
806
+
807
+ >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
808
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
809
+
810
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
811
+
812
+ >>> outputs = model(**inputs)
813
+ >>> last_hidden_state = outputs.last_hidden_state
814
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
815
+ ```"""
816
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
817
+
818
+ return self.text_model(
819
+ input_ids=input_ids,
820
+ attention_mask=attention_mask,
821
+ position_ids=position_ids,
822
+ output_attentions=output_attentions,
823
+ output_hidden_states=output_hidden_states,
824
+ return_dict=return_dict,
825
+ )
826
+
827
+
828
+ class CLIPVisionTransformer(nn.Module):
829
+ def __init__(self, config: CLIPVisionConfig):
830
+ super().__init__()
831
+ self.config = config
832
+ embed_dim = config.hidden_size
833
+
834
+ self.embeddings = CLIPVisionEmbeddings(config)
835
+ self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
836
+ self.encoder = CLIPEncoder(config)
837
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
838
+
839
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
840
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
841
+ def forward(
842
+ self,
843
+ pixel_values: Optional[torch.FloatTensor] = None,
844
+ output_attentions: Optional[bool] = None,
845
+ output_hidden_states: Optional[bool] = None,
846
+ return_dict: Optional[bool] = None,
847
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
848
+ r"""
849
+ Returns:
850
+
851
+ """
852
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
853
+ output_hidden_states = (
854
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
855
+ )
856
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
857
+
858
+ if pixel_values is None:
859
+ raise ValueError("You have to specify pixel_values")
860
+
861
+ hidden_states = self.embeddings(pixel_values)
862
+ hidden_states = self.pre_layrnorm(hidden_states)
863
+
864
+ encoder_outputs = self.encoder(
865
+ inputs_embeds=hidden_states,
866
+ output_attentions=output_attentions,
867
+ output_hidden_states=output_hidden_states,
868
+ return_dict=return_dict,
869
+ )
870
+
871
+ last_hidden_state = encoder_outputs[0]
872
+ pooled_output = last_hidden_state[:, 0, :]
873
+ pooled_output = self.post_layernorm(pooled_output)
874
+
875
+ if not return_dict:
876
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
877
+
878
+ return BaseModelOutputWithPooling(
879
+ last_hidden_state=last_hidden_state,
880
+ pooler_output=pooled_output,
881
+ hidden_states=encoder_outputs.hidden_states,
882
+ attentions=encoder_outputs.attentions,
883
+ )
884
+
885
+
886
+ @add_start_docstrings(
887
+ """The vision model from CLIP without any head or projection on top.""",
888
+ CLIP_START_DOCSTRING,
889
+ )
890
+ class CLIPVisionModel(CLIPPreTrainedModel):
891
+ config_class = CLIPVisionConfig
892
+ main_input_name = "pixel_values"
893
+
894
+ def __init__(self, config: CLIPVisionConfig):
895
+ super().__init__(config)
896
+ self.vision_model = CLIPVisionTransformer(config)
897
+ # Initialize weights and apply final processing
898
+ self.post_init()
899
+
900
+ def get_input_embeddings(self) -> nn.Module:
901
+ return self.vision_model.embeddings.patch_embedding
902
+
903
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
904
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
905
+ def forward(
906
+ self,
907
+ pixel_values: Optional[torch.FloatTensor] = None,
908
+ output_attentions: Optional[bool] = None,
909
+ output_hidden_states: Optional[bool] = None,
910
+ return_dict: Optional[bool] = None,
911
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
912
+ r"""
913
+ Returns:
914
+
915
+ Examples:
916
+
917
+ ```python
918
+ >>> from PIL import Image
919
+ >>> import requests
920
+ >>> from transformers import AutoProcessor, CLIPVisionModel
921
+
922
+ >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
923
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
924
+
925
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
926
+ >>> image = Image.open(requests.get(url, stream=True).raw)
927
+
928
+ >>> inputs = processor(images=image, return_tensors="pt")
929
+
930
+ >>> outputs = model(**inputs)
931
+ >>> last_hidden_state = outputs.last_hidden_state
932
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
933
+ ```"""
934
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
935
+
936
+ return self.vision_model(
937
+ pixel_values=pixel_values,
938
+ output_attentions=output_attentions,
939
+ output_hidden_states=output_hidden_states,
940
+ return_dict=return_dict,
941
+ )
942
+
943
+
944
+ @add_start_docstrings(CLIP_START_DOCSTRING)
945
+ class CLIPModel(CLIPPreTrainedModel):
946
+ config_class = CLIPConfig
947
+
948
+ def __init__(self, config: CLIPConfig):
949
+ super().__init__(config)
950
+
951
+ if not isinstance(config.text_config, CLIPTextConfig):
952
+ raise ValueError(
953
+ "config.text_config is expected to be of type CLIPTextConfig but is of type"
954
+ f" {type(config.text_config)}."
955
+ )
956
+
957
+ if not isinstance(config.vision_config, CLIPVisionConfig):
958
+ raise ValueError(
959
+ "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
960
+ f" {type(config.vision_config)}."
961
+ )
962
+
963
+ text_config = config.text_config
964
+ vision_config = config.vision_config
965
+
966
+ self.projection_dim = config.projection_dim
967
+ self.text_embed_dim = text_config.hidden_size
968
+ self.vision_embed_dim = vision_config.hidden_size
969
+
970
+ self.text_model = CLIPTextTransformer(text_config)
971
+ self.vision_model = CLIPVisionTransformer(vision_config)
972
+
973
+ self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
974
+ self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
975
+ self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value)
976
+
977
+ # Initialize weights and apply final processing
978
+ self.post_init()
979
+
980
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
981
+ def get_text_features(
982
+ self,
983
+ input_ids: Optional[torch.Tensor] = None,
984
+ attention_mask: Optional[torch.Tensor] = None,
985
+ position_ids: Optional[torch.Tensor] = None,
986
+ output_attentions: Optional[bool] = None,
987
+ output_hidden_states: Optional[bool] = None,
988
+ return_dict: Optional[bool] = None,
989
+ ) -> torch.FloatTensor:
990
+ r"""
991
+ Returns:
992
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
993
+ applying the projection layer to the pooled output of [`CLIPTextModel`].
994
+
995
+ Examples:
996
+
997
+ ```python
998
+ >>> from transformers import AutoTokenizer, CLIPModel
999
+
1000
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
1001
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
1002
+
1003
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
1004
+ >>> text_features = model.get_text_features(**inputs)
1005
+ ```"""
1006
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
1007
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1008
+ output_hidden_states = (
1009
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1010
+ )
1011
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1012
+
1013
+ text_outputs = self.text_model(
1014
+ input_ids=input_ids,
1015
+ attention_mask=attention_mask,
1016
+ position_ids=position_ids,
1017
+ output_attentions=output_attentions,
1018
+ output_hidden_states=output_hidden_states,
1019
+ return_dict=return_dict,
1020
+ )
1021
+
1022
+ pooled_output = text_outputs[1]
1023
+ text_features = self.text_projection(pooled_output)
1024
+
1025
+ return text_features
1026
+
1027
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
1028
+ def get_image_features(
1029
+ self,
1030
+ pixel_values: Optional[torch.FloatTensor] = None,
1031
+ output_attentions: Optional[bool] = None,
1032
+ output_hidden_states: Optional[bool] = None,
1033
+ return_dict: Optional[bool] = None,
1034
+ ) -> torch.FloatTensor:
1035
+ r"""
1036
+ Returns:
1037
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
1038
+ applying the projection layer to the pooled output of [`CLIPVisionModel`].
1039
+
1040
+ Examples:
1041
+
1042
+ ```python
1043
+ >>> from PIL import Image
1044
+ >>> import requests
1045
+ >>> from transformers import AutoProcessor, CLIPModel
1046
+
1047
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
1048
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1049
+
1050
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1051
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1052
+
1053
+ >>> inputs = processor(images=image, return_tensors="pt")
1054
+
1055
+ >>> image_features = model.get_image_features(**inputs)
1056
+ ```"""
1057
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
1058
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1059
+ output_hidden_states = (
1060
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1061
+ )
1062
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1063
+
1064
+ vision_outputs = self.vision_model(
1065
+ pixel_values=pixel_values,
1066
+ output_attentions=output_attentions,
1067
+ output_hidden_states=output_hidden_states,
1068
+ return_dict=return_dict,
1069
+ )
1070
+
1071
+ pooled_output = vision_outputs[1] # pooled_output
1072
+ image_features = self.visual_projection(pooled_output)
1073
+
1074
+ return image_features
1075
+
1076
+ @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
1077
+ @replace_return_docstrings(output_type=CLIPOutput, config_class=CLIPConfig)
1078
+ def forward(
1079
+ self,
1080
+ input_ids: Optional[torch.LongTensor] = None,
1081
+ pixel_values: Optional[torch.FloatTensor] = None,
1082
+ attention_mask: Optional[torch.Tensor] = None,
1083
+ position_ids: Optional[torch.LongTensor] = None,
1084
+ return_loss: Optional[bool] = None,
1085
+ output_attentions: Optional[bool] = None,
1086
+ output_hidden_states: Optional[bool] = None,
1087
+ return_dict: Optional[bool] = None,
1088
+ ) -> Union[Tuple, CLIPOutput]:
1089
+ r"""
1090
+ Returns:
1091
+
1092
+ Examples:
1093
+
1094
+ ```python
1095
+ >>> from PIL import Image
1096
+ >>> import requests
1097
+ >>> from transformers import AutoProcessor, CLIPModel
1098
+
1099
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
1100
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1101
+
1102
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1103
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1104
+
1105
+ >>> inputs = processor(
1106
+ ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
1107
+ ... )
1108
+
1109
+ >>> outputs = model(**inputs)
1110
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
1111
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
1112
+ ```"""
1113
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
1114
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1115
+ output_hidden_states = (
1116
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1117
+ )
1118
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1119
+
1120
+ vision_outputs = self.vision_model(
1121
+ pixel_values=pixel_values,
1122
+ output_attentions=output_attentions,
1123
+ output_hidden_states=output_hidden_states,
1124
+ return_dict=return_dict,
1125
+ )
1126
+
1127
+ text_outputs = self.text_model(
1128
+ input_ids=input_ids,
1129
+ attention_mask=attention_mask,
1130
+ position_ids=position_ids,
1131
+ output_attentions=output_attentions,
1132
+ output_hidden_states=output_hidden_states,
1133
+ return_dict=return_dict,
1134
+ )
1135
+
1136
+ image_embeds = vision_outputs[1]
1137
+ image_embeds = self.visual_projection(image_embeds)
1138
+
1139
+ text_embeds = text_outputs[1]
1140
+ text_embeds = self.text_projection(text_embeds)
1141
+
1142
+ # normalized features
1143
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
1144
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1145
+
1146
+ # cosine similarity as logits
1147
+ logit_scale = self.logit_scale.exp()
1148
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
1149
+ logits_per_image = logits_per_text.t()
1150
+
1151
+ loss = None
1152
+ if return_loss:
1153
+ loss = clip_loss(logits_per_text)
1154
+
1155
+ if not return_dict:
1156
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
1157
+ return ((loss,) + output) if loss is not None else output
1158
+
1159
+ return CLIPOutput(
1160
+ loss=loss,
1161
+ logits_per_image=logits_per_image,
1162
+ logits_per_text=logits_per_text,
1163
+ text_embeds=text_embeds,
1164
+ image_embeds=image_embeds,
1165
+ text_model_output=text_outputs,
1166
+ vision_model_output=vision_outputs,
1167
+ )
1168
+
1169
+
1170
+ @add_start_docstrings(
1171
+ """
1172
+ CLIP Text Model with a projection layer on top (a linear layer on top of the pooled output).
1173
+ """,
1174
+ CLIP_START_DOCSTRING,
1175
+ )
1176
+ class CLIPTextModelWithProjection(CLIPPreTrainedModel):
1177
+ config_class = CLIPTextConfig
1178
+
1179
+ _no_split_modules = ["CLIPEncoderLayer"]
1180
+
1181
+ def __init__(self, config: CLIPTextConfig):
1182
+ super().__init__(config)
1183
+
1184
+ self.text_model = CLIPTextTransformer(config)
1185
+
1186
+ self.text_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
1187
+
1188
+ # Initialize weights and apply final processing
1189
+ self.post_init()
1190
+
1191
+ def get_input_embeddings(self) -> nn.Module:
1192
+ return self.text_model.embeddings.token_embedding
1193
+
1194
+ def set_input_embeddings(self, value):
1195
+ self.text_model.embeddings.token_embedding = value
1196
+
1197
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
1198
+ @replace_return_docstrings(output_type=CLIPTextModelOutput, config_class=CLIPTextConfig)
1199
+ def forward(
1200
+ self,
1201
+ input_ids: Optional[torch.Tensor] = None,
1202
+ attention_mask: Optional[torch.Tensor] = None,
1203
+ position_ids: Optional[torch.Tensor] = None,
1204
+ output_attentions: Optional[bool] = None,
1205
+ output_hidden_states: Optional[bool] = None,
1206
+ return_dict: Optional[bool] = None,
1207
+ ) -> Union[Tuple, CLIPTextModelOutput]:
1208
+ r"""
1209
+ Returns:
1210
+
1211
+ Examples:
1212
+
1213
+ ```python
1214
+ >>> from transformers import AutoTokenizer, CLIPTextModelWithProjection
1215
+
1216
+ >>> model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
1217
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
1218
+
1219
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
1220
+
1221
+ >>> outputs = model(**inputs)
1222
+ >>> text_embeds = outputs.text_embeds
1223
+ ```"""
1224
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1225
+
1226
+ text_outputs = self.text_model(
1227
+ input_ids=input_ids,
1228
+ attention_mask=attention_mask,
1229
+ position_ids=position_ids,
1230
+ output_attentions=output_attentions,
1231
+ output_hidden_states=output_hidden_states,
1232
+ return_dict=return_dict,
1233
+ )
1234
+
1235
+ pooled_output = text_outputs[1]
1236
+
1237
+ text_embeds = self.text_projection(pooled_output)
1238
+
1239
+ if not return_dict:
1240
+ outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
1241
+ return tuple(output for output in outputs if output is not None)
1242
+
1243
+ return CLIPTextModelOutput(
1244
+ text_embeds=text_embeds,
1245
+ last_hidden_state=text_outputs.last_hidden_state,
1246
+ hidden_states=text_outputs.hidden_states,
1247
+ attentions=text_outputs.attentions,
1248
+ )
1249
+
1250
+
1251
+ @add_start_docstrings(
1252
+ """
1253
+ CLIP Vision Model with a projection layer on top (a linear layer on top of the pooled output).
1254
+ """,
1255
+ CLIP_START_DOCSTRING,
1256
+ )
1257
+ class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
1258
+ config_class = CLIPVisionConfig
1259
+ main_input_name = "pixel_values"
1260
+
1261
+ def __init__(self, config: CLIPVisionConfig):
1262
+ super().__init__(config)
1263
+
1264
+ self.vision_model = CLIPVisionTransformer(config)
1265
+
1266
+ self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
1267
+
1268
+ # Initialize weights and apply final processing
1269
+ self.post_init()
1270
+
1271
+ def get_input_embeddings(self) -> nn.Module:
1272
+ return self.vision_model.embeddings.patch_embedding
1273
+
1274
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
1275
+ @replace_return_docstrings(output_type=CLIPVisionModelOutput, config_class=CLIPVisionConfig)
1276
+ def forward(
1277
+ self,
1278
+ pixel_values: Optional[torch.FloatTensor] = None,
1279
+ output_attentions: Optional[bool] = None,
1280
+ output_hidden_states: Optional[bool] = None,
1281
+ return_dict: Optional[bool] = None,
1282
+ ) -> Union[Tuple, CLIPVisionModelOutput]:
1283
+ r"""
1284
+ Returns:
1285
+
1286
+ Examples:
1287
+
1288
+ ```python
1289
+ >>> from PIL import Image
1290
+ >>> import requests
1291
+ >>> from transformers import AutoProcessor, CLIPVisionModelWithProjection
1292
+
1293
+ >>> model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
1294
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1295
+
1296
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1297
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1298
+
1299
+ >>> inputs = processor(images=image, return_tensors="pt")
1300
+
1301
+ >>> outputs = model(**inputs)
1302
+ >>> image_embeds = outputs.image_embeds
1303
+ ```"""
1304
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1305
+
1306
+ vision_outputs = self.vision_model(
1307
+ pixel_values=pixel_values,
1308
+ output_attentions=output_attentions,
1309
+ output_hidden_states=output_hidden_states,
1310
+ return_dict=return_dict,
1311
+ )
1312
+
1313
+ pooled_output = vision_outputs[1] # pooled_output
1314
+
1315
+ image_embeds = self.visual_projection(pooled_output)
1316
+
1317
+ if not return_dict:
1318
+ outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:]
1319
+ return tuple(output for output in outputs if output is not None)
1320
+
1321
+ return CLIPVisionModelOutput(
1322
+ image_embeds=image_embeds,
1323
+ last_hidden_state=vision_outputs.last_hidden_state,
1324
+ hidden_states=vision_outputs.hidden_states,
1325
+ attentions=vision_outputs.attentions,
1326
+ )
Tiger Model/diffusiers-Tiger/__init__.py ADDED
@@ -0,0 +1,293 @@
1
+ __version__ = "0.21.0.dev0"
2
+
3
+ from .configuration_utils import ConfigMixin
4
+ from .utils import (
5
+ OptionalDependencyNotAvailable,
6
+ is_flax_available,
7
+ is_inflect_available,
8
+ is_invisible_watermark_available,
9
+ is_k_diffusion_available,
10
+ is_k_diffusion_version,
11
+ is_librosa_available,
12
+ is_note_seq_available,
13
+ is_onnx_available,
14
+ is_scipy_available,
15
+ is_torch_available,
16
+ is_torchsde_available,
17
+ is_transformers_available,
18
+ is_transformers_version,
19
+ is_unidecode_available,
20
+ logging,
21
+ )
22
+
23
+
24
+ try:
25
+ if not is_onnx_available():
26
+ raise OptionalDependencyNotAvailable()
27
+ except OptionalDependencyNotAvailable:
28
+ from .utils.dummy_onnx_objects import * # noqa F403
29
+ else:
30
+ from .pipelines import OnnxRuntimeModel
31
+
32
+ try:
33
+ if not is_torch_available():
34
+ raise OptionalDependencyNotAvailable()
35
+ except OptionalDependencyNotAvailable:
36
+ from .utils.dummy_pt_objects import * # noqa F403
37
+ else:
38
+ from .models import (
39
+ AsymmetricAutoencoderKL,
40
+ AutoencoderKL,
41
+ AutoencoderTiny,
42
+ ControlNetModel,
43
+ ModelMixin,
44
+ MultiAdapter,
45
+ PriorTransformer,
46
+ T2IAdapter,
47
+ T5FilmDecoder,
48
+ Transformer2DModel,
49
+ UNet1DModel,
50
+ UNet2DConditionModel,
51
+ UNet2DModel,
52
+ UNet3DConditionModel,
53
+ VQModel,
54
+ )
55
+ from .optimization import (
56
+ get_constant_schedule,
57
+ get_constant_schedule_with_warmup,
58
+ get_cosine_schedule_with_warmup,
59
+ get_cosine_with_hard_restarts_schedule_with_warmup,
60
+ get_linear_schedule_with_warmup,
61
+ get_polynomial_decay_schedule_with_warmup,
62
+ get_scheduler,
63
+ )
64
+ from .pipelines import (
65
+ AudioPipelineOutput,
66
+ AutoPipelineForImage2Image,
67
+ AutoPipelineForInpainting,
68
+ AutoPipelineForText2Image,
69
+ ConsistencyModelPipeline,
70
+ DanceDiffusionPipeline,
71
+ DDIMPipeline,
72
+ DDPMPipeline,
73
+ DiffusionPipeline,
74
+ DiTPipeline,
75
+ ImagePipelineOutput,
76
+ KarrasVePipeline,
77
+ LDMPipeline,
78
+ LDMSuperResolutionPipeline,
79
+ PNDMPipeline,
80
+ RePaintPipeline,
81
+ ScoreSdeVePipeline,
82
+ )
83
+ from .schedulers import (
84
+ CMStochasticIterativeScheduler,
85
+ DDIMInverseScheduler,
86
+ DDIMParallelScheduler,
87
+ DDIMScheduler,
88
+ DDPMParallelScheduler,
89
+ DDPMScheduler,
90
+ DEISMultistepScheduler,
91
+ DPMSolverMultistepInverseScheduler,
92
+ DPMSolverMultistepScheduler,
93
+ DPMSolverSinglestepScheduler,
94
+ EulerAncestralDiscreteScheduler,
95
+ EulerDiscreteScheduler,
96
+ HeunDiscreteScheduler,
97
+ IPNDMScheduler,
98
+ KarrasVeScheduler,
99
+ KDPM2AncestralDiscreteScheduler,
100
+ KDPM2DiscreteScheduler,
101
+ PNDMScheduler,
102
+ RePaintScheduler,
103
+ SchedulerMixin,
104
+ ScoreSdeVeScheduler,
105
+ UnCLIPScheduler,
106
+ UniPCMultistepScheduler,
107
+ VQDiffusionScheduler,
108
+ )
109
+ from .training_utils import EMAModel
110
+
111
+ try:
112
+ if not (is_torch_available() and is_scipy_available()):
113
+ raise OptionalDependencyNotAvailable()
114
+ except OptionalDependencyNotAvailable:
115
+ from .utils.dummy_torch_and_scipy_objects import * # noqa F403
116
+ else:
117
+ from .schedulers import LMSDiscreteScheduler
118
+
119
+ try:
120
+ if not (is_torch_available() and is_torchsde_available()):
121
+ raise OptionalDependencyNotAvailable()
122
+ except OptionalDependencyNotAvailable:
123
+ from .utils.dummy_torch_and_torchsde_objects import * # noqa F403
124
+ else:
125
+ from .schedulers import DPMSolverSDEScheduler
126
+
127
+ try:
128
+ if not (is_torch_available() and is_transformers_available()):
129
+ raise OptionalDependencyNotAvailable()
130
+ except OptionalDependencyNotAvailable:
131
+ from .utils.dummy_torch_and_transformers_objects import * # noqa F403
132
+ else:
133
+ from .pipelines import (
134
+ AltDiffusionImg2ImgPipeline,
135
+ AltDiffusionPipeline,
136
+ AudioLDMPipeline,
137
+ CycleDiffusionPipeline,
138
+ IFImg2ImgPipeline,
139
+ IFImg2ImgSuperResolutionPipeline,
140
+ IFInpaintingPipeline,
141
+ IFInpaintingSuperResolutionPipeline,
142
+ IFPipeline,
143
+ IFSuperResolutionPipeline,
144
+ ImageTextPipelineOutput,
145
+ KandinskyCombinedPipeline,
146
+ KandinskyImg2ImgCombinedPipeline,
147
+ KandinskyImg2ImgPipeline,
148
+ KandinskyInpaintCombinedPipeline,
149
+ KandinskyInpaintPipeline,
150
+ KandinskyPipeline,
151
+ KandinskyPriorPipeline,
152
+ KandinskyV22CombinedPipeline,
153
+ KandinskyV22ControlnetImg2ImgPipeline,
154
+ KandinskyV22ControlnetPipeline,
155
+ KandinskyV22Img2ImgCombinedPipeline,
156
+ KandinskyV22Img2ImgPipeline,
157
+ KandinskyV22InpaintCombinedPipeline,
158
+ KandinskyV22InpaintPipeline,
159
+ KandinskyV22Pipeline,
160
+ KandinskyV22PriorEmb2EmbPipeline,
161
+ KandinskyV22PriorPipeline,
162
+ LDMTextToImagePipeline,
163
+ PaintByExamplePipeline,
164
+ SemanticStableDiffusionPipeline,
165
+ ShapEImg2ImgPipeline,
166
+ ShapEPipeline,
167
+ StableDiffusionAdapterPipeline,
168
+ StableDiffusionAttendAndExcitePipeline,
169
+ StableDiffusionControlNetImg2ImgPipeline,
170
+ StableDiffusionControlNetInpaintPipeline,
171
+ StableDiffusionControlNetPipeline,
172
+ StableDiffusionDepth2ImgPipeline,
173
+ StableDiffusionDiffEditPipeline,
174
+ StableDiffusionGLIGENPipeline,
175
+ StableDiffusionImageVariationPipeline,
176
+ StableDiffusionImg2ImgPipeline,
177
+ StableDiffusionInpaintPipeline,
178
+ StableDiffusionInpaintPipelineLegacy,
179
+ StableDiffusionInstructPix2PixPipeline,
180
+ StableDiffusionLatentUpscalePipeline,
181
+ StableDiffusionLDM3DPipeline,
182
+ StableDiffusionModelEditingPipeline,
183
+ StableDiffusionPanoramaPipeline,
184
+ StableDiffusionParadigmsPipeline,
185
+ StableDiffusionPipeline,
186
+ StableDiffusionPipelineSafe,
187
+ StableDiffusionPix2PixZeroPipeline,
188
+ StableDiffusionSAGPipeline,
189
+ StableDiffusionUpscalePipeline,
190
+ StableDiffusionXLControlNetPipeline,
191
+ StableDiffusionXLImg2ImgPipeline,
192
+ StableDiffusionXLInpaintPipeline,
193
+ StableDiffusionXLInstructPix2PixPipeline,
194
+ StableDiffusionXLPipeline,
195
+ StableUnCLIPImg2ImgPipeline,
196
+ StableUnCLIPPipeline,
197
+ TextToVideoSDPipeline,
198
+ TextToVideoZeroPipeline,
199
+ UnCLIPImageVariationPipeline,
200
+ UnCLIPPipeline,
201
+ UniDiffuserModel,
202
+ UniDiffuserPipeline,
203
+ UniDiffuserTextDecoder,
204
+ VersatileDiffusionDualGuidedPipeline,
205
+ VersatileDiffusionImageVariationPipeline,
206
+ VersatileDiffusionPipeline,
207
+ VersatileDiffusionTextToImagePipeline,
208
+ VideoToVideoSDPipeline,
209
+ VQDiffusionPipeline,
210
+ )
211
+
212
+ try:
213
+ if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
214
+ raise OptionalDependencyNotAvailable()
215
+ except OptionalDependencyNotAvailable:
216
+ from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
217
+ else:
218
+ from .pipelines import StableDiffusionKDiffusionPipeline
219
+
220
+ try:
221
+ if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
222
+ raise OptionalDependencyNotAvailable()
223
+ except OptionalDependencyNotAvailable:
224
+ from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
225
+ else:
226
+ from .pipelines import (
227
+ OnnxStableDiffusionImg2ImgPipeline,
228
+ OnnxStableDiffusionInpaintPipeline,
229
+ OnnxStableDiffusionInpaintPipelineLegacy,
230
+ OnnxStableDiffusionPipeline,
231
+ OnnxStableDiffusionUpscalePipeline,
232
+ StableDiffusionOnnxPipeline,
233
+ )
234
+
235
+ try:
236
+ if not (is_torch_available() and is_librosa_available()):
237
+ raise OptionalDependencyNotAvailable()
238
+ except OptionalDependencyNotAvailable:
239
+ from .utils.dummy_torch_and_librosa_objects import * # noqa F403
240
+ else:
241
+ from .pipelines import AudioDiffusionPipeline, Mel
242
+
243
+ try:
244
+ if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
245
+ raise OptionalDependencyNotAvailable()
246
+ except OptionalDependencyNotAvailable:
247
+ from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
248
+ else:
249
+ from .pipelines import SpectrogramDiffusionPipeline
250
+
251
+ try:
252
+ if not is_flax_available():
253
+ raise OptionalDependencyNotAvailable()
254
+ except OptionalDependencyNotAvailable:
255
+ from .utils.dummy_flax_objects import * # noqa F403
256
+ else:
257
+ from .models.controlnet_flax import FlaxControlNetModel
258
+ from .models.modeling_flax_utils import FlaxModelMixin
259
+ from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
260
+ from .models.vae_flax import FlaxAutoencoderKL
261
+ from .pipelines import FlaxDiffusionPipeline
262
+ from .schedulers import (
263
+ FlaxDDIMScheduler,
264
+ FlaxDDPMScheduler,
265
+ FlaxDPMSolverMultistepScheduler,
266
+ FlaxKarrasVeScheduler,
267
+ FlaxLMSDiscreteScheduler,
268
+ FlaxPNDMScheduler,
269
+ FlaxSchedulerMixin,
270
+ FlaxScoreSdeVeScheduler,
271
+ )
272
+
273
+
274
+ try:
275
+ if not (is_flax_available() and is_transformers_available()):
276
+ raise OptionalDependencyNotAvailable()
277
+ except OptionalDependencyNotAvailable:
278
+ from .utils.dummy_flax_and_transformers_objects import * # noqa F403
279
+ else:
280
+ from .pipelines import (
281
+ FlaxStableDiffusionControlNetPipeline,
282
+ FlaxStableDiffusionImg2ImgPipeline,
283
+ FlaxStableDiffusionInpaintPipeline,
284
+ FlaxStableDiffusionPipeline,
285
+ )
286
+
287
+ try:
288
+ if not (is_note_seq_available()):
289
+ raise OptionalDependencyNotAvailable()
290
+ except OptionalDependencyNotAvailable:
291
+ from .utils.dummy_note_seq_objects import * # noqa F403
292
+ else:
293
+ from .pipelines import MidiProcessor
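The try/except blocks above gate each group of re-exports behind its optional backend, so a name is importable from the package root only when that backend is installed. A minimal sketch, assuming the same public layout as the upstream `diffusers` package and an environment where `torch` is available:

```python
# Minimal sketch: names gated behind the is_torch_available() check above are
# importable from the package root once torch is installed.
from diffusers import DDPMScheduler, UNet2DModel  # assumption: upstream package layout

scheduler = DDPMScheduler(num_train_timesteps=1000)
model = UNet2DModel(sample_size=32, in_channels=3, out_channels=3)
print(scheduler.config.num_train_timesteps, model.config.sample_size)
```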
Tiger Model/diffusiers-Tiger/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (8.47 kB).
 
Tiger Model/diffusiers-Tiger/__pycache__/configuration_utils.cpython-38.pyc ADDED
Binary file (24 kB).
 
Tiger Model/diffusiers-Tiger/__pycache__/fuse.cpython-38.pyc ADDED
Binary file (3.83 kB).
 
Tiger Model/diffusiers-Tiger/__pycache__/image_processor.cpython-38.pyc ADDED
Binary file (12.7 kB).
 
Tiger Model/diffusiers-Tiger/__pycache__/loaders.cpython-38.pyc ADDED
Binary file (78.3 kB).
 
Tiger Model/diffusiers-Tiger/__pycache__/optimization.cpython-38.pyc ADDED
Binary file (12.8 kB).
 
Tiger Model/diffusiers-Tiger/__pycache__/training_utils.cpython-38.pyc ADDED
Binary file (10.6 kB).
 
Tiger Model/diffusiers-Tiger/commands/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABC, abstractmethod
16
+ from argparse import ArgumentParser
17
+
18
+
19
+ class BaseDiffusersCLICommand(ABC):
20
+ @staticmethod
21
+ @abstractmethod
22
+ def register_subcommand(parser: ArgumentParser):
23
+ raise NotImplementedError()
24
+
25
+ @abstractmethod
26
+ def run(self):
27
+ raise NotImplementedError()
Tiger Model/diffusiers-Tiger/commands/diffusers_cli.py ADDED
@@ -0,0 +1,43 @@
1
+ #!/usr/bin/env python
2
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from argparse import ArgumentParser
17
+
18
+ from .env import EnvironmentCommand
19
+ from .fp16_safetensors import FP16SafetensorsCommand
20
+
21
+
22
+ def main():
23
+ parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
24
+ commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
25
+
26
+ # Register commands
27
+ EnvironmentCommand.register_subcommand(commands_parser)
28
+ FP16SafetensorsCommand.register_subcommand(commands_parser)
29
+
30
+ # Let's go
31
+ args = parser.parse_args()
32
+
33
+ if not hasattr(args, "func"):
34
+ parser.print_help()
35
+ exit(1)
36
+
37
+ # Run
38
+ service = args.func(args)
39
+ service.run()
40
+
41
+
42
+ if __name__ == "__main__":
43
+ main()
Tiger Model/diffusiers-Tiger/commands/env.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import platform
16
+ from argparse import ArgumentParser
17
+
18
+ import huggingface_hub
19
+
20
+ from .. import __version__ as version
21
+ from ..utils import is_accelerate_available, is_torch_available, is_transformers_available, is_xformers_available
22
+ from . import BaseDiffusersCLICommand
23
+
24
+
25
+ def info_command_factory(_):
26
+ return EnvironmentCommand()
27
+
28
+
29
+ class EnvironmentCommand(BaseDiffusersCLICommand):
30
+ @staticmethod
31
+ def register_subcommand(parser: ArgumentParser):
32
+ download_parser = parser.add_parser("env")
33
+ download_parser.set_defaults(func=info_command_factory)
34
+
35
+ def run(self):
36
+ hub_version = huggingface_hub.__version__
37
+
38
+ pt_version = "not installed"
39
+ pt_cuda_available = "NA"
40
+ if is_torch_available():
41
+ import torch
42
+
43
+ pt_version = torch.__version__
44
+ pt_cuda_available = torch.cuda.is_available()
45
+
46
+ transformers_version = "not installed"
47
+ if is_transformers_available():
48
+ import transformers
49
+
50
+ transformers_version = transformers.__version__
51
+
52
+ accelerate_version = "not installed"
53
+ if is_accelerate_available():
54
+ import accelerate
55
+
56
+ accelerate_version = accelerate.__version__
57
+
58
+ xformers_version = "not installed"
59
+ if is_xformers_available():
60
+ import xformers
61
+
62
+ xformers_version = xformers.__version__
63
+
64
+ info = {
65
+ "`diffusers` version": version,
66
+ "Platform": platform.platform(),
67
+ "Python version": platform.python_version(),
68
+ "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
69
+ "Huggingface_hub version": hub_version,
70
+ "Transformers version": transformers_version,
71
+ "Accelerate version": accelerate_version,
72
+ "xFormers version": xformers_version,
73
+ "Using GPU in script?": "<fill in>",
74
+ "Using distributed or parallel set-up in script?": "<fill in>",
75
+ }
76
+
77
+ print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
78
+ print(self.format_dict(info))
79
+
80
+ return info
81
+
82
+ @staticmethod
83
+ def format_dict(d):
84
+ return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
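A minimal sketch of exercising the command without the CLI wrapper, assuming the same module path as upstream `diffusers.commands.env`:

```python
# Minimal sketch: run the environment report programmatically.
# Assumes the upstream module path diffusers.commands.env.
from diffusers.commands.env import EnvironmentCommand

info = EnvironmentCommand().run()  # prints the bullet list and returns the dict
print(sorted(info.keys()))
```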
Tiger Model/diffusiers-Tiger/commands/fp16_safetensors.py ADDED
@@ -0,0 +1,133 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Usage example:
17
+ diffusers-cli fp16_safetensors --ckpt_id=openai/shap-e --fp16 --use_safetensors
18
+ """
19
+
20
+ import glob
21
+ import json
22
+ from argparse import ArgumentParser, Namespace
23
+ from importlib import import_module
24
+
25
+ import huggingface_hub
26
+ import torch
27
+ from huggingface_hub import hf_hub_download
28
+ from packaging import version
29
+
30
+ from ..utils import logging
31
+ from . import BaseDiffusersCLICommand
32
+
33
+
34
+ def conversion_command_factory(args: Namespace):
35
+ return FP16SafetensorsCommand(
36
+ args.ckpt_id,
37
+ args.fp16,
38
+ args.use_safetensors,
39
+ args.use_auth_token,
40
+ )
41
+
42
+
43
+ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
44
+ @staticmethod
45
+ def register_subcommand(parser: ArgumentParser):
46
+ conversion_parser = parser.add_parser("fp16_safetensors")
47
+ conversion_parser.add_argument(
48
+ "--ckpt_id",
49
+ type=str,
50
+ help="Repo id of the checkpoints on which to run the conversion. Example: 'openai/shap-e'.",
51
+ )
52
+ conversion_parser.add_argument(
53
+ "--fp16", action="store_true", help="If serializing the variables in FP16 precision."
54
+ )
55
+ conversion_parser.add_argument(
56
+ "--use_safetensors", action="store_true", help="If serializing in the safetensors format."
57
+ )
58
+ conversion_parser.add_argument(
59
+ "--use_auth_token",
60
+ action="store_true",
61
+ help="When working with checkpoints having private visibility. When used `huggingface-cli login` needs to be run beforehand.",
62
+ )
63
+ conversion_parser.set_defaults(func=conversion_command_factory)
64
+
65
+ def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool, use_auth_token: bool):
66
+ self.logger = logging.get_logger("diffusers-cli/fp16_safetensors")
67
+ self.ckpt_id = ckpt_id
68
+ self.local_ckpt_dir = f"/tmp/{ckpt_id}"
69
+ self.fp16 = fp16
70
+
71
+ self.use_safetensors = use_safetensors
72
+
73
+ if not self.use_safetensors and not self.fp16:
74
+ raise NotImplementedError(
75
+ "When `use_safetensors` and `fp16` both are False, then this command is of no use."
76
+ )
77
+
78
+ self.use_auth_token = use_auth_token
79
+
80
+ def run(self):
81
+ if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
82
+ raise ImportError(
83
+ "The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub"
84
+ " installation."
85
+ )
86
+ else:
87
+ from huggingface_hub import create_commit
88
+ from huggingface_hub._commit_api import CommitOperationAdd
89
+
90
+ model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json", token=self.use_auth_token)
91
+ with open(model_index, "r") as f:
92
+ pipeline_class_name = json.load(f)["_class_name"]
93
+ pipeline_class = getattr(import_module("diffusers"), pipeline_class_name)
94
+ self.logger.info(f"Pipeline class imported: {pipeline_class_name}.")
95
+
96
+ # Load the appropriate pipeline. We could have use `DiffusionPipeline`
97
+ # here, but just to avoid any rough edge cases.
98
+ pipeline = pipeline_class.from_pretrained(
99
+ self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32, use_auth_token=self.use_auth_token
100
+ )
101
+ pipeline.save_pretrained(
102
+ self.local_ckpt_dir,
103
+ safe_serialization=True if self.use_safetensors else False,
104
+ variant="fp16" if self.fp16 else None,
105
+ )
106
+ self.logger.info(f"Pipeline locally saved to {self.local_ckpt_dir}.")
107
+
108
+ # Fetch all the paths.
109
+ if self.fp16:
110
+ modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.fp16.*")
111
+ elif self.use_safetensors:
112
+ modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.safetensors")
113
+
114
+ # Prepare for the PR.
115
+ commit_message = f"Serialize variables with FP16: {self.fp16} and safetensors: {self.use_safetensors}."
116
+ operations = []
117
+ for path in modified_paths:
118
+ operations.append(CommitOperationAdd(path_in_repo="/".join(path.split("/")[4:]), path_or_fileobj=path))
119
+
120
+ # Open the PR.
121
+ commit_description = (
122
+ "Variables converted by the [`diffusers`' `fp16_safetensors`"
123
+ " CLI](https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/fp16_safetensors.py)."
124
+ )
125
+ hub_pr_url = create_commit(
126
+ repo_id=self.ckpt_id,
127
+ operations=operations,
128
+ commit_message=commit_message,
129
+ commit_description=commit_description,
130
+ repo_type="model",
131
+ create_pr=True,
132
+ ).pr_url
133
+ self.logger.info(f"PR created here: {hub_pr_url}.")
Tiger Model/diffusiers-Tiger/configuration_utils.py ADDED
@@ -0,0 +1,686 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ ConfigMixin base class and utilities."""
17
+ import dataclasses
18
+ import functools
19
+ import importlib
20
+ import inspect
21
+ import json
22
+ import os
23
+ import re
24
+ from collections import OrderedDict
25
+ from pathlib import PosixPath
26
+ from typing import Any, Dict, Tuple, Union
27
+
28
+ import numpy as np
29
+ from huggingface_hub import create_repo, hf_hub_download
30
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
31
+ from requests import HTTPError
32
+
33
+ from . import __version__
34
+ from .utils import (
35
+ DIFFUSERS_CACHE,
36
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
37
+ DummyObject,
38
+ deprecate,
39
+ extract_commit_hash,
40
+ http_user_agent,
41
+ logging,
42
+ )
43
+
44
+
45
+ logger = logging.get_logger(__name__)
46
+
47
+ _re_configuration_file = re.compile(r"config\.(.*)\.json")
48
+
49
+
50
+ class FrozenDict(OrderedDict):
51
+ def __init__(self, *args, **kwargs):
52
+ super().__init__(*args, **kwargs)
53
+
54
+ for key, value in self.items():
55
+ setattr(self, key, value)
56
+
57
+ self.__frozen = True
58
+
59
+ def __delitem__(self, *args, **kwargs):
60
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
61
+
62
+ def setdefault(self, *args, **kwargs):
63
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
64
+
65
+ def pop(self, *args, **kwargs):
66
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
67
+
68
+ def update(self, *args, **kwargs):
69
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
70
+
71
+ def __setattr__(self, name, value):
72
+ if hasattr(self, "__frozen") and self.__frozen:
73
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
74
+ super().__setattr__(name, value)
75
+
76
+ def __setitem__(self, name, value):
77
+ if hasattr(self, "__frozen") and self.__frozen:
78
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
79
+ super().__setitem__(name, value)
80
+
81
+
82
+ class ConfigMixin:
83
+ r"""
84
+ Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
85
+ provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
86
+ saving classes that inherit from [`ConfigMixin`].
87
+
88
+ Class attributes:
89
+ - **config_name** (`str`) -- A filename under which the config should stored when calling
90
+ [`~ConfigMixin.save_config`] (should be overridden by parent class).
91
+ - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
92
+ overridden by subclass).
93
+ - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
94
+ - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
95
+ should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
96
+ subclass).
97
+ """
98
+ config_name = None
99
+ ignore_for_config = []
100
+ has_compatibles = False
101
+
102
+ _deprecated_kwargs = []
103
+
104
+ def register_to_config(self, **kwargs):
105
+ if self.config_name is None:
106
+ raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
107
+ # Special case for `kwargs` used in deprecation warning added to schedulers
108
+ # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
109
+ # or solve in a more general way.
110
+ kwargs.pop("kwargs", None)
111
+
112
+ if not hasattr(self, "_internal_dict"):
113
+ internal_dict = kwargs
114
+ else:
115
+ previous_dict = dict(self._internal_dict)
116
+ internal_dict = {**self._internal_dict, **kwargs}
117
+ logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
118
+
119
+ self._internal_dict = FrozenDict(internal_dict)
120
+
121
+ def __getattr__(self, name: str) -> Any:
122
+ """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
123
+ config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
124
+
125
+ This function is mostly copied from PyTorch's __getattr__ overwrite:
126
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
127
+ """
128
+
129
+ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
130
+ is_attribute = name in self.__dict__
131
+
132
+ if is_in_config and not is_attribute:
133
+ deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
134
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
135
+ return self._internal_dict[name]
136
+
137
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
138
+
139
+ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
140
+ """
141
+ Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
142
+ [`~ConfigMixin.from_config`] class method.
143
+
144
+ Args:
145
+ save_directory (`str` or `os.PathLike`):
146
+ Directory where the configuration JSON file is saved (will be created if it does not exist).
147
+ push_to_hub (`bool`, *optional*, defaults to `False`):
148
+ Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
149
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
150
+ namespace).
151
+ kwargs (`Dict[str, Any]`, *optional*):
152
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
153
+ """
154
+ if os.path.isfile(save_directory):
155
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
156
+
157
+ os.makedirs(save_directory, exist_ok=True)
158
+
159
+ # If we save using the predefined names, we can load using `from_config`
160
+ output_config_file = os.path.join(save_directory, self.config_name)
161
+
162
+ self.to_json_file(output_config_file)
163
+ logger.info(f"Configuration saved in {output_config_file}")
164
+
165
+ if push_to_hub:
166
+ commit_message = kwargs.pop("commit_message", None)
167
+ private = kwargs.pop("private", False)
168
+ create_pr = kwargs.pop("create_pr", False)
169
+ token = kwargs.pop("token", None)
170
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
171
+ repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
172
+
173
+ self._upload_folder(
174
+ save_directory,
175
+ repo_id,
176
+ token=token,
177
+ commit_message=commit_message,
178
+ create_pr=create_pr,
179
+ )
180
+
181
+ @classmethod
182
+ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
183
+ r"""
184
+ Instantiate a Python class from a config dictionary.
185
+
186
+ Parameters:
187
+ config (`Dict[str, Any]`):
188
+ A config dictionary from which the Python class is instantiated. Make sure to only load configuration
189
+ files of compatible classes.
190
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
191
+ Whether kwargs that are not consumed by the Python class should be returned or not.
192
+ kwargs (remaining dictionary of keyword arguments, *optional*):
193
+ Can be used to update the configuration object (after it is loaded) and initiate the Python class.
194
+ `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
195
+ overwrite the same named arguments in `config`.
196
+
197
+ Returns:
198
+ [`ModelMixin`] or [`SchedulerMixin`]:
199
+ A model or scheduler object instantiated from a config dictionary.
200
+
201
+ Examples:
202
+
203
+ ```python
204
+ >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
205
+
206
+ >>> # Download scheduler from huggingface.co and cache.
207
+ >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
208
+
209
+ >>> # Instantiate DDIM scheduler class with same config as DDPM
210
+ >>> scheduler = DDIMScheduler.from_config(scheduler.config)
211
+
212
+ >>> # Instantiate PNDM scheduler class with same config as DDPM
213
+ >>> scheduler = PNDMScheduler.from_config(scheduler.config)
214
+ ```
215
+ """
216
+ # <===== TO BE REMOVED WITH DEPRECATION
217
+ # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
218
+ if "pretrained_model_name_or_path" in kwargs:
219
+ config = kwargs.pop("pretrained_model_name_or_path")
220
+
221
+ if config is None:
222
+ raise ValueError("Please make sure to provide a config as the first positional argument.")
223
+ # ======>
224
+
225
+ if not isinstance(config, dict):
226
+ deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
227
+ if "Scheduler" in cls.__name__:
228
+ deprecation_message += (
229
+ f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
230
+ " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
231
+ " be removed in v1.0.0."
232
+ )
233
+ elif "Model" in cls.__name__:
234
+ deprecation_message += (
235
+ f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
236
+ f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
237
+ " instead. This functionality will be removed in v1.0.0."
238
+ )
239
+ deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
240
+ config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
241
+
242
+ init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
243
+
244
+ # Allow dtype to be specified on initialization
245
+ if "dtype" in unused_kwargs:
246
+ init_dict["dtype"] = unused_kwargs.pop("dtype")
247
+
248
+ # add possible deprecated kwargs
249
+ for deprecated_kwarg in cls._deprecated_kwargs:
250
+ if deprecated_kwarg in unused_kwargs:
251
+ init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
252
+
253
+ # Return model and optionally state and/or unused_kwargs
254
+ model = cls(**init_dict)
255
+
256
+ # make sure to also save config parameters that might be used for compatible classes
257
+ model.register_to_config(**hidden_dict)
258
+
259
+ # add hidden kwargs of compatible classes to unused_kwargs
260
+ unused_kwargs = {**unused_kwargs, **hidden_dict}
261
+
262
+ if return_unused_kwargs:
263
+ return (model, unused_kwargs)
264
+ else:
265
+ return model
266
+
267
+ @classmethod
268
+ def get_config_dict(cls, *args, **kwargs):
269
+ deprecation_message = (
270
+ f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
271
+ " removed in version v1.0.0"
272
+ )
273
+ deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
274
+ return cls.load_config(*args, **kwargs)
275
+
276
+ @classmethod
277
+ def load_config(
278
+ cls,
279
+ pretrained_model_name_or_path: Union[str, os.PathLike],
280
+ return_unused_kwargs=False,
281
+ return_commit_hash=False,
282
+ **kwargs,
283
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
284
+ r"""
285
+ Load a model or scheduler configuration.
286
+
287
+ Parameters:
288
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
289
+ Can be either:
290
+
291
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
292
+ the Hub.
293
+ - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
294
+ [`~ConfigMixin.save_config`].
295
+
296
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
297
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
298
+ is not used.
299
+ force_download (`bool`, *optional*, defaults to `False`):
300
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
301
+ cached versions if they exist.
302
+ resume_download (`bool`, *optional*, defaults to `False`):
303
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
304
+ incompletely downloaded files are deleted.
305
+ proxies (`Dict[str, str]`, *optional*):
306
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
307
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
308
+ output_loading_info(`bool`, *optional*, defaults to `False`):
309
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
310
+ local_files_only (`bool`, *optional*, defaults to `False`):
311
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
312
+ won't be downloaded from the Hub.
313
+ use_auth_token (`str` or *bool*, *optional*):
314
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
315
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
316
+ revision (`str`, *optional*, defaults to `"main"`):
317
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
318
+ allowed by Git.
319
+ subfolder (`str`, *optional*, defaults to `""`):
320
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
321
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
322
+ Whether unused keyword arguments of the config are returned.
323
+ return_commit_hash (`bool`, *optional*, defaults to `False`):
324
+ Whether the `commit_hash` of the loaded configuration is returned.
325
+
326
+ Returns:
327
+ `dict`:
328
+ A dictionary of all the parameters stored in a JSON configuration file.
329
+
330
+ """
331
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
332
+ force_download = kwargs.pop("force_download", False)
333
+ resume_download = kwargs.pop("resume_download", False)
334
+ proxies = kwargs.pop("proxies", None)
335
+ use_auth_token = kwargs.pop("use_auth_token", None)
336
+ local_files_only = kwargs.pop("local_files_only", False)
337
+ revision = kwargs.pop("revision", None)
338
+ _ = kwargs.pop("mirror", None)
339
+ subfolder = kwargs.pop("subfolder", None)
340
+ user_agent = kwargs.pop("user_agent", {})
341
+
342
+ user_agent = {**user_agent, "file_type": "config"}
343
+ user_agent = http_user_agent(user_agent)
344
+
345
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
346
+
347
+ if cls.config_name is None:
348
+ raise ValueError(
349
+ "`self.config_name` is not defined. Note that one should not load a config from "
350
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
351
+ )
352
+
353
+ if os.path.isfile(pretrained_model_name_or_path):
354
+ config_file = pretrained_model_name_or_path
355
+ elif os.path.isdir(pretrained_model_name_or_path):
356
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
357
+ # Load from a PyTorch checkpoint
358
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
359
+ elif subfolder is not None and os.path.isfile(
360
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
361
+ ):
362
+ config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
363
+ else:
364
+ raise EnvironmentError(
365
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
366
+ )
367
+ else:
368
+ try:
369
+ # Load from URL or cache if already cached
370
+ config_file = hf_hub_download(
371
+ pretrained_model_name_or_path,
372
+ filename=cls.config_name,
373
+ cache_dir=cache_dir,
374
+ force_download=force_download,
375
+ proxies=proxies,
376
+ resume_download=resume_download,
377
+ local_files_only=local_files_only,
378
+ use_auth_token=use_auth_token,
379
+ user_agent=user_agent,
380
+ subfolder=subfolder,
381
+ revision=revision,
382
+ )
383
+ except RepositoryNotFoundError:
384
+ raise EnvironmentError(
385
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
386
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
387
+ " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
388
+ " login`."
389
+ )
390
+ except RevisionNotFoundError:
391
+ raise EnvironmentError(
392
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
393
+ " this model name. Check the model page at"
394
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
395
+ )
396
+ except EntryNotFoundError:
397
+ raise EnvironmentError(
398
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
399
+ )
400
+ except HTTPError as err:
401
+ raise EnvironmentError(
402
+ "There was a specific connection error when trying to load"
403
+ f" {pretrained_model_name_or_path}:\n{err}"
404
+ )
405
+ except ValueError:
406
+ raise EnvironmentError(
407
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
408
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
409
+ f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
410
+ " run the library in offline mode at"
411
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
412
+ )
413
+ except EnvironmentError:
414
+ raise EnvironmentError(
415
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
416
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
417
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
418
+ f"containing a {cls.config_name} file"
419
+ )
420
+
421
+ try:
422
+ # Load config dict
423
+ config_dict = cls._dict_from_json_file(config_file)
424
+
425
+ commit_hash = extract_commit_hash(config_file)
426
+ except (json.JSONDecodeError, UnicodeDecodeError):
427
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
428
+
429
+ if not (return_unused_kwargs or return_commit_hash):
430
+ return config_dict
431
+
432
+ outputs = (config_dict,)
433
+
434
+ if return_unused_kwargs:
435
+ outputs += (kwargs,)
436
+
437
+ if return_commit_hash:
438
+ outputs += (commit_hash,)
439
+
440
+ return outputs
441
+
442
+ @staticmethod
443
+ def _get_init_keys(cls):
444
+ return set(dict(inspect.signature(cls.__init__).parameters).keys())
445
+
446
+ @classmethod
447
+ def extract_init_dict(cls, config_dict, **kwargs):
448
+ # Skip keys that were not present in the original config, so default __init__ values were used
449
+ used_defaults = config_dict.get("_use_default_values", [])
450
+ config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}
451
+
452
+ # 0. Copy origin config dict
453
+ original_dict = dict(config_dict.items())
454
+
455
+ # 1. Retrieve expected config attributes from __init__ signature
456
+ expected_keys = cls._get_init_keys(cls)
457
+ expected_keys.remove("self")
458
+ # remove general kwargs if present in dict
459
+ if "kwargs" in expected_keys:
460
+ expected_keys.remove("kwargs")
461
+ # remove flax internal keys
462
+ if hasattr(cls, "_flax_internal_args"):
463
+ for arg in cls._flax_internal_args:
464
+ expected_keys.remove(arg)
465
+
466
+ # 2. Remove attributes that cannot be expected from expected config attributes
467
+ # remove keys to be ignored
468
+ if len(cls.ignore_for_config) > 0:
469
+ expected_keys = expected_keys - set(cls.ignore_for_config)
470
+
471
+ # load diffusers library to import compatible and original scheduler
472
+ diffusers_library = importlib.import_module(__name__.split(".")[0])
473
+
474
+ if cls.has_compatibles:
475
+ compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
476
+ else:
477
+ compatible_classes = []
478
+
479
+ expected_keys_comp_cls = set()
480
+ for c in compatible_classes:
481
+ expected_keys_c = cls._get_init_keys(c)
482
+ expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
483
+ expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
484
+ config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
485
+
486
+ # remove attributes from orig class that cannot be expected
487
+ orig_cls_name = config_dict.pop("_class_name", cls.__name__)
488
+ if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
489
+ orig_cls = getattr(diffusers_library, orig_cls_name)
490
+ unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
491
+ config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
492
+
493
+ # remove private attributes
494
+ config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
495
+
496
+ # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
497
+ init_dict = {}
498
+ for key in expected_keys:
499
+ # if config param is passed to kwarg and is present in config dict
500
+ # it should overwrite existing config dict key
501
+ if key in kwargs and key in config_dict:
502
+ config_dict[key] = kwargs.pop(key)
503
+
504
+ if key in kwargs:
505
+ # overwrite key
506
+ init_dict[key] = kwargs.pop(key)
507
+ elif key in config_dict:
508
+ # use value from config dict
509
+ init_dict[key] = config_dict.pop(key)
510
+
511
+ # 4. Give nice warning if unexpected values have been passed
512
+ if len(config_dict) > 0:
513
+ logger.warning(
514
+ f"The config attributes {config_dict} were passed to {cls.__name__}, "
515
+ "but are not expected and will be ignored. Please verify your "
516
+ f"{cls.config_name} configuration file."
517
+ )
518
+
519
+ # 5. Give nice info if config attributes are initialized to default because they have not been passed
520
+ passed_keys = set(init_dict.keys())
521
+ if len(expected_keys - passed_keys) > 0:
522
+ logger.info(
523
+ f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
524
+ )
525
+
526
+ # 6. Define unused keyword arguments
527
+ unused_kwargs = {**config_dict, **kwargs}
528
+
529
+ # 7. Define "hidden" config parameters that were saved for compatible classes
530
+ hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
531
+
532
+ return init_dict, unused_kwargs, hidden_config_dict
533
+
534
+ @classmethod
535
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
536
+ with open(json_file, "r", encoding="utf-8") as reader:
537
+ text = reader.read()
538
+ return json.loads(text)
539
+
540
+ def __repr__(self):
541
+ return f"{self.__class__.__name__} {self.to_json_string()}"
542
+
543
+ @property
544
+ def config(self) -> Dict[str, Any]:
545
+ """
546
+ Returns the config of the class as a frozen dictionary
547
+
548
+ Returns:
549
+ `Dict[str, Any]`: Config of the class.
550
+ """
551
+ return self._internal_dict
552
+
553
+ def to_json_string(self) -> str:
554
+ """
555
+ Serializes the configuration instance to a JSON string.
556
+
557
+ Returns:
558
+ `str`:
559
+ String containing all the attributes that make up the configuration instance in JSON format.
560
+ """
561
+ config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
562
+ config_dict["_class_name"] = self.__class__.__name__
563
+ config_dict["_diffusers_version"] = __version__
564
+
565
+ def to_json_saveable(value):
566
+ if isinstance(value, np.ndarray):
567
+ value = value.tolist()
568
+ elif isinstance(value, PosixPath):
569
+ value = str(value)
570
+ return value
571
+
572
+ config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
573
+ # Don't save "_ignore_files" or "_use_default_values"
574
+ config_dict.pop("_ignore_files", None)
575
+ config_dict.pop("_use_default_values", None)
576
+
577
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
578
+
579
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
580
+ """
581
+ Save the configuration instance's parameters to a JSON file.
582
+
583
+ Args:
584
+ json_file_path (`str` or `os.PathLike`):
585
+ Path to the JSON file to save a configuration instance's parameters.
586
+ """
587
+ with open(json_file_path, "w", encoding="utf-8") as writer:
588
+ writer.write(self.to_json_string())
589
+
590
+
591
+ def register_to_config(init):
592
+ r"""
593
+ Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
594
+ automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
595
+ shouldn't be registered in the config, use the `ignore_for_config` class variable
596
+
597
+ Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
598
+ """
599
+
600
+ @functools.wraps(init)
601
+ def inner_init(self, *args, **kwargs):
602
+ # Ignore private kwargs in the init.
603
+ init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
604
+ config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
605
+ if not isinstance(self, ConfigMixin):
606
+ raise RuntimeError(
607
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
608
+ "not inherit from `ConfigMixin`."
609
+ )
610
+
611
+ ignore = getattr(self, "ignore_for_config", [])
612
+ # Get positional arguments aligned with kwargs
613
+ new_kwargs = {}
614
+ signature = inspect.signature(init)
615
+ parameters = {
616
+ name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
617
+ }
618
+ for arg, name in zip(args, parameters.keys()):
619
+ new_kwargs[name] = arg
620
+
621
+ # Then add all kwargs
622
+ new_kwargs.update(
623
+ {
624
+ k: init_kwargs.get(k, default)
625
+ for k, default in parameters.items()
626
+ if k not in ignore and k not in new_kwargs
627
+ }
628
+ )
629
+
630
+ # Take note of the parameters that were not present in the loaded config
631
+ if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
632
+ new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
633
+
634
+ new_kwargs = {**config_init_kwargs, **new_kwargs}
635
+ getattr(self, "register_to_config")(**new_kwargs)
636
+ init(self, *args, **init_kwargs)
637
+
638
+ return inner_init
639
+
640
+
641
+ def flax_register_to_config(cls):
642
+ original_init = cls.__init__
643
+
644
+ @functools.wraps(original_init)
645
+ def init(self, *args, **kwargs):
646
+ if not isinstance(self, ConfigMixin):
647
+ raise RuntimeError(
648
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
649
+ "not inherit from `ConfigMixin`."
650
+ )
651
+
652
+ # Ignore private kwargs in the init. Retrieve all passed attributes
653
+ init_kwargs = dict(kwargs.items())
654
+
655
+ # Retrieve default values
656
+ fields = dataclasses.fields(self)
657
+ default_kwargs = {}
658
+ for field in fields:
659
+ # ignore flax specific attributes
660
+ if field.name in self._flax_internal_args:
661
+ continue
662
+ if type(field.default) == dataclasses._MISSING_TYPE:
663
+ default_kwargs[field.name] = None
664
+ else:
665
+ default_kwargs[field.name] = getattr(self, field.name)
666
+
667
+ # Make sure init_kwargs override default kwargs
668
+ new_kwargs = {**default_kwargs, **init_kwargs}
669
+ # dtype should be part of `init_kwargs`, but not `new_kwargs`
670
+ if "dtype" in new_kwargs:
671
+ new_kwargs.pop("dtype")
672
+
673
+ # Get positional arguments aligned with kwargs
674
+ for i, arg in enumerate(args):
675
+ name = fields[i].name
676
+ new_kwargs[name] = arg
677
+
678
+ # Take note of the parameters that were not present in the loaded config
679
+ if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
680
+ new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
681
+
682
+ getattr(self, "register_to_config")(**new_kwargs)
683
+ original_init(self, *args, **kwargs)
684
+
685
+ cls.__init__ = init
686
+ return cls
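A minimal sketch of how `ConfigMixin` and the `@register_to_config` decorator above are typically used together, assuming the same public API as the upstream `diffusers.configuration_utils` module; the class name and arguments below are illustrative only:

```python
# Minimal sketch: record __init__ arguments in a frozen config and round-trip it.
# MyScheduler and its arguments are illustrative, not part of this repository.
from diffusers.configuration_utils import ConfigMixin, register_to_config


class MyScheduler(ConfigMixin):
    config_name = "scheduler_config.json"  # filename written by save_config()

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, beta_start: float = 1e-4):
        # All non-private __init__ arguments are captured into self.config.
        pass


sched = MyScheduler(num_train_timesteps=500)
print(sched.config.num_train_timesteps)           # 500, frozen attribute-style access
sched.save_config("./my_scheduler")               # writes ./my_scheduler/scheduler_config.json
restored = MyScheduler.from_config(sched.config)  # rebuild an instance from the config
```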
Tiger Model/diffusiers-Tiger/dependency_versions_check.py ADDED
@@ -0,0 +1,47 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import sys
15
+
16
+ from .dependency_versions_table import deps
17
+ from .utils.versions import require_version, require_version_core
18
+
19
+
20
+ # define which module versions we always want to check at run time
21
+ # (usually the ones defined in `install_requires` in setup.py)
22
+ #
23
+ # order specific notes:
24
+ # - tqdm must be checked before tokenizers
25
+
26
+ pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
27
+ if sys.version_info < (3, 7):
28
+ pkgs_to_check_at_runtime.append("dataclasses")
29
+ if sys.version_info < (3, 8):
30
+ pkgs_to_check_at_runtime.append("importlib_metadata")
31
+
32
+ for pkg in pkgs_to_check_at_runtime:
33
+ if pkg in deps:
34
+ if pkg == "tokenizers":
35
+ # must be loaded here, or else tqdm check may fail
36
+ from .utils import is_tokenizers_available
37
+
38
+ if not is_tokenizers_available():
39
+ continue # not required, check version only if installed
40
+
41
+ require_version_core(deps[pkg])
42
+ else:
43
+ raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
44
+
45
+
46
+ def dep_version_check(pkg, hint=None):
47
+ require_version(deps[pkg], hint)
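The same `require_version` helper can also be called directly; a minimal sketch, assuming the `utils.versions` helpers imported at the top of this file (exposed upstream as `diffusers.utils.versions`):

```python
# Minimal sketch: check installed packages against pinned specs at runtime.
# Assumes the utils.versions helpers imported above.
from diffusers.utils.versions import require_version

require_version("torch>=1.4")  # raises if torch is missing or older than the spec
require_version("transformers>=4.25.1", "pip install -U transformers")  # optional hint
```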
Tiger Model/diffusiers-Tiger/dependency_versions_table.py ADDED
@@ -0,0 +1,44 @@
1
+ # THIS FILE HAS BEEN AUTOGENERATED. To update:
2
+ # 1. modify the `_deps` dict in setup.py
3
+ # 2. run `make deps_table_update`
4
+ deps = {
5
+ "Pillow": "Pillow",
6
+ "accelerate": "accelerate>=0.11.0",
7
+ "compel": "compel==0.1.8",
8
+ "black": "black~=23.1",
9
+ "datasets": "datasets",
10
+ "filelock": "filelock",
11
+ "flax": "flax>=0.4.1",
12
+ "hf-doc-builder": "hf-doc-builder>=0.3.0",
13
+ "huggingface-hub": "huggingface-hub>=0.13.2",
14
+ "requests-mock": "requests-mock==1.10.0",
15
+ "importlib_metadata": "importlib_metadata",
16
+ "invisible-watermark": "invisible-watermark>=0.2.0",
17
+ "isort": "isort>=5.5.4",
18
+ "jax": "jax>=0.2.8,!=0.3.2",
19
+ "jaxlib": "jaxlib>=0.1.65",
20
+ "Jinja2": "Jinja2",
21
+ "k-diffusion": "k-diffusion>=0.0.12",
22
+ "torchsde": "torchsde",
23
+ "note_seq": "note_seq",
24
+ "librosa": "librosa",
25
+ "numpy": "numpy",
26
+ "omegaconf": "omegaconf",
27
+ "parameterized": "parameterized",
28
+ "protobuf": "protobuf>=3.20.3,<4",
29
+ "pytest": "pytest",
30
+ "pytest-timeout": "pytest-timeout",
31
+ "pytest-xdist": "pytest-xdist",
32
+ "ruff": "ruff==0.0.280",
33
+ "safetensors": "safetensors>=0.3.1",
34
+ "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
35
+ "scipy": "scipy",
36
+ "onnx": "onnx",
37
+ "regex": "regex!=2019.12.17",
38
+ "requests": "requests",
39
+ "tensorboard": "tensorboard",
40
+ "torch": "torch>=1.4",
41
+ "torchvision": "torchvision",
42
+ "transformers": "transformers>=4.25.1",
43
+ "urllib3": "urllib3<=2.0.0",
44
+ }
Tiger Model/diffusiers-Tiger/fuse.py ADDED
@@ -0,0 +1,175 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class DAF(nn.Module):
6
+ '''
7
+ 直接相加 DirectAddFuse
8
+ '''
9
+
10
+ def __init__(self):
11
+ super(DAF, self).__init__()
12
+
13
+ def forward(self, x, residual):
14
+ return x + residual
15
+
16
+
17
+ class iAFF(nn.Module):
18
+ '''
19
+ Multi-feature fusion: iterative attentional feature fusion (iAFF)
20
+ '''
21
+
22
+ def __init__(self, channels=64, r=4):
23
+ super(iAFF, self).__init__()
24
+ inter_channels = int(channels // r)
25
+
26
+ # local attention branch
27
+ self.local_att = nn.Sequential(
28
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
29
+ nn.BatchNorm2d(inter_channels),
30
+ nn.ReLU(inplace=True),
31
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
32
+ nn.BatchNorm2d(channels),
33
+ )
34
+
35
+ # global attention branch
36
+ self.global_att = nn.Sequential(
37
+ nn.AdaptiveAvgPool2d(1),
38
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
39
+ nn.BatchNorm2d(inter_channels),
40
+ nn.ReLU(inplace=True),
41
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
42
+ nn.BatchNorm2d(channels),
43
+ )
44
+
45
+ # second local attention branch
46
+ self.local_att2 = nn.Sequential(
47
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
48
+ nn.BatchNorm2d(inter_channels),
49
+ nn.ReLU(inplace=True),
50
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
51
+ nn.BatchNorm2d(channels),
52
+ )
53
+ # second global attention branch
54
+ self.global_att2 = nn.Sequential(
55
+ nn.AdaptiveAvgPool2d(1),
56
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
57
+ nn.BatchNorm2d(inter_channels),
58
+ nn.ReLU(inplace=True),
59
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
60
+ nn.BatchNorm2d(channels),
61
+ )
62
+
63
+ self.sigmoid = nn.Sigmoid()
64
+
65
+ def forward(self, x, residual):
66
+ xa = x + residual
67
+ xl = self.local_att(xa)
68
+ xg = self.global_att(xa)
69
+ xlg = xl + xg
70
+ wei = self.sigmoid(xlg)
71
+ xi = x * wei + residual * (1 - wei)
72
+
73
+ xl2 = self.local_att2(xi)
74
+ xg2 = self.global_att2(xi)  # second global-attention branch
75
+ xlg2 = xl2 + xg2
76
+ wei2 = self.sigmoid(xlg2)
77
+ xo = x * wei2 + residual * (1 - wei2)
78
+ return xo
79
+
80
+
81
+ class AFF(nn.Module):
82
+ '''
83
+ Multi-feature fusion: attentional feature fusion (AFF)
84
+ '''
85
+
86
+ def __init__(self, channels=64, r=4):
87
+ super(AFF, self).__init__()
88
+ inter_channels = int(channels // r)
89
+
90
+ self.local_att = nn.Sequential(
91
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
92
+ nn.BatchNorm2d(inter_channels),
93
+ nn.ReLU(inplace=True),
94
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
95
+ nn.BatchNorm2d(channels),
96
+ )
97
+
98
+ self.global_att = nn.Sequential(
99
+ nn.AdaptiveAvgPool2d(1),
100
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
101
+ nn.BatchNorm2d(inter_channels),
102
+ nn.ReLU(inplace=True),
103
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
104
+ nn.BatchNorm2d(channels),
105
+ )
106
+
107
+ self.sigmoid = nn.Sigmoid()
108
+
109
+ def forward(self, x, residual):
110
+ xa = x + residual
111
+ xl = self.local_att(xa)
112
+ xg = self.global_att(xa)
113
+ xlg = xl + xg
114
+ wei = self.sigmoid(xlg)
115
+
116
+ xo = 2 * x * wei + 2 * residual * (1 - wei)
117
+ return xo
118
+
119
+
120
+ class MS_CAM(nn.Module):
121
+ '''
122
+ Channel-wise weighting of a single feature map, similar in effect to an SE block
123
+ '''
124
+
125
+ def __init__(self, channels=64, r=4):
126
+ super(MS_CAM, self).__init__()
127
+ inter_channels = int(channels // r)
128
+
129
+ self.local_att = nn.Sequential(
130
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
131
+ nn.BatchNorm2d(inter_channels),
132
+ nn.ReLU(inplace=True),
133
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
134
+ nn.BatchNorm2d(channels),
135
+ )
136
+
137
+ self.global_att = nn.Sequential(
138
+ nn.AdaptiveAvgPool2d(1),
139
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
140
+ nn.BatchNorm2d(inter_channels),
141
+ nn.ReLU(inplace=True),
142
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
143
+ nn.BatchNorm2d(channels),
144
+ )
145
+
146
+ self.sigmoid = nn.Sigmoid()
147
+
148
+ def forward(self, x):
149
+ xl = self.local_att(x)
150
+ xg = self.global_att(x)
151
+ xlg = xl + xg
152
+ wei = self.sigmoid(xlg)
153
+ return x * wei
154
+
155
+
156
+
157
+ if __name__ == '__main__':
158
+ import os
159
+ device = torch.device("cpu")
160
+ x = torch.ones(1, 2, 2, 2).to(device)
161
+ print(x)
162
+ a = x[0]
163
+ print(a)
164
+ b = torch.ones(2, 2, 2)
165
+ c = torch.stack((a, b))
166
+ print(x.shape)
167
+ # x, residual= torch.ones(1, 2, 2, 2).to(device), torch.ones(1,64, 32, 32).to(device)
168
+ # x = torch.cat(x, dim=1)
169
+ # print(x.shape)
170
+ # channels=x.shape[1]
171
+ # print(channels)
172
+ # model=AFF(channels=channels)
173
+ # model=model.to(device).train()
174
+ # output = model(x, residual)
175
+ # print(output.shape)
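The fusion modules above all take two feature maps of identical shape and blend them with a sigmoid-gated mixture of local (1x1 conv) and global (pooled) channel attention. A minimal sketch of how AFF could be exercised on dummy tensors (the import path and variable names below are illustrative, not from the source):

    import torch
    from fuse import AFF  # assumes fuse.py is on the import path

    x = torch.randn(1, 64, 32, 32)         # main feature map
    residual = torch.randn(1, 64, 32, 32)  # feature map to fuse with x
    aff = AFF(channels=64, r=4).eval()     # channels must match the inputs
    with torch.no_grad():
        fused = aff(x, residual)           # sigmoid-weighted blend of x and residual
    print(fused.shape)                     # torch.Size([1, 64, 32, 32])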
Tiger Model/diffusiers-Tiger/getWeight.py ADDED
@@ -0,0 +1,88 @@
1
+ import math
2
+ import os
3
+ import random
4
+ import shutil
5
+ from pathlib import Path
6
+ from pynvml import *
7
+ import accelerate
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from transformers import AutoTokenizer, PretrainedConfig
13
+
14
+ tensor1 = torch.tensor([[49406, 1884, 33667, 267, 21263, 268, 1126, 268, 7771, 267,
15
+ 32955, 267, 38692, 267, 13989, 43204, 267, 1042, 13989, 49407,
16
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
17
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
18
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
19
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
20
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
21
+ 0, 0, 0, 0, 0, 0, 0],
22
+ [49406, 1884, 33667, 267, 41122, 3633, 267, 21263, 268, 1126,
23
+ 268, 7771, 267, 6148, 267, 32955, 267, 13989, 43204, 267,
24
+ 1042, 13989, 267, 1579, 3396, 267, 2442, 1579, 3396, 49407,
25
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
26
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
27
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
28
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
29
+ 0, 0, 0, 0, 0, 0, 0],
30
+ [49406, 1884, 33667, 267, 21263, 268, 1126, 268, 7771, 267,
31
+ 3143, 267, 6307, 267, 1070, 1042, 13989, 49407, 0, 0,
32
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
33
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
34
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
35
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
36
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
37
+ 0, 0, 0, 0, 0, 0, 0],
38
+ [49406, 1884, 33667, 267, 21263, 268, 1126, 268, 7771, 267,
39
+ 46131, 267, 3143, 267, 6307, 49407, 0, 0, 0, 0,
40
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
41
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
42
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
43
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
44
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
45
+ 0, 0, 0, 0, 0, 0, 0],
46
+ [49406, 1884, 33667, 267, 21263, 268, 1126, 268, 7771, 267,
47
+ 6148, 267, 32955, 267, 38692, 267, 13989, 43204, 267, 1042,
48
+ 13989, 267, 1579, 3396, 267, 5094, 268, 789, 1579, 3396,
49
+ 49407, 0, 0, 0, 0, 0, 0, 0, 0, 0,
50
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
51
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
52
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
53
+ 0, 0, 0, 0, 0, 0, 0],
54
+ [49406, 1884, 33667, 267, 21263, 268, 1126, 268, 7771, 267,
55
+ 32955, 267, 38692, 6448, 49407, 0, 0, 0, 0, 0,
56
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
57
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
58
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
59
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
60
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
61
+ 0, 0, 0, 0, 0, 0, 0]])
62
+
63
+
64
+ l = tensor1.tolist()
65
+ list_2 = sum(l, [])
66
+ def remove_item(n):
67
+ return n != 0 and n !=49407 and n!=49406 and n!=267
68
+ list_3 = list(filter(remove_item, list_2))
69
+
70
+ dict = {}
71
+ for key in list_3:
72
+ dict[key] = dict.get(key, 0) + 1
73
+ print(dict)
74
+
75
+ revision = None
76
+ tokenizer = AutoTokenizer.from_pretrained(
77
+ "/export/home/daifang/Diffusion/diffusers/model/sd-8_28",
78
+ subfolder="tokenizer",
79
+ revision=revision,
80
+ use_fast=False,
81
+ )
82
+ # captions = ['papillary blood flow', 'malignant follicular, solid, unclear, irregular, hales, circular, enormous, white point', 'papillary, wider-than-tall, solid, unclear, irregular, echo uneven, low echo, white points, sand-like white points', 'papillary, wider-than-tall, solid, unclear, irregular, echo uneven, extremely low echo, white points, sand-like white points', 'papillary, wider-than-tall, solid, unclear, irregular, echo uneven, low echo', 'papillary, wider-than-tall, solid, unclear, irregular, echo uneven, low echo, white points, sand-like white points']
83
+
84
+
85
+ captions = ['papillary, taller-than-wide, solid, unclear, irregular, echo uneven, low echo', 'papillary, wider-than-tall, solid, unclear, irregular, echo uneven, low echo, white points, sand-like white points', 'papillary, wider-than-tall, unclear, irregular, echo uneven, low echo, white points, sand-like white points', 'papillary, taller-than-wide, solid, unclear, irregular, echo uneven, low echo', 'papillary, taller-than-wide, solid, unclear, irregular, echo uneven, low echo, white points, large white points', 'No focus']
86
+ inputs = tokenizer(
87
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt")
88
+ print(inputs)
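The script above strips padding and special tokens (0, 49406, 49407) plus the comma token (267) and counts how often each remaining token id appears in the tokenized captions. One plausible follow-up, shown only as an assumption about how such counts could become per-token weights (the formula is illustrative and not taken from this repository):

    # hypothetical: convert token counts into normalized inverse-frequency weights
    counts = {1884: 6, 33667: 6, 21263: 6, 32955: 3}   # e.g. the dict printed above
    total = sum(counts.values())
    weights = {tok: total / (len(counts) * c) for tok, c in counts.items()}
    print(weights)                                      # rarer tokens receive larger weights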
Tiger Model/diffusiers-Tiger/image_processor.py ADDED
@@ -0,0 +1,366 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import warnings
16
+ from typing import List, Optional, Union
17
+
18
+ import numpy as np
19
+ import PIL
20
+ import torch
21
+ from PIL import Image
22
+
23
+ from .configuration_utils import ConfigMixin, register_to_config
24
+ from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
25
+
26
+
27
+ class VaeImageProcessor(ConfigMixin):
28
+ """
29
+ Image processor for VAE.
30
+
31
+ Args:
32
+ do_resize (`bool`, *optional*, defaults to `True`):
33
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
34
+ `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
35
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
36
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
37
+ resample (`str`, *optional*, defaults to `lanczos`):
38
+ Resampling filter to use when resizing the image.
39
+ do_normalize (`bool`, *optional*, defaults to `True`):
40
+ Whether to normalize the image to [-1,1].
41
+ do_convert_rgb (`bool`, *optional*, defaults to be `False`):
42
+ Whether to convert the images to RGB format.
43
+ """
44
+
45
+ config_name = CONFIG_NAME
46
+
47
+ @register_to_config
48
+ def __init__(
49
+ self,
50
+ do_resize: bool = True,
51
+ vae_scale_factor: int = 8,
52
+ resample: str = "lanczos",
53
+ do_normalize: bool = True,
54
+ do_convert_rgb: bool = False,
55
+ ):
56
+ super().__init__()
57
+
58
+ @staticmethod
59
+ def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
60
+ """
61
+ Convert a numpy image or a batch of images to a PIL image.
62
+ """
63
+ if images.ndim == 3:
64
+ images = images[None, ...]
65
+ images = (images * 255).round().astype("uint8")
66
+ if images.shape[-1] == 1:
67
+ # special case for grayscale (single channel) images
68
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
69
+ else:
70
+ pil_images = [Image.fromarray(image) for image in images]
71
+
72
+ return pil_images
73
+
74
+ @staticmethod
75
+ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
76
+ """
77
+ Convert a PIL image or a list of PIL images to NumPy arrays.
78
+ """
79
+ if not isinstance(images, list):
80
+ images = [images]
81
+ images = [np.array(image).astype(np.float32) / 255.0 for image in images]
82
+ images = np.stack(images, axis=0)
83
+
84
+ return images
85
+
86
+ @staticmethod
87
+ def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
88
+ """
89
+ Convert a NumPy image to a PyTorch tensor.
90
+ """
91
+ if images.ndim == 3:
92
+ images = images[..., None]
93
+
94
+ images = torch.from_numpy(images.transpose(0, 3, 1, 2))
95
+ return images
96
+
97
+ @staticmethod
98
+ def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
99
+ """
100
+ Convert a PyTorch tensor to a NumPy image.
101
+ """
102
+ images = images.cpu().permute(0, 2, 3, 1).float().numpy()
103
+ return images
104
+
105
+ @staticmethod
106
+ def normalize(images):
107
+ """
108
+ Normalize an image array to [-1,1].
109
+ """
110
+ return 2.0 * images - 1.0
111
+
112
+ @staticmethod
113
+ def denormalize(images):
114
+ """
115
+ Denormalize an image array to [0,1].
116
+ """
117
+ return (images / 2 + 0.5).clamp(0, 1)
118
+
119
+ @staticmethod
120
+ def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
121
+ """
122
+ Converts an image to RGB format.
123
+ """
124
+ image = image.convert("RGB")
125
+ return image
126
+
127
+ def resize(
128
+ self,
129
+ image: PIL.Image.Image,
130
+ height: Optional[int] = None,
131
+ width: Optional[int] = None,
132
+ ) -> PIL.Image.Image:
133
+ """
134
+ Resize a PIL image. Both height and width are downscaled to the next integer multiple of `vae_scale_factor`.
135
+ """
136
+ if height is None:
137
+ height = image.height
138
+ if width is None:
139
+ width = image.width
140
+
141
+ width, height = (
142
+ x - x % self.config.vae_scale_factor for x in (width, height)
143
+ ) # resize to integer multiple of vae_scale_factor
144
+ image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
145
+ return image
146
+
147
+ def preprocess(
148
+ self,
149
+ image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
150
+ height: Optional[int] = None,
151
+ width: Optional[int] = None,
152
+ ) -> torch.Tensor:
153
+ """
154
+ Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
155
+ """
156
+ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
157
+ if isinstance(image, supported_formats):
158
+ image = [image]
159
+ elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
160
+ raise ValueError(
161
+ f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
162
+ )
163
+
164
+ if isinstance(image[0], PIL.Image.Image):
165
+ if self.config.do_convert_rgb:
166
+ image = [self.convert_to_rgb(i) for i in image]
167
+ if self.config.do_resize:
168
+ image = [self.resize(i, height, width) for i in image]
169
+ image = self.pil_to_numpy(image) # to np
170
+ image = self.numpy_to_pt(image) # to pt
171
+
172
+ elif isinstance(image[0], np.ndarray):
173
+ image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
174
+ image = self.numpy_to_pt(image)
175
+ _, _, height, width = image.shape
176
+ if self.config.do_resize and (
177
+ height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
178
+ ):
179
+ raise ValueError(
180
+ f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}"
181
+ f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
182
+ )
183
+
184
+ elif isinstance(image[0], torch.Tensor):
185
+ image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
186
+ _, channel, height, width = image.shape
187
+
188
+ # don't need any preprocess if the image is latents
189
+ if channel == 4:
190
+ return image
191
+
192
+ if self.config.do_resize and (
193
+ height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
194
+ ):
195
+ raise ValueError(
196
+ f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}"
197
+ f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
198
+ )
199
+
200
+ # expected range [0,1], normalize to [-1,1]
201
+ do_normalize = self.config.do_normalize
202
+ if image.min() < 0:
203
+ warnings.warn(
204
+ "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
205
+ f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
206
+ FutureWarning,
207
+ )
208
+ do_normalize = False
209
+
210
+ if do_normalize:
211
+ image = self.normalize(image)
212
+
213
+ return image
214
+
215
+ def postprocess(
216
+ self,
217
+ image: torch.FloatTensor,
218
+ output_type: str = "pil",
219
+ do_denormalize: Optional[List[bool]] = None,
220
+ ):
221
+ if not isinstance(image, torch.Tensor):
222
+ raise ValueError(
223
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
224
+ )
225
+ if output_type not in ["latent", "pt", "np", "pil"]:
226
+ deprecation_message = (
227
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
228
+ "`pil`, `np`, `pt`, `latent`"
229
+ )
230
+ deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
231
+ output_type = "np"
232
+
233
+ if output_type == "latent":
234
+ return image
235
+
236
+ if do_denormalize is None:
237
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
238
+
239
+ image = torch.stack(
240
+ [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
241
+ )
242
+
243
+ if output_type == "pt":
244
+ return image
245
+
246
+ image = self.pt_to_numpy(image)
247
+
248
+ if output_type == "np":
249
+ return image
250
+
251
+ if output_type == "pil":
252
+ return self.numpy_to_pil(image)
253
+
254
+
255
+ class VaeImageProcessorLDM3D(VaeImageProcessor):
256
+ """
257
+ Image processor for VAE LDM3D.
258
+
259
+ Args:
260
+ do_resize (`bool`, *optional*, defaults to `True`):
261
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
262
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
263
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
264
+ resample (`str`, *optional*, defaults to `lanczos`):
265
+ Resampling filter to use when resizing the image.
266
+ do_normalize (`bool`, *optional*, defaults to `True`):
267
+ Whether to normalize the image to [-1,1].
268
+ """
269
+
270
+ config_name = CONFIG_NAME
271
+
272
+ @register_to_config
273
+ def __init__(
274
+ self,
275
+ do_resize: bool = True,
276
+ vae_scale_factor: int = 8,
277
+ resample: str = "lanczos",
278
+ do_normalize: bool = True,
279
+ ):
280
+ super().__init__()
281
+
282
+ @staticmethod
283
+ def numpy_to_pil(images):
284
+ """
285
+ Convert a NumPy image or a batch of images to a PIL image.
286
+ """
287
+ if images.ndim == 3:
288
+ images = images[None, ...]
289
+ images = (images * 255).round().astype("uint8")
290
+ if images.shape[-1] == 1:
291
+ # special case for grayscale (single channel) images
292
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
293
+ else:
294
+ pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
295
+
296
+ return pil_images
297
+
298
+ @staticmethod
299
+ def rgblike_to_depthmap(image):
300
+ """
301
+ Args:
302
+ image: RGB-like depth image
303
+
304
+ Returns: depth map
305
+
306
+ """
307
+ return image[:, :, 1] * 2**8 + image[:, :, 2]
308
+
309
+ def numpy_to_depth(self, images):
310
+ """
311
+ Convert a NumPy depth image or a batch of images to a PIL image.
312
+ """
313
+ if images.ndim == 3:
314
+ images = images[None, ...]
315
+ images_depth = images[:, :, :, 3:]
316
+ if images.shape[-1] == 6:
317
+ images_depth = (images_depth * 255).round().astype("uint8")
318
+ pil_images = [
319
+ Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
320
+ ]
321
+ elif images.shape[-1] == 4:
322
+ images_depth = (images_depth * 65535.0).astype(np.uint16)
323
+ pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
324
+ else:
325
+ raise Exception("Not supported")
326
+
327
+ return pil_images
328
+
329
+ def postprocess(
330
+ self,
331
+ image: torch.FloatTensor,
332
+ output_type: str = "pil",
333
+ do_denormalize: Optional[List[bool]] = None,
334
+ ):
335
+ if not isinstance(image, torch.Tensor):
336
+ raise ValueError(
337
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
338
+ )
339
+ if output_type not in ["latent", "pt", "np", "pil"]:
340
+ deprecation_message = (
341
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
342
+ "`pil`, `np`, `pt`, `latent`"
343
+ )
344
+ deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
345
+ output_type = "np"
346
+
347
+ if do_denormalize is None:
348
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
349
+
350
+ image = torch.stack(
351
+ [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
352
+ )
353
+
354
+ image = self.pt_to_numpy(image)
355
+
356
+ if output_type == "np":
357
+ if image.shape[-1] == 6:
358
+ image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
359
+ else:
360
+ image_depth = image[:, :, :, 3:]
361
+ return image[:, :, :, :3], image_depth
362
+
363
+ if output_type == "pil":
364
+ return self.numpy_to_pil(image), self.numpy_to_depth(image)
365
+ else:
366
+ raise Exception(f"This type {output_type} is not supported")
Tiger Model/diffusiers-Tiger/loaders.py ADDED
The diff for this file is too large to render. See raw diff
 
Tiger Model/diffusiers-Tiger/models/README.md ADDED
@@ -0,0 +1,3 @@
1
+ # Models
2
+
3
+ For more detail on the models, please refer to the [docs](https://huggingface.co/docs/diffusers/api/models/overview).
Tiger Model/diffusiers-Tiger/models/__init__.py ADDED
@@ -0,0 +1,39 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from ..utils import is_flax_available, is_torch_available
16
+
17
+
18
+ if is_torch_available():
19
+ from .adapter import MultiAdapter, T2IAdapter
20
+ from .autoencoder_asym_kl import AsymmetricAutoencoderKL
21
+ from .autoencoder_kl import AutoencoderKL
22
+ from .autoencoder_tiny import AutoencoderTiny
23
+ from .controlnet import ControlNetModel
24
+ from .dual_transformer_2d import DualTransformer2DModel
25
+ from .modeling_utils import ModelMixin
26
+ from .prior_transformer import PriorTransformer
27
+ from .t5_film_transformer import T5FilmDecoder
28
+ from .transformer_2d import Transformer2DModel
29
+ from .unet_1d import UNet1DModel
30
+ from .unet_2d import UNet2DModel
31
+ from .unet_2d_condition import UNet2DConditionModel
32
+ from .modeling_utils import ModelMixin
33
+ from .unet_3d_condition import UNet3DConditionModel
34
+ from .vq_model import VQModel
35
+
36
+ if is_flax_available():
37
+ from .controlnet_flax import FlaxControlNetModel
38
+ from .unet_2d_condition_flax import FlaxUNet2DConditionModel
39
+ from .vae_flax import FlaxAutoencoderKL
Tiger Model/diffusiers-Tiger/models/activations.py ADDED
@@ -0,0 +1,14 @@
1
+ from torch import nn
2
+
3
+
4
+ def get_activation(act_fn):
5
+ if act_fn in ["swish", "silu"]:
6
+ return nn.SiLU()
7
+ elif act_fn == "mish":
8
+ return nn.Mish()
9
+ elif act_fn == "gelu":
10
+ return nn.GELU()
11
+ elif act_fn == "relu":
12
+ return nn.ReLU()
13
+ else:
14
+ raise ValueError(f"Unsupported activation function: {act_fn}")
Tiger Model/diffusiers-Tiger/models/adapter.py ADDED
@@ -0,0 +1,291 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List, Optional
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from ..configuration_utils import ConfigMixin, register_to_config
21
+ from .modeling_utils import ModelMixin
22
+ from .resnet import Downsample2D
23
+
24
+
25
+ class MultiAdapter(ModelMixin):
26
+ r"""
27
+ MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
28
+ user-assigned weighting.
29
+
30
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
31
+ implements for all the model (such as downloading or saving, etc.)
32
+
33
+ Parameters:
34
+ adapters (`List[T2IAdapter]`, *optional*, defaults to None):
35
+ A list of `T2IAdapter` model instances.
36
+ """
37
+
38
+ def __init__(self, adapters: List["T2IAdapter"]):
39
+ super(MultiAdapter, self).__init__()
40
+
41
+ self.num_adapter = len(adapters)
42
+ self.adapters = nn.ModuleList(adapters)
43
+
44
+ def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
45
+ r"""
46
+ Args:
47
+ xs (`torch.Tensor`):
48
+ (batch, channel, height, width) input images for multiple adapter models concatenated along dimension 1,
49
+ `channel` should equal `num_adapter` * "number of channels per image".
50
+ adapter_weights (`List[float]`, *optional*, defaults to None):
51
+ List of floats representing the weight which will be multiplied with each adapter's output before adding
52
+ them together.
53
+ """
54
+ if adapter_weights is None:
55
+ adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter)
56
+ else:
57
+ adapter_weights = torch.tensor(adapter_weights)
58
+
59
+ if xs.shape[1] % self.num_adapter != 0:
60
+ raise ValueError(
61
+ f"Expecting multi-adapter's input have number of channel that cab be evenly divisible "
62
+ f"by num_adapter: {xs.shape[1]} % {self.num_adapter} != 0"
63
+ )
64
+ x_list = torch.chunk(xs, self.num_adapter, dim=1)
65
+ accume_state = None
66
+ for x, w, adapter in zip(x_list, adapter_weights, self.adapters):
67
+ features = adapter(x)
68
+ if accume_state is None:
69
+ accume_state = features
70
+ else:
71
+ for i in range(len(features)):
72
+ accume_state[i] += w * features[i]
73
+ return accume_state
74
+
75
+
76
+ class T2IAdapter(ModelMixin, ConfigMixin):
77
+ r"""
78
+ A simple ResNet-like model that accepts images containing control signals such as keyposes and depth. The model
79
+ generates multiple feature maps that are used as additional conditioning in [`UNet2DConditionModel`]. The model's
80
+ architecture follows the original implementation of
81
+ [Adapter](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L97)
82
+ and
83
+ [AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235).
84
+
85
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
86
+ implements for all the model (such as downloading or saving, etc.)
87
+
88
+ Parameters:
89
+ in_channels (`int`, *optional*, defaults to 3):
90
+ Number of channels of the Adapter's input (*control image*). Set this parameter to 1 if you're using a grayscale
91
+ image as *control image*.
92
+ channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
93
+ The number of channel of each downsample block's output hidden state. The `len(block_out_channels)` will
94
+ also determine the number of downsample blocks in the Adapter.
95
+ num_res_blocks (`int`, *optional*, defaults to 2):
96
+ Number of ResNet blocks in each downsample block
97
+ """
98
+
99
+ @register_to_config
100
+ def __init__(
101
+ self,
102
+ in_channels: int = 3,
103
+ channels: List[int] = [320, 640, 1280, 1280],
104
+ num_res_blocks: int = 2,
105
+ downscale_factor: int = 8,
106
+ adapter_type: str = "full_adapter",
107
+ ):
108
+ super().__init__()
109
+
110
+ if adapter_type == "full_adapter":
111
+ self.adapter = FullAdapter(in_channels, channels, num_res_blocks, downscale_factor)
112
+ elif adapter_type == "light_adapter":
113
+ self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor)
114
+ else:
115
+ raise ValueError(f"unknown adapter_type: {type}. Choose either 'full_adapter' or 'simple_adapter'")
116
+
117
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
118
+ return self.adapter(x)
119
+
120
+ @property
121
+ def total_downscale_factor(self):
122
+ return self.adapter.total_downscale_factor
123
+
124
+
125
+ # full adapter
126
+
127
+
128
+ class FullAdapter(nn.Module):
129
+ def __init__(
130
+ self,
131
+ in_channels: int = 3,
132
+ channels: List[int] = [320, 640, 1280, 1280],
133
+ num_res_blocks: int = 2,
134
+ downscale_factor: int = 8,
135
+ ):
136
+ super().__init__()
137
+
138
+ in_channels = in_channels * downscale_factor**2
139
+
140
+ self.unshuffle = nn.PixelUnshuffle(downscale_factor)
141
+ self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
142
+
143
+ self.body = nn.ModuleList(
144
+ [
145
+ AdapterBlock(channels[0], channels[0], num_res_blocks),
146
+ *[
147
+ AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True)
148
+ for i in range(1, len(channels))
149
+ ],
150
+ ]
151
+ )
152
+
153
+ self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)
154
+
155
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
156
+ x = self.unshuffle(x)
157
+ x = self.conv_in(x)
158
+
159
+ features = []
160
+
161
+ for block in self.body:
162
+ x = block(x)
163
+ features.append(x)
164
+
165
+ return features
166
+
167
+
168
+ class AdapterBlock(nn.Module):
169
+ def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
170
+ super().__init__()
171
+
172
+ self.downsample = None
173
+ if down:
174
+ self.downsample = Downsample2D(in_channels)
175
+
176
+ self.in_conv = None
177
+ if in_channels != out_channels:
178
+ self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
179
+
180
+ self.resnets = nn.Sequential(
181
+ *[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)],
182
+ )
183
+
184
+ def forward(self, x):
185
+ if self.downsample is not None:
186
+ x = self.downsample(x)
187
+
188
+ if self.in_conv is not None:
189
+ x = self.in_conv(x)
190
+
191
+ x = self.resnets(x)
192
+
193
+ return x
194
+
195
+
196
+ class AdapterResnetBlock(nn.Module):
197
+ def __init__(self, channels):
198
+ super().__init__()
199
+ self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
200
+ self.act = nn.ReLU()
201
+ self.block2 = nn.Conv2d(channels, channels, kernel_size=1)
202
+
203
+ def forward(self, x):
204
+ h = x
205
+ h = self.block1(h)
206
+ h = self.act(h)
207
+ h = self.block2(h)
208
+
209
+ return h + x
210
+
211
+
212
+ # light adapter
213
+
214
+
215
+ class LightAdapter(nn.Module):
216
+ def __init__(
217
+ self,
218
+ in_channels: int = 3,
219
+ channels: List[int] = [320, 640, 1280],
220
+ num_res_blocks: int = 4,
221
+ downscale_factor: int = 8,
222
+ ):
223
+ super().__init__()
224
+
225
+ in_channels = in_channels * downscale_factor**2
226
+
227
+ self.unshuffle = nn.PixelUnshuffle(downscale_factor)
228
+
229
+ self.body = nn.ModuleList(
230
+ [
231
+ LightAdapterBlock(in_channels, channels[0], num_res_blocks),
232
+ *[
233
+ LightAdapterBlock(channels[i], channels[i + 1], num_res_blocks, down=True)
234
+ for i in range(len(channels) - 1)
235
+ ],
236
+ LightAdapterBlock(channels[-1], channels[-1], num_res_blocks, down=True),
237
+ ]
238
+ )
239
+
240
+ self.total_downscale_factor = downscale_factor * (2 ** len(channels))
241
+
242
+ def forward(self, x):
243
+ x = self.unshuffle(x)
244
+
245
+ features = []
246
+
247
+ for block in self.body:
248
+ x = block(x)
249
+ features.append(x)
250
+
251
+ return features
252
+
253
+
254
+ class LightAdapterBlock(nn.Module):
255
+ def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
256
+ super().__init__()
257
+ mid_channels = out_channels // 4
258
+
259
+ self.downsample = None
260
+ if down:
261
+ self.downsample = Downsample2D(in_channels)
262
+
263
+ self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
264
+ self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
265
+ self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)
266
+
267
+ def forward(self, x):
268
+ if self.downsample is not None:
269
+ x = self.downsample(x)
270
+
271
+ x = self.in_conv(x)
272
+ x = self.resnets(x)
273
+ x = self.out_conv(x)
274
+
275
+ return x
276
+
277
+
278
+ class LightAdapterResnetBlock(nn.Module):
279
+ def __init__(self, channels):
280
+ super().__init__()
281
+ self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
282
+ self.act = nn.ReLU()
283
+ self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
284
+
285
+ def forward(self, x):
286
+ h = x
287
+ h = self.block1(h)
288
+ h = self.act(h)
289
+ h = self.block2(h)
290
+
291
+ return h + x
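A minimal sketch of T2IAdapter on a dummy control image (illustrative values only; the defaults above produce a four-level feature pyramid whose channel counts follow `channels`, with the spatial size first reduced by `downscale_factor` and then halved per extra level):

    import torch

    adapter = T2IAdapter(in_channels=3, channels=[320, 640, 1280, 1280], num_res_blocks=2, downscale_factor=8)
    control = torch.randn(1, 3, 512, 512)   # e.g. a depth or keypose map
    features = adapter(control)
    for f in features:
        print(f.shape)
    # (1, 320, 64, 64), (1, 640, 32, 32), (1, 1280, 16, 16), (1, 1280, 8, 8)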
Tiger Model/diffusiers-Tiger/models/attention.py ADDED
@@ -0,0 +1,437 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from ..utils import maybe_allow_in_graph
21
+ from .activations import get_activation
22
+ from .attention_processor import Attention
23
+ from .embeddings import CombinedTimestepLabelEmbeddings
24
+ from .lora import LoRACompatibleLinear
25
+
26
+
27
+ @maybe_allow_in_graph
28
+ class GatedSelfAttentionDense(nn.Module):
29
+ def __init__(self, query_dim, context_dim, n_heads, d_head):
30
+ super().__init__()
31
+
32
+ # we need a linear projection since we need cat visual feature and obj feature
33
+ self.linear = nn.Linear(context_dim, query_dim)
34
+
35
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
36
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
37
+
38
+ self.norm1 = nn.LayerNorm(query_dim)
39
+ self.norm2 = nn.LayerNorm(query_dim)
40
+
41
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
42
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
43
+
44
+ self.enabled = True
45
+
46
+ def forward(self, x, objs):
47
+ if not self.enabled:
48
+ return x
49
+
50
+ n_visual = x.shape[1]
51
+ objs = self.linear(objs)
52
+
53
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
54
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
55
+
56
+ return x
57
+
58
+
59
+ @maybe_allow_in_graph
60
+ class BasicTransformerBlock(nn.Module):
61
+ r"""
62
+ A basic Transformer block.
63
+
64
+ Parameters:
65
+ dim (`int`): The number of channels in the input and output.
66
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
67
+ attention_head_dim (`int`): The number of channels in each head.
68
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
69
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
70
+ only_cross_attention (`bool`, *optional*):
71
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
72
+ double_self_attention (`bool`, *optional*):
73
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
74
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
75
+ num_embeds_ada_norm (:
76
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
77
+ attention_bias (:
78
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
79
+ """
80
+
81
+ def __init__(
82
+ self,
83
+ dim: int,
84
+ num_attention_heads: int,
85
+ attention_head_dim: int,
86
+ dropout=0.0,
87
+ cross_attention_dim: Optional[int] = None,
88
+ activation_fn: str = "geglu",
89
+ num_embeds_ada_norm: Optional[int] = None,
90
+ attention_bias: bool = False,
91
+ only_cross_attention: bool = False,
92
+ double_self_attention: bool = False,
93
+ upcast_attention: bool = False,
94
+ norm_elementwise_affine: bool = True,
95
+ norm_type: str = "layer_norm",
96
+ final_dropout: bool = False,
97
+ attention_type: str = "default",
98
+ weight: Optional[torch.LongTensor] = None,
99
+ ):
100
+ super().__init__()
101
+ self.only_cross_attention = only_cross_attention
102
+
103
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
104
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
105
+
106
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
107
+ raise ValueError(
108
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
109
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
110
+ )
111
+
112
+ # Define 3 blocks. Each block has its own normalization layer.
113
+ # 1. Self-Attn
114
+ if self.use_ada_layer_norm:
115
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
116
+ elif self.use_ada_layer_norm_zero:
117
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
118
+ else:
119
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
120
+ self.attn1 = Attention(
121
+ query_dim=dim,
122
+ heads=num_attention_heads,
123
+ dim_head=attention_head_dim,
124
+ dropout=dropout,
125
+ bias=attention_bias,
126
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
127
+ upcast_attention=upcast_attention,
128
+ )
129
+
130
+ # 2. Cross-Attn
131
+ if cross_attention_dim is not None or double_self_attention:
132
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
133
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
134
+ # the second cross attention block.
135
+ self.norm2 = (
136
+ AdaLayerNorm(dim, num_embeds_ada_norm)
137
+ if self.use_ada_layer_norm
138
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
139
+ )
140
+ self.attn2 = Attention(
141
+ query_dim=dim,
142
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
143
+ heads=num_attention_heads,
144
+ dim_head=attention_head_dim,
145
+ dropout=dropout,
146
+ bias=attention_bias,
147
+ upcast_attention=upcast_attention,
148
+ ) # is self-attn if encoder_hidden_states is none
149
+ else:
150
+ self.norm2 = None
151
+ self.attn2 = None
152
+
153
+ # 3. Feed-forward
154
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
155
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
156
+
157
+ # 4. Fuser
158
+ if attention_type == "gated":
159
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
160
+
161
+ # let chunk size default to None
162
+ self._chunk_size = None
163
+ self._chunk_dim = 0
164
+
165
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
166
+ # Sets chunk feed-forward
167
+ self._chunk_size = chunk_size
168
+ self._chunk_dim = dim
169
+
170
+ def forward(
171
+ self,
172
+ hidden_states: torch.FloatTensor,
173
+ attention_mask: Optional[torch.FloatTensor] = None,
174
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
175
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
176
+ timestep: Optional[torch.LongTensor] = None,
177
+ cross_attention_kwargs: Dict[str, Any] = None,
178
+ class_labels: Optional[torch.LongTensor] = None,
179
+ weight: Optional[torch.LongTensor] = None,
180
+ ):
181
+ # Notice that normalization is always applied before the real computation in the following blocks.
182
+ # 1. Self-Attention
183
+ if self.use_ada_layer_norm:
184
+ norm_hidden_states = self.norm1(hidden_states, timestep)
185
+ elif self.use_ada_layer_norm_zero:
186
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
187
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
188
+ )
189
+ else:
190
+ norm_hidden_states = self.norm1(hidden_states)
191
+
192
+ # 0. Prepare GLIGEN inputs
193
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
194
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
195
+
196
+ attn_output = self.attn1(
197
+ norm_hidden_states,
198
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
199
+ attention_mask=attention_mask,
200
+ weight=weight,
201
+ **cross_attention_kwargs,
202
+ )
203
+ if self.use_ada_layer_norm_zero:
204
+ attn_output = gate_msa.unsqueeze(1) * attn_output
205
+ hidden_states = attn_output + hidden_states
206
+
207
+ # 1.5 GLIGEN Control
208
+ if gligen_kwargs is not None:
209
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
210
+ # 1.5 ends
211
+
212
+ # 2. Cross-Attention
213
+ if self.attn2 is not None:
214
+ norm_hidden_states = (
215
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
216
+ )
217
+
218
+ attn_output = self.attn2(
219
+ norm_hidden_states,
220
+ encoder_hidden_states=encoder_hidden_states,
221
+ attention_mask=encoder_attention_mask,
222
+ **cross_attention_kwargs,
223
+ )
224
+ hidden_states = attn_output + hidden_states
225
+
226
+ # 3. Feed-forward
227
+ norm_hidden_states = self.norm3(hidden_states)
228
+
229
+ if self.use_ada_layer_norm_zero:
230
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
231
+
232
+ if self._chunk_size is not None:
233
+ # "feed_forward_chunk_size" can be used to save memory
234
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
235
+ raise ValueError(
236
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
237
+ )
238
+
239
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
240
+ ff_output = torch.cat(
241
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
242
+ dim=self._chunk_dim,
243
+ )
244
+ else:
245
+ ff_output = self.ff(norm_hidden_states)
246
+
247
+ if self.use_ada_layer_norm_zero:
248
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
249
+
250
+ hidden_states = ff_output + hidden_states
251
+
252
+ return hidden_states
253
+
254
+
255
+ class FeedForward(nn.Module):
256
+ r"""
257
+ A feed-forward layer.
258
+
259
+ Parameters:
260
+ dim (`int`): The number of channels in the input.
261
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
262
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
263
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
264
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
265
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
266
+ """
267
+
268
+ def __init__(
269
+ self,
270
+ dim: int,
271
+ dim_out: Optional[int] = None,
272
+ mult: int = 4,
273
+ dropout: float = 0.0,
274
+ activation_fn: str = "geglu",
275
+ final_dropout: bool = False,
276
+ ):
277
+ super().__init__()
278
+ inner_dim = int(dim * mult)
279
+ dim_out = dim_out if dim_out is not None else dim
280
+
281
+ if activation_fn == "gelu":
282
+ act_fn = GELU(dim, inner_dim)
283
+ if activation_fn == "gelu-approximate":
284
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
285
+ elif activation_fn == "geglu":
286
+ act_fn = GEGLU(dim, inner_dim)
287
+ elif activation_fn == "geglu-approximate":
288
+ act_fn = ApproximateGELU(dim, inner_dim)
289
+
290
+ self.net = nn.ModuleList([])
291
+ # project in
292
+ self.net.append(act_fn)
293
+ # project dropout
294
+ self.net.append(nn.Dropout(dropout))
295
+ # project out
296
+ self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
297
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
298
+ if final_dropout:
299
+ self.net.append(nn.Dropout(dropout))
300
+
301
+ def forward(self, hidden_states):
302
+ for module in self.net:
303
+ hidden_states = module(hidden_states)
304
+ return hidden_states
305
+
306
+
307
+ class GELU(nn.Module):
308
+ r"""
309
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
310
+ """
311
+
312
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
313
+ super().__init__()
314
+ self.proj = nn.Linear(dim_in, dim_out)
315
+ self.approximate = approximate
316
+
317
+ def gelu(self, gate):
318
+ if gate.device.type != "mps":
319
+ return F.gelu(gate, approximate=self.approximate)
320
+ # mps: gelu is not implemented for float16
321
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
322
+
323
+ def forward(self, hidden_states):
324
+ hidden_states = self.proj(hidden_states)
325
+ hidden_states = self.gelu(hidden_states)
326
+ return hidden_states
327
+
328
+
329
+ class GEGLU(nn.Module):
330
+ r"""
331
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
332
+
333
+ Parameters:
334
+ dim_in (`int`): The number of channels in the input.
335
+ dim_out (`int`): The number of channels in the output.
336
+ """
337
+
338
+ def __init__(self, dim_in: int, dim_out: int):
339
+ super().__init__()
340
+ self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)
341
+
342
+ def gelu(self, gate):
343
+ if gate.device.type != "mps":
344
+ return F.gelu(gate)
345
+ # mps: gelu is not implemented for float16
346
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
347
+
348
+ def forward(self, hidden_states):
349
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
350
+ return hidden_states * self.gelu(gate)
351
+
352
+
353
+ class ApproximateGELU(nn.Module):
354
+ """
355
+ The approximate form of Gaussian Error Linear Unit (GELU)
356
+
357
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
358
+ """
359
+
360
+ def __init__(self, dim_in: int, dim_out: int):
361
+ super().__init__()
362
+ self.proj = nn.Linear(dim_in, dim_out)
363
+
364
+ def forward(self, x):
365
+ x = self.proj(x)
366
+ return x * torch.sigmoid(1.702 * x)
367
+
368
+
369
+ class AdaLayerNorm(nn.Module):
370
+ """
371
+ Norm layer modified to incorporate timestep embeddings.
372
+ """
373
+
374
+ def __init__(self, embedding_dim, num_embeddings):
375
+ super().__init__()
376
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
377
+ self.silu = nn.SiLU()
378
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
379
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
380
+
381
+ def forward(self, x, timestep):
382
+ emb = self.linear(self.silu(self.emb(timestep)))
383
+ scale, shift = torch.chunk(emb, 2)
384
+ x = self.norm(x) * (1 + scale) + shift
385
+ return x
386
+
387
+
388
+ class AdaLayerNormZero(nn.Module):
389
+ """
390
+ Norm layer adaptive layer norm zero (adaLN-Zero).
391
+ """
392
+
393
+ def __init__(self, embedding_dim, num_embeddings):
394
+ super().__init__()
395
+
396
+ self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
397
+
398
+ self.silu = nn.SiLU()
399
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
400
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
401
+
402
+ def forward(self, x, timestep, class_labels, hidden_dtype=None):
403
+ emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
404
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
405
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
406
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
407
+
408
+
409
+ class AdaGroupNorm(nn.Module):
410
+ """
411
+ GroupNorm layer modified to incorporate timestep embeddings.
412
+ """
413
+
414
+ def __init__(
415
+ self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
416
+ ):
417
+ super().__init__()
418
+ self.num_groups = num_groups
419
+ self.eps = eps
420
+
421
+ if act_fn is None:
422
+ self.act = None
423
+ else:
424
+ self.act = get_activation(act_fn)
425
+
426
+ self.linear = nn.Linear(embedding_dim, out_dim * 2)
427
+
428
+ def forward(self, x, emb):
429
+ if self.act:
430
+ emb = self.act(emb)
431
+ emb = self.linear(emb)
432
+ emb = emb[:, :, None, None]
433
+ scale, shift = emb.chunk(2, dim=1)
434
+
435
+ x = F.group_norm(x, self.num_groups, eps=self.eps)
436
+ x = x * (1 + scale) + shift
437
+ return x
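A minimal sketch of BasicTransformerBlock with cross-attention (shapes are arbitrary; it assumes the modified Attention class in this repository accepts the extra `weight` keyword that the block threads through, which defaults to None):

    import torch

    block = BasicTransformerBlock(dim=320, num_attention_heads=8, attention_head_dim=40, cross_attention_dim=768)
    hidden_states = torch.randn(2, 64, 320)           # (batch, image tokens, channels)
    text_states = torch.randn(2, 77, 768)             # (batch, text tokens, text embedding dim)
    out = block(hidden_states, encoder_hidden_states=text_states)
    print(out.shape)                                   # torch.Size([2, 64, 320])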
Tiger Model/diffusiers-Tiger/models/attention_processor.py ADDED
@@ -0,0 +1,1716 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Callable, Optional, Union
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from ..utils import deprecate, logging, maybe_allow_in_graph
21
+ from ..utils.import_utils import is_xformers_available
22
+ from .lora import LoRALinearLayer
23
+
24
+
25
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
26
+
27
+
28
+ if is_xformers_available():
29
+ import xformers
30
+ import xformers.ops
31
+ else:
32
+ xformers = None
33
+
34
+
35
+ @maybe_allow_in_graph
36
+ class Attention(nn.Module):
37
+ r"""
38
+ A cross attention layer.
39
+
40
+ Parameters:
41
+ query_dim (`int`): The number of channels in the query.
42
+ cross_attention_dim (`int`, *optional*):
43
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
44
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
45
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
46
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
47
+ bias (`bool`, *optional*, defaults to False):
48
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
49
+ """
50
+
51
+ def __init__(
52
+ self,
53
+ query_dim: int,
54
+ cross_attention_dim: Optional[int] = None,
55
+ heads: int = 8,
56
+ dim_head: int = 64,
57
+ dropout: float = 0.0,
58
+ bias=False,
59
+ upcast_attention: bool = False,
60
+ upcast_softmax: bool = False,
61
+ cross_attention_norm: Optional[str] = None,
62
+ cross_attention_norm_num_groups: int = 32,
63
+ added_kv_proj_dim: Optional[int] = None,
64
+ norm_num_groups: Optional[int] = None,
65
+ spatial_norm_dim: Optional[int] = None,
66
+ out_bias: bool = True,
67
+ scale_qk: bool = True,
68
+ only_cross_attention: bool = False,
69
+ eps: float = 1e-5,
70
+ rescale_output_factor: float = 1.0,
71
+ residual_connection: bool = False,
72
+ _from_deprecated_attn_block=False,
73
+ processor: Optional["AttnProcessor"] = None,
74
+ weight: Optional[torch.FloatTensor] = None,
75
+ ):
76
+ super().__init__()
77
+ inner_dim = dim_head * heads
78
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
79
+ self.upcast_attention = upcast_attention
80
+ self.upcast_softmax = upcast_softmax
81
+ self.rescale_output_factor = rescale_output_factor
82
+ self.residual_connection = residual_connection
83
+ self.dropout = dropout
84
+ self.weight = weight
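+ # Tiger-specific addition: an optional per-sample scaling tensor. Note that this
+ # stored copy is not read again in this file; the `weight` passed to forward() at
+ # call time is what the weight-aware processors actually use.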
85
+ # we make use of this private variable to know whether this class is loaded
86
+ # with a deprecated state dict so that we can convert it on the fly
87
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
88
+
89
+ self.scale_qk = scale_qk
90
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
91
+
92
+ self.heads = heads
93
+ # for slice_size > 0 the attention score computation
94
+ # is split across the batch axis to save memory
95
+ # You can set slice_size with `set_attention_slice`
96
+ self.sliceable_head_dim = heads
97
+
98
+ self.added_kv_proj_dim = added_kv_proj_dim
99
+ self.only_cross_attention = only_cross_attention
100
+
101
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
102
+ raise ValueError(
103
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
104
+ )
105
+
106
+ if norm_num_groups is not None:
107
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
108
+ else:
109
+ self.group_norm = None
110
+
111
+ if spatial_norm_dim is not None:
112
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
113
+ else:
114
+ self.spatial_norm = None
115
+
116
+ if cross_attention_norm is None:
117
+ self.norm_cross = None
118
+ elif cross_attention_norm == "layer_norm":
119
+ self.norm_cross = nn.LayerNorm(cross_attention_dim)
120
+ elif cross_attention_norm == "group_norm":
121
+ if self.added_kv_proj_dim is not None:
122
+ # The given `encoder_hidden_states` are initially of shape
123
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
124
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
125
+ # before the projection, so we need to use `added_kv_proj_dim` as
126
+ # the number of channels for the group norm.
127
+ norm_cross_num_channels = added_kv_proj_dim
128
+ else:
129
+ norm_cross_num_channels = cross_attention_dim
130
+
131
+ self.norm_cross = nn.GroupNorm(
132
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
133
+ )
134
+ else:
135
+ raise ValueError(
136
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
137
+ )
138
+
139
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
140
+
141
+ if not self.only_cross_attention:
142
+ # only relevant for the `AddedKVProcessor` classes
143
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
144
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
145
+ else:
146
+ self.to_k = None
147
+ self.to_v = None
148
+
149
+ if self.added_kv_proj_dim is not None:
150
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
151
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)
152
+
153
+ self.to_out = nn.ModuleList([])
154
+ self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
155
+ self.to_out.append(nn.Dropout(dropout))
156
+
157
+ # set attention processor
158
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
159
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
160
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
161
+ if processor is None:
162
+ processor = (
163
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
164
+ )
165
+ self.set_processor(processor)
166
+
167
+ def set_use_memory_efficient_attention_xformers(
168
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
169
+ ):
170
+ is_lora = hasattr(self, "processor") and isinstance(
171
+ self.processor,
172
+ LORA_ATTENTION_PROCESSORS,
173
+ )
174
+ is_custom_diffusion = hasattr(self, "processor") and isinstance(
175
+ self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
176
+ )
177
+ is_added_kv_processor = hasattr(self, "processor") and isinstance(
178
+ self.processor,
179
+ (
180
+ AttnAddedKVProcessor,
181
+ AttnAddedKVProcessor2_0,
182
+ SlicedAttnAddedKVProcessor,
183
+ XFormersAttnAddedKVProcessor,
184
+ LoRAAttnAddedKVProcessor,
185
+ ),
186
+ )
187
+
188
+ if use_memory_efficient_attention_xformers:
189
+ if is_added_kv_processor and (is_lora or is_custom_diffusion):
190
+ raise NotImplementedError(
191
+ f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
192
+ )
193
+ if not is_xformers_available():
194
+ raise ModuleNotFoundError(
195
+ (
196
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
197
+ " xformers"
198
+ ),
199
+ name="xformers",
200
+ )
201
+ elif not torch.cuda.is_available():
202
+ raise ValueError(
203
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
204
+ " only available for GPU "
205
+ )
206
+ else:
207
+ try:
208
+ # Make sure we can run the memory efficient attention
209
+ _ = xformers.ops.memory_efficient_attention(
210
+ torch.randn((1, 2, 40), device="cuda"),
211
+ torch.randn((1, 2, 40), device="cuda"),
212
+ torch.randn((1, 2, 40), device="cuda"),
213
+ )
214
+ except Exception as e:
215
+ raise e
216
+
217
+ if is_lora:
218
+ # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
219
+ # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
220
+ processor = LoRAXFormersAttnProcessor(
221
+ hidden_size=self.processor.hidden_size,
222
+ cross_attention_dim=self.processor.cross_attention_dim,
223
+ rank=self.processor.rank,
224
+ attention_op=attention_op,
225
+ )
226
+ processor.load_state_dict(self.processor.state_dict())
227
+ processor.to(self.processor.to_q_lora.up.weight.device)
228
+ elif is_custom_diffusion:
229
+ processor = CustomDiffusionXFormersAttnProcessor(
230
+ train_kv=self.processor.train_kv,
231
+ train_q_out=self.processor.train_q_out,
232
+ hidden_size=self.processor.hidden_size,
233
+ cross_attention_dim=self.processor.cross_attention_dim,
234
+ attention_op=attention_op,
235
+ )
236
+ processor.load_state_dict(self.processor.state_dict())
237
+ if hasattr(self.processor, "to_k_custom_diffusion"):
238
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
239
+ elif is_added_kv_processor:
240
+ # throw warning
241
+ logger.info(
242
+ "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
243
+ )
244
+ processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
245
+ else:
246
+ processor = XFormersAttnProcessor(attention_op=attention_op)
247
+ else:
248
+ if is_lora:
249
+ attn_processor_class = (
250
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
251
+ )
252
+ processor = attn_processor_class(
253
+ hidden_size=self.processor.hidden_size,
254
+ cross_attention_dim=self.processor.cross_attention_dim,
255
+ rank=self.processor.rank,
256
+ )
257
+ processor.load_state_dict(self.processor.state_dict())
258
+ processor.to(self.processor.to_q_lora.up.weight.device)
259
+ elif is_custom_diffusion:
260
+ processor = CustomDiffusionAttnProcessor(
261
+ train_kv=self.processor.train_kv,
262
+ train_q_out=self.processor.train_q_out,
263
+ hidden_size=self.processor.hidden_size,
264
+ cross_attention_dim=self.processor.cross_attention_dim,
265
+ )
266
+ processor.load_state_dict(self.processor.state_dict())
267
+ if hasattr(self.processor, "to_k_custom_diffusion"):
268
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
269
+ else:
270
+ processor = (
271
+ AttnProcessor2_0()
272
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
273
+ else AttnProcessor()
274
+ )
275
+
276
+ self.set_processor(processor)
277
+
278
+ def set_attention_slice(self, slice_size):
279
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
280
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
281
+
282
+ if slice_size is not None and self.added_kv_proj_dim is not None:
283
+ processor = SlicedAttnAddedKVProcessor(slice_size)
284
+ elif slice_size is not None:
285
+ processor = SlicedAttnProcessor(slice_size)
286
+ elif self.added_kv_proj_dim is not None:
287
+ processor = AttnAddedKVProcessor()
288
+ else:
289
+ processor = (
290
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
291
+ )
292
+
293
+ self.set_processor(processor)
294
+
295
+ def set_processor(self, processor: "AttnProcessor"):
296
+ if (
297
+ hasattr(self, "processor")
298
+ and isinstance(self.processor, torch.nn.Module)
299
+ and not isinstance(processor, torch.nn.Module)
300
+ ):
301
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
302
+ self._modules.pop("processor")
303
+
304
+ self.processor = processor
305
+
306
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, weight=None, **cross_attention_kwargs):
+ # Only forward the Tiger-specific `weight` when it is provided, so that processors
+ # whose __call__ does not accept it (e.g. AttnProcessor2_0 or the LoRA/xFormers
+ # processors) still work with the default weight=None.
+ if weight is not None:
+ cross_attention_kwargs["weight"] = weight
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
315
+
316
+ def batch_to_head_dim(self, tensor):
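+ # Inverse of head_to_batch_dim: (batch*heads, seq_len, head_dim) -> (batch, seq_len, heads*head_dim).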
317
+ head_size = self.heads
318
+ batch_size, seq_len, dim = tensor.shape
319
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
320
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
321
+ return tensor
322
+
323
+ def head_to_batch_dim(self, tensor, out_dim=3):
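+ # Splits the channel dim into heads: with out_dim=3 (default) the result is
+ # (batch*heads, seq_len, head_dim); with out_dim=4 it is (batch, heads, seq_len, head_dim).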
324
+ head_size = self.heads
325
+ batch_size, seq_len, dim = tensor.shape
326
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
327
+ tensor = tensor.permute(0, 2, 1, 3)
328
+
329
+ if out_dim == 3:
330
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
331
+
332
+ return tensor
333
+
334
+ def get_attention_scores(self, query, key, weight=None, attention_mask=None):
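+ # Computes softmax(query @ key^T * scale + mask) with a single fused baddbmm call
+ # (beta=0 discards the uninitialised bias buffer when no mask is given, beta=1 adds
+ # the mask as an additive bias). `weight` is accepted for compatibility with the
+ # Tiger processors but is not used in the score computation itself.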
335
+ dtype = query.dtype
336
+ if self.upcast_attention:
337
+ query = query.float()
338
+ key = key.float()
339
+ if attention_mask is None:
340
+ baddbmm_input = torch.empty(
341
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
342
+ )
343
+ beta = 0
344
+ else:
345
+ baddbmm_input = attention_mask
346
+ beta = 1
347
+
348
+ attention_scores = torch.baddbmm(
349
+ baddbmm_input,
350
+ query,
351
+ key.transpose(-1, -2),
352
+ beta=beta,
353
+ alpha=self.scale,
354
+ )
355
+
356
+ del baddbmm_input
357
+
358
+ if self.upcast_softmax:
359
+ attention_scores = attention_scores.float()
360
+ attention_probs = attention_scores.softmax(dim=-1)
361
+ del attention_scores
362
+
363
+ attention_probs = attention_probs.to(dtype)
364
+
365
+ return attention_probs
366
+
367
+ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
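+ # Pads the mask along its last dimension when its length differs from `target_length`
+ # and repeats it across attention heads (dim 0 for the 3-d layout, an extra dim 1 for
+ # the 4-d layout) so it can be added to the attention scores as a bias.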
368
+ if batch_size is None:
369
+ deprecate(
370
+ "batch_size=None",
371
+ "0.0.15",
372
+ (
373
+ "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
374
+ " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
375
+ " `prepare_attention_mask` when preparing the attention_mask."
376
+ ),
377
+ )
378
+ batch_size = 1
379
+ head_size = self.heads
380
+ if attention_mask is None:
381
+ return attention_mask
382
+
383
+ current_length: int = attention_mask.shape[-1]
384
+ if current_length != target_length:
385
+ if attention_mask.device.type == "mps":
386
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
387
+ # Instead, we can manually construct the padding tensor.
388
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
389
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
390
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
391
+ else:
392
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
393
+
394
+ if out_dim == 3:
395
+ if attention_mask.shape[0] < batch_size * head_size:
396
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
397
+ elif out_dim == 4:
398
+ attention_mask = attention_mask.unsqueeze(1)
399
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
400
+ return attention_mask
401
+
402
+ def norm_encoder_hidden_states(self, encoder_hidden_states):
403
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
404
+
405
+ if isinstance(self.norm_cross, nn.LayerNorm):
406
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
407
+ elif isinstance(self.norm_cross, nn.GroupNorm):
408
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
409
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
410
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
411
+ else:
412
+ assert False
413
+
414
+ return encoder_hidden_states
415
+
416
+ class AttnProcessor:
417
+ r"""
418
+ Default processor for performing attention-related computations.
419
+ """
420
+
421
+ def __call__(
422
+ self,
423
+ attn: Attention,
424
+ hidden_states,
425
+ encoder_hidden_states=None,
426
+ attention_mask=None,
427
+ temb=None,
428
+ weight=None):
429
+ residual = hidden_states
430
+
431
+ if attn.spatial_norm is not None:
432
+ hidden_states = attn.spatial_norm(hidden_states, temb)
433
+
434
+ input_ndim = hidden_states.ndim
435
+
436
+ if input_ndim == 4:
437
+ batch_size, channel, height, width = hidden_states.shape
438
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
439
+
440
+ batch_size, sequence_length, _ = (
441
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
442
+ )
443
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
444
+ if attn.group_norm is not None:
445
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
446
+
447
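+ # Tiger-specific step: `weight` (apparently one value per sample, given the two
+ # unsqueeze calls) is broadcast to (batch, 1, 1) and rescales the hidden states before
+ # the query projection (and before key/value as well in the self-attention case,
+ # where encoder_hidden_states falls back to hidden_states).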
+ if weight is not None:
448
+ multiplier = weight.unsqueeze(1).unsqueeze(2)
449
+ hidden_states = hidden_states * multiplier
450
+ query = attn.to_q(hidden_states)
451
+
452
+
453
+ if encoder_hidden_states is None:
454
+ encoder_hidden_states = hidden_states
455
+ elif attn.norm_cross:
456
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
457
+ key = attn.to_k(encoder_hidden_states)
458
+ value = attn.to_v(encoder_hidden_states)
459
+
460
+ query = attn.head_to_batch_dim(query)
461
+ key = attn.head_to_batch_dim(key)
462
+ value = attn.head_to_batch_dim(value)
463
+
464
+ attention_probs = attn.get_attention_scores(query, key, weight, attention_mask)
465
+ hidden_states = torch.bmm(attention_probs, value)
466
+ hidden_states = attn.batch_to_head_dim(hidden_states)
467
+
468
+ # linear proj
469
+ hidden_states = attn.to_out[0](hidden_states)
470
+ # dropout
471
+ hidden_states = attn.to_out[1](hidden_states)
472
+
473
+ if input_ndim == 4:
474
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
475
+
476
+ if attn.residual_connection:
477
+ hidden_states = hidden_states + residual
478
+
479
+ hidden_states = hidden_states / attn.rescale_output_factor
480
+
481
+ return hidden_states
482
+
483
+ class Guid_AttnProcessor:
484
+ r"""
485
+ Variant of the default attention processor (`AttnProcessor`) that performs the same computation but without the Tiger-specific `weight` rescaling of the hidden states.
486
+ """
487
+
488
+ def __call__(
489
+ self,
490
+ attn: Attention,
491
+ hidden_states,
492
+ encoder_hidden_states=None,
493
+ attention_mask=None,
494
+ temb=None,
495
+ ):
496
+ residual = hidden_states
497
+
498
+ if attn.spatial_norm is not None:
499
+ hidden_states = attn.spatial_norm(hidden_states, temb)
500
+
501
+ input_ndim = hidden_states.ndim
502
+
503
+ if input_ndim == 4:
504
+ batch_size, channel, height, width = hidden_states.shape
505
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
506
+
507
+ batch_size, sequence_length, _ = (
508
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
509
+ )
510
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
511
+
512
+ if attn.group_norm is not None:
513
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
514
+
515
+ query = attn.to_q(hidden_states)
516
+
517
+ if encoder_hidden_states is None:
518
+ encoder_hidden_states = hidden_states
519
+ elif attn.norm_cross:
520
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
521
+
522
+ key = attn.to_k(encoder_hidden_states)
523
+ value = attn.to_v(encoder_hidden_states)
524
+
525
+ query = attn.head_to_batch_dim(query)
526
+ key = attn.head_to_batch_dim(key)
527
+ value = attn.head_to_batch_dim(value)
528
+
529
+ attention_probs = attn.get_attention_scores(query, key, attention_mask=attention_mask)
530
+ hidden_states = torch.bmm(attention_probs, value)
531
+ hidden_states = attn.batch_to_head_dim(hidden_states)
532
+
533
+ # linear proj
534
+ hidden_states = attn.to_out[0](hidden_states)
535
+ # dropout
536
+ hidden_states = attn.to_out[1](hidden_states)
537
+
538
+ if input_ndim == 4:
539
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
540
+
541
+ if attn.residual_connection:
542
+ hidden_states = hidden_states + residual
543
+
544
+ hidden_states = hidden_states / attn.rescale_output_factor
545
+
546
+ return hidden_states
547
+
548
+ class LoRAAttnProcessor(nn.Module):
549
+ r"""
550
+ Processor for implementing the LoRA attention mechanism.
551
+
552
+ Args:
553
+ hidden_size (`int`, *optional*):
554
+ The hidden size of the attention layer.
555
+ cross_attention_dim (`int`, *optional*):
556
+ The number of channels in the `encoder_hidden_states`.
557
+ rank (`int`, defaults to 4):
558
+ The dimension of the LoRA update matrices.
559
+ network_alpha (`int`, *optional*):
560
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
561
+ """
562
+
563
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs):
564
+ super().__init__()
565
+
566
+ self.hidden_size = hidden_size
567
+ self.cross_attention_dim = cross_attention_dim
568
+ self.rank = rank
569
+
570
+ q_rank = kwargs.pop("q_rank", None)
571
+ q_hidden_size = kwargs.pop("q_hidden_size", None)
572
+ q_rank = q_rank if q_rank is not None else rank
573
+ q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
574
+
575
+ v_rank = kwargs.pop("v_rank", None)
576
+ v_hidden_size = kwargs.pop("v_hidden_size", None)
577
+ v_rank = v_rank if v_rank is not None else rank
578
+ v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
579
+
580
+ out_rank = kwargs.pop("out_rank", None)
581
+ out_hidden_size = kwargs.pop("out_hidden_size", None)
582
+ out_rank = out_rank if out_rank is not None else rank
583
+ out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
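+ # The optional q/v/out overrides allow per-projection LoRA ranks and widths; when not
+ # supplied they fall back to the shared `rank` and `hidden_size`.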
584
+
585
+ self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
586
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
587
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
588
+ self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
589
+
590
+ def __call__(
591
+ self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
592
+ ):
593
+ residual = hidden_states
594
+
595
+ if attn.spatial_norm is not None:
596
+ hidden_states = attn.spatial_norm(hidden_states, temb)
597
+
598
+ input_ndim = hidden_states.ndim
599
+
600
+ if input_ndim == 4:
601
+ batch_size, channel, height, width = hidden_states.shape
602
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
603
+
604
+ batch_size, sequence_length, _ = (
605
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
606
+ )
607
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
608
+
609
+ if attn.group_norm is not None:
610
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
611
+
612
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
613
+ query = attn.head_to_batch_dim(query)
614
+
615
+ if encoder_hidden_states is None:
616
+ encoder_hidden_states = hidden_states
617
+ elif attn.norm_cross:
618
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
619
+
620
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
621
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
622
+
623
+ key = attn.head_to_batch_dim(key)
624
+ value = attn.head_to_batch_dim(value)
625
+
626
+ attention_probs = attn.get_attention_scores(query, key, attention_mask=attention_mask)
627
+ hidden_states = torch.bmm(attention_probs, value)
628
+ hidden_states = attn.batch_to_head_dim(hidden_states)
629
+
630
+ # linear proj
631
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
632
+ # dropout
633
+ hidden_states = attn.to_out[1](hidden_states)
634
+
635
+ if input_ndim == 4:
636
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
637
+
638
+ if attn.residual_connection:
639
+ hidden_states = hidden_states + residual
640
+
641
+ hidden_states = hidden_states / attn.rescale_output_factor
642
+
643
+ return hidden_states
644
+
645
+ class CustomDiffusionAttnProcessor(nn.Module):
646
+ r"""
647
+ Processor for implementing attention for the Custom Diffusion method.
648
+
649
+ Args:
650
+ train_kv (`bool`, defaults to `True`):
651
+ Whether to newly train the key and value matrices corresponding to the text features.
652
+ train_q_out (`bool`, defaults to `True`):
653
+ Whether to newly train query matrices corresponding to the latent image features.
654
+ hidden_size (`int`, *optional*, defaults to `None`):
655
+ The hidden size of the attention layer.
656
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
657
+ The number of channels in the `encoder_hidden_states`.
658
+ out_bias (`bool`, defaults to `True`):
659
+ Whether to include the bias parameter in `train_q_out`.
660
+ dropout (`float`, *optional*, defaults to 0.0):
661
+ The dropout probability to use.
662
+ """
663
+
664
+ def __init__(
665
+ self,
666
+ train_kv=True,
667
+ train_q_out=True,
668
+ hidden_size=None,
669
+ cross_attention_dim=None,
670
+ out_bias=True,
671
+ dropout=0.0,
672
+ ):
673
+ super().__init__()
674
+ self.train_kv = train_kv
675
+ self.train_q_out = train_q_out
676
+
677
+ self.hidden_size = hidden_size
678
+ self.cross_attention_dim = cross_attention_dim
679
+
680
+ # `_custom_diffusion` id for easy serialization and loading.
681
+ if self.train_kv:
682
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
683
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
684
+ if self.train_q_out:
685
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
686
+ self.to_out_custom_diffusion = nn.ModuleList([])
687
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
688
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
689
+
690
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
691
+ batch_size, sequence_length, _ = hidden_states.shape
692
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
693
+ if self.train_q_out:
694
+ query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
695
+ else:
696
+ query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
697
+
698
+ if encoder_hidden_states is None:
699
+ crossattn = False
700
+ encoder_hidden_states = hidden_states
701
+ else:
702
+ crossattn = True
703
+ if attn.norm_cross:
704
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
705
+
706
+ if self.train_kv:
707
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
708
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
709
+ key = key.to(attn.to_q.weight.dtype)
710
+ value = value.to(attn.to_q.weight.dtype)
711
+ else:
712
+ key = attn.to_k(encoder_hidden_states)
713
+ value = attn.to_v(encoder_hidden_states)
714
+
715
+ if crossattn:
716
+ detach = torch.ones_like(key)
717
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
718
+ key = detach * key + (1 - detach) * key.detach()
719
+ value = detach * value + (1 - detach) * value.detach()
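+ # Gradient masking: the first text token's key/value is replaced by its detached copy
+ # (detach == 0 there), so it still contributes to the forward pass but receives no
+ # gradient through this path; all remaining tokens keep their gradients.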
720
+
721
+ query = attn.head_to_batch_dim(query)
722
+ key = attn.head_to_batch_dim(key)
723
+ value = attn.head_to_batch_dim(value)
724
+
725
+ attention_probs = attn.get_attention_scores(query, key, attention_mask=attention_mask)
726
+ hidden_states = torch.bmm(attention_probs, value)
727
+ hidden_states = attn.batch_to_head_dim(hidden_states)
728
+
729
+ if self.train_q_out:
730
+ # linear proj
731
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
732
+ # dropout
733
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
734
+ else:
735
+ # linear proj
736
+ hidden_states = attn.to_out[0](hidden_states)
737
+ # dropout
738
+ hidden_states = attn.to_out[1](hidden_states)
739
+
740
+ return hidden_states
741
+
742
+ class AttnAddedKVProcessor:
743
+ r"""
744
+ Processor for performing attention-related computations with extra learnable key and value matrices for the text
745
+ encoder.
746
+ """
747
+
748
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
749
+ residual = hidden_states
750
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
751
+ batch_size, sequence_length, _ = hidden_states.shape
752
+
753
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
754
+
755
+ if encoder_hidden_states is None:
756
+ encoder_hidden_states = hidden_states
757
+ elif attn.norm_cross:
758
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
759
+
760
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
761
+
762
+ query = attn.to_q(hidden_states)
763
+ query = attn.head_to_batch_dim(query)
764
+
765
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
766
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
767
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
768
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
769
+
770
+ if not attn.only_cross_attention:
771
+ key = attn.to_k(hidden_states)
772
+ value = attn.to_v(hidden_states)
773
+ key = attn.head_to_batch_dim(key)
774
+ value = attn.head_to_batch_dim(value)
775
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
776
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
777
+ else:
778
+ key = encoder_hidden_states_key_proj
779
+ value = encoder_hidden_states_value_proj
780
+
781
+ attention_probs = attn.get_attention_scores(query, key, attention_mask=attention_mask)
782
+ hidden_states = torch.bmm(attention_probs, value)
783
+ hidden_states = attn.batch_to_head_dim(hidden_states)
784
+
785
+ # linear proj
786
+ hidden_states = attn.to_out[0](hidden_states)
787
+ # dropout
788
+ hidden_states = attn.to_out[1](hidden_states)
789
+
790
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
791
+ hidden_states = hidden_states + residual
792
+
793
+ return hidden_states
794
+
795
+ class AttnAddedKVProcessor2_0:
796
+ r"""
797
+ Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
798
+ learnable key and value matrices for the text encoder.
799
+ """
800
+
801
+ def __init__(self):
802
+ if not hasattr(F, "scaled_dot_product_attention"):
803
+ raise ImportError(
804
+ "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
805
+ )
806
+
807
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
808
+ residual = hidden_states
809
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
810
+ batch_size, sequence_length, _ = hidden_states.shape
811
+
812
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
813
+
814
+ if encoder_hidden_states is None:
815
+ encoder_hidden_states = hidden_states
816
+ elif attn.norm_cross:
817
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
818
+
819
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
820
+
821
+ query = attn.to_q(hidden_states)
822
+ query = attn.head_to_batch_dim(query, out_dim=4)
823
+
824
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
825
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
826
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
827
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
828
+
829
+ if not attn.only_cross_attention:
830
+ key = attn.to_k(hidden_states)
831
+ value = attn.to_v(hidden_states)
832
+ key = attn.head_to_batch_dim(key, out_dim=4)
833
+ value = attn.head_to_batch_dim(value, out_dim=4)
834
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
835
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
836
+ else:
837
+ key = encoder_hidden_states_key_proj
838
+ value = encoder_hidden_states_value_proj
839
+
840
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
841
+ # TODO: add support for attn.scale when we move to Torch 2.1
842
+ hidden_states = F.scaled_dot_product_attention(
843
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
844
+ )
845
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
846
+
847
+ # linear proj
848
+ hidden_states = attn.to_out[0](hidden_states)
849
+ # dropout
850
+ hidden_states = attn.to_out[1](hidden_states)
851
+
852
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
853
+ hidden_states = hidden_states + residual
854
+
855
+ return hidden_states
856
+
857
+ class LoRAAttnAddedKVProcessor(nn.Module):
858
+ r"""
859
+ Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
860
+ encoder.
861
+
862
+ Args:
863
+ hidden_size (`int`, *optional*):
864
+ The hidden size of the attention layer.
865
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
866
+ The number of channels in the `encoder_hidden_states`.
867
+ rank (`int`, defaults to 4):
868
+ The dimension of the LoRA update matrices.
869
+
870
+ """
871
+
872
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
873
+ super().__init__()
874
+
875
+ self.hidden_size = hidden_size
876
+ self.cross_attention_dim = cross_attention_dim
877
+ self.rank = rank
878
+
879
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
880
+ self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
881
+ self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
882
+ self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
883
+ self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
884
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
885
+
886
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
887
+ residual = hidden_states
888
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
889
+ batch_size, sequence_length, _ = hidden_states.shape
890
+
891
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
892
+
893
+ if encoder_hidden_states is None:
894
+ encoder_hidden_states = hidden_states
895
+ elif attn.norm_cross:
896
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
897
+
898
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
899
+
900
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
901
+ query = attn.head_to_batch_dim(query)
902
+
903
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora(
904
+ encoder_hidden_states
905
+ )
906
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora(
907
+ encoder_hidden_states
908
+ )
909
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
910
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
911
+
912
+ if not attn.only_cross_attention:
913
+ key = attn.to_k(hidden_states) + scale * self.to_k_lora(hidden_states)
914
+ value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states)
915
+ key = attn.head_to_batch_dim(key)
916
+ value = attn.head_to_batch_dim(value)
917
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
918
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
919
+ else:
920
+ key = encoder_hidden_states_key_proj
921
+ value = encoder_hidden_states_value_proj
922
+
923
+ attention_probs = attn.get_attention_scores(query, key, attention_mask=attention_mask)
924
+ hidden_states = torch.bmm(attention_probs, value)
925
+ hidden_states = attn.batch_to_head_dim(hidden_states)
926
+
927
+ # linear proj
928
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
929
+ # dropout
930
+ hidden_states = attn.to_out[1](hidden_states)
931
+
932
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
933
+ hidden_states = hidden_states + residual
934
+
935
+ return hidden_states
936
+
937
+
938
+ class XFormersAttnAddedKVProcessor:
939
+ r"""
940
+ Processor for implementing memory efficient attention using xFormers.
941
+
942
+ Args:
943
+ attention_op (`Callable`, *optional*, defaults to `None`):
944
+ The base
945
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
946
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
947
+ operator.
948
+ """
949
+
950
+ def __init__(self, attention_op: Optional[Callable] = None):
951
+ self.attention_op = attention_op
952
+
953
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
954
+ residual = hidden_states
955
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
956
+ batch_size, sequence_length, _ = hidden_states.shape
957
+
958
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
959
+
960
+ if encoder_hidden_states is None:
961
+ encoder_hidden_states = hidden_states
962
+ elif attn.norm_cross:
963
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
964
+
965
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
966
+
967
+ query = attn.to_q(hidden_states)
968
+ query = attn.head_to_batch_dim(query)
969
+
970
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
971
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
972
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
973
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
974
+
975
+ if not attn.only_cross_attention:
976
+ key = attn.to_k(hidden_states)
977
+ value = attn.to_v(hidden_states)
978
+ key = attn.head_to_batch_dim(key)
979
+ value = attn.head_to_batch_dim(value)
980
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
981
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
982
+ else:
983
+ key = encoder_hidden_states_key_proj
984
+ value = encoder_hidden_states_value_proj
985
+
986
+ hidden_states = xformers.ops.memory_efficient_attention(
987
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
988
+ )
989
+ hidden_states = hidden_states.to(query.dtype)
990
+ hidden_states = attn.batch_to_head_dim(hidden_states)
991
+
992
+ # linear proj
993
+ hidden_states = attn.to_out[0](hidden_states)
994
+ # dropout
995
+ hidden_states = attn.to_out[1](hidden_states)
996
+
997
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
998
+ hidden_states = hidden_states + residual
999
+
1000
+ return hidden_states
1001
+
1002
+
1003
+ class XFormersAttnProcessor:
1004
+ r"""
1005
+ Processor for implementing memory efficient attention using xFormers.
1006
+
1007
+ Args:
1008
+ attention_op (`Callable`, *optional*, defaults to `None`):
1009
+ The base
1010
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1011
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
1012
+ operator.
1013
+ """
1014
+
1015
+ def __init__(self, attention_op: Optional[Callable] = None):
1016
+ self.attention_op = attention_op
1017
+
1018
+ def __call__(
1019
+ self,
1020
+ attn: Attention,
1021
+ hidden_states: torch.FloatTensor,
1022
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1023
+ attention_mask: Optional[torch.FloatTensor] = None,
1024
+ temb: Optional[torch.FloatTensor] = None,
1025
+ ):
1026
+ residual = hidden_states
1027
+
1028
+ if attn.spatial_norm is not None:
1029
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1030
+
1031
+ input_ndim = hidden_states.ndim
1032
+
1033
+ if input_ndim == 4:
1034
+ batch_size, channel, height, width = hidden_states.shape
1035
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1036
+
1037
+ batch_size, key_tokens, _ = (
1038
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1039
+ )
1040
+
1041
+ attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
1042
+ if attention_mask is not None:
1043
+ # expand our mask's singleton query_tokens dimension:
1044
+ # [batch*heads, 1, key_tokens] ->
1045
+ # [batch*heads, query_tokens, key_tokens]
1046
+ # so that it can be added as a bias onto the attention scores that xformers computes:
1047
+ # [batch*heads, query_tokens, key_tokens]
1048
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
1049
+ _, query_tokens, _ = hidden_states.shape
1050
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
1051
+
1052
+ if attn.group_norm is not None:
1053
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1054
+
1055
+ query = attn.to_q(hidden_states)
1056
+
1057
+ if encoder_hidden_states is None:
1058
+ encoder_hidden_states = hidden_states
1059
+ elif attn.norm_cross:
1060
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1061
+
1062
+ key = attn.to_k(encoder_hidden_states)
1063
+ value = attn.to_v(encoder_hidden_states)
1064
+
1065
+ query = attn.head_to_batch_dim(query).contiguous()
1066
+ key = attn.head_to_batch_dim(key).contiguous()
1067
+ value = attn.head_to_batch_dim(value).contiguous()
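+ # xformers expects contiguous (batch*heads, seq_len, head_dim) tensors; the mask, if
+ # any, is passed as an additive attn_bias.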
1068
+
1069
+ hidden_states = xformers.ops.memory_efficient_attention(
1070
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1071
+ )
1072
+ hidden_states = hidden_states.to(query.dtype)
1073
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1074
+
1075
+ # linear proj
1076
+ hidden_states = attn.to_out[0](hidden_states)
1077
+ # dropout
1078
+ hidden_states = attn.to_out[1](hidden_states)
1079
+
1080
+ if input_ndim == 4:
1081
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1082
+
1083
+ if attn.residual_connection:
1084
+ hidden_states = hidden_states + residual
1085
+
1086
+ hidden_states = hidden_states / attn.rescale_output_factor
1087
+
1088
+ return hidden_states
1089
+
1090
+
1091
+ class AttnProcessor2_0:
1092
+ r"""
1093
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
1094
+ """
1095
+
1096
+ def __init__(self):
1097
+ if not hasattr(F, "scaled_dot_product_attention"):
1098
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1099
+
1100
+ def __call__(
1101
+ self,
1102
+ attn: Attention,
1103
+ hidden_states,
1104
+ encoder_hidden_states=None,
1105
+ attention_mask=None,
1106
+ temb=None,
1107
+ ):
1108
+ residual = hidden_states
1109
+
1110
+ if attn.spatial_norm is not None:
1111
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1112
+
1113
+ input_ndim = hidden_states.ndim
1114
+
1115
+ if input_ndim == 4:
1116
+ batch_size, channel, height, width = hidden_states.shape
1117
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1118
+
1119
+ batch_size, sequence_length, _ = (
1120
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1121
+ )
1122
+
1123
+ if attention_mask is not None:
1124
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1125
+ # scaled_dot_product_attention expects attention_mask shape to be
1126
+ # (batch, heads, source_length, target_length)
1127
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1128
+
1129
+ if attn.group_norm is not None:
1130
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1131
+
1132
+ query = attn.to_q(hidden_states)
1133
+
1134
+ if encoder_hidden_states is None:
1135
+ encoder_hidden_states = hidden_states
1136
+ elif attn.norm_cross:
1137
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1138
+
1139
+ key = attn.to_k(encoder_hidden_states)
1140
+ value = attn.to_v(encoder_hidden_states)
1141
+
1142
+ inner_dim = key.shape[-1]
1143
+ head_dim = inner_dim // attn.heads
1144
+
1145
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1146
+
1147
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1148
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1149
+
1150
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1151
+ # TODO: add support for attn.scale when we move to Torch 2.1
1152
+ hidden_states = F.scaled_dot_product_attention(
1153
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1154
+ )
1155
+
1156
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1157
+ hidden_states = hidden_states.to(query.dtype)
1158
+
1159
+ # linear proj
1160
+ hidden_states = attn.to_out[0](hidden_states)
1161
+ # dropout
1162
+ hidden_states = attn.to_out[1](hidden_states)
1163
+
1164
+ if input_ndim == 4:
1165
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1166
+
1167
+ if attn.residual_connection:
1168
+ hidden_states = hidden_states + residual
1169
+
1170
+ hidden_states = hidden_states / attn.rescale_output_factor
1171
+
1172
+ return hidden_states
1173
+
1174
+
1175
+ class LoRAXFormersAttnProcessor(nn.Module):
1176
+ r"""
1177
+ Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
1178
+
1179
+ Args:
1180
+ hidden_size (`int`, *optional*):
1181
+ The hidden size of the attention layer.
1182
+ cross_attention_dim (`int`, *optional*):
1183
+ The number of channels in the `encoder_hidden_states`.
1184
+ rank (`int`, defaults to 4):
1185
+ The dimension of the LoRA update matrices.
1186
+ attention_op (`Callable`, *optional*, defaults to `None`):
1187
+ The base
1188
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1189
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
1190
+ operator.
1191
+ network_alpha (`int`, *optional*):
1192
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
1193
+
1194
+ """
1195
+
1196
+ def __init__(
1197
+ self,
1198
+ hidden_size,
1199
+ cross_attention_dim,
1200
+ rank=4,
1201
+ attention_op: Optional[Callable] = None,
1202
+ network_alpha=None,
1203
+ **kwargs,
1204
+ ):
1205
+ super().__init__()
1206
+
1207
+ self.hidden_size = hidden_size
1208
+ self.cross_attention_dim = cross_attention_dim
1209
+ self.rank = rank
1210
+ self.attention_op = attention_op
1211
+
1212
+ q_rank = kwargs.pop("q_rank", None)
1213
+ q_hidden_size = kwargs.pop("q_hidden_size", None)
1214
+ q_rank = q_rank if q_rank is not None else rank
1215
+ q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
1216
+
1217
+ v_rank = kwargs.pop("v_rank", None)
1218
+ v_hidden_size = kwargs.pop("v_hidden_size", None)
1219
+ v_rank = v_rank if v_rank is not None else rank
1220
+ v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
1221
+
1222
+ out_rank = kwargs.pop("out_rank", None)
1223
+ out_hidden_size = kwargs.pop("out_hidden_size", None)
1224
+ out_rank = out_rank if out_rank is not None else rank
1225
+ out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
1226
+
1227
+ self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
1228
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1229
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
1230
+ self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
1231
+
1232
+ def __call__(
1233
+ self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
1234
+ ):
1235
+ residual = hidden_states
1236
+
1237
+ if attn.spatial_norm is not None:
1238
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1239
+
1240
+ input_ndim = hidden_states.ndim
1241
+
1242
+ if input_ndim == 4:
1243
+ batch_size, channel, height, width = hidden_states.shape
1244
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1245
+
1246
+ batch_size, sequence_length, _ = (
1247
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1248
+ )
1249
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1250
+
1251
+ if attn.group_norm is not None:
1252
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1253
+
1254
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
1255
+ query = attn.head_to_batch_dim(query).contiguous()
1256
+
1257
+ if encoder_hidden_states is None:
1258
+ encoder_hidden_states = hidden_states
1259
+ elif attn.norm_cross:
1260
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1261
+
1262
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
1263
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
1264
+
1265
+ key = attn.head_to_batch_dim(key).contiguous()
1266
+ value = attn.head_to_batch_dim(value).contiguous()
1267
+
1268
+ hidden_states = xformers.ops.memory_efficient_attention(
1269
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1270
+ )
1271
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1272
+
1273
+ # linear proj
1274
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
1275
+ # dropout
1276
+ hidden_states = attn.to_out[1](hidden_states)
1277
+
1278
+ if input_ndim == 4:
1279
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1280
+
1281
+ if attn.residual_connection:
1282
+ hidden_states = hidden_states + residual
1283
+
1284
+ hidden_states = hidden_states / attn.rescale_output_factor
1285
+
1286
+ return hidden_states
1287
+
1288
+
1289
+ class LoRAAttnProcessor2_0(nn.Module):
1290
+ r"""
1291
+ Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
1292
+ attention.
1293
+
1294
+ Args:
1295
+ hidden_size (`int`):
1296
+ The hidden size of the attention layer.
1297
+ cross_attention_dim (`int`, *optional*):
1298
+ The number of channels in the `encoder_hidden_states`.
1299
+ rank (`int`, defaults to 4):
1300
+ The dimension of the LoRA update matrices.
1301
+ network_alpha (`int`, *optional*):
1302
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
1303
+ """
1304
+
1305
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs):
1306
+ super().__init__()
1307
+ if not hasattr(F, "scaled_dot_product_attention"):
1308
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1309
+
1310
+ self.hidden_size = hidden_size
1311
+ self.cross_attention_dim = cross_attention_dim
1312
+ self.rank = rank
1313
+
1314
+ q_rank = kwargs.pop("q_rank", None)
1315
+ q_hidden_size = kwargs.pop("q_hidden_size", None)
1316
+ q_rank = q_rank if q_rank is not None else rank
1317
+ q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
1318
+
1319
+ v_rank = kwargs.pop("v_rank", None)
1320
+ v_hidden_size = kwargs.pop("v_hidden_size", None)
1321
+ v_rank = v_rank if v_rank is not None else rank
1322
+ v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
1323
+
1324
+ out_rank = kwargs.pop("out_rank", None)
1325
+ out_hidden_size = kwargs.pop("out_hidden_size", None)
1326
+ out_rank = out_rank if out_rank is not None else rank
1327
+ out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
1328
+
1329
+ self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
1330
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1331
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
1332
+ self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
1333
+
1334
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
1335
+ residual = hidden_states
1336
+
1337
+ input_ndim = hidden_states.ndim
1338
+
1339
+ if input_ndim == 4:
1340
+ batch_size, channel, height, width = hidden_states.shape
1341
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1342
+
1343
+ batch_size, sequence_length, _ = (
1344
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1345
+ )
1346
+ inner_dim = hidden_states.shape[-1]
1347
+
1348
+ if attention_mask is not None:
1349
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1350
+ # scaled_dot_product_attention expects attention_mask shape to be
1351
+ # (batch, heads, source_length, target_length)
1352
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1353
+
1354
+ if attn.group_norm is not None:
1355
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1356
+
1357
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
1358
+
1359
+ if encoder_hidden_states is None:
1360
+ encoder_hidden_states = hidden_states
1361
+ elif attn.norm_cross:
1362
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1363
+
1364
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
1365
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
1366
+
1367
+ head_dim = inner_dim // attn.heads
1368
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1369
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1370
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1371
+
1372
+ # TODO: add support for attn.scale when we move to Torch 2.1
1373
+ hidden_states = F.scaled_dot_product_attention(
1374
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1375
+ )
1376
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1377
+ hidden_states = hidden_states.to(query.dtype)
1378
+
1379
+ # linear proj
1380
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
1381
+ # dropout
1382
+ hidden_states = attn.to_out[1](hidden_states)
1383
+
1384
+ if input_ndim == 4:
1385
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1386
+
1387
+ if attn.residual_connection:
1388
+ hidden_states = hidden_states + residual
1389
+
1390
+ hidden_states = hidden_states / attn.rescale_output_factor
1391
+
1392
+ return hidden_states
1393
+
1394
+
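# Illustrative sketch of the LoRA update used by LoRAAttnProcessor2_0 above:
# the frozen projection W is combined with a low-rank correction B(A(x)) scaled
# by `scale`. Plain nn.Linear layers stand in for LoRALinearLayer here, and the
# shapes/values are arbitrary assumptions, not the model's real configuration.
import torch
import torch.nn as nn

hidden_size, rank, scale = 320, 4, 1.0
x = torch.randn(2, 77, hidden_size)

to_q = nn.Linear(hidden_size, hidden_size)            # frozen base projection W
lora_down = nn.Linear(hidden_size, rank, bias=False)  # "A": project down to rank r
lora_up = nn.Linear(rank, hidden_size, bias=False)    # "B": project back up
nn.init.zeros_(lora_up.weight)                        # LoRA starts as a no-op

query = to_q(x) + scale * lora_up(lora_down(x))       # mirrors attn.to_q(h) + scale * self.to_q_lora(h)
print(query.shape)                                    # torch.Size([2, 77, 320])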
1395
+ class CustomDiffusionXFormersAttnProcessor(nn.Module):
1396
+ r"""
1397
+ Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
1398
+
1399
+ Args:
1400
+ train_kv (`bool`, defaults to `True`):
1401
+ Whether to newly train the key and value matrices corresponding to the text features.
1402
+ train_q_out (`bool`, defaults to `True`):
1403
+ Whether to newly train query matrices corresponding to the latent image features.
1404
+ hidden_size (`int`, *optional*, defaults to `None`):
1405
+ The hidden size of the attention layer.
1406
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
1407
+ The number of channels in the `encoder_hidden_states`.
1408
+ out_bias (`bool`, defaults to `True`):
1409
+ Whether to include the bias parameter in `train_q_out`.
1410
+ dropout (`float`, *optional*, defaults to 0.0):
1411
+ The dropout probability to use.
1412
+ attention_op (`Callable`, *optional*, defaults to `None`):
1413
+ The base
1414
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
1415
+ as the attention operator. It is recommended to set this to `None` and allow xFormers to choose the best operator.
1416
+ """
1417
+
1418
+ def __init__(
1419
+ self,
1420
+ train_kv=True,
1421
+ train_q_out=False,
1422
+ hidden_size=None,
1423
+ cross_attention_dim=None,
1424
+ out_bias=True,
1425
+ dropout=0.0,
1426
+ attention_op: Optional[Callable] = None,
1427
+ ):
1428
+ super().__init__()
1429
+ self.train_kv = train_kv
1430
+ self.train_q_out = train_q_out
1431
+
1432
+ self.hidden_size = hidden_size
1433
+ self.cross_attention_dim = cross_attention_dim
1434
+ self.attention_op = attention_op
1435
+
1436
+ # `_custom_diffusion` id for easy serialization and loading.
1437
+ if self.train_kv:
1438
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1439
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1440
+ if self.train_q_out:
1441
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
1442
+ self.to_out_custom_diffusion = nn.ModuleList([])
1443
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
1444
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
1445
+
1446
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
1447
+ batch_size, sequence_length, _ = (
1448
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1449
+ )
1450
+
1451
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1452
+
1453
+ if self.train_q_out:
1454
+ query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
1455
+ else:
1456
+ query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
1457
+
1458
+ if encoder_hidden_states is None:
1459
+ crossattn = False
1460
+ encoder_hidden_states = hidden_states
1461
+ else:
1462
+ crossattn = True
1463
+ if attn.norm_cross:
1464
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1465
+
1466
+ if self.train_kv:
1467
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
1468
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
1469
+ key = key.to(attn.to_q.weight.dtype)
1470
+ value = value.to(attn.to_q.weight.dtype)
1471
+ else:
1472
+ key = attn.to_k(encoder_hidden_states)
1473
+ value = attn.to_v(encoder_hidden_states)
1474
+
1475
+ if crossattn:
1476
+ detach = torch.ones_like(key)
1477
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
1478
+ key = detach * key + (1 - detach) * key.detach()
1479
+ value = detach * value + (1 - detach) * value.detach()
1480
+
1481
+ query = attn.head_to_batch_dim(query).contiguous()
1482
+ key = attn.head_to_batch_dim(key).contiguous()
1483
+ value = attn.head_to_batch_dim(value).contiguous()
1484
+
1485
+ hidden_states = xformers.ops.memory_efficient_attention(
1486
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1487
+ )
1488
+ hidden_states = hidden_states.to(query.dtype)
1489
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1490
+
1491
+ if self.train_q_out:
1492
+ # linear proj
1493
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
1494
+ # dropout
1495
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
1496
+ else:
1497
+ # linear proj
1498
+ hidden_states = attn.to_out[0](hidden_states)
1499
+ # dropout
1500
+ hidden_states = attn.to_out[1](hidden_states)
1501
+ return hidden_states
1502
+
1503
+
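# Illustrative sketch of the gradient-masking trick applied to the cross-attention
# keys/values above: the first text-token position is detached, so only the
# remaining positions receive gradient updates. Shapes are arbitrary assumptions.
import torch
import torch.nn as nn

tokens = torch.randn(1, 4, 8, requires_grad=True)  # (batch, tokens, dim)
proj = nn.Linear(8, 8)
k = proj(tokens)

detach = torch.ones_like(k)
detach[:, :1, :] = 0.0                              # mask out the first token
k_masked = detach * k + (1 - detach) * k.detach()

k_masked.sum().backward()
print(tokens.grad[:, 0].abs().sum())                # tensor(0.) - no gradient at token 0
print(tokens.grad[:, 1:].abs().sum())               # non-zero for the remaining tokens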
1504
+ class SlicedAttnProcessor:
1505
+ r"""
1506
+ Processor for implementing sliced attention.
1507
+
1508
+ Args:
1509
+ slice_size (`int`, *optional*):
1510
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
1511
+ `attention_head_dim` must be a multiple of the `slice_size`.
1512
+ """
1513
+
1514
+ def __init__(self, slice_size):
1515
+ self.slice_size = slice_size
1516
+
1517
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
1518
+ residual = hidden_states
1519
+
1520
+ input_ndim = hidden_states.ndim
1521
+
1522
+ if input_ndim == 4:
1523
+ batch_size, channel, height, width = hidden_states.shape
1524
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1525
+
1526
+ batch_size, sequence_length, _ = (
1527
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1528
+ )
1529
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1530
+
1531
+ if attn.group_norm is not None:
1532
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1533
+
1534
+ query = attn.to_q(hidden_states)
1535
+ dim = query.shape[-1]
1536
+ query = attn.head_to_batch_dim(query)
1537
+
1538
+ if encoder_hidden_states is None:
1539
+ encoder_hidden_states = hidden_states
1540
+ elif attn.norm_cross:
1541
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1542
+
1543
+ key = attn.to_k(encoder_hidden_states)
1544
+ value = attn.to_v(encoder_hidden_states)
1545
+ key = attn.head_to_batch_dim(key)
1546
+ value = attn.head_to_batch_dim(value)
1547
+
1548
+ batch_size_attention, query_tokens, _ = query.shape
1549
+ hidden_states = torch.zeros(
1550
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
1551
+ )
1552
+
1553
+ for i in range(batch_size_attention // self.slice_size):
1554
+ start_idx = i * self.slice_size
1555
+ end_idx = (i + 1) * self.slice_size
1556
+
1557
+ query_slice = query[start_idx:end_idx]
1558
+ key_slice = key[start_idx:end_idx]
1559
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
1560
+ ###########################################################################################################
1561
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, weight, attn_mask_slice)
1562
+
1563
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
1564
+
1565
+ hidden_states[start_idx:end_idx] = attn_slice
1566
+
1567
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1568
+
1569
+ # linear proj
1570
+ hidden_states = attn.to_out[0](hidden_states)
1571
+ # dropout
1572
+ hidden_states = attn.to_out[1](hidden_states)
1573
+
1574
+ if input_ndim == 4:
1575
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1576
+
1577
+ if attn.residual_connection:
1578
+ hidden_states = hidden_states + residual
1579
+
1580
+ hidden_states = hidden_states / attn.rescale_output_factor
1581
+
1582
+ return hidden_states
1583
+
1584
+
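# Back-of-the-envelope sketch of what slicing buys: the (batch * heads) rows
# produced by head_to_batch_dim are processed in chunks of `slice_size`, so the
# full attention-score tensor is never materialised at once. The numbers below
# are arbitrary assumptions.
batch, heads, tokens, slice_size = 2, 8, 4096, 4

batch_size_attention = batch * heads                   # 16 rows after head_to_batch_dim
full_scores = batch_size_attention * tokens * tokens   # score elements if computed in one shot
sliced_scores = slice_size * tokens * tokens           # score elements alive per slice

print(full_scores // sliced_scores)                    # 4 -> peak score buffer is 4x smaller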
1585
+ class SlicedAttnAddedKVProcessor:
1586
+ r"""
1587
+ Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
1588
+
1589
+ Args:
1590
+ slice_size (`int`, *optional*):
1591
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
1592
+ `attention_head_dim` must be a multiple of the `slice_size`.
1593
+ """
1594
+
1595
+ def __init__(self, slice_size):
1596
+ self.slice_size = slice_size
1597
+
1598
+ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
1599
+ residual = hidden_states
1600
+
1601
+ if attn.spatial_norm is not None:
1602
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1603
+
1604
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
1605
+
1606
+ batch_size, sequence_length, _ = hidden_states.shape
1607
+
1608
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1609
+
1610
+ if encoder_hidden_states is None:
1611
+ encoder_hidden_states = hidden_states
1612
+ elif attn.norm_cross:
1613
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1614
+
1615
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1616
+
1617
+ query = attn.to_q(hidden_states)
1618
+ dim = query.shape[-1]
1619
+ query = attn.head_to_batch_dim(query)
1620
+
1621
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1622
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1623
+
1624
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
1625
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
1626
+
1627
+ if not attn.only_cross_attention:
1628
+ key = attn.to_k(hidden_states)
1629
+ value = attn.to_v(hidden_states)
1630
+ key = attn.head_to_batch_dim(key)
1631
+ value = attn.head_to_batch_dim(value)
1632
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1633
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1634
+ else:
1635
+ key = encoder_hidden_states_key_proj
1636
+ value = encoder_hidden_states_value_proj
1637
+
1638
+ batch_size_attention, query_tokens, _ = query.shape
1639
+ hidden_states = torch.zeros(
1640
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
1641
+ )
1642
+
1643
+ for i in range(batch_size_attention // self.slice_size):
1644
+ start_idx = i * self.slice_size
1645
+ end_idx = (i + 1) * self.slice_size
1646
+
1647
+ query_slice = query[start_idx:end_idx]
1648
+ key_slice = key[start_idx:end_idx]
1649
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
1650
+ ###########################################################################################################
1651
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, weight, attn_mask_slice)
1652
+
1653
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
1654
+
1655
+ hidden_states[start_idx:end_idx] = attn_slice
1656
+
1657
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1658
+
1659
+ # linear proj
1660
+ hidden_states = attn.to_out[0](hidden_states)
1661
+ # dropout
1662
+ hidden_states = attn.to_out[1](hidden_states)
1663
+
1664
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1665
+ hidden_states = hidden_states + residual
1666
+
1667
+ return hidden_states
1668
+
1669
+
1670
+ AttentionProcessor = Union[
1671
+ AttnProcessor,
1672
+ Guid_AttnProcessor,
1673
+ AttnProcessor2_0,
1674
+ XFormersAttnProcessor,
1675
+ SlicedAttnProcessor,
1676
+ AttnAddedKVProcessor,
1677
+ SlicedAttnAddedKVProcessor,
1678
+ AttnAddedKVProcessor2_0,
1679
+ XFormersAttnAddedKVProcessor,
1680
+ LoRAAttnProcessor,
1681
+ LoRAXFormersAttnProcessor,
1682
+ LoRAAttnProcessor2_0,
1683
+ LoRAAttnAddedKVProcessor,
1684
+ CustomDiffusionAttnProcessor,
1685
+ CustomDiffusionXFormersAttnProcessor,
1686
+ ]
1687
+
1688
+ LORA_ATTENTION_PROCESSORS = (
1689
+ LoRAAttnProcessor,
1690
+ LoRAAttnProcessor2_0,
1691
+ LoRAXFormersAttnProcessor,
1692
+ LoRAAttnAddedKVProcessor,
1693
+ )
1694
+
1695
+
1696
+ class SpatialNorm(nn.Module):
1697
+ """
1698
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
1699
+ """
1700
+
1701
+ def __init__(
1702
+ self,
1703
+ f_channels,
1704
+ zq_channels,
1705
+ ):
1706
+ super().__init__()
1707
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
1708
+ self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
1709
+ self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
1710
+
1711
+ def forward(self, f, zq):
1712
+ f_size = f.shape[-2:]
1713
+ zq = F.interpolate(zq, size=f_size, mode="nearest")
1714
+ norm_f = self.norm_layer(f)
1715
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
1716
+ return new_f
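# Illustrative shape check for SpatialNorm: `zq` is resized to the spatial size
# of `f` and predicts a per-pixel scale and shift for the group-normalised
# feature map. Channel counts and sizes below are arbitrary assumptions.
import torch

norm = SpatialNorm(f_channels=64, zq_channels=4)
f = torch.randn(1, 64, 32, 32)    # feature map to normalise
zq = torch.randn(1, 4, 8, 8)      # conditioning latents at a coarser resolution
print(norm(f, zq).shape)          # torch.Size([1, 64, 32, 32])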
Tiger Model/diffusiers-Tiger/models/autoencoder_asym_kl.py ADDED
@@ -0,0 +1,180 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional, Tuple, Union
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ from ..configuration_utils import ConfigMixin, register_to_config
20
+ from ..utils import apply_forward_hook
21
+ from .autoencoder_kl import AutoencoderKLOutput
22
+ from .modeling_utils import ModelMixin
23
+ from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
24
+
25
+
26
+ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
27
+ r"""
28
+ Designing a Better Asymmetric VQGAN for Stable Diffusion (https://arxiv.org/abs/2306.04632). A VAE model with KL loss
29
+ for encoding images into latents and decoding latent representations into images.
30
+
31
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
32
+ for all models (such as downloading or saving).
33
+
34
+ Parameters:
35
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
36
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
37
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
38
+ Tuple of downsample block types.
39
+ down_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
40
+ Tuple of down block output channels.
41
+ layers_per_down_block (`int`, *optional*, defaults to `1`):
42
+ Number layers for down block.
43
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
44
+ Tuple of upsample block types.
45
+ up_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
46
+ Tuple of up block output channels.
47
+ layers_per_up_block (`int`, *optional*, defaults to `1`):
48
+ Number layers for up block.
49
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
50
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
51
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
52
+ norm_num_groups (`int`, *optional*, defaults to `32`):
53
+ Number of groups to use for the first normalization layer in ResNet blocks.
54
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
55
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
56
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
57
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
58
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
59
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
60
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
61
+ """
62
+
63
+ @register_to_config
64
+ def __init__(
65
+ self,
66
+ in_channels: int = 3,
67
+ out_channels: int = 3,
68
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
69
+ down_block_out_channels: Tuple[int] = (64,),
70
+ layers_per_down_block: int = 1,
71
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
72
+ up_block_out_channels: Tuple[int] = (64,),
73
+ layers_per_up_block: int = 1,
74
+ act_fn: str = "silu",
75
+ latent_channels: int = 4,
76
+ norm_num_groups: int = 32,
77
+ sample_size: int = 32,
78
+ scaling_factor: float = 0.18215,
79
+ ) -> None:
80
+ super().__init__()
81
+
82
+ # pass init params to Encoder
83
+ self.encoder = Encoder(
84
+ in_channels=in_channels,
85
+ out_channels=latent_channels,
86
+ down_block_types=down_block_types,
87
+ block_out_channels=down_block_out_channels,
88
+ layers_per_block=layers_per_down_block,
89
+ act_fn=act_fn,
90
+ norm_num_groups=norm_num_groups,
91
+ double_z=True,
92
+ )
93
+
94
+ # pass init params to Decoder
95
+ self.decoder = MaskConditionDecoder(
96
+ in_channels=latent_channels,
97
+ out_channels=out_channels,
98
+ up_block_types=up_block_types,
99
+ block_out_channels=up_block_out_channels,
100
+ layers_per_block=layers_per_up_block,
101
+ act_fn=act_fn,
102
+ norm_num_groups=norm_num_groups,
103
+ )
104
+
105
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
106
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
107
+
108
+ self.use_slicing = False
109
+ self.use_tiling = False
110
+
111
+ @apply_forward_hook
112
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
113
+ h = self.encoder(x)
114
+ moments = self.quant_conv(h)
115
+ posterior = DiagonalGaussianDistribution(moments)
116
+
117
+ if not return_dict:
118
+ return (posterior,)
119
+
120
+ return AutoencoderKLOutput(latent_dist=posterior)
121
+
122
+ def _decode(
123
+ self,
124
+ z: torch.FloatTensor,
125
+ image: Optional[torch.FloatTensor] = None,
126
+ mask: Optional[torch.FloatTensor] = None,
127
+ return_dict: bool = True,
128
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
129
+ z = self.post_quant_conv(z)
130
+ dec = self.decoder(z, image, mask)
131
+
132
+ if not return_dict:
133
+ return (dec,)
134
+
135
+ return DecoderOutput(sample=dec)
136
+
137
+ @apply_forward_hook
138
+ def decode(
139
+ self,
140
+ z: torch.FloatTensor,
141
+ image: Optional[torch.FloatTensor] = None,
142
+ mask: Optional[torch.FloatTensor] = None,
143
+ return_dict: bool = True,
144
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
145
+ decoded = self._decode(z, image, mask).sample
146
+
147
+ if not return_dict:
148
+ return (decoded,)
149
+
150
+ return DecoderOutput(sample=decoded)
151
+
152
+ def forward(
153
+ self,
154
+ sample: torch.FloatTensor,
155
+ mask: Optional[torch.FloatTensor] = None,
156
+ sample_posterior: bool = False,
157
+ return_dict: bool = True,
158
+ generator: Optional[torch.Generator] = None,
159
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
160
+ r"""
161
+ Args:
162
+ sample (`torch.FloatTensor`): Input sample.
163
+ mask (`torch.FloatTensor`, *optional*, defaults to `None`): Optional inpainting mask.
164
+ sample_posterior (`bool`, *optional*, defaults to `False`):
165
+ Whether to sample from the posterior.
166
+ return_dict (`bool`, *optional*, defaults to `True`):
167
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
168
+ """
169
+ x = sample
170
+ posterior = self.encode(x).latent_dist
171
+ if sample_posterior:
172
+ z = posterior.sample(generator=generator)
173
+ else:
174
+ z = posterior.mode()
175
+ dec = self.decode(z, sample, mask).sample
176
+
177
+ if not return_dict:
178
+ return (dec,)
179
+
180
+ return DecoderOutput(sample=dec)
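# Illustrative round trip with a randomly initialised default configuration
# (a real checkpoint would normally be loaded via `from_pretrained`). Passing
# `image` and `mask` to `decode` additionally conditions the decoder on the
# unmasked pixels; both are omitted in this minimal sketch.
import torch

vae = AsymmetricAutoencoderKL()      # default (toy) configuration
image = torch.randn(1, 3, 32, 32)    # pixel-space input

posterior = vae.encode(image).latent_dist
z = posterior.sample()
reconstruction = vae.decode(z).sample
print(reconstruction.shape)          # torch.Size([1, 3, 32, 32])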
Tiger Model/diffusiers-Tiger/models/autoencoder_kl.py ADDED
@@ -0,0 +1,417 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Dict, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from ..configuration_utils import ConfigMixin, register_to_config
21
+ from ..loaders import FromOriginalVAEMixin
22
+ from ..utils import BaseOutput, apply_forward_hook
23
+ from .attention_processor import AttentionProcessor, AttnProcessor
24
+ from .modeling_utils import ModelMixin
25
+ from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
26
+
27
+
28
+ @dataclass
29
+ class AutoencoderKLOutput(BaseOutput):
30
+ """
31
+ Output of AutoencoderKL encoding method.
32
+
33
+ Args:
34
+ latent_dist (`DiagonalGaussianDistribution`):
35
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
36
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
37
+ """
38
+
39
+ latent_dist: "DiagonalGaussianDistribution"
40
+
41
+
42
+ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
43
+ r"""
44
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
45
+
46
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
47
+ for all models (such as downloading or saving).
48
+
49
+ Parameters:
50
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
51
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
52
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
53
+ Tuple of downsample block types.
54
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
55
+ Tuple of upsample block types.
56
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
57
+ Tuple of block output channels.
58
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
59
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
60
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
61
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
62
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
63
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
64
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
65
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
66
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
67
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
68
+ force_upcast (`bool`, *optional*, defaults to `True`):
69
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
70
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case
71
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
72
+ """
73
+
74
+ _supports_gradient_checkpointing = True
75
+
76
+ @register_to_config
77
+ def __init__(
78
+ self,
79
+ in_channels: int = 3,
80
+ out_channels: int = 3,
81
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
82
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
83
+ block_out_channels: Tuple[int] = (64,),
84
+ layers_per_block: int = 1,
85
+ act_fn: str = "silu",
86
+ latent_channels: int = 4,
87
+ norm_num_groups: int = 32,
88
+ sample_size: int = 32,
89
+ scaling_factor: float = 0.18215,
90
+ force_upcast: float = True,
91
+ ):
92
+ super().__init__()
93
+
94
+ # pass init params to Encoder
95
+ self.encoder = Encoder(
96
+ in_channels=in_channels,
97
+ out_channels=latent_channels,
98
+ down_block_types=down_block_types,
99
+ block_out_channels=block_out_channels,
100
+ layers_per_block=layers_per_block,
101
+ act_fn=act_fn,
102
+ norm_num_groups=norm_num_groups,
103
+ double_z=True,
104
+ )
105
+
106
+ # pass init params to Decoder
107
+ self.decoder = Decoder(
108
+ in_channels=latent_channels,
109
+ out_channels=out_channels,
110
+ up_block_types=up_block_types,
111
+ block_out_channels=block_out_channels,
112
+ layers_per_block=layers_per_block,
113
+ norm_num_groups=norm_num_groups,
114
+ act_fn=act_fn,
115
+ )
116
+
117
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
118
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
119
+
120
+ self.use_slicing = False
121
+ self.use_tiling = False
122
+
123
+ # only relevant if vae tiling is enabled
124
+ self.tile_sample_min_size = self.config.sample_size
125
+ sample_size = (
126
+ self.config.sample_size[0]
127
+ if isinstance(self.config.sample_size, (list, tuple))
128
+ else self.config.sample_size
129
+ )
130
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
131
+ self.tile_overlap_factor = 0.25
132
+
133
+ def _set_gradient_checkpointing(self, module, value=False):
134
+ if isinstance(module, (Encoder, Decoder)):
135
+ module.gradient_checkpointing = value
136
+
137
+ def enable_tiling(self, use_tiling: bool = True):
138
+ r"""
139
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
140
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
141
+ processing larger images.
142
+ """
143
+ self.use_tiling = use_tiling
144
+
145
+ def disable_tiling(self):
146
+ r"""
147
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
148
+ decoding in one step.
149
+ """
150
+ self.enable_tiling(False)
151
+
152
+ def enable_slicing(self):
153
+ r"""
154
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
155
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
156
+ """
157
+ self.use_slicing = True
158
+
159
+ def disable_slicing(self):
160
+ r"""
161
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
162
+ decoding in one step.
163
+ """
164
+ self.use_slicing = False
165
+
166
+ @property
167
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
168
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
169
+ r"""
170
+ Returns:
171
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
172
+ indexed by its weight name.
173
+ """
174
+ # set recursively
175
+ processors = {}
176
+
177
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
178
+ if hasattr(module, "set_processor"):
179
+ processors[f"{name}.processor"] = module.processor
180
+
181
+ for sub_name, child in module.named_children():
182
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
183
+
184
+ return processors
185
+
186
+ for name, module in self.named_children():
187
+ fn_recursive_add_processors(name, module, processors)
188
+
189
+ return processors
190
+
191
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
192
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
193
+ r"""
194
+ Sets the attention processor to use to compute attention.
195
+
196
+ Parameters:
197
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
198
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
199
+ for **all** `Attention` layers.
200
+
201
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
202
+ processor. This is strongly recommended when setting trainable attention processors.
203
+
204
+ """
205
+ count = len(self.attn_processors.keys())
206
+
207
+ if isinstance(processor, dict) and len(processor) != count:
208
+ raise ValueError(
209
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
210
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
211
+ )
212
+
213
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
214
+ if hasattr(module, "set_processor"):
215
+ if not isinstance(processor, dict):
216
+ module.set_processor(processor)
217
+ else:
218
+ module.set_processor(processor.pop(f"{name}.processor"))
219
+
220
+ for sub_name, child in module.named_children():
221
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
222
+
223
+ for name, module in self.named_children():
224
+ fn_recursive_attn_processor(name, module, processor)
225
+
226
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
227
+ def set_default_attn_processor(self):
228
+ """
229
+ Disables custom attention processors and sets the default attention implementation.
230
+ """
231
+ self.set_attn_processor(AttnProcessor())
232
+
233
+ @apply_forward_hook
234
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
235
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
236
+ return self.tiled_encode(x, return_dict=return_dict)
237
+
238
+ if self.use_slicing and x.shape[0] > 1:
239
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
240
+ h = torch.cat(encoded_slices)
241
+ else:
242
+ h = self.encoder(x)
243
+
244
+ moments = self.quant_conv(h)
245
+ posterior = DiagonalGaussianDistribution(moments)
246
+
247
+ if not return_dict:
248
+ return (posterior,)
249
+
250
+ return AutoencoderKLOutput(latent_dist=posterior)
251
+
252
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
253
+ if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
254
+ return self.tiled_decode(z, return_dict=return_dict)
255
+
256
+ z = self.post_quant_conv(z)
257
+ dec = self.decoder(z)
258
+
259
+ if not return_dict:
260
+ return (dec,)
261
+
262
+ return DecoderOutput(sample=dec)
263
+
264
+ @apply_forward_hook
265
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
266
+ if self.use_slicing and z.shape[0] > 1:
267
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
268
+ decoded = torch.cat(decoded_slices)
269
+ else:
270
+ decoded = self._decode(z).sample
271
+
272
+ if not return_dict:
273
+ return (decoded,)
274
+
275
+ return DecoderOutput(sample=decoded)
276
+
277
+ def blend_v(self, a, b, blend_extent):
278
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
279
+ for y in range(blend_extent):
280
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
281
+ return b
282
+
283
+ def blend_h(self, a, b, blend_extent):
284
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
285
+ for x in range(blend_extent):
286
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
287
+ return b
288
+
289
+ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
290
+ r"""Encode a batch of images using a tiled encoder.
291
+
292
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
293
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
294
+ different from non-tiled encoding because each tile is encoded separately from its neighbours. To avoid tiling artifacts, the
295
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
296
+ output, but they should be much less noticeable.
297
+
298
+ Args:
299
+ x (`torch.FloatTensor`): Input batch of images.
300
+ return_dict (`bool`, *optional*, defaults to `True`):
301
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
302
+
303
+ Returns:
304
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
305
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
306
+ `tuple` is returned.
307
+ """
308
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
309
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
310
+ row_limit = self.tile_latent_min_size - blend_extent
311
+
312
+ # Split the image into 512x512 tiles and encode them separately.
313
+ rows = []
314
+ for i in range(0, x.shape[2], overlap_size):
315
+ row = []
316
+ for j in range(0, x.shape[3], overlap_size):
317
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
318
+ tile = self.encoder(tile)
319
+ tile = self.quant_conv(tile)
320
+ row.append(tile)
321
+ rows.append(row)
322
+ result_rows = []
323
+ for i, row in enumerate(rows):
324
+ result_row = []
325
+ for j, tile in enumerate(row):
326
+ # blend the above tile and the left tile
327
+ # to the current tile and add the current tile to the result row
328
+ if i > 0:
329
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
330
+ if j > 0:
331
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
332
+ result_row.append(tile[:, :, :row_limit, :row_limit])
333
+ result_rows.append(torch.cat(result_row, dim=3))
334
+
335
+ moments = torch.cat(result_rows, dim=2)
336
+ posterior = DiagonalGaussianDistribution(moments)
337
+
338
+ if not return_dict:
339
+ return (posterior,)
340
+
341
+ return AutoencoderKLOutput(latent_dist=posterior)
342
+
343
+ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
344
+ r"""
345
+ Decode a batch of images using a tiled decoder.
346
+
347
+ Args:
348
+ z (`torch.FloatTensor`): Input batch of latent vectors.
349
+ return_dict (`bool`, *optional*, defaults to `True`):
350
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
351
+
352
+ Returns:
353
+ [`~models.vae.DecoderOutput`] or `tuple`:
354
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
355
+ returned.
356
+ """
357
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
358
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
359
+ row_limit = self.tile_sample_min_size - blend_extent
360
+
361
+ # Split z into overlapping 64x64 tiles and decode them separately.
362
+ # The tiles have an overlap to avoid seams between tiles.
363
+ rows = []
364
+ for i in range(0, z.shape[2], overlap_size):
365
+ row = []
366
+ for j in range(0, z.shape[3], overlap_size):
367
+ tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
368
+ tile = self.post_quant_conv(tile)
369
+ decoded = self.decoder(tile)
370
+ row.append(decoded)
371
+ rows.append(row)
372
+ result_rows = []
373
+ for i, row in enumerate(rows):
374
+ result_row = []
375
+ for j, tile in enumerate(row):
376
+ # blend the above tile and the left tile
377
+ # to the current tile and add the current tile to the result row
378
+ if i > 0:
379
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
380
+ if j > 0:
381
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
382
+ result_row.append(tile[:, :, :row_limit, :row_limit])
383
+ result_rows.append(torch.cat(result_row, dim=3))
384
+
385
+ dec = torch.cat(result_rows, dim=2)
386
+ if not return_dict:
387
+ return (dec,)
388
+
389
+ return DecoderOutput(sample=dec)
390
+
391
+ def forward(
392
+ self,
393
+ sample: torch.FloatTensor,
394
+ sample_posterior: bool = False,
395
+ return_dict: bool = True,
396
+ generator: Optional[torch.Generator] = None,
397
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
398
+ r"""
399
+ Args:
400
+ sample (`torch.FloatTensor`): Input sample.
401
+ sample_posterior (`bool`, *optional*, defaults to `False`):
402
+ Whether to sample from the posterior.
403
+ return_dict (`bool`, *optional*, defaults to `True`):
404
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
405
+ """
406
+ x = sample
407
+ posterior = self.encode(x).latent_dist
408
+ if sample_posterior:
409
+ z = posterior.sample(generator=generator)
410
+ else:
411
+ z = posterior.mode()
412
+ dec = self.decode(z).sample
413
+
414
+ if not return_dict:
415
+ return (dec,)
416
+
417
+ return DecoderOutput(sample=dec)
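# Illustrative sketch of the `scaling_factor` convention described in the
# docstring, using a randomly initialised default configuration: latents handed
# to the diffusion model are multiplied by `scaling_factor`, and divided by it
# again before decoding. Shapes are arbitrary assumptions.
import torch

vae = AutoencoderKL()                 # default (toy) configuration
image = torch.randn(1, 3, 32, 32)

posterior = vae.encode(image).latent_dist
latents = posterior.sample() * vae.config.scaling_factor   # what the diffusion model sees

decoded = vae.decode(latents / vae.config.scaling_factor).sample
print(decoded.shape)                  # torch.Size([1, 3, 32, 32])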
Tiger Model/diffusiers-Tiger/models/autoencoder_tiny.py ADDED
@@ -0,0 +1,342 @@
1
+ # Copyright 2023 Ollin Boer Bohan and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Tuple, Union
18
+
19
+ import torch
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..utils import BaseOutput, apply_forward_hook
23
+ from .modeling_utils import ModelMixin
24
+ from .vae import DecoderOutput, DecoderTiny, EncoderTiny
25
+
26
+
27
+ @dataclass
28
+ class AutoencoderTinyOutput(BaseOutput):
29
+ """
30
+ Output of AutoencoderTiny encoding method.
31
+
32
+ Args:
33
+ latents (`torch.Tensor`): Encoded outputs of the `Encoder`.
34
+
35
+ """
36
+
37
+ latents: torch.Tensor
38
+
39
+
40
+ class AutoencoderTiny(ModelMixin, ConfigMixin):
41
+ r"""
42
+ A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.
43
+
44
+ [`AutoencoderTiny`] is a wrapper around the original implementation of `TAESD`.
45
+
46
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
47
+ all models (such as downloading or saving).
48
+
49
+ Parameters:
50
+ in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
51
+ out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
52
+ encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
53
+ Tuple of integers representing the number of output channels for each encoder block. The length of the
54
+ tuple should be equal to the number of encoder blocks.
55
+ decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
56
+ Tuple of integers representing the number of output channels for each decoder block. The length of the
57
+ tuple should be equal to the number of decoder blocks.
58
+ act_fn (`str`, *optional*, defaults to `"relu"`):
59
+ Activation function to be used throughout the model.
60
+ latent_channels (`int`, *optional*, defaults to 4):
61
+ Number of channels in the latent representation. The latent space acts as a compressed representation of
62
+ the input image.
63
+ upsampling_scaling_factor (`int`, *optional*, defaults to 2):
64
+ Scaling factor for upsampling in the decoder. It determines the size of the output image during the
65
+ upsampling process.
66
+ num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
67
+ Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
68
+ length of the tuple should be equal to the number of stages in the encoder. Each stage has a different
69
+ number of encoder blocks.
70
+ num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
71
+ Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
72
+ length of the tuple should be equal to the number of stages in the decoder. Each stage has a different
73
+ number of decoder blocks.
74
+ latent_magnitude (`float`, *optional*, defaults to 3.0):
75
+ Magnitude of the latent representation. This parameter scales the latent representation values to control
76
+ the extent of information preservation.
77
+ latent_shift (float, *optional*, defaults to 0.5):
78
+ Shift applied to the latent representation. This parameter controls the center of the latent space.
79
+ scaling_factor (`float`, *optional*, defaults to 1.0):
80
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
81
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
82
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
83
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
84
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
85
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. For this Autoencoder,
86
+ however, no such scaling factor was used, hence the value of 1.0 as the default.
87
+ force_upcast (`bool`, *optional*, default to `False`):
88
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
89
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case
90
+ `force_upcast` can be set to `False` (see this fp16-friendly
91
+ [AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
92
+ """
93
+ _supports_gradient_checkpointing = True
94
+
95
+ @register_to_config
96
+ def __init__(
97
+ self,
98
+ in_channels=3,
99
+ out_channels=3,
100
+ encoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
101
+ decoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
102
+ act_fn: str = "relu",
103
+ latent_channels: int = 4,
104
+ upsampling_scaling_factor: int = 2,
105
+ num_encoder_blocks: Tuple[int] = (1, 3, 3, 3),
106
+ num_decoder_blocks: Tuple[int] = (3, 3, 3, 1),
107
+ latent_magnitude: int = 3,
108
+ latent_shift: float = 0.5,
109
+ force_upcast: float = False,
110
+ scaling_factor: float = 1.0,
111
+ ):
112
+ super().__init__()
113
+
114
+ if len(encoder_block_out_channels) != len(num_encoder_blocks):
115
+ raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.")
116
+ if len(decoder_block_out_channels) != len(num_decoder_blocks):
117
+ raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.")
118
+
119
+ self.encoder = EncoderTiny(
120
+ in_channels=in_channels,
121
+ out_channels=latent_channels,
122
+ num_blocks=num_encoder_blocks,
123
+ block_out_channels=encoder_block_out_channels,
124
+ act_fn=act_fn,
125
+ )
126
+
127
+ self.decoder = DecoderTiny(
128
+ in_channels=latent_channels,
129
+ out_channels=out_channels,
130
+ num_blocks=num_decoder_blocks,
131
+ block_out_channels=decoder_block_out_channels,
132
+ upsampling_scaling_factor=upsampling_scaling_factor,
133
+ act_fn=act_fn,
134
+ )
135
+
136
+ self.latent_magnitude = latent_magnitude
137
+ self.latent_shift = latent_shift
138
+ self.scaling_factor = scaling_factor
139
+
140
+ self.use_slicing = False
141
+ self.use_tiling = False
142
+
143
+ # only relevant if vae tiling is enabled
144
+ self.spatial_scale_factor = 2**out_channels
145
+ self.tile_overlap_factor = 0.125
146
+ self.tile_sample_min_size = 512
147
+ self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
148
+
149
+ def _set_gradient_checkpointing(self, module, value=False):
150
+ if isinstance(module, (EncoderTiny, DecoderTiny)):
151
+ module.gradient_checkpointing = value
152
+
153
+ def scale_latents(self, x):
154
+ """raw latents -> [0, 1]"""
155
+ return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)
156
+
157
+ def unscale_latents(self, x):
158
+ """[0, 1] -> raw latents"""
159
+ return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
160
+
161
+ def enable_slicing(self):
162
+ r"""
163
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
164
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
165
+ """
166
+ self.use_slicing = True
167
+
168
+ def disable_slicing(self):
169
+ r"""
170
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
171
+ decoding in one step.
172
+ """
173
+ self.use_slicing = False
174
+
175
+ def enable_tiling(self, use_tiling: bool = True):
176
+ r"""
177
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
178
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
179
+ processing larger images.
180
+ """
181
+ self.use_tiling = use_tiling
182
+
183
+ def disable_tiling(self):
184
+ r"""
185
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
186
+ decoding in one step.
187
+ """
188
+ self.enable_tiling(False)
189
+
190
+ def _tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
191
+ r"""Encode a batch of images using a tiled encoder.
192
+
193
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
194
+ steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
195
+ tiles overlap and are blended together to form a smooth output.
196
+
197
+ Args:
198
+ x (`torch.FloatTensor`): Input batch of images.
199
+ return_dict (`bool`, *optional*, defaults to `True`):
200
+ Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.
201
+
202
+ Returns:
203
+ [`~models.autoencoder_tiny.AutoencoderTinyOutput`] or `tuple`:
204
+ If return_dict is True, a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] is returned, otherwise a
205
+ plain `tuple` is returned.
206
+ """
207
+ # scale of encoder output relative to input
208
+ sf = self.spatial_scale_factor
209
+ tile_size = self.tile_sample_min_size
210
+
211
+ # number of pixels to blend and to traverse between tiles
212
+ blend_size = int(tile_size * self.tile_overlap_factor)
213
+ traverse_size = tile_size - blend_size
214
+
215
+ # tiles index (up/left)
216
+ ti = range(0, x.shape[-2], traverse_size)
217
+ tj = range(0, x.shape[-1], traverse_size)
218
+
219
+ # mask for blending
220
+ blend_masks = torch.stack(
221
+ torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij")
222
+ )
223
+ blend_masks = blend_masks.clamp(0, 1).to(x.device)
224
+
225
+ # output array
226
+ out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device)
227
+ for i in ti:
228
+ for j in tj:
229
+ tile_in = x[..., i : i + tile_size, j : j + tile_size]
230
+ # tile result
231
+ tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf]
232
+ tile = self.encoder(tile_in)
233
+ h, w = tile.shape[-2], tile.shape[-1]
234
+ # blend tile result into output
235
+ blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
236
+ blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
237
+ blend_mask = blend_mask_i * blend_mask_j
238
+ tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w]
239
+ tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
240
+ return out
241
+
242
+ def _tiled_decode(self, x: torch.FloatTensor) -> torch.FloatTensor:
243
+ r"""Encode a batch of images using a tiled encoder.
244
+
245
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
246
+ steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
247
+ tiles overlap and are blended together to form a smooth output.
248
+
249
+ Args:
250
+ x (`torch.FloatTensor`): Input batch of images.
251
+ return_dict (`bool`, *optional*, defaults to `True`):
252
+ Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.
253
+
254
+ Returns:
255
+ [`~models.vae.DecoderOutput`] or `tuple`:
256
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
257
+ returned.
258
+ """
259
+ # scale of decoder output relative to input
260
+ sf = self.spatial_scale_factor
261
+ tile_size = self.tile_latent_min_size
262
+
263
+ # number of pixels to blend and to traverse between tiles
264
+ blend_size = int(tile_size * self.tile_overlap_factor)
265
+ traverse_size = tile_size - blend_size
266
+
267
+ # tiles index (up/left)
268
+ ti = range(0, x.shape[-2], traverse_size)
269
+ tj = range(0, x.shape[-1], traverse_size)
270
+
271
+ # mask for blending
272
+ blend_masks = torch.stack(
273
+ torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij")
274
+ )
275
+ blend_masks = blend_masks.clamp(0, 1).to(x.device)
276
+
277
+ # output array
278
+ out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device)
279
+ for i in ti:
280
+ for j in tj:
281
+ tile_in = x[..., i : i + tile_size, j : j + tile_size]
282
+ # tile result
283
+ tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf]
284
+ tile = self.decoder(tile_in)
285
+ h, w = tile.shape[-2], tile.shape[-1]
286
+ # blend tile result into output
287
+ blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
288
+ blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
289
+ blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w]
290
+ tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
291
+ return out
292
+
293
+ @apply_forward_hook
294
+ def encode(
295
+ self, x: torch.FloatTensor, return_dict: bool = True
296
+ ) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
297
+ if self.use_slicing and x.shape[0] > 1:
298
+ output = [self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)]
299
+ output = torch.cat(output)
300
+ else:
301
+ output = self._tiled_encode(x) if self.use_tiling else self.encoder(x)
302
+
303
+ if not return_dict:
304
+ return (output,)
305
+
306
+ return AutoencoderTinyOutput(latents=output)
307
+
308
+ @apply_forward_hook
309
+ def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
310
+ if self.use_slicing and x.shape[0] > 1:
311
+ output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)]
312
+ output = torch.cat(output)
313
+ else:
314
+ output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
315
+ # Refer to the following discussion to know why this is needed.
316
+ # https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854
317
+ output = output.mul_(2).sub_(1)
318
+
319
+ if not return_dict:
320
+ return (output,)
321
+
322
+ return DecoderOutput(sample=output)
323
+
324
+ def forward(
325
+ self,
326
+ sample: torch.FloatTensor,
327
+ return_dict: bool = True,
328
+ ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
329
+ r"""
330
+ Args:
331
+ sample (`torch.FloatTensor`): Input sample.
332
+ return_dict (`bool`, *optional*, defaults to `True`):
333
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
334
+ """
335
+ enc = self.encode(sample).latents
336
+ scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
337
+ unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
338
+ dec = self.decode(unscaled_enc).sample
339
+
340
+ if not return_dict:
341
+ return (dec,)
342
+ return DecoderOutput(sample=dec)
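A minimal round-trip sketch of how the tiny autoencoder above is meant to be used, assuming the surrounding class is the `AutoencoderTiny` defined earlier in this file and that `scale_latents`/`unscale_latents` map latents to and from [0, 1] as in upstream diffusers; shapes and the quantization step are illustrative only:

import torch

def tiny_vae_roundtrip(vae, images: torch.FloatTensor) -> torch.FloatTensor:
    # images in [-1, 1]; encode() returns AutoencoderTinyOutput with a .latents field
    latents = vae.encode(images).latents
    # mimic forward(): squeeze latents into uint8, as if stored in an RGBA image
    quantized = vae.scale_latents(latents).mul(255).round().byte()
    latents = vae.unscale_latents(quantized / 255.0)
    # decode() returns DecoderOutput; .sample is rescaled to [-1, 1] by the mul_(2).sub_(1) above
    return vae.decode(latents).sample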
Tiger Model/diffusiers-Tiger/models/controlnet.py ADDED
@@ -0,0 +1,762 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ from torch import nn
19
+ from torch.nn import functional as F
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..loaders import FromOriginalControlnetMixin
23
+ from ..utils import BaseOutput, logging
24
+ from .attention_processor import AttentionProcessor, AttnProcessor, Guid_AttnProcessor
25
+ from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
26
+ from .modeling_utils import ModelMixin
27
+ from .unet_2d_blocks import (
28
+ CrossAttnDownBlock2D,
29
+ DownBlock2D,
30
+ UNetMidBlock2DCrossAttn,
31
+ get_down_block,
32
+ )
33
+ from .unet_2d_condition import UNet2DConditionModel
34
+
35
+
36
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
37
+
38
+
39
+ @dataclass
40
+ class ControlNetOutput(BaseOutput):
41
+ """
42
+ The output of [`ControlNetModel`].
43
+
44
+ Args:
45
+ down_block_res_samples (`tuple[torch.Tensor]`):
46
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
47
+ be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
48
+ used to condition the original UNet's downsampling activations.
49
+ mid_block_res_sample (`torch.Tensor`):
50
+ The activation of the middle block (the lowest sample resolution). The tensor should be of shape
51
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
52
+ Output can be used to condition the original UNet's middle block activation.
53
+ """
54
+
55
+ down_block_res_samples: Tuple[torch.Tensor]
56
+ mid_block_res_sample: torch.Tensor
57
+
58
+
59
+ class ControlNetConditioningEmbedding(nn.Module):
60
+ """
61
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
62
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
63
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
64
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
65
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
66
+ model) to encode image-space conditions ... into feature maps ..."
67
+ """
68
+
69
+ def __init__(
70
+ self,
71
+ conditioning_embedding_channels: int,
72
+ conditioning_channels: int = 3,
73
+ block_out_channels: Tuple[int] = (16, 32, 96, 256),
74
+ ):
75
+ super().__init__()
76
+
77
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
78
+
79
+ self.blocks = nn.ModuleList([])
80
+
81
+ for i in range(len(block_out_channels) - 1):
82
+ channel_in = block_out_channels[i]
83
+ channel_out = block_out_channels[i + 1]
84
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
85
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
86
+
87
+ self.conv_out = zero_module(
88
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
89
+ )
90
+
91
+ def forward(self, conditioning):
92
+ embedding = self.conv_in(conditioning)
93
+ embedding = F.silu(embedding)
94
+
95
+ for block in self.blocks:
96
+ embedding = block(embedding)
97
+ embedding = F.silu(embedding)
98
+
99
+ embedding = self.conv_out(embedding)
100
+
101
+ return embedding
102
+
103
+
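A shape sketch for the conditioning embedder just defined (values are illustrative; 320 matches `block_out_channels[0]` of the default `ControlNetModel` below): the three stride-2 convolutions bring a 512x512 condition image down to the 64x64 latent resolution.

import torch

cond_embed = ControlNetConditioningEmbedding(
    conditioning_embedding_channels=320,
    conditioning_channels=3,
    block_out_channels=(16, 32, 96, 256),
)
cond_image = torch.randn(1, 3, 512, 512)   # image-space condition (e.g. edges, depth)
features = cond_embed(cond_image)          # -> (1, 320, 64, 64); conv_out is zero-initialized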
104
+ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
105
+ _supports_gradient_checkpointing = True
106
+
107
+ @register_to_config
108
+ def __init__(
109
+ self,
110
+ in_channels: int = 4,
111
+ conditioning_channels: int = 3,
112
+ flip_sin_to_cos: bool = True,
113
+ freq_shift: int = 0,
114
+ down_block_types: Tuple[str] = (
115
+ "CrossAttnDownBlock2D",
116
+ "CrossAttnDownBlock2D",
117
+ "CrossAttnDownBlock2D",
118
+ "DownBlock2D",
119
+ ),
120
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
121
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
122
+ layers_per_block: int = 2,
123
+ downsample_padding: int = 1,
124
+ mid_block_scale_factor: float = 1,
125
+ act_fn: str = "silu",
126
+ norm_num_groups: Optional[int] = 32,
127
+ norm_eps: float = 1e-5,
128
+ cross_attention_dim: int = 1280,
129
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
130
+ encoder_hid_dim: Optional[int] = None,
131
+ encoder_hid_dim_type: Optional[str] = None,
132
+ attention_head_dim: Union[int, Tuple[int]] = 8,
133
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
134
+ use_linear_projection: bool = False,
135
+ class_embed_type: Optional[str] = None,
136
+ addition_embed_type: Optional[str] = None,
137
+ addition_time_embed_dim: Optional[int] = None,
138
+ num_class_embeds: Optional[int] = None,
139
+ upcast_attention: bool = False,
140
+ resnet_time_scale_shift: str = "default",
141
+ projection_class_embeddings_input_dim: Optional[int] = None,
142
+ controlnet_conditioning_channel_order: str = "rgb",
143
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
144
+ global_pool_conditions: bool = False,
145
+ addition_embed_type_num_heads=64,
146
+ weight: Optional[torch.Tensor] = None,
147
+ ):
148
+ super().__init__()
149
+
150
+ # If `num_attention_heads` is not defined (which is the case for most models)
151
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
152
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
153
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
154
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
155
+ # which is why we correct for the naming here.
156
+ num_attention_heads = num_attention_heads or attention_head_dim
157
+
158
+ # Check inputs
159
+ if len(block_out_channels) != len(down_block_types):
160
+ raise ValueError(
161
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
162
+ )
163
+
164
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
165
+ raise ValueError(
166
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
167
+ )
168
+
169
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
170
+ raise ValueError(
171
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
172
+ )
173
+
174
+ if isinstance(transformer_layers_per_block, int):
175
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
176
+
177
+ # input
178
+ conv_in_kernel = 3
179
+ conv_in_padding = (conv_in_kernel - 1) // 2
180
+ self.conv_in = nn.Conv2d(
181
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
182
+ )
183
+
184
+ # time
185
+ time_embed_dim = block_out_channels[0] * 4
186
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
187
+ timestep_input_dim = block_out_channels[0]
188
+ self.time_embedding = TimestepEmbedding(
189
+ timestep_input_dim,
190
+ time_embed_dim,
191
+ act_fn=act_fn,
192
+ )
193
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
194
+ encoder_hid_dim_type = "text_proj"
195
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
196
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
197
+
198
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
199
+ raise ValueError(
200
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
201
+ )
202
+
203
+ if encoder_hid_dim_type == "text_proj":
204
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
205
+ elif encoder_hid_dim_type == "text_image_proj":
206
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
207
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
208
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
209
+ self.encoder_hid_proj = TextImageProjection(
210
+ text_embed_dim=encoder_hid_dim,
211
+ image_embed_dim=cross_attention_dim,
212
+ cross_attention_dim=cross_attention_dim,
213
+ )
214
+
215
+ elif encoder_hid_dim_type is not None:
216
+ raise ValueError(
217
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
218
+ )
219
+ else:
220
+ self.encoder_hid_proj = None
221
+ # class embedding
222
+ if class_embed_type is None and num_class_embeds is not None:
223
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
224
+ elif class_embed_type == "timestep":
225
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
226
+ elif class_embed_type == "identity":
227
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
228
+ elif class_embed_type == "projection":
229
+ if projection_class_embeddings_input_dim is None:
230
+ raise ValueError(
231
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
232
+ )
233
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
234
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
235
+ # 2. it projects from an arbitrary input dimension.
236
+ #
237
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
238
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
239
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
240
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
241
+ else:
242
+ self.class_embedding = None
243
+
244
+ if addition_embed_type == "text_nd":
245
+ if encoder_hid_dim is not None:
246
+ text_time_embedding_from_dim = encoder_hid_dim
247
+ else:
248
+ text_time_embedding_from_dim = cross_attention_dim
249
+
250
+ self.add_embedding = TextTimeEmbedding(
251
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
252
+ )
253
+ elif addition_embed_type == "text_image":
254
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
255
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
256
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
257
+ self.add_embedding = TextImageTimeEmbedding(
258
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
259
+ )
260
+ elif addition_embed_type == "text_time":
261
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
262
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
263
+
264
+ elif addition_embed_type is not None:
265
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
266
+
267
+ # control net conditioning embedding
268
+ ############################################################### ControlNetConditioningEmbedding #############################################################################
269
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
270
+ conditioning_embedding_channels=block_out_channels[0],
271
+ block_out_channels=conditioning_embedding_out_channels,
272
+ conditioning_channels=conditioning_channels,
273
+ )
274
+ self.down_blocks = nn.ModuleList([])
275
+ self.controlnet_down_blocks = nn.ModuleList([])
276
+
277
+ if isinstance(only_cross_attention, bool):
278
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
279
+
280
+ if isinstance(attention_head_dim, int):
281
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
282
+
283
+ if isinstance(num_attention_heads, int):
284
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
285
+
286
+ # down
287
+ output_channel = block_out_channels[0]
288
+
289
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
290
+ controlnet_block = zero_module(controlnet_block)
291
+ self.controlnet_down_blocks.append(controlnet_block)
292
+
293
+ for i, down_block_type in enumerate(down_block_types):
294
+ input_channel = output_channel
295
+ output_channel = block_out_channels[i]
296
+ is_final_block = i == len(block_out_channels) - 1
297
+
298
+ down_block = get_down_block(
299
+ down_block_type,
300
+ num_layers=layers_per_block,
301
+ transformer_layers_per_block=transformer_layers_per_block[i],
302
+ in_channels=input_channel,
303
+ out_channels=output_channel,
304
+ temb_channels=time_embed_dim,
305
+ add_downsample=not is_final_block,
306
+ resnet_eps=norm_eps,
307
+ resnet_act_fn=act_fn,
308
+ resnet_groups=norm_num_groups,
309
+ cross_attention_dim=cross_attention_dim,
310
+ num_attention_heads=num_attention_heads[i],
311
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
312
+ downsample_padding=downsample_padding,
313
+ use_linear_projection=use_linear_projection,
314
+ only_cross_attention=only_cross_attention[i],
315
+ upcast_attention=upcast_attention,
316
+ resnet_time_scale_shift=resnet_time_scale_shift,
317
+ weight=weight,
318
+ )
319
+
320
+ self.down_blocks.append(down_block)
321
+
322
+ for _ in range(layers_per_block):
323
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
324
+ controlnet_block = zero_module(controlnet_block)
325
+ self.controlnet_down_blocks.append(controlnet_block)
326
+
327
+ if not is_final_block:
328
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
329
+ controlnet_block = zero_module(controlnet_block)
330
+ self.controlnet_down_blocks.append(controlnet_block)
331
+
332
+ # mid
333
+ mid_block_channel = block_out_channels[-1]
334
+
335
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
336
+ controlnet_block = zero_module(controlnet_block)
337
+ self.controlnet_mid_block = controlnet_block
338
+
339
+ self.mid_block = UNetMidBlock2DCrossAttn(
340
+ transformer_layers_per_block=transformer_layers_per_block[-1],
341
+ in_channels=mid_block_channel,
342
+ temb_channels=time_embed_dim,
343
+ resnet_eps=norm_eps,
344
+ resnet_act_fn=act_fn,
345
+ output_scale_factor=mid_block_scale_factor,
346
+ resnet_time_scale_shift=resnet_time_scale_shift,
347
+ cross_attention_dim=cross_attention_dim,
348
+ num_attention_heads=num_attention_heads[-1],
349
+ resnet_groups=norm_num_groups,
350
+ use_linear_projection=use_linear_projection,
351
+ upcast_attention=upcast_attention,
352
+ )
353
+
354
+ @classmethod
355
+ def from_unet(
356
+ cls,
357
+ unet: UNet2DConditionModel,
358
+ controlnet_conditioning_channel_order: str = "rgb",
359
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
360
+ load_weights_from_unet: bool = True,
361
+ weight: Optional[torch.Tensor] = None,
362
+ ):
363
+ r"""
364
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
365
+
366
+ Parameters:
367
+ unet (`UNet2DConditionModel`):
368
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
369
+ where applicable.
370
+ """
371
+ transformer_layers_per_block = (
372
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
373
+ )
374
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
375
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
376
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
377
+ addition_time_embed_dim = (
378
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
379
+ )
380
+ controlnet = cls(
381
+ encoder_hid_dim=encoder_hid_dim,
382
+ encoder_hid_dim_type=encoder_hid_dim_type,
383
+ addition_embed_type=addition_embed_type,
384
+ addition_time_embed_dim=addition_time_embed_dim,
385
+ transformer_layers_per_block=transformer_layers_per_block,
386
+ in_channels=unet.config.in_channels,
387
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
388
+ freq_shift=unet.config.freq_shift,
389
+ down_block_types=unet.config.down_block_types,
390
+ only_cross_attention=unet.config.only_cross_attention,
391
+ block_out_channels=unet.config.block_out_channels,
392
+ layers_per_block=unet.config.layers_per_block,
393
+ downsample_padding=unet.config.downsample_padding,
394
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
395
+ act_fn=unet.config.act_fn,
396
+ norm_num_groups=unet.config.norm_num_groups,
397
+ norm_eps=unet.config.norm_eps,
398
+ cross_attention_dim=unet.config.cross_attention_dim,
399
+ attention_head_dim=unet.config.attention_head_dim,
400
+ num_attention_heads=unet.config.num_attention_heads,
401
+ use_linear_projection=unet.config.use_linear_projection,
402
+ class_embed_type=unet.config.class_embed_type,
403
+ num_class_embeds=unet.config.num_class_embeds,
404
+ upcast_attention=unet.config.upcast_attention,
405
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
406
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
407
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
408
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
409
+ )
410
+
411
+ if load_weights_from_unet:
412
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
413
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
414
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
415
+
416
+ if controlnet.class_embedding:
417
+ controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
418
+
419
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
420
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
421
+
422
+ return controlnet
423
+
424
+ @property
425
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
426
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
427
+ r"""
428
+ Returns:
429
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
430
+ indexed by their weight names.
431
+ """
432
+ # set recursively
433
+ processors = {}
434
+
435
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
436
+ if hasattr(module, "set_processor"):
437
+ processors[f"{name}.processor"] = module.processor
438
+
439
+ for sub_name, child in module.named_children():
440
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
441
+
442
+ return processors
443
+
444
+ for name, module in self.named_children():
445
+ fn_recursive_add_processors(name, module, processors)
446
+
447
+ return processors
448
+
449
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
450
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
451
+ r"""
452
+ Sets the attention processor to use to compute attention.
453
+
454
+ Parameters:
455
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
456
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
457
+ for **all** `Attention` layers.
458
+
459
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
460
+ processor. This is strongly recommended when setting trainable attention processors.
461
+
462
+ """
463
+ count = len(self.attn_processors.keys())
464
+
465
+ if isinstance(processor, dict) and len(processor) != count:
466
+ raise ValueError(
467
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
468
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
469
+ )
470
+
471
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
472
+ if hasattr(module, "set_processor"):
473
+ if not isinstance(processor, dict):
474
+ module.set_processor(processor)
475
+ else:
476
+ module.set_processor(processor.pop(f"{name}.processor"))
477
+
478
+ for sub_name, child in module.named_children():
479
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
480
+ for name, module in self.named_children():
481
+ fn_recursive_attn_processor(name, module, processor)
482
+
483
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
484
+ def set_default_attn_processor(self):
485
+ """
486
+ Disables custom attention processors and sets the default attention implementation.
487
+ """
488
+ self.set_attn_processor(AttnProcessor())
489
+
490
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
491
+ def set_attention_slice(self, slice_size):
492
+ r"""
493
+ Enable sliced attention computation.
494
+
495
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
496
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
497
+
498
+ Args:
499
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
500
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
501
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
502
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
503
+ must be a multiple of `slice_size`.
504
+ """
505
+ sliceable_head_dims = []
506
+
507
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
508
+ if hasattr(module, "set_attention_slice"):
509
+ sliceable_head_dims.append(module.sliceable_head_dim)
510
+
511
+ for child in module.children():
512
+ fn_recursive_retrieve_sliceable_dims(child)
513
+
514
+ # retrieve number of attention layers
515
+ for module in self.children():
516
+ fn_recursive_retrieve_sliceable_dims(module)
517
+
518
+ num_sliceable_layers = len(sliceable_head_dims)
519
+
520
+ if slice_size == "auto":
521
+ # half the attention head size is usually a good trade-off between
522
+ # speed and memory
523
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
524
+ elif slice_size == "max":
525
+ # make smallest slice possible
526
+ slice_size = num_sliceable_layers * [1]
527
+
528
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
529
+
530
+ if len(slice_size) != len(sliceable_head_dims):
531
+ raise ValueError(
532
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
533
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
534
+ )
535
+
536
+ for i in range(len(slice_size)):
537
+ size = slice_size[i]
538
+ dim = sliceable_head_dims[i]
539
+ if size is not None and size > dim:
540
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
541
+
542
+ # Recursively walk through all the children.
543
+ # Any children which exposes the set_attention_slice method
544
+ # gets the message
545
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
546
+ if hasattr(module, "set_attention_slice"):
547
+ module.set_attention_slice(slice_size.pop())
548
+
549
+ for child in module.children():
550
+ fn_recursive_set_attention_slice(child, slice_size)
551
+
552
+ reversed_slice_size = list(reversed(slice_size))
553
+ for module in self.children():
554
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
555
+
556
+ def _set_gradient_checkpointing(self, module, value=False):
557
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
558
+ module.gradient_checkpointing = value
559
+
560
+ def forward(
561
+ self,
562
+ sample: torch.FloatTensor,
563
+ timestep: Union[torch.Tensor, float, int],
564
+ encoder_hidden_states: torch.Tensor,
565
+ controlnet_cond: torch.FloatTensor,
566
+ conditioning_scale: float = 1.0,
567
+ class_labels: Optional[torch.Tensor] = None,
568
+ timestep_cond: Optional[torch.Tensor] = None,
569
+ attention_mask: Optional[torch.Tensor] = None,
570
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
571
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
572
+ guess_mode: bool = False,
573
+ return_dict: bool = True,
574
+ weight: Optional[torch.Tensor] = None,
575
+ ) -> Union[ControlNetOutput, Tuple]:
576
+ """
577
+ The [`ControlNetModel`] forward method.
578
+
579
+ Args:
580
+ sample (`torch.FloatTensor`):
581
+ The noisy input tensor.
582
+ timestep (`Union[torch.Tensor, float, int]`):
583
+ The number of timesteps to denoise an input.
584
+ encoder_hidden_states (`torch.Tensor`):
585
+ The encoder hidden states.
586
+ controlnet_cond (`torch.FloatTensor`):
587
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
588
+ conditioning_scale (`float`, defaults to `1.0`):
589
+ The scale factor for ControlNet outputs.
590
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
591
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
592
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
593
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
594
+ added_cond_kwargs (`dict`):
595
+ Additional conditions for the Stable Diffusion XL UNet.
596
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
597
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
598
+ guess_mode (`bool`, defaults to `False`):
599
+ In this mode, the ControlNet encoder tries its best to recognize the content of the input image even if
600
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
601
+ return_dict (`bool`, defaults to `True`):
602
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
603
+
604
+ Returns:
605
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
606
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
607
+ returned where the first element is the sample tensor.
608
+ """
609
+ # check channel order
610
+ channel_order = self.config.controlnet_conditioning_channel_order
611
+
612
+ if channel_order == "rgb":
613
+ # in rgb order by default
614
+ ...
615
+ elif channel_order == "bgr":
616
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
617
+ else:
618
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
619
+
620
+ # prepare attention_mask
621
+
622
+ if attention_mask is not None:
623
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
624
+ attention_mask = attention_mask.unsqueeze(1)
625
+
626
+ # 1. time
627
+ timesteps = timestep
628
+ if not torch.is_tensor(timesteps):
629
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
630
+ # This would be a good case for the `match` statement (Python 3.10+)
631
+ is_mps = sample.device.type == "mps"
632
+ if isinstance(timestep, float):
633
+ dtype = torch.float32 if is_mps else torch.float64
634
+ else:
635
+ dtype = torch.int32 if is_mps else torch.int64
636
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
637
+ elif len(timesteps.shape) == 0:
638
+ timesteps = timesteps[None].to(sample.device)
639
+
640
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
641
+ timesteps = timesteps.expand(sample.shape[0])
642
+
643
+ t_emb = self.time_proj(timesteps)
644
+
645
+ # timesteps does not contain any weights and will always return f32 tensors
646
+ # but time_embedding might actually be running in fp16. so we need to cast here.
647
+ # there might be better ways to encapsulate this.
648
+ t_emb = t_emb.to(dtype=sample.dtype)
649
+
650
+ emb = self.time_embedding(t_emb, timestep_cond)
651
+ aug_emb = None
652
+
653
+ if self.class_embedding is not None:
654
+ if class_labels is None:
655
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
656
+
657
+ if self.config.class_embed_type == "timestep":
658
+ class_labels = self.time_proj(class_labels)
659
+
660
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
661
+ emb = emb + class_emb
662
+
663
+ if "addition_embed_type" in self.config:
664
+ if self.config.addition_embed_type == "text":
665
+ aug_emb = self.add_embedding(encoder_hidden_states)
666
+
667
+ elif self.config.addition_embed_type == "text_time":
668
+ if "text_embeds" not in added_cond_kwargs:
669
+ raise ValueError(
670
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
671
+ )
672
+ text_embeds = added_cond_kwargs.get("text_embeds")
673
+ if "time_ids" not in added_cond_kwargs:
674
+ raise ValueError(
675
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
676
+ )
677
+ time_ids = added_cond_kwargs.get("time_ids")
678
+ time_embeds = self.add_time_proj(time_ids.flatten())
679
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
680
+
681
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
682
+ add_embeds = add_embeds.to(emb.dtype)
683
+ aug_emb = self.add_embedding(add_embeds)
684
+
685
+ emb = emb + aug_emb if aug_emb is not None else emb
686
+
687
+ # 2. pre-process
688
+ sample = self.conv_in(sample)
689
+
690
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
691
+ sample = sample + controlnet_cond
692
+
693
+ # 3. down
694
+ down_block_res_samples = (sample,)
695
+ for downsample_block in self.down_blocks:
696
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
697
+ print('controlnet encoder_hidden_states_nd', encoder_hidden_states.shape)
698
+ sample, res_samples = downsample_block(
699
+ hidden_states=sample,
700
+ temb=emb,
701
+ encoder_hidden_states=encoder_hidden_states,
702
+ attention_mask=attention_mask,
703
+ cross_attention_kwargs=cross_attention_kwargs,
704
+ weight=weight,
705
+ )
706
+ else:
707
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
708
+
709
+ down_block_res_samples += res_samples
710
+
711
+
712
+ # 4. mid
713
+ if self.mid_block is not None:
714
+ sample = self.mid_block(
715
+ sample,
716
+ emb,
717
+ encoder_hidden_states=encoder_hidden_states,
718
+ attention_mask=attention_mask,
719
+ cross_attention_kwargs=cross_attention_kwargs,
720
+ )
721
+
722
+ # 5. Control net blocks
723
+
724
+ controlnet_down_block_res_samples = ()
725
+
726
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
727
+ down_block_res_sample = controlnet_block(down_block_res_sample)
728
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
729
+
730
+ down_block_res_samples = controlnet_down_block_res_samples
731
+
732
+ mid_block_res_sample = self.controlnet_mid_block(sample)
733
+
734
+ # 6. scaling
735
+ if guess_mode and not self.config.global_pool_conditions:
736
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
737
+
738
+ scales = scales * conditioning_scale
739
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
740
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
741
+ else:
742
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
743
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
744
+
745
+ if self.config.global_pool_conditions:
746
+ down_block_res_samples = [
747
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
748
+ ]
749
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
750
+
751
+ if not return_dict:
752
+ return (down_block_res_samples, mid_block_res_sample)
753
+
754
+ return ControlNetOutput(
755
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
756
+ )
757
+
758
+
759
+ def zero_module(module):
760
+ for p in module.parameters():
761
+ nn.init.zeros_(p)
762
+ return module
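A minimal usage sketch for the model above, assuming a Stable Diffusion 1.x style `UNet2DConditionModel` instance named `unet` is already loaded (so `cross_attention_dim` is 768) and that the customized down blocks accept `weight=None`; tensor shapes are illustrative only:

import torch

controlnet = ControlNetModel.from_unet(unet)

sample = torch.randn(1, 4, 64, 64)                 # noisy latents
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 768)    # text-encoder hidden states
controlnet_cond = torch.randn(1, 3, 512, 512)      # image-space condition

out = controlnet(
    sample,
    timestep,
    encoder_hidden_states=encoder_hidden_states,
    controlnet_cond=controlnet_cond,
    conditioning_scale=1.0,
)
# out.down_block_res_samples and out.mid_block_res_sample are then added to the matching
# residuals inside the main UNet's forward pass.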
Tiger Model/diffusiers-Tiger/models/dual_transformer_2d.py ADDED
@@ -0,0 +1,151 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional
15
+
16
+ from torch import nn
17
+
18
+ from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
19
+
20
+
21
+ class DualTransformer2DModel(nn.Module):
22
+ """
23
+ Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
24
+
25
+ Parameters:
26
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
27
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
28
+ in_channels (`int`, *optional*):
29
+ Pass if the input is continuous. The number of channels in the input and output.
30
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
31
+ dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
32
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
33
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
34
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
35
+ `ImagePositionalEmbeddings`.
36
+ num_vector_embeds (`int`, *optional*):
37
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
38
+ Includes the class for the masked latent pixel.
39
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
40
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
41
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
42
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
43
+ up to but not more than steps than `num_embeds_ada_norm`.
44
+ attention_bias (`bool`, *optional*):
45
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ num_attention_heads: int = 16,
51
+ attention_head_dim: int = 88,
52
+ in_channels: Optional[int] = None,
53
+ num_layers: int = 1,
54
+ dropout: float = 0.0,
55
+ norm_num_groups: int = 32,
56
+ cross_attention_dim: Optional[int] = None,
57
+ attention_bias: bool = False,
58
+ sample_size: Optional[int] = None,
59
+ num_vector_embeds: Optional[int] = None,
60
+ activation_fn: str = "geglu",
61
+ num_embeds_ada_norm: Optional[int] = None,
62
+ ):
63
+ super().__init__()
64
+ self.transformers = nn.ModuleList(
65
+ [
66
+ Transformer2DModel(
67
+ num_attention_heads=num_attention_heads,
68
+ attention_head_dim=attention_head_dim,
69
+ in_channels=in_channels,
70
+ num_layers=num_layers,
71
+ dropout=dropout,
72
+ norm_num_groups=norm_num_groups,
73
+ cross_attention_dim=cross_attention_dim,
74
+ attention_bias=attention_bias,
75
+ sample_size=sample_size,
76
+ num_vector_embeds=num_vector_embeds,
77
+ activation_fn=activation_fn,
78
+ num_embeds_ada_norm=num_embeds_ada_norm,
79
+ )
80
+ for _ in range(2)
81
+ ]
82
+ )
83
+
84
+ # Variables that can be set by a pipeline:
85
+
86
+ # The ratio of transformer1 to transformer2's output states to be combined during inference
87
+ self.mix_ratio = 0.5
88
+
89
+ # The shape of `encoder_hidden_states` is expected to be
90
+ # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
91
+ self.condition_lengths = [77, 257]
92
+
93
+ # Which transformer to use to encode which condition.
94
+ # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
95
+ self.transformer_index_for_condition = [1, 0]
96
+
97
+ def forward(
98
+ self,
99
+ hidden_states,
100
+ encoder_hidden_states,
101
+ timestep=None,
102
+ attention_mask=None,
103
+ cross_attention_kwargs=None,
104
+ return_dict: bool = True,
105
+ ):
106
+ """
107
+ Args:
108
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
109
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
110
+ hidden_states
111
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
112
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
113
+ self-attention.
114
+ timestep ( `torch.long`, *optional*):
115
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
116
+ attention_mask (`torch.FloatTensor`, *optional*):
117
+ Optional attention mask to be applied in Attention
118
+ return_dict (`bool`, *optional*, defaults to `True`):
119
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
120
+
121
+ Returns:
122
+ [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
123
+ [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
124
+ returning a tuple, the first element is the sample tensor.
125
+ """
126
+ input_states = hidden_states
127
+
128
+ encoded_states = []
129
+ tokens_start = 0
130
+ # attention_mask is not used yet
131
+ for i in range(2):
132
+ # for each of the two transformers, pass the corresponding condition tokens
133
+ condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
134
+ transformer_index = self.transformer_index_for_condition[i]
135
+ encoded_state = self.transformers[transformer_index](
136
+ input_states,
137
+ encoder_hidden_states=condition_state,
138
+ timestep=timestep,
139
+ cross_attention_kwargs=cross_attention_kwargs,
140
+ return_dict=False,
141
+ )[0]
142
+ encoded_states.append(encoded_state - input_states)
143
+ tokens_start += self.condition_lengths[i]
144
+
145
+ output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
146
+ output_states = output_states + input_states
147
+
148
+ if not return_dict:
149
+ return (output_states,)
150
+
151
+ return Transformer2DModelOutput(sample=output_states)
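A small sketch of how the dual transformer consumes a concatenated condition (77 text tokens plus 257 image tokens by default, per `condition_lengths`), assuming the `Transformer2DModel` in this folder keeps the standard diffusers call signature; sizes are illustrative:

import torch

dual = DualTransformer2DModel(
    num_attention_heads=8,
    attention_head_dim=40,     # inner dim 8 * 40 = 320 == in_channels
    in_channels=320,
    cross_attention_dim=768,
)
hidden_states = torch.randn(1, 320, 64, 64)
text_cond = torch.randn(1, 77, 768)
image_cond = torch.randn(1, 257, 768)
# forward() splits the tokens back apart via condition_lengths, runs one transformer per
# condition, and blends the residuals as mix_ratio * out0 + (1 - mix_ratio) * out1
out = dual(hidden_states, torch.cat([text_cond, image_cond], dim=1), return_dict=False)[0]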
Tiger Model/diffusiers-Tiger/models/embeddings.py ADDED
@@ -0,0 +1,602 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from typing import Optional
16
+
17
+ import numpy as np
18
+ import torch
19
+ from torch import nn
20
+
21
+ from .activations import get_activation
22
+
23
+
24
+ def get_timestep_embedding(
25
+ timesteps: torch.Tensor,
26
+ embedding_dim: int,
27
+ flip_sin_to_cos: bool = False,
28
+ downscale_freq_shift: float = 1,
29
+ scale: float = 1,
30
+ max_period: int = 10000,
31
+ ):
32
+ """
33
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
34
+
35
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
36
+ These may be fractional.
37
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
38
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
39
+ """
40
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
41
+
42
+ half_dim = embedding_dim // 2
43
+ exponent = -math.log(max_period) * torch.arange(
44
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
45
+ )
46
+ exponent = exponent / (half_dim - downscale_freq_shift)
47
+
48
+ emb = torch.exp(exponent)
49
+ emb = timesteps[:, None].float() * emb[None, :]
50
+
51
+ # scale embeddings
52
+ emb = scale * emb
53
+
54
+ # concat sine and cosine embeddings
55
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
56
+
57
+ # flip sine and cosine embeddings
58
+ if flip_sin_to_cos:
59
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
60
+
61
+ # zero pad
62
+ if embedding_dim % 2 == 1:
63
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
64
+ return emb
65
+
66
+
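A quick shape check for the sinusoidal embedding above; the cosine half comes first here because `flip_sin_to_cos=True`:

import torch

timesteps = torch.tensor([0, 10, 999])
emb = get_timestep_embedding(timesteps, embedding_dim=320, flip_sin_to_cos=True, downscale_freq_shift=0)
# emb.shape == (3, 320)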
67
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
68
+ """
69
+ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
70
+ [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
71
+ """
72
+ grid_h = np.arange(grid_size, dtype=np.float32)
73
+ grid_w = np.arange(grid_size, dtype=np.float32)
74
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
75
+ grid = np.stack(grid, axis=0)
76
+
77
+ grid = grid.reshape([2, 1, grid_size, grid_size])
78
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
79
+ if cls_token and extra_tokens > 0:
80
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
81
+ return pos_embed
82
+
83
+
84
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
85
+ if embed_dim % 2 != 0:
86
+ raise ValueError("embed_dim must be divisible by 2")
87
+
88
+ # use half of dimensions to encode grid_h
89
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
90
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
91
+
92
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
93
+ return emb
94
+
95
+
96
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
97
+ """
98
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
99
+ """
100
+ if embed_dim % 2 != 0:
101
+ raise ValueError("embed_dim must be divisible by 2")
102
+
103
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
104
+ omega /= embed_dim / 2.0
105
+ omega = 1.0 / 10000**omega # (D/2,)
106
+
107
+ pos = pos.reshape(-1) # (M,)
108
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
109
+
110
+ emb_sin = np.sin(out) # (M, D/2)
111
+ emb_cos = np.cos(out) # (M, D/2)
112
+
113
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
114
+ return emb
115
+
116
+
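The fixed 2D sin/cos table built from the two helpers above is what `PatchEmbed` below registers as `pos_embed`; a quick shape check (note the result is a NumPy array):

pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14)   # e.g. 224 // 16 = 14 patches per side
# pos.shape == (14 * 14, 768); rows follow the flattened (h, w) grid order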
117
+ class PatchEmbed(nn.Module):
118
+ """2D Image to Patch Embedding"""
119
+
120
+ def __init__(
121
+ self,
122
+ height=224,
123
+ width=224,
124
+ patch_size=16,
125
+ in_channels=3,
126
+ embed_dim=768,
127
+ layer_norm=False,
128
+ flatten=True,
129
+ bias=True,
130
+ ):
131
+ super().__init__()
132
+
133
+ num_patches = (height // patch_size) * (width // patch_size)
134
+ self.flatten = flatten
135
+ self.layer_norm = layer_norm
136
+
137
+ self.proj = nn.Conv2d(
138
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
139
+ )
140
+ if layer_norm:
141
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
142
+ else:
143
+ self.norm = None
144
+
145
+ pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
146
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
147
+
148
+ def forward(self, latent):
149
+ latent = self.proj(latent)
150
+ if self.flatten:
151
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
152
+ if self.layer_norm:
153
+ latent = self.norm(latent)
154
+ return latent + self.pos_embed
155
+
156
+
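A short usage sketch for `PatchEmbed` with a ViT-style configuration; sizes are illustrative:

import torch

patchify = PatchEmbed(height=224, width=224, patch_size=16, in_channels=3, embed_dim=768)
tokens = patchify(torch.randn(2, 3, 224, 224))
# tokens.shape == (2, 196, 768): BCHW -> BNC via flatten/transpose, then the fixed pos_embed is added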
157
+ class TimestepEmbedding(nn.Module):
158
+ def __init__(
159
+ self,
160
+ in_channels: int,
161
+ time_embed_dim: int,
162
+ act_fn: str = "silu",
163
+ out_dim: int = None,
164
+ post_act_fn: Optional[str] = None,
165
+ cond_proj_dim=None,
166
+ ):
167
+ super().__init__()
168
+
169
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
170
+
171
+ if cond_proj_dim is not None:
172
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
173
+ else:
174
+ self.cond_proj = None
175
+
176
+ self.act = get_activation(act_fn)
177
+
178
+ if out_dim is not None:
179
+ time_embed_dim_out = out_dim
180
+ else:
181
+ time_embed_dim_out = time_embed_dim
182
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
183
+
184
+ if post_act_fn is None:
185
+ self.post_act = None
186
+ else:
187
+ self.post_act = get_activation(post_act_fn)
188
+
189
+ def forward(self, sample, condition=None):
190
+ if condition is not None:
191
+ sample = sample + self.cond_proj(condition)
192
+ sample = self.linear_1(sample)
193
+
194
+ if self.act is not None:
195
+ sample = self.act(sample)
196
+
197
+ sample = self.linear_2(sample)
198
+
199
+ if self.post_act is not None:
200
+ sample = self.post_act(sample)
201
+ return sample
202
+
203
+
204
+ class Timesteps(nn.Module):
205
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
206
+ super().__init__()
207
+ self.num_channels = num_channels
208
+ self.flip_sin_to_cos = flip_sin_to_cos
209
+ self.downscale_freq_shift = downscale_freq_shift
210
+
211
+ def forward(self, timesteps):
212
+ t_emb = get_timestep_embedding(
213
+ timesteps,
214
+ self.num_channels,
215
+ flip_sin_to_cos=self.flip_sin_to_cos,
216
+ downscale_freq_shift=self.downscale_freq_shift,
217
+ )
218
+ return t_emb
219
+
220
+
221
+ class GaussianFourierProjection(nn.Module):
222
+ """Gaussian Fourier embeddings for noise levels."""
223
+
224
+ def __init__(
225
+ self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
226
+ ):
227
+ super().__init__()
228
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
229
+ self.log = log
230
+ self.flip_sin_to_cos = flip_sin_to_cos
231
+
232
+ if set_W_to_weight:
233
+ # to delete later
234
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
235
+
236
+ self.weight = self.W
237
+
238
+ def forward(self, x):
239
+ if self.log:
240
+ x = torch.log(x)
241
+
242
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
243
+
244
+ if self.flip_sin_to_cos:
245
+ out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
246
+ else:
247
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
248
+ return out
249
+
250
+
251
+ class ImagePositionalEmbeddings(nn.Module):
252
+ """
253
+ Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
254
+ height and width of the latent space.
255
+
256
+ For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
257
+
258
+ For VQ-diffusion:
259
+
260
+ Output vector embeddings are used as input for the transformer.
261
+
262
+ Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
263
+
264
+ Args:
265
+ num_embed (`int`):
266
+ Number of embeddings for the latent pixels embeddings.
267
+ height (`int`):
268
+ Height of the latent image i.e. the number of height embeddings.
269
+ width (`int`):
270
+ Width of the latent image i.e. the number of width embeddings.
271
+ embed_dim (`int`):
272
+ Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
273
+ """
274
+
275
+ def __init__(
276
+ self,
277
+ num_embed: int,
278
+ height: int,
279
+ width: int,
280
+ embed_dim: int,
281
+ ):
282
+ super().__init__()
283
+
284
+ self.height = height
285
+ self.width = width
286
+ self.num_embed = num_embed
287
+ self.embed_dim = embed_dim
288
+
289
+ self.emb = nn.Embedding(self.num_embed, embed_dim)
290
+ self.height_emb = nn.Embedding(self.height, embed_dim)
291
+ self.width_emb = nn.Embedding(self.width, embed_dim)
292
+
293
+ def forward(self, index):
294
+ emb = self.emb(index)
295
+
296
+ height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
297
+
298
+ # 1 x H x D -> 1 x H x 1 x D
299
+ height_emb = height_emb.unsqueeze(2)
300
+
301
+ width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
302
+
303
+ # 1 x W x D -> 1 x 1 x W x D
304
+ width_emb = width_emb.unsqueeze(1)
305
+
306
+ pos_emb = height_emb + width_emb
307
+
308
+ # 1 x H x W x D -> 1 x L x D
309
+ pos_emb = pos_emb.view(1, self.height * self.width, -1)
310
+
311
+ emb = emb + pos_emb[:, : emb.shape[1], :]
312
+
313
+ return emb
314
+
315
+
316
+ class LabelEmbedding(nn.Module):
317
+ """
318
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
319
+
320
+ Args:
321
+ num_classes (`int`): The number of classes.
322
+ hidden_size (`int`): The size of the vector embeddings.
323
+ dropout_prob (`float`): The probability of dropping a label.
324
+ """
325
+
326
+ def __init__(self, num_classes, hidden_size, dropout_prob):
327
+ super().__init__()
328
+ use_cfg_embedding = dropout_prob > 0
329
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
330
+ self.num_classes = num_classes
331
+ self.dropout_prob = dropout_prob
332
+
333
+ def token_drop(self, labels, force_drop_ids=None):
334
+ """
335
+ Drops labels to enable classifier-free guidance.
336
+ """
337
+ if force_drop_ids is None:
338
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
339
+ else:
340
+ drop_ids = torch.tensor(force_drop_ids == 1)
341
+ labels = torch.where(drop_ids, self.num_classes, labels)
342
+ return labels
343
+
344
+ def forward(self, labels: torch.LongTensor, force_drop_ids=None):
345
+ use_dropout = self.dropout_prob > 0
346
+ if (self.training and use_dropout) or (force_drop_ids is not None):
347
+ labels = self.token_drop(labels, force_drop_ids)
348
+ embeddings = self.embedding_table(labels)
349
+ return embeddings
350
+
351
+
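A sketch of the classifier-free-guidance dropout implemented by `LabelEmbedding` above: dropped labels are remapped to the extra "null" index `num_classes`, which owns the last row of the embedding table; numbers are illustrative:

import torch

label_emb = LabelEmbedding(num_classes=1000, hidden_size=768, dropout_prob=0.1)
labels = torch.randint(0, 1000, (4,))
emb = label_emb(labels)    # in training mode, ~10% of labels are replaced by index 1000 before lookup
# emb.shape == (4, 768)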
352
+ class TextImageProjection(nn.Module):
353
+ def __init__(
354
+ self,
355
+ text_embed_dim: int = 1024,
356
+ image_embed_dim: int = 768,
357
+ cross_attention_dim: int = 768,
358
+ num_image_text_embeds: int = 10,
359
+ ):
360
+ super().__init__()
361
+
362
+ self.num_image_text_embeds = num_image_text_embeds
363
+ self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
364
+ self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
365
+
366
+ def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
367
+ batch_size = text_embeds.shape[0]
368
+
369
+ # image
370
+ image_text_embeds = self.image_embeds(image_embeds)
371
+ image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
372
+
373
+ # text
374
+ text_embeds = self.text_proj(text_embeds)
375
+
376
+ return torch.cat([image_text_embeds, text_embeds], dim=1)
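# Editorial sketch (not part of the uploaded file): the projection turns one pooled image embedding
# into `num_image_text_embeds` tokens and prepends them to the projected text sequence; sizes follow the defaults above.
import torch
proj = TextImageProjection()
text_embeds = torch.randn(2, 77, 1024)
image_embeds = torch.randn(2, 768)
print(proj(text_embeds, image_embeds).shape)  # torch.Size([2, 87, 768]) -> 10 image tokens + 77 text tokens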
377
+
378
+
379
+ class ImageProjection(nn.Module):
380
+ def __init__(
381
+ self,
382
+ image_embed_dim: int = 768,
383
+ cross_attention_dim: int = 768,
384
+ num_image_text_embeds: int = 32,
385
+ ):
386
+ super().__init__()
387
+
388
+ self.num_image_text_embeds = num_image_text_embeds
389
+ self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
390
+ self.norm = nn.LayerNorm(cross_attention_dim)
391
+
392
+ def forward(self, image_embeds: torch.FloatTensor):
393
+ batch_size = image_embeds.shape[0]
394
+
395
+ # image
396
+ image_embeds = self.image_embeds(image_embeds)
397
+ image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
398
+ image_embeds = self.norm(image_embeds)
399
+ return image_embeds
400
+
401
+
402
+ class CombinedTimestepLabelEmbeddings(nn.Module):
403
+ def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
404
+ super().__init__()
405
+
406
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
407
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
408
+ self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
409
+
410
+ def forward(self, timestep, class_labels, hidden_dtype=None):
411
+ timesteps_proj = self.time_proj(timestep)
412
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
413
+
414
+ class_labels = self.class_embedder(class_labels) # (N, D)
415
+
416
+ conditioning = timesteps_emb + class_labels # (N, D)
417
+
418
+ return conditioning
419
+
420
+
421
+ class TextTimeEmbedding(nn.Module):
422
+ def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
423
+ super().__init__()
424
+ self.norm1 = nn.LayerNorm(encoder_dim)
425
+ self.pool = AttentionPooling(num_heads, encoder_dim)
426
+ self.proj = nn.Linear(encoder_dim, time_embed_dim)
427
+ self.norm2 = nn.LayerNorm(time_embed_dim)
428
+
429
+ def forward(self, hidden_states):
430
+ hidden_states = self.norm1(hidden_states)
431
+ hidden_states = self.pool(hidden_states)
432
+ hidden_states = self.proj(hidden_states)
433
+ hidden_states = self.norm2(hidden_states)
434
+ return hidden_states
435
+
436
+
437
+ class TextImageTimeEmbedding(nn.Module):
438
+ def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
439
+ super().__init__()
440
+ self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
441
+ self.text_norm = nn.LayerNorm(time_embed_dim)
442
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
443
+
444
+ def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
445
+ # text
446
+ time_text_embeds = self.text_proj(text_embeds)
447
+ time_text_embeds = self.text_norm(time_text_embeds)
448
+
449
+ # image
450
+ time_image_embeds = self.image_proj(image_embeds)
451
+
452
+ return time_image_embeds + time_text_embeds
453
+
454
+
455
+ class ImageTimeEmbedding(nn.Module):
456
+ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
457
+ super().__init__()
458
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
459
+ self.image_norm = nn.LayerNorm(time_embed_dim)
460
+
461
+ def forward(self, image_embeds: torch.FloatTensor):
462
+ # image
463
+ time_image_embeds = self.image_proj(image_embeds)
464
+ time_image_embeds = self.image_norm(time_image_embeds)
465
+ return time_image_embeds
466
+
467
+
468
+ class ImageHintTimeEmbedding(nn.Module):
469
+ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
470
+ super().__init__()
471
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
472
+ self.image_norm = nn.LayerNorm(time_embed_dim)
473
+ self.input_hint_block = nn.Sequential(
474
+ nn.Conv2d(3, 16, 3, padding=1),
475
+ nn.SiLU(),
476
+ nn.Conv2d(16, 16, 3, padding=1),
477
+ nn.SiLU(),
478
+ nn.Conv2d(16, 32, 3, padding=1, stride=2),
479
+ nn.SiLU(),
480
+ nn.Conv2d(32, 32, 3, padding=1),
481
+ nn.SiLU(),
482
+ nn.Conv2d(32, 96, 3, padding=1, stride=2),
483
+ nn.SiLU(),
484
+ nn.Conv2d(96, 96, 3, padding=1),
485
+ nn.SiLU(),
486
+ nn.Conv2d(96, 256, 3, padding=1, stride=2),
487
+ nn.SiLU(),
488
+ nn.Conv2d(256, 4, 3, padding=1),
489
+ )
490
+
491
+ def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
492
+ # image
493
+ time_image_embeds = self.image_proj(image_embeds)
494
+ time_image_embeds = self.image_norm(time_image_embeds)
495
+ hint = self.input_hint_block(hint)
496
+ return time_image_embeds, hint
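# Editorial sketch (not part of the uploaded file): the hint branch downsamples an RGB hint by 8x
# into a 4-channel feature map while the image embedding is projected to a time embedding.
import torch
hint_embed = ImageHintTimeEmbedding()
time_emb, hint = hint_embed(torch.randn(1, 768), torch.randn(1, 3, 512, 512))
print(time_emb.shape, hint.shape)  # torch.Size([1, 1536]) torch.Size([1, 4, 64, 64])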
497
+
498
+
499
+ class AttentionPooling(nn.Module):
500
+ # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
501
+
502
+ def __init__(self, num_heads, embed_dim, dtype=None):
503
+ super().__init__()
504
+ self.dtype = dtype
505
+ self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
506
+ self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
507
+ self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
508
+ self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
509
+ self.num_heads = num_heads
510
+ self.dim_per_head = embed_dim // self.num_heads
511
+
512
+ def forward(self, x):
513
+ bs, length, width = x.size()
514
+
515
+ def shape(x):
516
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
517
+ x = x.view(bs, -1, self.num_heads, self.dim_per_head)
518
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
519
+ x = x.transpose(1, 2)
520
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
521
+ x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
522
+ # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
523
+ x = x.transpose(1, 2)
524
+ return x
525
+
526
+ class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
527
+ x = torch.cat([class_token, x], dim=1) # (bs, length+1, width)
528
+
529
+ # (bs*n_heads, class_token_length, dim_per_head)
530
+ q = shape(self.q_proj(class_token))
531
+ # (bs*n_heads, length+class_token_length, dim_per_head)
532
+ k = shape(self.k_proj(x))
533
+ v = shape(self.v_proj(x))
534
+
535
+ # (bs*n_heads, class_token_length, length+class_token_length):
536
+ scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
537
+ weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
538
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
539
+
540
+ # (bs*n_heads, dim_per_head, class_token_length)
541
+ a = torch.einsum("bts,bcs->bct", weight, v)
542
+
543
+ # (bs, 1, width)
544
+ a = a.reshape(bs, -1, 1).transpose(1, 2)
545
+
546
+ return a[:, 0, :] # cls_token
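# Editorial sketch (not part of the uploaded file): the pooling collapses a token sequence into a
# single vector by attending from a mean-derived "class token"; sizes are arbitrary.
import torch
pool = AttentionPooling(num_heads=8, embed_dim=64)
tokens = torch.randn(2, 77, 64)
print(pool(tokens).shape)  # torch.Size([2, 64])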
547
+
548
+
549
+ class FourierEmbedder(nn.Module):
550
+ def __init__(self, num_freqs=64, temperature=100):
551
+ super().__init__()
552
+
553
+ self.num_freqs = num_freqs
554
+ self.temperature = temperature
555
+
556
+ freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
557
+ freq_bands = freq_bands[None, None, None]
558
+ self.register_buffer("freq_bands", freq_bands, persistent=False)
559
+
560
+ def __call__(self, x):
561
+ x = self.freq_bands * x.unsqueeze(-1)
562
+ return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
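# Editorial shape check (not part of the uploaded file): for (batch, num_boxes, 4) xyxy inputs the
# embedder returns (batch, num_boxes, num_freqs * 2 * 4), matching PositionNet.position_dim below.
import torch
fe = FourierEmbedder(num_freqs=8)
boxes = torch.rand(2, 5, 4)
print(fe(boxes).shape)  # torch.Size([2, 5, 64])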
563
+
564
+
565
+ class PositionNet(nn.Module):
566
+ def __init__(self, positive_len, out_dim, fourier_freqs=8):
567
+ super().__init__()
568
+ self.positive_len = positive_len
569
+ self.out_dim = out_dim
570
+
571
+ self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
572
+ self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
573
+
574
+ if isinstance(out_dim, tuple):
575
+ out_dim = out_dim[0]
576
+ self.linears = nn.Sequential(
577
+ nn.Linear(self.positive_len + self.position_dim, 512),
578
+ nn.SiLU(),
579
+ nn.Linear(512, 512),
580
+ nn.SiLU(),
581
+ nn.Linear(512, out_dim),
582
+ )
583
+
584
+ self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
585
+ self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
586
+
587
+ def forward(self, boxes, masks, positive_embeddings):
588
+ masks = masks.unsqueeze(-1)
589
+
590
+ # embed the box positions (the input may include padding as a placeholder)
591
+ xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 -> B*N*C
592
+
593
+ # learnable null embedding
594
+ positive_null = self.null_positive_feature.view(1, 1, -1)
595
+ xyxy_null = self.null_position_feature.view(1, 1, -1)
596
+
597
+ # replace padding with learnable null embedding
598
+ positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
599
+ xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
600
+
601
+ objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
602
+ return objs
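A brief usage sketch for the PositionNet module above (editorial illustration, not part of the uploaded file; the sizes are arbitrary and the classes are assumed importable from this module):

import torch

net = PositionNet(positive_len=768, out_dim=768)
boxes = torch.rand(2, 30, 4)             # normalized xyxy boxes, padded to 30 slots
masks = torch.zeros(2, 30)
masks[:, :3] = 1                         # only the first three boxes are real
phrase_embeds = torch.rand(2, 30, 768)   # per-box conditioning features
objs = net(boxes, masks, phrase_embeds)  # -> (2, 30, 768); padded slots fall back to the learned null embeddings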
Tiger Model/diffusiers-Tiger/models/lora.py ADDED
@@ -0,0 +1,117 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional
16
+
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+
21
+ class LoRALinearLayer(nn.Module):
22
+ def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
23
+ super().__init__()
24
+
25
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
26
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
27
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
28
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
29
+ self.network_alpha = network_alpha
30
+ self.rank = rank
31
+
32
+ nn.init.normal_(self.down.weight, std=1 / rank)
33
+ nn.init.zeros_(self.up.weight)
34
+
35
+ def forward(self, hidden_states):
36
+ orig_dtype = hidden_states.dtype
37
+ dtype = self.down.weight.dtype
38
+
39
+ down_hidden_states = self.down(hidden_states.to(dtype))
40
+ up_hidden_states = self.up(down_hidden_states)
41
+
42
+ if self.network_alpha is not None:
43
+ up_hidden_states *= self.network_alpha / self.rank
44
+
45
+ return up_hidden_states.to(orig_dtype)
46
+
47
+
48
+ class LoRAConv2dLayer(nn.Module):
49
+ def __init__(
50
+ self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None
51
+ ):
52
+ super().__init__()
53
+
54
+ self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
55
+ # according to the official kohya_ss trainer, kernel_size is always fixed for the up layer
56
+ # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
57
+ self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)
58
+
59
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
60
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
61
+ self.network_alpha = network_alpha
62
+ self.rank = rank
63
+
64
+ nn.init.normal_(self.down.weight, std=1 / rank)
65
+ nn.init.zeros_(self.up.weight)
66
+
67
+ def forward(self, hidden_states):
68
+ orig_dtype = hidden_states.dtype
69
+ dtype = self.down.weight.dtype
70
+
71
+ down_hidden_states = self.down(hidden_states.to(dtype))
72
+ up_hidden_states = self.up(down_hidden_states)
73
+
74
+ if self.network_alpha is not None:
75
+ up_hidden_states *= self.network_alpha / self.rank
76
+
77
+ return up_hidden_states.to(orig_dtype)
78
+
79
+
80
+ class LoRACompatibleConv(nn.Conv2d):
81
+ """
82
+ A convolutional layer that can be used with LoRA.
83
+ """
84
+
85
+ def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
86
+ super().__init__(*args, **kwargs)
87
+ self.lora_layer = lora_layer
88
+
89
+ def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
90
+ self.lora_layer = lora_layer
91
+
92
+ def forward(self, x):
93
+ if self.lora_layer is None:
94
+ # make sure to use the functional conv2d call, as otherwise torch.compile's graph will break
95
+ # see: https://github.com/huggingface/diffusers/pull/4315
96
+ return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
97
+ else:
98
+ return super().forward(x) + self.lora_layer(x)
99
+
100
+
101
+ class LoRACompatibleLinear(nn.Linear):
102
+ """
103
+ A Linear layer that can be used with LoRA.
104
+ """
105
+
106
+ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
107
+ super().__init__(*args, **kwargs)
108
+ self.lora_layer = lora_layer
109
+
110
+ def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
111
+ self.lora_layer = lora_layer
112
+
113
+ def forward(self, x):
114
+ if self.lora_layer is None:
115
+ return super().forward(x)
116
+ else:
117
+ return super().forward(x) + self.lora_layer(x)
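A minimal sketch of how these layers compose (editorial illustration, not part of the uploaded file; the feature sizes are arbitrary):

import torch

base = LoRACompatibleLinear(320, 320)    # behaves exactly like nn.Linear until a LoRA layer is attached
lora = LoRALinearLayer(320, 320, rank=4, network_alpha=4)
base.set_lora_layer(lora)

x = torch.randn(1, 77, 320)
y = base(x)                              # base projection + low-rank correction
# Because `up` is zero-initialized, y matches the plain Linear output right after attachment;
# training then only needs to move the small `down`/`up` matrices.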
Tiger Model/diffusiers-Tiger/models/modeling_utils.py ADDED
@@ -0,0 +1,997 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import inspect
18
+ import itertools
19
+ import os
20
+ import re
21
+ from functools import partial
22
+ from typing import Any, Callable, List, Optional, Tuple, Union
23
+
24
+ import safetensors
25
+ import torch
26
+ from huggingface_hub import create_repo
27
+ from torch import Tensor, device, nn
28
+
29
+ from .. import __version__
30
+ from ..utils import (
31
+ CONFIG_NAME,
32
+ DIFFUSERS_CACHE,
33
+ FLAX_WEIGHTS_NAME,
34
+ HF_HUB_OFFLINE,
35
+ SAFETENSORS_WEIGHTS_NAME,
36
+ WEIGHTS_NAME,
37
+ _add_variant,
38
+ _get_model_file,
39
+ deprecate,
40
+ is_accelerate_available,
41
+ is_torch_version,
42
+ logging,
43
+ )
44
+ from ..utils.hub_utils import PushToHubMixin
45
+
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+
50
+ if is_torch_version(">=", "1.9.0"):
51
+ _LOW_CPU_MEM_USAGE_DEFAULT = True
52
+ else:
53
+ _LOW_CPU_MEM_USAGE_DEFAULT = False
54
+
55
+
56
+ if is_accelerate_available():
57
+ import accelerate
58
+ from accelerate.utils import set_module_tensor_to_device
59
+ from accelerate.utils.versions import is_torch_version
60
+
61
+
62
+ def get_parameter_device(parameter: torch.nn.Module):
63
+ try:
64
+ parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
65
+ return next(parameters_and_buffers).device
66
+ except StopIteration:
67
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
68
+
69
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
70
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
71
+ return tuples
72
+
73
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
74
+ first_tuple = next(gen)
75
+ return first_tuple[1].device
76
+
77
+
78
+ def get_parameter_dtype(parameter: torch.nn.Module):
79
+ try:
80
+ params = tuple(parameter.parameters())
81
+ if len(params) > 0:
82
+ return params[0].dtype
83
+
84
+ buffers = tuple(parameter.buffers())
85
+ if len(buffers) > 0:
86
+ return buffers[0].dtype
87
+
88
+ except StopIteration:
89
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
90
+
91
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
92
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
93
+ return tuples
94
+
95
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
96
+ first_tuple = next(gen)
97
+ return first_tuple[1].dtype
98
+
99
+
100
+ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
101
+ """
102
+ Reads a checkpoint file, returning properly formatted errors if they arise.
103
+ """
104
+ try:
105
+ if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
106
+ return torch.load(checkpoint_file, map_location="cpu")
107
+ else:
108
+ return safetensors.torch.load_file(checkpoint_file, device="cpu")
109
+ except Exception as e:
110
+ try:
111
+ with open(checkpoint_file) as f:
112
+ if f.read().startswith("version"):
113
+ raise OSError(
114
+ "You seem to have cloned a repository without having git-lfs installed. Please install "
115
+ "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
116
+ "you cloned."
117
+ )
118
+ else:
119
+ raise ValueError(
120
+ f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
121
+ "model. Make sure you have saved the model properly."
122
+ ) from e
123
+ except (UnicodeDecodeError, ValueError):
124
+ raise OSError(
125
+ f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
126
+ f"at '{checkpoint_file}'. "
127
+ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
128
+ )
129
+
130
+
131
+ def _load_state_dict_into_model(model_to_load, state_dict):
132
+ # Convert old format to new format if needed from a PyTorch state_dict
133
+ # copy state_dict so _load_from_state_dict can modify it
134
+ state_dict = state_dict.copy()
135
+ error_msgs = []
136
+
137
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
138
+ # so we need to apply the function recursively.
139
+ def load(module: torch.nn.Module, prefix=""):
140
+ args = (state_dict, prefix, {}, True, [], [], error_msgs)
141
+ module._load_from_state_dict(*args)
142
+
143
+ for name, child in module._modules.items():
144
+ if child is not None:
145
+ load(child, prefix + name + ".")
146
+
147
+ load(model_to_load)
148
+
149
+ return error_msgs
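# Editorial sketch (not part of the uploaded file): the helper above loads whatever keys match,
# module by module, and collects size-mismatch messages instead of raising immediately.
import torch
from torch import nn
tiny = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
bad_sd = {"0.weight": torch.zeros(4, 4), "0.bias": torch.zeros(4), "1.weight": torch.zeros(3, 4)}  # wrong shape on purpose
errors = _load_state_dict_into_model(tiny, bad_sd)
print(errors)  # e.g. ["size mismatch for 1.weight: ..."]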
150
+
151
+
152
+ class ModelMixin(torch.nn.Module, PushToHubMixin):
153
+ r"""
154
+ Base class for all models.
155
+
156
+ [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
157
+ saving models.
158
+
159
+ - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
160
+ """
161
+ config_name = CONFIG_NAME
162
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
163
+ _supports_gradient_checkpointing = False
164
+ _keys_to_ignore_on_load_unexpected = None
165
+
166
+ def __init__(self):
167
+ super().__init__()
168
+
169
+ def __getattr__(self, name: str) -> Any:
170
+ """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
171
+ config attributes directly. See https://github.com/huggingface/diffusers/pull/3129. We need to overwrite
172
+ `__getattr__` here as well so that we don't trigger `torch.nn.Module`'s `__getattr__`:
173
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
174
+ """
175
+
176
+ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
177
+ is_attribute = name in self.__dict__
178
+
179
+ if is_in_config and not is_attribute:
180
+ deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
181
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
182
+ return self._internal_dict[name]
183
+
184
+ # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
185
+ return super().__getattr__(name)
186
+
187
+ @property
188
+ def is_gradient_checkpointing(self) -> bool:
189
+ """
190
+ Whether gradient checkpointing is activated for this model or not.
191
+ """
192
+ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
193
+
194
+ def enable_gradient_checkpointing(self):
195
+ """
196
+ Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
197
+ *checkpoint activations* in other frameworks).
198
+ """
199
+ if not self._supports_gradient_checkpointing:
200
+ raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
201
+ self.apply(partial(self._set_gradient_checkpointing, value=True))
202
+
203
+ def disable_gradient_checkpointing(self):
204
+ """
205
+ Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
206
+ *checkpoint activations* in other frameworks).
207
+ """
208
+ if self._supports_gradient_checkpointing:
209
+ self.apply(partial(self._set_gradient_checkpointing, value=False))
210
+
211
+ def set_use_memory_efficient_attention_xformers(
212
+ self, valid: bool, attention_op: Optional[Callable] = None
213
+ ) -> None:
214
+ # Recursively walk through all the children.
215
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method
216
+ # gets the message
217
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
218
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
219
+ module.set_use_memory_efficient_attention_xformers(valid, attention_op)
220
+
221
+ for child in module.children():
222
+ fn_recursive_set_mem_eff(child)
223
+
224
+ for module in self.children():
225
+ if isinstance(module, torch.nn.Module):
226
+ fn_recursive_set_mem_eff(module)
227
+
228
+ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
229
+ r"""
230
+ Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
231
+
232
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
233
+ inference. Speed up during training is not guaranteed.
234
+
235
+ <Tip warning={true}>
236
+
237
+ ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
238
+ precedence.
239
+
240
+ </Tip>
241
+
242
+ Parameters:
243
+ attention_op (`Callable`, *optional*):
244
+ Override the default `None` operator for use as `op` argument to the
245
+ [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
246
+ function of xFormers.
247
+
248
+ Examples:
249
+
250
+ ```py
251
+ >>> import torch
252
+ >>> from diffusers import UNet2DConditionModel
253
+ >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
254
+
255
+ >>> model = UNet2DConditionModel.from_pretrained(
256
+ ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
257
+ ... )
258
+ >>> model = model.to("cuda")
259
+ >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
260
+ ```
261
+ """
262
+ self.set_use_memory_efficient_attention_xformers(True, attention_op)
263
+
264
+ def disable_xformers_memory_efficient_attention(self):
265
+ r"""
266
+ Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
267
+ """
268
+ self.set_use_memory_efficient_attention_xformers(False)
269
+
270
+ def save_pretrained(
271
+ self,
272
+ save_directory: Union[str, os.PathLike],
273
+ is_main_process: bool = True,
274
+ save_function: Callable = None,
275
+ safe_serialization: bool = True,
276
+ variant: Optional[str] = None,
277
+ push_to_hub: bool = False,
278
+ **kwargs,
279
+ ):
280
+ """
281
+ Save a model and its configuration file to a directory so that it can be reloaded using the
282
+ [`~models.ModelMixin.from_pretrained`] class method.
283
+
284
+ Arguments:
285
+ save_directory (`str` or `os.PathLike`):
286
+ Directory to save a model and its configuration file to. Will be created if it doesn't exist.
287
+ is_main_process (`bool`, *optional*, defaults to `True`):
288
+ Whether the process calling this is the main process or not. Useful during distributed training when you
289
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
290
+ process to avoid race conditions.
291
+ save_function (`Callable`):
292
+ The function to use to save the state dictionary. Useful during distributed training when you need to
293
+ replace `torch.save` with another method. Can be configured with the environment variable
294
+ `DIFFUSERS_SAVE_MODE`.
295
+ safe_serialization (`bool`, *optional*, defaults to `True`):
296
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
297
+ variant (`str`, *optional*):
298
+ If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
299
+ push_to_hub (`bool`, *optional*, defaults to `False`):
300
+ Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
301
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
302
+ namespace).
303
+ kwargs (`Dict[str, Any]`, *optional*):
304
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
305
+ """
306
+ if os.path.isfile(save_directory):
307
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
308
+ return
309
+
310
+ os.makedirs(save_directory, exist_ok=True)
311
+
312
+ if push_to_hub:
313
+ commit_message = kwargs.pop("commit_message", None)
314
+ private = kwargs.pop("private", False)
315
+ create_pr = kwargs.pop("create_pr", False)
316
+ token = kwargs.pop("token", None)
317
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
318
+ repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
319
+
320
+ # Only save the model itself if we are using distributed training
321
+ model_to_save = self
322
+
323
+ # Attach architecture to the config
324
+ # Save the config
325
+ if is_main_process:
326
+ model_to_save.save_config(save_directory)
327
+
328
+ # Save the model
329
+ state_dict = model_to_save.state_dict()
330
+
331
+ weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
332
+ weights_name = _add_variant(weights_name, variant)
333
+
334
+ # Save the model
335
+ if safe_serialization:
336
+ safetensors.torch.save_file(
337
+ state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
338
+ )
339
+ else:
340
+ torch.save(state_dict, os.path.join(save_directory, weights_name))
341
+
342
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
343
+
344
+ if push_to_hub:
345
+ self._upload_folder(
346
+ save_directory,
347
+ repo_id,
348
+ token=token,
349
+ commit_message=commit_message,
350
+ create_pr=create_pr,
351
+ )
352
+
353
+ @classmethod
354
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
355
+ r"""
356
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
357
+
358
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
359
+ train the model, set it back in training mode with `model.train()`.
360
+
361
+ Parameters:
362
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
363
+ Can be either:
364
+
365
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
366
+ the Hub.
367
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
368
+ with [`~ModelMixin.save_pretrained`].
369
+
370
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
371
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
372
+ is not used.
373
+ torch_dtype (`str` or `torch.dtype`, *optional*):
374
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
375
+ dtype is automatically derived from the model's weights.
376
+ force_download (`bool`, *optional*, defaults to `False`):
377
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
378
+ cached versions if they exist.
379
+ resume_download (`bool`, *optional*, defaults to `False`):
380
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
381
+ incompletely downloaded files are deleted.
382
+ proxies (`Dict[str, str]`, *optional*):
383
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
384
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
385
+ output_loading_info (`bool`, *optional*, defaults to `False`):
386
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
387
+ local_files_only(`bool`, *optional*, defaults to `False`):
388
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
389
+ won't be downloaded from the Hub.
390
+ use_auth_token (`str` or *bool*, *optional*):
391
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
392
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
393
+ revision (`str`, *optional*, defaults to `"main"`):
394
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
395
+ allowed by Git.
396
+ from_flax (`bool`, *optional*, defaults to `False`):
397
+ Load the model weights from a Flax checkpoint save file.
398
+ subfolder (`str`, *optional*, defaults to `""`):
399
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
400
+ mirror (`str`, *optional*):
401
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
402
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
403
+ information.
404
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
405
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
406
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
407
+ same device.
408
+
409
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
410
+ more information about each option see [designing a device
411
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
412
+ max_memory (`Dict`, *optional*):
413
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
414
+ each GPU and the available CPU RAM if unset.
415
+ offload_folder (`str` or `os.PathLike`, *optional*):
416
+ The path to offload weights if `device_map` contains the value `"disk"`.
417
+ offload_state_dict (`bool`, *optional*):
418
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
419
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
420
+ when there is some disk offload.
421
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
422
+ Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
423
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
424
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
425
+ argument to `True` will raise an error.
426
+ variant (`str`, *optional*):
427
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
428
+ loading `from_flax`.
429
+ use_safetensors (`bool`, *optional*, defaults to `None`):
430
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
431
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
432
+ weights. If set to `False`, `safetensors` weights are not loaded.
433
+
434
+ <Tip>
435
+
436
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
437
+ `huggingface-cli login`. You can also activate the special
438
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
439
+ firewalled environment.
440
+
441
+ </Tip>
442
+
443
+ Example:
444
+
445
+ ```py
446
+ from diffusers import UNet2DConditionModel
447
+
448
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
449
+ ```
450
+
451
+ If you get the error message below, you need to finetune the weights for your downstream task:
452
+
453
+ ```bash
454
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
455
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
456
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
457
+ ```
458
+ """
459
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
460
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
461
+ force_download = kwargs.pop("force_download", False)
462
+ from_flax = kwargs.pop("from_flax", False)
463
+ resume_download = kwargs.pop("resume_download", False)
464
+ proxies = kwargs.pop("proxies", None)
465
+ output_loading_info = kwargs.pop("output_loading_info", False)
466
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
467
+ use_auth_token = kwargs.pop("use_auth_token", None)
468
+ revision = kwargs.pop("revision", None)
469
+ torch_dtype = kwargs.pop("torch_dtype", None)
470
+ subfolder = kwargs.pop("subfolder", None)
471
+ device_map = kwargs.pop("device_map", None)
472
+ max_memory = kwargs.pop("max_memory", None)
473
+ offload_folder = kwargs.pop("offload_folder", None)
474
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
475
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
476
+ variant = kwargs.pop("variant", None)
477
+ use_safetensors = kwargs.pop("use_safetensors", None)
478
+
479
+ allow_pickle = False
480
+ if use_safetensors is None:
481
+ use_safetensors = True
482
+ allow_pickle = True
483
+
484
+ if low_cpu_mem_usage and not is_accelerate_available():
485
+ low_cpu_mem_usage = False
486
+ logger.warning(
487
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
488
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
489
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
490
+ " install accelerate\n```\n."
491
+ )
492
+
493
+ if device_map is not None and not is_accelerate_available():
494
+ raise NotImplementedError(
495
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
496
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
497
+ )
498
+
499
+ # Check if we can handle device_map and dispatching the weights
500
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
501
+ raise NotImplementedError(
502
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
503
+ " `device_map=None`."
504
+ )
505
+
506
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
507
+ raise NotImplementedError(
508
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
509
+ " `low_cpu_mem_usage=False`."
510
+ )
511
+
512
+ if low_cpu_mem_usage is False and device_map is not None:
513
+ raise ValueError(
514
+ f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
515
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
516
+ )
517
+
518
+ # Load config if we don't provide a configuration
519
+ config_path = pretrained_model_name_or_path
520
+
521
+ user_agent = {
522
+ "diffusers": __version__,
523
+ "file_type": "model",
524
+ "framework": "pytorch",
525
+ }
526
+
527
+ # load config
528
+ config, unused_kwargs, commit_hash = cls.load_config(
529
+ config_path,
530
+ cache_dir=cache_dir,
531
+ return_unused_kwargs=True,
532
+ return_commit_hash=True,
533
+ force_download=force_download,
534
+ resume_download=resume_download,
535
+ proxies=proxies,
536
+ local_files_only=local_files_only,
537
+ use_auth_token=use_auth_token,
538
+ revision=revision,
539
+ subfolder=subfolder,
540
+ device_map=device_map,
541
+ max_memory=max_memory,
542
+ offload_folder=offload_folder,
543
+ offload_state_dict=offload_state_dict,
544
+ user_agent=user_agent,
545
+ **kwargs,
546
+ )
547
+
548
+ # load model
549
+ model_file = None
550
+ if from_flax:
551
+ model_file = _get_model_file(
552
+ pretrained_model_name_or_path,
553
+ weights_name=FLAX_WEIGHTS_NAME,
554
+ cache_dir=cache_dir,
555
+ force_download=force_download,
556
+ resume_download=resume_download,
557
+ proxies=proxies,
558
+ local_files_only=local_files_only,
559
+ use_auth_token=use_auth_token,
560
+ revision=revision,
561
+ subfolder=subfolder,
562
+ user_agent=user_agent,
563
+ commit_hash=commit_hash,
564
+ )
565
+ model = cls.from_config(config, **unused_kwargs)
566
+
567
+ # Convert the weights
568
+ from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
569
+
570
+ model = load_flax_checkpoint_in_pytorch_model(model, model_file)
571
+ else:
572
+ if use_safetensors:
573
+ try:
574
+ model_file = _get_model_file(
575
+ pretrained_model_name_or_path,
576
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
577
+ cache_dir=cache_dir,
578
+ force_download=force_download,
579
+ resume_download=resume_download,
580
+ proxies=proxies,
581
+ local_files_only=local_files_only,
582
+ use_auth_token=use_auth_token,
583
+ revision=revision,
584
+ subfolder=subfolder,
585
+ user_agent=user_agent,
586
+ commit_hash=commit_hash,
587
+ )
588
+ except IOError as e:
589
+ if not allow_pickle:
590
+ raise e
591
+ pass
592
+ if model_file is None:
593
+ model_file = _get_model_file(
594
+ pretrained_model_name_or_path,
595
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
596
+ cache_dir=cache_dir,
597
+ force_download=force_download,
598
+ resume_download=resume_download,
599
+ proxies=proxies,
600
+ local_files_only=local_files_only,
601
+ use_auth_token=use_auth_token,
602
+ revision=revision,
603
+ subfolder=subfolder,
604
+ user_agent=user_agent,
605
+ commit_hash=commit_hash,
606
+ )
607
+
608
+ if low_cpu_mem_usage:
609
+ # Instantiate model with empty weights
610
+ with accelerate.init_empty_weights():
611
+ model = cls.from_config(config, **unused_kwargs)
612
+
613
+ # if device_map is None, load the state dict and move the params from meta device to the cpu
614
+ if device_map is None:
615
+ param_device = "cpu"
616
+ state_dict = load_state_dict(model_file, variant=variant)
617
+ model._convert_deprecated_attention_blocks(state_dict)
618
+ # move the params from meta device to cpu
619
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
620
+ if len(missing_keys) > 0:
621
+ raise ValueError(
622
+ f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
623
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
624
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
625
+ " those weights or else make sure your checkpoint file is correct."
626
+ )
627
+ unexpected_keys = []
628
+
629
+ empty_state_dict = model.state_dict()
630
+ for param_name, param in state_dict.items():
631
+ accepts_dtype = "dtype" in set(
632
+ inspect.signature(set_module_tensor_to_device).parameters.keys()
633
+ )
634
+
635
+ if param_name not in empty_state_dict:
636
+ unexpected_keys.append(param_name)
637
+ continue
638
+
639
+ if empty_state_dict[param_name].shape != param.shape:
640
+ raise ValueError(
641
+ f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
642
+ )
643
+
644
+ if accepts_dtype:
645
+ set_module_tensor_to_device(
646
+ model, param_name, param_device, value=param, dtype=torch_dtype
647
+ )
648
+ else:
649
+ set_module_tensor_to_device(model, param_name, param_device, value=param)
650
+
651
+ if cls._keys_to_ignore_on_load_unexpected is not None:
652
+ for pat in cls._keys_to_ignore_on_load_unexpected:
653
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
654
+
655
+ if len(unexpected_keys) > 0:
656
+ logger.warn(
657
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
658
+ )
659
+
660
+ else: # else let accelerate handle loading and dispatching.
661
+ # Load weights and dispatch according to the device_map
662
+ # by default the device_map is None and the weights are loaded on the CPU
663
+ try:
664
+ accelerate.load_checkpoint_and_dispatch(
665
+ model,
666
+ model_file,
667
+ device_map,
668
+ max_memory=max_memory,
669
+ offload_folder=offload_folder,
670
+ offload_state_dict=offload_state_dict,
671
+ dtype=torch_dtype,
672
+ )
673
+ except AttributeError as e:
674
+ # When using accelerate loading, we do not have the ability to load the state
675
+ # dict and rename the weight names manually. Additionally, accelerate skips
676
+ # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
677
+ # (which look like they should be private variables?), so we can't use the standard hooks
678
+ # to rename parameters on load. We need to mimic the original weight names so the correct
679
+ # attributes are available. After we have loaded the weights, we convert the deprecated
680
+ # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
681
+ # the weights so we don't have to do this again.
682
+
683
+ if "'Attention' object has no attribute" in str(e):
684
+ logger.warn(
685
+ f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
686
+ " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
687
+ " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
688
+ " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
689
+ " please also re-upload it or open a PR on the original repository."
690
+ )
691
+ model._temp_convert_self_to_deprecated_attention_blocks()
692
+ accelerate.load_checkpoint_and_dispatch(
693
+ model,
694
+ model_file,
695
+ device_map,
696
+ max_memory=max_memory,
697
+ offload_folder=offload_folder,
698
+ offload_state_dict=offload_state_dict,
699
+ dtype=torch_dtype,
700
+ )
701
+ model._undo_temp_convert_self_to_deprecated_attention_blocks()
702
+ else:
703
+ raise e
704
+
705
+ loading_info = {
706
+ "missing_keys": [],
707
+ "unexpected_keys": [],
708
+ "mismatched_keys": [],
709
+ "error_msgs": [],
710
+ }
711
+ else:
712
+ model = cls.from_config(config, **unused_kwargs)
713
+
714
+ state_dict = load_state_dict(model_file, variant=variant)
715
+ model._convert_deprecated_attention_blocks(state_dict)
716
+
717
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
718
+ model,
719
+ state_dict,
720
+ model_file,
721
+ pretrained_model_name_or_path,
722
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
723
+ )
724
+
725
+ loading_info = {
726
+ "missing_keys": missing_keys,
727
+ "unexpected_keys": unexpected_keys,
728
+ "mismatched_keys": mismatched_keys,
729
+ "error_msgs": error_msgs,
730
+ }
731
+
732
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
733
+ raise ValueError(
734
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
735
+ )
736
+ elif torch_dtype is not None:
737
+ model = model.to(torch_dtype)
738
+
739
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
740
+
741
+ # Set model in evaluation mode to deactivate DropOut modules by default
742
+ model.eval()
743
+ if output_loading_info:
744
+ return model, loading_info
745
+
746
+ return model
747
+
748
+ @classmethod
749
+ def _load_pretrained_model(
750
+ cls,
751
+ model,
752
+ state_dict,
753
+ resolved_archive_file,
754
+ pretrained_model_name_or_path,
755
+ ignore_mismatched_sizes=False,
756
+ ):
757
+ # Retrieve missing & unexpected_keys
758
+ model_state_dict = model.state_dict()
759
+ loaded_keys = list(state_dict.keys())
760
+
761
+ expected_keys = list(model_state_dict.keys())
762
+
763
+ original_loaded_keys = loaded_keys
764
+
765
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
766
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
767
+
768
+ # Make sure we are able to load base models as well as derived models (with heads)
769
+ model_to_load = model
770
+
771
+ def _find_mismatched_keys(
772
+ state_dict,
773
+ model_state_dict,
774
+ loaded_keys,
775
+ ignore_mismatched_sizes,
776
+ ):
777
+ mismatched_keys = []
778
+ if ignore_mismatched_sizes:
779
+ for checkpoint_key in loaded_keys:
780
+ model_key = checkpoint_key
781
+
782
+ if (
783
+ model_key in model_state_dict
784
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
785
+ ):
786
+ mismatched_keys.append(
787
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
788
+ )
789
+ del state_dict[checkpoint_key]
790
+ return mismatched_keys
791
+
792
+ if state_dict is not None:
793
+ # Whole checkpoint
794
+ mismatched_keys = _find_mismatched_keys(
795
+ state_dict,
796
+ model_state_dict,
797
+ original_loaded_keys,
798
+ ignore_mismatched_sizes,
799
+ )
800
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
801
+
802
+ if len(error_msgs) > 0:
803
+ error_msg = "\n\t".join(error_msgs)
804
+ if "size mismatch" in error_msg:
805
+ error_msg += (
806
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
807
+ )
808
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
809
+
810
+ if len(unexpected_keys) > 0:
811
+ logger.warning(
812
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
813
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
814
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
815
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
816
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
817
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
818
+ " identical (initializing a BertForSequenceClassification model from a"
819
+ " BertForSequenceClassification model)."
820
+ )
821
+ else:
822
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
823
+ if len(missing_keys) > 0:
824
+ logger.warning(
825
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
826
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
827
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
828
+ )
829
+ elif len(mismatched_keys) == 0:
830
+ logger.info(
831
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
832
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
833
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
834
+ " without further training."
835
+ )
836
+ if len(mismatched_keys) > 0:
837
+ mismatched_warning = "\n".join(
838
+ [
839
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
840
+ for key, shape1, shape2 in mismatched_keys
841
+ ]
842
+ )
843
+ logger.warning(
844
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
845
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
846
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
847
+ " able to use it for predictions and inference."
848
+ )
849
+
850
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
851
+
852
+ @property
853
+ def device(self) -> device:
854
+ """
855
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
856
+ device).
857
+ """
858
+ return get_parameter_device(self)
859
+
860
+ @property
861
+ def dtype(self) -> torch.dtype:
862
+ """
863
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
864
+ """
865
+ return get_parameter_dtype(self)
866
+
867
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
868
+ """
869
+ Get number of (trainable or non-embedding) parameters in the module.
870
+
871
+ Args:
872
+ only_trainable (`bool`, *optional*, defaults to `False`):
873
+ Whether or not to return only the number of trainable parameters.
874
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
875
+ Whether or not to return only the number of non-embedding parameters.
876
+
877
+ Returns:
878
+ `int`: The number of parameters.
879
+
880
+ Example:
881
+
882
+ ```py
883
+ from diffusers import UNet2DConditionModel
884
+
885
+ model_id = "runwayml/stable-diffusion-v1-5"
886
+ unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
887
+ unet.num_parameters(only_trainable=True)
888
+ 859520964
889
+ ```
890
+ """
891
+
892
+ if exclude_embeddings:
893
+ embedding_param_names = [
894
+ f"{name}.weight"
895
+ for name, module_type in self.named_modules()
896
+ if isinstance(module_type, torch.nn.Embedding)
897
+ ]
898
+ non_embedding_parameters = [
899
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
900
+ ]
901
+ return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
902
+ else:
903
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
904
+
905
+ def _convert_deprecated_attention_blocks(self, state_dict):
906
+ deprecated_attention_block_paths = []
907
+
908
+ def recursive_find_attn_block(name, module):
909
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
910
+ deprecated_attention_block_paths.append(name)
911
+
912
+ for sub_name, sub_module in module.named_children():
913
+ sub_name = sub_name if name == "" else f"{name}.{sub_name}"
914
+ recursive_find_attn_block(sub_name, sub_module)
915
+
916
+ recursive_find_attn_block("", self)
917
+
918
+ # NOTE: we have to check if the deprecated parameters are in the state dict
919
+ # because it is possible we are loading from a state dict that was already
920
+ # converted
921
+
922
+ for path in deprecated_attention_block_paths:
923
+ # group_norm path stays the same
924
+
925
+ # query -> to_q
926
+ if f"{path}.query.weight" in state_dict:
927
+ state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
928
+ if f"{path}.query.bias" in state_dict:
929
+ state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
930
+
931
+ # key -> to_k
932
+ if f"{path}.key.weight" in state_dict:
933
+ state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
934
+ if f"{path}.key.bias" in state_dict:
935
+ state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
936
+
937
+ # value -> to_v
938
+ if f"{path}.value.weight" in state_dict:
939
+ state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
940
+ if f"{path}.value.bias" in state_dict:
941
+ state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
942
+
943
+ # proj_attn -> to_out.0
944
+ if f"{path}.proj_attn.weight" in state_dict:
945
+ state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
946
+ if f"{path}.proj_attn.bias" in state_dict:
947
+ state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
948
+
949
+ def _temp_convert_self_to_deprecated_attention_blocks(self):
950
+ deprecated_attention_block_modules = []
951
+
952
+ def recursive_find_attn_block(module):
953
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
954
+ deprecated_attention_block_modules.append(module)
955
+
956
+ for sub_module in module.children():
957
+ recursive_find_attn_block(sub_module)
958
+
959
+ recursive_find_attn_block(self)
960
+
961
+ for module in deprecated_attention_block_modules:
962
+ module.query = module.to_q
963
+ module.key = module.to_k
964
+ module.value = module.to_v
965
+ module.proj_attn = module.to_out[0]
966
+
967
+ # We don't _have_ to delete the old attributes, but it's helpful to ensure
968
+ # that _all_ the weights are loaded into the new attributes and we're not
969
+ # making an incorrect assumption that this model should be converted when
970
+ # it really shouldn't be.
971
+ del module.to_q
972
+ del module.to_k
973
+ del module.to_v
974
+ del module.to_out
975
+
976
+ def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
977
+ deprecated_attention_block_modules = []
978
+
979
+ def recursive_find_attn_block(module):
980
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
981
+ deprecated_attention_block_modules.append(module)
982
+
983
+ for sub_module in module.children():
984
+ recursive_find_attn_block(sub_module)
985
+
986
+ recursive_find_attn_block(self)
987
+
988
+ for module in deprecated_attention_block_modules:
989
+ module.to_q = module.query
990
+ module.to_k = module.key
991
+ module.to_v = module.value
992
+ module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
993
+
994
+ del module.query
995
+ del module.key
996
+ del module.value
997
+ del module.proj_attn
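The `_convert_deprecated_attention_blocks` helper above only renames parameters that are actually present in the state dict, so checkpoints that were already converted pass through untouched. Below is a minimal, standalone sketch of that renaming applied to a toy state dict; the module path `mid_block.attentions.0` is a hypothetical example, not something taken from this upload.

```py
# Illustrative sketch only: mirrors the query/key/value/proj_attn -> to_q/to_k/to_v/to_out.0
# renaming performed by _convert_deprecated_attention_blocks, applied to a toy state dict.
import torch

state_dict = {
    "mid_block.attentions.0.query.weight": torch.zeros(4, 4),
    "mid_block.attentions.0.query.bias": torch.zeros(4),
    "mid_block.attentions.0.proj_attn.weight": torch.zeros(4, 4),
}

rename = {"query": "to_q", "key": "to_k", "value": "to_v", "proj_attn": "to_out.0"}
path = "mid_block.attentions.0"  # hypothetical deprecated attention block path

for old, new in rename.items():
    for suffix in ("weight", "bias"):
        old_key = f"{path}.{old}.{suffix}"
        if old_key in state_dict:  # only convert keys that actually exist
            state_dict[f"{path}.{new}.{suffix}"] = state_dict.pop(old_key)

print(sorted(state_dict))
# ['mid_block.attentions.0.to_out.0.weight',
#  'mid_block.attentions.0.to_q.bias',
#  'mid_block.attentions.0.to_q.weight']
```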
Tiger Model/diffusiers-Tiger/models/prior_transformer.py ADDED
@@ -0,0 +1,364 @@
1
+ from dataclasses import dataclass
2
+ from typing import Dict, Optional, Union
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+ from ..configuration_utils import ConfigMixin, register_to_config
9
+ from ..utils import BaseOutput
10
+ from .attention import BasicTransformerBlock
11
+ from .attention_processor import AttentionProcessor, AttnProcessor
12
+ from .embeddings import TimestepEmbedding, Timesteps
13
+ from .modeling_utils import ModelMixin
14
+
15
+
16
+ @dataclass
17
+ class PriorTransformerOutput(BaseOutput):
18
+ """
19
+ The output of [`PriorTransformer`].
20
+
21
+ Args:
22
+ predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
23
+ The predicted CLIP image embedding conditioned on the CLIP text embedding input.
24
+ """
25
+
26
+ predicted_image_embedding: torch.FloatTensor
27
+
28
+
29
+ class PriorTransformer(ModelMixin, ConfigMixin):
30
+ """
31
+ A Prior Transformer model.
32
+
33
+ Parameters:
34
+ num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
35
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
36
+ num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
37
+ embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
38
+ num_embeddings (`int`, *optional*, defaults to 77):
39
+ The number of embeddings of the model input `hidden_states`
40
+ additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
41
+ projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
42
+ additional_embeddings`.
43
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
44
+ time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
45
+ The activation function to use to create timestep embeddings.
46
+ norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
47
+ passing to Transformer blocks. Set it to `None` if normalization is not needed.
48
+ embedding_proj_norm_type (`str`, *optional*, defaults to None):
49
+ The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
50
+ needed.
51
+ encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
52
+ The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
53
+ `encoder_hidden_states` is `None`.
54
+ added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
55
+ Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
56
+ product between the text embedding and image embedding as proposed in the unclip paper
57
+ https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
58
+ time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
59
+ If None, will be set to `num_attention_heads * attention_head_dim`
60
+ embedding_proj_dim (`int`, *optional*, default to None):
61
+ The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
62
+ clip_embed_dim (`int`, *optional*, default to None):
63
+ The dimension of the output. If None, will be set to `embedding_dim`.
64
+ """
65
+
66
+ @register_to_config
67
+ def __init__(
68
+ self,
69
+ num_attention_heads: int = 32,
70
+ attention_head_dim: int = 64,
71
+ num_layers: int = 20,
72
+ embedding_dim: int = 768,
73
+ num_embeddings=77,
74
+ additional_embeddings=4,
75
+ dropout: float = 0.0,
76
+ time_embed_act_fn: str = "silu",
77
+ norm_in_type: Optional[str] = None, # layer
78
+ embedding_proj_norm_type: Optional[str] = None, # layer
79
+ encoder_hid_proj_type: Optional[str] = "linear", # linear
80
+ added_emb_type: Optional[str] = "prd", # prd
81
+ time_embed_dim: Optional[int] = None,
82
+ embedding_proj_dim: Optional[int] = None,
83
+ clip_embed_dim: Optional[int] = None,
84
+ ):
85
+ super().__init__()
86
+ self.num_attention_heads = num_attention_heads
87
+ self.attention_head_dim = attention_head_dim
88
+ inner_dim = num_attention_heads * attention_head_dim
89
+ self.additional_embeddings = additional_embeddings
90
+
91
+ time_embed_dim = time_embed_dim or inner_dim
92
+ embedding_proj_dim = embedding_proj_dim or embedding_dim
93
+ clip_embed_dim = clip_embed_dim or embedding_dim
94
+
95
+ self.time_proj = Timesteps(inner_dim, True, 0)
96
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
97
+
98
+ self.proj_in = nn.Linear(embedding_dim, inner_dim)
99
+
100
+ if embedding_proj_norm_type is None:
101
+ self.embedding_proj_norm = None
102
+ elif embedding_proj_norm_type == "layer":
103
+ self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
104
+ else:
105
+ raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
106
+
107
+ self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
108
+
109
+ if encoder_hid_proj_type is None:
110
+ self.encoder_hidden_states_proj = None
111
+ elif encoder_hid_proj_type == "linear":
112
+ self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
113
+ else:
114
+ raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
115
+
116
+ self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
117
+
118
+ if added_emb_type == "prd":
119
+ self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
120
+ elif added_emb_type is None:
121
+ self.prd_embedding = None
122
+ else:
123
+ raise ValueError(
124
+ f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
125
+ )
126
+
127
+ self.transformer_blocks = nn.ModuleList(
128
+ [
129
+ BasicTransformerBlock(
130
+ inner_dim,
131
+ num_attention_heads,
132
+ attention_head_dim,
133
+ dropout=dropout,
134
+ activation_fn="gelu",
135
+ attention_bias=True,
136
+ )
137
+ for d in range(num_layers)
138
+ ]
139
+ )
140
+
141
+ if norm_in_type == "layer":
142
+ self.norm_in = nn.LayerNorm(inner_dim)
143
+ elif norm_in_type is None:
144
+ self.norm_in = None
145
+ else:
146
+ raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
147
+
148
+ self.norm_out = nn.LayerNorm(inner_dim)
149
+
150
+ self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
151
+
152
+ causal_attention_mask = torch.full(
153
+ [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
154
+ )
155
+ causal_attention_mask.triu_(1)
156
+ causal_attention_mask = causal_attention_mask[None, ...]
157
+ self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
158
+
159
+ self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
160
+ self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
161
+
162
+ @property
163
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
164
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
165
+ r"""
166
+ Returns:
167
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
168
+ indexed by its weight name.
169
+ """
170
+ # set recursively
171
+ processors = {}
172
+
173
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
174
+ if hasattr(module, "set_processor"):
175
+ processors[f"{name}.processor"] = module.processor
176
+
177
+ for sub_name, child in module.named_children():
178
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
179
+
180
+ return processors
181
+
182
+ for name, module in self.named_children():
183
+ fn_recursive_add_processors(name, module, processors)
184
+
185
+ return processors
186
+
187
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
188
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
189
+ r"""
190
+ Sets the attention processor to use to compute attention.
191
+
192
+ Parameters:
193
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
194
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
195
+ for **all** `Attention` layers.
196
+
197
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
198
+ processor. This is strongly recommended when setting trainable attention processors.
199
+
200
+ """
201
+ count = len(self.attn_processors.keys())
202
+
203
+ if isinstance(processor, dict) and len(processor) != count:
204
+ raise ValueError(
205
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
206
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
207
+ )
208
+
209
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
210
+ if hasattr(module, "set_processor"):
211
+ if not isinstance(processor, dict):
212
+ module.set_processor(processor)
213
+ else:
214
+ module.set_processor(processor.pop(f"{name}.processor"))
215
+
216
+ for sub_name, child in module.named_children():
217
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
218
+
219
+ for name, module in self.named_children():
220
+ fn_recursive_attn_processor(name, module, processor)
221
+
222
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
223
+ def set_default_attn_processor(self):
224
+ """
225
+ Disables custom attention processors and sets the default attention implementation.
226
+ """
227
+ self.set_attn_processor(AttnProcessor())
228
+
229
+ def forward(
230
+ self,
231
+ hidden_states,
232
+ timestep: Union[torch.Tensor, float, int],
233
+ proj_embedding: torch.FloatTensor,
234
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
235
+ attention_mask: Optional[torch.BoolTensor] = None,
236
+ return_dict: bool = True,
237
+ ):
238
+ """
239
+ The [`PriorTransformer`] forward method.
240
+
241
+ Args:
242
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
243
+ The currently predicted image embeddings.
244
+ timestep (`torch.LongTensor`):
245
+ Current denoising step.
246
+ proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
247
+ Projected embedding vector the denoising process is conditioned on.
248
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
249
+ Hidden states of the text embeddings the denoising process is conditioned on.
250
+ attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
251
+ Text mask for the text embeddings.
252
+ return_dict (`bool`, *optional*, defaults to `True`):
253
+ Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
254
+ tuple.
255
+
256
+ Returns:
257
+ [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
258
+ If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
259
+ tuple is returned where the first element is the sample tensor.
260
+ """
261
+ batch_size = hidden_states.shape[0]
262
+
263
+ timesteps = timestep
264
+ if not torch.is_tensor(timesteps):
265
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
266
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
267
+ timesteps = timesteps[None].to(hidden_states.device)
268
+
269
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
270
+ timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
271
+
272
+ timesteps_projected = self.time_proj(timesteps)
273
+
274
+ # timesteps does not contain any weights and will always return f32 tensors
275
+ # but time_embedding might be fp16, so we need to cast here.
276
+ timesteps_projected = timesteps_projected.to(dtype=self.dtype)
277
+ time_embeddings = self.time_embedding(timesteps_projected)
278
+
279
+ if self.embedding_proj_norm is not None:
280
+ proj_embedding = self.embedding_proj_norm(proj_embedding)
281
+
282
+ proj_embeddings = self.embedding_proj(proj_embedding)
283
+ if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
284
+ encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
285
+ elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
286
+ raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
287
+
288
+ hidden_states = self.proj_in(hidden_states)
289
+
290
+ positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
291
+
292
+ additional_embeds = []
293
+ additional_embeddings_len = 0
294
+
295
+ if encoder_hidden_states is not None:
296
+ additional_embeds.append(encoder_hidden_states)
297
+ additional_embeddings_len += encoder_hidden_states.shape[1]
298
+
299
+ if len(proj_embeddings.shape) == 2:
300
+ proj_embeddings = proj_embeddings[:, None, :]
301
+
302
+ if len(hidden_states.shape) == 2:
303
+ hidden_states = hidden_states[:, None, :]
304
+
305
+ additional_embeds = additional_embeds + [
306
+ proj_embeddings,
307
+ time_embeddings[:, None, :],
308
+ hidden_states,
309
+ ]
310
+
311
+ if self.prd_embedding is not None:
312
+ prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
313
+ additional_embeds.append(prd_embedding)
314
+
315
+ hidden_states = torch.cat(
316
+ additional_embeds,
317
+ dim=1,
318
+ )
319
+
320
+ # Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens
321
+ additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
322
+ if positional_embeddings.shape[1] < hidden_states.shape[1]:
323
+ positional_embeddings = F.pad(
324
+ positional_embeddings,
325
+ (
326
+ 0,
327
+ 0,
328
+ additional_embeddings_len,
329
+ self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
330
+ ),
331
+ value=0.0,
332
+ )
333
+
334
+ hidden_states = hidden_states + positional_embeddings
335
+
336
+ if attention_mask is not None:
337
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
338
+ attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
339
+ attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
340
+ attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
341
+
342
+ if self.norm_in is not None:
343
+ hidden_states = self.norm_in(hidden_states)
344
+
345
+ for block in self.transformer_blocks:
346
+ hidden_states = block(hidden_states, attention_mask=attention_mask)
347
+
348
+ hidden_states = self.norm_out(hidden_states)
349
+
350
+ if self.prd_embedding is not None:
351
+ hidden_states = hidden_states[:, -1]
352
+ else:
353
+ hidden_states = hidden_states[:, additional_embeddings_len:]
354
+
355
+ predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
356
+
357
+ if not return_dict:
358
+ return (predicted_image_embedding,)
359
+
360
+ return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
361
+
362
+ def post_process_latents(self, prior_latents):
363
+ prior_latents = (prior_latents * self.clip_std) + self.clip_mean
364
+ return prior_latents
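The least obvious step in `PriorTransformer.forward` is the mask handling: the boolean text mask is converted to an additive mask, padded so the appended tokens are never masked, combined with the registered causal mask, and repeated once per attention head. A minimal, standalone sketch with toy sizes (batch 1, 3 text tokens, 2 additional tokens and 2 heads are assumptions chosen for illustration, not values from the model config):

```py
# Illustrative sketch only: reproduces the mask arithmetic from PriorTransformer.forward
# with plain torch ops and toy dimensions.
import torch
import torch.nn.functional as F

num_embeddings, additional_embeddings, heads = 3, 2, 2
total = num_embeddings + additional_embeddings

text_mask = torch.tensor([[1, 1, 0]])  # 1 = real token, 0 = padding

# as in __init__: additive causal mask, -10000 strictly above the diagonal
causal = torch.full([total, total], -10000.0).triu_(1)[None, ...]

# as in forward()
attn = (1 - text_mask.float()) * -10000.0                  # padding positions get -10000
attn = F.pad(attn, (0, additional_embeddings), value=0.0)  # appended tokens stay unmasked
attn = attn[:, None, :] + causal                           # broadcast to (batch, total, total)
attn = attn.repeat_interleave(heads, dim=0)                # one copy per attention head

print(attn.shape)  # torch.Size([2, 5, 5])
```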
Tiger Model/diffusiers-Tiger/models/resnet.py ADDED
@@ -0,0 +1,878 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ # `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partial
17
+ from typing import Optional
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from .activations import get_activation
24
+ from .attention import AdaGroupNorm
25
+ from .attention_processor import SpatialNorm
26
+ from .lora import LoRACompatibleConv, LoRACompatibleLinear
27
+
28
+
29
+ class Upsample1D(nn.Module):
30
+ """A 1D upsampling layer with an optional convolution.
31
+
32
+ Parameters:
33
+ channels (`int`):
34
+ number of channels in the inputs and outputs.
35
+ use_conv (`bool`, default `False`):
36
+ option to use a convolution.
37
+ use_conv_transpose (`bool`, default `False`):
38
+ option to use a convolution transpose.
39
+ out_channels (`int`, optional):
40
+ number of output channels. Defaults to `channels`.
41
+ """
42
+
43
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
44
+ super().__init__()
45
+ self.channels = channels
46
+ self.out_channels = out_channels or channels
47
+ self.use_conv = use_conv
48
+ self.use_conv_transpose = use_conv_transpose
49
+ self.name = name
50
+
51
+ self.conv = None
52
+ if use_conv_transpose:
53
+ self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
54
+ elif use_conv:
55
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
56
+
57
+ def forward(self, inputs):
58
+ assert inputs.shape[1] == self.channels
59
+ if self.use_conv_transpose:
60
+ return self.conv(inputs)
61
+
62
+ outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
63
+
64
+ if self.use_conv:
65
+ outputs = self.conv(outputs)
66
+
67
+ return outputs
68
+
69
+
70
+ class Downsample1D(nn.Module):
71
+ """A 1D downsampling layer with an optional convolution.
72
+
73
+ Parameters:
74
+ channels (`int`):
75
+ number of channels in the inputs and outputs.
76
+ use_conv (`bool`, default `False`):
77
+ option to use a convolution.
78
+ out_channels (`int`, optional):
79
+ number of output channels. Defaults to `channels`.
80
+ padding (`int`, default `1`):
81
+ padding for the convolution.
82
+ """
83
+
84
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
85
+ super().__init__()
86
+ self.channels = channels
87
+ self.out_channels = out_channels or channels
88
+ self.use_conv = use_conv
89
+ self.padding = padding
90
+ stride = 2
91
+ self.name = name
92
+
93
+ if use_conv:
94
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
95
+ else:
96
+ assert self.channels == self.out_channels
97
+ self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
98
+
99
+ def forward(self, inputs):
100
+ assert inputs.shape[1] == self.channels
101
+ return self.conv(inputs)
102
+
103
+
104
+ class Upsample2D(nn.Module):
105
+ """A 2D upsampling layer with an optional convolution.
106
+
107
+ Parameters:
108
+ channels (`int`):
109
+ number of channels in the inputs and outputs.
110
+ use_conv (`bool`, default `False`):
111
+ option to use a convolution.
112
+ use_conv_transpose (`bool`, default `False`):
113
+ option to use a convolution transpose.
114
+ out_channels (`int`, optional):
115
+ number of output channels. Defaults to `channels`.
116
+ """
117
+
118
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
119
+ super().__init__()
120
+ self.channels = channels
121
+ self.out_channels = out_channels or channels
122
+ self.use_conv = use_conv
123
+ self.use_conv_transpose = use_conv_transpose
124
+ self.name = name
125
+
126
+ conv = None
127
+ if use_conv_transpose:
128
+ conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
129
+ elif use_conv:
130
+ conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1)
131
+
132
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
133
+ if name == "conv":
134
+ self.conv = conv
135
+ else:
136
+ self.Conv2d_0 = conv
137
+
138
+ def forward(self, hidden_states, output_size=None):
139
+ assert hidden_states.shape[1] == self.channels
140
+
141
+ if self.use_conv_transpose:
142
+ return self.conv(hidden_states)
143
+
144
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
145
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
146
+ # https://github.com/pytorch/pytorch/issues/86679
147
+ dtype = hidden_states.dtype
148
+ if dtype == torch.bfloat16:
149
+ hidden_states = hidden_states.to(torch.float32)
150
+
151
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
152
+ if hidden_states.shape[0] >= 64:
153
+ hidden_states = hidden_states.contiguous()
154
+
155
+ # if `output_size` is passed we force the interpolation output
156
+ # size and do not make use of `scale_factor=2`
157
+ if output_size is None:
158
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
159
+ else:
160
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
161
+
162
+ # If the input is bfloat16, we cast back to bfloat16
163
+ if dtype == torch.bfloat16:
164
+ hidden_states = hidden_states.to(dtype)
165
+
166
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
167
+ if self.use_conv:
168
+ if self.name == "conv":
169
+ hidden_states = self.conv(hidden_states)
170
+ else:
171
+ hidden_states = self.Conv2d_0(hidden_states)
172
+
173
+ return hidden_states
174
+
175
+
176
+ class Downsample2D(nn.Module):
177
+ """A 2D downsampling layer with an optional convolution.
178
+
179
+ Parameters:
180
+ channels (`int`):
181
+ number of channels in the inputs and outputs.
182
+ use_conv (`bool`, default `False`):
183
+ option to use a convolution.
184
+ out_channels (`int`, optional):
185
+ number of output channels. Defaults to `channels`.
186
+ padding (`int`, default `1`):
187
+ padding for the convolution.
188
+ """
189
+
190
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
191
+ super().__init__()
192
+ self.channels = channels
193
+ self.out_channels = out_channels or channels
194
+ self.use_conv = use_conv
195
+ self.padding = padding
196
+ stride = 2
197
+ self.name = name
198
+
199
+ if use_conv:
200
+ conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding)
201
+ else:
202
+ assert self.channels == self.out_channels
203
+ conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
204
+
205
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
206
+ if name == "conv":
207
+ self.Conv2d_0 = conv
208
+ self.conv = conv
209
+ elif name == "Conv2d_0":
210
+ self.conv = conv
211
+ else:
212
+ self.conv = conv
213
+
214
+ def forward(self, hidden_states):
215
+ assert hidden_states.shape[1] == self.channels
216
+ if self.use_conv and self.padding == 0:
217
+ pad = (0, 1, 0, 1)
218
+ hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
219
+
220
+ assert hidden_states.shape[1] == self.channels
221
+ hidden_states = self.conv(hidden_states)
222
+
223
+ return hidden_states
224
+
225
+
226
+ class FirUpsample2D(nn.Module):
227
+ """A 2D FIR upsampling layer with an optional convolution.
228
+
229
+ Parameters:
230
+ channels (`int`):
231
+ number of channels in the inputs and outputs.
232
+ use_conv (`bool`, default `False`):
233
+ option to use a convolution.
234
+ out_channels (`int`, optional):
235
+ number of output channels. Defaults to `channels`.
236
+ fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
237
+ kernel for the FIR filter.
238
+ """
239
+
240
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
241
+ super().__init__()
242
+ out_channels = out_channels if out_channels else channels
243
+ if use_conv:
244
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
245
+ self.use_conv = use_conv
246
+ self.fir_kernel = fir_kernel
247
+ self.out_channels = out_channels
248
+
249
+ def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
250
+ """Fused `upsample_2d()` followed by `Conv2d()`.
251
+
252
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
253
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
254
+ arbitrary order.
255
+
256
+ Args:
257
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
258
+ weight: Weight tensor of the shape `[filterH, filterW, inChannels,
259
+ outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
260
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
261
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
262
+ factor: Integer upsampling factor (default: 2).
263
+ gain: Scaling factor for signal magnitude (default: 1.0).
264
+
265
+ Returns:
266
+ output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
267
+ datatype as `hidden_states`.
268
+ """
269
+
270
+ assert isinstance(factor, int) and factor >= 1
271
+
272
+ # Setup filter kernel.
273
+ if kernel is None:
274
+ kernel = [1] * factor
275
+
276
+ # setup kernel
277
+ kernel = torch.tensor(kernel, dtype=torch.float32)
278
+ if kernel.ndim == 1:
279
+ kernel = torch.outer(kernel, kernel)
280
+ kernel /= torch.sum(kernel)
281
+
282
+ kernel = kernel * (gain * (factor**2))
283
+
284
+ if self.use_conv:
285
+ convH = weight.shape[2]
286
+ convW = weight.shape[3]
287
+ inC = weight.shape[1]
288
+
289
+ pad_value = (kernel.shape[0] - factor) - (convW - 1)
290
+
291
+ stride = (factor, factor)
292
+ # Determine data dimensions.
293
+ output_shape = (
294
+ (hidden_states.shape[2] - 1) * factor + convH,
295
+ (hidden_states.shape[3] - 1) * factor + convW,
296
+ )
297
+ output_padding = (
298
+ output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
299
+ output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
300
+ )
301
+ assert output_padding[0] >= 0 and output_padding[1] >= 0
302
+ num_groups = hidden_states.shape[1] // inC
303
+
304
+ # Transpose weights.
305
+ weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
306
+ weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
307
+ weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
308
+
309
+ inverse_conv = F.conv_transpose2d(
310
+ hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
311
+ )
312
+
313
+ output = upfirdn2d_native(
314
+ inverse_conv,
315
+ torch.tensor(kernel, device=inverse_conv.device),
316
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
317
+ )
318
+ else:
319
+ pad_value = kernel.shape[0] - factor
320
+ output = upfirdn2d_native(
321
+ hidden_states,
322
+ torch.tensor(kernel, device=hidden_states.device),
323
+ up=factor,
324
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
325
+ )
326
+
327
+ return output
328
+
329
+ def forward(self, hidden_states):
330
+ if self.use_conv:
331
+ height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
332
+ height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
333
+ else:
334
+ height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
335
+
336
+ return height
337
+
338
+
339
+ class FirDownsample2D(nn.Module):
340
+ """A 2D FIR downsampling layer with an optional convolution.
341
+
342
+ Parameters:
343
+ channels (`int`):
344
+ number of channels in the inputs and outputs.
345
+ use_conv (`bool`, default `False`):
346
+ option to use a convolution.
347
+ out_channels (`int`, optional):
348
+ number of output channels. Defaults to `channels`.
349
+ fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
350
+ kernel for the FIR filter.
351
+ """
352
+
353
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
354
+ super().__init__()
355
+ out_channels = out_channels if out_channels else channels
356
+ if use_conv:
357
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
358
+ self.fir_kernel = fir_kernel
359
+ self.use_conv = use_conv
360
+ self.out_channels = out_channels
361
+
362
+ def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
363
+ """Fused `Conv2d()` followed by `downsample_2d()`.
364
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
365
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
366
+ arbitrary order.
367
+
368
+ Args:
369
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
370
+ weight:
371
+ Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
372
+ performed by `inChannels = x.shape[0] // numGroups`.
373
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
374
+ factor`, which corresponds to average pooling.
375
+ factor: Integer downsampling factor (default: 2).
376
+ gain: Scaling factor for signal magnitude (default: 1.0).
377
+
378
+ Returns:
379
+ output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
380
+ same datatype as `x`.
381
+ """
382
+
383
+ assert isinstance(factor, int) and factor >= 1
384
+ if kernel is None:
385
+ kernel = [1] * factor
386
+
387
+ # setup kernel
388
+ kernel = torch.tensor(kernel, dtype=torch.float32)
389
+ if kernel.ndim == 1:
390
+ kernel = torch.outer(kernel, kernel)
391
+ kernel /= torch.sum(kernel)
392
+
393
+ kernel = kernel * gain
394
+
395
+ if self.use_conv:
396
+ _, _, convH, convW = weight.shape
397
+ pad_value = (kernel.shape[0] - factor) + (convW - 1)
398
+ stride_value = [factor, factor]
399
+ upfirdn_input = upfirdn2d_native(
400
+ hidden_states,
401
+ torch.tensor(kernel, device=hidden_states.device),
402
+ pad=((pad_value + 1) // 2, pad_value // 2),
403
+ )
404
+ output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
405
+ else:
406
+ pad_value = kernel.shape[0] - factor
407
+ output = upfirdn2d_native(
408
+ hidden_states,
409
+ torch.tensor(kernel, device=hidden_states.device),
410
+ down=factor,
411
+ pad=((pad_value + 1) // 2, pad_value // 2),
412
+ )
413
+
414
+ return output
415
+
416
+ def forward(self, hidden_states):
417
+ if self.use_conv:
418
+ downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
419
+ hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
420
+ else:
421
+ hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
422
+
423
+ return hidden_states
424
+
425
+
426
+ # downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
427
+ class KDownsample2D(nn.Module):
428
+ def __init__(self, pad_mode="reflect"):
429
+ super().__init__()
430
+ self.pad_mode = pad_mode
431
+ kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
432
+ self.pad = kernel_1d.shape[1] // 2 - 1
433
+ self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
434
+
435
+ def forward(self, inputs):
436
+ inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
437
+ weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
438
+ indices = torch.arange(inputs.shape[1], device=inputs.device)
439
+ kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
440
+ weight[indices, indices] = kernel
441
+ return F.conv2d(inputs, weight, stride=2)
442
+
443
+
444
+ class KUpsample2D(nn.Module):
445
+ def __init__(self, pad_mode="reflect"):
446
+ super().__init__()
447
+ self.pad_mode = pad_mode
448
+ kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
449
+ self.pad = kernel_1d.shape[1] // 2 - 1
450
+ self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
451
+
452
+ def forward(self, inputs):
453
+ inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
454
+ weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
455
+ indices = torch.arange(inputs.shape[1], device=inputs.device)
456
+ kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
457
+ weight[indices, indices] = kernel
458
+ return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
459
+
460
+
461
+ class ResnetBlock2D(nn.Module):
462
+ r"""
463
+ A Resnet block.
464
+
465
+ Parameters:
466
+ in_channels (`int`): The number of channels in the input.
467
+ out_channels (`int`, *optional*, default to be `None`):
468
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
469
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
470
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
471
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
472
+ groups_out (`int`, *optional*, default to None):
473
+ The number of groups to use for the second normalization layer. if set to None, same as `groups`.
474
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
475
+ non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
476
+ time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
477
+ By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
478
+ "ada_group" for a stronger conditioning with scale and shift.
479
+ kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
480
+ [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
481
+ output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
482
+ use_in_shortcut (`bool`, *optional*, default to `True`):
483
+ If `True`, add a 1x1 nn.conv2d layer for skip-connection.
484
+ up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
485
+ down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
486
+ conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
487
+ `conv_shortcut` output.
488
+ conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
489
+ If None, same as `out_channels`.
490
+ """
491
+
492
+ def __init__(
493
+ self,
494
+ *,
495
+ in_channels,
496
+ out_channels=None,
497
+ conv_shortcut=False,
498
+ dropout=0.0,
499
+ temb_channels=512,
500
+ groups=32,
501
+ groups_out=None,
502
+ pre_norm=True,
503
+ eps=1e-6,
504
+ non_linearity="swish",
505
+ skip_time_act=False,
506
+ time_embedding_norm="default", # default, scale_shift, ada_group, spatial
507
+ kernel=None,
508
+ output_scale_factor=1.0,
509
+ use_in_shortcut=None,
510
+ up=False,
511
+ down=False,
512
+ conv_shortcut_bias: bool = True,
513
+ conv_2d_out_channels: Optional[int] = None,
514
+ ):
515
+ super().__init__()
516
+ self.pre_norm = pre_norm
517
+ self.pre_norm = True
518
+ self.in_channels = in_channels
519
+ out_channels = in_channels if out_channels is None else out_channels
520
+ self.out_channels = out_channels
521
+ self.use_conv_shortcut = conv_shortcut
522
+ self.up = up
523
+ self.down = down
524
+ self.output_scale_factor = output_scale_factor
525
+ self.time_embedding_norm = time_embedding_norm
526
+ self.skip_time_act = skip_time_act
527
+
528
+ if groups_out is None:
529
+ groups_out = groups
530
+
531
+ if self.time_embedding_norm == "ada_group":
532
+ self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
533
+ elif self.time_embedding_norm == "spatial":
534
+ self.norm1 = SpatialNorm(in_channels, temb_channels)
535
+ else:
536
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
537
+
538
+ self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
539
+
540
+ if temb_channels is not None:
541
+ if self.time_embedding_norm == "default":
542
+ self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels)
543
+ elif self.time_embedding_norm == "scale_shift":
544
+ self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels)
545
+ elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
546
+ self.time_emb_proj = None
547
+ else:
548
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
549
+ else:
550
+ self.time_emb_proj = None
551
+
552
+ if self.time_embedding_norm == "ada_group":
553
+ self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
554
+ elif self.time_embedding_norm == "spatial":
555
+ self.norm2 = SpatialNorm(out_channels, temb_channels)
556
+ else:
557
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
558
+
559
+ self.dropout = torch.nn.Dropout(dropout)
560
+ conv_2d_out_channels = conv_2d_out_channels or out_channels
561
+ self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
562
+
563
+ self.nonlinearity = get_activation(non_linearity)
564
+
565
+ self.upsample = self.downsample = None
566
+ if self.up:
567
+ if kernel == "fir":
568
+ fir_kernel = (1, 3, 3, 1)
569
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
570
+ elif kernel == "sde_vp":
571
+ self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
572
+ else:
573
+ self.upsample = Upsample2D(in_channels, use_conv=False)
574
+ elif self.down:
575
+ if kernel == "fir":
576
+ fir_kernel = (1, 3, 3, 1)
577
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
578
+ elif kernel == "sde_vp":
579
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
580
+ else:
581
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
582
+
583
+ self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
584
+
585
+ self.conv_shortcut = None
586
+ if self.use_in_shortcut:
587
+ self.conv_shortcut = LoRACompatibleConv(
588
+ in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
589
+ )
590
+
591
+ def forward(self, input_tensor, temb):
592
+ hidden_states = input_tensor
593
+
594
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
595
+ hidden_states = self.norm1(hidden_states, temb)
596
+ else:
597
+ hidden_states = self.norm1(hidden_states)
598
+
599
+ hidden_states = self.nonlinearity(hidden_states)
600
+
601
+ if self.upsample is not None:
602
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
603
+ if hidden_states.shape[0] >= 64:
604
+ input_tensor = input_tensor.contiguous()
605
+ hidden_states = hidden_states.contiguous()
606
+ input_tensor = self.upsample(input_tensor)
607
+ hidden_states = self.upsample(hidden_states)
608
+ elif self.downsample is not None:
609
+ input_tensor = self.downsample(input_tensor)
610
+ hidden_states = self.downsample(hidden_states)
611
+
612
+ hidden_states = self.conv1(hidden_states)
613
+
614
+ if self.time_emb_proj is not None:
615
+ if not self.skip_time_act:
616
+ temb = self.nonlinearity(temb)
617
+ temb = self.time_emb_proj(temb)[:, :, None, None]
618
+
619
+ if temb is not None and self.time_embedding_norm == "default":
620
+ hidden_states = hidden_states + temb
621
+
622
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
623
+ hidden_states = self.norm2(hidden_states, temb)
624
+ else:
625
+ hidden_states = self.norm2(hidden_states)
626
+
627
+ if temb is not None and self.time_embedding_norm == "scale_shift":
628
+ scale, shift = torch.chunk(temb, 2, dim=1)
629
+ hidden_states = hidden_states * (1 + scale) + shift
630
+
631
+ hidden_states = self.nonlinearity(hidden_states)
632
+
633
+ hidden_states = self.dropout(hidden_states)
634
+ hidden_states = self.conv2(hidden_states)
635
+
636
+ if self.conv_shortcut is not None:
637
+ input_tensor = self.conv_shortcut(input_tensor)
638
+
639
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
640
+
641
+ return output_tensor
642
+
643
+
644
+ # unet_rl.py
645
+ def rearrange_dims(tensor):
646
+ if len(tensor.shape) == 2:
647
+ return tensor[:, :, None]
648
+ if len(tensor.shape) == 3:
649
+ return tensor[:, :, None, :]
650
+ elif len(tensor.shape) == 4:
651
+ return tensor[:, :, 0, :]
652
+ else:
653
+ raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
654
+
655
+
656
+ class Conv1dBlock(nn.Module):
657
+ """
658
+ Conv1d --> GroupNorm --> Mish
659
+ """
660
+
661
+ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
662
+ super().__init__()
663
+
664
+ self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
665
+ self.group_norm = nn.GroupNorm(n_groups, out_channels)
666
+ self.mish = nn.Mish()
667
+
668
+ def forward(self, inputs):
669
+ intermediate_repr = self.conv1d(inputs)
670
+ intermediate_repr = rearrange_dims(intermediate_repr)
671
+ intermediate_repr = self.group_norm(intermediate_repr)
672
+ intermediate_repr = rearrange_dims(intermediate_repr)
673
+ output = self.mish(intermediate_repr)
674
+ return output
675
+
676
+
677
+ # unet_rl.py
678
+ class ResidualTemporalBlock1D(nn.Module):
679
+ def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
680
+ super().__init__()
681
+ self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
682
+ self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
683
+
684
+ self.time_emb_act = nn.Mish()
685
+ self.time_emb = nn.Linear(embed_dim, out_channels)
686
+
687
+ self.residual_conv = (
688
+ nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
689
+ )
690
+
691
+ def forward(self, inputs, t):
692
+ """
693
+ Args:
694
+ inputs : [ batch_size x inp_channels x horizon ]
695
+ t : [ batch_size x embed_dim ]
696
+
697
+ returns:
698
+ out : [ batch_size x out_channels x horizon ]
699
+ """
700
+ t = self.time_emb_act(t)
701
+ t = self.time_emb(t)
702
+ out = self.conv_in(inputs) + rearrange_dims(t)
703
+ out = self.conv_out(out)
704
+ return out + self.residual_conv(inputs)
705
+
706
+
707
+ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
708
+ r"""Upsample2D a batch of 2D images with the given filter.
709
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
710
+ filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
711
+ `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
712
+ a multiple of the upsampling factor.
713
+
714
+ Args:
715
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
716
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
717
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
718
+ factor: Integer upsampling factor (default: 2).
719
+ gain: Scaling factor for signal magnitude (default: 1.0).
720
+
721
+ Returns:
722
+ output: Tensor of the shape `[N, C, H * factor, W * factor]`
723
+ """
724
+ assert isinstance(factor, int) and factor >= 1
725
+ if kernel is None:
726
+ kernel = [1] * factor
727
+
728
+ kernel = torch.tensor(kernel, dtype=torch.float32)
729
+ if kernel.ndim == 1:
730
+ kernel = torch.outer(kernel, kernel)
731
+ kernel /= torch.sum(kernel)
732
+
733
+ kernel = kernel * (gain * (factor**2))
734
+ pad_value = kernel.shape[0] - factor
735
+ output = upfirdn2d_native(
736
+ hidden_states,
737
+ kernel.to(device=hidden_states.device),
738
+ up=factor,
739
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
740
+ )
741
+ return output
742
+
743
+
744
+ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
745
+ r"""Downsample2D a batch of 2D images with the given filter.
746
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
747
+ given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
748
+ specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
749
+ shape is a multiple of the downsampling factor.
750
+
751
+ Args:
752
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
753
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
754
+ (separable). The default is `[1] * factor`, which corresponds to average pooling.
755
+ factor: Integer downsampling factor (default: 2).
756
+ gain: Scaling factor for signal magnitude (default: 1.0).
757
+
758
+ Returns:
759
+ output: Tensor of the shape `[N, C, H // factor, W // factor]`
760
+ """
761
+
762
+ assert isinstance(factor, int) and factor >= 1
763
+ if kernel is None:
764
+ kernel = [1] * factor
765
+
766
+ kernel = torch.tensor(kernel, dtype=torch.float32)
767
+ if kernel.ndim == 1:
768
+ kernel = torch.outer(kernel, kernel)
769
+ kernel /= torch.sum(kernel)
770
+
771
+ kernel = kernel * gain
772
+ pad_value = kernel.shape[0] - factor
773
+ output = upfirdn2d_native(
774
+ hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
775
+ )
776
+ return output
777
+
778
+
779
+ def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
780
+ up_x = up_y = up
781
+ down_x = down_y = down
782
+ pad_x0 = pad_y0 = pad[0]
783
+ pad_x1 = pad_y1 = pad[1]
784
+
785
+ _, channel, in_h, in_w = tensor.shape
786
+ tensor = tensor.reshape(-1, in_h, in_w, 1)
787
+
788
+ _, in_h, in_w, minor = tensor.shape
789
+ kernel_h, kernel_w = kernel.shape
790
+
791
+ out = tensor.view(-1, in_h, 1, in_w, 1, minor)
792
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
793
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
794
+
795
+ out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
796
+ out = out.to(tensor.device) # Move back to mps if necessary
797
+ out = out[
798
+ :,
799
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
800
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
801
+ :,
802
+ ]
803
+
804
+ out = out.permute(0, 3, 1, 2)
805
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
806
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
807
+ out = F.conv2d(out, w)
808
+ out = out.reshape(
809
+ -1,
810
+ minor,
811
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
812
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
813
+ )
814
+ out = out.permute(0, 2, 3, 1)
815
+ out = out[:, ::down_y, ::down_x, :]
816
+
817
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
818
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
819
+
820
+ return out.view(-1, channel, out_h, out_w)
821
+
822
+
823
+ class TemporalConvLayer(nn.Module):
824
+ """
825
+ Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
826
+ https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
827
+ """
828
+
829
+ def __init__(self, in_dim, out_dim=None, dropout=0.0):
830
+ super().__init__()
831
+ out_dim = out_dim or in_dim
832
+ self.in_dim = in_dim
833
+ self.out_dim = out_dim
834
+
835
+ # conv layers
836
+ self.conv1 = nn.Sequential(
837
+ nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
838
+ )
839
+ self.conv2 = nn.Sequential(
840
+ nn.GroupNorm(32, out_dim),
841
+ nn.SiLU(),
842
+ nn.Dropout(dropout),
843
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
844
+ )
845
+ self.conv3 = nn.Sequential(
846
+ nn.GroupNorm(32, out_dim),
847
+ nn.SiLU(),
848
+ nn.Dropout(dropout),
849
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
850
+ )
851
+ self.conv4 = nn.Sequential(
852
+ nn.GroupNorm(32, out_dim),
853
+ nn.SiLU(),
854
+ nn.Dropout(dropout),
855
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
856
+ )
857
+
858
+ # zero out the last layer params, so the conv block is the identity at initialization
859
+ nn.init.zeros_(self.conv4[-1].weight)
860
+ nn.init.zeros_(self.conv4[-1].bias)
861
+
862
+ def forward(self, hidden_states, num_frames=1):
863
+ hidden_states = (
864
+ hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
865
+ )
866
+
867
+ identity = hidden_states
868
+ hidden_states = self.conv1(hidden_states)
869
+ hidden_states = self.conv2(hidden_states)
870
+ hidden_states = self.conv3(hidden_states)
871
+ hidden_states = self.conv4(hidden_states)
872
+
873
+ hidden_states = identity + hidden_states
874
+
875
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
876
+ (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
877
+ )
878
+ return hidden_states
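`Upsample2D` and `Downsample2D` above boil down to a nearest-neighbour 2x interpolation followed by a 3x3 convolution, and a stride-2 3x3 convolution, respectively; the FIR variants replace the resampling with `upfirdn2d_native`, which with the default `[1] * factor` kernel matches nearest-neighbour upsampling, as the docstrings note. A minimal shape check using plain `nn.Conv2d` in place of the LoRA-compatible wrappers (channel and spatial sizes are toy values):

```py
# Illustrative sketch only: the up/downsampling pattern of Upsample2D / Downsample2D
# written with plain convolutions, checking the resulting shapes.
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 8, 16, 16)  # (N, C, H, W)

# Upsample2D with use_conv=True: nearest 2x interpolation, then a 3x3 conv
up_conv = nn.Conv2d(8, 8, kernel_size=3, padding=1)
up = up_conv(F.interpolate(x, scale_factor=2.0, mode="nearest"))
print(up.shape)    # torch.Size([1, 8, 32, 32])

# Downsample2D with use_conv=True: a stride-2 3x3 conv halves H and W
down_conv = nn.Conv2d(8, 8, kernel_size=3, stride=2, padding=1)
down = down_conv(x)
print(down.shape)  # torch.Size([1, 8, 8, 8])
```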
Tiger Model/diffusiers-Tiger/models/t5_film_transformer.py ADDED
@@ -0,0 +1,321 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+
16
+ import torch
17
+ from torch import nn
18
+
19
+ from ..configuration_utils import ConfigMixin, register_to_config
20
+ from .attention_processor import Attention
21
+ from .embeddings import get_timestep_embedding
22
+ from .modeling_utils import ModelMixin
23
+
24
+
25
+ class T5FilmDecoder(ModelMixin, ConfigMixin):
26
+ @register_to_config
27
+ def __init__(
28
+ self,
29
+ input_dims: int = 128,
30
+ targets_length: int = 256,
31
+ max_decoder_noise_time: float = 2000.0,
32
+ d_model: int = 768,
33
+ num_layers: int = 12,
34
+ num_heads: int = 12,
35
+ d_kv: int = 64,
36
+ d_ff: int = 2048,
37
+ dropout_rate: float = 0.1,
38
+ ):
39
+ super().__init__()
40
+
41
+ self.conditioning_emb = nn.Sequential(
42
+ nn.Linear(d_model, d_model * 4, bias=False),
43
+ nn.SiLU(),
44
+ nn.Linear(d_model * 4, d_model * 4, bias=False),
45
+ nn.SiLU(),
46
+ )
47
+
48
+ self.position_encoding = nn.Embedding(targets_length, d_model)
49
+ self.position_encoding.weight.requires_grad = False
50
+
51
+ self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False)
52
+
53
+ self.dropout = nn.Dropout(p=dropout_rate)
54
+
55
+ self.decoders = nn.ModuleList()
56
+ for lyr_num in range(num_layers):
57
+ # FiLM conditional T5 decoder
58
+ lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate)
59
+ self.decoders.append(lyr)
60
+
61
+ self.decoder_norm = T5LayerNorm(d_model)
62
+
63
+ self.post_dropout = nn.Dropout(p=dropout_rate)
64
+ self.spec_out = nn.Linear(d_model, input_dims, bias=False)
65
+
66
+ def encoder_decoder_mask(self, query_input, key_input):
67
+ mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
68
+ return mask.unsqueeze(-3)
69
+
70
+ def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time):
71
+ batch, _, _ = decoder_input_tokens.shape
72
+ assert decoder_noise_time.shape == (batch,)
73
+
74
+ # decoder_noise_time is in [0, 1), so rescale to expected timing range.
75
+ time_steps = get_timestep_embedding(
76
+ decoder_noise_time * self.config.max_decoder_noise_time,
77
+ embedding_dim=self.config.d_model,
78
+ max_period=self.config.max_decoder_noise_time,
79
+ ).to(dtype=self.dtype)
80
+
81
+ conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1)
82
+
83
+ assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4)
84
+
85
+ seq_length = decoder_input_tokens.shape[1]
86
+
87
+ # If we want to use relative positions for audio context, we can just offset
88
+ # this sequence by the length of encodings_and_masks.
89
+ decoder_positions = torch.broadcast_to(
90
+ torch.arange(seq_length, device=decoder_input_tokens.device),
91
+ (batch, seq_length),
92
+ )
93
+
94
+ position_encodings = self.position_encoding(decoder_positions)
95
+
96
+ inputs = self.continuous_inputs_projection(decoder_input_tokens)
97
+ inputs += position_encodings
98
+ y = self.dropout(inputs)
99
+
100
+ # decoder: No padding present.
101
+ decoder_mask = torch.ones(
102
+ decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype
103
+ )
104
+
105
+ # Translate encoding masks to encoder-decoder masks.
106
+ encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks]
107
+
108
+ # cross attend style: concat encodings
109
+ encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1)
110
+ encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1)
111
+
112
+ for lyr in self.decoders:
113
+ y = lyr(
114
+ y,
115
+ conditioning_emb=conditioning_emb,
116
+ encoder_hidden_states=encoded,
117
+ encoder_attention_mask=encoder_decoder_mask,
118
+ )[0]
119
+
120
+ y = self.decoder_norm(y)
121
+ y = self.post_dropout(y)
122
+
123
+ spec_out = self.spec_out(y)
124
+ return spec_out
125
+
126
+
127
+ class DecoderLayer(nn.Module):
128
+ def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer_norm_epsilon=1e-6):
129
+ super().__init__()
130
+ self.layer = nn.ModuleList()
131
+
132
+ # cond self attention: layer 0
133
+ self.layer.append(
134
+ T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate)
135
+ )
136
+
137
+ # cross attention: layer 1
138
+ self.layer.append(
139
+ T5LayerCrossAttention(
140
+ d_model=d_model,
141
+ d_kv=d_kv,
142
+ num_heads=num_heads,
143
+ dropout_rate=dropout_rate,
144
+ layer_norm_epsilon=layer_norm_epsilon,
145
+ )
146
+ )
147
+
148
+ # Film Cond MLP + dropout: last layer
149
+ self.layer.append(
150
+ T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon)
151
+ )
152
+
153
+ def forward(
154
+ self,
155
+ hidden_states,
156
+ conditioning_emb=None,
157
+ attention_mask=None,
158
+ encoder_hidden_states=None,
159
+ encoder_attention_mask=None,
160
+ encoder_decoder_position_bias=None,
161
+ ):
162
+ hidden_states = self.layer[0](
163
+ hidden_states,
164
+ conditioning_emb=conditioning_emb,
165
+ attention_mask=attention_mask,
166
+ )
167
+
168
+ if encoder_hidden_states is not None:
169
+ encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to(
170
+ encoder_hidden_states.dtype
171
+ )
172
+
173
+ hidden_states = self.layer[1](
174
+ hidden_states,
175
+ key_value_states=encoder_hidden_states,
176
+ attention_mask=encoder_extended_attention_mask,
177
+ )
178
+
179
+ # Apply Film Conditional Feed Forward layer
180
+ hidden_states = self.layer[-1](hidden_states, conditioning_emb)
181
+
182
+ return (hidden_states,)
183
+
184
+
185
+ class T5LayerSelfAttentionCond(nn.Module):
186
+ def __init__(self, d_model, d_kv, num_heads, dropout_rate):
187
+ super().__init__()
188
+ self.layer_norm = T5LayerNorm(d_model)
189
+ self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
190
+ self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
191
+ self.dropout = nn.Dropout(dropout_rate)
192
+
193
+ def forward(
194
+ self,
195
+ hidden_states,
196
+ conditioning_emb=None,
197
+ attention_mask=None,
198
+ ):
199
+ # pre_self_attention_layer_norm
200
+ normed_hidden_states = self.layer_norm(hidden_states)
201
+
202
+ if conditioning_emb is not None:
203
+ normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb)
204
+
205
+ # Self-attention block
206
+ attention_output = self.attention(normed_hidden_states)
207
+
208
+ hidden_states = hidden_states + self.dropout(attention_output)
209
+
210
+ return hidden_states
211
+
212
+
213
+ class T5LayerCrossAttention(nn.Module):
214
+ def __init__(self, d_model, d_kv, num_heads, dropout_rate, layer_norm_epsilon):
215
+ super().__init__()
216
+ self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
217
+ self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
218
+ self.dropout = nn.Dropout(dropout_rate)
219
+
220
+ def forward(
221
+ self,
222
+ hidden_states,
223
+ key_value_states=None,
224
+ attention_mask=None,
225
+ ):
226
+ normed_hidden_states = self.layer_norm(hidden_states)
227
+ attention_output = self.attention(
228
+ normed_hidden_states,
229
+ encoder_hidden_states=key_value_states,
230
+ attention_mask=attention_mask.squeeze(1),
231
+ )
232
+ layer_output = hidden_states + self.dropout(attention_output)
233
+ return layer_output
234
+
235
+
236
+ class T5LayerFFCond(nn.Module):
237
+ def __init__(self, d_model, d_ff, dropout_rate, layer_norm_epsilon):
238
+ super().__init__()
239
+ self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
240
+ self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
241
+ self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
242
+ self.dropout = nn.Dropout(dropout_rate)
243
+
244
+ def forward(self, hidden_states, conditioning_emb=None):
245
+ forwarded_states = self.layer_norm(hidden_states)
246
+ if conditioning_emb is not None:
247
+ forwarded_states = self.film(forwarded_states, conditioning_emb)
248
+
249
+ forwarded_states = self.DenseReluDense(forwarded_states)
250
+ hidden_states = hidden_states + self.dropout(forwarded_states)
251
+ return hidden_states
252
+
253
+
254
+ class T5DenseGatedActDense(nn.Module):
255
+ def __init__(self, d_model, d_ff, dropout_rate):
256
+ super().__init__()
257
+ self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
258
+ self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
259
+ self.wo = nn.Linear(d_ff, d_model, bias=False)
260
+ self.dropout = nn.Dropout(dropout_rate)
261
+ self.act = NewGELUActivation()
262
+
263
+ def forward(self, hidden_states):
264
+ hidden_gelu = self.act(self.wi_0(hidden_states))
265
+ hidden_linear = self.wi_1(hidden_states)
266
+ hidden_states = hidden_gelu * hidden_linear
267
+ hidden_states = self.dropout(hidden_states)
268
+
269
+ hidden_states = self.wo(hidden_states)
270
+ return hidden_states
271
+
272
+
273
+ class T5LayerNorm(nn.Module):
274
+ def __init__(self, hidden_size, eps=1e-6):
275
+ """
276
+ Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
277
+ """
278
+ super().__init__()
279
+ self.weight = nn.Parameter(torch.ones(hidden_size))
280
+ self.variance_epsilon = eps
281
+
282
+ def forward(self, hidden_states):
283
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
284
+ # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
285
+ # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
286
+ # half-precision inputs is done in fp32
287
+
288
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
289
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
290
+
291
+ # convert into half-precision if necessary
292
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
293
+ hidden_states = hidden_states.to(self.weight.dtype)
294
+
295
+ return self.weight * hidden_states
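A small numerical check of the root-mean-square scaling described in the comment above, using assumed toy shapes (not tied to this module):

import torch

x = torch.randn(2, 4, 8)
variance = x.pow(2).mean(-1, keepdim=True)   # mean of squares; no mean subtraction, no bias
y = x * torch.rsqrt(variance + 1e-6)         # scale so the RMS of each feature vector is ~1
print(y.pow(2).mean(-1))                     # values close to 1.0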
296
+
297
+
298
+ class NewGELUActivation(nn.Module):
299
+ """
300
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
301
+ the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
302
+ """
303
+
304
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
305
+ return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
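For reference, the tanh formula above matches PyTorch's built-in approximate GELU; a quick check, assuming PyTorch 1.12+ where `F.gelu(..., approximate="tanh")` is available:

import math
import torch
import torch.nn.functional as F

x = torch.randn(5)
manual = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
print(torch.allclose(manual, F.gelu(x, approximate="tanh"), atol=1e-6))  # expected: True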
306
+
307
+
308
+ class T5FiLMLayer(nn.Module):
309
+ """
310
+ FiLM Layer
311
+ """
312
+
313
+ def __init__(self, in_features, out_features):
314
+ super().__init__()
315
+ self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
316
+
317
+ def forward(self, x, conditioning_emb):
318
+ emb = self.scale_bias(conditioning_emb)
319
+ scale, shift = torch.chunk(emb, 2, -1)
320
+ x = x * (1 + scale) + shift
321
+ return x
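A minimal sketch of the FiLM-style modulation performed by `T5FiLMLayer`, with toy shapes chosen for illustration (the `film` projection below is a stand-in, not the module defined above):

import torch
from torch import nn

d_model = 768
film = nn.Linear(d_model * 4, d_model * 2, bias=False)   # joint projection to (scale, shift)

x = torch.randn(2, 10, d_model)          # (batch, seq, d_model), e.g. a normed hidden state
cond = torch.randn(2, 1, d_model * 4)    # conditioning embedding, broadcast over the sequence

scale, shift = film(cond).chunk(2, dim=-1)
y = x * (1 + scale) + shift              # feature-wise affine modulation
print(y.shape)                           # torch.Size([2, 10, 768])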
Tiger Model/diffusiers-Tiger/models/transformer_2d.py ADDED
@@ -0,0 +1,359 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..models.embeddings import ImagePositionalEmbeddings
23
+ from ..utils import BaseOutput, deprecate
24
+ from .attention import BasicTransformerBlock
25
+ from .embeddings import PatchEmbed
26
+ from .lora import LoRACompatibleConv, LoRACompatibleLinear
27
+ from .modeling_utils import ModelMixin
28
+
29
+
30
+ @dataclass
31
+ class Transformer2DModelOutput(BaseOutput):
32
+ """
33
+ The output of [`Transformer2DModel`].
34
+
35
+ Args:
36
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
37
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
38
+ distributions for the unnoised latent pixels.
39
+ """
40
+
41
+ sample: torch.FloatTensor
42
+
43
+
44
+ class Transformer2DModel(ModelMixin, ConfigMixin):
45
+ """
46
+ A 2D Transformer model for image-like data.
47
+
48
+ Parameters:
49
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
50
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
51
+ in_channels (`int`, *optional*):
52
+ The number of channels in the input and output (specify if the input is **continuous**).
53
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
54
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
55
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
56
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
57
+ This is fixed during training since it is used to learn a number of position embeddings.
58
+ num_vector_embeds (`int`, *optional*):
59
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
60
+ Includes the class for the masked latent pixel.
61
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
62
+ num_embeds_ada_norm ( `int`, *optional*):
63
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
64
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
65
+ added to the hidden states.
66
+
67
+ During inference, you can denoise for up to but not more than `num_embeds_ada_norm` steps.
68
+ attention_bias (`bool`, *optional*):
69
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
70
+ """
71
+
72
+ @register_to_config
73
+ def __init__(
74
+ self,
75
+ num_attention_heads: int = 16,
76
+ attention_head_dim: int = 88,
77
+ in_channels: Optional[int] = None,
78
+ out_channels: Optional[int] = None,
79
+ num_layers: int = 1,
80
+ dropout: float = 0.0,
81
+ norm_num_groups: int = 32,
82
+ cross_attention_dim: Optional[int] = None,
83
+ attention_bias: bool = False,
84
+ sample_size: Optional[int] = None,
85
+ num_vector_embeds: Optional[int] = None,
86
+ patch_size: Optional[int] = None,
87
+ activation_fn: str = "geglu",
88
+ num_embeds_ada_norm: Optional[int] = None,
89
+ use_linear_projection: bool = False,
90
+ only_cross_attention: bool = False,
91
+ upcast_attention: bool = False,
92
+ norm_type: str = "layer_norm",
93
+ norm_elementwise_affine: bool = True,
94
+ attention_type: str = "default",
95
+ ):
96
+ super().__init__()
97
+ self.use_linear_projection = use_linear_projection
98
+ self.num_attention_heads = num_attention_heads
99
+ self.attention_head_dim = attention_head_dim
100
+ inner_dim = num_attention_heads * attention_head_dim
101
+
102
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
103
+ # Define whether input is continuous or discrete depending on configuration
104
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
105
+ self.is_input_vectorized = num_vector_embeds is not None
106
+ self.is_input_patches = in_channels is not None and patch_size is not None
107
+
108
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
109
+ deprecation_message = (
110
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
111
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
112
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
113
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
114
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
115
+ )
116
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
117
+ norm_type = "ada_norm"
118
+
119
+ if self.is_input_continuous and self.is_input_vectorized:
120
+ raise ValueError(
121
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
122
+ " sure that either `in_channels` or `num_vector_embeds` is None."
123
+ )
124
+ elif self.is_input_vectorized and self.is_input_patches:
125
+ raise ValueError(
126
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
127
+ " sure that either `num_vector_embeds` or `num_patches` is None."
128
+ )
129
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
130
+ raise ValueError(
131
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
132
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
133
+ )
134
+
135
+ # 2. Define input layers
136
+ if self.is_input_continuous:
137
+ self.in_channels = in_channels
138
+
139
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
140
+ if use_linear_projection:
141
+ self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
142
+ else:
143
+ self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
144
+ elif self.is_input_vectorized:
145
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
146
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
147
+
148
+ self.height = sample_size
149
+ self.width = sample_size
150
+ self.num_vector_embeds = num_vector_embeds
151
+ self.num_latent_pixels = self.height * self.width
152
+
153
+ self.latent_image_embedding = ImagePositionalEmbeddings(
154
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
155
+ )
156
+ elif self.is_input_patches:
157
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
158
+
159
+ self.height = sample_size
160
+ self.width = sample_size
161
+
162
+ self.patch_size = patch_size
163
+ self.pos_embed = PatchEmbed(
164
+ height=sample_size,
165
+ width=sample_size,
166
+ patch_size=patch_size,
167
+ in_channels=in_channels,
168
+ embed_dim=inner_dim,
169
+ )
170
+
171
+ # 3. Define transformers blocks
172
+ self.transformer_blocks = nn.ModuleList(
173
+ [
174
+ BasicTransformerBlock(
175
+ inner_dim,
176
+ num_attention_heads,
177
+ attention_head_dim,
178
+ dropout=dropout,
179
+ cross_attention_dim=cross_attention_dim,
180
+ activation_fn=activation_fn,
181
+ num_embeds_ada_norm=num_embeds_ada_norm,
182
+ attention_bias=attention_bias,
183
+ only_cross_attention=only_cross_attention,
184
+ upcast_attention=upcast_attention,
185
+ norm_type=norm_type,
186
+ norm_elementwise_affine=norm_elementwise_affine,
187
+ attention_type=attention_type,
188
+ )
189
+ for d in range(num_layers)
190
+ ]
191
+ )
192
+
193
+ # 4. Define output layers
194
+ self.out_channels = in_channels if out_channels is None else out_channels
195
+ if self.is_input_continuous:
196
+ # TODO: should use out_channels for continuous projections
197
+ if use_linear_projection:
198
+ self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
199
+ else:
200
+ self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
201
+ elif self.is_input_vectorized:
202
+ self.norm_out = nn.LayerNorm(inner_dim)
203
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
204
+ elif self.is_input_patches:
205
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
206
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
207
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
208
+
209
+ self.gradient_checkpointing = False
210
+
211
+ def forward(
212
+ self,
213
+ hidden_states: torch.Tensor,
214
+ encoder_hidden_states: Optional[torch.Tensor] = None,
215
+ timestep: Optional[torch.LongTensor] = None,
216
+ class_labels: Optional[torch.LongTensor] = None,
217
+ cross_attention_kwargs: Dict[str, Any] = None,
218
+ attention_mask: Optional[torch.Tensor] = None,
219
+ encoder_attention_mask: Optional[torch.Tensor] = None,
220
+ return_dict: bool = True,
221
+ ):
222
+ """
223
+ The [`Transformer2DModel`] forward method.
224
+
225
+ Args:
226
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
227
+ Input `hidden_states`.
228
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
229
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
230
+ self-attention.
231
+ timestep ( `torch.LongTensor`, *optional*):
232
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
233
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
234
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
235
+ `AdaLayerNormZero`.
236
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
237
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
238
+
239
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
240
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
241
+
242
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
243
+ above. This bias will be added to the cross-attention scores.
244
+ return_dict (`bool`, *optional*, defaults to `True`):
245
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
246
+ tuple.
247
+
248
+ Returns:
249
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
250
+ `tuple` where the first element is the sample tensor.
251
+ """
252
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
253
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
254
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
255
+ # expects mask of shape:
256
+ # [batch, key_tokens]
257
+ # adds singleton query_tokens dimension:
258
+ # [batch, 1, key_tokens]
259
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
260
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
261
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
262
+ if attention_mask is not None and attention_mask.ndim == 2:
263
+ # assume that mask is expressed as:
264
+ # (1 = keep, 0 = discard)
265
+ # convert mask into a bias that can be added to attention scores:
266
+ # (keep = +0, discard = -10000.0)
267
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
268
+ attention_mask = attention_mask.unsqueeze(1)
269
+
270
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
271
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
272
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
273
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
274
+
275
+ # 1. Input
276
+ if self.is_input_continuous:
277
+ batch, _, height, width = hidden_states.shape
278
+ residual = hidden_states
279
+
280
+ hidden_states = self.norm(hidden_states)
281
+ if not self.use_linear_projection:
282
+ hidden_states = self.proj_in(hidden_states)
283
+ inner_dim = hidden_states.shape[1]
284
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
285
+ else:
286
+ inner_dim = hidden_states.shape[1]
287
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
288
+ hidden_states = self.proj_in(hidden_states)
289
+ elif self.is_input_vectorized:
290
+ hidden_states = self.latent_image_embedding(hidden_states)
291
+ elif self.is_input_patches:
292
+ hidden_states = self.pos_embed(hidden_states)
293
+
294
+ # 2. Blocks
295
+ for block in self.transformer_blocks:
296
+ if self.training and self.gradient_checkpointing:
297
+ hidden_states = torch.utils.checkpoint.checkpoint(
298
+ block,
299
+ hidden_states,
300
+ attention_mask,
301
+ encoder_hidden_states,
302
+ encoder_attention_mask,
303
+ timestep,
304
+ cross_attention_kwargs,
305
+ class_labels,
306
+ use_reentrant=False,
307
+ )
308
+ else:
309
+ hidden_states = block(
310
+ hidden_states,
311
+ attention_mask=attention_mask,
312
+ encoder_hidden_states=encoder_hidden_states,
313
+ encoder_attention_mask=encoder_attention_mask,
314
+ timestep=timestep,
315
+ cross_attention_kwargs=cross_attention_kwargs,
316
+ class_labels=class_labels,
317
+ )
318
+
319
+ # 3. Output
320
+ if self.is_input_continuous:
321
+ if not self.use_linear_projection:
322
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
323
+ hidden_states = self.proj_out(hidden_states)
324
+ else:
325
+ hidden_states = self.proj_out(hidden_states)
326
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
327
+
328
+ output = hidden_states + residual
329
+ elif self.is_input_vectorized:
330
+ hidden_states = self.norm_out(hidden_states)
331
+ logits = self.out(hidden_states)
332
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
333
+ logits = logits.permute(0, 2, 1)
334
+
335
+ # log(p(x_0))
336
+ output = F.log_softmax(logits.double(), dim=1).float()
337
+ elif self.is_input_patches:
338
+ # TODO: cleanup!
339
+ conditioning = self.transformer_blocks[0].norm1.emb(
340
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
341
+ )
342
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
343
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
344
+ hidden_states = self.proj_out_2(hidden_states)
345
+
346
+ # unpatchify
347
+ height = width = int(hidden_states.shape[1] ** 0.5)
348
+ hidden_states = hidden_states.reshape(
349
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
350
+ )
351
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
352
+ output = hidden_states.reshape(
353
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
354
+ )
355
+
356
+ if not return_dict:
357
+ return (output,)
358
+
359
+ return Transformer2DModelOutput(sample=output)
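The 2-D mask-to-bias conversion used in `Transformer2DModel.forward` can be checked in isolation; a toy example with an assumed mask:

import torch

attention_mask = torch.tensor([[1, 1, 0, 0]])      # (batch, key_tokens); 1 = keep, 0 = discard
bias = (1 - attention_mask.float()) * -10000.0     # keep -> 0.0, discard -> -10000.0
bias = bias.unsqueeze(1)                           # (batch, 1, key_tokens), broadcastable over query tokens
print(bias)                                        # tensor([[[0., 0., -10000., -10000.]]])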
Tiger Model/diffusiers-Tiger/models/transformer_temporal.py ADDED
@@ -0,0 +1,179 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+ from ..configuration_utils import ConfigMixin, register_to_config
21
+ from ..utils import BaseOutput
22
+ from .attention import BasicTransformerBlock
23
+ from .modeling_utils import ModelMixin
24
+
25
+
26
+ @dataclass
27
+ class TransformerTemporalModelOutput(BaseOutput):
28
+ """
29
+ The output of [`TransformerTemporalModel`].
30
+
31
+ Args:
32
+ sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
33
+ The hidden states output conditioned on `encoder_hidden_states` input.
34
+ """
35
+
36
+ sample: torch.FloatTensor
37
+
38
+
39
+ class TransformerTemporalModel(ModelMixin, ConfigMixin):
40
+ """
41
+ A Transformer model for video-like data.
42
+
43
+ Parameters:
44
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
45
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
46
+ in_channels (`int`, *optional*):
47
+ The number of channels in the input and output (specify if the input is **continuous**).
48
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
49
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
50
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
51
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
52
+ This is fixed during training since it is used to learn a number of position embeddings.
53
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
54
+ attention_bias (`bool`, *optional*):
55
+ Configure if the `TransformerBlock` attention should contain a bias parameter.
56
+ double_self_attention (`bool`, *optional*):
57
+ Configure if each `TransformerBlock` should contain two self-attention layers.
58
+ """
59
+
60
+ @register_to_config
61
+ def __init__(
62
+ self,
63
+ num_attention_heads: int = 16,
64
+ attention_head_dim: int = 88,
65
+ in_channels: Optional[int] = None,
66
+ out_channels: Optional[int] = None,
67
+ num_layers: int = 1,
68
+ dropout: float = 0.0,
69
+ norm_num_groups: int = 32,
70
+ cross_attention_dim: Optional[int] = None,
71
+ attention_bias: bool = False,
72
+ sample_size: Optional[int] = None,
73
+ activation_fn: str = "geglu",
74
+ norm_elementwise_affine: bool = True,
75
+ double_self_attention: bool = True,
76
+ ):
77
+ super().__init__()
78
+ self.num_attention_heads = num_attention_heads
79
+ self.attention_head_dim = attention_head_dim
80
+ inner_dim = num_attention_heads * attention_head_dim
81
+
82
+ self.in_channels = in_channels
83
+
84
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
85
+ self.proj_in = nn.Linear(in_channels, inner_dim)
86
+
87
+ # 3. Define transformers blocks
88
+ self.transformer_blocks = nn.ModuleList(
89
+ [
90
+ BasicTransformerBlock(
91
+ inner_dim,
92
+ num_attention_heads,
93
+ attention_head_dim,
94
+ dropout=dropout,
95
+ cross_attention_dim=cross_attention_dim,
96
+ activation_fn=activation_fn,
97
+ attention_bias=attention_bias,
98
+ double_self_attention=double_self_attention,
99
+ norm_elementwise_affine=norm_elementwise_affine,
100
+ )
101
+ for d in range(num_layers)
102
+ ]
103
+ )
104
+
105
+ self.proj_out = nn.Linear(inner_dim, in_channels)
106
+
107
+ def forward(
108
+ self,
109
+ hidden_states,
110
+ encoder_hidden_states=None,
111
+ timestep=None,
112
+ class_labels=None,
113
+ num_frames=1,
114
+ cross_attention_kwargs=None,
115
+ return_dict: bool = True,
116
+ ):
117
+ """
118
+ The [`TransformerTemporal`] forward method.
119
+
120
+ Args:
121
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
122
+ Input hidden_states.
123
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
124
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
125
+ self-attention.
126
+ timestep ( `torch.long`, *optional*):
127
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
128
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
129
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
130
+ `AdaLayerNormZero`.
131
+ return_dict (`bool`, *optional*, defaults to `True`):
132
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
133
+ tuple.
134
+
135
+ Returns:
136
+ [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
137
+ If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
138
+ returned, otherwise a `tuple` where the first element is the sample tensor.
139
+ """
140
+ # 1. Input
141
+ batch_frames, channel, height, width = hidden_states.shape
142
+ batch_size = batch_frames // num_frames
143
+
144
+ residual = hidden_states
145
+
146
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
147
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
148
+
149
+ hidden_states = self.norm(hidden_states)
150
+ hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
151
+
152
+ hidden_states = self.proj_in(hidden_states)
153
+
154
+ # 2. Blocks
155
+ for block in self.transformer_blocks:
156
+ hidden_states = block(
157
+ hidden_states,
158
+ encoder_hidden_states=encoder_hidden_states,
159
+ timestep=timestep,
160
+ cross_attention_kwargs=cross_attention_kwargs,
161
+ class_labels=class_labels,
162
+ )
163
+
164
+ # 3. Output
165
+ hidden_states = self.proj_out(hidden_states)
166
+ hidden_states = (
167
+ hidden_states[None, None, :]
168
+ .reshape(batch_size, height, width, channel, num_frames)
169
+ .permute(0, 3, 4, 1, 2)
170
+ .contiguous()
171
+ )
172
+ hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
173
+
174
+ output = hidden_states + residual
175
+
176
+ if not return_dict:
177
+ return (output,)
178
+
179
+ return TransformerTemporalModelOutput(sample=output)
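A toy-tensor sketch of the layout change `TransformerTemporalModel.forward` performs before attending over frames (assumed sizes, no model instantiated):

import torch

num_frames, batch, channel, height, width = 4, 2, 8, 6, 6
hidden_states = torch.randn(batch * num_frames, channel, height, width)

x = hidden_states[None, :].reshape(batch, num_frames, channel, height, width)
x = x.permute(0, 2, 1, 3, 4)                                            # (B, C, F, H, W)
x = x.permute(0, 3, 4, 2, 1).reshape(batch * height * width, num_frames, channel)
print(x.shape)   # torch.Size([72, 4, 8]); attention then runs over the frame axis per spatial location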
Tiger Model/diffusiers-Tiger/models/unet_1d.py ADDED
@@ -0,0 +1,255 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..utils import BaseOutput
23
+ from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
24
+ from .modeling_utils import ModelMixin
25
+ from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
26
+
27
+
28
+ @dataclass
29
+ class UNet1DOutput(BaseOutput):
30
+ """
31
+ The output of [`UNet1DModel`].
32
+
33
+ Args:
34
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
35
+ The hidden states output from the last layer of the model.
36
+ """
37
+
38
+ sample: torch.FloatTensor
39
+
40
+
41
+ class UNet1DModel(ModelMixin, ConfigMixin):
42
+ r"""
43
+ A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
44
+
45
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
46
+ for all models (such as downloading or saving).
47
+
48
+ Parameters:
49
+ sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
50
+ in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
51
+ out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
52
+ extra_in_channels (`int`, *optional*, defaults to 0):
53
+ Number of additional channels to be added to the input of the first down block. Useful for cases where the
54
+ input data has more channels than what the model was initially designed for.
55
+ time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
56
+ freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
57
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
58
+ Whether to flip sin to cos for Fourier time embedding.
59
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
60
+ Tuple of downsample block types.
61
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
62
+ Tuple of upsample block types.
63
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
64
+ Tuple of block output channels.
65
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
66
+ out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
67
+ act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
68
+ norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
69
+ layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
70
+ downsample_each_block (`bool`, *optional*, defaults to `False`):
71
+ Experimental feature for using a UNet without upsampling.
72
+ """
73
+
74
+ @register_to_config
75
+ def __init__(
76
+ self,
77
+ sample_size: int = 65536,
78
+ sample_rate: Optional[int] = None,
79
+ in_channels: int = 2,
80
+ out_channels: int = 2,
81
+ extra_in_channels: int = 0,
82
+ time_embedding_type: str = "fourier",
83
+ flip_sin_to_cos: bool = True,
84
+ use_timestep_embedding: bool = False,
85
+ freq_shift: float = 0.0,
86
+ down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
87
+ up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
88
+ mid_block_type: Tuple[str] = "UNetMidBlock1D",
89
+ out_block_type: str = None,
90
+ block_out_channels: Tuple[int] = (32, 32, 64),
91
+ act_fn: str = None,
92
+ norm_num_groups: int = 8,
93
+ layers_per_block: int = 1,
94
+ downsample_each_block: bool = False,
95
+ ):
96
+ super().__init__()
97
+ self.sample_size = sample_size
98
+
99
+ # time
100
+ if time_embedding_type == "fourier":
101
+ self.time_proj = GaussianFourierProjection(
102
+ embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
103
+ )
104
+ timestep_input_dim = 2 * block_out_channels[0]
105
+ elif time_embedding_type == "positional":
106
+ self.time_proj = Timesteps(
107
+ block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
108
+ )
109
+ timestep_input_dim = block_out_channels[0]
110
+
111
+ if use_timestep_embedding:
112
+ time_embed_dim = block_out_channels[0] * 4
113
+ self.time_mlp = TimestepEmbedding(
114
+ in_channels=timestep_input_dim,
115
+ time_embed_dim=time_embed_dim,
116
+ act_fn=act_fn,
117
+ out_dim=block_out_channels[0],
118
+ )
119
+
120
+ self.down_blocks = nn.ModuleList([])
121
+ self.mid_block = None
122
+ self.up_blocks = nn.ModuleList([])
123
+ self.out_block = None
124
+
125
+ # down
126
+ output_channel = in_channels
127
+ for i, down_block_type in enumerate(down_block_types):
128
+ input_channel = output_channel
129
+ output_channel = block_out_channels[i]
130
+
131
+ if i == 0:
132
+ input_channel += extra_in_channels
133
+
134
+ is_final_block = i == len(block_out_channels) - 1
135
+
136
+ down_block = get_down_block(
137
+ down_block_type,
138
+ num_layers=layers_per_block,
139
+ in_channels=input_channel,
140
+ out_channels=output_channel,
141
+ temb_channels=block_out_channels[0],
142
+ add_downsample=not is_final_block or downsample_each_block,
143
+ )
144
+ self.down_blocks.append(down_block)
145
+
146
+ # mid
147
+ self.mid_block = get_mid_block(
148
+ mid_block_type,
149
+ in_channels=block_out_channels[-1],
150
+ mid_channels=block_out_channels[-1],
151
+ out_channels=block_out_channels[-1],
152
+ embed_dim=block_out_channels[0],
153
+ num_layers=layers_per_block,
154
+ add_downsample=downsample_each_block,
155
+ )
156
+
157
+ # up
158
+ reversed_block_out_channels = list(reversed(block_out_channels))
159
+ output_channel = reversed_block_out_channels[0]
160
+ if out_block_type is None:
161
+ final_upsample_channels = out_channels
162
+ else:
163
+ final_upsample_channels = block_out_channels[0]
164
+
165
+ for i, up_block_type in enumerate(up_block_types):
166
+ prev_output_channel = output_channel
167
+ output_channel = (
168
+ reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
169
+ )
170
+
171
+ is_final_block = i == len(block_out_channels) - 1
172
+
173
+ up_block = get_up_block(
174
+ up_block_type,
175
+ num_layers=layers_per_block,
176
+ in_channels=prev_output_channel,
177
+ out_channels=output_channel,
178
+ temb_channels=block_out_channels[0],
179
+ add_upsample=not is_final_block,
180
+ )
181
+ self.up_blocks.append(up_block)
182
+ prev_output_channel = output_channel
183
+
184
+ # out
185
+ num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
186
+ self.out_block = get_out_block(
187
+ out_block_type=out_block_type,
188
+ num_groups_out=num_groups_out,
189
+ embed_dim=block_out_channels[0],
190
+ out_channels=out_channels,
191
+ act_fn=act_fn,
192
+ fc_dim=block_out_channels[-1] // 4,
193
+ )
194
+
195
+ def forward(
196
+ self,
197
+ sample: torch.FloatTensor,
198
+ timestep: Union[torch.Tensor, float, int],
199
+ return_dict: bool = True,
200
+ ) -> Union[UNet1DOutput, Tuple]:
201
+ r"""
202
+ The [`UNet1DModel`] forward method.
203
+
204
+ Args:
205
+ sample (`torch.FloatTensor`):
206
+ The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
207
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
208
+ return_dict (`bool`, *optional*, defaults to `True`):
209
+ Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
210
+
211
+ Returns:
212
+ [`~models.unet_1d.UNet1DOutput`] or `tuple`:
213
+ If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
214
+ returned where the first element is the sample tensor.
215
+ """
216
+
217
+ # 1. time
218
+ timesteps = timestep
219
+ if not torch.is_tensor(timesteps):
220
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
221
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
222
+ timesteps = timesteps[None].to(sample.device)
223
+
224
+ timestep_embed = self.time_proj(timesteps)
225
+ if self.config.use_timestep_embedding:
226
+ timestep_embed = self.time_mlp(timestep_embed)
227
+ else:
228
+ timestep_embed = timestep_embed[..., None]
229
+ timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
230
+ timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
231
+
232
+ # 2. down
233
+ down_block_res_samples = ()
234
+ for downsample_block in self.down_blocks:
235
+ sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
236
+ down_block_res_samples += res_samples
237
+
238
+ # 3. mid
239
+ if self.mid_block:
240
+ sample = self.mid_block(sample, timestep_embed)
241
+
242
+ # 4. up
243
+ for i, upsample_block in enumerate(self.up_blocks):
244
+ res_samples = down_block_res_samples[-1:]
245
+ down_block_res_samples = down_block_res_samples[:-1]
246
+ sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
247
+
248
+ # 5. post-process
249
+ if self.out_block:
250
+ sample = self.out_block(sample, timestep_embed)
251
+
252
+ if not return_dict:
253
+ return (sample,)
254
+
255
+ return UNet1DOutput(sample=sample)
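A minimal sketch of the timestep-embedding broadcast in step 1 of `UNet1DModel.forward` for the `use_timestep_embedding=False` branch; `proj` below is a stand-in for `self.time_proj(timesteps)` with an assumed embedding width:

import torch

sample = torch.randn(3, 2, 128)                   # (batch, channels, sample_size)
timesteps = torch.tensor([10], dtype=torch.long)  # a Python int timestep promoted to a tensor

proj = torch.randn(1, 16)                         # stand-in for self.time_proj(timesteps)
timestep_embed = proj[..., None]                  # (1, 16, 1)
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
timestep_embed = timestep_embed.broadcast_to(sample.shape[:1] + timestep_embed.shape[1:])
print(timestep_embed.shape)                       # torch.Size([3, 16, 128])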
Tiger Model/diffusiers-Tiger/models/unet_1d_blocks.py ADDED
@@ -0,0 +1,656 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from .activations import get_activation
21
+ from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
22
+
23
+
24
+ class DownResnetBlock1D(nn.Module):
25
+ def __init__(
26
+ self,
27
+ in_channels,
28
+ out_channels=None,
29
+ num_layers=1,
30
+ conv_shortcut=False,
31
+ temb_channels=32,
32
+ groups=32,
33
+ groups_out=None,
34
+ non_linearity=None,
35
+ time_embedding_norm="default",
36
+ output_scale_factor=1.0,
37
+ add_downsample=True,
38
+ ):
39
+ super().__init__()
40
+ self.in_channels = in_channels
41
+ out_channels = in_channels if out_channels is None else out_channels
42
+ self.out_channels = out_channels
43
+ self.use_conv_shortcut = conv_shortcut
44
+ self.time_embedding_norm = time_embedding_norm
45
+ self.add_downsample = add_downsample
46
+ self.output_scale_factor = output_scale_factor
47
+
48
+ if groups_out is None:
49
+ groups_out = groups
50
+
51
+ # there will always be at least one resnet
52
+ resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
53
+
54
+ for _ in range(num_layers):
55
+ resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
56
+
57
+ self.resnets = nn.ModuleList(resnets)
58
+
59
+ if non_linearity is None:
60
+ self.nonlinearity = None
61
+ else:
62
+ self.nonlinearity = get_activation(non_linearity)
63
+
64
+ self.downsample = None
65
+ if add_downsample:
66
+ self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
67
+
68
+ def forward(self, hidden_states, temb=None):
69
+ output_states = ()
70
+
71
+ hidden_states = self.resnets[0](hidden_states, temb)
72
+ for resnet in self.resnets[1:]:
73
+ hidden_states = resnet(hidden_states, temb)
74
+
75
+ output_states += (hidden_states,)
76
+
77
+ if self.nonlinearity is not None:
78
+ hidden_states = self.nonlinearity(hidden_states)
79
+
80
+ if self.downsample is not None:
81
+ hidden_states = self.downsample(hidden_states)
82
+
83
+ return hidden_states, output_states
84
+
85
+
86
+ class UpResnetBlock1D(nn.Module):
87
+ def __init__(
88
+ self,
89
+ in_channels,
90
+ out_channels=None,
91
+ num_layers=1,
92
+ temb_channels=32,
93
+ groups=32,
94
+ groups_out=None,
95
+ non_linearity=None,
96
+ time_embedding_norm="default",
97
+ output_scale_factor=1.0,
98
+ add_upsample=True,
99
+ ):
100
+ super().__init__()
101
+ self.in_channels = in_channels
102
+ out_channels = in_channels if out_channels is None else out_channels
103
+ self.out_channels = out_channels
104
+ self.time_embedding_norm = time_embedding_norm
105
+ self.add_upsample = add_upsample
106
+ self.output_scale_factor = output_scale_factor
107
+
108
+ if groups_out is None:
109
+ groups_out = groups
110
+
111
+ # there will always be at least one resnet
112
+ resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
113
+
114
+ for _ in range(num_layers):
115
+ resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
116
+
117
+ self.resnets = nn.ModuleList(resnets)
118
+
119
+ if non_linearity is None:
120
+ self.nonlinearity = None
121
+ else:
122
+ self.nonlinearity = get_activation(non_linearity)
123
+
124
+ self.upsample = None
125
+ if add_upsample:
126
+ self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
127
+
128
+ def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None):
129
+ if res_hidden_states_tuple is not None:
130
+ res_hidden_states = res_hidden_states_tuple[-1]
131
+ hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
132
+
133
+ hidden_states = self.resnets[0](hidden_states, temb)
134
+ for resnet in self.resnets[1:]:
135
+ hidden_states = resnet(hidden_states, temb)
136
+
137
+ if self.nonlinearity is not None:
138
+ hidden_states = self.nonlinearity(hidden_states)
139
+
140
+ if self.upsample is not None:
141
+ hidden_states = self.upsample(hidden_states)
142
+
143
+ return hidden_states
144
+
145
+
146
+ class ValueFunctionMidBlock1D(nn.Module):
147
+ def __init__(self, in_channels, out_channels, embed_dim):
148
+ super().__init__()
149
+ self.in_channels = in_channels
150
+ self.out_channels = out_channels
151
+ self.embed_dim = embed_dim
152
+
153
+ self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
154
+ self.down1 = Downsample1D(out_channels // 2, use_conv=True)
155
+ self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
156
+ self.down2 = Downsample1D(out_channels // 4, use_conv=True)
157
+
158
+ def forward(self, x, temb=None):
159
+ x = self.res1(x, temb)
160
+ x = self.down1(x)
161
+ x = self.res2(x, temb)
162
+ x = self.down2(x)
163
+ return x
164
+
165
+
166
+ class MidResTemporalBlock1D(nn.Module):
167
+ def __init__(
168
+ self,
169
+ in_channels,
170
+ out_channels,
171
+ embed_dim,
172
+ num_layers: int = 1,
173
+ add_downsample: bool = False,
174
+ add_upsample: bool = False,
175
+ non_linearity=None,
176
+ ):
177
+ super().__init__()
178
+ self.in_channels = in_channels
179
+ self.out_channels = out_channels
180
+ self.add_downsample = add_downsample
181
+
182
+ # there will always be at least one resnet
183
+ resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
184
+
185
+ for _ in range(num_layers):
186
+ resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
187
+
188
+ self.resnets = nn.ModuleList(resnets)
189
+
190
+ if non_linearity is None:
191
+ self.nonlinearity = None
192
+ else:
193
+ self.nonlinearity = get_activation(non_linearity)
194
+
195
+ self.upsample = None
196
+ if add_upsample:
197
+ self.upsample = Downsample1D(out_channels, use_conv=True)
198
+
199
+ self.downsample = None
200
+ if add_downsample:
201
+ self.downsample = Downsample1D(out_channels, use_conv=True)
202
+
203
+ if self.upsample and self.downsample:
204
+ raise ValueError("Block cannot downsample and upsample")
205
+
206
+ def forward(self, hidden_states, temb):
207
+ hidden_states = self.resnets[0](hidden_states, temb)
208
+ for resnet in self.resnets[1:]:
209
+ hidden_states = resnet(hidden_states, temb)
210
+
211
+ if self.upsample:
212
+ hidden_states = self.upsample(hidden_states)
213
+ if self.downsample:
214
+ hidden_states = self.downsample(hidden_states)
215
+
216
+ return hidden_states
217
+
218
+
219
+ class OutConv1DBlock(nn.Module):
220
+ def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
221
+ super().__init__()
222
+ self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
223
+ self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
224
+ self.final_conv1d_act = get_activation(act_fn)
225
+ self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
226
+
227
+ def forward(self, hidden_states, temb=None):
228
+ hidden_states = self.final_conv1d_1(hidden_states)
229
+ hidden_states = rearrange_dims(hidden_states)
230
+ hidden_states = self.final_conv1d_gn(hidden_states)
231
+ hidden_states = rearrange_dims(hidden_states)
232
+ hidden_states = self.final_conv1d_act(hidden_states)
233
+ hidden_states = self.final_conv1d_2(hidden_states)
234
+ return hidden_states
235
+
236
+
237
+ class OutValueFunctionBlock(nn.Module):
238
+ def __init__(self, fc_dim, embed_dim, act_fn="mish"):
239
+ super().__init__()
240
+ self.final_block = nn.ModuleList(
241
+ [
242
+ nn.Linear(fc_dim + embed_dim, fc_dim // 2),
243
+ get_activation(act_fn),
244
+ nn.Linear(fc_dim // 2, 1),
245
+ ]
246
+ )
247
+
248
+ def forward(self, hidden_states, temb):
249
+ hidden_states = hidden_states.view(hidden_states.shape[0], -1)
250
+ hidden_states = torch.cat((hidden_states, temb), dim=-1)
251
+ for layer in self.final_block:
252
+ hidden_states = layer(hidden_states)
253
+
254
+ return hidden_states
255
+
256
+
257
+ _kernels = {
258
+ "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
259
+ "cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
260
+ "lanczos3": [
261
+ 0.003689131001010537,
262
+ 0.015056144446134567,
263
+ -0.03399861603975296,
264
+ -0.066637322306633,
265
+ 0.13550527393817902,
266
+ 0.44638532400131226,
267
+ 0.44638532400131226,
268
+ 0.13550527393817902,
269
+ -0.066637322306633,
270
+ -0.03399861603975296,
271
+ 0.015056144446134567,
272
+ 0.003689131001010537,
273
+ ],
274
+ }
275
+
276
+
277
+ class Downsample1d(nn.Module):
278
+ def __init__(self, kernel="linear", pad_mode="reflect"):
279
+ super().__init__()
280
+ self.pad_mode = pad_mode
281
+ kernel_1d = torch.tensor(_kernels[kernel])
282
+ self.pad = kernel_1d.shape[0] // 2 - 1
283
+ self.register_buffer("kernel", kernel_1d)
284
+
285
+ def forward(self, hidden_states):
286
+ hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
287
+ weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
288
+ indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
289
+ kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
290
+ weight[indices, indices] = kernel
291
+ return F.conv1d(hidden_states, weight, stride=2)
292
+
293
+
294
+ class Upsample1d(nn.Module):
295
+ def __init__(self, kernel="linear", pad_mode="reflect"):
296
+ super().__init__()
297
+ self.pad_mode = pad_mode
298
+ kernel_1d = torch.tensor(_kernels[kernel]) * 2
299
+ self.pad = kernel_1d.shape[0] // 2 - 1
300
+ self.register_buffer("kernel", kernel_1d)
301
+
302
+ def forward(self, hidden_states, temb=None):
303
+ hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
304
+ weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
305
+ indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
306
+ kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
307
+ weight[indices, indices] = kernel
308
+ return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)
309
+
310
+
311
+ class SelfAttention1d(nn.Module):
312
+ def __init__(self, in_channels, n_head=1, dropout_rate=0.0):
313
+ super().__init__()
314
+ self.channels = in_channels
315
+ self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
316
+ self.num_heads = n_head
317
+
318
+ self.query = nn.Linear(self.channels, self.channels)
319
+ self.key = nn.Linear(self.channels, self.channels)
320
+ self.value = nn.Linear(self.channels, self.channels)
321
+
322
+ self.proj_attn = nn.Linear(self.channels, self.channels, bias=True)
323
+
324
+ self.dropout = nn.Dropout(dropout_rate, inplace=True)
325
+
326
+ def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
327
+ new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
328
+ # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
329
+ new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
330
+ return new_projection
331
+
332
+ def forward(self, hidden_states):
333
+ residual = hidden_states
334
+ batch, channel_dim, seq = hidden_states.shape
335
+
336
+ hidden_states = self.group_norm(hidden_states)
337
+ hidden_states = hidden_states.transpose(1, 2)
338
+
339
+ query_proj = self.query(hidden_states)
340
+ key_proj = self.key(hidden_states)
341
+ value_proj = self.value(hidden_states)
342
+
343
+ query_states = self.transpose_for_scores(query_proj)
344
+ key_states = self.transpose_for_scores(key_proj)
345
+ value_states = self.transpose_for_scores(value_proj)
346
+
347
+ scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1]))
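# Pre-scaling q and k each by `scale` (= 1/sqrt(sqrt(head_dim))) is mathematically the same as
# dividing the attention scores by sqrt(head_dim), but splitting the factor keeps the
# intermediate products smaller, which is friendlier to fp16.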
348
+
349
+ attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
350
+ attention_probs = torch.softmax(attention_scores, dim=-1)
351
+
352
+ # compute attention output
353
+ hidden_states = torch.matmul(attention_probs, value_states)
354
+
355
+ hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
356
+ new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
357
+ hidden_states = hidden_states.view(new_hidden_states_shape)
358
+
359
+ # compute next hidden_states
360
+ hidden_states = self.proj_attn(hidden_states)
361
+ hidden_states = hidden_states.transpose(1, 2)
362
+ hidden_states = self.dropout(hidden_states)
363
+
364
+ output = hidden_states + residual
365
+
366
+ return output
367
+
368
+
369
+ class ResConvBlock(nn.Module):
370
+ def __init__(self, in_channels, mid_channels, out_channels, is_last=False):
371
+ super().__init__()
372
+ self.is_last = is_last
373
+ self.has_conv_skip = in_channels != out_channels
374
+
375
+ if self.has_conv_skip:
376
+ self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)
377
+
378
+ self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
379
+ self.group_norm_1 = nn.GroupNorm(1, mid_channels)
380
+ self.gelu_1 = nn.GELU()
381
+ self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)
382
+
383
+ if not self.is_last:
384
+ self.group_norm_2 = nn.GroupNorm(1, out_channels)
385
+ self.gelu_2 = nn.GELU()
386
+
387
+ def forward(self, hidden_states):
388
+ residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
389
+
390
+ hidden_states = self.conv_1(hidden_states)
391
+ hidden_states = self.group_norm_1(hidden_states)
392
+ hidden_states = self.gelu_1(hidden_states)
393
+ hidden_states = self.conv_2(hidden_states)
394
+
395
+ if not self.is_last:
396
+ hidden_states = self.group_norm_2(hidden_states)
397
+ hidden_states = self.gelu_2(hidden_states)
398
+
399
+ output = hidden_states + residual
400
+ return output
401
+
402
+
403
+ class UNetMidBlock1D(nn.Module):
404
+ def __init__(self, mid_channels, in_channels, out_channels=None):
405
+ super().__init__()
406
+
407
+ out_channels = in_channels if out_channels is None else out_channels
408
+
409
+ # there is always at least one resnet
410
+ self.down = Downsample1d("cubic")
411
+ resnets = [
412
+ ResConvBlock(in_channels, mid_channels, mid_channels),
413
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
414
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
415
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
416
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
417
+ ResConvBlock(mid_channels, mid_channels, out_channels),
418
+ ]
419
+ attentions = [
420
+ SelfAttention1d(mid_channels, mid_channels // 32),
421
+ SelfAttention1d(mid_channels, mid_channels // 32),
422
+ SelfAttention1d(mid_channels, mid_channels // 32),
423
+ SelfAttention1d(mid_channels, mid_channels // 32),
424
+ SelfAttention1d(mid_channels, mid_channels // 32),
425
+ SelfAttention1d(out_channels, out_channels // 32),
426
+ ]
427
+ self.up = Upsample1d(kernel="cubic")
428
+
429
+ self.attentions = nn.ModuleList(attentions)
430
+ self.resnets = nn.ModuleList(resnets)
431
+
432
+ def forward(self, hidden_states, temb=None):
433
+ hidden_states = self.down(hidden_states)
434
+ for attn, resnet in zip(self.attentions, self.resnets):
435
+ hidden_states = resnet(hidden_states)
436
+ hidden_states = attn(hidden_states)
437
+
438
+ hidden_states = self.up(hidden_states)
439
+
440
+ return hidden_states
441
+
442
+
443
+ class AttnDownBlock1D(nn.Module):
444
+ def __init__(self, out_channels, in_channels, mid_channels=None):
445
+ super().__init__()
446
+ mid_channels = out_channels if mid_channels is None else mid_channels
447
+
448
+ self.down = Downsample1d("cubic")
449
+ resnets = [
450
+ ResConvBlock(in_channels, mid_channels, mid_channels),
451
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
452
+ ResConvBlock(mid_channels, mid_channels, out_channels),
453
+ ]
454
+ attentions = [
455
+ SelfAttention1d(mid_channels, mid_channels // 32),
456
+ SelfAttention1d(mid_channels, mid_channels // 32),
457
+ SelfAttention1d(out_channels, out_channels // 32),
458
+ ]
459
+
460
+ self.attentions = nn.ModuleList(attentions)
461
+ self.resnets = nn.ModuleList(resnets)
462
+
463
+ def forward(self, hidden_states, temb=None):
464
+ hidden_states = self.down(hidden_states)
465
+
466
+ for resnet, attn in zip(self.resnets, self.attentions):
467
+ hidden_states = resnet(hidden_states)
468
+ hidden_states = attn(hidden_states)
469
+
470
+ return hidden_states, (hidden_states,)
471
+
472
+
473
+ class DownBlock1D(nn.Module):
474
+ def __init__(self, out_channels, in_channels, mid_channels=None):
475
+ super().__init__()
476
+ mid_channels = out_channels if mid_channels is None else mid_channels
477
+
478
+ self.down = Downsample1d("cubic")
479
+ resnets = [
480
+ ResConvBlock(in_channels, mid_channels, mid_channels),
481
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
482
+ ResConvBlock(mid_channels, mid_channels, out_channels),
483
+ ]
484
+
485
+ self.resnets = nn.ModuleList(resnets)
486
+
487
+ def forward(self, hidden_states, temb=None):
488
+ hidden_states = self.down(hidden_states)
489
+
490
+ for resnet in self.resnets:
491
+ hidden_states = resnet(hidden_states)
492
+
493
+ return hidden_states, (hidden_states,)
494
+
495
+
496
+ class DownBlock1DNoSkip(nn.Module):
497
+ def __init__(self, out_channels, in_channels, mid_channels=None):
498
+ super().__init__()
499
+ mid_channels = out_channels if mid_channels is None else mid_channels
500
+
501
+ resnets = [
502
+ ResConvBlock(in_channels, mid_channels, mid_channels),
503
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
504
+ ResConvBlock(mid_channels, mid_channels, out_channels),
505
+ ]
506
+
507
+ self.resnets = nn.ModuleList(resnets)
508
+
509
+ def forward(self, hidden_states, temb=None):
510
+ hidden_states = torch.cat([hidden_states, temb], dim=1)
511
+ for resnet in self.resnets:
512
+ hidden_states = resnet(hidden_states)
513
+
514
+ return hidden_states, (hidden_states,)
515
+
516
+
517
+ class AttnUpBlock1D(nn.Module):
518
+ def __init__(self, in_channels, out_channels, mid_channels=None):
519
+ super().__init__()
520
+ mid_channels = out_channels if mid_channels is None else mid_channels
521
+
522
+ resnets = [
523
+ ResConvBlock(2 * in_channels, mid_channels, mid_channels),
524
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
525
+ ResConvBlock(mid_channels, mid_channels, out_channels),
526
+ ]
527
+ attentions = [
528
+ SelfAttention1d(mid_channels, mid_channels // 32),
529
+ SelfAttention1d(mid_channels, mid_channels // 32),
530
+ SelfAttention1d(out_channels, out_channels // 32),
531
+ ]
532
+
533
+ self.attentions = nn.ModuleList(attentions)
534
+ self.resnets = nn.ModuleList(resnets)
535
+ self.up = Upsample1d(kernel="cubic")
536
+
537
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
538
+ res_hidden_states = res_hidden_states_tuple[-1]
539
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
540
+
541
+ for resnet, attn in zip(self.resnets, self.attentions):
542
+ hidden_states = resnet(hidden_states)
543
+ hidden_states = attn(hidden_states)
544
+
545
+ hidden_states = self.up(hidden_states)
546
+
547
+ return hidden_states
548
+
549
+
550
+ class UpBlock1D(nn.Module):
551
+ def __init__(self, in_channels, out_channels, mid_channels=None):
552
+ super().__init__()
553
+ mid_channels = in_channels if mid_channels is None else mid_channels
554
+
555
+ resnets = [
556
+ ResConvBlock(2 * in_channels, mid_channels, mid_channels),
557
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
558
+ ResConvBlock(mid_channels, mid_channels, out_channels),
559
+ ]
560
+
561
+ self.resnets = nn.ModuleList(resnets)
562
+ self.up = Upsample1d(kernel="cubic")
563
+
564
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
565
+ res_hidden_states = res_hidden_states_tuple[-1]
566
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
567
+
568
+ for resnet in self.resnets:
569
+ hidden_states = resnet(hidden_states)
570
+
571
+ hidden_states = self.up(hidden_states)
572
+
573
+ return hidden_states
574
+
575
+
576
+ class UpBlock1DNoSkip(nn.Module):
577
+ def __init__(self, in_channels, out_channels, mid_channels=None):
578
+ super().__init__()
579
+ mid_channels = in_channels if mid_channels is None else mid_channels
580
+
581
+ resnets = [
582
+ ResConvBlock(2 * in_channels, mid_channels, mid_channels),
583
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
584
+ ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
585
+ ]
586
+
587
+ self.resnets = nn.ModuleList(resnets)
588
+
589
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
590
+ res_hidden_states = res_hidden_states_tuple[-1]
591
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
592
+
593
+ for resnet in self.resnets:
594
+ hidden_states = resnet(hidden_states)
595
+
596
+ return hidden_states
597
+
598
+
599
+ def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample):
600
+ if down_block_type == "DownResnetBlock1D":
601
+ return DownResnetBlock1D(
602
+ in_channels=in_channels,
603
+ num_layers=num_layers,
604
+ out_channels=out_channels,
605
+ temb_channels=temb_channels,
606
+ add_downsample=add_downsample,
607
+ )
608
+ elif down_block_type == "DownBlock1D":
609
+ return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
610
+ elif down_block_type == "AttnDownBlock1D":
611
+ return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
612
+ elif down_block_type == "DownBlock1DNoSkip":
613
+ return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
614
+ raise ValueError(f"{down_block_type} does not exist.")
615
+
616
+
617
+ def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample):
618
+ if up_block_type == "UpResnetBlock1D":
619
+ return UpResnetBlock1D(
620
+ in_channels=in_channels,
621
+ num_layers=num_layers,
622
+ out_channels=out_channels,
623
+ temb_channels=temb_channels,
624
+ add_upsample=add_upsample,
625
+ )
626
+ elif up_block_type == "UpBlock1D":
627
+ return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
628
+ elif up_block_type == "AttnUpBlock1D":
629
+ return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
630
+ elif up_block_type == "UpBlock1DNoSkip":
631
+ return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
632
+ raise ValueError(f"{up_block_type} does not exist.")
633
+
634
+
635
+ def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample):
636
+ if mid_block_type == "MidResTemporalBlock1D":
637
+ return MidResTemporalBlock1D(
638
+ num_layers=num_layers,
639
+ in_channels=in_channels,
640
+ out_channels=out_channels,
641
+ embed_dim=embed_dim,
642
+ add_downsample=add_downsample,
643
+ )
644
+ elif mid_block_type == "ValueFunctionMidBlock1D":
645
+ return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
646
+ elif mid_block_type == "UNetMidBlock1D":
647
+ return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
648
+ raise ValueError(f"{mid_block_type} does not exist.")
649
+
650
+
651
+ def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim):
652
+ if out_block_type == "OutConv1DBlock":
653
+ return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
654
+ elif out_block_type == "ValueFunction":
655
+ return OutValueFunctionBlock(fc_dim, embed_dim, act_fn)
656
+ return None
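A quick sanity check of the 1D block factories above (an editor's sketch, not part of the uploaded file; it assumes the classes defined in this module are importable, and the shapes are illustrative only):

import torch

down = get_down_block("DownBlock1D", num_layers=1, in_channels=32, out_channels=64,
                      temb_channels=128, add_downsample=True)
up = get_up_block("UpBlock1D", num_layers=1, in_channels=64, out_channels=32,
                  temb_channels=128, add_upsample=True)

x = torch.randn(2, 32, 16)   # (batch, channels, length)
h, res = down(x)             # Downsample1d("cubic") halves the length -> (2, 64, 8)
y = up(h, res)               # skip features are concatenated, then Upsample1d doubles the length
print(h.shape, y.shape)      # expected: (2, 64, 8) and (2, 32, 16)

Note that DownBlock1D returns both its output and a one-element tuple of residual states, which is exactly what the matching up block consumes.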
Tiger Model/diffusiers-Tiger/models/unet_2d.py ADDED
@@ -0,0 +1,329 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from ..configuration_utils import ConfigMixin, register_to_config
21
+ from ..utils import BaseOutput
22
+ from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
23
+ from .modeling_utils import ModelMixin
24
+ from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
25
+
26
+
27
+ @dataclass
28
+ class UNet2DOutput(BaseOutput):
29
+ """
30
+ The output of [`UNet2DModel`].
31
+
32
+ Args:
33
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
34
+ The hidden states output from the last layer of the model.
35
+ """
36
+
37
+ sample: torch.FloatTensor
38
+
39
+
40
+ class UNet2DModel(ModelMixin, ConfigMixin):
41
+ r"""
42
+ A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
43
+
44
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
45
+ for all models (such as downloading or saving).
46
+
47
+ Parameters:
48
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
49
+ Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
50
+ 1)`.
51
+ in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
52
+ out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
53
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
54
+ time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
55
+ freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
56
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
57
+ Whether to flip sin to cos for Fourier time embedding.
58
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
59
+ Tuple of downsample block types.
60
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
61
+ Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
62
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
63
+ Tuple of upsample block types.
64
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
65
+ Tuple of block output channels.
66
+ layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
67
+ mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
68
+ downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
69
+ downsample_type (`str`, *optional*, defaults to `conv`):
70
+ The downsample type for downsampling layers. Choose between "conv" and "resnet"
71
+ upsample_type (`str`, *optional*, defaults to `conv`):
72
+ The upsample type for upsampling layers. Choose between "conv" and "resnet"
73
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
74
+ attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
75
+ norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
76
+ norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
77
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
78
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
79
+ class_embed_type (`str`, *optional*, defaults to `None`):
80
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
81
+ `"timestep"`, or `"identity"`.
82
+ num_class_embeds (`int`, *optional*, defaults to `None`):
83
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
84
+ conditioning with `class_embed_type` equal to `None`.
85
+ """
86
+
87
+ @register_to_config
88
+ def __init__(
89
+ self,
90
+ sample_size: Optional[Union[int, Tuple[int, int]]] = None,
91
+ in_channels: int = 3,
92
+ out_channels: int = 3,
93
+ center_input_sample: bool = False,
94
+ time_embedding_type: str = "positional",
95
+ freq_shift: int = 0,
96
+ flip_sin_to_cos: bool = True,
97
+ down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
98
+ up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
99
+ block_out_channels: Tuple[int] = (224, 448, 672, 896),
100
+ layers_per_block: int = 2,
101
+ mid_block_scale_factor: float = 1,
102
+ downsample_padding: int = 1,
103
+ downsample_type: str = "conv",
104
+ upsample_type: str = "conv",
105
+ act_fn: str = "silu",
106
+ attention_head_dim: Optional[int] = 8,
107
+ norm_num_groups: int = 32,
108
+ norm_eps: float = 1e-5,
109
+ resnet_time_scale_shift: str = "default",
110
+ add_attention: bool = True,
111
+ class_embed_type: Optional[str] = None,
112
+ num_class_embeds: Optional[int] = None,
113
+ ):
114
+ super().__init__()
115
+
116
+ self.sample_size = sample_size
117
+ time_embed_dim = block_out_channels[0] * 4
118
+
119
+ # Check inputs
120
+ if len(down_block_types) != len(up_block_types):
121
+ raise ValueError(
122
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
123
+ )
124
+
125
+ if len(block_out_channels) != len(down_block_types):
126
+ raise ValueError(
127
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
128
+ )
129
+
130
+ # input
131
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
132
+
133
+ # time
134
+ if time_embedding_type == "fourier":
135
+ self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
136
+ timestep_input_dim = 2 * block_out_channels[0]
137
+ elif time_embedding_type == "positional":
138
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
139
+ timestep_input_dim = block_out_channels[0]
140
+
141
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
142
+
143
+ # class embedding
144
+ if class_embed_type is None and num_class_embeds is not None:
145
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
146
+ elif class_embed_type == "timestep":
147
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
148
+ elif class_embed_type == "identity":
149
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
150
+ else:
151
+ self.class_embedding = None
152
+
153
+ self.down_blocks = nn.ModuleList([])
154
+ self.mid_block = None
155
+ self.up_blocks = nn.ModuleList([])
156
+
157
+ # down
158
+ output_channel = block_out_channels[0]
159
+ for i, down_block_type in enumerate(down_block_types):
160
+ input_channel = output_channel
161
+ output_channel = block_out_channels[i]
162
+ is_final_block = i == len(block_out_channels) - 1
163
+
164
+ down_block = get_down_block(
165
+ down_block_type,
166
+ num_layers=layers_per_block,
167
+ in_channels=input_channel,
168
+ out_channels=output_channel,
169
+ temb_channels=time_embed_dim,
170
+ add_downsample=not is_final_block,
171
+ resnet_eps=norm_eps,
172
+ resnet_act_fn=act_fn,
173
+ resnet_groups=norm_num_groups,
174
+ attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
175
+ downsample_padding=downsample_padding,
176
+ resnet_time_scale_shift=resnet_time_scale_shift,
177
+ downsample_type=downsample_type,
178
+ )
179
+ self.down_blocks.append(down_block)
180
+
181
+ # mid
182
+ self.mid_block = UNetMidBlock2D(
183
+ in_channels=block_out_channels[-1],
184
+ temb_channels=time_embed_dim,
185
+ resnet_eps=norm_eps,
186
+ resnet_act_fn=act_fn,
187
+ output_scale_factor=mid_block_scale_factor,
188
+ resnet_time_scale_shift=resnet_time_scale_shift,
189
+ attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
190
+ resnet_groups=norm_num_groups,
191
+ add_attention=add_attention,
192
+ )
193
+
194
+ # up
195
+ reversed_block_out_channels = list(reversed(block_out_channels))
196
+ output_channel = reversed_block_out_channels[0]
197
+ for i, up_block_type in enumerate(up_block_types):
198
+ prev_output_channel = output_channel
199
+ output_channel = reversed_block_out_channels[i]
200
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
201
+
202
+ is_final_block = i == len(block_out_channels) - 1
203
+
204
+ up_block = get_up_block(
205
+ up_block_type,
206
+ num_layers=layers_per_block + 1,
207
+ in_channels=input_channel,
208
+ out_channels=output_channel,
209
+ prev_output_channel=prev_output_channel,
210
+ temb_channels=time_embed_dim,
211
+ add_upsample=not is_final_block,
212
+ resnet_eps=norm_eps,
213
+ resnet_act_fn=act_fn,
214
+ resnet_groups=norm_num_groups,
215
+ attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
216
+ resnet_time_scale_shift=resnet_time_scale_shift,
217
+ upsample_type=upsample_type,
218
+ )
219
+ self.up_blocks.append(up_block)
220
+ prev_output_channel = output_channel
221
+
222
+ # out
223
+ num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
224
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
225
+ self.conv_act = nn.SiLU()
226
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
227
+
228
+ def forward(
229
+ self,
230
+ sample: torch.FloatTensor,
231
+ timestep: Union[torch.Tensor, float, int],
232
+ class_labels: Optional[torch.Tensor] = None,
233
+ return_dict: bool = True,
234
+ ) -> Union[UNet2DOutput, Tuple]:
235
+ r"""
236
+ The [`UNet2DModel`] forward method.
237
+
238
+ Args:
239
+ sample (`torch.FloatTensor`):
240
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
241
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
242
+ class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
243
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
244
+ return_dict (`bool`, *optional*, defaults to `True`):
245
+ Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
246
+
247
+ Returns:
248
+ [`~models.unet_2d.UNet2DOutput`] or `tuple`:
249
+ If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
250
+ returned where the first element is the sample tensor.
251
+ """
252
+ # 0. center input if necessary
253
+ if self.config.center_input_sample:
254
+ sample = 2 * sample - 1.0
255
+
256
+ # 1. time
257
+ timesteps = timestep
258
+ if not torch.is_tensor(timesteps):
259
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
260
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
261
+ timesteps = timesteps[None].to(sample.device)
262
+
263
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
264
+ timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
265
+
266
+ t_emb = self.time_proj(timesteps)
267
+
268
+ # timesteps does not contain any weights and will always return f32 tensors
269
+ # but time_embedding might actually be running in fp16. so we need to cast here.
270
+ # there might be better ways to encapsulate this.
271
+ t_emb = t_emb.to(dtype=self.dtype)
272
+ emb = self.time_embedding(t_emb)
273
+
274
+ if self.class_embedding is not None:
275
+ if class_labels is None:
276
+ raise ValueError("class_labels should be provided when doing class conditioning")
277
+
278
+ if self.config.class_embed_type == "timestep":
279
+ class_labels = self.time_proj(class_labels)
280
+
281
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
282
+ emb = emb + class_emb
283
+
284
+ # 2. pre-process
285
+ skip_sample = sample
286
+ sample = self.conv_in(sample)
287
+
288
+ # 3. down
289
+ down_block_res_samples = (sample,)
290
+ for downsample_block in self.down_blocks:
291
+ if hasattr(downsample_block, "skip_conv"):
292
+ sample, res_samples, skip_sample = downsample_block(
293
+ hidden_states=sample, temb=emb, skip_sample=skip_sample
294
+ )
295
+ else:
296
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
297
+
298
+ down_block_res_samples += res_samples
299
+
300
+ # 4. mid
301
+ sample = self.mid_block(sample, emb)
302
+
303
+ # 5. up
304
+ skip_sample = None
305
+ for upsample_block in self.up_blocks:
306
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
307
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
308
+
309
+ if hasattr(upsample_block, "skip_conv"):
310
+ sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
311
+ else:
312
+ sample = upsample_block(sample, res_samples, emb)
313
+
314
+ # 6. post-process
315
+ sample = self.conv_norm_out(sample)
316
+ sample = self.conv_act(sample)
317
+ sample = self.conv_out(sample)
318
+
319
+ if skip_sample is not None:
320
+ sample += skip_sample
321
+
322
+ if self.config.time_embedding_type == "fourier":
323
+ timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
324
+ sample = sample / timesteps
325
+
326
+ if not return_dict:
327
+ return (sample,)
328
+
329
+ return UNet2DOutput(sample=sample)
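A minimal forward-pass sketch for the model above (an editor's illustration, not part of the uploaded file; the configuration values are arbitrary, chosen only so that the spatial size stays divisible by the downsampling factor and the channel counts by `norm_num_groups`):

import torch

model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)

noisy = torch.randn(1, 3, 32, 32)
pred = model(noisy, timestep=10).sample   # UNet2DOutput.sample, same shape as the input
print(pred.shape)                         # expected: torch.Size([1, 3, 32, 32])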
Tiger Model/diffusiers-Tiger/models/unet_2d_blocks.py ADDED
The diff for this file is too large to render. See raw diff
 
Tiger Model/diffusiers-Tiger/models/unet_2d_condition.py ADDED
@@ -0,0 +1,1009 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..loaders import UNet2DConditionLoadersMixin
23
+ from ..utils import BaseOutput, logging
24
+ from .activations import get_activation
25
+ from .attention_processor import AttentionProcessor, AttnProcessor
26
+ from .embeddings import (
27
+ GaussianFourierProjection,
28
+ ImageHintTimeEmbedding,
29
+ ImageProjection,
30
+ ImageTimeEmbedding,
31
+ PositionNet,
32
+ TextImageProjection,
33
+ TextImageTimeEmbedding,
34
+ TextTimeEmbedding,
35
+ TimestepEmbedding,
36
+ Timesteps,
37
+ )
38
+ from .modeling_utils import ModelMixin
39
+ from .unet_2d_blocks import (
40
+ UNetMidBlock2DCrossAttn,
41
+ UNetMidBlock2DSimpleCrossAttn,
42
+ get_down_block,
43
+ get_up_block,
44
+ )
45
+
46
+
47
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
48
+
49
+
50
+ @dataclass
51
+ class UNet2DConditionOutput(BaseOutput):
52
+ """
53
+ The output of [`UNet2DConditionModel`].
54
+
55
+ Args:
56
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
57
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
58
+ """
59
+
60
+ sample: torch.FloatTensor = None
61
+
62
+
63
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
64
+ r"""
65
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
66
+ shaped output.
67
+
68
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
69
+ for all models (such as downloading or saving).
70
+
71
+ Parameters:
72
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
73
+ Height and width of input/output sample.
74
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
75
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
76
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
77
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
78
+ Whether to flip the sin to cos in the time embedding.
79
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
80
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
81
+ The tuple of downsample blocks to use.
82
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
83
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
84
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
85
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
86
+ The tuple of upsample blocks to use.
87
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
88
+ Whether to include self-attention in the basic transformer blocks, see
89
+ [`~models.attention.BasicTransformerBlock`].
90
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
91
+ The tuple of output channels for each block.
92
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
93
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
94
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
95
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
96
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
97
+ If `None`, normalization and activation layers are skipped in post-processing.
98
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
99
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
100
+ The dimension of the cross attention features.
101
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
102
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
103
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
104
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
105
+ encoder_hid_dim (`int`, *optional*, defaults to None):
106
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
107
+ dimension to `cross_attention_dim`.
108
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
109
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
110
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
111
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
112
+ num_attention_heads (`int`, *optional*):
113
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
114
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
115
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
116
+ class_embed_type (`str`, *optional*, defaults to `None`):
117
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
118
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
119
+ addition_embed_type (`str`, *optional*, defaults to `None`):
120
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
121
+ "text". "text" will use the `TextTimeEmbedding` layer.
122
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
123
+ Dimension for the timestep embeddings.
124
+ num_class_embeds (`int`, *optional*, defaults to `None`):
125
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
126
+ class conditioning with `class_embed_type` equal to `None`.
127
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
128
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
129
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
130
+ An optional override for the dimension of the projected time embedding.
131
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
132
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
133
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
134
+ timestep_post_act (`str`, *optional*, defaults to `None`):
135
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
136
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
137
+ The dimension of `cond_proj` layer in the timestep embedding.
138
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
139
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
140
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
141
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
142
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
143
+ embeddings with the class embeddings.
144
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
145
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
146
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
147
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
148
+ otherwise.
149
+ """
150
+
151
+ _supports_gradient_checkpointing = True
152
+
153
+ @register_to_config
154
+ def __init__(
155
+ self,
156
+ sample_size: Optional[int] = None,
157
+ in_channels: int = 4,
158
+ out_channels: int = 4,
159
+ center_input_sample: bool = False,
160
+ flip_sin_to_cos: bool = True,
161
+ freq_shift: int = 0,
162
+ down_block_types: Tuple[str] = (
163
+ "CrossAttnDownBlock2D",
164
+ "CrossAttnDownBlock2D",
165
+ "CrossAttnDownBlock2D",
166
+ "DownBlock2D",
167
+ ),
168
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
169
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
170
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
171
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
172
+ layers_per_block: Union[int, Tuple[int]] = 2,
173
+ downsample_padding: int = 1,
174
+ mid_block_scale_factor: float = 1,
175
+ act_fn: str = "silu",
176
+ norm_num_groups: Optional[int] = 32,
177
+ norm_eps: float = 1e-5,
178
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
179
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
180
+ encoder_hid_dim: Optional[int] = None,
181
+ encoder_hid_dim_type: Optional[str] = None,
182
+ attention_head_dim: Union[int, Tuple[int]] = 8,
183
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
184
+ dual_cross_attention: bool = False,
185
+ use_linear_projection: bool = False,
186
+ class_embed_type: Optional[str] = None,
187
+ addition_embed_type: Optional[str] = None,
188
+ addition_time_embed_dim: Optional[int] = None,
189
+ num_class_embeds: Optional[int] = None,
190
+ upcast_attention: bool = False,
191
+ resnet_time_scale_shift: str = "default",
192
+ resnet_skip_time_act: bool = False,
193
+ resnet_out_scale_factor: int = 1.0,
194
+ time_embedding_type: str = "positional",
195
+ time_embedding_dim: Optional[int] = None,
196
+ time_embedding_act_fn: Optional[str] = None,
197
+ timestep_post_act: Optional[str] = None,
198
+ time_cond_proj_dim: Optional[int] = None,
199
+ conv_in_kernel: int = 3,
200
+ conv_out_kernel: int = 3,
201
+ projection_class_embeddings_input_dim: Optional[int] = None,
202
+ attention_type: str = "default",
203
+ class_embeddings_concat: bool = False,
204
+ mid_block_only_cross_attention: Optional[bool] = None,
205
+ cross_attention_norm: Optional[str] = None,
206
+ addition_embed_type_num_heads=64,
207
+ ):
208
+ super().__init__()
209
+
210
+ self.sample_size = sample_size
211
+
212
+ if num_attention_heads is not None:
213
+ raise ValueError(
214
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
215
+ )
216
+
217
+ # If `num_attention_heads` is not defined (which is the case for most models)
218
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
219
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
220
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
221
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
222
+ # which is why we correct for the naming here.
223
+ num_attention_heads = num_attention_heads or attention_head_dim
224
+
225
+ # Check inputs
226
+ if len(down_block_types) != len(up_block_types):
227
+ raise ValueError(
228
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
229
+ )
230
+
231
+ if len(block_out_channels) != len(down_block_types):
232
+ raise ValueError(
233
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
234
+ )
235
+
236
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
237
+ raise ValueError(
238
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
239
+ )
240
+
241
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
242
+ raise ValueError(
243
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
244
+ )
245
+
246
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
247
+ raise ValueError(
248
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
249
+ )
250
+
251
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
252
+ raise ValueError(
253
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
254
+ )
255
+
256
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
257
+ raise ValueError(
258
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
259
+ )
260
+
261
+ # input
262
+ conv_in_padding = (conv_in_kernel - 1) // 2
263
+ self.conv_in = nn.Conv2d(
264
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
265
+ )
266
+
267
+ # time
268
+ if time_embedding_type == "fourier":
269
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
270
+ if time_embed_dim % 2 != 0:
271
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
272
+ self.time_proj = GaussianFourierProjection(
273
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
274
+ )
275
+ timestep_input_dim = time_embed_dim
276
+ elif time_embedding_type == "positional":
277
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
278
+
279
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
280
+ timestep_input_dim = block_out_channels[0]
281
+ else:
282
+ raise ValueError(
283
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
284
+ )
285
+
286
+ self.time_embedding = TimestepEmbedding(
287
+ timestep_input_dim,
288
+ time_embed_dim,
289
+ act_fn=act_fn,
290
+ post_act_fn=timestep_post_act,
291
+ cond_proj_dim=time_cond_proj_dim,
292
+ )
293
+
294
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
295
+ encoder_hid_dim_type = "text_proj"
296
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
297
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
298
+
299
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
300
+ raise ValueError(
301
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
302
+ )
303
+
304
+ if encoder_hid_dim_type == "text_proj":
305
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
306
+ elif encoder_hid_dim_type == "text_image_proj":
307
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
308
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
309
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
310
+ self.encoder_hid_proj = TextImageProjection(
311
+ text_embed_dim=encoder_hid_dim,
312
+ image_embed_dim=cross_attention_dim,
313
+ cross_attention_dim=cross_attention_dim,
314
+ )
315
+ elif encoder_hid_dim_type == "image_proj":
316
+ # Kandinsky 2.2
317
+ self.encoder_hid_proj = ImageProjection(
318
+ image_embed_dim=encoder_hid_dim,
319
+ cross_attention_dim=cross_attention_dim,
320
+ )
321
+ elif encoder_hid_dim_type is not None:
322
+ raise ValueError(
323
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
324
+ )
325
+ else:
326
+ self.encoder_hid_proj = None
327
+
328
+ # class embedding
329
+ if class_embed_type is None and num_class_embeds is not None:
330
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
331
+ elif class_embed_type == "timestep":
332
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
333
+ elif class_embed_type == "identity":
334
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
335
+ elif class_embed_type == "projection":
336
+ if projection_class_embeddings_input_dim is None:
337
+ raise ValueError(
338
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
339
+ )
340
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
341
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
342
+ # 2. it projects from an arbitrary input dimension.
343
+ #
344
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
345
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
346
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
347
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
348
+ elif class_embed_type == "simple_projection":
349
+ if projection_class_embeddings_input_dim is None:
350
+ raise ValueError(
351
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
352
+ )
353
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
354
+ else:
355
+ self.class_embedding = None
356
+
357
+ if addition_embed_type == "text":
358
+ if encoder_hid_dim is not None:
359
+ text_time_embedding_from_dim = encoder_hid_dim
360
+ else:
361
+ text_time_embedding_from_dim = cross_attention_dim
362
+
363
+ self.add_embedding = TextTimeEmbedding(
364
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
365
+ )
366
+ elif addition_embed_type == "text_image":
367
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
368
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
369
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
370
+ self.add_embedding = TextImageTimeEmbedding(
371
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
372
+ )
373
+ elif addition_embed_type == "text_time":
374
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
375
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
376
+ elif addition_embed_type == "image":
377
+ # Kandinsky 2.2
378
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
379
+ elif addition_embed_type == "image_hint":
380
+ # Kandinsky 2.2 ControlNet
381
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
382
+ elif addition_embed_type is not None:
383
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
384
+
385
+ if time_embedding_act_fn is None:
386
+ self.time_embed_act = None
387
+ else:
388
+ self.time_embed_act = get_activation(time_embedding_act_fn)
389
+
390
+ self.down_blocks = nn.ModuleList([])
391
+ self.up_blocks = nn.ModuleList([])
392
+
393
+ if isinstance(only_cross_attention, bool):
394
+ if mid_block_only_cross_attention is None:
395
+ mid_block_only_cross_attention = only_cross_attention
396
+
397
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
398
+
399
+ if mid_block_only_cross_attention is None:
400
+ mid_block_only_cross_attention = False
401
+
402
+ if isinstance(num_attention_heads, int):
403
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
404
+
405
+ if isinstance(attention_head_dim, int):
406
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
407
+
408
+ if isinstance(cross_attention_dim, int):
409
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
410
+
411
+ if isinstance(layers_per_block, int):
412
+ layers_per_block = [layers_per_block] * len(down_block_types)
413
+
414
+ if isinstance(transformer_layers_per_block, int):
415
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
416
+
417
+ if class_embeddings_concat:
418
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
419
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
420
+ # regular time embeddings
421
+ blocks_time_embed_dim = time_embed_dim * 2
422
+ else:
423
+ blocks_time_embed_dim = time_embed_dim
424
+
425
+ # down
426
+ output_channel = block_out_channels[0]
427
+ for i, down_block_type in enumerate(down_block_types):
428
+ input_channel = output_channel
429
+ output_channel = block_out_channels[i]
430
+ is_final_block = i == len(block_out_channels) - 1
431
+
432
+ down_block = get_down_block(
433
+ down_block_type,
434
+ num_layers=layers_per_block[i],
435
+ transformer_layers_per_block=transformer_layers_per_block[i],
436
+ in_channels=input_channel,
437
+ out_channels=output_channel,
438
+ temb_channels=blocks_time_embed_dim,
439
+ add_downsample=not is_final_block,
440
+ resnet_eps=norm_eps,
441
+ resnet_act_fn=act_fn,
442
+ resnet_groups=norm_num_groups,
443
+ cross_attention_dim=cross_attention_dim[i],
444
+ num_attention_heads=num_attention_heads[i],
445
+ downsample_padding=downsample_padding,
446
+ dual_cross_attention=dual_cross_attention,
447
+ use_linear_projection=use_linear_projection,
448
+ only_cross_attention=only_cross_attention[i],
449
+ upcast_attention=upcast_attention,
450
+ resnet_time_scale_shift=resnet_time_scale_shift,
451
+ attention_type=attention_type,
452
+ resnet_skip_time_act=resnet_skip_time_act,
453
+ resnet_out_scale_factor=resnet_out_scale_factor,
454
+ cross_attention_norm=cross_attention_norm,
455
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
456
+ )
457
+ self.down_blocks.append(down_block)
458
+
459
+ # mid
460
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
461
+ self.mid_block = UNetMidBlock2DCrossAttn(
462
+ transformer_layers_per_block=transformer_layers_per_block[-1],
463
+ in_channels=block_out_channels[-1],
464
+ temb_channels=blocks_time_embed_dim,
465
+ resnet_eps=norm_eps,
466
+ resnet_act_fn=act_fn,
467
+ output_scale_factor=mid_block_scale_factor,
468
+ resnet_time_scale_shift=resnet_time_scale_shift,
469
+ cross_attention_dim=cross_attention_dim[-1],
470
+ num_attention_heads=num_attention_heads[-1],
471
+ resnet_groups=norm_num_groups,
472
+ dual_cross_attention=dual_cross_attention,
473
+ use_linear_projection=use_linear_projection,
474
+ upcast_attention=upcast_attention,
475
+ attention_type=attention_type,
476
+ )
477
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
478
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
479
+ in_channels=block_out_channels[-1],
480
+ temb_channels=blocks_time_embed_dim,
481
+ resnet_eps=norm_eps,
482
+ resnet_act_fn=act_fn,
483
+ output_scale_factor=mid_block_scale_factor,
484
+ cross_attention_dim=cross_attention_dim[-1],
485
+ attention_head_dim=attention_head_dim[-1],
486
+ resnet_groups=norm_num_groups,
487
+ resnet_time_scale_shift=resnet_time_scale_shift,
488
+ skip_time_act=resnet_skip_time_act,
489
+ only_cross_attention=mid_block_only_cross_attention,
490
+ cross_attention_norm=cross_attention_norm,
491
+ )
492
+ elif mid_block_type is None:
493
+ self.mid_block = None
494
+ else:
495
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
496
+
497
+ # count how many layers upsample the images
498
+ self.num_upsamplers = 0
499
+
500
+ # up
501
+ reversed_block_out_channels = list(reversed(block_out_channels))
502
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
503
+ reversed_layers_per_block = list(reversed(layers_per_block))
504
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
505
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
506
+ only_cross_attention = list(reversed(only_cross_attention))
507
+
508
+ output_channel = reversed_block_out_channels[0]
509
+ for i, up_block_type in enumerate(up_block_types):
510
+ is_final_block = i == len(block_out_channels) - 1
511
+
512
+ prev_output_channel = output_channel
513
+ output_channel = reversed_block_out_channels[i]
514
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
515
+
516
+ # add upsample block for all BUT final layer
517
+ if not is_final_block:
518
+ add_upsample = True
519
+ self.num_upsamplers += 1
520
+ else:
521
+ add_upsample = False
522
+
523
+ up_block = get_up_block(
524
+ up_block_type,
525
+ num_layers=reversed_layers_per_block[i] + 1,
526
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
527
+ in_channels=input_channel,
528
+ out_channels=output_channel,
529
+ prev_output_channel=prev_output_channel,
530
+ temb_channels=blocks_time_embed_dim,
531
+ add_upsample=add_upsample,
532
+ resnet_eps=norm_eps,
533
+ resnet_act_fn=act_fn,
534
+ resnet_groups=norm_num_groups,
535
+ cross_attention_dim=reversed_cross_attention_dim[i],
536
+ num_attention_heads=reversed_num_attention_heads[i],
537
+ dual_cross_attention=dual_cross_attention,
538
+ use_linear_projection=use_linear_projection,
539
+ only_cross_attention=only_cross_attention[i],
540
+ upcast_attention=upcast_attention,
541
+ resnet_time_scale_shift=resnet_time_scale_shift,
542
+ attention_type=attention_type,
543
+ resnet_skip_time_act=resnet_skip_time_act,
544
+ resnet_out_scale_factor=resnet_out_scale_factor,
545
+ cross_attention_norm=cross_attention_norm,
546
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
547
+ )
548
+ self.up_blocks.append(up_block)
549
+ prev_output_channel = output_channel
550
+
551
+ # out
552
+ if norm_num_groups is not None:
553
+ self.conv_norm_out = nn.GroupNorm(
554
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
555
+ )
556
+
557
+ self.conv_act = get_activation(act_fn)
558
+
559
+ else:
560
+ self.conv_norm_out = None
561
+ self.conv_act = None
562
+
563
+ conv_out_padding = (conv_out_kernel - 1) // 2
564
+ self.conv_out = nn.Conv2d(
565
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
566
+ )
567
+
568
+ if attention_type == "gated":
569
+ positive_len = 768
570
+ if isinstance(cross_attention_dim, int):
571
+ positive_len = cross_attention_dim
572
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
573
+ positive_len = cross_attention_dim[0]
574
+ self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim)
575
+
576
+ @property
577
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
578
+ r"""
579
+ Returns:
580
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
581
+ indexed by their weight names.
582
+ """
583
+ # set recursively
584
+ processors = {}
585
+
586
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
587
+ if hasattr(module, "set_processor"):
588
+ processors[f"{name}.processor"] = module.processor
589
+
590
+ for sub_name, child in module.named_children():
591
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
592
+
593
+ return processors
594
+
595
+ for name, module in self.named_children():
596
+ fn_recursive_add_processors(name, module, processors)
597
+
598
+ return processors
599
+
600
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
601
+ r"""
602
+ Sets the attention processor to use to compute attention.
603
+
604
+ Parameters:
605
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
606
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
607
+ for **all** `Attention` layers.
608
+
609
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
610
+ processor. This is strongly recommended when setting trainable attention processors.
611
+
612
+ """
613
+ count = len(self.attn_processors.keys())
614
+
615
+ if isinstance(processor, dict) and len(processor) != count:
616
+ raise ValueError(
617
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
618
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
619
+ )
620
+
621
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
622
+ if hasattr(module, "set_processor"):
623
+ if not isinstance(processor, dict):
624
+ module.set_processor(processor)
625
+ else:
626
+ module.set_processor(processor.pop(f"{name}.processor"))
627
+
628
+ for sub_name, child in module.named_children():
629
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
630
+
631
+ for name, module in self.named_children():
632
+ fn_recursive_attn_processor(name, module, processor)
633
+
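As a usage note, a minimal sketch of the two call patterns described above (hypothetical `unet` variable, assuming an already instantiated model):

# Sketch only: broadcast one processor to every attention layer, or pass a per-layer dict
# keyed by the ".processor" paths reported by `attn_processors`.
unet.set_attn_processor(AttnProcessor())
per_layer = {name: AttnProcessor() for name in unet.attn_processors.keys()}
unet.set_attn_processor(per_layer)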
634
+ def set_default_attn_processor(self):
635
+ """
636
+ Disables custom attention processors and sets the default attention implementation.
637
+ """
638
+ self.set_attn_processor(AttnProcessor())
639
+
640
+ def set_attention_slice(self, slice_size):
641
+ r"""
642
+ Enable sliced attention computation.
643
+
644
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
645
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
646
+
647
+ Args:
648
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
649
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
650
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
651
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
652
+ must be a multiple of `slice_size`.
653
+ """
654
+ sliceable_head_dims = []
655
+
656
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
657
+ if hasattr(module, "set_attention_slice"):
658
+ sliceable_head_dims.append(module.sliceable_head_dim)
659
+
660
+ for child in module.children():
661
+ fn_recursive_retrieve_sliceable_dims(child)
662
+
663
+ # retrieve number of attention layers
664
+ for module in self.children():
665
+ fn_recursive_retrieve_sliceable_dims(module)
666
+
667
+ num_sliceable_layers = len(sliceable_head_dims)
668
+
669
+ if slice_size == "auto":
670
+ # half the attention head size is usually a good trade-off between
671
+ # speed and memory
672
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
673
+ elif slice_size == "max":
674
+ # make smallest slice possible
675
+ slice_size = num_sliceable_layers * [1]
676
+
677
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
678
+
679
+ if len(slice_size) != len(sliceable_head_dims):
680
+ raise ValueError(
681
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
682
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
683
+ )
684
+
685
+ for i in range(len(slice_size)):
686
+ size = slice_size[i]
687
+ dim = sliceable_head_dims[i]
688
+ if size is not None and size > dim:
689
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
690
+
691
+ # Recursively walk through all the children.
692
+ # Any child that exposes the set_attention_slice method
693
+ # gets the message
694
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
695
+ if hasattr(module, "set_attention_slice"):
696
+ module.set_attention_slice(slice_size.pop())
697
+
698
+ for child in module.children():
699
+ fn_recursive_set_attention_slice(child, slice_size)
700
+
701
+ reversed_slice_size = list(reversed(slice_size))
702
+ for module in self.children():
703
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
704
+
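A hedged usage sketch of the slicing options (assumes `unet` is an instantiated model and that no requested slice exceeds a layer's head dimension):

# Sketch only: trade a little speed for lower peak memory during attention.
unet.set_attention_slice("auto")  # each sliceable layer computes attention in two steps (dim // 2)
unet.set_attention_slice("max")   # slice size 1 everywhere, the most memory-frugal setting
unet.set_attention_slice(2)       # one integer broadcast to every sliceable layer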
705
+ def _set_gradient_checkpointing(self, module, value=False):
706
+ if hasattr(module, "gradient_checkpointing"):
707
+ module.gradient_checkpointing = value
708
+
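This hook is what the `ModelMixin` checkpointing toggles flip; a brief sketch (assumes the model declares `_supports_gradient_checkpointing = True`, as the upstream diffusers class does):

# Sketch only: recursively sets `module.gradient_checkpointing` on every block that defines it.
unet.enable_gradient_checkpointing()
unet.disable_gradient_checkpointing()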
709
+ def forward(
710
+ self,
711
+ sample: torch.FloatTensor,
712
+ timestep: Union[torch.Tensor, float, int],
713
+ encoder_hidden_states: torch.Tensor,
714
+ class_labels: Optional[torch.Tensor] = None,
715
+ timestep_cond: Optional[torch.Tensor] = None,
716
+ attention_mask: Optional[torch.Tensor] = None,
717
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
718
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
719
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
720
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
721
+ encoder_attention_mask: Optional[torch.Tensor] = None,
722
+ return_dict: bool = True,
723
+ ) -> Union[UNet2DConditionOutput, Tuple]:
724
+ r"""
725
+ The [`UNet2DConditionModel`] forward method.
726
+
727
+ Args:
728
+ sample (`torch.FloatTensor`):
729
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
730
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
731
+ encoder_hidden_states (`torch.FloatTensor`):
732
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
733
+ encoder_attention_mask (`torch.Tensor`):
734
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
735
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
736
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
737
+ return_dict (`bool`, *optional*, defaults to `True`):
738
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
739
+ tuple.
740
+ cross_attention_kwargs (`dict`, *optional*):
741
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
742
+ added_cond_kwargs: (`dict`, *optional*):
743
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
744
+ are passed along to the UNet blocks.
745
+
746
+ Returns:
747
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
748
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
749
+ a `tuple` is returned where the first element is the sample tensor.
750
+ """
751
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
752
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
753
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
754
+ # on the fly if necessary.
755
+ default_overall_up_factor = 2**self.num_upsamplers
756
+
757
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
758
+ forward_upsample_size = False
759
+ upsample_size = None
760
+
761
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
762
+ logger.info("Forward upsample size to force interpolation output size.")
763
+ forward_upsample_size = True
764
+
765
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
766
+ # expects mask of shape:
767
+ # [batch, key_tokens]
768
+ # adds singleton query_tokens dimension:
769
+ # [batch, 1, key_tokens]
770
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
771
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
772
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
773
+ if attention_mask is not None:
774
+ # assume that mask is expressed as:
775
+ # (1 = keep, 0 = discard)
776
+ # convert mask into a bias that can be added to attention scores:
777
+ # (keep = +0, discard = -10000.0)
778
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
779
+ attention_mask = attention_mask.unsqueeze(1)
780
+
781
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
782
+ if encoder_attention_mask is not None:
783
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
784
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
785
+
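The mask-to-bias conversion above can be checked on a toy tensor; a small illustrative sketch (not part of the model):

# Sketch: a keep/discard mask of shape [batch, key_tokens] becomes an additive attention bias.
mask = torch.tensor([[1.0, 1.0, 0.0]])   # keep, keep, discard
bias = (1 - mask) * -10000.0             # tensor([[0., 0., -10000.]])
bias = bias.unsqueeze(1)                 # [batch, 1, key_tokens], broadcastable over query tokens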
786
+ # 0. center input if necessary
787
+ if self.config.center_input_sample:
788
+ sample = 2 * sample - 1.0
789
+
790
+ # 1. time
791
+ timesteps = timestep
792
+ if not torch.is_tensor(timesteps):
793
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
794
+ # This would be a good case for the `match` statement (Python 3.10+)
795
+ is_mps = sample.device.type == "mps"
796
+ if isinstance(timestep, float):
797
+ dtype = torch.float32 if is_mps else torch.float64
798
+ else:
799
+ dtype = torch.int32 if is_mps else torch.int64
800
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
801
+ elif len(timesteps.shape) == 0:
802
+ timesteps = timesteps[None].to(sample.device)
803
+
804
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
805
+ timesteps = timesteps.expand(sample.shape[0])
806
+
807
+ t_emb = self.time_proj(timesteps)
808
+
809
+ # `Timesteps` does not contain any weights and will always return f32 tensors
810
+ # but time_embedding might actually be running in fp16. so we need to cast here.
811
+ # there might be better ways to encapsulate this.
812
+ t_emb = t_emb.to(dtype=sample.dtype)
813
+
814
+ emb = self.time_embedding(t_emb, timestep_cond)
815
+ aug_emb = None
816
+
817
+ if self.class_embedding is not None:
818
+ if class_labels is None:
819
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
820
+
821
+ if self.config.class_embed_type == "timestep":
822
+ class_labels = self.time_proj(class_labels)
823
+
824
+ # `Timesteps` does not contain any weights and will always return f32 tensors
825
+ # there might be better ways to encapsulate this.
826
+ class_labels = class_labels.to(dtype=sample.dtype)
827
+
828
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
829
+
830
+ if self.config.class_embeddings_concat:
831
+ emb = torch.cat([emb, class_emb], dim=-1)
832
+ else:
833
+ emb = emb + class_emb
834
+
835
+ if self.config.addition_embed_type == "text":
836
+ aug_emb = self.add_embedding(encoder_hidden_states)
837
+ elif self.config.addition_embed_type == "text_image":
838
+ # Kandinsky 2.1 - style
839
+ if "image_embeds" not in added_cond_kwargs:
840
+ raise ValueError(
841
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
842
+ )
843
+
844
+ image_embs = added_cond_kwargs.get("image_embeds")
845
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
846
+ aug_emb = self.add_embedding(text_embs, image_embs)
847
+ elif self.config.addition_embed_type == "text_time":
848
+ # SDXL - style
849
+ if "text_embeds" not in added_cond_kwargs:
850
+ raise ValueError(
851
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
852
+ )
853
+ text_embeds = added_cond_kwargs.get("text_embeds")
854
+ if "time_ids" not in added_cond_kwargs:
855
+ raise ValueError(
856
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
857
+ )
858
+ time_ids = added_cond_kwargs.get("time_ids")
859
+ time_embeds = self.add_time_proj(time_ids.flatten())
860
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
861
+
862
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
863
+ add_embeds = add_embeds.to(emb.dtype)
864
+ aug_emb = self.add_embedding(add_embeds)
865
+ elif self.config.addition_embed_type == "image":
866
+ # Kandinsky 2.2 - style
867
+ if "image_embeds" not in added_cond_kwargs:
868
+ raise ValueError(
869
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
870
+ )
871
+ image_embs = added_cond_kwargs.get("image_embeds")
872
+ aug_emb = self.add_embedding(image_embs)
873
+ elif self.config.addition_embed_type == "image_hint":
874
+ # Kandinsky 2.2 - style
875
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
876
+ raise ValueError(
877
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
878
+ )
879
+ image_embs = added_cond_kwargs.get("image_embeds")
880
+ hint = added_cond_kwargs.get("hint")
881
+ aug_emb, hint = self.add_embedding(image_embs, hint)
882
+ sample = torch.cat([sample, hint], dim=1)
883
+
884
+ emb = emb + aug_emb if aug_emb is not None else emb
885
+
886
+ if self.time_embed_act is not None:
887
+ emb = self.time_embed_act(emb)
888
+
889
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
890
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
891
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
892
+ # Kandinsky 2.1 - style
893
+ if "image_embeds" not in added_cond_kwargs:
894
+ raise ValueError(
895
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
896
+ )
897
+
898
+ image_embeds = added_cond_kwargs.get("image_embeds")
899
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
900
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
901
+ # Kandinsky 2.2 - style
902
+ if "image_embeds" not in added_cond_kwargs:
903
+ raise ValueError(
904
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
905
+ )
906
+ image_embeds = added_cond_kwargs.get("image_embeds")
907
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
908
+ # 2. pre-process
909
+ sample = self.conv_in(sample)
910
+
911
+ # 2.5 GLIGEN position net
912
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
913
+ cross_attention_kwargs = cross_attention_kwargs.copy()
914
+ gligen_args = cross_attention_kwargs.pop("gligen")
915
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
916
+
917
+ # 3. down
918
+
919
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
920
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
921
+
922
+ down_block_res_samples = (sample,)
923
+ for downsample_block in self.down_blocks:
924
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
925
+ # For t2i-adapter CrossAttnDownBlock2D
926
+ additional_residuals = {}
927
+ if is_adapter and len(down_block_additional_residuals) > 0:
928
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
929
+
930
+ sample, res_samples = downsample_block(
931
+ hidden_states=sample,
932
+ temb=emb,
933
+ encoder_hidden_states=encoder_hidden_states,
934
+ attention_mask=attention_mask,
935
+ cross_attention_kwargs=cross_attention_kwargs,
936
+ encoder_attention_mask=encoder_attention_mask,
937
+ **additional_residuals,
938
+ )
939
+ else:
940
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
941
+
942
+ if is_adapter and len(down_block_additional_residuals) > 0:
943
+ sample += down_block_additional_residuals.pop(0)
944
+
945
+ down_block_res_samples += res_samples
946
+
947
+ if is_controlnet:
948
+ new_down_block_res_samples = ()
949
+
950
+ for down_block_res_sample, down_block_additional_residual in zip(
951
+ down_block_res_samples, down_block_additional_residuals
952
+ ):
953
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
954
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
955
+
956
+ down_block_res_samples = new_down_block_res_samples
957
+
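The ControlNet branch above pairs each collected skip activation with one injected residual; the loop is equivalent to this one-liner sketch (same hypothetical tensors):

# Sketch: element-wise sum over aligned skip tensors, preserving the tuple structure.
down_block_res_samples = tuple(
    res + ctrl for res, ctrl in zip(down_block_res_samples, down_block_additional_residuals)
)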
958
+ # 4. mid
959
+ if self.mid_block is not None:
960
+ sample = self.mid_block(
961
+ sample,
962
+ emb,
963
+ encoder_hidden_states=encoder_hidden_states,
964
+ attention_mask=attention_mask,
965
+ cross_attention_kwargs=cross_attention_kwargs,
966
+ encoder_attention_mask=encoder_attention_mask,
967
+ )
968
+
969
+ if is_controlnet:
970
+ sample = sample + mid_block_additional_residual
971
+
972
+ # 5. up
973
+ for i, upsample_block in enumerate(self.up_blocks):
974
+ is_final_block = i == len(self.up_blocks) - 1
975
+
976
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
977
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
978
+
979
+ # if we have not reached the final block and need to forward the
980
+ # upsample size, we do it here
981
+ if not is_final_block and forward_upsample_size:
982
+ upsample_size = down_block_res_samples[-1].shape[2:]
983
+
984
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
985
+ sample = upsample_block(
986
+ hidden_states=sample,
987
+ temb=emb,
988
+ res_hidden_states_tuple=res_samples,
989
+ encoder_hidden_states=encoder_hidden_states,
990
+ cross_attention_kwargs=cross_attention_kwargs,
991
+ upsample_size=upsample_size,
992
+ attention_mask=attention_mask,
993
+ encoder_attention_mask=encoder_attention_mask,
994
+ )
995
+ else:
996
+ sample = upsample_block(
997
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
998
+ )
999
+
1000
+ # 6. post-process
1001
+ if self.conv_norm_out:
1002
+ sample = self.conv_norm_out(sample)
1003
+ sample = self.conv_act(sample)
1004
+ sample = self.conv_out(sample)
1005
+
1006
+ if not return_dict:
1007
+ return (sample,)
1008
+
1009
+ return UNet2DConditionOutput(sample=sample)
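To make the forward contract concrete, a minimal smoke-test sketch with a deliberately tiny configuration (hypothetical sizes; assumes this class keeps the upstream diffusers constructor defaults):

import torch

# Sketch only: exercise the forward pass with dummy tensors.
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
)
latents = torch.randn(1, 4, 32, 32)        # (batch, channel, height, width)
text_states = torch.randn(1, 77, 32)       # (batch, sequence_length, cross_attention_dim)
out = unet(latents, timestep=10, encoder_hidden_states=text_states)
assert out.sample.shape == latents.shape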
Tiger Model/diffusiers-Tiger/models/unet_3d_blocks.py ADDED
@@ -0,0 +1,679 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ from torch import nn
17
+
18
+ from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
19
+ from .transformer_2d import Transformer2DModel
20
+ from .transformer_temporal import TransformerTemporalModel
21
+
22
+
23
+ def get_down_block(
24
+ down_block_type,
25
+ num_layers,
26
+ in_channels,
27
+ out_channels,
28
+ temb_channels,
29
+ add_downsample,
30
+ resnet_eps,
31
+ resnet_act_fn,
32
+ num_attention_heads,
33
+ resnet_groups=None,
34
+ cross_attention_dim=None,
35
+ downsample_padding=None,
36
+ dual_cross_attention=False,
37
+ use_linear_projection=True,
38
+ only_cross_attention=False,
39
+ upcast_attention=False,
40
+ resnet_time_scale_shift="default",
41
+ ):
42
+ if down_block_type == "DownBlock3D":
43
+ return DownBlock3D(
44
+ num_layers=num_layers,
45
+ in_channels=in_channels,
46
+ out_channels=out_channels,
47
+ temb_channels=temb_channels,
48
+ add_downsample=add_downsample,
49
+ resnet_eps=resnet_eps,
50
+ resnet_act_fn=resnet_act_fn,
51
+ resnet_groups=resnet_groups,
52
+ downsample_padding=downsample_padding,
53
+ resnet_time_scale_shift=resnet_time_scale_shift,
54
+ )
55
+ elif down_block_type == "CrossAttnDownBlock3D":
56
+ if cross_attention_dim is None:
57
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
58
+ return CrossAttnDownBlock3D(
59
+ num_layers=num_layers,
60
+ in_channels=in_channels,
61
+ out_channels=out_channels,
62
+ temb_channels=temb_channels,
63
+ add_downsample=add_downsample,
64
+ resnet_eps=resnet_eps,
65
+ resnet_act_fn=resnet_act_fn,
66
+ resnet_groups=resnet_groups,
67
+ downsample_padding=downsample_padding,
68
+ cross_attention_dim=cross_attention_dim,
69
+ num_attention_heads=num_attention_heads,
70
+ dual_cross_attention=dual_cross_attention,
71
+ use_linear_projection=use_linear_projection,
72
+ only_cross_attention=only_cross_attention,
73
+ upcast_attention=upcast_attention,
74
+ resnet_time_scale_shift=resnet_time_scale_shift,
75
+ )
76
+ raise ValueError(f"{down_block_type} does not exist.")
77
+
78
+
79
+ def get_up_block(
80
+ up_block_type,
81
+ num_layers,
82
+ in_channels,
83
+ out_channels,
84
+ prev_output_channel,
85
+ temb_channels,
86
+ add_upsample,
87
+ resnet_eps,
88
+ resnet_act_fn,
89
+ num_attention_heads,
90
+ resnet_groups=None,
91
+ cross_attention_dim=None,
92
+ dual_cross_attention=False,
93
+ use_linear_projection=True,
94
+ only_cross_attention=False,
95
+ upcast_attention=False,
96
+ resnet_time_scale_shift="default",
97
+ ):
98
+ if up_block_type == "UpBlock3D":
99
+ return UpBlock3D(
100
+ num_layers=num_layers,
101
+ in_channels=in_channels,
102
+ out_channels=out_channels,
103
+ prev_output_channel=prev_output_channel,
104
+ temb_channels=temb_channels,
105
+ add_upsample=add_upsample,
106
+ resnet_eps=resnet_eps,
107
+ resnet_act_fn=resnet_act_fn,
108
+ resnet_groups=resnet_groups,
109
+ resnet_time_scale_shift=resnet_time_scale_shift,
110
+ )
111
+ elif up_block_type == "CrossAttnUpBlock3D":
112
+ if cross_attention_dim is None:
113
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
114
+ return CrossAttnUpBlock3D(
115
+ num_layers=num_layers,
116
+ in_channels=in_channels,
117
+ out_channels=out_channels,
118
+ prev_output_channel=prev_output_channel,
119
+ temb_channels=temb_channels,
120
+ add_upsample=add_upsample,
121
+ resnet_eps=resnet_eps,
122
+ resnet_act_fn=resnet_act_fn,
123
+ resnet_groups=resnet_groups,
124
+ cross_attention_dim=cross_attention_dim,
125
+ num_attention_heads=num_attention_heads,
126
+ dual_cross_attention=dual_cross_attention,
127
+ use_linear_projection=use_linear_projection,
128
+ only_cross_attention=only_cross_attention,
129
+ upcast_attention=upcast_attention,
130
+ resnet_time_scale_shift=resnet_time_scale_shift,
131
+ )
132
+ raise ValueError(f"{up_block_type} does not exist.")
133
+
134
+
135
+ class UNetMidBlock3DCrossAttn(nn.Module):
136
+ def __init__(
137
+ self,
138
+ in_channels: int,
139
+ temb_channels: int,
140
+ dropout: float = 0.0,
141
+ num_layers: int = 1,
142
+ resnet_eps: float = 1e-6,
143
+ resnet_time_scale_shift: str = "default",
144
+ resnet_act_fn: str = "swish",
145
+ resnet_groups: int = 32,
146
+ resnet_pre_norm: bool = True,
147
+ num_attention_heads=1,
148
+ output_scale_factor=1.0,
149
+ cross_attention_dim=1280,
150
+ dual_cross_attention=False,
151
+ use_linear_projection=True,
152
+ upcast_attention=False,
153
+ ):
154
+ super().__init__()
155
+
156
+ self.has_cross_attention = True
157
+ self.num_attention_heads = num_attention_heads
158
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
159
+
160
+ # there is always at least one resnet
161
+ resnets = [
162
+ ResnetBlock2D(
163
+ in_channels=in_channels,
164
+ out_channels=in_channels,
165
+ temb_channels=temb_channels,
166
+ eps=resnet_eps,
167
+ groups=resnet_groups,
168
+ dropout=dropout,
169
+ time_embedding_norm=resnet_time_scale_shift,
170
+ non_linearity=resnet_act_fn,
171
+ output_scale_factor=output_scale_factor,
172
+ pre_norm=resnet_pre_norm,
173
+ )
174
+ ]
175
+ temp_convs = [
176
+ TemporalConvLayer(
177
+ in_channels,
178
+ in_channels,
179
+ dropout=0.1,
180
+ )
181
+ ]
182
+ attentions = []
183
+ temp_attentions = []
184
+
185
+ for _ in range(num_layers):
186
+ attentions.append(
187
+ Transformer2DModel(
188
+ in_channels // num_attention_heads,
189
+ num_attention_heads,
190
+ in_channels=in_channels,
191
+ num_layers=1,
192
+ cross_attention_dim=cross_attention_dim,
193
+ norm_num_groups=resnet_groups,
194
+ use_linear_projection=use_linear_projection,
195
+ upcast_attention=upcast_attention,
196
+ )
197
+ )
198
+ temp_attentions.append(
199
+ TransformerTemporalModel(
200
+ in_channels // num_attention_heads,
201
+ num_attention_heads,
202
+ in_channels=in_channels,
203
+ num_layers=1,
204
+ cross_attention_dim=cross_attention_dim,
205
+ norm_num_groups=resnet_groups,
206
+ )
207
+ )
208
+ resnets.append(
209
+ ResnetBlock2D(
210
+ in_channels=in_channels,
211
+ out_channels=in_channels,
212
+ temb_channels=temb_channels,
213
+ eps=resnet_eps,
214
+ groups=resnet_groups,
215
+ dropout=dropout,
216
+ time_embedding_norm=resnet_time_scale_shift,
217
+ non_linearity=resnet_act_fn,
218
+ output_scale_factor=output_scale_factor,
219
+ pre_norm=resnet_pre_norm,
220
+ )
221
+ )
222
+ temp_convs.append(
223
+ TemporalConvLayer(
224
+ in_channels,
225
+ in_channels,
226
+ dropout=0.1,
227
+ )
228
+ )
229
+
230
+ self.resnets = nn.ModuleList(resnets)
231
+ self.temp_convs = nn.ModuleList(temp_convs)
232
+ self.attentions = nn.ModuleList(attentions)
233
+ self.temp_attentions = nn.ModuleList(temp_attentions)
234
+
235
+ def forward(
236
+ self,
237
+ hidden_states,
238
+ temb=None,
239
+ encoder_hidden_states=None,
240
+ attention_mask=None,
241
+ num_frames=1,
242
+ cross_attention_kwargs=None,
243
+ ):
244
+ hidden_states = self.resnets[0](hidden_states, temb)
245
+ hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
246
+ for attn, temp_attn, resnet, temp_conv in zip(
247
+ self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
248
+ ):
249
+ hidden_states = attn(
250
+ hidden_states,
251
+ encoder_hidden_states=encoder_hidden_states,
252
+ cross_attention_kwargs=cross_attention_kwargs,
253
+ return_dict=False,
254
+ )[0]
255
+ hidden_states = temp_attn(
256
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
257
+ )[0]
258
+ hidden_states = resnet(hidden_states, temb)
259
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
260
+
261
+ return hidden_states
262
+
263
+
264
+ class CrossAttnDownBlock3D(nn.Module):
265
+ def __init__(
266
+ self,
267
+ in_channels: int,
268
+ out_channels: int,
269
+ temb_channels: int,
270
+ dropout: float = 0.0,
271
+ num_layers: int = 1,
272
+ resnet_eps: float = 1e-6,
273
+ resnet_time_scale_shift: str = "default",
274
+ resnet_act_fn: str = "swish",
275
+ resnet_groups: int = 32,
276
+ resnet_pre_norm: bool = True,
277
+ num_attention_heads=1,
278
+ cross_attention_dim=1280,
279
+ output_scale_factor=1.0,
280
+ downsample_padding=1,
281
+ add_downsample=True,
282
+ dual_cross_attention=False,
283
+ use_linear_projection=False,
284
+ only_cross_attention=False,
285
+ upcast_attention=False,
286
+ ):
287
+ super().__init__()
288
+ resnets = []
289
+ attentions = []
290
+ temp_attentions = []
291
+ temp_convs = []
292
+
293
+ self.has_cross_attention = True
294
+ self.num_attention_heads = num_attention_heads
295
+
296
+ for i in range(num_layers):
297
+ in_channels = in_channels if i == 0 else out_channels
298
+ resnets.append(
299
+ ResnetBlock2D(
300
+ in_channels=in_channels,
301
+ out_channels=out_channels,
302
+ temb_channels=temb_channels,
303
+ eps=resnet_eps,
304
+ groups=resnet_groups,
305
+ dropout=dropout,
306
+ time_embedding_norm=resnet_time_scale_shift,
307
+ non_linearity=resnet_act_fn,
308
+ output_scale_factor=output_scale_factor,
309
+ pre_norm=resnet_pre_norm,
310
+ )
311
+ )
312
+ temp_convs.append(
313
+ TemporalConvLayer(
314
+ out_channels,
315
+ out_channels,
316
+ dropout=0.1,
317
+ )
318
+ )
319
+ attentions.append(
320
+ Transformer2DModel(
321
+ out_channels // num_attention_heads,
322
+ num_attention_heads,
323
+ in_channels=out_channels,
324
+ num_layers=1,
325
+ cross_attention_dim=cross_attention_dim,
326
+ norm_num_groups=resnet_groups,
327
+ use_linear_projection=use_linear_projection,
328
+ only_cross_attention=only_cross_attention,
329
+ upcast_attention=upcast_attention,
330
+ )
331
+ )
332
+ temp_attentions.append(
333
+ TransformerTemporalModel(
334
+ out_channels // num_attention_heads,
335
+ num_attention_heads,
336
+ in_channels=out_channels,
337
+ num_layers=1,
338
+ cross_attention_dim=cross_attention_dim,
339
+ norm_num_groups=resnet_groups,
340
+ )
341
+ )
342
+ self.resnets = nn.ModuleList(resnets)
343
+ self.temp_convs = nn.ModuleList(temp_convs)
344
+ self.attentions = nn.ModuleList(attentions)
345
+ self.temp_attentions = nn.ModuleList(temp_attentions)
346
+
347
+ if add_downsample:
348
+ self.downsamplers = nn.ModuleList(
349
+ [
350
+ Downsample2D(
351
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
352
+ )
353
+ ]
354
+ )
355
+ else:
356
+ self.downsamplers = None
357
+
358
+ self.gradient_checkpointing = False
359
+
360
+ def forward(
361
+ self,
362
+ hidden_states,
363
+ temb=None,
364
+ encoder_hidden_states=None,
365
+ attention_mask=None,
366
+ num_frames=1,
367
+ cross_attention_kwargs=None,
368
+ ):
369
+ # TODO(Patrick, William) - attention mask is not used
370
+ output_states = ()
371
+
372
+ for resnet, temp_conv, attn, temp_attn in zip(
373
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
374
+ ):
375
+ hidden_states = resnet(hidden_states, temb)
376
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
377
+ hidden_states = attn(
378
+ hidden_states,
379
+ encoder_hidden_states=encoder_hidden_states,
380
+ cross_attention_kwargs=cross_attention_kwargs,
381
+ return_dict=False,
382
+ )[0]
383
+ hidden_states = temp_attn(
384
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
385
+ )[0]
386
+
387
+ output_states += (hidden_states,)
388
+
389
+ if self.downsamplers is not None:
390
+ for downsampler in self.downsamplers:
391
+ hidden_states = downsampler(hidden_states)
392
+
393
+ output_states += (hidden_states,)
394
+
395
+ return hidden_states, output_states
396
+
397
+
398
+ class DownBlock3D(nn.Module):
399
+ def __init__(
400
+ self,
401
+ in_channels: int,
402
+ out_channels: int,
403
+ temb_channels: int,
404
+ dropout: float = 0.0,
405
+ num_layers: int = 1,
406
+ resnet_eps: float = 1e-6,
407
+ resnet_time_scale_shift: str = "default",
408
+ resnet_act_fn: str = "swish",
409
+ resnet_groups: int = 32,
410
+ resnet_pre_norm: bool = True,
411
+ output_scale_factor=1.0,
412
+ add_downsample=True,
413
+ downsample_padding=1,
414
+ ):
415
+ super().__init__()
416
+ resnets = []
417
+ temp_convs = []
418
+
419
+ for i in range(num_layers):
420
+ in_channels = in_channels if i == 0 else out_channels
421
+ resnets.append(
422
+ ResnetBlock2D(
423
+ in_channels=in_channels,
424
+ out_channels=out_channels,
425
+ temb_channels=temb_channels,
426
+ eps=resnet_eps,
427
+ groups=resnet_groups,
428
+ dropout=dropout,
429
+ time_embedding_norm=resnet_time_scale_shift,
430
+ non_linearity=resnet_act_fn,
431
+ output_scale_factor=output_scale_factor,
432
+ pre_norm=resnet_pre_norm,
433
+ )
434
+ )
435
+ temp_convs.append(
436
+ TemporalConvLayer(
437
+ out_channels,
438
+ out_channels,
439
+ dropout=0.1,
440
+ )
441
+ )
442
+
443
+ self.resnets = nn.ModuleList(resnets)
444
+ self.temp_convs = nn.ModuleList(temp_convs)
445
+
446
+ if add_downsample:
447
+ self.downsamplers = nn.ModuleList(
448
+ [
449
+ Downsample2D(
450
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
451
+ )
452
+ ]
453
+ )
454
+ else:
455
+ self.downsamplers = None
456
+
457
+ self.gradient_checkpointing = False
458
+
459
+ def forward(self, hidden_states, temb=None, num_frames=1):
460
+ output_states = ()
461
+
462
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
463
+ hidden_states = resnet(hidden_states, temb)
464
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
465
+
466
+ output_states += (hidden_states,)
467
+
468
+ if self.downsamplers is not None:
469
+ for downsampler in self.downsamplers:
470
+ hidden_states = downsampler(hidden_states)
471
+
472
+ output_states += (hidden_states,)
473
+
474
+ return hidden_states, output_states
475
+
476
+
477
+ class CrossAttnUpBlock3D(nn.Module):
478
+ def __init__(
479
+ self,
480
+ in_channels: int,
481
+ out_channels: int,
482
+ prev_output_channel: int,
483
+ temb_channels: int,
484
+ dropout: float = 0.0,
485
+ num_layers: int = 1,
486
+ resnet_eps: float = 1e-6,
487
+ resnet_time_scale_shift: str = "default",
488
+ resnet_act_fn: str = "swish",
489
+ resnet_groups: int = 32,
490
+ resnet_pre_norm: bool = True,
491
+ num_attention_heads=1,
492
+ cross_attention_dim=1280,
493
+ output_scale_factor=1.0,
494
+ add_upsample=True,
495
+ dual_cross_attention=False,
496
+ use_linear_projection=False,
497
+ only_cross_attention=False,
498
+ upcast_attention=False,
499
+ ):
500
+ super().__init__()
501
+ resnets = []
502
+ temp_convs = []
503
+ attentions = []
504
+ temp_attentions = []
505
+
506
+ self.has_cross_attention = True
507
+ self.num_attention_heads = num_attention_heads
508
+
509
+ for i in range(num_layers):
510
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
511
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
512
+
513
+ resnets.append(
514
+ ResnetBlock2D(
515
+ in_channels=resnet_in_channels + res_skip_channels,
516
+ out_channels=out_channels,
517
+ temb_channels=temb_channels,
518
+ eps=resnet_eps,
519
+ groups=resnet_groups,
520
+ dropout=dropout,
521
+ time_embedding_norm=resnet_time_scale_shift,
522
+ non_linearity=resnet_act_fn,
523
+ output_scale_factor=output_scale_factor,
524
+ pre_norm=resnet_pre_norm,
525
+ )
526
+ )
527
+ temp_convs.append(
528
+ TemporalConvLayer(
529
+ out_channels,
530
+ out_channels,
531
+ dropout=0.1,
532
+ )
533
+ )
534
+ attentions.append(
535
+ Transformer2DModel(
536
+ out_channels // num_attention_heads,
537
+ num_attention_heads,
538
+ in_channels=out_channels,
539
+ num_layers=1,
540
+ cross_attention_dim=cross_attention_dim,
541
+ norm_num_groups=resnet_groups,
542
+ use_linear_projection=use_linear_projection,
543
+ only_cross_attention=only_cross_attention,
544
+ upcast_attention=upcast_attention,
545
+ )
546
+ )
547
+ temp_attentions.append(
548
+ TransformerTemporalModel(
549
+ out_channels // num_attention_heads,
550
+ num_attention_heads,
551
+ in_channels=out_channels,
552
+ num_layers=1,
553
+ cross_attention_dim=cross_attention_dim,
554
+ norm_num_groups=resnet_groups,
555
+ )
556
+ )
557
+ self.resnets = nn.ModuleList(resnets)
558
+ self.temp_convs = nn.ModuleList(temp_convs)
559
+ self.attentions = nn.ModuleList(attentions)
560
+ self.temp_attentions = nn.ModuleList(temp_attentions)
561
+
562
+ if add_upsample:
563
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
564
+ else:
565
+ self.upsamplers = None
566
+
567
+ self.gradient_checkpointing = False
568
+
569
+ def forward(
570
+ self,
571
+ hidden_states,
572
+ res_hidden_states_tuple,
573
+ temb=None,
574
+ encoder_hidden_states=None,
575
+ upsample_size=None,
576
+ attention_mask=None,
577
+ num_frames=1,
578
+ cross_attention_kwargs=None,
579
+ ):
580
+ # TODO(Patrick, William) - attention mask is not used
581
+ for resnet, temp_conv, attn, temp_attn in zip(
582
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
583
+ ):
584
+ # pop res hidden states
585
+ res_hidden_states = res_hidden_states_tuple[-1]
586
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
587
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
588
+
589
+ hidden_states = resnet(hidden_states, temb)
590
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
591
+ hidden_states = attn(
592
+ hidden_states,
593
+ encoder_hidden_states=encoder_hidden_states,
594
+ cross_attention_kwargs=cross_attention_kwargs,
595
+ return_dict=False,
596
+ )[0]
597
+ hidden_states = temp_attn(
598
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
599
+ )[0]
600
+
601
+ if self.upsamplers is not None:
602
+ for upsampler in self.upsamplers:
603
+ hidden_states = upsampler(hidden_states, upsample_size)
604
+
605
+ return hidden_states
606
+
607
+
608
+ class UpBlock3D(nn.Module):
609
+ def __init__(
610
+ self,
611
+ in_channels: int,
612
+ prev_output_channel: int,
613
+ out_channels: int,
614
+ temb_channels: int,
615
+ dropout: float = 0.0,
616
+ num_layers: int = 1,
617
+ resnet_eps: float = 1e-6,
618
+ resnet_time_scale_shift: str = "default",
619
+ resnet_act_fn: str = "swish",
620
+ resnet_groups: int = 32,
621
+ resnet_pre_norm: bool = True,
622
+ output_scale_factor=1.0,
623
+ add_upsample=True,
624
+ ):
625
+ super().__init__()
626
+ resnets = []
627
+ temp_convs = []
628
+
629
+ for i in range(num_layers):
630
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
631
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
632
+
633
+ resnets.append(
634
+ ResnetBlock2D(
635
+ in_channels=resnet_in_channels + res_skip_channels,
636
+ out_channels=out_channels,
637
+ temb_channels=temb_channels,
638
+ eps=resnet_eps,
639
+ groups=resnet_groups,
640
+ dropout=dropout,
641
+ time_embedding_norm=resnet_time_scale_shift,
642
+ non_linearity=resnet_act_fn,
643
+ output_scale_factor=output_scale_factor,
644
+ pre_norm=resnet_pre_norm,
645
+ )
646
+ )
647
+ temp_convs.append(
648
+ TemporalConvLayer(
649
+ out_channels,
650
+ out_channels,
651
+ dropout=0.1,
652
+ )
653
+ )
654
+
655
+ self.resnets = nn.ModuleList(resnets)
656
+ self.temp_convs = nn.ModuleList(temp_convs)
657
+
658
+ if add_upsample:
659
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
660
+ else:
661
+ self.upsamplers = None
662
+
663
+ self.gradient_checkpointing = False
664
+
665
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
666
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
667
+ # pop res hidden states
668
+ res_hidden_states = res_hidden_states_tuple[-1]
669
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
670
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
671
+
672
+ hidden_states = resnet(hidden_states, temb)
673
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
674
+
675
+ if self.upsamplers is not None:
676
+ for upsampler in self.upsamplers:
677
+ hidden_states = upsampler(hidden_states, upsample_size)
678
+
679
+ return hidden_states
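For orientation, a hedged shape sketch of how these 3D blocks are driven: frames stay folded into the batch dimension and `num_frames` is passed separately so the temporal layers can unfold them (hypothetical sizes, assuming upstream diffusers behaviour for the imported layers):

import torch

# Sketch only: a DownBlock3D on dummy video latents.
block = DownBlock3D(in_channels=32, out_channels=32, temb_channels=128, num_layers=1)
hidden = torch.randn(2 * 8, 32, 16, 16)    # (batch * num_frames, channels, height, width)
temb = torch.randn(2 * 8, 128)             # time embedding, also folded over frames
hidden, skip_states = block(hidden, temb=temb, num_frames=8)  # spatial size halves: 16 -> 8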
Tiger Model/diffusiers-Tiger/models/unet_3d_condition.py ADDED
@@ -0,0 +1,627 @@
1
+ # Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
2
+ # Copyright 2023 The ModelScope Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from ..configuration_utils import ConfigMixin, register_to_config
23
+ from ..loaders import UNet2DConditionLoadersMixin
24
+ from ..utils import BaseOutput, logging
25
+ from .attention_processor import AttentionProcessor, AttnProcessor
26
+ from .embeddings import TimestepEmbedding, Timesteps
27
+ from .modeling_utils import ModelMixin
28
+ from .transformer_temporal import TransformerTemporalModel
29
+ from .unet_3d_blocks import (
30
+ CrossAttnDownBlock3D,
31
+ CrossAttnUpBlock3D,
32
+ DownBlock3D,
33
+ UNetMidBlock3DCrossAttn,
34
+ UpBlock3D,
35
+ get_down_block,
36
+ get_up_block,
37
+ )
38
+
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+
43
+ @dataclass
44
+ class UNet3DConditionOutput(BaseOutput):
45
+ """
46
+ The output of [`UNet3DConditionModel`].
47
+
48
+ Args:
49
+ sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
50
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
51
+ """
52
+
53
+ sample: torch.FloatTensor
54
+
55
+
56
+ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
57
+ r"""
58
+ A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
59
+ shaped output.
60
+
61
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
62
+ for all models (such as downloading or saving).
63
+
64
+ Parameters:
65
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
66
+ Height and width of input/output sample.
67
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
68
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
69
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
70
+ The tuple of downsample blocks to use.
71
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
72
+ The tuple of upsample blocks to use.
73
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
74
+ The tuple of output channels for each block.
75
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
76
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
77
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
78
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
79
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
80
+ If `None`, normalization and activation layers are skipped in post-processing.
81
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
82
+ cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
83
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
84
+ num_attention_heads (`int`, *optional*): The number of attention heads.
85
+ """
86
+
87
+ _supports_gradient_checkpointing = False
88
+
89
+ @register_to_config
90
+ def __init__(
91
+ self,
92
+ sample_size: Optional[int] = None,
93
+ in_channels: int = 4,
94
+ out_channels: int = 4,
95
+ down_block_types: Tuple[str] = (
96
+ "CrossAttnDownBlock3D",
97
+ "CrossAttnDownBlock3D",
98
+ "CrossAttnDownBlock3D",
99
+ "DownBlock3D",
100
+ ),
101
+ up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
102
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
103
+ layers_per_block: int = 2,
104
+ downsample_padding: int = 1,
105
+ mid_block_scale_factor: float = 1,
106
+ act_fn: str = "silu",
107
+ norm_num_groups: Optional[int] = 32,
108
+ norm_eps: float = 1e-5,
109
+ cross_attention_dim: int = 1024,
110
+ attention_head_dim: Union[int, Tuple[int]] = 64,
111
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
112
+ ):
113
+ super().__init__()
114
+
115
+ self.sample_size = sample_size
116
+
117
+ if num_attention_heads is not None:
118
+ raise NotImplementedError(
119
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
120
+ )
121
+
122
+ # If `num_attention_heads` is not defined (which is the case for most models)
123
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
124
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
125
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
126
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
127
+ # which is why we correct for the naming here.
128
+ num_attention_heads = num_attention_heads or attention_head_dim
129
+
130
+ # Check inputs
131
+ if len(down_block_types) != len(up_block_types):
132
+ raise ValueError(
133
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
134
+ )
135
+
136
+ if len(block_out_channels) != len(down_block_types):
137
+ raise ValueError(
138
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
139
+ )
140
+
141
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
142
+ raise ValueError(
143
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
144
+ )
145
+
146
+        # input
+        conv_in_kernel = 3
+        conv_out_kernel = 3
+        conv_in_padding = (conv_in_kernel - 1) // 2
+        self.conv_in = nn.Conv2d(
+            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+        )
+
+        # time
+        time_embed_dim = block_out_channels[0] * 4
+        self.time_proj = Timesteps(block_out_channels[0], True, 0)
+        timestep_input_dim = block_out_channels[0]
+
+        self.time_embedding = TimestepEmbedding(
+            timestep_input_dim,
+            time_embed_dim,
+            act_fn=act_fn,
+        )
+
+        self.transformer_in = TransformerTemporalModel(
+            num_attention_heads=8,
+            attention_head_dim=attention_head_dim,
+            in_channels=block_out_channels[0],
+            num_layers=1,
+        )
+
+        # down & up blocks
+        self.down_blocks = nn.ModuleList([])
+        self.up_blocks = nn.ModuleList([])
+
+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+        # down
+        output_channel = block_out_channels[0]
+        for i, down_block_type in enumerate(down_block_types):
+            input_channel = output_channel
+            output_channel = block_out_channels[i]
+            is_final_block = i == len(block_out_channels) - 1
+
+            down_block = get_down_block(
+                down_block_type,
+                num_layers=layers_per_block,
+                in_channels=input_channel,
+                out_channels=output_channel,
+                temb_channels=time_embed_dim,
+                add_downsample=not is_final_block,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
+                cross_attention_dim=cross_attention_dim,
+                num_attention_heads=num_attention_heads[i],
+                downsample_padding=downsample_padding,
+                dual_cross_attention=False,
+            )
+            self.down_blocks.append(down_block)
+
+        # mid
+        self.mid_block = UNetMidBlock3DCrossAttn(
+            in_channels=block_out_channels[-1],
+            temb_channels=time_embed_dim,
+            resnet_eps=norm_eps,
+            resnet_act_fn=act_fn,
+            output_scale_factor=mid_block_scale_factor,
+            cross_attention_dim=cross_attention_dim,
+            num_attention_heads=num_attention_heads[-1],
+            resnet_groups=norm_num_groups,
+            dual_cross_attention=False,
+        )
+
+        # count how many layers upsample the images
+        self.num_upsamplers = 0
+
+        # up
+        reversed_block_out_channels = list(reversed(block_out_channels))
+        reversed_num_attention_heads = list(reversed(num_attention_heads))
+
+        output_channel = reversed_block_out_channels[0]
+        for i, up_block_type in enumerate(up_block_types):
+            is_final_block = i == len(block_out_channels) - 1
+
+            prev_output_channel = output_channel
+            output_channel = reversed_block_out_channels[i]
+            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+            # add upsample block for all BUT final layer
+            if not is_final_block:
+                add_upsample = True
+                self.num_upsamplers += 1
+            else:
+                add_upsample = False
+
+            up_block = get_up_block(
+                up_block_type,
+                num_layers=layers_per_block + 1,
+                in_channels=input_channel,
+                out_channels=output_channel,
+                prev_output_channel=prev_output_channel,
+                temb_channels=time_embed_dim,
+                add_upsample=add_upsample,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
+                cross_attention_dim=cross_attention_dim,
+                num_attention_heads=reversed_num_attention_heads[i],
+                dual_cross_attention=False,
+            )
+            self.up_blocks.append(up_block)
+            prev_output_channel = output_channel
+
+        # out
+        if norm_num_groups is not None:
+            self.conv_norm_out = nn.GroupNorm(
+                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+            )
+            self.conv_act = nn.SiLU()
+        else:
+            self.conv_norm_out = None
+            self.conv_act = None
+
+        conv_out_padding = (conv_out_kernel - 1) // 2
+        self.conv_out = nn.Conv2d(
+            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+        )
+
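+    # Worked example of the wiring above (assuming block_out_channels=(320, 640, 1280, 1280) and four
+    # block types): the down path maps channels 320->320, 320->640, 640->1280, 1280->1280 and skips the
+    # final downsample, the mid block stays at 1280, and the up path mirrors this in reverse, so
+    # self.num_upsamplers ends up as len(block_out_channels) - 1 == 3.
+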
+    @property
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "set_processor"):
+                processors[f"{name}.processor"] = module.processor
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
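+    # The returned keys are dotted module paths ending in ".processor"; with the 3D blocks used here a
+    # typical key might look like "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor"
+    # (the exact paths depend on how the block classes in this repository name their submodules).
+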
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
+    def set_attention_slice(self, slice_size):
+        r"""
+        Enable sliced attention computation.
+
+        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+        several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+        Args:
+            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+                must be a multiple of `slice_size`.
+        """
+        sliceable_head_dims = []
+
+        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+            if hasattr(module, "set_attention_slice"):
+                sliceable_head_dims.append(module.sliceable_head_dim)
+
+            for child in module.children():
+                fn_recursive_retrieve_sliceable_dims(child)
+
+        # retrieve number of attention layers
+        for module in self.children():
+            fn_recursive_retrieve_sliceable_dims(module)
+
+        num_sliceable_layers = len(sliceable_head_dims)
+
+        if slice_size == "auto":
+            # half the attention head size is usually a good trade-off between
+            # speed and memory
+            slice_size = [dim // 2 for dim in sliceable_head_dims]
+        elif slice_size == "max":
+            # make smallest slice possible
+            slice_size = num_sliceable_layers * [1]
+
+        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+        if len(slice_size) != len(sliceable_head_dims):
+            raise ValueError(
+                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+            )
+
+        for i in range(len(slice_size)):
+            size = slice_size[i]
+            dim = sliceable_head_dims[i]
+            if size is not None and size > dim:
+                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+        # Recursively walk through all the children.
+        # Any children which exposes the set_attention_slice method
+        # gets the message
+        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+            if hasattr(module, "set_attention_slice"):
+                module.set_attention_slice(slice_size.pop())
+
+            for child in module.children():
+                fn_recursive_set_attention_slice(child, slice_size)
+
+        reversed_slice_size = list(reversed(slice_size))
+        for module in self.children():
+            fn_recursive_set_attention_slice(module, reversed_slice_size)
+
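+    # Usage sketch (illustrative, assuming `unet` is an instance of this model):
+    #     unet.set_attention_slice("auto")  # every sliceable layer uses half of its head dim per slice
+    #     unet.set_attention_slice("max")   # slice size 1 everywhere: lowest memory, slowest
+    #     unet.set_attention_slice(2)       # the same fixed slice size for every sliceable layer
+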
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
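+    # Usage sketch (illustrative): pass a single processor instance for all attention layers, or a dict
+    # keyed exactly like `self.attn_processors`:
+    #     unet.set_attn_processor(AttnProcessor())
+    #     unet.set_attn_processor({name: AttnProcessor() for name in unet.attn_processors})
+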
+    def enable_forward_chunking(self, chunk_size=None, dim=0):
+        """
+        Enables [feed forward
+        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+        Parameters:
+            chunk_size (`int`, *optional*):
+                The chunk size of the feed-forward layers. If not specified, each feed-forward layer is run
+                individually over each tensor of dim=`dim`.
+            dim (`int`, *optional*, defaults to `0`):
+                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+                or dim=1 (sequence length).
+        """
+        if dim not in [0, 1]:
+            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+        # By default chunk size is 1
+        chunk_size = chunk_size or 1
+
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, chunk_size, dim)
+
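+    # Usage sketch (illustrative): chunk the feed-forward layers over the sequence dimension to trade a
+    # little speed for memory during inference:
+    #     unet.enable_forward_chunking(chunk_size=1, dim=1)
+    #     # ... run the denoising loop ...
+    #     unet.disable_forward_chunking()
+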
+    def disable_forward_chunking(self):
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, None, 0)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
+            module.gradient_checkpointing = value
+
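+    # Note: this hook only flips the `gradient_checkpointing` flag on the 3D blocks; it is normally
+    # triggered for every submodule via the base model class's `enable_gradient_checkpointing()` helper
+    # (assuming this class inherits the usual diffusers `ModelMixin` utilities).
+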
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        class_labels: Optional[torch.Tensor] = None,
+        timestep_cond: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+        mid_block_additional_residual: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+    ) -> Union[UNet3DConditionOutput, Tuple]:
+        r"""
+        The [`UNet3DConditionModel`] forward method.
+
+        Args:
+            sample (`torch.FloatTensor`):
+                The noisy input tensor with the following shape `(batch, channel, num_frames, height, width)`.
+            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.FloatTensor`):
+                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
+                tuple.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
+
+        Returns:
+            [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise
+                a `tuple` is returned where the first element is the sample tensor.
+        """
+        # By default samples have to be at least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+        # However, the upsampling interpolation output size can be forced to fit any upsampling size
+        # on the fly if necessary.
+        default_overall_up_factor = 2**self.num_upsamplers
+
+        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+        forward_upsample_size = False
+        upsample_size = None
+
+        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+            logger.info("Forward upsample size to force interpolation output size.")
+            forward_upsample_size = True
+
+        # prepare attention_mask
+        if attention_mask is not None:
+            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # 1. time
+        timesteps = timestep
+        if not torch.is_tensor(timesteps):
+            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+            # This would be a good case for the `match` statement (Python 3.10+)
+            is_mps = sample.device.type == "mps"
+            if isinstance(timestep, float):
+                dtype = torch.float32 if is_mps else torch.float64
+            else:
+                dtype = torch.int32 if is_mps else torch.int64
+            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+        elif len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
+
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        num_frames = sample.shape[2]
+        timesteps = timesteps.expand(sample.shape[0])
+
+        t_emb = self.time_proj(timesteps)
+
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=self.dtype)
+
+        emb = self.time_embedding(t_emb, timestep_cond)
+        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+        encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
+
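+        # At this point the conditioning has been duplicated per frame: with a batch of B videos and
+        # num_frames F, `emb` goes from (B, time_embed_dim) to (B * F, time_embed_dim) and
+        # `encoder_hidden_states` from (B, seq_len, dim) to (B * F, seq_len, dim), matching the
+        # frame-flattened `sample` produced in the next step.
+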
+        # 2. pre-process
+        sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
+        sample = self.conv_in(sample)
+
+        sample = self.transformer_in(
+            sample,
+            num_frames=num_frames,
+            cross_attention_kwargs=cross_attention_kwargs,
+            return_dict=False,
+        )[0]
+
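+        # Shape walkthrough (illustrative numbers): a latent video of shape (B, C, F, H, W) =
+        # (2, 4, 16, 32, 32) is permuted to (2, 16, 4, 32, 32) and flattened to (32, 4, 32, 32), so
+        # conv_in treats every frame as an independent 2D sample, while the temporal transformer above
+        # re-mixes information across the F frames of each video.
+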
+        # 3. down
+        down_block_res_samples = (sample,)
+        for downsample_block in self.down_blocks:
+            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+                sample, res_samples = downsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    num_frames=num_frames,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                )
+            else:
+                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
+
+            down_block_res_samples += res_samples
+
+        if down_block_additional_residuals is not None:
+            new_down_block_res_samples = ()
+
+            for down_block_res_sample, down_block_additional_residual in zip(
+                down_block_res_samples, down_block_additional_residuals
+            ):
+                down_block_res_sample = down_block_res_sample + down_block_additional_residual
+                new_down_block_res_samples += (down_block_res_sample,)
+
+            down_block_res_samples = new_down_block_res_samples
+
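+        # The optional `down_block_additional_residuals` / `mid_block_additional_residual` inputs are
+        # added one-to-one onto the skip tensors and (below) the mid-block output; in diffusers this is
+        # the hook used for ControlNet-style conditioning, so the tuple must line up element for element
+        # with `down_block_res_samples`.
+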
+        # 4. mid
+        if self.mid_block is not None:
+            sample = self.mid_block(
+                sample,
+                emb,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=attention_mask,
+                num_frames=num_frames,
+                cross_attention_kwargs=cross_attention_kwargs,
+            )
+
+        if mid_block_additional_residual is not None:
+            sample = sample + mid_block_additional_residual
+
+        # 5. up
+        for i, upsample_block in enumerate(self.up_blocks):
+            is_final_block = i == len(self.up_blocks) - 1
+
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+            # if we have not reached the final block and need to forward the
+            # upsample size, we do it here
+            if not is_final_block and forward_upsample_size:
+                upsample_size = down_block_res_samples[-1].shape[2:]
+
+            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    encoder_hidden_states=encoder_hidden_states,
+                    upsample_size=upsample_size,
+                    attention_mask=attention_mask,
+                    num_frames=num_frames,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                )
+            else:
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    upsample_size=upsample_size,
+                    num_frames=num_frames,
+                )
+
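+        # Each up block consumes the most recent `len(upsample_block.resnets)` skip tensors (the up
+        # blocks were built with `num_layers=layers_per_block + 1`), so the residual stack filled by the
+        # down path is drained here in reverse order.
+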
+        # 6. post-process
+        if self.conv_norm_out:
+            sample = self.conv_norm_out(sample)
+            sample = self.conv_act(sample)
+
+        sample = self.conv_out(sample)
+
+        # reshape to (batch, channel, num_frames, height, width)
+        sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)
+
+        if not return_dict:
+            return (sample,)
+
+        return UNet3DConditionOutput(sample=sample)
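+
+    # End-to-end usage sketch (illustrative; shapes assume in_channels == out_channels == 4 and a text
+    # encoder that emits 77 tokens; adjust to the actual configuration used in this repository):
+    #     unet = UNet3DConditionModel(...)
+    #     sample = torch.randn(1, 4, 8, 32, 32)  # (batch, channel, num_frames, height, width)
+    #     encoder_hidden_states = torch.randn(1, 77, unet.config.cross_attention_dim)
+    #     out = unet(sample, timestep=10, encoder_hidden_states=encoder_hidden_states).sample
+    #     # `out` keeps the (batch, out_channels, num_frames, height, width) layout of the input latents.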