quickjkee committed
Commit f3f05b4 · verified · 1 Parent(s): 0472a4c

Create pipeline.py

Files changed (1)
  1. pipeline.py +282 -0
pipeline.py ADDED
@@ -0,0 +1,282 @@
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ from typing import Any, Callable, Dict, List, Union, Optional
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+ from diffusers.utils import (
+     USE_PEFT_BACKEND,
+     is_torch_xla_available,
+     logging,
+     replace_example_docstring,
+     scale_lora_layers,
+     unscale_lora_layers,
+ )
+ from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
+ from diffusers import DiffusionPipeline, ImagePipelineOutput
+
+ if is_torch_xla_available():
+     import torch_xla.core.xla_model as xm
+
+     XLA_AVAILABLE = True
+ else:
+     XLA_AVAILABLE = False
+
+
+ class SwDPipeline(DiffusionPipeline):
+     """Scale-wise denoising pipeline built on Stable Diffusion 3 style components.
+
+     The caller supplies the denoising `timesteps`, the matching `sigmas` (with one
+     extra trailing entry, since `sigmas[i + 1]` is read at every step) and the
+     per-step latent `scales`; after each step the predicted clean latent is upscaled
+     to the next scale and re-noised. The usual SD3 helpers (`encode_prompt`,
+     `check_inputs`, `prepare_latents`) and components (`transformer`, `vae`) must be
+     available on the pipeline instance.
+     """
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         prompt: Union[str, List[str]] = None,
+         prompt_2: Optional[Union[str, List[str]]] = None,
+         prompt_3: Optional[Union[str, List[str]]] = None,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         num_inference_steps: int = 28,
+         sigmas: Optional[List[float]] = None,
+         timesteps: Optional[torch.Tensor] = None,  # 1-D tensor; each entry is broadcast over the batch below
+         scales: Optional[List[int]] = None,  # latent spatial sizes, read as scales[i + 1] at step i
+         guidance_scale: float = 7.0,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         negative_prompt_2: Optional[Union[str, List[str]]] = None,
+         negative_prompt_3: Optional[Union[str, List[str]]] = None,
+         num_images_per_prompt: Optional[int] = 1,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+         ip_adapter_image: Optional[PipelineImageInput] = None,
+         ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+         clip_skip: Optional[int] = None,
+         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+         max_sequence_length: int = 256,
+         skip_guidance_layers: List[int] = None,
+         skip_layer_guidance_scale: float = 2.8,
+         skip_layer_guidance_stop: float = 0.2,
+         skip_layer_guidance_start: float = 0.01,
+         mu: Optional[float] = None,
+     ):
+         height = height or self.default_sample_size * self.vae_scale_factor
+         width = width or self.default_sample_size * self.vae_scale_factor
+
+         # 1. Check inputs. Raise error if not correct
+         self.check_inputs(
+             prompt,
+             prompt_2,
+             prompt_3,
+             height,
+             width,
+             negative_prompt=negative_prompt,
+             negative_prompt_2=negative_prompt_2,
+             negative_prompt_3=negative_prompt_3,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             pooled_prompt_embeds=pooled_prompt_embeds,
+             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+             max_sequence_length=max_sequence_length,
+         )
+
+         self._guidance_scale = guidance_scale
+         self._skip_layer_guidance_scale = skip_layer_guidance_scale
+         self._clip_skip = clip_skip
+         self._joint_attention_kwargs = joint_attention_kwargs
+         self._interrupt = False
+
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         device = self._execution_device
+
+         lora_scale = (
+             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+         )
+
+         # 3. Encode prompt
+         (
+             prompt_embeds,
+             negative_prompt_embeds,
+             pooled_prompt_embeds,
+             negative_pooled_prompt_embeds,
+         ) = self.encode_prompt(
+             prompt=prompt,
+             prompt_2=prompt_2,
+             prompt_3=prompt_3,
+             negative_prompt=negative_prompt,
+             negative_prompt_2=negative_prompt_2,
+             negative_prompt_3=negative_prompt_3,
+             do_classifier_free_guidance=self.do_classifier_free_guidance,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             pooled_prompt_embeds=pooled_prompt_embeds,
+             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+             device=device,
+             clip_skip=self.clip_skip,
+             num_images_per_prompt=num_images_per_prompt,
+             max_sequence_length=max_sequence_length,
+             lora_scale=lora_scale,
+         )
+
+         if self.do_classifier_free_guidance:
+             if skip_guidance_layers is not None:
+                 original_prompt_embeds = prompt_embeds
+                 original_pooled_prompt_embeds = pooled_prompt_embeds
+             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
+         # 4. Prepare latent variables
+         num_channels_latents = self.transformer.config.in_channels
+         latents = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+         )
+
+         # 5. Prepare timestep bookkeeping
+         # `timesteps`, `sigmas` and `scales` are supplied by the caller rather than derived
+         # from the scheduler; the scheduler is only consulted for its `order` below.
+         scheduler_kwargs = {}
+         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+         self._num_timesteps = len(timesteps)
+
+         # 6. Prepare image embeddings
+         if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+             ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+                 ip_adapter_image,
+                 ip_adapter_image_embeds,
+                 device,
+                 batch_size * num_images_per_prompt,
+                 self.do_classifier_free_guidance,
+             )
+
+             if self.joint_attention_kwargs is None:
+                 self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+             else:
+                 self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+         # 7. Denoising loop
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 if self.interrupt:
+                     continue
+
+                 # expand the latents if we are doing classifier free guidance
+                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                 timestep = t.expand(latent_model_input.shape[0])
+
+                 noise_pred = self.transformer(
+                     hidden_states=latent_model_input,
+                     timestep=timestep,
+                     encoder_hidden_states=prompt_embeds,
+                     pooled_projections=pooled_prompt_embeds,
+                     joint_attention_kwargs=self.joint_attention_kwargs,
+                     return_dict=False,
+                 )[0]
+
+                 # perform guidance
+                 if self.do_classifier_free_guidance:
+                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+                     should_skip_layers = (
+                         True
+                         if i > num_inference_steps * skip_layer_guidance_start
+                         and i < num_inference_steps * skip_layer_guidance_stop
+                         else False
+                     )
+                     if skip_guidance_layers is not None and should_skip_layers:
+                         timestep = t.expand(latents.shape[0])
+                         latent_model_input = latents
+                         noise_pred_skip_layers = self.transformer(
+                             hidden_states=latent_model_input,
+                             timestep=timestep,
+                             encoder_hidden_states=original_prompt_embeds,
+                             pooled_projections=original_pooled_prompt_embeds,
+                             joint_attention_kwargs=self.joint_attention_kwargs,
+                             return_dict=False,
+                             skip_layers=skip_guidance_layers,
+                         )[0]
+                         noise_pred = (
+                             noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
+                         )
+
+                 # compute the previous noisy sample x_t -> x_t-1 (scale-wise update)
+                 latents_dtype = latents.dtype
+                 sigma = sigmas[i]
+                 sigma_next = sigmas[i + 1]  # sigmas must have len(timesteps) + 1 entries
+                 # predict the clean latent, then upscale it to the next scale before re-noising
+                 x0_pred = (latents - sigma * noise_pred)
+                 try:
+                     x0_pred = torch.nn.functional.interpolate(x0_pred, size=scales[i + 1], mode='bicubic')
+                 except IndexError:
+                     # final step: there is no next scale, keep x0_pred as is
+                     pass
+                 # re-noise the (upscaled) prediction on the same device/dtype as the working latents
+                 noise = torch.randn(x0_pred.shape, generator=generator).to(device=x0_pred.device, dtype=x0_pred.dtype)
+                 latents = (1 - sigma_next) * x0_pred + sigma_next * noise
+
+                 if latents.dtype != latents_dtype:
+                     if torch.backends.mps.is_available():
+                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                         latents = latents.to(latents_dtype)
+
+                 if callback_on_step_end is not None:
+                     callback_kwargs = {}
+                     for k in callback_on_step_end_tensor_inputs:
+                         callback_kwargs[k] = locals()[k]
+                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                     latents = callback_outputs.pop("latents", latents)
+                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                     negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                     negative_pooled_prompt_embeds = callback_outputs.pop(
+                         "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                     )
+
+                 # update the progress bar
+                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                     progress_bar.update()
+
+                 if XLA_AVAILABLE:
+                     xm.mark_step()
+
+         if output_type == "latent":
+             image = latents
+
+         else:
+             latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+
+             image = self.vae.decode(latents, return_dict=False)[0]
+             image = self.image_processor.postprocess(image, output_type=output_type)
+
+         # Offload all models
+         self.maybe_free_model_hooks()
+
+         if not return_dict:
+             return (image,)
+
+         return StableDiffusion3PipelineOutput(images=image)
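
Below is a minimal usage sketch for the pipeline defined in this commit. It is illustrative only: the repository id, the loading route via `custom_pipeline`, and every schedule value are assumptions rather than part of the commit, and the class relies on Stable Diffusion 3 style components and helpers (`encode_prompt`, `check_inputs`, `prepare_latents`, `transformer`, `vae`) being present on the loaded pipeline instance. The call must receive `timesteps` as a 1-D tensor, `sigmas` with one extra trailing entry, and `scales` listing the latent resolution used from each step onward.

import torch
from diffusers import DiffusionPipeline

# Hypothetical checkpoint repository that ships this pipeline.py as its custom pipeline code.
repo_id = "your-org/swd-sd3-checkpoint"
pipe = DiffusionPipeline.from_pretrained(
    repo_id,
    custom_pipeline=repo_id,  # load pipeline.py from the same repository
    torch_dtype=torch.float16,
).to("cuda")

# Made-up scale-wise schedule. Invariants the denoising loop relies on:
#   len(sigmas) == len(timesteps) + 1, since sigmas[i + 1] is read at every step;
#   scales[i + 1] is the latent size the x0 prediction is upscaled to after step i.
sigmas = [1.0, 0.75, 0.5, 0.25, 0.0]
timesteps = torch.tensor([s * 1000 for s in sigmas[:-1]], device="cuda")
scales = [32, 48, 64, 128]  # latent sizes; with the 8x SD3 VAE, 128 corresponds to 1024 px

image = pipe(
    prompt="a photo of an astronaut riding a horse on the moon",
    num_inference_steps=len(timesteps),
    timesteps=timesteps,
    sigmas=sigmas,
    scales=scales,
    height=scales[0] * 8,  # start at the first latent scale (8x VAE downsampling)
    width=scales[0] * 8,
    guidance_scale=0.0,  # few-step distilled models are typically run without CFG
    generator=torch.Generator().manual_seed(0),
).images[0]
image.save("swd_sample.png")

Note that `height`/`width` only set the starting latent size; the output resolution is determined by the last entry of `scales` (times the VAE downsampling factor), since that is the scale the final denoising steps operate at.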