File size: 1,227 Bytes
affcd23 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
from abc import abstractmethod
from typing import Union
import numpy as np
import torch.nn as nn
from torch import Tensor
class BaseModel(nn.Module):
    """
    Base class for all models.

    Subclasses are expected to override `forward` and, when the network
    changes the time dimension of its input, `transform_input_lengths`.
    Both raise `NotImplementedError` here.
    """

    def __init__(self, n_feats, n_class, **batch):
        """
        Initialise the underlying `nn.Module`.

        :param n_feats: accepted for a uniform constructor signature; not used
            by the base class itself
        :param n_class: accepted for a uniform constructor signature; not used
            by the base class itself
        :param batch: any extra keyword arguments; accepted and ignored here
        """
        super().__init__()

    # NOTE(review): `nn.Module` is not declared with an ABC metaclass here, so
    # this decorator presumably does not block instantiation; the explicit
    # `raise` below is what signals a missing override. Confirm before relying
    # on either behaviour.
    @abstractmethod
    def forward(self, **batch) -> Union[Tensor, dict]:
        """
        Forward pass logic.
        Can return a torch.Tensor (it will be interpreted as logits) or a dict.

        :param batch: keyword arguments describing the input batch
        :return: Model output
        :raises NotImplementedError: always, unless overridden by a subclass
        """
        raise NotImplementedError()

    def __str__(self):
        """
        Model prints with number of trainable parameters.

        :return: the standard module representation followed by a line with
            the total element count of all parameters that require gradients
        """
        # `numel()` always yields a Python int, including for 0-dim (scalar)
        # parameters. Multiplying the entries of an empty shape with `np.prod`
        # yields the float 1.0 instead, which turned the whole total into a
        # float and printed it as e.g. "9.0".
        params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return super().__str__() + "\nTrainable parameters: {}".format(params)

    def transform_input_lengths(self, input_lengths):
        """
        Input length transformation function.
        For example: if your NN transforms spectrogram of time-length `N` into an
        output with time-length `N / 2`, then this function should return `input_lengths // 2`

        :param input_lengths: time-lengths of the inputs before the model is applied
        :return: time-lengths of the corresponding model outputs
        :raises NotImplementedError: always, unless overridden by a subclass
        """
        raise NotImplementedError()
|