tyfeld committed on
Commit · ea359a8
Parent(s): 2f15a78
initial
- app.py +871 -0
- models/__init__.py +3 -0
- models/common_modules.py +357 -0
- models/configuration_llada.py +463 -0
- models/logging.py +338 -0
- models/lr_schedulers.py +302 -0
- models/misc.py +53 -0
- models/modeling_llada.py +1500 -0
- models/modeling_magvitv2.py +440 -0
- models/modeling_mmada.py +668 -0
- models/modeling_utils.py +1207 -0
- models/sampling.py +118 -0
- models/training_utils.py +455 -0
- training/__init__.py +1 -0
- training/prompting_utils.py +475 -0
app.py
ADDED
@@ -0,0 +1,871 @@
import gradio as gr
import torch
import numpy as np
import torch.nn.functional as F
from transformers import AutoTokenizer
from torchvision import transforms
from models import MAGVITv2, get_mask_schedule, MMadaModelLM
from training.prompting_utils import UniversalPrompting
from PIL import Image

def image_transform(image, resolution=256, normalize=True):
    image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BICUBIC)(image)
    image = transforms.CenterCrop((resolution, resolution))(image)
    image = transforms.ToTensor()(image)
    if normalize:
        image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image)
    return image

def add_gumbel_noise(logits, temperature):
    """
    Adds Gumbel noise to logits for stochastic sampling.
    Equivalent to argmax(logits + temperature * G) where G ~ Gumbel(0,1).
    This version is more numerically stable than a version involving exp() and division.
    """
    if abs(temperature) < 1e-9:  # Effectively zero temperature
        return logits
    # Ensure logits are float64 for precision when adding noise
    logits = logits.to(torch.float64)
    # Standard Gumbel noise: -log(-log(U)), U ~ Uniform(0,1)
    # Add small epsilon for numerical stability inside logs
    noise = torch.rand_like(logits, dtype=torch.float64)
    standard_gumbel_noise = -torch.log(-torch.log(noise + 1e-20) + 1e-20)
    return logits + temperature * standard_gumbel_noise

def get_num_transfer_tokens(mask_index, steps):
    mask_num = mask_index.sum(dim=1, keepdim=True)
    # Ensure steps is at least 1 to avoid division by zero if mask_num is also 0 (though sum should be >= 0)
    steps = max(1, int(steps))  # Ensure steps is a positive integer
    base = mask_num // steps
    remainder = mask_num % steps
    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.long) + base
    for i in range(mask_num.size(0)):  # Iterate over batch
        if remainder[i] > 0:  # Ensure remainder is positive before indexing
            num_transfer_tokens[i, :remainder[i].item()] += 1  # .item() for single value tensor to int
    return num_transfer_tokens

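# Illustrative sketch (editor's addition, not part of the committed file):
# add_gumbel_noise() implements Gumbel-max sampling. Taking argmax over
# logits + temperature * G, with G ~ Gumbel(0, 1), draws a token with
# probability softmax(logits / temperature), so no explicit softmax or
# multinomial call is needed. A quick empirical check using only torch:
#
#   logits = torch.tensor([2.0, 1.0, 0.5])
#   counts = torch.zeros(3)
#   for _ in range(10000):
#       counts[torch.argmax(add_gumbel_noise(logits, temperature=1.0))] += 1
#   print(counts / counts.sum())           # approximately softmax(logits)
#   print(torch.softmax(logits, dim=-1))   # reference distribution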
MODEL = None
TOKENIZER = None
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MASK_ID = None
uni_prompting = None
VQ_MODEL = MAGVITv2().from_pretrained("/data_storage/shared/pretrained_models/models--showlab--magvitv2").to(DEVICE)

DEFAULT_MODEL_PATH = "/data_storage/lbw/MMaDA/mmada-training-stage3-llada-instruct-512-cot-uni/checkpoint-210000/unwrapped_model"  # Default
CURRENT_MODEL_PATH = None

MODEL_CHOICES = [
    "MMaDA-8B-Base",
    "MMaDA-8B-MixCoT (coming soon)",
    "MMaDA-8B-Max (coming soon)"
]
MODEL_ACTUAL_PATHS = {
    "MMaDA-8B-Base": DEFAULT_MODEL_PATH,
}

def clear_outputs_action():
    return None, None

def _load_model_and_tokenizer_core(model_path_to_load, model_display_name_for_status):
    global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH, DEVICE, uni_prompting

    if MODEL is not None and CURRENT_MODEL_PATH == model_path_to_load:
        return f"Model '{model_display_name_for_status}' from '{model_path_to_load}' is already loaded. MASK_ID: {MASK_ID}"

    CURRENT_MODEL_PATH = model_path_to_load

    status_msg_parts = [f"Loading '{model_display_name_for_status}'..."]
    try:
        TOKENIZER = AutoTokenizer.from_pretrained(model_path_to_load, trust_remote_code=True)
        status_msg_parts.append(f"Tokenizer for '{model_display_name_for_status}' loaded.")

        MODEL = MMadaModelLM.from_pretrained(model_path_to_load, trust_remote_code=True, torch_dtype=torch.bfloat16).to(DEVICE).eval()
        status_msg_parts.append(f"Model '{model_display_name_for_status}' loaded to {DEVICE}.")

        uni_prompting = UniversalPrompting(TOKENIZER, max_text_len=512, special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"), ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=True)

        if hasattr(TOKENIZER, 'mask_token_id') and TOKENIZER.mask_token_id is not None:
            MASK_ID = TOKENIZER.mask_token_id
            status_msg_parts.append(f"Using MASK_ID from tokenizer: {MASK_ID}.")
        else:
            MASK_ID = 126336
            status_msg_parts.append(f"Using default MASK_ID: {MASK_ID}.")

        if TOKENIZER.pad_token_id is None:
            if TOKENIZER.eos_token_id is not None:
                TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
                TOKENIZER.pad_token = TOKENIZER.eos_token
                status_msg_parts.append(f"Set pad_token_id to eos_token_id ({TOKENIZER.eos_token_id}).")
            else:
                status_msg_parts.append("Warning: pad_token_id is None and no eos_token_id.")

        if TOKENIZER.eos_token_id is None:  # Important for cleaning up output in visualization
            status_msg_parts.append("Warning: tokenizer.eos_token_id is None. EOS cleanup might not work.")

        TOKENIZER.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n' }}"

        return " ".join(status_msg_parts)
    except Exception as e:
        MODEL = None
        TOKENIZER = None
        MASK_ID = None
        CURRENT_MODEL_PATH = None
        return f"Error loading model '{model_display_name_for_status}': {str(e)}"

def handle_model_selection_change(selected_model_name_ui):
    if "coming soon" in selected_model_name_ui.lower():
        global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH
        MODEL = None
        TOKENIZER = None
        MASK_ID = None
        CURRENT_MODEL_PATH = None
        return f"'{selected_model_name_ui}' is not yet available. Please select 'MMaDA-8B-Base'."

    actual_path = MODEL_ACTUAL_PATHS.get(selected_model_name_ui)
    if not actual_path:
        return f"Path for '{selected_model_name_ui}' is not defined. Cannot load."

    return _load_model_and_tokenizer_core(actual_path, selected_model_name_ui)

def get_highlighted_text_tuples(current_x_ids_batch, prompt_input_ids, prompt_len, tk, current_mask_id, raw_prompt_attention_mask):
    if current_x_ids_batch is None or current_x_ids_batch.ndim == 0 or current_x_ids_batch.shape[0] == 0:
        return [("Error in sequence data for visualization.", "ERROR")]
    # only answer part
    current_x_ids_batch = current_x_ids_batch[:, prompt_len:]
    seq_ids = current_x_ids_batch[0].tolist()
    eos_token_id = tk.eos_token_id  # Get EOS token ID

    # Stage 1: Build initial list of tuples with (token_str, label, token_id_int)
    # This helps in identifying EOS tokens later without re-checking the type.
    intermediate_tuples = []
    for j, token_id_int in enumerate(seq_ids):
        try:
            token_str = tk.decode([token_id_int], skip_special_tokens=True, clean_up_tokenization_spaces=False)
        except Exception:  # Handle cases where a token ID might be problematic (e.g. with mock)
            token_str = f"[ID:{token_id_int}]"

        label = "ERROR"
        if token_id_int == current_mask_id:
            token_str = "[MASK]"
            label = "MASK"
        else:
            label = "GEN"
        intermediate_tuples.append((token_str, label, token_id_int))

    return intermediate_tuples

@torch.no_grad()
def generate_viz_wrapper_t2i(prompt_text, steps, guidance_scale, mask_schedule="cosine"):
    global MODEL, TOKENIZER, MASK_ID, DEVICE, uni_prompting

    if MODEL is None or TOKENIZER is None or MASK_ID is None:
        yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
        return
    steps = int(steps)
    guidance_scale = float(guidance_scale)

    image_tokens = torch.ones((1, 1024), dtype=torch.long, device=DEVICE) * MASK_ID
    prompt_text = [prompt_text]
    input_ids, attention_mask = uni_prompting((prompt_text, image_tokens), 't2i_gen')

    if guidance_scale > 0:
        uncond_input_ids, uncond_attention_mask = uni_prompting(([''], image_tokens), 't2i_gen')
    else:
        uncond_input_ids, uncond_attention_mask = None, None

    mask_schedule = get_mask_schedule(mask_schedule)
    blank_image = Image.new("RGB", (512, 512), (255, 255, 255))
    yield blank_image, "Starting generation..."
    for image_step, status_msg_step in MODEL.t2i_generate_decoding_stepwise(
            input_ids=input_ids,
            uncond_input_ids=uncond_input_ids,
            attention_mask=attention_mask,
            uncond_attention_mask=uncond_attention_mask,
            temperature=1.0,
            timesteps=steps,
            guidance_scale=guidance_scale,
            noise_schedule=mask_schedule,
            noise_type="mask",
            seq_len=1024,
            vq_model=VQ_MODEL,
            uni_prompting=uni_prompting):
        yield image_step, status_msg_step


@torch.no_grad()
def generate_viz_wrapper_lm(prompt_text, steps, gen_length, block_length, temperature,
                            cfg_scale, remasking_strategy, thinking_mode_lm):
    global MODEL, TOKENIZER, MASK_ID, DEVICE
    print(f"thinking_mode_lm: {thinking_mode_lm}")
    if MODEL is None or TOKENIZER is None or MASK_ID is None:
        yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
        return

    steps = int(steps)
    gen_length = int(gen_length)
    block_length = int(block_length)

    if thinking_mode_lm:
        prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" + prompt_text

    try:
        m = [{"role": "user", "content": prompt_text}]
        processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
    except Exception as e:
        yield [("Error applying chat template.", "ERROR")], f"Chat template error: {e}"
        processed_prompt_text = prompt_text
    try:
        if TOKENIZER.pad_token_id is None:
            if TOKENIZER.eos_token_id is not None:
                TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
            else:  # Should have been caught by load_model, but double check
                yield [("Tokenizer Error", "ERROR")], "pad_token_id is not set in tokenizer."
                return

        tokenized = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=MODEL.config.max_position_embeddings if hasattr(MODEL.config, 'max_position_embeddings') else 2048)
        input_ids = tokenized['input_ids'].to(DEVICE)
        # Keep the prompt attention mask: the CFG branch below uses it to tell
        # real prompt tokens apart from left padding.
        raw_prompt_attention_mask = tokenized['attention_mask'].to(DEVICE)

    except Exception as e:
        yield [("Error tokenizing prompt.", "ERROR")], f"Tokenization error: {e}"
        return

    batch_size = input_ids.shape[0]
    prompt_len = input_ids.shape[1]

    x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE)
    x[:, :prompt_len] = input_ids.clone()

    yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation: Prompt + Initial Masks"

    if gen_length == 0:
        final_text_output = TOKENIZER.batch_decode(x[:, prompt_len:], skip_special_tokens=True)
        yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0] if final_text_output else ""
        return

    if block_length <= 0 or gen_length % block_length != 0:
        yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
            f"Error: gen_length ({gen_length}) must be divisible by block_length ({block_length}) and block_length > 0."
        return
    num_blocks = gen_length // block_length

    if steps <= 0 or steps % num_blocks != 0:
        yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
            f"Error: steps ({steps}) must be positive and divisible by num_blocks ({num_blocks}). Steps: {steps}, Num Blocks: {num_blocks}"
        return
    steps_per_block = steps // num_blocks

    for num_block_iter in range(num_blocks):
        current_block_start_idx_in_x = prompt_len + num_block_iter * block_length
        current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length

        block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool)
        block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = \
            (x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID)

        num_transfer_tokens_for_this_block = get_num_transfer_tokens(
            block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x],
            steps_per_block
        )

        for i_step_in_block in range(steps_per_block):
            mask_index_global = (x == MASK_ID)

            if cfg_scale > 0.:
                un_x = x.clone()
                # For unconditional pass, mask out the original prompt tokens that are not padding
                # raw_prompt_attention_mask is (B, prompt_len)
                prompt_active_tokens_mask = raw_prompt_attention_mask.bool()  # True where actual prompt tokens are
                un_x[:, :prompt_len][prompt_active_tokens_mask] = MASK_ID

                x_cfg_input = torch.cat([x, un_x], dim=0)
                # Pass attention_mask for CFG if model expects it, covering both parts
                # For simplicity, not passing explicit attention_mask here; relies on model's internal handling.
                model_output = MODEL(x_cfg_input)
                logits_cond, logits_uncond = torch.chunk(model_output.logits, 2, dim=0)
                logits = logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond)
            else:
                # Not passing explicit attention_mask here; relies on model's internal handling.
                model_output = MODEL(x)
                logits = model_output.logits

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1)

            if remasking_strategy == 'low_confidence':
                probs = F.softmax(logits.to(torch.float64), dim=-1)
                x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1)
            elif remasking_strategy == 'random':
                x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float64)
            else:
                yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), f"Error: Unknown remasking strategy '{remasking_strategy}'"
                return

            confidence_for_selection = torch.full_like(x0_probs, -torch.inf)
            candidate_positions_for_unmasking = mask_index_global & block_masks_bool_current
            confidence_for_selection = torch.where(
                candidate_positions_for_unmasking,
                x0_probs,
                -torch.inf
            )

            x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x)

            transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool)
            num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block]

            for j_batch_idx in range(batch_size):
                k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(),
                            candidate_positions_for_unmasking[j_batch_idx].sum().item())  # ensure k isn't too large

                if k_val > 0:
                    # Ensure confidence_for_selection[j_batch_idx] is 1D for topk
                    conf_slice = confidence_for_selection[j_batch_idx]
                    if conf_slice.ndim > 1: conf_slice = conf_slice.view(-1)  # Should already be 1D from x0_probs

                    # Check if there are enough valid (non -inf) confidences
                    valid_conf_count = (conf_slice > -torch.inf).sum().item()
                    actual_k = min(k_val, valid_conf_count)

                    if actual_k > 0:
                        _, topk_indices_in_x = torch.topk(conf_slice, k=actual_k)
                        transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True

            x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool]

            current_total_step = num_block_iter * steps_per_block + i_step_in_block + 1
            total_overall_steps = num_blocks * steps_per_block
            status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block} (Total: {current_total_step}/{total_overall_steps})"
            yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg

    final_generated_ids = x[:, prompt_len:]
    final_text_output = TOKENIZER.batch_decode(final_generated_ids, skip_special_tokens=True)

    final_text_str = final_text_output[0] if final_text_output and len(final_text_output) > 0 else ""
    yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_str

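# Illustrative sketch (editor's addition, not part of the committed file):
# The block-wise loop above splits the generation window into equal blocks
# and spends an equal share of the sampling steps on each. With the UI
# defaults (gen_length=512, block_length=128, steps=256):
#     num_blocks      = 512 // 128 = 4
#     steps_per_block = 256 // 4   = 64
# get_num_transfer_tokens() then spreads a block's masked positions over
# those steps, e.g. for one fully masked block of 128 tokens:
#
#   mask_index = torch.ones(1, 128, dtype=torch.bool)
#   schedule = get_num_transfer_tokens(mask_index, 64)   # shape (1, 64), two tokens per step
#   assert schedule.sum().item() == 128                  # every masked token gets revealed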
@torch.no_grad()
def generate_viz_wrapper(uploaded_image_pil, prompt_text, steps, gen_length, block_length, temperature,
                         cfg_scale, remasking_strategy, thinking_mode_mmu):
    global MODEL, TOKENIZER, MASK_ID, DEVICE

    if MODEL is None or TOKENIZER is None or MASK_ID is None:
        yield [("Error: Model not loaded. Please load the model first.", "ERROR")], "Model not loaded."
        return

    steps = int(steps)
    gen_length = int(gen_length)
    block_length = int(block_length)

    if thinking_mode_mmu:
        prompt_text = "You should first think about the reasoning process in the mind and then provide the user with the answer. The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here\n" + prompt_text

    try:
        m = [{"role": "user", "content": prompt_text}]
        processed_prompt_text = TOKENIZER.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
    except Exception as e:
        yield [("Error applying chat template.", "ERROR")], f"Chat template error: {e}"
        processed_prompt_text = prompt_text

    image_vq_ids_tensor = None
    if uploaded_image_pil is not None:
        try:
            image = image_transform(uploaded_image_pil, resolution=512).to(DEVICE)
            image = image.unsqueeze(0)
            image_vq_ids_tensor = VQ_MODEL.get_code(image) + 126349
        except Exception as e:
            yield [("Error processing image.", "ERROR")], f"Image to VQ tokens conversion failed: {str(e)}"
            return

    try:
        if TOKENIZER.pad_token_id is None:
            if TOKENIZER.eos_token_id is not None:
                TOKENIZER.pad_token_id = TOKENIZER.eos_token_id
            else:
                yield [("Tokenizer Error", "ERROR")], "pad_token_id is not set in tokenizer."
                return

        tokenized = TOKENIZER(text=processed_prompt_text, return_tensors="pt", padding="longest", padding_side="left", truncation=True, max_length=MODEL.config.max_position_embeddings if hasattr(MODEL.config, 'max_position_embeddings') else 2048)
        input_ids = tokenized['input_ids'].to(DEVICE)
        # Keep the text attention mask so the CFG branch below can tell real
        # prompt tokens apart from left padding.
        raw_prompt_attention_mask = tokenized['attention_mask'].to(DEVICE)
        if image_vq_ids_tensor is not None:
            if image_vq_ids_tensor.ndim == 1:
                image_vq_ids_tensor = image_vq_ids_tensor.unsqueeze(0)

            input_ids = torch.cat([
                (torch.ones(input_ids.shape[0], 1) * torch.tensor([126089])).to(DEVICE),
                (torch.ones(input_ids.shape[0], 1) * torch.tensor([126084])).to(DEVICE),
                image_vq_ids_tensor,
                (torch.ones(input_ids.shape[0], 1) * torch.tensor([126085])).to(DEVICE),
                input_ids
            ], dim=1).long()
            # Extend the mask with ones for the prepended task/image tokens so it
            # stays aligned with the concatenated prompt.
            raw_prompt_attention_mask = torch.cat([
                torch.ones(input_ids.shape[0], input_ids.shape[1] - raw_prompt_attention_mask.shape[1],
                           dtype=raw_prompt_attention_mask.dtype, device=DEVICE),
                raw_prompt_attention_mask
            ], dim=1)

    except Exception as e:
        yield [("Error tokenizing prompt.", "ERROR")], f"Tokenization error: {e}"
        return

    batch_size = input_ids.shape[0]
    prompt_len = input_ids.shape[1]

    x = torch.full((batch_size, prompt_len + gen_length), MASK_ID, dtype=torch.long, device=DEVICE)
    x[:, :prompt_len] = input_ids.clone()

    yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), "Starting generation: Prompt + Initial Masks"

    if gen_length == 0:
        final_text_output = TOKENIZER.batch_decode(x[:, prompt_len:], skip_special_tokens=True)
        yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_output[0] if final_text_output else ""
        return

    if block_length <= 0 or gen_length % block_length != 0:
        yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
            f"Error: gen_length ({gen_length}) must be divisible by block_length ({block_length}) and block_length > 0."
        return
    num_blocks = gen_length // block_length

    if steps <= 0 or steps % num_blocks != 0:
        yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), \
            f"Error: steps ({steps}) must be positive and divisible by num_blocks ({num_blocks}). Steps: {steps}, Num Blocks: {num_blocks}"
        return
    steps_per_block = steps // num_blocks

    for num_block_iter in range(num_blocks):
        current_block_start_idx_in_x = prompt_len + num_block_iter * block_length
        current_block_end_idx_in_x = prompt_len + (num_block_iter + 1) * block_length

        block_masks_bool_current = torch.zeros_like(x, dtype=torch.bool)
        block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x] = \
            (x[:, current_block_start_idx_in_x:current_block_end_idx_in_x] == MASK_ID)

        num_transfer_tokens_for_this_block = get_num_transfer_tokens(
            block_masks_bool_current[:, current_block_start_idx_in_x:current_block_end_idx_in_x],
            steps_per_block
        )

        for i_step_in_block in range(steps_per_block):
            mask_index_global = (x == MASK_ID)

            if cfg_scale > 0.:
                un_x = x.clone()
                # For unconditional pass, mask out the original prompt tokens that are not padding
                # raw_prompt_attention_mask is (B, prompt_len)
                prompt_active_tokens_mask = raw_prompt_attention_mask.bool()  # True where actual prompt tokens are
                un_x[:, :prompt_len][prompt_active_tokens_mask] = MASK_ID

                x_cfg_input = torch.cat([x, un_x], dim=0)
                # Pass attention_mask for CFG if model expects it, covering both parts
                # For simplicity, not passing explicit attention_mask here; relies on model's internal handling.
                model_output = MODEL(x_cfg_input)
                logits_cond, logits_uncond = torch.chunk(model_output.logits, 2, dim=0)
                logits = logits_uncond + (cfg_scale + 1) * (logits_cond - logits_uncond)
            else:
                # Not passing explicit attention_mask here; relies on model's internal handling.
                model_output = MODEL(x)
                logits = model_output.logits

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0_predicted_tokens = torch.argmax(logits_with_noise, dim=-1)

            if remasking_strategy == 'low_confidence':
                probs = F.softmax(logits.to(torch.float64), dim=-1)
                x0_probs = torch.gather(probs, dim=-1, index=x0_predicted_tokens.unsqueeze(-1)).squeeze(-1)
            elif remasking_strategy == 'random':
                x0_probs = torch.rand(x.shape, device=x.device, dtype=torch.float64)
            else:
                yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), f"Error: Unknown remasking strategy '{remasking_strategy}'"
                return

            confidence_for_selection = torch.full_like(x0_probs, -torch.inf)
            candidate_positions_for_unmasking = mask_index_global & block_masks_bool_current
            confidence_for_selection = torch.where(
                candidate_positions_for_unmasking,
                x0_probs,
                -torch.inf
            )

            x0_final_candidates = torch.where(mask_index_global, x0_predicted_tokens, x)

            transfer_indices_bool = torch.zeros_like(x, dtype=torch.bool)
            num_to_transfer_this_step_batch = num_transfer_tokens_for_this_block[:, i_step_in_block]

            for j_batch_idx in range(batch_size):
                k_val = min(num_to_transfer_this_step_batch[j_batch_idx].item(),
                            candidate_positions_for_unmasking[j_batch_idx].sum().item())  # ensure k isn't too large

                if k_val > 0:
                    # Ensure confidence_for_selection[j_batch_idx] is 1D for topk
                    conf_slice = confidence_for_selection[j_batch_idx]
                    if conf_slice.ndim > 1: conf_slice = conf_slice.view(-1)  # Should already be 1D from x0_probs

                    # Check if there are enough valid (non -inf) confidences
                    valid_conf_count = (conf_slice > -torch.inf).sum().item()
                    actual_k = min(k_val, valid_conf_count)

                    if actual_k > 0:
                        _, topk_indices_in_x = torch.topk(conf_slice, k=actual_k)
                        transfer_indices_bool[j_batch_idx, topk_indices_in_x] = True

            x[transfer_indices_bool] = x0_final_candidates[transfer_indices_bool]

            current_total_step = num_block_iter * steps_per_block + i_step_in_block + 1
            total_overall_steps = num_blocks * steps_per_block
            status_msg = f"Block {num_block_iter+1}/{num_blocks}, Step {i_step_in_block+1}/{steps_per_block} (Total: {current_total_step}/{total_overall_steps})"
            yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), status_msg

    final_generated_ids = x[:, prompt_len:]
    final_text_output = TOKENIZER.batch_decode(final_generated_ids, skip_special_tokens=True)

    final_text_str = final_text_output[0] if final_text_output and len(final_text_output) > 0 else ""
    yield get_highlighted_text_tuples(x, input_ids, prompt_len, TOKENIZER, MASK_ID, raw_prompt_attention_mask), final_text_str


css_styles = """
.gradio-container{font-family:'IBM Plex Sans',sans-serif;margin:auto;}
.gr-input {background:#f9f9f9 !important;border:1px solid #e0e0e0 !important;}
.gr-output{background:#f0f0f0 !important;border:1px solid #d0d0d0 !important;}

.highlighted-text span{
padding:2px 4px;border-radius:4px;margin:1px 2px;display:inline-block;line-height:1.6;
}

footer{display:none !important}

#live-update-scrollable-box {
    max-height: 800px; /* Adjust this maximum height as needed, e.g. '300px', '50vh', etc. */
    overflow-y: auto !important; /* Show a vertical scrollbar when content exceeds max-height */
    display: block; /* Make sure the element is block-level so max-height takes effect */
}
#think_btn {
    background-color: #f3f4f6 !important;
    border: 1px solid #d0d0d0 !important;
    color: #111827 !important;
    font-size: 16px !important;
    font-weight: bold !important;
}
#think_btn:hover {
    background-color: #e0e0e0 !important;
    border: 1px solid #c0c0c0 !important;
    color: #222 !important;
}
#think_btn:active {
    background-color: #2563eb !important;
    border: 1px solid #b0b0b0 !important;
    color: white !important;
}
"""

# thinking_mode_t2i = gr.State(False)
def toggle_thinking_mode_lm(current_thinking_mode):
    # print(f"current_thinking_mode: {current_thinking_mode}")
    new_state = not current_thinking_mode
    new_label = "Thinking Mode ✅" if new_state else "Thinking Mode ❌"
    return new_state, gr.update(value=new_label)

def toggle_thinking_mode_mmu(current_thinking_mode):
    new_state = not current_thinking_mode
    new_label = "Thinking Mode ✅" if new_state else "Thinking Mode ❌"
    return new_state, gr.update(value=new_label)


color_map_config = {
    "MASK": "lightgrey",
    "GEN": "#DCABFA",
}

theme = gr.themes.Ocean(
    primary_hue="fuchsia",
)
with gr.Blocks(css=css_styles, theme=theme) as demo:
    # with gr.Blocks(css=css_styles, theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky)) as demo:
    # with gr.Blocks() as demo:
    thinking_mode_lm = gr.State(False)
    thinking_mode_mmu = gr.State(False)
    gr.Markdown("<h1 style='text-align: center; margin-bottom: 20px;'>MMaDA</h1>")
    gr.Markdown("Interactively explore the step-by-step generation process of a diffusion language model. "
                "The model begins with a fully masked sequence (except for the prompt) and progressively refines it by unmasking tokens.")
    gr.Markdown("### Select Model")
    with gr.Row():
        model_select_radio = gr.Radio(
            label="Select Text Generation Model",
            choices=MODEL_CHOICES,
            value=MODEL_CHOICES[0]
        )
        model_load_status_box = gr.Textbox(
            label="Model Load Status",
            interactive=False,
            lines=3,
            max_lines=5
        )

    gr.Markdown("## Part 1. Text Generation")
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input_box_lm = gr.Textbox(label="Enter your prompt:", lines=3, value="A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?")
            think_button_lm = gr.Button("🧠 Enable Thinking Mode", elem_id="think_btn")
            with gr.Accordion("Generation Parameters", open=True):
                with gr.Row():
                    gen_length_slider_lm = gr.Slider(minimum=8, maximum=1024, value=512, step=64, label="Generation Length", info="Number of tokens to generate.")
                    steps_slider_lm = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).")
                with gr.Row():
                    block_length_slider_lm = gr.Slider(minimum=8, maximum=1024, value=128, step=32, label="Block Length", info="gen_length must be divisible by this.")
                    remasking_dropdown_lm = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy")
                with gr.Row():
                    cfg_scale_slider_lm = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale", info="Classifier-Free Guidance. 0 disables it.")
                    temperature_slider_lm = gr.Slider(minimum=0.0, maximum=2.0, value=1, step=0.05, label="Temperature", info="Controls randomness via Gumbel noise. 0 is deterministic.")

            with gr.Row():
                run_button_ui_lm = gr.Button("Generate Sequence", variant="primary", scale=3)
                clear_button_ui_lm = gr.Button("Clear Outputs", scale=1)

        with gr.Column(scale=3):
            # gr.Markdown("## Live Generation Process")
            output_visualization_box_lm = gr.HighlightedText(
                label="Live Generation Process",
                show_legend=True,
                color_map=color_map_config,
                combine_adjacent=False,
                interactive=False,
                elem_id="live-update-scrollable-box",
            )
            # gr.Markdown("## Final Generated Text")
            output_final_text_box_lm = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)

    gr.Examples(
        examples=[
            ["A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?", 256, 512, 128, 1, 0, "low_confidence"],
            ["Lily can run 12 kilometers per hour for 4 hours. After that, she can run 6 kilometers per hour. How many kilometers can she run in 8 hours?", 256, 512, 64, 1, 0, "low_confidence"]
        ],
        inputs=[prompt_input_box_lm, steps_slider_lm, gen_length_slider_lm, block_length_slider_lm, temperature_slider_lm, cfg_scale_slider_lm, remasking_dropdown_lm],
        outputs=[output_visualization_box_lm, output_final_text_box_lm],
        fn=generate_viz_wrapper_lm,
    )

    gr.Markdown("---")
    gr.Markdown("## Part 2. Multimodal Understanding")
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input_box_mmu = gr.Textbox(
                label="Enter your prompt:",
                lines=3,
                value="Please describe this image in detail."
            )
            think_button_mmu = gr.Button("🧠 Enable Thinking Mode", elem_id="think_btn")
            with gr.Accordion("Generation Parameters", open=True):
                with gr.Row():
                    gen_length_slider_mmu = gr.Slider(minimum=64, maximum=1024, value=512, step=64, label="Generation Length", info="Number of tokens to generate.")
                    steps_slider_mmu = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).")
                with gr.Row():
                    block_length_slider_mmu = gr.Slider(minimum=32, maximum=1024, value=128, step=32, label="Block Length", info="gen_length must be divisible by this.")
                    remasking_dropdown_mmu = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy")
                with gr.Row():
                    cfg_scale_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale", info="Classifier-Free Guidance. 0 disables it.")
                    temperature_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=1, step=0.05, label="Temperature", info="Controls randomness via Gumbel noise. 0 is deterministic.")

            with gr.Row():
                image_upload_box = gr.Image(type="pil", label="Upload Image")

            with gr.Row():
                run_button_ui_mmu = gr.Button("Generate Description", variant="primary", scale=3)
                clear_button_ui_mmu = gr.Button("Clear Outputs", scale=1)

        with gr.Column(scale=3):
            gr.Markdown("## Live Generation Process")
            output_visualization_box_mmu = gr.HighlightedText(
                label="Token Sequence (Live Update)",
                show_legend=True,
                color_map=color_map_config,
                combine_adjacent=False,
                interactive=False,
                elem_id="live-update-scrollable-box",
            )
            gr.Markdown("## Final Generated Text")
            output_final_text_box_mmu = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)

    gr.Examples(
        examples=[
            [
                "mmu_validation_2/sunflower.jpg",
                "Please describe this image in detail.",
                256,
                512,
                128,
                1,
                0,
                "low_confidence"
            ],
            [
                "mmu_validation_2/woman.jpg",
                "Please describe this image in detail.",
                256,
                512,
                128,
                1,
                0,
                "low_confidence"
            ]
        ],
        inputs=[
            image_upload_box,
            prompt_input_box_mmu,
            steps_slider_mmu,
            gen_length_slider_mmu,
            block_length_slider_mmu,
            temperature_slider_mmu,
            cfg_scale_slider_mmu,
            remasking_dropdown_mmu
        ],
        outputs=[output_visualization_box_mmu, output_final_text_box_mmu],
        fn=generate_viz_wrapper,
    )

    gr.Markdown("---")
    gr.Markdown("## Part 3. Text-to-Image Generation")
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input_box_t2i = gr.Textbox(label="Enter your prompt:", lines=3, value="A sea turtle swimming near a coral reef in the ocean, with a clear blue sky and water in the background.")

            with gr.Accordion("Generation Parameters", open=True):
                with gr.Row():
                    steps_slider_t2i = gr.Slider(minimum=5, maximum=100, value=15, step=5, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).")
                    guidance_scale_slider_t2i = gr.Slider(minimum=0.0, maximum=7.0, value=3.5, step=0.5, label="Guidance Scale", info="Classifier-Free Guidance. 0 disables it.")

            with gr.Row():
                scheduler_radio_t2i = gr.Radio(
                    choices=["cosine", "sigmoid", "linear"],
                    value="cosine",
                    label="Scheduler",
                )

            with gr.Row():
                run_button_ui_t2i = gr.Button("Generate Image", variant="primary", scale=3)
                clear_button_ui_t2i = gr.Button("Clear Outputs", scale=1)

        with gr.Column(scale=3):
            # gr.Markdown("## Live Generation Process")
            output_image_t2i = gr.Image(label="Generated Image", interactive=False, type="pil")
            output_status_t2i = gr.Textbox(label="Generation Status", interactive=False)

    gr.Examples(
        examples=[
            ["A sea turtle swimming near a coral reef in the ocean, with a clear blue sky and water in the background.", 15, 3.5, "cosine"],
            ["A beautiful sunset over a calm ocean, with a few clouds in the sky.", 15, 3.5, "cosine"]
        ],
        inputs=[prompt_input_box_t2i, steps_slider_t2i, guidance_scale_slider_t2i, scheduler_radio_t2i],
        outputs=[output_image_t2i, output_status_t2i],
        fn=generate_viz_wrapper_t2i,
    )

    run_button_ui_t2i.click(
        fn=generate_viz_wrapper_t2i,
        inputs=[
            prompt_input_box_t2i,
            steps_slider_t2i,
            guidance_scale_slider_t2i,
            scheduler_radio_t2i
        ],
        outputs=[output_image_t2i, output_status_t2i]
    )

    clear_button_ui_t2i.click(
        fn=lambda: (None, ""),
        inputs=None,
        outputs=[output_image_t2i, output_status_t2i],
        queue=False
    )

    think_button_lm.click(
        fn=toggle_thinking_mode_lm,
        inputs=[thinking_mode_lm],
        outputs=[thinking_mode_lm, think_button_lm]
    )

    think_button_mmu.click(
        fn=toggle_thinking_mode_mmu,
        inputs=[thinking_mode_mmu],
        outputs=[thinking_mode_mmu, think_button_mmu]
    )

    def initialize_default_model():
        default_model = "MMaDA-8B-Base"
        result = handle_model_selection_change(default_model)
        return default_model, result

    demo.load(
        fn=initialize_default_model,
        inputs=None,
        outputs=[model_select_radio, model_load_status_box],
        queue=True
    )

    def clear_outputs():
        return None, None, None  # Clear image, visualization, and final text

    clear_button_ui_lm.click(
        fn=clear_outputs,
        inputs=None,
        outputs=[image_upload_box, output_visualization_box_lm, output_final_text_box_lm],
        queue=False
    )
    clear_button_ui_mmu.click(
        fn=clear_outputs,
        inputs=None,
        outputs=[image_upload_box, output_visualization_box_mmu, output_final_text_box_mmu],
        queue=False
    )

    run_button_ui_lm.click(
        fn=generate_viz_wrapper_lm,
        inputs=[
            prompt_input_box_lm,
            steps_slider_lm,
            gen_length_slider_lm,
            block_length_slider_lm,
            temperature_slider_lm,
            cfg_scale_slider_lm,
            remasking_dropdown_lm,
            thinking_mode_lm
        ],
        outputs=[output_visualization_box_lm, output_final_text_box_lm]
    )

    run_button_ui_mmu.click(
        fn=generate_viz_wrapper,
        inputs=[
            image_upload_box,
            prompt_input_box_mmu,
            steps_slider_mmu,
            gen_length_slider_mmu,
            block_length_slider_mmu,
            temperature_slider_mmu,
            cfg_scale_slider_mmu,
            remasking_dropdown_mmu,
            thinking_mode_mmu
        ],
        outputs=[output_visualization_box_mmu, output_final_text_box_mmu]
    )


if __name__ == "__main__":
    print(f"Starting Gradio App. Attempting to use device: {DEVICE}")
    demo.launch(share=True)
models/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .modeling_magvitv2 import VQGANEncoder, VQGANDecoder, LFQuantizer, MAGVITv2
from .sampling import *
from .modeling_mmada import MMadaModelLM, MMadaConfig
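# Editor's note: app.py imports these symbols via the package, e.g.
#   from models import MAGVITv2, get_mask_schedule, MMadaModelLM
# MAGVITv2 and MMadaModelLM come from the explicit imports above, while
# get_mask_schedule is presumably re-exported by `from .sampling import *`.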
models/common_modules.py
ADDED
@@ -0,0 +1,357 @@
"""
Modified from https://github.com/CompVis/taming-transformers/blob/master/taming/modules/diffusionmodules/model.py#L34
"""

import math
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(
        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
    )


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class DepthToSpaceUpsample(nn.Module):
    def __init__(
        self,
        in_channels,
    ):
        super().__init__()
        conv = nn.Conv2d(in_channels, in_channels * 4, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU(),
            Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2),
        )

        self.init_conv_(conv)

    def init_conv_(self, conv):
        o, i, h, w = conv.weight.shape
        conv_weight = torch.empty(o // 4, i, h, w)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, "o ... -> (o 4) ...")

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        out = self.net(x)
        return out


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


def unpack_time(t, batch):
    _, c, w, h = t.size()
    out = torch.reshape(t, [batch, -1, c, w, h])
    out = rearrange(out, "b t c h w -> b c t h w")
    return out


def pack_time(t):
    out = rearrange(t, "b c t h w -> b t c h w")
    _, _, c, w, h = out.size()
    return torch.reshape(out, [-1, c, w, h])

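# Illustrative sketch (editor's addition, not part of the committed file):
# pack_time()/unpack_time() fold the time axis into the batch axis so that
# 2-D modules can process a video frame-by-frame, then restore the layout:
#
#   video = torch.randn(2, 8, 5, 16, 16)       # (b, c, t, h, w)
#   frames = pack_time(video)                  # (b*t, c, h, w) == (10, 8, 16, 16)
#   restored = unpack_time(frames, batch=2)    # back to (2, 8, 5, 16, 16)
#   assert torch.equal(video, restored)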
class TimeDownsample2x(nn.Module):
    def __init__(
        self,
        dim,
        dim_out=None,
        kernel_size=3,
    ):
        super().__init__()
        if dim_out is None:
            dim_out = dim
        self.time_causal_padding = (kernel_size - 1, 0)
        self.conv = nn.Conv1d(dim, dim_out, kernel_size, stride=2)

    def forward(self, x):
        x = rearrange(x, "b c t h w -> b h w c t")
        b, h, w, c, t = x.size()
        x = torch.reshape(x, [-1, c, t])

        x = F.pad(x, self.time_causal_padding)
        out = self.conv(x)

        out = torch.reshape(out, [b, h, w, c, t])
        out = rearrange(out, "b h w c t -> b c t h w")
        # out = rearrange(out, "b h w c t -> b c t h w")  # duplicated in the original commit; applying the same permutation twice would scramble the axes
        return out


class TimeUpsample2x(nn.Module):
    def __init__(self, dim, dim_out=None):
        super().__init__()
        if dim_out is None:
            dim_out = dim
        conv = nn.Conv1d(dim, dim_out * 2, 1)

        self.net = nn.Sequential(
            nn.SiLU(), conv, Rearrange("b (c p) t -> b c (t p)", p=2)
        )

        self.init_conv_(conv)

    def init_conv_(self, conv):
        o, i, t = conv.weight.shape
        conv_weight = torch.empty(o // 2, i, t)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, "o ... -> (o 2) ...")

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        x = rearrange(x, "b c t h w -> b h w c t")
        b, h, w, c, t = x.size()
        x = torch.reshape(x, [-1, c, t])

        out = self.net(x)
        out = out[:, :, 1:].contiguous()

        out = torch.reshape(out, [b, h, w, c, t])
        out = rearrange(out, "b h w c t -> b c t h w")
        return out


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw   w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


class TimeAttention(AttnBlock):
    def forward(self, x, *args, **kwargs):
        x = rearrange(x, "b c t h w -> b h w t c")
        b, h, w, t, c = x.size()
        x = torch.reshape(x, (-1, t, c))

        x = super().forward(x, *args, **kwargs)

        x = torch.reshape(x, [b, h, w, t, c])
        return rearrange(x, "b h w t c -> b c t h w")


class Residual(nn.Module):
    def __init__(self, fn: nn.Module):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x


def cast_tuple(t, length=1):
    return t if isinstance(t, tuple) else ((t,) * length)


class CausalConv3d(nn.Module):
    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int, int, int]],
        pad_mode="constant",
        **kwargs
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        dilation = kwargs.pop("dilation", 1)
        stride = kwargs.pop("stride", 1)

        self.pad_mode = pad_mode
        time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.time_pad = time_pad
        self.time_causal_padding = (
            width_pad,
            width_pad,
            height_pad,
            height_pad,
            time_pad,
            0,
        )

        stride = (stride, 1, 1)
        dilation = (dilation, 1, 1)
        self.conv = nn.Conv3d(
            chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs
        )

    def forward(self, x):
        pad_mode = self.pad_mode if self.time_pad < x.shape[2] else "constant"

        x = F.pad(x, self.time_causal_padding, mode=pad_mode)
        return self.conv(x)

284 |
+
def ResnetBlockCausal3D(
|
285 |
+
dim, kernel_size: Union[int, Tuple[int, int, int]], pad_mode: str = "constant"
|
286 |
+
):
|
287 |
+
net = nn.Sequential(
|
288 |
+
Normalize(dim),
|
289 |
+
nn.SiLU(),
|
290 |
+
CausalConv3d(dim, dim, kernel_size, pad_mode),
|
291 |
+
Normalize(dim),
|
292 |
+
nn.SiLU(),
|
293 |
+
CausalConv3d(dim, dim, kernel_size, pad_mode),
|
294 |
+
)
|
295 |
+
return Residual(net)
|
296 |
+
|
297 |
+
|
298 |
+
class ResnetBlock(nn.Module):
|
299 |
+
def __init__(
|
300 |
+
self,
|
301 |
+
*,
|
302 |
+
in_channels,
|
303 |
+
out_channels=None,
|
304 |
+
conv_shortcut=False,
|
305 |
+
dropout,
|
306 |
+
temb_channels=512
|
307 |
+
):
|
308 |
+
super().__init__()
|
309 |
+
self.in_channels = in_channels
|
310 |
+
out_channels = in_channels if out_channels is None else out_channels
|
311 |
+
self.out_channels = out_channels
|
312 |
+
self.use_conv_shortcut = conv_shortcut
|
313 |
+
|
314 |
+
self.norm1 = Normalize(in_channels)
|
315 |
+
self.conv1 = torch.nn.Conv2d(
|
316 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
317 |
+
)
|
318 |
+
if temb_channels > 0:
|
319 |
+
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
320 |
+
else:
|
321 |
+
self.temb_proj = None
|
322 |
+
self.norm2 = Normalize(out_channels)
|
323 |
+
self.dropout = torch.nn.Dropout(dropout)
|
324 |
+
self.conv2 = torch.nn.Conv2d(
|
325 |
+
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
326 |
+
)
|
327 |
+
if self.in_channels != self.out_channels:
|
328 |
+
if self.use_conv_shortcut:
|
329 |
+
self.conv_shortcut = torch.nn.Conv2d(
|
330 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
331 |
+
)
|
332 |
+
else:
|
333 |
+
self.nin_shortcut = torch.nn.Conv2d(
|
334 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
335 |
+
)
|
336 |
+
|
337 |
+
def forward(self, x, temb):
|
338 |
+
h = x
|
339 |
+
h = self.norm1(h)
|
340 |
+
h = nonlinearity(h)
|
341 |
+
h = self.conv1(h)
|
342 |
+
|
343 |
+
if temb is not None:
|
344 |
+
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
345 |
+
|
346 |
+
h = self.norm2(h)
|
347 |
+
h = nonlinearity(h)
|
348 |
+
h = self.dropout(h)
|
349 |
+
h = self.conv2(h)
|
350 |
+
|
351 |
+
if self.in_channels != self.out_channels:
|
352 |
+
if self.use_conv_shortcut:
|
353 |
+
x = self.conv_shortcut(x)
|
354 |
+
else:
|
355 |
+
x = self.nin_shortcut(x)
|
356 |
+
|
357 |
+
return x + h
|
models/configuration_llada.py
ADDED
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
LLaDA configuration
|
3 |
+
"""
|
4 |
+
from transformers import AutoConfig, PretrainedConfig
|
5 |
+
|
6 |
+
from enum import Enum
|
7 |
+
from os import PathLike
|
8 |
+
from typing import Union
|
9 |
+
from dataclasses import asdict, dataclass, field
|
10 |
+
from glob import glob
|
11 |
+
from pathlib import Path
|
12 |
+
from typing import (
|
13 |
+
Any,
|
14 |
+
Dict,
|
15 |
+
Iterable,
|
16 |
+
List,
|
17 |
+
Optional,
|
18 |
+
Tuple,
|
19 |
+
Type,
|
20 |
+
TypeVar,
|
21 |
+
Union,
|
22 |
+
cast,
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
__all__ = [
|
27 |
+
"ActivationType",
|
28 |
+
"ActivationCheckpointingStrategy",
|
29 |
+
"BlockType",
|
30 |
+
"LayerNormType",
|
31 |
+
"InitFnType",
|
32 |
+
"ModelConfig",
|
33 |
+
]
|
34 |
+
|
35 |
+
PathOrStr = Union[str, PathLike]
|
36 |
+
|
37 |
+
|
38 |
+
class StrEnum(str, Enum):
|
39 |
+
"""
|
40 |
+
This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
|
41 |
+
We include this here for compatibility with older version of Python.
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __str__(self) -> str:
|
45 |
+
return self.value
|
46 |
+
|
47 |
+
def __repr__(self) -> str:
|
48 |
+
return f"'{str(self)}'"
|
49 |
+
|
50 |
+
|
51 |
+
class LayerNormType(StrEnum):
|
52 |
+
default = "default"
|
53 |
+
"""
|
54 |
+
The default LayerNorm implementation, equivalent to PyTorch's built-in version.
|
55 |
+
"""
|
56 |
+
|
57 |
+
low_precision = "low_precision"
|
58 |
+
"""
|
59 |
+
A low-precision version of the default LayerNorm.
|
60 |
+
"""
|
61 |
+
|
62 |
+
rms = "rms"
|
63 |
+
"""
|
64 |
+
An RMSNorm implementation. When using ``torch.compile`` this is
|
65 |
+
probably the fastest implementation.
|
66 |
+
"""
|
67 |
+
|
68 |
+
gemma_rms = "gemma_rms"
|
69 |
+
"""
|
70 |
+
An RMSNorm implementation by gemmma. When using ``torch.compile`` this is
|
71 |
+
probably the fastest implementation.
|
72 |
+
"""
|
73 |
+
|
74 |
+
amd_compatible = "amd_compatible"
|
75 |
+
"""
|
76 |
+
LayerNorm implemented manually to work around an issue with ROCm.
|
77 |
+
"""
|
78 |
+
|
79 |
+
|
80 |
+
class ActivationType(StrEnum):
|
81 |
+
gelu = "gelu"
|
82 |
+
relu = "relu"
|
83 |
+
silu = "silu"
|
84 |
+
swiglu = "swiglu"
|
85 |
+
|
86 |
+
|
87 |
+
class BlockType(StrEnum):
|
88 |
+
sequential = "sequential"
|
89 |
+
parallel = "parallel"
|
90 |
+
|
91 |
+
llama = "llama"
|
92 |
+
"""
|
93 |
+
A block similar to the sequential block with slightly different
|
94 |
+
implementations of operations like attention to imitate the behavior of Llama.
|
95 |
+
"""
|
96 |
+
|
97 |
+
|
98 |
+
class InitFnType(StrEnum):
|
99 |
+
mitchell = "mitchell"
|
100 |
+
"""
|
101 |
+
The strategy suggested to us by Mitchell Wortsman from UW.
|
102 |
+
This uses a truncated normal distribution with an adaptive standard deviation that depends
|
103 |
+
on the size of the weights as well as the depth of the layer.
|
104 |
+
"""
|
105 |
+
|
106 |
+
normal = "normal"
|
107 |
+
"""
|
108 |
+
All weights are initialized from the same normal distribution.
|
109 |
+
"""
|
110 |
+
|
111 |
+
kaiming_normal = "kaiming_normal"
|
112 |
+
"""
|
113 |
+
All weights are initialized with the Kaiming method from a normal distribution.
|
114 |
+
Note this currently won't work with FSDP.
|
115 |
+
"""
|
116 |
+
|
117 |
+
fan_in = "fan_in"
|
118 |
+
"""
|
119 |
+
"Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
|
120 |
+
is the input dimensionality of the kernel.
|
121 |
+
"""
|
122 |
+
|
123 |
+
full_megatron = "full_megatron"
|
124 |
+
"""
|
125 |
+
This is what metaseq calls "full megatron init". It is the init used for Llama 2.
|
126 |
+
"""
|
127 |
+
|
128 |
+
|
129 |
+
@dataclass
|
130 |
+
class ModelConfig():
|
131 |
+
"""
|
132 |
+
LLaDA (model) configuration.
|
133 |
+
"""
|
134 |
+
|
135 |
+
# Note that the defaults for these attributes are equivalent to the base GPT2 model.
|
136 |
+
|
137 |
+
d_model: int = 768
|
138 |
+
"""
|
139 |
+
The hidden size of the model.
|
140 |
+
"""
|
141 |
+
|
142 |
+
n_heads: int = 12
|
143 |
+
"""
|
144 |
+
The number of self-attention heads.
|
145 |
+
"""
|
146 |
+
|
147 |
+
n_kv_heads: Optional[int] = None
|
148 |
+
"""
|
149 |
+
The number of heads to use for keys and values. Defaults to `n_heads`.
|
150 |
+
Set this to ``None`` or ``n_heads`` for normal multi-head attention.
|
151 |
+
Set this to 1 for multi-query attention.
|
152 |
+
Set it to some in-between value for Llama2-style grouped query attention.
|
153 |
+
"""
|
154 |
+
|
155 |
+
n_layers: int = 12
|
156 |
+
"""
|
157 |
+
The number of layers/blocks.
|
158 |
+
"""
|
159 |
+
|
160 |
+
mlp_ratio: int = 4
|
161 |
+
"""
|
162 |
+
The ratio of the inner MLP dimensionality to ``d_model``.
|
163 |
+
This is only used when ``mlp_hidden_size`` is not set.
|
164 |
+
"""
|
165 |
+
|
166 |
+
mlp_hidden_size: Optional[int] = None
|
167 |
+
"""
|
168 |
+
Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
|
169 |
+
"""
|
170 |
+
|
171 |
+
activation_type: ActivationType = ActivationType.swiglu
|
172 |
+
"""
|
173 |
+
The activation function to use within the MLP layers.
|
174 |
+
"""
|
175 |
+
|
176 |
+
block_type: BlockType = BlockType.sequential
|
177 |
+
"""
|
178 |
+
The transformer block implementation.
|
179 |
+
"""
|
180 |
+
|
181 |
+
block_group_size: int = 1
|
182 |
+
"""
|
183 |
+
The number of blocks to group together into a single parent block.
|
184 |
+
This has no affect on the number of parameters in the model and is only used to wrap groups
|
185 |
+
of blocks together with a single FSDP wrapper during training.
|
186 |
+
"""
|
187 |
+
|
188 |
+
alibi: bool = False
|
189 |
+
"""
|
190 |
+
If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
|
191 |
+
"""
|
192 |
+
|
193 |
+
alibi_bias_max: float = 8.0
|
194 |
+
"""
|
195 |
+
Maximum absolute value of ALiBi bias.
|
196 |
+
"""
|
197 |
+
|
198 |
+
rope: bool = False
|
199 |
+
"""
|
200 |
+
Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
|
201 |
+
"""
|
202 |
+
|
203 |
+
rope_full_precision: bool = True
|
204 |
+
"""
|
205 |
+
If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
|
206 |
+
apply RoPE at the precision of the input.
|
207 |
+
"""
|
208 |
+
|
209 |
+
flash_attention: bool = False
|
210 |
+
"""
|
211 |
+
If ``True``, use ``FlashAttention``.
|
212 |
+
"""
|
213 |
+
|
214 |
+
attention_dropout: float = 0.1
|
215 |
+
"""
|
216 |
+
The dropout probability within the attention modules.
|
217 |
+
"""
|
218 |
+
|
219 |
+
multi_query_attention: Optional[bool] = None
|
220 |
+
"""
|
221 |
+
Use the Multi-Query formulation of attention used in PaLM. This reduces the number of parameters
|
222 |
+
and is more efficient during inference.
|
223 |
+
"""
|
224 |
+
|
225 |
+
attention_layer_norm: bool = False
|
226 |
+
"""
|
227 |
+
Apply layer norm to the keys and queries within the attention mechanism.
|
228 |
+
This can help stabilize training.
|
229 |
+
"""
|
230 |
+
|
231 |
+
residual_dropout: float = 0.1
|
232 |
+
"""
|
233 |
+
The dropout probability for the MLP and attention output within each block.
|
234 |
+
"""
|
235 |
+
|
236 |
+
embedding_dropout: float = 0.1
|
237 |
+
"""
|
238 |
+
The dropout probability for embeddings.
|
239 |
+
"""
|
240 |
+
|
241 |
+
input_emb_norm: bool = False
|
242 |
+
"""
|
243 |
+
An input hidden_states norm implementation by gemmma.
|
244 |
+
"""
|
245 |
+
|
246 |
+
layer_norm_type: LayerNormType = LayerNormType.default
|
247 |
+
"""
|
248 |
+
The layernorm implementation to use.
|
249 |
+
"""
|
250 |
+
|
251 |
+
layer_norm_with_affine: bool = True
|
252 |
+
"""
|
253 |
+
Whether to include bias and weight parameters for the layer norms.
|
254 |
+
This only affects layer norms that are immediately followed by a linear layer in the forward pass,
|
255 |
+
so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
|
256 |
+
to ``False``.
|
257 |
+
"""
|
258 |
+
|
259 |
+
rms_norm_eps: float = 1e-05
|
260 |
+
"""
|
261 |
+
The rms layernorm eps param.
|
262 |
+
"""
|
263 |
+
|
264 |
+
attention_layer_norm_with_affine: bool = True
|
265 |
+
"""
|
266 |
+
Toggle affine transform for the QK norms.
|
267 |
+
"""
|
268 |
+
|
269 |
+
max_sequence_length: int = 1024
|
270 |
+
"""
|
271 |
+
The maximum input sequence length supported by the model.
|
272 |
+
"""
|
273 |
+
|
274 |
+
rope_theta: float = 10000.0
|
275 |
+
"""
|
276 |
+
The rope base param.
|
277 |
+
"""
|
278 |
+
|
279 |
+
include_qkv_bias: Optional[bool] = False
|
280 |
+
"""
|
281 |
+
Whether or not to include bias parameters in qkv linear layers.
|
282 |
+
"""
|
283 |
+
|
284 |
+
include_bias: bool = False
|
285 |
+
"""
|
286 |
+
Whether or not to include bias parameters in linear layers.
|
287 |
+
In PaLM, they got rid of all bias terms because they found that large
|
288 |
+
models tend to have near 0 bias terms anyway.
|
289 |
+
"""
|
290 |
+
|
291 |
+
bias_for_layer_norm: Optional[bool] = None
|
292 |
+
"""
|
293 |
+
Whether or not to include bias parameters in layer norm.
|
294 |
+
This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
|
295 |
+
layer norm.
|
296 |
+
When this is None (the default), it inherits the setting from include_bias.
|
297 |
+
"""
|
298 |
+
|
299 |
+
scale_logits: bool = False
|
300 |
+
"""
|
301 |
+
If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
|
302 |
+
"""
|
303 |
+
|
304 |
+
vocab_size: int = 50257
|
305 |
+
"""
|
306 |
+
Vocabulary size of the model.
|
307 |
+
"""
|
308 |
+
|
309 |
+
embedding_size: Optional[int] = 50304
|
310 |
+
"""
|
311 |
+
The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
|
312 |
+
to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
|
313 |
+
next multiple of 128 that's greater than ``vocab_size`` can improve throughput
|
314 |
+
substantially.
|
315 |
+
"""
|
316 |
+
|
317 |
+
weight_tying: bool = True
|
318 |
+
"""
|
319 |
+
Whether to tie output linear weights to the input embedding.
|
320 |
+
"""
|
321 |
+
|
322 |
+
eos_token_id: int = 50256
|
323 |
+
"""
|
324 |
+
The ID of the end-of-sentence special token.
|
325 |
+
"""
|
326 |
+
|
327 |
+
pad_token_id: int = 50256
|
328 |
+
"""
|
329 |
+
The ID of the token to use for padding. Defaults to the ID of the EOS token.
|
330 |
+
"""
|
331 |
+
|
332 |
+
mask_token_id: Optional[int] = 50256
|
333 |
+
"""
|
334 |
+
The ID of the token to use for mask token. Defaults to the ID of the EOS token.
|
335 |
+
"""
|
336 |
+
|
337 |
+
init_device: Optional[str] = None
|
338 |
+
"""
|
339 |
+
The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
|
340 |
+
"""
|
341 |
+
|
342 |
+
init_fn: InitFnType = InitFnType.normal
|
343 |
+
"""
|
344 |
+
The weight initialization strategy.
|
345 |
+
"""
|
346 |
+
|
347 |
+
init_std: float = 0.02
|
348 |
+
"""
|
349 |
+
The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
|
350 |
+
as "normal".
|
351 |
+
"""
|
352 |
+
|
353 |
+
init_cutoff_factor: Optional[float] = None
|
354 |
+
"""
|
355 |
+
A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
|
356 |
+
as "normal". Setting this to None means values are not cutoff.
|
357 |
+
"""
|
358 |
+
|
359 |
+
precision: Optional[str] = None
|
360 |
+
"""
|
361 |
+
Precision used to train/evaluate with. You shouldn't set this directly.
|
362 |
+
See :data:`TrainConfig.precision` instead.
|
363 |
+
"""
|
364 |
+
|
365 |
+
@property
|
366 |
+
def effective_n_kv_heads(self) -> int:
|
367 |
+
if self.n_kv_heads is None:
|
368 |
+
if self.multi_query_attention is True:
|
369 |
+
return 1
|
370 |
+
else:
|
371 |
+
return self.n_heads
|
372 |
+
else:
|
373 |
+
if self.multi_query_attention is None:
|
374 |
+
return self.n_kv_heads
|
375 |
+
if self.multi_query_attention:
|
376 |
+
n_kv_heads_should_be = 1
|
377 |
+
else:
|
378 |
+
n_kv_heads_should_be = self.n_heads
|
379 |
+
if self.n_kv_heads == n_kv_heads_should_be:
|
380 |
+
return n_kv_heads_should_be
|
381 |
+
else:
|
382 |
+
raise Exception(
|
383 |
+
"You can't set `multi_query_attention` and `n_kv_heads` at the same time."
|
384 |
+
)
|
385 |
+
|
386 |
+
class ActivationCheckpointingStrategy(StrEnum):
|
387 |
+
whole_layer = "whole_layer"
|
388 |
+
"""
|
389 |
+
Checkpoint every transformer layer.
|
390 |
+
"""
|
391 |
+
|
392 |
+
one_in_two = "one_in_two"
|
393 |
+
"""
|
394 |
+
Checkpoint one in two transformer layers.
|
395 |
+
"""
|
396 |
+
|
397 |
+
one_in_three = "one_in_three"
|
398 |
+
"""
|
399 |
+
Checkpoint one in three transformer layers.
|
400 |
+
"""
|
401 |
+
|
402 |
+
one_in_four = "one_in_four"
|
403 |
+
"""
|
404 |
+
Checkpoint one in four transformer layers.
|
405 |
+
"""
|
406 |
+
|
407 |
+
two_in_three = "two_in_three"
|
408 |
+
"""
|
409 |
+
Checkpoint two out of every three transformer layers.
|
410 |
+
"""
|
411 |
+
|
412 |
+
three_in_four = "three_in_four"
|
413 |
+
"""
|
414 |
+
Checkpoint three out of four of every transformer layers.
|
415 |
+
"""
|
416 |
+
|
417 |
+
four_in_five = "four_in_five"
|
418 |
+
"""
|
419 |
+
Checkpoint four out of five of every transformer layers.
|
420 |
+
"""
|
421 |
+
|
422 |
+
nine_in_ten = "nine_in_ten"
|
423 |
+
"""
|
424 |
+
Checkpoint nine out of ten of every transformer layers.
|
425 |
+
"""
|
426 |
+
|
427 |
+
fine_grained = "fine_grained"
|
428 |
+
"""
|
429 |
+
Focus checkpointing on where it is cheap to recompute and saves most memory.
|
430 |
+
"""
|
431 |
+
|
432 |
+
|
433 |
+
class LLaDAConfig(PretrainedConfig):
|
434 |
+
model_type = "llada"
|
435 |
+
keys_to_ignore_at_inference = ["past_key_values"] # TODO: confirm
|
436 |
+
|
437 |
+
def __init__(self, use_cache: bool = False, **kwargs):
|
438 |
+
model_config = ModelConfig()
|
439 |
+
all_kwargs = model_config.__dict__
|
440 |
+
all_kwargs.update(kwargs)
|
441 |
+
all_kwargs.update({"use_cache": use_cache})
|
442 |
+
all_kwargs.update(
|
443 |
+
{
|
444 |
+
"architectures": all_kwargs.get("architectures", ["LLaDAModelLM"])
|
445 |
+
}
|
446 |
+
)
|
447 |
+
super().__init__(**all_kwargs)
|
448 |
+
|
449 |
+
@property
|
450 |
+
def num_attention_heads(self):
|
451 |
+
return self.n_heads
|
452 |
+
|
453 |
+
@property
|
454 |
+
def num_hidden_layers(self):
|
455 |
+
return self.n_layers
|
456 |
+
|
457 |
+
@property
|
458 |
+
def hidden_size(self):
|
459 |
+
return self.d_model
|
460 |
+
|
461 |
+
|
462 |
+
# Register the config class so that it is available for transformer pipelines, auto-loading etc.
|
463 |
+
AutoConfig.register("llada", LLaDAConfig)
|
models/logging.py
ADDED
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Optuna, Hugging Face
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Logging utilities."""
|
16 |
+
|
17 |
+
import logging
|
18 |
+
import os
|
19 |
+
import sys
|
20 |
+
import threading
|
21 |
+
from logging import CRITICAL # NOQA
|
22 |
+
from logging import DEBUG # NOQA
|
23 |
+
from logging import ERROR # NOQA
|
24 |
+
from logging import FATAL # NOQA
|
25 |
+
from logging import INFO # NOQA
|
26 |
+
from logging import NOTSET # NOQA
|
27 |
+
from logging import WARN # NOQA
|
28 |
+
from logging import WARNING # NOQA
|
29 |
+
from typing import Optional
|
30 |
+
|
31 |
+
from tqdm import auto as tqdm_lib
|
32 |
+
|
33 |
+
_lock = threading.Lock()
|
34 |
+
_default_handler: Optional[logging.Handler] = None
|
35 |
+
|
36 |
+
log_levels = {
|
37 |
+
"debug": logging.DEBUG,
|
38 |
+
"info": logging.INFO,
|
39 |
+
"warning": logging.WARNING,
|
40 |
+
"error": logging.ERROR,
|
41 |
+
"critical": logging.CRITICAL,
|
42 |
+
}
|
43 |
+
|
44 |
+
_default_log_level = logging.WARNING
|
45 |
+
|
46 |
+
_tqdm_active = True
|
47 |
+
|
48 |
+
|
49 |
+
def _get_default_logging_level():
|
50 |
+
"""
|
51 |
+
If muse_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
|
52 |
+
not - fall back to `_default_log_level`
|
53 |
+
"""
|
54 |
+
env_level_str = os.getenv("muse_VERBOSITY", None)
|
55 |
+
if env_level_str:
|
56 |
+
if env_level_str in log_levels:
|
57 |
+
return log_levels[env_level_str]
|
58 |
+
else:
|
59 |
+
logging.getLogger().warning(
|
60 |
+
f"Unknown option muse_VERBOSITY={env_level_str}, has to be one of: { ', '.join(log_levels.keys()) }"
|
61 |
+
)
|
62 |
+
return _default_log_level
|
63 |
+
|
64 |
+
|
65 |
+
def _get_library_name() -> str:
|
66 |
+
return __name__.split(".")[0]
|
67 |
+
|
68 |
+
|
69 |
+
def _get_library_root_logger() -> logging.Logger:
|
70 |
+
return logging.getLogger(_get_library_name())
|
71 |
+
|
72 |
+
|
73 |
+
def _configure_library_root_logger() -> None:
|
74 |
+
global _default_handler
|
75 |
+
|
76 |
+
with _lock:
|
77 |
+
if _default_handler:
|
78 |
+
# This library has already configured the library root logger.
|
79 |
+
return
|
80 |
+
_default_handler = logging.StreamHandler() # Set sys.stderr as stream.
|
81 |
+
_default_handler.flush = sys.stderr.flush
|
82 |
+
|
83 |
+
# Apply our default configuration to the library root logger.
|
84 |
+
library_root_logger = _get_library_root_logger()
|
85 |
+
library_root_logger.addHandler(_default_handler)
|
86 |
+
library_root_logger.setLevel(_get_default_logging_level())
|
87 |
+
library_root_logger.propagate = False
|
88 |
+
|
89 |
+
|
90 |
+
def _reset_library_root_logger() -> None:
|
91 |
+
global _default_handler
|
92 |
+
|
93 |
+
with _lock:
|
94 |
+
if not _default_handler:
|
95 |
+
return
|
96 |
+
|
97 |
+
library_root_logger = _get_library_root_logger()
|
98 |
+
library_root_logger.removeHandler(_default_handler)
|
99 |
+
library_root_logger.setLevel(logging.NOTSET)
|
100 |
+
_default_handler = None
|
101 |
+
|
102 |
+
|
103 |
+
def get_log_levels_dict():
|
104 |
+
return log_levels
|
105 |
+
|
106 |
+
|
107 |
+
def get_logger(name: Optional[str] = None) -> logging.Logger:
|
108 |
+
"""
|
109 |
+
Return a logger with the specified name.
|
110 |
+
|
111 |
+
This function is not supposed to be directly accessed unless you are writing a custom muse module.
|
112 |
+
"""
|
113 |
+
|
114 |
+
if name is None:
|
115 |
+
name = _get_library_name()
|
116 |
+
|
117 |
+
_configure_library_root_logger()
|
118 |
+
return logging.getLogger(name)
|
119 |
+
|
120 |
+
|
121 |
+
def get_verbosity() -> int:
|
122 |
+
"""
|
123 |
+
Return the current level for the 🤗 muse' root logger as an int.
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
`int`: The logging level.
|
127 |
+
|
128 |
+
<Tip>
|
129 |
+
|
130 |
+
🤗 muse has following logging levels:
|
131 |
+
|
132 |
+
- 50: `muse.logging.CRITICAL` or `muse.logging.FATAL`
|
133 |
+
- 40: `muse.logging.ERROR`
|
134 |
+
- 30: `muse.logging.WARNING` or `muse.logging.WARN`
|
135 |
+
- 20: `muse.logging.INFO`
|
136 |
+
- 10: `muse.logging.DEBUG`
|
137 |
+
|
138 |
+
</Tip>"""
|
139 |
+
|
140 |
+
_configure_library_root_logger()
|
141 |
+
return _get_library_root_logger().getEffectiveLevel()
|
142 |
+
|
143 |
+
|
144 |
+
def set_verbosity(verbosity: int) -> None:
|
145 |
+
"""
|
146 |
+
Set the verbosity level for the 🤗 muse' root logger.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
verbosity (`int`):
|
150 |
+
Logging level, e.g., one of:
|
151 |
+
|
152 |
+
- `muse.logging.CRITICAL` or `muse.logging.FATAL`
|
153 |
+
- `muse.logging.ERROR`
|
154 |
+
- `muse.logging.WARNING` or `muse.logging.WARN`
|
155 |
+
- `muse.logging.INFO`
|
156 |
+
- `muse.logging.DEBUG`
|
157 |
+
"""
|
158 |
+
|
159 |
+
_configure_library_root_logger()
|
160 |
+
_get_library_root_logger().setLevel(verbosity)
|
161 |
+
|
162 |
+
|
163 |
+
def set_verbosity_info():
|
164 |
+
"""Set the verbosity to the `INFO` level."""
|
165 |
+
return set_verbosity(INFO)
|
166 |
+
|
167 |
+
|
168 |
+
def set_verbosity_warning():
|
169 |
+
"""Set the verbosity to the `WARNING` level."""
|
170 |
+
return set_verbosity(WARNING)
|
171 |
+
|
172 |
+
|
173 |
+
def set_verbosity_debug():
|
174 |
+
"""Set the verbosity to the `DEBUG` level."""
|
175 |
+
return set_verbosity(DEBUG)
|
176 |
+
|
177 |
+
|
178 |
+
def set_verbosity_error():
|
179 |
+
"""Set the verbosity to the `ERROR` level."""
|
180 |
+
return set_verbosity(ERROR)
|
181 |
+
|
182 |
+
|
183 |
+
def disable_default_handler() -> None:
|
184 |
+
"""Disable the default handler of the HuggingFace muse' root logger."""
|
185 |
+
|
186 |
+
_configure_library_root_logger()
|
187 |
+
|
188 |
+
assert _default_handler is not None
|
189 |
+
_get_library_root_logger().removeHandler(_default_handler)
|
190 |
+
|
191 |
+
|
192 |
+
def enable_default_handler() -> None:
|
193 |
+
"""Enable the default handler of the HuggingFace muse' root logger."""
|
194 |
+
|
195 |
+
_configure_library_root_logger()
|
196 |
+
|
197 |
+
assert _default_handler is not None
|
198 |
+
_get_library_root_logger().addHandler(_default_handler)
|
199 |
+
|
200 |
+
|
201 |
+
def add_handler(handler: logging.Handler) -> None:
|
202 |
+
"""adds a handler to the HuggingFace muse' root logger."""
|
203 |
+
|
204 |
+
_configure_library_root_logger()
|
205 |
+
|
206 |
+
assert handler is not None
|
207 |
+
_get_library_root_logger().addHandler(handler)
|
208 |
+
|
209 |
+
|
210 |
+
def remove_handler(handler: logging.Handler) -> None:
|
211 |
+
"""removes given handler from the HuggingFace muse' root logger."""
|
212 |
+
|
213 |
+
_configure_library_root_logger()
|
214 |
+
|
215 |
+
assert handler is not None and handler not in _get_library_root_logger().handlers
|
216 |
+
_get_library_root_logger().removeHandler(handler)
|
217 |
+
|
218 |
+
|
219 |
+
def disable_propagation() -> None:
|
220 |
+
"""
|
221 |
+
Disable propagation of the library log outputs. Note that log propagation is disabled by default.
|
222 |
+
"""
|
223 |
+
|
224 |
+
_configure_library_root_logger()
|
225 |
+
_get_library_root_logger().propagate = False
|
226 |
+
|
227 |
+
|
228 |
+
def enable_propagation() -> None:
|
229 |
+
"""
|
230 |
+
Enable propagation of the library log outputs. Please disable the HuggingFace muse' default handler to prevent
|
231 |
+
double logging if the root logger has been configured.
|
232 |
+
"""
|
233 |
+
|
234 |
+
_configure_library_root_logger()
|
235 |
+
_get_library_root_logger().propagate = True
|
236 |
+
|
237 |
+
|
238 |
+
def enable_explicit_format() -> None:
|
239 |
+
"""
|
240 |
+
Enable explicit formatting for every HuggingFace muse' logger. The explicit formatter is as follows:
|
241 |
+
```
|
242 |
+
[LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
|
243 |
+
```
|
244 |
+
All handlers currently bound to the root logger are affected by this method.
|
245 |
+
"""
|
246 |
+
handlers = _get_library_root_logger().handlers
|
247 |
+
|
248 |
+
for handler in handlers:
|
249 |
+
formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
|
250 |
+
handler.setFormatter(formatter)
|
251 |
+
|
252 |
+
|
253 |
+
def reset_format() -> None:
|
254 |
+
"""
|
255 |
+
Resets the formatting for HuggingFace muse' loggers.
|
256 |
+
|
257 |
+
All handlers currently bound to the root logger are affected by this method.
|
258 |
+
"""
|
259 |
+
handlers = _get_library_root_logger().handlers
|
260 |
+
|
261 |
+
for handler in handlers:
|
262 |
+
handler.setFormatter(None)
|
263 |
+
|
264 |
+
|
265 |
+
def warning_advice(self, *args, **kwargs):
|
266 |
+
"""
|
267 |
+
This method is identical to `logger.warning()`, but if env var muse_NO_ADVISORY_WARNINGS=1 is set, this
|
268 |
+
warning will not be printed
|
269 |
+
"""
|
270 |
+
no_advisory_warnings = os.getenv("muse_NO_ADVISORY_WARNINGS", False)
|
271 |
+
if no_advisory_warnings:
|
272 |
+
return
|
273 |
+
self.warning(*args, **kwargs)
|
274 |
+
|
275 |
+
|
276 |
+
logging.Logger.warning_advice = warning_advice
|
277 |
+
|
278 |
+
|
279 |
+
class EmptyTqdm:
|
280 |
+
"""Dummy tqdm which doesn't do anything."""
|
281 |
+
|
282 |
+
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
|
283 |
+
self._iterator = args[0] if args else None
|
284 |
+
|
285 |
+
def __iter__(self):
|
286 |
+
return iter(self._iterator)
|
287 |
+
|
288 |
+
def __getattr__(self, _):
|
289 |
+
"""Return empty function."""
|
290 |
+
|
291 |
+
def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
|
292 |
+
return
|
293 |
+
|
294 |
+
return empty_fn
|
295 |
+
|
296 |
+
def __enter__(self):
|
297 |
+
return self
|
298 |
+
|
299 |
+
def __exit__(self, type_, value, traceback):
|
300 |
+
return
|
301 |
+
|
302 |
+
|
303 |
+
class _tqdm_cls:
|
304 |
+
def __call__(self, *args, **kwargs):
|
305 |
+
if _tqdm_active:
|
306 |
+
return tqdm_lib.tqdm(*args, **kwargs)
|
307 |
+
else:
|
308 |
+
return EmptyTqdm(*args, **kwargs)
|
309 |
+
|
310 |
+
def set_lock(self, *args, **kwargs):
|
311 |
+
self._lock = None
|
312 |
+
if _tqdm_active:
|
313 |
+
return tqdm_lib.tqdm.set_lock(*args, **kwargs)
|
314 |
+
|
315 |
+
def get_lock(self):
|
316 |
+
if _tqdm_active:
|
317 |
+
return tqdm_lib.tqdm.get_lock()
|
318 |
+
|
319 |
+
|
320 |
+
tqdm = _tqdm_cls()
|
321 |
+
|
322 |
+
|
323 |
+
def is_progress_bar_enabled() -> bool:
|
324 |
+
"""Return a boolean indicating whether tqdm progress bars are enabled."""
|
325 |
+
global _tqdm_active
|
326 |
+
return bool(_tqdm_active)
|
327 |
+
|
328 |
+
|
329 |
+
def enable_progress_bar():
|
330 |
+
"""Enable tqdm progress bar."""
|
331 |
+
global _tqdm_active
|
332 |
+
_tqdm_active = True
|
333 |
+
|
334 |
+
|
335 |
+
def disable_progress_bar():
|
336 |
+
"""Disable tqdm progress bar."""
|
337 |
+
global _tqdm_active
|
338 |
+
_tqdm_active = False
|
models/lr_schedulers.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""PyTorch optimization for diffusion models."""
|
16 |
+
|
17 |
+
import math
|
18 |
+
from enum import Enum
|
19 |
+
from typing import Optional, Union
|
20 |
+
|
21 |
+
from torch.optim import Optimizer
|
22 |
+
from torch.optim.lr_scheduler import LambdaLR
|
23 |
+
|
24 |
+
from .logging import get_logger
|
25 |
+
|
26 |
+
logger = get_logger(__name__)
|
27 |
+
|
28 |
+
|
29 |
+
class SchedulerType(Enum):
|
30 |
+
LINEAR = "linear"
|
31 |
+
COSINE = "cosine"
|
32 |
+
COSINE_WITH_RESTARTS = "cosine_with_restarts"
|
33 |
+
POLYNOMIAL = "polynomial"
|
34 |
+
CONSTANT = "constant"
|
35 |
+
CONSTANT_WITH_WARMUP = "constant_with_warmup"
|
36 |
+
|
37 |
+
|
38 |
+
def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
|
39 |
+
"""
|
40 |
+
Create a schedule with a constant learning rate, using the learning rate set in optimizer.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
44 |
+
The optimizer for which to schedule the learning rate.
|
45 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
46 |
+
The index of the last epoch when resuming training.
|
47 |
+
|
48 |
+
Return:
|
49 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
50 |
+
"""
|
51 |
+
return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
|
52 |
+
|
53 |
+
|
54 |
+
def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
|
55 |
+
"""
|
56 |
+
Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
|
57 |
+
increases linearly between 0 and the initial lr set in the optimizer.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
61 |
+
The optimizer for which to schedule the learning rate.
|
62 |
+
num_warmup_steps (`int`):
|
63 |
+
The number of steps for the warmup phase.
|
64 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
65 |
+
The index of the last epoch when resuming training.
|
66 |
+
|
67 |
+
Return:
|
68 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
69 |
+
"""
|
70 |
+
|
71 |
+
def lr_lambda(current_step: int):
|
72 |
+
if current_step < num_warmup_steps:
|
73 |
+
return float(current_step) / float(max(1.0, num_warmup_steps))
|
74 |
+
return 1.0
|
75 |
+
|
76 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
|
77 |
+
|
78 |
+
|
79 |
+
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
|
80 |
+
"""
|
81 |
+
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
|
82 |
+
a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
86 |
+
The optimizer for which to schedule the learning rate.
|
87 |
+
num_warmup_steps (`int`):
|
88 |
+
The number of steps for the warmup phase.
|
89 |
+
num_training_steps (`int`):
|
90 |
+
The total number of training steps.
|
91 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
92 |
+
The index of the last epoch when resuming training.
|
93 |
+
|
94 |
+
Return:
|
95 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
96 |
+
"""
|
97 |
+
|
98 |
+
def lr_lambda(current_step: int):
|
99 |
+
if current_step < num_warmup_steps:
|
100 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
101 |
+
return max(
|
102 |
+
0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
|
103 |
+
)
|
104 |
+
|
105 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
106 |
+
|
107 |
+
|
108 |
+
def get_cosine_schedule_with_warmup(
|
109 |
+
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1, min_lr_scale: float = 0.0
|
110 |
+
):
|
111 |
+
"""
|
112 |
+
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
113 |
+
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
|
114 |
+
initial lr set in the optimizer.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
118 |
+
The optimizer for which to schedule the learning rate.
|
119 |
+
num_warmup_steps (`int`):
|
120 |
+
The number of steps for the warmup phase.
|
121 |
+
num_training_steps (`int`):
|
122 |
+
The total number of training steps.
|
123 |
+
num_periods (`float`, *optional*, defaults to 0.5):
|
124 |
+
The number of periods of the cosine function in a schedule (the default is to just decrease from the max
|
125 |
+
value to 0 following a half-cosine).
|
126 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
127 |
+
The index of the last epoch when resuming training.
|
128 |
+
|
129 |
+
Return:
|
130 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
131 |
+
"""
|
132 |
+
|
133 |
+
# def lr_lambda(current_step):
|
134 |
+
# if current_step < num_warmup_steps:
|
135 |
+
# return float(current_step) / float(max(1, num_warmup_steps))
|
136 |
+
# progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
137 |
+
# return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
|
138 |
+
|
139 |
+
# return LambdaLR(optimizer, lr_lambda, last_epoch)
|
140 |
+
|
141 |
+
def lr_lambda(current_step):
|
142 |
+
if current_step < num_warmup_steps:
|
143 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
144 |
+
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
145 |
+
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * 2.0 * num_cycles * progress))
|
146 |
+
return min_lr_scale + (1.0 - min_lr_scale) * cosine_decay
|
147 |
+
|
148 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
149 |
+
|
150 |
+
|
151 |
+
def get_cosine_with_hard_restarts_schedule_with_warmup(
|
152 |
+
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
|
153 |
+
):
|
154 |
+
"""
|
155 |
+
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
156 |
+
initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
|
157 |
+
linearly between 0 and the initial lr set in the optimizer.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
161 |
+
The optimizer for which to schedule the learning rate.
|
162 |
+
num_warmup_steps (`int`):
|
163 |
+
The number of steps for the warmup phase.
|
164 |
+
num_training_steps (`int`):
|
165 |
+
The total number of training steps.
|
166 |
+
num_cycles (`int`, *optional*, defaults to 1):
|
167 |
+
The number of hard restarts to use.
|
168 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
169 |
+
The index of the last epoch when resuming training.
|
170 |
+
|
171 |
+
Return:
|
172 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
173 |
+
"""
|
174 |
+
|
175 |
+
def lr_lambda(current_step):
|
176 |
+
if current_step < num_warmup_steps:
|
177 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
178 |
+
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
179 |
+
if progress >= 1.0:
|
180 |
+
return 0.0
|
181 |
+
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
|
182 |
+
|
183 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
184 |
+
|
185 |
+
|
186 |
+
def get_polynomial_decay_schedule_with_warmup(
|
187 |
+
optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
|
188 |
+
):
|
189 |
+
"""
|
190 |
+
Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
|
191 |
+
optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
|
192 |
+
initial lr set in the optimizer.
|
193 |
+
|
194 |
+
Args:
|
195 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
196 |
+
The optimizer for which to schedule the learning rate.
|
197 |
+
num_warmup_steps (`int`):
|
198 |
+
The number of steps for the warmup phase.
|
199 |
+
num_training_steps (`int`):
|
200 |
+
The total number of training steps.
|
201 |
+
lr_end (`float`, *optional*, defaults to 1e-7):
|
202 |
+
The end LR.
|
203 |
+
power (`float`, *optional*, defaults to 1.0):
|
204 |
+
Power factor.
|
205 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
206 |
+
The index of the last epoch when resuming training.
|
207 |
+
|
208 |
+
Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
|
209 |
+
implementation at
|
210 |
+
https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
|
211 |
+
|
212 |
+
Return:
|
213 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
214 |
+
|
215 |
+
"""
|
216 |
+
|
217 |
+
lr_init = optimizer.defaults["lr"]
|
218 |
+
if not (lr_init > lr_end):
|
219 |
+
raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")
|
220 |
+
|
221 |
+
def lr_lambda(current_step: int):
|
222 |
+
if current_step < num_warmup_steps:
|
223 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
224 |
+
elif current_step > num_training_steps:
|
225 |
+
return lr_end / lr_init # as LambdaLR multiplies by lr_init
|
226 |
+
else:
|
227 |
+
lr_range = lr_init - lr_end
|
228 |
+
decay_steps = num_training_steps - num_warmup_steps
|
229 |
+
pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
|
230 |
+
decay = lr_range * pct_remaining**power + lr_end
|
231 |
+
return decay / lr_init # as LambdaLR multiplies by lr_init
|
232 |
+
|
233 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
234 |
+
|
235 |
+
|
236 |
+
TYPE_TO_SCHEDULER_FUNCTION = {
|
237 |
+
SchedulerType.LINEAR: get_linear_schedule_with_warmup,
|
238 |
+
SchedulerType.COSINE: get_cosine_schedule_with_warmup,
|
239 |
+
SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
|
240 |
+
SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
|
241 |
+
SchedulerType.CONSTANT: get_constant_schedule,
|
242 |
+
SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
|
243 |
+
}
|
244 |
+
|
245 |
+
|
246 |
+
def get_scheduler(
|
247 |
+
name: Union[str, SchedulerType],
|
248 |
+
optimizer: Optimizer,
|
249 |
+
num_warmup_steps: Optional[int] = None,
|
250 |
+
num_training_steps: Optional[int] = None,
|
251 |
+
num_cycles: int = 1,
|
252 |
+
power: float = 1.0,
|
253 |
+
min_lr_scale: float = 0.0
|
254 |
+
):
|
255 |
+
"""
|
256 |
+
Unified API to get any scheduler from its name.
|
257 |
+
|
258 |
+
Args:
|
259 |
+
name (`str` or `SchedulerType`):
|
260 |
+
The name of the scheduler to use.
|
261 |
+
optimizer (`torch.optim.Optimizer`):
|
262 |
+
The optimizer that will be used during training.
|
263 |
+
num_warmup_steps (`int`, *optional*):
|
264 |
+
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
265 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
266 |
+
num_training_steps (`int``, *optional*):
|
267 |
+
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
268 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
269 |
+
num_cycles (`int`, *optional*):
|
270 |
+
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
271 |
+
power (`float`, *optional*, defaults to 1.0):
|
272 |
+
Power factor. See `POLYNOMIAL` scheduler
|
273 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
274 |
+
The index of the last epoch when resuming training.
|
275 |
+
"""
|
276 |
+
name = SchedulerType(name)
|
277 |
+
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
278 |
+
if name == SchedulerType.CONSTANT:
|
279 |
+
return schedule_func(optimizer)
|
280 |
+
|
281 |
+
# All other schedulers require `num_warmup_steps`
|
282 |
+
if num_warmup_steps is None:
|
283 |
+
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
284 |
+
|
285 |
+
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
286 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
287 |
+
|
288 |
+
# All other schedulers require `num_training_steps`
|
289 |
+
if num_training_steps is None:
|
290 |
+
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
291 |
+
|
292 |
+
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
293 |
+
return schedule_func(
|
294 |
+
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles, min_lr_scale=min_lr_scale
|
295 |
+
)
|
296 |
+
|
297 |
+
if name == SchedulerType.POLYNOMIAL:
|
298 |
+
return schedule_func(
|
299 |
+
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
300 |
+
)
|
301 |
+
|
302 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
models/misc.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from omegaconf import OmegaConf
|
2 |
+
import torch
|
3 |
+
from typing import (
|
4 |
+
Any,
|
5 |
+
Callable,
|
6 |
+
Dict,
|
7 |
+
Iterable,
|
8 |
+
List,
|
9 |
+
NamedTuple,
|
10 |
+
NewType,
|
11 |
+
Optional,
|
12 |
+
Sized,
|
13 |
+
Tuple,
|
14 |
+
Type,
|
15 |
+
TypeVar,
|
16 |
+
Union,
|
17 |
+
)
|
18 |
+
try:
|
19 |
+
from typing import Literal
|
20 |
+
except ImportError:
|
21 |
+
from typing_extensions import Literal
|
22 |
+
|
23 |
+
# Tensor dtype
|
24 |
+
# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
|
25 |
+
from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
|
26 |
+
|
27 |
+
# Config type
|
28 |
+
from omegaconf import DictConfig
|
29 |
+
|
30 |
+
# PyTorch Tensor type
|
31 |
+
from torch import Tensor
|
32 |
+
|
33 |
+
# Runtime type checking decorator
|
34 |
+
from typeguard import typechecked as typechecker
|
35 |
+
|
36 |
+
|
37 |
+
def broadcast(tensor, src=0):
|
38 |
+
if not _distributed_available():
|
39 |
+
return tensor
|
40 |
+
else:
|
41 |
+
torch.distributed.broadcast(tensor, src=src)
|
42 |
+
return tensor
|
43 |
+
|
44 |
+
def _distributed_available():
|
45 |
+
return torch.distributed.is_available() and torch.distributed.is_initialized()
|
46 |
+
|
47 |
+
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
|
48 |
+
# added by Xavier -- delete '--local-rank' in multi-nodes training, don't know why there is such a keyword
|
49 |
+
if '--local-rank' in cfg:
|
50 |
+
del cfg['--local-rank']
|
51 |
+
# added by Xavier -- delete '--local-rank' in multi-nodes training, don't know why there is such a keyword
|
52 |
+
scfg = OmegaConf.structured(fields(**cfg))
|
53 |
+
return scfg
|
models/modeling_llada.py
ADDED
@@ -0,0 +1,1500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations

import logging
import math
import sys
from abc import abstractmethod
from collections import defaultdict
from functools import partial
from typing import (
    Callable,
    Dict,
    Iterable,
    List,
    NamedTuple,
    Optional,
    Sequence,
    Set,
    Tuple,
    cast,
)
from dataclasses import fields
from typing import List, Optional, Tuple, Union

import torch
import torch.backends.cuda
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto import AutoModel
from transformers.cache_utils import Cache

from .configuration_llada import (
    LLaDAConfig,
    StrEnum,
    InitFnType,
    ActivationType,
    BlockType,
    LayerNormType,
    ModelConfig,
    ActivationCheckpointingStrategy,
)

if sys.version_info.minor > 8:
    from collections.abc import MutableMapping
elif sys.version_info.minor == 8:
    from typing import MutableMapping
else:
    raise SystemExit("This script supports Python 3.8 or higher")

__all__ = [
    "LayerNormBase",
    "LayerNorm",
    "RMSLayerNorm",
    "GemmaRMSLayerNorm",
    "RotaryEmbedding",
    "Activation",
    "GELU",
    "ReLU",
    "SwiGLU",
    "LLaDABlock",
    "LLaDASequentialBlock",
    "LLaDAModel",
    "LLaDAOutput",
    "LLaDAGenerateOutput",
]


log = logging.getLogger(__name__)


class ModuleType(StrEnum):
    in_module = "in"
    out_module = "out"
    emb = "emb"
    final_out = "final_out"


def init_weights(
    config: ModelConfig,
    module: Union[nn.Linear, nn.Embedding],
    d: Optional[int] = None,
    layer_id: Optional[int] = None,
    std_factor: float = 1.0,
    type_of_module: Optional[ModuleType] = None,
) -> None:
    """
    Initialize weights of a linear or embedding module.

    :param config: The model config.
    :param module: The linear or embedding submodule to initialize.
    :param d: The effective input dimensionality of the weights. This could be smaller than the actual dimensions
        for fused layers.
    :param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by
        ``1 / sqrt(2 * (layer_id + 1))``.
    """
    d = d if d is not None else config.d_model
    if config.init_fn == InitFnType.normal:
        std = config.init_std * std_factor
        if config.init_cutoff_factor is not None:
            cutoff_value = config.init_cutoff_factor * std
            nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
        else:
            nn.init.normal_(module.weight, mean=0.0, std=std)
    elif config.init_fn == InitFnType.mitchell:
        std = std_factor / math.sqrt(d)
        if layer_id is not None:
            std = std / math.sqrt(2 * (layer_id + 1))
        nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
    elif config.init_fn == InitFnType.kaiming_normal:
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
    elif config.init_fn == InitFnType.fan_in:
        std = std_factor / math.sqrt(d)
        nn.init.normal_(module.weight, mean=0.0, std=std)
    elif config.init_fn == InitFnType.full_megatron:
        if type_of_module is None:
            raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")

        cutoff_factor = config.init_cutoff_factor
        if cutoff_factor is None:
            cutoff_factor = 3

        if type_of_module == ModuleType.in_module:
            # for att_proj (same as QKV), ff_proj
            std = config.init_std
        elif type_of_module == ModuleType.out_module:
            # for attn_out, ff_out
            std = config.init_std / math.sqrt(2.0 * config.n_layers)
        elif type_of_module == ModuleType.emb:
            # positional embeddings (wpe)
            # token embeddings (wte)
            std = config.init_std
        elif type_of_module == ModuleType.final_out:
            # final output (ff_out)
            std = config.d_model**-0.5
        else:
            raise RuntimeError(f"Unknown module type '{type_of_module}'")
        nn.init.trunc_normal_(
            module.weight,
            mean=0.0,
            std=std,
            a=-cutoff_factor * std,
            b=cutoff_factor * std,
        )
    else:
        raise NotImplementedError(config.init_fn)

    if isinstance(module, nn.Linear):
        if module.bias is not None:
            nn.init.zeros_(module.bias)

        if config.init_fn == InitFnType.normal and getattr(module, "_is_residual", False):
            with torch.no_grad():
                module.weight.div_(math.sqrt(2 * config.n_layers))

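# Worked example of the schemes above (illustrative numbers, not from any particular config):
# with the "mitchell" scheme and d = 4096, std = 1 / sqrt(4096) = 0.015625 for an input
# projection; an output projection in layer_id = 11 is further scaled by 1 / sqrt(2 * 12),
# giving std ~= 0.00319, and samples are truncated to +/- 3 standard deviations.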

def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
    """
    Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
    is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
    """
    if check_neg_inf:
        x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
    if check_pos_inf:
        x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)

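# Example: for a float32 attention bias, ensure_finite_ replaces -inf entries with
# torch.finfo(torch.float32).min (about -3.4e38), so that adding masks together downstream
# cannot overflow to -inf and make scaled-dot-product attention emit NaNs.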

def activation_checkpoint_function(cfg: ModelConfig):
    preserve_rng_state = (
        (cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and (cfg.residual_dropout == 0.0)
    )
    from torch.utils.checkpoint import checkpoint

    return partial(
        checkpoint,
        preserve_rng_state=preserve_rng_state,
        use_reentrant=False,
    )


class BufferCache(dict, MutableMapping[str, torch.Tensor]):
    """
    Cache for attention biases and other things that would normally be stored as buffers.
    We avoid using buffers because we've run into various issues doing so with FSDP.
    In general it appears the way FSDP handles buffers is not well-defined.
    It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid
    since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into
    NaNs when they're synchronized due to casting or some other issue.
    """


def _non_meta_init_device(config: ModelConfig) -> torch.device:
    if config.init_device is not None and config.init_device != "meta":
        return torch.device(config.init_device)
    else:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Dropout(nn.Dropout):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.p == 0.0:
            return input
        else:
            return F.dropout(input, self.p, self.training, self.inplace)


class LayerNormBase(nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        *,
        size: Optional[int] = None,
        elementwise_affine: Optional[bool] = True,
        eps: float = 1e-05,
    ):
        super().__init__()
        self.config = config
        self.eps = eps
        self.normalized_shape = (size or config.d_model,)
        if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
            self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
            use_bias = self.config.bias_for_layer_norm
            if use_bias is None:
                use_bias = self.config.include_bias
            if use_bias:
                self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
            else:
                self.register_parameter("bias", None)
        else:
            self.register_parameter("bias", None)
            self.register_parameter("weight", None)

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @classmethod
    def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> LayerNormBase:
        if config.layer_norm_type == LayerNormType.default:
            return LayerNorm(config, size=size, low_precision=False, **kwargs)
        elif config.layer_norm_type == LayerNormType.low_precision:
            return LayerNorm(config, size=size, low_precision=True, **kwargs)
        elif config.layer_norm_type == LayerNormType.rms:
            return RMSLayerNorm(config, size=size, **kwargs)
        elif config.layer_norm_type == LayerNormType.gemma_rms:
            return GemmaRMSLayerNorm(config, size=size, **kwargs)
        else:
            raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")

    def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
        # `is_autocast_cpu_enabled()` for CPU autocast.
        # See https://github.com/pytorch/pytorch/issues/110966.
        if tensor.device.type == "cuda" and torch.is_autocast_enabled():
            return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_gpu_dtype())
        elif tensor.device.type == "cpu" and torch.is_autocast_cpu_enabled():
            return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype())
        else:
            return tensor

    def reset_parameters(self):
        if self.weight is not None:
            torch.nn.init.ones_(self.weight)  # type: ignore
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)  # type: ignore


class LayerNorm(LayerNormBase):
    """
    The default :class:`LayerNorm` implementation which can optionally run in low precision.
    """

    def __init__(
        self,
        config: ModelConfig,
        size: Optional[int] = None,
        low_precision: bool = False,
        elementwise_affine: Optional[bool] = None,
        eps: float = 1e-05,
    ):
        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
        self.low_precision = low_precision

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.low_precision:
            module_device = x.device
            downcast_x = self._cast_if_autocast_enabled(x)
            downcast_weight = (
                self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
            )
            downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
            with torch.autocast(enabled=False, device_type=module_device.type):
                return F.layer_norm(
                    downcast_x, self.normalized_shape, weight=downcast_weight, bias=downcast_bias, eps=self.eps
                )
        else:
            return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)


class RMSLayerNorm(LayerNormBase):
    """
    RMS layer norm, a simplified :class:`LayerNorm` implementation
    """

    def __init__(
        self,
        config: ModelConfig,
        size: Optional[int] = None,
        elementwise_affine: Optional[bool] = None,
        eps: float = 1e-5,
    ):
        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.autocast(enabled=False, device_type=x.device.type):
            og_dtype = x.dtype
            x = x.to(torch.float32)
            variance = x.pow(2).mean(-1, keepdim=True)
            x = x * torch.rsqrt(variance + self.eps)
            x = x.to(og_dtype)

        if self.weight is not None:
            if self.bias is not None:
                return self.weight * x + self.bias
            else:
                return self.weight * x
        else:
            return x

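# RMSLayerNorm computes, in float32, y = x / sqrt(mean(x^2) + eps) and then rescales by the
# learned weight (and bias, if configured). Unlike standard LayerNorm it does not subtract the
# mean, which is what makes it cheaper; GemmaRMSLayerNorm below differs only in scaling by
# (1 + weight) instead of weight.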

class GemmaRMSLayerNorm(LayerNormBase):
    """
    Gemma RMS layer norm, a simplified :class:`LayerNorm` implementation
    """

    def __init__(
        self,
        config: ModelConfig,
        size: Optional[int] = None,
        elementwise_affine: Optional[bool] = None,
        eps: float = 1e-5,
    ):
        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.autocast(enabled=False, device_type=x.device.type):
            og_dtype = x.dtype
            x = x.to(torch.float32)
            variance = x.pow(2).mean(-1, keepdim=True)
            x = x * torch.rsqrt(variance + self.eps)
            x = x.to(og_dtype)

        if self.weight is not None:
            if self.bias is not None:
                return x * (1 + self.weight) + self.bias
            else:
                return x * (1 + self.weight)
        else:
            return x


class RotaryEmbedding(nn.Module):
    """
    [Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).
    """

    def __init__(self, config: ModelConfig, cache: BufferCache):
        super().__init__()
        self.config = config
        self.__cache = cache
        # Warm up cache.
        self.rope_theta = config.rope_theta
        self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config))

    def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        if (
            (pos_sin := self.__cache.get("rope_pos_sin")) is not None
            and (pos_cos := self.__cache.get("rope_pos_cos")) is not None
            and pos_sin.shape[-2] >= seq_len
            and pos_cos.shape[-2] >= seq_len
        ):
            if pos_sin.device != device:
                pos_sin = pos_sin.to(device)
                self.__cache["rope_pos_sin"] = pos_sin
            if pos_cos.device != device:
                pos_cos = pos_cos.to(device)
                self.__cache["rope_pos_cos"] = pos_cos
            return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]

        with torch.autocast(device.type, enabled=False):
            dim = self.config.d_model // self.config.n_heads
            inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
            seq = torch.arange(seq_len, device=device, dtype=torch.float)
            freqs = einsum("i , j -> i j", seq, inv_freq)
            positions = torch.cat((freqs, freqs), dim=-1)
            pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
        self.__cache["rope_pos_sin"] = pos_sin
        self.__cache["rope_pos_cos"] = pos_cos
        return pos_sin, pos_cos

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        B, nh, T, hs = x.size()
        x = x.view(B, nh, T, 2, hs // 2)
        x1, x2 = x.unbind(dim=-2)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.config.rope_full_precision:
            q_, k_ = q.float(), k.float()
        else:
            q_, k_ = q, k

        with torch.autocast(q.device.type, enabled=False):
            query_len, key_len = q_.shape[-2], k_.shape[-2]  # could be different if layer_past not None
            pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
            pos_sin = pos_sin.type_as(q_)
            pos_cos = pos_cos.type_as(q_)
            q_ = self.apply_rotary_pos_emb(
                pos_sin[:, :, key_len - query_len : key_len, :],
                pos_cos[:, :, key_len - query_len : key_len, :],
                q_,
            )
            k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
        return q_.type_as(q), k_.type_as(k)

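# Sketch of the rotary-embedding math implemented above: for head dimension hs, the
# frequencies are inv_freq[j] = rope_theta ** (-2j / hs) for j = 0 .. hs/2 - 1, the angle at
# position p is p * inv_freq[j], and apply_rotary_pos_emb rotates channel pairs by that angle
# via t * cos + rotate_half(t) * sin. Note that rotate_half splits the head dimension into two
# contiguous halves (GPT-NeoX style) rather than interleaving even/odd channels.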

class Activation(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @property
    @abstractmethod
    def output_multiplier(self) -> float:
        raise NotImplementedError

    @classmethod
    def build(cls, config: ModelConfig) -> Activation:
        if config.activation_type == ActivationType.gelu:
            return cast(Activation, GELU(approximate="none"))
        elif config.activation_type == ActivationType.relu:
            return cast(Activation, ReLU(inplace=False))
        elif config.activation_type == ActivationType.silu:
            return cast(Activation, SiLU(inplace=False))
        elif config.activation_type == ActivationType.swiglu:
            return SwiGLU(config)
        else:
            raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")


class GELU(nn.GELU):
    @property
    def output_multiplier(self) -> float:
        return 1.0


class ReLU(nn.ReLU):
    @property
    def output_multiplier(self) -> float:
        return 1.0


class SiLU(nn.SiLU):
    @property
    def output_multiplier(self) -> float:
        return 1.0


class SwiGLU(Activation):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

    @property
    def output_multiplier(self) -> float:
        return 0.5

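# SwiGLU halves the width of its input: chunk() splits the ff_proj output into a value half
# and a gate half, and the result silu(gate) * value has hidden_size / 2 features. That is why
# output_multiplier is 0.5; LLaDABlock uses it to size the input dimension of ff_out.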

def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
    att_bias = torch.triu(
        torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
        diagonal=1,
    )
    att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
    return att_bias.view(1, 1, seq_len, seq_len)  # type: ignore


def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
    if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len:
        if causal_bias.device != device:
            causal_bias = causal_bias.to(device)
            cache["causal_attention_bias"] = causal_bias
        return causal_bias
    with torch.autocast(device.type, enabled=False):
        causal_bias = causal_attention_bias(seq_len, device)
    cache["causal_attention_bias"] = causal_bias
    return causal_bias

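# Example: causal_attention_bias(3, device) returns a (1, 1, 3, 3) float tensor whose strictly
# upper-triangular entries are the dtype minimum and whose lower triangle (including the
# diagonal) is 0, so adding it to attention scores blocks attention to future positions.
# get_causal_attention_bias memoizes this tensor in the BufferCache under the key
# "causal_attention_bias" and reuses it whenever the cached bias is at least seq_len wide.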

def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)

    # shape: (1, 1, seq_len, seq_len)
    alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
    alibi_bias.abs_().mul_(-1)

    # shape: (n_heads,)
    m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
    m.mul_(config.alibi_bias_max / config.n_heads)

    # shape: (1, n_heads, seq_len, seq_len)
    return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1)))  # type: ignore

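# ALiBi worked example: with n_heads = 8 and alibi_bias_max = 8, head h (1-indexed) gets
# slope 1 / 2**h, and the bias added to the score of query position i attending to key
# position j is -|i - j| / 2**h, so distant keys are penalized linearly with distance and
# each head penalizes at a different rate.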

class LLaDABlock(nn.Module):
    """
    A base class for transformer block implementations.
    """

    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
        super().__init__()
        self.layer_id = layer_id
        self.config = config
        self.hidden_size = (
            config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
        )
        self.__cache = cache
        assert config.d_model % config.n_heads == 0

        self._activation_checkpoint_fn = None

        # Dropout.
        self.dropout = Dropout(config.residual_dropout)

        # Layer norms.
        self.k_norm: Optional[LayerNormBase] = None
        self.q_norm: Optional[LayerNormBase] = None
        if config.attention_layer_norm:
            self.k_norm = LayerNormBase.build(
                config,
                size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
                elementwise_affine=config.attention_layer_norm_with_affine,
            )
            self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)

        # Activation function.
        self.act = Activation.build(config)
        assert (self.act.output_multiplier * self.hidden_size) % 1 == 0

        # Attention output projection.
        self.attn_out = nn.Linear(
            config.d_model, config.d_model, bias=config.include_bias, device=config.init_device
        )

        # Feed-forward output projection.
        self.ff_out = nn.Linear(
            int(self.act.output_multiplier * self.hidden_size),
            config.d_model,
            bias=config.include_bias,
            device=config.init_device,
        )
        self.ff_out._is_residual = True  # type: ignore

        # Rotary embeddings.
        if self.config.rope:
            self.rotary_emb = RotaryEmbedding(config, self.__cache)

        self.flash_attn_func = None
        if config.flash_attention:
            try:
                from flash_attn import flash_attn_func  # type: ignore

                self.flash_attn_func = flash_attn_func
            except ModuleNotFoundError:
                pass

    def reset_parameters(self):
        if self.k_norm is not None:
            self.k_norm.reset_parameters()
        if self.q_norm is not None:
            self.q_norm.reset_parameters()
        init_weights(
            self.config,
            self.attn_out,
            d=self.config.d_model,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )
        init_weights(
            self.config,
            self.ff_out,
            d=self.ff_out.in_features,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
        if strategy == ActivationCheckpointingStrategy.fine_grained:
            self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
        else:
            self._activation_checkpoint_fn = None

    @classmethod
    def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
        target_dtype = input_dtype
        # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
        # `is_autocast_cpu_enabled()` for CPU autocast.
        # See https://github.com/pytorch/pytorch/issues/110966.
        if bias.device.type == "cuda" and torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
            target_dtype = torch.get_autocast_cpu_dtype()
        if bias.dtype != target_dtype:
            bias = bias.to(target_dtype)
            ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
        return bias

    def _scaled_dot_product_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """
        Computes scaled dot product attention on query, key and value tensors, using an optional
        attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.
        """
        if self.flash_attn_func is not None and attn_mask is None:
            r = self.flash_attn_func(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=False
            )
            return r.transpose(1, 2)
        else:
            # torch's sdpa doesn't support GQA, so we're doing this
            assert k.size(1) == v.size(1)
            num_kv_heads = k.size(1)
            num_q_heads = q.size(1)
            if num_q_heads != num_kv_heads:
                assert num_q_heads % num_kv_heads == 0
                k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
                v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)

            # Modify: MDM set causal to False, and with no attn_mask.
            return F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=False,
            )

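    # Grouped-query attention note: when there are more query heads than KV heads (say 32
    # query heads and 8 KV heads), each key/value head is repeated 32 // 8 = 4 times along the
    # head dimension with repeat_interleave before calling F.scaled_dot_product_attention, so
    # the head counts match. The flash-attn path is taken only when no attention mask is given,
    # and both paths run with causal=False, i.e. full bidirectional attention for the MDM setting.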
    def attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_bias: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        B, T, C = q.size()  # batch size, sequence length, d_model
        dtype = k.dtype

        # Optionally apply layer norm to keys and queries.
        if self.q_norm is not None and self.k_norm is not None:
            q = self.q_norm(q).to(dtype=dtype)
            k = self.k_norm(k).to(dtype=dtype)

        # Move head forward to be next to the batch dim.
        # shape: (B, nh, T, hs)
        q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
        # shape: (B, n_kv_h, T, hs)
        k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
        # shape: (B, n_kv_h, T, hs)
        v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)

        if layer_past is not None:
            past_key, past_value = layer_past
            k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)

        present = (k, v) if use_cache else None
        query_len, key_len = q.shape[-2], k.shape[-2]  # could be different if layer_past not None

        if self.config.rope:
            # Apply rotary embeddings.
            q, k = self.rotary_emb(q, k)

        if attention_bias is not None:
            # Resize and cast attention bias.
            # The current dtype of the attention bias might not match the dtype that the SDP attn function will
            # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding
            # as down-casting the attention bias to the autocast precision will result in -infs, which will
            # cause the SDP attn function to produce NaNs.
            attention_bias = self._cast_attn_bias(
                attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype
            )

        # Get the attention scores.
        # shape: (B, nh, T, hs)
        att = self._scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=0.0 if not self.training else self.config.attention_dropout,
            is_causal=False,
        )

        # Re-assemble all head outputs side-by-side.
        att = att.transpose(1, 2).contiguous().view(B, T, C)

        # Apply output projection.
        return self.attn_out(att), present

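    # Note on the attention path above: queries, keys, and values are reshaped to
    # (B, n_heads, T, head_dim), RoPE is applied when config.rope is set, and the scores are
    # computed without a causal mask, so every token attends to every other token. The
    # (k, v) "present" tuple is returned only when use_cache is True.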
    @abstractmethod
    def forward(
        self,
        x: torch.Tensor,
        attention_bias: Optional[torch.FloatTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        raise NotImplementedError

    @classmethod
    def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> LLaDABlock:
        if config.block_type == BlockType.sequential:
            return LLaDASequentialBlock(layer_id, config, cache)
        elif config.block_type == BlockType.llama:
            return LLaDALlamaBlock(layer_id, config, cache)
        else:
            raise NotImplementedError(f"Unknown block type: '{config.block_type}'")


class LLaDASequentialBlock(LLaDABlock):
    """
    This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).
    """

    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
        super().__init__(layer_id, config, cache)
        # Layer norms.
        self.attn_norm = LayerNorm.build(config)
        self.ff_norm = LayerNorm.build(config)
        # Attention input projection. Projects x -> (q, k, v)
        head_dim = config.d_model // config.n_heads
        self.fused_dims = (
            config.d_model,
            config.effective_n_kv_heads * head_dim,
            config.effective_n_kv_heads * head_dim,
        )
        self.att_proj = nn.Linear(
            config.d_model, sum(self.fused_dims), bias=config.include_bias | config.include_qkv_bias, device=config.init_device
        )
        # Feed-forward input projection.
        self.ff_proj = nn.Linear(
            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
        )

    def reset_parameters(self):
        super().reset_parameters()
        self.attn_norm.reset_parameters()
        self.ff_norm.reset_parameters()
        # NOTE: the standard deviation for these weights does not depend on the layer.
        init_weights(
            self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
        )
        init_weights(
            self.config, self.ff_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_bias: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        # Get query, key, value projections.
        # shape:
        #  - for regular attn q, k, v: (batch_size, seq_len, d_model)
        #  - for multi-query attn q: (batch_size, seq_len, d_model)
        #                      k, v: (batch_size, seq_len, d_model // n_heads)
        #  - for group query attn q: (batch_size, seq_len, d_model)
        #                      k, v: (batch_size, seq_len, d_model // n_kv_heads)
        if self._activation_checkpoint_fn is not None:
            q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split(
                self.fused_dims, dim=-1
            )
        else:
            q, k, v = self.att_proj(self.attn_norm(x)).split(self.fused_dims, dim=-1)

        # Get attention scores.
        if self._activation_checkpoint_fn is not None:
            att, cache = self._activation_checkpoint_fn(  # type: ignore
                self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
            )
        else:
            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)

        # Add attention scores.
        # shape: (B, T, C)
        x = x + self.dropout(att)

        # Add feed-forward projection.
        # shape: (batch_size, seq_len, d_model)
        og_x = x
        if self._activation_checkpoint_fn is not None:
            x = self._activation_checkpoint_fn(self.ff_norm, x)  # type: ignore
        else:
            x = self.ff_norm(x)
        x = self.ff_proj(x)
        if self._activation_checkpoint_fn is not None:
            x = self._activation_checkpoint_fn(self.act, x)  # type: ignore
        else:
            x = self.act(x)
        x = self.ff_out(x)
        x = self.dropout(x)
        x = og_x + x

        return x, cache

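# In the sequential block, att_proj is a single fused projection whose output is split by
# self.fused_dims into (q, k, v). As an illustrative example, with d_model = 4096, n_heads = 32
# and effective_n_kv_heads = 32, head_dim is 128, fused_dims is (4096, 4096, 4096), and
# att_proj maps 4096 -> 12288 features before the split.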

class LLaDALlamaBlock(LLaDABlock):
    """
    This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection). This block is similar to `LLaDASequentialBlock`
    but some operations have slightly different implementations to imitate the
    behavior of Llama.
    """

    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
        super().__init__(layer_id, config, cache)
        # Layer norms.
        self.attn_norm = LayerNorm.build(config)
        self.ff_norm = LayerNorm.build(config)
        self.__cache = cache

        # Attention input projection. Projects x -> (q, k, v)
        head_dim = config.d_model // config.n_heads
        q_proj_out_dim = config.d_model
        k_proj_out_dim = config.effective_n_kv_heads * head_dim
        v_proj_out_dim = config.effective_n_kv_heads * head_dim
        self.q_proj = nn.Linear(
            config.d_model, q_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
        )
        self.k_proj = nn.Linear(
            config.d_model, k_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
        )
        self.v_proj = nn.Linear(
            config.d_model, v_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
        )

        # Feed-forward input projection.
        self.ff_proj = nn.Linear(
            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
        )
        # new add
        self.up_proj = nn.Linear(
            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
        )

    def reset_parameters(self):
        super().reset_parameters()
        self.attn_norm.reset_parameters()
        self.ff_norm.reset_parameters()
        # NOTE: the standard deviation for these weights does not depend on the layer.
        init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None)
        init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None)
        init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None)
        init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None)
        init_weights(self.config, self.up_proj, d=self.config.d_model, layer_id=None)  # new add

    def forward(
        self,
        x: torch.Tensor,
        attention_bias: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        # Get query, key, value projections.
        # shape:
        #  - for regular attn q, k, v: (batch_size, seq_len, d_model)
        #  - for multi-query attn q: (batch_size, seq_len, d_model)
        #                      k, v: (batch_size, seq_len, d_model // n_heads)
        #  - for group query attn q: (batch_size, seq_len, d_model)
        #                      k, v: (batch_size, seq_len, d_model // n_kv_heads)
        x_normed = self.attn_norm(x)
        q = self.q_proj(x_normed)
        k = self.k_proj(x_normed)
        v = self.v_proj(x_normed)

        # Get attention scores.
        if self._activation_checkpoint_fn is not None:
            att, cache = self._activation_checkpoint_fn(  # type: ignore
                self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
            )
        else:
            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)

        # Add attention scores.
        # shape: (B, T, C)
        x = x + self.dropout(att)

        # Add feed-forward projection.
        # shape: (batch_size, seq_len, d_model)
        og_x = x
        if self._activation_checkpoint_fn is not None:
            x = self._activation_checkpoint_fn(self.ff_norm, x)  # type: ignore
        else:
            x = self.ff_norm(x)
        x, x_up = self.ff_proj(x), self.up_proj(x)  # new add
        if self._activation_checkpoint_fn is not None:
            x = self._activation_checkpoint_fn(self.act, x)  # type: ignore
        else:
            x = self.act(x)
        x = x * x_up  # new add
        x = self.ff_out(x)
        x = self.dropout(x)
        x = og_x + x

        return x, cache

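# The Llama-style block uses a gated MLP: ff_proj and up_proj both map d_model to hidden_size,
# the activation is applied to the ff_proj branch only, and the two branches are multiplied
# elementwise before ff_out projects back to d_model. With a SiLU activation this is the usual
# SwiGLU feed-forward, expressed with two separate projections instead of one fused, chunked one.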

class LLaDAOutput(NamedTuple):
    logits: torch.FloatTensor
    """
    A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities
    for the next token *before* normalization via (log) softmax.
    """

    attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
    """
    Attention keys and values from each block.
    """

    hidden_states: Optional[Tuple[torch.Tensor]]
    """
    Hidden states from each block.
    """


class LLaDAGenerateOutput(NamedTuple):
    token_ids: torch.LongTensor
    """
    The generated token IDs, a tensor of shape `(batch_size, beam_size, max_steps)`.
    These do *not* include the original input IDs.
    """

    scores: torch.FloatTensor
    """
    The scores of the generated sequences, a tensor of shape `(batch_size, beam_size)`.
    """


class LLaDABlockGroup(nn.ModuleList):
    def __init__(self, config: ModelConfig, layer_offset: int, modules: Optional[Iterable[nn.Module]] = None):
        super().__init__(modules)
        self.config = config
        self.layer_offset = layer_offset
        self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
        self._activation_checkpoint_fn = activation_checkpoint_function(self.config)

    def forward(
        self,
        x: torch.Tensor,
        attention_bias: Optional[torch.FloatTensor] = None,
        layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
        for block_idx, block in enumerate(self):
            layer_past = None if layers_past is None else layers_past[block_idx]
            block_idx += self.layer_offset
            if (
                (self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
                or (
                    self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
                    and block_idx % 2 == 0
                )
                or (
                    self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
                    and block_idx % 3 == 0
                )
                or (
                    self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
                    and block_idx % 4 == 0
                )
            ):
                # shape: (batch_size, seq_len, d_model)
                x, cache = self._activation_checkpoint_fn(  # type: ignore
                    block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
                )
            else:
                # shape: (batch_size, seq_len, d_model)
                x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
            if attn_key_values is not None:
                assert cache is not None
                attn_key_values.append(cache)
        return x, attn_key_values

    def reset_parameters(self):
        for block in self:
            block.reset_parameters()

    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
        self.activation_checkpointing_strategy = strategy
        for block in self:
            block.set_activation_checkpointing(strategy)


class LLaDAModel(nn.Module):
    def __init__(self, config: ModelConfig, init_params: bool = True):
        super().__init__()
        self.config = config
        self.__cache = BufferCache()

        # Validate config.
        if self.config.alibi and self.config.flash_attention:
            raise Exception("ALiBi is currently not supported with FlashAttention")

        if self.config.alibi and self.config.rope:
            raise Exception("ALiBi and RoPE are mutually exclusive")

        if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size:
            if self.config.embedding_size < self.config.vocab_size:
                raise Exception("embedding size should be at least as big as vocab size")
            elif self.config.embedding_size % 128 != 0:
                import warnings

                warnings.warn(
                    "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
                )

        self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
        self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)

        if not (
            0 < self.config.block_group_size <= self.config.n_layers
            and self.config.n_layers % self.config.block_group_size == 0
        ):
            raise Exception("n layers must be divisible by block group size")

        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(False)  # this is super slow so make sure torch won't use it

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(
                    config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
                ),
                emb_drop=Dropout(config.embedding_dropout),
                ln_f=LayerNorm.build(config),
            )
        )

        blocks = [LLaDABlock.build(i, config, self.__cache) for i in range(config.n_layers)]
        if self.config.block_group_size > 1:
            block_groups = [
                LLaDABlockGroup(config, i, blocks[i : i + config.block_group_size])
                for i in range(0, config.n_layers, config.block_group_size)
            ]
            self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
        else:
            self.transformer.update({"blocks": nn.ModuleList(blocks)})

        if not (self.config.alibi or self.config.rope):
            self.transformer.update(
                {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
            )
        if not config.weight_tying:
            self.transformer.update(
                {
                    "ff_out": nn.Linear(
                        config.d_model,
                        config.embedding_size or config.vocab_size,
                        bias=config.include_bias,
                        device=config.init_device,
                    )
                }
            )
        # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
        if init_params and self.config.init_device != "meta":
            self.reset_parameters()
        self.__num_fwd_flops: Optional[int] = None

        # Warm up cache.
        if self.config.alibi:
            get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
            self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))

    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
        self.activation_checkpointing_strategy = strategy
        if self.config.block_group_size != 1:
            for block_group in self.transformer.block_groups:
                block_group.set_activation_checkpointing(strategy)
        else:
            for block in self.transformer.blocks:
                block.set_activation_checkpointing(strategy)

    @property
    def device(self) -> torch.device:
        device: torch.device = self.transformer.wte.weight.device  # type: ignore
        if device.type == "meta":
            return _non_meta_init_device(self.config)
        else:
            return device

    def reset_parameters(self):
        log.info("Initializing model parameters...")
        # Top-level embeddings / linear layers.
        init_weights(
            self.config,
            self.transformer.wte,  # type: ignore
            std_factor=(0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0,
            type_of_module=ModuleType.emb,
        )
        if hasattr(self.transformer, "wpe"):
            init_weights(self.config, self.transformer.wpe, type_of_module=ModuleType.emb)  # type: ignore

        # Top-level layer norm.
        self.transformer.ln_f.reset_parameters()  # type: ignore

        # Output weights.
        if hasattr(self.transformer, "ff_out"):
            init_weights(self.config, self.transformer.ff_out, type_of_module=ModuleType.final_out)  # type: ignore

        # Let the blocks handle themselves.
        if self.config.block_group_size == 1:
            for block in self.transformer.blocks:
                block.reset_parameters()
        else:
            for block_group in self.transformer.block_groups:
                block_group.reset_parameters()

    def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
        if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[
            -1
        ] >= seq_len:
            if alibi_bias.device != device:
                alibi_bias = alibi_bias.to(device)
                self.__cache["alibi_attention_bias"] = alibi_bias
            return alibi_bias
        with torch.autocast(device.type, enabled=False):
            alibi_bias = alibi_attention_bias(seq_len, self.config, device)
        self.__cache["alibi_attention_bias"] = alibi_bias
        return alibi_bias

    def forward(
        self,
        input_ids: torch.LongTensor,
        input_embeddings: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_bias: Optional[torch.Tensor] = None,
        past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
        last_logits_only: bool = False,
        output_hidden_states: Optional[bool] = None,
    ) -> LLaDAOutput:
        """
        :param input_ids: A tensor of shape `(batch_size, seq_len)`.
        :param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
            embeddings. When provided, it is treated as the output of the input embedding layer.
        :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
            which input IDs are masked. A `1` value in the mask means that
            the corresponding input ID should *not* be ignored. A `0` means
            that the corresponding input ID is masked.

            This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
            library.
        :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
            `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
            to introduce causal or other biases.

            If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
            indicates that the i-th element in the sequence is allowed to attend to the j-th
            element in the sequence.

            If the tensor is a float tensor, it will just be added to the attention
            scores before the softmax.

            The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
        :param past_key_values: Pre-computed keys and values for each attention block.
            Can be used to speed up sequential decoding. The `input_ids` which have
            their past given to this model should not be passed as `input_ids` as they have already been computed.
        :param use_cache: If `True`, return key and value tensors for each block.
        :param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
            This can speed up decoding when you only care about the next token.
        """
        # Add Basic MDM Model config check
        assert not self.config.alibi, "ALiBi length extrapolation is not supported for MDM."
        assert self.config.rope, "RoPE must be used in the Llama-style encoder for MDM."
        assert (past_key_values is None and not use_cache), "The KV cache is not supported for MDM."

        output_hidden_states = output_hidden_states if output_hidden_states is not None else False

        if past_key_values:
            assert len(past_key_values) == self.config.n_layers

        batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
        if past_key_values is None:
            past_length = 0
        else:
            past_length = past_key_values[0][0].size(-2)

        # Get embeddings of input.
        # shape: (batch_size, seq_len, d_model)
        # print(f"input_ids: {input_ids}, input_ids.shape: {input_ids.shape}")
        # print(f"transformer wte weight shape: {self.transformer.wte.weight.shape}")
        x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings  # type: ignore

        # print(f"xshape: {x.shape}")

        if self.config.input_emb_norm:
            x = x * (self.config.d_model**0.5)

        if not (self.config.alibi or self.config.rope):
            # Get positional embeddings.
            # shape: (1, seq_len)
            pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
            # shape: (1, seq_len, d_model)
            pos_emb = self.transformer.wpe(pos)  # type: ignore
            x = pos_emb + x

        # Add input + positional embeddings and apply dropout.
        # shape: (batch_size, seq_len, d_model)
        x = self.transformer.emb_drop(x)  # type: ignore

        # Transform the attention mask into what the blocks expect.
        if attention_mask is not None and 0.0 in attention_mask:
            # shape: (batch_size, 1, 1, seq_len)
            attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
            attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
        else:
            attention_mask = None

        # Merge attention mask with attention bias.
        if (
            attention_bias is not None
            or attention_mask is not None
            or self.config.alibi
            # NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
            # with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
            # scores correctly.
            or past_key_values is not None
        ):
            if attention_bias is None and self.config.alibi:
                # print(f"get_causal_attention_bias")
                attention_bias = get_causal_attention_bias(
                    self.__cache, past_length + seq_len, x.device
                ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
            elif attention_bias is None:
                # print(f"get_causal_attention_bias")
                attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
            elif attention_bias.dtype in (torch.int8, torch.bool):
                # print(f"attention_bias.dtype in (torch.int8, torch.bool)")
                attention_bias = attention_bias.to(dtype=torch.float)
                attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)

            # Transform to the right shape and data type.
            mask_len = seq_len
            if attention_mask is not None:
                mask_len = attention_mask.shape[-1]
            elif past_key_values is not None:
                mask_len = past_key_values[0][0].shape[-2] + seq_len
            attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)

            # Add in the masking bias.
            if attention_mask is not None:
                attention_bias = attention_bias + attention_mask
                # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
                # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
                # it can produce NaNs.
                ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)

        attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None

        # decoder layers
        all_hidden_states = []

        # Apply blocks one-by-one.
        if self.config.block_group_size == 1:
            for block_idx, block in enumerate(self.transformer.blocks):
                if output_hidden_states:
                    # add hidden states
                    all_hidden_states.append(x)

                layer_past = None if past_key_values is None else past_key_values[block_idx]
                if (
                    (self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
                    or (
                        self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
                        and block_idx % 2 == 0
                    )
                    or (
                        self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
                        and block_idx % 3 == 0
                    )
                    or (
                        self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
                        and block_idx % 4 == 0
                    )
                ):
                    # shape: (batch_size, seq_len, d_model)
                    x, cache = self._activation_checkpoint_fn(
                        block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
                    )
                else:
                    # shape: (batch_size, seq_len, d_model)
                    x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
                if attn_key_values is not None:
                    assert cache is not None
                    attn_key_values.append(cache)
        else:
            for group_idx, block_group in enumerate(self.transformer.block_groups):
                if output_hidden_states:
                    # add hidden states
                    all_hidden_states.append(x)

                layers_past = (
                    None
                    if past_key_values is None
                    else past_key_values[
                        group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
                    ]
                )
                x, cache = block_group(
                    x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache
                )
                if attn_key_values is not None:
                    assert cache is not None
                    attn_key_values.extend(cache)

        if last_logits_only:
            # shape: (batch_size, 1, d_model)
            x = x[:, -1, :].unsqueeze(1)

        # Apply final layer norm.
        # shape: (batch_size, seq_len or 1, d_model)
        x = self.transformer.ln_f(x)  # type: ignore
        if output_hidden_states:
            # add final hidden state post-final-layernorm, following HuggingFace's convention
            all_hidden_states.append(x)

        # Get logits.
        # shape: (batch_size, seq_len or 1, vocab_size)
        if self.config.weight_tying:
            logits = F.linear(x, self.transformer.wte.weight, None)  # type: ignore
        else:
            logits = self.transformer.ff_out(x)  # type: ignore
        if self.config.scale_logits:
            logits.mul_(1 / math.sqrt(self.config.d_model))

        return LLaDAOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None)  # type: ignore[arg-type]

1368 |
+
|
1369 |
+
def create_model_config_from_pretrained_config(config: LLaDAConfig):
|
1370 |
+
"""
|
1371 |
+
Build a ModelConfig by copying the matching fields from a pretrained LLaDAConfig.
|
1372 |
+
"""
|
1373 |
+
|
1374 |
+
kwargs = {}
|
1375 |
+
for field in fields(ModelConfig):
|
1376 |
+
kwargs[field.name] = getattr(config, field.name)
|
1377 |
+
|
1378 |
+
model_config = ModelConfig(**kwargs)
|
1379 |
+
return model_config
|
1380 |
+
|
1381 |
+
|
1382 |
+
class LLaDAModelLM(PreTrainedModel):
|
1383 |
+
"""
|
1384 |
+
Extremely barebones HF model wrapper.
|
1385 |
+
"""
|
1386 |
+
|
1387 |
+
config_class = LLaDAConfig
|
1388 |
+
base_model_prefix = "model"
|
1389 |
+
_no_split_modules = ["LLaDABlock", "LLaDASequentialBlock", "LLaDALlamaBlock"]
|
1390 |
+
|
1391 |
+
def __init__(self, config: LLaDAConfig, model: Optional[LLaDAModel] = None, init_params: bool = False):
|
1392 |
+
super().__init__(config)
|
1393 |
+
|
1394 |
+
if not model:
|
1395 |
+
model_config = create_model_config_from_pretrained_config(config)
|
1396 |
+
# Initialize model (always on CPU to start with so we don't run out of GPU memory).
|
1397 |
+
model_config.init_device = "cpu"
|
1398 |
+
self.model = LLaDAModel(model_config, init_params=init_params)
|
1399 |
+
else:
|
1400 |
+
self.model = model
|
1401 |
+
|
1402 |
+
def forward(
|
1403 |
+
self,
|
1404 |
+
input_ids: torch.LongTensor = None,
|
1405 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1406 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1407 |
+
attention_bias: Optional[torch.Tensor] = None,
|
1408 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1409 |
+
labels: Optional[torch.LongTensor] = None,
|
1410 |
+
use_cache: Optional[bool] = None,
|
1411 |
+
output_attentions: Optional[bool] = None,
|
1412 |
+
output_hidden_states: Optional[bool] = None,
|
1413 |
+
return_dict: Optional[bool] = None,
|
1414 |
+
cache_position: Optional[Cache] = None, # This is a hack mitigation of an issue in transformers `4.39.x`
|
1415 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
1416 |
+
if use_cache is None:
|
1417 |
+
use_cache = self.config.use_cache
|
1418 |
+
|
1419 |
+
if output_attentions:
|
1420 |
+
raise ValueError("output_attentions is not yet supported in LLaDA")
|
1421 |
+
|
1422 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1423 |
+
|
1424 |
+
# decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
|
1425 |
+
outputs = self.model.forward(
|
1426 |
+
input_ids=input_ids,
|
1427 |
+
input_embeddings=inputs_embeds,
|
1428 |
+
attention_mask=attention_mask,
|
1429 |
+
attention_bias=attention_bias,
|
1430 |
+
past_key_values=None,
|
1431 |
+
use_cache=False,
|
1432 |
+
output_hidden_states=output_hidden_states,
|
1433 |
+
)
|
1434 |
+
|
1435 |
+
logits = outputs.logits
|
1436 |
+
hidden_states = outputs.hidden_states
|
1437 |
+
|
1438 |
+
loss = None
|
1439 |
+
if labels is not None:
|
1440 |
+
import warnings
|
1441 |
+
warnings.warn("Note that for LLaDA, you cannot calculate the loss here.", UserWarning)
|
1442 |
+
if not return_dict:
|
1443 |
+
output = (logits,) + outputs[1:]
|
1444 |
+
return (loss,) + output if loss is not None else output
|
1445 |
+
|
1446 |
+
return CausalLMOutputWithPast(
|
1447 |
+
logits=logits,
|
1448 |
+
past_key_values=outputs.attn_key_values,
|
1449 |
+
hidden_states=hidden_states,
|
1450 |
+
)
|
1451 |
+
|
1452 |
+
def can_generate(self) -> bool:
|
1453 |
+
return True
|
1454 |
+
|
1455 |
+
def prepare_inputs_for_generation(
|
1456 |
+
self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
|
1457 |
+
):
|
1458 |
+
if past_key_values:
|
1459 |
+
# This is because we want the model to only process the last generated token.
|
1460 |
+
input_ids = input_ids[:, -1:]
|
1461 |
+
model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
|
1462 |
+
|
1463 |
+
model_inputs.update(kwargs)
|
1464 |
+
model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
|
1465 |
+
return model_inputs
|
1466 |
+
|
1467 |
+
# TODO: these are required to make the implementation complete.
|
1468 |
+
# def resize_position_embeddings(self, new_num_position_embeddings: int):
|
1469 |
+
# pass
|
1470 |
+
#
|
1471 |
+
# def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
|
1472 |
+
# pass
|
1473 |
+
#
|
1474 |
+
# def _reorder_cache(self, past_key_values, beam_idx):
|
1475 |
+
# pass
|
1476 |
+
|
1477 |
+
def get_input_embeddings(self) -> torch.nn.Module:
|
1478 |
+
return self.model.transformer.wte
|
1479 |
+
|
1480 |
+
def set_input_embeddings(self, value: torch.nn.Module):
|
1481 |
+
self.model.transformer.wte = value
|
1482 |
+
|
1483 |
+
def get_output_embeddings(self):
|
1484 |
+
if self.config.weight_tying:
|
1485 |
+
return self.model.transformer.wte
|
1486 |
+
else:
|
1487 |
+
return self.model.transformer.ff_out
|
1488 |
+
|
1489 |
+
def set_output_embeddings(self, value: torch.nn.Module):
|
1490 |
+
if self.config.weight_tying:
|
1491 |
+
self.model.transformer.wte = value
|
1492 |
+
else:
|
1493 |
+
self.model.transformer.ff_out = value
|
1494 |
+
|
1495 |
+
def tie_weights(self):
|
1496 |
+
if self.config.weight_tying:
|
1497 |
+
self.model.transformer.ff_out = self.model.transformer.wte
|
1498 |
+
|
1499 |
+
# Register the model so that it is available for transformer pipelines, auto-loading, etc.
|
1500 |
+
AutoModel.register(LLaDAConfig, LLaDAModelLM)
|
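Since `LLaDAModelLM` is registered with `AutoModel` above, a checkpoint whose config resolves to `LLaDAConfig` can be loaded through the Auto API once this module has been imported. A minimal sketch, assuming a hypothetical checkpoint path (hub checkpoints may additionally need `trust_remote_code=True`):

```python
import torch
from transformers import AutoModel

# The path is hypothetical; it stands in for any directory saved by LLaDAModelLM.
model = AutoModel.from_pretrained("path/to/llada-checkpoint").eval()
with torch.no_grad():
    out = model(input_ids=torch.randint(0, 1000, (1, 16)))
print(out.logits.shape)  # (batch, seq_len, vocab_size)
```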
models/modeling_magvitv2.py
ADDED
@@ -0,0 +1,440 @@
1 |
+
from dataclasses import dataclass, field
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from .common_modules import *
|
6 |
+
from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
|
7 |
+
from .misc import *
|
8 |
+
import math
|
9 |
+
|
10 |
+
class Updateable:
|
11 |
+
def do_update_step(
|
12 |
+
self, epoch: int, global_step: int, on_load_weights: bool = False
|
13 |
+
):
|
14 |
+
for attr in self.__dir__():
|
15 |
+
if attr.startswith("_"):
|
16 |
+
continue
|
17 |
+
try:
|
18 |
+
module = getattr(self, attr)
|
19 |
+
except Exception:
|
20 |
+
continue  # ignore attributes such as properties, which may not be retrievable via getattr
|
21 |
+
if isinstance(module, Updateable):
|
22 |
+
module.do_update_step(
|
23 |
+
epoch, global_step, on_load_weights=on_load_weights
|
24 |
+
)
|
25 |
+
self.update_step(epoch, global_step, on_load_weights=on_load_weights)
|
26 |
+
|
27 |
+
def do_update_step_end(self, epoch: int, global_step: int):
|
28 |
+
for attr in self.__dir__():
|
29 |
+
if attr.startswith("_"):
|
30 |
+
continue
|
31 |
+
try:
|
32 |
+
module = getattr(self, attr)
|
33 |
+
except Exception:
|
34 |
+
continue  # ignore attributes such as properties, which may not be retrievable via getattr
|
35 |
+
if isinstance(module, Updateable):
|
36 |
+
module.do_update_step_end(epoch, global_step)
|
37 |
+
self.update_step_end(epoch, global_step)
|
38 |
+
|
39 |
+
def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
|
40 |
+
# override this method to implement custom update logic
|
41 |
+
# If on_load_weights is True, be careful with anything that evaluates the model,
|
42 |
+
# as the models and tensors are not guaranteed to be on the same device
|
43 |
+
pass
|
44 |
+
|
45 |
+
def update_step_end(self, epoch: int, global_step: int):
|
46 |
+
pass
|
47 |
+
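`Updateable` is a small mixin: `do_update_step` walks an object's attributes, recurses into any child that is also `Updateable`, and then calls the object's own `update_step`. A minimal sketch (the classes below are made up purely for illustration):

```python
class WarmupScale(Updateable):
    def __init__(self):
        self.scale = 0.0

    def update_step(self, epoch, global_step, on_load_weights=False):
        # simple linear warm-up, purely illustrative
        self.scale = min(1.0, global_step / 1000)


class Pipeline(Updateable):
    def __init__(self):
        self.noise = WarmupScale()  # child mixin instance, found via __dir__()


pipe = Pipeline()
pipe.do_update_step(epoch=0, global_step=500)
print(pipe.noise.scale)  # 0.5
```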
|
48 |
+
class VQGANEncoder(ModelMixin, ConfigMixin):
|
49 |
+
@dataclass
|
50 |
+
class Config:
|
51 |
+
ch: int = 128
|
52 |
+
ch_mult: List[int] = field(default_factory=lambda: [1, 2, 2, 4, 4])
|
53 |
+
num_res_blocks: List[int] = field(default_factory=lambda: [4, 3, 4, 3, 4])
|
54 |
+
attn_resolutions: List[int] = field(default_factory=lambda: [5])
|
55 |
+
dropout: float = 0.0
|
56 |
+
in_ch: int = 3
|
57 |
+
out_ch: int = 3
|
58 |
+
resolution: int = 256
|
59 |
+
z_channels: int = 13
|
60 |
+
double_z: bool = False
|
61 |
+
|
62 |
+
def __init__(self,
|
63 |
+
ch: int = 128,
|
64 |
+
ch_mult: List[int] = [1, 2, 2, 4, 4],
|
65 |
+
num_res_blocks: List[int] = [4, 3, 4, 3, 4],
|
66 |
+
attn_resolutions: List[int] = [5],
|
67 |
+
dropout: float = 0.0,
|
68 |
+
in_ch: int = 3,
|
69 |
+
out_ch: int = 3,
|
70 |
+
resolution: int = 256,
|
71 |
+
z_channels: int = 13,
|
72 |
+
double_z: bool = False):
|
73 |
+
super().__init__()
|
74 |
+
self.ch = ch
|
75 |
+
self.temb_ch = 0
|
76 |
+
self.num_resolutions = len(ch_mult)
|
77 |
+
self.num_res_blocks = num_res_blocks
|
78 |
+
self.resolution = resolution
|
79 |
+
self.in_ch = in_ch
|
80 |
+
# downsampling
|
81 |
+
self.conv_in = torch.nn.Conv2d(
|
82 |
+
self.in_ch, self.ch, kernel_size=3, stride=1, padding=1
|
83 |
+
)
|
84 |
+
|
85 |
+
curr_res = self.resolution
|
86 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
87 |
+
self.down = nn.ModuleList()
|
88 |
+
for i_level in range(self.num_resolutions):
|
89 |
+
block = nn.ModuleList()
|
90 |
+
attn = nn.ModuleList()
|
91 |
+
block_in = self.ch * in_ch_mult[i_level]
|
92 |
+
block_out = self.ch * ch_mult[i_level]
|
93 |
+
for i_block in range(self.num_res_blocks[i_level]):
|
94 |
+
block.append(
|
95 |
+
ResnetBlock(
|
96 |
+
in_channels=block_in,
|
97 |
+
out_channels=block_out,
|
98 |
+
temb_channels=self.temb_ch,
|
99 |
+
dropout=dropout,
|
100 |
+
)
|
101 |
+
)
|
102 |
+
block_in = block_out
|
103 |
+
if curr_res in attn_resolutions:
|
104 |
+
attn.append(AttnBlock(block_in))
|
105 |
+
down = nn.Module()
|
106 |
+
down.block = block
|
107 |
+
down.attn = attn
|
108 |
+
if i_level != self.num_resolutions - 1:
|
109 |
+
down.downsample = Downsample(block_in, True)
|
110 |
+
curr_res = curr_res // 2
|
111 |
+
self.down.append(down)
|
112 |
+
|
113 |
+
# middle
|
114 |
+
self.mid = nn.Module()
|
115 |
+
self.mid.block_1 = ResnetBlock(
|
116 |
+
in_channels=block_in,
|
117 |
+
out_channels=block_in,
|
118 |
+
temb_channels=self.temb_ch,
|
119 |
+
dropout=dropout,
|
120 |
+
)
|
121 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
122 |
+
self.mid.block_2 = ResnetBlock(
|
123 |
+
in_channels=block_in,
|
124 |
+
out_channels=block_in,
|
125 |
+
temb_channels=self.temb_ch,
|
126 |
+
dropout=dropout,
|
127 |
+
)
|
128 |
+
|
129 |
+
|
130 |
+
self.norm_out = Normalize(block_in)
|
131 |
+
self.conv_out = torch.nn.Conv2d(
|
132 |
+
block_in,
|
133 |
+
2 * z_channels if double_z else z_channels,
|
134 |
+
kernel_size=3,
|
135 |
+
stride=1,
|
136 |
+
padding=1,
|
137 |
+
)
|
138 |
+
|
139 |
+
self.quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
|
140 |
+
# for param in self.parameters():
|
141 |
+
# broadcast(param, src=0)
|
142 |
+
|
143 |
+
def forward(self, x):
|
144 |
+
# timestep embedding
|
145 |
+
temb = None
|
146 |
+
|
147 |
+
# downsampling
|
148 |
+
hs = [self.conv_in(x)]
|
149 |
+
for i_level in range(self.num_resolutions):
|
150 |
+
for i_block in range(self.num_res_blocks[i_level]):
|
151 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
152 |
+
if len(self.down[i_level].attn) > 0:
|
153 |
+
h = self.down[i_level].attn[i_block](h)
|
154 |
+
hs.append(h)
|
155 |
+
if i_level != self.num_resolutions - 1:
|
156 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
157 |
+
|
158 |
+
# middle
|
159 |
+
h = hs[-1]
|
160 |
+
h = self.mid.block_1(h, temb)
|
161 |
+
h = self.mid.attn_1(h)
|
162 |
+
h = self.mid.block_2(h, temb)
|
163 |
+
|
164 |
+
# end
|
165 |
+
h = self.norm_out(h)
|
166 |
+
h = nonlinearity(h)
|
167 |
+
h = self.conv_out(h)
|
168 |
+
h = self.quant_conv(h)
|
169 |
+
return h
|
170 |
+
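With the default configuration (`ch_mult` has five levels, so four downsampling stages), a 256x256 RGB input is reduced to a 13-channel 16x16 latent grid. A quick shape check with randomly initialized weights:

```python
import torch

encoder = VQGANEncoder()
with torch.no_grad():
    z = encoder(torch.randn(1, 3, 256, 256))  # (B, in_ch, resolution, resolution)
print(z.shape)  # torch.Size([1, 13, 16, 16])
```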
|
171 |
+
|
172 |
+
class LFQuantizer(nn.Module):
|
173 |
+
def __init__(self, num_codebook_entry: int = -1,
|
174 |
+
codebook_dim: int = 13,
|
175 |
+
beta: float = 0.25,
|
176 |
+
entropy_multiplier: float = 0.1,
|
177 |
+
commit_loss_multiplier: float = 0.1, ):
|
178 |
+
super().__init__()
|
179 |
+
self.codebook_size = 2 ** codebook_dim
|
180 |
+
print(
|
181 |
+
f"Look-up free quantizer with codebook size: {self.codebook_size}"
|
182 |
+
)
|
183 |
+
self.e_dim = codebook_dim
|
184 |
+
self.beta = beta
|
185 |
+
|
186 |
+
indices = torch.arange(self.codebook_size)
|
187 |
+
|
188 |
+
binary = (
|
189 |
+
indices.unsqueeze(1)
|
190 |
+
>> torch.arange(codebook_dim - 1, -1, -1, dtype=torch.long)
|
191 |
+
) & 1
|
192 |
+
|
193 |
+
embedding = binary.float() * 2 - 1
|
194 |
+
self.register_buffer("embedding", embedding)
|
195 |
+
self.register_buffer(
|
196 |
+
"power_vals", 2 ** torch.arange(codebook_dim - 1, -1, -1)
|
197 |
+
)
|
198 |
+
self.commit_loss_multiplier = commit_loss_multiplier
|
199 |
+
self.entropy_multiplier = entropy_multiplier
|
200 |
+
|
201 |
+
def get_indices(self, z_q):
|
202 |
+
return (
|
203 |
+
(self.power_vals.reshape(1, -1, 1, 1) * (z_q > 0).float())
|
204 |
+
.sum(1, keepdim=True)
|
205 |
+
.long()
|
206 |
+
)
|
207 |
+
|
208 |
+
def get_codebook_entry(self, indices, shape=None):
|
209 |
+
if shape is None:
|
210 |
+
h, w = int(math.sqrt(indices.shape[-1])), int(math.sqrt(indices.shape[-1]))
|
211 |
+
else:
|
212 |
+
h, w = shape
|
213 |
+
b, _ = indices.shape
|
214 |
+
indices = indices.reshape(-1)
|
215 |
+
z_q = self.embedding[indices]
|
216 |
+
z_q = z_q.view(b, h, w, -1)
|
217 |
+
|
218 |
+
# reshape back to match original input shape
|
219 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
220 |
+
|
221 |
+
return z_q
|
222 |
+
|
223 |
+
def forward(self, z, get_code=False):
|
224 |
+
"""
|
225 |
+
Takes the encoder output z and binarizes each channel to +1/-1; the discrete
|
226 |
+
code index is the binary reading of the resulting sign pattern (look-up free quantization).
|
227 |
+
z (continuous) -> z_q (discrete)
|
228 |
+
z.shape = (batch, channel, height, width)
|
229 |
+
quantization pipeline:
|
230 |
+
1. get encoder input (B,C,H,W)
|
231 |
+
2. flatten input to (B*H*W,C)
|
232 |
+
"""
|
233 |
+
if get_code:
|
234 |
+
return self.get_codebook_entry(z)
|
235 |
+
|
236 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
237 |
+
z = z.permute(0, 2, 3, 1).contiguous()
|
238 |
+
z_flattened = z.view(-1, self.e_dim)
|
239 |
+
ge_zero = (z_flattened > 0).float()
|
240 |
+
ones = torch.ones_like(z_flattened)
|
241 |
+
z_q = ones * ge_zero + -ones * (1 - ge_zero)
|
242 |
+
|
243 |
+
# preserve gradients
|
244 |
+
z_q = z_flattened + (z_q - z_flattened).detach()
|
245 |
+
|
246 |
+
# compute entropy loss
|
247 |
+
CatDist = torch.distributions.categorical.Categorical
|
248 |
+
logit = torch.stack(
|
249 |
+
[
|
250 |
+
-(z_flattened - torch.ones_like(z_q)).pow(2),
|
251 |
+
-(z_flattened - torch.ones_like(z_q) * -1).pow(2),
|
252 |
+
],
|
253 |
+
dim=-1,
|
254 |
+
)
|
255 |
+
cat_dist = CatDist(logits=logit)
|
256 |
+
entropy = cat_dist.entropy().mean()
|
257 |
+
mean_prob = cat_dist.probs.mean(0)
|
258 |
+
mean_entropy = CatDist(probs=mean_prob).entropy().mean()
|
259 |
+
|
260 |
+
# compute loss for embedding
|
261 |
+
commit_loss = torch.mean(
|
262 |
+
(z_q.detach() - z_flattened) ** 2
|
263 |
+
) + self.beta * torch.mean((z_q - z_flattened.detach()) ** 2)
|
264 |
+
|
265 |
+
# reshape back to match original input shape
|
266 |
+
z_q = z_q.view(z.shape)
|
267 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
268 |
+
|
269 |
+
return {
|
270 |
+
"z": z_q,
|
271 |
+
"quantizer_loss": commit_loss * self.commit_loss_multiplier,
|
272 |
+
"entropy_loss": (entropy - mean_entropy) * self.entropy_multiplier,
|
273 |
+
"indices": self.get_indices(z_q),
|
274 |
+
}
|
275 |
+
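The quantizer is look-up free: each latent channel is binarized to +1/-1 and the code index is the binary reading of the sign pattern, so indices and codebook entries round-trip exactly. A small sketch with a 4-bit codebook:

```python
import torch

quant = LFQuantizer(codebook_dim=4)          # 2**4 = 16 binary codes
z = torch.randn(1, 4, 8, 8)                  # (B, C=codebook_dim, H, W)

out = quant(z)
codes = quant.get_codebook_entry(out["indices"].reshape(1, -1))  # back to +/-1 codes
assert torch.equal(quant.get_indices(codes), out["indices"])     # exact round trip
```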
|
276 |
+
|
277 |
+
class VQGANDecoder(ModelMixin, ConfigMixin):
|
278 |
+
def __init__(self, ch: int = 128,
|
279 |
+
ch_mult: List[int] = [1, 1, 2, 2, 4],
|
280 |
+
num_res_blocks: List[int] = [4, 4, 3, 4, 3],
|
281 |
+
attn_resolutions: List[int] = [5],
|
282 |
+
dropout: float = 0.0,
|
283 |
+
in_ch: int = 3,
|
284 |
+
out_ch: int = 3,
|
285 |
+
resolution: int = 256,
|
286 |
+
z_channels: int = 13,
|
287 |
+
double_z: bool = False):
|
288 |
+
super().__init__()
|
289 |
+
self.ch = ch
|
290 |
+
self.temb_ch = 0
|
291 |
+
self.num_resolutions = len(ch_mult)
|
292 |
+
self.num_res_blocks = num_res_blocks
|
293 |
+
self.resolution = resolution
|
294 |
+
self.in_ch = in_ch
|
295 |
+
self.give_pre_end = False
|
296 |
+
|
297 |
+
self.z_channels = z_channels
|
298 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
299 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
300 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
301 |
+
curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
|
302 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
303 |
+
print(
|
304 |
+
"Working with z of shape {} = {} dimensions.".format(
|
305 |
+
self.z_shape, np.prod(self.z_shape)
|
306 |
+
)
|
307 |
+
)
|
308 |
+
|
309 |
+
# z to block_in
|
310 |
+
self.conv_in = torch.nn.Conv2d(
|
311 |
+
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
312 |
+
)
|
313 |
+
|
314 |
+
# middle
|
315 |
+
self.mid = nn.Module()
|
316 |
+
self.mid.block_1 = ResnetBlock(
|
317 |
+
in_channels=block_in,
|
318 |
+
out_channels=block_in,
|
319 |
+
temb_channels=self.temb_ch,
|
320 |
+
dropout=dropout,
|
321 |
+
)
|
322 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
323 |
+
self.mid.block_2 = ResnetBlock(
|
324 |
+
in_channels=block_in,
|
325 |
+
out_channels=block_in,
|
326 |
+
temb_channels=self.temb_ch,
|
327 |
+
dropout=dropout,
|
328 |
+
)
|
329 |
+
|
330 |
+
# upsampling
|
331 |
+
self.up = nn.ModuleList()
|
332 |
+
for i_level in reversed(range(self.num_resolutions)):
|
333 |
+
block = nn.ModuleList()
|
334 |
+
attn = nn.ModuleList()
|
335 |
+
block_out = ch * ch_mult[i_level]
|
336 |
+
for i_block in range(self.num_res_blocks[i_level]):
|
337 |
+
block.append(
|
338 |
+
ResnetBlock(
|
339 |
+
in_channels=block_in,
|
340 |
+
out_channels=block_out,
|
341 |
+
temb_channels=self.temb_ch,
|
342 |
+
dropout=dropout,
|
343 |
+
)
|
344 |
+
)
|
345 |
+
block_in = block_out
|
346 |
+
if curr_res in attn_resolutions:
|
347 |
+
attn.append(AttnBlock(block_in))
|
348 |
+
up = nn.Module()
|
349 |
+
up.block = block
|
350 |
+
up.attn = attn
|
351 |
+
if i_level != 0:
|
352 |
+
up.upsample = Upsample(block_in, True)
|
353 |
+
curr_res = curr_res * 2
|
354 |
+
self.up.insert(0, up) # prepend to get consistent order
|
355 |
+
|
356 |
+
self.norm_out = Normalize(block_in)
|
357 |
+
self.conv_out = torch.nn.Conv2d(
|
358 |
+
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
359 |
+
)
|
360 |
+
self.post_quant_conv = torch.nn.Conv2d(
|
361 |
+
z_channels, z_channels, 1
|
362 |
+
)
|
363 |
+
|
364 |
+
|
365 |
+
def forward(self, z):
|
366 |
+
# assert z.shape[1:] == self.z_shape[1:]
|
367 |
+
self.last_z_shape = z.shape
|
368 |
+
# timestep embedding
|
369 |
+
temb = None
|
370 |
+
output = dict()
|
371 |
+
z = self.post_quant_conv(z)
|
372 |
+
|
373 |
+
# z to block_in
|
374 |
+
h = self.conv_in(z)
|
375 |
+
|
376 |
+
# middle
|
377 |
+
h = self.mid.block_1(h, temb)
|
378 |
+
h = self.mid.attn_1(h)
|
379 |
+
h = self.mid.block_2(h, temb)
|
380 |
+
|
381 |
+
# upsampling
|
382 |
+
for i_level in reversed(range(self.num_resolutions)):
|
383 |
+
for i_block in range(self.num_res_blocks[i_level]):
|
384 |
+
h = self.up[i_level].block[i_block](h, temb)
|
385 |
+
if len(self.up[i_level].attn) > 0:
|
386 |
+
h = self.up[i_level].attn[i_block](h)
|
387 |
+
if i_level != 0:
|
388 |
+
h = self.up[i_level].upsample(h)
|
389 |
+
|
390 |
+
# end
|
391 |
+
output["output"] = h
|
392 |
+
if self.give_pre_end:
|
393 |
+
return output
|
394 |
+
|
395 |
+
h = self.norm_out(h)
|
396 |
+
h = nonlinearity(h)
|
397 |
+
h = self.conv_out(h)
|
398 |
+
output["output"] = h
|
399 |
+
return output
|
400 |
+
|
401 |
+
|
402 |
+
class MAGVITv2(ModelMixin, ConfigMixin):
|
403 |
+
@register_to_config
|
404 |
+
def __init__(
|
405 |
+
self,
|
406 |
+
):
|
407 |
+
super().__init__()
|
408 |
+
|
409 |
+
self.encoder = VQGANEncoder()
|
410 |
+
self.decoder = VQGANDecoder()
|
411 |
+
self.quantize = LFQuantizer()
|
412 |
+
|
413 |
+
def forward(self, pixel_values, return_loss=False):
|
414 |
+
pass
|
415 |
+
|
416 |
+
def encode(self, pixel_values, return_loss=False):
|
417 |
+
hidden_states = self.encoder(pixel_values)
|
418 |
+
quantized_states = self.quantize(hidden_states)['z']
|
419 |
+
codebook_indices = self.quantize.get_indices(quantized_states).reshape(pixel_values.shape[0], -1)
|
420 |
+
output = (quantized_states, codebook_indices)
|
421 |
+
return output
|
422 |
+
|
423 |
+
def get_code(self, pixel_values):
|
424 |
+
hidden_states = self.encoder(pixel_values)
|
425 |
+
codebook_indices = self.quantize.get_indices(self.quantize(hidden_states)['z']).reshape(pixel_values.shape[0], -1)
|
426 |
+
|
427 |
+
return codebook_indices
|
428 |
+
|
429 |
+
def decode_code(self, codebook_indices, shape=None):
|
430 |
+
z_q = self.quantize.get_codebook_entry(codebook_indices, shape=shape)
|
431 |
+
|
432 |
+
reconstructed_pixel_values = self.decoder(z_q)["output"]
|
433 |
+
return reconstructed_pixel_values
|
434 |
+
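Taken together, `get_code` and `decode_code` give the image-token round trip used elsewhere in the app: pixels to a flat grid of discrete indices, and indices back to pixels. A sketch with randomly initialized weights and the default 256x256 / 2**13-codebook configuration:

```python
import torch

vq_model = MAGVITv2().eval()
pixel_values = torch.randn(1, 3, 256, 256)

with torch.no_grad():
    codebook_indices = vq_model.get_code(pixel_values)       # (1, 256) token ids
    reconstruction = vq_model.decode_code(codebook_indices)  # (1, 3, 256, 256)
print(codebook_indices.shape, reconstruction.shape)
```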
|
435 |
+
|
436 |
+
if __name__ == '__main__':
|
437 |
+
encoder = VQGANEncoder()
|
438 |
+
import ipdb
|
439 |
+
ipdb.set_trace()
|
440 |
+
print()
|
models/modeling_mmada.py
ADDED
@@ -0,0 +1,668 @@
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import logging
|
4 |
+
import math
|
5 |
+
import sys
|
6 |
+
from abc import abstractmethod
|
7 |
+
from collections import defaultdict
|
8 |
+
from functools import partial
|
9 |
+
from typing import (
|
10 |
+
Callable,
|
11 |
+
Dict,
|
12 |
+
Iterable,
|
13 |
+
List,
|
14 |
+
NamedTuple,
|
15 |
+
Optional,
|
16 |
+
Sequence,
|
17 |
+
Set,
|
18 |
+
Tuple,
|
19 |
+
cast,
|
20 |
+
)
|
21 |
+
from dataclasses import fields
|
22 |
+
from typing import List, Optional, Tuple, Union
|
23 |
+
import numpy as np
|
24 |
+
import torch
|
25 |
+
import torch.backends.cuda
|
26 |
+
import torch.nn as nn
|
27 |
+
import torch.nn.functional as F
|
28 |
+
from torch import einsum
|
29 |
+
from transformers import PreTrainedModel
|
30 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
31 |
+
from transformers.models.auto import AutoModel, AutoConfig, AutoModelForCausalLM
|
32 |
+
from transformers.cache_utils import Cache
|
33 |
+
from PIL import Image
|
34 |
+
from .configuration_llada import (
|
35 |
+
LLaDAConfig,
|
36 |
+
StrEnum,
|
37 |
+
InitFnType,
|
38 |
+
ActivationType,
|
39 |
+
BlockType,
|
40 |
+
LayerNormType,
|
41 |
+
ModelConfig,
|
42 |
+
ActivationCheckpointingStrategy,
|
43 |
+
)
|
44 |
+
|
45 |
+
from .modeling_llada import LLaDAModelLM
|
46 |
+
from .sampling import cosine_schedule, mask_by_random_topk
|
47 |
+
from transformers import PretrainedConfig
|
48 |
+
|
49 |
+
def add_gumbel_noise(logits, temperature):
|
50 |
+
'''
|
51 |
+
The Gumbel max is a method for sampling categorical distributions.
|
52 |
+
According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
|
53 |
+
Thus, we use float64.
|
54 |
+
'''
|
55 |
+
if temperature == 0:
|
56 |
+
return logits
|
57 |
+
logits = logits.to(torch.float64)
|
58 |
+
noise = torch.rand_like(logits, dtype=torch.float64)
|
59 |
+
gumbel_noise = (- torch.log(noise)) ** temperature
|
60 |
+
return logits.exp() / gumbel_noise
|
61 |
+
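A quick illustration of `add_gumbel_noise`: with `temperature == 0` it returns the raw logits, so the subsequent argmax is greedy; with `temperature == 1`, taking the argmax of the perturbed scores is equivalent to sampling from the softmax distribution (the Gumbel-max trick), computed in float64 as noted above.

```python
import torch

torch.manual_seed(0)
logits = torch.tensor([[[2.0, 0.5, 0.1]]])      # (batch=1, seq=1, vocab=3)

greedy = torch.argmax(add_gumbel_noise(logits, temperature=0.0), dim=-1)
sampled = torch.argmax(add_gumbel_noise(logits, temperature=1.0), dim=-1)
print(greedy.item(), sampled.item())  # greedy is always 0; sampled varies per draw
```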
|
62 |
+
|
63 |
+
def get_num_transfer_tokens(mask_index, steps):
|
64 |
+
'''
|
65 |
+
In the reverse process, the interval [0, 1] is uniformly discretized into `steps` intervals.
|
66 |
+
Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)),
|
67 |
+
the expected number of tokens transitioned at each step should be consistent.
|
68 |
+
|
69 |
+
This function is designed to precompute the number of tokens that need to be transitioned at each step.
|
70 |
+
'''
|
71 |
+
mask_num = mask_index.sum(dim=1, keepdim=True)
|
72 |
+
|
73 |
+
base = mask_num // steps
|
74 |
+
remainder = mask_num % steps
|
75 |
+
|
76 |
+
num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
|
77 |
+
|
78 |
+
for i in range(mask_num.size(0)):
|
79 |
+
num_transfer_tokens[i, :remainder[i]] += 1
|
80 |
+
|
81 |
+
return num_transfer_tokens
|
82 |
+
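A worked example of the transfer budget: with 10 masked positions and 4 steps, the base budget is 2 per step and the remainder of 2 goes to the first steps, so the schedule is 3, 3, 2, 2 and sums back to the number of masked tokens.

```python
import torch

mask_index = torch.zeros(1, 16, dtype=torch.bool)
mask_index[0, :10] = True                          # 10 masked positions

num = get_num_transfer_tokens(mask_index, steps=4)
print(num)                                         # tensor([[3, 3, 2, 2]])
assert num.sum().item() == int(mask_index.sum())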
|
83 |
+
class MMadaConfig(PretrainedConfig):
|
84 |
+
model_type = "mmada"
|
85 |
+
|
86 |
+
def __init__(self, **kwargs):
|
87 |
+
super().__init__(**kwargs)
|
88 |
+
|
89 |
+
allowed_keys = [
|
90 |
+
"vocab_size",
|
91 |
+
"llm_vocab_size",
|
92 |
+
"llm_model_path",
|
93 |
+
"codebook_size",
|
94 |
+
"num_vq_tokens",
|
95 |
+
"num_new_special_tokens",
|
96 |
+
"gradient_checkpointing",
|
97 |
+
"new_vocab_size",
|
98 |
+
]
|
99 |
+
|
100 |
+
for key in allowed_keys:
|
101 |
+
if key in kwargs:
|
102 |
+
setattr(self, key, kwargs[key])
|
103 |
+
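`MMadaConfig` only turns the whitelisted keys above into attributes; everything else is handled by `PretrainedConfig` as usual. A small sketch with illustrative values (not the released checkpoint's):

```python
config = MMadaConfig(
    vocab_size=126_464,          # illustrative numbers only
    codebook_size=8192,
    num_vq_tokens=256,
    num_new_special_tokens=0,
    gradient_checkpointing=False,
)
print(config.codebook_size, config.num_vq_tokens)
```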
|
104 |
+
|
105 |
+
|
106 |
+
class MMadaModelLM(LLaDAModelLM):
|
107 |
+
config_class = MMadaConfig
|
108 |
+
base_model_prefix = "model"
|
109 |
+
def __init__(self, config: MMadaConfig, *args, **kwargs):
|
110 |
+
print(f"Initializing MMadaModelLM with config: {config}")
|
111 |
+
super().__init__(config, *args, **kwargs)
|
112 |
+
|
113 |
+
# # resize token embeddings
|
114 |
+
# print(f"Resizing token embeddings to {config.new_vocab_size}")
|
115 |
+
# self.resize_token_embeddings(config.new_vocab_size)
|
116 |
+
|
117 |
+
@torch.no_grad()
|
118 |
+
def t2i_generate(
|
119 |
+
self,
|
120 |
+
input_ids: torch.LongTensor = None,
|
121 |
+
uncond_input_ids: torch.LongTensor = None,
|
122 |
+
attention_mask=None,
|
123 |
+
uncond_attention_mask=None,
|
124 |
+
temperature=1.0,
|
125 |
+
timesteps=18, # ideal number of steps is 18 in maskgit paper
|
126 |
+
guidance_scale=0,
|
127 |
+
noise_schedule=cosine_schedule,
|
128 |
+
generator: torch.Generator = None,
|
129 |
+
config=None,
|
130 |
+
seq_len=1024,
|
131 |
+
mask_token_id = 126336,
|
132 |
+
resolution = 512,
|
133 |
+
codebook_size = 8192,
|
134 |
+
**kwargs,
|
135 |
+
):
|
136 |
+
"""
|
137 |
+
Generate 1:1 similar to the original MaskGit repo
|
138 |
+
https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79
|
139 |
+
"""
|
140 |
+
|
141 |
+
# begin with all image token ids masked
|
142 |
+
# 计算有多少个mask token
|
143 |
+
mask_count = (input_ids == mask_token_id).sum().item()
|
144 |
+
num_vq_tokens = seq_len
|
145 |
+
num_new_special_tokens = 0
|
146 |
+
uni_prompting = kwargs.get("uni_prompting", None)
|
147 |
+
# print(f"config.model.mmada.llm_vocab_size: {config.model.mmada.llm_vocab_size}, {len(uni_prompting.text_tokenizer)}")
|
148 |
+
input_ids_minus_lm_vocab_size = input_ids[:, -(num_vq_tokens + 1):-1].clone()
|
149 |
+
input_ids_minus_lm_vocab_size = torch.where(input_ids_minus_lm_vocab_size == mask_token_id, mask_token_id, input_ids_minus_lm_vocab_size - len(uni_prompting.text_tokenizer) - num_new_special_tokens)
|
150 |
+
|
151 |
+
# for classifier-free guidance
|
152 |
+
if uncond_input_ids is not None:
|
153 |
+
uncond_prefix = uncond_input_ids[:, :resolution + 1]
|
154 |
+
|
155 |
+
for step in range(timesteps):
|
156 |
+
if uncond_input_ids is not None and guidance_scale > 0:
|
157 |
+
uncond_input_ids = torch.cat(
|
158 |
+
[uncond_prefix, input_ids[:, resolution + 1:]], dim=1)
|
159 |
+
model_input = torch.cat([input_ids, uncond_input_ids])
|
160 |
+
attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0)
|
161 |
+
attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
|
162 |
+
logits = self(model_input, attention_bias=attention_bias).logits
|
163 |
+
# print(f"logits.shape: {logits.shape}")
|
164 |
+
cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0)
|
165 |
+
# logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
|
166 |
+
# it seems that muse has a different cfg setting
|
167 |
+
logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
|
168 |
+
logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size]
|
169 |
+
else:
|
170 |
+
attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
|
171 |
+
logits = self(input_ids, attention_bias=attention_bias).logits
|
172 |
+
logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size]
|
173 |
+
|
174 |
+
# logits: 1, 1024, 8192
|
175 |
+
# print(f"logits.shape: {logits.shape}")
|
176 |
+
probs = logits.softmax(dim=-1)
|
177 |
+
sampled = probs.reshape(-1, logits.size(-1))
|
178 |
+
# print(f"probs: {probs}, probs.shape: {probs.shape}, sampled: {sampled}, sampled.shape: {sampled.shape}")
|
179 |
+
sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1]) # 1, 1024
|
180 |
+
|
181 |
+
unknown_map = input_ids_minus_lm_vocab_size == mask_token_id
|
182 |
+
# print(f"unknown_map.sum(dim=-1, keepdim=True): {unknown_map.sum(dim=-1, keepdim=True)}")
|
183 |
+
sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size)
|
184 |
+
# Defines the mask ratio for the next round. The number to mask out is
|
185 |
+
# determined by mask_ratio * unknown_number_in_the_beginning.
|
186 |
+
ratio = 1.0 * (step + 1) / timesteps
|
187 |
+
mask_ratio = noise_schedule(torch.tensor(ratio))
|
188 |
+
# Computes the probability of each selected token.
|
189 |
+
selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
|
190 |
+
selected_probs = selected_probs.squeeze(-1)
|
191 |
+
|
192 |
+
# Ignores the tokens given in the input by overwriting their confidence.
|
193 |
+
selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
|
194 |
+
# Gets mask lens for each sample in the batch according to the mask ratio.
|
195 |
+
mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device)
|
196 |
+
# Keeps at least one prediction in this round and also masks out at least
|
197 |
+
# one token for the next iteration
|
198 |
+
mask_len = torch.max(
|
199 |
+
torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
|
200 |
+
)
|
201 |
+
# print(f"mask_len: {mask_len}, mask_len.shape: {mask_len.shape}")
|
202 |
+
# Adds noise for randomness
|
203 |
+
temperature = temperature * (1.0 - ratio)
|
204 |
+
masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)
|
205 |
+
# Masks tokens with lower confidence.
|
206 |
+
input_ids[:, -(num_vq_tokens + 1):-1] = torch.where(masking, mask_token_id,
|
207 |
+
sampled_ids + len(uni_prompting.text_tokenizer)
|
208 |
+
+ num_new_special_tokens)
|
209 |
+
input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids)
|
210 |
+
|
211 |
+
return sampled_ids
|
212 |
+
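The decoding loop above keeps `noise_schedule(ratio) * num_vq_tokens` tokens masked after each step. A standalone sketch of that budget, assuming the imported `cosine_schedule` follows the usual MaskGit choice `cos(pi/2 * ratio)` (an assumption, since `.sampling` is not shown here):

```python
import math

timesteps, num_vq_tokens = 18, 1024
for step in range(timesteps):
    ratio = (step + 1) / timesteps
    mask_ratio = math.cos(math.pi / 2 * ratio)      # assumed cosine schedule
    mask_len = int(num_vq_tokens * mask_ratio)
    print(f"step {step + 1:2d}: {mask_len:4d} tokens stay masked")
```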
|
213 |
+
def forward_process(
|
214 |
+
self,
|
215 |
+
input_ids,
|
216 |
+
labels,
|
217 |
+
batch_size_t2i=0,
|
218 |
+
batch_size_lm=0,
|
219 |
+
batch_size_mmu=0,
|
220 |
+
max_seq_length=128,
|
221 |
+
p_mask_lm=None,
|
222 |
+
p_mask_mmu=None,
|
223 |
+
answer_lengths=None,
|
224 |
+
t2i_masks=None,
|
225 |
+
answer_lengths_lm=None
|
226 |
+
):
|
227 |
+
# attention bias, True for batch_size, 1, seq_len, seq_len
|
228 |
+
attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1])
|
229 |
+
attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1)
|
230 |
+
attention_bias[:batch_size_t2i] = attention_bias_t2i
|
231 |
+
logits = self(input_ids, attention_bias=attention_bias).logits
|
232 |
+
# logits = self(input_ids).logits
|
233 |
+
self.output_size = logits.shape[-1]
|
234 |
+
|
235 |
+
# print(f"logits shape: {logits.shape}") B, 359, vocab_size
|
236 |
+
|
237 |
+
if batch_size_t2i == 0:
|
238 |
+
loss_t2i = torch.tensor(0.0, device=input_ids.device)
|
239 |
+
else:
|
240 |
+
# t2i loss
|
241 |
+
loss_t2i = F.cross_entropy(
|
242 |
+
logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size),
|
243 |
+
labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100,
|
244 |
+
)
|
245 |
+
|
246 |
+
# llada loss
|
247 |
+
masked_indices = input_ids == self.config.mask_token_id
|
248 |
+
masked_indices_lm = masked_indices[batch_size_t2i:batch_size_t2i + batch_size_lm]
|
249 |
+
# debug helper: count the number of masked tokens per row
|
250 |
+
# if masked_indices_lm.numel() > 0:
|
251 |
+
# mask_counts = torch.sum(masked_indices_lm, dim=1)
|
252 |
+
# logging.info(f"[LM mask nums]: {mask_counts.cpu()}.")
|
253 |
+
# else:
|
254 |
+
# logging.info("[LM mask nums] no LM sample.")
|
255 |
+
masked_indices_mmu = masked_indices[-batch_size_mmu:]
|
256 |
+
p_mask_lm = p_mask_lm.to(masked_indices_lm.device)
|
257 |
+
p_mask_mmu = p_mask_mmu.to(masked_indices_mmu.device)
|
258 |
+
answer_lengths = answer_lengths.to(masked_indices_mmu.device)
|
259 |
+
loss_lm = F.cross_entropy(
|
260 |
+
logits[batch_size_t2i:batch_size_t2i + batch_size_lm][masked_indices_lm].contiguous().view(-1, self.output_size),
|
261 |
+
labels[batch_size_t2i:batch_size_t2i + batch_size_lm][masked_indices_lm].contiguous().view(-1), ignore_index=-100, reduction='none'
|
262 |
+
)/p_mask_lm[masked_indices_lm]
|
263 |
+
# print(f"logits lm shape: {logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape}")
|
264 |
+
loss_lm = loss_lm.sum() / (logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape[0] * logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape[1])
|
265 |
+
|
266 |
+
# llm loss
|
267 |
+
answer_lengths_lm = answer_lengths_lm.to(masked_indices_lm.device)
|
268 |
+
loss_lm = torch.sum(loss_lm / answer_lengths_lm[masked_indices_lm]) / (logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape[0])
|
269 |
+
|
270 |
+
loss_mmu = F.cross_entropy(
|
271 |
+
logits[-batch_size_mmu:][masked_indices_mmu].contiguous().view(-1, self.output_size),
|
272 |
+
labels[-batch_size_mmu:][masked_indices_mmu].contiguous().view(-1), ignore_index=-100, reduction='none'
|
273 |
+
)/p_mask_mmu[masked_indices_mmu]
|
274 |
+
loss_mmu = torch.sum(loss_mmu/answer_lengths[masked_indices_mmu]) / (logits[-batch_size_mmu:].shape[0])
|
275 |
+
|
276 |
+
return logits, loss_t2i, loss_lm, loss_mmu
|
277 |
+
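The text losses above follow the masked-diffusion objective: per-token cross entropy on masked positions, divided by the masking probability `p_mask` and by the answer length, then averaged over the batch. A stripped-down sketch of that weighting, independent of the model (all tensors below are synthetic):

```python
import torch
import torch.nn.functional as F

B, L, V = 2, 8, 32
logits = torch.randn(B, L, V)
labels = torch.randint(0, V, (B, L))
masked = torch.rand(B, L) < 0.5                   # which positions were masked
p_mask = torch.full((B, L), 0.5)                  # prob. each position was masked
answer_lengths = torch.full((B, L), float(L))     # answer length, broadcast per token

token_loss = F.cross_entropy(
    logits[masked], labels[masked], reduction="none"
) / p_mask[masked]
loss = torch.sum(token_loss / answer_lengths[masked]) / B
print(loss)
```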
|
278 |
+
def forward_process_with_r2i(
|
279 |
+
self,
|
280 |
+
input_ids,
|
281 |
+
labels,
|
282 |
+
t2i_masks=None,
|
283 |
+
max_seq_length=128,
|
284 |
+
batch_size_t2i=0,
|
285 |
+
batch_size_lm=0,
|
286 |
+
batch_size_mmu=0,
|
287 |
+
batch_size_r2i=0,
|
288 |
+
p_mask_lm=None,
|
289 |
+
p_mask_mmu=None,
|
290 |
+
p_mask_r2i=None,
|
291 |
+
answer_lengths=None,
|
292 |
+
answer_lengths_lm=None,
|
293 |
+
answer_lengths_r2i=None,
|
294 |
+
):
|
295 |
+
# attention bias, True for batch_size, 1, seq_len, seq_len
|
296 |
+
attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1])
|
297 |
+
attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1)
|
298 |
+
attention_bias[:batch_size_t2i] = attention_bias_t2i
|
299 |
+
logits = self(input_ids, attention_bias=attention_bias).logits
|
300 |
+
# logits = self(input_ids).logits
|
301 |
+
self.output_size = logits.shape[-1]
|
302 |
+
|
303 |
+
# print(f"logits shape: {logits.shape}") B, 359, vocab_size
|
304 |
+
|
305 |
+
if batch_size_t2i == 0:
|
306 |
+
loss_t2i = torch.tensor(0.0, device=input_ids.device)
|
307 |
+
else:
|
308 |
+
# t2i loss
|
309 |
+
loss_t2i = F.cross_entropy(
|
310 |
+
logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size),
|
311 |
+
labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100,
|
312 |
+
)
|
313 |
+
|
314 |
+
# llada loss
|
315 |
+
|
316 |
+
start_lm = batch_size_t2i
|
317 |
+
end_lm = start_lm + batch_size_lm
|
318 |
+
start_mmu = end_lm
|
319 |
+
end_mmu = start_mmu + batch_size_mmu
|
320 |
+
start_r2i = end_mmu
|
321 |
+
end_r2i = start_r2i + batch_size_r2i
|
322 |
+
|
323 |
+
masked_indices = input_ids == self.config.mask_token_id
|
324 |
+
masked_indices_lm = masked_indices[start_lm:end_lm]
|
325 |
+
masked_indices_mmu = masked_indices[start_mmu:end_mmu]
|
326 |
+
masked_indices_r2i = masked_indices[start_r2i:end_r2i]
|
327 |
+
|
328 |
+
p_mask_lm = p_mask_lm.to(masked_indices_lm.device)
|
329 |
+
p_mask_mmu = p_mask_mmu.to(masked_indices_mmu.device)
|
330 |
+
p_mask_r2i = p_mask_r2i.to(masked_indices_r2i.device)
|
331 |
+
|
332 |
+
answer_lengths = answer_lengths.to(masked_indices_mmu.device)
|
333 |
+
answer_lengths_lm = answer_lengths_lm.to(masked_indices_lm.device)
|
334 |
+
answer_lengths_r2i = answer_lengths_r2i.to(masked_indices_r2i.device)
|
335 |
+
|
336 |
+
loss_lm = F.cross_entropy(
|
337 |
+
logits[start_lm:end_lm][masked_indices_lm].contiguous().view(-1, self.output_size),
|
338 |
+
labels[start_lm:end_lm][masked_indices_lm].contiguous().view(-1), ignore_index=-100, reduction='none'
|
339 |
+
)/p_mask_lm[masked_indices_lm]
|
340 |
+
# print(f"logits lm shape: {logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape}")
|
341 |
+
loss_lm = loss_lm.sum() / (logits[start_lm:end_lm].shape[0] * logits[start_lm:end_lm].shape[1])
|
342 |
+
loss_lm = torch.sum(loss_lm / answer_lengths_lm[masked_indices_lm]) / (logits[start_lm:end_lm].shape[0])
|
343 |
+
|
344 |
+
loss_mmu = F.cross_entropy(
|
345 |
+
logits[start_mmu:end_mmu][masked_indices_mmu].contiguous().view(-1, self.output_size),
|
346 |
+
labels[start_mmu:end_mmu][masked_indices_mmu].contiguous().view(-1), ignore_index=-100, reduction='none'
|
347 |
+
)/p_mask_mmu[masked_indices_mmu]
|
348 |
+
loss_mmu = torch.sum(loss_mmu/answer_lengths[masked_indices_mmu]) / (logits[start_mmu:end_mmu].shape[0])
|
349 |
+
|
350 |
+
loss_r2i = F.cross_entropy(
|
351 |
+
logits[start_r2i:end_r2i][masked_indices_r2i].contiguous().view(-1, self.output_size),
|
352 |
+
labels[start_r2i:end_r2i][masked_indices_r2i].contiguous().view(-1), ignore_index=-100, reduction='none'
|
353 |
+
)/p_mask_r2i[masked_indices_r2i]
|
354 |
+
loss_r2i = torch.sum(loss_r2i/answer_lengths_r2i[masked_indices_r2i]) / (logits[start_r2i:end_r2i].shape[0])
|
355 |
+
|
356 |
+
return logits, loss_t2i, loss_lm, loss_mmu, loss_r2i
|
357 |
+
|
358 |
+
|
359 |
+
def forward_t2i(
|
360 |
+
self,
|
361 |
+
input_ids,
|
362 |
+
labels,
|
363 |
+
batch_size_t2i=0,
|
364 |
+
max_seq_length=128,
|
365 |
+
t2i_masks=None
|
366 |
+
):
|
367 |
+
# attention bias, True for batch_size, 1, seq_len, seq_len
|
368 |
+
attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1])
|
369 |
+
attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1)
|
370 |
+
attention_bias[:batch_size_t2i] = attention_bias_t2i
|
371 |
+
logits = self(input_ids, attention_bias=attention_bias).logits
|
372 |
+
# logits = self(input_ids).logits
|
373 |
+
self.output_size = logits.shape[-1]
|
374 |
+
|
375 |
+
# print(f"logits shape: {logits.shape}") B, 359, vocab_size
|
376 |
+
|
377 |
+
loss_t2i = F.cross_entropy(
|
378 |
+
logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size),
|
379 |
+
labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100,
|
380 |
+
)
|
381 |
+
|
382 |
+
return loss_t2i
|
383 |
+
|
384 |
+
|
385 |
+
|
386 |
+
|
387 |
+
|
388 |
+
@torch.no_grad()
|
389 |
+
def mmu_generate(self, idx=None, input_embeddings=None, max_new_tokens=128, steps=128,block_length=128, temperature=0.0, top_k=None, eot_token=None, cfg_scale=0.0, remasking='low_confidence', mask_id=126336, attention_mask=None):
|
390 |
+
"""
|
391 |
+
Take a conditioning sequence of indices idx (LongTensor of shape (b, t)), append
|
392 |
+
max_new_tokens masked tokens, and fill them in block by block over the given number of diffusion steps.
|
393 |
+
Most likely you'll want to make sure to be in model.eval() mode of operation for this.
|
394 |
+
"""
|
395 |
+
|
396 |
+
if attention_mask is not None and 0.0 in attention_mask:
|
397 |
+
attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
|
398 |
+
# print(f"attention_bias: {attention_bias}")
|
399 |
+
else:
|
400 |
+
attention_bias = None
|
401 |
+
try:
|
402 |
+
device = idx.device
|
403 |
+
except AttributeError:
|
404 |
+
device = input_embeddings.device
|
405 |
+
|
406 |
+
result = []
|
407 |
+
batch_size = idx.shape[0]
|
408 |
+
x = torch.full((batch_size, idx.shape[1] + max_new_tokens), mask_id, dtype=torch.long).to(self.device)
|
409 |
+
x[:, :idx.shape[1]] = idx.clone()
|
410 |
+
prompt_index = (x != mask_id)
|
411 |
+
|
412 |
+
|
413 |
+
assert max_new_tokens % block_length == 0
|
414 |
+
num_blocks = max_new_tokens // block_length
|
415 |
+
|
416 |
+
assert steps % num_blocks == 0
|
417 |
+
steps = steps // num_blocks
|
418 |
+
|
419 |
+
# print(f"num_blocks: {num_blocks}, steps: {steps}")
|
420 |
+
# num_transfer_tokens = get_num_transfer_tokens(prompt_index, steps)
|
421 |
+
for num_block in range(num_blocks):
|
422 |
+
block_mask_index = (x[:, idx.shape[1] + num_block * block_length: idx.shape[1] + (num_block + 1) * block_length:] == mask_id)
|
423 |
+
num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
|
424 |
+
# num_transfer_tokens = get_num_transfer_tokens(prompt_index, steps)
|
425 |
+
# print(f"num_transfer_tokens: {num_transfer_tokens}, num_transfer_tokens.shape: {num_transfer_tokens.shape}")
|
426 |
+
for i in range(steps):
|
427 |
+
mask_index = (x == mask_id)
|
428 |
+
if cfg_scale > 0.0:
|
429 |
+
un_x = x.clone()
|
430 |
+
un_x[prompt_index] = mask_id
|
431 |
+
x_ = torch.cat([x, un_x], dim=0)
|
432 |
+
logits = self(x_).logits
|
433 |
+
logits, un_logits = torch.chunk(logits, 2, dim=0)
|
434 |
+
logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
|
435 |
+
else:
|
436 |
+
logits = self(x, attention_bias=attention_bias).logits
|
437 |
+
|
438 |
+
logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
|
439 |
+
x0 = torch.argmax(logits_with_noise, dim=-1) # b, l
|
440 |
+
if remasking == 'low_confidence':
|
441 |
+
p = F.softmax(logits.to(torch.float64), dim=-1)
|
442 |
+
x0_p = torch.squeeze(
|
443 |
+
torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
|
444 |
+
elif remasking == 'random':
|
445 |
+
x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
|
446 |
+
else:
|
447 |
+
raise NotImplementedError(remasking)
|
448 |
+
|
449 |
+
x0_p[:, idx.shape[1] + (num_block + 1) * block_length:] = -np.inf
|
450 |
+
|
451 |
+
x0 = torch.where(mask_index, x0, x)
|
452 |
+
confidence = torch.where(mask_index, x0_p, -np.inf)
|
453 |
+
|
454 |
+
transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
|
455 |
+
for j in range(confidence.shape[0]):
|
456 |
+
_, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
|
457 |
+
transfer_index[j, select_index] = True
|
458 |
+
x[transfer_index] = x0[transfer_index]
|
459 |
+
|
460 |
+
|
461 |
+
# logits = logits[:, -1, :] / temperature
|
462 |
+
# # optionally crop the logits to only the top k options
|
463 |
+
# if top_k is not None:
|
464 |
+
# v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
465 |
+
# logits[logits < v[:, [-1]]] = -float('Inf')
|
466 |
+
# # apply softmax to convert logits to (normalized) probabilities
|
467 |
+
# probs = F.softmax(logits, dim=-1)
|
468 |
+
# # sample from the distribution
|
469 |
+
# idx_next = torch.multinomial(probs, num_samples=1)
|
470 |
+
# result.append(idx_next[0][0])
|
471 |
+
# # append sampled index to the running sequence and continue
|
472 |
+
# if self.config.w_clip_vit:
|
473 |
+
# idx_next_embeddings = self.mmada.model.embed_tokens(idx_next)
|
474 |
+
# input_embeddings = torch.cat([input_embeddings, idx_next_embeddings], dim=1)
|
475 |
+
# else:
|
476 |
+
# idx = torch.cat((idx, idx_next), dim=1)
|
477 |
+
|
478 |
+
# if eot_token is not None and idx_next.cpu() == eot_token:
|
479 |
+
# break
|
480 |
+
|
481 |
+
return x
|
482 |
+
|
483 |
+
@torch.no_grad()
|
484 |
+
def mmu_generate_fast(self, idx=None, input_embeddings=None, max_new_tokens=128, steps=128,block_length=128, temperature=0.0, top_k=None, eot_token=None, cfg_scale=0.0, remasking='low_confidence', mask_id=126336, attention_mask=None):
|
485 |
+
"""
|
486 |
+
Take a conditioning sequence of indices idx (LongTensor of shape (b, t)), append
|
487 |
+
max_new_tokens masked tokens, and fill them in block by block over the given number of diffusion steps.
|
488 |
+
Most likely you'll want to make sure to be in model.eval() mode of operation for this.
|
489 |
+
"""
|
490 |
+
|
491 |
+
if attention_mask is not None and 0.0 in attention_mask:
|
492 |
+
attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
|
493 |
+
# print(f"attention_bias: {attention_bias}")
|
494 |
+
else:
|
495 |
+
attention_bias = None
|
496 |
+
try:
|
497 |
+
device = idx.device
|
498 |
+
except AttributeError:
|
499 |
+
device = input_embeddings.device
|
500 |
+
|
501 |
+
result = []
|
502 |
+
batch_size = idx.shape[0]
|
503 |
+
x = torch.full((batch_size, idx.shape[1] + max_new_tokens), mask_id, dtype=torch.long).to(self.device)
|
504 |
+
x[:, :idx.shape[1]] = idx.clone()
|
505 |
+
prompt_index = (x != mask_id)
|
506 |
+
|
507 |
+
|
508 |
+
assert max_new_tokens % block_length == 0
|
509 |
+
num_blocks = max_new_tokens // block_length
|
510 |
+
|
511 |
+
assert steps % num_blocks == 0
|
512 |
+
steps = steps // num_blocks
|
513 |
+
|
514 |
+
for num_block in range(num_blocks):
|
515 |
+
block_mask_index = (x[:, idx.shape[1] + num_block * block_length: idx.shape[1] + (num_block + 1) * block_length:] == mask_id)
|
516 |
+
num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
|
517 |
+
for i in range(steps):
|
518 |
+
mask_index = (x == mask_id)
|
519 |
+
if cfg_scale > 0.0:
|
520 |
+
un_x = x.clone()
|
521 |
+
un_x[prompt_index] = mask_id
|
522 |
+
x_ = torch.cat([x, un_x], dim=0)
|
523 |
+
logits = self(x_).logits
|
524 |
+
logits, un_logits = torch.chunk(logits, 2, dim=0)
|
525 |
+
logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
|
526 |
+
else:
|
527 |
+
logits = self(x, attention_bias=attention_bias).logits
|
528 |
+
|
529 |
+
logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
|
530 |
+
x0 = torch.argmax(logits_with_noise, dim=-1) # b, l
|
531 |
+
if remasking == 'low_confidence':
|
532 |
+
p = F.softmax(logits.to(torch.float64), dim=-1)
|
533 |
+
x0_p = torch.squeeze(
|
534 |
+
torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
|
535 |
+
elif remasking == 'random':
|
536 |
+
x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
|
537 |
+
else:
|
538 |
+
raise NotImplementedError(remasking)
|
539 |
+
|
540 |
+
x0_p[:, idx.shape[1] + (num_block + 1) * block_length:] = -np.inf
|
541 |
+
|
542 |
+
x0 = torch.where(mask_index, x0, x)
|
543 |
+
confidence = torch.where(mask_index, x0_p, -np.inf)
|
544 |
+
|
545 |
+
transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
|
546 |
+
for j in range(confidence.shape[0]):
|
547 |
+
_, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
|
548 |
+
transfer_index[j, select_index] = True
|
549 |
+
x[transfer_index] = x0[transfer_index]
|
550 |
+
if eot_token is not None:
|
551 |
+
last_token_index_in_current_block = idx.shape[1] + (num_block + 1) * block_length - 1
|
552 |
+
if last_token_index_in_current_block < x.shape[1]:
|
553 |
+
tokens_at_block_end = x[:, last_token_index_in_current_block]
|
554 |
+
if torch.all(tokens_at_block_end == eot_token):
|
555 |
+
break
|
556 |
+
return x
|
557 |
+
|
558 |
+
@torch.no_grad()
|
559 |
+
def t2i_generate_decoding_stepwise(
|
560 |
+
self,
|
561 |
+
input_ids: torch.LongTensor = None,
|
562 |
+
uncond_input_ids: torch.LongTensor = None,
|
563 |
+
attention_mask=None,
|
564 |
+
uncond_attention_mask=None,
|
565 |
+
temperature=1.0,
|
566 |
+
timesteps=18, # ideal number of steps is 18 in maskgit paper
|
567 |
+
guidance_scale=0,
|
568 |
+
noise_schedule=cosine_schedule,
|
569 |
+
generator: torch.Generator = None,
|
570 |
+
config=None,
|
571 |
+
seq_len=1024,
|
572 |
+
mask_token_id = 126336,
|
573 |
+
resolution = 512,
|
574 |
+
codebook_size = 8192,
|
575 |
+
vq_model = None,
|
576 |
+
**kwargs,
|
577 |
+
):
|
578 |
+
"""
|
579 |
+
Generate 1:1 similar to the original MaskGit repo
|
580 |
+
https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79
|
581 |
+
"""
|
582 |
+
|
583 |
+
# begin with all image token ids masked
|
584 |
+
# count how many mask tokens are present
|
585 |
+
mask_count = (input_ids == mask_token_id).sum().item()
|
586 |
+
num_vq_tokens = seq_len
|
587 |
+
num_new_special_tokens = 0
|
588 |
+
uni_prompting = kwargs.get("uni_prompting", None)
|
589 |
+
# print(f"config.model.mmada.llm_vocab_size: {config.model.mmada.llm_vocab_size}, {len(uni_prompting.text_tokenizer)}")
|
590 |
+
input_ids_minus_lm_vocab_size = input_ids[:, -(num_vq_tokens + 1):-1].clone()
|
591 |
+
input_ids_minus_lm_vocab_size = torch.where(input_ids_minus_lm_vocab_size == mask_token_id, mask_token_id, input_ids_minus_lm_vocab_size - len(uni_prompting.text_tokenizer) - num_new_special_tokens)
|
592 |
+
|
593 |
+
# for classifier-free guidance
|
594 |
+
if uncond_input_ids is not None:
|
595 |
+
uncond_prefix = uncond_input_ids[:, :resolution + 1]
|
596 |
+
|
597 |
+
for step in range(timesteps):
|
598 |
+
if uncond_input_ids is not None and guidance_scale > 0:
|
599 |
+
uncond_input_ids = torch.cat(
|
600 |
+
[uncond_prefix, input_ids[:, resolution + 1:]], dim=1)
|
601 |
+
model_input = torch.cat([input_ids, uncond_input_ids])
|
602 |
+
attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0)
|
603 |
+
attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
|
604 |
+
logits = self(model_input, attention_bias=attention_bias).logits
|
605 |
+
# print(f"logits.shape: {logits.shape}")
|
606 |
+
cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0)
|
607 |
+
# logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
|
608 |
+
# it seems that muse has a different cfg setting
|
609 |
+
logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
|
610 |
+
logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size]
|
611 |
+
else:
|
612 |
+
attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
|
613 |
+
logits = self(input_ids, attention_bias=attention_bias).logits
|
614 |
+
logits = logits[:, -(num_vq_tokens + 1):-1, len(uni_prompting.text_tokenizer) + num_new_special_tokens: len(uni_prompting.text_tokenizer) + num_new_special_tokens + codebook_size]
|
615 |
+
|
616 |
+
# logits: 1, 1024, 8192
|
617 |
+
# print(f"logits.shape: {logits.shape}")
|
618 |
+
probs = logits.softmax(dim=-1)
|
619 |
+
sampled = probs.reshape(-1, logits.size(-1))
|
620 |
+
# print(f"probs: {probs}, probs.shape: {probs.shape}, sampled: {sampled}, sampled.shape: {sampled.shape}")
|
621 |
+
sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1]) # 1, 1024
|
622 |
+
|
623 |
+
unknown_map = input_ids_minus_lm_vocab_size == mask_token_id
|
624 |
+
# print(f"unknown_map.sum(dim=-1, keepdim=True): {unknown_map.sum(dim=-1, keepdim=True)}")
|
625 |
+
sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size)
|
626 |
+
# Defines the mask ratio for the next round. The number to mask out is
|
627 |
+
current_image_vq_indices = sampled_ids.clone()
|
628 |
+
# print(f"current_image_vq_indices: {current_image_vq_indices}")
|
629 |
+
current_image_vq_indices = torch.clamp(current_image_vq_indices, 0, 8192 - 1)
|
630 |
+
current_image = vq_model.decode_code(current_image_vq_indices)
|
631 |
+
images = torch.clamp((current_image + 1.0) / 2.0, min=0.0, max=1.0)
|
632 |
+
images *= 255.0
|
633 |
+
images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
|
634 |
+
pil_images = Image.fromarray(images[0])
|
635 |
+
yield pil_images, f"Step {step + 1}/{timesteps}"
|
636 |
+
# determined by mask_ratio * unknown_number_in_the_beginning.
|
637 |
+
ratio = 1.0 * (step + 1) / timesteps
|
638 |
+
mask_ratio = noise_schedule(torch.tensor(ratio))
|
639 |
+
# Computes the probabilities of each selected tokens.
|
640 |
+
selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
|
641 |
+
selected_probs = selected_probs.squeeze(-1)
|
642 |
+
|
643 |
+
# Ignores the tokens given in the input by overwriting their confidence.
|
644 |
+
selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
|
645 |
+
# Gets mask lens for each sample in the batch according to the mask ratio.
|
646 |
+
mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device)
|
647 |
+
# Keeps at least one of prediction in this round and also masks out at least
|
648 |
+
# one and for the next iteration
|
649 |
+
mask_len = torch.max(
|
650 |
+
torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
|
651 |
+
)
|
652 |
+
# print(f"mask_len: {mask_len}, mask_len.shape: {mask_len.shape}")
|
653 |
+
# Adds noise for randomness
|
654 |
+
temperature = temperature * (1.0 - ratio)
|
655 |
+
masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)
|
656 |
+
# Masks tokens with lower confidence.
|
657 |
+
input_ids[:, -(num_vq_tokens + 1):-1] = torch.where(masking, mask_token_id,
|
658 |
+
sampled_ids + len(uni_prompting.text_tokenizer)
|
659 |
+
+ num_new_special_tokens)
|
660 |
+
input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids)
|
661 |
+
|
662 |
+
|
663 |
+
return sampled_ids
|
664 |
+
|
665 |
+
|
666 |
+
AutoConfig.register("mmada", MMadaConfig)
|
667 |
+
AutoModelForCausalLM.register(MMadaConfig, MMadaModelLM)
|
668 |
+
AutoModel.register(MMadaConfig, MMadaModelLM)
|
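
The three `register` calls above hook the custom `mmada` model type into the `transformers` Auto-class machinery. A minimal usage sketch follows; it assumes this module has been imported (so the `register()` calls have run) and uses a placeholder checkpoint path that is not part of this commit:

# Hedged usage sketch, not part of the commit:
# "path/to/mmada-checkpoint" is a placeholder directory whose config.json
# declares `"model_type": "mmada"`.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("path/to/mmada-checkpoint")           # resolves to MMadaConfig
model = AutoModelForCausalLM.from_pretrained("path/to/mmada-checkpoint")  # resolves to MMadaModelLM
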
models/modeling_utils.py
ADDED
@@ -0,0 +1,1207 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import itertools
import json
import os
import re
from collections import OrderedDict
from functools import partial
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

import safetensors
import torch
from huggingface_hub import create_repo, split_torch_state_dict_into_shards
from huggingface_hub.utils import validate_hf_hub_args
from torch import Tensor, nn

from diffusers import __version__
from diffusers.utils import (
    FLAX_WEIGHTS_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    WEIGHTS_INDEX_NAME,
    _add_variant,
    _get_checkpoint_shard_files,
    _get_model_file,
    deprecate,
    is_accelerate_available,
    is_torch_version,
    logging,
)

CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
SAFETENSORS_WEIGHTS_NAME = "pytorch_model.safetensors"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"

from diffusers.utils.hub_utils import (
    PushToHubMixin,
    load_or_create_model_card,
    populate_model_card,
)
from diffusers.models.model_loading_utils import (
    _determine_device_map,
    _fetch_index_file,
    _load_state_dict_into_model,
    load_model_dict_into_meta,
    load_state_dict,
)

from diffusers.configuration_utils import ConfigMixin, register_to_config

logger = logging.get_logger(__name__)

_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")


if is_torch_version(">=", "1.9.0"):
    _LOW_CPU_MEM_USAGE_DEFAULT = True
else:
    _LOW_CPU_MEM_USAGE_DEFAULT = False


if is_accelerate_available():
    import accelerate


def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
    try:
        parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
        return next(parameters_and_buffers).device
    except StopIteration:
        # For torch.nn.DataParallel compatibility in PyTorch 1.5

        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].device


def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
    try:
        params = tuple(parameter.parameters())
        if len(params) > 0:
            return params[0].dtype

        buffers = tuple(parameter.buffers())
        if len(buffers) > 0:
            return buffers[0].dtype

    except StopIteration:
        # For torch.nn.DataParallel compatibility in PyTorch 1.5

        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].dtype

class ModelMixin(torch.nn.Module, PushToHubMixin):
    r"""
    Base class for all models.

    [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
    saving models.

        - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
    """

    config_name = CONFIG_NAME
    _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
    _supports_gradient_checkpointing = False
    _keys_to_ignore_on_load_unexpected = None
    _no_split_modules = None

    def __init__(self):
        super().__init__()

    def __getattr__(self, name: str) -> Any:
        """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
        config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
        __getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
        https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
        """

        is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
        is_attribute = name in self.__dict__

        if is_in_config and not is_attribute:
            deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
            deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
            return self._internal_dict[name]

        # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
        return super().__getattr__(name)

    @property
    def is_gradient_checkpointing(self) -> bool:
        """
        Whether gradient checkpointing is activated for this model or not.
        """
        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())

    def enable_gradient_checkpointing(self) -> None:
        """
        Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
        *checkpoint activations* in other frameworks).
        """
        if not self._supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
        self.apply(partial(self._set_gradient_checkpointing, value=True))

    def disable_gradient_checkpointing(self) -> None:
        """
        Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
        *checkpoint activations* in other frameworks).
        """
        if self._supports_gradient_checkpointing:
            self.apply(partial(self._set_gradient_checkpointing, value=False))

    def set_use_npu_flash_attention(self, valid: bool) -> None:
        r"""
        Set the switch for the npu flash attention.
        """

        def fn_recursive_set_npu_flash_attention(module: torch.nn.Module):
            if hasattr(module, "set_use_npu_flash_attention"):
                module.set_use_npu_flash_attention(valid)

            for child in module.children():
                fn_recursive_set_npu_flash_attention(child)

        for module in self.children():
            if isinstance(module, torch.nn.Module):
                fn_recursive_set_npu_flash_attention(module)

    def enable_npu_flash_attention(self) -> None:
        r"""
        Enable npu flash attention from torch_npu
        """
        self.set_use_npu_flash_attention(True)

    def disable_npu_flash_attention(self) -> None:
        r"""
        Disable npu flash attention from torch_npu
        """
        self.set_use_npu_flash_attention(False)

    def set_use_memory_efficient_attention_xformers(
        self, valid: bool, attention_op: Optional[Callable] = None
    ) -> None:
        # Recursively walk through all the children.
        # Any children which exposes the set_use_memory_efficient_attention_xformers method
        # gets the message
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid, attention_op)

            for child in module.children():
                fn_recursive_set_mem_eff(child)

        for module in self.children():
            if isinstance(module, torch.nn.Module):
                fn_recursive_set_mem_eff(module)

    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None:
        r"""
        Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).

        When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
        inference. Speed up during training is not guaranteed.

        <Tip warning={true}>

        ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
        precedent.

        </Tip>

        Parameters:
            attention_op (`Callable`, *optional*):
                Override the default `None` operator for use as `op` argument to the
                [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
                function of xFormers.

        Examples:

        ```py
        >>> import torch
        >>> from diffusers import UNet2DConditionModel
        >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

        >>> model = UNet2DConditionModel.from_pretrained(
        ...     "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
        ... )
        >>> model = model.to("cuda")
        >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
        ```
        """
        self.set_use_memory_efficient_attention_xformers(True, attention_op)

    def disable_xformers_memory_efficient_attention(self) -> None:
        r"""
        Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
        """
        self.set_use_memory_efficient_attention_xformers(False)

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        is_main_process: bool = True,
        save_function: Optional[Callable] = None,
        safe_serialization: bool = True,
        variant: Optional[str] = None,
        max_shard_size: Union[int, str] = "10GB",
        push_to_hub: bool = False,
        **kwargs,
    ):
        """
        Save a model and its configuration file to a directory so that it can be reloaded using the
        [`~models.ModelMixin.from_pretrained`] class method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save a model and its configuration file to. Will be created if it doesn't exist.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
            variant (`str`, *optional*):
                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
            max_shard_size (`int` or `str`, defaults to `"10GB"`):
                The maximum size for a checkpoint before being sharded. Checkpoint shards will then each be of a size
                lower than this. If expressed as a string, it needs to be digits followed by a unit (like `"5GB"`).
                If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
                period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
                This is to establish a common default size for this argument across different libraries in the Hugging
                Face ecosystem (`transformers`, and `accelerate`, for example).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
        weights_name = _add_variant(weights_name, variant)
        weight_name_split = weights_name.split(".")
        if len(weight_name_split) in [2, 3]:
            weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
        else:
            raise ValueError(f"Invalid {weights_name} provided.")

        os.makedirs(save_directory, exist_ok=True)

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            private = kwargs.pop("private", False)
            create_pr = kwargs.pop("create_pr", False)
            token = kwargs.pop("token", None)
            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id

        # Only save the model itself if we are using distributed training
        model_to_save = self

        # Attach architecture to the config
        # Save the config
        if is_main_process:
            model_to_save.save_config(save_directory)

        # Save the model
        state_dict = model_to_save.state_dict()

        # Split the state dict into shards
        state_dict_split = split_torch_state_dict_into_shards(
            state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
        )

        # Clean the folder from a previous save
        if is_main_process:
            for filename in os.listdir(save_directory):
                if filename in state_dict_split.filename_to_tensors.keys():
                    continue
                full_filename = os.path.join(save_directory, filename)
                if not os.path.isfile(full_filename):
                    continue
                weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
                weights_without_ext = weights_without_ext.replace("{suffix}", "")
                filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
                # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
                if (
                    filename.startswith(weights_without_ext)
                    and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
                ):
                    os.remove(full_filename)

        for filename, tensors in state_dict_split.filename_to_tensors.items():
            shard = {tensor: state_dict[tensor] for tensor in tensors}
            filepath = os.path.join(save_directory, filename)
            if safe_serialization:
                # At some point we will need to deal better with save_function (used for TPU and other distributed
                # joyfulness), but for now this is enough.
                safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
            else:
                torch.save(shard, filepath)

        if state_dict_split.is_sharded:
            index = {
                "metadata": state_dict_split.metadata,
                "weight_map": state_dict_split.tensor_to_filename,
            }
            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
            save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
            # Save the index as well
            with open(save_index_file, "w", encoding="utf-8") as f:
                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
                f.write(content)
            logger.info(
                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
                f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameter has been saved in the "
                f"index located at {save_index_file}."
            )
        else:
            path_to_weights = os.path.join(save_directory, weights_name)
            logger.info(f"Model weights saved in {path_to_weights}")

        if push_to_hub:
            # Create a new empty model card and eventually tag it
            model_card = load_or_create_model_card(repo_id, token=token)
            model_card = populate_model_card(model_card)
            model_card.save(Path(save_directory, "README.md").as_posix())

            self._upload_folder(
                save_directory,
                repo_id,
                token=token,
                commit_message=commit_message,
                create_pr=create_pr,
            )

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        r"""
        Instantiate a pretrained PyTorch model from a pretrained model configuration.

        The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
        train the model, set it back in training mode with `model.train()`.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`~ModelMixin.save_pretrained`].

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            torch_dtype (`str` or `torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
                dtype is automatically derived from the model's weights.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info (`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            from_flax (`bool`, *optional*, defaults to `False`):
                Load the model weights from a Flax checkpoint save file.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            mirror (`str`, *optional*):
                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                information.
            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
                A map that specifies where each submodule should go. It doesn't need to be defined for each
                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
                same device. Defaults to `None`, meaning that the model will be loaded on CPU.

                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
                more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
            max_memory (`Dict`, *optional*):
                A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
                each GPU and the available CPU RAM if unset.
            offload_folder (`str` or `os.PathLike`, *optional*):
                The path to offload weights if `device_map` contains the value `"disk"`.
            offload_state_dict (`bool`, *optional*):
                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
                when there is some disk offload.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                argument to `True` will raise an error.
            variant (`str`, *optional*):
                Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
                loading `from_flax`.
            use_safetensors (`bool`, *optional*, defaults to `None`):
                If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
                `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                weights. If set to `False`, `safetensors` weights are not loaded.

        <Tip>

        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
        `huggingface-cli login`. You can also activate the special
        ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
        firewalled environment.

        </Tip>

        Example:

        ```py
        from diffusers import UNet2DConditionModel

        unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
        ```

        If you get the error message below, you need to finetune the weights for your downstream task:

        ```bash
        Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
        - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
        ```
        """
        cache_dir = kwargs.pop("cache_dir", None)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        force_download = kwargs.pop("force_download", False)
        from_flax = kwargs.pop("from_flax", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        torch_dtype = kwargs.pop("torch_dtype", None)
        subfolder = kwargs.pop("subfolder", None)
        device_map = kwargs.pop("device_map", None)
        max_memory = kwargs.pop("max_memory", None)
        offload_folder = kwargs.pop("offload_folder", None)
        offload_state_dict = kwargs.pop("offload_state_dict", False)
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        if low_cpu_mem_usage and not is_accelerate_available():
            low_cpu_mem_usage = False
            logger.warning(
                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                " install accelerate\n```\n."
            )

        if device_map is not None and not is_accelerate_available():
            raise NotImplementedError(
                "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
                " `device_map=None`. You can install accelerate with `pip install accelerate`."
            )

        # Check if we can handle device_map and dispatching the weights
        if device_map is not None and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `device_map=None`."
            )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        if low_cpu_mem_usage is False and device_map is not None:
            raise ValueError(
                f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
            )

        # change device_map into a map if we passed an int, a str or a torch.device
        if isinstance(device_map, torch.device):
            device_map = {"": device_map}
        elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
            try:
                device_map = {"": torch.device(device_map)}
            except RuntimeError:
                raise ValueError(
                    "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
                    f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
                )
        elif isinstance(device_map, int):
            if device_map < 0:
                raise ValueError(
                    "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
                )
            else:
                device_map = {"": device_map}

        if device_map is not None:
            if low_cpu_mem_usage is None:
                low_cpu_mem_usage = True
            elif not low_cpu_mem_usage:
                raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")

        if low_cpu_mem_usage:
            if device_map is not None and not is_torch_version(">=", "1.10"):
                # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
                raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")

        # Load config if we don't provide a configuration
        config_path = pretrained_model_name_or_path

        user_agent = {
            "diffusers": __version__,
            "file_type": "model",
            "framework": "pytorch",
        }

        # load config
        config, unused_kwargs, commit_hash = cls.load_config(
            config_path,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            return_commit_hash=True,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            **kwargs,
        )

        # Determine if we're loading from a directory of sharded checkpoints.
        is_sharded = False
        index_file = None
        is_local = os.path.isdir(pretrained_model_name_or_path)
        index_file = _fetch_index_file(
            is_local=is_local,
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder=subfolder or "",
            use_safetensors=use_safetensors,
            cache_dir=cache_dir,
            variant=variant,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            user_agent=user_agent,
            commit_hash=commit_hash,
        )
        if index_file is not None and index_file.is_file():
            is_sharded = True

        if is_sharded and from_flax:
            raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")

        # load model
        model_file = None
        if from_flax:
            model_file = _get_model_file(
                pretrained_model_name_or_path,
                weights_name=FLAX_WEIGHTS_NAME,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                subfolder=subfolder,
                user_agent=user_agent,
                commit_hash=commit_hash,
            )
            model = cls.from_config(config, **unused_kwargs)

            # Convert the weights
            from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model

            model = load_flax_checkpoint_in_pytorch_model(model, model_file)
        else:
            if is_sharded:
                sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
                    pretrained_model_name_or_path,
                    index_file,
                    cache_dir=cache_dir,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    user_agent=user_agent,
                    revision=revision,
                    subfolder=subfolder or "",
                )

            elif use_safetensors and not is_sharded:
                try:
                    model_file = _get_model_file(
                        pretrained_model_name_or_path,
                        weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                        cache_dir=cache_dir,
                        force_download=force_download,
                        proxies=proxies,
                        local_files_only=local_files_only,
                        token=token,
                        revision=revision,
                        subfolder=subfolder,
                        user_agent=user_agent,
                        commit_hash=commit_hash,
                    )

                except IOError as e:
                    logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
                    if not allow_pickle:
                        raise
                    logger.warning(
                        "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
                    )

            if model_file is None and not is_sharded:
                model_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=_add_variant(WEIGHTS_NAME, variant),
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
                    commit_hash=commit_hash,
                )

            if low_cpu_mem_usage:
                # Instantiate model with empty weights
                with accelerate.init_empty_weights():
                    model = cls.from_config(config, **unused_kwargs)

                # if device_map is None, load the state dict and move the params from meta device to the cpu
                if device_map is None and not is_sharded:
                    param_device = "cpu"
                    state_dict = load_state_dict(model_file, variant=variant)
                    model._convert_deprecated_attention_blocks(state_dict)
                    # move the params from meta device to cpu
                    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
                    if len(missing_keys) > 0:
                        raise ValueError(
                            f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
                            f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
                            " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
                            " those weights or else make sure your checkpoint file is correct."
                        )

                    unexpected_keys = load_model_dict_into_meta(
                        model,
                        state_dict,
                        device=param_device,
                        dtype=torch_dtype,
                        model_name_or_path=pretrained_model_name_or_path,
                    )

                    if cls._keys_to_ignore_on_load_unexpected is not None:
                        for pat in cls._keys_to_ignore_on_load_unexpected:
                            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

                    if len(unexpected_keys) > 0:
                        logger.warning(
                            f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
                        )

                else:  # else let accelerate handle loading and dispatching.
                    # Load weights and dispatch according to the device_map
                    # by default the device_map is None and the weights are loaded on the CPU
                    force_hook = True
                    device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
                    if device_map is None and is_sharded:
                        # we load the parameters on the cpu
                        device_map = {"": "cpu"}
                        force_hook = False
                    try:
                        accelerate.load_checkpoint_and_dispatch(
                            model,
                            model_file if not is_sharded else index_file,
                            device_map,
                            max_memory=max_memory,
                            offload_folder=offload_folder,
                            offload_state_dict=offload_state_dict,
                            dtype=torch_dtype,
                            force_hooks=force_hook,
                            strict=True,
                        )
                    except AttributeError as e:
                        # When using accelerate loading, we do not have the ability to load the state
                        # dict and rename the weight names manually. Additionally, accelerate skips
                        # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
                        # (which look like they should be private variables?), so we can't use the standard hooks
                        # to rename parameters on load. We need to mimic the original weight names so the correct
                        # attributes are available. After we have loaded the weights, we convert the deprecated
                        # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
                        # the weights so we don't have to do this again.

                        if "'Attention' object has no attribute" in str(e):
                            logger.warning(
                                f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
                                " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
                                " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
                                " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
                                " please also re-upload it or open a PR on the original repository."
                            )
                            model._temp_convert_self_to_deprecated_attention_blocks()
                            accelerate.load_checkpoint_and_dispatch(
                                model,
                                model_file if not is_sharded else index_file,
                                device_map,
                                max_memory=max_memory,
                                offload_folder=offload_folder,
                                offload_state_dict=offload_state_dict,
                                dtype=torch_dtype,
                                force_hooks=force_hook,
                                strict=True,
                            )
                            model._undo_temp_convert_self_to_deprecated_attention_blocks()
                        else:
                            raise e

                loading_info = {
                    "missing_keys": [],
                    "unexpected_keys": [],
                    "mismatched_keys": [],
                    "error_msgs": [],
                }
            else:
                model = cls.from_config(config, **unused_kwargs)

                state_dict = load_state_dict(model_file, variant=variant)
                model._convert_deprecated_attention_blocks(state_dict)

                model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
                    model,
                    state_dict,
                    model_file,
                    pretrained_model_name_or_path,
                    ignore_mismatched_sizes=ignore_mismatched_sizes,
                )

                loading_info = {
                    "missing_keys": missing_keys,
                    "unexpected_keys": unexpected_keys,
                    "mismatched_keys": mismatched_keys,
                    "error_msgs": error_msgs,
                }

        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
            raise ValueError(
                f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
            )
        elif torch_dtype is not None:
            model = model.to(torch_dtype)

        model.register_to_config(_name_or_path=pretrained_model_name_or_path)

        # Set model in evaluation mode to deactivate DropOut modules by default
        model.eval()
        if output_loading_info:
            return model, loading_info

        return model

    @classmethod
    def _load_pretrained_model(
        cls,
        model,
        state_dict: OrderedDict,
        resolved_archive_file,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        ignore_mismatched_sizes: bool = False,
    ):
        # Retrieve missing & unexpected_keys
        model_state_dict = model.state_dict()
        loaded_keys = list(state_dict.keys())

        expected_keys = list(model_state_dict.keys())

        original_loaded_keys = loaded_keys

        missing_keys = list(set(expected_keys) - set(loaded_keys))
        unexpected_keys = list(set(loaded_keys) - set(expected_keys))

        # Make sure we are able to load base models as well as derived models (with heads)
        model_to_load = model

        def _find_mismatched_keys(
            state_dict,
            model_state_dict,
            loaded_keys,
            ignore_mismatched_sizes,
        ):
            mismatched_keys = []
            if ignore_mismatched_sizes:
                for checkpoint_key in loaded_keys:
                    model_key = checkpoint_key

                    if (
                        model_key in model_state_dict
                        and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
                    ):
                        mismatched_keys.append(
                            (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
                        )
                        del state_dict[checkpoint_key]
            return mismatched_keys

        if state_dict is not None:
            # Whole checkpoint
            mismatched_keys = _find_mismatched_keys(
                state_dict,
                model_state_dict,
                original_loaded_keys,
                ignore_mismatched_sizes,
            )
            error_msgs = _load_state_dict_into_model(model_to_load, state_dict)

        if len(error_msgs) > 0:
            error_msg = "\n\t".join(error_msgs)
            if "size mismatch" in error_msg:
                error_msg += (
                    "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
                )
            raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
                " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
                " identical (initializing a BertForSequenceClassification model from a"
                " BertForSequenceClassification model)."
            )
        else:
            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        elif len(mismatched_keys) == 0:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
                f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
                " without further training."
            )
        if len(mismatched_keys) > 0:
            mismatched_warning = "\n".join(
                [
                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                    for key, shape1, shape2 in mismatched_keys
                ]
            )
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
                " able to use it for predictions and inference."
            )

        return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs

    @classmethod
    def _get_signature_keys(cls, obj):
        parameters = inspect.signature(obj.__init__).parameters
        required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
        optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
        expected_modules = set(required_parameters.keys()) - {"self"}

        return expected_modules, optional_parameters

    # Adapted from `transformers` modeling_utils.py
    def _get_no_split_modules(self, device_map: str):
        """
        Get the modules of the model that should not be split when using device_map. We iterate through the modules to
        get the underlying `_no_split_modules`.

        Args:
            device_map (`str`):
                The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]

        Returns:
            `List[str]`: List of modules that should not be split
        """
        _no_split_modules = set()
        modules_to_check = [self]
        while len(modules_to_check) > 0:
            module = modules_to_check.pop(-1)
            # if the module does not appear in _no_split_modules, we also check the children
            if module.__class__.__name__ not in _no_split_modules:
                if isinstance(module, ModelMixin):
                    if module._no_split_modules is None:
                        raise ValueError(
                            f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
                            "class needs to implement the `_no_split_modules` attribute."
                        )
                    else:
                        _no_split_modules = _no_split_modules | set(module._no_split_modules)
                modules_to_check += list(module.children())
        return list(_no_split_modules)

    @property
    def device(self) -> torch.device:
        """
        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
        device).
        """
        return get_parameter_device(self)

    @property
    def dtype(self) -> torch.dtype:
        """
        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        return get_parameter_dtype(self)

    def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
        """
        Get number of (trainable or non-embedding) parameters in the module.

        Args:
            only_trainable (`bool`, *optional*, defaults to `False`):
                Whether or not to return only the number of trainable parameters.
            exclude_embeddings (`bool`, *optional*, defaults to `False`):
                Whether or not to return only the number of non-embedding parameters.

        Returns:
            `int`: The number of parameters.

        Example:

        ```py
        from diffusers import UNet2DConditionModel

        model_id = "runwayml/stable-diffusion-v1-5"
        unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
        unet.num_parameters(only_trainable=True)
        859520964
        ```
        """

        if exclude_embeddings:
            embedding_param_names = [
                f"{name}.weight"
                for name, module_type in self.named_modules()
                if isinstance(module_type, torch.nn.Embedding)
            ]
            non_embedding_parameters = [
                parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
            ]
            return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
        else:
            return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)

    def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
        deprecated_attention_block_paths = []

        def recursive_find_attn_block(name, module):
            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
                deprecated_attention_block_paths.append(name)

            for sub_name, sub_module in module.named_children():
                sub_name = sub_name if name == "" else f"{name}.{sub_name}"
                recursive_find_attn_block(sub_name, sub_module)

        recursive_find_attn_block("", self)

        # NOTE: we have to check if the deprecated parameters are in the state dict
        # because it is possible we are loading from a state dict that was already
        # converted

        for path in deprecated_attention_block_paths:
            # group_norm path stays the same

            # query -> to_q
            if f"{path}.query.weight" in state_dict:
                state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
            if f"{path}.query.bias" in state_dict:
                state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")

            # key -> to_k
            if f"{path}.key.weight" in state_dict:
                state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
            if f"{path}.key.bias" in state_dict:
                state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")

            # value -> to_v
            if f"{path}.value.weight" in state_dict:
                state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
            if f"{path}.value.bias" in state_dict:
                state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")

            # proj_attn -> to_out.0
            if f"{path}.proj_attn.weight" in state_dict:
                state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
            if f"{path}.proj_attn.bias" in state_dict:
                state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")

    def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
        deprecated_attention_block_modules = []

        def recursive_find_attn_block(module):
            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
                deprecated_attention_block_modules.append(module)

            for sub_module in module.children():
                recursive_find_attn_block(sub_module)

        recursive_find_attn_block(self)

        for module in deprecated_attention_block_modules:
            module.query = module.to_q
            module.key = module.to_k
            module.value = module.to_v
            module.proj_attn = module.to_out[0]

            # We don't _have_ to delete the old attributes, but it's helpful to ensure
            # that _all_ the weights are loaded into the new attributes and we're not
            # making an incorrect assumption that this model should be converted when
            # it really shouldn't be.
            del module.to_q
            del module.to_k
|
1130 |
+
del module.to_v
|
1131 |
+
del module.to_out
|
1132 |
+
|
1133 |
+
def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
|
1134 |
+
deprecated_attention_block_modules = []
|
1135 |
+
|
1136 |
+
def recursive_find_attn_block(module) -> None:
|
1137 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
1138 |
+
deprecated_attention_block_modules.append(module)
|
1139 |
+
|
1140 |
+
for sub_module in module.children():
|
1141 |
+
recursive_find_attn_block(sub_module)
|
1142 |
+
|
1143 |
+
recursive_find_attn_block(self)
|
1144 |
+
|
1145 |
+
for module in deprecated_attention_block_modules:
|
1146 |
+
module.to_q = module.query
|
1147 |
+
module.to_k = module.key
|
1148 |
+
module.to_v = module.value
|
1149 |
+
module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
|
1150 |
+
|
1151 |
+
del module.query
|
1152 |
+
del module.key
|
1153 |
+
del module.value
|
1154 |
+
del module.proj_attn
|
1155 |
+
|
1156 |
+
|
1157 |
+
class LegacyModelMixin(ModelMixin):
|
1158 |
+
r"""
|
1159 |
+
A subclass of `ModelMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
|
1160 |
+
pipeline-specific classes (like `DiTTransformer2DModel`).
|
1161 |
+
"""
|
1162 |
+
|
1163 |
+
@classmethod
|
1164 |
+
@validate_hf_hub_args
|
1165 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
1166 |
+
# To prevent dependency import problem.
|
1167 |
+
from diffusers.models.model_loading_utils import _fetch_remapped_cls_from_config
|
1168 |
+
|
1169 |
+
# Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls.
|
1170 |
+
kwargs_copy = kwargs.copy()
|
1171 |
+
|
1172 |
+
cache_dir = kwargs.pop("cache_dir", None)
|
1173 |
+
force_download = kwargs.pop("force_download", False)
|
1174 |
+
proxies = kwargs.pop("proxies", None)
|
1175 |
+
local_files_only = kwargs.pop("local_files_only", None)
|
1176 |
+
token = kwargs.pop("token", None)
|
1177 |
+
revision = kwargs.pop("revision", None)
|
1178 |
+
subfolder = kwargs.pop("subfolder", None)
|
1179 |
+
|
1180 |
+
# Load config if we don't provide a configuration
|
1181 |
+
config_path = pretrained_model_name_or_path
|
1182 |
+
|
1183 |
+
user_agent = {
|
1184 |
+
"diffusers": __version__,
|
1185 |
+
"file_type": "model",
|
1186 |
+
"framework": "pytorch",
|
1187 |
+
}
|
1188 |
+
|
1189 |
+
# load config
|
1190 |
+
config, _, _ = cls.load_config(
|
1191 |
+
config_path,
|
1192 |
+
cache_dir=cache_dir,
|
1193 |
+
return_unused_kwargs=True,
|
1194 |
+
return_commit_hash=True,
|
1195 |
+
force_download=force_download,
|
1196 |
+
proxies=proxies,
|
1197 |
+
local_files_only=local_files_only,
|
1198 |
+
token=token,
|
1199 |
+
revision=revision,
|
1200 |
+
subfolder=subfolder,
|
1201 |
+
user_agent=user_agent,
|
1202 |
+
**kwargs,
|
1203 |
+
)
|
1204 |
+
# resolve remapping
|
1205 |
+
remapped_class = _fetch_remapped_cls_from_config(config, cls)
|
1206 |
+
|
1207 |
+
return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
|
models/sampling.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/lucidrains/muse-maskgit-pytorch
|
2 |
+
|
3 |
+
import math
|
4 |
+
from functools import partial
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
|
10 |
+
def log(t, eps=1e-20):
|
11 |
+
return torch.log(t.clamp(min=eps))
|
12 |
+
|
13 |
+
|
14 |
+
def gumbel_noise(t, generator=None):
|
15 |
+
noise = torch.zeros_like(t).uniform_(0, 1, generator=generator)
|
16 |
+
return -log(-log(noise))
|
17 |
+
|
18 |
+
|
19 |
+
def gumbel_sample(t, temperature=1.0, dim=-1, generator=None):
|
20 |
+
return ((t / max(temperature, 1e-10)) + gumbel_noise(t, generator=generator)).argmax(dim=dim)
|
21 |
+
|
22 |
+
|
23 |
+
def top_k(logits, thres=0.9):
|
24 |
+
k = math.ceil((1 - thres) * logits.shape[-1])
|
25 |
+
val, ind = logits.topk(k, dim=-1)
|
26 |
+
probs = torch.full_like(logits, float("-inf"))
|
27 |
+
probs.scatter_(2, ind, val)
|
28 |
+
return probs
|
29 |
+
|
30 |
+
|
31 |
+
def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
|
32 |
+
confidence = log(probs) + temperature * gumbel_noise(probs, generator=generator)
|
33 |
+
sorted_confidence = torch.sort(confidence, dim=-1).values
|
34 |
+
cut_off = torch.gather(sorted_confidence, 1, mask_len.long())
|
35 |
+
masking = confidence < cut_off
|
36 |
+
return masking
|
37 |
+
|
38 |
+
|
39 |
+
def cosine_schedule(t):
|
40 |
+
return torch.cos(t * math.pi * 0.5)
|
41 |
+
|
42 |
+
|
43 |
+
def linear_schedule(t):
|
44 |
+
mask_ratio = 1 - t
|
45 |
+
mask_ratio = mask_ratio.clamp(min=1e-6, max=1.0)
|
46 |
+
return mask_ratio
|
47 |
+
|
48 |
+
|
49 |
+
def pow(t, method):
|
50 |
+
exponent = float(method.replace("pow", ""))
|
51 |
+
mask_ratio = 1.0 - t**exponent
|
52 |
+
mask_ratio = mask_ratio.clamp(min=1e-6, max=1.0)
|
53 |
+
return mask_ratio
|
54 |
+
|
55 |
+
|
56 |
+
def sigmoid_schedule(t, start=-3, end=3, tau=1.0, clip_min=1e-6):
|
57 |
+
for item in [t, start, end, tau]:
|
58 |
+
item = torch.tensor(item) if not torch.is_tensor(item) else item
|
59 |
+
|
60 |
+
# A gamma function based on sigmoid function.
|
61 |
+
v_start = torch.sigmoid(torch.tensor(start / tau))
|
62 |
+
v_end = torch.sigmoid(torch.tensor(end / tau))
|
63 |
+
output = torch.sigmoid((t * (end - start) + start) / tau)
|
64 |
+
output = (v_end - output) / (v_end - v_start)
|
65 |
+
return torch.clip(output, clip_min, 1.0)
|
66 |
+
|
67 |
+
|
68 |
+
def get_mask_schedule(method, **schedule_kwargs):
|
69 |
+
if method == "cosine":
|
70 |
+
return cosine_schedule
|
71 |
+
elif method == "linear":
|
72 |
+
return linear_schedule
|
73 |
+
elif "pow" in method:
|
74 |
+
return partial(pow, method=method)
|
75 |
+
elif method == "sigmoid":
|
76 |
+
return partial(sigmoid_schedule, **schedule_kwargs)
|
77 |
+
else:
|
78 |
+
raise ValueError("Unknown schedule method: {}".format(method))
|
79 |
+
|
80 |
+
def top_k_top_p_filtering(
|
81 |
+
logits: torch.Tensor,
|
82 |
+
top_k: int = 0,
|
83 |
+
top_p: float = 1.0,
|
84 |
+
filter_value: float = -float("Inf"),
|
85 |
+
min_tokens_to_keep: int = 1,
|
86 |
+
) -> torch.Tensor:
|
87 |
+
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
88 |
+
Args:
|
89 |
+
logits: logits distribution shape (batch size, vocabulary size)
|
90 |
+
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
91 |
+
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
92 |
+
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
93 |
+
Make sure we keep at least min_tokens_to_keep per batch example in the output
|
94 |
+
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
95 |
+
"""
|
96 |
+
if top_k > 0:
|
97 |
+
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
|
98 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
99 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
100 |
+
logits[indices_to_remove] = filter_value
|
101 |
+
|
102 |
+
if top_p < 1.0:
|
103 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
104 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
105 |
+
|
106 |
+
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
|
107 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
108 |
+
if min_tokens_to_keep > 1:
|
109 |
+
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
110 |
+
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
|
111 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
112 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
113 |
+
sorted_indices_to_remove[..., 0] = 0
|
114 |
+
|
115 |
+
# scatter sorted tensors to original indexing
|
116 |
+
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
117 |
+
logits[indices_to_remove] = filter_value
|
118 |
+
return logits
|
models/training_utils.py
ADDED
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import copy
|
17 |
+
import os
|
18 |
+
import random
|
19 |
+
from typing import Any, Dict, Iterable, Optional, Union
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import pandas as pd
|
23 |
+
import torch
|
24 |
+
import torch.nn.functional as F
|
25 |
+
|
26 |
+
|
27 |
+
def enable_full_determinism(seed: int):
|
28 |
+
"""
|
29 |
+
Helper function for reproducible behavior during distributed training. See
|
30 |
+
- https://pytorch.org/docs/stable/notes/randomness.html for pytorch
|
31 |
+
"""
|
32 |
+
# set seed first
|
33 |
+
set_seed(seed)
|
34 |
+
|
35 |
+
# Enable PyTorch deterministic mode. This potentially requires either the environment
|
36 |
+
# variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
|
37 |
+
# depending on the CUDA version, so we set them both here
|
38 |
+
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
39 |
+
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
40 |
+
torch.use_deterministic_algorithms(True)
|
41 |
+
|
42 |
+
# Enable CUDNN deterministic mode
|
43 |
+
torch.backends.cudnn.deterministic = True
|
44 |
+
torch.backends.cudnn.benchmark = False
|
45 |
+
|
46 |
+
|
47 |
+
def set_seed(seed: int):
|
48 |
+
"""
|
49 |
+
Args:
|
50 |
+
Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
|
51 |
+
seed (`int`): The seed to set.
|
52 |
+
"""
|
53 |
+
random.seed(seed)
|
54 |
+
np.random.seed(seed)
|
55 |
+
torch.manual_seed(seed)
|
56 |
+
torch.cuda.manual_seed_all(seed)
|
57 |
+
# ^^ safe to call this function even if cuda is not available
|
58 |
+
|
59 |
+
|
60 |
+
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
|
61 |
+
class EMA:
|
62 |
+
"""
|
63 |
+
Exponential Moving Average of models weights
|
64 |
+
"""
|
65 |
+
|
66 |
+
def __init__(
|
67 |
+
self,
|
68 |
+
parameters: Iterable[torch.nn.Parameter],
|
69 |
+
decay: float = 0.9999,
|
70 |
+
min_decay: float = 0.0,
|
71 |
+
update_after_step: int = 0,
|
72 |
+
use_ema_warmup: bool = False,
|
73 |
+
inv_gamma: Union[float, int] = 1.0,
|
74 |
+
power: Union[float, int] = 2 / 3,
|
75 |
+
model_cls: Optional[Any] = None,
|
76 |
+
model_config: Dict[str, Any] = None,
|
77 |
+
**kwargs,
|
78 |
+
):
|
79 |
+
"""
|
80 |
+
Args:
|
81 |
+
parameters (Iterable[torch.nn.Parameter]): The parameters to track.
|
82 |
+
decay (float): The decay factor for the exponential moving average.
|
83 |
+
min_decay (float): The minimum decay factor for the exponential moving average.
|
84 |
+
update_after_step (int): The number of steps to wait before starting to update the EMA weights.
|
85 |
+
use_ema_warmup (bool): Whether to use EMA warmup.
|
86 |
+
inv_gamma (float):
|
87 |
+
Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
|
88 |
+
power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
|
89 |
+
device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
|
90 |
+
weights will be stored on CPU.
|
91 |
+
|
92 |
+
@crowsonkb's notes on EMA Warmup:
|
93 |
+
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
|
94 |
+
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
|
95 |
+
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
|
96 |
+
at 215.4k steps).
|
97 |
+
"""
|
98 |
+
|
99 |
+
parameters = list(parameters)
|
100 |
+
self.shadow_params = [p.clone().detach() for p in parameters]
|
101 |
+
|
102 |
+
self.temp_stored_params = None
|
103 |
+
|
104 |
+
self.decay = decay
|
105 |
+
self.min_decay = min_decay
|
106 |
+
self.update_after_step = update_after_step
|
107 |
+
self.use_ema_warmup = use_ema_warmup
|
108 |
+
self.inv_gamma = inv_gamma
|
109 |
+
self.power = power
|
110 |
+
self.optimization_step = 0
|
111 |
+
self.cur_decay_value = None # set in `step()`
|
112 |
+
|
113 |
+
self.model_cls = model_cls
|
114 |
+
self.model_config = model_config
|
115 |
+
|
116 |
+
@classmethod
|
117 |
+
def from_pretrained(cls, path, model_cls) -> "EMA":
|
118 |
+
_, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
|
119 |
+
model = model_cls.from_pretrained(path)
|
120 |
+
|
121 |
+
ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config)
|
122 |
+
|
123 |
+
ema_model.load_state_dict(ema_kwargs)
|
124 |
+
return ema_model
|
125 |
+
|
126 |
+
def save_pretrained(self, path):
|
127 |
+
if self.model_cls is None:
|
128 |
+
raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")
|
129 |
+
|
130 |
+
if self.model_config is None:
|
131 |
+
raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")
|
132 |
+
|
133 |
+
model = self.model_cls.from_config(self.model_config)
|
134 |
+
state_dict = self.state_dict()
|
135 |
+
state_dict.pop("shadow_params", None)
|
136 |
+
|
137 |
+
model.register_to_config(**state_dict)
|
138 |
+
self.copy_to(model.parameters())
|
139 |
+
model.save_pretrained(path)
|
140 |
+
|
141 |
+
def get_decay(self, optimization_step: int) -> float:
|
142 |
+
"""
|
143 |
+
Compute the decay factor for the exponential moving average.
|
144 |
+
"""
|
145 |
+
step = max(0, optimization_step - self.update_after_step - 1)
|
146 |
+
|
147 |
+
if step <= 0:
|
148 |
+
return 0.0
|
149 |
+
|
150 |
+
if self.use_ema_warmup:
|
151 |
+
cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
|
152 |
+
else:
|
153 |
+
cur_decay_value = (1 + step) / (10 + step)
|
154 |
+
|
155 |
+
cur_decay_value = min(cur_decay_value, self.decay)
|
156 |
+
# make sure decay is not smaller than min_decay
|
157 |
+
cur_decay_value = max(cur_decay_value, self.min_decay)
|
158 |
+
return cur_decay_value
|
159 |
+
|
160 |
+
@torch.no_grad()
|
161 |
+
def step(self, parameters: Iterable[torch.nn.Parameter]):
|
162 |
+
parameters = list(parameters)
|
163 |
+
|
164 |
+
self.optimization_step += 1
|
165 |
+
|
166 |
+
# Compute the decay factor for the exponential moving average.
|
167 |
+
decay = self.get_decay(self.optimization_step)
|
168 |
+
self.cur_decay_value = decay
|
169 |
+
one_minus_decay = 1 - decay
|
170 |
+
|
171 |
+
for s_param, param in zip(self.shadow_params, parameters):
|
172 |
+
if param.requires_grad:
|
173 |
+
s_param.sub_(one_minus_decay * (s_param - param))
|
174 |
+
else:
|
175 |
+
s_param.copy_(param)
|
176 |
+
|
177 |
+
torch.cuda.empty_cache()
|
178 |
+
|
179 |
+
def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
|
180 |
+
"""
|
181 |
+
Copy current averaged parameters into given collection of parameters.
|
182 |
+
|
183 |
+
Args:
|
184 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
185 |
+
updated with the stored moving averages. If `None`, the parameters with which this
|
186 |
+
`ExponentialMovingAverage` was initialized will be used.
|
187 |
+
"""
|
188 |
+
parameters = list(parameters)
|
189 |
+
for s_param, param in zip(self.shadow_params, parameters):
|
190 |
+
param.data.copy_(s_param.to(param.device).data)
|
191 |
+
|
192 |
+
def to(self, device=None, dtype=None) -> None:
|
193 |
+
r"""Move internal buffers of the ExponentialMovingAverage to `device`.
|
194 |
+
|
195 |
+
Args:
|
196 |
+
device: like `device` argument to `torch.Tensor.to`
|
197 |
+
"""
|
198 |
+
# .to() on the tensors handles None correctly
|
199 |
+
self.shadow_params = [
|
200 |
+
p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
|
201 |
+
for p in self.shadow_params
|
202 |
+
]
|
203 |
+
|
204 |
+
def state_dict(self) -> dict:
|
205 |
+
r"""
|
206 |
+
Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
|
207 |
+
checkpointing to save the ema state dict.
|
208 |
+
"""
|
209 |
+
# Following PyTorch conventions, references to tensors are returned:
|
210 |
+
# "returns a reference to the state and not its copy!" -
|
211 |
+
# https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
|
212 |
+
return {
|
213 |
+
"decay": self.decay,
|
214 |
+
"min_decay": self.min_decay,
|
215 |
+
"optimization_step": self.optimization_step,
|
216 |
+
"update_after_step": self.update_after_step,
|
217 |
+
"use_ema_warmup": self.use_ema_warmup,
|
218 |
+
"inv_gamma": self.inv_gamma,
|
219 |
+
"power": self.power,
|
220 |
+
"shadow_params": self.shadow_params,
|
221 |
+
}
|
222 |
+
|
223 |
+
def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
|
224 |
+
r"""
|
225 |
+
Args:
|
226 |
+
Save the current parameters for restoring later.
|
227 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
228 |
+
temporarily stored.
|
229 |
+
"""
|
230 |
+
self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
|
231 |
+
|
232 |
+
def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
|
233 |
+
r"""
|
234 |
+
Args:
|
235 |
+
Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without:
|
236 |
+
affecting the original optimization process. Store the parameters before the `copy_to()` method. After
|
237 |
+
validation (or model saving), use this to restore the former parameters.
|
238 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
239 |
+
updated with the stored parameters. If `None`, the parameters with which this
|
240 |
+
`ExponentialMovingAverage` was initialized will be used.
|
241 |
+
"""
|
242 |
+
if self.temp_stored_params is None:
|
243 |
+
raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
|
244 |
+
for c_param, param in zip(self.temp_stored_params, parameters):
|
245 |
+
param.data.copy_(c_param.data)
|
246 |
+
|
247 |
+
# Better memory-wise.
|
248 |
+
self.temp_stored_params = None
|
249 |
+
|
250 |
+
def load_state_dict(self, state_dict: dict) -> None:
|
251 |
+
r"""
|
252 |
+
Args:
|
253 |
+
Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
|
254 |
+
ema state dict.
|
255 |
+
state_dict (dict): EMA state. Should be an object returned
|
256 |
+
from a call to :meth:`state_dict`.
|
257 |
+
"""
|
258 |
+
# deepcopy, to be consistent with module API
|
259 |
+
state_dict = copy.deepcopy(state_dict)
|
260 |
+
|
261 |
+
self.decay = state_dict.get("decay", self.decay)
|
262 |
+
if self.decay < 0.0 or self.decay > 1.0:
|
263 |
+
raise ValueError("Decay must be between 0 and 1")
|
264 |
+
|
265 |
+
self.min_decay = state_dict.get("min_decay", self.min_decay)
|
266 |
+
if not isinstance(self.min_decay, float):
|
267 |
+
raise ValueError("Invalid min_decay")
|
268 |
+
|
269 |
+
self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
|
270 |
+
if not isinstance(self.optimization_step, int):
|
271 |
+
raise ValueError("Invalid optimization_step")
|
272 |
+
|
273 |
+
self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
|
274 |
+
if not isinstance(self.update_after_step, int):
|
275 |
+
raise ValueError("Invalid update_after_step")
|
276 |
+
|
277 |
+
self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
|
278 |
+
if not isinstance(self.use_ema_warmup, bool):
|
279 |
+
raise ValueError("Invalid use_ema_warmup")
|
280 |
+
|
281 |
+
self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
|
282 |
+
if not isinstance(self.inv_gamma, (float, int)):
|
283 |
+
raise ValueError("Invalid inv_gamma")
|
284 |
+
|
285 |
+
self.power = state_dict.get("power", self.power)
|
286 |
+
if not isinstance(self.power, (float, int)):
|
287 |
+
raise ValueError("Invalid power")
|
288 |
+
|
289 |
+
shadow_params = state_dict.get("shadow_params", None)
|
290 |
+
if shadow_params is not None:
|
291 |
+
self.shadow_params = shadow_params
|
292 |
+
if not isinstance(self.shadow_params, list):
|
293 |
+
raise ValueError("shadow_params must be a list")
|
294 |
+
if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
|
295 |
+
raise ValueError("shadow_params must all be Tensors")
|
296 |
+
|
297 |
+
|
298 |
+
# calculates entropy over each pixel distribution
|
299 |
+
def pixel_entropy_per_percent_masked_bucket(logits, input_ids, mask_id):
|
300 |
+
# only calculated entropy over image tokens that were masked in the original image
|
301 |
+
masked_tokens = input_ids == mask_id
|
302 |
+
num_masked_pixels = masked_tokens.sum(-1)
|
303 |
+
|
304 |
+
probs = F.softmax(logits, dim=-1)
|
305 |
+
log_probs = F.log_softmax(logits, dim=-1)
|
306 |
+
|
307 |
+
entropy_per_pixel = -((probs * log_probs).sum(-1))
|
308 |
+
|
309 |
+
# the predictions for non-masked aren't used, so set their entropies to zero
|
310 |
+
entropy_per_pixel[~masked_tokens] = 0
|
311 |
+
|
312 |
+
entropy_per_image_numerator = entropy_per_pixel.sum(-1)
|
313 |
+
entropy_per_image = entropy_per_image_numerator / num_masked_pixels
|
314 |
+
|
315 |
+
total_buckets = 10
|
316 |
+
masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)
|
317 |
+
|
318 |
+
entropy_by_masked_bucket = average_by_buckets(entropy_per_image, masked_buckets, total_buckets)
|
319 |
+
|
320 |
+
return entropy_by_masked_bucket
|
321 |
+
|
322 |
+
|
323 |
+
# calculates entropy over the averaged distribution of pixels for the whole image
|
324 |
+
def image_entropy_per_percent_masked_bucket(logits, input_ids, mask_id):
|
325 |
+
# only calculated entropy over image tokens that were masked in the original image
|
326 |
+
masked_tokens = input_ids == mask_id
|
327 |
+
num_masked_pixels = masked_tokens.sum(-1, keepdim=True)
|
328 |
+
|
329 |
+
pixel_probs = F.softmax(logits, dim=-1)
|
330 |
+
pixel_probs[~masked_tokens] = 0
|
331 |
+
image_probs_numerator = pixel_probs.sum(-2)
|
332 |
+
image_probs = image_probs_numerator / num_masked_pixels
|
333 |
+
|
334 |
+
image_log_probs = image_probs.log()
|
335 |
+
|
336 |
+
entropy_per_image = -((image_probs * image_log_probs).sum(-1))
|
337 |
+
|
338 |
+
total_buckets = 10
|
339 |
+
masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)
|
340 |
+
|
341 |
+
entropy_by_masked_bucket = average_by_buckets(entropy_per_image, masked_buckets, total_buckets)
|
342 |
+
|
343 |
+
return entropy_by_masked_bucket
|
344 |
+
|
345 |
+
|
346 |
+
def cross_entropy_per_percent_masked_bucket(logits, labels, input_ids, mask_id, output_size, label_smoothing):
|
347 |
+
cross_entropy_per_image = F.cross_entropy(
|
348 |
+
logits.view(-1, output_size),
|
349 |
+
labels.view(-1),
|
350 |
+
ignore_index=-100,
|
351 |
+
label_smoothing=label_smoothing,
|
352 |
+
reduction="none",
|
353 |
+
)
|
354 |
+
|
355 |
+
total_buckets = 10
|
356 |
+
masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)
|
357 |
+
|
358 |
+
cross_entropy_by_percent_masked_bucket = average_by_buckets(cross_entropy_per_image, masked_buckets, total_buckets)
|
359 |
+
|
360 |
+
return cross_entropy_by_percent_masked_bucket
|
361 |
+
|
362 |
+
|
363 |
+
def token_probability_distributions_per_percent_masked_bucket(logits, input_ids, mask_id):
|
364 |
+
probs = F.softmax(logits, dim=-1)
|
365 |
+
|
366 |
+
total_buckets = 10
|
367 |
+
masked_buckets = input_ids_to_masked_buckets(input_ids, mask_id, total_buckets)
|
368 |
+
|
369 |
+
data = []
|
370 |
+
|
371 |
+
for bucket_idx in range(total_buckets):
|
372 |
+
indices_for_bucket = masked_buckets[masked_buckets == bucket_idx]
|
373 |
+
|
374 |
+
# It's ok if none were noised in the range of this bucket. This
|
375 |
+
# function will be called for a later training step where it's likely
|
376 |
+
# there will be an element noised in the range.
|
377 |
+
if indices_for_bucket.shape[0] == 0:
|
378 |
+
continue
|
379 |
+
|
380 |
+
index_for_bucket = indices_for_bucket[0]
|
381 |
+
|
382 |
+
image_probs = probs[index_for_bucket]
|
383 |
+
|
384 |
+
# find the index of a masked pixel for the image
|
385 |
+
input_ids_for_image = input_ids[index_for_bucket]
|
386 |
+
masked_pixels_probs = image_probs[input_ids_for_image == mask_id]
|
387 |
+
|
388 |
+
masked_pixel_probs = masked_pixels_probs[0]
|
389 |
+
|
390 |
+
masked_pixel_probs = masked_pixel_probs.cpu().numpy()
|
391 |
+
|
392 |
+
for masked_pixel_prob in masked_pixel_probs:
|
393 |
+
data.append({"bucket": bucket_idx, "masked_pixel_prob": masked_pixel_prob})
|
394 |
+
|
395 |
+
df = pd.DataFrame(data)
|
396 |
+
|
397 |
+
return df
|
398 |
+
|
399 |
+
|
400 |
+
def average_by_buckets(values, masked_buckets, total_buckets):
|
401 |
+
unique_buckets, bucket_counts = masked_buckets.unique(dim=0, return_counts=True)
|
402 |
+
|
403 |
+
numerator = torch.zeros(total_buckets, device=values.device)
|
404 |
+
|
405 |
+
numerator.scatter_add_(0, masked_buckets, values)
|
406 |
+
|
407 |
+
# default value is one because the buckets for which there aren't
|
408 |
+
# any values will have a numerator of zero. So we just need to not divide
|
409 |
+
# by zero.
|
410 |
+
denominator = torch.ones(total_buckets, device=values.device, dtype=torch.long)
|
411 |
+
denominator[unique_buckets] = bucket_counts
|
412 |
+
|
413 |
+
averaged_by_buckets = numerator / denominator
|
414 |
+
|
415 |
+
return averaged_by_buckets
|
416 |
+
|
417 |
+
|
418 |
+
def input_ids_to_masked_buckets(input_ids, mask_id, total_buckets=10):
|
419 |
+
assert total_buckets == 10
|
420 |
+
|
421 |
+
masked_percent = (input_ids == mask_id).sum(-1) / input_ids.shape[-1]
|
422 |
+
|
423 |
+
# we do not formally use timesteps to noise images. Instead, we mask a percent
|
424 |
+
# of the pixels. We don't want to log entropy for every mask percent between 0 and 1,
|
425 |
+
# and we also want to track how the entropy evolves over time w/in a range of mask
|
426 |
+
# percents that should have similar entropy. So we bucket the masked percents into a
|
427 |
+
# fixed number of buckets
|
428 |
+
|
429 |
+
# we could generalize this later if needed but for now, let's just assume a fixed
|
430 |
+
# number of 10 buckets.
|
431 |
+
|
432 |
+
# How this maps to a bucket index:
|
433 |
+
# (mask) * bucket_index +
|
434 |
+
# (mask_1) * bucket_index_1
|
435 |
+
#
|
436 |
+
# -> Where the mask is true will be set to the expected bucket index,
|
437 |
+
# where the mask is false will be set to 0.
|
438 |
+
#
|
439 |
+
# Given the probabilities are between 0 and 1, each masked_percent will get mapped
|
440 |
+
# to a timestep by one and only one of the masks.
|
441 |
+
|
442 |
+
masked_buckets = (
|
443 |
+
((0 < masked_percent) & (masked_percent <= 0.1)) * 0
|
444 |
+
+ ((0.1 < masked_percent) & (masked_percent <= 0.2)) * 1
|
445 |
+
+ ((0.2 < masked_percent) & (masked_percent <= 0.3)) * 2
|
446 |
+
+ ((0.3 < masked_percent) & (masked_percent <= 0.4)) * 3
|
447 |
+
+ ((0.4 < masked_percent) & (masked_percent <= 0.5)) * 4
|
448 |
+
+ ((0.5 < masked_percent) & (masked_percent <= 0.6)) * 5
|
449 |
+
+ ((0.6 < masked_percent) & (masked_percent <= 0.7)) * 6
|
450 |
+
+ ((0.7 < masked_percent) & (masked_percent <= 0.8)) * 7
|
451 |
+
+ ((0.8 < masked_percent) & (masked_percent <= 0.9)) * 8
|
452 |
+
+ ((0.9 < masked_percent) & (masked_percent <= 1.0)) * 9
|
453 |
+
)
|
454 |
+
|
455 |
+
return masked_buckets
|
training/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# from .mmada_grpo_trainer import DiffusionGRPOTrainer
|
training/prompting_utils.py
ADDED
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2025 MMaDA team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
reserved_token_mapping = {
|
18 |
+
'<|soi|>': 126084,
|
19 |
+
'<|eoi|>': 126085,
|
20 |
+
'<|sov|>': 126086,
|
21 |
+
'<|eov|>': 126087,
|
22 |
+
'<|t2i|>': 126088,
|
23 |
+
'<|mmu|>': 126089,
|
24 |
+
'<|t2v|>': 126090,
|
25 |
+
'<|v2v|>': 126091,
|
26 |
+
'<|lvg|>': 126092,
|
27 |
+
'[iPAD]': 126093,
|
28 |
+
'<|r2i|>': 126094,
|
29 |
+
}
|
30 |
+
|
31 |
+
|
32 |
+
import torch
|
33 |
+
class UniversalPrompting():
|
34 |
+
def __init__(self, text_tokenizer,
|
35 |
+
special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
|
36 |
+
max_text_len=8000, max_seq_len=377, ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=False):
|
37 |
+
"""
|
38 |
+
:param text_tokenizer: original text tokenizer
|
39 |
+
"""
|
40 |
+
if not use_reserved_token:
|
41 |
+
self.text_tokenizer = text_tokenizer
|
42 |
+
self.text_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
43 |
+
self.text_tokenizer.add_tokens(list(special_tokens))
|
44 |
+
self.sptids_dict = {token: torch.tensor(self.text_tokenizer.convert_tokens_to_ids([token])) for token in
|
45 |
+
special_tokens}
|
46 |
+
self.sptids_dict['<|sot|>'] = torch.tensor([self.text_tokenizer.bos_token_id])
|
47 |
+
self.sptids_dict['<|eot|>'] = torch.tensor([self.text_tokenizer.eos_token_id])
|
48 |
+
self.sptids_dict['<|pad|>'] = torch.tensor([self.text_tokenizer.pad_token_id])
|
49 |
+
else:
|
50 |
+
self.text_tokenizer = text_tokenizer
|
51 |
+
self.sptids_dict = {}
|
52 |
+
for token, token_id in reserved_token_mapping.items():
|
53 |
+
self.sptids_dict[token] = torch.tensor([token_id])
|
54 |
+
self.sptids_dict['<|sot|>'] = torch.tensor([self.text_tokenizer.bos_token_id])
|
55 |
+
self.sptids_dict['<|eot|>'] = torch.tensor([self.text_tokenizer.eos_token_id])
|
56 |
+
end_header_tokens = self.text_tokenizer.convert_tokens_to_ids(['<|end_header_id|>'])
|
57 |
+
if end_header_tokens and len(end_header_tokens) > 0 and end_header_tokens[0]:
|
58 |
+
self.sptids_dict['<|end_header_id|>'] = torch.tensor(end_header_tokens)
|
59 |
+
self.sptids_dict['<|eot_id|>'] = torch.tensor(self.text_tokenizer.convert_tokens_to_ids(['<|eot_id|>']))
|
60 |
+
self.sptids_dict['<|start_header_id|>'] = torch.tensor(self.text_tokenizer.convert_tokens_to_ids(['<|start_header_id|>']))
|
61 |
+
else:
|
62 |
+
special_tokens_dict = {
|
63 |
+
'additional_special_tokens': [
|
64 |
+
'<|start_header_id|>',
|
65 |
+
'<|end_header_id|>',
|
66 |
+
'<|eot_id|>'
|
67 |
+
]
|
68 |
+
}
|
69 |
+
num_added = self.text_tokenizer.add_special_tokens(special_tokens_dict)
|
70 |
+
new_token_id = self.text_tokenizer.convert_tokens_to_ids(['<|end_header_id|>'])
|
71 |
+
self.sptids_dict['<|end_header_id|>'] = torch.tensor(new_token_id)
|
72 |
+
self.sptids_dict['<|eot_id|>'] = torch.tensor(self.text_tokenizer.convert_tokens_to_ids(['<|eot_id|>']))
|
73 |
+
self.sptids_dict['<|start_header_id|>'] = torch.tensor(self.text_tokenizer.convert_tokens_to_ids(['<|start_header_id|>']))
|
74 |
+
# plus 1 because at this time we add a task token before
|
75 |
+
print(f"self.sptids_dict: {self.sptids_dict}")
|
76 |
+
self.max_text_len = max_text_len + 1
|
77 |
+
self.pad_id = reserved_token_mapping['[iPAD]']
|
78 |
+
self.ignore_id = ignore_id
|
79 |
+
self.cond_dropout_prob = cond_dropout_prob
|
80 |
+
|
81 |
+
def t2i_prompt(self, text_ids, image_ids, labels):
|
82 |
+
|
83 |
+
device = image_ids.device
|
84 |
+
sequence_ids = []
|
85 |
+
attention_masks = []
|
86 |
+
label_ids = []
|
87 |
+
probs = torch.rand(len(text_ids))
|
88 |
+
for i in range(len(text_ids)):
|
89 |
+
|
90 |
+
if len(text_ids[i]) == 0:
|
91 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id]
|
92 |
+
elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
|
93 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
|
94 |
+
|
95 |
+
temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
|
96 |
+
|
97 |
+
# randomly dropout text condition
|
98 |
+
if probs[i] < self.cond_dropout_prob:
|
99 |
+
temp_ids = [int(self.sptids_dict['<|t2i|>']), self.text_tokenizer.bos_token_id, self.text_tokenizer.eos_token_id]
|
100 |
+
|
101 |
+
if self.max_text_len >= len(temp_ids):
|
102 |
+
old_len = len(temp_ids)
|
103 |
+
temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
|
104 |
+
temp_masks = [0] * (self.max_text_len - old_len) + [1] * (old_len + image_ids.shape[-1] + 2)
|
105 |
+
else:
|
106 |
+
# should add the eos token
|
107 |
+
temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
|
108 |
+
temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 2) # +2 for two special tokens
|
109 |
+
# prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
|
110 |
+
temp_label_ids = torch.cat([
|
111 |
+
# should we predict text tokens when doing image reconstruction?
|
112 |
+
torch.tensor(temp_ids).to(device),
|
113 |
+
self.sptids_dict['<|soi|>'].to(device),
|
114 |
+
labels[i],
|
115 |
+
self.sptids_dict['<|eoi|>'].to(device)
|
116 |
+
], dim=0)
|
117 |
+
|
118 |
+
temp_label_ids = torch.where(temp_label_ids == self.pad_id, self.ignore_id, temp_label_ids)
|
119 |
+
|
120 |
+
temp_ids = torch.cat([
|
121 |
+
torch.tensor(temp_ids).to(device),
|
122 |
+
self.sptids_dict['<|soi|>'].to(device),
|
123 |
+
image_ids[i],
|
124 |
+
self.sptids_dict['<|eoi|>'].to(device)
|
125 |
+
], dim=0)
|
126 |
+
|
127 |
+
# sequence_ids: [pad]...[pad] <|t2i|> <bos> text_1 ... text_n <eos> <|soi|> image_1 ... image_m <|eoi|>
|
128 |
+
temp_masks = torch.tensor(temp_masks).to(device)
|
129 |
+
sequence_ids.append(temp_ids.unsqueeze(0))
|
130 |
+
attention_masks.append(temp_masks.unsqueeze(0))
|
131 |
+
label_ids.append(temp_label_ids.unsqueeze(0))
|
132 |
+
|
133 |
+
return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)
|
134 |
+
|
135 |
+
def t2i_gen_prompt(self, text_ids, image_ids):
|
136 |
+
|
137 |
+
device = image_ids.device
|
138 |
+
sequence_ids = []
|
139 |
+
attention_masks = []
|
140 |
+
for i in range(len(text_ids)):
|
141 |
+
if len(text_ids[i]) == 0:
|
142 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id]
|
143 |
+
elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
|
144 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
|
145 |
+
# note that, llama3 tokenizer automatically add the bot token at first but without eot
|
146 |
+
temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
|
147 |
+
if self.max_text_len >= len(temp_ids):
|
148 |
+
old_len = len(temp_ids)
|
149 |
+
temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
|
150 |
+
temp_masks = [0] * (self.max_text_len - old_len) + [1] * (old_len + image_ids.shape[-1] + 2)
|
151 |
+
else:
|
152 |
+
# should add the eos token
|
153 |
+
temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
|
154 |
+
temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 2) # +2 for two special tokens
|
155 |
+
|
156 |
+
# prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
|
157 |
+
temp_ids = torch.cat([
|
158 |
+
torch.tensor(temp_ids).to(device),
|
159 |
+
self.sptids_dict['<|soi|>'].to(device),
|
160 |
+
image_ids[i],
|
161 |
+
self.sptids_dict['<|eoi|>'].to(device)
|
162 |
+
], dim=0)
|
163 |
+
|
164 |
+
temp_masks = torch.tensor(temp_masks).to(device)
|
165 |
+
sequence_ids.append(temp_ids.unsqueeze(0))
|
166 |
+
attention_masks.append(temp_masks.unsqueeze(0))
|
167 |
+
|
168 |
+
return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0)
|
169 |
+
|
170 |
+
# language modeling
|
171 |
+
def lm_prompt(self, text_ids, max_seq_len):
|
172 |
+
sequence_ids = []
|
173 |
+
attention_masks = []
|
174 |
+
label_ids = []
|
175 |
+
for i in range(len(text_ids)):
|
176 |
+
if len(text_ids[i]) == 0:
|
177 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id]
|
178 |
+
elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
|
179 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
|
180 |
+
|
181 |
+
temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id]
|
182 |
+
|
183 |
+
if max_seq_len >= len(temp_ids):
|
184 |
+
temp_labels_ids = temp_ids + [self.text_tokenizer.eos_token_id] * (max_seq_len - len(temp_ids))
|
185 |
+
temp_ids = temp_ids + [self.text_tokenizer.eos_token_id] * (max_seq_len - len(temp_ids))
|
186 |
+
temp_masks = [1] * len(temp_ids) + [0] * (max_seq_len - len(temp_ids))
|
187 |
+
else:
|
188 |
+
# In language modeling, we only process text tokens. We do not add the eos token if the text length
|
189 |
+
# exceeds the max sequence length
|
190 |
+
temp_labels_ids = temp_ids[:max_seq_len]
|
191 |
+
temp_ids = temp_ids[:max_seq_len]
|
192 |
+
temp_masks = [1] * len(temp_ids) # +2 for two special tokens
|
193 |
+
|
194 |
+
# prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
|
195 |
+
temp_ids = torch.tensor(temp_ids)
|
196 |
+
temp_masks = torch.tensor(temp_masks)
|
197 |
+
temp_labels_ids = torch.tensor(temp_labels_ids)
|
198 |
+
sequence_ids.append(temp_ids.unsqueeze(0))
|
199 |
+
attention_masks.append(temp_masks.unsqueeze(0))
|
200 |
+
label_ids.append(temp_labels_ids.unsqueeze(0))
|
201 |
+
|
202 |
+
# input_ids, masks, labels
|
203 |
+
return torch.cat(sequence_ids, dim=0), torch.cat(attention_masks, dim=0), torch.cat(label_ids, dim=0)
|
204 |
+
|
205 |
+
# language modeling
|
206 |
+
def lm_chat_prompt(self, text_ids, max_seq_len):
|
207 |
+
sequence_ids = []
|
208 |
+
prompt_masks = []
|
209 |
+
label_ids = []
|
210 |
+
|
211 |
+
for i in range(len(text_ids)):
|
212 |
+
if len(text_ids[i]) == 0:
|
213 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id]
|
214 |
+
elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
|
215 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
|
216 |
+
|
217 |
+
temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id]
|
218 |
+
|
219 |
+
if max_seq_len >= len(temp_ids):
|
220 |
+
temp_labels_ids = temp_ids + [self.text_tokenizer.eos_token_id] * (max_seq_len - len(temp_ids))
|
221 |
+
temp_ids = temp_ids + [self.text_tokenizer.eos_token_id] * (max_seq_len - len(temp_ids))
|
222 |
+
else:
|
223 |
+
# In language modeling, we only process text tokens. We do not add the eos token if the text length
|
224 |
+
# exceeds the max sequence length
|
225 |
+
temp_labels_ids = temp_ids[:max_seq_len]
|
226 |
+
temp_ids = temp_ids[:max_seq_len]
|
227 |
+
|
228 |
+
end_header_id = int(self.sptids_dict['<|end_header_id|>'])
|
229 |
+
end_header_pos = -1
|
230 |
+
for pos in range(len(temp_ids) - 1, -1, -1): # 尝试从文本序列中寻找<|end_header_id|>
|
231 |
+
if temp_ids[pos] == end_header_id:
|
232 |
+
end_header_pos = pos
|
233 |
+
break
|
234 |
+
if end_header_pos != -1:
|
235 |
+
prompt_length = end_header_pos + 1
|
236 |
+
else:
|
237 |
+
prompt_length = 0
|
238 |
+
temp_masks = [1] * prompt_length + [0] * (len(temp_ids) - prompt_length)
|
239 |
+
|
240 |
+
# prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
|
241 |
+
temp_ids = torch.tensor(temp_ids)
|
242 |
+
temp_masks = torch.tensor(temp_masks)
|
243 |
+
temp_labels_ids = torch.tensor(temp_labels_ids)
|
244 |
+
sequence_ids.append(temp_ids.unsqueeze(0))
|
245 |
+
prompt_masks.append(temp_masks.unsqueeze(0))
|
246 |
+
label_ids.append(temp_labels_ids.unsqueeze(0))
|
247 |
+
|
248 |
+
# input_ids, masks, labels
|
249 |
+
return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0), torch.cat(label_ids, dim=0)
|
250 |
+
|
251 |
+
def mmu_prompt(self, image_ids, text_ids):
|
252 |
+
device = image_ids.device
|
253 |
+
sequence_ids = []
|
254 |
+
prompt_masks = []
|
255 |
+
label_ids = []
|
256 |
+
max_text_len = self.max_text_len - 1
|
257 |
+
for i in range(len(text_ids)):
|
258 |
+
# note that, llama3 tokenizer automatically add the bot token at first but without eot
|
259 |
+
# for empty list []
|
260 |
+
|
261 |
+
if len(text_ids[i]) == 0:
|
262 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id]
|
263 |
+
elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
|
264 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
|
265 |
+
|
266 |
+
temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id]
|
267 |
+
|
268 |
+
if max_text_len >= len(temp_ids):
|
269 |
+
# minus 1 because task token was prepended to the former image tokens
|
270 |
+
temp_ids = temp_ids + [self.text_tokenizer.eos_token_id] * (max_text_len - len(temp_ids))
|
271 |
+
temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) + [0] * (max_text_len - len(temp_ids))
|
272 |
+
else:
|
273 |
+
# should add the eos token
|
274 |
+
temp_ids = temp_ids[:max_text_len - 1] + [self.text_tokenizer.eos_token_id]
|
275 |
+
temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) # +2 for two special tokens
|
276 |
+
|
277 |
+
# prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
|
278 |
+
temp_label_ids = torch.cat([
|
279 |
+
torch.tensor([self.ignore_id]).to(device),
|
280 |
+
torch.tensor([self.ignore_id]).to(device),
|
281 |
+
torch.ones_like(image_ids[i]) * self.ignore_id,
|
282 |
+
torch.tensor([self.ignore_id]).to(device),
|
283 |
+
torch.tensor(temp_ids).to(device),
|
284 |
+
], dim=0)
|
285 |
+
|
286 |
+
temp_label_ids = torch.where(temp_label_ids == self.pad_id, self.ignore_id, temp_label_ids)
|
287 |
+
|
288 |
+
return_temp_ids = torch.cat([
|
289 |
+
self.sptids_dict['<|mmu|>'].to(device), # task token
|
290 |
+
self.sptids_dict['<|soi|>'].to(device),
|
291 |
+
image_ids[i],
|
292 |
+
self.sptids_dict['<|eoi|>'].to(device),
|
293 |
+
torch.tensor(temp_ids).to(device),
|
294 |
+
], dim=0)
|
295 |
+
end_header_id = int(self.sptids_dict['<|end_header_id|>'])
|
296 |
+
end_header_pos = -1
|
297 |
+
for pos in range(len(temp_ids) - 1, -1, -1):
|
298 |
+
if temp_ids[pos] == end_header_id:
|
299 |
+
end_header_pos = pos
|
300 |
+
break
|
301 |
+
if end_header_pos != -1:
|
302 |
+
prompt_length = len(return_temp_ids) - len(temp_ids) + end_header_pos + 1
|
303 |
+
else:
|
304 |
+
prompt_length = len(return_temp_ids) - len(temp_ids)
|
305 |
+
predict_length = len(return_temp_ids) - prompt_length
|
306 |
+
prompt_mask = [1] * prompt_length + [0] * predict_length
|
307 |
+
prompt_mask = torch.tensor(prompt_mask).to(device)
|
308 |
+
sequence_ids.append(return_temp_ids.unsqueeze(0))
|
309 |
+
prompt_masks.append(prompt_mask.unsqueeze(0))
|
310 |
+
label_ids.append(temp_label_ids.unsqueeze(0))
|
311 |
+
|
312 |
+
return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0), torch.cat(label_ids, dim=0)
|
313 |
+
|
314 |
+
def mmu_gen_prompt(self, image_ids, text_ids):
|
315 |
+
device = image_ids.device
|
316 |
+
sequence_ids = []
|
317 |
+
prompt_masks = []
|
318 |
+
max_text_len = self.max_text_len - 1
|
319 |
+
for i in range(len(text_ids)):
|
320 |
+
|
321 |
+
if len(text_ids[i]) == 0:
|
322 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id]
|
323 |
+
elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
|
324 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
|
325 |
+
|
326 |
+
temp_ids = text_ids[i] + [self.text_tokenizer.eos_token_id]
|
327 |
+
|
328 |
+
if max_text_len >= len(temp_ids):
|
329 |
+
# minus 1 because task token was prepended to the former image tokens
|
330 |
+
temp_ids = temp_ids + [self.text_tokenizer.eos_token_id] * (max_text_len - len(temp_ids))
|
331 |
+
else:
|
332 |
+
# should add the eos token
|
333 |
+
temp_ids = temp_ids[:max_text_len - 1] + [self.text_tokenizer.eos_token_id]
|
334 |
+
|
335 |
+
# print(f"mmu temp_ids: {temp_ids}")
|
336 |
+
return_temp_ids = torch.cat([
|
337 |
+
self.sptids_dict['<|mmu|>'].to(device), # task token
|
338 |
+
self.sptids_dict['<|soi|>'].to(device),
|
339 |
+
image_ids[i],
|
340 |
+
self.sptids_dict['<|eoi|>'].to(device),
|
341 |
+
torch.tensor(temp_ids).to(device),
|
342 |
+
], dim=0)
|
343 |
+
|
344 |
+
end_header_id = int(self.sptids_dict['<|end_header_id|>'])
|
345 |
+
end_header_pos = -1
|
346 |
+
for pos in range(len(temp_ids) - 1, -1, -1):
|
347 |
+
if temp_ids[pos] == end_header_id:
|
348 |
+
end_header_pos = pos
|
349 |
+
break
|
350 |
+
if end_header_pos != -1:
|
351 |
+
prompt_length = len(return_temp_ids) - len(temp_ids) + end_header_pos + 1
|
352 |
+
else:
|
353 |
+
prompt_length = len(return_temp_ids) - len(temp_ids)
|
354 |
+
predict_length = len(temp_ids) - prompt_length
|
355 |
+
print(f"prompt_length: {prompt_length}, predict_length: {predict_length}, all length: {len(return_temp_ids)}, {return_temp_ids[-predict_length:]}")
|
356 |
+
prompt_mask = [1] * prompt_length + [0] * predict_length
|
357 |
+
prompt_mask = torch.tensor(prompt_mask).to(device)
|
358 |
+
sequence_ids.append(return_temp_ids.unsqueeze(0))
|
359 |
+
prompt_masks.append(prompt_mask.unsqueeze(0))
|
360 |
+
return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0)
|
361 |
+
|
362 |
+
def r2i_prompt(self, image_ids, text_ids):
|
363 |
+
device = image_ids.device
|
364 |
+
sequence_ids = []
|
365 |
+
prompt_masks = []
|
366 |
+
label_ids = []
|
367 |
+
r2i_id = int(self.sptids_dict['<|r2i|>'])
|
368 |
+
soi_id = int(self.sptids_dict['<|soi|>'])
|
369 |
+
eoi_id = int(self.sptids_dict['<|eoi|>'])
|
370 |
+
max_text_len = self.max_text_len - 1 # 512,include BOS text EOS
|
371 |
+
for i in range(len(text_ids)):
|
372 |
+
# note that, llama3 tokenizer automatically add the bot token at first but without eot
|
373 |
+
# for empty list []
|
374 |
+
if len(text_ids[i]) == 0:
|
375 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id]
|
376 |
+
elif text_ids[i][0]!= self.text_tokenizer.bos_token_id:
|
377 |
+
text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
|
378 |
+
text_ids_with_bos_eos = text_ids[i] + [self.text_tokenizer.eos_token_id]
|
379 |
+
if max_text_len >= len(text_ids_with_bos_eos):
|
380 |
+
# minus 1 because task token was prepended to the former image tokens
|
381 |
+
text_ids_full_len = text_ids_with_bos_eos + [self.text_tokenizer.eos_token_id] * (max_text_len - len(text_ids_with_bos_eos))
|
382 |
+
else:
|
383 |
+
# should add the eos token
|
384 |
+
text_ids_full_len = text_ids_with_bos_eos[:max_text_len - 1] + [self.text_tokenizer.eos_token_id]
|
385 |
+
|
386 |
+
sequence_ids.append(torch.cat([
|
387 |
+
torch.tensor([r2i_id]).to(device), # task token
|
388 |
+
torch.tensor(text_ids_full_len).to(device),
|
389 |
+
torch.tensor([soi_id]).to(device),
|
390 |
+
image_ids[i],
|
391 |
+
torch.tensor([eoi_id]).to(device),
|
392 |
+
], dim=0).unsqueeze(0))
|
393 |
+
|
394 |
+
end_header_id = int(self.sptids_dict['<|end_header_id|>'])
|
395 |
+
end_header_pos = -1
|
396 |
+
for pos in range(len(text_ids_full_len) - 1, -1, -1):
|
397 |
+
if text_ids_full_len[pos] == end_header_id:
|
398 |
+
end_header_pos = pos
|
399 |
+
break
|
400 |
+
prompt_mask = torch.zeros(sequence_ids[i].size(1)).to(device)
|
401 |
+
prompt_mask[0] = 1 # task_id
|
402 |
+
if end_header_pos != -1:
|
403 |
+
prompt_mask[1:end_header_pos+2] = 1
|
404 |
+
else:
|
405 |
+
prompt_mask[1:len(text_ids_full_len)+1] = 1
|
406 |
+
prompt_mask[len(text_ids_full_len)+1] = 1
|
407 |
+
prompt_mask[len(text_ids_full_len)+2+len(image_ids[i])] = 1
|
408 |
+
prompt_masks.append(prompt_mask.unsqueeze(0))
|
409 |
+
|
410 |
+
return torch.cat(sequence_ids, dim=0), torch.cat(prompt_masks, dim=0), torch.cat(sequence_ids, dim=0)
|
411 |
+
|
412 |
+
|
413 |
+
|
414 |
+
def mask_prompt(self):
|
415 |
+
pass
|
416 |
+
|
417 |
+
def __call__(self, input, task, padding=True, config=None):
|
418 |
+
"""
|
419 |
+
input (tuple) : data pairs contain text(str), image(tensor), or videos(tensor).
|
420 |
+
task (str) : a flag indicates the current task.
|
421 |
+
"""
|
422 |
+
if task == "t2i":
|
423 |
+
text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
|
424 |
+
image_ids = input[1] # (B, #tokens)
|
425 |
+
sequence_ids_with_masks = self.t2i_prompt(text_ids, image_ids, input[2])
|
426 |
+
|
427 |
+
elif task == "t2v":
|
428 |
+
text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
|
429 |
+
image_ids = input[1] # (B, #tokens)
|
430 |
+
sequence_ids_with_masks = self.t2v_prompt(text_ids, image_ids, input[2])
|
431 |
+
|
432 |
+
elif task == "t2i_plus_lm":
|
433 |
+
text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
|
434 |
+
image_ids = input[1] # (B, #tokens)
|
435 |
+
sequence_ids_with_masks = self.t2i_prompt(text_ids[:config.training.batch_size], image_ids,
|
436 |
+
input[2])
|
437 |
+
sequence_ids_with_masks_lm = self.lm_prompt(text_ids[config.training.batch_size:], input[3])
|
438 |
+
return sequence_ids_with_masks, sequence_ids_with_masks_lm
|
439 |
+
|
440 |
+
elif task == "t2i_gen":
|
441 |
+
text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
|
442 |
+
image_ids = input[1] # (B, #tokens)
|
443 |
+
sequence_ids_with_masks = self.t2i_gen_prompt(text_ids, image_ids)
|
444 |
+
|
445 |
+
elif task == "t2v_gen":
|
446 |
+
text_ids = self.text_tokenizer(input[0])['input_ids'] # (B, max_len)
|
447 |
+
image_ids = input[1] # (B, #tokens)
|
448 |
+
sequence_ids_with_masks = self.t2v_gen_prompt(text_ids, image_ids)
|
449 |
+
|
450 |
+
elif task == "lm":
|
451 |
+
text_ids = self.text_tokenizer(input[0], truncation=True)['input_ids'] # (B, max_len)
|
452 |
+
sequence_ids_with_masks = self.lm_prompt(text_ids, input[1])
|
453 |
+
|
454 |
+
elif task == "lm_chat":
|
455 |
+
text_ids = self.text_tokenizer(input[0], truncation=True)['input_ids'] # (B, max_len)
|
456 |
+
sequence_ids_with_masks = self.lm_chat_prompt(text_ids, input[1])
|
457 |
+
|
458 |
+
elif task == "mmu":
|
459 |
+
image_ids = input[0]
|
460 |
+
text_ids = self.text_tokenizer(input[1])['input_ids']
|
461 |
+
sequence_ids_with_masks = self.mmu_prompt(image_ids, text_ids)
|
462 |
+
|
463 |
+
elif task == "r2i":
|
464 |
+
image_ids = input[0]
|
465 |
+
text_ids = self.text_tokenizer(input[1])['input_ids']
|
466 |
+
sequence_ids_with_masks = self.r2i_prompt(image_ids, text_ids)
|
467 |
+
|
468 |
+
else:
|
469 |
+
raise NotImplementedError
|
470 |
+
|
471 |
+
return sequence_ids_with_masks
|
472 |
+
|
473 |
+
|
474 |
+
if __name__ == '__main__':
|
475 |
+
pass
|