import logging

import torch.nn as nn
import numpy as np


class BaseModel(nn.Module):
    """
    Base class for all models
    """
    def __init__(self, config):
        super(BaseModel, self).__init__()
        self.config = config
        self.logger = logging.getLogger(self.__class__.__name__)

    def forward(self, *input):
        """
        Forward pass logic

        :return: Model output
        """
        raise NotImplementedError

    def summary(self):
        """
        Model summary
        """
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        self.logger.info('Trainable parameters: {}'.format(params))
        self.logger.info(self)
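

# Illustrative sketch (not part of the original file): a minimal hypothetical
# subclass showing how BaseModel is intended to be extended. The class name,
# layer sizes, and config contents below are assumptions for demonstration only.
class ExampleModel(BaseModel):
    def __init__(self, config):
        super(ExampleModel, self).__init__(config)
        # A tiny fully connected network; dimensions are arbitrary placeholders.
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Only forward() needs to be overridden; summary() is inherited.
        x = nn.functional.relu(self.fc1(x))
        return self.fc2(x)


# Example usage (assuming a config dict is available):
#   model = ExampleModel(config={})
#   model.summary()  # logs the trainable parameter count and the module tree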