Spaces:
Runtime error
Runtime error
File size: 1,601 Bytes
1c3eb47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List, Optional, Tuple
from mmengine.model import BaseModule
from mmengine.structures import BaseDataElement
class BaseHead(BaseModule, metaclass=ABCMeta):
    """Base head.

    Abstract interface for heads: subclasses must implement ``loss`` and
    ``predict``. Because both are declared with ``@abstractmethod`` under
    ``ABCMeta``, a subclass that does not override them cannot be
    instantiated.

    Args:
        init_cfg (dict, optional): The extra init config of layers, passed
            straight through to ``BaseModule.__init__``. Defaults to None.
    """

    def __init__(self, init_cfg: Optional[dict] = None):
        # Zero-argument super() is equivalent to super(BaseHead, self) here
        # and keeps working if the class is ever renamed.
        super().__init__(init_cfg=init_cfg)

    @abstractmethod
    def loss(self, feats: Tuple,
             data_samples: List[BaseDataElement]) -> dict:
        """Calculate losses from the extracted features.

        Args:
            feats (tuple): The features extracted from the backbone.
            data_samples (List[BaseDataElement]): The annotation data of
                each sample.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """

    @abstractmethod
    def predict(
        self,
        feats: Tuple,
        data_samples: Optional[List[BaseDataElement]] = None,
    ) -> List[BaseDataElement]:
        """Predict results from the extracted features.

        Args:
            feats (tuple): The features extracted from the backbone.
            data_samples (List[BaseDataElement], optional): The annotation
                data of each sample. If not None, implementations are
                expected to set ``pred_label`` on these input data samples.
                Defaults to None.

        Returns:
            List[BaseDataElement]: A list of data samples that contain the
            predicted results.
        """
|