import copy
import numpy as np
from typing import List
import torch
from fvcore.transforms import NoOpTransform
from torch import nn
from detectron2.config import configurable
from detectron2.data.transforms import (
    RandomFlip,
    ResizeShortestEdge,
    ResizeTransform,
    apply_augmentations,
)

__all__ = ["DatasetMapperTTA"]


class DatasetMapperTTA:
    """
    Implement test-time augmentation for detection data.

    It is a callable which takes a dataset dict from a detection dataset
    and returns a list of dataset dicts where the images are augmented from
    the input image by the transformations defined in the config.
    """

    @configurable
    def __init__(self, min_sizes: List[int], max_size: int, flip: bool):
        """
        Args:
            min_sizes: list of short-edge sizes to resize the image to
            max_size: maximum height or width of resized images
            flip: whether to apply flipping augmentation
        """
        self.min_sizes = min_sizes
        self.max_size = max_size
        self.flip = flip

    @classmethod
    def from_config(cls, cfg):
        return {
            "min_sizes": cfg.TEST.AUG.MIN_SIZES,
            "max_size": cfg.TEST.AUG.MAX_SIZE,
            "flip": cfg.TEST.AUG.FLIP,
        }
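
    # Illustrative config snippet (YAML) showing the keys read by ``from_config`` above;
    # the values are examples only, not the library defaults:
    #
    #   TEST:
    #     AUG:
    #       MIN_SIZES: (400, 500, 600, 700, 800)
    #       MAX_SIZE: 4000
    #       FLIP: True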

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict: a dict in standard model input format. See tutorials for details.

        Returns:
            list[dict]:
                a list of dicts, which contain augmented versions of the input image.
                The total number of dicts is ``len(min_sizes) * (2 if flip else 1)``.
                Each dict has field "transforms" which is a TransformList,
                containing the transforms that are used to generate this image.
        """
        numpy_image = dataset_dict["image"].permute(1, 2, 0).numpy()
        shape = numpy_image.shape
        orig_shape = (dataset_dict["height"], dataset_dict["width"])
        if shape[:2] != orig_shape:
            # It transforms the "original" image in the dataset to the input image
            pre_tfm = ResizeTransform(orig_shape[0], orig_shape[1], shape[0], shape[1])
        else:
            pre_tfm = NoOpTransform()

        # Create all combinations of augmentations to use
        aug_candidates = []  # each element is a list[Augmentation]
        for min_size in self.min_sizes:
            resize = ResizeShortestEdge(min_size, self.max_size)
            aug_candidates.append([resize])  # resize only
            if self.flip:
                flip = RandomFlip(prob=1.0)
                aug_candidates.append([resize, flip])  # resize + flip

        # Apply all the augmentations
        ret = []
        for aug in aug_candidates:
            new_image, tfms = apply_augmentations(aug, np.copy(numpy_image))
            torch_image = torch.from_numpy(np.ascontiguousarray(new_image.transpose(2, 0, 1)))

            dic = copy.deepcopy(dataset_dict)
            dic["transforms"] = pre_tfm + tfms
            dic["image"] = torch_image
            ret.append(dic)
        return ret
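

# Minimal usage sketch (an illustration, not part of the module API). It assumes
# ``dataset_dict`` already contains an "image" tensor in (C, H, W) layout plus the
# original "height"/"width", as produced by detectron2's standard DatasetMapper:
#
#   tta_mapper = DatasetMapperTTA(min_sizes=[400, 500, 600], max_size=1333, flip=True)
#   augmented = tta_mapper(dataset_dict)
#   assert len(augmented) == len(tta_mapper.min_sizes) * 2  # flip doubles the count
#   for d in augmented:
#       image, tfms = d["image"], d["transforms"]  # tensor + TransformList that produced it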