|
from abc import ABC, abstractmethod

from typing import Union



import numpy as np

import torch.nn as nn

from torch import Tensor
|
|
|
|
|
class BaseModel(nn.Module, ABC):
    """
    Abstract base class for all models.

    Inheriting from ``ABC`` (in addition to ``nn.Module``) makes the
    ``@abstractmethod`` decorator on ``forward`` actually enforced:
    ``BaseModel`` itself cannot be instantiated, and subclasses must
    implement ``forward``.
    """

    def __init__(self, n_feats, n_class, **batch):
        """
        :param n_feats: number of input features (stored by subclasses as needed)
        :param n_class: number of output classes (stored by subclasses as needed)
        :param batch: extra keyword arguments, ignored here so subclasses can
            forward a whole batch dict to ``super().__init__``
        """
        super().__init__()

    @abstractmethod
    def forward(self, **batch) -> Union[Tensor, dict]:
        """
        Forward pass logic.

        Can return a torch.Tensor (it will be interpreted as logits) or a dict.

        :return: Model output
        """
        raise NotImplementedError()

    def __str__(self):
        """
        Model prints with number of trainable parameters.

        Uses ``Tensor.numel()`` rather than ``np.prod(p.size())`` so that
        0-dim (scalar) parameters count as 1 (an ``int``) instead of the
        float ``1.0`` that ``np.prod`` of an empty shape would yield.
        """
        trainable = (p for p in self.parameters() if p.requires_grad)
        n_params = sum(p.numel() for p in trainable)
        return super().__str__() + f"\nTrainable parameters: {n_params}"

    def transform_input_lengths(self, input_lengths):
        """
        Input length transformation function.

        For example: if your NN transforms spectrogram of time-length `N` into an
        output with time-length `N / 2`, then this function should return `input_lengths // 2`

        :raises NotImplementedError: must be overridden by subclasses that
            change the time dimension of their input
        """
        raise NotImplementedError()
|
|