"""
This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.

Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/test_time_augmentation.py
"""

import copy
import logging
from itertools import count

import numpy as np
import torch
from fvcore.transforms import HFlipTransform
from torch import nn
from torch.nn.parallel import DistributedDataParallel

from detectron2.data.detection_utils import read_image
from detectron2.modeling import DatasetMapperTTA

__all__ = [
    "SemanticSegmentorWithTTA",
]


class SemanticSegmentorWithTTA(nn.Module):
    """
    A SemanticSegmentor with test-time augmentation enabled.
    Its :meth:`__call__` method has the same interface as :meth:`SemanticSegmentor.forward`.

    The wrapped model is run once per augmented view of each image and the
    resulting "sem_seg" predictions are averaged.
    """

    def __init__(self, cfg, model, tta_mapper=None, batch_size=1):
        """
        Args:
            cfg (CfgNode):
            model (SemanticSegmentor): a SemanticSegmentor to apply TTA on.
                A ``DistributedDataParallel`` wrapper is unwrapped to its ``.module``.
            tta_mapper (callable): takes a dataset dict and returns a list of
                augmented versions of the dataset dict. Defaults to
                `DatasetMapperTTA(cfg)`.
            batch_size (int): batch the augmented images into this batch size for inference.
                NOTE(review): stored on the instance but not read by any method
                visible in this class; augmented views are run one at a time.
        """
        super().__init__()
        if isinstance(model, DistributedDataParallel):
            model = model.module
        # Keep a private copy so later edits to the caller's cfg do not leak in.
        self.cfg = cfg.clone()

        self.model = model

        if tta_mapper is None:
            tta_mapper = DatasetMapperTTA(cfg)
        self.tta_mapper = tta_mapper
        self.batch_size = batch_size

    def __call__(self, batched_inputs):
        """
        Same input/output format as :meth:`SemanticSegmentor.forward`

        Args:
            batched_inputs (iterable[dict]): dataset dicts. Each must contain
                either an "image" tensor or a "file_name" to read it from.

        Returns:
            list[dict]: one ``{"sem_seg": tensor}`` dict per input, in order.
        """

        def _maybe_read_image(dataset_dict):
            # Shallow copy: keys are added/removed below without touching the
            # caller's dict.
            ret = copy.copy(dataset_dict)
            if "image" not in ret:
                image = read_image(ret.pop("file_name"), self.model.input_format)
                image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))  # CHW
                ret["image"] = image
            # Always take the shape from the dict so that a caller-supplied
            # "image" without "height"/"width" works too (the local `image`
            # only exists when the file was read above), and fill each key
            # independently so a dict carrying only one of them is completed.
            # Assumes "image" is CHW, as produced above.
            image = ret["image"]
            ret.setdefault("height", image.shape[1])
            ret.setdefault("width", image.shape[2])
            return ret

        return [self._inference_one_image(_maybe_read_image(x)) for x in batched_inputs]

    def _inference_one_image(self, input):
        """
        Run the model on every augmented view of one image and average the results.

        Args:
            input (dict): one dataset dict with "image" field being a CHW tensor

        Returns:
            dict: one output dict, ``{"sem_seg": averaged_prediction}``

        Raises:
            ValueError: if the TTA mapper produced no augmented inputs, so
                there is nothing to average.
        """
        augmented_inputs, tfms = self._get_augmented_inputs(input)

        summed_predictions = None
        num_predictions = 0
        with torch.no_grad():
            # Use a distinct loop name so the `input` argument is not rebound.
            for aug_input, tfm in zip(augmented_inputs, tfms):
                prediction = self.model([aug_input])[0].pop("sem_seg")
                if any(isinstance(t, HFlipTransform) for t in tfm.transforms):
                    # Undo the horizontal flip so all predictions are aligned
                    # before averaging (dim 2 is taken to be the width axis of
                    # a C x H x W prediction -- TODO confirm against the model).
                    prediction = prediction.flip(dims=[2])
                if summed_predictions is None:
                    summed_predictions = prediction
                else:
                    summed_predictions += prediction
                num_predictions += 1

        if summed_predictions is None:
            raise ValueError(
                "tta_mapper returned no augmented inputs; cannot compute a TTA prediction."
            )

        return {"sem_seg": summed_predictions / num_predictions}

    def _get_augmented_inputs(self, input):
        """
        Apply the TTA mapper to one dataset dict.

        Args:
            input (dict): one dataset dict.

        Returns:
            tuple: ``(augmented_inputs, tfms)`` where ``tfms[i]`` is the
            "transforms" entry popped from ``augmented_inputs[i]``.
        """
        augmented_inputs = self.tta_mapper(input)
        # The "transforms" entry is removed from each dict and returned separately.
        tfms = [x.pop("transforms") for x in augmented_inputs]
        return augmented_inputs, tfms