File size: 5,237 Bytes
67a8158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import os.path as osp
import numpy as np
import numpy.random as npr
import PIL

import torch
import torchvision
import xml.etree.ElementTree as ET
import json
import copy
import math

def singleton(class_):
    """
    Class decorator that turns a class into a lazily built singleton.
    The decorated name becomes a factory function: its first call
        constructs the instance with the given arguments, every
        later call returns that same instance and ignores its
        arguments.
    Args:
        class_: class,
            the class to instantiate at most once.
    Returns:
        the factory function that replaces the class.
    """
    created = {}

    def getinstance(*args, **kwargs):
        # Fast path: the instance was already built by an earlier call.
        if class_ in created:
            return created[class_]
        obj = class_(*args, **kwargs)
        created[class_] = obj
        return obj

    return getinstance

@singleton
class get_transform(object):
    """
    Registry that maps transform class names to transform classes
        and builds transform objects from a config.
    Being decorated with singleton, every get_transform() call
        returns the same shared registry.
    """
    def __init__(self):
        # class name (str) -> registered transform class
        self.transform = {}

    def register(self, transf):
        """
        Store a transform class under its own __name__.
        Args:
            transf: class,
                the transform class to register.
        """
        name = transf.__name__
        self.transform[name] = transf

    def __call__(self, cfg):
        """
        Build the transform(s) described by cfg.
        Args:
            cfg: None, a config entry or a list of config entries,
                each entry provides .type (the registered name)
                and .args (keyword arguments for the constructor).
        Returns:
            None when cfg is None,
            a compose of all built transforms when cfg is a list,
            otherwise the single built transform.
        """
        if cfg is None:
            return None

        def build(entry):
            # look the class up by name, then construct it
            name = entry.type
            return self.transform[name](**entry.args)

        if isinstance(cfg, list):
            return compose([build(entry) for entry in cfg])
        return build(cfg)

def register():
    """
    Decorator factory that adds the decorated class to the
        get_transform registry.
    Returns:
        a class decorator that registers the class and hands it
        back unchanged.
    """
    def wrapper(class_):
        registry = get_transform()
        registry.register(class_)
        return class_
    return wrapper

def have(must=(), may=()):
    """
    The nextgen decorator that takes two lists of item names
        telling which items of the element the transform
        will operate on.
    The decorated method runs first; afterwards every listed item
        is routed to an exec* method of the transform object.
    Args:
        must: [] of str,
            the names of the items that must be included
            inside the element.
            If element[name] exist: do the transform
            If element[name] is None: raise ValueError.
            If element[name] not exist: raise ValueError.
        may: [] of str,
            the names of the items that may be contained
            inside the element for transform.
            If element[name] exist: do the transform
            If element[name] is None: ignore it.
            If element[name] not exist: ignore it.
    Returns:
        a decorator for a method with the signature (self, e).
    """
    # function-scope import keeps the block self-contained
    import functools

    def route(self, item, e, d):
        """
        Route the element to a proper function
            for calculation.
        Args:
            self: object,
                the transform functor.
            item: str,
                the item name of the data.
            e: {},
                the element
            d: nparray, tensor or PIL.Image,
                the data to transform.
        Returns:
            the element returned by the handler, with e[item]
            replaced by the transformed data.
        Raises:
            ValueError: d is of an unsupported type.
            TypeError: the transform has no matching exec* handler.
        """
        if isinstance(d, np.ndarray):
            dtype = 'nparray'
        elif isinstance(d, torch.Tensor):
            dtype = 'tensor'
        elif isinstance(d, PIL.Image.Image):
            dtype = 'pilimage'
        else:
            raise ValueError(
                'unsupported data type {} for item {!r}'.format(
                    type(d).__name__, item))

        # find function by order, from the most specific handler
        # to the most generic one
        candidates = (
            'exec_{}_{}'.format(item, dtype),
            'exec_{}'.format(item),
            'exec_{}'.format(dtype),
            'exec',
        )
        for attrname in candidates:
            f = getattr(self, attrname, None)
            if f is not None:
                break
        else:
            # the failure used to be an opaque "NoneType is not
            # callable" TypeError; keep the type, explain the cause
            raise TypeError(
                '{} has no exec handler for item {!r} ({})'.format(
                    self.__class__.__name__, item, dtype))
        d, e = f(d, e)
        e[item] = d
        return e

    def wrapper(func):
        @functools.wraps(func)
        def inner(self, e):
            e['imsize_previous'] = e['imsize_current']

            # remember the image size seen before this transform under
            # the first free tag: <base>, <base>1, <base>2, ...
            base = 'imsize_before_' + self.__class__.__name__
            tag = base
            cnt = 0
            while tag in e:
                cnt += 1
                tag = base + str(cnt)
            e[tag] = e['imsize_current']

            e = func(self, e)
            # must transform list
            for item in must:
                try:
                    d = e[item]
                except KeyError:
                    raise ValueError(
                        'required item {!r} is missing from the '
                        'element'.format(item)) from None
                if d is None:
                    raise ValueError(
                        'required item {!r} is None'.format(item))
                e = route(self, item, e, d)
            # may transform list
            for item in may:
                try:
                    d = e[item]
                except KeyError:
                    # optional item not present: nothing to transform
                    continue
                if d is not None:
                    e = route(self, item, e, d)
            return e
        return inner
    return wrapper

class compose(object):
    """
    Chain several transforms into one callable.
    Args:
        transforms: iterable of callables,
            applied in iteration order, each one receives the
            result of the previous one.
    """
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, element):
        """
        Feed element through every transform in turn and return
            the final result.
        """
        result = element
        for step in self.transforms:
            result = step(result)
        return result

class TBase(object):
    """
    Base class of the transforms.
    Provides the fallback exec handler and the rand helper that
        draws random numbers through one single entry point.
    """
    def __init__(self):
        pass

    def exec(self, data, element):
        """
        Fallback handler, always raises ValueError.
        Args:
            data: the data to transform.
            element: {}, the element the data belongs to.
        """
        raise ValueError

    def rand(self, uid, tag, rand_f, *args, **kwargs):
        """
        Draw a random value by calling rand_f(*args, **kwargs).
        Args:
            uid: string element['unique_id']
            tag: string tells the tag uses when tracking the random number.
                Or the tag to restore the tracked random number.
            rand_f: the random function used to generate the random number.
            *args: the positional arguments for the given random function.
            **kwargs: the keyword arguments for the given random function.
        Returns:
            whatever rand_f returns.
        """
        # uid and tag are only referenced by the disabled history
        # tracking below; the value always comes straight from rand_f.
        # if rnduh().hdata is not None:
        #     return rnduh().get_history(uid, self.__class__.__name__, tag)
        # if rnduh().record_path is None:
        #     return rand_f(*args, **kwargs)
        # the special mode to create the random file.
        # rnduh().record(uid, self.__class__.__name__, tag, d)
        return rand_f(*args, **kwargs)