from pathlib import Path
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.migration import pl_legacy_patch
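
# A minimal LightningModule base class: all train/val/test steps are routed
# through a single `_step` hook (implemented by subclasses together with
# `forward`), plus helpers for saving, locating, and loading the best
# checkpoint by filename.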
class VeryBasicModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self._step_train = 0
        self._step_val = 0
        self._step_test = 0

    def forward(self, x_in):
        raise NotImplementedError

    def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx: int):
        raise NotImplementedError

    def training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0):
        self._step_train += 1  # = self.global_step
        return self._step(batch, batch_idx, "train", self._step_train, optimizer_idx)

    def validation_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0):
        self._step_val += 1
        return self._step(batch, batch_idx, "val", self._step_val, optimizer_idx)

    def test_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0):
        self._step_test += 1
        return self._step(batch, batch_idx, "test", self._step_test, optimizer_idx)

    def _epoch_end(self, outputs: list, state: str):
        return

    def training_epoch_end(self, outputs):
        self._epoch_end(outputs, "train")

    def validation_epoch_end(self, outputs):
        self._epoch_end(outputs, "val")

    def test_epoch_end(self, outputs):
        self._epoch_end(outputs, "test")

    @classmethod
    def save_best_checkpoint(cls, path_checkpoint_dir, best_model_path):
        # Record the filename of the best checkpoint so it can be recovered later.
        with open(Path(path_checkpoint_dir) / 'best_checkpoint.json', 'w') as f:
            json.dump({'best_model_epoch': Path(best_model_path).name}, f)

    @classmethod
    def _get_best_checkpoint_path(cls, path_checkpoint_dir, version=0, **kwargs):
        path_version = 'lightning_logs/version_' + str(version)
        with open(Path(path_checkpoint_dir) / path_version / 'best_checkpoint.json', 'r') as f:
            path_rel_best_checkpoint = Path(json.load(f)['best_model_epoch'])
        return Path(path_checkpoint_dir) / path_rel_best_checkpoint

    @classmethod
    def load_best_checkpoint(cls, path_checkpoint_dir, version=0, **kwargs):
        path_best_checkpoint = cls._get_best_checkpoint_path(path_checkpoint_dir, version)
        return cls.load_from_checkpoint(path_best_checkpoint, **kwargs)

    def load_pretrained(self, checkpoint_path, map_location=None, **kwargs):
        checkpoint_path = Path(checkpoint_path)  # accept str or Path
        if checkpoint_path.is_dir():
            checkpoint_path = self._get_best_checkpoint_path(checkpoint_path, **kwargs)
        with pl_legacy_patch():
            if map_location is not None:
                checkpoint = pl_load(checkpoint_path, map_location=map_location)
            else:
                checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
        return self.load_weights(checkpoint["state_dict"], **kwargs)

    def load_weights(self, pretrained_weights, strict=True, **kwargs):
        # Merge the (optionally filtered) pretrained entries into the current state dict.
        key_filter = kwargs.get('filter', lambda key: key in pretrained_weights)
        init_weights = self.state_dict()
        pretrained_weights = {key: value for key, value in pretrained_weights.items() if key_filter(key)}
        init_weights.update(pretrained_weights)
        self.load_state_dict(init_weights, strict=strict)
        return self
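
# --- Illustration only: a minimal (hypothetical) subclass sketch showing how
# the `forward`/`_step` hooks above are meant to be filled in. The layer
# sizes and the 'source'/'target' batch keys are assumptions, not part of
# the base class API.
class ToyRegressor(VeryBasicModel):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))

    def forward(self, x_in):
        return self.net(x_in)

    def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx: int):
        pred = self(batch['source'])
        loss = F.mse_loss(pred, batch['target'])
        self.log(f"{state}/loss", loss)  # one metric per train/val/test phase
        return loss

    def configure_optimizers(self):
        # VeryBasicModel leaves optimizer setup to the subclass (BasicModel
        # below provides it generically).
        return torch.optim.AdamW(self.parameters(), lr=1e-3)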
class BasicModel(VeryBasicModel):
    def __init__(self,
                 optimizer=torch.optim.AdamW,
                 optimizer_kwargs={'lr': 1e-3, 'weight_decay': 1e-2},
                 lr_scheduler=None,
                 lr_scheduler_kwargs={},
                 ):
        super().__init__()
        self.save_hyperparameters()
        self.optimizer = optimizer
        self.optimizer_kwargs = optimizer_kwargs
        self.lr_scheduler = lr_scheduler
        self.lr_scheduler_kwargs = lr_scheduler_kwargs

    def configure_optimizers(self):
        optimizer = self.optimizer(self.parameters(), **self.optimizer_kwargs)
        if self.lr_scheduler is not None:
            lr_scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)
            return [optimizer], [lr_scheduler]
        else:
            return [optimizer]
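
# --- Illustration only: how the classes above are typically wired together.
# `ToyRegressor` is the sketch subclass from above; the directories and the
# scheduler choice are assumptions. A BasicModel subclass would not need its
# own configure_optimizers and could instead take, e.g.,
# lr_scheduler=torch.optim.lr_scheduler.CosineAnnealingLR,
# lr_scheduler_kwargs={'T_max': 10}.
#
#   model = ToyRegressor()
#   trainer = pl.Trainer(max_epochs=10, default_root_dir='runs')
#   trainer.fit(model, train_dataloaders=..., val_dataloaders=...)
#
#   # Persist the best checkpoint's filename, then recover it later; the
#   # directory passed at save time must match the lightning_logs/version_<n>
#   # layout that _get_best_checkpoint_path reads from.
#   ToyRegressor.save_best_checkpoint('runs/lightning_logs/version_0',
#                                     trainer.checkpoint_callback.best_model_path)
#   best = ToyRegressor.load_best_checkpoint('runs', version=0)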