mueller-franzes's picture
init
f85e212
raw
history blame
4.34 kB
from pathlib import Path
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.migration import pl_legacy_patch
class VeryBasicModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.save_hyperparameters()
self._step_train = 0
self._step_val = 0
self._step_test = 0
def forward(self, x_in):
raise NotImplementedError
def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int):
raise NotImplementedError
def training_step(self, batch: dict, batch_idx: int, optimizer_idx:int = 0 ):
self._step_train += 1 # =self.global_step
return self._step(batch, batch_idx, "train", self._step_train, optimizer_idx)
def validation_step(self, batch: dict, batch_idx: int, optimizer_idx:int = 0):
self._step_val += 1
return self._step(batch, batch_idx, "val", self._step_val, optimizer_idx )
def test_step(self, batch: dict, batch_idx: int, optimizer_idx:int = 0):
self._step_test += 1
return self._step(batch, batch_idx, "test", self._step_test, optimizer_idx)
def _epoch_end(self, outputs: list, state: str):
return
def training_epoch_end(self, outputs):
self._epoch_end(outputs, "train")
def validation_epoch_end(self, outputs):
self._epoch_end(outputs, "val")
def test_epoch_end(self, outputs):
self._epoch_end(outputs, "test")
@classmethod
def save_best_checkpoint(cls, path_checkpoint_dir, best_model_path):
with open(Path(path_checkpoint_dir) / 'best_checkpoint.json', 'w') as f:
json.dump({'best_model_epoch': Path(best_model_path).name}, f)
@classmethod
def _get_best_checkpoint_path(cls, path_checkpoint_dir, version=0, **kwargs):
path_version = 'lightning_logs/version_'+str(version)
with open(Path(path_checkpoint_dir) / path_version/ 'best_checkpoint.json', 'r') as f:
path_rel_best_checkpoint = Path(json.load(f)['best_model_epoch'])
return Path(path_checkpoint_dir)/path_rel_best_checkpoint
@classmethod
def load_best_checkpoint(cls, path_checkpoint_dir, version=0, **kwargs):
path_best_checkpoint = cls._get_best_checkpoint_path(path_checkpoint_dir, version)
return cls.load_from_checkpoint(path_best_checkpoint, **kwargs)
def load_pretrained(self, checkpoint_path, map_location=None, **kwargs):
if checkpoint_path.is_dir():
checkpoint_path = self._get_best_checkpoint_path(checkpoint_path, **kwargs)
with pl_legacy_patch():
if map_location is not None:
checkpoint = pl_load(checkpoint_path, map_location=map_location)
else:
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
return self.load_weights(checkpoint["state_dict"], **kwargs)
def load_weights(self, pretrained_weights, strict=True, **kwargs):
filter = kwargs.get('filter', lambda key:key in pretrained_weights)
init_weights = self.state_dict()
pretrained_weights = {key: value for key, value in pretrained_weights.items() if filter(key)}
init_weights.update(pretrained_weights)
self.load_state_dict(init_weights, strict=strict)
return self
class BasicModel(VeryBasicModel):
def __init__(self,
optimizer=torch.optim.AdamW,
optimizer_kwargs={'lr':1e-3, 'weight_decay':1e-2},
lr_scheduler= None,
lr_scheduler_kwargs={},
):
super().__init__()
self.save_hyperparameters()
self.optimizer = optimizer
self.optimizer_kwargs = optimizer_kwargs
self.lr_scheduler = lr_scheduler
self.lr_scheduler_kwargs = lr_scheduler_kwargs
def configure_optimizers(self):
optimizer = self.optimizer(self.parameters(), **self.optimizer_kwargs)
if self.lr_scheduler is not None:
lr_scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)
return [optimizer], [lr_scheduler]
else:
return [optimizer]