animrods committed
Commit 2122032 · verified · 1 Parent(s): 8a0a37b

Delete bria_utils.py

Files changed (1)
  1. bria_utils.py +0 -302
bria_utils.py DELETED
@@ -1,302 +0,0 @@
-from typing import Union, Optional, List
-import torch
-from diffusers.utils import logging
-from transformers import (
-    T5EncoderModel,
-    T5TokenizerFast,
-)
-from transformers import (
-    CLIPTextModel,
-    CLIPTextModelWithProjection,
-    CLIPTokenizer
-)
-
-import numpy as np
-import torch.distributed as dist
-import math
-import os
-
-logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-
-def get_t5_prompt_embeds(
-    tokenizer: T5TokenizerFast,
-    text_encoder: T5EncoderModel,
-    prompt: Union[str, List[str]] = None,
-    num_images_per_prompt: int = 1,
-    max_sequence_length: int = 128,
-    device: Optional[torch.device] = None,
-):
-    device = device or text_encoder.device
-
-    prompt = [prompt] if isinstance(prompt, str) else prompt
-    batch_size = len(prompt)
-
-    text_inputs = tokenizer(
-        prompt,
-        # padding="max_length",
-        max_length=max_sequence_length,
-        truncation=True,
-        add_special_tokens=True,
-        return_tensors="pt",
-    )
-    text_input_ids = text_inputs.input_ids
-    untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-        removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
-        logger.warning(
-            "The following part of your input was truncated because `max_sequence_length` is set to "
-            f" {max_sequence_length} tokens: {removed_text}"
-        )
-
-    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
-
-    # Concat zeros to max_sequence
-    b, seq_len, dim = prompt_embeds.shape
-    if seq_len < max_sequence_length:
-        padding = torch.zeros((b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device)
-        prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)
-
-    prompt_embeds = prompt_embeds.to(device=device)
-
-    _, seq_len, _ = prompt_embeds.shape
-
-    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-    return prompt_embeds
-
-# in order to get the same sigmas as in training and sample from them
-def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000):
-    timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
-    sigmas = timesteps / num_train_timesteps
-
-    inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)]
-    new_sigmas = sigmas[inds]
-    return new_sigmas
-
-def is_ng_none(negative_prompt):
-    return negative_prompt is None or negative_prompt == '' or (isinstance(negative_prompt, list) and negative_prompt[0] is None) or (type(negative_prompt) == list and negative_prompt[0] == '')
-
-class CudaTimerContext:
-    def __init__(self, times_arr):
-        self.times_arr = times_arr
-
-    def __enter__(self):
-        self.before_event = torch.cuda.Event(enable_timing=True)
-        self.after_event = torch.cuda.Event(enable_timing=True)
-        self.before_event.record()
-
-    def __exit__(self, type, value, traceback):
-        self.after_event.record()
-        torch.cuda.synchronize()
-        elapsed_time = self.before_event.elapsed_time(self.after_event) / 1000
-        self.times_arr.append(elapsed_time)
-
-
-def get_env_prefix():
-    env = os.environ.get("CLOUD_PROVIDER", 'AWS').upper()
-    if env == 'AWS':
-        return 'SM_CHANNEL'
-    elif env == 'AZURE':
-        return 'AZUREML_DATAREFERENCE'
-
-    raise Exception(f'Env {env} not supported')
-
-
-def compute_density_for_timestep_sampling(
-    weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
-):
-    """Compute the density for sampling the timesteps when doing SD3 training.
-
-    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
-
-    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
-    """
-    if weighting_scheme == "logit_normal":
-        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
-        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
-        u = torch.nn.functional.sigmoid(u)
-    elif weighting_scheme == "mode":
-        u = torch.rand(size=(batch_size,), device="cpu")
-        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
-    else:
-        u = torch.rand(size=(batch_size,), device="cpu")
-    return u
-
-def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
-    """Computes loss weighting scheme for SD3 training.
-
-    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
-
-    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
-    """
-    if weighting_scheme == "sigma_sqrt":
-        weighting = (sigmas**-2.0).float()
-    elif weighting_scheme == "cosmap":
-        bot = 1 - 2 * sigmas + 2 * sigmas**2
-        weighting = 2 / (math.pi * bot)
-    else:
-        weighting = torch.ones_like(sigmas)
-    return weighting
-
-
-def initialize_distributed():
-    # Initialize the process group for distributed training
-    dist.init_process_group('nccl')
-
-    # Get the current process's rank (ID) and the total number of processes (world size)
-    rank = dist.get_rank()
-    world_size = dist.get_world_size()
-
-    print(f"Initialized distributed training: Rank {rank}/{world_size}")
-
-
-def get_clip_prompt_embeds(
-    text_encoder: CLIPTextModel,
-    text_encoder_2: CLIPTextModelWithProjection,
-    tokenizer: CLIPTokenizer,
-    tokenizer_2: CLIPTokenizer,
-    prompt: Union[str, List[str]] = None,
-    num_images_per_prompt: int = 1,
-    max_sequence_length: int = 77,
-    device: Optional[torch.device] = None,
-):
-
-    device = device or text_encoder.device
-    assert max_sequence_length == tokenizer.model_max_length
-    prompt = [prompt] if isinstance(prompt, str) else prompt
-
-    # Define tokenizers and text encoders
-    tokenizers = [tokenizer, tokenizer_2]
-    text_encoders = [text_encoder, text_encoder_2]
-
-    # textual inversion: process multi-vector tokens if necessary
-    prompt_embeds_list = []
-    prompts = [prompt, prompt]
-    for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
-        text_inputs = tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-
-        text_input_ids = text_inputs.input_ids
-        prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True)
-
-        # We are only ALWAYS interested in the pooled output of the final text encoder
-        pooled_prompt_embeds = prompt_embeds[0]
-        prompt_embeds = prompt_embeds.hidden_states[-2]
-
-        prompt_embeds_list.append(prompt_embeds)
-
-    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
-
-
-    bs_embed, seq_len, _ = prompt_embeds.shape
-    # duplicate text embeddings for each generation per prompt, using mps friendly method
-    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
-        bs_embed * num_images_per_prompt, -1
-    )
-
-    return prompt_embeds, pooled_prompt_embeds
-
-def get_1d_rotary_pos_embed(
-    dim: int,
-    pos: Union[np.ndarray, int],
-    theta: float = 10000.0,
-    use_real=False,
-    linear_factor=1.0,
-    ntk_factor=1.0,
-    repeat_interleave_real=True,
-    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
-):
-    """
-    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
-
-    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
-    index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
-    data type.
-
-    Args:
-        dim (`int`): Dimension of the frequency tensor.
-        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
-        theta (`float`, *optional*, defaults to 10000.0):
-            Scaling factor for frequency computation. Defaults to 10000.0.
-        use_real (`bool`, *optional*):
-            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
-        linear_factor (`float`, *optional*, defaults to 1.0):
-            Scaling factor for the context extrapolation. Defaults to 1.0.
-        ntk_factor (`float`, *optional*, defaults to 1.0):
-            Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
-        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
-            If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
-            Otherwise, they are concatenated with themselves.
-        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
-            the dtype of the frequency tensor.
-    Returns:
-        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
-    """
-    assert dim % 2 == 0
-
-    if isinstance(pos, int):
-        pos = torch.arange(pos)
-    if isinstance(pos, np.ndarray):
-        pos = torch.from_numpy(pos)  # type: ignore  # [S]
-
-    theta = theta * ntk_factor
-    freqs = (
-        1.0
-        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
-        / linear_factor
-    )  # [D/2]
-    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
-    if use_real and repeat_interleave_real:
-        # flux, hunyuan-dit, cogvideox
-        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
-        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
-        return freqs_cos, freqs_sin
-    elif use_real:
-        # stable audio, allegro
-        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
-        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
-        return freqs_cos, freqs_sin
-    else:
-        # lumina
-        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
-        return freqs_cis
-
-
-class FluxPosEmbed(torch.nn.Module):
-    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
-    def __init__(self, theta: int, axes_dim: List[int]):
-        super().__init__()
-        self.theta = theta
-        self.axes_dim = axes_dim
-
-    def forward(self, ids: torch.Tensor) -> torch.Tensor:
-        n_axes = ids.shape[-1]
-        cos_out = []
-        sin_out = []
-        pos = ids.float()
-        is_mps = ids.device.type == "mps"
-        freqs_dtype = torch.float32 if is_mps else torch.float64
-        for i in range(n_axes):
-            cos, sin = get_1d_rotary_pos_embed(
-                self.axes_dim[i],
-                pos[:, i],
-                theta=self.theta,
-                repeat_interleave_real=True,
-                use_real=True,
-                freqs_dtype=freqs_dtype,
-            )
-            cos_out.append(cos)
-            sin_out.append(sin)
-        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
-        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
-        return freqs_cos, freqs_sin