import torch
import abc
import os
import copy

import pytorch_lightning as pl

from utils.lr_scheduler import *
from torch import distributed as dist


class AbstractModel(pl.LightningModule):
    def __init__(self,
                 lr_scheduler_kwargs: dict = None,
                 optimizer_kwargs: dict = None,
                 save_path: str = None,
                 from_checkpoint: str = None,
                 load_prev_scheduler: bool = False,
                 save_weights_only: bool = True):
"""
|
|
|
|
Args:
|
|
lr_scheduler: Kwargs for lr_scheduler
|
|
optimizer_kwargs: Kwargs for optimizer_kwargs
|
|
save_path: Save trained model
|
|
from_checkpoint: Load model from checkpoint
|
|
load_prev_scheduler: Whether load previous scheduler from checkpoint
|
|
load_strict: Whether load model strictly
|
|
save_weights_only: Whether save only weights or also optimizer and lr_scheduler
|
|
|
|
"""
|
|
        super().__init__()
        self.initialize_model()

        # Register each stage's metrics as attributes so they are tracked by Lightning
        self.metrics = {}
        for stage in ["train", "valid", "test"]:
            stage_metrics = self.initialize_metrics(stage)
            for metric_name, metric in stage_metrics.items():
                setattr(self, metric_name, metric)
            self.metrics[stage] = stage_metrics

        if lr_scheduler_kwargs is None:
            self.lr_scheduler_kwargs = {
                "class": "ConstantLRScheduler",
                "init_lr": 0,
            }
            print("No lr_scheduler_kwargs provided. The default learning rate is 0.")
        else:
            self.lr_scheduler_kwargs = lr_scheduler_kwargs

        if optimizer_kwargs is None:
            self.optimizer_kwargs = {
                "class": "AdamW",
                "betas": (0.9, 0.98),
                "weight_decay": 0.01,
            }
            print("No optimizer_kwargs provided. The default optimizer is AdamW.")
        else:
            self.optimizer_kwargs = optimizer_kwargs
        self.init_optimizers()

        self.save_path = save_path
        self.save_weights_only = save_weights_only

        self.temp_step = 0
        self.step = 0
        self.epoch = 0

        self.load_prev_scheduler = load_prev_scheduler
        self.from_checkpoint = from_checkpoint
        if from_checkpoint:
            self.load_checkpoint(from_checkpoint)

    @abc.abstractmethod
    def initialize_model(self) -> None:
        """
        All model initialization should be done here.
        Note that the whole model must be named "self.model" for model saving and loading.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def forward(self, *args, **kwargs):
        """
        Forward propagation
        """
        raise NotImplementedError

    @abc.abstractmethod
    def initialize_metrics(self, stage: str) -> dict:
        """
        Initialize metrics for each stage
        Args:
            stage: "train", "valid" or "test"

        Returns:
            A dictionary of metrics for the stage. Keys are metric names and values are metric objects
        """
        raise NotImplementedError
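
    # Illustrative sketch (not part of this module): a subclass's initialize_metrics might
    # build torchmetrics objects keyed by stage, assuming torchmetrics is installed and the
    # task is multiclass classification with a hypothetical num_classes:
    #
    #     def initialize_metrics(self, stage: str) -> dict:
    #         import torchmetrics
    #         return {f"{stage}_acc": torchmetrics.Accuracy(task="multiclass", num_classes=10)}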

    @abc.abstractmethod
    def loss_func(self, stage: str, outputs, labels) -> torch.Tensor:
        """
        Calculate the loss for the given stage.
        Args:
            stage: "train", "valid" or "test"
            outputs: model outputs for calculating loss
            labels: labels for calculating loss

        Returns:
            loss
        """
        raise NotImplementedError
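
    # Illustrative sketch (not part of this module): for a classification model, loss_func
    # might compute cross-entropy and update the stage metrics; the f"{stage}_acc" key is a
    # hypothetical name carried over from the sketch above:
    #
    #     def loss_func(self, stage, outputs, labels) -> torch.Tensor:
    #         loss = torch.nn.functional.cross_entropy(outputs, labels)
    #         self.metrics[stage][f"{stage}_acc"].update(outputs.detach(), labels)
    #         return loss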

    @staticmethod
    def load_weights(model, weights):
        model_dict = model.state_dict()

        unused_params = []
        missed_params = list(model_dict.keys())

        for k, v in weights.items():
            if k in model_dict:
                model_dict[k] = v
                missed_params.remove(k)
            else:
                unused_params.append(k)

        if len(missed_params) > 0:
            print(f"\033[31mSome weights of {type(model).__name__} were not "
                  f"initialized from the model checkpoint: {missed_params}\033[0m")

        if len(unused_params) > 0:
            print(f"\033[31mSome weights of the model checkpoint were not used: {unused_params}\033[0m")

        model.load_state_dict(model_dict)

    def optimizer_step(
            self,
            epoch: int,
            batch_idx: int,
            optimizer,
            optimizer_closure=None,
    ) -> None:
        super().optimizer_step(epoch, batch_idx, optimizer, optimizer_closure)

        # Track optimizer updates manually: self.step advances once every
        # accumulate_grad_batches calls
        self.temp_step += 1
        if self.temp_step == self.trainer.accumulate_grad_batches:
            self.step += 1
            self.temp_step = 0

    def on_train_epoch_end(self):
        self.epoch += 1

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(**inputs)
        loss = self.loss_func('train', outputs, labels)

        self.log("loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(**inputs)
        loss = self.loss_func('valid', outputs, labels)
        self.valid_outputs.append(loss)
        return loss

    def test_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(**inputs)
        loss = self.loss_func('test', outputs, labels)
        self.test_outputs.append(loss)
        return loss

    def on_train_start(self) -> None:
        # Restore the previous training state if a scheduler state was loaded from the checkpoint
        if getattr(self, "prev_scheduler", None) is not None:
            try:
                self.step = self.prev_scheduler["global_step"]
                self.epoch = self.prev_scheduler["epoch"]
                self.best_value = self.prev_scheduler["best_value"]
                self.lr_scheduler.load_state_dict(self.prev_scheduler["lr_scheduler"])
                print(f"Previous training global step: {self.step}")
                print(f"Previous training epoch: {self.epoch}")
                print(f"Previous best value: {self.best_value}")
                print(f"Previous lr_scheduler: {self.prev_scheduler['lr_scheduler']}")

                # DeepSpeed manages its own optimizer state, so restore it through the engine
                if hasattr(self.trainer.strategy, "deepspeed_engine"):
                    try:
                        self.trainer.strategy.deepspeed_engine.load_checkpoint(self.from_checkpoint)
                    except Exception as e:
                        print(e)
                else:
                    self.optimizer.load_state_dict(self.prev_scheduler["optimizer"])

            except Exception as e:
                print(e)
                raise Exception("Error in loading the previous scheduler. Please set load_prev_scheduler=False")

    def on_validation_epoch_start(self) -> None:
        self.valid_outputs = []

    def on_test_epoch_start(self) -> None:
        self.test_outputs = []

    def load_checkpoint(self, from_checkpoint: str) -> None:
        """
        Args:
            from_checkpoint: Path to checkpoint.
        """
        # If a directory is given, look for <dirname>/<dirname>.pt as written by check_save_condition
        if os.path.isdir(from_checkpoint):
            basename = os.path.basename(from_checkpoint)
            from_checkpoint = os.path.join(from_checkpoint, f"{basename}.pt")

        state_dict = torch.load(from_checkpoint, map_location=self.device)
        self.load_weights(self.model, state_dict["model"])

        if self.load_prev_scheduler:
            # Keep everything except the model weights for on_train_start to restore
            state_dict.pop("model")
            self.prev_scheduler = state_dict

    def save_checkpoint(self, save_path: str, save_info: dict = None, save_weights_only: bool = True) -> None:
        """
        Save the model to save_path
        Args:
            save_path: Path to save the model
            save_info: Other info to save
            save_weights_only: Whether to save only the model weights
        """
        save_dir = os.path.dirname(save_path)
        os.makedirs(save_dir, exist_ok=True)

        state_dict = {} if save_info is None else save_info
        state_dict["model"] = self.model.state_dict()

        # Cast weights to full precision before saving
        for k, v in state_dict["model"].items():
            state_dict["model"][k] = v.float()

        if not save_weights_only:
            state_dict["global_step"] = self.step
            state_dict["epoch"] = self.epoch
            state_dict["best_value"] = getattr(self, "best_value", None)
            state_dict["lr_scheduler"] = self.lr_schedulers().state_dict()

            # DeepSpeed saves the optimizer state through its own engine, so only save it here otherwise
            if not hasattr(self.trainer.strategy, "deepspeed_engine"):
                state_dict["optimizer"] = self.optimizers().optimizer.state_dict()

        torch.save(state_dict, save_path)

    def check_save_condition(self, now_value: float, mode: str, save_info: dict = None) -> None:
        """
        Check whether to save the model. If save_path is not None and now_value is the best so far, save the model.
        Args:
            now_value: Current metric value
            mode: "min" or "max", i.e. whether lower or higher values are better
            save_info: Other info to save
        """
        assert mode in ["min", "max"], "mode should be 'min' or 'max'"

        if self.save_path is not None:
            # save_path is re-evaluated as an f-string, so it may reference attributes of self
            # (e.g. a hypothetical template such as "ckpt/epoch_{self.epoch}")
            save_path = eval(f"f'{self.save_path}'")

            save_dir = os.path.dirname(save_path)
            os.makedirs(save_dir, exist_ok=True)

            # Only save when the current value beats the best value seen so far
            best_value = getattr(self, "best_value", None)
            if best_value is not None:
                if (mode == "min" and now_value >= best_value) or (mode == "max" and now_value <= best_value):
                    return

            setattr(self, "best_value", now_value)

            # With DeepSpeed, save_path is treated as a directory: the engine writes its own
            # checkpoint there and rank 0 writes the plain weights next to it
            if hasattr(self.trainer.strategy, "deepspeed_engine"):
                if not self.save_weights_only:
                    self.trainer.strategy.deepspeed_engine.save_checkpoint(save_path, tag="deepspeed_ckpt")

                if dist.get_rank() == 0:
                    basename = os.path.basename(save_path)
                    ckpt_path = os.path.join(save_path, f"{basename}.pt")
                    self.save_checkpoint(ckpt_path, save_info, self.save_weights_only)

            else:
                if dist.get_rank() == 0:
                    self.save_checkpoint(save_path, save_info, self.save_weights_only)
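
    # Illustrative sketch (not part of this module): a subclass's on_validation_epoch_end
    # might tie get_log_dict, log_info, reset_metrics and check_save_condition together:
    #
    #     def on_validation_epoch_end(self):
    #         log_dict = self.get_log_dict("valid")
    #         log_dict["valid_loss"] = torch.mean(torch.stack(self.valid_outputs))
    #         self.log_info(log_dict)
    #         self.reset_metrics("valid")
    #         self.check_save_condition(log_dict["valid_loss"], mode="min")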

    def reset_metrics(self, stage) -> None:
        """
        Reset metrics for the given stage
        Args:
            stage: "train", "valid" or "test"
        """
        for metric in self.metrics[stage].values():
            metric.reset()

    def get_log_dict(self, stage: str) -> dict:
        """
        Get the log dict for the stage
        Args:
            stage: "train", "valid" or "test"

        Returns:
            A dictionary of metrics for the stage. Keys are metric names and values are metric values
        """
        return {name: metric.compute() for name, metric in self.metrics[stage].items()}

    def log_info(self, info: dict) -> None:
        """
        Record metrics during training and testing
        Args:
            info: dict of metrics
        """
        # Only rank 0 writes to the logger
        if getattr(self, "logger", None) is not None and dist.get_rank() == 0:
            info["learning_rate"] = self.lr_scheduler.get_last_lr()[0]
            info["epoch"] = self.epoch
            self.logger.log_metrics(info, step=self.step)

    def init_optimizers(self):
        copy_optimizer_kwargs = copy.deepcopy(self.optimizer_kwargs)

        # No weight decay for LayerNorm weights and biases
        no_decay = ['LayerNorm.weight', 'bias']
        weight_decay = copy_optimizer_kwargs.pop("weight_decay")

        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': weight_decay},
            {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]

        # "class" selects the optimizer from torch.optim; the remaining kwargs are passed to its constructor
        optimizer_cls = eval(f"torch.optim.{copy_optimizer_kwargs.pop('class')}")
        self.optimizer = optimizer_cls(optimizer_grouped_parameters,
                                       lr=self.lr_scheduler_kwargs['init_lr'],
                                       **copy_optimizer_kwargs)

        # "class" selects the scheduler from utils.lr_scheduler; the remaining kwargs are passed to its constructor
        tmp_kwargs = copy.deepcopy(self.lr_scheduler_kwargs)
        lr_scheduler = tmp_kwargs.pop("class")
        self.lr_scheduler = eval(lr_scheduler)(self.optimizer, **tmp_kwargs)
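
    # Illustrative sketch (hypothetical values): kwargs passed to __init__ are expected to
    # carry a "class" key plus constructor arguments, e.g.:
    #
    #     lr_scheduler_kwargs = {"class": "ConstantLRScheduler", "init_lr": 1e-4}
    #     optimizer_kwargs = {"class": "AdamW", "betas": (0.9, 0.98), "weight_decay": 0.01}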

    def configure_optimizers(self):
        return {"optimizer": self.optimizer,
                "lr_scheduler": {"scheduler": self.lr_scheduler,
                                 "interval": "step",
                                 "frequency": 1}
                }