File size: 2,184 Bytes
6250360
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch

# Method codes accepted by ``MTY.transpose``.
FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM = 0, 1

# Every ordering of the four labels 1..4, listed in lexicographic order
# (4! = 24 rows).
all_types = [
    [1, 2, 3, 4], [1, 2, 4, 3], [1, 3, 2, 4], [1, 3, 4, 2], [1, 4, 2, 3], [1, 4, 3, 2],
    [2, 1, 3, 4], [2, 1, 4, 3], [2, 3, 1, 4], [2, 3, 4, 1], [2, 4, 1, 3], [2, 4, 3, 1],
    [3, 1, 2, 4], [3, 1, 4, 2], [3, 2, 1, 4], [3, 2, 4, 1], [3, 4, 1, 2], [3, 4, 2, 1],
    [4, 1, 2, 3], [4, 1, 3, 2], [4, 2, 1, 3], [4, 2, 3, 1], [4, 3, 1, 2], [4, 3, 2, 1],
]

# Zero-based twin of ``all_types``: row k holds the same ordering with every
# label shifted down by one. ``MTY.transpose`` indexes this table with the
# stored labels and searches it for reversed orderings.
aty = [[label - 1 for label in order] for order in all_types]

class MTY(object):
    """Integer type labels, one per instance, tied to an image size.

    Each value in ``self.mty`` is used as an index into the module-level
    ``aty`` table of orderings. ``crop``, ``resize``, ``to`` and
    ``__getitem__`` only rewrap the labels (updating the stored size, moving
    the tensor, or selecting entries). ``transpose`` is the only operation
    that changes label values.
    """

    def __init__(self, mty, size, mode=None):
        """Wrap ``mty`` as a 1-D int64 tensor.

        Args:
            mty: 1-D sequence or tensor of integer labels (indices into ``aty``).
            size: image size; ``__repr__`` reads ``size[0]`` as the width and
                ``size[1]`` as the height.
            mode: opaque value carried through unchanged by every operation.

        Raises:
            AssertionError: if ``mty`` is not one-dimensional.
        """
        # A tensor keeps its device; any other input (list, array, ...) is
        # materialised on the CPU.
        # FIXME: revisit once device handling is integrated more cleanly.
        device = mty.device if isinstance(mty, torch.Tensor) else torch.device('cpu')
        mty = torch.as_tensor(mty, dtype=torch.int64, device=device)

        # Exactly one label per instance.
        assert mty.dim() == 1, str(mty.size())
        self.mty = mty

        self.size = size
        self.mode = mode

    def crop(self, box):
        """Return a new MTY sized to ``box``, sharing the same label tensor.

        ``box`` is read as ``(x0, y0, x1, y1)``. The new size is
        ``(x1 - x0, y1 - y0)``.
        """
        w, h = box[2] - box[0], box[3] - box[1]
        return type(self)(self.mty, (w, h), self.mode)

    def resize(self, size, *args, **kwargs):
        """Return a new MTY with image size ``size``; labels are unchanged.

        Extra positional and keyword arguments are accepted and ignored.
        """
        return type(self)(self.mty, size, self.mode)

    def transpose(self, method):
        """Return a new MTY with every label mapped to its flipped ordering.

        For each label ``t``, the ordering ``aty[t]`` is reversed and the
        result label is the index of that reversed ordering in ``aty``.
        Applying the flip twice restores the original labels.

        Raises:
            NotImplementedError: if ``method`` is not ``FLIP_LEFT_RIGHT``.
        """
        if method not in (FLIP_LEFT_RIGHT,):
            raise NotImplementedError(
                    "Only FLIP_LEFT_RIGHT implemented")

        # flip_lut[t] is the label whose ordering is aty[t] reversed. The table
        # is built with constant work (one entry per ordering) and applied with
        # a single gather on the labels' device. This replaces the former
        # per-instance Python loop, which did a list search and an
        # element-wise tensor write for every label.
        flip_lut = torch.as_tensor(
            [aty.index(order[::-1]) for order in aty],
            dtype=torch.int64,
            device=self.mty.device,
        )
        flipped_data = flip_lut[self.mty]

        return type(self)(flipped_data, self.size, self.mode)

    def to(self, *args, **kwargs):
        """Return a new MTY whose labels are ``self.mty.to(*args, **kwargs)``."""
        return type(self)(self.mty.to(*args, **kwargs), self.size, self.mode)

    def __getitem__(self, item):
        """Return a new MTY holding ``self.mty[item]``.

        An integer ``item`` yields a 0-D tensor, which fails the 1-D check in
        ``__init__``. Use slices, index lists or boolean masks instead.
        """
        return type(self)(self.mty[item], self.size, self.mode)

    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += 'num_instances={}, '.format(len(self.mty))
        s += 'image_width={}, '.format(self.size[0])
        s += 'image_height={})'.format(self.size[1])
        return s