# tess-2-demo: sdlm/trainers/trainer_diffusion.py
import math
import time
import warnings
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import Trainer
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=FutureWarning)
from transformers.deepspeed import deepspeed_init
from transformers.integrations import TensorBoardCallback
from transformers.trainer_pt_utils import (
IterableDatasetShard,
find_batch_size,
nested_concat,
nested_detach,
nested_numpify,
)
from transformers.trainer_utils import denumpify_detensorize, has_length, speed_metrics
from transformers.utils import (
is_apex_available,
is_datasets_available,
is_sagemaker_mp_enabled,
logging,
)
from sdlm.inference.inference_utils import (
logits_projection,
predict_conditional_generated,
)
from sdlm.models.utils import is_cdcd_check
from sdlm.pipelines.simplex_ddpm import SimplexDDPMClassifierGuidancePipeline
from sdlm.utils import convert_to_simplex, pad_data, scale, self_condition_preds
if is_apex_available():
from apex import amp
if is_datasets_available():
import datasets
GENERATION_RESULTS = "generated"
logger = logging.get_logger(__name__)
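# Extended EvalLoopOutput: alongside metrics and the sample count, the diffusion
# evaluation loop also returns the accumulated logits, final simplex representations,
# input ids, and the decoded generation results.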
class EvalLoopOutput(NamedTuple):
logits: Union[np.ndarray, Tuple[np.ndarray]]
simplex: Union[np.ndarray, Tuple[np.ndarray]]
input_ids: Optional[Union[np.ndarray, Tuple[np.ndarray]]]
metrics: Optional[Dict[str, float]]
results: Optional[Dict[str, List[str]]]
num_samples: Optional[int]
class DiffusionTrainer(Trainer):
def __init__(
self,
noise_scheduler,
inference_noise_schedulers,
diffusion_args,
data_args,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.original_data_collator = self.data_collator
self.noise_scheduler = noise_scheduler
self.diffusion_args = diffusion_args
self.data_args = data_args
self.vocab_size = self.model.config.vocab_size
self.inference_noise_schedulers = inference_noise_schedulers
self.inference_timesteps = diffusion_args.num_inference_diffusion_steps
self.tb_writer = self.get_tb_writer()
self.eos_token_id = self.tokenizer.eos_token_id
self.classifier_free_guidance = (
diffusion_args.guidance_scale > 1.0
and data_args.conditional_generation is not None
)
self.counter = 0
# TODO: control seed.
self.self_cond_generator = np.random.default_rng(42)
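    # Embed the sampling hyperparameters (top_p, temperature, seed, guidance scale) in
    # the split name so metrics from different inference configurations don't clash.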
def annotated_split(self, split):
return f"{split}_top_p_{self.diffusion_args.top_p}_temperature_{self.diffusion_args.temperature}_seed_{self.args.seed}_guidance_scale_{self.diffusion_args.guidance_scale}"
def save_metrics(self, split, metrics, combined=True):
super().save_metrics(self.annotated_split(split), metrics, combined)
def log_metrics(self, split, metrics):
super().log_metrics(self.annotated_split(split), metrics)
def get_tb_writer(self):
for cb in self.callback_handler.callbacks:
if isinstance(cb, TensorBoardCallback):
return cb
return None
def training_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.
Subclass and override to inject custom behavior.
Args:
model (`nn.Module`):
The model to train.
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument `labels`. Check your model's documentation for all accepted arguments.
Return:
`torch.Tensor`: The tensor with training loss on this batch.
"""
model.train()
inputs = self._prepare_inputs(inputs)
# Truncate the length if needed.
if self.data_args.truncation_length > 0:
inputs["input_ids"] = inputs["input_ids"][
:, : -self.data_args.truncation_length
]
inputs["span_mask"] = inputs["span_mask"][
:, : -self.data_args.truncation_length
]
# Creates the noisy simplex and timesteps.
simplex = convert_to_simplex(
inputs["input_ids"], self.diffusion_args.simplex_value, self.vocab_size
)
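        # convert_to_simplex builds the continuous token representation the diffusion
        # process operates on (roughly, a vocab-sized vector per token scaled to
        # +/- simplex_value, i.e. a scaled one-hot over the vocabulary).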
noise = self.diffusion_args.simplex_value * torch.randn(
simplex.shape, device=simplex.device, dtype=simplex.dtype
)
bsz = simplex.shape[0]
# Sample a random timestep for each simplex token representation.
        # Currently we always sample one timestep per sequence (batchwise) and broadcast
        # it across tokens, which better matches inference; the disabled branch below
        # samples a separate timestep per token (tokenwise CDCD).
if True: # np.random.rand(1) > 0.5:
timesteps = torch.randint(
0,
len(self.noise_scheduler),
(bsz, inputs["input_ids"].shape[1])
if False # is_tokenwise_cdcd_check(self.model)
else (bsz,),
device=simplex.device,
dtype=torch.int64,
)
timesteps = timesteps[:, None].expand(-1, inputs["input_ids"].shape[1])
else:
timesteps = torch.randint(
0,
len(self.noise_scheduler),
(bsz, inputs["input_ids"].shape[1])
if True # is_tokenwise_cdcd_check(self.model)
else (bsz,),
device=simplex.device,
dtype=torch.int64,
)
# expand out timesteps to match tokenwise setup
# if True: # not is_tokenwise_cdcd_check(self.model):
# timesteps = timesteps[:, None].expand(-1, inputs["input_ids"].shape[1])
# save original timesteps for warping
original_timesteps = timesteps
        # CDCD: warp the timesteps according to the learned CDF; they are re-scaled to
        # the correct range afterwards. t_max is len(self.noise_scheduler) - 1 because
        # valid timesteps lie in [0, len(self.noise_scheduler)), e.g. [0, 5000).
if is_cdcd_check(self.model):
input_ids = inputs["input_ids"]
span_mask = inputs["span_mask"]
token_input = torch.where(
(input_ids * span_mask) > 1, self.tokenizer.pad_token_id, input_ids
)
timesteps = self.model.warp_timesteps(
timesteps,
token_input=token_input,
span_mask=span_mask,
t_max=len(self.noise_scheduler) - 1,
)
# Adds noise to each simplex representation (Forward diffusion process).
noisy_simplex = self.noise_scheduler.add_noise(simplex, noise, timesteps)
# the warper model will scale the timesteps to the correct range.
timesteps = scale(timesteps, len(self.noise_scheduler))
# original_timesteps_scaled = scale(original_timesteps, len(self.noise_scheduler))
# inputs.update(
# {"original_timesteps": scale(original_timesteps, len(self.noise_scheduler))}
# )
inputs.update(
{
"timesteps": timesteps,
"simplex": noisy_simplex,
}
)
# inputs.update({"max_timestep": len(self.noise_scheduler)})
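        # Self-conditioning: with probability 0.5, run an extra forward pass (detached,
        # at the next-noisier timestep) and feed its projected prediction back to the
        # model as `previous_pred` for the actual training pass.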
if self.diffusion_args.self_condition is not None:
previous_pred = None
# previous_hidden = None
if self.self_cond_generator.random(1) > 0.5:
next_timestep = inputs.pop("timesteps")
next_simplex = inputs.pop("simplex")
timesteps = torch.clamp(
(next_timestep * len(self.noise_scheduler)) + 1,
max=len(self.noise_scheduler) - 1,
)
if is_cdcd_check(self.model):
input_ids = inputs["input_ids"]
span_mask = inputs["span_mask"]
token_input = torch.where(
(input_ids * span_mask) > 1,
self.tokenizer.pad_token_id,
input_ids,
)
timesteps = self.model.warp_timesteps(
timesteps,
token_input=token_input,
span_mask=span_mask,
t_max=len(self.noise_scheduler) - 1,
)
noisy_simplex = self.noise_scheduler.add_noise(
simplex, noise, timesteps
)
timesteps = scale(timesteps, len(self.noise_scheduler))
inputs.update(
{
"timesteps": timesteps,
"simplex": noisy_simplex,
}
)
# we don't backprop through this.
with torch.no_grad():
outputs = model(**inputs, previous_pred=previous_pred)
logits_projection_fct = lambda x: logits_projection( # noqa: E731
x,
self.diffusion_args.sampling_type,
self.diffusion_args.top_p,
self.diffusion_args.simplex_value,
self.diffusion_args.temperature,
)
previous_pred = self_condition_preds(
self.diffusion_args.self_condition,
outputs.logits,
logits_projection_fct,
).detach()
                # As with the rest of the self-conditioning pass, do not backprop through this.
                # previous_hidden = outputs.hidden_states.detach()
                # Restore the original timesteps/simplex that were popped above.
inputs.update(
{
"timesteps": next_timestep,
"simplex": next_simplex,
}
)
inputs.update({"previous_pred": previous_pred})
# inputs.update({"previous_hidden": previous_hidden})
else:
inputs.update({"previous_pred": None})
# inputs.update({"previous_hidden": None})
# previous_hidden = None
        # NOTE: applied after the self-conditioning computation so that pass is unaffected.
# inputs.update(
# {"classifier_free_guidance_in_train": self.classifier_free_guidance}
# )
# re-warp based on previous hidden state
if is_cdcd_check(self.model):
            # Replace tokens inside the masked span with the pad token before warping.
input_ids = inputs["input_ids"]
span_mask = inputs["span_mask"]
token_input = torch.where(
(input_ids * span_mask) > 1, self.tokenizer.pad_token_id, input_ids
)
timesteps = self.model.warp_timesteps(
original_timesteps,
t_max=len(self.noise_scheduler) - 1,
token_input=token_input,
span_mask=span_mask,
)
noisy_simplex = self.noise_scheduler.add_noise(simplex, noise, timesteps)
timesteps = scale(timesteps, len(self.noise_scheduler))
inputs.update(
{
"timesteps": timesteps,
"simplex": noisy_simplex,
}
)
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs)
        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        # HACK: the transformers update removed self.do_grad_scaling / self.scaler;
        # gradient scaling is now handled inside accelerator.backward. Use a separate
        # `if` below so the backward pass still runs when n_gpu > 1.
        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss)
return loss.detach() / self.args.gradient_accumulation_steps
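    # "Light" evaluation step: compute the training-style diffusion loss on an eval
    # batch (single noising step, no iterative sampling), so the reported eval loss is
    # directly comparable to the training loss.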
def light_prediction_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
with torch.no_grad():
inputs = self._prepare_inputs(inputs)
# Truncate the length if needed.
if self.data_args.truncation_length > 0:
inputs["input_ids"] = inputs["input_ids"][
:, : -self.data_args.truncation_length
]
inputs["span_mask"] = inputs["span_mask"][
:, : -self.data_args.truncation_length
]
# Creates the noisy simplex and timesteps.
simplex = convert_to_simplex(
inputs["input_ids"], self.diffusion_args.simplex_value, self.vocab_size
)
noise = self.diffusion_args.simplex_value * torch.randn(
simplex.shape, device=simplex.device, dtype=simplex.dtype
)
bsz = simplex.shape[0]
# Sample a random timestep for each simplex token representation.
            # We use the training timestep range to stay consistent with training.
            # As in training_step, we currently always sample batchwise timesteps and
            # broadcast them across tokens (the tokenwise branch below is disabled).
if True:
timesteps = torch.randint(
0,
len(self.noise_scheduler),
(bsz, inputs["input_ids"].shape[1])
if False # is_tokenwise_cdcd_check(self.model)
else (bsz,),
device=simplex.device,
dtype=torch.int64,
)
timesteps = timesteps[:, None].expand(-1, inputs["input_ids"].shape[1])
else:
timesteps = torch.randint(
0,
len(self.noise_scheduler),
(bsz, inputs["input_ids"].shape[1])
if True # is_tokenwise_cdcd_check(self.model)
else (bsz,),
device=simplex.device,
dtype=torch.int64,
)
# original_timesteps = timesteps
            # For CDCD, warp the timesteps through the learned CDF and make sure they are
            # scaled to the correct range afterwards.
if is_cdcd_check(self.model):
input_ids = inputs["input_ids"]
span_mask = inputs["span_mask"]
token_input = torch.where(
(input_ids * span_mask) > 1, self.tokenizer.pad_token_id, input_ids
)
timesteps = self.model.warp_timesteps(
timesteps,
t_max=len(self.noise_scheduler) - 1,
token_input=token_input,
span_mask=span_mask,
)
# Adds noise to each simplex representation (Forward diffusion process).
noisy_simplex = self.noise_scheduler.add_noise(simplex, noise, timesteps)
timesteps = scale(timesteps, len(self.noise_scheduler))
# original_timesteps_scaled = scale(
# original_timesteps, len(self.noise_scheduler)
# )
# inputs.update({"original_timesteps": original_timesteps_scaled})
inputs.update(
{
"timesteps": timesteps,
"simplex": noisy_simplex,
}
)
# inputs.update({"max_timestep": len(self.noise_scheduler)})
if self.diffusion_args.self_condition is not None:
previous_pred = None
# last_hidden_state = None
if np.random.rand(1) > 0.5:
outputs = model(**inputs, previous_pred=previous_pred)
logits_projection_fct = lambda x: logits_projection( # noqa: E731
x,
self.diffusion_args.sampling_type,
self.diffusion_args.top_p,
self.diffusion_args.simplex_value,
self.diffusion_args.temperature,
)
previous_pred = self_condition_preds(
self.diffusion_args.self_condition,
outputs.logits,
logits_projection_fct,
)
# last_hidden_state = outputs.hidden_states
inputs.update(
{
"previous_pred": previous_pred,
# "previous_hidden": last_hidden_state,
}
)
            # NOTE: applied after the self-conditioning computation so that pass is unaffected.
# inputs.update(
# {"classifier_free_guidance_in_train": self.classifier_free_guidance}
# )
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs)
if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
            # Note: no division by gradient accumulation steps for eval; we want the
            # per-sample average loss.
            return loss.detach()
# TODO: argument for doing one step.
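    # Full inference step: run the diffusion sampling pipeline to completion on a batch
    # and return the (simplex, logits) pair from the final denoising step.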
def prediction_step(
self,
inputs: Dict[str, Union[torch.Tensor, Any]],
model: nn.Module,
        pipeline: SimplexDDPMClassifierGuidancePipeline,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
inputs = self._prepare_inputs(inputs)
# full inference.
with torch.no_grad():
with self.compute_loss_context_manager():
for i, x in enumerate(
pipeline(
seq_length=self.data_args.max_seq_length
- self.data_args.truncation_length,
batch=inputs,
guidance_scale=self.diffusion_args.guidance_scale,
generator=torch.Generator(device=self.args.device).manual_seed(
self.args.seed
)
if self.diffusion_args.generate_with_seed
else None,
is_generator=False,
use_gumbel_softmax=self.diffusion_args.use_gumbel_softmax,
do_hard_sample=self.diffusion_args.do_hard_sample,
softmax_temperature=self.diffusion_args.softmax_temperature,
num_guidance_steps=self.diffusion_args.num_guidance_steps,
)
):
outputs = x
logits = nested_detach(outputs.logits)
simplex = nested_detach(outputs.simplex)
return (simplex, logits)
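    # Evaluation loop adapted from `transformers.Trainer.evaluation_loop`: optionally
    # computes a training-style "light" loss first, then runs full diffusion sampling on
    # each batch, accumulating simplex/logits/inputs/masks/prefixes for text metrics.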
def evaluation_loop(
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
noise_scheduler=None,
light_eval_dataloader=None,
do_light_eval=False,
) -> EvalLoopOutput:
"""
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
Works both with or without labels.
"""
args = self.args
is_conditional_generation = self.data_args.conditional_generation is not None
save_prefixes = is_conditional_generation
prediction_loss_only = (
prediction_loss_only
if prediction_loss_only is not None
else args.prediction_loss_only
)
# if eval is called w/o train handle model prep here
if self.is_deepspeed_enabled and self.model_wrapped is self.model:
_, _ = deepspeed_init(self, num_training_steps=0, inference=True)
model = self._wrap_model(self.model, training=False, dataloader=dataloader)
if len(self.accelerator._models) == 0 and model is self.model:
model = (
self.accelerator.prepare(model)
if self.is_deepspeed_enabled
else self.accelerator.prepare_model(model, evaluation_mode=True)
)
if self.is_fsdp_enabled:
self.model = model
# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
# backward compatibility
if self.is_deepspeed_enabled:
self.deepspeed = self.model_wrapped
# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
# while ``train`` is running, cast it to the right dtype first and then put on device
if not self.is_in_train:
if args.fp16_full_eval:
model = model.to(dtype=torch.float16, device=args.device)
elif args.bf16_full_eval:
model = model.to(dtype=torch.bfloat16, device=args.device)
batch_size = self.args.eval_batch_size
logger.info(f"***** Running {description} *****")
if has_length(dataloader):
logger.info(f" Num examples = {self.num_examples(dataloader)}")
else:
logger.info(" Num examples: Unknown")
logger.info(f" Batch size = {batch_size}")
model.eval()
pipeline = SimplexDDPMClassifierGuidancePipeline(
model=model,
scheduler=noise_scheduler,
simplex_value=self.diffusion_args.simplex_value,
top_p=self.diffusion_args.top_p,
sampling_type=self.diffusion_args.sampling_type,
is_conditional_generation=is_conditional_generation,
tokenizer=self.tokenizer,
classifier_free_uncond_input=self.diffusion_args.classifier_free_uncond_input,
temperature=self.diffusion_args.temperature,
guidance_softmax_combination=self.diffusion_args.guidance_softmax_combination,
classifier_model_name_or_path=self.diffusion_args.classifier_model_name_or_path,
)
self.callback_handler.eval_dataloader = dataloader
# Do this before wrapping.
eval_dataset = getattr(dataloader, "dataset", None)
# Initialize containers
# logits/simplex/labels on GPU/TPU (accumulated for eval_accumulation_steps)
losses_host = None
logits_host = None
simplex_host = None
inputs_host = None
masks_host = None
prefixes_host = None
# logits/simplex/labels on CPU (final containers)
all_losses = None
all_logits = None
all_simplex = None
all_inputs = None
all_masks = None
all_prefixes = None
observed_num_examples = 0
# light evaluation loop.
if light_eval_dataloader is not None and do_light_eval:
for step, inputs in enumerate(light_eval_dataloader):
# Truncate the length if needed.
if self.data_args.truncation_length > 0:
inputs["input_ids"] = inputs["input_ids"][
:, : -self.data_args.truncation_length
]
inputs["span_mask"] = inputs["span_mask"][
:, : -self.data_args.truncation_length
]
max_seq_length = (
self.data_args.max_seq_length - self.data_args.truncation_length
)
assert self.data_args.eval_context_size < max_seq_length
# predict loss mimicking training.
loss = self.light_prediction_step(model, inputs)
if loss is not None:
losses = self._nested_gather(loss.repeat(batch_size))
losses_host = (
losses
if losses_host is None
else torch.cat((losses_host, losses), dim=0)
)
if (
args.eval_accumulation_steps is not None
and (step + 1) % args.eval_accumulation_steps == 0
):
if losses_host is not None:
losses = nested_numpify(losses_host)
all_losses = (
losses
if all_losses is None
else np.concatenate((all_losses, losses), axis=0)
)
losses_host = None
# Main evaluation loop
for step, inputs in enumerate(dataloader):
            has_mask = "span_mask" in inputs
# Truncate the length if needed.
if self.data_args.truncation_length > 0:
inputs["input_ids"] = inputs["input_ids"][
:, : -self.data_args.truncation_length
]
inputs["span_mask"] = inputs["span_mask"][
:, : -self.data_args.truncation_length
]
max_seq_length = (
self.data_args.max_seq_length - self.data_args.truncation_length
)
assert self.data_args.eval_context_size < max_seq_length
# Update the observed num examples
observed_batch_size = find_batch_size(inputs)
if observed_batch_size is not None:
observed_num_examples += observed_batch_size
# For batch samplers, batch_size is not known by the dataloader in advance.
if batch_size is None:
batch_size = observed_batch_size
# Prediction step
simplex, logits = self.prediction_step(inputs, model, pipeline=pipeline)
inputs_decode = self._prepare_input(inputs["input_ids"])
masks = self._prepare_input(inputs["span_mask"]) if has_mask else None
if save_prefixes:
prefixes = (
pad_data(
[input[~mask] for input, mask in zip(inputs_decode, masks)],
self.tokenizer,
)
if has_mask
else None
)
prefixes = self._prepare_input(prefixes)
else:
prefixes = None
# Update containers on host
if prefixes is not None:
prefixes = self.accelerator.pad_across_processes(
prefixes, dim=1, pad_index=self.eos_token_id
)
prefixes = self._nested_gather(prefixes)
prefixes_host = (
prefixes
if prefixes_host is None
else nested_concat(
prefixes_host, prefixes, padding_index=self.eos_token_id
)
)
if inputs_decode is not None:
inputs_decode = self.accelerator.pad_across_processes(
inputs_decode, dim=1, pad_index=self.eos_token_id
)
inputs_decode = self._nested_gather(inputs_decode)
inputs_host = (
inputs_decode
if inputs_host is None
else nested_concat(
inputs_host, inputs_decode, padding_index=self.eos_token_id
)
)
# Note that this block should be before masks block, since we need masks here.
if simplex is not None:
# In case of having a mask softmax is applied over the simplex non-masked values.
if has_mask:
mask_value = torch.finfo(simplex.dtype).min
mask_value = torch.tensor(
mask_value, dtype=simplex.dtype, device=simplex.device
)
simplex = torch.where(masks[:, :, None], simplex, mask_value)
simplex = F.softmax(simplex, dim=-1)
if self.preprocess_logits_for_metrics is not None:
simplex = self.preprocess_logits_for_metrics(simplex)
simplex = self.accelerator.pad_across_processes(
simplex, dim=1, pad_index=self.eos_token_id
)
simplex = self._nested_gather(simplex)
                # NOTE: after the masking/softmax above this is no longer the raw
                # simplex but the processed distribution.
simplex_host = (
simplex
if simplex_host is None
else nested_concat(
simplex_host, simplex, padding_index=self.eos_token_id
)
)
if masks is not None:
masks = self.accelerator.pad_across_processes(masks, dim=1, pad_index=0)
masks = self._nested_gather(masks)
# We pad masks with False tokens.
masks_host = (
masks
if masks_host is None
else nested_concat(masks_host, masks, padding_index=0)
)
if logits is not None:
if self.preprocess_logits_for_metrics is not None:
logits = self.preprocess_logits_for_metrics(logits)
logits = self.accelerator.pad_across_processes(
logits, dim=1, pad_index=self.eos_token_id
)
logits = self._nested_gather(logits)
logits_host = (
logits
if logits_host is None
else nested_concat(
logits_host, logits, padding_index=self.eos_token_id
)
)
self.control = self.callback_handler.on_prediction_step(
args, self.state, self.control
)
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
if (
args.eval_accumulation_steps is not None
and (step + 1) % args.eval_accumulation_steps == 0
):
if logits_host is not None:
logits = nested_numpify(logits_host)
all_logits = (
logits
if all_logits is None
else nested_concat(
all_logits, logits, padding_index=self.eos_token_id
)
)
if simplex_host is not None:
simplex = nested_numpify(simplex_host)
all_simplex = (
simplex
if all_simplex is None
else nested_concat(
all_simplex, simplex, padding_index=self.eos_token_id
)
)
if inputs_host is not None:
inputs_decode = nested_numpify(inputs_host)
all_inputs = (
inputs_decode
if all_inputs is None
else nested_concat(
all_inputs, inputs_decode, padding_index=self.eos_token_id
)
)
if masks_host is not None:
masks = nested_numpify(masks_host)
all_masks = (
masks
if all_masks is None
else nested_concat(all_masks, masks, padding_index=0)
)
if prefixes_host is not None:
prefixes = nested_numpify(prefixes_host)
all_prefixes = (
prefixes
if all_prefixes is None
else nested_concat(
all_prefixes, prefixes, padding_index=self.eos_token_id
)
)
# Set back to None to begin a new accumulation
logits_host, simplex_host, inputs_host, masks_host, prefixes_host = (
None,
None,
None,
None,
None,
)
# Gather all remaining tensors and put them back on the CPU
if losses_host is not None:
all_losses = nested_numpify(losses_host)
if logits_host is not None:
all_logits = nested_numpify(logits_host)
if simplex_host is not None:
all_simplex = nested_numpify(simplex_host)
if inputs_host is not None:
all_inputs = nested_numpify(inputs_host)
if masks_host is not None:
all_masks = nested_numpify(masks_host)
if prefixes_host is not None:
all_prefixes = nested_numpify(prefixes_host)
if args.past_index and hasattr(self, "_past"):
# Clean the state at the end of the evaluation loop
delattr(self, "_past")
# Number of samples
if has_length(eval_dataset):
num_samples = len(eval_dataset)
# The instance check is weird and does not actually check for the type, but whether the dataset has the right
# methods. Therefore we need to make sure it also has the attribute.
elif (
isinstance(eval_dataset, IterableDatasetShard)
and getattr(eval_dataset, "num_examples", 0) > 0
):
num_samples = eval_dataset.num_examples
else:
if has_length(dataloader):
num_samples = self.num_examples(dataloader)
else: # both len(dataloader.dataset) and len(dataloader) fail
num_samples = observed_num_examples
if num_samples == 0 and observed_num_examples > 0:
num_samples = observed_num_examples
# Generates the texts.
results = {}
if is_conditional_generation:
# We predict the masked tokens only. Here, we compute the masked tokens.
results.update(
predict_conditional_generated(
all_masks,
all_inputs,
self.tokenizer,
all_simplex,
"pred_texts_from_simplex",
self.data_args.skip_special_tokens,
)
)
results.update(
predict_conditional_generated(
all_masks,
all_inputs,
self.tokenizer,
all_logits,
"pred_texts_from_logits",
self.data_args.skip_special_tokens,
)
)
else:
results.update(
{
"pred_texts_from_simplex": self.tokenizer.batch_decode(
all_simplex,
skip_special_tokens=self.data_args.skip_special_tokens,
)
}
)
results.update(
{
"pred_texts_from_logits": self.tokenizer.batch_decode(
all_logits,
skip_special_tokens=self.data_args.skip_special_tokens,
)
}
)
if is_conditional_generation:
results.update(
{
"gold_texts_masked": [
self.tokenizer.decode(
input[mask],
skip_special_tokens=self.data_args.skip_special_tokens,
)
for mask, input in zip(all_masks, all_inputs)
]
}
)
if save_prefixes:
results.update(
{
"prefixes": [
self.tokenizer.decode(
x, skip_special_tokens=True
) # self.data_args.skip_special_tokens)
for x in all_prefixes
]
}
)
# Metrics.
if self.compute_metrics is not None:
metrics = self.compute_metrics(results)
else:
metrics = {}
# To be JSON-serializable, we need to remove numpy types or zero-d tensors
metrics = denumpify_detensorize(metrics)
if all_losses is not None:
metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
# Prefix all keys with metric_key_prefix + '_'
for key in list(metrics.keys()):
if not key.startswith(f"{metric_key_prefix}_"):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
return EvalLoopOutput(
logits=all_logits,
simplex=all_simplex,
input_ids=all_inputs,
metrics=metrics,
num_samples=num_samples,
results=results,
)
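    # Runs the evaluation loop once per configured number of inference diffusion steps
    # (each with its own noise scheduler) and logs metrics/results per run under an
    # "inference_<steps>_" prefix.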
def evaluate(
self,
eval_dataset: Optional[Dataset] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
) -> Dict[str, float]:
"""
Run evaluation and returns metrics.
The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
(pass it to the init `compute_metrics` argument).
You can also subclass and override this method to inject custom behavior.
Args:
eval_dataset (`Dataset`, *optional*):
Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
method.
            ignore_keys (`List[str]`, *optional*):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
"eval_bleu" if the prefix is "eval" (default)
Returns:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
dictionary also contains the epoch number which comes from the training state.
"""
# memory metrics - must set up as early as possible
self._memory_tracker.start()
eval_dataloader = self.get_eval_dataloader(eval_dataset)
light_eval_dataloader = self.get_light_eval_dataloader(eval_dataset)
start_time = time.time()
outputs = []
timesteps = self.inference_timesteps
for timestep, noise_scheduler in zip(
timesteps, self.inference_noise_schedulers
):
output = self.evaluation_loop(
eval_dataloader,
description="Evaluation",
# No point gathering the predictions if there are no metrics, otherwise we defer to
# self.args.prediction_loss_only
prediction_loss_only=True if self.compute_metrics is None else None,
ignore_keys=ignore_keys,
metric_key_prefix=metric_key_prefix,
noise_scheduler=noise_scheduler,
light_eval_dataloader=light_eval_dataloader,
                # We only need the light-eval loss once, since it is the same for every
                # number of inference timesteps.
                do_light_eval=timestep == timesteps[0],
)
outputs.append(output)
key_prefix = f"inference_{timestep}_"
metrics = {key_prefix + k: v for k, v in output.metrics.items()}
results = {key_prefix + k: v for k, v in output.results.items()}
# reset output with new metrics / results
output = EvalLoopOutput(
logits=output.logits,
simplex=output.simplex,
input_ids=output.input_ids,
metrics=metrics,
num_samples=output.num_samples,
results=results,
)
total_batch_size = self.args.eval_batch_size * self.args.world_size
output.metrics.update(
speed_metrics(
metric_key_prefix,
start_time,
num_samples=output.num_samples,
num_steps=math.ceil(output.num_samples / total_batch_size),
)
)
self.log(output.metrics)
self.control = self.callback_handler.on_evaluate(
self.args, self.state, self.control, output.metrics
)
self._memory_tracker.stop_and_update_metrics(output.metrics)
# Save the results
self.save_metrics(
GENERATION_RESULTS + "_" + key_prefix + metric_key_prefix,
output.results,
)
logger.info("Results are saved now")
# log outside so we can group generations together
if self.args.log_generated_texts:
length = len(outputs[0].logits)
results = {
f"{k}_inference_{i}": v
for o, i in zip(outputs, timesteps)
for k, v in o.results.items()
}
self.log_results_to_tensorboard(self.state, length, results)
return output.metrics
def log_results_to_tensorboard(self, state, length, results):
        # TODO: fix this properly; tb_writer.tb_writer can be None when running in
        # eval-only mode.
if self.tb_writer.tb_writer is None:
return
for i in range(length):
total_text = ""
for k, v in results.items():
total_text += f"*** {k} ***: {v[i]}" + " \n"
self.tb_writer.tb_writer.add_text(
f"sample_{i}", total_text, state.global_step
)
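    # The collator passed at construction time is a factory keyed by split name; swap in
    # the split-specific collator before delegating to the parent dataloader builders.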
def get_train_dataloader(self) -> DataLoader:
self.data_collator = self.original_data_collator("train")
return super().get_train_dataloader()
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
self.data_collator = self.original_data_collator("eval")
return super().get_eval_dataloader(eval_dataset)
def get_light_eval_dataloader(
self, eval_dataset: Optional[Dataset] = None
) -> DataLoader:
"""
Returns the evaluation [`~torch.utils.data.DataLoader`].
Used for the light evaluation, which matches masking with training.
Args:
eval_dataset (`torch.utils.data.Dataset`, *optional*):
If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
by the `model.forward()` method are automatically removed. It must implement `__len__`.
"""
if eval_dataset is None and self.eval_dataset is None:
raise ValueError("Trainer: evaluation requires an eval_dataset.")
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
data_collator = self.original_data_collator("train")
if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(
eval_dataset, description="evaluation"
)
else:
data_collator = self._get_collator_with_removed_columns(
data_collator, description="evaluation"
)
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
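    # Optimizer override: the usual decay / no-decay grouping, plus dedicated parameter
    # groups with their own learning rates for the timestep embedding and, if present,
    # the CDCD CDF (timestep-warping) parameters.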
def create_optimizer(self):
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is not None:
return self.optimizer
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args
)
# override to apply higher lr to timestep_embed and cdcd cdf
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in opt_model.named_parameters()
if (
n in decay_parameters
and p.requires_grad
and not ("timestep_embed" in n or "cdf" in n)
)
],
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (
n not in decay_parameters
and p.requires_grad
and not ("timestep_embed" in n or "cdf" in n)
)
],
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (("timestep_embed" in n) and p.requires_grad)
],
"weight_decay": 0.0,
"lr": self.args.timestep_embed_lr or self.args.learning_rate,
},
]
        # CDCD: if the model has learned CDF (timestep-warping) parameters, put them in
        # their own group with no weight decay and a fixed 1e-3 learning rate.
cdf_params = [
p
for n, p in opt_model.named_parameters()
if (("cdf" in n) and p.requires_grad)
]
if cdf_params:
optimizer_grouped_parameters.append(
{
"params": cdf_params,
"weight_decay": 0.0,
"lr": 1e-3,
}
)
optimizer_kwargs.pop("lr")
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
return self.optimizer