Spaces:
Build error
Build error
# Copyright (c) OpenMMLab. All rights reserved.
import logging | |
from abc import ABCMeta, abstractmethod | |
import torch.nn as nn | |
# from .utils import load_checkpoint
from mmcv_custom.checkpoint import load_checkpoint | |
class BaseBackbone(nn.Module, metaclass=ABCMeta):
    """Base backbone.

    This class defines the basic functions of a backbone. Any backbone that
    inherits this class should at least define its own `forward` function.
    Because `forward` is declared abstract, a subclass that does not override
    it cannot be instantiated.
    """

    def init_weights(self, pretrained=None, patch_padding='pad', part_features=None):
        """Init backbone weights.

        Args:
            pretrained (str | None): If pretrained is a string, then it
                initializes backbone weights by loading the pretrained
                checkpoint. If pretrained is None, then it follows default
                initializer or customized initializer in subclasses.
            patch_padding: Forwarded unchanged to ``load_checkpoint`` as the
                ``patch_padding`` keyword. Defaults to ``'pad'``. Its meaning
                is defined by ``load_checkpoint``, not by this class.
            part_features: Forwarded unchanged to ``load_checkpoint`` as the
                ``part_features`` keyword. Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if pretrained is None:
            # use default initializer or customized initializer in subclasses
            return

        if not isinstance(pretrained, str):
            raise TypeError('pretrained must be a str or None.'
                            f' But received {type(pretrained)}.')

        # The root logger is handed to the loader, which is called with
        # strict=False.
        logger = logging.getLogger()
        load_checkpoint(
            self,
            pretrained,
            strict=False,
            logger=logger,
            patch_padding=patch_padding,
            part_features=part_features)

    @abstractmethod
    def forward(self, x):
        """Forward function.

        Subclasses must override this method.

        Args:
            x (Tensor | tuple[Tensor]): x could be a torch.Tensor or a tuple of
                torch.Tensor, containing input data for forward computation.
        """