# PAIR-Diffusion / annotator / OneFormer / oneformer / test_time_augmentation.py
# vidit98 - demo files - 2171e8f - 4.1 kB
# ------------------------------------------------------------------------------
# Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/test_time_augmentation.py
# ------------------------------------------------------------------------------
import copy
import logging
from itertools import count
import numpy as np
import torch
from fvcore.transforms import HFlipTransform
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from detectron2.data.detection_utils import read_image
from .datasetmapper_tta import DatasetMapperTTA
import torch.nn.functional as F
# Names exported by ``from <this module> import *``.
__all__ = ["SemanticSegmentorWithTTA"]
class SemanticSegmentorWithTTA(nn.Module):
    """
    A SemanticSegmentor with test-time augmentation enabled.

    Each input is expanded by ``tta_mapper`` into several augmented versions,
    the wrapped model is run on every version, predictions made on
    horizontally flipped inputs are flipped back, and the "sem_seg" outputs
    are averaged.

    Its :meth:`__call__` method has the same interface as :meth:`SemanticSegmentor.forward`.
    """

    def __init__(self, cfg, model, tta_mapper=None, batch_size=1):
        """
        Args:
            cfg (CfgNode): config; a clone is kept on the instance and
                ``MODEL.SEM_SEG_HEAD.NUM_CLASSES`` is read from it.
            model (SemanticSegmentor): a SemanticSegmentor to apply TTA on.
                A ``DistributedDataParallel`` wrapper is unwrapped to its
                underlying ``.module``.
            tta_mapper (callable): takes a dataset dict and returns a list of
                augmented versions of the dataset dict. Defaults to
                `DatasetMapperTTA(cfg)`.
            batch_size (int): batch the augmented images into this batch size for inference.
                NOTE(review): the value is stored, but the inference loop in
                this class feeds the model one augmented image at a time and
                never reads it.
        """
        super().__init__()
        if isinstance(model, DistributedDataParallel):
            model = model.module
        self.cfg = cfg.clone()
        self.num_classes = self.cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
        self.model = model
        if tta_mapper is None:
            tta_mapper = DatasetMapperTTA(cfg)
        self.tta_mapper = tta_mapper
        self.batch_size = batch_size

    def __call__(self, batched_inputs):
        """
        Same input/output format as :meth:`SemanticSegmentor.forward`

        Args:
            batched_inputs (iterable[dict]): dataset dicts; each must carry
                either an "image" tensor or a "file_name" to read it from.

        Returns:
            list[dict]: one ``{"sem_seg": ...}`` dict per input, in order.
        """
        return [
            self._inference_one_image(self._maybe_read_image(dataset_dict))
            for dataset_dict in batched_inputs
        ]

    def _maybe_read_image(self, dataset_dict):
        """
        Return a shallow copy of ``dataset_dict`` that is guaranteed to have
        "image", "height" and "width" entries.

        The image is loaded from "file_name" (that key is removed from the
        copy) only when "image" is absent. Missing "height"/"width" default
        to the image's own size; values supplied by the caller are kept.
        """
        ret = copy.copy(dataset_dict)
        if "image" not in ret:
            image = read_image(ret.pop("file_name"), self.model.input_format)
            image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))  # CHW
            ret["image"] = image
        # Take the size from ret["image"] rather than the local above, so the
        # fallback also works when the caller supplied the image directly.
        # Each key is filled independently so a dict carrying only one of
        # them is still completed.
        image = ret["image"]
        ret.setdefault("height", image.shape[1])
        ret.setdefault("width", image.shape[2])
        return ret

    def _inference_one_image(self, input):
        """
        Args:
            input (dict): one dataset dict with "image" field being a CHW tensor

        Returns:
            dict: one output dict whose "sem_seg" entry is the mean of the
            model's "sem_seg" predictions over all augmented versions.

        Raises:
            ValueError: if the TTA mapper produced no augmented inputs.
        """
        augmented_inputs, tfms = self._get_augmented_inputs(input)
        prediction_sum = None
        num_predictions = 0
        with torch.no_grad():
            for augmented_input, tfm in zip(augmented_inputs, tfms):
                prediction = self.model([augmented_input])[0].pop("sem_seg")
                if any(isinstance(t, HFlipTransform) for t in tfm.transforms):
                    # Undo the horizontal flip (last axis of the 3-D
                    # prediction) so all predictions line up before summing.
                    prediction = prediction.flip(dims=[2])
                if prediction_sum is None:
                    prediction_sum = prediction
                else:
                    prediction_sum += prediction
                num_predictions += 1
        if num_predictions == 0:
            raise ValueError("tta_mapper returned no augmented inputs; nothing to average")
        return {"sem_seg": prediction_sum / num_predictions}

    def _get_augmented_inputs(self, input):
        """
        Run the TTA mapper on one dataset dict.

        Returns:
            tuple: ``(augmented_inputs, tfms)`` where ``tfms[i]`` is the
            "transforms" entry popped from ``augmented_inputs[i]``.
        """
        augmented_inputs = self.tta_mapper(input)
        tfms = [x.pop("transforms") for x in augmented_inputs]
        return augmented_inputs, tfms