import math
import random
import types
import time
from collections import defaultdict
from contextlib import nullcontext
from functools import cached_property, partial
from contextlib import ExitStack
from unidisc.utils.tensor_utils import get_contiguous_blocks, get_contiguous_blocks_per_sample, get_interleaved_indices
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate.utils import gather, gather_object
from einops import rearrange
from tensordict import TensorDict
from torch import Tensor
from tqdm.auto import tqdm
import model_eval
import model_setup
import model_utils
import utils
from decoupled_utils import (Profiler, barrier, dprint, get_rank, get_world_size, gprint,
is_local_main_process, is_main_process,
is_torch_cuda_available, is_torch_xla_available,
print_memory, rprint, save_memory_profile,
synchronize_device, try_except, use_dist)
from unidisc.tokenizers.image_tokenizers import (decode_latents, get_image_batch,
get_vae, vae_encode_image)
from unidisc.utils.cuda_utils import sync_times
from unidisc.utils.xla_utils import shard_output
from model_utils import (Loss, ddprint, ema_update, empty_device_cache, get_chameleon_txt_indices, get_interleaved_block_mask, log,
replace_nan_dict, update_histogram, update_logs, get_block_mask)
from unidisc.utils.trainer_utils import TrainingState, incremental_dict_update, linear_warmup
is_xla_available = is_torch_xla_available()
if is_xla_available:
import torch_xla
from torch_xla.distributed.spmd import XLAShardedTensor
def maybe_unwrap(t: torch.Tensor) -> torch.Tensor:
return t.global_tensor if isinstance(t, XLAShardedTensor) else t
class Diffusion:
def __init__(self, config, tokenizer, device, disable_init=False):
super().__init__()
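        # Methods are implemented in companion modules (model_setup, model_utils,
        # model_eval) and bound onto this instance at construction time, keeping this
        # file focused on the core training / diffusion logic.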
setup_methods = [
'init', 'to', 'get_params', 'get_vae', 'get_cond_vae', 'configure_optimizers',
'_validate_configuration', 'register_signal_handler', 'on_train_start',
'optimizer_step', 'init_dataloader', 'set_accelerator', 'set_callbacks',
'on_train_step_end', 'init_optimizer_lr_scheduler', 'after_backward', 'checkpoint',
'print_hashes', 'shortcut_return', 'reset_validation_metrics', 'unwrap_model'
]
for method_name in setup_methods:
setattr(self, method_name, types.MethodType(getattr(model_setup, method_name), self))
utils_methods = [
'get_coord_plot', '_score_entropy', 'sample_subs_guidance',
'restore_model_and_semi_ar_sample', '_reconstruction_loss',
'restore_model_and_sample', 'get_score', '_staggered_score',
'_analytic_update', '_denoiser_update', '_transp_transition',
'eval_retokenize', 'compute_generative_perplexity', '_d3pm_loss',
'_d3pm_parameterization', '_sedd_parameterization',
'get_base_shapes_for_mup', 'update_histogram', '_maybe_sub_sample',
'viz_images_from_dataloader', 'compute_cider'
]
for method_name in utils_methods:
setattr(self, method_name, types.MethodType(getattr(model_utils, method_name), self))
eval_methods = [
'get_every_n_evals', 'on_validation_epoch_start', 'sample',
'predict_step', 'validation_step', 'on_validation_epoch_end',
'on_validation_epoch_cleanup', '_sample_prior', '_ddpm_forward',
'_ddpm_update', '_ddpm_caching_update', '_sample', '_ar_sampler',
'decode_batch', 'sample_transfusion', 'sample_continuous_image',
'decode_sampling', '_ddpm_update_finetune_controlled_tweedie',
'sample_masking', 'log_flops', "visualize_samples", "_maskgit_update",
"_first_hitting_update", "update_inline_fid", "compute_inline_fid",
"update_clean_fid", "compute_clean_fid_eval", "sample_for_fid",
"compute_clip_score", "mauve_store_references", "zero_shot_eval_step",
"zero_shot_eval_epoch_end", "get_cfg_weight", "cleanup_fid_output",
"calculate_chameleon_perplexity", "get_anole_data",
"update_img_to_txt_mauve_clip", "compute_mauve_entropy",
"get_top_k", "compute_entropy", "get_mauve_score", "get_valid_seq", "gather_tokens",
"count_valid_tokens", "compute_val_metrics_standalone", "_maskgit_nucleus_update",
"get_img_text_saturation_batch", "handle_interleaved_decode", "get_interleaved_image",
"auto_enhance", "get_clip_score", "get_dfn_score", "get_hpsv2_score", "get_model_likelihood_score",
"get_laion_aesthetic_score", "get_rewards", "get_chameleon_score", "clear_reward_models",
"get_text_likelihood_score", "get_text_reward_model_score", "save_image_text_pair"
]
for method_name in eval_methods:
setattr(self, method_name, types.MethodType(getattr(model_eval, method_name), self))
        if not disable_init:
            model_setup.init(self, config, tokenizer, device)
@cached_property
def xla_mesh(self):
import torch_xla.distributed.spmd as xs
return xs.get_global_mesh()
def on_train_resume(self):
if not is_torch_xla_available():
empty_device_cache()
if self.ema is not None and not self.config.trainer.use_custom_ema:
self.ema.restore(self.get_params(), raise_error_if_already_restored=False)
self.backbone.train()
def zero_shot_update_batch(self, batch):
dataset = self.config.data.train
if dataset is None:
return batch
def get_attr(attr_name):
return getattr(self.config.model, attr_name, None)
if dataset == "nlphuji/flickr30k":
            # flickr30k is an image-captioning dataset; lay the batch out as [txt, img],
            # zeroing the text span (to be predicted) and providing the image tokens.
batch['gt_input_ids'] = batch['input_ids']
image_input_ids = get_image_batch(self.config, self.get_vae(), batch, self.device)
image_input_ids += self.text_vocab_size
batch["input_ids"] = torch.cat([torch.zeros_like(batch['gt_input_ids'], dtype=torch.int64), image_input_ids], dim=-1).to(self.device)
batch['attention_mask'] = torch.cat([torch.zeros_like(batch['gt_input_ids'], dtype=torch.bool), torch.ones_like(image_input_ids, dtype=torch.bool)], dim=-1).to(self.device)
batch["modality"] = torch.cat([torch.zeros_like(batch['gt_input_ids'], dtype=torch.int64), torch.ones_like(image_input_ids, dtype=torch.int64)], dim=-1).to(self.device)
elif dataset == "facebook/winoground":
# get image and text input ids
caption_0_input_ids = batch['caption_0_input_ids']
caption_1_input_ids = batch['caption_1_input_ids']
image_0 = batch['img_0']
image_1 = batch['img_1']
# tokenize and store captions separately
image_0_input_ids = vae_encode_image(self.config, self.get_vae(), image_0, self.device, get_attr("vae_type")) + self.text_vocab_size
image_1_input_ids = vae_encode_image(self.config, self.get_vae(), image_1, self.device, get_attr("vae_type")) + self.text_vocab_size
            # make 4 combinations of image and text
batch['input_ids_0_0'] = torch.cat([caption_0_input_ids, image_0_input_ids], dim=-1).to(self.device)
batch['input_ids_0_1'] = torch.cat([caption_0_input_ids, image_1_input_ids], dim=-1).to(self.device)
batch['input_ids_1_0'] = torch.cat([caption_1_input_ids, image_0_input_ids], dim=-1).to(self.device)
batch['input_ids_1_1'] = torch.cat([caption_1_input_ids, image_1_input_ids], dim=-1).to(self.device)
batch['attention_mask'] = torch.cat([torch.zeros_like(caption_0_input_ids, dtype=torch.bool), torch.ones_like(image_0_input_ids, dtype=torch.bool)], dim=-1).to(self.device)
batch['modality'] = torch.cat([torch.zeros_like(caption_0_input_ids, dtype=torch.int64), torch.ones_like(image_0_input_ids, dtype=torch.int64)], dim=-1).to(self.device)
batch["modality_mask"] = F.one_hot(batch["modality"], num_classes=2).to(torch.bool)
return batch
def update_batch(self, batch):
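        """Normalize a raw dataloader batch into the unified token space: tokenize images
        with the VAE, shift image token ids by ``text_vocab_size``, and build ``modality`` /
        ``attention_mask`` / ``modality_mask`` tensors of length ``config.model.length``."""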
if getattr(self.config.eval, 'big_seq_len_eval', False):
            # Build a synthetic batch of sequence length N with the first half text and the
            # second half image (e.g., txt length 4096 + img length 4096 = 8192)
N = self.config.model.length
new_batch = dict()
new_batch['input_ids'] = torch.zeros(batch['input_ids'].shape[0], N, device=self.device, dtype=batch['input_ids'].dtype)
new_batch['attention_mask'] = torch.ones(batch['attention_mask'].shape[0], N, device=self.device, dtype=batch['attention_mask'].dtype)
new_batch['modality'] = torch.zeros(batch['modality'].shape[0], N, device=self.device, dtype=batch['modality'].dtype)
new_batch['modality'][:, N//2:] = 1
new_batch['modality_mask'] = F.one_hot(new_batch['modality'], num_classes=2).to(torch.bool)
batch = new_batch
return batch
continuous_mode = self.config.trainer.image_mode == "continuous"
if batch is None:
gprint(f"Warning! Batch is None")
return batch
if isinstance(batch, TensorDict):
batch.batch_size = (batch.batch_size[0],)
if self.image_model or getattr(self.config.data, "force_image_dataset", False):
text_input_ids = None
if isinstance(batch, TensorDict) and (self.is_compiled or getattr(self.config.trainer, "force_convert_to_dict", False)):
batch = dict(batch.items())
if "txt_input_ids" in batch or "img_input_ids" in batch:
index_keys = ["img_input_ids", "txt_input_ids", "sample_ids"]
for key in index_keys:
if key in batch:
if isinstance(batch[key], list):
batch[key] = torch.stack(batch[key], dim=0)
batch[key] = batch[key].to(torch.int64)
index_keys = ["img_label"]
for key in index_keys:
if key in batch:
batch[key] = batch[key].squeeze(-1)
img_input_ids = batch.pop("img_input_ids")
batch["input_ids"] = img_input_ids
batch["attention_mask"] = torch.ones_like(img_input_ids).to(torch.bool)
if "txt_input_ids" in batch:
batch["input_ids"] = torch.cat([batch["txt_input_ids"], batch["input_ids"] + self.text_vocab_size], dim=-1)
batch["attention_mask"] = torch.cat([batch["txt_attention_mask"], batch["attention_mask"]], dim=-1)
batch["input_ids"] = batch["input_ids"].to(torch.int64)
if "modality" not in batch:
if getattr(self.config.trainer, "ignore_text_in_unified", False):
modality = torch.ones_like(batch["input_ids"], dtype=torch.int64)
else:
assert self.config.model.txt_length > 0 and self.config.model.img_length > 0
modality = torch.zeros_like(batch["input_ids"], dtype=torch.int64)
modality[:, -img_input_ids.shape[-1]:] = 1
batch["modality"] = modality
elif (self.config.trainer.multimodal_batches or continuous_mode) and \
not getattr(self.config.trainer, "use_legacy_update_batch_fn", False):
if "img" in batch:
is_image_batch = (batch["modality"] == 1).all(dim=-1)
image_input_ids = get_image_batch(self.config, self.get_vae(), batch, self.device)
assert ((batch["modality"].sum(dim=-1) == 0) | (batch["modality"].sum(dim=-1) >= image_input_ids.shape[1])).all()
if getattr(self.config.trainer, "add_label", False):
assert (batch["modality"] == 1).all()
batch["input_ids"][:, 1:] = torch.where(is_image_batch[:, None], image_input_ids, batch["input_ids"][:, 1:])
elif image_input_ids.ndim == 3:
batch["img_emb"] = torch.where((batch["modality"] == 1)[:, :, None], image_input_ids, torch.nan)
elif (batch["input_ids"][batch["modality"] == 1] == -1).all():
batch["input_ids"].masked_scatter_(batch["modality"] == 1, image_input_ids)
else:
batch["input_ids"] = torch.where(is_image_batch[:, None], image_input_ids, batch["input_ids"])
if getattr(self.config.trainer, "force_shift_raw_image_batches", False):
assert not getattr(self.config.trainer, "force_shift_image_batches", False)
batch["input_ids"] = torch.where(batch["modality"] == 1, batch["input_ids"] + self.text_vocab_size, batch["input_ids"])
else:
if getattr(self.config.trainer, "add_label", False):
shift_index = self.vocab_size - self.config.model.add_labels
batch["input_ids"] = torch.cat([batch["label"] + shift_index, batch["input_ids"]], dim=-1)
batch["attention_mask"] = torch.cat([torch.zeros_like(batch["label"], dtype=torch.bool), batch["attention_mask"]], dim=-1)
batch["modality"] = torch.cat([torch.ones_like(batch["label"], dtype=torch.int64), batch["modality"]], dim=-1)
assert (batch["modality"] == 1).all()
batch["input_ids"] = batch["input_ids"].to(torch.int64)
if "sample_ids" in batch:
batch["sample_ids"] = batch["sample_ids"].to(torch.int64)
if getattr(self.config.trainer, "force_shift_image_batches", False):
batch["input_ids"] = torch.where(batch["modality"] == 1, batch["input_ids"] + self.text_vocab_size, batch["input_ids"])
else:
            if continuous_mode:
                raise NotImplementedError("continuous image mode is not supported in this legacy path")
else:
if "input_ids" in batch and not self.config.trainer.ignore_text_in_unified:
assert self.config.model.unified_model
assert "attention_mask" in batch
text_input_ids = batch["input_ids"]
image_ids = get_image_batch(self.config, self.get_vae(), batch, self.device)
image_attention_mask = torch.ones_like(image_ids).to(torch.bool)
if "cond_img" in batch:
cond_image_ids = get_image_batch(self.config, self.get_cond_vae(), batch, self.device, use_cond=True)
batch["cond_input_ids"] = cond_image_ids
if text_input_ids is not None:
assert batch["input_ids"].shape[1] == self.config.model.txt_length
assert image_ids.shape[1] == self.config.model.img_length
image_ids = image_ids + self.text_vocab_size
batch["input_ids"] = torch.cat([batch["input_ids"].to(self.device), image_ids], dim=-1)
batch["attention_mask"] = torch.cat([batch["attention_mask"].to(self.device), image_attention_mask], dim=-1).to(torch.bool)
assert batch["input_ids"].shape[1] == batch["attention_mask"].shape[1] == self.config.model.length
batch["modality"] = torch.zeros_like(batch["input_ids"], dtype=torch.int64)
batch["modality"][:, -image_ids.shape[-1]:] = 1
else:
assert self.unified_model is False
batch["input_ids"] = image_ids
batch["attention_mask"] = image_attention_mask
batch["modality"] = torch.ones_like(batch["input_ids"], dtype=torch.int64)
if "txt_x0_unmask" in batch and "img_x0_unmask" in batch:
assert not continuous_mode
batch["gt_img_input_ids"] = image_ids
batch["x0_unmask"] = torch.cat([batch["txt_x0_unmask"], batch["img_x0_unmask"]], dim=-1)
batch["input_ids"][~batch["x0_unmask"]] = self.mask_index
if (batch["input_ids"].shape[1] != self.config.model.length) and not self.config.trainer.ar_inpainting:
gprint(f"Warning! Input ids shape: {batch['input_ids'].shape}, model length: {self.config.model.length}")
batch["input_ids"] = batch["input_ids"][:, : self.config.model.length]
assert False, f"input ids are not the correct length input ids shape: {batch['input_ids'].shape}, model length: {self.config.model.length}"
if getattr(self.config.model, "img_cond", False):
assert "cond_input_ids" in batch
assert not continuous_mode
if "modality" in batch:
batch["modality"] = batch["modality"].to(torch.int64)
if self.config.trainer.multimodal_batches and batch["modality"].ndim == 2 and batch["modality"].shape[-1] == 1:
batch["modality"] = batch["modality"].repeat(1, self.config.model.length)
else:
if self.image_model and not self.config.trainer.multimodal_batches:
assert self.config.model.txt_length > 0 and self.config.model.img_length > 0
modality = torch.zeros_like(batch["input_ids"], dtype=torch.int64)
modality[:, self.static_img_sl] = 1
batch["modality"] = modality
elif self.config.data.txt_only:
batch["modality"] = torch.zeros_like(batch["input_ids"], dtype=torch.int64)
if "modality" in batch:
batch["modality"][batch["modality"] == -1] = 0
assert batch["modality"].min() == 0 and batch["modality"].max() == 1
batch["modality_mask"] = F.one_hot(batch["modality"], num_classes=2).to(torch.bool)
batch["batch_contains_img"] = (batch["modality"] == 1).any(dim=-1)
batch['txt_sl'] = self.txt_sl(batch)
batch['img_sl'] = self.img_sl(batch)
if getattr(self.config.trainer, "force_remove_img_tokens", False):
assert not continuous_mode
batch["input_ids"] = batch["input_ids"][batch['txt_sl']]
batch["attention_mask"] = batch["attention_mask"][batch['txt_sl']]
if getattr(self.config.trainer, "add_label", False):
assert getattr(self.config.model, "add_labels", False)
assert "label" in batch
batch["label"] = batch["label"].to(torch.int64)
assert 0 <= batch["label"].min() and batch["label"].max() < self.config.model.add_labels
shift_index = self.vocab_size - self.config.model.add_labels
assert batch["input_ids"].shape[-1] == self.config.model.length
if batch["label"].ndim == 1:
batch["input_ids"][:, [0]] = (batch["label"] + shift_index).unsqueeze(-1)
else:
batch["input_ids"][:, [0]] = batch["label"] + shift_index
batch["attention_mask"][:, 0] = False
if isinstance(batch, dict):
for key in batch.keys():
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(self.device)
elif isinstance(batch, TensorDict):
assert self.config.backbone != "gemma"
batch = batch.to(self.device)
if getattr(self.config.trainer, "force_full_attention_mask", False):
batch["attention_mask"] = torch.ones_like(batch["attention_mask"], dtype=torch.bool)
batch["attention_mask"] = batch["attention_mask"].to(torch.bool)
if self.config.data.require_sample_ids:
assert "sample_ids" in batch
batch["sample_ids"][~(batch["attention_mask"].bool())] = -1
batch["attention_mask"][batch["sample_ids"] == -1] = False
# Flip [txt, img] -> [img, txt]
        # TODO: Flip by sample, not by whole batch. Since we train with ~8 batches, this is acceptable for now.
if (self.training or getattr(self.config.trainer, "force_flip_ar_val", False)) and self.config.parameterization == "ar" and getattr(self.config.trainer, "rand_flip_ar_prob", None) is not None:
assert (batch["modality"][:, :self.config.model.txt_length] == 0).all() and (batch["modality"][:, self.config.model.txt_length:] == 1).all(), "Modality does not match img_before_txt configuration"
batch_flip_mask = torch.rand(batch["modality"].shape[0], device=self.device) < self.config.trainer.rand_flip_ar_prob
img_slice = slice(-self.config.model.img_length, None)
txt_slice = slice(None, self.config.model.txt_length)
for key in ["modality", "attention_mask", "input_ids"]:
batch[key][batch_flip_mask] = torch.cat([batch[key][batch_flip_mask][:, img_slice], batch[key][batch_flip_mask][:, txt_slice]], dim=1)
if "modality_mask" in batch:
batch["modality_mask"] = F.one_hot(batch["modality"], num_classes=2).to(torch.bool)
batch['txt_sl'] = None
batch['img_sl'] = None
batch["batch_flip_mask"] = batch_flip_mask
if self.config.trainer.interleaved and "sample_ids" not in batch:
batch["sample_ids"] = torch.zeros_like(batch["modality"], dtype=torch.int64)
if self.config.trainer.interleaved:
batch_indices, start_positions, end_positions = get_contiguous_blocks(batch["modality"])
interleaved_metadata = TensorDict({
"batch_indices": batch_indices,
"start_positions": start_positions,
"end_positions": end_positions
}, batch_size=[])
allowed_image_sizes = (64, 256, 1024, 2304, 4096)
block_sizes = (end_positions - start_positions).to(torch.int32)
is_txt_block = batch["modality"][batch_indices, start_positions] == 0
is_valid_img_size = torch.isin(block_sizes, torch.tensor(allowed_image_sizes, dtype=torch.int32, device=self.device))
if not ((is_txt_block | is_valid_img_size).all()):
gprint(f"WARNING: Found non-text block of size {block_sizes[~(is_txt_block | is_valid_img_size)]} in interleaved batch")
if isinstance(batch, TensorDict):
batch.batch_size = []
batch["interleaved_metadata"] = interleaved_metadata
return batch
def get_cond_dict(self, batch):
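        """Collect conditioning kwargs for the backbone forward pass (conditional image ids,
        class label, attention mask, and modality ids, as configured)."""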
ret_dict = dict()
if "cond_input_ids" in batch:
ret_dict["x_cond"] = batch["cond_input_ids"]
if "img_label" in batch:
ret_dict["label"] = batch["img_label"]
if self.config.model.use_attention_mask:
ret_dict["attention_mask"] = batch["attention_mask"]
if self.config.trainer.multimodal_batches:
ret_dict["modality"] = batch["modality"]
if self.config.trainer.image_mode == "continuous":
ret_dict["continuous_mode"] = True
ret_dict["modality"] = batch["modality"]
if self.parameterization == "ar" and "modality" in batch:
ret_dict["modality"] = batch["modality"]
return ret_dict
def training_step(self, batch, batch_idx):
batch = self.update_batch(batch)
return self.compute_loss(batch, prefix="train", batch_idx=batch_idx)
def q_xt(self, x, move_chance, allow_move_mask=None, return_ignore_batch_mask_for_metrics=False, mask_image_square=False, mask_text_region=False, batch=None):
"""Computes the noisy sample xt.
Args:
x: int torch.Tensor with shape (batch_size,
diffusion_model_input_length), input.
move_chance: float torch.Tensor with shape (batch_size, 1).
"""
if self.config.backbone == "maskdit" and getattr(self.config.trainer, "force_single_timestep_per_batch", False):
num_to_mask = int(x.shape[1] * move_chance[0].item())
batch_size, seq_len = x.shape
random_indices = torch.rand(batch_size, seq_len, device=x.device).argsort(dim=1)[:, :num_to_mask]
xt = x.scatter(1, random_indices, self.mask_index)
return xt
move_indices = torch.rand(*x.shape, device=x.device) < move_chance
if mask_image_square:
latent_dim = int(math.sqrt(self.config.model.img_length))
img_move_indices = move_indices[:, self.static_img_sl].clone().reshape(move_indices.shape[0], latent_dim, latent_dim)
max_d = int(math.sqrt(self.config.model.img_length))
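            # Replace the i.i.d. token mask over the image span with one contiguous random
            # square of side d in the latent grid (e.g., for inpainting-style objectives).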
for b in range(move_indices.shape[0]):
if move_chance[b] == 1:
continue
h, w = img_move_indices[b].shape
d = random.randint(max_d // 2, max_d - 2)
i = random.randint(0, h - d)
j = random.randint(0, w - d)
mask = torch.zeros_like(img_move_indices[b], dtype=torch.bool)
mask[i:i+d, j:j+d] = True
move_indices[b, self.static_img_sl] = mask.reshape(-1)
if mask_text_region:
for b in range(x.shape[0]):
if move_chance[b] == 1:
continue
should_mask = torch.zeros_like(move_indices[b, self.static_txt_sl], dtype=torch.bool)
max_valid = (x[b] == self.tokenizer.eos_token_id).nonzero()[0, 0] if self.tokenizer.eos_token_id in x[b] else x.shape[1]
d = random.randint(max_valid//3, max_valid-1)
start = random.randint(0, max_valid - d)
should_mask[start:start+d] = True
move_indices[b, self.static_txt_sl] = should_mask
ignore_batch_mask_for_metrics = None
should_mask_txt, should_mask_img = None, None
if (mask_prob := getattr(self.config.trainer, "mask_entire_modality", None)) is not None \
and (mask_image_square is False and mask_text_region is False) and self.backbone.training:
assert batch is not None
batch_size, seq_len = x.shape
if getattr(self.config.trainer, "mask_txt_only", False):
should_mask_txt = torch.rand(batch_size, 1, device=x.device) < mask_prob
should_mask_img = torch.zeros_like(should_mask_txt, device=x.device)
else:
should_mask_txt = torch.rand(batch_size, 1, device=x.device) < mask_prob/2
should_mask_img = torch.rand(batch_size, 1, device=x.device) < mask_prob/2
if self.config.trainer.multimodal_batches:
if self.config.trainer.interleaved:
batch_indices, start_positions, end_positions = get_contiguous_blocks_per_sample(batch["modality"], batch["sample_ids"])
block_size = end_positions - start_positions
size_mask = block_size > 4
batch_indices, start_positions, end_positions = batch_indices[size_mask], start_positions[size_mask], end_positions[size_mask]
block_counts = torch.zeros_like(batch_indices)
max_num_sample_ids = torch.zeros_like(batch_indices)
for i in range(len(batch_indices)):
curr_sample_id = batch["sample_ids"][batch_indices[i], start_positions[i]]
# Find blocks before this one with same batch index and sample_id
prev_blocks_mask = (batch_indices[:i] == batch_indices[i]) & \
(batch["sample_ids"][batch_indices[:i], start_positions[:i]] == curr_sample_id)
total_in_sample = ((batch_indices == batch_indices[i]) & (batch["sample_ids"][batch_indices, start_positions] == curr_sample_id)).sum()
block_counts[i] = prev_blocks_mask.sum()
max_num_sample_ids[i] = total_in_sample
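                    # Weight each block's mask probability by its ordinal position within its
                    # sample; the factor of 2 below keeps the expected rate near mask_prob.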
block_prob = (block_counts + 1) / max_num_sample_ids
positions = torch.arange(move_indices.shape[-1], device=move_indices.device).unsqueeze(0) # Shape: [1, N]
mask = (positions >= start_positions.unsqueeze(1)) & (positions < end_positions.unsqueeze(1)) # Shape: [M, N]
mask = mask & (torch.rand(batch_indices.shape[0], 1, device=x.device) < (mask_prob * block_prob * 2)[..., None])
expanded_batch_indices = batch_indices.unsqueeze(1).expand(-1, move_indices.shape[1]) # Shape: [M, N]
# True if we should manually mask the part of the sequence
accum = torch.zeros_like(move_indices, dtype=torch.int32) # Shape: [B, N]
accum.scatter_add_(0, expanded_batch_indices, mask.int()) # Accumulate counts
accum = accum.to(torch.bool)
move_indices = move_indices | accum
# We ignore the entire sequence if any of the blocks are fully masked
ignore_batch_mask_for_metrics = torch.zeros((move_indices.shape[0],), device=x.device, dtype=torch.bool)
ignore_batch_mask_for_metrics.scatter_add_(0, batch_indices, mask.any(dim=-1))
else:
# TODO: Be smarter about masking for interleaved
                    # To keep the masking probability even across modalities, prefer masking
                    # neither modality rather than both when both are selected.
both_mask = should_mask_txt & should_mask_img
should_mask_txt = torch.where(both_mask, False, should_mask_txt)
should_mask_img = torch.where(both_mask, False, should_mask_img)
move_indices = torch.where(should_mask_txt, batch["modality_mask"][..., 0], move_indices)
move_indices = torch.where(should_mask_img, batch["modality_mask"][..., 1], move_indices)
ignore_batch_mask_for_metrics = should_mask_img | should_mask_txt
else:
both_mask = should_mask_txt & should_mask_img
should_mask_txt[both_mask] = False
should_mask_img[both_mask] = False
should_mask_img[batch["txt_sl"].all(dim=-1)] = False
move_indices[:, self.static_txt_sl] = torch.where(should_mask_txt, True, move_indices[:, self.static_txt_sl])
move_indices[:, self.static_img_sl] = torch.where(should_mask_img, True, move_indices[:, self.static_img_sl])
ignore_batch_mask_for_metrics = should_mask_img | should_mask_txt
joint_ar_nar_mask = None
if self.config.trainer.joint_ar_nar_prob is not None and self.training:
batch_size = x.shape[0]
current_prob = linear_warmup(
current_step=self.global_step,
warmup_steps=self.config.trainer.joint_ar_nar_prob_warmup_steps,
final_value=self.config.trainer.joint_ar_nar_prob,
initial_value=1.0
)
joint_ar_nar_mask = torch.rand(batch_size, device=x.device) < current_prob
move_indices = torch.where(joint_ar_nar_mask[:, None], False, move_indices)
if self.config.trainer.add_label:
move_indices[:, 0] = False
if self.config.trainer.first_token_dropout is not None and self.training:
_initial_mask = torch.rand(x.shape[0], device=x.device) < self.config.trainer.first_token_dropout
move_indices[:, 0] = torch.where(_initial_mask, True, move_indices[:, 0])
if ignore_batch_mask_for_metrics is None:
ignore_batch_mask_for_metrics = _initial_mask
else:
ignore_batch_mask_for_metrics = ignore_batch_mask_for_metrics | _initial_mask
if allow_move_mask is not None:
move_indices = move_indices & allow_move_mask
if getattr(self.config.trainer, "discrete_diffusion_mode", "absorbing") == "uniform":
if getattr(self.config.model, "force_argmax_valid_indices", False):
assert self.mask_index == self.text_vocab_size - 1
text_random_tokens = torch.randint(0, self.text_vocab_size - 1, size=x.shape, device=x.device)
img_random_tokens = torch.randint(self.text_vocab_size, self.vocab_size, size=x.shape, device=x.device)
random_tokens = torch.where(batch["modality_mask"][..., 0], text_random_tokens, img_random_tokens)
assert not torch.any(random_tokens == self.mask_index)
else:
                random_tokens = torch.randint(0, self.vocab_size, size=x.shape, device=x.device)
random_tokens = torch.where(random_tokens == self.mask_index, random_tokens + 1, random_tokens) # avoid mask index
xt = torch.where(move_indices, random_tokens, x)
else:
xt = torch.where(move_indices, self.mask_index, x)
if self.parameterization == "ar":
xt = x.clone()
if return_ignore_batch_mask_for_metrics:
return xt, ignore_batch_mask_for_metrics, joint_ar_nar_mask, should_mask_txt, should_mask_img, move_indices
else:
return xt
def _sample_t(self, n, device):
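        """Sample per-example diffusion timesteps in [sampling_eps, 1], optionally stratified
        (antithetic) across the batch to reduce the variance of the loss estimate."""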
if self.config.backbone == "maskdit" and getattr(self.config.trainer, "force_single_timestep_per_batch", False):
_eps_t = torch.rand(1, device=device).repeat(n)
else:
_eps_t = torch.rand(n, device=device)
if self.config.trainer.joint_ar_nar_timestep_warmup_steps is not None:
max_t = linear_warmup(
current_step=self.global_step,
warmup_steps=self.config.trainer.joint_ar_nar_timestep_warmup_steps,
final_value=1,
initial_value=0,
start_step=0
)
_eps_t = _eps_t * max_t
if max_t == 1:
offset = torch.arange(n, device=device) / n
_eps_t = (_eps_t / n + offset) % 1
elif self.antithetic_sampling:
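            # Antithetic / stratified sampling: place one timestep in each bucket of width
            # 1/n so the batch covers [0, 1) evenly.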
offset = torch.arange(n, device=device) / n
_eps_t = (_eps_t / n + offset) % 1
if getattr(self.config.trainer, "force_timestep", None) is not None:
_eps_t[:] = self.config.trainer.force_timestep
elif getattr(self.config.eval, "ar_inpainting_force_val", None) is not None:
_eps_t[:] = self.config.eval.ar_inpainting_force_val
t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
if self.importance_sampling:
return self.noise.importance_sampling_transformation(t)
return t.to(torch.float32)
def _subs_parameterization(self, logits, xt, batch=None, modality=None, **kwargs):
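        """Convert raw logits to normalized log-probabilities for the SUBS parameterization:
        the mask token gets zero probability, and already-unmasked positions are forced to a
        delta distribution on their observed token."""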
# log prob at the mask index = - infinity
if not self.allow_slicing:
logits = logits.clone()
logits[..., self.mask_index] += self.neg_infinity
if getattr(self.config.model, "force_argmax_valid_indices", False):
if self.config.trainer.multimodal_batches:
_txt_sl = batch["txt_sl"] if modality is None else modality == 0
_img_sl = batch["img_sl"] if modality is None else modality == 1
logits[..., self.text_vocab_size:] = torch.where(_txt_sl[..., None], self.neg_infinity, logits[..., self.text_vocab_size:])
logits[..., :self.text_vocab_size] = torch.where(_img_sl[..., None], self.neg_infinity, logits[..., :self.text_vocab_size])
else:
logits[..., self.static_txt_sl, self.text_vocab_size:] = self.neg_infinity
logits[..., self.static_img_sl, :self.text_vocab_size] = self.neg_infinity
# Normalize the logits such that x.exp() is
# a probability distribution over vocab_size.
logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
if self.parameterization != "ar" and xt is not None:
# Apply updates directly in the logits matrix.
# For the logits of the unmasked tokens, set all values
# to -infinity except for the indices corresponding to
# the unmasked tokens.
unmasked_indices = xt != self.mask_index
if not self.allow_slicing:
logits = torch.where(unmasked_indices.unsqueeze(-1), torch.full_like(logits, self.neg_infinity), logits)
logits = torch.where(
unmasked_indices.unsqueeze(-1) & (torch.arange(logits.size(-1)).to(logits.device) == xt.unsqueeze(-1)),
torch.zeros_like(logits),
logits
)
else:
logits[unmasked_indices] = self.neg_infinity
logits[unmasked_indices, xt[unmasked_indices]] = 0
return logits
def _process_sigma(self, sigma):
if sigma is None:
assert (self.parameterization == "ar" or self.config.trainer.ar_llm_loss) or self.config.trainer.allow_null_sigma
return sigma
if sigma.ndim > 1 and not self.config.trainer.image_mode == "continuous":
sigma = sigma.squeeze(-1)
assert sigma.ndim == 1, sigma.shape
if not self.time_conditioning and getattr(self.config.model, "force_time_conditioning", False):
sigma = torch.zeros_like(sigma)
return sigma
def forward(
self,
x,
sigma,
batch=None,
forward_attention_mask=None,
return_additional_loss=False,
x_img_emb=None,
disable_ar_shift=False,
continuous_mode=False,
joint_ar_nar_mask=None,
return_logits=False,
block_mask=None,
update_cache_slice=None,
**kwargs,
):
"""Returns log score."""
sigma = self._process_sigma(sigma)
if self.config.trainer.image_mode == "continuous": assert "modality" in kwargs
        # Autocast only when a non-fp32 dtype is in use and autocast is not disabled for this pass.
        disable_autocast = self.config.trainer.disable_forward_autocast_during_eval and self.backbone.training
        should_autocast = (disable_autocast is False) and (self.dtype != torch.float32)
with ExitStack() as stack:
if should_autocast:
stack.enter_context(torch.autocast(device_type=self.device.type, dtype=self.dtype))
orig_modality = None
if self.config.backbone == "elm":
if getattr(self.config.trainer, "print_llm_ppl", False):
_labels = x.clone()
_labels[~forward_attention_mask] = -100
kwargs['labels'] = _labels
if "modality" in kwargs:
if self.config.mode == "eval": orig_modality = kwargs.pop("modality")
else: kwargs.pop("modality")
if "modality_mask" in kwargs: kwargs.pop("modality_mask")
if "x0" in kwargs: kwargs.pop("x0")
if "start_pos" in kwargs: kwargs.pop("start_pos")
if "sample_ids" in kwargs: kwargs.pop("sample_ids")
output = self.backbone(input_ids=x, **kwargs)
if self.config.mode == "eval": kwargs["modality"] = orig_modality
if isinstance(output, Tensor):
logits = output
else:
logits = output.logits
if getattr(self.config.trainer, "print_llm_ppl", False):
rprint(f"AR PPL: {torch.exp(output.loss)}")
else:
if self.config.trainer.compile == 'max-autotune' and not is_xla_available:
torch.compiler.cudagraph_mark_step_begin()
logits = self.backbone(x, sigma, continuous_mode=continuous_mode, x_img_emb=x_img_emb, block_mask=block_mask, update_cache_slice=update_cache_slice, **kwargs)
if self.config.trainer.force_bf16_eval:
logits = logits.to(torch.bfloat16)
if continuous_mode:
assert self.parameterization == "ar"
logits, logits_img = logits
if self.config.trainer.ar_shift and not disable_ar_shift:
            # trainer.ar_shift shifts logits/targets by one position for AR training;
            # disable_ar_shift turns the shift off when sampling at inference time.
logits = logits[:, :-1]
xt = x[:, 1:]
if orig_modality is not None and self.config.mode == 'eval':
orig_modality = orig_modality[:, 1:]
else:
xt = x
if self.config.trainer.low_precision_loss:
logits = logits.to(self.dtype)
if continuous_mode:
logits_img = logits_img.to(self.dtype)
if self.parameterization == "planner":
return logits
elif self.config.trainer.ar_llm_loss:
assert not self.parameterization == "ar"
model_output = self._subs_parameterization(logits, xt=xt, modality=orig_modality), logits
if is_xla_available: shard_output(model_output[0], self.xla_mesh)
if is_xla_available: shard_output(model_output[1], self.xla_mesh)
return model_output if return_additional_loss else model_output[0]
elif self.parameterization == "ar":
if not getattr(self.config.trainer, "use_orig_unidisc_dit", False):
logits = torch.where(
torch.arange(logits.shape[-1], device=logits.device)[None, None, :] == self.mask_index, self.neg_infinity, logits
)
_modality = kwargs.get("modality") if batch is None else batch.get("modality")
# During eval, we let the sampler handle this part.
if getattr(self.config.model, "force_argmax_valid_indices", False) and _modality.shape[1] == (logits.shape[1] + 1):
if not self.allow_slicing:
logits = logits.clone()
logits[..., self.text_vocab_size:] = torch.where(
(kwargs.get("modality") == 0)[..., 1:, None], torch.finfo(logits.dtype).min, logits[..., self.text_vocab_size:]
)
logits[..., :self.text_vocab_size] = torch.where(
(kwargs.get("modality") == 1)[..., 1:, None], torch.finfo(logits.dtype).min, logits[..., :self.text_vocab_size]
)
logits = logits.log_softmax(-1)
if continuous_mode:
return (logits, logits_img)
elif self.parameterization == "subs":
if return_logits:
return logits
model_output = self._subs_parameterization(logits, xt=xt, batch=batch, **kwargs)
if is_xla_available: shard_output(model_output, self.xla_mesh)
return model_output
elif self.parameterization == "sedd":
return self._sedd_parameterization(logits=logits, xt=x, sigma=sigma)
elif self.parameterization == "d3pm":
return self._d3pm_parameterization(logits=logits)
return logits
def compute_loss(self, batch, prefix, batch_idx=-1):
if not is_xla_available and ((self.current_run_fwd_bwd_pass == 0 and self.config.mode == 'train') or batch_idx == 0):
self.visualize_samples(batch, batch_idx, split=prefix)
if getattr(self.config.trainer, 'overfit_on_first_batch', False):
if batch_idx <= 0:
# store it
self.overfit_batch = batch.copy()
else:
batch = self.overfit_batch
kwargs = self.get_cond_dict(batch)
modality_mask = batch.get("modality_mask", None)
(input_tokens, output_tokens, attention_mask) = self._maybe_sub_sample(batch["input_ids"], batch.get("attention_mask", None))
continuous_mode = self.config.trainer.image_mode == "continuous"
joint_ar_nar_mask, modality = None, None
if continuous_mode:
assert 'modality' in batch
x0, img_emb, attention_mask, modality = (
batch["input_ids"],
batch["img_emb"],
batch["attention_mask"],
batch["modality"],
            )  # img_emb is zeros over the text span followed by the image embeddings
xt = x0
B, N_tot, C = img_emb.shape
noise_scheduler = self.get_vae().scheduler
noise = torch.randn_like(img_emb)
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (B,), device=img_emb.device).long()
img_timesteps = timesteps.unsqueeze(-1).expand(-1, N_tot).to(self.dtype)
zero_timesteps = torch.zeros_like(img_timesteps)
unet_conditioning = torch.where(modality == 1, img_timesteps, zero_timesteps)
# unet_conditioning = timesteps.to(self.dtype)
# unet_conditioning = torch.where(modality_mask==1, timesteps.to(self.dtype), torch.zeros_like(timesteps.to(self.dtype)))
x_img_emb = noise_scheduler.add_noise(img_emb, noise, timesteps).to(self.dtype)
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(img_emb, noise, timesteps) # todo, might break
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
target = target.to(self.dtype)
else:
unet_conditioning, xt, x0, x_img_emb, modality_mask = None, None, input_tokens, None, batch.get("modality_mask", None)
if self.parameterization != "ar":
t = self._sample_t(x0.shape[0], x0.device)
if self.T > 0:
t = (t * self.T).to(torch.int)
t = t / self.T
t += 1 / self.T # t \in {1/T, 2/T, ..., 1}
if self.change_of_variables:
unet_conditioning = t[:, None]
f_T = torch.log1p(-torch.exp(-self.noise.sigma_max))
f_0 = torch.log1p(-torch.exp(-self.noise.sigma_min))
move_chance = torch.exp(f_0 + t * (f_T - f_0))
move_chance = move_chance[:, None]
else:
# total, rate
sigma, dsigma = self.noise(t)
unet_conditioning = sigma[:, None]
move_chance = 1 - torch.exp(-sigma[:, None])
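                    # For the absorbing schedule, the probability that a token has been
                    # masked by time t is 1 - exp(-sigma(t)).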
xt, ignore_batch_mask_for_metrics, joint_ar_nar_mask, should_mask_txt, should_mask_img, move_indices = self.q_xt(x0, move_chance, return_ignore_batch_mask_for_metrics=True, batch=batch)
if (self.config.model.flex_attention_img_masking_prob is not None or self.config.model.flex_attention_txt_masking_prob is not None) and self.backbone.training:
assert xt.shape[1] == (self.config.model.img_length + self.config.model.txt_length)
txt_batch_attn_dropout = torch.rand(xt.shape[0], device=xt.device) < self.config.model.flex_attention_txt_masking_prob
img_batch_attn_dropout = torch.rand(xt.shape[0], device=xt.device) < self.config.model.flex_attention_img_masking_prob
# If we mask out a modality, we cannot let it only see itself
txt_batch_attn_dropout = txt_batch_attn_dropout & ~should_mask_txt.squeeze(-1)
img_batch_attn_dropout = img_batch_attn_dropout & ~should_mask_img.squeeze(-1)
kwargs['block_mask'] = get_block_mask(txt_batch_attn_dropout, img_batch_attn_dropout, self.config.model.txt_length, xt.shape[0], xt.shape[1], xt.device)
# TODO: Somehow report these metrics so we know what's going on
ignore_batch_mask_for_metrics = ignore_batch_mask_for_metrics | (txt_batch_attn_dropout | img_batch_attn_dropout).unsqueeze(-1)
if getattr(self.config.trainer, "interleaved_training_flex_attention", False):
kwargs['block_mask'] = get_interleaved_block_mask(batch["sample_ids"], batch_size=xt.shape[0], seq_len=xt.shape[1], device=xt.device)
kwargs['sample_ids'] = batch["sample_ids"]
elif self.config.trainer.ar_inpainting:
x0 = torch.cat([x0, x0], dim=1)
kwargs['modality'] = torch.cat([kwargs['modality'], kwargs['modality']], dim=1)
attention_mask = torch.cat([torch.zeros_like(attention_mask, dtype=attention_mask.dtype), torch.ones_like(attention_mask, dtype=attention_mask.dtype)], dim=1)
modality_mask = torch.cat([modality_mask, modality_mask], dim=1)
min_val, max_val = 0.0, 1.0
n = x0.shape[0]
_eps_t = torch.rand(n, device=self.device)
offset = torch.arange(n, device=self.device) / n
_eps_t = (_eps_t / n + offset) % 1
t = (max_val - min_val) * _eps_t + min_val
if getattr(self.config.eval, "ar_inpainting_force_val", None) is not None:
t = torch.full_like(t, getattr(self.config.eval, "ar_inpainting_force_val"), dtype=t.dtype, device=t.device)
move_indices = torch.rand(*x0.shape, device=x0.device) < t[:, None]
move_indices[:, x0.shape[1] // 2:] = False
x0 = torch.where(move_indices, self.mask_index, x0)
xt = x0
else:
xt = x0
if (self.training or getattr(self.config.trainer, "force_flip_ar_val", False)) and self.config.trainer.rand_ar_modality_dropout is not None:
assert not is_xla_available
xt = xt.clone()
batch_modality_dropout = torch.rand(xt.shape[0], device=xt.device) < self.config.trainer.rand_ar_modality_dropout
first_modality = batch["modality"][:, 0]
first_modality_mask = batch["modality"] == first_modality[:, None]
xt = torch.where(first_modality_mask & batch_modality_dropout[:, None], self.mask_index, xt)
attention_mask = torch.where(first_modality_mask & batch_modality_dropout[:, None], False, attention_mask)
true_logits = None
model_output = self.forward(
xt, unet_conditioning, return_additional_loss=True, batch=batch, x_img_emb=x_img_emb, joint_ar_nar_mask=joint_ar_nar_mask, **kwargs
)
if isinstance(model_output, tuple):
if continuous_mode:
model_output, img_output = model_output # model_output is for text, img_output is for image although both will have N_total length (zeroed out according to modality mask)
B, _, C = img_output.shape
# use modality mask to get the correct logits
x0 = x0[modality==0].reshape(B, -1)
xt = xt[modality==0].reshape(B, -1)
attention_mask = torch.ones_like(x0, dtype=torch.bool) # since we separate text, we don't need to mask it out
img_output = img_output[modality==1].reshape(B, -1, C)
target = target[modality==1].reshape(B, -1, C)
else:
model_output, true_logits = model_output
to_dtype = self.dtype if self.config.trainer.low_precision_loss else torch.float32
model_output = model_output.to(to_dtype)
if true_logits is not None:
true_logits = true_logits.to(self.dtype)
if continuous_mode:
img_output = img_output.to(to_dtype)
target = target.to(to_dtype)
if self.config.trainer.ar_shift:
x0 = x0[:, 1:]
xt = xt[:, 1:]
attention_mask = attention_mask[:, 1:]
if modality_mask is not None: modality_mask = modality_mask[:, 1:]
if modality is not None: modality = modality[:, 1:]
if not self.is_compiled:
utils.print_nans(model_output, "model_output")
if self.parameterization == "sedd":
return dsigma[:, None] * self._score_entropy(model_output, sigma[:, None], xt, x0)
elif self.parameterization == "planner":
return F.binary_cross_entropy_with_logits(model_output.squeeze(-1), move_indices.float()).mean()
diffusion_loss = None
if self.T > 0:
diffusion_loss = self._d3pm_loss(model_output=model_output, xt=xt, x0=x0, t=t)
if self.parameterization == "d3pm":
reconstruction_loss = self._reconstruction_loss(x0)
elif self.parameterization == "subs" or self.parameterization == "ar":
reconstruction_loss = 0
# return reconstruction_loss + diffusion_loss
if self.parameterization == "ar":
if getattr(self.config.trainer, "use_orig_unidisc_dit", False):
return self.shortcut_return(model_output, x0, attention_mask, prefix)
else:
log_p_theta = model_output.gather(-1, x0[:, :, None])[:, :, 0]
else:
# SUBS parameterization, continuous time
log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1)
if self.change_of_variables or self.importance_sampling:
return log_p_theta * torch.log1p(-torch.exp(-self.noise.sigma_min))
if self.parameterization == "ar" or getattr(self.config.trainer, "no_ce_weighting", False):
std_weighting = 1
else:
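            # Continuous-time ELBO weighting for masked diffusion: dsigma / (e^sigma - 1).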
std_weighting = (dsigma / torch.expm1(sigma))[:, None]
# ddprint(f"self.current_run_fwd_bwd_pass: {self.current_run_fwd_bwd_pass}, log_p_theta: {torch.isnan(log_p_theta).any()}")
# if torch.isnan(log_p_theta).any() or self.current_run_fwd_bwd_pass > 15473:
# import pickle
# import time
# rank = get_rank()
# timestamp = int(time.time() * 1e9) # nanosecond timestep
# filename = f'batch_datastep_{self.current_run_fwd_bwd_pass}_rank{rank}_{timestamp}.pkl'
# with open(filename, 'wb') as f:
# pickle.dump(log_p_theta, f)
# ddprint(f"Saved batch to {filename}")
loss = -log_p_theta * std_weighting
if not (self.parameterization == "ar" or (self.config.trainer.ar_llm_loss and joint_ar_nar_mask is None) or getattr(self.config.trainer, "no_ce_weighting", False)):
gamma = getattr(self.config.trainer, "softmin_snr", None)
if gamma is not None:
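                # Softmin-SNR weighting: cap the per-timestep weight at roughly gamma via
                # dsigma / (expm1(sigma) + 1/gamma), stabilizing training at low noise levels.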
softmin_weighting = (dsigma / (torch.expm1(sigma) + (1 / gamma)))[:, None]
loss = -log_p_theta * softmin_weighting
if diffusion_loss is not None:
assert self.T > 0
loss = diffusion_loss
std_loss = -log_p_theta * std_weighting
loss_dict = dict(std_loss=std_loss.detach(), extra_losses=dict())
if self.config.trainer.log_seperate_modal_losses:
assert not continuous_mode
loss_dict.update(
dict(
std_txt_loss=(std_loss.detach() * modality_mask[..., 0] * attention_mask),
std_img_loss=(std_loss.detach() * modality_mask[..., 1] * attention_mask)
)
)
if getattr(self.config.trainer, "mask_entire_modality", None) is not None and self.backbone.training and not self.config.parameterization == "ar":
loss_dict['batch_ignore_loss'] = ignore_batch_mask_for_metrics.squeeze(-1)
if joint_ar_nar_mask is not None:
if "batch_ignore_loss" in loss_dict:
loss_dict["batch_ignore_loss"] = loss_dict["batch_ignore_loss"] | joint_ar_nar_mask
else:
loss_dict["batch_ignore_loss"] = joint_ar_nar_mask
if (self.config.trainer.multimodal_batches or (self.config.trainer.text_loss_weight is not None and self.config.trainer.img_loss_weight is not None)) and not continuous_mode:
txt_mask = modality_mask[..., 0] & attention_mask
img_mask = modality_mask[..., 1] & attention_mask
txt_count = txt_mask.sum()
img_count = img_mask.sum()
total_count = txt_count + img_count
txt_frac = txt_count / total_count
img_frac = img_count / total_count
loss_dict["extra_losses"]["trainer/img_frac"] = img_frac
loss_dict["extra_losses"]["trainer/txt_frac"] = txt_frac
loss_dict["extra_losses"]["trainer/attention_mask_valid_frac"] = attention_mask.sum() / attention_mask.numel()
if "batch_ignore_loss" in loss_dict:
loss_dict["extra_losses"]["trainer/ignore_batch_metrics_frac"] = loss_dict["batch_ignore_loss"].sum() / loss_dict["batch_ignore_loss"].numel()
if joint_ar_nar_mask is not None:
pass # Defer loss mean until after ar_loss is calculated
elif self.config.trainer.text_loss_weight is not None and self.config.trainer.img_loss_weight is not None:
assert not continuous_mode
loss = loss * attention_mask
txt_loss = (
loss[txt_mask].sum() / txt_count
) * txt_frac * self.config.trainer.text_loss_weight
img_loss = (
loss[img_mask].sum() / img_count
) * img_frac * self.config.trainer.img_loss_weight
if getattr(self.config.trainer, "set_max_txt_loss_ratio", None) is not None and not (torch.isnan(img_loss).any() or torch.isnan(txt_loss).any()):
max_txt_loss = getattr(self.config.trainer, "set_max_txt_loss_ratio", 1.5) * img_loss.detach()
scale = torch.minimum(torch.tensor(1.0, device=txt_loss.device), max_txt_loss / (txt_loss.detach() + 1e-8))
txt_loss = txt_loss * scale
txt_loss = torch.nan_to_num(txt_loss, nan=0.0)
img_loss = torch.nan_to_num(img_loss, nan=0.0)
if getattr(self.config.trainer, "force_remove_img_tokens", False):
img_loss = torch.tensor(0, device=loss.device, dtype=loss.dtype)
loss = txt_loss + img_loss
loss_dict.update(dict(txt_loss=txt_loss.clone().detach(), img_loss=img_loss.clone().detach()))
elif continuous_mode:
img_loss = F.mse_loss(img_output, target)
if attention_mask[:, self.static_txt_sl].numel() == 0:
# Let grads pass even though this is zeros...
txt_loss = (loss[:, self.static_txt_sl] * attention_mask[:, self.static_txt_sl]).sum()
else:
txt_loss = (loss[:, self.static_txt_sl] * attention_mask[:, self.static_txt_sl]).sum() / attention_mask[:, self.static_txt_sl].sum()
loss = txt_loss + img_loss * self.config.trainer.image_loss_weight
loss_dict.update(dict(img_loss=img_loss.clone().detach(), txt_loss=txt_loss.clone().detach()))
else:
_attention_mask = torch.ones_like(attention_mask) if getattr(self.config.trainer, "force_full_attention_mask_loss_only", False) else attention_mask
loss = (loss * _attention_mask).sum() / _attention_mask.sum()
loss = torch.nan_to_num(loss, nan=0.0)
ar_loss = None
if self.config.trainer.ar_llm_loss:
assert not continuous_mode
valid_loss = xt == self.mask_index
_labels = x0.clone()
_labels = torch.where(valid_loss, _labels, -1)
_labels = torch.where(~attention_mask.to(torch.bool), -1, _labels)
_logits = true_logits
_logits[:, :, self.mask_index] += self.neg_infinity
if getattr(self.config.model, "force_argmax_valid_indices", False):
assert not self.config.trainer.multimodal_batches
_logits[:, self.static_txt_sl, self.text_vocab_size:] = torch.finfo(_logits.dtype).min
_logits[:, self.static_img_sl, : self.text_vocab_size] = torch.finfo(_logits.dtype).min
_logits = _logits.contiguous().view(-1, _logits.shape[-1])
_labels = _labels.contiguous().view(-1)
if self.config.trainer.ar_print_loss:
_labels = _labels.to(_logits.device)
                loss_fct = nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
                ce_loss = loss_fct(_logits, _labels)
ce_loss = ce_loss.mean(dim=-1)
if hasattr(self, 'histogram') is False:
self.histogram = {}
update_histogram(self.histogram, t, ce_loss)
rprint(f"ELM loss: move: {move_chance}, t:{t}, {ce_loss}")
loss_fct = nn.CrossEntropyLoss(ignore_index=-1, reduction='none' if joint_ar_nar_mask is not None else 'mean')
ce_loss = loss_fct(_logits, _labels)
loss_dict["extra_losses"]["trainer/ce_loss"] = ce_loss
ar_loss = ce_loss
if joint_ar_nar_mask is not None:
__true_logits = true_logits.clone()
__true_logits = torch.where(torch.arange(true_logits.shape[-1], device=true_logits.device)[None, None, :] == self.mask_index, self.neg_infinity, __true_logits)
log_softmax = __true_logits.log_softmax(-1)
ar_loss = -log_softmax.gather(-1, x0[:, :, None])[:, :, 0]
assert ar_loss is not None
assert ar_loss.ndim == 2
assert loss.ndim == 2
ar_loss_weight = joint_ar_nar_mask.sum(dim=0) / joint_ar_nar_mask.shape[0]
nar_loss_weight = 1 - ar_loss_weight
loss_dict["extra_losses"]["trainer/ar_loss_weight"] = ar_loss_weight.detach().float()
loss_dict["extra_losses"]["trainer/nar_loss_weight"] = nar_loss_weight.detach().float()
loss_dict["extra_losses"]["trainer/ce_loss"] = ar_loss.mean().detach().float()
ar_loss = (ar_loss * ar_loss_weight) * attention_mask
nar_loss = (loss * nar_loss_weight) * attention_mask
valid_count = attention_mask.sum()
if not is_xla_available:
ar_valid_count = attention_mask[joint_ar_nar_mask].sum()
nar_valid_count = attention_mask[~joint_ar_nar_mask].sum()
loss_dict["extra_losses"]["trainer/ar_loss"] = (ar_loss[joint_ar_nar_mask].sum() / ar_valid_count).detach().float()
loss_dict["extra_losses"]["trainer/nar_loss"] = (loss[~joint_ar_nar_mask].sum() / nar_valid_count).detach().float()
loss_dict["extra_losses"]["trainer/ar_ppl"] = torch.exp(loss_dict["extra_losses"]["trainer/ar_loss"]).detach().float()
loss_dict["extra_losses"]["trainer/nar_ppl"] = torch.exp(loss_dict["extra_losses"]["trainer/nar_loss"]).detach().float()
            # NOTE: assumes no auxiliary z-loss; add a weighted z-loss term here if one is configured.
            loss = torch.where(joint_ar_nar_mask[:, None], ar_loss, nar_loss).sum() / valid_count
elif ar_loss is not None:
loss = ar_loss
loss_dict = dict(loss=loss, **loss_dict)
std_loss = loss_dict.get("std_loss", 0)
std_nlls = std_loss * attention_mask
if "batch_ignore_loss" in loss_dict:
attention_mask = torch.where(loss_dict['batch_ignore_loss'][:, None].repeat(1, attention_mask.shape[-1]), torch.full_like(attention_mask, False), attention_mask)
losses = Loss(
loss=loss_dict["loss"],
img_loss=loss_dict.get("img_loss", 0),
txt_loss=loss_dict.get("txt_loss", 0),
nlls=std_nlls,
txt_nlls=loss_dict.get("std_txt_loss", 0),
img_nlls=loss_dict.get("std_img_loss", 0),
token_mask=attention_mask,
modality_mask=modality_mask,
extra_losses=loss_dict.get("extra_losses", None),
)
if getattr(self.config.trainer, "disable_torchmetrics", False):
raise NotImplementedError("Torchmetrics disabled")
elif prefix == "train":
return losses
elif prefix == "val":
self.valid_metrics.update(losses.nlls, losses.token_mask)
if hasattr(self, "valid_txt_metrics"):
self.valid_txt_metrics.update(losses.txt_nlls, losses.modality_mask[..., 0] & losses.token_mask)
self.valid_img_metrics.update(losses.img_nlls, losses.modality_mask[..., 1] & losses.token_mask)
elif prefix == "test":
self.test_metrics.update(losses.nlls, losses.token_mask)
metrics = self.test_metrics
self.log_dict(metrics, on_step=False, on_epoch=True, sync_dist=True)
else:
raise ValueError(f"Invalid prefix: {prefix}")
@torch.no_grad()
def zero_shot_eval(self):
dataloader = self.validation_dataloader
total_batches = len(dataloader)
rprint(f"Zero shot eval with {total_batches} batches with limit_val_batches: {self.config.trainer.limit_val_batches}")
for idx, batch in tqdm(enumerate(dataloader), total=total_batches, desc="Zero shot eval validation steps", disable=not is_main_process()):
if self.config.trainer.limit_val_batches is not None and idx >= self.config.trainer.limit_val_batches:
break
self.zero_shot_eval_step(batch, idx)
self.zero_shot_eval_epoch_end()
def validate(self, state: TrainingState):
self.on_validation_epoch_start()
if getattr(self.config.eval, "compute_val_metrics_standalone", False) and getattr(self.config.eval, "bypass_normal_validation", False):
batch = next(iter(self.validation_dataloader))
self.on_validation_epoch_end(example_batch=batch)
self.on_validation_epoch_cleanup()
return
total_len = 10 if self.config.data.iterable or self.config.data.webdataset_indexed else len(self.validation_dataloader)
dprint(f"Validation batches: {total_len}")
total_batches = (
self.config.trainer.limit_val_batches
if (self.config.trainer.limit_val_batches is not None and self.fid_eval is False)
else total_len
)
if getattr(self.config.eval, 'pplx_full_dataset', False):
rprint("[INFO] PPLX full dataset eval, setting total_batches to total_len")
total_batches = total_len
elif self.config.eval.max_num_fid_batches_per_device is not None and self.fid_eval:
total_batches = min(total_len, self.config.eval.max_num_fid_batches_per_device)
_dataloader = self.train_dataloader if self.config.eval.val_with_train_data else self.validation_dataloader
rprint(f"Validating with {total_batches} batches on {self.world_size} GPUs with batch size {self.config.loader.eval_batch_size}")
for idx, batch in tqdm(enumerate(_dataloader), total=total_batches, desc="Validation steps", disable=not is_main_process()):
if self.config.trainer.limit_val_batches is not None and idx >= total_batches:
break
self.validation_step(batch, idx)
if getattr(self.config.eval, "eval_large_batch", None) is not None:
assert isinstance(batch, TensorDict)
dataloader_iter = iter(_dataloader)
large_batch = [next(dataloader_iter, None) for _ in range(getattr(self.config.eval, "eval_large_batch", None))]
large_batch = [b for b in large_batch if b is not None]
large_batch = torch.stack(large_batch, dim=0)
batch = large_batch
gprint(f"Large batch shape: {batch.shape}")
else:
batch = next(iter(_dataloader))
if self.config.eval.visualize_data_only:
return
if self.config.eval.compute_standalone_mauve and not getattr(self.config.eval, "global_disable_mauve", False):
self.mauve_store_references(_dataloader)
if self.config.mode == "eval":
gprint(f"Batch shape: {batch['input_ids'].shape}")
self.on_validation_epoch_end(example_batch=batch)
self.on_validation_epoch_cleanup()
@cached_property
def global_batch_size(self):
"""Batch size for a single step over all GPUs"""
# SPMD treats all ranks [regardless of node] as a single device
return self.step_batch_size * (1 if (self.config.trainer.xla_spmd and is_xla_available) else self.world_size)
@cached_property
def step_batch_size(self):
"""Batch size for a single step for a single GPU"""
return self.config.loader.batch_size * self.config.trainer.accumulate_grad_batches
@cached_property
def world_size(self):
"""Number of GPUs over all nodes"""
return get_world_size()
@cached_property
def num_tokens_per_sample(self):
"""Number of tokens per sample"""
return self.config.model.length
@cached_property
def gradient_accumulation_steps(self):
"""Number of gradient accumulation steps"""
return self.config.trainer.accumulate_grad_batches
@cached_property
def static_txt_sl(self):
return slice(None, self.config.model.txt_length)
@cached_property
def static_img_sl(self):
return slice(-self.config.model.img_length, None)
def img_txt_pair_batch_mask(self, batch=None):
return batch["modality_mask"][..., 1].sum(dim=-1) > 0
def txt_sl(self, batch=None):
return batch["modality_mask"][..., 0]
def img_sl(self, batch=None):
return batch["modality_mask"][..., 1]
@cached_property
def is_compiled(self):
return is_xla_available or self.config.trainer.compile
@property
def allow_slicing(self):
return not is_xla_available and not self.backbone.training
@property
def training(self):
return self.backbone.training
def get_step_metrics(self):
return {
"trainer/global_step": self.global_step,
"global_samples": self.global_step * self.global_batch_size,
"train_metrics/global_tokens": self.global_step * self.global_batch_size * self.config.model.length,
"effective_global_tokens": self.global_step * self.global_batch_size * self.config.model.length * (0.5 if self.config.parameterization == "subs" else 1.0),
"effective_global_step": int(self.global_step * (0.5 if self.config.parameterization == "subs" else 1.0)),
}
def train(self):
tr = self.config.trainer
total_batch_size = self.global_batch_size
initial_global_step = self.global_step
true_step = 0
first_epoch = 0
self.current_run_global_step = 0
self.current_run_fwd_bwd_pass = 0
rprint(f"Started at step {self.accelerator.step}")
if self.non_embedding_params < 1e9: # Hashing every weight is slow, so skip for models past ~1B parameters
with try_except(write_error_to_file=True, clear_cuda_cache=True):
self.print_hashes()
# There is an unknown bug in accelerate where non-master ranks don't load the step count from a checkpoint.
# As a workaround, copy rank 0's step count to any rank that reports a step of 0.
if is_torch_cuda_available():
dprint(f"Gathering step from {self.world_size} ranks")
starting_steps = gather_object([self.accelerator.step])
rprint(f"Starting steps: {starting_steps}")
if not all(x > 0 for x in starting_steps):
rprint(f"Not all ranks have a step > 0; setting all ranks to: {starting_steps[0]}")
self.accelerator.step = starting_steps[0]
if is_xla_available:
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
assert (self.config.trainer.accumulate_grad_batches == 1) or getattr(self.config.trainer, "allow_accum_grad_batches_xla", False), "Accumulate grad batches must be 1 for XLA"
rprint(f"***** Starting training at global step: {self.global_step} *****")
rprint(f" Instantaneous batch size per device = {self.config.loader.batch_size}")
rprint(f" Gradient Accumulation steps = {tr.accumulate_grad_batches}")
rprint(f" Num GPUs = {tr.devices}")
rprint(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
rprint(f" Total optimization steps = {tr.max_steps}")
rprint(f" Reported Global Batch Size: {self.global_batch_size}, Reported Step Batch Size: {self.step_batch_size}, Reported World Size: {self.world_size}")
if not self.config.data.iterable and not self.config.data.webdataset_indexed and is_torch_cuda_available():
num_epoch_steps = len(self.train_dataloader)
rprint(f" Num examples = {len(self.train_dataloader.dataset)}")
rprint(f" Num batches each epoch = {len(self.train_dataloader)}")
rprint(f"Train Dataloader Size on single GPU: {num_epoch_steps}")
if len(self.train_dataloader.dataset) < total_batch_size:
rprint("The training dataloader is smaller than the total batch size. This may lead to unexpected behaviour.")
else:
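# Iterable / indexed-webdataset loaders (and non-CUDA backends) may not expose a length,
# so fall back to a fixed placeholder for progress accounting.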
num_epoch_steps = 10000
if self.config.trainer.pytorch_profile:
profiler = Profiler(
output_dir=self.config.output_dir, warmup_steps=tr.profiler_warmup_steps, active_steps=tr.profiler_active_steps, record_memory=True
)
if self.config.trainer.viz_images_only:
return self.viz_images_from_dataloader()
progress_bar = tqdm(range(0, tr.max_steps), initial=initial_global_step, desc="Steps", disable=not is_local_main_process(), leave=False, smoothing=0.15)
global_step_metrics = defaultdict(float)
global_extra_wandb_metrics = dict()
accumulate_steps = 0
first_start_time = time.time()
self.on_train_start()
rprint(f"Training for {tr.num_epochs} epochs...")
last_end_step_time = start_timing(f"Dataloading accum:{accumulate_steps}, #{true_step}, global_step:{self.global_step}")
for epoch in range(first_epoch, tr.num_epochs):
rprint(f"Starting epoch {epoch}...")
for step, batch in enumerate(self.train_dataloader):
ddprint(f"Data Step: {step}")
if self.config.trainer.iterate_dataloader_only:
rprint(f"Iterating dataloader only: {step}")
# rprint((batch["modality"] == 0).sum(), (batch["modality"] == 1).sum())
if (batch["attention_mask"] == 0).all(dim=-1).any():
breakpoint()
batch = self.update_batch(batch)
if (batch["sample_ids"] == -1).all(dim=-1).any():
breakpoint()
continue
elif getattr(self.config.trainer, "iterate_dataloader_n_dataloader_batches", None) is not None and step <= self.config.trainer.iterate_dataloader_n_dataloader_batches:
self.current_run_fwd_bwd_pass += 1
if self.current_run_fwd_bwd_pass % self.config.trainer.accumulate_grad_batches == 0:
self.global_step += 1
self.current_run_global_step += 1
ddprint(f"Iterating dataloader only for {self.config.trainer.iterate_dataloader_n_dataloader_batches} dataloader batches. At step {self.global_step=}, {self.current_run_global_step=}, {self.current_run_fwd_bwd_pass=}")
continue
if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()
if self.config.trainer.sync_dataloader_timing: synchronize_device()
global_step_metrics[f"dataloading_time"] += end_timing(last_end_step_time)
if self.config.trainer.nvtx_profile and self.is_compiled and step == 4:
torch.cuda.cudart().cudaProfilerStart()
if self.current_run_global_step == 1 and is_xla_available:
gprint(f"First start time: {time.time() - first_start_time}")
if getattr(self.config.data, "force_dummy_tensordict", False):
gprint(self.global_step, self.current_run_global_step, true_step, batch["idx"].tolist(), batch["dataset_idx"].tolist())
if getattr(self.config.trainer, "assert_at_n_steps", None) is not None and self.global_step == self.config.trainer.assert_at_n_steps:
gprint(batch["img_input_ids"].min(), batch["img_input_ids"].max(), batch["txt_input_ids"].min(), batch["txt_input_ids"].max())
if batch is None:
rprint(f"Batch is None at step {step}")
continue
if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()
ddprint(f"After Data Step 2: {step}")
with nullcontext() if is_xla_available else self.accelerator.accumulate(self.backbone):
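# Gradient accumulation is effectively disabled under XLA (see the assert above), so accelerate's accumulate() context is skipped there.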
ddprint(f"Before forward pass for global_step: {self.global_step}")
start_forward_time = start_timing(f"Forward Pass accum:{accumulate_steps}, #{true_step}, global_step:{self.global_step}")
global_step_metrics["examples_seen_per_gpu"] += len(next(iter(batch.values())))
state: TrainingState = TrainingState(
epoch_step=step,
num_epoch_steps=num_epoch_steps,
global_step=self.global_step,
epoch=epoch,
true_step=true_step,
current_run_global_step=self.current_run_global_step,
)
if self.accelerator.sync_gradients and not is_xla_available:
self.cb_handler.on_train_step_start(state=state, unit=None)
if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()
ddprint(f"Before Fwd: {step}")
with xp.StepTrace('Forward', step_num=step) if (is_xla_available and self.config.trainer.tpu_profile) else nullcontext():
losses = self.training_step(batch, step)
ddprint(f"After Fwd: {step}")
global_step_metrics["forward_pass_time"] += end_timing(start_forward_time)
true_step += 1
evaluate_extra_log_data = lambda: dict() # Overridden below when per-modality metrics are available
if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()
if isinstance(losses, dict):
for k, v in losses.items():
if isinstance(v, torch.Tensor):
global_step_metrics[k.removeprefix("metric_")] += v.detach().cpu().item()
else:
global_extra_wandb_metrics[k.removeprefix("metric_")] = v
losses = dict(
filter(lambda item: not item[0].startswith("metric_"), losses.items())
) # Allow for custom metrics that are not losses
loss = sum(losses.values())
elif isinstance(losses, Loss):
loss = losses.loss
metrics = self.train_metrics(losses.nlls, losses.token_mask)
txt_metrics, img_metrics = None, None
if hasattr(self, "txt_metrics") and losses.modality_mask is not None:
txt_metrics = self.txt_metrics(losses.txt_nlls, losses.modality_mask[..., 0] & losses.token_mask)
if hasattr(self, "img_metrics") and losses.modality_mask is not None:
img_metrics = self.img_metrics(losses.img_nlls, losses.modality_mask[..., 1] & losses.token_mask)
extra_losses_dict = losses.extra_losses if losses.extra_losses is not None else dict()
if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()
def evaluate_extra_log_data():
# Only report per-modality metrics when both were actually computed above;
# otherwise the closure would hit a NameError on the unbound locals.
if txt_metrics is not None and img_metrics is not None:
return {
**{f"train/txt_{k.split('/')[-1]}": v for k, v in replace_nan_dict(txt_metrics).items()},
**{f"train/img_{k.split('/')[-1]}": v for k, v in replace_nan_dict(img_metrics).items()},
}
else:
return {}
ddprint(f"Before loss: {step}")
incremental_dict_update(global_extra_wandb_metrics, {
"trainer/loss": loss,
"trainer/img_loss": losses.img_loss,
"trainer/txt_loss": losses.txt_loss,
**{
"global_samples": self.global_step * self.global_batch_size,
"train_metrics/global_tokens": self.global_step * self.global_batch_size * self.config.model.length,
"effective_global_tokens": self.global_step * self.global_batch_size * self.config.model.length * (0.5 if self.config.parameterization == "subs" else 1.0),
"effective_global_step": int(self.global_step * (0.5 if self.config.parameterization == "subs" else 1.0)),
},
**metrics,
**extra_losses_dict,
})
if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()
else:
loss = losses
if is_torch_cuda_available():
global_step_metrics["loss"] = loss.detach().cpu().item() # Skipped on XLA, where .item() would force a device sync
ddprint(f"Before backward pass for global_step: {self.global_step}")
# On XLA, short-circuit the isfinite check to avoid forcing an eval of the lazily-built graph
if tr.backward_pass and (is_xla_available or torch.isfinite(loss).all()):
start_backward_time = start_timing(f"Backward Pass accum:{accumulate_steps}, #{true_step}, global_step:{self.global_step}")
if self.accelerator.sync_gradients:
start_sync_time = start_timing(f"Gradient Sync global_step:{self.global_step}")
if getattr(self.config.trainer, "sync_timing", False):
sync_times(self.device)
if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()
# After each forward we perform a backward. While accumulating, accelerate wraps the backward in an internal no_sync, so
# gradients stay on each GPU until the final backward before an optimizer step (controllable via sync_each_batch). Note that
# for the last backward the sync happens inside the backward call below, so any straggler timing must happen before this call.
with xp.StepTrace('Backward', step_num=step) if (is_xla_available and self.config.trainer.tpu_profile) else nullcontext():
ddprint(f"Before accelerator.backward for global_step: {self.global_step}")
self.accelerator.backward(loss)
ddprint(f"After accelerator.backward for global_step: {self.global_step}")
with xp.StepTrace('After Backward + Clip', step_num=step) if (is_xla_available and self.config.trainer.tpu_profile) else nullcontext():
if self.accelerator.sync_gradients:
ddprint(f"Before after.backward for global_step: {self.global_step}")
self.after_backward(state)
if tr.gradient_clip_val is not None:
ddprint(f"Before self.accelerator.clip_grad_norm_ for global_step: {self.global_step}")
total_grad_norm = self.accelerator.clip_grad_norm_(self.backbone.parameters(), tr.gradient_clip_val)
ddprint(f"After self.accelerator.clip_grad_norm_ for global_step: {self.global_step}")
with xp.StepTrace('Optimizer + Scheduler Step', step_num=step) if (is_xla_available and self.config.trainer.tpu_profile) else nullcontext():
ddprint(f"Before optimizer step for global_step: {self.global_step}, {step}")
# Note: xm.optimizer_step(self.optimizer) does not appear to be needed on XLA, so the plain optimizer step is always used.
self.optimizer.step()
ddprint(f"After optimizer step for global_step: {self.global_step}, {step}")
self.lr_scheduler.step()
ddprint(f"After lr_scheduler step for global_step: {self.global_step}, {step}")
zero_grad_kwargs = dict()
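# apex's fused optimizers presumably override zero_grad() without a set_to_none argument, so only pass it for other optimizers.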
if "apex" not in self.config.trainer.optimizer_cls:
zero_grad_kwargs["set_to_none"] = tr.set_grads_to_none
ddprint(f"Before zero_grad for global_step: {self.global_step}, {step}")
self.optimizer.zero_grad(**zero_grad_kwargs)
ddprint(f"Zeroed gradients for global_step: {self.global_step}, {step}")
if self.accelerator.sync_gradients:
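# Update EMA weights once per optimizer step: either via the custom in-place update or the wrapped EMA helper's step().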
if self.ema is not None:
if self.config.trainer.use_custom_ema:
ema_update(self.unwrap_model(self.ema), self.unwrap_model(self.backbone), self.config.trainer.ema)
else:
self.ema.step(self.get_params())
global_step_metrics["gradient_sync_time"] += end_timing(start_sync_time)
global_step_metrics["backward_pass_time"] += end_timing(start_backward_time)
else:
if not torch.isfinite(loss).all(): gprint(f"Loss is not finite: {loss}")
gprint("Skipping backward pass!")
accumulate_steps += 1
self.current_run_fwd_bwd_pass += 1
# Important: A single "global_step" is a single optimizer step. The accumulate decorator silently skips backward + optimizer to allow for gradient accumulation.
# A "true_step" counts the number of forward passes (on a per-GPU basis). The condition below should only happen immediately after a backward + optimizer step.
ddprint(f"Syncing gradients for global_step: {self.global_step}. Should sync: {self.accelerator.sync_gradients}, {step}, {self.accelerator.step}, {self.accelerator.gradient_accumulation_steps}")
if self.accelerator.sync_gradients:
start_gradient_sync_time = start_timing(f"On Sync Gradients global_step:{self.global_step}, {step}")
ddprint(f"Before on_train_step_end for global_step: {self.global_step}, {step}, {self.accelerator.step}, {self.accelerator.gradient_accumulation_steps}")
state.batch = batch
del loss, losses, batch
gradient_sync_time_after_train_step_end_time = start_timing(f"On Sync Gradients global_step:{self.global_step}, {step}")
self.on_train_step_end(state)
ddprint(f"After on_train_step_end for global_step: {self.global_step}, {step}, {self.accelerator.step}, {self.accelerator.gradient_accumulation_steps}")
global_step_metrics["gradient_sync_time_after_train_step_end"] += end_timing(gradient_sync_time_after_train_step_end_time)
if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()
if self.config.trainer.profile_memory and self.global_step + 1 >= tr.max_steps:
rprint("Finished profiling memory...")
break
if self.config.trainer.pytorch_profile and profiler.step(self.global_step):
rprint(f"Profiling finished at step: {self.global_step}")
break
if getattr(self.config.trainer, "throw_failure_for_testing", False) and self.current_run_global_step == 5:
raise RuntimeError("Test failure")
if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()
progress_bar.update(1)
self.global_step += 1
self.current_run_global_step += 1
global_step_metrics["gradient_sync_time"] += end_timing(start_gradient_sync_time)
logs = {
"examples_seen": self.global_step * total_batch_size,
"trainer/global_step": self.global_step,
**global_step_metrics,
**{f"lr_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_last_lr())},
**global_extra_wandb_metrics,
}
if is_torch_cuda_available():
logs["gpu_max_mem_reserved_gb"] = torch.cuda.max_memory_reserved() / (1024**3)
logs["gpu_cur_mem_reserved_gb"] = torch.cuda.memory_reserved() / (1024**3)
logs["gpu_max_mem_allocated_gb"] = torch.cuda.max_memory_allocated() / (1024**3)
logs["gpu_cur_mem_allocated_gb"] = torch.cuda.memory_allocated() / (1024**3)
if is_xla_available:
if self.global_step % getattr(self.config.trainer, "log_every_n_steps", 1) == 0:
xm.add_step_closure(update_logs, args=(logs, evaluate_extra_log_data), run_async=False)
del logs
global_extra_wandb_metrics = dict()
if self.config.trainer.tpu_force_mark_step: xm.mark_step()
else:
logs.update(evaluate_extra_log_data())
progress_bar.set_postfix(**logs)
log(logs)
global_extra_wandb_metrics = dict()
if getattr(self.config.trainer, "sync_timing", False):
global_step_metrics = {f"rank_{get_rank()}/{k}": v for k, v in global_step_metrics.items()}
all_step_metrics = self.accelerator.gather_for_metrics([global_step_metrics], use_gather_object=True)
merged_metrics = {k: v for d in all_step_metrics for k, v in d.items()}
log(merged_metrics)
if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()
global_step_metrics = defaultdict(float)
accumulate_steps = 0
if self.global_step >= tr.max_steps:
break
ddprint(f"After logging for step v3: {self.global_step}, {step}")
if getattr(self.config.trainer, "assert_at_n_steps", None) is not None and self.current_run_global_step >= getattr(self.config.trainer, "assert_at_n_steps", None):
raise RuntimeError(f"Assertion failed at step {self.current_run_global_step}")
ddprint(f"After logging for step v4: {self.global_step}, {step}")
if is_xla_available and self.config.trainer.tpu_profile and (self.global_step == 0 or self.global_step % 50 == 0) and is_main_process():
import torch_xla.debug.metrics as met
rprint(met.metrics_report())
met.clear_all()
if is_xla_available and self.config.trainer.tpu_force_mark_step: xm.mark_step()
ddprint(f"Finished sync_gradients: {self.global_step}, {self.accelerator.step}, {self.accelerator.gradient_accumulation_steps}")
ddprint(f"Finished step: {self.global_step},{step},{self.accelerator.step},{self.accelerator.gradient_accumulation_steps},{self.accelerator.gradient_state.__repr__()}")
if self.config.trainer.sync_dataloader_timing: synchronize_device()
last_end_step_time = start_timing(f"Dataloading #{true_step + 1}")
if self.global_step >= tr.max_steps:
break
dprint(f"Finished epoch: {epoch}")
# Training finished; optionally dump profiles and save a final checkpoint.
rprint("Training finished.")
barrier()
if tr.profile_memory:
print_memory(verbose=True)
save_memory_profile(self.config.output_dir / "profile")
if tr.pytorch_profile:
profiler.finish()
elif tr.nvtx_profile:
torch.cuda.cudart().cudaProfilerStop()
elif self.global_step > 100 or not tr.skip_early_checkpointing:
self.checkpoint(state)
barrier()