from abc import abstractmethod
from typing import Union

import numpy as np
import torch.nn as nn
from torch import Tensor


class BaseModel(nn.Module):
    """
    Base class for all models.

    Subclasses are expected to override :meth:`forward` and
    :meth:`transform_input_lengths`; both raise ``NotImplementedError`` here.
    """

    def __init__(self, n_feats, n_class, **batch):
        """
        Initialise the underlying ``nn.Module``.

        :param n_feats: accepted for interface uniformity; not used by the
            base class itself.
        :param n_class: accepted for interface uniformity; not used by the
            base class itself.
        :param batch: extra keyword arguments; ignored by the base class.
        """
        super().__init__()

    # NOTE: nn.Module does not use ABCMeta, so this decorator is only a
    # marker for readers -- the NotImplementedError below is what actually
    # fires when a subclass does not override forward().
    @abstractmethod
    def forward(self, **batch) -> Union[Tensor, dict]:
        """
        Forward pass logic.
        Can return a torch.Tensor (it will be interpreted as logits) or a dict.

        :param batch: keyword arguments forwarded to the model.
        :return: Model output
        :raises NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    def __str__(self):
        """
        Model prints with number of trainable parameters.

        :return: the standard ``nn.Module`` string representation followed by
            a line with the count of parameters that have
            ``requires_grad=True``.
        """
        # numel() yields an exact Python int for every shape, including
        # 0-dim (scalar) parameters; np.prod over an empty size would return
        # the float 1.0 and turn the whole total into a float.
        params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return super().__str__() + "\nTrainable parameters: {}".format(params)

    def transform_input_lengths(self, input_lengths):
        """
        Input length transformation function.
        For example: if your NN transforms spectrogram of time-length `N` into an
        output with time-length `N / 2`, then this function should return
        `input_lengths // 2`.

        :param input_lengths: lengths of the inputs before the model is applied.
        :return: the corresponding output lengths (defined by subclasses).
        :raises NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError()