tyfeld committed
Commit ea359a8 · 1 Parent(s): 2f15a78
app.py ADDED
@@ -0,0 +1,871 @@
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import torch.nn.functional as F
5
+ from transformers import AutoTokenizer
6
+ from torchvision import transforms
7
+ from models import MAGVITv2, get_mask_schedule, MMadaModelLM
8
+ from training.prompting_utils import UniversalPrompting
9
+ from PIL import Image
10
+
11
+ def image_transform(image, resolution=256, normalize=True):
12
+ image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BICUBIC)(image)
13
+ image = transforms.CenterCrop((resolution, resolution))(image)
14
+ image = transforms.ToTensor()(image)
15
+ if normalize:
16
+ image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image)
17
+ return image
18
+
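A minimal usage sketch (assuming the imports at the top of this file): the transform resizes, center-crops, and normalizes a PIL image into a (3, resolution, resolution) tensor with values roughly in [-1, 1].

    img = Image.new("RGB", (640, 480), (128, 128, 128))   # placeholder image for illustration
    t = image_transform(img, resolution=256)
    print(t.shape)                                         # torch.Size([3, 256, 256])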
19
+ def add_gumbel_noise(logits, temperature):
20
+ """
21
+ Adds Gumbel noise to logits for stochastic sampling.
22
+ Equivalent to argmax(logits + temperature * G) where G ~ Gumbel(0,1).
23
+ This version is more numerically stable than a version involving exp() and division.
24
+ """
25
+ if abs(temperature) < 1e-9: # Effectively zero temperature
26
+ return logits
27
+ # Use float64 for numerical precision when adding Gumbel noise
28
+ logits = logits.to(torch.float64)
29
+ # Standard Gumbel noise: -log(-log(U)), U ~ Uniform(0,1)
30
+ # Add small epsilon for numerical stability inside logs
31
+ noise = torch.rand_like(logits, dtype=torch.float64)
32
+ standard_gumbel_noise = -torch.log(-torch.log(noise + 1e-20) + 1e-20)
33
+ return logits + temperature * standard_gumbel_noise
34
+
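A minimal sampling sketch (assuming the imports above): adding Gumbel noise and taking the argmax draws a token from the temperature-scaled softmax, while temperature 0 reduces to a plain greedy argmax.

    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]], dtype=torch.float64)
    sampled = torch.argmax(add_gumbel_noise(logits, temperature=1.0), dim=-1)  # stochastic pick
    greedy = torch.argmax(add_gumbel_noise(logits, temperature=0.0), dim=-1)   # deterministic argmax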
35
+ def get_num_transfer_tokens(mask_index, steps):
36
+ mask_num = mask_index.sum(dim=1, keepdim=True)
37
+ # Ensure steps is at least 1 to avoid division by zero if mask_num is also 0 (though sum should be >=0)
38
+ steps = max(1, int(steps)) # Ensure steps is a positive integer
39
+ base = mask_num // steps
40
+ remainder = mask_num % steps
41
+ num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.long) + base
42
+ for i in range(mask_num.size(0)): # Iterate over batch
43
+ if remainder[i] > 0 : # Ensure remainder is positive before indexing
44
+ num_transfer_tokens[i, :remainder[i].item()] += 1 # .item() for single value tensor to int
45
+ return num_transfer_tokens
46
+
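A quick sketch of the per-step schedule this produces: masked positions are split as evenly as possible across the steps, with the remainder front-loaded.

    mask_index = torch.tensor([[True] * 10 + [False] * 6])   # 10 masked positions
    print(get_num_transfer_tokens(mask_index, steps=4))      # tensor([[3, 3, 2, 2]])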
47
+ MODEL = None
48
+ TOKENIZER = None
49
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
50
+ MASK_ID = None
51
+ uni_prompting = None
52
+ VQ_MODEL = MAGVITv2().from_pretrained("/data_storage/shared/pretrained_models/models--showlab--magvitv2").to(DEVICE)
53
+
54
+ DEFAULT_MODEL_PATH = "/data_storage/lbw/MMaDA/mmada-training-stage3-llada-instruct-512-cot-uni/checkpoint-210000/unwrapped_model" # Default
55
+ CURRENT_MODEL_PATH = None
56
+
57
+ MODEL_CHOICES = [
58
+ "MMaDA-8B-Base",
59
+ "MMaDA-8B-MixCoT (coming soon)",
60
+ "MMaDA-8B-Max (coming soon)"
61
+ ]
62
+ MODEL_ACTUAL_PATHS = {
63
+ "MMaDA-8B-Base": DEFAULT_MODEL_PATH,
64
+ }
65
+
66
+ def clear_outputs_action():
67
+ return None, None
68
+
69
+ def _load_model_and_tokenizer_core(model_path_to_load, model_display_name_for_status):
70
+ global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH, DEVICE, uni_prompting
71
+
72
+ if MODEL is not None and CURRENT_MODEL_PATH == model_path_to_load:
73
+ return f"Model '{model_display_name_for_status}' from '{model_path_to_load}' is already loaded. MASK_ID: {MASK_ID}"
74
+
75
+ CURRENT_MODEL_PATH = model_path_to_load
76
+
77
+ status_msg_parts = [f"Loading '{model_display_name_for_status}'..."]
78
+ try:
79
+ TOKENIZER = AutoTokenizer.from_pretrained(model_path_to_load, trust_remote_code=True)
80
+ status_msg_parts.append(f"Tokenizer for '{model_display_name_for_status}' loaded.")
81
+
82
+ MODEL = MMadaModelLM.from_pretrained(model_path_to_load, trust_remote_code=True, torch_dtype=torch.bfloat16).to(DEVICE).eval()
83
+ status_msg_parts.append(f"Model '{model_display_name_for_status}' loaded to {DEVICE}.")
84
+
85
+ uni_prompting = UniversalPrompting(TOKENIZER, max_text_len=512, special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=True)
86
+
87
+ if hasattr(TOKENIZER, 'mask_token_id') and TOKENIZER.mask_token_id is not None:
88
+ MASK_ID = TOKENIZER.mask_token_id
89
+ status_msg_parts.append(f"Using MASK_ID from tokenizer: {MASK_ID}.")
90
+ else:
91
+ MASK_ID = 126336
92
+ status_msg_parts.append(f"Using default MASK_ID: {MASK_ID}.")
93
+
94
+ if TOKENIZER.pad_token_id is None:
95
+ if TOKENIZER.eos_token_id is not None:
96
+ TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
97
+ TOKENIZER.pad_token = TOKENIZER.eos_token
98
+ status_msg_parts.append(f"Set pad_token_id to eos_token_id ({TOKENIZER.eos_token_id}).")
99
+ else:
100
+ status_msg_parts.append("Warning: pad_token_id is None and no eos_token_id.")
101
+
102
+ if TOKENIZER.eos_token_id is None: # Important for cleaning up output in visualization
103
+ status_msg_parts.append("Warning: tokenizer.eos_token_id is None. EOS cleanup might not work.")
104
+
105
+ TOKENIZER.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n' }}"
106
+
107
+ return " ".join(status_msg_parts)
108
+ except Exception as e:
109
+ MODEL = None
110
+ TOKENIZER = None
111
+ MASK_ID = None
112
+ CURRENT_MODEL_PATH = None
113
+ return f"Error loading model '{model_display_name_for_status}': {str(e)}"
114
+
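For reference, a sketch of the prompt string produced by the chat template assigned above (hypothetical question; `<bos>` stands for the tokenizer's actual BOS token):

    msgs = [{"role": "user", "content": "What is 2 + 2?"}]
    text = TOKENIZER.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
    # -> "<bos><|start_header_id|>user<|end_header_id|>\nWhat is 2 + 2?<|eot_id|>"
    #    "<|start_header_id|>assistant<|end_header_id|>\n"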
115
+ def handle_model_selection_change(selected_model_name_ui):
116
+ if "coming soon" in selected_model_name_ui.lower():
117
+ global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH
118
+ MODEL = None
119
+ TOKENIZER = None
120
+ MASK_ID = None
121
+ CURRENT_MODEL_PATH = None
122
+ return f"'{selected_model_name_ui}' is not yet available. Please select 'Model A'."
123
+
124
+ actual_path = MODEL_ACTUAL_PATHS.get(selected_model_name_ui)
125
+ if not actual_path:
126
+ return f"Path for '{selected_model_name_ui}' is not defined. Cannot load."
127
+
128
+ return _load_model_and_tokenizer_core(actual_path, selected_model_name_ui)
129
+
130
+
131
+ def get_highlighted_text_tuples(current_x_ids_batch, prompt_input_ids, prompt_len, tk, current_mask_id, raw_prompt_attention_mask):
132
+ if current_x_ids_batch is None or current_x_ids_batch.ndim == 0 or current_x_ids_batch.shape[0] == 0:
133
+ return [("Error in sequence data for visualization.", "ERROR")]
134
+ # Keep only the generated (answer) part; drop the prompt tokens
135
+ current_x_ids_batch = current_x_ids_batch[:, prompt_len:]
136
+ seq_ids = current_x_ids_batch[0].tolist()
137
+ eos_token_id = tk.eos_token_id # Get EOS token ID
138
+
139
+ # Stage 1: Build initial list of tuples with (token_str, label, token_id_int)
140
+ # This helps in identifying EOS tokens later without re-checking the type.
141
+ intermediate_tuples = []
142
+ for j, token_id_int in enumerate(seq_ids):
143
+ try:
144
+ token_str = tk.decode([token_id_int], skip_special_tokens=True, clean_up_tokenization_spaces=False)
145
+ except Exception: # Handle cases where a token ID might be problematic (e.g. with mock)
146
+ token_str = f"[ID:{token_id_int}]"
147
+
148
+ label = "ERROR"
149
+ if token_id_int == current_mask_id:
150
+ token_str = "[MASK]"
151
+ label = "MASK"
152
+ else:
153
+ label = "GEN"
154
+ intermediate_tuples.append((token_str, label, token_id_int))
155
+
156
+ return intermediate_tuples
157
+
158
+ @torch.no_grad()
159
+ def generate_viz_wrapper_t2i(prompt_text, steps, guidance_scale, mask_schedule="cosine"):
160
+ global MODEL, TOKENIZER, MASK_ID, DEVICE, uni_prompting
161
+
162
+ if MODEL is None or TOKENIZER is None or MASK_ID is None:
163
+ yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
164
+ return
165
+ steps = int(steps)
166
+ guidance_scale = float(guidance_scale)
167
+
168
+ image_tokens = torch.ones((1, 1024), dtype=torch.long, device=DEVICE) * MASK_ID
169
+ prompt_text = [prompt_text]
170
+ input_ids, attention_mask = uni_prompting((prompt_text, image_tokens), 't2i_gen')
171
+
172
+ if guidance_scale > 0:
173
+ uncond_input_ids, uncond_attention_mask = uni_prompting(([''], image_tokens), 't2i_gen')
174
+ else:
175
+ uncond_input_ids, uncond_attention_mask = None, None
176
+
177
+ mask_schedule = get_mask_schedule(mask_schedule)
178
+ blank_image = Image.new("RGB", (512, 512), (255, 255, 255))
179
+ yield blank_image, "Starting generation..."
180
+ for image_step, status_msg_step in MODEL.t2i_generate_decoding_stepwise(
181
+ input_ids = input_ids,
182
+ uncond_input_ids = uncond_input_ids,
183
+ attention_mask = attention_mask,
184
+ uncond_attention_mask = uncond_attention_mask,
185
+ temperature=1.0,
186
+ timesteps = steps,
187
+ guidance_scale = guidance_scale,
188
+ noise_schedule = mask_schedule,
189
+ noise_type = "mask",
190
+ seq_len = 1024,
191
+ vq_model = VQ_MODEL,
192
+ uni_prompting=uni_prompting):
193
+ yield image_step, status_msg_step
194
+
195
+
196
+
197
+
198
+ @torch.no_grad()
199
+ def generate_viz_wrapper_lm(prompt_text, steps, gen_length, block_length, temperature,
200
+ cfg_scale, remasking_strategy, thinking_mode_lm):
201
+ global MODEL, TOKENIZER, MASK_ID, DEVICE
202
+ print(f"thinking_mode_lm: {thinking_mode_lm}")
203
+ if MODEL is None or TOKENIZER is None or MASK_ID is None:
204
+ yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
205
+ return
206
+
207
+ steps = int(steps)
208
+ gen_length = int(gen_length)
209
+ block_length = int(block_length)
210
+
211
+ if thinking_mode_lm:
212
+ prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" + prompt_text
213
+
214
+ try:
215
+ m = [{"role": "user", "content": prompt_text}]
216
+ processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
217
+ except Exception as e:
218
+ yield [("Error applying chat template.", "ERROR")], f"Chat template error: {e}"
219
+ processed_prompt_text = prompt_text
220
+ try:
221
+ if TOKENIZER.pad_token_id is None:
222
+ if TOKENIZER.eos_token_id is not None:
223
+ TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
224
+ else: # Should have been caught by load_model, but double check
225
+ yield [("Tokenizer Error", "ERROR")], "pad_token_id is not set in tokenizer."
226
+ return
227
+
228
+ input_ids = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=MODEL.config.max_position_embeddings if hasattr(MODEL.config, 'max_position_embeddings') else 2048)['input_ids'].to(DEVICE)
229
+ raw_prompt_attention_mask = None
230
+
231
+ except Exception as e:
232
+ yield [("Error tokenizing prompt.", "ERROR")], f"Tokenization error: {e}"
233
+ return
234
+
235
+
236
+
237
+ batch_size = input_ids.shape[0]
238
+ prompt_len = input_ids.shape[1]
239
+
240
+ x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE)
241
+ x[:, :prompt_len] = input_ids.clone()
242
+
243
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation: Prompt + Initial Masks"
244
+
245
+ if gen_length == 0:
246
+ final_text_output = TOKENIZER.batch_decode(x[:,prompt_len:], skip_special_tokens=True)
247
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0] if final_text_output else ""
248
+ return
249
+
250
+ if block_length <= 0 or gen_length % block_length != 0 :
251
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
252
+ f"Error: gen_length ({gen_length}) must be divisible by block_length ({block_length}) and block_length > 0."
253
+ return
254
+ num_blocks = gen_length // block_length
255
+
256
+ if steps <=0 or steps % num_blocks != 0:
257
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
258
+ f"Error: steps ({steps}) must be positive and divisible by num_blocks ({num_blocks}). Steps: {steps}, Num Blocks: {num_blocks}"
259
+ return
260
+ steps_per_block = steps // num_blocks
261
+
262
+ for num_block_iter in range(num_blocks):
263
+ current_block_start_idx_in_x = prompt_len + num_block_iter * block_length
264
+ current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length
265
+
266
+ block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool)
267
+ block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = \
268
+ (x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID)
269
+
270
+ num_transfer_tokens_for_this_block = get_num_transfer_tokens(
271
+ block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x],
272
+ steps_per_block
273
+ )
274
+
275
+ for i_step_in_block in range(steps_per_block):
276
+ mask_index_global = (x == MASK_ID)
277
+
278
+ if cfg_scale > 0.:
279
+ un_x = x.clone()
280
+ # For unconditional pass, mask out the original prompt tokens that are not padding
281
+ # raw_prompt_attention_mask is (B, prompt_len)
282
+ prompt_active_tokens_mask = raw_prompt_attention_mask.bool() # True where actual prompt tokens are
283
+ un_x[:, :prompt_len][prompt_active_tokens_mask] = MASK_ID
284
+
285
+ x_cfg_input = torch.cat([x, un_x], dim=0)
286
+ # Pass attention_mask for CFG if model expects it, covering both parts
287
+ # For simplicity, not passing explicit attention_mask here; relies on model's internal handling.
288
+ model_output = MODEL(x_cfg_input)
289
+ logits_cond, logits_uncond = torch.chunk(model_output.logits, 2, dim=0)
290
+ logits = logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond)
291
+ else:
292
+ # Not passing explicit attention_mask here; relies on model's internal handling.
293
+ model_output = MODEL(x)
294
+ logits = model_output.logits
295
+
296
+ logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
297
+ x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
298
+
299
+ if remasking_strategy == 'low_confidence':
300
+ probs = F.softmax(logits.to(torch.float64), dim=-1)
301
+ x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1)
302
+ elif remasking_strategy == 'random':
303
+ x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float64)
304
+ else:
305
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), f"Error: Unknown remasking strategy '{remasking_strategy}'"
306
+ return
307
+
308
+ confidence_for_selection = torch.full_like(x0_probs, -torch.inf)
309
+ candidate_positions_for_unmasking = mask_index_global & block_masks_bool_current
310
+ confidence_for_selection = torch.where(
311
+ candidate_positions_for_unmasking,
312
+ x0_probs,
313
+ -torch.inf
314
+ )
315
+
316
+ x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x)
317
+
318
+ transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool)
319
+ num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block]
320
+
321
+ for j_batch_idx in range(batch_size):
322
+ k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(),
323
+ candidate_positions_for_unmasking[j_batch_idx].sum().item()) # ensure k isn't too large
324
+
325
+ if k_val > 0:
326
+ # Ensure confidence_for_selection[j_batch_idx] is 1D for topk
327
+ conf_slice = confidence_for_selection[j_batch_idx]
328
+ if conf_slice.ndim > 1: conf_slice = conf_slice.view(-1) # Should already be 1D from x0_probs
329
+
330
+ # Check if there are enough valid (non -inf) confidences
331
+ valid_conf_count = (conf_slice > -torch.inf).sum().item()
332
+ actual_k = min(k_val, valid_conf_count)
333
+
334
+ if actual_k > 0:
335
+ _, topk_indices_in_x = torch.topk(conf_slice, k=actual_k)
336
+ transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True
337
+
338
+ x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool]
339
+
340
+ current_total_step = num_block_iter * steps_per_block + i_step_in_block + 1
341
+ total_overall_steps = num_blocks * steps_per_block
342
+ status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block} (Total: {current_total_step}/{total_overall_steps})"
343
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg
344
+
345
+ final_generated_ids = x[:, prompt_len:]
346
+ final_text_output = TOKENIZER.batch_decode(final_generated_ids, skip_special_tokens=True)
347
+
348
+ final_text_str = final_text_output[0] if final_text_output and len(final_text_output) > 0 else ""
349
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_str
350
+
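The divisibility checks above, with the UI defaults, work out as in this small sketch:

    gen_length, block_length, steps = 512, 128, 256
    num_blocks = gen_length // block_length     # 4 blocks of 128 tokens each
    steps_per_block = steps // num_blocks       # 64 denoising steps per block
    assert gen_length % block_length == 0 and steps % num_blocks == 0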
351
+ @torch.no_grad()
352
+ def generate_viz_wrapper(uploaded_image_pil, prompt_text, steps, gen_length, block_length, temperature,
353
+ cfg_scale, remasking_strategy, thinking_mode_mmu):
354
+ global MODEL, TOKENIZER, MASK_ID, DEVICE
355
+
356
+ if MODEL is None or TOKENIZER is None or MASK_ID is None:
357
+ yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
358
+ return
359
+
360
+ steps = int(steps)
361
+ gen_length = int(gen_length)
362
+ block_length = int(block_length)
363
+
364
+ if thinking_mode_mmu:
365
+ prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" + prompt_text
366
+
367
+ try:
368
+ m = [{"role": "user", "content": prompt_text}]
369
+ processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
370
+ except Exception as e:
371
+ yield [("Error applying chat template.", "ERROR")], f"Chat template error: {e}"
372
+ processed_prompt_text = prompt_text
373
+
374
+ image_vq_ids_tensor = None
375
+ if uploaded_image_pil is not None:
376
+ try:
377
+
378
+ image = image_transform(uploaded_image_pil, resolution=512).to(DEVICE)
379
+ image = image.unsqueeze(0)
380
+ image_vq_ids_tensor = VQ_MODEL.get_code(image) + 126349
381
+ except Exception as e:
382
+ yield [("Error processing image.", "ERROR")], f"Image to VQ tokens conversion failed: {str(e)}"
383
+ return
384
+
385
+
386
+ try:
387
+ if TOKENIZER.pad_token_id is None:
388
+ if TOKENIZER.eos_token_id is not None:
389
+ TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
390
+ else:
391
+ yield [("Tokenizer Error", "ERROR")], "pad_token_id is not set in tokenizer."
392
+ return
393
+
394
+ input_ids = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=MODEL.config.max_position_embeddings if hasattr(MODEL.config, 'max_position_embeddings') else 2048)['input_ids'].to(DEVICE)
395
+ raw_prompt_attention_mask = None
396
+ if image_vq_ids_tensor is not None:
397
+ if image_vq_ids_tensor.ndim == 1:
398
+ image_vq_ids_tensor = image_vq_ids_tensor.unsqueeze(0)
399
+
400
+ input_ids = torch.cat([
401
+ (torch.ones(input_ids.shape[0], 1) * torch.tensor([126089])).to(DEVICE),
402
+ (torch.ones(input_ids.shape[0], 1) * torch.tensor([126084])).to(DEVICE),
403
+ image_vq_ids_tensor,
404
+ (torch.ones(input_ids.shape[0], 1) * torch.tensor([126085])).to(DEVICE),
405
+ input_ids
406
+ ], dim=1).long()
407
+
408
+ else:
409
+ input_ids = input_ids
410
+
411
+
412
+ except Exception as e:
413
+ yield [("Error tokenizing prompt.", "ERROR")], f"Tokenization error: {e}"
414
+ return
415
+
416
+
417
+
418
+ batch_size = input_ids.shape[0]
419
+ prompt_len = input_ids.shape[1]
420
+
421
+ x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE)
422
+ x[:, :prompt_len] = input_ids.clone()
423
+
424
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation: Prompt + Initial Masks"
425
+
426
+ if gen_length == 0:
427
+ final_text_output = TOKENIZER.batch_decode(x[:,prompt_len:], skip_special_tokens=True)
428
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0] if final_text_output else ""
429
+ return
430
+
431
+ if block_length <= 0 or gen_length % block_length != 0 :
432
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
433
+ f"Error: gen_length ({gen_length}) must be divisible by block_length ({block_length}) and block_length > 0."
434
+ return
435
+ num_blocks = gen_length // block_length
436
+
437
+ if steps <=0 or steps % num_blocks != 0:
438
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
439
+ f"Error: steps ({steps}) must be positive and divisible by num_blocks ({num_blocks}). Steps: {steps}, Num Blocks: {num_blocks}"
440
+ return
441
+ steps_per_block = steps // num_blocks
442
+
443
+ for num_block_iter in range(num_blocks):
444
+ current_block_start_idx_in_x = prompt_len + num_block_iter * block_length
445
+ current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length
446
+
447
+ block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool)
448
+ block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = \
449
+ (x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID)
450
+
451
+ num_transfer_tokens_for_this_block = get_num_transfer_tokens(
452
+ block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x],
453
+ steps_per_block
454
+ )
455
+
456
+ for i_step_in_block in range(steps_per_block):
457
+ mask_index_global = (x == MASK_ID)
458
+
459
+ if cfg_scale > 0.:
460
+ un_x = x.clone()
461
+ # For unconditional pass, mask out the original prompt tokens that are not padding
462
+ # raw_prompt_attention_mask is (B, prompt_len)
463
+ prompt_active_tokens_mask = raw_prompt_attention_mask.bool() # True where actual prompt tokens are
464
+ un_x[:, :prompt_len][prompt_active_tokens_mask] = MASK_ID
465
+
466
+ x_cfg_input = torch.cat([x, un_x], dim=0)
467
+ # Pass attention_mask for CFG if model expects it, covering both parts
468
+ # For simplicity, not passing explicit attention_mask here; relies on model's internal handling.
469
+ model_output = MODEL(x_cfg_input)
470
+ logits_cond, logits_uncond = torch.chunk(model_output.logits, 2, dim=0)
471
+ logits = logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond)
472
+ else:
473
+ # Not passing explicit attention_mask here; relies on model's internal handling.
474
+ model_output = MODEL(x)
475
+ logits = model_output.logits
476
+
477
+ logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
478
+ x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1)
479
+
480
+ if remasking_strategy == 'low_confidence':
481
+ probs = F.softmax(logits.to(torch.float64), dim=-1)
482
+ x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1)
483
+ elif remasking_strategy == 'random':
484
+ x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float64)
485
+ else:
486
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), f"Error: Unknown remasking strategy '{remasking_strategy}'"
487
+ return
488
+
489
+ confidence_for_selection = torch.full_like(x0_probs, -torch.inf)
490
+ candidate_positions_for_unmasking = mask_index_global & block_masks_bool_current
491
+ confidence_for_selection = torch.where(
492
+ candidate_positions_for_unmasking,
493
+ x0_probs,
494
+ -torch.inf
495
+ )
496
+
497
+ x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x)
498
+
499
+ transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool)
500
+ num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block]
501
+
502
+ for j_batch_idx in range(batch_size):
503
+ k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(),
504
+ candidate_positions_for_unmasking[j_batch_idx].sum().item()) # ensure k isn't too large
505
+
506
+ if k_val > 0:
507
+ # Ensure confidence_for_selection[j_batch_idx] is 1D for topk
508
+ conf_slice = confidence_for_selection[j_batch_idx]
509
+ if conf_slice.ndim > 1: conf_slice = conf_slice.view(-1) # Should already be 1D from x0_probs
510
+
511
+ # Check if there are enough valid (non -inf) confidences
512
+ valid_conf_count = (conf_slice > -torch.inf).sum().item()
513
+ actual_k = min(k_val, valid_conf_count)
514
+
515
+ if actual_k > 0:
516
+ _, topk_indices_in_x = torch.topk(conf_slice, k=actual_k)
517
+ transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True
518
+
519
+ x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool]
520
+
521
+ current_total_step = num_block_iter * steps_per_block + i_step_in_block + 1
522
+ total_overall_steps = num_blocks * steps_per_block
523
+ status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block} (Total: {current_total_step}/{total_overall_steps})"
524
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg
525
+
526
+ final_generated_ids = x[:, prompt_len:]
527
+ final_text_output = TOKENIZER.batch_decode(final_generated_ids, skip_special_tokens=True)
528
+
529
+ final_text_str = final_text_output[0] if final_text_output and len(final_text_output) > 0 else ""
530
+ yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_str
531
+
532
+
533
+ css_styles = """
534
+ .gradio-container{font-family:'IBM Plex Sans',sans-serif;margin:auto;}
535
+ .gr-input {background:#f9f9f9 !important;border:1px solid #e0e0e0 !important;}
536
+ .gr-output{background:#f0f0f0 !important;border:1px solid #d0d0d0 !important;}
537
+
538
+ .highlighted-text span{
539
+ padding:2px 4px;border-radius:4px;margin:1px 2px;display:inline-block;line-height:1.6;
540
+ }
541
+
542
+ footer{display:none !important}
543
+
544
+ #live-update-scrollable-box {
545
+ max-height: 800px; /* Adjust this maximum height as needed, e.g. '300px', '50vh', etc. */
546
+ overflow-y: auto !important; /* Show a vertical scrollbar when content exceeds max-height */
547
+ display: block; /* Ensure the element is block-level so that max-height takes effect */
548
+
549
+ }
550
+ #think_btn {
551
+ background-color: #f3f4f6 !important;
552
+ border: 1px solid #d0d0d0 !important;
553
+ color: #111827 !important;
554
+ font-size: 16px !important;
555
+ font-weight: bold !important;
556
+ }
557
+ #think_btn:hover {
558
+ background-color: #e0e0e0 !important;
559
+ border: 1px solid #c0c0c0 !important;
560
+ color: #222 !important;
561
+ }
562
+ #think_btn:active {
563
+ background-color: #2563eb !important;
564
+ border: 1px solid #b0b0b0 !important;
565
+ color: white !important;
566
+ }
567
+ """
568
+
569
+
570
+ # thinking_mode_t2i = gr.State(False)
571
+ def toggle_thinking_mode_lm(current_thinking_mode):
572
+ # print(f"current_thinking_mode: {current_thinking_mode}")
573
+ new_state = not current_thinking_mode
574
+ new_label = "Thinking Mode ✅" if new_state else "Thinking Mode ❌"
575
+ return new_state, gr.update(value=new_label)
576
+
577
+ def toggle_thinking_mode_mmu(current_thinking_mode):
578
+ new_state = not current_thinking_mode
579
+ new_label = "Thinking Mode ✅" if new_state else "Thinking Mode ❌"
580
+ return new_state, gr.update(value=new_label)
581
+
582
+
583
+ color_map_config = {
584
+ "MASK": "lightgrey",
585
+ "GEN": "#DCABFA",
586
+ }
587
+
588
+ theme = gr.themes.Ocean(
589
+ primary_hue="fuchsia",
590
+ )
591
+ with gr.Blocks(css=css_styles, theme=theme) as demo:
592
+ # with gr.Blocks(css=css_styles, theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky)) as demo:
593
+ # with gr.Blocks() as demo:
594
+ thinking_mode_lm = gr.State(False)
595
+ thinking_mode_mmu = gr.State(False)
596
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 20px;'>MMaDA </h1>")
597
+ gr.Markdown("Interactively explore the step-by-step generation process of a diffusion language model. "
598
+ "The model begins with a fully masked sequence (except for the prompt) and progressively refines it by unmasking tokens.")
599
+ gr.Markdown("### Select Model")
600
+ with gr.Row():
601
+ model_select_radio = gr.Radio(
602
+ label="Select Text Generation Model",
603
+ choices=MODEL_CHOICES,
604
+ value=MODEL_CHOICES[0]
605
+ )
606
+ model_load_status_box = gr.Textbox(
607
+ label="Model Load Status",
608
+ interactive=False,
609
+ lines=3,
610
+ max_lines=5
611
+ )
612
+
613
+ gr.Markdown("## Part 1. Text Generation")
614
+ with gr.Row():
615
+ with gr.Column(scale=2):
616
+ prompt_input_box_lm = gr.Textbox(label="Enter your prompt:", lines=3, value="A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?")
617
+ think_button_lm = gr.Button("🧠 Enable Thinking Mode", elem_id="think_btn")
618
+ with gr.Accordion("Generation Parameters", open=True):
619
+ with gr.Row():
620
+ gen_length_slider_lm = gr.Slider(minimum=8, maximum=1024, value=512, step=64, label="Generation Length", info="Number of tokens to generate.")
621
+ steps_slider_lm = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).")
622
+ with gr.Row():
623
+ block_length_slider_lm = gr.Slider(minimum=8, maximum=1024, value=128, step=32, label="Block Length", info="gen_length must be divisible by this.")
624
+ remasking_dropdown_lm = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy")
625
+ with gr.Row():
626
+ cfg_scale_slider_lm = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale", info="Classifier-Free Guidance. 0 disables it.")
627
+ temperature_slider_lm = gr.Slider(minimum=0.0, maximum=2.0, value=1, step=0.05, label="Temperature", info="Controls randomness via Gumbel noise. 0 is deterministic.")
628
+
629
+
630
+ with gr.Row():
631
+ run_button_ui_lm = gr.Button("Generate Sequence", variant="primary", scale=3)
632
+ clear_button_ui_lm = gr.Button("Clear Outputs", scale=1)
633
+
634
+ with gr.Column(scale=3):
635
+ # gr.Markdown("## Live Generation Process")
636
+ output_visualization_box_lm = gr.HighlightedText(
637
+ label="Live Generation Process",
638
+ show_legend=True,
639
+ color_map=color_map_config,
640
+ combine_adjacent=False,
641
+ interactive=False,
642
+ elem_id="live-update-scrollable-box",
643
+ )
644
+ # gr.Markdown("## Final Generated Text")
645
+ output_final_text_box_lm = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)
646
+
647
+
648
+
649
+ gr.Examples(
650
+ examples=[
651
+ ["A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?", 256, 512, 128, 1, 0, "low_confidence"],
652
+ ["Lily can run 12 kilometers per hour for 4 hours. After that, she can run 6 kilometers per hour. How many kilometers can she run in 8 hours?", 256, 512, 64, 1, 0, "low_confidence"]
653
+ ],
654
+ inputs=[prompt_input_box_lm, steps_slider_lm, gen_length_slider_lm, block_length_slider_lm, temperature_slider_lm, cfg_scale_slider_lm, remasking_dropdown_lm],
655
+ outputs=[output_visualization_box_lm, output_final_text_box_lm],
656
+ fn=generate_viz_wrapper_lm,
657
+ )
658
+
659
+ gr.Markdown("---")
660
+ gr.Markdown("## Part 2. Multimodal Understanding")
661
+ with gr.Row():
662
+ with gr.Column(scale=2):
663
+ prompt_input_box_mmu = gr.Textbox(
664
+ label="Enter your prompt:",
665
+ lines=3,
666
+ value="Please describe this image in detail."
667
+ )
668
+ think_button_mmu = gr.Button("🧠 Enable Thinking Mode", elem_id="think_btn")
669
+ with gr.Accordion("Generation Parameters", open=True):
670
+ with gr.Row():
671
+ gen_length_slider_mmu = gr.Slider(minimum=64, maximum=1024, value=512, step=64, label="Generation Length", info="Number of tokens to generate.")
672
+ steps_slider_mmu = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).")
673
+ with gr.Row():
674
+ block_length_slider_mmu = gr.Slider(minimum=32, maximum=1024, value=128, step=32, label="Block Length", info="gen_length must be divisible by this.")
675
+ remasking_dropdown_mmu = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy")
676
+ with gr.Row():
677
+ cfg_scale_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale", info="Classifier-Free Guidance. 0 disables it.")
678
+ temperature_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=1, step=0.05, label="Temperature", info="Controls randomness via Gumbel noise. 0 is deterministic.")
679
+
680
+ with gr.Row():
681
+ image_upload_box = gr.Image(type="pil", label="Upload Image")
682
+
683
+ with gr.Row():
684
+ run_button_ui_mmu = gr.Button("Generate Description", variant="primary", scale=3)
685
+ clear_button_ui_mmu = gr.Button("Clear Outputs", scale=1)
686
+
687
+ with gr.Column(scale=3):
688
+ gr.Markdown("## Live Generation Process")
689
+ output_visualization_box_mmu = gr.HighlightedText(
690
+ label="Token Sequence (Live Update)",
691
+ show_legend=True,
692
+ color_map=color_map_config,
693
+ combine_adjacent=False,
694
+ interactive=False,
695
+ elem_id="live-update-scrollable-box",
696
+ )
697
+ gr.Markdown("## Final Generated Text")
698
+ output_final_text_box_mmu = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)
699
+
700
+
701
+ gr.Examples(
702
+ examples=[
703
+ [
704
+ "mmu_validation_2/sunflower.jpg",
705
+ "Please describe this image in detail.",
706
+ 256,
707
+ 512,
708
+ 128,
709
+ 1,
710
+ 0,
711
+ "low_confidence"
712
+ ],
713
+ [
714
+ "mmu_validation_2/woman.jpg",
715
+ "Please describe this image in detail.",
716
+ 256,
717
+ 512,
718
+ 128,
719
+ 1,
720
+ 0,
721
+ "low_confidence"
722
+ ]
723
+ ],
724
+ inputs=[
725
+ image_upload_box,
726
+ prompt_input_box_mmu,
727
+ steps_slider_mmu,
728
+ gen_length_slider_mmu,
729
+ block_length_slider_mmu,
730
+ temperature_slider_mmu,
731
+ cfg_scale_slider_mmu,
732
+ remasking_dropdown_mmu
733
+ ],
734
+ outputs=[output_visualization_box_mmu, output_final_text_box_mmu],
735
+ fn=generate_viz_wrapper,
736
+ )
737
+
738
+ gr.Markdown("---")
739
+ gr.Markdown("## Part 3. Text-to-Image Generation")
740
+ with gr.Row():
741
+ with gr.Column(scale=2):
742
+ prompt_input_box_t2i = gr.Textbox(label="Enter your prompt:", lines=3, value="A sea turtle swimming near a coral reef in the ocean, with a clear blue sky and water in the background.")
743
+
744
+ with gr.Accordion("Generation Parameters", open=True):
745
+ with gr.Row():
746
+ steps_slider_t2i = gr.Slider(minimum=5, maximum=100, value=15, step=5, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).")
747
+ guidance_scale_slider_t2i = gr.Slider(minimum=0.0, maximum=7.0, value=3.5, step=0.5, label="Guidance Scale", info="Classifier-Free Guidance. 0 disables it.")
748
+
749
+
750
+ with gr.Row():
751
+ scheduler_radio_t2i = gr.Radio(
752
+ choices=["cosine", "sigmoid", "linear"],
753
+ value="cosine",
754
+ label="Scheduler",
755
+ )
756
+
757
+ with gr.Row():
758
+ run_button_ui_t2i = gr.Button("Generate Image", variant="primary", scale=3)
759
+ clear_button_ui_t2i = gr.Button("Clear Outputs", scale=1)
760
+
761
+
762
+ with gr.Column(scale=3):
763
+ # gr.Markdown("## Live Generation Process")
764
+ output_image_t2i = gr.Image(label="Generated Image", interactive=False, type="pil")
765
+ output_status_t2i = gr.Textbox(label="Generation Status", interactive=False)
766
+
767
+ gr.Examples(
768
+ examples=[
769
+ ["A sea turtle swimming near a coral reef in the ocean, with a clear blue sky and water in the background.", 15, 3.5, "cosine"],
770
+ ["A beautiful sunset over a calm ocean, with a few clouds in the sky.", 15, 3.5, "cosine"]
771
+ ],
772
+ inputs=[prompt_input_box_t2i, steps_slider_t2i, guidance_scale_slider_t2i, scheduler_radio_t2i],
773
+ outputs=[output_image_t2i, output_status_t2i],
774
+ fn=generate_viz_wrapper_t2i,
775
+ )
776
+
777
+ run_button_ui_t2i.click(
778
+ fn=generate_viz_wrapper_t2i,
779
+ inputs=[
780
+ prompt_input_box_t2i,
781
+ steps_slider_t2i,
782
+ guidance_scale_slider_t2i,
783
+ scheduler_radio_t2i
784
+ ],
785
+ outputs=[output_image_t2i, output_status_t2i]
786
+ )
787
+
788
+ clear_button_ui_t2i.click(
789
+ fn=lambda: (None, ""),
790
+ inputs=None,
791
+ outputs=[output_image_t2i, output_status_t2i],
792
+ queue=False
793
+ )
794
+
795
+ think_button_lm.click(
796
+ fn=toggle_thinking_mode_lm,
797
+ inputs=[thinking_mode_lm],
798
+ outputs=[thinking_mode_lm, think_button_lm]
799
+ )
800
+
801
+ think_button_mmu.click(
802
+ fn=toggle_thinking_mode_mmu,
803
+ inputs=[thinking_mode_mmu],
804
+ outputs=[thinking_mode_mmu, think_button_mmu]
805
+ )
806
+
807
+
808
+
809
+ def initialize_default_model():
810
+ default_model = "MMaDA-8B-Base"
811
+ result = handle_model_selection_change(default_model)
812
+ return default_model, result
813
+
814
+ demo.load(
815
+ fn=initialize_default_model,
816
+ inputs=None,
817
+ outputs=[model_select_radio, model_load_status_box],
818
+ queue=True
819
+ )
820
+
821
+ def clear_outputs():
822
+ return None, None, None # Clear image, visualization, and final text
823
+
824
+ clear_button_ui_lm.click(
825
+ fn=clear_outputs,
826
+ inputs=None,
827
+ outputs=[image_upload_box, output_visualization_box_lm, output_final_text_box_lm],
828
+ queue=False
829
+ )
830
+ clear_button_ui_mmu.click(
831
+ fn=clear_outputs,
832
+ inputs=None,
833
+ outputs=[image_upload_box, output_visualization_box_mmu, output_final_text_box_mmu],
834
+ queue=False
835
+ )
836
+
837
+ run_button_ui_lm.click(
838
+ fn=generate_viz_wrapper_lm,
839
+ inputs=[
840
+ prompt_input_box_lm,
841
+ steps_slider_lm,
842
+ gen_length_slider_lm,
843
+ block_length_slider_lm,
844
+ temperature_slider_lm,
845
+ cfg_scale_slider_lm,
846
+ remasking_dropdown_lm,
847
+ thinking_mode_lm
848
+ ],
849
+ outputs=[output_visualization_box_lm, output_final_text_box_lm]
850
+ )
851
+
852
+ run_button_ui_mmu.click(
853
+ fn=generate_viz_wrapper,
854
+ inputs=[
855
+ image_upload_box,
856
+ prompt_input_box_mmu,
857
+ steps_slider_mmu,
858
+ gen_length_slider_mmu,
859
+ block_length_slider_mmu,
860
+ temperature_slider_mmu,
861
+ cfg_scale_slider_mmu,
862
+ remasking_dropdown_mmu,
863
+ thinking_mode_mmu
864
+ ],
865
+ outputs=[output_visualization_box_mmu, output_final_text_box_mmu]
866
+ )
867
+
868
+
869
+ if __name__ == "__main__":
870
+ print(f"Starting Gradio App. Attempting to use device: {DEVICE}")
871
+ demo.launch(share=True)
models/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .modeling_magvitv2 import VQGANEncoder, VQGANDecoder, LFQuantizer, MAGVITv2
2
+ from .sampling import *
3
+ from .modeling_mmada import MMadaModelLM, MMadaConfig
models/common_modules.py ADDED
@@ -0,0 +1,357 @@
1
+ """
2
+ Modified from https://github.com/CompVis/taming-transformers/blob/master/taming/modules/diffusionmodules/model.py#L34
3
+ """
4
+
5
+ import math
6
+ from typing import Tuple, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange, repeat
13
+ from einops.layers.torch import Rearrange
14
+
15
+
16
+ def nonlinearity(x):
17
+ # swish
18
+ return x * torch.sigmoid(x)
19
+
20
+
21
+ def Normalize(in_channels):
22
+ return torch.nn.GroupNorm(
23
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
24
+ )
25
+
26
+
27
+ class Upsample(nn.Module):
28
+ def __init__(self, in_channels, with_conv):
29
+ super().__init__()
30
+ self.with_conv = with_conv
31
+ if self.with_conv:
32
+ self.conv = torch.nn.Conv2d(
33
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
34
+ )
35
+
36
+ def forward(self, x):
37
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
38
+ if self.with_conv:
39
+ x = self.conv(x)
40
+ return x
41
+
42
+
43
+ class DepthToSpaceUpsample(nn.Module):
44
+ def __init__(
45
+ self,
46
+ in_channels,
47
+ ):
48
+ super().__init__()
49
+ conv = nn.Conv2d(in_channels, in_channels * 4, 1)
50
+
51
+ self.net = nn.Sequential(
52
+ conv,
53
+ nn.SiLU(),
54
+ Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2),
55
+ )
56
+
57
+ self.init_conv_(conv)
58
+
59
+ def init_conv_(self, conv):
60
+ o, i, h, w = conv.weight.shape
61
+ conv_weight = torch.empty(o // 4, i, h, w)
62
+ nn.init.kaiming_uniform_(conv_weight)
63
+ conv_weight = repeat(conv_weight, "o ... -> (o 4) ...")
64
+
65
+ conv.weight.data.copy_(conv_weight)
66
+ nn.init.zeros_(conv.bias.data)
67
+
68
+ def forward(self, x):
69
+ out = self.net(x)
70
+ return out
71
+
72
+
73
+ class Downsample(nn.Module):
74
+ def __init__(self, in_channels, with_conv):
75
+ super().__init__()
76
+ self.with_conv = with_conv
77
+ if self.with_conv:
78
+ # no asymmetric padding in torch conv, must do it ourselves
79
+ self.conv = torch.nn.Conv2d(
80
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
81
+ )
82
+
83
+ def forward(self, x):
84
+ if self.with_conv:
85
+ pad = (0, 1, 0, 1)
86
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
87
+ x = self.conv(x)
88
+ else:
89
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
90
+ return x
91
+
92
+
93
+ def unpack_time(t, batch):
94
+ _, c, w, h = t.size()
95
+ out = torch.reshape(t, [batch, -1, c, w, h])
96
+ out = rearrange(out, "b t c h w -> b c t h w")
97
+ return out
98
+
99
+
100
+ def pack_time(t):
101
+ out = rearrange(t, "b c t h w -> b t c h w")
102
+ _, _, c, w, h = out.size()
103
+ return torch.reshape(out, [-1, c, w, h])
104
+
105
+
106
+ class TimeDownsample2x(nn.Module):
107
+ def __init__(
108
+ self,
109
+ dim,
110
+ dim_out=None,
111
+ kernel_size=3,
112
+ ):
113
+ super().__init__()
114
+ if dim_out is None:
115
+ dim_out = dim
116
+ self.time_causal_padding = (kernel_size - 1, 0)
117
+ self.conv = nn.Conv1d(dim, dim_out, kernel_size, stride=2)
118
+
119
+ def forward(self, x):
120
+ x = rearrange(x, "b c t h w -> b h w c t")
121
+ b, h, w, c, t = x.size()
122
+ x = torch.reshape(x, [-1, c, t])
123
+
124
+ x = F.pad(x, self.time_causal_padding)
125
+ out = self.conv(x)
126
+
127
+ out = torch.reshape(out, [b, h, w, c, t])
128
+ out = rearrange(out, "b h w c t -> b c t h w")
130
+ return out
131
+
132
+
133
+ class TimeUpsample2x(nn.Module):
134
+ def __init__(self, dim, dim_out=None):
135
+ super().__init__()
136
+ if dim_out is None:
137
+ dim_out = dim
138
+ conv = nn.Conv1d(dim, dim_out * 2, 1)
139
+
140
+ self.net = nn.Sequential(
141
+ nn.SiLU(), conv, Rearrange("b (c p) t -> b c (t p)", p=2)
142
+ )
143
+
144
+ self.init_conv_(conv)
145
+
146
+ def init_conv_(self, conv):
147
+ o, i, t = conv.weight.shape
148
+ conv_weight = torch.empty(o // 2, i, t)
149
+ nn.init.kaiming_uniform_(conv_weight)
150
+ conv_weight = repeat(conv_weight, "o ... -> (o 2) ...")
151
+
152
+ conv.weight.data.copy_(conv_weight)
153
+ nn.init.zeros_(conv.bias.data)
154
+
155
+ def forward(self, x):
156
+ x = rearrange(x, "b c t h w -> b h w c t")
157
+ b, h, w, c, t = x.size()
158
+ x = torch.reshape(x, [-1, c, t])
159
+
160
+ out = self.net(x)
161
+ out = out[:, :, 1:].contiguous()
162
+
163
+ out = torch.reshape(out, [b, h, w, c, t])
164
+ out = rearrange(out, "b h w c t -> b c t h w")
165
+ return out
166
+
167
+
168
+ class AttnBlock(nn.Module):
169
+ def __init__(self, in_channels):
170
+ super().__init__()
171
+ self.in_channels = in_channels
172
+
173
+ self.norm = Normalize(in_channels)
174
+ self.q = torch.nn.Conv2d(
175
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
176
+ )
177
+ self.k = torch.nn.Conv2d(
178
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
179
+ )
180
+ self.v = torch.nn.Conv2d(
181
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
182
+ )
183
+ self.proj_out = torch.nn.Conv2d(
184
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
185
+ )
186
+
187
+ def forward(self, x):
188
+ h_ = x
189
+ h_ = self.norm(h_)
190
+ q = self.q(h_)
191
+ k = self.k(h_)
192
+ v = self.v(h_)
193
+
194
+ # compute attention
195
+ b, c, h, w = q.shape
196
+ q = q.reshape(b, c, h * w)
197
+ q = q.permute(0, 2, 1) # b,hw,c
198
+ k = k.reshape(b, c, h * w) # b,c,hw
199
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
200
+ w_ = w_ * (int(c) ** (-0.5))
201
+ w_ = torch.nn.functional.softmax(w_, dim=2)
202
+
203
+ # attend to values
204
+ v = v.reshape(b, c, h * w)
205
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
206
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
207
+ h_ = h_.reshape(b, c, h, w)
208
+
209
+ h_ = self.proj_out(h_)
210
+
211
+ return x + h_
212
+
213
+
214
+ class TimeAttention(AttnBlock):
215
+ def forward(self, x, *args, **kwargs):
216
+ x = rearrange(x, "b c t h w -> b h w t c")
217
+ b, h, w, t, c = x.size()
218
+ x = torch.reshape(x, (-1, t, c))
219
+
220
+ x = super().forward(x, *args, **kwargs)
221
+
222
+ x = torch.reshape(x, [b, h, w, t, c])
223
+ return rearrange(x, "b h w t c -> b c t h w")
224
+
225
+
226
+ class Residual(nn.Module):
227
+ def __init__(self, fn: nn.Module):
228
+ super().__init__()
229
+ self.fn = fn
230
+
231
+ def forward(self, x, **kwargs):
232
+ return self.fn(x, **kwargs) + x
233
+
234
+
235
+ def cast_tuple(t, length=1):
236
+ return t if isinstance(t, tuple) else ((t,) * length)
237
+
238
+
239
+ class CausalConv3d(nn.Module):
240
+ def __init__(
241
+ self,
242
+ chan_in,
243
+ chan_out,
244
+ kernel_size: Union[int, Tuple[int, int, int]],
245
+ pad_mode="constant",
246
+ **kwargs
247
+ ):
248
+ super().__init__()
249
+ kernel_size = cast_tuple(kernel_size, 3)
250
+
251
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
252
+
253
+ dilation = kwargs.pop("dilation", 1)
254
+ stride = kwargs.pop("stride", 1)
255
+
256
+ self.pad_mode = pad_mode
257
+ time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
258
+ height_pad = height_kernel_size // 2
259
+ width_pad = width_kernel_size // 2
260
+
261
+ self.time_pad = time_pad
262
+ self.time_causal_padding = (
263
+ width_pad,
264
+ width_pad,
265
+ height_pad,
266
+ height_pad,
267
+ time_pad,
268
+ 0,
269
+ )
270
+
271
+ stride = (stride, 1, 1)
272
+ dilation = (dilation, 1, 1)
273
+ self.conv = nn.Conv3d(
274
+ chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs
275
+ )
276
+
277
+ def forward(self, x):
278
+ pad_mode = self.pad_mode if self.time_pad < x.shape[2] else "constant"
279
+
280
+ x = F.pad(x, self.time_causal_padding, mode=pad_mode)
281
+ return self.conv(x)
282
+
283
+
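A minimal shape sketch (assuming the imports at the top of this module): with stride 1 the causal padding preserves the temporal length and only pads on the past side, so frame t never sees future frames.

    x = torch.randn(1, 8, 5, 16, 16)          # (B, C, T, H, W)
    conv = CausalConv3d(8, 8, kernel_size=3)
    print(conv(x).shape)                      # torch.Size([1, 8, 5, 16, 16])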
284
+ def ResnetBlockCausal3D(
285
+ dim, kernel_size: Union[int, Tuple[int, int, int]], pad_mode: str = "constant"
286
+ ):
287
+ net = nn.Sequential(
288
+ Normalize(dim),
289
+ nn.SiLU(),
290
+ CausalConv3d(dim, dim, kernel_size, pad_mode),
291
+ Normalize(dim),
292
+ nn.SiLU(),
293
+ CausalConv3d(dim, dim, kernel_size, pad_mode),
294
+ )
295
+ return Residual(net)
296
+
297
+
298
+ class ResnetBlock(nn.Module):
299
+ def __init__(
300
+ self,
301
+ *,
302
+ in_channels,
303
+ out_channels=None,
304
+ conv_shortcut=False,
305
+ dropout,
306
+ temb_channels=512
307
+ ):
308
+ super().__init__()
309
+ self.in_channels = in_channels
310
+ out_channels = in_channels if out_channels is None else out_channels
311
+ self.out_channels = out_channels
312
+ self.use_conv_shortcut = conv_shortcut
313
+
314
+ self.norm1 = Normalize(in_channels)
315
+ self.conv1 = torch.nn.Conv2d(
316
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
317
+ )
318
+ if temb_channels > 0:
319
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
320
+ else:
321
+ self.temb_proj = None
322
+ self.norm2 = Normalize(out_channels)
323
+ self.dropout = torch.nn.Dropout(dropout)
324
+ self.conv2 = torch.nn.Conv2d(
325
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
326
+ )
327
+ if self.in_channels != self.out_channels:
328
+ if self.use_conv_shortcut:
329
+ self.conv_shortcut = torch.nn.Conv2d(
330
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
331
+ )
332
+ else:
333
+ self.nin_shortcut = torch.nn.Conv2d(
334
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
335
+ )
336
+
337
+ def forward(self, x, temb):
338
+ h = x
339
+ h = self.norm1(h)
340
+ h = nonlinearity(h)
341
+ h = self.conv1(h)
342
+
343
+ if temb is not None:
344
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
345
+
346
+ h = self.norm2(h)
347
+ h = nonlinearity(h)
348
+ h = self.dropout(h)
349
+ h = self.conv2(h)
350
+
351
+ if self.in_channels != self.out_channels:
352
+ if self.use_conv_shortcut:
353
+ x = self.conv_shortcut(x)
354
+ else:
355
+ x = self.nin_shortcut(x)
356
+
357
+ return x + h
models/configuration_llada.py ADDED
@@ -0,0 +1,463 @@
1
+ """
2
+ LLaDA configuration
3
+ """
4
+ from transformers import AutoConfig, PretrainedConfig
5
+
6
+ from enum import Enum
7
+ from os import PathLike
8
+ from typing import Union
9
+ from dataclasses import asdict, dataclass, field
10
+ from glob import glob
11
+ from pathlib import Path
12
+ from typing import (
13
+ Any,
14
+ Dict,
15
+ Iterable,
16
+ List,
17
+ Optional,
18
+ Tuple,
19
+ Type,
20
+ TypeVar,
21
+ Union,
22
+ cast,
23
+ )
24
+
25
+
26
+ __all__ = [
27
+ "ActivationType",
28
+ "ActivationCheckpointingStrategy",
29
+ "BlockType",
30
+ "LayerNormType",
31
+ "InitFnType",
32
+ "ModelConfig",
33
+ ]
34
+
35
+ PathOrStr = Union[str, PathLike]
36
+
37
+
38
+ class StrEnum(str, Enum):
39
+ """
40
+ This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
41
+ We include this here for compatibility with older versions of Python.
42
+ """
43
+
44
+ def __str__(self) -> str:
45
+ return self.value
46
+
47
+ def __repr__(self) -> str:
48
+ return f"'{str(self)}'"
49
+
50
+
51
+ class LayerNormType(StrEnum):
52
+ default = "default"
53
+ """
54
+ The default LayerNorm implementation, equivalent to PyTorch's built-in version.
55
+ """
56
+
57
+ low_precision = "low_precision"
58
+ """
59
+ A low-precision version of the default LayerNorm.
60
+ """
61
+
62
+ rms = "rms"
63
+ """
64
+ An RMSNorm implementation. When using ``torch.compile`` this is
65
+ probably the fastest implementation.
66
+ """
67
+
68
+ gemma_rms = "gemma_rms"
69
+ """
70
+ An RMSNorm implementation from Gemma. When using ``torch.compile`` this is
71
+ probably the fastest implementation.
72
+ """
73
+
74
+ amd_compatible = "amd_compatible"
75
+ """
76
+ LayerNorm implemented manually to work around an issue with ROCm.
77
+ """
78
+
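A small sketch of how these enums behave: members compare and serialize as plain strings, which keeps config round-trips simple.

    assert LayerNormType.rms == "rms"
    assert repr(LayerNormType.default) == "'default'"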
79
+
80
+ class ActivationType(StrEnum):
81
+ gelu = "gelu"
82
+ relu = "relu"
83
+ silu = "silu"
84
+ swiglu = "swiglu"
85
+
86
+
87
+ class BlockType(StrEnum):
88
+ sequential = "sequential"
89
+ parallel = "parallel"
90
+
91
+ llama = "llama"
92
+ """
93
+ A block similar to the sequential block with slightly different
94
+ implementations of operations like attention to imitate the behavior of Llama.
95
+ """
96
+
97
+
98
+ class InitFnType(StrEnum):
99
+ mitchell = "mitchell"
100
+ """
101
+ The strategy suggested to us by Mitchell Wortsman from UW.
102
+ This uses a truncated normal distribution with an adaptive standard deviation that depends
103
+ on the size of the weights as well as the depth of the layer.
104
+ """
105
+
106
+ normal = "normal"
107
+ """
108
+ All weights are initialized from the same normal distribution.
109
+ """
110
+
111
+ kaiming_normal = "kaiming_normal"
112
+ """
113
+ All weights are initialized with the Kaiming method from a normal distribution.
114
+ Note this currently won't work with FSDP.
115
+ """
116
+
117
+ fan_in = "fan_in"
118
+ """
119
+ "Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
120
+ is the input dimensionality of the kernel.
121
+ """
122
+
123
+ full_megatron = "full_megatron"
124
+ """
125
+ This is what metaseq calls "full megatron init". It is the init used for Llama 2.
126
+ """
127
+
128
+
129
+ @dataclass
130
+ class ModelConfig():
131
+ """
132
+ LLaDA (model) configuration.
133
+ """
134
+
135
+ # Note that the defaults for these attributes are equivalent to the base GPT2 model.
136
+
137
+ d_model: int = 768
138
+ """
139
+ The hidden size of the model.
140
+ """
141
+
142
+ n_heads: int = 12
143
+ """
144
+ The number of self-attention heads.
145
+ """
146
+
147
+ n_kv_heads: Optional[int] = None
148
+ """
149
+ The number of heads to use for keys and values. Defaults to `n_heads`.
150
+ Set this to ``None`` or ``n_heads`` for normal multi-head attention.
151
+ Set this to 1 for multi-query attention.
152
+ Set it to some in-between value for Llama2-style grouped query attention.
153
+ """
154
+
155
+ n_layers: int = 12
156
+ """
157
+ The number of layers/blocks.
158
+ """
159
+
160
+ mlp_ratio: int = 4
161
+ """
162
+ The ratio of the inner MLP dimensionality to ``d_model``.
163
+ This is only used when ``mlp_hidden_size`` is not set.
164
+ """
165
+
166
+ mlp_hidden_size: Optional[int] = None
167
+ """
168
+ Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
169
+ """
170
+
171
+ activation_type: ActivationType = ActivationType.swiglu
172
+ """
173
+ The activation function to use within the MLP layers.
174
+ """
175
+
176
+ block_type: BlockType = BlockType.sequential
177
+ """
178
+ The transformer block implementation.
179
+ """
180
+
181
+ block_group_size: int = 1
182
+ """
183
+ The number of blocks to group together into a single parent block.
184
+ This has no effect on the number of parameters in the model and is only used to wrap groups
185
+ of blocks together with a single FSDP wrapper during training.
186
+ """
187
+
188
+ alibi: bool = False
189
+ """
190
+ If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
191
+ """
192
+
193
+ alibi_bias_max: float = 8.0
194
+ """
195
+ Maximum absolute value of ALiBi bias.
196
+ """
197
+
198
+ rope: bool = False
199
+ """
200
+ Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
201
+ """
202
+
203
+ rope_full_precision: bool = True
204
+ """
205
+ If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
206
+ apply RoPE at the precision of the input.
207
+ """
208
+
209
+ flash_attention: bool = False
210
+ """
211
+ If ``True``, use ``FlashAttention``.
212
+ """
213
+
214
+ attention_dropout: float = 0.1
215
+ """
216
+ The dropout probability within the attention modules.
217
+ """
218
+
219
+ multi_query_attention: Optional[bool] = None
220
+ """
221
+ Use the Multi-Query formulation of attention used in PaLM. This reduces the number of parameters
222
+ and is more efficient during inference.
223
+ """
224
+
225
+ attention_layer_norm: bool = False
226
+ """
227
+ Apply layer norm to the keys and queries within the attention mechanism.
228
+ This can help stabilize training.
229
+ """
230
+
231
+ residual_dropout: float = 0.1
232
+ """
233
+ The dropout probability for the MLP and attention output within each block.
234
+ """
235
+
236
+ embedding_dropout: float = 0.1
237
+ """
238
+ The dropout probability for embeddings.
239
+ """
240
+
241
+ input_emb_norm: bool = False
242
+ """
243
+ An input hidden_states normalization as used by Gemma.
244
+ """
245
+
246
+ layer_norm_type: LayerNormType = LayerNormType.default
247
+ """
248
+ The layernorm implementation to use.
249
+ """
250
+
251
+ layer_norm_with_affine: bool = True
252
+ """
253
+ Whether to include bias and weight parameters for the layer norms.
254
+ This only affects layer norms that are immediately followed by a linear layer in the forward pass,
255
+ so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
256
+ to ``False``.
257
+ """
258
+
259
+ rms_norm_eps: float = 1e-05
260
+ """
261
+ The rms layernorm eps param.
262
+ """
263
+
264
+ attention_layer_norm_with_affine: bool = True
265
+ """
266
+ Toggle affine transform for the QK norms.
267
+ """
268
+
269
+ max_sequence_length: int = 1024
270
+ """
271
+ The maximum input sequence length supported by the model.
272
+ """
273
+
274
+ rope_theta: float = 10000.0
275
+ """
276
+ The rope base param.
277
+ """
278
+
279
+ include_qkv_bias: Optional[bool] = False
280
+ """
281
+ Whether or not to include bias parameters in qkv linear layers.
282
+ """
283
+
284
+ include_bias: bool = False
285
+ """
286
+ Whether or not to include bias parameters in linear layers.
287
+ In PaLM, they got rid of all bias terms because they found that large
288
+ models tend to have near 0 bias terms anyway.
289
+ """
290
+
291
+ bias_for_layer_norm: Optional[bool] = None
292
+ """
293
+ Whether or not to include bias parameters in layer norm.
294
+ This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
295
+ layer norm.
296
+ When this is None (the default), it inherits the setting from include_bias.
297
+ """
298
+
299
+ scale_logits: bool = False
300
+ """
301
+ If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
302
+ """
303
+
304
+ vocab_size: int = 50257
305
+ """
306
+ Vocabulary size of the model.
307
+ """
308
+
309
+ embedding_size: Optional[int] = 50304
310
+ """
311
+ The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
312
+ to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
313
+ next multiple of 128 that's greater than ``vocab_size`` can improve throughput
314
+ substantially.
315
+ """
316
+
317
+ weight_tying: bool = True
318
+ """
319
+ Whether to tie output linear weights to the input embedding.
320
+ """
321
+
322
+ eos_token_id: int = 50256
323
+ """
324
+ The ID of the end-of-sentence special token.
325
+ """
326
+
327
+ pad_token_id: int = 50256
328
+ """
329
+ The ID of the token to use for padding. Defaults to the ID of the EOS token.
330
+ """
331
+
332
+ mask_token_id: Optional[int] = 50256
333
+ """
334
+ The ID of the token to use for mask token. Defaults to the ID of the EOS token.
335
+ """
336
+
337
+ init_device: Optional[str] = None
338
+ """
339
+ The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
340
+ """
341
+
342
+ init_fn: InitFnType = InitFnType.normal
343
+ """
344
+ The weight initialization strategy.
345
+ """
346
+
347
+ init_std: float = 0.02
348
+ """
349
+ The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
350
+ as "normal".
351
+ """
352
+
353
+ init_cutoff_factor: Optional[float] = None
354
+ """
355
+ A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
356
+ as "normal". Setting this to None means values are not cut off.
357
+ """
358
+
359
+ precision: Optional[str] = None
360
+ """
361
+ Precision used to train/evaluate with. You shouldn't set this directly.
362
+ See :data:`TrainConfig.precision` instead.
363
+ """
364
+
365
+ @property
366
+ def effective_n_kv_heads(self) -> int:
367
+ if self.n_kv_heads is None:
368
+ if self.multi_query_attention is True:
369
+ return 1
370
+ else:
371
+ return self.n_heads
372
+ else:
373
+ if self.multi_query_attention is None:
374
+ return self.n_kv_heads
375
+ if self.multi_query_attention:
376
+ n_kv_heads_should_be = 1
377
+ else:
378
+ n_kv_heads_should_be = self.n_heads
379
+ if self.n_kv_heads == n_kv_heads_should_be:
380
+ return n_kv_heads_should_be
381
+ else:
382
+ raise Exception(
383
+ "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
384
+ )
385
+
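+ # Illustrative sketch (not part of the original file): how ``effective_n_kv_heads`` resolves
+ # the attention variant from the fields above.
+ # ModelConfig(n_heads=12) -> 12 KV heads (standard multi-head attention)
+ # ModelConfig(n_heads=12, multi_query_attention=True) -> 1 KV head (multi-query attention)
+ # ModelConfig(n_heads=12, n_kv_heads=4) -> 4 KV heads (grouped-query attention)
+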
386
+ class ActivationCheckpointingStrategy(StrEnum):
387
+ whole_layer = "whole_layer"
388
+ """
389
+ Checkpoint every transformer layer.
390
+ """
391
+
392
+ one_in_two = "one_in_two"
393
+ """
394
+ Checkpoint one in two transformer layers.
395
+ """
396
+
397
+ one_in_three = "one_in_three"
398
+ """
399
+ Checkpoint one in three transformer layers.
400
+ """
401
+
402
+ one_in_four = "one_in_four"
403
+ """
404
+ Checkpoint one in four transformer layers.
405
+ """
406
+
407
+ two_in_three = "two_in_three"
408
+ """
409
+ Checkpoint two out of every three transformer layers.
410
+ """
411
+
412
+ three_in_four = "three_in_four"
413
+ """
414
+ Checkpoint three out of every four transformer layers.
415
+ """
416
+
417
+ four_in_five = "four_in_five"
418
+ """
419
+ Checkpoint four out of every five transformer layers.
420
+ """
421
+
422
+ nine_in_ten = "nine_in_ten"
423
+ """
424
+ Checkpoint nine out of every ten transformer layers.
425
+ """
426
+
427
+ fine_grained = "fine_grained"
428
+ """
429
+ Focus checkpointing where recomputation is cheap and the memory savings are largest.
430
+ """
431
+
432
+
433
+ class LLaDAConfig(PretrainedConfig):
434
+ model_type = "llada"
435
+ keys_to_ignore_at_inference = ["past_key_values"] # TODO: confirm
436
+
437
+ def __init__(self, use_cache: bool = False, **kwargs):
438
+ model_config = ModelConfig()
439
+ all_kwargs = model_config.__dict__
440
+ all_kwargs.update(kwargs)
441
+ all_kwargs.update({"use_cache": use_cache})
442
+ all_kwargs.update(
443
+ {
444
+ "architectures": all_kwargs.get("architectures", ["LLaDAModelLM"])
445
+ }
446
+ )
447
+ super().__init__(**all_kwargs)
448
+
449
+ @property
450
+ def num_attention_heads(self):
451
+ return self.n_heads
452
+
453
+ @property
454
+ def num_hidden_layers(self):
455
+ return self.n_layers
456
+
457
+ @property
458
+ def hidden_size(self):
459
+ return self.d_model
460
+
461
+
462
+ # Register the config class so that it is available for transformer pipelines, auto-loading etc.
463
+ AutoConfig.register("llada", LLaDAConfig)
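+
+ # Minimal usage sketch (values are illustrative, not part of the original file):
+ # config = LLaDAConfig(n_layers=22, d_model=2048, n_heads=32)
+ # assert config.hidden_size == config.d_model and config.num_hidden_layers == config.n_layers
+ # After the registration above, transformers.AutoConfig can also resolve model_type="llada".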
models/logging.py ADDED
@@ -0,0 +1,338 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Optuna, Hugging Face
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Logging utilities."""
16
+
17
+ import logging
18
+ import os
19
+ import sys
20
+ import threading
21
+ from logging import CRITICAL # NOQA
22
+ from logging import DEBUG # NOQA
23
+ from logging import ERROR # NOQA
24
+ from logging import FATAL # NOQA
25
+ from logging import INFO # NOQA
26
+ from logging import NOTSET # NOQA
27
+ from logging import WARN # NOQA
28
+ from logging import WARNING # NOQA
29
+ from typing import Optional
30
+
31
+ from tqdm import auto as tqdm_lib
32
+
33
+ _lock = threading.Lock()
34
+ _default_handler: Optional[logging.Handler] = None
35
+
36
+ log_levels = {
37
+ "debug": logging.DEBUG,
38
+ "info": logging.INFO,
39
+ "warning": logging.WARNING,
40
+ "error": logging.ERROR,
41
+ "critical": logging.CRITICAL,
42
+ }
43
+
44
+ _default_log_level = logging.WARNING
45
+
46
+ _tqdm_active = True
47
+
48
+
49
+ def _get_default_logging_level():
50
+ """
51
+ If muse_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
52
+ not - fall back to `_default_log_level`
53
+ """
54
+ env_level_str = os.getenv("muse_VERBOSITY", None)
55
+ if env_level_str:
56
+ if env_level_str in log_levels:
57
+ return log_levels[env_level_str]
58
+ else:
59
+ logging.getLogger().warning(
60
+ f"Unknown option muse_VERBOSITY={env_level_str}, has to be one of: { ', '.join(log_levels.keys()) }"
61
+ )
62
+ return _default_log_level
63
+
64
+
65
+ def _get_library_name() -> str:
66
+ return __name__.split(".")[0]
67
+
68
+
69
+ def _get_library_root_logger() -> logging.Logger:
70
+ return logging.getLogger(_get_library_name())
71
+
72
+
73
+ def _configure_library_root_logger() -> None:
74
+ global _default_handler
75
+
76
+ with _lock:
77
+ if _default_handler:
78
+ # This library has already configured the library root logger.
79
+ return
80
+ _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
81
+ _default_handler.flush = sys.stderr.flush
82
+
83
+ # Apply our default configuration to the library root logger.
84
+ library_root_logger = _get_library_root_logger()
85
+ library_root_logger.addHandler(_default_handler)
86
+ library_root_logger.setLevel(_get_default_logging_level())
87
+ library_root_logger.propagate = False
88
+
89
+
90
+ def _reset_library_root_logger() -> None:
91
+ global _default_handler
92
+
93
+ with _lock:
94
+ if not _default_handler:
95
+ return
96
+
97
+ library_root_logger = _get_library_root_logger()
98
+ library_root_logger.removeHandler(_default_handler)
99
+ library_root_logger.setLevel(logging.NOTSET)
100
+ _default_handler = None
101
+
102
+
103
+ def get_log_levels_dict():
104
+ return log_levels
105
+
106
+
107
+ def get_logger(name: Optional[str] = None) -> logging.Logger:
108
+ """
109
+ Return a logger with the specified name.
110
+
111
+ This function is not supposed to be directly accessed unless you are writing a custom muse module.
112
+ """
113
+
114
+ if name is None:
115
+ name = _get_library_name()
116
+
117
+ _configure_library_root_logger()
118
+ return logging.getLogger(name)
119
+
120
+
121
+ def get_verbosity() -> int:
122
+ """
123
+ Return the current level for the 🤗 muse' root logger as an int.
124
+
125
+ Returns:
126
+ `int`: The logging level.
127
+
128
+ <Tip>
129
+
130
+ 🤗 muse has the following logging levels:
131
+
132
+ - 50: `muse.logging.CRITICAL` or `muse.logging.FATAL`
133
+ - 40: `muse.logging.ERROR`
134
+ - 30: `muse.logging.WARNING` or `muse.logging.WARN`
135
+ - 20: `muse.logging.INFO`
136
+ - 10: `muse.logging.DEBUG`
137
+
138
+ </Tip>"""
139
+
140
+ _configure_library_root_logger()
141
+ return _get_library_root_logger().getEffectiveLevel()
142
+
143
+
144
+ def set_verbosity(verbosity: int) -> None:
145
+ """
146
+ Set the verbosity level for the 🤗 muse' root logger.
147
+
148
+ Args:
149
+ verbosity (`int`):
150
+ Logging level, e.g., one of:
151
+
152
+ - `muse.logging.CRITICAL` or `muse.logging.FATAL`
153
+ - `muse.logging.ERROR`
154
+ - `muse.logging.WARNING` or `muse.logging.WARN`
155
+ - `muse.logging.INFO`
156
+ - `muse.logging.DEBUG`
157
+ """
158
+
159
+ _configure_library_root_logger()
160
+ _get_library_root_logger().setLevel(verbosity)
161
+
162
+
163
+ def set_verbosity_info():
164
+ """Set the verbosity to the `INFO` level."""
165
+ return set_verbosity(INFO)
166
+
167
+
168
+ def set_verbosity_warning():
169
+ """Set the verbosity to the `WARNING` level."""
170
+ return set_verbosity(WARNING)
171
+
172
+
173
+ def set_verbosity_debug():
174
+ """Set the verbosity to the `DEBUG` level."""
175
+ return set_verbosity(DEBUG)
176
+
177
+
178
+ def set_verbosity_error():
179
+ """Set the verbosity to the `ERROR` level."""
180
+ return set_verbosity(ERROR)
181
+
182
+
183
+ def disable_default_handler() -> None:
184
+ """Disable the default handler of the HuggingFace muse' root logger."""
185
+
186
+ _configure_library_root_logger()
187
+
188
+ assert _default_handler is not None
189
+ _get_library_root_logger().removeHandler(_default_handler)
190
+
191
+
192
+ def enable_default_handler() -> None:
193
+ """Enable the default handler of the HuggingFace muse' root logger."""
194
+
195
+ _configure_library_root_logger()
196
+
197
+ assert _default_handler is not None
198
+ _get_library_root_logger().addHandler(_default_handler)
199
+
200
+
201
+ def add_handler(handler: logging.Handler) -> None:
202
+ """adds a handler to the HuggingFace muse' root logger."""
203
+
204
+ _configure_library_root_logger()
205
+
206
+ assert handler is not None
207
+ _get_library_root_logger().addHandler(handler)
208
+
209
+
210
+ def remove_handler(handler: logging.Handler) -> None:
211
+ """removes given handler from the HuggingFace muse' root logger."""
212
+
213
+ _configure_library_root_logger()
214
+
215
+ assert handler is not None and handler not in _get_library_root_logger().handlers
216
+ _get_library_root_logger().removeHandler(handler)
217
+
218
+
219
+ def disable_propagation() -> None:
220
+ """
221
+ Disable propagation of the library log outputs. Note that log propagation is disabled by default.
222
+ """
223
+
224
+ _configure_library_root_logger()
225
+ _get_library_root_logger().propagate = False
226
+
227
+
228
+ def enable_propagation() -> None:
229
+ """
230
+ Enable propagation of the library log outputs. Please disable the HuggingFace muse' default handler to prevent
231
+ double logging if the root logger has been configured.
232
+ """
233
+
234
+ _configure_library_root_logger()
235
+ _get_library_root_logger().propagate = True
236
+
237
+
238
+ def enable_explicit_format() -> None:
239
+ """
240
+ Enable explicit formatting for every HuggingFace muse' logger. The explicit formatter is as follows:
241
+ ```
242
+ [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
243
+ ```
244
+ All handlers currently bound to the root logger are affected by this method.
245
+ """
246
+ handlers = _get_library_root_logger().handlers
247
+
248
+ for handler in handlers:
249
+ formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
250
+ handler.setFormatter(formatter)
251
+
252
+
253
+ def reset_format() -> None:
254
+ """
255
+ Resets the formatting for HuggingFace muse' loggers.
256
+
257
+ All handlers currently bound to the root logger are affected by this method.
258
+ """
259
+ handlers = _get_library_root_logger().handlers
260
+
261
+ for handler in handlers:
262
+ handler.setFormatter(None)
263
+
264
+
265
+ def warning_advice(self, *args, **kwargs):
266
+ """
267
+ This method is identical to `logger.warning()`, but if env var muse_NO_ADVISORY_WARNINGS=1 is set, this
268
+ warning will not be printed
269
+ """
270
+ no_advisory_warnings = os.getenv("muse_NO_ADVISORY_WARNINGS", False)
271
+ if no_advisory_warnings:
272
+ return
273
+ self.warning(*args, **kwargs)
274
+
275
+
276
+ logging.Logger.warning_advice = warning_advice
277
+
278
+
279
+ class EmptyTqdm:
280
+ """Dummy tqdm which doesn't do anything."""
281
+
282
+ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
283
+ self._iterator = args[0] if args else None
284
+
285
+ def __iter__(self):
286
+ return iter(self._iterator)
287
+
288
+ def __getattr__(self, _):
289
+ """Return empty function."""
290
+
291
+ def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
292
+ return
293
+
294
+ return empty_fn
295
+
296
+ def __enter__(self):
297
+ return self
298
+
299
+ def __exit__(self, type_, value, traceback):
300
+ return
301
+
302
+
303
+ class _tqdm_cls:
304
+ def __call__(self, *args, **kwargs):
305
+ if _tqdm_active:
306
+ return tqdm_lib.tqdm(*args, **kwargs)
307
+ else:
308
+ return EmptyTqdm(*args, **kwargs)
309
+
310
+ def set_lock(self, *args, **kwargs):
311
+ self._lock = None
312
+ if _tqdm_active:
313
+ return tqdm_lib.tqdm.set_lock(*args, **kwargs)
314
+
315
+ def get_lock(self):
316
+ if _tqdm_active:
317
+ return tqdm_lib.tqdm.get_lock()
318
+
319
+
320
+ tqdm = _tqdm_cls()
321
+
322
+
323
+ def is_progress_bar_enabled() -> bool:
324
+ """Return a boolean indicating whether tqdm progress bars are enabled."""
325
+ global _tqdm_active
326
+ return bool(_tqdm_active)
327
+
328
+
329
+ def enable_progress_bar():
330
+ """Enable tqdm progress bar."""
331
+ global _tqdm_active
332
+ _tqdm_active = True
333
+
334
+
335
+ def disable_progress_bar():
336
+ """Disable tqdm progress bar."""
337
+ global _tqdm_active
338
+ _tqdm_active = False
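+
+ # Usage sketch (assumes this module is importable as ``models.logging``):
+ # from models import logging as mlogging
+ # logger = mlogging.get_logger(__name__)
+ # mlogging.set_verbosity_info()
+ # logger.info("progress bars enabled: %s", mlogging.is_progress_bar_enabled())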
models/lr_schedulers.py ADDED
@@ -0,0 +1,302 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch optimization for diffusion models."""
16
+
17
+ import math
18
+ from enum import Enum
19
+ from typing import Optional, Union
20
+
21
+ from torch.optim import Optimizer
22
+ from torch.optim.lr_scheduler import LambdaLR
23
+
24
+ from .logging import get_logger
25
+
26
+ logger = get_logger(__name__)
27
+
28
+
29
+ class SchedulerType(Enum):
30
+ LINEAR = "linear"
31
+ COSINE = "cosine"
32
+ COSINE_WITH_RESTARTS = "cosine_with_restarts"
33
+ POLYNOMIAL = "polynomial"
34
+ CONSTANT = "constant"
35
+ CONSTANT_WITH_WARMUP = "constant_with_warmup"
36
+
37
+
38
+ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
39
+ """
40
+ Create a schedule with a constant learning rate, using the learning rate set in optimizer.
41
+
42
+ Args:
43
+ optimizer ([`~torch.optim.Optimizer`]):
44
+ The optimizer for which to schedule the learning rate.
45
+ last_epoch (`int`, *optional*, defaults to -1):
46
+ The index of the last epoch when resuming training.
47
+
48
+ Return:
49
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
50
+ """
51
+ return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
52
+
53
+
54
+ def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
55
+ """
56
+ Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
57
+ increases linearly between 0 and the initial lr set in the optimizer.
58
+
59
+ Args:
60
+ optimizer ([`~torch.optim.Optimizer`]):
61
+ The optimizer for which to schedule the learning rate.
62
+ num_warmup_steps (`int`):
63
+ The number of steps for the warmup phase.
64
+ last_epoch (`int`, *optional*, defaults to -1):
65
+ The index of the last epoch when resuming training.
66
+
67
+ Return:
68
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
69
+ """
70
+
71
+ def lr_lambda(current_step: int):
72
+ if current_step < num_warmup_steps:
73
+ return float(current_step) / float(max(1.0, num_warmup_steps))
74
+ return 1.0
75
+
76
+ return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
77
+
78
+
79
+ def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
80
+ """
81
+ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
82
+ a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
83
+
84
+ Args:
85
+ optimizer ([`~torch.optim.Optimizer`]):
86
+ The optimizer for which to schedule the learning rate.
87
+ num_warmup_steps (`int`):
88
+ The number of steps for the warmup phase.
89
+ num_training_steps (`int`):
90
+ The total number of training steps.
91
+ last_epoch (`int`, *optional*, defaults to -1):
92
+ The index of the last epoch when resuming training.
93
+
94
+ Return:
95
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
96
+ """
97
+
98
+ def lr_lambda(current_step: int):
99
+ if current_step < num_warmup_steps:
100
+ return float(current_step) / float(max(1, num_warmup_steps))
101
+ return max(
102
+ 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
103
+ )
104
+
105
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
106
+
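+ # Worked example: with num_warmup_steps=100 and num_training_steps=1000, the lr multiplier is
+ # 0.5 at step 50, peaks at 1.0 at step 100, then falls linearly to 0.0 at step 1000 (e.g. 0.5 at step 550).
+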
107
+
108
+ def get_cosine_schedule_with_warmup(
109
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1, min_lr_scale: float = 0.0
110
+ ):
111
+ """
112
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
113
+ initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
114
+ initial lr set in the optimizer.
115
+
116
+ Args:
117
+ optimizer ([`~torch.optim.Optimizer`]):
118
+ The optimizer for which to schedule the learning rate.
119
+ num_warmup_steps (`int`):
120
+ The number of steps for the warmup phase.
121
+ num_training_steps (`int`):
122
+ The total number of training steps.
123
+ num_cycles (`float`, *optional*, defaults to 0.5):
124
+ The number of periods of the cosine function in a schedule (the default is to just decrease from the max
125
+ value to 0 following a half-cosine).
126
+ last_epoch (`int`, *optional*, defaults to -1):
127
+ The index of the last epoch when resuming training.
+ min_lr_scale (`float`, *optional*, defaults to 0.0):
+ The floor of the decayed multiplier, as a fraction of the initial lr; the schedule decays to this value instead of 0.
128
+
129
+ Return:
130
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
131
+ """
132
+
133
+ # def lr_lambda(current_step):
134
+ # if current_step < num_warmup_steps:
135
+ # return float(current_step) / float(max(1, num_warmup_steps))
136
+ # progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
137
+ # return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
138
+
139
+ # return LambdaLR(optimizer, lr_lambda, last_epoch)
140
+
141
+ def lr_lambda(current_step):
142
+ if current_step < num_warmup_steps:
143
+ return float(current_step) / float(max(1, num_warmup_steps))
144
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
145
+ cosine_decay = 0.5 * (1.0 + math.cos(math.pi * 2.0 * num_cycles * progress))
146
+ return min_lr_scale + (1.0 - min_lr_scale) * cosine_decay
147
+
148
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
149
+
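+ # Worked example of the ``min_lr_scale`` floor (an addition relative to the upstream HF schedule):
+ # with the default num_cycles=0.5 and min_lr_scale=0.1, the multiplier warms up to 1.0 and then
+ # decays as 0.1 + 0.9 * cosine_decay, settling at 10% of the peak lr at the end of training
+ # instead of 0.
+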
150
+
151
+ def get_cosine_with_hard_restarts_schedule_with_warmup(
152
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
153
+ ):
154
+ """
155
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
156
+ initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
157
+ linearly between 0 and the initial lr set in the optimizer.
158
+
159
+ Args:
160
+ optimizer ([`~torch.optim.Optimizer`]):
161
+ The optimizer for which to schedule the learning rate.
162
+ num_warmup_steps (`int`):
163
+ The number of steps for the warmup phase.
164
+ num_training_steps (`int`):
165
+ The total number of training steps.
166
+ num_cycles (`int`, *optional*, defaults to 1):
167
+ The number of hard restarts to use.
168
+ last_epoch (`int`, *optional*, defaults to -1):
169
+ The index of the last epoch when resuming training.
170
+
171
+ Return:
172
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
173
+ """
174
+
175
+ def lr_lambda(current_step):
176
+ if current_step < num_warmup_steps:
177
+ return float(current_step) / float(max(1, num_warmup_steps))
178
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
179
+ if progress >= 1.0:
180
+ return 0.0
181
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
182
+
183
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
184
+
185
+
186
+ def get_polynomial_decay_schedule_with_warmup(
187
+ optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
188
+ ):
189
+ """
190
+ Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
191
+ optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
192
+ initial lr set in the optimizer.
193
+
194
+ Args:
195
+ optimizer ([`~torch.optim.Optimizer`]):
196
+ The optimizer for which to schedule the learning rate.
197
+ num_warmup_steps (`int`):
198
+ The number of steps for the warmup phase.
199
+ num_training_steps (`int`):
200
+ The total number of training steps.
201
+ lr_end (`float`, *optional*, defaults to 1e-7):
202
+ The end LR.
203
+ power (`float`, *optional*, defaults to 1.0):
204
+ Power factor.
205
+ last_epoch (`int`, *optional*, defaults to -1):
206
+ The index of the last epoch when resuming training.
207
+
208
+ Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
209
+ implementation at
210
+ https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
211
+
212
+ Return:
213
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
214
+
215
+ """
216
+
217
+ lr_init = optimizer.defaults["lr"]
218
+ if not (lr_init > lr_end):
219
+ raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")
220
+
221
+ def lr_lambda(current_step: int):
222
+ if current_step < num_warmup_steps:
223
+ return float(current_step) / float(max(1, num_warmup_steps))
224
+ elif current_step > num_training_steps:
225
+ return lr_end / lr_init # as LambdaLR multiplies by lr_init
226
+ else:
227
+ lr_range = lr_init - lr_end
228
+ decay_steps = num_training_steps - num_warmup_steps
229
+ pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
230
+ decay = lr_range * pct_remaining**power + lr_end
231
+ return decay / lr_init # as LambdaLR multiplies by lr_init
232
+
233
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
234
+
235
+
236
+ TYPE_TO_SCHEDULER_FUNCTION = {
237
+ SchedulerType.LINEAR: get_linear_schedule_with_warmup,
238
+ SchedulerType.COSINE: get_cosine_schedule_with_warmup,
239
+ SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
240
+ SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
241
+ SchedulerType.CONSTANT: get_constant_schedule,
242
+ SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
243
+ }
244
+
245
+
246
+ def get_scheduler(
247
+ name: Union[str, SchedulerType],
248
+ optimizer: Optimizer,
249
+ num_warmup_steps: Optional[int] = None,
250
+ num_training_steps: Optional[int] = None,
251
+ num_cycles: int = 1,
252
+ power: float = 1.0,
253
+ min_lr_scale: float = 0.0
254
+ ):
255
+ """
256
+ Unified API to get any scheduler from its name.
257
+
258
+ Args:
259
+ name (`str` or `SchedulerType`):
260
+ The name of the scheduler to use.
261
+ optimizer (`torch.optim.Optimizer`):
262
+ The optimizer that will be used during training.
263
+ num_warmup_steps (`int`, *optional*):
264
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
265
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
266
+ num_training_steps (`int`, *optional*):
267
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
268
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
269
+ num_cycles (`int`, *optional*):
270
+ The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
271
+ power (`float`, *optional*, defaults to 1.0):
272
+ Power factor. See the `POLYNOMIAL` scheduler.
273
+ min_lr_scale (`float`, *optional*, defaults to 0.0):
274
+ The final lr floor (as a fraction of the initial lr), forwarded to the `COSINE` scheduler.
275
+ """
276
+ name = SchedulerType(name)
277
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
278
+ if name == SchedulerType.CONSTANT:
279
+ return schedule_func(optimizer)
280
+
281
+ # All other schedulers require `num_warmup_steps`
282
+ if num_warmup_steps is None:
283
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
284
+
285
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
286
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
287
+
288
+ # All other schedulers require `num_training_steps`
289
+ if num_training_steps is None:
290
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
291
+
292
+ if name == SchedulerType.COSINE_WITH_RESTARTS:
293
+ return schedule_func(
294
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
295
+ )
296
+
297
+ if name == SchedulerType.POLYNOMIAL:
298
+ return schedule_func(
299
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
300
+ )
301
+
302
+ if name == SchedulerType.COSINE:
+ return schedule_func(
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, min_lr_scale=min_lr_scale
+ )
+
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
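+
+ # Usage sketch (the optimizer, model, and step counts are illustrative assumptions):
+ # import torch
+ # optim = torch.optim.AdamW(model.parameters(), lr=1e-4)
+ # sched = get_scheduler("cosine", optim, num_warmup_steps=1000, num_training_steps=100000, min_lr_scale=0.1)
+ # for _ in range(100000):
+ #     ...  # forward / backward
+ #     optim.step()
+ #     sched.step()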
models/misc.py ADDED
@@ -0,0 +1,53 @@
1
+ from omegaconf import OmegaConf
2
+ import torch
3
+ from typing import (
4
+ Any,
5
+ Callable,
6
+ Dict,
7
+ Iterable,
8
+ List,
9
+ NamedTuple,
10
+ NewType,
11
+ Optional,
12
+ Sized,
13
+ Tuple,
14
+ Type,
15
+ TypeVar,
16
+ Union,
17
+ )
18
+ try:
19
+ from typing import Literal
20
+ except ImportError:
21
+ from typing_extensions import Literal
22
+
23
+ # Tensor dtype
24
+ # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
25
+ from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
26
+
27
+ # Config type
28
+ from omegaconf import DictConfig
29
+
30
+ # PyTorch Tensor type
31
+ from torch import Tensor
32
+
33
+ # Runtime type checking decorator
34
+ from typeguard import typechecked as typechecker
35
+
36
+
37
+ def broadcast(tensor, src=0):
38
+ if not _distributed_available():
39
+ return tensor
40
+ else:
41
+ torch.distributed.broadcast(tensor, src=src)
42
+ return tensor
43
+
44
+ def _distributed_available():
45
+ return torch.distributed.is_available() and torch.distributed.is_initialized()
46
+
47
+ def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
48
+ # Drop the stray '--local-rank' key injected during multi-node training (comment originally by Xavier; root cause unknown).
49
+ if cfg is not None and '--local-rank' in cfg:
50
+ del cfg['--local-rank']
52
+ scfg = OmegaConf.structured(fields(**cfg))
53
+ return scfg
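+
+ # Usage sketch (the dataclass and YAML path are hypothetical):
+ # from dataclasses import dataclass
+ # @dataclass
+ # class TrainCfg:
+ #     lr: float = 1e-4
+ # cfg = parse_structured(TrainCfg, OmegaConf.to_container(OmegaConf.load("config.yaml")))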
models/modeling_llada.py ADDED
@@ -0,0 +1,1500 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import math
5
+ import sys
6
+ from abc import abstractmethod
7
+ from collections import defaultdict
8
+ from functools import partial
9
+ from typing import (
10
+ Callable,
11
+ Dict,
12
+ Iterable,
13
+ List,
14
+ NamedTuple,
15
+ Optional,
16
+ Sequence,
17
+ Set,
18
+ Tuple,
19
+ cast,
20
+ )
21
+ from dataclasses import fields
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.backends.cuda
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ from torch import einsum
29
+ from transformers import PreTrainedModel
30
+ from transformers.modeling_outputs import CausalLMOutputWithPast
31
+ from transformers.models.auto import AutoModel
32
+ from transformers.cache_utils import Cache
33
+
34
+ from .configuration_llada import (
35
+ LLaDAConfig,
36
+ StrEnum,
37
+ InitFnType,
38
+ ActivationType,
39
+ BlockType,
40
+ LayerNormType,
41
+ ModelConfig,
42
+ ActivationCheckpointingStrategy,
43
+ )
44
+
45
+ if sys.version_info.minor > 8:
46
+ from collections.abc import MutableMapping
47
+ elif sys.version_info.minor == 8:
48
+ from typing import MutableMapping
49
+ else:
50
+ raise SystemExit("This script supports Python 3.8 or higher")
51
+
52
+ __all__ = [
53
+ "LayerNormBase",
54
+ "LayerNorm",
55
+ "RMSLayerNorm",
56
+ "GemmaRMSLayerNorm",
57
+ "RotaryEmbedding",
58
+ "Activation",
59
+ "GELU",
60
+ "ReLU",
61
+ "SwiGLU",
62
+ "LLaDABlock",
63
+ "LLaDASequentialBlock",
64
+ "LLaDAModel",
65
+ "LLaDAOutput",
66
+ "LLaDAGenerateOutput",
67
+ ]
68
+
69
+
70
+ log = logging.getLogger(__name__)
71
+
72
+
73
+ class ModuleType(StrEnum):
74
+ in_module = "in"
75
+ out_module = "out"
76
+ emb = "emb"
77
+ final_out = "final_out"
78
+
79
+
80
+ def init_weights(
81
+ config: ModelConfig,
82
+ module: Union[nn.Linear, nn.Embedding],
83
+ d: Optional[int] = None,
84
+ layer_id: Optional[int] = None,
85
+ std_factor: float = 1.0,
86
+ type_of_module: Optional[ModuleType] = None,
87
+ ) -> None:
88
+ """
89
+ Initialize weights of a linear or embedding module.
90
+
91
+ :param config: The model config.
92
+ :param module: The linear or embedding submodule to initialize.
93
+ :param d: The effective input dimensionality of the weights. This could be smaller than the actual dimensions
94
+ for fused layers.
95
+ :param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by
96
+ ``1 / sqrt(2 * (layer_id + 1))``.
97
+ """
98
+ d = d if d is not None else config.d_model
99
+ if config.init_fn == InitFnType.normal:
100
+ std = config.init_std * std_factor
101
+ if config.init_cutoff_factor is not None:
102
+ cutoff_value = config.init_cutoff_factor * std
103
+ nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
104
+ else:
105
+ nn.init.normal_(module.weight, mean=0.0, std=std)
106
+ elif config.init_fn == InitFnType.mitchell:
107
+ std = std_factor / math.sqrt(d)
108
+ if layer_id is not None:
109
+ std = std / math.sqrt(2 * (layer_id + 1))
110
+ nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
111
+ elif config.init_fn == InitFnType.kaiming_normal:
112
+ nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
113
+ elif config.init_fn == InitFnType.fan_in:
114
+ std = std_factor / math.sqrt(d)
115
+ nn.init.normal_(module.weight, mean=0.0, std=std)
116
+ elif config.init_fn == InitFnType.full_megatron:
117
+ if type_of_module is None:
118
+ raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")
119
+
120
+ cutoff_factor = config.init_cutoff_factor
121
+ if cutoff_factor is None:
122
+ cutoff_factor = 3
123
+
124
+ if type_of_module == ModuleType.in_module:
125
+ # for att_proj (same as QKV), ff_proj
126
+ std = config.init_std
127
+ elif type_of_module == ModuleType.out_module:
128
+ # for attn_out, ff_out
129
+ std = config.init_std / math.sqrt(2.0 * config.n_layers)
130
+ elif type_of_module == ModuleType.emb:
131
+ # positional embeddings (wpe)
132
+ # token embeddings (wte)
133
+ std = config.init_std
134
+ elif type_of_module == ModuleType.final_out:
135
+ # final output (ff_out)
136
+ std = config.d_model**-0.5
137
+ else:
138
+ raise RuntimeError(f"Unknown module type '{type_of_module}'")
139
+ nn.init.trunc_normal_(
140
+ module.weight,
141
+ mean=0.0,
142
+ std=std,
143
+ a=-cutoff_factor * std,
144
+ b=cutoff_factor * std,
145
+ )
146
+ else:
147
+ raise NotImplementedError(config.init_fn)
148
+
149
+ if isinstance(module, nn.Linear):
150
+ if module.bias is not None:
151
+ nn.init.zeros_(module.bias)
152
+
153
+ if config.init_fn == InitFnType.normal and getattr(module, "_is_residual", False):
154
+ with torch.no_grad():
155
+ module.weight.div_(math.sqrt(2 * config.n_layers))
156
+
157
+
158
+ def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
159
+ """
160
+ Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
161
+ is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
162
+ """
163
+ if check_neg_inf:
164
+ x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
165
+ if check_pos_inf:
166
+ x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
167
+
168
+
169
+ def activation_checkpoint_function(cfg: ModelConfig):
170
+ preserve_rng_state = (
171
+ (cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and (cfg.residual_dropout == 0.0)
172
+ )
173
+ from torch.utils.checkpoint import checkpoint
174
+
175
+ return partial(
176
+ checkpoint,
177
+ preserve_rng_state=preserve_rng_state,
178
+ use_reentrant=False,
179
+ )
180
+
181
+
182
+ class BufferCache(dict, MutableMapping[str, torch.Tensor]):
183
+ """
184
+ Cache for attention biases and other things that would normally be stored as buffers.
185
+ We avoid using buffers because we've run into various issues doing so with FSDP.
186
+ In general it appears the way FSDP handles buffers is not well-defined.
187
+ It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid
188
+ since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into
189
+ NaNs when they're synchronized due to casting or some other issue.
190
+ """
191
+
192
+
193
+ def _non_meta_init_device(config: ModelConfig) -> torch.device:
194
+ if config.init_device is not None and config.init_device != "meta":
195
+ return torch.device(config.init_device)
196
+ else:
197
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
198
+
199
+
200
+ class Dropout(nn.Dropout):
201
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
202
+ if self.p == 0.0:
203
+ return input
204
+ else:
205
+ return F.dropout(input, self.p, self.training, self.inplace)
206
+
207
+
208
+ class LayerNormBase(nn.Module):
209
+ def __init__(
210
+ self,
211
+ config: ModelConfig,
212
+ *,
213
+ size: Optional[int] = None,
214
+ elementwise_affine: Optional[bool] = True,
215
+ eps: float = 1e-05,
216
+ ):
217
+ super().__init__()
218
+ self.config = config
219
+ self.eps = eps
220
+ self.normalized_shape = (size or config.d_model,)
221
+ if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
222
+ self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
223
+ use_bias = self.config.bias_for_layer_norm
224
+ if use_bias is None:
225
+ use_bias = self.config.include_bias
226
+ if use_bias:
227
+ self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
228
+ else:
229
+ self.register_parameter("bias", None)
230
+ else:
231
+ self.register_parameter("bias", None)
232
+ self.register_parameter("weight", None)
233
+
234
+ @abstractmethod
235
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
236
+ raise NotImplementedError
237
+
238
+ @classmethod
239
+ def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> LayerNormBase:
240
+ if config.layer_norm_type == LayerNormType.default:
241
+ return LayerNorm(config, size=size, low_precision=False, **kwargs)
242
+ elif config.layer_norm_type == LayerNormType.low_precision:
243
+ return LayerNorm(config, size=size, low_precision=True, **kwargs)
244
+ elif config.layer_norm_type == LayerNormType.rms:
245
+ return RMSLayerNorm(config, size=size, **kwargs)
246
+ elif config.layer_norm_type == LayerNormType.gemma_rms:
247
+ return GemmaRMSLayerNorm(config, size=size, **kwargs)
248
+ else:
249
+ raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")
250
+
251
+ def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
252
+ # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
253
+ # `is_autocast_cpu_enabled()` for CPU autocast.
254
+ # See https://github.com/pytorch/pytorch/issues/110966.
255
+ if tensor.device.type == "cuda" and torch.is_autocast_enabled():
256
+ return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_gpu_dtype())
257
+ elif tensor.device.type == "cpu" and torch.is_autocast_cpu_enabled():
258
+ return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype())
259
+ else:
260
+ return tensor
261
+
262
+ def reset_parameters(self):
263
+ if self.weight is not None:
264
+ torch.nn.init.ones_(self.weight) # type: ignore
265
+ if self.bias is not None:
266
+ torch.nn.init.zeros_(self.bias) # type: ignore
267
+
268
+
269
+ class LayerNorm(LayerNormBase):
270
+ """
271
+ The default :class:`LayerNorm` implementation which can optionally run in low precision.
272
+ """
273
+
274
+ def __init__(
275
+ self,
276
+ config: ModelConfig,
277
+ size: Optional[int] = None,
278
+ low_precision: bool = False,
279
+ elementwise_affine: Optional[bool] = None,
280
+ eps: float = 1e-05,
281
+ ):
282
+ super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
283
+ self.low_precision = low_precision
284
+
285
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
286
+ if self.low_precision:
287
+ module_device = x.device
288
+ downcast_x = self._cast_if_autocast_enabled(x)
289
+ downcast_weight = (
290
+ self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
291
+ )
292
+ downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
293
+ with torch.autocast(enabled=False, device_type=module_device.type):
294
+ return F.layer_norm(
295
+ downcast_x, self.normalized_shape, weight=downcast_weight, bias=downcast_bias, eps=self.eps
296
+ )
297
+ else:
298
+ return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
299
+
300
+
301
+ class RMSLayerNorm(LayerNormBase):
302
+ """
303
+ RMS layer norm, a simplified :class:`LayerNorm` implementation
304
+ """
305
+
306
+ def __init__(
307
+ self,
308
+ config: ModelConfig,
309
+ size: Optional[int] = None,
310
+ elementwise_affine: Optional[bool] = None,
311
+ eps: float = 1e-5,
312
+ ):
313
+ super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)
314
+
315
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
316
+ with torch.autocast(enabled=False, device_type=x.device.type):
317
+ og_dtype = x.dtype
318
+ x = x.to(torch.float32)
319
+ variance = x.pow(2).mean(-1, keepdim=True)
320
+ x = x * torch.rsqrt(variance + self.eps)
321
+ x = x.to(og_dtype)
322
+
323
+ if self.weight is not None:
324
+ if self.bias is not None:
325
+ return self.weight * x + self.bias
326
+ else:
327
+ return self.weight * x
328
+ else:
329
+ return x
330
+
331
+
332
+ class GemmaRMSLayerNorm(LayerNormBase):
333
+ """
334
+ Gemma RMS layer norm, a simplified :class:`LayerNorm` implementation
335
+ """
336
+
337
+ def __init__(
338
+ self,
339
+ config: ModelConfig,
340
+ size: Optional[int] = None,
341
+ elementwise_affine: Optional[bool] = None,
342
+ eps: float = 1e-5,
343
+ ):
344
+ super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)
345
+
346
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
347
+ with torch.autocast(enabled=False, device_type=x.device.type):
348
+ og_dtype = x.dtype
349
+ x = x.to(torch.float32)
350
+ variance = x.pow(2).mean(-1, keepdim=True)
351
+ x = x * torch.rsqrt(variance + self.eps)
352
+ x = x.to(og_dtype)
353
+
354
+ if self.weight is not None:
355
+ if self.bias is not None:
356
+ return x * (1 + self.weight) + self.bias
357
+ else:
358
+ return x * (1 + self.weight)
359
+ else:
360
+ return x
361
+
362
+
363
+ class RotaryEmbedding(nn.Module):
364
+ """
365
+ [Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).
366
+ """
367
+
368
+ def __init__(self, config: ModelConfig, cache: BufferCache):
369
+ super().__init__()
370
+ self.config = config
371
+ self.__cache = cache
372
+ # Warm up cache.
373
+ self.rope_theta = config.rope_theta
374
+ self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config))
375
+
376
+ def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
377
+ if (
378
+ (pos_sin := self.__cache.get("rope_pos_sin")) is not None
379
+ and (pos_cos := self.__cache.get("rope_pos_cos")) is not None
380
+ and pos_sin.shape[-2] >= seq_len
381
+ and pos_cos.shape[-2] >= seq_len
382
+ ):
383
+ if pos_sin.device != device:
384
+ pos_sin = pos_sin.to(device)
385
+ self.__cache["rope_pos_sin"] = pos_sin
386
+ if pos_cos.device != device:
387
+ pos_cos = pos_cos.to(device)
388
+ self.__cache["rope_pos_cos"] = pos_cos
389
+ return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]
390
+
391
+ with torch.autocast(device.type, enabled=False):
392
+ dim = self.config.d_model // self.config.n_heads
393
+ inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
394
+ seq = torch.arange(seq_len, device=device, dtype=torch.float)
395
+ freqs = einsum("i , j -> i j", seq, inv_freq)
396
+ positions = torch.cat((freqs, freqs), dim=-1)
397
+ pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
398
+ self.__cache["rope_pos_sin"] = pos_sin
399
+ self.__cache["rope_pos_cos"] = pos_cos
400
+ return pos_sin, pos_cos
401
+
402
+ def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
403
+ B, nh, T, hs = x.size()
404
+ x = x.view(B, nh, T, 2, hs // 2)
405
+ x1, x2 = x.unbind(dim=-2)
406
+ return torch.cat((-x2, x1), dim=-1)
407
+
408
+ def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
409
+ return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)
410
+
411
+ def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
412
+ if self.config.rope_full_precision:
413
+ q_, k_ = q.float(), k.float()
414
+ else:
415
+ q_, k_ = q, k
416
+
417
+ with torch.autocast(q.device.type, enabled=False):
418
+ query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None
419
+ pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
420
+ pos_sin = pos_sin.type_as(q_)
421
+ pos_cos = pos_cos.type_as(q_)
422
+ q_ = self.apply_rotary_pos_emb(
423
+ pos_sin[:, :, key_len - query_len : key_len, :],
424
+ pos_cos[:, :, key_len - query_len : key_len, :],
425
+ q_,
426
+ )
427
+ k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
428
+ return q_.type_as(q), k_.type_as(k)
429
+
430
+
431
+ class Activation(nn.Module):
432
+ def __init__(self, config: ModelConfig):
433
+ super().__init__()
434
+ self.config = config
435
+
436
+ @abstractmethod
437
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
438
+ raise NotImplementedError
439
+
440
+ @property
441
+ @abstractmethod
442
+ def output_multiplier(self) -> float:
443
+ raise NotImplementedError
444
+
445
+ @classmethod
446
+ def build(cls, config: ModelConfig) -> Activation:
447
+ if config.activation_type == ActivationType.gelu:
448
+ return cast(Activation, GELU(approximate="none"))
449
+ elif config.activation_type == ActivationType.relu:
450
+ return cast(Activation, ReLU(inplace=False))
451
+ elif config.activation_type == ActivationType.silu:
452
+ return cast(Activation, SiLU(inplace=False))
453
+ elif config.activation_type == ActivationType.swiglu:
454
+ return SwiGLU(config)
455
+ else:
456
+ raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")
457
+
458
+
459
+ class GELU(nn.GELU):
460
+ @property
461
+ def output_multiplier(self) -> float:
462
+ return 1.0
463
+
464
+
465
+ class ReLU(nn.ReLU):
466
+ @property
467
+ def output_multiplier(self) -> float:
468
+ return 1.0
469
+
470
+ class SiLU(nn.SiLU):
471
+ @property
472
+ def output_multiplier(self) -> float:
473
+ return 1.0
474
+
475
+ class SwiGLU(Activation):
476
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
477
+ x, gate = x.chunk(2, dim=-1)
478
+ return F.silu(gate) * x
479
+
480
+ @property
481
+ def output_multiplier(self) -> float:
482
+ return 0.5
483
+
484
+
485
+ def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
486
+ att_bias = torch.triu(
487
+ torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
488
+ diagonal=1,
489
+ )
490
+ att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
491
+ return att_bias.view(1, 1, seq_len, seq_len) # type: ignore
492
+
493
+
494
+ def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
495
+ if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len:
496
+ if causal_bias.device != device:
497
+ causal_bias = causal_bias.to(device)
498
+ cache["causal_attention_bias"] = causal_bias
499
+ return causal_bias
500
+ with torch.autocast(device.type, enabled=False):
501
+ causal_bias = causal_attention_bias(seq_len, device)
502
+ cache["causal_attention_bias"] = causal_bias
503
+ return causal_bias
504
+
505
+
506
+ def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
507
+ alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
508
+
509
+ # shape: (1, 1, seq_len, seq_len)
510
+ alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
511
+ alibi_bias.abs_().mul_(-1)
512
+
513
+ # shape: (n_heads,)
514
+ m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
515
+ m.mul_(config.alibi_bias_max / config.n_heads)
516
+
517
+ # shape: (1, n_heads, seq_len, seq_len)
518
+ return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1))) # type: ignore
519
+
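+ # NOTE: the returned bias is -|i - j| scaled per head by 2 ** -(h * alibi_bias_max / n_heads)
+ # for head index h = 1..n_heads, so later heads receive a geometrically smaller distance
+ # penalty and can attend further back.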
520
+
521
+ class LLaDABlock(nn.Module):
522
+ """
523
+ A base class for transformer block implementations.
524
+ """
525
+
526
+ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
527
+ super().__init__()
528
+ self.layer_id = layer_id
529
+ self.config = config
530
+ self.hidden_size = (
531
+ config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
532
+ )
533
+ self.__cache = cache
534
+ assert config.d_model % config.n_heads == 0
535
+
536
+ self._activation_checkpoint_fn = None
537
+
538
+ # Dropout.
539
+ self.dropout = Dropout(config.residual_dropout)
540
+
541
+ # Layer norms.
542
+ self.k_norm: Optional[LayerNormBase] = None
543
+ self.q_norm: Optional[LayerNormBase] = None
544
+ if config.attention_layer_norm:
545
+ self.k_norm = LayerNormBase.build(
546
+ config,
547
+ size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
548
+ elementwise_affine=config.attention_layer_norm_with_affine,
549
+ )
550
+ self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)
551
+
552
+ # Activation function.
553
+ self.act = Activation.build(config)
554
+ assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
555
+
556
+ # Attention output projection.
557
+ self.attn_out = nn.Linear(
558
+ config.d_model, config.d_model, bias=config.include_bias, device=config.init_device
559
+ )
560
+
561
+ # Feed-forward output projection.
562
+ self.ff_out = nn.Linear(
563
+ int(self.act.output_multiplier * self.hidden_size),
564
+ config.d_model,
565
+ bias=config.include_bias,
566
+ device=config.init_device,
567
+ )
568
+ self.ff_out._is_residual = True # type: ignore
569
+
570
+ # Rotary embeddings.
571
+ if self.config.rope:
572
+ self.rotary_emb = RotaryEmbedding(config, self.__cache)
573
+
574
+ self.flash_attn_func = None
575
+ if config.flash_attention:
576
+ try:
577
+ from flash_attn import flash_attn_func # type: ignore
578
+
579
+ self.flash_attn_func = flash_attn_func
580
+ except ModuleNotFoundError:
581
+ pass
582
+
583
+ def reset_parameters(self):
584
+ if self.k_norm is not None:
585
+ self.k_norm.reset_parameters()
586
+ if self.q_norm is not None:
587
+ self.q_norm.reset_parameters()
588
+ init_weights(
589
+ self.config,
590
+ self.attn_out,
591
+ d=self.config.d_model,
592
+ layer_id=self.layer_id,
593
+ type_of_module=ModuleType.out_module,
594
+ )
595
+ init_weights(
596
+ self.config,
597
+ self.ff_out,
598
+ d=self.ff_out.in_features,
599
+ layer_id=self.layer_id,
600
+ type_of_module=ModuleType.out_module,
601
+ )
602
+
603
+ def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
604
+ if strategy == ActivationCheckpointingStrategy.fine_grained:
605
+ self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
606
+ else:
607
+ self._activation_checkpoint_fn = None
608
+
609
+ @classmethod
610
+ def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
611
+ target_dtype = input_dtype
612
+ # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
613
+ # `is_autocast_cpu_enabled()` for CPU autocast.
614
+ # See https://github.com/pytorch/pytorch/issues/110966.
615
+ if bias.device.type == "cuda" and torch.is_autocast_enabled():
616
+ target_dtype = torch.get_autocast_gpu_dtype()
617
+ elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
618
+ target_dtype = torch.get_autocast_cpu_dtype()
619
+ if bias.dtype != target_dtype:
620
+ bias = bias.to(target_dtype)
621
+ ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
622
+ return bias
623
+
624
+ def _scaled_dot_product_attention(
625
+ self,
626
+ q: torch.Tensor,
627
+ k: torch.Tensor,
628
+ v: torch.Tensor,
629
+ attn_mask: Optional[torch.Tensor] = None,
630
+ dropout_p: float = 0.0,
631
+ is_causal: bool = False,
632
+ ) -> torch.Tensor:
633
+ """
634
+ Computes scaled dot product attention on query, key and value tensors, using an optional
635
+ attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.
636
+ """
637
+ if self.flash_attn_func is not None and attn_mask is None:
638
+ r = self.flash_attn_func(
639
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=False
640
+ )
641
+ return r.transpose(1, 2)
642
+ else:
643
+ # torch's sdpa doesn't support GQA, so we're doing this
644
+ assert k.size(1) == v.size(1)
645
+ num_kv_heads = k.size(1)
646
+ num_q_heads = q.size(1)
647
+ if num_q_heads != num_kv_heads:
648
+ assert num_q_heads % num_kv_heads == 0
649
+ k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
650
+ v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
651
+
652
+ # Modify: MDM set causal to False, and with no attn_mask.
653
+ return F.scaled_dot_product_attention(
654
+ q,
655
+ k,
656
+ v,
657
+ attn_mask=None,
658
+ dropout_p=dropout_p,
659
+ is_causal=False,
660
+ )
661
+
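+ # NOTE: the repeat_interleave above expands grouped KV heads so that
+ # F.scaled_dot_product_attention sees matching head counts; e.g. 32 query heads with
+ # 8 KV heads repeats each key/value head 4 times along the head dimension.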
662
+ def attention(
663
+ self,
664
+ q: torch.Tensor,
665
+ k: torch.Tensor,
666
+ v: torch.Tensor,
667
+ attention_bias: Optional[torch.Tensor] = None,
668
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
669
+ use_cache: bool = False,
670
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
671
+ B, T, C = q.size() # batch size, sequence length, d_model
672
+ dtype = k.dtype
673
+
674
+ # Optionally apply layer norm to keys and queries.
675
+ if self.q_norm is not None and self.k_norm is not None:
676
+ q = self.q_norm(q).to(dtype=dtype)
677
+ k = self.k_norm(k).to(dtype=dtype)
678
+
679
+ # Move head forward to be next to the batch dim.
680
+ # shape: (B, nh, T, hs)
681
+ q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
682
+ # shape: (B, n_kv_h, T, hs)
683
+ k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
684
+ # shape: (B, n_kv_h, T, hs)
685
+ v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
686
+
687
+ if layer_past is not None:
688
+ past_key, past_value = layer_past
689
+ k = torch.cat((past_key, k), dim=-2)
690
+ v = torch.cat((past_value, v), dim=-2)
691
+
692
+ present = (k, v) if use_cache else None
693
+ query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
694
+
695
+ if self.config.rope:
696
+ # Apply rotary embeddings.
697
+ q, k = self.rotary_emb(q, k)
698
+
699
+ if attention_bias is not None:
700
+ # Resize and cast attention bias.
701
+ # The current dtype of the attention bias might not match the dtype that the SDP attn function will
702
+ # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding
703
+ # as down-casting the attention bias to the autocast precision will result in -infs, which will
704
+ # cause the SDP attn function to produce NaNs.
705
+ attention_bias = self._cast_attn_bias(
706
+ attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype
707
+ )
708
+
709
+ # Get the attention scores.
710
+ # shape: (B, nh, T, hs)
711
+ att = self._scaled_dot_product_attention(
712
+ q,
713
+ k,
714
+ v,
715
+ attn_mask=None,
716
+ dropout_p=0.0 if not self.training else self.config.attention_dropout,
717
+ is_causal=False,
718
+ )
719
+
720
+ # Re-assemble all head outputs side-by-side.
721
+ att = att.transpose(1, 2).contiguous().view(B, T, C)
722
+
723
+ # Apply output projection.
724
+ return self.attn_out(att), present
725
+
726
+ @abstractmethod
727
+ def forward(
728
+ self,
729
+ x: torch.Tensor,
730
+ attention_bias: Optional[torch.FloatTensor] = None,
731
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
732
+ use_cache: bool = False,
733
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
734
+ raise NotImplementedError
735
+
736
+ @classmethod
737
+ def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> LLaDABlock:
738
+ if config.block_type == BlockType.sequential:
739
+ return LLaDASequentialBlock(layer_id, config, cache)
740
+ elif config.block_type == BlockType.llama:
741
+ return LLaDALlamaBlock(layer_id, config, cache)
742
+ else:
743
+ raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
744
+
745
+
746
+ class LLaDASequentialBlock(LLaDABlock):
747
+ """
748
+ This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
749
+ (plus another skip connection).
750
+ """
751
+
752
+ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
753
+ super().__init__(layer_id, config, cache)
754
+ # Layer norms.
755
+ self.attn_norm = LayerNorm.build(config)
756
+ self.ff_norm = LayerNorm.build(config)
757
+ # Attention input projection. Projects x -> (q, k, v)
758
+ head_dim = config.d_model // config.n_heads
759
+ self.fused_dims = (
760
+ config.d_model,
761
+ config.effective_n_kv_heads * head_dim,
762
+ config.effective_n_kv_heads * head_dim,
763
+ )
764
+ self.att_proj = nn.Linear(
765
+ config.d_model, sum(self.fused_dims), bias=config.include_bias | config.include_qkv_bias, device=config.init_device
766
+ )
767
+ # Feed-forward input projection.
768
+ self.ff_proj = nn.Linear(
769
+ config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
770
+ )
771
+
772
+ def reset_parameters(self):
773
+ super().reset_parameters()
774
+ self.attn_norm.reset_parameters()
775
+ self.ff_norm.reset_parameters()
776
+ # NOTE: the standard deviation for these weights does not depend on the layer.
777
+ init_weights(
778
+ self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
779
+ )
780
+ init_weights(
781
+ self.config, self.ff_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
782
+ )
783
+
784
+ def forward(
785
+ self,
786
+ x: torch.Tensor,
787
+ attention_bias: Optional[torch.Tensor] = None,
788
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
789
+ use_cache: bool = False,
790
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
791
+ # Get query, key, value projections.
792
+ # shape:
793
+ # - for regular attn q, k, v: (batch_size, seq_len, d_model)
794
+ # - for multi-query attn q: (batch_size, seq_len, d_model)
795
+ # k, v: (batch_size, seq_len, d_model // n_heads)
796
+ # - for group query attn q: (batch_size, seq_len, d_model)
797
+ # k, v: (batch_size, seq_len, d_model // n_kv_heads)
798
+ if self._activation_checkpoint_fn is not None:
799
+ q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split(
800
+ self.fused_dims, dim=-1
801
+ )
802
+ else:
803
+ q, k, v = self.att_proj(self.attn_norm(x)).split(self.fused_dims, dim=-1)
804
+
805
+ # Get attention scores.
806
+ if self._activation_checkpoint_fn is not None:
807
+ att, cache = self._activation_checkpoint_fn( # type: ignore
808
+ self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
809
+ )
810
+ else:
811
+ att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
812
+
813
+ # Add attention scores.
814
+ # shape: (B, T, C)
815
+ x = x + self.dropout(att)
816
+
817
+ # Add feed-forward projection.
818
+ # shape: (batch_size, seq_len, d_model)
819
+ og_x = x
820
+ if self._activation_checkpoint_fn is not None:
821
+ x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
822
+ else:
823
+ x = self.ff_norm(x)
824
+ x = self.ff_proj(x)
825
+ if self._activation_checkpoint_fn is not None:
826
+ x = self._activation_checkpoint_fn(self.act, x) # type: ignore
827
+ else:
828
+ x = self.act(x)
829
+ x = self.ff_out(x)
830
+ x = self.dropout(x)
831
+ x = og_x + x
832
+
833
+ return x, cache
834
+
835
+
836
+ class LLaDALlamaBlock(LLaDABlock):
837
+ """
838
+ This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
839
+ (plus another skip connection). This block is similar to `LLaDASequentialBlock`
840
+ but some operations have slightly different implementations to imitate the
841
+ behavior of Llama.
842
+ """
843
+
844
+ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
845
+ super().__init__(layer_id, config, cache)
846
+ # Layer norms.
847
+ self.attn_norm = LayerNorm.build(config)
848
+ self.ff_norm = LayerNorm.build(config)
849
+ self.__cache = cache
850
+
851
+ # Attention input projection. Projects x -> (q, k, v)
852
+ head_dim = config.d_model // config.n_heads
853
+ q_proj_out_dim = config.d_model
854
+ k_proj_out_dim = config.effective_n_kv_heads * head_dim
855
+ v_proj_out_dim = config.effective_n_kv_heads * head_dim
856
+ self.q_proj = nn.Linear(
857
+ config.d_model, q_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
858
+ )
859
+ self.k_proj = nn.Linear(
860
+ config.d_model, k_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
861
+ )
862
+ self.v_proj = nn.Linear(
863
+ config.d_model, v_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
864
+ )
865
+
866
+ # Feed-forward input projection.
867
+ self.ff_proj = nn.Linear(
868
+ config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
869
+ )
870
+ # newly added: up projection for the LLaMA-style gated MLP
871
+ self.up_proj = nn.Linear(
872
+ config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
873
+ )
874
+
875
+ def reset_parameters(self):
876
+ super().reset_parameters()
877
+ self.attn_norm.reset_parameters()
878
+ self.ff_norm.reset_parameters()
879
+ # NOTE: the standard deviation for these weights does not depend on the layer.
880
+ init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None)
881
+ init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None)
882
+ init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None)
883
+ init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None)
884
+ init_weights(self.config, self.up_proj, d=self.config.d_model, layer_id=None) # newly added
885
+
886
+ def forward(
887
+ self,
888
+ x: torch.Tensor,
889
+ attention_bias: Optional[torch.Tensor] = None,
890
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
891
+ use_cache: bool = False,
892
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
893
+ # Get query, key, value projections.
894
+ # shape:
895
+ # - for regular attn q, k, v: (batch_size, seq_len, d_model)
896
+ # - for multi-query attn q: (batch_size, seq_len, d_model)
897
+ # k, v: (batch_size, seq_len, d_model // n_heads)
898
+ # - for group query attn q: (batch_size, seq_len, d_model)
899
+ # k, v: (batch_size, seq_len, d_model // n_kv_heads)
900
+ x_normed = self.attn_norm(x)
901
+ q = self.q_proj(x_normed)
902
+ k = self.k_proj(x_normed)
903
+ v = self.v_proj(x_normed)
904
+
905
+ # Get attention scores.
906
+ if self._activation_checkpoint_fn is not None:
907
+ att, cache = self._activation_checkpoint_fn( # type: ignore
908
+ self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
909
+ )
910
+ else:
911
+ att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
912
+
913
+ # Add attention scores.
914
+ # shape: (B, T, C)
915
+ x = x + self.dropout(att)
916
+
917
+ # Add feed-forward projection.
918
+ # shape: (batch_size, seq_len, d_model)
919
+ og_x = x
920
+ if self._activation_checkpoint_fn is not None:
921
+ x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
922
+ else:
923
+ x = self.ff_norm(x)
924
+ x, x_up = self.ff_proj(x), self.up_proj(x) # newly added: gate and up projections
925
+ if self._activation_checkpoint_fn is not None:
926
+ x = self._activation_checkpoint_fn(self.act, x) # type: ignore
927
+ else:
928
+ x = self.act(x)
929
+ x = x * x_up # newly added: SwiGLU-style gating
930
+ x = self.ff_out(x)
931
+ x = self.dropout(x)
932
+ x = og_x + x
933
+
934
+ return x, cache
935
+
936
+
937
+ class LLaDAOutput(NamedTuple):
938
+ logits: torch.FloatTensor
939
+ """
940
+ A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities
941
+ for the next token *before* normalization via (log) softmax.
942
+ """
943
+
944
+ attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
945
+ """
946
+ Attention keys and values from each block.
947
+ """
948
+
949
+ hidden_states: Optional[Tuple[torch.Tensor]]
950
+ """
951
+ Hidden states from each block.
952
+ """
953
+
954
+
955
+ class LLaDAGenerateOutput(NamedTuple):
956
+ token_ids: torch.LongTensor
957
+ """
958
+ The generated token IDs, a tensor of shape `(batch_size, beam_size, max_steps)`.
959
+ These do *not* include the original input IDs.
960
+ """
961
+
962
+ scores: torch.FloatTensor
963
+ """
964
+ The scores of the generated sequences, a tensor of shape `(batch_size, beam_size)`.
965
+ """
966
+
967
+
968
+ class LLaDABlockGroup(nn.ModuleList):
969
+ def __init__(self, config: ModelConfig, layer_offset: int, modules: Optional[Iterable[nn.Module]] = None):
970
+ super().__init__(modules)
971
+ self.config = config
972
+ self.layer_offset = layer_offset
973
+ self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
974
+ self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
975
+
976
+ def forward(
977
+ self,
978
+ x: torch.Tensor,
979
+ attention_bias: Optional[torch.FloatTensor] = None,
980
+ layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
981
+ use_cache: bool = False,
982
+ ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
983
+ attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
984
+ for block_idx, block in enumerate(self):
985
+ layer_past = None if layers_past is None else layers_past[block_idx]
986
+ block_idx += self.layer_offset
987
+ if (
988
+ (self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
989
+ or (
990
+ self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
991
+ and block_idx % 2 == 0
992
+ )
993
+ or (
994
+ self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
995
+ and block_idx % 3 == 0
996
+ )
997
+ or (
998
+ self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
999
+ and block_idx % 4 == 0
1000
+ )
1001
+ ):
1002
+ # shape: (batch_size, seq_len, d_model)
1003
+ x, cache = self._activation_checkpoint_fn( # type: ignore
1004
+ block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
1005
+ )
1006
+ else:
1007
+ # shape: (batch_size, seq_len, d_model)
1008
+ x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
1009
+ if attn_key_values is not None:
1010
+ assert cache is not None
1011
+ attn_key_values.append(cache)
1012
+ return x, attn_key_values
1013
+
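+ # NOTE: with `one_in_two` / `one_in_three` / `one_in_four`, a block is checkpointed when its
+ # global index (block_idx + layer_offset) is divisible by 2 / 3 / 4, while `whole_layer`
+ # checkpoints every block in the group.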
1014
+ def reset_parameters(self):
1015
+ for block in self:
1016
+ block.reset_parameters()
1017
+
1018
+ def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
1019
+ self.activation_checkpointing_strategy = strategy
1020
+ for block in self:
1021
+ block.set_activation_checkpointing(strategy)
1022
+
1023
+
1024
+ class LLaDAModel(nn.Module):
1025
+ def __init__(self, config: ModelConfig, init_params: bool = True):
1026
+ super().__init__()
1027
+ self.config = config
1028
+ self.__cache = BufferCache()
1029
+
1030
+ # Validate config.
1031
+ if self.config.alibi and self.config.flash_attention:
1032
+ raise Exception("ALiBi is currently not supported with FlashAttention")
1033
+
1034
+ if self.config.alibi and self.config.rope:
1035
+ raise Exception("ALiBi and RoPE are mutually exclusive")
1036
+
1037
+ if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size:
1038
+ if self.config.embedding_size < self.config.vocab_size:
1039
+ raise Exception("embedding size should be at least as big as vocab size")
1040
+ elif self.config.embedding_size % 128 != 0:
1041
+ import warnings
1042
+
1043
+ warnings.warn(
1044
+ "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
1045
+ )
1046
+
1047
+ self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
1048
+ self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)
1049
+
1050
+ if not (
1051
+ 0 < self.config.block_group_size <= self.config.n_layers
1052
+ and self.config.n_layers % self.config.block_group_size == 0
1053
+ ):
1054
+ raise Exception("n layers must be divisible by block group size")
1055
+
1056
+ torch.backends.cuda.enable_flash_sdp(True)
1057
+ torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it
1058
+
1059
+ self.transformer = nn.ModuleDict(
1060
+ dict(
1061
+ wte=nn.Embedding(
1062
+ config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
1063
+ ),
1064
+ emb_drop=Dropout(config.embedding_dropout),
1065
+ ln_f=LayerNorm.build(config),
1066
+ )
1067
+ )
1068
+
1069
+ blocks = [LLaDABlock.build(i, config, self.__cache) for i in range(config.n_layers)]
1070
+ if self.config.block_group_size > 1:
1071
+ block_groups = [
1072
+ LLaDABlockGroup(config, i, blocks[i : i + config.block_group_size])
1073
+ for i in range(0, config.n_layers, config.block_group_size)
1074
+ ]
1075
+ self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
1076
+ else:
1077
+ self.transformer.update({"blocks": nn.ModuleList(blocks)})
1078
+
1079
+ if not (self.config.alibi or self.config.rope):
1080
+ self.transformer.update(
1081
+ {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
1082
+ )
1083
+ if not config.weight_tying:
1084
+ self.transformer.update(
1085
+ {
1086
+ "ff_out": nn.Linear(
1087
+ config.d_model,
1088
+ config.embedding_size or config.vocab_size,
1089
+ bias=config.include_bias,
1090
+ device=config.init_device,
1091
+ )
1092
+ }
1093
+ )
1094
+ # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
1095
+ if init_params and self.config.init_device != "meta":
1096
+ self.reset_parameters()
1097
+ self.__num_fwd_flops: Optional[int] = None
1098
+
1099
+ # Warm up cache.
1100
+ if self.config.alibi:
1101
+ get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
1102
+ self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
1103
+
1104
+ def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
1105
+ self.activation_checkpointing_strategy = strategy
1106
+ if self.config.block_group_size != 1:
1107
+ for block_group in self.transformer.block_groups:
1108
+ block_group.set_activation_checkpointing(strategy)
1109
+ else:
1110
+ for block in self.transformer.blocks:
1111
+ block.set_activation_checkpointing(strategy)
1112
+
1113
+ @property
1114
+ def device(self) -> torch.device:
1115
+ device: torch.device = self.transformer.wte.weight.device # type: ignore
1116
+ if device.type == "meta":
1117
+ return _non_meta_init_device(self.config)
1118
+ else:
1119
+ return device
1120
+
1121
+ def reset_parameters(self):
1122
+ log.info("Initializing model parameters...")
1123
+ # Top-level embeddings / linear layers.
1124
+ init_weights(
1125
+ self.config,
1126
+ self.transformer.wte, # type: ignore
1127
+ std_factor=(0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0,
1128
+ type_of_module=ModuleType.emb,
1129
+ )
1130
+ if hasattr(self.transformer, "wpe"):
1131
+ init_weights(self.config, self.transformer.wpe, type_of_module=ModuleType.emb) # type: ignore
1132
+
1133
+ # Top-level layer norm.
1134
+ self.transformer.ln_f.reset_parameters() # type: ignore
1135
+
1136
+ # Output weights.
1137
+ if hasattr(self.transformer, "ff_out"):
1138
+ init_weights(self.config, self.transformer.ff_out, type_of_module=ModuleType.final_out) # type: ignore
1139
+
1140
+ # Let the blocks handle themselves.
1141
+ if self.config.block_group_size == 1:
1142
+ for block in self.transformer.blocks:
1143
+ block.reset_parameters()
1144
+ else:
1145
+ for block_group in self.transformer.block_groups:
1146
+ block_group.reset_parameters()
1147
+
1148
+ def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
1149
+ if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[
1150
+ -1
1151
+ ] >= seq_len:
1152
+ if alibi_bias.device != device:
1153
+ alibi_bias = alibi_bias.to(device)
1154
+ self.__cache["alibi_attention_bias"] = alibi_bias
1155
+ return alibi_bias
1156
+ with torch.autocast(device.type, enabled=False):
1157
+ alibi_bias = alibi_attention_bias(seq_len, self.config, device)
1158
+ self.__cache["alibi_attention_bias"] = alibi_bias
1159
+ return alibi_bias
1160
+
1161
+ def forward(
1162
+ self,
1163
+ input_ids: torch.LongTensor,
1164
+ input_embeddings: Optional[torch.FloatTensor] = None,
1165
+ attention_mask: Optional[torch.Tensor] = None,
1166
+ attention_bias: Optional[torch.Tensor] = None,
1167
+ past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
1168
+ use_cache: bool = False,
1169
+ last_logits_only: bool = False,
1170
+ output_hidden_states: Optional[bool] = None,
1171
+ ) -> LLaDAOutput:
1172
+ """
1173
+ :param input_ids: A tensor of shape `(batch_size, seq_len)`.
1174
+ :param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
1175
+ embeddings. When provided, it is treated as the output of the input embedding layer.
1176
+ :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
1177
+ which input IDs are masked. A `1` value in the mask means that
1178
+ the corresponding input ID should *not* be ignored. A `0` means
1179
+ that the corresponding input ID is masked.
1180
+
1181
+ This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
1182
+ library.
1183
+ :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
1184
+ `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
1185
+ to introduce causal or other biases.
1186
+
1187
+ If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
1188
+ indicates that the i-th element in the sequence is allowed to attend to the j-th
1189
+ element in the sequence.
1190
+
1191
+ If the tensor is a float tensor, it will just be added to the attention
1192
+ scores before the softmax.
1193
+
1194
+ The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
1195
+ :param past_key_values: Pre-computed keys and values for each attention block.
1196
+ Can be used to speed up sequential decoding. The `input_ids` which have
1197
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
1198
+ :param use_cache: If `True`, return key and value tensors for each block.
1199
+ :param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
1200
+ This can speed up decoding when you only care about the next token.
1201
+ """
1202
+ # Add Basic MDM Model config check
1203
+ assert not self.config.alibi, "ALiBi length extrapolation is not supported for MDM."
1204
+ assert self.config.rope, "RoPE must be used in the LLaMA-style encoder for MDM."
1205
+ assert (past_key_values is None and not use_cache), "The KV cache is not supported for MDM."
1206
+
1207
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else False
1208
+
1209
+ if past_key_values:
1210
+ assert len(past_key_values) == self.config.n_layers
1211
+
1212
+ batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
1213
+ if past_key_values is None:
1214
+ past_length = 0
1215
+ else:
1216
+ past_length = past_key_values[0][0].size(-2)
1217
+
1218
+ # Get embeddings of input.
1219
+ # shape: (batch_size, seq_len, d_model)
1220
+ # print(f"input_ids: {input_ids}, input_ids.shape: {input_ids.shape}")
1221
+ # print(f"transformer wte weight shape: {self.transformer.wte.weight.shape}")
1222
+ x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore
1223
+
1224
+ # print(f"xshape: {x.shape}")
1225
+
1226
+ if self.config.input_emb_norm:
1227
+ x = x * (self.config.d_model**0.5)
1228
+
1229
+ if not (self.config.alibi or self.config.rope):
1230
+ # Get positional embeddings.
1231
+ # shape: (1, seq_len)
1232
+ pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
1233
+ # shape: (1, seq_len, d_model)
1234
+ pos_emb = self.transformer.wpe(pos) # type: ignore
1235
+ x = pos_emb + x
1236
+
1237
+ # Add input + positional embeddings and apply dropout.
1238
+ # shape: (batch_size, seq_len, d_model)
1239
+ x = self.transformer.emb_drop(x) # type: ignore
1240
+
1241
+ # Transform the attention mask into what the blocks expect.
1242
+ if attention_mask is not None and 0.0 in attention_mask:
1243
+ # shape: (batch_size, 1, 1, seq_len)
1244
+ attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
1245
+ attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
1246
+ else:
1247
+ attention_mask = None
1248
+
1249
+ # Merge attention mask with attention bias.
1250
+ if (
1251
+ attention_bias is not None
1252
+ or attention_mask is not None
1253
+ or self.config.alibi
1254
+ # NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
1255
+ # with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
1256
+ # scores correctly.
1257
+ or past_key_values is not None
1258
+ ):
1259
+ if attention_bias is None and self.config.alibi:
1260
+ # print(f"get_causal_attention_bias")
1261
+ attention_bias = get_causal_attention_bias(
1262
+ self.__cache, past_length + seq_len, x.device
1263
+ ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
1264
+ elif attention_bias is None:
1265
+ # print(f"get_causal_attention_bias")
1266
+ attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
1267
+ elif attention_bias.dtype in (torch.int8, torch.bool):
1268
+ # print(f"attention_bias.dtype in (torch.int8, torch.bool)")
1269
+ attention_bias = attention_bias.to(dtype=torch.float)
1270
+ attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
1271
+
1272
+ # Transform to the right shape and data type.
1273
+ mask_len = seq_len
1274
+ if attention_mask is not None:
1275
+ mask_len = attention_mask.shape[-1]
1276
+ elif past_key_values is not None:
1277
+ mask_len = past_key_values[0][0].shape[-2] + seq_len
1278
+ attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
1279
+
1280
+ # Add in the masking bias.
1281
+ if attention_mask is not None:
1282
+ attention_bias = attention_bias + attention_mask
1283
+ # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
1284
+ # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
1285
+ # it can produce NaNs.
1286
+ ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
1287
+
1288
+ attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
1289
+
1290
+ # decoder layers
1291
+ all_hidden_states = []
1292
+
1293
+ # Apply blocks one-by-one.
1294
+ if self.config.block_group_size == 1:
1295
+ for block_idx, block in enumerate(self.transformer.blocks):
1296
+ if output_hidden_states:
1297
+ # add hidden states
1298
+ all_hidden_states.append(x)
1299
+
1300
+ layer_past = None if past_key_values is None else past_key_values[block_idx]
1301
+ if (
1302
+ (self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
1303
+ or (
1304
+ self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
1305
+ and block_idx % 2 == 0
1306
+ )
1307
+ or (
1308
+ self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
1309
+ and block_idx % 3 == 0
1310
+ )
1311
+ or (
1312
+ self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
1313
+ and block_idx % 4 == 0
1314
+ )
1315
+ ):
1316
+ # shape: (batch_size, seq_len, d_model)
1317
+ x, cache = self._activation_checkpoint_fn(
1318
+ block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
1319
+ )
1320
+ else:
1321
+ # shape: (batch_size, seq_len, d_model)
1322
+ x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
1323
+ if attn_key_values is not None:
1324
+ assert cache is not None
1325
+ attn_key_values.append(cache)
1326
+ else:
1327
+ for group_idx, block_group in enumerate(self.transformer.block_groups):
1328
+ if output_hidden_states:
1329
+ # add hidden states
1330
+ all_hidden_states.append(x)
1331
+
1332
+ layers_past = (
1333
+ None
1334
+ if past_key_values is None
1335
+ else past_key_values[
1336
+ group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
1337
+ ]
1338
+ )
1339
+ x, cache = block_group(
1340
+ x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache
1341
+ )
1342
+ if attn_key_values is not None:
1343
+ assert cache is not None
1344
+ attn_key_values.extend(cache)
1345
+
1346
+ if last_logits_only:
1347
+ # shape: (batch_size, 1, d_model)
1348
+ x = x[:, -1, :].unsqueeze(1)
1349
+
1350
+ # Apply final layer norm.
1351
+ # shape: (batch_size, seq_len or 1, d_model)
1352
+ x = self.transformer.ln_f(x) # type: ignore
1353
+ if output_hidden_states:
1354
+ # add final hidden state post-final-layernorm, following HuggingFace's convention
1355
+ all_hidden_states.append(x)
1356
+
1357
+ # Get logits.
1358
+ # shape: (batch_size, seq_len or 1, vocab_size)
1359
+ if self.config.weight_tying:
1360
+ logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore
1361
+ else:
1362
+ logits = self.transformer.ff_out(x) # type: ignore
1363
+ if self.config.scale_logits:
1364
+ logits.mul_(1 / math.sqrt(self.config.d_model))
1365
+
1366
+ return LLaDAOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None) # type: ignore[arg-type]
1367
+
1368
+
1369
+ def create_model_config_from_pretrained_config(config: LLaDAConfig):
1370
+ """
1371
+ Utility function
1372
+ """
1373
+
1374
+ kwargs = {}
1375
+ for field in fields(ModelConfig):
1376
+ kwargs[field.name] = getattr(config, field.name)
1377
+
1378
+ model_config = ModelConfig(**kwargs)
1379
+ return model_config
1380
+
1381
+
1382
+ class LLaDAModelLM(PreTrainedModel):
1383
+ """
1384
+ Extremely barebones HF model wrapper.
1385
+ """
1386
+
1387
+ config_class = LLaDAConfig
1388
+ base_model_prefix = "model"
1389
+ _no_split_modules = ["LLaDABlock", "LLaDASequentialBlock", "LLaDALlamaBlock"]
1390
+
1391
+ def __init__(self, config: LLaDAConfig, model: Optional[LLaDAModel] = None, init_params: bool = False):
1392
+ super().__init__(config)
1393
+
1394
+ if not model:
1395
+ model_config = create_model_config_from_pretrained_config(config)
1396
+ # Initialize model (always on CPU to start with so we don't run out of GPU memory).
1397
+ model_config.init_device = "cpu"
1398
+ self.model = LLaDAModel(model_config, init_params=init_params)
1399
+ else:
1400
+ self.model = model
1401
+
1402
+ def forward(
1403
+ self,
1404
+ input_ids: torch.LongTensor = None,
1405
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1406
+ attention_mask: Optional[torch.Tensor] = None,
1407
+ attention_bias: Optional[torch.Tensor] = None,
1408
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1409
+ labels: Optional[torch.LongTensor] = None,
1410
+ use_cache: Optional[bool] = None,
1411
+ output_attentions: Optional[bool] = None,
1412
+ output_hidden_states: Optional[bool] = None,
1413
+ return_dict: Optional[bool] = None,
1414
+ cache_position: Optional[Cache] = None, # This is a hack mitigation of an issue in transformers `4.39.x`
1415
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1416
+ if use_cache is None:
1417
+ use_cache = self.config.use_cache
1418
+
1419
+ if output_attentions:
1420
+ raise ValueError("output_attentions is not yet supported in LLaDA")
1421
+
1422
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1423
+
1424
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1425
+ outputs = self.model.forward(
1426
+ input_ids=input_ids,
1427
+ input_embeddings=inputs_embeds,
1428
+ attention_mask=attention_mask,
1429
+ attention_bias=attention_bias,
1430
+ past_key_values=None,
1431
+ use_cache=False,
1432
+ output_hidden_states=output_hidden_states,
1433
+ )
1434
+
1435
+ logits = outputs.logits
1436
+ hidden_states = outputs.hidden_states
1437
+
1438
+ loss = None
1439
+ if labels is not None:
1440
+ import warnings
1441
+ warnings.warn("Note that for LLaDA, you cannot calculate the loss here.", UserWarning)
1442
+ if not return_dict:
1443
+ output = (logits,) + outputs[1:]
1444
+ return (loss,) + output if loss is not None else output
1445
+
1446
+ return CausalLMOutputWithPast(
1447
+ logits=logits,
1448
+ past_key_values=outputs.attn_key_values,
1449
+ hidden_states=hidden_states,
1450
+ )
1451
+
1452
+ def can_generate(self) -> bool:
1453
+ return True
1454
+
1455
+ def prepare_inputs_for_generation(
1456
+ self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
1457
+ ):
1458
+ if past_key_values:
1459
+ # This is because we want the model to only process the last generated token.
1460
+ input_ids = input_ids[:, -1:]
1461
+ model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
1462
+
1463
+ model_inputs.update(kwargs)
1464
+ model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
1465
+ return model_inputs
1466
+
1467
+ # TODO: these are required to make the implementation complete.
1468
+ # def resize_position_embeddings(self, new_num_position_embeddings: int):
1469
+ # pass
1470
+ #
1471
+ # def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
1472
+ # pass
1473
+ #
1474
+ # def _reorder_cache(self, past_key_values, beam_idx):
1475
+ # pass
1476
+
1477
+ def get_input_embeddings(self) -> torch.nn.Module:
1478
+ return self.model.transformer.wte
1479
+
1480
+ def set_input_embeddings(self, value: torch.nn.Module):
1481
+ self.model.transformer.wte = value
1482
+
1483
+ def get_output_embeddings(self):
1484
+ if self.config.weight_tying:
1485
+ return self.model.transformer.wte
1486
+ else:
1487
+ return self.model.transformer.ff_out
1488
+
1489
+ def set_output_embeddings(self, value: torch.nn.Module):
1490
+ if self.config.weight_tying:
1491
+ self.model.transformer.wte = value
1492
+ else:
1493
+ self.model.transformer.ff_out = value
1494
+
1495
+ def tie_weights(self):
1496
+ if self.config.weight_tying:
1497
+ self.model.transformer.ff_out = self.model.transformer.wte
1498
+
1499
+ # Register the model so that it is available for transformer pipelines, auto-loading, etc.
1500
+ AutoModel.register(LLaDAConfig, LLaDAModelLM)
models/modeling_magvitv2.py ADDED
@@ -0,0 +1,440 @@
1
+ from dataclasses import dataclass, field
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from .common_modules import *
6
+ from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
7
+ from .misc import *
8
+ import math
9
+
10
+ class Updateable:
11
+ def do_update_step(
12
+ self, epoch: int, global_step: int, on_load_weights: bool = False
13
+ ):
14
+ for attr in self.__dir__():
15
+ if attr.startswith("_"):
16
+ continue
17
+ try:
18
+ module = getattr(self, attr)
19
+ except:
20
+ continue # ignore attributes such as properties, which can't be retrieved with getattr
21
+ if isinstance(module, Updateable):
22
+ module.do_update_step(
23
+ epoch, global_step, on_load_weights=on_load_weights
24
+ )
25
+ self.update_step(epoch, global_step, on_load_weights=on_load_weights)
26
+
27
+ def do_update_step_end(self, epoch: int, global_step: int):
28
+ for attr in self.__dir__():
29
+ if attr.startswith("_"):
30
+ continue
31
+ try:
32
+ module = getattr(self, attr)
33
+ except:
34
+ continue # ignore attributes such as properties, which can't be retrieved with getattr
35
+ if isinstance(module, Updateable):
36
+ module.do_update_step_end(epoch, global_step)
37
+ self.update_step_end(epoch, global_step)
38
+
39
+ def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
40
+ # override this method to implement custom update logic
41
+ # if on_load_weights is True, you should be careful doing things related to model evaluations,
42
+ # as the models and tensors are not guaranteed to be on the same device
43
+ pass
44
+
45
+ def update_step_end(self, epoch: int, global_step: int):
46
+ pass
47
+
48
+ class VQGANEncoder(ModelMixin, ConfigMixin):
49
+ @dataclass
50
+ class Config:
51
+ ch: int = 128
52
+ ch_mult: List[int] = field(default_factory=lambda: [1, 2, 2, 4, 4])
53
+ num_res_blocks: List[int] = field(default_factory=lambda: [4, 3, 4, 3, 4])
54
+ attn_resolutions: List[int] = field(default_factory=lambda: [5])
55
+ dropout: float = 0.0
56
+ in_ch: int = 3
57
+ out_ch: int = 3
58
+ resolution: int = 256
59
+ z_channels: int = 13
60
+ double_z: bool = False
61
+
62
+ def __init__(self,
63
+ ch: int = 128,
64
+ ch_mult: List[int] = [1, 2, 2, 4, 4],
65
+ num_res_blocks: List[int] = [4, 3, 4, 3, 4],
66
+ attn_resolutions: List[int] = [5],
67
+ dropout: float = 0.0,
68
+ in_ch: int = 3,
69
+ out_ch: int = 3,
70
+ resolution: int = 256,
71
+ z_channels: int = 13,
72
+ double_z: bool = False):
73
+ super().__init__()
74
+ self.ch = ch
75
+ self.temb_ch = 0
76
+ self.num_resolutions = len(ch_mult)
77
+ self.num_res_blocks = num_res_blocks
78
+ self.resolution = resolution
79
+ self.in_ch = in_ch
80
+ # downsampling
81
+ self.conv_in = torch.nn.Conv2d(
82
+ self.in_ch, self.ch, kernel_size=3, stride=1, padding=1
83
+ )
84
+
85
+ curr_res = self.resolution
86
+ in_ch_mult = (1,) + tuple(ch_mult)
87
+ self.down = nn.ModuleList()
88
+ for i_level in range(self.num_resolutions):
89
+ block = nn.ModuleList()
90
+ attn = nn.ModuleList()
91
+ block_in = self.ch * in_ch_mult[i_level]
92
+ block_out = self.ch * ch_mult[i_level]
93
+ for i_block in range(self.num_res_blocks[i_level]):
94
+ block.append(
95
+ ResnetBlock(
96
+ in_channels=block_in,
97
+ out_channels=block_out,
98
+ temb_channels=self.temb_ch,
99
+ dropout=dropout,
100
+ )
101
+ )
102
+ block_in = block_out
103
+ if curr_res in attn_resolutions:
104
+ attn.append(AttnBlock(block_in))
105
+ down = nn.Module()
106
+ down.block = block
107
+ down.attn = attn
108
+ if i_level != self.num_resolutions - 1:
109
+ down.downsample = Downsample(block_in, True)
110
+ curr_res = curr_res // 2
111
+ self.down.append(down)
112
+
113
+ # middle
114
+ self.mid = nn.Module()
115
+ self.mid.block_1 = ResnetBlock(
116
+ in_channels=block_in,
117
+ out_channels=block_in,
118
+ temb_channels=self.temb_ch,
119
+ dropout=dropout,
120
+ )
121
+ self.mid.attn_1 = AttnBlock(block_in)
122
+ self.mid.block_2 = ResnetBlock(
123
+ in_channels=block_in,
124
+ out_channels=block_in,
125
+ temb_channels=self.temb_ch,
126
+ dropout=dropout,
127
+ )
128
+
129
+
130
+ self.norm_out = Normalize(block_in)
131
+ self.conv_out = torch.nn.Conv2d(
132
+ block_in,
133
+ 2 * z_channels if double_z else z_channels,
134
+ kernel_size=3,
135
+ stride=1,
136
+ padding=1,
137
+ )
138
+
139
+ self.quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
140
+ # for param in self.parameters():
141
+ # broadcast(param, src=0)
142
+
143
+ def forward(self, x):
144
+ # timestep embedding
145
+ temb = None
146
+
147
+ # downsampling
148
+ hs = [self.conv_in(x)]
149
+ for i_level in range(self.num_resolutions):
150
+ for i_block in range(self.num_res_blocks[i_level]):
151
+ h = self.down[i_level].block[i_block](hs[-1], temb)
152
+ if len(self.down[i_level].attn) > 0:
153
+ h = self.down[i_level].attn[i_block](h)
154
+ hs.append(h)
155
+ if i_level != self.num_resolutions - 1:
156
+ hs.append(self.down[i_level].downsample(hs[-1]))
157
+
158
+ # middle
159
+ h = hs[-1]
160
+ h = self.mid.block_1(h, temb)
161
+ h = self.mid.attn_1(h)
162
+ h = self.mid.block_2(h, temb)
163
+
164
+ # end
165
+ h = self.norm_out(h)
166
+ h = nonlinearity(h)
167
+ h = self.conv_out(h)
168
+ h = self.quant_conv(h)
169
+ return h
170
+
171
+
172
+ class LFQuantizer(nn.Module):
173
+ def __init__(self, num_codebook_entry: int = -1,
174
+ codebook_dim: int = 13,
175
+ beta: float = 0.25,
176
+ entropy_multiplier: float = 0.1,
177
+ commit_loss_multiplier: float = 0.1, ):
178
+ super().__init__()
179
+ self.codebook_size = 2 ** codebook_dim
180
+ print(
181
+ f"Look-up free quantizer with codebook size: {self.codebook_size}"
182
+ )
183
+ self.e_dim = codebook_dim
184
+ self.beta = beta
185
+
186
+ indices = torch.arange(self.codebook_size)
187
+
188
+ binary = (
189
+ indices.unsqueeze(1)
190
+ >> torch.arange(codebook_dim - 1, -1, -1, dtype=torch.long)
191
+ ) & 1
192
+
193
+ embedding = binary.float() * 2 - 1
194
+ self.register_buffer("embedding", embedding)
195
+ self.register_buffer(
196
+ "power_vals", 2 ** torch.arange(codebook_dim - 1, -1, -1)
197
+ )
198
+ self.commit_loss_multiplier = commit_loss_multiplier
199
+ self.entropy_multiplier = entropy_multiplier
200
+
201
+ def get_indices(self, z_q):
202
+ return (
203
+ (self.power_vals.reshape(1, -1, 1, 1) * (z_q > 0).float())
204
+ .sum(1, keepdim=True)
205
+ .long()
206
+ )
207
+
208
+ def get_codebook_entry(self, indices, shape=None):
209
+ if shape is None:
210
+ h, w = int(math.sqrt(indices.shape[-1])), int(math.sqrt(indices.shape[-1]))
211
+ else:
212
+ h, w = shape
213
+ b, _ = indices.shape
214
+ indices = indices.reshape(-1)
215
+ z_q = self.embedding[indices]
216
+ z_q = z_q.view(b, h, w, -1)
217
+
218
+ # reshape back to match original input shape
219
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
220
+
221
+ return z_q
222
+
223
+ def forward(self, z, get_code=False):
224
+ """
225
+ Inputs the output of the encoder network z and maps it to a discrete
226
+ one-hot vector that is the index of the closest embedding vector e_j
227
+ z (continuous) -> z_q (discrete)
228
+ z.shape = (batch, channel, height, width)
229
+ quantization pipeline:
230
+ 1. get encoder input (B,C,H,W)
231
+ 2. flatten input to (B*H*W,C)
232
+ """
233
+ if get_code:
234
+ return self.get_codebook_entry(z)
235
+
236
+ # reshape z -> (batch, height, width, channel) and flatten
237
+ z = z.permute(0, 2, 3, 1).contiguous()
238
+ z_flattened = z.view(-1, self.e_dim)
239
+ ge_zero = (z_flattened > 0).float()
240
+ ones = torch.ones_like(z_flattened)
241
+ z_q = ones * ge_zero + -ones * (1 - ge_zero)
242
+
243
+ # preserve gradients
244
+ z_q = z_flattened + (z_q - z_flattened).detach()
245
+
246
+ # compute entropy loss
247
+ CatDist = torch.distributions.categorical.Categorical
248
+ logit = torch.stack(
249
+ [
250
+ -(z_flattened - torch.ones_like(z_q)).pow(2),
251
+ -(z_flattened - torch.ones_like(z_q) * -1).pow(2),
252
+ ],
253
+ dim=-1,
254
+ )
255
+ cat_dist = CatDist(logits=logit)
256
+ entropy = cat_dist.entropy().mean()
257
+ mean_prob = cat_dist.probs.mean(0)
258
+ mean_entropy = CatDist(probs=mean_prob).entropy().mean()
259
+
260
+ # compute loss for embedding
261
+ commit_loss = torch.mean(
262
+ (z_q.detach() - z_flattened) ** 2
263
+ ) + self.beta * torch.mean((z_q - z_flattened.detach()) ** 2)
264
+
265
+ # reshape back to match original input shape
266
+ z_q = z_q.view(z.shape)
267
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
268
+
269
+ return {
270
+ "z": z_q,
271
+ "quantizer_loss": commit_loss * self.commit_loss_multiplier,
272
+ "entropy_loss": (entropy - mean_entropy) * self.entropy_multiplier,
273
+ "indices": self.get_indices(z_q),
274
+ }
275
+
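+ # NOTE: this look-up-free quantizer maps each of the `codebook_dim` latent channels to +/-1
+ # by its sign, which implicitly defines 2 ** codebook_dim codes (8192 for the default
+ # codebook_dim=13); `get_indices` packs the positive channels as bits via `power_vals`, and
+ # z_q = z + (sign(z) - z).detach() is the straight-through estimator that keeps gradients.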
276
+
277
+ class VQGANDecoder(ModelMixin, ConfigMixin):
278
+ def __init__(self, ch: int = 128,
279
+ ch_mult: List[int] = [1, 1, 2, 2, 4],
280
+ num_res_blocks: List[int] = [4, 4, 3, 4, 3],
281
+ attn_resolutions: List[int] = [5],
282
+ dropout: float = 0.0,
283
+ in_ch: int = 3,
284
+ out_ch: int = 3,
285
+ resolution: int = 256,
286
+ z_channels: int = 13,
287
+ double_z: bool = False):
288
+ super().__init__()
289
+ self.ch = ch
290
+ self.temb_ch = 0
291
+ self.num_resolutions = len(ch_mult)
292
+ self.num_res_blocks = num_res_blocks
293
+ self.resolution = resolution
294
+ self.in_ch = in_ch
295
+ self.give_pre_end = False
296
+
297
+ self.z_channels = z_channels
298
+ # compute in_ch_mult, block_in and curr_res at lowest res
299
+ in_ch_mult = (1,) + tuple(ch_mult)
300
+ block_in = ch * ch_mult[self.num_resolutions - 1]
301
+ curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
302
+ self.z_shape = (1, z_channels, curr_res, curr_res)
303
+ print(
304
+ "Working with z of shape {} = {} dimensions.".format(
305
+ self.z_shape, np.prod(self.z_shape)
306
+ )
307
+ )
308
+
309
+ # z to block_in
310
+ self.conv_in = torch.nn.Conv2d(
311
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
312
+ )
313
+
314
+ # middle
315
+ self.mid = nn.Module()
316
+ self.mid.block_1 = ResnetBlock(
317
+ in_channels=block_in,
318
+ out_channels=block_in,
319
+ temb_channels=self.temb_ch,
320
+ dropout=dropout,
321
+ )
322
+ self.mid.attn_1 = AttnBlock(block_in)
323
+ self.mid.block_2 = ResnetBlock(
324
+ in_channels=block_in,
325
+ out_channels=block_in,
326
+ temb_channels=self.temb_ch,
327
+ dropout=dropout,
328
+ )
329
+
330
+ # upsampling
331
+ self.up = nn.ModuleList()
332
+ for i_level in reversed(range(self.num_resolutions)):
333
+ block = nn.ModuleList()
334
+ attn = nn.ModuleList()
335
+ block_out = ch * ch_mult[i_level]
336
+ for i_block in range(self.num_res_blocks[i_level]):
337
+ block.append(
338
+ ResnetBlock(
339
+ in_channels=block_in,
340
+ out_channels=block_out,
341
+ temb_channels=self.temb_ch,
342
+ dropout=dropout,
343
+ )
344
+ )
345
+ block_in = block_out
346
+ if curr_res in attn_resolutions:
347
+ attn.append(AttnBlock(block_in))
348
+ up = nn.Module()
349
+ up.block = block
350
+ up.attn = attn
351
+ if i_level != 0:
352
+ up.upsample = Upsample(block_in, True)
353
+ curr_res = curr_res * 2
354
+ self.up.insert(0, up) # prepend to get consistent order
355
+
356
+ self.norm_out = Normalize(block_in)
357
+ self.conv_out = torch.nn.Conv2d(
358
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
359
+ )
360
+ self.post_quant_conv = torch.nn.Conv2d(
361
+ z_channels, z_channels, 1
362
+ )
363
+
364
+
365
+ def forward(self, z):
366
+ # assert z.shape[1:] == self.z_shape[1:]
367
+ self.last_z_shape = z.shape
368
+ # timestep embedding
369
+ temb = None
370
+ output = dict()
371
+ z = self.post_quant_conv(z)
372
+
373
+ # z to block_in
374
+ h = self.conv_in(z)
375
+
376
+ # middle
377
+ h = self.mid.block_1(h, temb)
378
+ h = self.mid.attn_1(h)
379
+ h = self.mid.block_2(h, temb)
380
+
381
+ # upsampling
382
+ for i_level in reversed(range(self.num_resolutions)):
383
+ for i_block in range(self.num_res_blocks[i_level]):
384
+ h = self.up[i_level].block[i_block](h, temb)
385
+ if len(self.up[i_level].attn) > 0:
386
+ h = self.up[i_level].attn[i_block](h)
387
+ if i_level != 0:
388
+ h = self.up[i_level].upsample(h)
389
+
390
+ # end
391
+ output["output"] = h
392
+ if self.give_pre_end:
393
+ return output
394
+
395
+ h = self.norm_out(h)
396
+ h = nonlinearity(h)
397
+ h = self.conv_out(h)
398
+ output["output"] = h
399
+ return output
400
+
401
+
402
+ class MAGVITv2(ModelMixin, ConfigMixin):
403
+ @register_to_config
404
+ def __init__(
405
+ self,
406
+ ):
407
+ super().__init__()
408
+
409
+ self.encoder = VQGANEncoder()
410
+ self.decoder = VQGANDecoder()
411
+ self.quantize = LFQuantizer()
412
+
413
+ def forward(self, pixel_values, return_loss=False):
414
+ pass
415
+
416
+ def encode(self, pixel_values, return_loss=False):
417
+ hidden_states = self.encoder(pixel_values)
418
+ quantized_states = self.quantize(hidden_states)['z']
419
+ codebook_indices = self.quantize.get_indices(quantized_states).reshape(pixel_values.shape[0], -1)
420
+ output = (quantized_states, codebook_indices)
421
+ return output
422
+
423
+ def get_code(self, pixel_values):
424
+ hidden_states = self.encoder(pixel_values)
425
+ codebook_indices = self.quantize.get_indices(self.quantize(hidden_states)['z']).reshape(pixel_values.shape[0], -1)
426
+
427
+ return codebook_indices
428
+
429
+ def decode_code(self, codebook_indices, shape=None):
430
+ z_q = self.quantize.get_codebook_entry(codebook_indices, shape=shape)
431
+
432
+ reconstructed_pixel_values = self.decoder(z_q)["output"]
433
+ return reconstructed_pixel_values
434
+
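+ # Illustrative round-trip sketch (not from the original file; shapes assume the default
+ # 256x256 configuration with four downsampling stages and a 2**13 codebook):
+ # vq = MAGVITv2()
+ # pixel_values = torch.randn(1, 3, 256, 256) # normalized RGB batch
+ # quantized, indices = vq.encode(pixel_values) # indices: roughly (1, 256), values in [0, 8191]
+ # recon = vq.decode_code(indices) # (1, 3, 256, 256) reconstruction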
435
+
436
+ if __name__ == '__main__':
437
+ encoder = VQGANEncoder()
438
+ import ipdb
439
+ ipdb.set_trace()
440
+ print()
models/modeling_mmada.py ADDED
@@ -0,0 +1,668 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import math
5
+ import sys
6
+ from abc import abstractmethod
7
+ from collections import defaultdict
8
+ from functools import partial
9
+ from typing import (
10
+ Callable,
11
+ Dict,
12
+ Iterable,
13
+ List,
14
+ NamedTuple,
15
+ Optional,
16
+ Sequence,
17
+ Set,
18
+ Tuple,
19
+ cast,
20
+ )
21
+ from dataclasses import fields
22
+ from typing import List, Optional, Tuple, Union
23
+ import numpy as np
24
+ import torch
25
+ import torch.backends.cuda
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ from torch import einsum
29
+ from transformers import PreTrainedModel
30
+ from transformers.modeling_outputs import CausalLMOutputWithPast
31
+ from transformers.models.auto import AutoModel, AutoConfig, AutoModelForCausalLM
32
+ from transformers.cache_utils import Cache
33
+ from PIL import Image
34
+ from .configuration_llada import (
35
+ LLaDAConfig,
36
+ StrEnum,
37
+ InitFnType,
38
+ ActivationType,
39
+ BlockType,
40
+ LayerNormType,
41
+ ModelConfig,
42
+ ActivationCheckpointingStrategy,
43
+ )
44
+
45
+ from .modeling_llada import LLaDAModelLM
46
+ from .sampling import cosine_schedule, mask_by_random_topk
47
+ from transformers import PretrainedConfig
48
+
49
+ def add_gumbel_noise(logits, temperature):
50
+ '''
51
+ The Gumbel max is a method for sampling categorical distributions.
52
+ According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
53
+ Thus, we use float64.
54
+ '''
55
+ if temperature == 0:
56
+ return logits
57
+ logits = logits.to(torch.float64)
58
+ noise = torch.rand_like(logits, dtype=torch.float64)
59
+ gumbel_noise = (- torch.log(noise)) ** temperature
60
+ return logits.exp() / gumbel_noise
61
+
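A quick sanity-check sketch (not part of the commit) of why this exp/divide form samples like the standard Gumbel-max trick at temperature 1: taking the log of `exp(logits) / (-log U)` gives `logits - log(-log U)`, i.e. logits plus standard Gumbel noise, so the argmax is unchanged.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(4, 10, dtype=torch.float64)
u = torch.rand_like(logits)

llada_form = logits.exp() / (-torch.log(u))        # temperature == 1 case of add_gumbel_noise
standard_form = logits - torch.log(-torch.log(u))  # logits + Gumbel(0, 1) noise

# log is monotonic, so both forms select the same token
assert torch.equal(llada_form.argmax(dim=-1), standard_form.argmax(dim=-1))
```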
62
+
63
+ def get_num_transfer_tokens(mask_index, steps):
64
+ '''
65
+ In the reverse process, the interval [0, 1] is uniformly discretized into `steps` intervals.
66
+ Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)),
67
+ the expected number of tokens transitioned at each step should be consistent.
68
+
69
+ This function is designed to precompute the number of tokens that need to be transitioned at each step.
70
+ '''
71
+ mask_num = mask_index.sum(dim=1, keepdim=True)
72
+
73
+ base = mask_num // steps
74
+ remainder = mask_num % steps
75
+
76
+ num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
77
+
78
+ for i in range(mask_num.size(0)):
79
+ num_transfer_tokens[i, :remainder[i]] += 1
80
+
81
+ return num_transfer_tokens
82
+
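For intuition, a small illustrative call (values made up): 10 masked positions spread over 4 steps gives the remainder to the earliest steps.

```python
import torch

mask_index = torch.tensor([[True] * 10 + [False] * 6])
print(get_num_transfer_tokens(mask_index, steps=4))
# expected: tensor([[3, 3, 2, 2]])
```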
83
+ class MMadaConfig(PretrainedConfig):
84
+ model_type = "mmada"
85
+
86
+ def __init__(self, **kwargs):
87
+ super().__init__(**kwargs)
88
+
89
+ allowed_keys = [
90
+ "vocab_size",
91
+ "llm_vocab_size",
92
+ "llm_model_path",
93
+ "codebook_size",
94
+ "num_vq_tokens",
95
+ "num_new_special_tokens",
96
+ "gradient_checkpointing",
97
+ "new_vocab_size",
98
+ ]
99
+
100
+ for key in allowed_keys:
101
+ if key in kwargs:
102
+ setattr(self, key, kwargs[key])
103
+
104
+
105
+
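A hedged construction example: only keys listed in `allowed_keys` are copied onto the config, and every value below is illustrative rather than a shipped default.

```python
cfg = MMadaConfig(
    llm_model_path="path/to/llada-base",  # hypothetical path
    codebook_size=8192,                   # illustrative values
    num_vq_tokens=1024,
    num_new_special_tokens=0,
)
print(cfg.model_type)      # "mmada"
print(cfg.codebook_size)   # 8192
```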
106
+ class MMadaModelLM(LLaDAModelLM):
107
+ config_class = MMadaConfig
108
+ base_model_prefix = "model"
109
+ def __init__(self, config: MMadaConfig, *args, **kwargs):
110
+ print(f"Initializing MMadaModelLM with config: {config}")
111
+ super().__init__(config, *args, **kwargs)
112
+
113
+ # # resize token embeddings
114
+ # print(f"Resizing token embeddings to {config.new_vocab_size}")
115
+ # self.resize_token_embeddings(config.new_vocab_size)
116
+
117
+ @torch.no_grad()
118
+ def t2i_generate(
119
+ self,
120
+ input_ids: torch.LongTensor = None,
121
+ uncond_input_ids: torch.LongTensor = None,
122
+ attention_mask=None,
123
+ uncond_attention_mask=None,
124
+ temperature=1.0,
125
+ timesteps=18, # ideal number of steps is 18 in maskgit paper
126
+ guidance_scale=0,
127
+ noise_schedule=cosine_schedule,
128
+ generator: torch.Generator = None,
129
+ config=None,
130
+ seq_len=1024,
131
+ mask_token_id = 126336,
132
+ resolution = 512,
133
+ codebook_size = 8192,
134
+ **kwargs,
135
+ ):
136
+ """
137
+ Generation loop kept close to 1:1 with the original MaskGit repo:
138
+ https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79
139
+ """
140
+
141
+ # begin with all image token ids masked
142
+ # count how many tokens are currently masked
143
+ mask_count = (input_ids == mask_token_id).sum().item()
144
+ num_vq_tokens = seq_len
145
+ num_new_special_tokens = 0
146
+ uni_prompting = kwargs.get("uni_prompting", None)
147
+ # print(f"config.model.mmada.llm_vocab_size: {config.model.mmada.llm_vocab_size}, {len(uni_prompting.text_tokenizer)}")
148
+ input_ids_minus_lm_vocab_size = input_ids[:, -(num_vq_tokens + 1):-1].clone()
149
+ input_ids_minus_lm_vocab_size = torch.where(input_ids_minus_lm_vocab_size == mask_token_id, mask_token_id, input_ids_minus_lm_vocab_size - len(uni_prompting.text_tokenizer) - num_new_special_tokens)
150
+
151
+ # for classifier-free guidance
152
+ if uncond_input_ids is not None:
153
+ uncond_prefix = uncond_input_ids[:, :resolution + 1]
154
+
155
+ for step in range(timesteps):
156
+ if uncond_input_ids is not None and guidance_scale > 0:
157
+ uncond_input_ids = torch.cat(
158
+ [uncond_prefix, input_ids[:, resolution + 1:]], dim=1)
159
+ model_input = torch.cat([input_ids, uncond_input_ids])
160
+ attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0)
161
+ attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
162
+ logits = self(model_input, attention_bias=attention_bias).logits
163
+ # print(f"logits.shape: {logits.shape}")
164
+ cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0)
165
+ # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
166
+ # it seems that muse has a different cfg setting
167
+ logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
168
+ logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size]
169
+ else:
170
+ attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
171
+ logits = self(input_ids, attention_bias=attention_bias).logits
172
+ logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size]
173
+
174
+ # logits: 1, 1024, 8192
175
+ # print(f"logits.shape: {logits.shape}")
176
+ probs = logits.softmax(dim=-1)
177
+ sampled = probs.reshape(-1, logits.size(-1))
178
+ # print(f"probs: {probs}, probs.shape: {probs.shape}, sampled: {sampled}, sampled.shape: {sampled.shape}")
179
+ sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1]) # 1, 1024
180
+
181
+ unknown_map = input_ids_minus_lm_vocab_size == mask_token_id
182
+ # print(f"unknown_map.sum(dim=-1, keepdim=True): {unknown_map.sum(dim=-1, keepdim=True)}")
183
+ sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size)
184
+ # Defines the mask ratio for the next round. The number to mask out is
185
+ # determined by mask_ratio * unknown_number_in_the_beginning.
186
+ ratio = 1.0 * (step + 1) / timesteps
187
+ mask_ratio = noise_schedule(torch.tensor(ratio))
188
+ # Computes the probability of each selected token.
189
+ selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
190
+ selected_probs = selected_probs.squeeze(-1)
191
+
192
+ # Ignores the tokens given in the input by overwriting their confidence.
193
+ selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
194
+ # Gets mask lens for each sample in the batch according to the mask ratio.
195
+ mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device)
196
+ # Keeps at least one prediction in this round and masks out at least
197
+ # one token for the next iteration.
198
+ mask_len = torch.max(
199
+ torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
200
+ )
201
+ # print(f"mask_len: {mask_len}, mask_len.shape: {mask_len.shape}")
202
+ # Adds noise for randomness
203
+ temperature = temperature * (1.0 - ratio)
204
+ masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)
205
+ # Masks tokens with lower confidence.
206
+ input_ids[:, -(num_vq_tokens + 1):-1] = torch.where(masking, mask_token_id,
207
+ sampled_ids + len(uni_prompting.text_tokenizer)
208
+ + num_new_special_tokens)
209
+ input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids)
210
+
211
+ return sampled_ids
212
+
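A small sketch of the schedule that drives the loop above (purely illustrative): at each of the 18 steps, the cosine schedule decides how many of the 1024 image tokens stay masked for the next round.

```python
import torch

num_vq_tokens, timesteps = 1024, 18
for step in range(timesteps):
    ratio = (step + 1) / timesteps
    mask_ratio = cosine_schedule(torch.tensor(ratio))   # imported from .sampling above
    still_masked = int((num_vq_tokens * mask_ratio).floor())
    print(f"step {step + 1:2d}: keep {still_masked} tokens masked")
```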
213
+ def forward_process(
214
+ self,
215
+ input_ids,
216
+ labels,
217
+ batch_size_t2i=0,
218
+ batch_size_lm=0,
219
+ batch_size_mmu=0,
220
+ max_seq_length=128,
221
+ p_mask_lm=None,
222
+ p_mask_mmu=None,
223
+ answer_lengths=None,
224
+ t2i_masks=None,
225
+ answer_lengths_lm=None
226
+ ):
227
+ # attention bias, True for batch_size, 1, seq_len, seq_len
228
+ attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1])
229
+ attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1)
230
+ attention_bias[:batch_size_t2i] = attention_bias_t2i
231
+ logits = self(input_ids, attention_bias=attention_bias).logits
232
+ # logits = self(input_ids).logits
233
+ self.output_size = logits.shape[-1]
234
+
235
+ # print(f"logits shape: {logits.shape}") B, 359, vocab_size
236
+
237
+ if batch_size_t2i == 0:
238
+ loss_t2i = torch.tensor(0.0, device=input_ids.device)
239
+ else:
240
+ # t2i loss
241
+ loss_t2i = F.cross_entropy(
242
+ logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size),
243
+ labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100,
244
+ )
245
+
246
+ # llada loss
247
+ masked_indices = input_ids == self.config.mask_token_id
248
+ masked_indices_lm = masked_indices[batch_size_t2i:batch_size_t2i + batch_size_lm]
249
+ # debug code: count the number of masked tokens per row
250
+ # if masked_indices_lm.numel() > 0:
251
+ # mask_counts = torch.sum(masked_indices_lm, dim=1)
252
+ # logging.info(f"[LM mask nums]: {mask_counts.cpu()}.")
253
+ # else:
254
+ # logging.info("[LM mask nums] no LM sample.")
255
+ masked_indices_mmu = masked_indices[-batch_size_mmu:]
256
+ p_mask_lm = p_mask_lm.to(masked_indices_lm.device)
257
+ p_mask_mmu = p_mask_mmu.to(masked_indices_mmu.device)
258
+ answer_lengths = answer_lengths.to(masked_indices_mmu.device)
259
+ loss_lm = F.cross_entropy(
260
+ logits[batch_size_t2i:batch_size_t2i + batch_size_lm][masked_indices_lm].contiguous().view(-1, self.output_size),
261
+ labels[batch_size_t2i:batch_size_t2i + batch_size_lm][masked_indices_lm].contiguous().view(-1), ignore_index=-100, reduction='none'
262
+ )/p_mask_lm[masked_indices_lm]
263
+ # print(f"logits lm shape: {logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape}")
264
+ loss_lm = loss_lm.sum() / (logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape[0] * logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape[1])
265
+
266
+ # llm loss
267
+ answer_lengths_lm = answer_lengths_lm.to(masked_indices_lm.device)
268
+ loss_lm = torch.sum(loss_lm / answer_lengths_lm[masked_indices_lm]) / (logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape[0])
269
+
270
+ loss_mmu = F.cross_entropy(
271
+ logits[-batch_size_mmu:][masked_indices_mmu].contiguous().view(-1, self.output_size),
272
+ labels[-batch_size_mmu:][masked_indices_mmu].contiguous().view(-1), ignore_index=-100, reduction='none'
273
+ )/p_mask_mmu[masked_indices_mmu]
274
+ loss_mmu = torch.sum(loss_mmu/answer_lengths[masked_indices_mmu]) / (logits[-batch_size_mmu:].shape[0])
275
+
276
+ return logits, loss_t2i, loss_lm, loss_mmu
277
+
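For reference, a schematic (all names invented) of the masked-token reweighting used for the LM/MMU branches above: each token's cross entropy is divided by `p_mask` (the probability that the token was masked at its sampled noise level) and by the answer length, then averaged over the batch.

```python
import torch.nn.functional as F

# Schematic only; logits_masked, labels_masked, p_mask_masked, answer_len_masked
# and batch_size are hypothetical stand-ins for the gathered tensors above.
ce = F.cross_entropy(logits_masked, labels_masked, reduction='none')  # (N_masked,)
weighted = ce / p_mask_masked / answer_len_masked                     # per-token reweighting
loss = weighted.sum() / batch_size
```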
278
+ def forward_process_with_r2i(
279
+ self,
280
+ input_ids,
281
+ labels,
282
+ t2i_masks=None,
283
+ max_seq_length=128,
284
+ batch_size_t2i=0,
285
+ batch_size_lm=0,
286
+ batch_size_mmu=0,
287
+ batch_size_r2i=0,
288
+ p_mask_lm=None,
289
+ p_mask_mmu=None,
290
+ p_mask_r2i=None,
291
+ answer_lengths=None,
292
+ answer_lengths_lm=None,
293
+ answer_lengths_r2i=None,
294
+ ):
295
+ # attention bias, True for batch_size, 1, seq_len, seq_len
296
+ attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1])
297
+ attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1)
298
+ attention_bias[:batch_size_t2i] = attention_bias_t2i
299
+ logits = self(input_ids, attention_bias=attention_bias).logits
300
+ # logits = self(input_ids).logits
301
+ self.output_size = logits.shape[-1]
302
+
303
+ # print(f"logits shape: {logits.shape}") B, 359, vocab_size
304
+
305
+ if batch_size_t2i == 0:
306
+ loss_t2i = torch.tensor(0.0, device=input_ids.device)
307
+ else:
308
+ # t2i loss
309
+ loss_t2i = F.cross_entropy(
310
+ logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size),
311
+ labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100,
312
+ )
313
+
314
+ # llada loss
315
+
316
+ start_lm = batch_size_t2i
317
+ end_lm = start_lm + batch_size_lm
318
+ start_mmu = end_lm
319
+ end_mmu = start_mmu + batch_size_mmu
320
+ start_r2i = end_mmu
321
+ end_r2i = start_r2i + batch_size_r2i
322
+
323
+ masked_indices = input_ids == self.config.mask_token_id
324
+ masked_indices_lm = masked_indices[start_lm:end_lm]
325
+ masked_indices_mmu = masked_indices[start_mmu:end_mmu]
326
+ masked_indices_r2i = masked_indices[start_r2i:end_r2i]
327
+
328
+ p_mask_lm = p_mask_lm.to(masked_indices_lm.device)
329
+ p_mask_mmu = p_mask_mmu.to(masked_indices_mmu.device)
330
+ p_mask_r2i = p_mask_r2i.to(masked_indices_r2i.device)
331
+
332
+ answer_lengths = answer_lengths.to(masked_indices_mmu.device)
333
+ answer_lengths_lm = answer_lengths_lm.to(masked_indices_lm.device)
334
+ answer_lengths_r2i = answer_lengths_r2i.to(masked_indices_r2i.device)
335
+
336
+ loss_lm = F.cross_entropy(
337
+ logits[start_lm:end_lm][masked_indices_lm].contiguous().view(-1, self.output_size),
338
+ labels[start_lm:end_lm][masked_indices_lm].contiguous().view(-1), ignore_index=-100, reduction='none'
339
+ )/p_mask_lm[masked_indices_lm]
340
+ # print(f"logits lm shape: {logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape}")
341
+ loss_lm = loss_lm.sum() / (logits[start_lm:end_lm].shape[0] * logits[start_lm:end_lm].shape[1])
342
+ loss_lm = torch.sum(loss_lm / answer_lengths_lm[masked_indices_lm]) / (logits[start_lm:end_lm].shape[0])
343
+
344
+ loss_mmu = F.cross_entropy(
345
+ logits[start_mmu:end_mmu][masked_indices_mmu].contiguous().view(-1, self.output_size),
346
+ labels[start_mmu:end_mmu][masked_indices_mmu].contiguous().view(-1), ignore_index=-100, reduction='none'
347
+ )/p_mask_mmu[masked_indices_mmu]
348
+ loss_mmu = torch.sum(loss_mmu/answer_lengths[masked_indices_mmu]) / (logits[start_mmu:end_mmu].shape[0])
349
+
350
+ loss_r2i = F.cross_entropy(
351
+ logits[start_r2i:end_r2i][masked_indices_r2i].contiguous().view(-1, self.output_size),
352
+ labels[start_r2i:end_r2i][masked_indices_r2i].contiguous().view(-1), ignore_index=-100, reduction='none'
353
+ )/p_mask_r2i[masked_indices_r2i]
354
+ loss_r2i = torch.sum(loss_r2i/answer_lengths_r2i[masked_indices_r2i]) / (logits[start_r2i:end_r2i].shape[0])
355
+
356
+ return logits, loss_t2i, loss_lm, loss_mmu, loss_r2i
357
+
358
+
359
+ def forward_t2i(
360
+ self,
361
+ input_ids,
362
+ labels,
363
+ batch_size_t2i=0,
364
+ max_seq_length=128,
365
+ t2i_masks=None
366
+ ):
367
+ # attention bias, True for batch_size, 1, seq_len, seq_len
368
+ attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1])
369
+ attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1)
370
+ attention_bias[:batch_size_t2i] = attention_bias_t2i
371
+ logits = self(input_ids, attention_bias=attention_bias).logits
372
+ # logits = self(input_ids).logits
373
+ self.output_size = logits.shape[-1]
374
+
375
+ # print(f"logits shape: {logits.shape}") B, 359, vocab_size
376
+
377
+ loss_t2i = F.cross_entropy(
378
+ logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size),
379
+ labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100,
380
+ )
381
+
382
+ return loss_t2i
383
+
384
+
385
+
386
+
387
+
388
+ @torch.no_grad()
389
+ def mmu_generate(self, idx=None, input_embeddings=None, max_new_tokens=128, steps=128,block_length=128, temperature=0.0, top_k=None, eot_token=None, cfg_scale=0.0, remasking='low_confidence', mask_id=126336, attention_mask=None):
390
+ """
391
+ Take a conditioning sequence of indices idx (LongTensor of shape (b, t)), append
392
+ max_new_tokens masked positions, and fill them in block by block with iterative remasking.
393
+ Make sure the model is in model.eval() mode when calling this.
394
+ """
395
+
396
+ if attention_mask is not None and 0.0 in attention_mask:
397
+ attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
398
+ # print(f"attention_bias: {attention_bias}")
399
+ else:
400
+ attention_bias = None
401
+ try:
402
+ device = idx.device
403
+ except AttributeError:
404
+ device = input_embeddings.device
405
+
406
+ result = []
407
+ batch_size = idx.shape[0]
408
+ x = torch.full((batch_size, idx.shape[1] + max_new_tokens), mask_id, dtype=torch.long).to(self.device)
409
+ x[:, :idx.shape[1]] = idx.clone()
410
+ prompt_index = (x != mask_id)
411
+
412
+
413
+ assert max_new_tokens % block_length == 0
414
+ num_blocks = max_new_tokens // block_length
415
+
416
+ assert steps % num_blocks == 0
417
+ steps = steps // num_blocks
418
+
419
+ # print(f"num_blocks: {num_blocks}, steps: {steps}")
420
+ # num_transfer_tokens = get_num_transfer_tokens(prompt_index, steps)
421
+ for num_block in range(num_blocks):
422
+ block_mask_index = (x[:, idx.shape[1] + num_block * block_length: idx.shape[1] + (num_block + 1) * block_length:] == mask_id)
423
+ num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
424
+ # num_transfer_tokens = get_num_transfer_tokens(prompt_index, steps)
425
+ # print(f"num_transfer_tokens: {num_transfer_tokens}, num_transfer_tokens.shape: {num_transfer_tokens.shape}")
426
+ for i in range(steps):
427
+ mask_index = (x == mask_id)
428
+ if cfg_scale > 0.0:
429
+ un_x = x.clone()
430
+ un_x[prompt_index] = mask_id
431
+ x_ = torch.cat([x, un_x], dim=0)
432
+ logits = self(x_).logits
433
+ logits, un_logits = torch.chunk(logits, 2, dim=0)
434
+ logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
435
+ else:
436
+ logits = self(x, attention_bias=attention_bias).logits
437
+
438
+ logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
439
+ x0 = torch.argmax(logits_with_noise, dim=-1) # b, l
440
+ if remasking == 'low_confidence':
441
+ p = F.softmax(logits.to(torch.float64), dim=-1)
442
+ x0_p = torch.squeeze(
443
+ torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
444
+ elif remasking == 'random':
445
+ x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
446
+ else:
447
+ raise NotImplementedError(remasking)
448
+
449
+ x0_p[:, idx.shape[1] + (num_block + 1) * block_length:] = -np.inf
450
+
451
+ x0 = torch.where(mask_index, x0, x)
452
+ confidence = torch.where(mask_index, x0_p, -np.inf)
453
+
454
+ transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
455
+ for j in range(confidence.shape[0]):
456
+ _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
457
+ transfer_index[j, select_index] = True
458
+ x[transfer_index] = x0[transfer_index]
459
+
460
+
461
+ # logits = logits[:, -1, :] / temperature
462
+ # # optionally crop the logits to only the top k options
463
+ # if top_k is not None:
464
+ # v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
465
+ # logits[logits < v[:, [-1]]] = -float('Inf')
466
+ # # apply softmax to convert logits to (normalized) probabilities
467
+ # probs = F.softmax(logits, dim=-1)
468
+ # # sample from the distribution
469
+ # idx_next = torch.multinomial(probs, num_samples=1)
470
+ # result.append(idx_next[0][0])
471
+ # # append sampled index to the running sequence and continue
472
+ # if self.config.w_clip_vit:
473
+ # idx_next_embeddings = self.mmada.model.embed_tokens(idx_next)
474
+ # input_embeddings = torch.cat([input_embeddings, idx_next_embeddings], dim=1)
475
+ # else:
476
+ # idx = torch.cat((idx, idx_next), dim=1)
477
+
478
+ # if eot_token is not None and idx_next.cpu() == eot_token:
479
+ # break
480
+
481
+ return x
482
+
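A hedged usage sketch for the block-wise text generation above; `tokenizer`, `model`, and `prompt` are assumptions rather than objects defined in this repo, and the hyperparameters simply have to satisfy the asserts (max_new_tokens divisible by block_length, steps divisible by the number of blocks).

```python
prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
out = model.mmu_generate(
    prompt_ids,
    max_new_tokens=128,   # 4 blocks of 32 tokens
    steps=128,            # 32 denoising steps per block
    block_length=32,
    temperature=0.0,      # greedy: no Gumbel noise added
)
answer = tokenizer.batch_decode(out[:, prompt_ids.shape[1]:], skip_special_tokens=True)
```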
483
+ @torch.no_grad()
484
+ def mmu_generate_fast(self, idx=None, input_embeddings=None, max_new_tokens=128, steps=128,block_length=128, temperature=0.0, top_k=None, eot_token=None, cfg_scale=0.0, remasking='low_confidence', mask_id=126336, attention_mask=None):
485
+ """
486
+ Take a conditioning sequence of indices idx (LongTensor of shape (b, t)), append
487
+ max_new_tokens masked positions, and fill them in block by block with iterative remasking,
488
+ stopping early once every sequence ends the current block with eot_token. Use model.eval() mode.
489
+ """
490
+
491
+ if attention_mask is not None and 0.0 in attention_mask:
492
+ attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
493
+ # print(f"attention_bias: {attention_bias}")
494
+ else:
495
+ attention_bias = None
496
+ try:
497
+ device = idx.device
498
+ except AttributeError:
499
+ device = input_embeddings.device
500
+
501
+ result = []
502
+ batch_size = idx.shape[0]
503
+ x = torch.full((batch_size, idx.shape[1] + max_new_tokens), mask_id, dtype=torch.long).to(self.device)
504
+ x[:, :idx.shape[1]] = idx.clone()
505
+ prompt_index = (x != mask_id)
506
+
507
+
508
+ assert max_new_tokens % block_length == 0
509
+ num_blocks = max_new_tokens // block_length
510
+
511
+ assert steps % num_blocks == 0
512
+ steps = steps // num_blocks
513
+
514
+ for num_block in range(num_blocks):
515
+ block_mask_index = (x[:, idx.shape[1] + num_block * block_length: idx.shape[1] + (num_block + 1) * block_length:] == mask_id)
516
+ num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
517
+ for i in range(steps):
518
+ mask_index = (x == mask_id)
519
+ if cfg_scale > 0.0:
520
+ un_x = x.clone()
521
+ un_x[prompt_index] = mask_id
522
+ x_ = torch.cat([x, un_x], dim=0)
523
+ logits = self(x_).logits
524
+ logits, un_logits = torch.chunk(logits, 2, dim=0)
525
+ logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
526
+ else:
527
+ logits = self(x, attention_bias=attention_bias).logits
528
+
529
+ logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
530
+ x0 = torch.argmax(logits_with_noise, dim=-1) # b, l
531
+ if remasking == 'low_confidence':
532
+ p = F.softmax(logits.to(torch.float64), dim=-1)
533
+ x0_p = torch.squeeze(
534
+ torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
535
+ elif remasking == 'random':
536
+ x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
537
+ else:
538
+ raise NotImplementedError(remasking)
539
+
540
+ x0_p[:, idx.shape[1] + (num_block + 1) * block_length:] = -np.inf
541
+
542
+ x0 = torch.where(mask_index, x0, x)
543
+ confidence = torch.where(mask_index, x0_p, -np.inf)
544
+
545
+ transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
546
+ for j in range(confidence.shape[0]):
547
+ _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
548
+ transfer_index[j, select_index] = True
549
+ x[transfer_index] = x0[transfer_index]
550
+ if eot_token is not None:
551
+ last_token_index_in_current_block = idx.shape[1] + (num_block + 1) * block_length - 1
552
+ if last_token_index_in_current_block < x.shape[1]:
553
+ tokens_at_block_end = x[:, last_token_index_in_current_block]
554
+ if torch.all(tokens_at_block_end == eot_token):
555
+ break
556
+ return x
557
+
558
+ @torch.no_grad()
559
+ def t2i_generate_decoding_stepwise(
560
+ self,
561
+ input_ids: torch.LongTensor = None,
562
+ uncond_input_ids: torch.LongTensor = None,
563
+ attention_mask=None,
564
+ uncond_attention_mask=None,
565
+ temperature=1.0,
566
+ timesteps=18, # ideal number of steps is 18 in maskgit paper
567
+ guidance_scale=0,
568
+ noise_schedule=cosine_schedule,
569
+ generator: torch.Generator = None,
570
+ config=None,
571
+ seq_len=1024,
572
+ mask_token_id = 126336,
573
+ resolution = 512,
574
+ codebook_size = 8192,
575
+ vq_model = None,
576
+ **kwargs,
577
+ ):
578
+ """
579
+ Generation loop kept close to 1:1 with the original MaskGit repo:
580
+ https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79
581
+ """
582
+
583
+ # begin with all image token ids masked
584
+ # count how many tokens are currently masked
585
+ mask_count = (input_ids == mask_token_id).sum().item()
586
+ num_vq_tokens = seq_len
587
+ num_new_special_tokens = 0
588
+ uni_prompting = kwargs.get("uni_prompting", None)
589
+ # print(f"config.model.mmada.llm_vocab_size: {config.model.mmada.llm_vocab_size}, {len(uni_prompting.text_tokenizer)}")
590
+ input_ids_minus_lm_vocab_size = input_ids[:, -(num_vq_tokens + 1):-1].clone()
591
+ input_ids_minus_lm_vocab_size = torch.where(input_ids_minus_lm_vocab_size == mask_token_id, mask_token_id, input_ids_minus_lm_vocab_size - len(uni_prompting.text_tokenizer) - num_new_special_tokens)
592
+
593
+ # for classifier-free guidance
594
+ if uncond_input_ids is not None:
595
+ uncond_prefix = uncond_input_ids[:, :resolution + 1]
596
+
597
+ for step in range(timesteps):
598
+ if uncond_input_ids is not None and guidance_scale > 0:
599
+ uncond_input_ids = torch.cat(
600
+ [uncond_prefix, input_ids[:, resolution + 1:]], dim=1)
601
+ model_input = torch.cat([input_ids, uncond_input_ids])
602
+ attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0)
603
+ attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
604
+ logits = self(model_input, attention_bias=attention_bias).logits
605
+ # print(f"logits.shape: {logits.shape}")
606
+ cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0)
607
+ # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
608
+ # it seems that muse has a different cfg setting
609
+ logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
610
+ logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size]
611
+ else:
612
+ attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
613
+ logits = self(input_ids, attention_bias=attention_bias).logits
614
+ logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size]
615
+
616
+ # logits: 1, 1024, 8192
617
+ # print(f"logits.shape: {logits.shape}")
618
+ probs = logits.softmax(dim=-1)
619
+ sampled = probs.reshape(-1, logits.size(-1))
620
+ # print(f"probs: {probs}, probs.shape: {probs.shape}, sampled: {sampled}, sampled.shape: {sampled.shape}")
621
+ sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1]) # 1, 1024
622
+
623
+ unknown_map = input_ids_minus_lm_vocab_size == mask_token_id
624
+ # print(f"unknown_map.sum(dim=-1, keepdim=True): {unknown_map.sum(dim=-1, keepdim=True)}")
625
+ sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size)
626
+ # Defines the mask ratio for the next round. The number to mask out is
627
+ current_image_vq_indices = sampled_ids.clone()
628
+ # print(f"current_image_vq_indices: {current_image_vq_indices}")
629
+ current_image_vq_indices = torch.clamp(current_image_vq_indices, 0, 8192 - 1)
630
+ current_image = vq_model.decode_code(current_image_vq_indices)
631
+ images = torch.clamp((current_image + 1.0) / 2.0, min=0.0, max=1.0)
632
+ images *= 255.0
633
+ images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
634
+ pil_images = Image.fromarray(images[0])
635
+ yield pil_images, f"Step {step + 1}/{timesteps}"
636
+ # determined by mask_ratio * unknown_number_in_the_beginning.
637
+ ratio = 1.0 * (step + 1) / timesteps
638
+ mask_ratio = noise_schedule(torch.tensor(ratio))
639
+ # Computes the probability of each selected token.
640
+ selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
641
+ selected_probs = selected_probs.squeeze(-1)
642
+
643
+ # Ignores the tokens given in the input by overwriting their confidence.
644
+ selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
645
+ # Gets mask lens for each sample in the batch according to the mask ratio.
646
+ mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device)
647
+ # Keeps at least one prediction in this round and masks out at least
648
+ # one token for the next iteration.
649
+ mask_len = torch.max(
650
+ torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
651
+ )
652
+ # print(f"mask_len: {mask_len}, mask_len.shape: {mask_len.shape}")
653
+ # Adds noise for randomness
654
+ temperature = temperature * (1.0 - ratio)
655
+ masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)
656
+ # Masks tokens with lower confidence.
657
+ input_ids[:, -(num_vq_tokens + 1):-1] = torch.where(masking, mask_token_id,
658
+ sampled_ids + len(uni_prompting.text_tokenizer)
659
+ + num_new_special_tokens)
660
+ input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids)
661
+
662
+
663
+ return sampled_ids
664
+
665
+
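Because `t2i_generate_decoding_stepwise` yields, callers iterate over it to stream intermediate previews. A hedged sketch follows; the surrounding objects such as `uni_prompting` and `vq_model` are assumed to be constructed elsewhere (as in app.py), and the guidance scale is illustrative.

```python
for preview, status in model.t2i_generate_decoding_stepwise(
        input_ids=input_ids,
        attention_mask=attention_mask,
        guidance_scale=3.5,        # illustrative CFG scale
        timesteps=18,
        vq_model=vq_model,
        uni_prompting=uni_prompting):
    print(status)                  # e.g. "Step 1/18"; `preview` is a PIL.Image
```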
666
+ AutoConfig.register("mmada", MMadaConfig)
667
+ AutoModelForCausalLM.register(MMadaConfig, MMadaModelLM)
668
+ AutoModel.register(MMadaConfig, MMadaModelLM)
models/modeling_utils.py ADDED
@@ -0,0 +1,1207 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import inspect
18
+ import itertools
19
+ import json
20
+ import os
21
+ import re
22
+ from collections import OrderedDict
23
+ from functools import partial
24
+ from pathlib import Path
25
+ from typing import Any, Callable, List, Optional, Tuple, Union
26
+
27
+ import safetensors
28
+ import torch
29
+ from huggingface_hub import create_repo, split_torch_state_dict_into_shards
30
+ from huggingface_hub.utils import validate_hf_hub_args
31
+ from torch import Tensor, nn
32
+
33
+ from diffusers import __version__
34
+ from diffusers.utils import (
35
+ FLAX_WEIGHTS_NAME,
36
+ SAFE_WEIGHTS_INDEX_NAME,
37
+ WEIGHTS_INDEX_NAME,
38
+ _add_variant,
39
+ _get_checkpoint_shard_files,
40
+ _get_model_file,
41
+ deprecate,
42
+ is_accelerate_available,
43
+ is_torch_version,
44
+ logging,
45
+ )
46
+
47
+ CONFIG_NAME = "config.json"
48
+ WEIGHTS_NAME = "pytorch_model.bin"
49
+ SAFETENSORS_WEIGHTS_NAME = "pytorch_model.safetensors"
50
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
51
+
52
+ from diffusers.utils.hub_utils import (
53
+ PushToHubMixin,
54
+ load_or_create_model_card,
55
+ populate_model_card,
56
+ )
57
+ from diffusers.models.model_loading_utils import (
58
+ _determine_device_map,
59
+ _fetch_index_file,
60
+ _load_state_dict_into_model,
61
+ load_model_dict_into_meta,
62
+ load_state_dict,
63
+ )
64
+
65
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
66
+
67
+ logger = logging.get_logger(__name__)
68
+
69
+ _REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")
70
+
71
+
72
+ if is_torch_version(">=", "1.9.0"):
73
+ _LOW_CPU_MEM_USAGE_DEFAULT = True
74
+ else:
75
+ _LOW_CPU_MEM_USAGE_DEFAULT = False
76
+
77
+
78
+ if is_accelerate_available():
79
+ import accelerate
80
+
81
+
82
+ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
83
+ try:
84
+ parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
85
+ return next(parameters_and_buffers).device
86
+ except StopIteration:
87
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
88
+
89
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
90
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
91
+ return tuples
92
+
93
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
94
+ first_tuple = next(gen)
95
+ return first_tuple[1].device
96
+
97
+
98
+ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
99
+ try:
100
+ params = tuple(parameter.parameters())
101
+ if len(params) > 0:
102
+ return params[0].dtype
103
+
104
+ buffers = tuple(parameter.buffers())
105
+ if len(buffers) > 0:
106
+ return buffers[0].dtype
107
+
108
+ except StopIteration:
109
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
110
+
111
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
112
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
113
+ return tuples
114
+
115
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
116
+ first_tuple = next(gen)
117
+ return first_tuple[1].dtype
118
+
119
+
120
+ class ModelMixin(torch.nn.Module, PushToHubMixin):
121
+ r"""
122
+ Base class for all models.
123
+
124
+ [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
125
+ saving models.
126
+
127
+ - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
128
+ """
129
+
130
+ config_name = CONFIG_NAME
131
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
132
+ _supports_gradient_checkpointing = False
133
+ _keys_to_ignore_on_load_unexpected = None
134
+ _no_split_modules = None
135
+
136
+ def __init__(self):
137
+ super().__init__()
138
+
139
+ def __getattr__(self, name: str) -> Any:
140
+ """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
141
+ config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
142
+ __getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
143
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
144
+ """
145
+
146
+ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
147
+ is_attribute = name in self.__dict__
148
+
149
+ if is_in_config and not is_attribute:
150
+ deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
151
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
152
+ return self._internal_dict[name]
153
+
154
+ # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
155
+ return super().__getattr__(name)
156
+
157
+ @property
158
+ def is_gradient_checkpointing(self) -> bool:
159
+ """
160
+ Whether gradient checkpointing is activated for this model or not.
161
+ """
162
+ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
163
+
164
+ def enable_gradient_checkpointing(self) -> None:
165
+ """
166
+ Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
167
+ *checkpoint activations* in other frameworks).
168
+ """
169
+ if not self._supports_gradient_checkpointing:
170
+ raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
171
+ self.apply(partial(self._set_gradient_checkpointing, value=True))
172
+
173
+ def disable_gradient_checkpointing(self) -> None:
174
+ """
175
+ Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
176
+ *checkpoint activations* in other frameworks).
177
+ """
178
+ if self._supports_gradient_checkpointing:
179
+ self.apply(partial(self._set_gradient_checkpointing, value=False))
180
+
181
+ def set_use_npu_flash_attention(self, valid: bool) -> None:
182
+ r"""
183
+ Set the switch for the npu flash attention.
184
+ """
185
+
186
+ def fn_recursive_set_npu_flash_attention(module: torch.nn.Module):
187
+ if hasattr(module, "set_use_npu_flash_attention"):
188
+ module.set_use_npu_flash_attention(valid)
189
+
190
+ for child in module.children():
191
+ fn_recursive_set_npu_flash_attention(child)
192
+
193
+ for module in self.children():
194
+ if isinstance(module, torch.nn.Module):
195
+ fn_recursive_set_npu_flash_attention(module)
196
+
197
+ def enable_npu_flash_attention(self) -> None:
198
+ r"""
199
+ Enable npu flash attention from torch_npu
200
+
201
+ """
202
+ self.set_use_npu_flash_attention(True)
203
+
204
+ def disable_npu_flash_attention(self) -> None:
205
+ r"""
206
+ Disable npu flash attention from torch_npu
207
+
208
+ """
209
+ self.set_use_npu_flash_attention(False)
210
+
211
+ def set_use_memory_efficient_attention_xformers(
212
+ self, valid: bool, attention_op: Optional[Callable] = None
213
+ ) -> None:
214
+ # Recursively walk through all the children.
215
+ # Any child that exposes the set_use_memory_efficient_attention_xformers method
216
+ # gets the message
217
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
218
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
219
+ module.set_use_memory_efficient_attention_xformers(valid, attention_op)
220
+
221
+ for child in module.children():
222
+ fn_recursive_set_mem_eff(child)
223
+
224
+ for module in self.children():
225
+ if isinstance(module, torch.nn.Module):
226
+ fn_recursive_set_mem_eff(module)
227
+
228
+ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None:
229
+ r"""
230
+ Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
231
+
232
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
233
+ inference. Speed up during training is not guaranteed.
234
+
235
+ <Tip warning={true}>
236
+
237
+ ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
238
+ precedent.
239
+
240
+ </Tip>
241
+
242
+ Parameters:
243
+ attention_op (`Callable`, *optional*):
244
+ Override the default `None` operator for use as `op` argument to the
245
+ [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
246
+ function of xFormers.
247
+
248
+ Examples:
249
+
250
+ ```py
251
+ >>> import torch
252
+ >>> from diffusers import UNet2DConditionModel
253
+ >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
254
+
255
+ >>> model = UNet2DConditionModel.from_pretrained(
256
+ ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
257
+ ... )
258
+ >>> model = model.to("cuda")
259
+ >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
260
+ ```
261
+ """
262
+ self.set_use_memory_efficient_attention_xformers(True, attention_op)
263
+
264
+ def disable_xformers_memory_efficient_attention(self) -> None:
265
+ r"""
266
+ Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
267
+ """
268
+ self.set_use_memory_efficient_attention_xformers(False)
269
+
270
+ def save_pretrained(
271
+ self,
272
+ save_directory: Union[str, os.PathLike],
273
+ is_main_process: bool = True,
274
+ save_function: Optional[Callable] = None,
275
+ safe_serialization: bool = True,
276
+ variant: Optional[str] = None,
277
+ max_shard_size: Union[int, str] = "10GB",
278
+ push_to_hub: bool = False,
279
+ **kwargs,
280
+ ):
281
+ """
282
+ Save a model and its configuration file to a directory so that it can be reloaded using the
283
+ [`~models.ModelMixin.from_pretrained`] class method.
284
+
285
+ Arguments:
286
+ save_directory (`str` or `os.PathLike`):
287
+ Directory to save a model and its configuration file to. Will be created if it doesn't exist.
288
+ is_main_process (`bool`, *optional*, defaults to `True`):
289
+ Whether the process calling this is the main process or not. Useful during distributed training when you
290
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
291
+ process to avoid race conditions.
292
+ save_function (`Callable`):
293
+ The function to use to save the state dictionary. Useful during distributed training when you need to
294
+ replace `torch.save` with another method. Can be configured with the environment variable
295
+ `DIFFUSERS_SAVE_MODE`.
296
+ safe_serialization (`bool`, *optional*, defaults to `True`):
297
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
298
+ variant (`str`, *optional*):
299
+ If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
300
+ max_shard_size (`int` or `str`, defaults to `"10GB"`):
301
+ The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be
302
+ smaller than this size. If expressed as a string, it needs to be digits followed by a unit (like `"5GB"`).
303
+ If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
304
+ period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
305
+ This is to establish a common default size for this argument across different libraries in the Hugging
306
+ Face ecosystem (`transformers`, and `accelerate`, for example).
307
+ push_to_hub (`bool`, *optional*, defaults to `False`):
308
+ Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
309
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
310
+ namespace).
311
+ kwargs (`Dict[str, Any]`, *optional*):
312
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
313
+ """
314
+ if os.path.isfile(save_directory):
315
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
316
+ return
317
+
318
+ weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
319
+ weights_name = _add_variant(weights_name, variant)
320
+ weight_name_split = weights_name.split(".")
321
+ if len(weight_name_split) in [2, 3]:
322
+ weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
323
+ else:
324
+ raise ValueError(f"Invalid {weights_name} provided.")
325
+
326
+ os.makedirs(save_directory, exist_ok=True)
327
+
328
+ if push_to_hub:
329
+ commit_message = kwargs.pop("commit_message", None)
330
+ private = kwargs.pop("private", False)
331
+ create_pr = kwargs.pop("create_pr", False)
332
+ token = kwargs.pop("token", None)
333
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
334
+ repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
335
+
336
+ # Only save the model itself if we are using distributed training
337
+ model_to_save = self
338
+
339
+ # Attach architecture to the config
340
+ # Save the config
341
+ if is_main_process:
342
+ model_to_save.save_config(save_directory)
343
+
344
+ # Save the model
345
+ state_dict = model_to_save.state_dict()
346
+
347
+ # Save the model
348
+ state_dict_split = split_torch_state_dict_into_shards(
349
+ state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
350
+ )
351
+
352
+ # Clean the folder from a previous save
353
+ if is_main_process:
354
+ for filename in os.listdir(save_directory):
355
+ if filename in state_dict_split.filename_to_tensors.keys():
356
+ continue
357
+ full_filename = os.path.join(save_directory, filename)
358
+ if not os.path.isfile(full_filename):
359
+ continue
360
+ weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
361
+ weights_without_ext = weights_without_ext.replace("{suffix}", "")
362
+ filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
363
+ # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
364
+ if (
365
+ filename.startswith(weights_without_ext)
366
+ and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
367
+ ):
368
+ os.remove(full_filename)
369
+
370
+ for filename, tensors in state_dict_split.filename_to_tensors.items():
371
+ shard = {tensor: state_dict[tensor] for tensor in tensors}
372
+ filepath = os.path.join(save_directory, filename)
373
+ if safe_serialization:
374
+ # At some point we will need to deal better with save_function (used for TPU and other distributed
375
+ # joyfulness), but for now this is enough.
376
+ safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
377
+ else:
378
+ torch.save(shard, filepath)
379
+
380
+ if state_dict_split.is_sharded:
381
+ index = {
382
+ "metadata": state_dict_split.metadata,
383
+ "weight_map": state_dict_split.tensor_to_filename,
384
+ }
385
+ save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
386
+ save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
387
+ # Save the index as well
388
+ with open(save_index_file, "w", encoding="utf-8") as f:
389
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
390
+ f.write(content)
391
+ logger.info(
392
+ f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
393
+ f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
394
+ f"index located at {save_index_file}."
395
+ )
396
+ else:
397
+ path_to_weights = os.path.join(save_directory, weights_name)
398
+ logger.info(f"Model weights saved in {path_to_weights}")
399
+
400
+ if push_to_hub:
401
+ # Create a new empty model card and eventually tag it
402
+ model_card = load_or_create_model_card(repo_id, token=token)
403
+ model_card = populate_model_card(model_card)
404
+ model_card.save(Path(save_directory, "README.md").as_posix())
405
+
406
+ self._upload_folder(
407
+ save_directory,
408
+ repo_id,
409
+ token=token,
410
+ commit_message=commit_message,
411
+ create_pr=create_pr,
412
+ )
413
+
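A hedged usage sketch for the method above; the output directory and the shard limit are illustrative, not defaults from this repo. With `safe_serialization=True` the weights are written as `pytorch_model.safetensors` (sharded, plus an index JSON, once they exceed `max_shard_size`).

```python
model.save_pretrained(
    "./mmada-checkpoint",      # hypothetical output directory
    safe_serialization=True,   # write *.safetensors instead of pickle
    max_shard_size="5GB",      # shard above this size and emit an index file
)
```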
414
+ @classmethod
415
+ @validate_hf_hub_args
416
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
417
+ r"""
418
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
419
+
420
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
421
+ train the model, set it back in training mode with `model.train()`.
422
+
423
+ Parameters:
424
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
425
+ Can be either:
426
+
427
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
428
+ the Hub.
429
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
430
+ with [`~ModelMixin.save_pretrained`].
431
+
432
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
433
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
434
+ is not used.
435
+ torch_dtype (`str` or `torch.dtype`, *optional*):
436
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
437
+ dtype is automatically derived from the model's weights.
438
+ force_download (`bool`, *optional*, defaults to `False`):
439
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
440
+ cached versions if they exist.
441
+ proxies (`Dict[str, str]`, *optional*):
442
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
443
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
444
+ output_loading_info (`bool`, *optional*, defaults to `False`):
445
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
446
+ local_files_only(`bool`, *optional*, defaults to `False`):
447
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
448
+ won't be downloaded from the Hub.
449
+ token (`str` or *bool*, *optional*):
450
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
451
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
452
+ revision (`str`, *optional*, defaults to `"main"`):
453
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
454
+ allowed by Git.
455
+ from_flax (`bool`, *optional*, defaults to `False`):
456
+ Load the model weights from a Flax checkpoint save file.
457
+ subfolder (`str`, *optional*, defaults to `""`):
458
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
459
+ mirror (`str`, *optional*):
460
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
461
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
462
+ information.
463
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
464
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
465
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
466
+ same device. Defaults to `None`, meaning that the model will be loaded on CPU.
467
+
468
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
469
+ more information about each option see [designing a device
470
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
471
+ max_memory (`Dict`, *optional*):
472
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
473
+ each GPU and the available CPU RAM if unset.
474
+ offload_folder (`str` or `os.PathLike`, *optional*):
475
+ The path to offload weights if `device_map` contains the value `"disk"`.
476
+ offload_state_dict (`bool`, *optional*):
477
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
478
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
479
+ when there is some disk offload.
480
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
481
+ Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
482
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
483
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
484
+ argument to `True` will raise an error.
485
+ variant (`str`, *optional*):
486
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
487
+ loading `from_flax`.
488
+ use_safetensors (`bool`, *optional*, defaults to `None`):
489
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
490
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
491
+ weights. If set to `False`, `safetensors` weights are not loaded.
492
+
493
+ <Tip>
494
+
495
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
496
+ `huggingface-cli login`. You can also activate the special
497
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
498
+ firewalled environment.
499
+
500
+ </Tip>
501
+
502
+ Example:
503
+
504
+ ```py
505
+ from diffusers import UNet2DConditionModel
506
+
507
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
508
+ ```
509
+
510
+ If you get the error message below, you need to finetune the weights for your downstream task:
511
+
512
+ ```bash
513
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
514
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
515
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
516
+ ```
517
+ """
518
+ cache_dir = kwargs.pop("cache_dir", None)
519
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
520
+ force_download = kwargs.pop("force_download", False)
521
+ from_flax = kwargs.pop("from_flax", False)
522
+ proxies = kwargs.pop("proxies", None)
523
+ output_loading_info = kwargs.pop("output_loading_info", False)
524
+ local_files_only = kwargs.pop("local_files_only", None)
525
+ token = kwargs.pop("token", None)
526
+ revision = kwargs.pop("revision", None)
527
+ torch_dtype = kwargs.pop("torch_dtype", None)
528
+ subfolder = kwargs.pop("subfolder", None)
529
+ device_map = kwargs.pop("device_map", None)
530
+ max_memory = kwargs.pop("max_memory", None)
531
+ offload_folder = kwargs.pop("offload_folder", None)
532
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
533
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
534
+ variant = kwargs.pop("variant", None)
535
+ use_safetensors = kwargs.pop("use_safetensors", None)
536
+
537
+ allow_pickle = False
538
+ if use_safetensors is None:
539
+ use_safetensors = True
540
+ allow_pickle = True
541
+
542
+ if low_cpu_mem_usage and not is_accelerate_available():
543
+ low_cpu_mem_usage = False
544
+ logger.warning(
545
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
546
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
547
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
548
+ " install accelerate\n```\n."
549
+ )
550
+
551
+ if device_map is not None and not is_accelerate_available():
552
+ raise NotImplementedError(
553
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
554
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
555
+ )
556
+
557
+ # Check if we can handle device_map and dispatching the weights
558
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
559
+ raise NotImplementedError(
560
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
561
+ " `device_map=None`."
562
+ )
563
+
564
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
565
+ raise NotImplementedError(
566
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
567
+ " `low_cpu_mem_usage=False`."
568
+ )
569
+
570
+ if low_cpu_mem_usage is False and device_map is not None:
571
+ raise ValueError(
572
+ f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
573
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
574
+ )
575
+
576
+ # change device_map into a map if we passed an int, a str or a torch.device
577
+ if isinstance(device_map, torch.device):
578
+ device_map = {"": device_map}
579
+ elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
580
+ try:
581
+ device_map = {"": torch.device(device_map)}
582
+ except RuntimeError:
583
+ raise ValueError(
584
+ "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
585
+ f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
586
+ )
587
+ elif isinstance(device_map, int):
588
+ if device_map < 0:
589
+ raise ValueError(
590
+ "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
591
+ )
592
+ else:
593
+ device_map = {"": device_map}
594
+
595
+ if device_map is not None:
596
+ if low_cpu_mem_usage is None:
597
+ low_cpu_mem_usage = True
598
+ elif not low_cpu_mem_usage:
599
+ raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
600
+
601
+ if low_cpu_mem_usage:
602
+ if device_map is not None and not is_torch_version(">=", "1.10"):
603
+ # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
604
+ raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")
605
+
606
+ # Load config if we don't provide a configuration
607
+ config_path = pretrained_model_name_or_path
608
+
609
+ user_agent = {
610
+ "diffusers": __version__,
611
+ "file_type": "model",
612
+ "framework": "pytorch",
613
+ }
614
+
615
+ # load config
616
+ config, unused_kwargs, commit_hash = cls.load_config(
617
+ config_path,
618
+ cache_dir=cache_dir,
619
+ return_unused_kwargs=True,
620
+ return_commit_hash=True,
621
+ force_download=force_download,
622
+ proxies=proxies,
623
+ local_files_only=local_files_only,
624
+ token=token,
625
+ revision=revision,
626
+ subfolder=subfolder,
627
+ user_agent=user_agent,
628
+ **kwargs,
629
+ )
630
+
631
+ # Determine if we're loading from a directory of sharded checkpoints.
632
+ is_sharded = False
633
+ index_file = None
634
+ is_local = os.path.isdir(pretrained_model_name_or_path)
635
+ index_file = _fetch_index_file(
636
+ is_local=is_local,
637
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
638
+ subfolder=subfolder or "",
639
+ use_safetensors=use_safetensors,
640
+ cache_dir=cache_dir,
641
+ variant=variant,
642
+ force_download=force_download,
643
+ proxies=proxies,
644
+ local_files_only=local_files_only,
645
+ token=token,
646
+ revision=revision,
647
+ user_agent=user_agent,
648
+ commit_hash=commit_hash,
649
+ )
650
+ if index_file is not None and index_file.is_file():
651
+ is_sharded = True
652
+
653
+ if is_sharded and from_flax:
654
+ raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")
655
+
656
+ # load model
657
+ model_file = None
658
+ if from_flax:
659
+ model_file = _get_model_file(
660
+ pretrained_model_name_or_path,
661
+ weights_name=FLAX_WEIGHTS_NAME,
662
+ cache_dir=cache_dir,
663
+ force_download=force_download,
664
+ proxies=proxies,
665
+ local_files_only=local_files_only,
666
+ token=token,
667
+ revision=revision,
668
+ subfolder=subfolder,
669
+ user_agent=user_agent,
670
+ commit_hash=commit_hash,
671
+ )
672
+ model = cls.from_config(config, **unused_kwargs)
673
+
674
+ # Convert the weights
675
+ from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
676
+
677
+ model = load_flax_checkpoint_in_pytorch_model(model, model_file)
678
+ else:
679
+ if is_sharded:
680
+ sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
681
+ pretrained_model_name_or_path,
682
+ index_file,
683
+ cache_dir=cache_dir,
684
+ proxies=proxies,
685
+ local_files_only=local_files_only,
686
+ token=token,
687
+ user_agent=user_agent,
688
+ revision=revision,
689
+ subfolder=subfolder or "",
690
+ )
691
+
692
+ elif use_safetensors and not is_sharded:
693
+ try:
694
+ model_file = _get_model_file(
695
+ pretrained_model_name_or_path,
696
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
697
+ cache_dir=cache_dir,
698
+ force_download=force_download,
699
+ proxies=proxies,
700
+ local_files_only=local_files_only,
701
+ token=token,
702
+ revision=revision,
703
+ subfolder=subfolder,
704
+ user_agent=user_agent,
705
+ commit_hash=commit_hash,
706
+ )
707
+
708
+ except IOError as e:
709
+ logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
710
+ if not allow_pickle:
711
+ raise
712
+ logger.warning(
713
+ "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
714
+ )
715
+
716
+ if model_file is None and not is_sharded:
717
+ model_file = _get_model_file(
718
+ pretrained_model_name_or_path,
719
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
720
+ cache_dir=cache_dir,
721
+ force_download=force_download,
722
+ proxies=proxies,
723
+ local_files_only=local_files_only,
724
+ token=token,
725
+ revision=revision,
726
+ subfolder=subfolder,
727
+ user_agent=user_agent,
728
+ commit_hash=commit_hash,
729
+ )
730
+
731
+ if low_cpu_mem_usage:
732
+ # Instantiate model with empty weights
733
+ with accelerate.init_empty_weights():
734
+ model = cls.from_config(config, **unused_kwargs)
735
+
736
+ # if device_map is None, load the state dict and move the params from meta device to the cpu
737
+ if device_map is None and not is_sharded:
738
+ param_device = "cpu"
739
+ state_dict = load_state_dict(model_file, variant=variant)
740
+ model._convert_deprecated_attention_blocks(state_dict)
741
+ # move the params from meta device to cpu
742
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
743
+ if len(missing_keys) > 0:
744
+ raise ValueError(
745
+ f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
746
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
747
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
748
+ " those weights or else make sure your checkpoint file is correct."
749
+ )
750
+
751
+ unexpected_keys = load_model_dict_into_meta(
752
+ model,
753
+ state_dict,
754
+ device=param_device,
755
+ dtype=torch_dtype,
756
+ model_name_or_path=pretrained_model_name_or_path,
757
+ )
758
+
759
+ if cls._keys_to_ignore_on_load_unexpected is not None:
760
+ for pat in cls._keys_to_ignore_on_load_unexpected:
761
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
762
+
763
+ if len(unexpected_keys) > 0:
764
+ logger.warning(
765
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
766
+ )
767
+
768
+ else: # else let accelerate handle loading and dispatching.
769
+ # Load weights and dispatch according to the device_map
770
+ # by default the device_map is None and the weights are loaded on the CPU
771
+ force_hook = True
772
+ device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
773
+ if device_map is None and is_sharded:
774
+ # we load the parameters on the cpu
775
+ device_map = {"": "cpu"}
776
+ force_hook = False
777
+ try:
778
+ accelerate.load_checkpoint_and_dispatch(
779
+ model,
780
+ model_file if not is_sharded else index_file,
781
+ device_map,
782
+ max_memory=max_memory,
783
+ offload_folder=offload_folder,
784
+ offload_state_dict=offload_state_dict,
785
+ dtype=torch_dtype,
786
+ force_hooks=force_hook,
787
+ strict=True,
788
+ )
789
+ except AttributeError as e:
790
+ # When using accelerate loading, we do not have the ability to load the state
791
+ # dict and rename the weight names manually. Additionally, accelerate skips
792
+ # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
793
+ # (which look like they should be private variables?), so we can't use the standard hooks
794
+ # to rename parameters on load. We need to mimic the original weight names so the correct
795
+ # attributes are available. After we have loaded the weights, we convert the deprecated
796
+ # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
797
+ # the weights so we don't have to do this again.
798
+
799
+ if "'Attention' object has no attribute" in str(e):
800
+ logger.warning(
801
+ f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
802
+ " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
803
+ " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
804
+ " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
805
+ " please also re-upload it or open a PR on the original repository."
806
+ )
807
+ model._temp_convert_self_to_deprecated_attention_blocks()
808
+ accelerate.load_checkpoint_and_dispatch(
809
+ model,
810
+ model_file if not is_sharded else index_file,
811
+ device_map,
812
+ max_memory=max_memory,
813
+ offload_folder=offload_folder,
814
+ offload_state_dict=offload_state_dict,
815
+ dtype=torch_dtype,
816
+ force_hooks=force_hook,
817
+ strict=True,
818
+ )
819
+ model._undo_temp_convert_self_to_deprecated_attention_blocks()
820
+ else:
821
+ raise e
822
+
823
+ loading_info = {
824
+ "missing_keys": [],
825
+ "unexpected_keys": [],
826
+ "mismatched_keys": [],
827
+ "error_msgs": [],
828
+ }
829
+ else:
830
+ model = cls.from_config(config, **unused_kwargs)
831
+
832
+ state_dict = load_state_dict(model_file, variant=variant)
833
+ model._convert_deprecated_attention_blocks(state_dict)
834
+
835
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
836
+ model,
837
+ state_dict,
838
+ model_file,
839
+ pretrained_model_name_or_path,
840
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
841
+ )
842
+
843
+ loading_info = {
844
+ "missing_keys": missing_keys,
845
+ "unexpected_keys": unexpected_keys,
846
+ "mismatched_keys": mismatched_keys,
847
+ "error_msgs": error_msgs,
848
+ }
849
+
850
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
851
+ raise ValueError(
852
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
853
+ )
854
+ elif torch_dtype is not None:
855
+ model = model.to(torch_dtype)
856
+
857
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
858
+
859
+ # Set model in evaluation mode to deactivate DropOut modules by default
860
+ model.eval()
861
+ if output_loading_info:
862
+ return model, loading_info
863
+
864
+ return model
865
+
866
+ @classmethod
867
+ def _load_pretrained_model(
868
+ cls,
869
+ model,
870
+ state_dict: OrderedDict,
871
+ resolved_archive_file,
872
+ pretrained_model_name_or_path: Union[str, os.PathLike],
873
+ ignore_mismatched_sizes: bool = False,
874
+ ):
875
+ # Retrieve missing & unexpected_keys
876
+ model_state_dict = model.state_dict()
877
+ loaded_keys = list(state_dict.keys())
878
+
879
+ expected_keys = list(model_state_dict.keys())
880
+
881
+ original_loaded_keys = loaded_keys
882
+
883
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
884
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
885
+
886
+ # Make sure we are able to load base models as well as derived models (with heads)
887
+ model_to_load = model
888
+
889
+ def _find_mismatched_keys(
890
+ state_dict,
891
+ model_state_dict,
892
+ loaded_keys,
893
+ ignore_mismatched_sizes,
894
+ ):
895
+ mismatched_keys = []
896
+ if ignore_mismatched_sizes:
897
+ for checkpoint_key in loaded_keys:
898
+ model_key = checkpoint_key
899
+
900
+ if (
901
+ model_key in model_state_dict
902
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
903
+ ):
904
+ mismatched_keys.append(
905
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
906
+ )
907
+ del state_dict[checkpoint_key]
908
+ return mismatched_keys
909
+
910
+ if state_dict is not None:
911
+ # Whole checkpoint
912
+ mismatched_keys = _find_mismatched_keys(
913
+ state_dict,
914
+ model_state_dict,
915
+ original_loaded_keys,
916
+ ignore_mismatched_sizes,
917
+ )
918
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
919
+
920
+ if len(error_msgs) > 0:
921
+ error_msg = "\n\t".join(error_msgs)
922
+ if "size mismatch" in error_msg:
923
+ error_msg += (
924
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
925
+ )
926
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
927
+
928
+ if len(unexpected_keys) > 0:
929
+ logger.warning(
930
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
931
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
932
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
933
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
934
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
935
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
936
+ " identical (initializing a BertForSequenceClassification model from a"
937
+ " BertForSequenceClassification model)."
938
+ )
939
+ else:
940
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
941
+ if len(missing_keys) > 0:
942
+ logger.warning(
943
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
944
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
945
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
946
+ )
947
+ elif len(mismatched_keys) == 0:
948
+ logger.info(
949
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
950
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
951
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
952
+ " without further training."
953
+ )
954
+ if len(mismatched_keys) > 0:
955
+ mismatched_warning = "\n".join(
956
+ [
957
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
958
+ for key, shape1, shape2 in mismatched_keys
959
+ ]
960
+ )
961
+ logger.warning(
962
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
963
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
964
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
965
+ " able to use it for predictions and inference."
966
+ )
967
+
968
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
969
+
970
+ @classmethod
971
+ def _get_signature_keys(cls, obj):
972
+ parameters = inspect.signature(obj.__init__).parameters
973
+ required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
974
+ optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
975
+ expected_modules = set(required_parameters.keys()) - {"self"}
976
+
977
+ return expected_modules, optional_parameters
978
+
979
+ # Adapted from `transformers` modeling_utils.py
980
+ def _get_no_split_modules(self, device_map: str):
981
+ """
982
+ Get the modules of the model that should not be split when using device_map. We iterate through the modules to
983
+ get the underlying `_no_split_modules`.
984
+
985
+ Args:
986
+ device_map (`str`):
987
+ The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
988
+
989
+ Returns:
990
+ `List[str]`: List of modules that should not be split
991
+ """
992
+ _no_split_modules = set()
993
+ modules_to_check = [self]
994
+ while len(modules_to_check) > 0:
995
+ module = modules_to_check.pop(-1)
996
+ # if the module does not appear in _no_split_modules, we also check the children
997
+ if module.__class__.__name__ not in _no_split_modules:
998
+ if isinstance(module, ModelMixin):
999
+ if module._no_split_modules is None:
1000
+ raise ValueError(
1001
+ f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
1002
+ "class needs to implement the `_no_split_modules` attribute."
1003
+ )
1004
+ else:
1005
+ _no_split_modules = _no_split_modules | set(module._no_split_modules)
1006
+ modules_to_check += list(module.children())
1007
+ return list(_no_split_modules)
1008
+
1009
+ @property
1010
+ def device(self) -> torch.device:
1011
+ """
1012
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
1013
+ device).
1014
+ """
1015
+ return get_parameter_device(self)
1016
+
1017
+ @property
1018
+ def dtype(self) -> torch.dtype:
1019
+ """
1020
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
1021
+ """
1022
+ return get_parameter_dtype(self)
1023
+
1024
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
1025
+ """
1026
+ Get number of (trainable or non-embedding) parameters in the module.
1027
+
1028
+ Args:
1029
+ only_trainable (`bool`, *optional*, defaults to `False`):
1030
+ Whether or not to return only the number of trainable parameters.
1031
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
1032
+ Whether or not to return only the number of non-embedding parameters.
1033
+
1034
+ Returns:
1035
+ `int`: The number of parameters.
1036
+
1037
+ Example:
1038
+
1039
+ ```py
1040
+ from diffusers import UNet2DConditionModel
1041
+
1042
+ model_id = "runwayml/stable-diffusion-v1-5"
1043
+ unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
1044
+ unet.num_parameters(only_trainable=True)
1045
+ 859520964
1046
+ ```
1047
+ """
1048
+
1049
+ if exclude_embeddings:
1050
+ embedding_param_names = [
1051
+ f"{name}.weight"
1052
+ for name, module_type in self.named_modules()
1053
+ if isinstance(module_type, torch.nn.Embedding)
1054
+ ]
1055
+ non_embedding_parameters = [
1056
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
1057
+ ]
1058
+ return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
1059
+ else:
1060
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
1061
+
1062
+ def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
1063
+ deprecated_attention_block_paths = []
1064
+
1065
+ def recursive_find_attn_block(name, module):
1066
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
1067
+ deprecated_attention_block_paths.append(name)
1068
+
1069
+ for sub_name, sub_module in module.named_children():
1070
+ sub_name = sub_name if name == "" else f"{name}.{sub_name}"
1071
+ recursive_find_attn_block(sub_name, sub_module)
1072
+
1073
+ recursive_find_attn_block("", self)
1074
+
1075
+ # NOTE: we have to check if the deprecated parameters are in the state dict
1076
+ # because it is possible we are loading from a state dict that was already
1077
+ # converted
1078
+
1079
+ for path in deprecated_attention_block_paths:
1080
+ # group_norm path stays the same
1081
+
1082
+ # query -> to_q
1083
+ if f"{path}.query.weight" in state_dict:
1084
+ state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
1085
+ if f"{path}.query.bias" in state_dict:
1086
+ state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
1087
+
1088
+ # key -> to_k
1089
+ if f"{path}.key.weight" in state_dict:
1090
+ state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
1091
+ if f"{path}.key.bias" in state_dict:
1092
+ state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
1093
+
1094
+ # value -> to_v
1095
+ if f"{path}.value.weight" in state_dict:
1096
+ state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
1097
+ if f"{path}.value.bias" in state_dict:
1098
+ state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
1099
+
1100
+ # proj_attn -> to_out.0
1101
+ if f"{path}.proj_attn.weight" in state_dict:
1102
+ state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
1103
+ if f"{path}.proj_attn.bias" in state_dict:
1104
+ state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
1105
+
1106
+ def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
1107
+ deprecated_attention_block_modules = []
1108
+
1109
+ def recursive_find_attn_block(module):
1110
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
1111
+ deprecated_attention_block_modules.append(module)
1112
+
1113
+ for sub_module in module.children():
1114
+ recursive_find_attn_block(sub_module)
1115
+
1116
+ recursive_find_attn_block(self)
1117
+
1118
+ for module in deprecated_attention_block_modules:
1119
+ module.query = module.to_q
1120
+ module.key = module.to_k
1121
+ module.value = module.to_v
1122
+ module.proj_attn = module.to_out[0]
1123
+
1124
+ # We don't _have_ to delete the old attributes, but it's helpful to ensure
1125
+ # that _all_ the weights are loaded into the new attributes and we're not
1126
+ # making an incorrect assumption that this model should be converted when
1127
+ # it really shouldn't be.
1128
+ del module.to_q
1129
+ del module.to_k
1130
+ del module.to_v
1131
+ del module.to_out
1132
+
1133
+ def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
1134
+ deprecated_attention_block_modules = []
1135
+
1136
+ def recursive_find_attn_block(module) -> None:
1137
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
1138
+ deprecated_attention_block_modules.append(module)
1139
+
1140
+ for sub_module in module.children():
1141
+ recursive_find_attn_block(sub_module)
1142
+
1143
+ recursive_find_attn_block(self)
1144
+
1145
+ for module in deprecated_attention_block_modules:
1146
+ module.to_q = module.query
1147
+ module.to_k = module.key
1148
+ module.to_v = module.value
1149
+ module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
1150
+
1151
+ del module.query
1152
+ del module.key
1153
+ del module.value
1154
+ del module.proj_attn
1155
+
1156
+
1157
+ class LegacyModelMixin(ModelMixin):
1158
+ r"""
1159
+ A subclass of `ModelMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
1160
+ pipeline-specific classes (like `DiTTransformer2DModel`).
1161
+ """
1162
+
1163
+ @classmethod
1164
+ @validate_hf_hub_args
1165
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
1166
+ # To prevent dependency import problem.
1167
+ from diffusers.models.model_loading_utils import _fetch_remapped_cls_from_config
1168
+
1169
+ # Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls.
1170
+ kwargs_copy = kwargs.copy()
1171
+
1172
+ cache_dir = kwargs.pop("cache_dir", None)
1173
+ force_download = kwargs.pop("force_download", False)
1174
+ proxies = kwargs.pop("proxies", None)
1175
+ local_files_only = kwargs.pop("local_files_only", None)
1176
+ token = kwargs.pop("token", None)
1177
+ revision = kwargs.pop("revision", None)
1178
+ subfolder = kwargs.pop("subfolder", None)
1179
+
1180
+ # Load config if we don't provide a configuration
1181
+ config_path = pretrained_model_name_or_path
1182
+
1183
+ user_agent = {
1184
+ "diffusers": __version__,
1185
+ "file_type": "model",
1186
+ "framework": "pytorch",
1187
+ }
1188
+
1189
+ # load config
1190
+ config, _, _ = cls.load_config(
1191
+ config_path,
1192
+ cache_dir=cache_dir,
1193
+ return_unused_kwargs=True,
1194
+ return_commit_hash=True,
1195
+ force_download=force_download,
1196
+ proxies=proxies,
1197
+ local_files_only=local_files_only,
1198
+ token=token,
1199
+ revision=revision,
1200
+ subfolder=subfolder,
1201
+ user_agent=user_agent,
1202
+ **kwargs,
1203
+ )
1204
+ # resolve remapping
1205
+ remapped_class = _fetch_remapped_cls_from_config(config, cls)
1206
+
1207
+ return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
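A minimal caller-side sketch of the loading path implemented above, assuming `diffusers` and `accelerate` are installed and the checkpoint from the docstring example is reachable; passing `device_map="auto"` goes through the accelerate dispatch branch and requires `low_cpu_mem_usage=True`, matching the checks in `from_pretrained`.

import torch
from diffusers import UNet2DConditionModel

# Illustrative only: repo id and dtype follow the docstring example above.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,   # required (and defaulted) when a device_map is given
    device_map="auto",        # dispatched via accelerate.load_checkpoint_and_dispatch
)
print(unet.device, unet.dtype, unet.num_parameters())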
models/sampling.py ADDED
@@ -0,0 +1,118 @@
1
+ # Adapted from https://github.com/lucidrains/muse-maskgit-pytorch
2
+
3
+ import math
4
+ from functools import partial
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+
10
+ def log(t, eps=1e-20):
11
+ return torch.log(t.clamp(min=eps))
12
+
13
+
14
+ def gumbel_noise(t, generator=None):
15
+ noise = torch.zeros_like(t).uniform_(0, 1, generator=generator)
16
+ return -log(-log(noise))
17
+
18
+
19
+ def gumbel_sample(t, temperature=1.0, dim=-1, generator=None):
20
+ return ((t / max(temperature, 1e-10)) + gumbel_noise(t, generator=generator)).argmax(dim=dim)
21
+
22
+
23
+ def top_k(logits, thres=0.9):
24
+ k = math.ceil((1 - thres) * logits.shape[-1])
25
+ val, ind = logits.topk(k, dim=-1)
26
+ probs = torch.full_like(logits, float("-inf"))
27
+ probs.scatter_(2, ind, val)
28
+ return probs
29
+
30
+
31
+ def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
32
+ confidence = log(probs) + temperature * gumbel_noise(probs, generator=generator)
33
+ sorted_confidence = torch.sort(confidence, dim=-1).values
34
+ cut_off = torch.gather(sorted_confidence, 1, mask_len.long())
35
+ masking = confidence < cut_off
36
+ return masking
37
+
38
+
39
+ def cosine_schedule(t):
40
+ return torch.cos(t * math.pi * 0.5)
41
+
42
+
43
+ def linear_schedule(t):
44
+ mask_ratio = 1 - t
45
+ mask_ratio = mask_ratio.clamp(min=1e-6, max=1.0)
46
+ return mask_ratio
47
+
48
+
49
+ def pow(t, method):
50
+ exponent = float(method.replace("pow", ""))
51
+ mask_ratio = 1.0 - t**exponent
52
+ mask_ratio = mask_ratio.clamp(min=1e-6, max=1.0)
53
+ return mask_ratio
54
+
55
+
56
+ def sigmoid_schedule(t, start=-3, end=3, tau=1.0, clip_min=1e-6):
57
+ if not torch.is_tensor(t):
58
+ t = torch.tensor(t)
59
+
60
+ # A gamma function based on sigmoid function.
61
+ v_start = torch.sigmoid(torch.tensor(start / tau))
62
+ v_end = torch.sigmoid(torch.tensor(end / tau))
63
+ output = torch.sigmoid((t * (end - start) + start) / tau)
64
+ output = (v_end - output) / (v_end - v_start)
65
+ return torch.clip(output, clip_min, 1.0)
66
+
67
+
68
+ def get_mask_schedule(method, **schedule_kwargs):
69
+ if method == "cosine":
70
+ return cosine_schedule
71
+ elif method == "linear":
72
+ return linear_schedule
73
+ elif "pow" in method:
74
+ return partial(pow, method=method)
75
+ elif method == "sigmoid":
76
+ return partial(sigmoid_schedule, **schedule_kwargs)
77
+ else:
78
+ raise ValueError("Unknown schedule method: {}".format(method))
79
+
80
+ def top_k_top_p_filtering(
81
+ logits: torch.Tensor,
82
+ top_k: int = 0,
83
+ top_p: float = 1.0,
84
+ filter_value: float = -float("Inf"),
85
+ min_tokens_to_keep: int = 1,
86
+ ) -> torch.Tensor:
87
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
88
+ Args:
89
+ logits: logits distribution shape (batch size, vocabulary size)
90
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
91
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
92
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
93
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
94
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
95
+ """
96
+ if top_k > 0:
97
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
98
+ # Remove all tokens with a probability less than the last token of the top-k
99
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
100
+ logits[indices_to_remove] = filter_value
101
+
102
+ if top_p < 1.0:
103
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
104
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
105
+
106
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
107
+ sorted_indices_to_remove = cumulative_probs > top_p
108
+ if min_tokens_to_keep > 1:
109
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
110
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
111
+ # Shift the indices to the right to keep also the first token above the threshold
112
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
113
+ sorted_indices_to_remove[..., 0] = 0
114
+
115
+ # scatter sorted tensors to original indexing
116
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
117
+ logits[indices_to_remove] = filter_value
118
+ return logits
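A short usage sketch for the helpers in this file (shapes and values are illustrative): `get_mask_schedule` resolves a schedule name to a mask-ratio callable, while `top_k_top_p_filtering` and `gumbel_sample` implement filtered stochastic sampling over a batch of logits.

import torch

schedule = get_mask_schedule("cosine")       # -> cosine_schedule
t = torch.linspace(0, 1, 5)
print(schedule(t))                           # mask ratio decays from 1.0 to 0.0

logits = torch.randn(2, 128)                 # (batch, vocab); sizes are illustrative
# top_k_top_p_filtering edits logits in place, so filter a copy
filtered = top_k_top_p_filtering(logits.clone(), top_k=20, top_p=0.9)
tokens = gumbel_sample(filtered, temperature=1.0)   # one sampled token id per row
print(tokens.shape)                          # torch.Size([2])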
models/training_utils.py ADDED
@@ -0,0 +1,455 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import copy
17
+ import os
18
+ import random
19
+ from typing import Any, Dict, Iterable, Optional, Union
20
+
21
+ import numpy as np
22
+ import pandas as pd
23
+ import torch
24
+ import torch.nn.functional as F
25
+
26
+
27
+ def enable_full_determinism(seed: int):
28
+ """
29
+ Helper function for reproducible behavior during distributed training. See
30
+ - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
31
+ """
32
+ # set seed first
33
+ set_seed(seed)
34
+
35
+ # Enable PyTorch deterministic mode. This potentially requires either the environment
36
+ # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
37
+ # depending on the CUDA version, so we set them both here
38
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
39
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
40
+ torch.use_deterministic_algorithms(True)
41
+
42
+ # Enable CUDNN deterministic mode
43
+ torch.backends.cudnn.deterministic = True
44
+ torch.backends.cudnn.benchmark = False
45
+
46
+
47
+ def set_seed(seed: int):
48
+ """
49
+ Args:
50
+ Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
51
+ seed (`int`): The seed to set.
52
+ """
53
+ random.seed(seed)
54
+ np.random.seed(seed)
55
+ torch.manual_seed(seed)
56
+ torch.cuda.manual_seed_all(seed)
57
+ # ^^ safe to call this function even if cuda is not available
58
+
59
+
60
+ # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
61
+ class EMA:
62
+ """
63
+ Exponential Moving Average of models weights
64
+ """
65
+
66
+ def __init__(
67
+ self,
68
+ parameters: Iterable[torch.nn.Parameter],
69
+ decay: float = 0.9999,
70
+ min_decay: float = 0.0,
71
+ update_after_step: int = 0,
72
+ use_ema_warmup: bool = False,
73
+ inv_gamma: Union[float, int] = 1.0,
74
+ power: Union[float, int] = 2 / 3,
75
+ model_cls: Optional[Any] = None,
76
+ model_config: Dict[str, Any] = None,
77
+ **kwargs,
78
+ ):
79
+ """
80
+ Args:
81
+ parameters (Iterable[torch.nn.Parameter]): The parameters to track.
82
+ decay (float): The decay factor for the exponential moving average.
83
+ min_decay (float): The minimum decay factor for the exponential moving average.
84
+ update_after_step (int): The number of steps to wait before starting to update the EMA weights.
85
+ use_ema_warmup (bool): Whether to use EMA warmup.
86
+ inv_gamma (float):
87
+ Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
88
+ power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
89
+ device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
90
+ weights will be stored on CPU.
91
+
92
+ @crowsonkb's notes on EMA Warmup:
93
+ If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
94
+ to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
95
+ gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
96
+ at 215.4k steps).
97
+ """
98
+
99
+ parameters = list(parameters)
100
+ self.shadow_params = [p.clone().detach() for p in parameters]
101
+
102
+ self.temp_stored_params = None
103
+
104
+ self.decay = decay
105
+ self.min_decay = min_decay
106
+ self.update_after_step = update_after_step
107
+ self.use_ema_warmup = use_ema_warmup
108
+ self.inv_gamma = inv_gamma
109
+ self.power = power
110
+ self.optimization_step = 0
111
+ self.cur_decay_value = None # set in `step()`
112
+
113
+ self.model_cls = model_cls
114
+ self.model_config = model_config
115
+
116
+ @classmethod
117
+ def from_pretrained(cls, path, model_cls) -> "EMA":
118
+ _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
119
+ model = model_cls.from_pretrained(path)
120
+
121
+ ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config)
122
+
123
+ ema_model.load_state_dict(ema_kwargs)
124
+ return ema_model
125
+
126
+ def save_pretrained(self, path):
127
+ if self.model_cls is None:
128
+ raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")
129
+
130
+ if self.model_config is None:
131
+ raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")
132
+
133
+ model = self.model_cls.from_config(self.model_config)
134
+ state_dict = self.state_dict()
135
+ state_dict.pop("shadow_params", None)
136
+
137
+ model.register_to_config(**state_dict)
138
+ self.copy_to(model.parameters())
139
+ model.save_pretrained(path)
140
+
141
+ def get_decay(self, optimization_step: int) -> float:
142
+ """
143
+ Compute the decay factor for the exponential moving average.
144
+ """
145
+ step = max(0, optimization_step - self.update_after_step - 1)
146
+
147
+ if step <= 0:
148
+ return 0.0
149
+
150
+ if self.use_ema_warmup:
151
+ cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
152
+ else:
153
+ cur_decay_value = (1 + step) / (10 + step)
154
+
155
+ cur_decay_value = min(cur_decay_value, self.decay)
156
+ # make sure decay is not smaller than min_decay
157
+ cur_decay_value = max(cur_decay_value, self.min_decay)
158
+ return cur_decay_value
159
+
160
+ @torch.no_grad()
161
+ def step(self, parameters: Iterable[torch.nn.Parameter]):
162
+ parameters = list(parameters)
163
+
164
+ self.optimization_step += 1
165
+
166
+ # Compute the decay factor for the exponential moving average.
167
+ decay = self.get_decay(self.optimization_step)
168
+ self.cur_decay_value = decay
169
+ one_minus_decay = 1 - decay
170
+
171
+ for s_param, param in zip(self.shadow_params, parameters):
172
+ if param.requires_grad:
173
+ s_param.sub_(one_minus_decay * (s_param - param))
174
+ else:
175
+ s_param.copy_(param)
176
+
177
+ torch.cuda.empty_cache()
178
+
179
+ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
180
+ """
181
+ Copy current averaged parameters into given collection of parameters.
182
+
183
+ Args:
184
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
185
+ updated with the stored moving averages. If `None`, the parameters with which this
186
+ `ExponentialMovingAverage` was initialized will be used.
187
+ """
188
+ parameters = list(parameters)
189
+ for s_param, param in zip(self.shadow_params, parameters):
190
+ param.data.copy_(s_param.to(param.device).data)
191
+
192
+ def to(self, device=None, dtype=None) -> None:
193
+ r"""Move internal buffers of the ExponentialMovingAverage to `device`.
194
+
195
+ Args:
196
+ device: like `device` argument to `torch.Tensor.to`
197
+ """
198
+ # .to() on the tensors handles None correctly
199
+ self.shadow_params = [
200
+ p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
201
+ for p in self.shadow_params
202
+ ]
203
+
204
+ def state_dict(self) -> dict:
205
+ r"""
206
+ Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
207
+ checkpointing to save the ema state dict.
208
+ """
209
+ # Following PyTorch conventions, references to tensors are returned:
210
+ # "returns a reference to the state and not its copy!" -
211
+ # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
212
+ return {
213
+ "decay": self.decay,
214
+ "min_decay": self.min_decay,
215
+ "optimization_step": self.optimization_step,
216
+ "update_after_step": self.update_after_step,
217
+ "use_ema_warmup": self.use_ema_warmup,
218
+ "inv_gamma": self.inv_gamma,
219
+ "power": self.power,
220
+ "shadow_params": self.shadow_params,
221
+ }
222
+
223
+ def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
224
+ r"""
225
+ Save the current parameters for restoring later.
226
+ Args:
227
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
228
+ temporarily stored.
229
+ """
230
+ self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
231
+
232
+ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
233
+ r"""
234
+ Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
235
+ without affecting the original optimization process. Store the parameters before the `copy_to()` method.
236
+ After validation (or model saving), use this to restore the former parameters.
237
+ Args:
238
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
239
+ updated with the stored parameters. If `None`, the parameters with which this
240
+ `ExponentialMovingAverage` was initialized will be used.
241
+ """
242
+ if self.temp_stored_params is None:
243
+ raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
244
+ for c_param, param in zip(self.temp_stored_params, parameters):
245
+ param.data.copy_(c_param.data)
246
+
247
+ # Better memory-wise.
248
+ self.temp_stored_params = None
249
+
250
+ def load_state_dict(self, state_dict: dict) -> None:
251
+ r"""
252
+ Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to
253
+ restore the ema state dict.
254
+ Args:
255
+ state_dict (dict): EMA state. Should be an object returned
256
+ from a call to :meth:`state_dict`.
257
+ """
258
+ # deepcopy, to be consistent with module API
259
+ state_dict = copy.deepcopy(state_dict)
260
+
261
+ self.decay = state_dict.get("decay", self.decay)
262
+ if self.decay < 0.0 or self.decay > 1.0:
263
+ raise ValueError("Decay must be between 0 and 1")
264
+
265
+ self.min_decay = state_dict.get("min_decay", self.min_decay)
266
+ if not isinstance(self.min_decay, float):
267
+ raise ValueError("Invalid min_decay")
268
+
269
+ self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
270
+ if not isinstance(self.optimization_step, int):
271
+ raise ValueError("Invalid optimization_step")
272
+
273
+ self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
274
+ if not isinstance(self.update_after_step, int):
275
+ raise ValueError("Invalid update_after_step")
276
+
277
+ self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
278
+ if not isinstance(self.use_ema_warmup, bool):
279
+ raise ValueError("Invalid use_ema_warmup")
280
+
281
+ self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
282
+ if not isinstance(self.inv_gamma, (float, int)):
283
+ raise ValueError("Invalid inv_gamma")
284
+
285
+ self.power = state_dict.get("power", self.power)
286
+ if not isinstance(self.power, (float, int)):
287
+ raise ValueError("Invalid power")
288
+
289
+ shadow_params = state_dict.get("shadow_params", None)
290
+ if shadow_params is not None:
291
+ self.shadow_params = shadow_params
292
+ if not isinstance(self.shadow_params, list):
293
+ raise ValueError("shadow_params must be a list")
294
+ if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
295
+ raise ValueError("shadow_params must all be Tensors")
296
+
297
+
298
+ # calculates entropy over each pixel distribution
299
+ def pixel_entropy_per_percent_masked_bucket(logits, input_ids, mask_id):
300
+ # only calculated entropy over image tokens that were masked in the original image
301
+ masked_tokens = input_ids == mask_id
302
+ num_masked_pixels = masked_tokens.sum(-1)
303
+
304
+ probs = F.softmax(logits, dim=-1)
305
+ log_probs = F.log_softmax(logits, dim=-1)
306
+
307
+ entropy_per_pixel = -((probs * log_probs).sum(-1))
308
+
309
+ # the predictions for non-masked aren't used, so set their entropies to zero
310
+ entropy_per_pixel[~masked_tokens] = 0
311
+
312
+ entropy_per_image_numerator = entropy_per_pixel.sum(-1)
313
+ entropy_per_image = entropy_per_image_numerator / num_masked_pixels
314
+
315
+ total_buckets = 10
316
+ masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)
317
+
318
+ entropy_by_masked_bucket = average_by_buckets(entropy_per_image, masked_buckets, total_buckets)
319
+
320
+ return entropy_by_masked_bucket
321
+
322
+
323
+ # calculates entropy over the averaged distribution of pixels for the whole image
324
+ def image_entropy_per_percent_masked_bucket(logits, input_ids, mask_id):
325
+ # only calculated entropy over image tokens that were masked in the original image
326
+ masked_tokens = input_ids == mask_id
327
+ num_masked_pixels = masked_tokens.sum(-1, keepdim=True)
328
+
329
+ pixel_probs = F.softmax(logits, dim=-1)
330
+ pixel_probs[~masked_tokens] = 0
331
+ image_probs_numerator = pixel_probs.sum(-2)
332
+ image_probs = image_probs_numerator / num_masked_pixels
333
+
334
+ image_log_probs = image_probs.log()
335
+
336
+ entropy_per_image = -((image_probs * image_log_probs).sum(-1))
337
+
338
+ total_buckets = 10
339
+ masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)
340
+
341
+ entropy_by_masked_bucket = average_by_buckets(entropy_per_image, masked_buckets, total_buckets)
342
+
343
+ return entropy_by_masked_bucket
344
+
345
+
346
+ def cross_entropy_per_percent_masked_bucket(logits, labels, input_ids, mask_id, output_size, label_smoothing):
347
+ cross_entropy_per_image = F.cross_entropy(
348
+ logits.view(-1, output_size),
349
+ labels.view(-1),
350
+ ignore_index=-100,
351
+ label_smoothing=label_smoothing,
352
+ reduction="none",
353
+ )
354
+
355
+ total_buckets = 10
356
+ masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)
357
+
358
+ cross_entropy_by_percent_masked_bucket = average_by_buckets(cross_entropy_per_image, masked_buckets, total_buckets)
359
+
360
+ return cross_entropy_by_percent_masked_bucket
361
+
362
+
363
+ def token_probability_distributions_per_percent_masked_bucket(logits, input_ids, mask_id):
364
+ probs = F.softmax(logits, dim=-1)
365
+
366
+ total_buckets = 10
367
+ masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)
368
+
369
+ data = []
370
+
371
+ for bucket_idx in range(total_buckets):
372
+ indices_for_bucket = masked_buckets[masked_buckets == bucket_idx]
373
+
374
+ # It's ok if none were noised in the range of this bucket. This
375
+ # function will be called for a later training step where it's likely
376
+ # there will be an element noised in the range.
377
+ if indices_for_bucket.shape[0] == 0:
378
+ continue
379
+
380
+ index_for_bucket = indices_for_bucket[0]
381
+
382
+ image_probs = probs[index_for_bucket]
383
+
384
+ # find the index of a masked pixel for the image
385
+ input_ids_for_image = input_ids[index_for_bucket]
386
+ masked_pixels_probs = image_probs[input_ids_for_image == mask_id]
387
+
388
+ masked_pixel_probs = masked_pixels_probs[0]
389
+
390
+ masked_pixel_probs = masked_pixel_probs.cpu().numpy()
391
+
392
+ for masked_pixel_prob in masked_pixel_probs:
393
+ data.append({"bucket": bucket_idx, "masked_pixel_prob": masked_pixel_prob})
394
+
395
+ df = pd.DataFrame(data)
396
+
397
+ return df
398
+
399
+
400
+ def average_by_buckets(values, masked_buckets, total_buckets):
401
+ unique_buckets, bucket_counts = masked_buckets.unique(dim=0, return_counts=True)
402
+
403
+ numerator = torch.zeros(total_buckets, device=values.device)
404
+
405
+ numerator.scatter_add_(0, masked_buckets, values)
406
+
407
+ # default value is one because the buckets for which there aren't
408
+ # any values will have a numerator of zero. So we just need to not divide
409
+ # by zero.
410
+ denominator = torch.ones(total_buckets, device=values.device, dtype=torch.long)
411
+ denominator[unique_buckets] = bucket_counts
412
+
413
+ averaged_by_buckets = numerator / denominator
414
+
415
+ return averaged_by_buckets
416
+
417
+
418
+ def input_ids_to_masked_buckets(input_ids, mask_id, total_buckets=10):
419
+ assert total_buckets == 10
420
+
421
+ masked_percent = (input_ids == mask_id).sum(-1) / input_ids.shape[-1]
422
+
423
+ # we do not formally use timesteps to noise images. Instead, we mask a percent
424
+ # of the pixels. We don't want to log entropy for every mask percent between 0 and 1,
425
+ # and we also want to track how the entropy evolves over time w/in a range of mask
426
+ # percents that should have similar entropy. So we bucket the masked percents into a
427
+ # fixed number of buckets
428
+
429
+ # we could generalize this later if needed but for now, let's just assume a fixed
430
+ # number of 10 buckets.
431
+
432
+ # How this maps to a bucket index:
433
+ # (mask) * bucket_index +
434
+ # (mask_1) * bucket_index_1
435
+ #
436
+ # -> Where the mask is true will be set to the expected bucket index,
437
+ # where the mask is false will be set to 0.
438
+ #
439
+ # Given the probabilities are between 0 and 1, each masked_percent will get mapped
440
+ # to a timestep by one and only one of the masks.
441
+
442
+ masked_buckets = (
443
+ ((0 < masked_percent) & (masked_percent <= 0.1)) * 0
444
+ + ((0.1 < masked_percent) & (masked_percent <= 0.2)) * 1
445
+ + ((0.2 < masked_percent) & (masked_percent <= 0.3)) * 2
446
+ + ((0.3 < masked_percent) & (masked_percent <= 0.4)) * 3
447
+ + ((0.4 < masked_percent) & (masked_percent <= 0.5)) * 4
448
+ + ((0.5 < masked_percent) & (masked_percent <= 0.6)) * 5
449
+ + ((0.6 < masked_percent) & (masked_percent <= 0.7)) * 6
450
+ + ((0.7 < masked_percent) & (masked_percent <= 0.8)) * 7
451
+ + ((0.8 < masked_percent) & (masked_percent <= 0.9)) * 8
452
+ + ((0.9 < masked_percent) & (masked_percent <= 1.0)) * 9
453
+ )
454
+
455
+ return masked_buckets
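A small sketch of how the bucket utilities above fit together (the mask id and tensor sizes are hypothetical): each sequence is assigned to a bucket by its masked fraction, and a per-image metric is then averaged within each bucket.

import torch

mask_id = 126336                              # hypothetical mask token id
input_ids = torch.randint(0, 1000, (4, 256))  # (batch, seq) of token ids
input_ids[:, :100] = mask_id                  # ~39% masked -> bucket index 3

per_image_loss = torch.rand(4)                # e.g. per-image cross entropy
buckets = input_ids_to_masked_buckets(input_ids, mask_id)      # shape (4,), values in 0..9
print(average_by_buckets(per_image_loss, buckets, total_buckets=10))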
training/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # from .mmada_grpo_trainer import DiffusionGRPOTrainer
training/prompting_utils.py ADDED
@@ -0,0 +1,475 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 MMaDA team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ reserved_token_mapping = {
18
+ '<|soi|>': 126084,
19
+ '<|eoi|>': 126085,
20
+ '<|sov|>': 126086,
21
+ '<|eov|>': 126087,
22
+ '<|t2i|>': 126088,
23
+ '<|mmu|>': 126089,
24
+ '<|t2v|>': 126090,
25
+ '<|v2v|>': 126091,
26
+ '<|lvg|>': 126092,
27
+ '[iPAD]': 126093,
28
+ '<|r2i|>': 126094,
29
+ }
30
+
31
+
32
+ import torch
33
+ class UniversalPrompting():
34
+ def __init__(self, text_tokenizer,
35
+ special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
36
+ max_text_len=8000, max_seq_len=377, ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=False):
37
+ """
38
+ :param text_tokenizer: original text tokenizer
39
+ """
40
+ if not use_reserved_token:
41
+ self.text_tokenizer = text_tokenizer
42
+ self.text_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
43
+ self.text_tokenizer.add_tokens(list(special_tokens))
44
+ self.sptids_dict = {token: torch.tensor(self.text_tokenizer.convert_tokens_to_ids([token])) for token in
45
+ special_tokens}
46
+ self.sptids_dict['<|sot|>'] = torch.tensor([self.text_tokenizer.bos_token_id])
47
+ self.sptids_dict['<|eot|>'] = torch.tensor([self.text_tokenizer.eos_token_id])
48
+ self.sptids_dict['<|pad|>'] = torch.tensor([self.text_tokenizer.pad_token_id])
49
+ else:
50
+ self.text_tokenizer = text_tokenizer
51
+ self.sptids_dict = {}
52
+ for token, token_id in reserved_token_mapping.items():
53
+ self.sptids_dict[token] = torch.tensor([token_id])
54
+ self.sptids_dict['<|sot|>'] = torch.tensor([self.text_tokenizer.bos_token_id])
55
+ self.sptids_dict['<|eot|>'] = torch.tensor([self.text_tokenizer.eos_token_id])
56
+ end_header_tokens = self.text_tokenizer.convert_tokens_to_ids(['<|end_header_id|>'])
57
+ if end_header_tokens and len(end_header_tokens) > 0 and end_header_tokens[0]:
58
+ self.sptids_dict['<|end_header_id|>'] = torch.tensor(end_header_tokens)
59
+ self.sptids_dict['<|eot_id|>'] = torch.tensor(self.text_tokenizer.convert_tokens_to_ids(['<|eot_id|>']))
60
+ self.sptids_dict['<|start_header_id|>'] = torch.tensor(self.text_tokenizer.convert_tokens_to_ids(['<|start_header_id|>']))
61
+ else:
62
+ special_tokens_dict = {
63
+ 'additional_special_tokens': [
64
+ '<|start_header_id|>',
65
+ '<|end_header_id|>',
66
+ '<|eot_id|>'
67
+ ]
68
+ }
69
+ num_added = self.text_tokenizer.add_special_tokens(special_tokens_dict)
70
+ new_token_id = self.text_tokenizer.convert_tokens_to_ids(['<|end_header_id|>'])
71
+ self.sptids_dict['<|end_header_id|>'] = torch.tensor(new_token_id)
72
+ self.sptids_dict['<|eot_id|>'] = torch.tensor(self.text_tokenizer.convert_tokens_to_ids(['<|eot_id|>']))
73
+ self.sptids_dict['<|start_header_id|>'] = torch.tensor(self.text_tokenizer.convert_tokens_to_ids(['<|start_header_id|>']))
74
+ # plus 1 because at this time we add a task token before
75
+ print(f"self.sptids_dict: {self.sptids_dict}")
76
+ self.max_text_len = max_text_len + 1
77
+ self.pad_id = reserved_token_mapping['[iPAD]']
78
+ self.ignore_id = ignore_id
79
+ self.cond_dropout_prob = cond_dropout_prob
80
+
81
+ def t2i_prompt(self, text_ids, image_ids, labels):
82
+
83
+ device = image_ids.device
84
+ sequence_ids = []
85
+ attention_masks = []
86
+ label_ids = []
87
+ probs = torch.rand(len(text_ids))
88
+ for i in range(len(text_ids)):
89
+
90
+ if len(text_ids[i]) == 0:
91
+ text_ids[i] = [self.text_tokenizer.bos_token_id]
92
+ elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
93
+ text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
94
+
95
+ temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
96
+
97
+ # randomly dropout text condition
98
+ if probs[i] < self.cond_dropout_prob:
99
+ temp_ids = [int(self.sptids_dict['<|t2i|>']), self.text_tokenizer.bos_token_id, self.text_tokenizer.eos_token_id]
100
+
101
+ if self.max_text_len >= len(temp_ids):
102
+ old_len = len(temp_ids)
103
+ temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
104
+ temp_masks = [0] * (self.max_text_len - old_len) + [1] * (old_len + image_ids.shape[-1] + 2)
105
+ else:
106
+ # should add the eos token
107
+ temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
108
+ temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 2) # +2 for two special tokens
109
+ # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
110
+ temp_label_ids = torch.cat([
111
+ # should we predict text tokens when doing image reconstruction?
112
+ torch.tensor(temp_ids).to(device),
113
+ self.sptids_dict['<|soi|>'].to(device),
114
+ labels[i],
115
+ self.sptids_dict['<|eoi|>'].to(device)
116
+ ], dim=0)
117
+
118
+ temp_label_ids = torch.where(temp_label_ids == self.pad_id, self.ignore_id, temp_label_ids)
119
+
120
+ temp_ids = torch.cat([
121
+ torch.tensor(temp_ids).to(device),
122
+ self.sptids_dict['<|soi|>'].to(device),
123
+ image_ids[i],
124
+ self.sptids_dict['<|eoi|>'].to(device)
125
+ ], dim=0)
126
+
127
+ # sequence_ids: [pad]...[pad] <|t2i|> <bos> text_1 ... text_n <eos> <|soi|> image_1 ... image_m <|eoi|>
128
+ temp_masks = torch.tensor(temp_masks).to(device)
129
+ sequence_ids.append(temp_ids.unsqueeze(0))
130
+ attention_masks.append(temp_masks.unsqueeze(0))
131
+ label_ids.append(temp_label_ids.unsqueeze(0))
132
+
133
+ return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)
134
+
135
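+ # Illustrative sketch (not part of the original commit): with toy ids, assuming
+ # [iPAD]=0, <|t2i|>=5, bos=1, eos=2, <|soi|>=3, <|eoi|>=4, max_text_len=6, a
+ # caption tokenized as [1, 7, 8, 9] plus a 4-token image [11, 12, 13, 14] yields
+ #
+ # sequence_ids : [5, 1, 7, 8, 9, 2, 3, 11, 12, 13, 14, 4]
+ # attention : [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+ #
+ # a shorter caption is left-padded with [iPAD] and those positions get mask 0.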
+ def t2i_gen_prompt(self, text_ids, image_ids):
136
+
137
+ device = image_ids.device
138
+ sequence_ids = []
139
+ attention_masks = []
140
+ for i in range(len(text_ids)):
141
+ if len(text_ids[i]) == 0:
142
+ text_ids[i] = [self.text_tokenizer.bos_token_id]
143
+ elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
144
+ text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
145
+ # note that the llama3 tokenizer automatically adds the begin-of-text (bos) token but not the end-of-text (eos) token
146
+ temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
147
+ if self.max_text_len >= len(temp_ids):
148
+ old_len = len(temp_ids)
149
+ temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
150
+ temp_masks = [0] * (self.max_text_len - old_len) + [1] * (old_len + image_ids.shape[-1] + 2)
151
+ else:
152
+ # should add the eos token
153
+ temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
154
+ temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 2) # +2 for two special tokens
155
+
156
+ # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
157
+ temp_ids = torch.cat([
158
+ torch.tensor(temp_ids).to(device),
159
+ self.sptids_dict['<|soi|>'].to(device),
160
+ image_ids[i],
161
+ self.sptids_dict['<|eoi|>'].to(device)
162
+ ], dim=0)
163
+
164
+ temp_masks = torch.tensor(temp_masks).to(device)
165
+ sequence_ids.append(temp_ids.unsqueeze(0))
166
+ attention_masks.append(temp_masks.unsqueeze(0))
167
+
168
+ return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0)
169
+
170
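+ # Illustrative note (not part of the original commit): at generation time the
+ # image_ids handed to t2i_gen_prompt are typically placeholders (e.g. a grid of
+ # mask-token ids) that the sampler fills in. A hedged usage sketch, assuming a
+ # hypothetical mask_token_id, a 1024-token image grid, and `uni_prompting` as an
+ # instance of this class:
+ #
+ # blank_image = torch.full((len(prompts), 1024), mask_token_id, dtype=torch.long)
+ # input_ids, attention_mask = uni_prompting((prompts, blank_image), 't2i_gen')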
+ # language modeling
171
+ def lm_prompt(self, text_ids, max_seq_len):
172
+ sequence_ids = []
173
+ attention_masks = []
174
+ label_ids = []
175
+ for i in range(len(text_ids)):
176
+ if len(text_ids[i]) == 0:
177
+ text_ids[i] = [self.text_tokenizer.bos_token_id]
178
+ elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
179
+ text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
180
+
181
+ temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id]
182
+
183
+ if max_seq_len >= len(temp_ids):
184
+ temp_labels_ids = temp_ids + [self.text_tokenizer.eos_token_id] * (max_seq_len - len(temp_ids))
185
+ temp_ids = temp_ids + [self.text_tokenizer.eos_token_id] * (max_seq_len - len(temp_ids))
186
+ temp_masks = [1] * len(temp_ids) + [0] * (max_seq_len - len(temp_ids))
187
+ else:
188
+ # In language modeling, we only process text tokens. We do not add the eos token if the text length
189
+ # exceeds the max sequence length
190
+ temp_labels_ids = temp_ids[:max_seq_len]
191
+ temp_ids = temp_ids[:max_seq_len]
192
+ temp_masks = [1] * len(temp_ids) # already at max_seq_len, no padding to mask
193
+
194
+ # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
195
+ temp_ids = torch.tensor(temp_ids)
196
+ temp_masks = torch.tensor(temp_masks)
197
+ temp_labels_ids = torch.tensor(temp_labels_ids)
198
+ sequence_ids.append(temp_ids.unsqueeze(0))
199
+ attention_masks.append(temp_masks.unsqueeze(0))
200
+ label_ids.append(temp_labels_ids.unsqueeze(0))
201
+
202
+ # input_ids, masks, labels
203
+ return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)
204
+
205
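+ # Illustrative note (not part of the original commit): lm_prompt right-pads both
+ # the input ids and the labels with eos rather than [iPAD], so the model is also
+ # trained to emit eos in the padded region. Toy example with bos=1, eos=2 and
+ # max_seq_len=6:
+ #
+ # text [1, 7, 8] -> ids [1, 7, 8, 2, 2, 2]
+ # labels [1, 7, 8, 2, 2, 2]
+ # mask [1, 1, 1, 1, 1, 1] (computed after padding, hence all ones)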
+ # language modeling
206
+ def lm_chat_prompt(self, text_ids, max_seq_len):
207
+ sequence_ids = []
208
+ prompt_masks = []
209
+ label_ids = []
210
+
211
+ for i in range(len(text_ids)):
212
+ if len(text_ids[i]) == 0:
213
+ text_ids[i] = [self.text_tokenizer.bos_token_id]
214
+ elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
215
+ text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
216
+
217
+ temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id]
218
+
219
+ if max_seq_len >= len(temp_ids):
220
+ temp_labels_ids = temp_ids + [self.text_tokenizer.eos_token_id] * (max_seq_len - len(temp_ids))
221
+ temp_ids = temp_ids + [self.text_tokenizer.eos_token_id] * (max_seq_len - len(temp_ids))
222
+ else:
223
+ # In language modeling, we only process text tokens. We do not add the eos token if the text length
224
+ # exceeds the max sequence length
225
+ temp_labels_ids = temp_ids[:max_seq_len]
226
+ temp_ids = temp_ids[:max_seq_len]
227
+
228
+ end_header_id = int(self.sptids_dict['<|end_header_id|>'])
229
+ end_header_pos = -1
230
+ for pos in range(len(temp_ids) - 1, -1, -1): # search the text sequence backwards for <|end_header_id|>
231
+ if temp_ids[pos] == end_header_id:
232
+ end_header_pos = pos
233
+ break
234
+ if end_header_pos != -1:
235
+ prompt_length = end_header_pos + 1
236
+ else:
237
+ prompt_length = 0
238
+ temp_masks = [1] * prompt_length + [0] * (len(temp_ids) - prompt_length)
239
+
240
+ # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
241
+ temp_ids = torch.tensor(temp_ids)
242
+ temp_masks = torch.tensor(temp_masks)
243
+ temp_labels_ids = torch.tensor(temp_labels_ids)
244
+ sequence_ids.append(temp_ids.unsqueeze(0))
245
+ prompt_masks.append(temp_masks.unsqueeze(0))
246
+ label_ids.append(temp_labels_ids.unsqueeze(0))
247
+
248
+ # input_ids, masks, labels
249
+ return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0), torch.cat(label_ids, dim=0)
250
+
251
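+ # Illustrative note (not part of the original commit): unlike lm_prompt,
+ # lm_chat_prompt returns a *prompt* mask rather than an attention mask: positions
+ # up to and including the last <|end_header_id|> (the chat prompt plus the
+ # assistant header) are 1 and the assistant reply positions are 0, so the caller
+ # can restrict the loss or the masking schedule to the response. A toy
+ # reproduction of this logic is sketched under the __main__ guard at the end of
+ # this file.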
+ def mmu_prompt(self, image_ids, text_ids):
252
+ device = image_ids.device
253
+ sequence_ids = []
254
+ prompt_masks = []
255
+ label_ids = []
256
+ max_text_len = self.max_text_len - 1
257
+ for i in range(len(text_ids)):
258
+ # note that the llama3 tokenizer automatically adds the begin-of-text (bos) token but not the end-of-text (eos) token
259
+ # for empty list []
260
+
261
+ if len(text_ids[i]) == 0:
262
+ text_ids[i] = [self.text_tokenizer.bos_token_id]
263
+ elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
264
+ text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
265
+
266
+ temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id]
267
+
268
+ if max_text_len >= len(temp_ids):
269
+ # minus 1 because the task token is prepended before the image tokens
270
+ temp_ids = temp_ids + [self.text_tokenizer.eos_token_id] * (max_text_len - len(temp_ids))
271
+ temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) + [0] * (max_text_len - len(temp_ids))
272
+ else:
273
+ # should add the eos token
274
+ temp_ids = temp_ids[:max_text_len - 1] + [self.text_tokenizer.eos_token_id]
275
+ temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) # +3 for the task token, <|soi|> and <|eoi|>
276
+
277
+ # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
278
+ temp_label_ids = torch.cat([
279
+ torch.tensor([self.ignore_id]).to(device),
280
+ torch.tensor([self.ignore_id]).to(device),
281
+ torch.ones_like(image_ids[i]) * self.ignore_id,
282
+ torch.tensor([self.ignore_id]).to(device),
283
+ torch.tensor(temp_ids).to(device),
284
+ ], dim=0)
285
+
286
+ temp_label_ids = torch.where(temp_label_ids == self.pad_id, self.ignore_id, temp_label_ids)
287
+
288
+ return_temp_ids = torch.cat([
289
+ self.sptids_dict['<|mmu|>'].to(device), # task token
290
+ self.sptids_dict['<|soi|>'].to(device),
291
+ image_ids[i],
292
+ self.sptids_dict['<|eoi|>'].to(device),
293
+ torch.tensor(temp_ids).to(device),
294
+ ], dim=0)
295
+ end_header_id = int(self.sptids_dict['<|end_header_id|>'])
296
+ end_header_pos = -1
297
+ for pos in range(len(temp_ids) - 1, -1, -1):
298
+ if temp_ids[pos] == end_header_id:
299
+ end_header_pos = pos
300
+ break
301
+ if end_header_pos != -1:
302
+ prompt_length = len(return_temp_ids) - len(temp_ids) + end_header_pos + 1
303
+ else:
304
+ prompt_length = len(return_temp_ids) - len(temp_ids)
305
+ predict_length = len(return_temp_ids) - prompt_length
306
+ prompt_mask = [1] * prompt_length + [0] * predict_length
307
+ prompt_mask = torch.tensor(prompt_mask).to(device)
308
+ sequence_ids.append(return_temp_ids.unsqueeze(0))
309
+ prompt_masks.append(prompt_mask.unsqueeze(0))
310
+ label_ids.append(temp_label_ids.unsqueeze(0))
311
+
312
+ return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0), torch.cat(label_ids, dim=0)
313
+
314
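+ # Illustrative sketch (not part of the original commit): mmu_prompt produces
+ #
+ # sequence : <|mmu|> <|soi|> image_1 ... image_m <|eoi|> text_1 ... text_n
+ # labels : ignore ignore ignore ... ignore ignore text_1 ... text_n
+ #
+ # so only the text answer contributes to the loss, while the prompt mask keeps
+ # the task token, the image block and the chat header fixed (1) and leaves the
+ # reply positions free (0).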
+ def mmu_gen_prompt(self, image_ids, text_ids):
315
+ device = image_ids.device
316
+ sequence_ids = []
317
+ prompt_masks = []
318
+ max_text_len = self.max_text_len - 1
319
+ for i in range(len(text_ids)):
320
+
321
+ if len(text_ids[i]) == 0:
322
+ text_ids[i] = [self.text_tokenizer.bos_token_id]
323
+ elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
324
+ text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
325
+
326
+ temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id]
327
+
328
+ if max_text_len >= len(temp_ids):
329
+ # minus 1 because the task token is prepended before the image tokens
330
+ temp_ids = temp_ids + [self.text_tokenizer.eos_token_id] * (max_text_len - len(temp_ids))
331
+ else:
332
+ # should add the eos token
333
+ temp_ids = temp_ids[:max_text_len - 1] + [self.text_tokenizer.eos_token_id]
334
+
335
+ # print(f"mmu temp_ids: {temp_ids}")
336
+ return_temp_ids = torch.cat([
337
+ self.sptids_dict['<|mmu|>'].to(device), # task token
338
+ self.sptids_dict['<|soi|>'].to(device),
339
+ image_ids[i],
340
+ self.sptids_dict['<|eoi|>'].to(device),
341
+ torch.tensor(temp_ids).to(device),
342
+ ], dim=0)
343
+
344
+ end_header_id = int(self.sptids_dict['<|end_header_id|>'])
345
+ end_header_pos = -1
346
+ for pos in range(len(temp_ids) - 1, -1, -1):
347
+ if temp_ids[pos] == end_header_id:
348
+ end_header_pos = pos
349
+ break
350
+ if end_header_pos != -1:
351
+ prompt_length = len(return_temp_ids) - len(temp_ids) + end_header_pos + 1
352
+ else:
353
+ prompt_length = len(return_temp_ids) - len(temp_ids)
354
+ predict_length = len(return_temp_ids) - prompt_length # mask must span the full sequence, as in mmu_prompt
355
+ print(f"prompt_length: {prompt_length}, predict_length: {predict_length}, all length: {len(return_temp_ids)}, {return_temp_ids[-predict_length:]}")
356
+ prompt_mask = [1] * prompt_length + [0] * predict_length
357
+ prompt_mask = torch.tensor(prompt_mask).to(device)
358
+ sequence_ids.append(return_temp_ids.unsqueeze(0))
359
+ prompt_masks.append(prompt_mask.unsqueeze(0))
360
+ return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0)
361
+
362
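+ # Illustrative note (not part of the original commit): mmu_gen_prompt is the
+ # inference-time counterpart of mmu_prompt (same layout, no labels). It is not
+ # wired into __call__ below, so it would be invoked directly on tokenized
+ # questions. A hedged sketch (`uni_prompting` and `image_tokens` are assumed names):
+ #
+ # q_ids = uni_prompting.text_tokenizer(questions)['input_ids']
+ # input_ids, prompt_masks = uni_prompting.mmu_gen_prompt(image_tokens, q_ids)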
+ def r2i_prompt(self, image_ids, text_ids):
363
+ device = image_ids.device
364
+ sequence_ids = []
365
+ prompt_masks = []
366
+ label_ids = []
367
+ r2i_id = int(self.sptids_dict['<|r2i|>'])
368
+ soi_id = int(self.sptids_dict['<|soi|>'])
369
+ eoi_id = int(self.sptids_dict['<|eoi|>'])
370
+ max_text_len = self.max_text_len - 1 # e.g. 512; covers bos, the text tokens and eos
371
+ for i in range(len(text_ids)):
372
+ # note that the llama3 tokenizer automatically adds the begin-of-text (bos) token but not the end-of-text (eos) token
373
+ # for empty list []
374
+ if len(text_ids[i]) == 0:
375
+ text_ids[i] = [self.text_tokenizer.bos_token_id]
376
+ elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
377
+ text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
378
+ text_ids_with_bos_eos = text_ids[i] + [self.text_tokenizer.eos_token_id]
379
+ if max_text_len >= len(text_ids_with_bos_eos):
380
+ # minus 1 because the task token is prepended before the image tokens
381
+ text_ids_full_len = text_ids_with_bos_eos + [self.text_tokenizer.eos_token_id] * (max_text_len - len(text_ids_with_bos_eos))
382
+ else:
383
+ # should add the eos token
384
+ text_ids_full_len = text_ids_with_bos_eos[:max_text_len - 1] + [self.text_tokenizer.eos_token_id]
385
+
386
+ sequence_ids.append(torch.cat([
387
+ torch.tensor([r2i_id]).to(device), # task token
388
+ torch.tensor(text_ids_full_len).to(device),
389
+ torch.tensor([soi_id]).to(device),
390
+ image_ids[i],
391
+ torch.tensor([eoi_id]).to(device),
392
+ ], dim=0).unsqueeze(0))
393
+
394
+ end_header_id = int(self.sptids_dict['<|end_header_id|>'])
395
+ end_header_pos = -1
396
+ for pos in range(len(text_ids_full_len) - 1, -1, -1):
397
+ if text_ids_full_len[pos] == end_header_id:
398
+ end_header_pos = pos
399
+ break
400
+ prompt_mask = torch.zeros(sequence_ids[i].size(1)).to(device)
401
+ prompt_mask[0] = 1 # task_id
402
+ if end_header_pos != -1:
403
+ prompt_mask[1:end_header_pos+2] = 1
404
+ else:
405
+ prompt_mask[1:len(text_ids_full_len)+1] = 1
406
+ prompt_mask[len(text_ids_full_len)+1] = 1
407
+ prompt_mask[len(text_ids_full_len)+2+len(image_ids[i])] = 1
408
+ prompt_masks.append(prompt_mask.unsqueeze(0))
409
+
410
+ return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0), torch.cat(sequence_ids, dim=0)
411
+
412
+
413
+
414
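+ # Illustrative sketch (not part of the original commit): r2i_prompt lays out
+ #
+ # <|r2i|> bos reasoning_text ... eos (eos-padded) <|soi|> image_1 ... image_m <|eoi|>
+ #
+ # and its prompt mask marks the task token, the text up to the last
+ # <|end_header_id|> (or the whole text when none is found), <|soi|> and <|eoi|>
+ # as fixed context; the labels returned are the sequence ids themselves.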
+ def mask_prompt(self):
415
+ pass
416
+
417
+ def __call__(self, input, task, padding=True, config=None):
418
+ """
419
+ input (tuple) : data pairs containing text (str), image tokens (tensor), or video tokens (tensor).
420
+ task (str) : a flag indicating the current task.
421
+ """
422
+ if task == "t2i":
423
+ text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
424
+ image_ids = input[1] # (B, #tokens)
425
+ sequence_ids_with_masks = self.t2i_prompt(text_ids, image_ids, input[2])
426
+
427
+ elif task == "t2v":
428
+ text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
429
+ image_ids = input[1] # (B, #tokens)
430
+ sequence_ids_with_masks = self.t2v_prompt(text_ids, image_ids, input[2])
431
+
432
+ elif task == "t2i_plus_lm":
433
+ text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
434
+ image_ids = input[1] # (B, #tokens)
435
+ sequence_ids_with_masks = self.t2i_prompt(text_ids[:config.training.batch_size], image_ids,
436
+ input[2])
437
+ sequence_ids_with_masks_lm = self.lm_prompt(text_ids[config.training.batch_size:], input[3])
438
+ return sequence_ids_with_masks, sequence_ids_with_masks_lm
439
+
440
+ elif task == "t2i_gen":
441
+ text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
442
+ image_ids = input[1] # (B, #tokens)
443
+ sequence_ids_with_masks = self.t2i_gen_prompt(text_ids, image_ids)
444
+
445
+ elif task == "t2v_gen":
446
+ text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
447
+ image_ids = input[1] # (B, #tokens)
448
+ sequence_ids_with_masks = self.t2v_gen_prompt(text_ids, image_ids)
449
+
450
+ elif task == "lm":
451
+ text_ids = self.text_tokenizer(input[0], truncation=True)['input_ids'] # (B, max_len)
452
+ sequence_ids_with_masks = self.lm_prompt(text_ids, input[1])
453
+
454
+ elif task == "lm_chat":
455
+ text_ids = self.text_tokenizer(input[0], truncation=True)['input_ids'] # (B, max_len)
456
+ sequence_ids_with_masks = self.lm_chat_prompt(text_ids, input[1])
457
+
458
+ elif task == "mmu":
459
+ image_ids = input[0]
460
+ text_ids = self.text_tokenizer(input[1])['input_ids']
461
+ sequence_ids_with_masks = self.mmu_prompt(image_ids, text_ids)
462
+
463
+ elif task == "r2i":
464
+ image_ids = input[0]
465
+ text_ids = self.text_tokenizer(input[1])['input_ids']
466
+ sequence_ids_with_masks = self.r2i_prompt(image_ids, text_ids)
467
+
468
+ else:
469
+ raise NotImplementedError
470
+
471
+ return sequence_ids_with_masks
472
+
473
+
474
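+ # Illustrative usage sketch (not part of the original commit). The constructor
+ # arguments are assumptions inferred from the attributes used above, not a
+ # definitive signature:
+ #
+ # tokenizer = AutoTokenizer.from_pretrained(...) # model path elided
+ # uni_prompting = UniversalPrompting(tokenizer, max_text_len=512, ignore_id=-100,
+ # cond_dropout_prob=0.1, reserved_token_mapping={...})
+ # ids, attn_mask, labels = uni_prompting((captions, image_tokens, image_tokens), 't2i')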
+ if __name__ == '__main__':
475
+ pass
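+ # Illustrative, non-authoritative smoke test (not part of the original commit):
+ # it reproduces the prompt-mask logic of lm_chat_prompt with toy ids only, so no
+ # tokenizer is needed; 999 stands in for the <|end_header_id|> id.
+ toy_ids = [1, 10, 11, 999, 20, 21, 22, 2]
+ end_header_pos = -1
+ for pos in range(len(toy_ids) - 1, -1, -1):
+ if toy_ids[pos] == 999:
+ end_header_pos = pos
+ break
+ prompt_length = end_header_pos + 1 if end_header_pos != -1 else 0
+ toy_mask = torch.tensor([1] * prompt_length + [0] * (len(toy_ids) - prompt_length))
+ print(toy_mask) # expected: tensor([1, 1, 1, 1, 0, 0, 0, 0])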