|
from abc import abstractmethod |
|
from typing import Any, Dict, List |
|
|
|
import torch |
|
from pytorch_lightning import LightningModule |
|
from torch import Tensor |
|
|
|
|
|
class LightningRegression(LightningModule):
    """Abstract base ``LightningModule`` for regression tasks.

    Subclasses implement ``forward``, ``configure_optimizers``, ``loss``,
    ``training_step`` and ``validation_step``.  The step methods are expected
    to append a per-step metric dict (with the keys listed in
    ``log_value_list``: 'loss', 'mse', 'mape') to ``self.train_step_output``
    or ``self.validation_step_output``; this base class averages those
    metrics at the end of each epoch, logs them, and resets the accumulators.
    """

    @abstractmethod
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Per-step metric dicts collected during the *current* epoch; cleared
        # by the epoch-end hooks after aggregation (see below).
        self.train_step_output: List[Dict] = []
        self.validation_step_output: List[Dict] = []
        # Metric keys every step-output dict is expected to contain.
        self.log_value_list: List[str] = ['loss', 'mse', 'mape']

    @abstractmethod
    def forward(self, *args, **kwargs) -> Any:
        """Run the model; must be provided by the subclass."""
        pass

    @abstractmethod
    def configure_optimizers(self):
        """Return the optimizer(s) (and optional schedulers); subclass hook."""
        pass

    @abstractmethod
    def loss(self, input: Tensor, output: Tensor, **kwargs):
        """Compute the loss for one step; subclass hook (placeholder 0)."""
        return 0

    @abstractmethod
    def training_step(self, batch, batch_idx):
        """One training step; should append metrics to ``train_step_output``."""
        pass

    def __average(self, key: str, outputs: List[Dict]) -> Tensor:
        """Return the mean of ``step[key]`` over all collected step outputs.

        NOTE(review): yields NaN for an empty ``outputs`` list — the hooks
        below assume at least one step ran during the epoch.
        """
        values = torch.tensor([step[key] for step in outputs],
                              dtype=torch.float32)
        return values.mean()

    @torch.no_grad()
    def on_train_epoch_end(self) -> None:
        """Log epoch-averaged training metrics, then reset the accumulator."""
        for key in self.log_value_list:
            val = self.__average(key=key, outputs=self.train_step_output)
            self.log(name=f"training/{key}", value=val)
        # BUG FIX: without clearing, step outputs accumulate across epochs,
        # so later epoch averages would mix in all previous epochs' steps
        # (and the list would grow unboundedly).
        self.train_step_output.clear()

    @torch.no_grad()
    @abstractmethod
    def validation_step(self, batch, batch_idx):
        """One validation step; should append metrics to ``validation_step_output``."""
        pass

    @torch.no_grad()
    def validation_epoch_end(self, outputs):
        """Log epoch-averaged validation metrics, then reset the accumulator.

        NOTE(review): ``validation_epoch_end`` is the Lightning 1.x hook and
        was removed in 2.0 (use ``on_validation_epoch_end``); kept as-is here
        to avoid changing the framework-facing interface — confirm the pinned
        pytorch_lightning version.
        """
        for key in self.log_value_list:
            val = self.__average(key=key, outputs=self.validation_step_output)
            # BUG FIX: was logged under "training/<key>", colliding with the
            # training metrics of the same name.
            self.log(name=f"validation/{key}", value=val)
        # BUG FIX: reset so the next epoch's average covers only its own steps.
        self.validation_step_output.clear()
|