Demo_ProTrek_650M_UniRef50/model/abstract_model.py
import torch
import abc
import os
import copy
import pytorch_lightning as pl
from utils.lr_scheduler import *
from torch import distributed as dist
class AbstractModel(pl.LightningModule):
def __init__(self,
lr_scheduler_kwargs: dict = None,
optimizer_kwargs: dict = None,
save_path: str = None,
from_checkpoint: str = None,
load_prev_scheduler: bool = False,
save_weights_only: bool = True,):
"""
Args:
lr_scheduler: Kwargs for lr_scheduler
optimizer_kwargs: Kwargs for optimizer_kwargs
save_path: Save trained model
from_checkpoint: Load model from checkpoint
load_prev_scheduler: Whether load previous scheduler from checkpoint
load_strict: Whether load model strictly
save_weights_only: Whether save only weights or also optimizer and lr_scheduler
"""
super().__init__()
self.initialize_model()
self.metrics = {}
for stage in ["train", "valid", "test"]:
stage_metrics = self.initialize_metrics(stage)
            # Register metrics as attributes
for metric_name, metric in stage_metrics.items():
setattr(self, metric_name, metric)
self.metrics[stage] = stage_metrics
if lr_scheduler_kwargs is None:
# Default lr_scheduler
self.lr_scheduler_kwargs = {
"class": "ConstantLRScheduler",
"init_lr": 0,
}
print("No lr_scheduler_kwargs provided. The default learning rate is 0.")
else:
self.lr_scheduler_kwargs = lr_scheduler_kwargs
if optimizer_kwargs is None:
# Default optimizer
self.optimizer_kwargs = {
"class": "AdamW",
"betas": (0.9, 0.98),
"weight_decay": 0.01,
}
print("No optimizer_kwargs provided. The default optimizer is AdamW.")
else:
self.optimizer_kwargs = optimizer_kwargs
self.init_optimizers()
self.save_path = save_path
self.save_weights_only = save_weights_only
# temp_step is used for accumulating gradients
self.temp_step = 0
self.step = 0
self.epoch = 0
self.load_prev_scheduler = load_prev_scheduler
self.from_checkpoint = from_checkpoint
if from_checkpoint:
self.load_checkpoint(from_checkpoint)
@abc.abstractmethod
def initialize_model(self) -> None:
"""
All model initialization should be done here
Note that the whole model must be named as "self.model" for model saving and loading
"""
raise NotImplementedError
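
    # A minimal subclass sketch (hypothetical module; any torch.nn.Module assigned to self.model works):
    #   def initialize_model(self):
    #       self.model = torch.nn.Linear(128, 2)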
@abc.abstractmethod
def forward(self, *args, **kwargs):
"""
Forward propagation
"""
raise NotImplementedError
@abc.abstractmethod
def initialize_metrics(self, stage: str) -> dict:
"""
Initialize metrics for each stage
Args:
stage: "train", "valid" or "test"
Returns:
A dictionary of metrics for the stage. Keys are metric names and values are metric objects
"""
raise NotImplementedError
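
    # A possible implementation sketch (assumes the torchmetrics package; the metric name is an example):
    #   def initialize_metrics(self, stage):
    #       return {f"{stage}_acc": torchmetrics.Accuracy(task="binary")}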
@abc.abstractmethod
def loss_func(self, stage: str, outputs, labels) -> torch.Tensor:
"""
Args:
stage: "train", "valid" or "test"
outputs: model outputs for calculating loss
labels: labels for calculating loss
Returns:
loss
"""
raise NotImplementedError
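
    # A possible implementation sketch (cross-entropy is only an illustrative loss):
    #   def loss_func(self, stage, outputs, labels):
    #       loss = torch.nn.functional.cross_entropy(outputs, labels)
    #       for metric in self.metrics[stage].values():
    #           metric.update(outputs, labels)
    #       return loss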
@staticmethod
def load_weights(model, weights):
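        """Non-strictly load ``weights`` into ``model``, printing any missing or unused parameter keys."""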
model_dict = model.state_dict()
unused_params = []
missed_params = list(model_dict.keys())
for k, v in weights.items():
if k in model_dict.keys():
model_dict[k] = v
missed_params.remove(k)
else:
unused_params.append(k)
if len(missed_params) > 0:
print(f"\033[31mSome weights of {type(model).__name__} were not "
f"initialized from the model checkpoint: {missed_params}\033[0m")
if len(unused_params) > 0:
print(f"\033[31mSome weights of the model checkpoint were not used: {unused_params}\033[0m")
model.load_state_dict(model_dict)
def optimizer_step(
self,
epoch: int,
batch_idx: int,
optimizer,
optimizer_closure=None,
) -> None:
super().optimizer_step(epoch, batch_idx, optimizer, optimizer_closure)
self.temp_step += 1
if self.temp_step == self.trainer.accumulate_grad_batches:
self.step += 1
self.temp_step = 0
# For pytorch-lightning 1.9.5
# def optimizer_step(
# self,
# epoch: int,
# batch_idx: int,
# optimizer,
# optimizer_idx: int = 0,
# optimizer_closure=None,
# on_tpu: bool = False,
# using_native_amp: bool = False,
# using_lbfgs: bool = False,
# ) -> None:
# super().optimizer_step(
# epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs
# )
# self.temp_step += 1
# if self.temp_step == self.trainer.accumulate_grad_batches:
# self.step += 1
# self.temp_step = 0
def on_train_epoch_end(self):
self.epoch += 1
def training_step(self, batch, batch_idx):
inputs, labels = batch
outputs = self(**inputs)
loss = self.loss_func('train', outputs, labels)
self.log("loss", loss, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
inputs, labels = batch
outputs = self(**inputs)
loss = self.loss_func('valid', outputs, labels)
self.valid_outputs.append(loss)
return loss
def test_step(self, batch, batch_idx):
inputs, labels = batch
outputs = self(**inputs)
loss = self.loss_func('test', outputs, labels)
self.test_outputs.append(loss)
return loss
def on_train_start(self) -> None:
# Load previous scheduler
if getattr(self, "prev_schechuler", None) is not None:
try:
self.step = self.prev_schechuler["global_step"]
self.epoch = self.prev_schechuler["epoch"]
self.best_value = self.prev_schechuler["best_value"]
self.lr_scheduler.load_state_dict(self.prev_schechuler["lr_scheduler"])
print(f"Previous training global step: {self.step}")
print(f"Previous training epoch: {self.epoch}")
print(f"Previous best value: {self.best_value}")
print(f"Previous lr_scheduler: {self.prev_schechuler['lr_scheduler']}")
# Load optimizer state
if hasattr(self.trainer.strategy, "deepspeed_engine"):
# For DeepSpeed strategy
try:
self.trainer.strategy.deepspeed_engine.load_checkpoint(self.from_checkpoint)
except Exception as e:
print(e)
else:
# For DDP strategy
                    self.optimizer.load_state_dict(self.prev_scheduler["optimizer"])
except Exception as e:
print(e)
raise Exception("Error in loading previous scheduler. Please set load_prev_scheduler=False")
def on_validation_epoch_start(self) -> None:
setattr(self, "valid_outputs", [])
def on_test_epoch_start(self) -> None:
setattr(self, "test_outputs", [])
def load_checkpoint(self, from_checkpoint: str) -> None:
"""
Args:
from_checkpoint: Path to checkpoint.
"""
# If ``from_checkpoint`` is a directory, load the checkpoint in it
if os.path.isdir(from_checkpoint):
basename = os.path.basename(from_checkpoint)
from_checkpoint = os.path.join(from_checkpoint, f"{basename}.pt")
state_dict = torch.load(from_checkpoint, map_location=self.device)
self.load_weights(self.model, state_dict["model"])
if self.load_prev_scheduler:
state_dict.pop("model")
            self.prev_scheduler = state_dict
def save_checkpoint(self, save_path: str, save_info: dict = None, save_weights_only: bool = True) -> None:
"""
Save model to save_path
Args:
save_path: Path to save model
save_info: Other info to save
save_weights_only: Whether only save model weights
"""
        save_dir = os.path.dirname(save_path)
        os.makedirs(save_dir, exist_ok=True)
state_dict = {} if save_info is None else save_info
state_dict["model"] = self.model.state_dict()
# Convert model weights to fp32
for k, v in state_dict["model"].items():
state_dict["model"][k] = v.float()
if not save_weights_only:
state_dict["global_step"] = self.step
state_dict["epoch"] = self.epoch
state_dict["best_value"] = getattr(self, f"best_value", None)
state_dict["lr_scheduler"] = self.lr_schedulers().state_dict()
# If not using DeepSpeed, save optimizer state
if not hasattr(self.trainer.strategy, "deepspeed_engine"):
state_dict["optimizer"] = self.optimizers().optimizer.state_dict()
torch.save(state_dict, save_path)
def check_save_condition(self, now_value: float, mode: str, save_info: dict = None) -> None:
"""
Check whether to save model. If save_path is not None and now_value is the best, save model.
Args:
now_value: Current metric value
mode: "min" or "max", meaning whether the lower the better or the higher the better
save_info: Other info to save
"""
assert mode in ["min", "max"], "mode should be 'min' or 'max'"
if self.save_path is not None:
# In case there are variables to be included in the save path
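            # e.g. a template like save_path = "ckpt/epoch_{self.epoch}/model.pt" (hypothetical) is expanded here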
save_path = eval(f"f'{self.save_path}'")
            save_dir = os.path.dirname(save_path)
            os.makedirs(save_dir, exist_ok=True)
# Check whether to save model
best_value = getattr(self, f"best_value", None)
if best_value is not None:
if mode == "min" and now_value >= best_value or mode == "max" and now_value <= best_value:
return
setattr(self, "best_value", now_value)
# For DeepSpeed strategy
if hasattr(self.trainer.strategy, "deepspeed_engine"):
if not self.save_weights_only:
self.trainer.strategy.deepspeed_engine.save_checkpoint(save_path, tag="deepspeed_ckpt")
# Save a complete checkpoint
if dist.get_rank() == 0:
basename = os.path.basename(save_path)
ckpt_path = os.path.join(save_path, f"{basename}.pt")
self.save_checkpoint(ckpt_path, save_info, self.save_weights_only)
# For normal situation
else:
if dist.get_rank() == 0:
self.save_checkpoint(save_path, save_info, self.save_weights_only)
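
    # Typical usage sketch (hypothetical call site, e.g. from a subclass's validation hook):
    #   self.check_save_condition(valid_loss, mode="min")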
def reset_metrics(self, stage) -> None:
"""
Reset metrics for given stage
Args:
stage: "train", "valid" or "test"
"""
for metric in self.metrics[stage].values():
metric.reset()
def get_log_dict(self, stage: str) -> dict:
"""
Get log dict for the stage
Args:
stage: "train", "valid" or "test"
Returns:
A dictionary of metrics for the stage. Keys are metric names and values are metric values
"""
return {name: metric.compute() for name, metric in self.metrics[stage].items()}
def log_info(self, info: dict) -> None:
"""
Record metrics during training and testing
Args:
info: dict of metrics
"""
if getattr(self, "logger", None) is not None and dist.get_rank() == 0:
info["learning_rate"] = self.lr_scheduler.get_last_lr()[0]
info["epoch"] = self.epoch
self.logger.log_metrics(info, step=self.step)
def init_optimizers(self):
copy_optimizer_kwargs = copy.deepcopy(self.optimizer_kwargs)
# No decay for layer norm and bias
no_decay = ['LayerNorm.weight', 'bias']
weight_decay = copy_optimizer_kwargs.pop("weight_decay")
optimizer_grouped_parameters = [
{'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
'weight_decay': weight_decay},
{'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
'weight_decay': 0.0}
]
optimizer_cls = eval(f"torch.optim.{copy_optimizer_kwargs.pop('class')}")
self.optimizer = optimizer_cls(optimizer_grouped_parameters,
lr=self.lr_scheduler_kwargs['init_lr'],
**copy_optimizer_kwargs)
tmp_kwargs = copy.deepcopy(self.lr_scheduler_kwargs)
lr_scheduler = tmp_kwargs.pop("class")
self.lr_scheduler = eval(lr_scheduler)(self.optimizer, **tmp_kwargs)
def configure_optimizers(self):
return {"optimizer": self.optimizer,
"lr_scheduler": {"scheduler": self.lr_scheduler,
"interval": "step",
"frequency": 1}
}