import random
from numbers import Number
from typing import List, Optional, Sequence, Tuple, Union

import mmengine
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.dist import barrier, broadcast, get_dist_info
from mmengine.logging import MessageHub
from mmengine.model import BaseDataPreprocessor, ImgDataPreprocessor
from mmengine.structures import PixelData
from mmengine.utils import is_seq_of
from torch import Tensor

from mmdet.models.utils import unfold_wo_center
from mmdet.models.utils.misc import samplelist_boxtype2tensor
from mmpl.registry import MODELS
from mmdet.structures import DetDataSample
from mmdet.structures.mask import BitmapMasks
from mmdet.utils import ConfigType

try:
    import skimage
except ImportError:
    skimage = None


class BatchFixedSizePadTokenMaskGPT(BaseDataPreprocessor):
    """Pad motion-token sequences to a common length and randomly corrupt
    tokens for GPT-style training.

    Args:
        pad_token (int): Token id used to pad every sequence to the length of
            the longest sequence in the batch. Padded positions are excluded
            from the loss by setting their labels to -100.
        p_token_keep (float): Probability of keeping a token unchanged; the
            remaining tokens are replaced with random codebook indices. If set
            to -1, the keep probability is sampled uniformly at random for
            each batch. Defaults to 1.0 (no corruption).
        nb_code (int): Codebook size used to draw random replacement tokens.
            Defaults to 512.
    """

    def __init__(self,
                 pad_token: int,
                 p_token_keep: float = 1.,
                 nb_code: int = 512,
                 ) -> None:
        super().__init__()
        self.pad_token = pad_token
        self.p_token_keep = p_token_keep
        self.nb_code = nb_code

    def forward(
            self,
            batch
    ):
        # Pad every token sequence in the batch to the length of the longest one.
        longest = max(len(item) for item in batch['motion_token'])
        bs = len(batch['motion_token'])
        attention_mask = torch.zeros(bs, longest, dtype=torch.long, device=self.device)
        input_ids = torch.ones(bs, longest, dtype=torch.long, device=self.device) * self.pad_token
        for i, item in enumerate(batch['motion_token']):
            input_ids[i, :len(item)] = item
            attention_mask[i, :len(item)] = 1

        # Targets are the uncorrupted ids; clone so the label masking below does
        # not write through a tensor shared with the corrupted inputs.
        tgt_ids = input_ids.clone()

        # Build a keep-mask: 1 keeps the original token, 0 replaces it with a
        # random codebook index. p_token_keep == -1 samples the keep probability
        # uniformly at random for the whole batch.
        if self.p_token_keep == -1:
            proba = np.random.rand(1)[0]
            mask = torch.bernoulli(proba * torch.ones(input_ids.shape,
                                                      device=input_ids.device))
        else:
            mask = torch.bernoulli(self.p_token_keep * torch.ones(input_ids.shape, device=input_ids.device))
        mask = mask.bool()
        r_indices = torch.randint_like(input_ids, self.nb_code)
        a_indices = mask * input_ids + mask.logical_not() * r_indices

        # Padded positions are ignored by the loss.
        tgt_ids[tgt_ids == self.pad_token] = -100

        data = dict()
        data['inputs'] = dict(
            input_ids=a_indices,
            attention_mask=attention_mask,
            labels=tgt_ids,
        )
        data['data_samples'] = batch
        return data


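# Illustrative usage sketch for BatchFixedSizePadTokenMaskGPT. The pad_token,
# nb_code and p_token_keep values below are assumptions for the example, not
# values taken from this project's configs:
#
#   preprocessor = BatchFixedSizePadTokenMaskGPT(pad_token=512, p_token_keep=0.5, nb_code=512)
#   batch = {'motion_token': [torch.tensor([3, 17, 42]), torch.tensor([8, 9])]}
#   data = preprocessor(batch)
#   # data['inputs']['input_ids']      -> (2, 3) tensor, ~half the tokens randomized
#   # data['inputs']['attention_mask'] -> (2, 3) tensor, 1 for real tokens, 0 for padding
#   # data['inputs']['labels']         -> (2, 3) tensor with -100 at padded positions

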
class NormalizationMotion(BaseDataPreprocessor):
    """Normalize raw motion features with precomputed per-channel statistics."""

    def __init__(
            self,
            mean_std_file: str,
    ) -> None:
        super().__init__()
        # Expected to contain {'motion': {'mean': Tensor, 'std': Tensor}}.
        self.mean_std_info = mmengine.load(mean_std_file)

    def forward(
            self,
            batch
    ):
        # Move the statistics to the current device on every call; this is a
        # no-op once the tensors are already resident there.
        for k, v in self.mean_std_info.items():
            for kk, vv in v.items():
                self.mean_std_info[k][kk] = vv.to(self.device, dtype=torch.float32)
        gt_motion = batch['motion']
        gt_motion = (gt_motion - self.mean_std_info['motion']['mean']) / self.mean_std_info['motion']['std']
        data = dict(
            inputs=gt_motion,
            data_samples=batch
        )
        return data

    def denormalize(self, x):
        return x * self.mean_std_info['motion']['std'] + self.mean_std_info['motion']['mean']
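

# Illustrative usage sketch for NormalizationMotion. It assumes the statistics
# file is loadable by mmengine.load and holds {'motion': {'mean': Tensor,
# 'std': Tensor}}; the file path below is hypothetical:
#
#   normalizer = NormalizationMotion(mean_std_file='data/motion_mean_std.pkl')
#   data = normalizer(batch)                       # batch['motion'] holds raw motion features
#   normalized = data['inputs']                    # (motion - mean) / std
#   restored = normalizer.denormalize(normalized)  # back to the original feature scale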