3v324v23 committed
Commit 40f772a · 1 Parent(s): 4d120be
Files changed (6)
  1. app.py +162 -0
  2. data1.pth +3 -0
  3. data2.pth +3 -0
  4. data3.pth +3 -0
  5. pipeline_controlnet_sd_xl_raw.py +1895 -0
  6. requirements.txt +208 -0
app.py ADDED
@@ -0,0 +1,162 @@
+ import torch
+ import gradio as gr
+ from pipeline_controlnet_sd_xl_raw import StableDiffusionXLControlNetRAWPipeline
+ from diffusers import ControlNetModel, UniPCMultistepScheduler
+ from torchvision import transforms
+ from PIL import Image
+ import traceback
+
+ # ========== 1. Load Models ==========
+ # base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
+ # controlnet_path = "/mnt/wencheng/RAWPami/diffusers/examples/controlnet/controlnet-model"
+
+ # controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
+ # pipe = StableDiffusionXLControlNetRAWPipeline.from_pretrained(
+ #     base_model_path,
+ #     controlnet=controlnet,
+ #     torch_dtype=torch.float16
+ # )
+ pipe = StableDiffusionXLControlNetRAWPipeline.from_pretrained(
+     "wencheng256/DiffusionRAW",
+     torch_dtype=torch.float16
+ )
+
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+ pipe.enable_model_cpu_offload()
+
+ # ========== 2. Utility function: tensor -> PIL ==========
+ def tensor_to_pil(img_tensor: torch.Tensor) -> Image.Image:
+     if img_tensor.is_cuda:
+         img_tensor = img_tensor.cpu()
+     if img_tensor.dtype != torch.float32:
+         img_tensor = img_tensor.float()
+     img_tensor = img_tensor.clamp(0, 1)
+     return transforms.ToPILImage()(img_tensor)
+
+ # ========== 3. Load a .pth file ==========
+ def load_pth_data(pth_path):
+     # map_location keeps loading device-agnostic; tensors are moved/cast later as needed
+     data = torch.load(pth_path, map_location="cpu")
+     rgb_tensor = data["rgb"]
+     raw_tensor = data["raw"]
+     mask_tensor = data["mask"]
+     cond_tensor = data["condition"]
+
+     # Each key can contain multiple images; only the first index is used here
+     raw_image_pil = tensor_to_pil(raw_tensor[0][:, :448])
+     # Reverse the channel order of the sRGB tensor for display
+     rgb_image_pil = tensor_to_pil(torch.flip(rgb_tensor[0], dims=[0])[:, :448])
+     mask_image_pil = tensor_to_pil(1 - mask_tensor[0])
+
+     return rgb_image_pil, raw_image_pil, mask_image_pil, raw_tensor, mask_tensor, cond_tensor
+
+ # ========== 4. Inference function ==========
+ def infer_fn(prompt, mask_edited, raw_tensor_state, mask_tensor_state, cond_tensor_state):
+     """
+     mask_edited: with tool='sketch', Gradio passes a dict of {'image': PIL, 'mask': PIL}.
+     """
+     try:
+         if isinstance(mask_edited, dict):
+             # Only the drawn mask is needed
+             mask_edited = mask_edited["mask"]
+
+         mask_edited_tensor = transforms.ToTensor()(mask_edited)
+         # Keep only one channel as a grayscale mask
+         mask_edited_tensor = mask_edited_tensor[:1]
+         mask_edited_tensor = mask_edited_tensor.unsqueeze(0).half()
+
+         raw_t = raw_tensor_state.half()
+         cond_t = cond_tensor_state.half()
+
+         generator = torch.manual_seed(0)
+         print("Mask shape:", mask_edited_tensor.shape)
+         print("Raw shape:", raw_t.shape)
+         print("Cond shape:", cond_t.shape)
+
+         result = pipe(
+             prompt=prompt,
+             num_inference_steps=20,
+             generator=generator,
+             image=raw_t,
+             mask_image=mask_edited_tensor,
+             control_image=cond_t
+         ).images[0]
+
+         return tensor_to_pil(result)
+
+     except Exception:
+         # The full traceback goes to the terminal log; the Image output stays empty
+         traceback.print_exc()
+         return None
+
+ def build_demo():
+     with gr.Blocks() as demo:
+         gr.Markdown("# DiffusionRAW")
+
+         # Dropdown to select one of the bundled .pth files
+         pth_options = ["./data1.pth", "./data2.pth", "./data3.pth"]
+         with gr.Row():
+             pth_selector = gr.Dropdown(
+                 pth_options,
+                 value=pth_options[0],
+                 label="Select a PTH file"
+             )
+             load_button = gr.Button("Load")
+
+         with gr.Row():
+             # Display the RAW image
+             raw_display = gr.Image(
+                 label="Raw Image (Display Only)",
+                 interactive=False,
+             )
+             rgb_display = gr.Image(
+                 label="sRGB Image (Display Only)",
+                 interactive=False,
+             )
+             # Mask editor with the sketch tool
+             mask_editor = gr.Image(
+                 label="Mask (Sketch)",
+                 tool="sketch",
+                 type="pil",
+                 brush_color="#FFFFFF",
+                 interactive=True,
+                 width=512,
+                 height=512
+             )
+
+         # States to carry the loaded tensors between callbacks
+         raw_tensor_state = gr.State()
+         mask_tensor_state = gr.State()
+         cond_tensor_state = gr.State()
+
+         load_button.click(
+             fn=load_pth_data,
+             inputs=[pth_selector],
+             outputs=[
+                 rgb_display,
+                 raw_display,
+                 mask_editor,
+                 raw_tensor_state,
+                 mask_tensor_state,
+                 cond_tensor_state
+             ]
+         )
+
+         prompt_input = gr.Textbox(label="Prompt", value="An RAW Image.", lines=1)
+         generate_button = gr.Button("Generate")
+         output_image = gr.Image(label="Output", show_download_button=False)
+
+         generate_button.click(
+             fn=infer_fn,
+             inputs=[
+                 prompt_input,
+                 mask_editor,
+                 raw_tensor_state,
+                 mask_tensor_state,
+                 cond_tensor_state
+             ],
+             outputs=[output_image]
+         )
+
+     return demo
+
+ if __name__ == "__main__":
+     demo = build_demo()
+     demo.launch(server_name="0.0.0.0", server_port=9112, debug=True)
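Editor's note: `load_pth_data` above expects each bundled `.pth` file to be a dict with `rgb`, `raw`, `mask`, and `condition` tensors. A minimal sketch for inspecting one of those files before wiring it into the demo (the key names come straight from `app.py`; shapes and dtypes are printed rather than assumed, and the LFS payloads must already be fetched):

```py
import torch

# Inspect one of the bundled data files; keys mirror what load_pth_data() reads.
data = torch.load("./data1.pth", map_location="cpu")
for key in ("rgb", "raw", "mask", "condition"):
    tensor = data[key]
    print(f"{key:>10}: shape={tuple(tensor.shape)}, dtype={tensor.dtype}")
```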
data1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3682dac1755bb77ee361279083deb66785e7a214a8f369e2e78686a2e7c6eb81
+ size 35841792
data2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:25c4e21bd1b79b3cb8a6ed6a42b2aa1dd4d81eb0811123cbeee9d583f7cb3a46
+ size 35841792
data3.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d1a52df7e3fe3c713402923462323582ca510303751d81010806fe4a83501daf
+ size 35841792
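Editor's note: the three `data*.pth` entries above are Git LFS pointer files; if the repository is cloned without LFS, `torch.load` will fail on the ~36 MB payloads. A small sanity check (a sketch, not part of the app) that tells an un-fetched pointer apart from the real checkpoint:

```py
from pathlib import Path

def is_lfs_pointer(path: str) -> bool:
    """Return True if `path` still holds a Git LFS pointer instead of the real payload."""
    head = Path(path).read_bytes()[:64]
    return head.startswith(b"version https://git-lfs.github.com/spec/v1")

for name in ("data1.pth", "data2.pth", "data3.pth"):
    status = "LFS pointer (run `git lfs pull`)" if is_lfs_pointer(name) else "binary payload present"
    print(name, "->", status)
```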
pipeline_controlnet_sd_xl_raw.py ADDED
@@ -0,0 +1,1895 @@
1
+ # Copyright 2024 Harutatsu Akiyama, Jinbin Bai, and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import PIL.Image
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from transformers import (
23
+ CLIPImageProcessor,
24
+ CLIPTextModel,
25
+ CLIPTextModelWithProjection,
26
+ CLIPTokenizer,
27
+ CLIPVisionModelWithProjection,
28
+ )
29
+
30
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
31
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
32
+ from diffusers.loaders import (
33
+ FromSingleFileMixin,
34
+ IPAdapterMixin,
35
+ StableDiffusionXLLoraLoaderMixin,
36
+ TextualInversionLoaderMixin,
37
+ )
38
+ from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
39
+ from diffusers.models.attention_processor import (
40
+ AttnProcessor2_0,
41
+ XFormersAttnProcessor,
42
+ )
43
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
44
+ from diffusers.schedulers import KarrasDiffusionSchedulers
45
+ from diffusers.utils import (
46
+ USE_PEFT_BACKEND,
47
+ deprecate,
48
+ is_invisible_watermark_available,
49
+ logging,
50
+ replace_example_docstring,
51
+ scale_lora_layers,
52
+ unscale_lora_layers,
53
+ )
54
+ from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
55
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
56
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
57
+
58
+
59
+ if is_invisible_watermark_available():
60
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
61
+
62
+
63
+ from diffusers.utils import is_torch_xla_available
64
+
65
+
66
+ if is_torch_xla_available():
67
+ import torch_xla.core.xla_model as xm
68
+
69
+ XLA_AVAILABLE = True
70
+ else:
71
+ XLA_AVAILABLE = False
72
+
73
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
74
+
75
+
76
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
77
+ def retrieve_latents(
78
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
79
+ ):
80
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
81
+ return encoder_output.latent_dist.sample(generator)
82
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
83
+ return encoder_output.latent_dist.mode()
84
+ elif hasattr(encoder_output, "latents"):
85
+ return encoder_output.latents
86
+ else:
87
+ raise AttributeError("Could not access latents of provided encoder_output")
88
+
89
+
90
+ EXAMPLE_DOC_STRING = """
91
+ Examples:
92
+ ```py
93
+ >>> # !pip install transformers accelerate
94
+ >>> from diffusers import StableDiffusionXLControlNetInpaintPipeline, ControlNetModel, DDIMScheduler
95
+ >>> from diffusers.utils import load_image
96
+ >>> from PIL import Image
97
+ >>> import cv2
+ >>> import numpy as np
98
+ >>> import torch
99
+
100
+ >>> init_image = load_image(
101
+ ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
102
+ ... )
103
+ >>> init_image = init_image.resize((1024, 1024))
104
+
105
+ >>> generator = torch.Generator(device="cpu").manual_seed(1)
106
+
107
+ >>> mask_image = load_image(
108
+ ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png"
109
+ ... )
110
+ >>> mask_image = mask_image.resize((1024, 1024))
111
+
112
+
113
+ >>> def make_canny_condition(image):
114
+ ... image = np.array(image)
115
+ ... image = cv2.Canny(image, 100, 200)
116
+ ... image = image[:, :, None]
117
+ ... image = np.concatenate([image, image, image], axis=2)
118
+ ... image = Image.fromarray(image)
119
+ ... return image
120
+
121
+
122
+ >>> control_image = make_canny_condition(init_image)
123
+
124
+ >>> controlnet = ControlNetModel.from_pretrained(
125
+ ... "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
126
+ ... )
127
+ >>> pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
128
+ ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
129
+ ... )
130
+
131
+ >>> pipe.enable_model_cpu_offload()
132
+
133
+ >>> # generate image
134
+ >>> image = pipe(
135
+ ... "a handsome man with ray-ban sunglasses",
136
+ ... num_inference_steps=20,
137
+ ... generator=generator,
138
+ ... eta=1.0,
139
+ ... image=init_image,
140
+ ... mask_image=mask_image,
141
+ ... control_image=control_image,
142
+ ... ).images[0]
143
+ ```
144
+ """
145
+
146
+
147
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
148
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
149
+ r"""
150
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
151
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
152
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
153
+
154
+ Args:
155
+ noise_cfg (`torch.Tensor`):
156
+ The predicted noise tensor for the guided diffusion process.
157
+ noise_pred_text (`torch.Tensor`):
158
+ The predicted noise tensor for the text-guided diffusion process.
159
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
160
+ A rescale factor applied to the noise predictions.
161
+
162
+ Returns:
163
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
164
+ """
165
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
166
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
167
+ # rescale the results from guidance (fixes overexposure)
168
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
169
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
170
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
171
+ return noise_cfg
172
+
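Editor's note: as a hedged illustration of where `rescale_noise_cfg` sits in a denoising step, the classifier-free-guidance combination is formed first and then rescaled toward the statistics of the text-conditioned prediction. The tensors below are random stand-ins, not real model outputs, and `rescale_noise_cfg` is assumed to be in scope (e.g. imported from this file):

```py
import torch

# Dummy noise predictions with the usual (batch, channels, h, w) latent layout.
noise_pred_uncond = torch.randn(1, 4, 64, 64)
noise_pred_text = torch.randn(1, 4, 64, 64)

guidance_scale = 7.5
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# guidance_rescale > 0 pulls the combined prediction's std back toward the
# text branch, as described in the docstring above.
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=0.7)
print(noise_pred.shape)
```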
173
+
174
+ class StableDiffusionXLControlNetRAWPipeline(
175
+ DiffusionPipeline,
176
+ StableDiffusionMixin,
177
+ StableDiffusionXLLoraLoaderMixin,
178
+ FromSingleFileMixin,
179
+ IPAdapterMixin,
180
+ TextualInversionLoaderMixin,
181
+ ):
182
+ r"""
183
+ Pipeline for text-to-image generation using Stable Diffusion XL.
184
+
185
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
186
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
187
+
188
+ The pipeline also inherits the following loading methods:
189
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
190
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
191
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
192
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
193
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
194
+
195
+ Args:
196
+ vae ([`AutoencoderKL`]):
197
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
198
+ text_encoder ([`CLIPTextModel`]):
199
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
200
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
201
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
202
+ text_encoder_2 ([` CLIPTextModelWithProjection`]):
203
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
204
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
205
+ specifically the
206
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
207
+ variant.
208
+ tokenizer (`CLIPTokenizer`):
209
+ Tokenizer of class
210
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
211
+ tokenizer_2 (`CLIPTokenizer`):
212
+ Second Tokenizer of class
213
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
214
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
215
+ scheduler ([`SchedulerMixin`]):
216
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
217
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
218
+ """
219
+
220
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
221
+
222
+ _optional_components = [
223
+ "tokenizer",
224
+ "tokenizer_2",
225
+ "text_encoder",
226
+ "text_encoder_2",
227
+ "image_encoder",
228
+ "feature_extractor",
229
+ ]
230
+ _callback_tensor_inputs = [
231
+ "latents",
232
+ "prompt_embeds",
233
+ "negative_prompt_embeds",
234
+ "add_text_embeds",
235
+ "add_time_ids",
236
+ "negative_pooled_prompt_embeds",
237
+ "add_neg_time_ids",
238
+ "mask",
239
+ "masked_image_latents",
240
+ "control_image",
241
+ ]
242
+
243
+ def __init__(
244
+ self,
245
+ vae: AutoencoderKL,
246
+ text_encoder: CLIPTextModel,
247
+ text_encoder_2: CLIPTextModelWithProjection,
248
+ tokenizer: CLIPTokenizer,
249
+ tokenizer_2: CLIPTokenizer,
250
+ unet: UNet2DConditionModel,
251
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
252
+ scheduler: KarrasDiffusionSchedulers,
253
+ requires_aesthetics_score: bool = False,
254
+ force_zeros_for_empty_prompt: bool = True,
255
+ add_watermarker: Optional[bool] = None,
256
+ feature_extractor: Optional[CLIPImageProcessor] = None,
257
+ image_encoder: Optional[CLIPVisionModelWithProjection] = None,
258
+ ):
259
+ super().__init__()
260
+
261
+ if isinstance(controlnet, (list, tuple)):
262
+ controlnet = MultiControlNetModel(controlnet)
263
+
264
+ self.register_modules(
265
+ vae=vae,
266
+ text_encoder=text_encoder,
267
+ text_encoder_2=text_encoder_2,
268
+ tokenizer=tokenizer,
269
+ tokenizer_2=tokenizer_2,
270
+ unet=unet,
271
+ controlnet=controlnet,
272
+ scheduler=scheduler,
273
+ feature_extractor=feature_extractor,
274
+ image_encoder=image_encoder,
275
+ )
276
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
277
+ self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
278
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
279
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
280
+ self.mask_processor = VaeImageProcessor(
281
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
282
+ )
283
+ self.control_image_processor = VaeImageProcessor(
284
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
285
+ )
286
+
287
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
288
+
289
+ if add_watermarker:
290
+ self.watermark = StableDiffusionXLWatermarker()
291
+ else:
292
+ self.watermark = None
293
+
294
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
295
+ def encode_prompt(
296
+ self,
297
+ prompt: str,
298
+ prompt_2: Optional[str] = None,
299
+ device: Optional[torch.device] = None,
300
+ num_images_per_prompt: int = 1,
301
+ do_classifier_free_guidance: bool = True,
302
+ negative_prompt: Optional[str] = None,
303
+ negative_prompt_2: Optional[str] = None,
304
+ prompt_embeds: Optional[torch.Tensor] = None,
305
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
306
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
307
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
308
+ lora_scale: Optional[float] = None,
309
+ clip_skip: Optional[int] = None,
310
+ ):
311
+ r"""
312
+ Encodes the prompt into text encoder hidden states.
313
+
314
+ Args:
315
+ prompt (`str` or `List[str]`, *optional*):
316
+ prompt to be encoded
317
+ prompt_2 (`str` or `List[str]`, *optional*):
318
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
319
+ used in both text-encoders
320
+ device: (`torch.device`):
321
+ torch device
322
+ num_images_per_prompt (`int`):
323
+ number of images that should be generated per prompt
324
+ do_classifier_free_guidance (`bool`):
325
+ whether to use classifier free guidance or not
326
+ negative_prompt (`str` or `List[str]`, *optional*):
327
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
328
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
329
+ less than `1`).
330
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
331
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
332
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
333
+ prompt_embeds (`torch.Tensor`, *optional*):
334
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
335
+ provided, text embeddings will be generated from `prompt` input argument.
336
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
337
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
338
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
339
+ argument.
340
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
341
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
342
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
343
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
344
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
345
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
346
+ input argument.
347
+ lora_scale (`float`, *optional*):
348
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
349
+ clip_skip (`int`, *optional*):
350
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
351
+ the output of the pre-final layer will be used for computing the prompt embeddings.
352
+ """
353
+ device = device or self._execution_device
354
+
355
+ # set lora scale so that monkey patched LoRA
356
+ # function of text encoder can correctly access it
357
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
358
+ self._lora_scale = lora_scale
359
+
360
+ # dynamically adjust the LoRA scale
361
+ if self.text_encoder is not None:
362
+ if not USE_PEFT_BACKEND:
363
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
364
+ else:
365
+ scale_lora_layers(self.text_encoder, lora_scale)
366
+
367
+ if self.text_encoder_2 is not None:
368
+ if not USE_PEFT_BACKEND:
369
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
370
+ else:
371
+ scale_lora_layers(self.text_encoder_2, lora_scale)
372
+
373
+ prompt = [prompt] if isinstance(prompt, str) else prompt
374
+
375
+ if prompt is not None:
376
+ batch_size = len(prompt)
377
+ else:
378
+ batch_size = prompt_embeds.shape[0]
379
+
380
+ # Define tokenizers and text encoders
381
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
382
+ text_encoders = (
383
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
384
+ )
385
+
386
+ if prompt_embeds is None:
387
+ prompt_2 = prompt_2 or prompt
388
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
389
+
390
+ # textual inversion: process multi-vector tokens if necessary
391
+ prompt_embeds_list = []
392
+ prompts = [prompt, prompt_2]
393
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
394
+ if isinstance(self, TextualInversionLoaderMixin):
395
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
396
+
397
+ text_inputs = tokenizer(
398
+ prompt,
399
+ padding="max_length",
400
+ max_length=tokenizer.model_max_length,
401
+ truncation=True,
402
+ return_tensors="pt",
403
+ )
404
+
405
+ text_input_ids = text_inputs.input_ids
406
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
407
+
408
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
409
+ text_input_ids, untruncated_ids
410
+ ):
411
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
412
+ logger.warning(
413
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
414
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
415
+ )
416
+
417
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
418
+
419
+ # We are only ALWAYS interested in the pooled output of the final text encoder
420
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
421
+ pooled_prompt_embeds = prompt_embeds[0]
422
+
423
+ if clip_skip is None:
424
+ prompt_embeds = prompt_embeds.hidden_states[-2]
425
+ else:
426
+ # "2" because SDXL always indexes from the penultimate layer.
427
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
428
+
429
+ prompt_embeds_list.append(prompt_embeds)
430
+
431
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
432
+
433
+ # get unconditional embeddings for classifier free guidance
434
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
435
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
436
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
437
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
438
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
439
+ negative_prompt = negative_prompt or ""
440
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
441
+
442
+ # normalize str to list
443
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
444
+ negative_prompt_2 = (
445
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
446
+ )
447
+
448
+ uncond_tokens: List[str]
449
+ if prompt is not None and type(prompt) is not type(negative_prompt):
450
+ raise TypeError(
451
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
452
+ f" {type(prompt)}."
453
+ )
454
+ elif batch_size != len(negative_prompt):
455
+ raise ValueError(
456
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
457
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
458
+ " the batch size of `prompt`."
459
+ )
460
+ else:
461
+ uncond_tokens = [negative_prompt, negative_prompt_2]
462
+
463
+ negative_prompt_embeds_list = []
464
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
465
+ if isinstance(self, TextualInversionLoaderMixin):
466
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
467
+
468
+ max_length = prompt_embeds.shape[1]
469
+ uncond_input = tokenizer(
470
+ negative_prompt,
471
+ padding="max_length",
472
+ max_length=max_length,
473
+ truncation=True,
474
+ return_tensors="pt",
475
+ )
476
+
477
+ negative_prompt_embeds = text_encoder(
478
+ uncond_input.input_ids.to(device),
479
+ output_hidden_states=True,
480
+ )
481
+
482
+ # We are only ALWAYS interested in the pooled output of the final text encoder
483
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
484
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
485
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
486
+
487
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
488
+
489
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
490
+
491
+ if self.text_encoder_2 is not None:
492
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
493
+ else:
494
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
495
+
496
+ bs_embed, seq_len, _ = prompt_embeds.shape
497
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
498
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
499
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
500
+
501
+ if do_classifier_free_guidance:
502
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
503
+ seq_len = negative_prompt_embeds.shape[1]
504
+
505
+ if self.text_encoder_2 is not None:
506
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
507
+ else:
508
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
509
+
510
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
511
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
512
+
513
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
514
+ bs_embed * num_images_per_prompt, -1
515
+ )
516
+ if do_classifier_free_guidance:
517
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
518
+ bs_embed * num_images_per_prompt, -1
519
+ )
520
+
521
+ if self.text_encoder is not None:
522
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
523
+ # Retrieve the original scale by scaling back the LoRA layers
524
+ unscale_lora_layers(self.text_encoder, lora_scale)
525
+
526
+ if self.text_encoder_2 is not None:
527
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
528
+ # Retrieve the original scale by scaling back the LoRA layers
529
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
530
+
531
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
532
+
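Editor's note: a short sketch of how `encode_prompt` is typically called on an assembled pipeline. Argument names follow the signature above; `pipe` is assumed to be an already-loaded `StableDiffusionXLControlNetRAWPipeline`, and the prompt strings are placeholders:

```py
# Sketch only: assumes `pipe` was created via from_pretrained as in app.py.
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt(
    prompt="A RAW photograph of a street at night",
    device=pipe._execution_device,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="blurry, low quality",
)
print(prompt_embeds.shape, pooled_prompt_embeds.shape)
```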
533
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
534
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
535
+ dtype = next(self.image_encoder.parameters()).dtype
536
+
537
+ if not isinstance(image, torch.Tensor):
538
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
539
+
540
+ image = image.to(device=device, dtype=dtype)
541
+ if output_hidden_states:
542
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
543
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
544
+ uncond_image_enc_hidden_states = self.image_encoder(
545
+ torch.zeros_like(image), output_hidden_states=True
546
+ ).hidden_states[-2]
547
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
548
+ num_images_per_prompt, dim=0
549
+ )
550
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
551
+ else:
552
+ image_embeds = self.image_encoder(image).image_embeds
553
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
554
+ uncond_image_embeds = torch.zeros_like(image_embeds)
555
+
556
+ return image_embeds, uncond_image_embeds
557
+
558
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
559
+ def prepare_ip_adapter_image_embeds(
560
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
561
+ ):
562
+ image_embeds = []
563
+ if do_classifier_free_guidance:
564
+ negative_image_embeds = []
565
+ if ip_adapter_image_embeds is None:
566
+ if not isinstance(ip_adapter_image, list):
567
+ ip_adapter_image = [ip_adapter_image]
568
+
569
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
570
+ raise ValueError(
571
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
572
+ )
573
+
574
+ for single_ip_adapter_image, image_proj_layer in zip(
575
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
576
+ ):
577
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
578
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
579
+ single_ip_adapter_image, device, 1, output_hidden_state
580
+ )
581
+
582
+ image_embeds.append(single_image_embeds[None, :])
583
+ if do_classifier_free_guidance:
584
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
585
+ else:
586
+ for single_image_embeds in ip_adapter_image_embeds:
587
+ if do_classifier_free_guidance:
588
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
589
+ negative_image_embeds.append(single_negative_image_embeds)
590
+ image_embeds.append(single_image_embeds)
591
+
592
+ ip_adapter_image_embeds = []
593
+ for i, single_image_embeds in enumerate(image_embeds):
594
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
595
+ if do_classifier_free_guidance:
596
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
597
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
598
+
599
+ single_image_embeds = single_image_embeds.to(device=device)
600
+ ip_adapter_image_embeds.append(single_image_embeds)
601
+
602
+ return ip_adapter_image_embeds
603
+
604
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
605
+ def prepare_extra_step_kwargs(self, generator, eta):
606
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
607
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
608
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
609
+ # and should be between [0, 1]
610
+
611
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
612
+ extra_step_kwargs = {}
613
+ if accepts_eta:
614
+ extra_step_kwargs["eta"] = eta
615
+
616
+ # check if the scheduler accepts generator
617
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
618
+ if accepts_generator:
619
+ extra_step_kwargs["generator"] = generator
620
+ return extra_step_kwargs
621
+
622
+ def check_image(self, image, prompt, prompt_embeds):
623
+ image_is_pil = isinstance(image, PIL.Image.Image)
624
+ image_is_tensor = isinstance(image, torch.Tensor)
625
+ image_is_np = isinstance(image, np.ndarray)
626
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
627
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
628
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
629
+
630
+ if (
631
+ not image_is_pil
632
+ and not image_is_tensor
633
+ and not image_is_np
634
+ and not image_is_pil_list
635
+ and not image_is_tensor_list
636
+ and not image_is_np_list
637
+ ):
638
+ raise TypeError(
639
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
640
+ )
641
+
642
+ if image_is_pil:
643
+ image_batch_size = 1
644
+ else:
645
+ image_batch_size = len(image)
646
+
647
+ if prompt is not None and isinstance(prompt, str):
648
+ prompt_batch_size = 1
649
+ elif prompt is not None and isinstance(prompt, list):
650
+ prompt_batch_size = len(prompt)
651
+ elif prompt_embeds is not None:
652
+ prompt_batch_size = prompt_embeds.shape[0]
653
+
654
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
655
+ raise ValueError(
656
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
657
+ )
658
+
659
+ def check_inputs(
660
+ self,
661
+ prompt,
662
+ prompt_2,
663
+ image,
664
+ mask_image,
665
+ strength,
666
+ num_inference_steps,
667
+ callback_steps,
668
+ output_type,
669
+ negative_prompt=None,
670
+ negative_prompt_2=None,
671
+ prompt_embeds=None,
672
+ negative_prompt_embeds=None,
673
+ ip_adapter_image=None,
674
+ ip_adapter_image_embeds=None,
675
+ pooled_prompt_embeds=None,
676
+ negative_pooled_prompt_embeds=None,
677
+ controlnet_conditioning_scale=1.0,
678
+ control_guidance_start=0.0,
679
+ control_guidance_end=1.0,
680
+ callback_on_step_end_tensor_inputs=None,
681
+ padding_mask_crop=None,
682
+ ):
683
+ if strength < 0 or strength > 1:
684
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
685
+ if num_inference_steps is None:
686
+ raise ValueError("`num_inference_steps` cannot be None.")
687
+ elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
688
+ raise ValueError(
689
+ f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
690
+ f" {type(num_inference_steps)}."
691
+ )
692
+
693
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
694
+ raise ValueError(
695
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
696
+ f" {type(callback_steps)}."
697
+ )
698
+
699
+ if callback_on_step_end_tensor_inputs is not None and not all(
700
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
701
+ ):
702
+ raise ValueError(
703
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
704
+ )
705
+
706
+ if prompt is not None and prompt_embeds is not None:
707
+ raise ValueError(
708
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
709
+ " only forward one of the two."
710
+ )
711
+ elif prompt_2 is not None and prompt_embeds is not None:
712
+ raise ValueError(
713
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
714
+ " only forward one of the two."
715
+ )
716
+ elif prompt is None and prompt_embeds is None:
717
+ raise ValueError(
718
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
719
+ )
720
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
721
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
722
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
723
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
724
+
725
+ if negative_prompt is not None and negative_prompt_embeds is not None:
726
+ raise ValueError(
727
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
728
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
729
+ )
730
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
731
+ raise ValueError(
732
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
733
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
734
+ )
735
+
736
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
737
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
738
+ raise ValueError(
739
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
740
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
741
+ f" {negative_prompt_embeds.shape}."
742
+ )
743
+
744
+ if padding_mask_crop is not None:
745
+ if not isinstance(image, PIL.Image.Image):
746
+ raise ValueError(
747
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
748
+ )
749
+ if not isinstance(mask_image, PIL.Image.Image):
750
+ raise ValueError(
751
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
752
+ f" {type(mask_image)}."
753
+ )
754
+ if output_type != "pil":
755
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
756
+
757
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
758
+ raise ValueError(
759
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
760
+ )
761
+
762
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
763
+ raise ValueError(
764
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
765
+ )
766
+
767
+ # `prompt` needs more sophisticated handling when there are multiple
768
+ # conditionings.
769
+ if isinstance(self.controlnet, MultiControlNetModel):
770
+ if isinstance(prompt, list):
771
+ logger.warning(
772
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
773
+ " prompts. The conditionings will be fixed across the prompts."
774
+ )
775
+
776
+ # Check `image`
777
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
778
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
779
+ )
780
+ # if (
781
+ # isinstance(self.controlnet, ControlNetModel)
782
+ # or is_compiled
783
+ # and isinstance(self.controlnet._orig_mod, ControlNetModel)
784
+ # ):
785
+ # self.check_image(image, prompt, prompt_embeds)
786
+ # elif (
787
+ # isinstance(self.controlnet, MultiControlNetModel)
788
+ # or is_compiled
789
+ # and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
790
+ # ):
791
+ # if not isinstance(image, list):
792
+ # raise TypeError("For multiple controlnets: `image` must be type `list`")
793
+
794
+ # # When `image` is a nested list:
795
+ # # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
796
+ # elif any(isinstance(i, list) for i in image):
797
+ # raise ValueError("A single batch of multiple conditionings are supported at the moment.")
798
+ # elif len(image) != len(self.controlnet.nets):
799
+ # raise ValueError(
800
+ # f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
801
+ # )
802
+
803
+ # for image_ in image:
804
+ # self.check_image(image_, prompt, prompt_embeds)
805
+ # else:
806
+ # assert False
807
+
808
+ # Check `controlnet_conditioning_scale`
809
+ if (
810
+ isinstance(self.controlnet, ControlNetModel)
811
+ or is_compiled
812
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
813
+ ):
814
+ if not isinstance(controlnet_conditioning_scale, float):
815
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
816
+ elif (
817
+ isinstance(self.controlnet, MultiControlNetModel)
818
+ or is_compiled
819
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
820
+ ):
821
+ if isinstance(controlnet_conditioning_scale, list):
822
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
823
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
824
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
825
+ self.controlnet.nets
826
+ ):
827
+ raise ValueError(
828
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
829
+ " the same length as the number of controlnets"
830
+ )
831
+ else:
832
+ assert False
833
+
834
+ if not isinstance(control_guidance_start, (tuple, list)):
835
+ control_guidance_start = [control_guidance_start]
836
+
837
+ if not isinstance(control_guidance_end, (tuple, list)):
838
+ control_guidance_end = [control_guidance_end]
839
+
840
+ if len(control_guidance_start) != len(control_guidance_end):
841
+ raise ValueError(
842
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
843
+ )
844
+
845
+ if isinstance(self.controlnet, MultiControlNetModel):
846
+ if len(control_guidance_start) != len(self.controlnet.nets):
847
+ raise ValueError(
848
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
849
+ )
850
+
851
+ for start, end in zip(control_guidance_start, control_guidance_end):
852
+ if start >= end:
853
+ raise ValueError(
854
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
855
+ )
856
+ if start < 0.0:
857
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
858
+ if end > 1.0:
859
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
860
+
861
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
862
+ raise ValueError(
863
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
864
+ )
865
+
866
+ if ip_adapter_image_embeds is not None:
867
+ if not isinstance(ip_adapter_image_embeds, list):
868
+ raise ValueError(
869
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
870
+ )
871
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
872
+ raise ValueError(
873
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
874
+ )
875
+
876
+ def prepare_control_image(
877
+ self,
878
+ image,
879
+ width,
880
+ height,
881
+ batch_size,
882
+ num_images_per_prompt,
883
+ device,
884
+ dtype,
885
+ crops_coords,
886
+ resize_mode,
887
+ do_classifier_free_guidance=False,
888
+ guess_mode=False,
889
+ ):
890
+ image = self.control_image_processor.preprocess(
891
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
892
+ ).to(dtype=torch.float32)
893
+ image_batch_size = image.shape[0]
894
+
895
+ if image_batch_size == 1:
896
+ repeat_by = batch_size
897
+ else:
898
+ # image batch size is the same as prompt batch size
899
+ repeat_by = num_images_per_prompt
900
+
901
+ image = image.repeat_interleave(repeat_by, dim=0)
902
+
903
+ image = image.to(device=device, dtype=dtype)
904
+
905
+ if do_classifier_free_guidance and not guess_mode:
906
+ image = torch.cat([image] * 2)
907
+
908
+ return image
909
+
910
+ def prepare_latents(
911
+ self,
912
+ batch_size,
913
+ num_channels_latents,
914
+ height,
915
+ width,
916
+ dtype,
917
+ device,
918
+ generator,
919
+ latents=None,
920
+ image=None,
921
+ timestep=None,
922
+ is_strength_max=True,
923
+ add_noise=True,
924
+ return_noise=False,
925
+ return_image_latents=False,
926
+ ):
927
+ shape = (
928
+ batch_size,
929
+ num_channels_latents,
930
+ int(height) // self.vae_scale_factor,
931
+ int(width) // self.vae_scale_factor,
932
+ )
933
+ if isinstance(generator, list) and len(generator) != batch_size:
934
+ raise ValueError(
935
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
936
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
937
+ )
938
+
939
+ if (image is None or timestep is None) and not is_strength_max:
940
+ raise ValueError(
941
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
942
+ "However, either the image or the noise timestep has not been provided."
943
+ )
944
+
945
+ if return_image_latents or (latents is None and not is_strength_max):
946
+ image = image.to(device=device, dtype=dtype)
947
+
948
+ if image.shape[1] == 4:
949
+ image_latents = image
950
+ else:
951
+ image_latents = self._encode_vae_image(image=image, generator=generator)
952
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
953
+
954
+ if latents is None and add_noise:
955
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
956
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
957
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
958
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
959
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
960
+ elif add_noise:
961
+ noise = latents.to(device)
962
+ latents = noise * self.scheduler.init_noise_sigma
963
+ else:
964
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
965
+ latents = image_latents.to(device)
966
+
967
+ outputs = (latents,)
968
+
969
+ if return_noise:
970
+ outputs += (noise,)
971
+
972
+ if return_image_latents:
973
+ outputs += (image_latents,)
974
+
975
+ return outputs
976
+
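Editor's note: the branching in `prepare_latents` reduces to a simple rule: with `strength == 1.0` the latents start as pure scheduler-scaled noise, otherwise they start from the encoded image with noise added at the chosen timestep. A toy sketch of that rule with stand-in tensors and a default-constructed scheduler (not the pipeline's actual call path):

```py
import torch
from diffusers import UniPCMultistepScheduler

scheduler = UniPCMultistepScheduler()
scheduler.set_timesteps(20)

image_latents = torch.randn(1, 4, 64, 64)  # stand-in for the VAE-encoded image
noise = torch.randn_like(image_latents)
strength = 0.6
is_strength_max = strength == 1.0

if is_strength_max:
    # Pure noise start, scaled by the scheduler's initial sigma.
    latents = noise * scheduler.init_noise_sigma
else:
    # Image + noise start at the first retained timestep.
    timestep = scheduler.timesteps[:1]
    latents = scheduler.add_noise(image_latents, noise, timestep)
```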
977
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
978
+ dtype = image.dtype
979
+ if self.vae.config.force_upcast:
980
+ image = image.float()
981
+ self.vae.to(dtype=torch.float32)
982
+
983
+ if isinstance(generator, list):
984
+ image_latents = [
985
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
986
+ for i in range(image.shape[0])
987
+ ]
988
+ image_latents = torch.cat(image_latents, dim=0)
989
+ else:
990
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
991
+
992
+ if self.vae.config.force_upcast:
993
+ self.vae.to(dtype)
994
+
995
+ image_latents = image_latents.to(dtype)
996
+ image_latents = self.vae.config.scaling_factor * image_latents
997
+
998
+ return image_latents
999
+
1000
+ def prepare_mask_latents(
1001
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
1002
+ ):
1003
+ # resize the mask to latents shape as we concatenate the mask to the latents
1004
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
1005
+ # and half precision
1006
+ mask = torch.nn.functional.interpolate(
1007
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
1008
+ )
1009
+ mask = mask.to(device=device, dtype=dtype)
1010
+
1011
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
1012
+ if mask.shape[0] < batch_size:
1013
+ if not batch_size % mask.shape[0] == 0:
1014
+ raise ValueError(
1015
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
1016
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
1017
+ " of masks that you pass is divisible by the total requested batch size."
1018
+ )
1019
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
1020
+
1021
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
1022
+
1023
+ masked_image_latents = None
1024
+ if masked_image is not None:
1025
+ masked_image = masked_image.to(device=device, dtype=dtype)
1026
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
1027
+ if masked_image_latents.shape[0] < batch_size:
1028
+ if not batch_size % masked_image_latents.shape[0] == 0:
1029
+ raise ValueError(
1030
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
1031
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
1032
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
1033
+ )
1034
+ masked_image_latents = masked_image_latents.repeat(
1035
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
1036
+ )
1037
+
1038
+ masked_image_latents = (
1039
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
1040
+ )
1041
+
1042
+ # aligning device to prevent device errors when concating it with the latent model input
1043
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
1044
+
1045
+ return mask, masked_image_latents
1046
+
1047
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
1048
+ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
1049
+ # get the original timestep using init_timestep
1050
+ if denoising_start is None:
1051
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
1052
+ t_start = max(num_inference_steps - init_timestep, 0)
1053
+
1054
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
1055
+ if hasattr(self.scheduler, "set_begin_index"):
1056
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
1057
+
1058
+ return timesteps, num_inference_steps - t_start
1059
+
1060
+ else:
1061
+ # Strength is irrelevant if we directly request a timestep to start at;
1062
+ # that is, strength is determined by the denoising_start instead.
1063
+ discrete_timestep_cutoff = int(
1064
+ round(
1065
+ self.scheduler.config.num_train_timesteps
1066
+ - (denoising_start * self.scheduler.config.num_train_timesteps)
1067
+ )
1068
+ )
1069
+
1070
+ num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
1071
+ if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
1072
+ # if the scheduler is a 2nd order scheduler we might have to do +1
1073
+ # because `num_inference_steps` might be even given that every timestep
1074
+ # (except the highest one) is duplicated. If `num_inference_steps` is even it would
1075
+ # mean that we cut the timesteps in the middle of the denoising step
1076
+ # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
1077
+ # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
1078
+ num_inference_steps = num_inference_steps + 1
1079
+
1080
+ # because t_n+1 >= t_n, we slice the timesteps starting from the end
1081
+ t_start = len(self.scheduler.timesteps) - num_inference_steps
1082
+ timesteps = self.scheduler.timesteps[t_start:]
1083
+ if hasattr(self.scheduler, "set_begin_index"):
1084
+ self.scheduler.set_begin_index(t_start)
1085
+ return timesteps, num_inference_steps
1086
+
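# Editor's note (worked example): with num_inference_steps=50, strength=0.6 and a
# first-order scheduler, init_timestep = min(int(50 * 0.6), 50) = 30 and
# t_start = 50 - 30 = 20, so only the last 30 of the 50 scheduled timesteps are run and the
# initial latent is noised to the level of timestep index 20. When `denoising_start` is
# given, the strength-based computation above is bypassed entirely.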
1087
+ def _get_add_time_ids(
1088
+ self,
1089
+ original_size,
1090
+ crops_coords_top_left,
1091
+ target_size,
1092
+ aesthetic_score,
1093
+ negative_aesthetic_score,
1094
+ dtype,
1095
+ text_encoder_projection_dim=None,
1096
+ ):
1097
+ if self.config.requires_aesthetics_score:
1098
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
1099
+ add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
1100
+ else:
1101
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
1102
+ add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
1103
+
1104
+ passed_add_embed_dim = (
1105
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
1106
+ )
1107
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
1108
+
1109
+ if (
1110
+ expected_add_embed_dim > passed_add_embed_dim
1111
+ and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
1112
+ ):
1113
+ raise ValueError(
1114
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} are correctly used by the model."
1115
+ )
1116
+ elif (
1117
+ expected_add_embed_dim < passed_add_embed_dim
1118
+ and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
1119
+ ):
1120
+ raise ValueError(
1121
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
1122
+ )
1123
+ elif expected_add_embed_dim != passed_add_embed_dim:
1124
+ raise ValueError(
1125
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
1126
+ )
1127
+
1128
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
1129
+ add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
1130
+
1131
+ return add_time_ids, add_neg_time_ids
1132
+
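# Editor's note (worked example): with requires_aesthetics_score=False,
# original_size=(800, 800), crops_coords_top_left=(0, 0) and target_size=(800, 800),
# add_time_ids == [800, 800, 0, 0, 800, 800]. For the standard SDXL base UNet
# (addition_time_embed_dim=256, text_encoder_2 projection_dim=1280) the check above
# compares 6 * 256 + 1280 = 2816 against unet.add_embedding.linear_1.in_features.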
1133
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
1134
+ def upcast_vae(self):
1135
+ dtype = self.vae.dtype
1136
+ self.vae.to(dtype=torch.float32)
1137
+ use_torch_2_0_or_xformers = isinstance(
1138
+ self.vae.decoder.mid_block.attentions[0].processor,
1139
+ (
1140
+ AttnProcessor2_0,
1141
+ XFormersAttnProcessor,
1142
+ ),
1143
+ )
1144
+ # if xformers or torch_2_0 is used attention block does not need
1145
+ # to be in float32 which can save lots of memory
1146
+ if use_torch_2_0_or_xformers:
1147
+ self.vae.post_quant_conv.to(dtype)
1148
+ self.vae.decoder.conv_in.to(dtype)
1149
+ self.vae.decoder.mid_block.to(dtype)
1150
+
1151
+ @property
1152
+ def guidance_scale(self):
1153
+ return self._guidance_scale
1154
+
1155
+ @property
1156
+ def clip_skip(self):
1157
+ return self._clip_skip
1158
+
1159
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1160
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1161
+ # corresponds to doing no classifier free guidance.
1162
+ @property
1163
+ def do_classifier_free_guidance(self):
1164
+ return self._guidance_scale > 1
1165
+
1166
+ @property
1167
+ def cross_attention_kwargs(self):
1168
+ return self._cross_attention_kwargs
1169
+
1170
+ @property
1171
+ def num_timesteps(self):
1172
+ return self._num_timesteps
1173
+
1174
+ @property
1175
+ def interrupt(self):
1176
+ return self._interrupt
1177
+
1178
+ @torch.no_grad()
1179
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
1180
+ def __call__(
1181
+ self,
1182
+ prompt: Union[str, List[str]] = None,
1183
+ prompt_2: Optional[Union[str, List[str]]] = None,
1184
+ image: PipelineImageInput = None,
1185
+ mask_image: PipelineImageInput = None,
1186
+ control_image = None,
1187
+ height: Optional[int] = None,
1188
+ width: Optional[int] = None,
1189
+ padding_mask_crop: Optional[int] = None,
1190
+ strength: float = 0.9999,
1191
+ num_inference_steps: int = 50,
1192
+ denoising_start: Optional[float] = None,
1193
+ denoising_end: Optional[float] = None,
1194
+ guidance_scale: float = 5.0,
1195
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1196
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
1197
+ num_images_per_prompt: Optional[int] = 1,
1198
+ eta: float = 0.0,
1199
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1200
+ latents: Optional[torch.Tensor] = None,
1201
+ prompt_embeds: Optional[torch.Tensor] = None,
1202
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
1203
+ ip_adapter_image: Optional[PipelineImageInput] = None,
1204
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
1205
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
1206
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
1207
+ output_type: Optional[str] = "pil",
1208
+ return_dict: bool = True,
1209
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1210
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
1211
+ guess_mode: bool = False,
1212
+ control_guidance_start: Union[float, List[float]] = 0.0,
1213
+ control_guidance_end: Union[float, List[float]] = 1.0,
1214
+ guidance_rescale: float = 0.0,
1215
+ original_size: Tuple[int, int] = None,
1216
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
1217
+ target_size: Tuple[int, int] = None,
1218
+ aesthetic_score: float = 6.0,
1219
+ negative_aesthetic_score: float = 2.5,
1220
+ clip_skip: Optional[int] = None,
1221
+ callback_on_step_end: Optional[
1222
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
1223
+ ] = None,
1224
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1225
+ **kwargs,
1226
+ ):
1227
+ r"""
1228
+ Function invoked when calling the pipeline for generation.
1229
+
1230
+ Args:
1231
+ prompt (`str` or `List[str]`, *optional*):
1232
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
1233
+ instead.
1234
+ prompt_2 (`str` or `List[str]`, *optional*):
1235
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
1236
+ used in both text-encoders
1237
+ image (`PIL.Image.Image`):
1238
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
1239
+ be masked out with `mask_image` and repainted according to `prompt`.
1240
+ mask_image (`PIL.Image.Image`):
1241
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1242
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
1243
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
1244
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
1245
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1246
+ The height in pixels of the generated image.
1247
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1248
+ The width in pixels of the generated image.
1249
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
1250
+ The size of the margin in the crop to be applied to the image and mask. If `None`, no crop is applied to the
1251
+ image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
1252
+ with the same aspect ratio as the image that contains all of the masked area, and then expand that area based
1253
+ on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
1254
+ resizing to the original image size for inpainting. This is useful when the masked area is small while
1255
+ the image is large and contains information irrelevant to inpainting, such as the background.
1256
+ strength (`float`, *optional*, defaults to 0.9999):
1257
+ Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
1258
+ between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
1259
+ `strength`. The number of denoising steps depends on the amount of noise initially added. When
1260
+ `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
1261
+ iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
1262
+ portion of the reference `image`. Note that in the case of `denoising_start` being declared as an
1263
+ integer, the value of `strength` will be ignored.
1264
+ num_inference_steps (`int`, *optional*, defaults to 50):
1265
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1266
+ expense of slower inference.
1267
+ denoising_start (`float`, *optional*):
1268
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
1269
+ bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
1270
+ it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
1271
+ strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
1272
+ is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
1273
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
1274
+ denoising_end (`float`, *optional*):
1275
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
1276
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
1277
+ still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
1278
+ denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
1279
+ final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
1280
+ forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
1281
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
1282
+ guidance_scale (`float`, *optional*, defaults to 5.0):
1283
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1284
+ `guidance_scale` is defined as `w` of equation 2 of [Imagen
1285
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1286
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
1287
+ usually at the expense of lower image quality.
1288
+ negative_prompt (`str` or `List[str]`, *optional*):
1289
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1290
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1291
+ less than `1`).
1292
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
1293
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
1294
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
1295
+ prompt_embeds (`torch.Tensor`, *optional*):
1296
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1297
+ provided, text embeddings will be generated from `prompt` input argument.
1298
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
1299
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1300
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1301
+ argument.
1302
+ ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1303
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
1304
+ Pre-generated image embeddings for IP-Adapter. It should be a list whose length equals the number of
1305
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
1306
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
1307
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
1308
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
1309
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1310
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
1311
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
1312
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1313
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1314
+ input argument.
1315
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1316
+ The number of images to generate per prompt.
1317
+ eta (`float`, *optional*, defaults to 0.0):
1318
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1319
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1320
+ generator (`torch.Generator`, *optional*):
1321
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1322
+ to make generation deterministic.
1323
+ latents (`torch.Tensor`, *optional*):
1324
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1325
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1326
+ tensor will be generated by sampling using the supplied random `generator`.
1327
+ output_type (`str`, *optional*, defaults to `"pil"`):
1328
+ The output format of the generated image. Choose between
1329
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1330
+ return_dict (`bool`, *optional*, defaults to `True`):
1331
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1332
+ plain tuple.
1333
+ cross_attention_kwargs (`dict`, *optional*):
1334
+ A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
1335
+ `self.processor` in
1336
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1337
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1338
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
1339
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
1340
+ explained in section 2.2 of
1341
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1342
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1343
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
1344
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
1345
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
1346
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1347
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1348
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
1349
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
1350
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1351
+ aesthetic_score (`float`, *optional*, defaults to 6.0):
1352
+ Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
1353
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
1354
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1355
+ negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
1356
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
1357
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
1358
+ simulate an aesthetic score of the generated image by influencing the negative text condition.
1359
+ clip_skip (`int`, *optional*):
1360
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1361
+ the output of the pre-final layer will be used for computing the prompt embeddings.
1362
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
1363
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1364
+ each denoising step during inference, with the following arguments: `callback_on_step_end(self:
1365
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1366
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
1367
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1368
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1369
+ will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the
1370
+ `._callback_tensor_inputs` attribute of your pipeline class.
1371
+
1372
+ Examples:
1373
+
1374
+ Returns:
1375
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
1376
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
1377
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
1378
+ """
1379
+
1380
+ callback = kwargs.pop("callback", None)
1381
+ callback_steps = kwargs.pop("callback_steps", None)
1382
+
1383
+ if callback is not None:
1384
+ deprecate(
1385
+ "callback",
1386
+ "1.0.0",
1387
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1388
+ )
1389
+ if callback_steps is not None:
1390
+ deprecate(
1391
+ "callback_steps",
1392
+ "1.0.0",
1393
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1394
+ )
1395
+
1396
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1397
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1398
+
1399
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
1400
+
1401
+ # align format for control guidance
1402
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1403
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1404
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1405
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1406
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1407
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
1408
+ control_guidance_start, control_guidance_end = (
1409
+ mult * [control_guidance_start],
1410
+ mult * [control_guidance_end],
1411
+ )
1412
+
1413
+ # # 0.0 Default height and width to unet
1414
+ # height = height or self.unet.config.sample_size * self.vae_scale_factor
1415
+ # width = width or self.unet.config.sample_size * self.vae_scale_factor
1416
+
1417
+ # 0.1 align format for control guidance
1418
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1419
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1420
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1421
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1422
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1423
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
1424
+ control_guidance_start, control_guidance_end = (
1425
+ mult * [control_guidance_start],
1426
+ mult * [control_guidance_end],
1427
+ )
1428
+
1429
+ # 1. Check inputs
1430
+ # self.check_inputs(
1431
+ # prompt,
1432
+ # prompt_2,
1433
+ # control_image,
1434
+ # mask_image,
1435
+ # strength,
1436
+ # num_inference_steps,
1437
+ # callback_steps,
1438
+ # output_type,
1439
+ # negative_prompt,
1440
+ # negative_prompt_2,
1441
+ # prompt_embeds,
1442
+ # negative_prompt_embeds,
1443
+ # ip_adapter_image,
1444
+ # ip_adapter_image_embeds,
1445
+ # pooled_prompt_embeds,
1446
+ # negative_pooled_prompt_embeds,
1447
+ # controlnet_conditioning_scale,
1448
+ # control_guidance_start,
1449
+ # control_guidance_end,
1450
+ # callback_on_step_end_tensor_inputs,
1451
+ # padding_mask_crop,
1452
+ # )
1453
+
1454
+ def pad_to_800(img: torch.Tensor) -> torch.Tensor:
1455
+ b, c, h, w = img.shape
1456
+ pad_h = max(0, 800 - h)
1457
+ pad_w = max(0, 800 - w)
1458
+ padded = F.pad(img, (0, pad_w, 0, pad_h), mode="constant", value=0)
1459
+ return padded[:, :, :800, :800]
1460
+ image = pad_to_800(image)
1461
+
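# Editor's note on the helper above: F.pad with pad=(0, pad_w, 0, pad_h) zero-pads the last
# dimension (width) on the right and the second-to-last dimension (height) at the bottom,
# so an assumed RAW input of shape (1, 4, 600, 448) becomes (1, 4, 800, 800) before the
# 800x800 crop. `F` is assumed to be `torch.nn.functional` imported earlier in this file.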
1462
+ self._guidance_scale = guidance_scale
1463
+ self._clip_skip = clip_skip
1464
+ self._cross_attention_kwargs = cross_attention_kwargs
1465
+ self._interrupt = False
1466
+
1467
+ # 2. Define call parameters
1468
+ if prompt is not None and isinstance(prompt, str):
1469
+ batch_size = 1
1470
+ elif prompt is not None and isinstance(prompt, list):
1471
+ batch_size = len(prompt)
1472
+ else:
1473
+ batch_size = prompt_embeds.shape[0]
1474
+
1475
+ device = self._execution_device
1476
+
1477
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
1478
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
1479
+
1480
+ # 3. Encode input prompt
1481
+ text_encoder_lora_scale = (
1482
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1483
+ )
1484
+
1485
+ (
1486
+ prompt_embeds,
1487
+ negative_prompt_embeds,
1488
+ pooled_prompt_embeds,
1489
+ negative_pooled_prompt_embeds,
1490
+ ) = self.encode_prompt(
1491
+ prompt=prompt,
1492
+ prompt_2=prompt_2,
1493
+ device=device,
1494
+ num_images_per_prompt=num_images_per_prompt,
1495
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1496
+ negative_prompt=negative_prompt,
1497
+ negative_prompt_2=negative_prompt_2,
1498
+ prompt_embeds=prompt_embeds,
1499
+ negative_prompt_embeds=negative_prompt_embeds,
1500
+ pooled_prompt_embeds=pooled_prompt_embeds,
1501
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1502
+ lora_scale=text_encoder_lora_scale,
1503
+ clip_skip=self.clip_skip,
1504
+ )
1505
+
1506
+ # 3.1 Encode ip_adapter_image
1507
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1508
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1509
+ ip_adapter_image,
1510
+ ip_adapter_image_embeds,
1511
+ device,
1512
+ batch_size * num_images_per_prompt,
1513
+ self.do_classifier_free_guidance,
1514
+ )
1515
+
1516
+ # 4. set timesteps
1517
+ def denoising_value_valid(dnv):
1518
+ return isinstance(dnv, float) and 0 < dnv < 1
1519
+
1520
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1521
+ timesteps, num_inference_steps = self.get_timesteps(
1522
+ num_inference_steps,
1523
+ strength,
1524
+ device,
1525
+ denoising_start=denoising_start if denoising_value_valid(denoising_start) else None,
1526
+ )
1527
+ # check that number of inference steps is not < 1 - as this doesn't make sense
1528
+ if num_inference_steps < 1:
1529
+ raise ValueError(
1530
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
1531
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
1532
+ )
1533
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
1534
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1535
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
1536
+ is_strength_max = strength == 1.0
1537
+ self._num_timesteps = len(timesteps)
1538
+
1539
+ # 5. Preprocess mask and image - resizes image and mask w.r.t height and width
1540
+ # 5.1 Prepare init image
1541
+ height, width = (800, 800)
1542
+
1543
+ if padding_mask_crop is not None:
1544
+ height, width = self.image_processor.get_default_height_width(image, height, width)
1545
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
1546
+ resize_mode = "fill"
1547
+ else:
1548
+ crops_coords = None
1549
+ resize_mode = "default"
1550
+
1551
+ original_image = image
1552
+ init_image = image
1553
+ # init_image = self.image_processor.preprocess(
1554
+ # image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
1555
+ # )
1556
+ # init_image = init_image.to(dtype=torch.float32)
1557
+
1558
+ # 5.2 Prepare control images
1559
+ # if isinstance(controlnet, ControlNetModel):
1560
+ # control_image = self.prepare_control_image(
1561
+ # image=control_image,
1562
+ # width=width,
1563
+ # height=height,
1564
+ # batch_size=batch_size * num_images_per_prompt,
1565
+ # num_images_per_prompt=num_images_per_prompt,
1566
+ # device=device,
1567
+ # dtype=controlnet.dtype,
1568
+ # crops_coords=crops_coords,
1569
+ # resize_mode=resize_mode,
1570
+ # do_classifier_free_guidance=self.do_classifier_free_guidance,
1571
+ # guess_mode=guess_mode,
1572
+ # )
1573
+ # elif isinstance(controlnet, MultiControlNetModel):
1574
+ # control_images = []
1575
+
1576
+ # for control_image_ in control_image:
1577
+ # control_image_ = self.prepare_control_image(
1578
+ # image=control_image_,
1579
+ # width=width,
1580
+ # height=height,
1581
+ # batch_size=batch_size * num_images_per_prompt,
1582
+ # num_images_per_prompt=num_images_per_prompt,
1583
+ # device=device,
1584
+ # dtype=controlnet.dtype,
1585
+ # crops_coords=crops_coords,
1586
+ # resize_mode=resize_mode,
1587
+ # do_classifier_free_guidance=self.do_classifier_free_guidance,
1588
+ # guess_mode=guess_mode,
1589
+ # )
1590
+
1591
+ # control_images.append(control_image_)
1592
+
1593
+ # control_image = control_images
1594
+ # else:
1595
+ # raise ValueError(f"{controlnet.__class__} is not supported.")
1596
+
1597
+ # 5.3 Prepare mask
1598
+ mask = mask_image
1599
+ masked_image = init_image * (mask < 0.5)
1600
+ _, _, height, width = init_image.shape
1601
+
1602
+ # 6. Prepare latent variables
1603
+ num_channels_latents = self.vae.config.latent_channels
1604
+ num_channels_unet = self.unet.config.in_channels
1605
+ return_image_latents = num_channels_unet == 4
1606
+
1607
+ add_noise = True if denoising_start is None else False
1608
+ latents_outputs = self.prepare_latents(
1609
+ batch_size * num_images_per_prompt,
1610
+ num_channels_latents,
1611
+ height,
1612
+ width,
1613
+ prompt_embeds.dtype,
1614
+ device,
1615
+ generator,
1616
+ latents,
1617
+ image=init_image,
1618
+ timestep=latent_timestep,
1619
+ is_strength_max=is_strength_max,
1620
+ add_noise=add_noise,
1621
+ return_noise=True,
1622
+ return_image_latents=return_image_latents,
1623
+ )
1624
+
1625
+ if return_image_latents:
1626
+ latents, noise, image_latents = latents_outputs
1627
+ else:
1628
+ latents, noise = latents_outputs
1629
+
1630
+ # 7. Prepare mask latent variables
1631
+ mask, masked_image_latents = self.prepare_mask_latents(
1632
+ mask,
1633
+ masked_image,
1634
+ batch_size * num_images_per_prompt,
1635
+ height,
1636
+ width,
1637
+ prompt_embeds.dtype,
1638
+ device,
1639
+ generator,
1640
+ self.do_classifier_free_guidance,
1641
+ )
1642
+
1643
+ # 8. Check that sizes of mask, masked image and latents match
1644
+ if num_channels_unet == 9:
1645
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
1646
+ num_channels_mask = mask.shape[1]
1647
+ num_channels_masked_image = masked_image_latents.shape[1]
1648
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
1649
+ raise ValueError(
1650
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
1651
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
1652
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
1653
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
1654
+ " `pipeline.unet` or your `mask_image` or `image` input."
1655
+ )
1656
+ elif num_channels_unet != 4:
1657
+ raise ValueError(
1658
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
1659
+ )
1660
+ # 8.1 Prepare extra step kwargs.
1661
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1662
+
1663
+ # 8.2 Create tensor stating which controlnets to keep
1664
+ controlnet_keep = []
1665
+ for i in range(len(timesteps)):
1666
+ keeps = [
1667
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1668
+ for s, e in zip(control_guidance_start, control_guidance_end)
1669
+ ]
1670
+ controlnet_keep.append(keeps if isinstance(controlnet, MultiControlNetModel) else keeps[0])
1671
+
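# Editor's note (worked example): with 10 timesteps, control_guidance_start=0.0 and
# control_guidance_end=0.5, controlnet_keep evaluates to 1.0 for steps i = 0..4 (ControlNet
# active) and 0.0 for steps i = 5..9 (ControlNet switched off), since the keep factor is
# 1.0 only while i / len(timesteps) >= start and (i + 1) / len(timesteps) <= end.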
1672
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1673
+ height, width = latents.shape[-2:]
1674
+ height = height * self.vae_scale_factor
1675
+ width = width * self.vae_scale_factor
1676
+
1677
+ original_size = original_size or (height, width)
1678
+ target_size = target_size or (height, width)
1679
+
1680
+ # 10. Prepare added time ids & embeddings
1681
+ add_text_embeds = pooled_prompt_embeds
1682
+ if self.text_encoder_2 is None:
1683
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1684
+ else:
1685
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1686
+
1687
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
1688
+ original_size,
1689
+ crops_coords_top_left,
1690
+ target_size,
1691
+ aesthetic_score,
1692
+ negative_aesthetic_score,
1693
+ dtype=prompt_embeds.dtype,
1694
+ text_encoder_projection_dim=text_encoder_projection_dim,
1695
+ )
1696
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
1697
+
1698
+ if self.do_classifier_free_guidance:
1699
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1700
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1701
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
1702
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
1703
+
1704
+ prompt_embeds = prompt_embeds.to(device)
1705
+ add_text_embeds = add_text_embeds.to(device)
1706
+ add_time_ids = add_time_ids.to(device)
1707
+
1708
+ # 11. Denoising loop
1709
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1710
+
1711
+ if (
1712
+ denoising_end is not None
1713
+ and denoising_start is not None
1714
+ and denoising_value_valid(denoising_end)
1715
+ and denoising_value_valid(denoising_start)
1716
+ and denoising_start >= denoising_end
1717
+ ):
1718
+ raise ValueError(
1719
+ f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: "
1720
+ + f" {denoising_end} when using type float."
1721
+ )
1722
+ elif denoising_end is not None and denoising_value_valid(denoising_end):
1723
+ discrete_timestep_cutoff = int(
1724
+ round(
1725
+ self.scheduler.config.num_train_timesteps
1726
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
1727
+ )
1728
+ )
1729
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1730
+ timesteps = timesteps[:num_inference_steps]
1731
+
1732
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1733
+ for i, t in enumerate(timesteps):
1734
+ if self.interrupt:
1735
+ continue
1736
+
1737
+ # expand the latents if we are doing classifier free guidance
1738
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1739
+
1740
+ # concat latents, mask, masked_image_latents in the channel dimension
1741
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1742
+
1743
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1744
+
1745
+ # controlnet(s) inference
1746
+ if guess_mode and self.do_classifier_free_guidance:
1747
+ # Infer ControlNet only for the conditional batch.
1748
+ control_model_input = latents
1749
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
1750
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1751
+ controlnet_added_cond_kwargs = {
1752
+ "text_embeds": add_text_embeds.chunk(2)[1],
1753
+ "time_ids": add_time_ids.chunk(2)[1],
1754
+ }
1755
+ else:
1756
+ control_model_input = latent_model_input
1757
+ controlnet_prompt_embeds = prompt_embeds
1758
+ controlnet_added_cond_kwargs = added_cond_kwargs
1759
+
1760
+ if isinstance(controlnet_keep[i], list):
1761
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1762
+ else:
1763
+ controlnet_cond_scale = controlnet_conditioning_scale
1764
+ if isinstance(controlnet_cond_scale, list):
1765
+ controlnet_cond_scale = controlnet_cond_scale[0]
1766
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
1767
+
1768
+ # # Resize control_image to match the size of the input to the controlnet
1769
+ # if control_image.shape[-2:] != control_model_input.shape[-2:]:
1770
+ # control_image = F.interpolate(control_image, size=control_model_input.shape[-2:], mode="bilinear", align_corners=False)
1771
+
1772
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1773
+ control_model_input,
1774
+ t,
1775
+ encoder_hidden_states=controlnet_prompt_embeds,
1776
+ controlnet_cond=control_image,
1777
+ conditioning_scale=cond_scale,
1778
+ guess_mode=guess_mode,
1779
+ added_cond_kwargs=controlnet_added_cond_kwargs,
1780
+ return_dict=False,
1781
+ )
1782
+
1783
+ if guess_mode and self.do_classifier_free_guidance:
1784
+ # Inferred ControlNet only for the conditional batch.
1785
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
1786
+ # add 0 to the unconditional batch to keep it unchanged.
1787
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1788
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1789
+
1790
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1791
+ added_cond_kwargs["image_embeds"] = image_embeds
1792
+
1793
+ if num_channels_unet == 9:
1794
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
1795
+
1796
+ # predict the noise residual
1797
+ noise_pred = self.unet(
1798
+ latent_model_input,
1799
+ t,
1800
+ encoder_hidden_states=prompt_embeds,
1801
+ cross_attention_kwargs=self.cross_attention_kwargs,
1802
+ down_block_additional_residuals=down_block_res_samples,
1803
+ mid_block_additional_residual=mid_block_res_sample,
1804
+ added_cond_kwargs=added_cond_kwargs,
1805
+ return_dict=False,
1806
+ )[0]
1807
+
1808
+ # perform guidance
1809
+ if self.do_classifier_free_guidance:
1810
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1811
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1812
+
1813
+ if self.do_classifier_free_guidance and guidance_rescale > 0.0:
1814
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1815
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
1816
+
1817
+ # compute the previous noisy sample x_t -> x_t-1
1818
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1819
+
1820
+ if num_channels_unet == 4:
1821
+ init_latents_proper = image_latents
1822
+ if self.do_classifier_free_guidance:
1823
+ init_mask, _ = mask.chunk(2)
1824
+ else:
1825
+ init_mask = mask
1826
+
1827
+ if i < len(timesteps) - 1:
1828
+ noise_timestep = timesteps[i + 1]
1829
+ init_latents_proper = self.scheduler.add_noise(
1830
+ init_latents_proper, noise, torch.tensor([noise_timestep])
1831
+ )
1832
+
1833
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
1834
+
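# Editor's note: for the 4-channel (non-inpainting) UNet, the blend above re-injects the
# known image content at every step: where init_mask is 0 the appropriately re-noised
# original latents are kept, and where init_mask is 1 the freshly denoised latents are
# used, which is what confines generation to the masked region.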
1835
+ if callback_on_step_end is not None:
1836
+ callback_kwargs = {}
1837
+ for k in callback_on_step_end_tensor_inputs:
1838
+ callback_kwargs[k] = locals()[k]
1839
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1840
+
1841
+ latents = callback_outputs.pop("latents", latents)
1842
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1843
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1844
+ control_image = callback_outputs.pop("control_image", control_image)
1845
+
1846
+ # call the callback, if provided
1847
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1848
+ progress_bar.update()
1849
+ if callback is not None and i % callback_steps == 0:
1850
+ step_idx = i // getattr(self.scheduler, "order", 1)
1851
+ callback(step_idx, t, latents)
1852
+
1853
+ if XLA_AVAILABLE:
1854
+ xm.mark_step()
1855
+
1856
+ # make sure the VAE is in float32 mode, as it overflows in float16
1857
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
1858
+ self.upcast_vae()
1859
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1860
+
1861
+ # If we do sequential model offloading, let's offload unet and controlnet
1862
+ # manually for max memory savings
1863
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1864
+ self.unet.to("cpu")
1865
+ self.controlnet.to("cpu")
1866
+ torch.cuda.empty_cache()
1867
+
1868
+ if not output_type == "latent":
1869
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1870
+ else:
1871
+ return StableDiffusionXLPipelineOutput(images=latents)
1872
+
1873
+ # apply watermark if available
1874
+ if self.watermark is not None:
1875
+ image = self.watermark.apply_watermark(image)
1876
+
1877
+ # image = self.image_processor.postprocess(image, output_type=output_type)
1878
+ image = image[:, :, :448]
1879
+ # if not output_type == "latent":
1880
+ # # apply watermark if available
1881
+ # if self.watermark is not None:
1882
+ # image = self.watermark.apply_watermark(image)
1883
+
1884
+ # image = self.image_processor.postprocess(image, output_type=output_type)
1885
+
1886
+ # if padding_mask_crop is not None:
1887
+ # image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
1888
+
1889
+ # Offload all models
1890
+ self.maybe_free_model_hooks()
1891
+
1892
+ if not return_dict:
1893
+ return (image,)
1894
+
1895
+ return StableDiffusionXLPipelineOutput(images=image)
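# Editor's note, minimal usage sketch (argument names match the signature above; the tensor
# variables and their shapes are assumptions, not taken from this repository):
#
#     out = pipe(
#         prompt="a RAW photo",
#         image=raw_tensor,           # pixel/RAW-domain batch to inpaint, (B, C, H, W)
#         mask_image=mask_tensor,     # (B, 1, H, W), 1 = repaint, 0 = keep
#         control_image=cond_tensor,  # ControlNet conditioning tensor
#         num_inference_steps=20,
#     )
#     images = out.images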
requirements.txt ADDED
@@ -0,0 +1,208 @@
1
+ absl-py==2.1.0
2
+ accelerate==1.5.2
3
+ aiofiles==23.2.1
4
+ aiohappyeyeballs==2.6.1
5
+ aiohttp==3.11.14
6
+ aiosignal==1.3.2
7
+ altair==5.5.0
8
+ annotated-types==0.7.0
9
+ anyio==4.9.0
10
+ argon2-cffi==23.1.0
11
+ argon2-cffi-bindings==21.2.0
12
+ arrow==1.3.0
13
+ asttokens==3.0.0
14
+ async-lru==2.0.5
15
+ async-timeout==5.0.1
16
+ attrs==25.3.0
17
+ autocommand==2.2.2
18
+ babel==2.17.0
19
+ backports.tarfile==1.2.0
20
+ beautifulsoup4==4.13.3
21
+ bleach==6.2.0
22
+ certifi==2025.1.31
23
+ cffi==1.17.1
24
+ charset-normalizer==3.4.1
25
+ click==8.1.8
26
+ comm==0.2.2
27
+ contourpy==1.3.0
28
+ cycler==0.12.1
29
+ datasets==3.4.1
30
+ debugpy==1.8.13
31
+ decorator==5.2.1
32
+ deepspeed==0.15.4
33
+ defusedxml==0.7.1
34
+ diffusers
35
+ dill==0.3.8
36
+ docker-pycreds==0.4.0
37
+ einops==0.8.1
38
+ eval_type_backport==0.2.2
39
+ exceptiongroup==1.2.2
40
+ executing==2.2.0
41
+ fastapi==0.115.12
42
+ fastjsonschema==2.21.1
43
+ ffmpy==0.5.0
44
+ filelock==3.13.1
45
+ fonttools==4.56.0
46
+ fqdn==1.5.1
47
+ frozenlist==1.5.0
48
+ fsspec==2024.6.1
49
+ ftfy==6.3.1
50
+ gitdb==4.0.12
51
+ GitPython==3.1.44
52
+ gradio==3.36.1
53
+ gradio_client==0.6.1
54
+ grpcio==1.71.0
55
+ h11==0.14.0
56
+ hjson==3.1.0
57
+ httpcore==1.0.7
58
+ httpx==0.28.1
59
+ huggingface-hub==0.29.3
60
+ idna==3.10
61
+ importlib_metadata==8.6.1
62
+ importlib_resources==6.5.2
63
+ inflect==7.3.1
64
+ ipykernel==6.29.5
65
+ ipython==8.18.1
66
+ isoduration==20.11.0
67
+ jaraco.collections==5.1.0
68
+ jaraco.context==5.3.0
69
+ jaraco.functools==4.0.1
70
+ jaraco.text==3.12.1
71
+ jedi==0.19.2
72
+ Jinja2==3.1.4
73
+ json5==0.10.0
74
+ jsonpointer==3.0.0
75
+ jsonschema==4.23.0
76
+ jsonschema-specifications==2024.10.1
77
+ jupyter-events==0.12.0
78
+ jupyter-lsp==2.2.5
79
+ jupyter_client==8.6.3
80
+ jupyter_core==5.7.2
81
+ jupyter_server==2.15.0
82
+ jupyter_server_terminals==0.5.3
83
+ jupyterlab==4.3.6
84
+ jupyterlab_pygments==0.3.0
85
+ jupyterlab_server==2.27.3
86
+ kiwisolver==1.4.7
87
+ linkify-it-py==2.0.3
88
+ Markdown==3.7
89
+ markdown-it-py==2.2.0
90
+ MarkupSafe==2.1.5
91
+ matplotlib==3.9.4
92
+ matplotlib-inline==0.1.7
93
+ mdit-py-plugins==0.3.3
94
+ mdurl==0.1.2
95
+ mistune==3.1.3
96
+ more-itertools==10.3.0
97
+ mpmath==1.3.0
98
+ msgpack==1.1.0
99
+ multidict==6.2.0
100
+ multiprocess==0.70.16
101
+ narwhals==1.33.0
102
+ nbclient==0.10.2
103
+ nbconvert==7.16.6
104
+ nbformat==5.10.4
105
+ nest-asyncio==1.6.0
106
+ networkx==3.2.1
107
+ ninja==1.11.1.3
108
+ notebook_shim==0.2.4
109
+ numpy==1.26.3
110
+ nvidia-cublas-cu12==12.4.5.8
111
+ nvidia-cuda-cupti-cu12==12.4.127
112
+ nvidia-cuda-nvrtc-cu12==12.4.127
113
+ nvidia-cuda-runtime-cu12==12.4.127
114
+ nvidia-cudnn-cu12==9.1.0.70
115
+ nvidia-cufft-cu12==11.2.1.3
116
+ nvidia-curand-cu12==10.3.5.147
117
+ nvidia-cusolver-cu12==11.6.1.9
118
+ nvidia-cusparse-cu12==12.3.1.170
119
+ nvidia-cusparselt-cu12==0.6.2
120
+ nvidia-ml-py==12.570.86
121
+ nvidia-nccl-cu12==2.21.5
122
+ nvidia-nvjitlink-cu12==12.4.127
123
+ nvidia-nvtx-cu12==12.4.127
124
+ opencv-python==4.11.0.86
125
+ orjson==3.10.16
126
+ overrides==7.7.0
127
+ packaging==24.2
128
+ pandas==2.2.3
129
+ pandocfilters==1.5.1
130
+ parso==0.8.4
131
+ pexpect==4.9.0
132
+ pillow==10.4.0
133
+ platformdirs==4.3.6
134
+ prometheus_client==0.21.1
135
+ prompt_toolkit==3.0.50
136
+ propcache==0.3.0
137
+ protobuf==5.29.3
138
+ psutil==7.0.0
139
+ ptyprocess==0.7.0
140
+ pure_eval==0.2.3
141
+ py-cpuinfo==9.0.0
142
+ pyarrow==19.0.1
143
+ pycparser==2.22
144
+ pydantic==2.10.6
145
+ pydantic_core==2.27.2
146
+ pydub==0.25.1
147
+ Pygments==2.19.1
148
+ pyparsing==3.2.2
149
+ python-dateutil==2.9.0.post0
150
+ python-json-logger==3.3.0
151
+ python-multipart==0.0.20
152
+ pytz==2025.1
153
+ PyYAML==6.0.2
154
+ pyzmq==26.3.0
155
+ referencing==0.36.2
156
+ regex==2024.11.6
157
+ requests==2.32.3
158
+ rfc3339-validator==0.1.4
159
+ rfc3986-validator==0.1.1
160
+ rich==14.0.0
161
+ rpds-py==0.23.1
162
+ ruff==0.11.2
163
+ safetensors==0.5.3
164
+ semantic-version==2.10.0
165
+ Send2Trash==1.8.3
166
+ sentry-sdk==2.23.1
167
+ setproctitle==1.3.5
168
+ shellingham==1.5.4
169
+ six==1.17.0
170
+ smmap==5.0.2
171
+ sniffio==1.3.1
172
+ soupsieve==2.6
173
+ stack-data==0.6.3
174
+ starlette==0.46.1
175
+ sympy==1.13.1
176
+ tensorboard==2.19.0
177
+ tensorboard-data-server==0.7.2
178
+ terminado==0.18.1
179
+ tinycss2==1.4.0
180
+ tokenizers==0.21.1
181
+ tomli==2.0.1
182
+ tomlkit==0.12.0
183
+ torch==2.5.1
184
+ torchaudio==2.5.1
185
+ torchvision==0.20.1
186
+ tornado==6.4.2
187
+ tqdm==4.67.1
188
+ traitlets==5.14.3
189
+ transformers==4.49.0
190
+ triton==3.1.0
191
+ typeguard==4.3.0
192
+ typer==0.15.2
193
+ types-python-dateutil==2.9.0.20241206
194
+ typing_extensions==4.12.2
195
+ tzdata==2025.1
196
+ uc-micro-py==1.0.3
197
+ uri-template==1.3.0
198
+ urllib3==2.3.0
199
+ uvicorn==0.34.0
200
+ wcwidth==0.2.13
201
+ webcolors==24.11.1
202
+ webencodings==0.5.1
203
+ websocket-client==1.8.0
204
+ websockets==11.0.3
205
+ Werkzeug==3.1.3
206
+ xxhash==3.5.0
207
+ yarl==1.18.3
208
+ zipp==3.21.0