import math
import time
import warnings
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import Trainer

with warnings.catch_warnings():
    warnings.simplefilter(action="ignore", category=FutureWarning)
    from transformers.deepspeed import deepspeed_init
from transformers.integrations import TensorBoardCallback
from transformers.trainer_pt_utils import (
    IterableDatasetShard,
    find_batch_size,
    nested_concat,
    nested_detach,
    nested_numpify,
)
from transformers.trainer_utils import denumpify_detensorize, has_length, speed_metrics
from transformers.utils import (
    is_apex_available,
    is_datasets_available,
    is_sagemaker_mp_enabled,
    logging,
)

from sdlm.inference.inference_utils import (
    logits_projection,
    predict_conditional_generated,
)
from sdlm.models.utils import is_cdcd_check
from sdlm.pipelines.simplex_ddpm import SimplexDDPMClassifierGuidancePipeline
from sdlm.utils import convert_to_simplex, pad_data, scale, self_condition_preds

if is_apex_available():
    from apex import amp

if is_datasets_available():
    import datasets

GENERATION_RESULTS = "generated"

logger = logging.get_logger(__name__)


class EvalLoopOutput(NamedTuple):
    logits: Union[np.ndarray, Tuple[np.ndarray]]
    simplex: Union[np.ndarray, Tuple[np.ndarray]]
    input_ids: Optional[Union[np.ndarray, Tuple[np.ndarray]]]
    metrics: Optional[Dict[str, float]]
    results: Optional[Dict[str, List[str]]]
    num_samples: Optional[int]


class DiffusionTrainer(Trainer):
    def __init__(
        self,
        noise_scheduler,
        inference_noise_schedulers,
        diffusion_args,
        data_args,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.original_data_collator = self.data_collator
        self.noise_scheduler = noise_scheduler
        self.diffusion_args = diffusion_args
        self.data_args = data_args
        self.vocab_size = self.model.config.vocab_size
        self.inference_noise_schedulers = inference_noise_schedulers
        self.inference_timesteps = diffusion_args.num_inference_diffusion_steps
        self.tb_writer = self.get_tb_writer()
        self.eos_token_id = self.tokenizer.eos_token_id
        self.classifier_free_guidance = (
            diffusion_args.guidance_scale > 1.0
            and data_args.conditional_generation is not None
        )
        self.counter = 0
        # TODO: control seed.
        self.self_cond_generator = np.random.default_rng(42)

    def annotated_split(self, split):
        return f"{split}_top_p_{self.diffusion_args.top_p}_temperature_{self.diffusion_args.temperature}_seed_{self.args.seed}_guidance_scale_{self.diffusion_args.guidance_scale}"

    def save_metrics(self, split, metrics, combined=True):
        super().save_metrics(self.annotated_split(split), metrics, combined)

    def log_metrics(self, split, metrics):
        super().log_metrics(self.annotated_split(split), metrics)

    def get_tb_writer(self):
        for cb in self.callback_handler.callbacks:
            if isinstance(cb, TensorBoardCallback):
                return cb
        return None

    def training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
    ) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under
                the argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
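
        In this diffusion trainer the step additionally builds a noisy simplex over
        the vocabulary (forward diffusion) and, when self-conditioning is enabled, a
        detached prediction from an extra forward pass. A rough sketch of the flow
        (simplified; shapes are indicative only):

            simplex       = convert_to_simplex(input_ids, simplex_value, vocab_size)  # (bsz, seq, vocab)
            noisy_simplex = noise_scheduler.add_noise(simplex, noise, timesteps)      # forward diffusion
            loss          = compute_loss(model, {..., "simplex": noisy_simplex, "timesteps": scale(t, T)})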
""" model.train() inputs = self._prepare_inputs(inputs) # Truncate the length if needed. if self.data_args.truncation_length > 0: inputs["input_ids"] = inputs["input_ids"][ :, : -self.data_args.truncation_length ] inputs["span_mask"] = inputs["span_mask"][ :, : -self.data_args.truncation_length ] # Creates the noisy simplex and timesteps. simplex = convert_to_simplex( inputs["input_ids"], self.diffusion_args.simplex_value, self.vocab_size ) noise = self.diffusion_args.simplex_value * torch.randn( simplex.shape, device=simplex.device, dtype=simplex.dtype ) bsz = simplex.shape[0] # Sample a random timestep for each simplex token representation. # testing just sampling the same place. This better matches reality. if True: # np.random.rand(1) > 0.5: timesteps = torch.randint( 0, len(self.noise_scheduler), (bsz, inputs["input_ids"].shape[1]) if False # is_tokenwise_cdcd_check(self.model) else (bsz,), device=simplex.device, dtype=torch.int64, ) timesteps = timesteps[:, None].expand(-1, inputs["input_ids"].shape[1]) else: timesteps = torch.randint( 0, len(self.noise_scheduler), (bsz, inputs["input_ids"].shape[1]) if True # is_tokenwise_cdcd_check(self.model) else (bsz,), device=simplex.device, dtype=torch.int64, ) # expand out timesteps to match tokenwise setup # if True: # not is_tokenwise_cdcd_check(self.model): # timesteps = timesteps[:, None].expand(-1, inputs["input_ids"].shape[1]) # save original timesteps for warping original_timesteps = timesteps # warp timesteps according to cdf # we re-scale the timesteps to the correct range. # the -1 is due to the timestep should be in range [0, 5000) if is_cdcd_check(self.model): input_ids = inputs["input_ids"] span_mask = inputs["span_mask"] token_input = torch.where( (input_ids * span_mask) > 1, self.tokenizer.pad_token_id, input_ids ) timesteps = self.model.warp_timesteps( timesteps, token_input=token_input, span_mask=span_mask, t_max=len(self.noise_scheduler) - 1, ) # Adds noise to each simplex representation (Forward diffusion process). noisy_simplex = self.noise_scheduler.add_noise(simplex, noise, timesteps) # the warper model will scale the timesteps to the correct range. timesteps = scale(timesteps, len(self.noise_scheduler)) # original_timesteps_scaled = scale(original_timesteps, len(self.noise_scheduler)) # inputs.update( # {"original_timesteps": scale(original_timesteps, len(self.noise_scheduler))} # ) inputs.update( { "timesteps": timesteps, "simplex": noisy_simplex, } ) # inputs.update({"max_timestep": len(self.noise_scheduler)}) if self.diffusion_args.self_condition is not None: previous_pred = None # previous_hidden = None if self.self_cond_generator.random(1) > 0.5: next_timestep = inputs.pop("timesteps") next_simplex = inputs.pop("simplex") timesteps = torch.clamp( (next_timestep * len(self.noise_scheduler)) + 1, max=len(self.noise_scheduler) - 1, ) if is_cdcd_check(self.model): input_ids = inputs["input_ids"] span_mask = inputs["span_mask"] token_input = torch.where( (input_ids * span_mask) > 1, self.tokenizer.pad_token_id, input_ids, ) timesteps = self.model.warp_timesteps( timesteps, token_input=token_input, span_mask=span_mask, t_max=len(self.noise_scheduler) - 1, ) noisy_simplex = self.noise_scheduler.add_noise( simplex, noise, timesteps ) timesteps = scale(timesteps, len(self.noise_scheduler)) inputs.update( { "timesteps": timesteps, "simplex": noisy_simplex, } ) # we don't backprop through this. 
                with torch.no_grad():
                    outputs = model(**inputs, previous_pred=previous_pred)
                    logits_projection_fct = lambda x: logits_projection(  # noqa: E731
                        x,
                        self.diffusion_args.sampling_type,
                        self.diffusion_args.top_p,
                        self.diffusion_args.simplex_value,
                        self.diffusion_args.temperature,
                    )
                    previous_pred = self_condition_preds(
                        self.diffusion_args.self_condition,
                        outputs.logits,
                        logits_projection_fct,
                    ).detach()  # following rest of self-conditioning, don't backprop through.
                    # previous_hidden = outputs.hidden_states.detach()
                # pop timestep/simplex and put the old ones back.
                inputs.update(
                    {
                        "timesteps": next_timestep,
                        "simplex": next_simplex,
                    }
                )
                inputs.update({"previous_pred": previous_pred})
                # inputs.update({"previous_hidden": previous_hidden})
            else:
                inputs.update({"previous_pred": None})
                # inputs.update({"previous_hidden": None})
                # previous_hidden = None
        # NOTE: we do this after computation of self-conditioning to not affect that one.
        # inputs.update(
        #     {"classifier_free_guidance_in_train": self.classifier_free_guidance}
        # )
        # re-warp based on previous hidden state
        if is_cdcd_check(self.model):
            # replace masked tokens with token.
            input_ids = inputs["input_ids"]
            span_mask = inputs["span_mask"]
            token_input = torch.where(
                (input_ids * span_mask) > 1, self.tokenizer.pad_token_id, input_ids
            )
            timesteps = self.model.warp_timesteps(
                original_timesteps,
                t_max=len(self.noise_scheduler) - 1,
                token_input=token_input,
                span_mask=span_mask,
            )
            noisy_simplex = self.noise_scheduler.add_noise(simplex, noise, timesteps)
            timesteps = scale(timesteps, len(self.noise_scheduler))
            inputs.update(
                {
                    "timesteps": timesteps,
                    "simplex": noisy_simplex,
                }
            )

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        # HACK: transformer update
        # if self.do_grad_scaling:
        #     self.scaler.scale(loss).backward()
        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss)
        return loss.detach() / self.args.gradient_accumulation_steps

    def light_prediction_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        with torch.no_grad():
            inputs = self._prepare_inputs(inputs)
            # Truncate the length if needed.
            if self.data_args.truncation_length > 0:
                inputs["input_ids"] = inputs["input_ids"][
                    :, : -self.data_args.truncation_length
                ]
                inputs["span_mask"] = inputs["span_mask"][
                    :, : -self.data_args.truncation_length
                ]
            # Creates the noisy simplex and timesteps.
            simplex = convert_to_simplex(
                inputs["input_ids"], self.diffusion_args.simplex_value, self.vocab_size
            )
            noise = self.diffusion_args.simplex_value * torch.randn(
                simplex.shape, device=simplex.device, dtype=simplex.dtype
            )
            bsz = simplex.shape[0]
            # Sample a random timestep for each simplex token representation.
            # we use the train timesteps to be consistent with the training process.
            # randomly flip between random batchwise and tokenwise timesteps.
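            # As in `training_step`, the branch below is currently pinned to the
            # batchwise path (`if True`): one timestep is drawn per sequence and then
            # expanded so every token position shares it, giving a (bsz, seq_len) tensor.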
            if True:
                timesteps = torch.randint(
                    0,
                    len(self.noise_scheduler),
                    (bsz, inputs["input_ids"].shape[1])
                    if False  # is_tokenwise_cdcd_check(self.model)
                    else (bsz,),
                    device=simplex.device,
                    dtype=torch.int64,
                )
                timesteps = timesteps[:, None].expand(-1, inputs["input_ids"].shape[1])
            else:
                timesteps = torch.randint(
                    0,
                    len(self.noise_scheduler),
                    (bsz, inputs["input_ids"].shape[1])
                    if True  # is_tokenwise_cdcd_check(self.model)
                    else (bsz,),
                    device=simplex.device,
                    dtype=torch.int64,
                )
            # original_timesteps = timesteps
            # if cdcd, we need to warp the timesteps with the cdf.
            # make sure we scale the timesteps to the correct range!
            if is_cdcd_check(self.model):
                input_ids = inputs["input_ids"]
                span_mask = inputs["span_mask"]
                token_input = torch.where(
                    (input_ids * span_mask) > 1, self.tokenizer.pad_token_id, input_ids
                )
                timesteps = self.model.warp_timesteps(
                    timesteps,
                    t_max=len(self.noise_scheduler) - 1,
                    token_input=token_input,
                    span_mask=span_mask,
                )
            # Adds noise to each simplex representation (Forward diffusion process).
            noisy_simplex = self.noise_scheduler.add_noise(simplex, noise, timesteps)
            timesteps = scale(timesteps, len(self.noise_scheduler))
            # original_timesteps_scaled = scale(
            #     original_timesteps, len(self.noise_scheduler)
            # )
            # inputs.update({"original_timesteps": original_timesteps_scaled})
            inputs.update(
                {
                    "timesteps": timesteps,
                    "simplex": noisy_simplex,
                }
            )
            # inputs.update({"max_timestep": len(self.noise_scheduler)})
            if self.diffusion_args.self_condition is not None:
                previous_pred = None
                # last_hidden_state = None
                if np.random.rand(1) > 0.5:
                    outputs = model(**inputs, previous_pred=previous_pred)
                    logits_projection_fct = lambda x: logits_projection(  # noqa: E731
                        x,
                        self.diffusion_args.sampling_type,
                        self.diffusion_args.top_p,
                        self.diffusion_args.simplex_value,
                        self.diffusion_args.temperature,
                    )
                    previous_pred = self_condition_preds(
                        self.diffusion_args.self_condition,
                        outputs.logits,
                        logits_projection_fct,
                    )
                    # last_hidden_state = outputs.hidden_states
                    inputs.update(
                        {
                            "previous_pred": previous_pred,
                            # "previous_hidden": last_hidden_state,
                        }
                    )
            # NOTE: we do this after computation of self-conditioning to not affect that one.
            # inputs.update(
            #     {"classifier_free_guidance_in_train": self.classifier_free_guidance}
            # )
            with self.compute_loss_context_manager():
                loss = self.compute_loss(model, inputs)
            if self.args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            return (
                loss.detach()
            )  # no division by gradient accumulation steps for eval. we want per-sample avg loss.

    # TODO: argument for doing one step.
    def prediction_step(
        self,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        model: nn.Module,
        pipeline: List[SimplexDDPMClassifierGuidancePipeline],
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        inputs = self._prepare_inputs(inputs)
        # full inference.
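        # The pipeline call below runs the full reverse-diffusion loop; the `for`
        # merely steps through its generator so that `outputs` ends up holding the
        # final denoising step, whose simplex and logits are returned.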
        with torch.no_grad():
            with self.compute_loss_context_manager():
                for i, x in enumerate(
                    pipeline(
                        seq_length=self.data_args.max_seq_length
                        - self.data_args.truncation_length,
                        batch=inputs,
                        guidance_scale=self.diffusion_args.guidance_scale,
                        generator=torch.Generator(device=self.args.device).manual_seed(
                            self.args.seed
                        )
                        if self.diffusion_args.generate_with_seed
                        else None,
                        is_generator=False,
                        use_gumbel_softmax=self.diffusion_args.use_gumbel_softmax,
                        do_hard_sample=self.diffusion_args.do_hard_sample,
                        softmax_temperature=self.diffusion_args.softmax_temperature,
                        num_guidance_steps=self.diffusion_args.num_guidance_steps,
                    )
                ):
                    outputs = x
                logits = nested_detach(outputs.logits)
                simplex = nested_detach(outputs.simplex)
        return (simplex, logits)

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        noise_scheduler=None,
        light_eval_dataloader=None,
        do_light_eval=False,
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels.
        """
        args = self.args
        is_conditional_generation = self.data_args.conditional_generation is not None
        save_prefixes = is_conditional_generation

        prediction_loss_only = (
            prediction_loss_only
            if prediction_loss_only is not None
            else args.prediction_loss_only
        )

        # if eval is called w/o train handle model prep here
        if self.is_deepspeed_enabled and self.model_wrapped is self.model:
            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        if len(self.accelerator._models) == 0 and model is self.model:
            model = (
                self.accelerator.prepare(model)
                if self.is_deepspeed_enabled
                else self.accelerator.prepare_model(model, evaluation_mode=True)
            )
            if self.is_fsdp_enabled:
                self.model = model
            # for the rest of this function `model` is the outside model, whether it was wrapped or not
            if model is not self.model:
                self.model_wrapped = model
            # backward compatibility
            if self.is_deepspeed_enabled:
                self.deepspeed = self.model_wrapped

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = self.args.eval_batch_size

        logger.info(f"***** Running {description} *****")
        if has_length(dataloader):
            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
        else:
            logger.info("  Num examples: Unknown")
        logger.info(f"  Batch size = {batch_size}")

        model.eval()

        pipeline = SimplexDDPMClassifierGuidancePipeline(
            model=model,
            scheduler=noise_scheduler,
            simplex_value=self.diffusion_args.simplex_value,
            top_p=self.diffusion_args.top_p,
            sampling_type=self.diffusion_args.sampling_type,
            is_conditional_generation=is_conditional_generation,
            tokenizer=self.tokenizer,
            classifier_free_uncond_input=self.diffusion_args.classifier_free_uncond_input,
            temperature=self.diffusion_args.temperature,
            guidance_softmax_combination=self.diffusion_args.guidance_softmax_combination,
            classifier_model_name_or_path=self.diffusion_args.classifier_model_name_or_path,
        )

        self.callback_handler.eval_dataloader = dataloader
        # Do this before wrapping.
        eval_dataset = getattr(dataloader, "dataset", None)

        # Initialize containers
        # logits/simplex/labels on GPU/TPU (accumulated for eval_accumulation_steps)
        losses_host = None
        logits_host = None
        simplex_host = None
        inputs_host = None
        masks_host = None
        prefixes_host = None

        # logits/simplex/labels on CPU (final containers)
        all_losses = None
        all_logits = None
        all_simplex = None
        all_inputs = None
        all_masks = None
        all_prefixes = None
        observed_num_examples = 0

        # light evaluation loop.
        if light_eval_dataloader is not None and do_light_eval:
            for step, inputs in enumerate(light_eval_dataloader):
                # Truncate the length if needed.
                if self.data_args.truncation_length > 0:
                    inputs["input_ids"] = inputs["input_ids"][
                        :, : -self.data_args.truncation_length
                    ]
                    inputs["span_mask"] = inputs["span_mask"][
                        :, : -self.data_args.truncation_length
                    ]
                    max_seq_length = (
                        self.data_args.max_seq_length
                        - self.data_args.truncation_length
                    )
                    assert self.data_args.eval_context_size < max_seq_length
                # predict loss mimicking training.
                loss = self.light_prediction_step(model, inputs)
                if loss is not None:
                    losses = self._nested_gather(loss.repeat(batch_size))
                    losses_host = (
                        losses
                        if losses_host is None
                        else torch.cat((losses_host, losses), dim=0)
                    )
                if (
                    args.eval_accumulation_steps is not None
                    and (step + 1) % args.eval_accumulation_steps == 0
                ):
                    if losses_host is not None:
                        losses = nested_numpify(losses_host)
                        all_losses = (
                            losses
                            if all_losses is None
                            else np.concatenate((all_losses, losses), axis=0)
                        )
                    losses_host = None

        # Main evaluation loop
        for step, inputs in enumerate(dataloader):
            has_mask = True if "span_mask" in inputs else False
            # Truncate the length if needed.
            if self.data_args.truncation_length > 0:
                inputs["input_ids"] = inputs["input_ids"][
                    :, : -self.data_args.truncation_length
                ]
                inputs["span_mask"] = inputs["span_mask"][
                    :, : -self.data_args.truncation_length
                ]
                max_seq_length = (
                    self.data_args.max_seq_length - self.data_args.truncation_length
                )
                assert self.data_args.eval_context_size < max_seq_length

            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size
                # For batch samplers, batch_size is not known by the dataloader in advance.
                if batch_size is None:
                    batch_size = observed_batch_size

            # Prediction step
            simplex, logits = self.prediction_step(inputs, model, pipeline=pipeline)
            inputs_decode = self._prepare_input(inputs["input_ids"])
            masks = self._prepare_input(inputs["span_mask"]) if has_mask else None
            if save_prefixes:
                prefixes = (
                    pad_data(
                        [input[~mask] for input, mask in zip(inputs_decode, masks)],
                        self.tokenizer,
                    )
                    if has_mask
                    else None
                )
                prefixes = self._prepare_input(prefixes)
            else:
                prefixes = None

            # Update containers on host
            if prefixes is not None:
                prefixes = self.accelerator.pad_across_processes(
                    prefixes, dim=1, pad_index=self.eos_token_id
                )
                prefixes = self._nested_gather(prefixes)
                prefixes_host = (
                    prefixes
                    if prefixes_host is None
                    else nested_concat(
                        prefixes_host, prefixes, padding_index=self.eos_token_id
                    )
                )
            if inputs_decode is not None:
                inputs_decode = self.accelerator.pad_across_processes(
                    inputs_decode, dim=1, pad_index=self.eos_token_id
                )
                inputs_decode = self._nested_gather(inputs_decode)
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(
                        inputs_host, inputs_decode, padding_index=self.eos_token_id
                    )
                )
            # Note that this block should be before masks block, since we need masks here.
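            # Sketch of the masked softmax below: positions where `span_mask` is False
            # are overwritten with the dtype minimum before the vocabulary softmax, so
            # only the generated span keeps informative probabilities.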
            if simplex is not None:
                # In case of having a mask softmax is applied over the simplex non-masked values.
                if has_mask:
                    mask_value = torch.finfo(simplex.dtype).min
                    mask_value = torch.tensor(
                        mask_value, dtype=simplex.dtype, device=simplex.device
                    )
                    simplex = torch.where(masks[:, :, None], simplex, mask_value)
                simplex = F.softmax(simplex, dim=-1)
                if self.preprocess_logits_for_metrics is not None:
                    simplex = self.preprocess_logits_for_metrics(simplex)
                simplex = self.accelerator.pad_across_processes(
                    simplex, dim=1, pad_index=self.eos_token_id
                )
                simplex = self._nested_gather(simplex)
                # TODO: note that this is no more a simplex, but the processed one.
                simplex_host = (
                    simplex
                    if simplex_host is None
                    else nested_concat(
                        simplex_host, simplex, padding_index=self.eos_token_id
                    )
                )
            if masks is not None:
                masks = self.accelerator.pad_across_processes(
                    masks, dim=1, pad_index=0
                )
                masks = self._nested_gather(masks)
                # We pad masks with False tokens.
                masks_host = (
                    masks
                    if masks_host is None
                    else nested_concat(masks_host, masks, padding_index=0)
                )
            if logits is not None:
                if self.preprocess_logits_for_metrics is not None:
                    logits = self.preprocess_logits_for_metrics(logits)
                logits = self.accelerator.pad_across_processes(
                    logits, dim=1, pad_index=self.eos_token_id
                )
                logits = self._nested_gather(logits)
                logits_host = (
                    logits
                    if logits_host is None
                    else nested_concat(
                        logits_host, logits, padding_index=self.eos_token_id
                    )
                )
            self.control = self.callback_handler.on_prediction_step(
                args, self.state, self.control
            )

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if (
                args.eval_accumulation_steps is not None
                and (step + 1) % args.eval_accumulation_steps == 0
            ):
                if logits_host is not None:
                    logits = nested_numpify(logits_host)
                    all_logits = (
                        logits
                        if all_logits is None
                        else nested_concat(
                            all_logits, logits, padding_index=self.eos_token_id
                        )
                    )
                if simplex_host is not None:
                    simplex = nested_numpify(simplex_host)
                    all_simplex = (
                        simplex
                        if all_simplex is None
                        else nested_concat(
                            all_simplex, simplex, padding_index=self.eos_token_id
                        )
                    )
                if inputs_host is not None:
                    inputs_decode = nested_numpify(inputs_host)
                    all_inputs = (
                        inputs_decode
                        if all_inputs is None
                        else nested_concat(
                            all_inputs, inputs_decode, padding_index=self.eos_token_id
                        )
                    )
                if masks_host is not None:
                    masks = nested_numpify(masks_host)
                    all_masks = (
                        masks
                        if all_masks is None
                        else nested_concat(all_masks, masks, padding_index=0)
                    )
                if prefixes_host is not None:
                    prefixes = nested_numpify(prefixes_host)
                    all_prefixes = (
                        prefixes
                        if all_prefixes is None
                        else nested_concat(
                            all_prefixes, prefixes, padding_index=self.eos_token_id
                        )
                    )

                # Set back to None to begin a new accumulation
                logits_host, simplex_host, inputs_host, masks_host, prefixes_host = (
                    None,
                    None,
                    None,
                    None,
                    None,
                )

        # Gather all remaining tensors and put them back on the CPU
        if losses_host is not None:
            all_losses = nested_numpify(losses_host)
        if logits_host is not None:
            all_logits = nested_numpify(logits_host)
        if simplex_host is not None:
            all_simplex = nested_numpify(simplex_host)
        if inputs_host is not None:
            all_inputs = nested_numpify(inputs_host)
        if masks_host is not None:
            all_masks = nested_numpify(masks_host)
        if prefixes_host is not None:
            all_prefixes = nested_numpify(prefixes_host)

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Number of samples
        if has_length(eval_dataset):
            num_samples = len(eval_dataset)
        # The instance check is weird and does not actually check for the type,
        # but whether the dataset has the right methods. Therefore we need to make sure it also has the attribute.
        elif (
            isinstance(eval_dataset, IterableDatasetShard)
            and getattr(eval_dataset, "num_examples", 0) > 0
        ):
            num_samples = eval_dataset.num_examples
        else:
            if has_length(dataloader):
                num_samples = self.num_examples(dataloader)
            else:  # both len(dataloader.dataset) and len(dataloader) fail
                num_samples = observed_num_examples
        if num_samples == 0 and observed_num_examples > 0:
            num_samples = observed_num_examples

        # Generates the texts.
        results = {}
        if is_conditional_generation:
            # We predict the masked tokens only. Here, we compute the masked tokens.
            results.update(
                predict_conditional_generated(
                    all_masks,
                    all_inputs,
                    self.tokenizer,
                    all_simplex,
                    "pred_texts_from_simplex",
                    self.data_args.skip_special_tokens,
                )
            )
            results.update(
                predict_conditional_generated(
                    all_masks,
                    all_inputs,
                    self.tokenizer,
                    all_logits,
                    "pred_texts_from_logits",
                    self.data_args.skip_special_tokens,
                )
            )
        else:
            results.update(
                {
                    "pred_texts_from_simplex": self.tokenizer.batch_decode(
                        all_simplex,
                        skip_special_tokens=self.data_args.skip_special_tokens,
                    )
                }
            )
            results.update(
                {
                    "pred_texts_from_logits": self.tokenizer.batch_decode(
                        all_logits,
                        skip_special_tokens=self.data_args.skip_special_tokens,
                    )
                }
            )
        if is_conditional_generation:
            results.update(
                {
                    "gold_texts_masked": [
                        self.tokenizer.decode(
                            input[mask],
                            skip_special_tokens=self.data_args.skip_special_tokens,
                        )
                        for mask, input in zip(all_masks, all_inputs)
                    ]
                }
            )
            if save_prefixes:
                results.update(
                    {
                        "prefixes": [
                            self.tokenizer.decode(
                                x, skip_special_tokens=True
                            )  # self.data_args.skip_special_tokens)
                            for x in all_prefixes
                        ]
                    }
                )

        # Metrics.
        if self.compute_metrics is not None:
            metrics = self.compute_metrics(results)
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if all_losses is not None:
            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return EvalLoopOutput(
            logits=all_logits,
            simplex=all_simplex,
            input_ids=all_inputs,
            metrics=metrics,
            num_samples=num_samples,
            results=results,
        )

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """
        Run evaluation and returns metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
        (pass it to the init `compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (`Dataset`, *optional*):
                Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
                not accepted by the `model.forward()` method are automatically removed. It must implement the
                `__len__` method.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "eval_bleu" if the prefix is "eval" (default)

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
            The dictionary also contains the epoch number which comes from the training state.
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        light_eval_dataloader = self.get_light_eval_dataloader(eval_dataset)
        start_time = time.time()
        outputs = []
        timesteps = self.inference_timesteps
        for timestep, noise_scheduler in zip(
            timesteps, self.inference_noise_schedulers
        ):
            output = self.evaluation_loop(
                eval_dataloader,
                description="Evaluation",
                # No point gathering the predictions if there are no metrics, otherwise we defer to
                # self.args.prediction_loss_only
                prediction_loss_only=True if self.compute_metrics is None else None,
                ignore_keys=ignore_keys,
                metric_key_prefix=metric_key_prefix,
                noise_scheduler=noise_scheduler,
                light_eval_dataloader=light_eval_dataloader,
                # we only need the loss once, since it is the same for all timesteps
                do_light_eval=timestep == timesteps[0],
            )
            outputs.append(output)
            key_prefix = f"inference_{timestep}_"
            metrics = {key_prefix + k: v for k, v in output.metrics.items()}
            results = {key_prefix + k: v for k, v in output.results.items()}
            # reset output with new metrics / results
            output = EvalLoopOutput(
                logits=output.logits,
                simplex=output.simplex,
                input_ids=output.input_ids,
                metrics=metrics,
                num_samples=output.num_samples,
                results=results,
            )
            total_batch_size = self.args.eval_batch_size * self.args.world_size
            output.metrics.update(
                speed_metrics(
                    metric_key_prefix,
                    start_time,
                    num_samples=output.num_samples,
                    num_steps=math.ceil(output.num_samples / total_batch_size),
                )
            )
            self.log(output.metrics)
            self.control = self.callback_handler.on_evaluate(
                self.args, self.state, self.control, output.metrics
            )
            self._memory_tracker.stop_and_update_metrics(output.metrics)
            # Save the results
            self.save_metrics(
                GENERATION_RESULTS + "_" + key_prefix + metric_key_prefix,
                output.results,
            )
            logger.info("Results are saved now")
        # log outside so we can group generations together
        if self.args.log_generated_texts:
            length = len(outputs[0].logits)
            results = {
                f"{k}_inference_{i}": v
                for o, i in zip(outputs, timesteps)
                for k, v in o.results.items()
            }
            self.log_results_to_tensorboard(self.state, length, results)
        return output.metrics

    def log_results_to_tensorboard(self, state, length, results):
        # TODO: we need to fix this which happens during the only eval option.
        if self.tb_writer.tb_writer is None:
            return
        for i in range(length):
            total_text = ""
            for k, v in results.items():
                total_text += f"*** {k} ***: {v[i]}" + " \n"
            self.tb_writer.tb_writer.add_text(
                f"sample_{i}", total_text, state.global_step
            )

    def get_train_dataloader(self) -> DataLoader:
        self.data_collator = self.original_data_collator("train")
        return super().get_train_dataloader()

    def get_eval_dataloader(
        self, eval_dataset: Optional[Dataset] = None
    ) -> DataLoader:
        self.data_collator = self.original_data_collator("eval")
        return super().get_eval_dataloader(eval_dataset)

    def get_light_eval_dataloader(
        self, eval_dataset: Optional[Dataset] = None
    ) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Used for the light evaluation, which matches masking with training.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
                by the `model.forward()` method are automatically removed. It must implement `__len__`.
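
        Unlike `get_eval_dataloader`, this loader uses the "train" variant of the
        original data collator, so examples are masked the same way they are during
        training and the light-evaluation loss stays comparable to the training loss.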
""" if eval_dataset is None and self.eval_dataset is None: raise ValueError("Trainer: evaluation requires an eval_dataset.") eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset data_collator = self.original_data_collator("train") if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): eval_dataset = self._remove_unused_columns( eval_dataset, description="evaluation" ) else: data_collator = self._get_collator_with_removed_columns( data_collator, description="evaluation" ) dataloader_params = { "batch_size": self.args.eval_batch_size, "collate_fn": data_collator, "num_workers": self.args.dataloader_num_workers, "pin_memory": self.args.dataloader_pin_memory, } if not isinstance(eval_dataset, torch.utils.data.IterableDataset): dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset) dataloader_params["drop_last"] = self.args.dataloader_drop_last return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) def create_optimizer(self): from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS from transformers.trainer_pt_utils import get_parameter_names opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model if self.optimizer is not None: return self.optimizer decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) decay_parameters = [name for name in decay_parameters if "bias" not in name] optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( self.args ) # override to apply higher lr to timestep_embed and cdcd cdf optimizer_grouped_parameters = [ { "params": [ p for n, p in opt_model.named_parameters() if ( n in decay_parameters and p.requires_grad and not ("timestep_embed" in n or "cdf" in n) ) ], "weight_decay": self.args.weight_decay, "lr": optimizer_kwargs["lr"], }, { "params": [ p for n, p in opt_model.named_parameters() if ( n not in decay_parameters and p.requires_grad and not ("timestep_embed" in n or "cdf" in n) ) ], "weight_decay": 0.0, "lr": optimizer_kwargs["lr"], }, { "params": [ p for n, p in opt_model.named_parameters() if (("timestep_embed" in n) and p.requires_grad) ], "weight_decay": 0.0, "lr": self.args.timestep_embed_lr or self.args.learning_rate, }, ] # check cdcd cdf_params = [ p for n, p in opt_model.named_parameters() if (("cdf" in n) and p.requires_grad) ] if cdf_params: optimizer_grouped_parameters.append( { "params": cdf_params, "weight_decay": 0.0, "lr": 1e-3, } ) optimizer_kwargs.pop("lr") self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) return self.optimizer