Spaces:
Running
on
Zero
Running
on
Zero
""" | |
This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”). | |
All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates. | |
Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/test_time_augmentation.py | |
""" | |
import copy | |
import logging | |
from itertools import count | |
import numpy as np | |
import torch | |
from fvcore.transforms import HFlipTransform | |
from torch import nn | |
from torch.nn.parallel import DistributedDataParallel | |
from detectron2.data.detection_utils import read_image | |
from detectron2.modeling import DatasetMapperTTA | |
__all__ = [ | |
"SemanticSegmentorWithTTA", | |
] | |
class SemanticSegmentorWithTTA(nn.Module):
    """
    A SemanticSegmentor with test-time augmentation enabled.
    Its :meth:`__call__` method has the same interface as :meth:`SemanticSegmentor.forward`.

    For every input image, the TTA mapper produces a list of augmented copies. Each copy
    is run through the wrapped model on its own. Predictions from copies whose transforms
    include a horizontal flip are flipped back. All predictions are then averaged into a
    single ``"sem_seg"`` output.
    """

    def __init__(self, cfg, model, tta_mapper=None, batch_size=1):
        """
        Args:
            cfg (CfgNode): config. A clone is kept on ``self.cfg``. When ``tta_mapper``
                is None, it is also used to build the default mapper.
            model (SemanticSegmentor): a SemanticSegmentor to apply TTA on. A
                ``DistributedDataParallel`` wrapper is unwrapped to its ``.module``.
            tta_mapper (callable): takes a dataset dict and returns a list of
                augmented versions of the dataset dict, each carrying a
                ``"transforms"`` entry. Defaults to `DatasetMapperTTA(cfg)`.
            batch_size (int): stored on ``self.batch_size`` for API compatibility.
                NOTE: inference currently runs one augmented image at a time and
                does not read this value.
        """
        super().__init__()
        if isinstance(model, DistributedDataParallel):
            model = model.module
        self.cfg = cfg.clone()
        self.model = model
        if tta_mapper is None:
            tta_mapper = DatasetMapperTTA(cfg)
        self.tta_mapper = tta_mapper
        self.batch_size = batch_size

    def __call__(self, batched_inputs):
        """
        Same input/output format as :meth:`SemanticSegmentor.forward`

        Args:
            batched_inputs (list[dict]): dataset dicts. A dict without an ``"image"``
                entry must provide ``"file_name"``. The image is then read from disk
                in ``self.model.input_format``.

        Returns:
            list[dict]: one ``{"sem_seg": ...}`` dict per input, in input order.
        """

        def _maybe_read_image(dataset_dict):
            # Shallow copy so the caller's dict is not mutated by the pop/inserts below.
            ret = copy.copy(dataset_dict)
            if "image" not in ret:
                image = read_image(ret.pop("file_name"), self.model.input_format)
                # Transpose the HWC array to a contiguous CHW tensor.
                ret["image"] = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))
            # Fill in whichever output size is missing from the CHW image itself.
            # Read the shape from ret["image"] so this also works when the caller
            # already supplied the image. A local set only in the branch above would
            # be unbound in that case.
            image = ret["image"]
            ret.setdefault("height", image.shape[1])
            ret.setdefault("width", image.shape[2])
            return ret

        return [self._inference_one_image(_maybe_read_image(x)) for x in batched_inputs]

    def _inference_one_image(self, input):
        """
        Run the model on every augmented copy of one image and average the results.

        Args:
            input (dict): one dataset dict with "image" field being a CHW tensor
        Returns:
            dict: ``{"sem_seg": prediction}``, the mean of the per-augmentation predictions
        Raises:
            ValueError: if the TTA mapper produced no augmented inputs.
        """
        augmented_inputs, tfms = self._get_augmented_inputs(input)
        if not augmented_inputs:
            raise ValueError("TTA mapper returned no augmented inputs; nothing to average.")

        final_predictions = None
        with torch.no_grad():
            for aug_input, tfm in zip(augmented_inputs, tfms):
                prediction = self.model([aug_input])[0].pop("sem_seg")
                # Undo a horizontal flip so all predictions share the original frame.
                # Assumes sem_seg is laid out (C, H, W), so dim 2 is width. TODO confirm.
                if any(isinstance(t, HFlipTransform) for t in tfm.transforms):
                    prediction = prediction.flip(dims=[2])
                if final_predictions is None:
                    final_predictions = prediction
                else:
                    # In-place accumulation avoids allocating a new tensor per augmentation.
                    final_predictions += prediction
        return {"sem_seg": final_predictions / len(augmented_inputs)}

    def _get_augmented_inputs(self, input):
        """
        Apply the TTA mapper to one dataset dict.

        Returns:
            tuple[list[dict], list]: the augmented dicts, with their "transforms" entry
            removed, and the popped transforms in the same order.
        """
        augmented_inputs = self.tta_mapper(input)
        tfms = [x.pop("transforms") for x in augmented_inputs]
        return augmented_inputs, tfms