File size: 4,074 Bytes
1c3eb47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import random
from numbers import Number
from typing import List, Optional, Sequence, Tuple, Union

import mmengine
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.dist import barrier, broadcast, get_dist_info
from mmengine.logging import MessageHub
from mmengine.model import BaseDataPreprocessor, ImgDataPreprocessor
from mmengine.structures import PixelData
from mmengine.utils import is_seq_of
from torch import Tensor

from mmdet.models.utils import unfold_wo_center
from mmdet.models.utils.misc import samplelist_boxtype2tensor
from mmpl.registry import MODELS
from mmdet.structures import DetDataSample
from mmdet.structures.mask import BitmapMasks
from mmdet.utils import ConfigType

try:
    import skimage
except ImportError:
    skimage = None


@MODELS.register_module()
class BatchFixedSizePadTokenMaskGPT(BaseDataPreprocessor):
    """Fixed size padding for batch images.

    Args:
        size (Tuple[int, int]): Fixed padding size. Expected padding
            shape (h, w). Defaults to None.
        img_pad_value (int): The padded pixel value for images.
            Defaults to 0.
        pad_mask (bool): Whether to pad instance masks. Defaults to False.
        mask_pad_value (int): The padded pixel value for instance masks.
            Defaults to 0.
        pad_seg (bool): Whether to pad semantic segmentation maps.
            Defaults to False.
        seg_pad_value (int): The padded pixel value for semantic
            segmentation maps. Defaults to 255.
    """

    def __init__(self,
                 pad_token: int,
                 p_token_keep: float = 1.,
                 nb_code: int = 512,
                 ) -> None:
        super().__init__()
        self.pad_token = pad_token
        self.p_token_keep = p_token_keep
        self.nb_code = nb_code

    def forward(
        self,
        batch
    ):
        # padding the input index to the same length

        longest = max([len(item) for item in batch['motion_token']])
        bs = len(batch['motion_token'])

        attention_mask = torch.zeros(bs, longest, dtype=torch.long, device=self.device)
        input_ids = torch.ones(bs, longest, dtype=torch.long, device=self.device) * self.pad_token
        for i, item in enumerate(batch['motion_token']):
            input_ids[i, :len(item)] = item
            attention_mask[i, :len(item)] = 1

        tgt_ids = input_ids

        if self.p_token_keep == -1:
            proba = np.random.rand(1)[0]
            mask = torch.bernoulli(proba * torch.ones(input_ids.shape,
                                                      device=input_ids.device))
        else:
            mask = torch.bernoulli(self.p_token_keep * torch.ones(input_ids.shape, device=input_ids.device))
        mask = mask.bool()
        r_indices = torch.randint_like(input_ids, self.nb_code)
        a_indices = mask * input_ids + mask.logical_not() * r_indices

        tgt_ids[tgt_ids == self.pad_token] = -100

        data = dict()
        data['inputs'] = dict(
            input_ids=a_indices,
            attention_mask=attention_mask,
            labels=tgt_ids,
        )
        data['data_samples'] = batch
        return data


@MODELS.register_module()
class NormalizationMotion(BaseDataPreprocessor):

    def __init__(
            self,
            mean_std_file: str,
    ) -> None:
        super().__init__()
        self.mean_std_info = mmengine.load(mean_std_file)

    def forward(
            self,
            batch
    ):
        for k, v in self.mean_std_info.items():
            for kk, vv in v.items():
                self.mean_std_info[k][kk] = vv.to(self.device, dtype=torch.float32)

        gt_motion = batch['motion']
        gt_motion = (gt_motion - self.mean_std_info['motion']['mean']) / self.mean_std_info['motion']['std']

        data = dict(
            inputs=gt_motion,
            data_samples=batch
        )
        return data

    def denormalize(self, x):
        return x * self.mean_std_info['motion']['std'] + self.mean_std_info['motion']['mean']