|
import logging
import math
import os
import random
import re
import shutil
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union

import numpy as np
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
from tqdm.auto import tqdm, trange

from transformers.data.data_collator import DataCollator, default_data_collator
from transformers.file_utils import is_apex_available, is_torch_tpu_available
from transformers.modeling_utils import PreTrainedModel
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput, is_wandb_available

from relogic.pretrainkit.training_args import TrainingArguments
from relogic.pretrainkit.trainer_utils import EvalPredictionWithSize, PredictionOutputWithSize
|
|
|
|
|
|
|
if is_apex_available(): |
|
from apex import amp |
|
|
|
|
|
if is_torch_tpu_available(): |
|
import torch_xla.core.xla_model as xm |
|
import torch_xla.debug.metrics as met |
|
import torch_xla.distributed.parallel_loader as pl |
|
|
|
try: |
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
_has_tensorboard = True |
|
except ImportError: |
|
try: |
|
from tensorboardX import SummaryWriter |
|
|
|
_has_tensorboard = True |
|
except ImportError: |
|
_has_tensorboard = False |
|
|
|
|
|
def is_tensorboard_available(): |
|
return _has_tensorboard |
|
|
|
|
|
if is_wandb_available(): |
|
import wandb |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def set_seed(seed: int): |
|
random.seed(seed) |
|
np.random.seed(seed) |
|
torch.manual_seed(seed) |
|
torch.cuda.manual_seed_all(seed) |
|
|
|
|
|
|
|
@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager that makes all processes in distributed training wait for the local master to do something first.
    """
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    yield
    if local_rank == 0:
        torch.distributed.barrier()
|
|
|
|
|
class SequentialDistributedSampler(Sampler): |
|
""" |
|
    Distributed Sampler that subsamples indices sequentially, making it easier
    to collate all results at the end.

    Even though we only use this sampler for eval and predict (no training),
    which means that the model params won't have to be synced (i.e. will not hang
    for synchronization even if the number of forward passes varies), we still add extra
    samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
    so that it is easy to `gather` or `reduce` the resulting tensors at the end of the loop.
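
    Example (illustrative arithmetic): with ``len(dataset) == 10`` and ``num_replicas == 4``,
    ``num_samples == 3`` and ``total_size == 12``, so the index list is padded to
    ``[0, 1, ..., 9, 0, 1]`` and rank 0 iterates over ``[0, 1, 2]``, rank 1 over ``[3, 4, 5]``,
    rank 2 over ``[6, 7, 8]``, and rank 3 over ``[9, 0, 1]``.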
|
""" |
|
|
|
def __init__(self, dataset, num_replicas=None, rank=None): |
|
if num_replicas is None: |
|
if not torch.distributed.is_available(): |
|
raise RuntimeError("Requires distributed package to be available") |
|
num_replicas = torch.distributed.get_world_size() |
|
if rank is None: |
|
if not torch.distributed.is_available(): |
|
raise RuntimeError("Requires distributed package to be available") |
|
rank = torch.distributed.get_rank() |
|
self.dataset = dataset |
|
self.num_replicas = num_replicas |
|
self.rank = rank |
|
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) |
|
self.total_size = self.num_samples * self.num_replicas |
|
|
|
def __iter__(self): |
|
indices = list(range(len(self.dataset))) |
|
|
|
|
|
indices += indices[: (self.total_size - len(indices))] |
|
assert len(indices) == self.total_size |
|
|
|
|
|
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples] |
|
assert len(indices) == self.num_samples |
|
|
|
return iter(indices) |
|
|
|
def __len__(self): |
|
return self.num_samples |
|
|
|
|
|
def get_tpu_sampler(dataset: Dataset): |
|
if xm.xrt_world_size() <= 1: |
|
return RandomSampler(dataset) |
|
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) |
|
|
|
|
|
class Trainer: |
|
""" |
|
Trainer is a simple but feature-complete training and eval loop for PyTorch, |
|
optimized for Transformers. |
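
    Example (a minimal sketch, not taken from this module: ``my_model``, ``my_train_set``,
    ``my_eval_set`` and ``my_collator`` are placeholders, and the ``TrainingArguments`` fields
    are assumed to mirror the Hugging Face ones)::

        from relogic.pretrainkit.training_args import TrainingArguments

        args = TrainingArguments(output_dir="out")
        trainer = Trainer(
            model=my_model,              # any PreTrainedModel
            args=args,
            train_dataset=my_train_set,  # a torch Dataset yielding dicts of tensors
            data_collator=my_collator,   # optional; defaults to default_data_collator
        )
        trainer.train()
        metrics = trainer.evaluate(eval_dataset=my_eval_set)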
|
""" |
|
|
|
model: PreTrainedModel |
|
args: TrainingArguments |
|
data_collator: DataCollator |
|
train_dataset: Optional[Dataset] |
|
eval_dataset: Optional[Dataset] |
|
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None |
|
prediction_loss_only: bool |
|
tb_writer: Optional["SummaryWriter"] = None |
|
    optimizers: Optional[Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]] = None
|
global_step: Optional[int] = None |
|
epoch: Optional[float] = None |
|
|
|
def __init__( |
|
self, |
|
model: PreTrainedModel, |
|
args: TrainingArguments, |
|
data_collator: Optional[DataCollator] = None, |
|
train_dataset: Optional[Dataset] = None, |
|
eval_dataset: Optional[Dataset] = None, |
|
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, |
|
prediction_loss_only=False, |
|
tb_writer: Optional["SummaryWriter"] = None, |
|
        optimizers: Optional[Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]] = None,
|
): |
|
""" |
|
Trainer is a simple but feature-complete training and eval loop for PyTorch, |
|
optimized for Transformers. |
|
|
|
Args: |
|
prediction_loss_only: |
|
(Optional) in evaluation and prediction, only return the loss |
|
""" |
|
self.model = model.to(args.device) |
|
self.args = args |
|
self.data_collator = data_collator if data_collator is not None else default_data_collator |
|
self.train_dataset = train_dataset |
|
self.eval_dataset = eval_dataset |
|
self.compute_metrics = compute_metrics |
|
self.prediction_loss_only = prediction_loss_only |
|
self.optimizers = optimizers |
|
if tb_writer is not None: |
|
self.tb_writer = tb_writer |
|
elif is_tensorboard_available() and self.is_world_master(): |
|
self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir) |
|
if not is_tensorboard_available(): |
|
logger.warning( |
|
"You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it." |
|
) |
|
if is_wandb_available(): |
|
self._setup_wandb() |
|
else: |
|
logger.info( |
|
"You are instantiating a Trainer but W&B is not installed. To use wandb logging, " |
|
"run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface." |
|
) |
|
set_seed(self.args.seed) |
|
|
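        # Create the output directory if needed.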
|
if self.is_world_master(): |
|
os.makedirs(self.args.output_dir, exist_ok=True) |
|
if is_torch_tpu_available(): |
|
|
|
|
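            # Set an xla_device flag on the model's config.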
|
self.model.config.xla_device = True |
|
if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)): |
|
self.data_collator = self.data_collator.collate_batch |
|
warnings.warn( |
|
( |
|
"The `data_collator` should now be a simple callable (function, class with `__call__`), classes " |
|
+ "with a `collate_batch` are deprecated and won't be supported in a future version." |
|
), |
|
FutureWarning, |
|
) |
|
|
|
def get_train_dataloader(self) -> DataLoader: |
|
if self.train_dataset is None: |
|
raise ValueError("Trainer: training requires a train_dataset.") |
|
if is_torch_tpu_available(): |
|
train_sampler = get_tpu_sampler(self.train_dataset) |
|
else: |
|
train_sampler = ( |
|
RandomSampler(self.train_dataset) |
|
if self.args.local_rank == -1 |
|
else DistributedSampler(self.train_dataset) |
|
) |
|
|
|
data_loader = DataLoader( |
|
self.train_dataset, |
|
batch_size=self.args.train_batch_size, |
|
sampler=train_sampler, |
|
collate_fn=self.data_collator, |
|
drop_last=self.args.dataloader_drop_last, |
|
) |
|
|
|
return data_loader |
|
|
|
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: |
|
if eval_dataset is None and self.eval_dataset is None: |
|
raise ValueError("Trainer: evaluation requires an eval_dataset.") |
|
|
|
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset |
|
|
|
if is_torch_tpu_available(): |
|
sampler = SequentialDistributedSampler( |
|
eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal() |
|
) |
|
elif self.args.local_rank != -1: |
|
sampler = SequentialDistributedSampler(eval_dataset) |
|
else: |
|
sampler = SequentialSampler(eval_dataset) |
|
|
|
data_loader = DataLoader( |
|
eval_dataset, |
|
sampler=sampler, |
|
batch_size=self.args.eval_batch_size, |
|
collate_fn=self.data_collator, |
|
drop_last=self.args.dataloader_drop_last, |
|
) |
|
|
|
return data_loader |
|
|
|
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: |
|
|
|
if is_torch_tpu_available(): |
|
sampler = SequentialDistributedSampler( |
|
test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal() |
|
) |
|
elif self.args.local_rank != -1: |
|
sampler = SequentialDistributedSampler(test_dataset) |
|
else: |
|
sampler = SequentialSampler(test_dataset) |
|
|
|
data_loader = DataLoader( |
|
test_dataset, |
|
sampler=sampler, |
|
batch_size=self.args.eval_batch_size, |
|
collate_fn=self.data_collator, |
|
drop_last=self.args.dataloader_drop_last, |
|
) |
|
|
|
return data_loader |
|
|
|
def get_optimizers( |
|
self, num_training_steps: int |
|
) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]: |
|
""" |
|
Setup the optimizer and the learning rate scheduler. |
|
|
|
We provide a reasonable default that works well. |
|
If you want to use something else, you can pass a tuple in the Trainer's init, |
|
or override this method in a subclass. |
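
        Example (a sketch of passing your own optimizer/scheduler through the constructor;
        ``my_model`` and ``args`` are placeholders)::

            optimizer = AdamW(my_model.parameters(), lr=5e-5)
            scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=1000)
            trainer = Trainer(model=my_model, args=args, optimizers=(optimizer, scheduler))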
|
""" |
|
if self.optimizers is not None: |
|
return self.optimizers |
|
|
|
no_decay = ["bias", "LayerNorm.weight"] |
|
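        # Parameters are split into three groups: (1) weight-decayed parameters outside the
        # relational_transformer, (2) all no-decay parameters (biases and LayerNorm weights), and
        # (3) weight-decayed relational_transformer parameters, which use a hard-coded learning rate of 7e-5.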
optimizer_grouped_parameters = [ |
|
{ |
|
"params": [p for n, p in self.model.named_parameters() if "relational_transformer" not in n and not any(nd in n for nd in no_decay)], |
|
"weight_decay": self.args.weight_decay, |
|
}, |
|
{ |
|
"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], |
|
"weight_decay": 0.0, |
|
}, |
|
{ |
|
"params": [p for n, p in self.model.named_parameters() if "relational_transformer" in n and not any(nd in n for nd in no_decay)], |
|
"weight_decay": self.args.weight_decay, |
|
"lr": 7e-5 |
|
} |
|
] |
|
optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon) |
|
scheduler = get_linear_schedule_with_warmup( |
|
optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps |
|
) |
|
return optimizer, scheduler |
|
|
|
def _setup_wandb(self): |
|
""" |
|
Setup the optional Weights & Biases (`wandb`) integration. |
|
|
|
One can override this method to customize the setup if needed. Find more information at https://docs.wandb.com/huggingface |
|
You can also override the following environment variables: |
|
|
|
Environment: |
|
WANDB_WATCH: |
|
(Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging |
|
or "all" to log gradients and parameters |
|
WANDB_PROJECT: |
|
(Optional): str - "huggingface" by default, set this to a custom string to store results in a different project |
|
WANDB_DISABLED: |
|
(Optional): boolean - defaults to false, set to "true" to disable wandb entirely |
|
""" |
|
if self.is_world_master(): |
|
logger.info( |
|
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"' |
|
) |
|
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args)) |
|
|
|
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false": |
|
wandb.watch( |
|
self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps) |
|
) |
|
|
|
def num_examples(self, dataloader: DataLoader) -> int: |
|
""" |
|
Helper to get num of examples from a DataLoader, by accessing its Dataset. |
|
""" |
|
return len(dataloader.dataset) |
|
|
|
def train(self, model_path: Optional[str] = None): |
|
""" |
|
Main training entry point. |
|
|
|
Args: |
|
model_path: |
|
(Optional) Local path to model if model to train has been instantiated from a local path |
|
If present, we will try reloading the optimizer/scheduler states from there. |
|
""" |
|
train_dataloader = self.get_train_dataloader() |
|
if self.args.max_steps > 0: |
|
t_total = self.args.max_steps |
|
num_train_epochs = ( |
|
self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1 |
|
) |
|
else: |
|
t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs) |
|
num_train_epochs = self.args.num_train_epochs |
|
|
|
optimizer, scheduler = self.get_optimizers(num_training_steps=t_total) |
|
|
|
|
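        # Check if saved optimizer and scheduler states exist, and reload them if so.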
|
if ( |
|
model_path is not None |
|
and os.path.isfile(os.path.join(model_path, "optimizer.pt")) |
|
and os.path.isfile(os.path.join(model_path, "scheduler.pt")) |
|
): |
|
|
|
optimizer.load_state_dict( |
|
torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device) |
|
) |
|
scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt"))) |
|
|
|
model = self.model |
|
if self.args.fp16: |
|
if not is_apex_available(): |
|
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") |
|
model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level) |
|
|
|
|
|
if self.args.n_gpu > 1: |
|
model = torch.nn.DataParallel(model) |
|
|
|
|
|
if self.args.local_rank != -1: |
|
model = torch.nn.parallel.DistributedDataParallel( |
|
model, |
|
device_ids=[self.args.local_rank], |
|
output_device=self.args.local_rank, |
|
find_unused_parameters=True, |
|
) |
|
|
|
if self.tb_writer is not None: |
|
self.tb_writer.add_text("args", self.args.to_json_string()) |
|
self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={}) |
|
|
|
|
|
if is_torch_tpu_available(): |
|
total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size() |
|
else: |
|
total_train_batch_size = ( |
|
self.args.train_batch_size |
|
* self.args.gradient_accumulation_steps |
|
* (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1) |
|
) |
|
logger.info("***** Running training *****") |
|
logger.info(" Num examples = %d", self.num_examples(train_dataloader)) |
|
logger.info(" Num Epochs = %d", num_train_epochs) |
|
logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size) |
|
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size) |
|
logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) |
|
logger.info(" Total optimization steps = %d", t_total) |
|
|
|
self.global_step = 0 |
|
self.epoch = 0 |
|
epochs_trained = 0 |
|
steps_trained_in_current_epoch = 0 |
|
|
|
if model_path is not None: |
|
|
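            # Set global_step to the global step of the last saved checkpoint, parsed from the model path.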
|
try: |
|
self.global_step = int(model_path.split("-")[-1].split("/")[0]) |
|
epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps) |
|
steps_trained_in_current_epoch = self.global_step % ( |
|
len(train_dataloader) // self.args.gradient_accumulation_steps |
|
) |
|
|
|
logger.info(" Continuing training from checkpoint, will skip to saved global_step") |
|
logger.info(" Continuing training from epoch %d", epochs_trained) |
|
logger.info(" Continuing training from global step %d", self.global_step) |
|
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) |
|
except ValueError: |
|
self.global_step = 0 |
|
logger.info(" Starting fine-tuning.") |
|
|
|
tr_loss = 0.0 |
|
logging_loss = 0.0 |
|
model.zero_grad() |
|
train_iterator = trange( |
|
epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_master() or not self.args.logging_tqdm |
|
) |
|
for epoch in train_iterator: |
|
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): |
|
train_dataloader.sampler.set_epoch(epoch) |
|
|
|
if is_torch_tpu_available(): |
|
parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader( |
|
self.args.device |
|
) |
|
epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_master() or not self.args.logging_tqdm) |
|
else: |
|
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master() or not self.args.logging_tqdm) |
|
|
|
for step, inputs in enumerate(epoch_iterator): |
|
|
|
|
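                # Skip past any already trained steps if resuming training.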
|
if steps_trained_in_current_epoch > 0: |
|
steps_trained_in_current_epoch -= 1 |
|
continue |
|
|
|
tr_loss += self._training_step(model, inputs, optimizer) |
|
|
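                # Take an optimizer step every `gradient_accumulation_steps` micro-batches, or at the
                # last step of an epoch that is shorter than `gradient_accumulation_steps`.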
|
if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( |
|
|
|
len(epoch_iterator) <= self.args.gradient_accumulation_steps |
|
and (step + 1) == len(epoch_iterator) |
|
): |
|
if self.args.fp16: |
|
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm) |
|
else: |
|
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm) |
|
|
|
if is_torch_tpu_available(): |
|
xm.optimizer_step(optimizer) |
|
else: |
|
optimizer.step() |
|
|
|
scheduler.step() |
|
model.zero_grad() |
|
self.global_step += 1 |
|
self.epoch = epoch + (step + 1) / len(epoch_iterator) |
|
|
|
if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or ( |
|
self.global_step == 1 and self.args.logging_first_step |
|
): |
|
logs: Dict[str, float] = {} |
|
logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps |
|
|
|
logs["learning_rate"] = ( |
|
scheduler.get_last_lr()[0] |
|
if version.parse(torch.__version__) >= version.parse("1.4") |
|
else scheduler.get_lr()[0] |
|
) |
|
logging_loss = tr_loss |
|
|
|
self._log(logs) |
|
|
|
if (self.args.eval_steps > 0 and self.global_step % self.args.eval_steps == 0): |
|
if self.args.evaluate_during_training: |
|
self.evaluate() |
|
|
|
if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0: |
|
|
|
|
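                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.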
|
if hasattr(model, "module"): |
|
assert model.module is self.model |
|
else: |
|
assert model is self.model |
|
|
|
output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}") |
|
|
|
self.save_model(output_dir) |
|
|
|
if self.is_world_master(): |
|
self._rotate_checkpoints() |
|
|
|
if is_torch_tpu_available(): |
|
xm.rendezvous("saving_optimizer_states") |
|
xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) |
|
xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) |
|
elif self.is_world_master(): |
|
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) |
|
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) |
|
|
|
if self.args.max_steps > 0 and self.global_step > self.args.max_steps: |
|
epoch_iterator.close() |
|
break |
|
if self.args.max_steps > 0 and self.global_step > self.args.max_steps: |
|
train_iterator.close() |
|
break |
|
if self.args.tpu_metrics_debug: |
|
|
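                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)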
|
xm.master_print(met.metrics_report()) |
|
|
|
if self.tb_writer: |
|
self.tb_writer.close() |
|
|
|
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") |
|
return TrainOutput(self.global_step, tr_loss / self.global_step) |
|
|
|
def _log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None: |
|
if self.epoch is not None: |
|
logs["epoch"] = self.epoch |
|
if self.global_step is None: |
|
|
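            # This happens when logging evaluation metrics before any training step has been taken.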
|
self.global_step = 0 |
|
if self.tb_writer: |
|
for k, v in logs.items(): |
|
if isinstance(v, (int, float)): |
|
self.tb_writer.add_scalar(k, v, self.global_step) |
|
else: |
|
logger.warning( |
|
"Trainer is attempting to log a value of " |
|
'"%s" of type %s for key "%s" as a scalar. ' |
|
"This invocation of Tensorboard's writer.add_scalar() " |
|
"is incorrect so we dropped this attribute.", |
|
v, |
|
type(v), |
|
k, |
|
) |
|
self.tb_writer.flush() |
|
if is_wandb_available(): |
|
if self.is_world_master(): |
|
wandb.log(logs, step=self.global_step) |
|
output = {**logs, **{"step": self.global_step}} |
|
if iterator is not None: |
|
iterator.write(output) |
|
else: |
|
logger.info(output) |
|
|
|
def _training_step( |
|
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer |
|
) -> float: |
|
model.train() |
|
for k, v in inputs.items(): |
|
if isinstance(v, torch.Tensor): |
|
inputs[k] = v.to(self.args.device) |
|
|
|
outputs = model(**inputs) |
|
loss = outputs[0] |
|
|
|
if self.args.n_gpu > 1: |
|
loss = loss.mean() |
|
if self.args.gradient_accumulation_steps > 1: |
|
loss = loss / self.args.gradient_accumulation_steps |
|
|
|
if self.args.fp16: |
|
with amp.scale_loss(loss, optimizer) as scaled_loss: |
|
scaled_loss.backward() |
|
else: |
|
loss.backward() |
|
|
|
return loss.item() |
|
|
|
def is_local_master(self) -> bool: |
|
if is_torch_tpu_available(): |
|
return xm.is_master_ordinal(local=True) |
|
else: |
|
return self.args.local_rank in [-1, 0] |
|
|
|
def is_world_master(self) -> bool: |
|
""" |
|
This will be True only in one process, even in distributed mode, |
|
even when training on multiple machines. |
|
""" |
|
if is_torch_tpu_available(): |
|
return xm.is_master_ordinal(local=False) |
|
else: |
|
return self.args.local_rank == -1 or torch.distributed.get_rank() == 0 |
|
|
|
def save_model(self, output_dir: Optional[str] = None): |
|
""" |
|
Saving best-practices: if you use default names for the model, |
|
you can reload it using from_pretrained(). |
|
|
|
Will only save from the world_master process (unless in TPUs). |
|
""" |
|
|
|
if is_torch_tpu_available(): |
|
self._save_tpu(output_dir) |
|
elif self.is_world_master(): |
|
self._save(output_dir) |
|
|
|
def _save_tpu(self, output_dir: Optional[str] = None): |
|
output_dir = output_dir if output_dir is not None else self.args.output_dir |
|
logger.info("Saving model checkpoint to %s", output_dir) |
|
|
|
if xm.is_master_ordinal(): |
|
os.makedirs(output_dir, exist_ok=True) |
|
torch.save(self.args, os.path.join(output_dir, "training_args.bin")) |
|
|
|
|
|
|
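        # Save a trained model and configuration using `save_pretrained()`;
        # it can then be reloaded with `from_pretrained()`.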
|
if not isinstance(self.model, PreTrainedModel): |
|
raise ValueError("Trainer.model appears to not be a PreTrainedModel") |
|
|
|
xm.rendezvous("saving_checkpoint") |
|
self.model.save_pretrained(output_dir) |
|
|
|
def _save(self, output_dir: Optional[str] = None): |
|
output_dir = output_dir if output_dir is not None else self.args.output_dir |
|
os.makedirs(output_dir, exist_ok=True) |
|
logger.info("Saving model checkpoint to %s", output_dir) |
|
|
|
|
|
|
|
|
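        # Save a trained model and configuration using `save_pretrained()`;
        # it can then be reloaded with `from_pretrained()`.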
|
self.model.save_pretrained(output_dir) |
|
|
|
|
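        # Good practice: save the training arguments together with the trained model.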
|
torch.save(self.args, os.path.join(output_dir, "training_args.bin")) |
|
|
|
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]: |
|
ordering_and_checkpoint_path = [] |
|
|
|
glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")] |
|
|
|
for path in glob_checkpoints: |
|
if use_mtime: |
|
ordering_and_checkpoint_path.append((os.path.getmtime(path), path)) |
|
else: |
|
regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path) |
|
if regex_match and regex_match.groups(): |
|
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path)) |
|
|
|
checkpoints_sorted = sorted(ordering_and_checkpoint_path) |
|
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] |
|
return checkpoints_sorted |
|
|
|
def _rotate_checkpoints(self, use_mtime=False) -> None: |
|
if self.args.save_total_limit is None or self.args.save_total_limit <= 0: |
|
return |
|
|
|
|
|
checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime) |
|
if len(checkpoints_sorted) <= self.args.save_total_limit: |
|
return |
|
|
|
number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit) |
|
checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete] |
|
for checkpoint in checkpoints_to_be_deleted: |
|
logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint)) |
|
shutil.rmtree(checkpoint) |
|
|
|
def evaluate( |
|
self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None, |
|
) -> Dict[str, float]: |
|
""" |
|
Run evaluation and return metrics. |
|
|
|
The calling script will be responsible for providing a method to compute metrics, as they are |
|
task-dependent. |
|
|
|
Args: |
|
eval_dataset: (Optional) Pass a dataset if you wish to override |
|
the one on the instance. |
|
Returns: |
|
A dict containing: |
|
- the eval loss |
|
- the potential metrics computed from the predictions |
|
""" |
|
eval_dataloader = self.get_eval_dataloader(eval_dataset) |
|
|
|
output = self._prediction_loop(eval_dataloader, description="Evaluation") |
|
|
|
self._log(output.metrics) |
|
|
|
if self.args.tpu_metrics_debug: |
|
|
|
xm.master_print(met.metrics_report()) |
|
|
|
return output.metrics |
|
|
|
def predict(self, test_dataset: Dataset) -> PredictionOutput: |
|
""" |
|
Run prediction and return predictions and potential metrics. |
|
|
|
Depending on the dataset and your use case, your test dataset may contain labels. |
|
In that case, this method will also return metrics, like in evaluate(). |
|
""" |
|
test_dataloader = self.get_test_dataloader(test_dataset) |
|
|
|
return self._prediction_loop(test_dataloader, description="Prediction") |
|
|
|
def _prediction_loop( |
|
self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None |
|
) -> PredictionOutput: |
|
""" |
|
Prediction/evaluation loop, shared by `evaluate()` and `predict()`. |
|
|
|
Works both with or without labels. |
|
|
|
        NOTE: One issue is the size of the predictions and labels.
        The original code assumes that predictions and labels from different batches share the same sequence length,
        which is not true for our application. To make this more general, the predictions and labels are flattened
        and their per-example sizes are tracked separately.
|
|
|
""" |
|
|
|
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only |
|
|
|
model = self.model |
|
|
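        # Multi-GPU eval. Note: in torch.distributed mode there is no point in wrapping the model
        # in DistributedDataParallel, since evaluation runs under `torch.no_grad()` anyway.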
|
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
|
|
|
|
|
|
|
batch_size = dataloader.batch_size |
|
logger.info("***** Running %s *****", description) |
|
logger.info(" Num examples = %d", self.num_examples(dataloader)) |
|
logger.info(" Batch size = %d", batch_size) |
|
eval_losses: List[float] = [] |
|
        preds: Optional[torch.Tensor] = None
        preds_size: Optional[torch.Tensor] = None
        label_ids: Optional[torch.Tensor] = None
        label_size: Optional[torch.Tensor] = None
|
model.eval() |
|
|
|
if is_torch_tpu_available(): |
|
dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device) |
|
|
|
for inputs in tqdm(dataloader, desc=description): |
|
has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"]) |
|
|
|
for k, v in inputs.items(): |
|
if isinstance(v, torch.Tensor): |
|
inputs[k] = v.to(self.args.device) |
|
|
|
with torch.no_grad(): |
|
outputs = model(**inputs) |
|
if has_labels: |
|
step_eval_loss, logits = outputs[:2] |
|
eval_losses += [step_eval_loss.mean().item()] |
|
else: |
|
logits = outputs[0] |
|
|
|
if not prediction_loss_only: |
|
|
|
|
|
if preds is None: |
|
preds = logits.detach() |
|
preds_size = preds.new_full(size=preds.size()[:1], fill_value=preds.size(1)).detach() |
|
preds = preds.view(-1) |
|
else: |
|
preds_size = torch.cat((preds_size, logits.new_full(size=logits.size()[:1], fill_value=logits.size(1)).detach()), dim=0) |
|
preds = torch.cat((preds, logits.detach().view(-1)), dim=0) |
|
|
|
if inputs.get("labels") is not None: |
|
if label_ids is None: |
|
label_ids = inputs["labels"].detach() |
|
label_size = label_ids.new_full(size=label_ids.size()[:1], fill_value=label_ids.size(1)).detach() |
|
label_ids = label_ids.view(-1) |
|
else: |
|
label_size = torch.cat((label_size, inputs["labels"].new_full(size=inputs["labels"].size()[:1], fill_value=inputs["labels"].size(1)).detach()), dim=0) |
|
label_ids = torch.cat((label_ids, inputs["labels"].detach().view(-1)), dim=0) |
|
|
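        # In distributed mode, concatenate all results (and their sizes) from all nodes.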
|
if self.args.local_rank != -1: |
|
|
|
if preds is not None: |
|
|
|
preds, preds_size = self.distributed_concat_with_size(preds, preds_size, num_total_examples=self.num_examples(dataloader)) |
|
if label_ids is not None: |
|
|
|
label_ids, label_size = self.distributed_concat_with_size(label_ids, label_size, num_total_examples=self.num_examples(dataloader)) |
|
elif is_torch_tpu_available(): |
|
|
|
|
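            # tpu-comment: gather all predictions and labels from the TPU worker shards.
            # Note that preds_size and label_size are not gathered on this code path.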
|
if preds is not None: |
|
preds = xm.mesh_reduce("eval_preds", preds, torch.cat) |
|
if label_ids is not None: |
|
label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat) |
|
|
|
|
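        # Finally, turn the aggregated tensors into numpy arrays.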
|
if preds is not None: |
|
preds = preds.cpu().numpy() |
|
preds_size = preds_size.cpu().numpy() |
|
if label_ids is not None: |
|
label_ids = label_ids.cpu().numpy() |
|
label_size = label_size.cpu().numpy() |
|
if self.compute_metrics is not None and preds is not None and label_ids is not None: |
|
|
|
metrics = self.compute_metrics(EvalPredictionWithSize(predictions=preds, predictions_size=preds_size, label_ids=label_ids, label_size=label_size)) |
|
else: |
|
metrics = {} |
|
if len(eval_losses) > 0: |
|
metrics["eval_loss"] = np.mean(eval_losses) |
|
|
|
|
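        # Prefix all metric keys with "eval_".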
|
for key in list(metrics.keys()): |
|
if not key.startswith("eval_"): |
|
metrics[f"eval_{key}"] = metrics.pop(key) |
|
|
|
|
|
return PredictionOutputWithSize(predictions=preds, predictions_size=preds_size, label_ids=label_ids, label_size=label_size, metrics=metrics) |
|
|
|
def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor: |
|
assert self.args.local_rank != -1 |
|
|
|
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())] |
|
torch.distributed.all_gather(output_tensors, tensor) |
|
|
|
concat = torch.cat(output_tensors, dim=0) |
|
|
|
|
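        # Truncate the dummy elements added by SequentialDistributedSampler.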
|
output = concat[:num_total_examples] |
|
return output |
|
|
|
def distributed_concat_tensor(self, tensor: torch.Tensor): |
|
assert self.args.local_rank != -1 |
|
|
|
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())] |
|
torch.distributed.all_gather(output_tensors, tensor) |
|
|
|
concat = torch.cat(output_tensors, dim=0) |
|
return concat |
|
|
|
def distributed_concat_varsize_tensor(self, tensor: torch.Tensor): |
|
assert self.args.local_rank != -1 |
|
|
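        # Gather the per-process lengths, pad every local tensor to the global max length,
        # all_gather the padded tensors, then strip the padding from each shard before concatenating.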
|
sizes = self.distributed_concat_tensor(tensor.new_full(size=(1,), fill_value=tensor.size(0))) |
|
max_size = sizes.max().item() |
|
|
|
padded = tensor.new_zeros(max_size) |
|
padded[:tensor.size(0)] = tensor |
|
|
|
padded_agg = self.distributed_concat_tensor(padded) |
|
slices = [] |
|
for i, size in enumerate(sizes): |
|
start_idx = i * max_size |
|
end_idx = start_idx + size.item() |
|
slices.append(padded_agg[start_idx: end_idx]) |
|
ret = torch.cat(slices, dim=0) |
|
return ret |
|
|
|
|
|
    def distributed_concat_with_size(self, tensor: torch.Tensor, size: torch.Tensor, num_total_examples: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
assert self.args.local_rank != -1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
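        # Gather the flattened tensor together with its per-example sizes from every process.
        # Note: `num_total_examples` is kept for signature symmetry with `distributed_concat` and is unused here.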
|
concat_sizes = self.distributed_concat_varsize_tensor(size) |
|
concat = self.distributed_concat_varsize_tensor(tensor) |
|
|
|
|
|
|
|
assert concat_sizes.sum() == concat.size(0) |
|
return concat, concat_sizes |
|
|
|
|
|
|
|
|