Spaces:
Runtime error
Runtime error
File size: 5,280 Bytes
4d0eb62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Callable, List, Union
from mmcv.transforms import BaseTransform, Compose
from mmpretrain.registry import TRANSFORMS
# Type of a single transform entry: either a config dict describing the
# transform, or a callable that takes a results dict and returns a dict.
Transform = Union[dict, Callable[[dict], dict]]
@TRANSFORMS.register_module()
class MultiView(BaseTransform):
    """A transform wrapper for multiple views of an image.

    Each wrapped pipeline ``i`` is run ``num_views[i]`` times, every run on
    an independent deep copy of the input results. The ``'img'`` entry of
    every run is gathered, in order, into a list that replaces
    ``results['img']``.

    Args:
        transforms (list[list[dict | callable]]): One sequence of transform
            objects or config dicts per pipeline. Each inner sequence is
            wrapped with :class:`Compose`.
        num_views (int | list[int]): Number of views generated by each
            pipeline. An int is treated as ``[num_views]`` and is therefore
            only valid with exactly one pipeline; a list must have the same
            length as ``transforms``.

    Examples:
        >>> # Example 1: MultiViews 1 pipeline with 2 views
        >>> pipeline = [
        >>>     dict(type='MultiView',
        >>>          num_views=2,
        >>>          transforms=[
        >>>              [
        >>>                  dict(type='Resize', scale=224)],
        >>>          ])
        >>> ]
        >>> # Example 2: MultiViews 2 pipelines, the first with 2 views,
        >>> # the second with 6 views
        >>> pipeline = [
        >>>     dict(type='MultiView',
        >>>          num_views=[2, 6],
        >>>          transforms=[
        >>>              [
        >>>                  dict(type='Resize', scale=224)],
        >>>              [
        >>>                  dict(type='Resize', scale=224),
        >>>                  dict(type='RandomSolarize')],
        >>>          ])
        >>> ]
    """

    def __init__(self, transforms: List[List[Transform]],
                 num_views: Union[int, List[int]]) -> None:
        super().__init__()
        if isinstance(num_views, int):
            num_views = [num_views]
        assert isinstance(num_views, list), (
            'num_views must be an int or a list of int, '
            f'got {type(num_views)}')
        assert len(num_views) == len(transforms), (
            f'num_views has {len(num_views)} entries but '
            f'{len(transforms)} pipelines were given')
        self.num_views = num_views

        # One Compose per pipeline; all views of a pipeline share it.
        self.pipelines = [Compose(trans) for trans in transforms]

        # Flattened, ordered run list: pipeline i appears num_views[i]
        # times, one entry per generated view.
        self.transforms = []
        for pipeline, n_views in zip(self.pipelines, num_views):
            self.transforms.extend([pipeline] * n_views)

    def transform(self, results: dict) -> dict:
        """Generate all views and store them under ``results['img']``.

        Args:
            results (dict): Result dict from previous pipelines. The wrapped
                pipelines only see deep copies of it; its ``'img'`` entry is
                then replaced in place.

        Returns:
            dict | None: ``results`` with ``'img'`` set to the list of
            generated views, or ``None`` if any wrapped pipeline returned
            ``None`` (the sample is dropped).
        """
        views = []
        for pipeline in self.transforms:
            outputs = pipeline(copy.deepcopy(results))
            # mmcv's Compose returns None when one of its transforms drops
            # the sample; propagate it instead of failing on None['img'].
            if outputs is None:
                return None
            views.append(outputs['img'])
        results['img'] = views
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__ + '('
        for i, p in enumerate(self.pipelines):
            repr_str += f'\nPipeline {i + 1} with {self.num_views[i]} views:\n'
            repr_str += str(p)
        repr_str += ')'
        return repr_str
@TRANSFORMS.register_module()
class ApplyToList(BaseTransform):
    """A transform wrapper to apply the wrapped transforms to a list of items.

    For example, to load and resize a list of images.

    Args:
        transforms (list[dict | callable]): Sequence of transform config
            dicts or transform objects to be wrapped. Config dicts are built
            with the ``TRANSFORMS`` registry; other entries are used as is.
        scatter_key (str): The key to scatter data dict. If the field is a
            list, scatter the list to multiple data dicts to do
            transformation.
        collate_keys (List[str]): The keys to collate from multiple data
            dicts. The fields in ``collate_keys`` will be composed into a
            list after transformation, and the other fields will be adopted
            from the first data dict. ``scatter_key`` is always collated.
    """

    def __init__(self, transforms, scatter_key, collate_keys):
        super().__init__()
        # Build config dicts with this package's registry; pass already
        # constructed callables through unchanged, as documented above.
        self.transforms = Compose([
            TRANSFORMS.build(t) if isinstance(t, dict) else t
            for t in transforms
        ])
        self.scatter_key = scatter_key
        self.collate_keys = set(collate_keys)
        self.collate_keys.add(self.scatter_key)

    def transform(self, results: dict):
        """Apply the wrapped transforms to each item of ``scatter_key``.

        Args:
            results (dict): Result dict from previous pipelines.

        Returns:
            dict | None: If ``results[scatter_key]`` is a list, the output
            for the first item with every key in ``collate_keys`` replaced
            by the list of that key's values across all items. Otherwise,
            the output of the wrapped transforms on ``results`` itself.
            ``None`` if the wrapped transforms return ``None`` for any item.

        Raises:
            ValueError: If ``results[scatter_key]`` is an empty list.
        """
        scatter_field = results.get(self.scatter_key)
        if not isinstance(scatter_field, list):
            return self.transforms(results)
        if not scatter_field:
            raise ValueError('Cannot scatter an empty list under key '
                             f'"{self.scatter_key}".')

        # Blank out the scatter field so each deep copy does not duplicate
        # the whole list; the key keeps its original position in the dict.
        template = {**results, self.scatter_key: None}
        scattered_results = []
        for item in scatter_field:
            single_results = copy.deepcopy(template)
            single_results[self.scatter_key] = item
            output = self.transforms(single_results)
            # A None output means the sample was dropped; propagate it.
            if output is None:
                return None
            scattered_results.append(output)

        # Merge the per-item outputs into the first one.
        final_output = scattered_results[0]
        for key in list(final_output):
            if key in self.collate_keys:
                final_output[key] = [
                    single[key] for single in scattered_results
                ]
        return final_output
|