# Copyright (c) OpenMMLab. All rights reserved.
import logging
from abc import ABCMeta, abstractmethod

import torch.nn as nn

from .utils import load_checkpoint

# from mmcv_custom.checkpoint import load_checkpoint


class BaseBackbone(nn.Module, metaclass=ABCMeta):
    """Base backbone.

    This class defines the basic functions of a backbone. Any backbone that
    inherits this class should at least define its own `forward` function.
    """

    def init_weights(self, pretrained=None, patch_padding='pad'):
        """Init backbone weights.

        Args:
            pretrained (str | None): If pretrained is a string, then it
                initializes backbone weights by loading the pretrained
                checkpoint. If pretrained is None, then it follows default
                initializer or customized initializer in subclasses.
            patch_padding (str): Passed through unchanged to
                ``load_checkpoint`` when ``pretrained`` is a string; ignored
                otherwise. Default: 'pad'. NOTE(review): the accepted values
                are defined by ``.utils.load_checkpoint`` — confirm there.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            # strict=False: keys missing from / unexpected in the checkpoint
            # are tolerated rather than treated as errors.
            load_checkpoint(
                self,
                pretrained,
                strict=False,
                logger=logger,
                patch_padding=patch_padding)
        elif pretrained is None:
            # use default initializer or customized initializer in subclasses
            pass
        else:
            raise TypeError('pretrained must be a str or None.'
                            f' But received {type(pretrained)}.')

    @abstractmethod
    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor | tuple[Tensor]): x could be a torch.Tensor or a tuple of
                torch.Tensor, containing input data for forward computation.
        """