AlanB committed
Commit e5ecc74 · 1 Parent(s): f4783b6

Upload pipeline.py


From my fork https://github.com/Skquark/diffusers

Files changed (1)
  1. pipeline.py +1148 -0
pipeline.py ADDED
@@ -0,0 +1,1148 @@
import inspect
import re
from typing import Callable, List, Optional, Union

import numpy as np
import torch

import PIL
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, is_accelerate_available, logging

# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }
# ------------------------------------------------------------------------------


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)


def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text
    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith("\\"):
            res.append([text[1:], 1.0])
        elif text == "(":
            round_brackets.append(len(res))
        elif text == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif text == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([text, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res


def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
    r"""
    Tokenize a list of prompts and return its tokens with weights of each token.

    No padding, starting or ending token is included.
    """
    tokens = []
    weights = []
    truncated = False
    for text in prompt:
        texts_and_weights = parse_prompt_attention(text)
        text_token = []
        text_weight = []
        for word, weight in texts_and_weights:
            # tokenize and discard the starting and the ending token
            token = pipe.tokenizer(word).input_ids[1:-1]
            text_token += token
            # copy the weight by length of token
            text_weight += [weight] * len(token)
            # stop if the text is too long (longer than truncation limit)
            if len(text_token) > max_length:
                truncated = True
                break
        # truncate
        if len(text_token) > max_length:
            truncated = True
            text_token = text_token[:max_length]
            text_weight = text_weight[:max_length]
        tokens.append(text_token)
        weights.append(text_weight)
    if truncated:
        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
    return tokens, weights


def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
    r"""
    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
    """
    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
    weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
    for i in range(len(tokens)):
        tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
        if no_boseos_middle:
            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
        else:
            w = []
            if len(weights[i]) == 0:
                w = [1.0] * weights_length
            else:
                for j in range(max_embeddings_multiples):
                    w.append(1.0)  # weight for starting token in this chunk
                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                    w.append(1.0)  # weight for ending token in this chunk
                w += [1.0] * (weights_length - len(w))
            weights[i] = w[:]

    return tokens, weights


def get_unweighted_text_embeddings(
    pipe: DiffusionPipeline,
    text_input: torch.Tensor,
    chunk_length: int,
    no_boseos_middle: Optional[bool] = True,
):
    """
    When the length of tokens is a multiple of the capacity of the text encoder,
    it should be split into chunks and sent to the text encoder individually.
    """
    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
    if max_embeddings_multiples > 1:
        text_embeddings = []
        for i in range(max_embeddings_multiples):
            # extract the i-th chunk
            text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()

            # cover the head and the tail by the starting and the ending tokens
            text_input_chunk[:, 0] = text_input[0, 0]
            text_input_chunk[:, -1] = text_input[0, -1]
            text_embedding = pipe.text_encoder(text_input_chunk)[0]

            if no_boseos_middle:
                if i == 0:
                    # discard the ending token
                    text_embedding = text_embedding[:, :-1]
                elif i == max_embeddings_multiples - 1:
                    # discard the starting token
                    text_embedding = text_embedding[:, 1:]
                else:
                    # discard both starting and ending tokens
                    text_embedding = text_embedding[:, 1:-1]

            text_embeddings.append(text_embedding)
        text_embeddings = torch.concat(text_embeddings, axis=1)
    else:
        text_embeddings = pipe.text_encoder(text_input)[0]
    return text_embeddings


def get_weighted_text_embeddings(
    pipe: DiffusionPipeline,
    prompt: Union[str, List[str]],
    uncond_prompt: Optional[Union[str, List[str]]] = None,
    max_embeddings_multiples: Optional[int] = 1,
    no_boseos_middle: Optional[bool] = False,
    skip_parsing: Optional[bool] = False,
    skip_weighting: Optional[bool] = False,
    **kwargs,
):
    r"""
    Prompts can be assigned local weights using brackets. For example, the
    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
    and the embedding tokens corresponding to those words get multiplied by a constant, 1.1.

    Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.

    Args:
        pipe (`DiffusionPipeline`):
            Pipe to provide access to the tokenizer and the text encoder.
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        uncond_prompt (`str` or `List[str]`):
            The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
            is provided, the embeddings of prompt and uncond_prompt are concatenated.
        max_embeddings_multiples (`int`, *optional*, defaults to `1`):
            The max multiple length of prompt embeddings compared to the max output length of text encoder.
        no_boseos_middle (`bool`, *optional*, defaults to `False`):
            When the tokenized text spans multiple chunks of the text encoder capacity, whether to keep the
            starting and ending tokens of each chunk in the middle.
        skip_parsing (`bool`, *optional*, defaults to `False`):
            Skip the parsing of brackets.
        skip_weighting (`bool`, *optional*, defaults to `False`):
            Skip the weighting. When the parsing is skipped, it is forced True.
    """
    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
    if isinstance(prompt, str):
        prompt = [prompt]

    if not skip_parsing:
        prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
        if uncond_prompt is not None:
            if isinstance(uncond_prompt, str):
                uncond_prompt = [uncond_prompt]
            uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
    else:
        prompt_tokens = [
            token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
        ]
        prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
        if uncond_prompt is not None:
            if isinstance(uncond_prompt, str):
                uncond_prompt = [uncond_prompt]
            uncond_tokens = [
                token[1:-1]
                for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
            ]
            uncond_weights = [[1.0] * len(token) for token in uncond_tokens]

    # round up the longest length of tokens to a multiple of (model_max_length - 2)
    max_length = max([len(token) for token in prompt_tokens])
    if uncond_prompt is not None:
        max_length = max(max_length, max([len(token) for token in uncond_tokens]))

    max_embeddings_multiples = min(
        max_embeddings_multiples,
        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
    )
    max_embeddings_multiples = max(1, max_embeddings_multiples)
    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2

    # pad the length of tokens and weights
    bos = pipe.tokenizer.bos_token_id
    eos = pipe.tokenizer.eos_token_id
    prompt_tokens, prompt_weights = pad_tokens_and_weights(
        prompt_tokens,
        prompt_weights,
        max_length,
        bos,
        eos,
        no_boseos_middle=no_boseos_middle,
        chunk_length=pipe.tokenizer.model_max_length,
    )
    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
    if uncond_prompt is not None:
        uncond_tokens, uncond_weights = pad_tokens_and_weights(
            uncond_tokens,
            uncond_weights,
            max_length,
            bos,
            eos,
            no_boseos_middle=no_boseos_middle,
            chunk_length=pipe.tokenizer.model_max_length,
        )
        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)

    # get the embeddings
    text_embeddings = get_unweighted_text_embeddings(
        pipe,
        prompt_tokens,
        pipe.tokenizer.model_max_length,
        no_boseos_middle=no_boseos_middle,
    )
    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
    if uncond_prompt is not None:
        uncond_embeddings = get_unweighted_text_embeddings(
            pipe,
            uncond_tokens,
            pipe.tokenizer.model_max_length,
            no_boseos_middle=no_boseos_middle,
        )
        uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)

    # assign weights to the prompts and normalize in the sense of mean
    # TODO: should we normalize by chunk or in a whole (current implementation)?
    if (not skip_parsing) and (not skip_weighting):
        previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
        text_embeddings *= prompt_weights.unsqueeze(-1)
        current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
        text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
        if uncond_prompt is not None:
            previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
            uncond_embeddings *= uncond_weights.unsqueeze(-1)
            current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
            uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

    if uncond_prompt is not None:
        return text_embeddings, uncond_embeddings
    return text_embeddings, None
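
# Illustrative sketch only (not part of the original control flow): assuming `pipe` is an
# already-loaded pipeline exposing `tokenizer` and `text_encoder`, a weighted prompt could be
# embedded roughly like
#     cond, uncond = get_weighted_text_embeddings(pipe, "a (red:1.3) car", "blurry, low quality")
# where `cond` and `uncond` are the conditional and unconditional embeddings later used for
# classifier-free guidance.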


def preprocess_image(image):
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0


def preprocess_mask(mask):
    mask = mask.convert("L")
    w, h = mask.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
    mask = np.array(mask).astype(np.float32) / 255.0
    mask = np.tile(mask, (4, 1, 1))
    mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
    mask = 1 - mask  # repaint white, keep black
    mask = torch.from_numpy(mask)
    return mask


class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for
    parsing weights in the prompt.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    def enable_xformers_memory_efficient_attention(self):
        r"""
        Enable memory efficient attention as implemented in xformers.

        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
        time. Speed up at training time is not guaranteed.

        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
        is used.
        """
        self.unet.set_use_memory_efficient_attention_xformers(True)

    def disable_xformers_memory_efficient_attention(self):
        r"""
        Disable memory efficient attention as implemented in xformers.
        """
        self.unet.set_use_memory_efficient_attention_xformers(False)

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    def enable_sequential_cpu_offload(self):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = self.device

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        init_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        strength: float = 0.8,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        max_embeddings_multiples: Optional[int] = 3,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        is_cancelled_callback: Optional[Callable[[], bool]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            init_image (`torch.FloatTensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process.
            mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            strength (`float`, *optional*, defaults to 0.8):
                Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
                `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
                noise will be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                The max multiple length of prompt embeddings compared to the max output length of text encoder.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            is_cancelled_callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. If the function returns
                `True`, the inference will be cancelled.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            `None` if cancelled by `is_cancelled_callback`,
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """

        if isinstance(prompt, str):
            batch_size = 1
            prompt = [prompt]
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        # get prompt text embeddings

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if negative_prompt is None:
            negative_prompt = [""] * batch_size
        elif isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * batch_size
        if batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )

        text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
            pipe=self,
            prompt=prompt,
            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
            max_embeddings_multiples=max_embeddings_multiples,
            **kwargs,
        )
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            bs_embed, seq_len, _ = uncond_embeddings.shape
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        latents_dtype = text_embeddings.dtype
        init_latents_orig = None
        mask = None
        noise = None

        if init_image is None:
            # get the initial random noise unless the user supplied it

            # Unlike in other pipelines, latents need to be generated in the target device
            # for 1-to-1 results reproducibility with the CompVis implementation.
            # However this currently doesn't work in `mps`.
            latents_shape = (
                batch_size * num_images_per_prompt,
                self.unet.in_channels,
                height // 8,
                width // 8,
            )

            if latents is None:
                if self.device.type == "mps":
                    # randn does not exist on mps
                    latents = torch.randn(
                        latents_shape,
                        generator=generator,
                        device="cpu",
                        dtype=latents_dtype,
                    ).to(self.device)
                else:
                    latents = torch.randn(
                        latents_shape,
                        generator=generator,
                        device=self.device,
                        dtype=latents_dtype,
                    )
            else:
                if latents.shape != latents_shape:
                    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
                latents = latents.to(self.device)

            timesteps = self.scheduler.timesteps.to(self.device)

            # scale the initial noise by the standard deviation required by the scheduler
            latents = latents * self.scheduler.init_noise_sigma
        else:
            if isinstance(init_image, PIL.Image.Image):
                init_image = preprocess_image(init_image)
            # encode the init image into latents and scale the latents
            init_image = init_image.to(device=self.device, dtype=latents_dtype)
            init_latent_dist = self.vae.encode(init_image).latent_dist
            init_latents = init_latent_dist.sample(generator=generator)
            init_latents = 0.18215 * init_latents
            init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
            init_latents_orig = init_latents

            # preprocess mask
            if mask_image is not None:
                if isinstance(mask_image, PIL.Image.Image):
                    mask_image = preprocess_mask(mask_image)
                mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
                mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)

                # check sizes
                if not mask.shape == init_latents.shape:
                    raise ValueError("The mask and init_image should be the same size!")

            # get the original timestep using init_timestep
            offset = self.scheduler.config.get("steps_offset", 0)
            init_timestep = int(num_inference_steps * strength) + offset
            init_timestep = min(init_timestep, num_inference_steps)

            timesteps = self.scheduler.timesteps[-init_timestep]
            timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)

            # add noise to latents using the timesteps
            if self.device.type == "mps":
                # randn does not exist on mps
                noise = torch.randn(
                    init_latents.shape,
                    generator=generator,
                    device="cpu",
                    dtype=latents_dtype,
                ).to(self.device)
            else:
                noise = torch.randn(
                    init_latents.shape,
                    generator=generator,
                    device=self.device,
                    dtype=latents_dtype,
                )
            latents = self.scheduler.add_noise(init_latents, noise, timesteps)

            t_start = max(num_inference_steps - init_timestep + offset, 0)
            timesteps = self.scheduler.timesteps[t_start:].to(self.device)

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            if mask is not None:
                # masking
                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
                latents = (init_latents_proper * mask) + (latents * (1 - mask))

            # call the callback, if provided
            if i % callback_steps == 0:
                if callback is not None:
                    callback(i, t, latents)
                if is_cancelled_callback is not None and is_cancelled_callback():
                    return None

        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
                self.device
            )
            image, has_nsfw_concept = self.safety_checker(
                images=image,
                clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype),
            )
        else:
            has_nsfw_concept = None

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

    def text2img(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        max_embeddings_multiples: Optional[int] = 3,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        r"""
        Function for text-to-image generation.
        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                The max multiple length of prompt embeddings compared to the max output length of text encoder.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        return self.__call__(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            max_embeddings_multiples=max_embeddings_multiples,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

    def img2img(
        self,
        init_image: Union[torch.FloatTensor, PIL.Image.Image],
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        max_embeddings_multiples: Optional[int] = 3,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        r"""
        Function for image-to-image generation.
        Args:
            init_image (`torch.FloatTensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process.
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            strength (`float`, *optional*, defaults to 0.8):
                Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
                `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
                noise will be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. This parameter will be modulated by `strength`.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                The max multiple length of prompt embeddings compared to the max output length of text encoder.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        return self.__call__(
            prompt=prompt,
            negative_prompt=negative_prompt,
            init_image=init_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            strength=strength,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            max_embeddings_multiples=max_embeddings_multiples,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

    def inpaint(
        self,
        init_image: Union[torch.FloatTensor, PIL.Image.Image],
        mask_image: Union[torch.FloatTensor, PIL.Image.Image],
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        max_embeddings_multiples: Optional[int] = 3,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        r"""
        Function for inpainting.
        Args:
            init_image (`torch.FloatTensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process. This is the image whose masked region will be inpainted.
            mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            strength (`float`, *optional*, defaults to 0.8):
                Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
                is 1, the denoising process will be run on the masked area for the full number of iterations specified
                in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
                noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
                the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                The max multiple length of prompt embeddings compared to the max output length of text encoder.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        return self.__call__(
            prompt=prompt,
            negative_prompt=negative_prompt,
            init_image=init_image,
            mask_image=mask_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            strength=strength,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            max_embeddings_multiples=max_embeddings_multiples,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )
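
A minimal usage sketch for this custom pipeline (not part of the uploaded file; the checkpoint id, prompt text, and output filename below are placeholder assumptions, and `custom_pipeline` can instead point at the repo or local folder that contains this pipeline.py):

import torch
from diffusers import DiffusionPipeline

# Load a Stable Diffusion checkpoint with this long-prompt-weighting pipeline as the custom pipeline.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",         # placeholder model id
    custom_pipeline="lpw_stable_diffusion",  # or the path/repo containing this pipeline.py
    torch_dtype=torch.float16,
).to("cuda")

# Weighted, longer-than-77-token prompts are handled by parse_prompt_attention /
# get_weighted_text_embeddings inside the pipeline; img2img() and inpaint() work the same way.
result = pipe.text2img(
    "a (photorealistic:1.2) portrait of an astronaut, (sharp focus:1.1), detailed background",
    negative_prompt="blurry, low quality",
    width=512,
    height=512,
    num_inference_steps=50,
    max_embeddings_multiples=3,
)
result.images[0].save("astronaut.png")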