|
|
|
from ..builder import CLASSIFIERS |
|
from ..heads import MultiLabelClsHead |
|
from .image import ImageClassifier |
|
|
|
|
|
@CLASSIFIERS.register_module()
class MetadataClassifier(ImageClassifier):
    """Image classifier whose head also receives per-image metadata.

    Behaves like its ``ImageClassifier`` parent for feature extraction, but
    forwards ``img_metas`` to the head in both training and testing.
    """

    def forward_train(self, img, gt_label, img_metas, **kwargs):
        """Compute the training losses for one batch.

        Args:
            img (Tensor): Input images of shape (N, C, H, W), typically
                mean-centered and std-scaled.
            gt_label (Tensor): Ground-truth labels. Shape (N, 1) for
                single-label tasks, or (N, C) for multi-label tasks.
            img_metas: Per-image metadata, passed unchanged to
                ``self.head.forward_train``. Presumably one entry per image
                -- confirm against the data pipeline.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            dict[str, Tensor]: The loss components produced by the head.
        """
        if self.augments is not None:
            # NOTE(review): only img and gt_label pass through the batch
            # augments; img_metas is forwarded as-is. If the augments mix
            # samples across the batch, confirm the head tolerates metadata
            # that no longer lines up with the augmented images.
            img, gt_label = self.augments(img, gt_label)

        feats = self.extract_feat(img)
        head_losses = self.head.forward_train(feats, gt_label, img_metas)
        # Hand back a fresh dict rather than the head's own object.
        return dict(head_losses)

    def simple_test(self, img, img_metas=None, **kwargs):
        """Run inference on a batch without test-time augmentation.

        Args:
            img (Tensor): Input images; presumably (N, C, H, W) as in
                ``forward_train`` -- confirm with the caller.
            img_metas: Per-image metadata forwarded to the head.
                Defaults to None.
            **kwargs: Extra options forwarded to ``self.head.simple_test``.

        Returns:
            Whatever ``self.head.simple_test`` returns.

        Raises:
            AssertionError: If the head is a ``MultiLabelClsHead`` and a
                ``softmax`` option was supplied in ``kwargs``.
        """
        feats = self.extract_feat(img)

        multi_label = isinstance(self.head, MultiLabelClsHead)
        if multi_label:
            # Multi-label heads score classes independently, so the caller
            # should ask for sigmoid rather than softmax.
            assert 'softmax' not in kwargs, (
                'Please use `sigmoid` instead of `softmax` '
                'in multi-label tasks.')

        return self.head.simple_test(feats, img_metas, **kwargs)
|
|