# Copyright (c) OpenMMLab. All rights reserved.
import logging
from abc import ABCMeta, abstractmethod

import torch.nn as nn

from .utils import load_checkpoint
# from mmcv_custom.checkpoint import load_checkpoint


class BaseBackbone(nn.Module, metaclass=ABCMeta):
    """Base backbone.
    This class defines the basic functions of a backbone. Any backbone that
    inherits this class should at least define its own `forward` function.
    """

    def init_weights(self, pretrained=None, patch_padding='pad'):
        """Init backbone weights.
        Args:
            pretrained (str | None): If pretrained is a string, then it
                initializes backbone weights by loading the pretrained
                checkpoint. If pretrained is None, then it follows default
                initializer or customized initializer in subclasses.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger,
                            patch_padding=patch_padding)
        elif pretrained is None:
            # use default initializer or customized initializer in subclasses
            pass
        else:
            raise TypeError('pretrained must be a str or None.'
                            f' But received {type(pretrained)}.')

    @abstractmethod
    def forward(self, x):
        """Forward function.
        Args:
            x (Tensor | tuple[Tensor]): x could be a torch.Tensor or a tuple of
                torch.Tensor, containing input data for forward computation.
        """