# Copyright (c) OpenMMLab. All rights reserved.
from ..builder import CLASSIFIERS
from ..heads import MultiLabelClsHead
from .image import ImageClassifier


@CLASSIFIERS.register_module()
class MetadataClassifier(ImageClassifier):
    """Image classifier that also forwards per-image metadata to its head.

    Behaves like :class:`ImageClassifier`, except that ``img_metas`` is
    passed to ``self.head.forward_train`` during training and to
    ``self.head.simple_test`` during testing. The head therefore has
    access to the metadata alongside the extracted image features.
    """

    def forward_train(self, img, gt_label, img_metas, **kwargs):
        """Forward computation during training.

        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            gt_label (Tensor): It should be of shape (N, 1) encoding the
                ground-truth label of input images for single label task.
                It should be of shape (N, C) encoding the ground-truth
                label of input images for multi-labels task.
            img_metas: Per-image metadata, passed through unchanged to
                ``self.head.forward_train``. Presumably a list with one
                entry per image -- TODO confirm against the data pipeline.
            **kwargs: Any extra inputs supplied by the caller. They are
                accepted so the caller can pass a full batch, but they are
                not used here.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # Batch augmentations rewrite images and labels together, so they
        # must run before feature extraction.
        # NOTE(review): img_metas is not passed to self.augments. If an
        # augmentation reorders or mixes samples, the metadata may no
        # longer line up with the augmented images -- verify.
        if self.augments is not None:
            img, gt_label = self.augments(img, gt_label)

        x = self.extract_feat(img)

        losses = dict()
        loss = self.head.forward_train(x, gt_label, img_metas)
        losses.update(loss)

        return losses

    def simple_test(self, img, img_metas=None, **kwargs):
        """Test without augmentation.

        Args:
            img (Tensor): Input images. Presumably of shape (N, C, H, W),
                as in :meth:`forward_train`.
            img_metas: Per-image metadata, forwarded positionally to
                ``self.head.simple_test``. Defaults to None.
            **kwargs: Extra test-time options forwarded to
                ``self.head.simple_test``. When the head is a
                ``MultiLabelClsHead``, a ``softmax`` option is rejected;
                use ``sigmoid`` instead.

        Returns:
            The value returned by ``self.head.simple_test``.
        """
        x = self.extract_feat(img)

        # Multi-label heads score each class independently, so a softmax
        # over classes is the wrong post-processing; fail fast here.
        if isinstance(self.head, MultiLabelClsHead):
            assert 'softmax' not in kwargs, (
                'Please use `sigmoid` instead of `softmax` '
                'in multi-label tasks.')

        res = self.head.simple_test(x, img_metas, **kwargs)
        return res