# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gzip
import json
import random
from pathlib import Path
from typing import Optional, Tuple, List, Dict
from abc import ABC, abstractmethod

from PIL import Image
import cv2
import albumentations as A
import numpy as np
import torch
import torchvision.transforms.functional as TF
import torchvision.transforms as T
from einops import rearrange, repeat, reduce

from fourm.utils import to_2tuple
from fourm.utils.data_constants import (IMAGENET_DEFAULT_MEAN,
                                        IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN,
                                        IMAGENET_SURFACE_NORMAL_STD, IMAGENET_SURFACE_NORMAL_MEAN,
                                        IMAGENET_INCEPTION_STD, SEG_IGNORE_INDEX, PAD_MASK_VALUE)

# The @-symbol is used to specify the resolution of a modality. Syntax: modality@resolution

def get_transform_key(mod_name):
    return mod_name.split('@')[0]


def get_transform_resolution(mod_name, default_resolution, to_tuple=True):
    res = int(mod_name.split('@')[1]) if '@' in mod_name else default_resolution
    return to_2tuple(res) if to_tuple else res


def get_transform(mod_name, transforms_dict):
    return transforms_dict.get(get_transform_key(mod_name), IdentityTransform())
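
# Illustrative examples of the modality@resolution syntax (modality names below are hypothetical):
#   get_transform_key('tok_rgb@224')                          -> 'tok_rgb'
#   get_transform_resolution('tok_rgb@224', 256)              -> (224, 224)
#   get_transform_resolution('caption', 256, to_tuple=False)  -> 256
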
def get_pil_resample_mode(resample_mode: str): | |
""" | |
Returns the PIL resampling mode for the given resample mode string. | |
Args: | |
resample_mode: Resampling mode string | |
""" | |
if resample_mode is None: | |
return None | |
elif resample_mode == "bilinear": | |
return Image.Resampling.BILINEAR if hasattr(Image, 'Resampling') else Image.BILINEAR | |
elif resample_mode == "bicubic": | |
return Image.Resampling.BICUBIC if hasattr(Image, 'Resampling') else Image.BICUBIC | |
elif resample_mode == "nearest": | |
return Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST | |
else: | |
raise ValueError(f"Resample mode {resample_mode} is not supported.") | |
class UnifiedDataTransform(object): | |
def __init__(self, transforms_dict, image_augmenter, resample_mode: str = None, add_sizes: bool = False, **kwargs): | |
"""Unified data augmentation for FourM | |
Args: | |
transforms_dict (dict): Dict of transforms for each modality | |
image_augmenter (AbstractImageAugmenter): Image augmenter | |
resample_mode (str, optional): Resampling mode for PIL images (default: None -> uses default resampling mode for data type) | |
One out of ["bilinear", "bicubic", "nearest", None]. | |
add_sizes (bool, optional): Whether to add crop coordinates and original size to the output dict | |
""" | |
self.transforms_dict = transforms_dict | |
self.image_augmenter = image_augmenter | |
self.resample_mode = resample_mode | |
self.add_sizes = add_sizes | |
def unified_image_augment(self, mod_dict, crop_settings): | |
"""Apply the image augmenter to all modalities where it is applicable | |
Args: | |
mod_dict (dict): Dict of modalities | |
crop_settings (dict): Crop settings | |
Returns: | |
dict: Transformed dict of modalities | |
""" | |
crop_coords, flip, orig_size, target_size, rand_aug_idx = self.image_augmenter(mod_dict, crop_settings) | |
mod_dict = { | |
k: self.transforms_dict[get_transform_key(k)].image_augment( | |
v, crop_coords=crop_coords, flip=flip, orig_size=orig_size, | |
target_size=get_transform_resolution(k, target_size), rand_aug_idx=rand_aug_idx, | |
resample_mode=self.resample_mode | |
) | |
for k, v in mod_dict.items() | |
} | |
if self.add_sizes: | |
mod_dict["crop_coords"] = torch.tensor(crop_coords) | |
mod_dict["orig_size"] = torch.tensor(orig_size) | |
return mod_dict | |
def __call__(self, mod_dict): | |
"""Apply the augmentation to a dict of modalities (both image based and sequence based modalities) | |
Args: | |
mod_dict (dict): Dict of modalities | |
Returns: | |
dict: Transformed dict of modalities | |
""" | |
crop_settings = mod_dict.pop("crop_settings", None) | |
mod_dict = {k: get_transform(k, self.transforms_dict).preprocess(v) for k, v in mod_dict.items()} | |
mod_dict = self.unified_image_augment(mod_dict, crop_settings) | |
mod_dict = {k: get_transform(k, self.transforms_dict).postprocess(v) for k, v in mod_dict.items()} | |
return mod_dict | |
    def __repr__(self):
        return "(UnifiedDataTransform)"
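
# Minimal usage sketch (illustrative only). `transforms_dict` maps transform keys to the modality
# transforms defined below, and `image_augmenter` is any callable following the AbstractImageAugmenter
# interface, i.e. returning (crop_coords, flip, orig_size, target_size, rand_aug_idx) for a mod_dict.
# The augmenter class name used here is hypothetical:
#
#   transforms_dict = {'rgb': RGBTransform(), 'depth': DepthTransform()}
#   augmenter = SomeImageAugmenter(target_size=224)  # hypothetical AbstractImageAugmenter subclass
#   unified_transform = UnifiedDataTransform(transforms_dict, augmenter, add_sizes=True)
#   out = unified_transform({'rgb': rgb_pil_image, 'depth@256': depth_pil_image})
#   # out['rgb'] is a normalized float tensor, out['depth@256'] a 1 x 256 x 256 depth tensor,
#   # and out['crop_coords'] / out['orig_size'] are included because add_sizes=True.
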
class AbstractTransform(ABC):

    @abstractmethod
    def load(self, sample):
        pass

    @abstractmethod
    def preprocess(self, sample):
        pass

    @abstractmethod
    def image_augment(self, v, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
                      rand_aug_idx: Optional[int], resample_mode: str = None):
        pass

    @abstractmethod
    def postprocess(self, v):
        pass
class ImageTransform(AbstractTransform):

    @staticmethod
    def pil_loader(path: str) -> Image.Image:
        # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
        # with open(path, 'rb') as f:
        #     img = Image.open(f)
        img = Image.open(path)
        return img

    @staticmethod
    def image_hflip(img: Image, flip: bool):
        """Horizontally flip an image

        :param img: Image to flip
        :param flip: Whether to flip the image
        :return: Flipped image (if flip = True)
        """
        if flip:
            img = TF.hflip(img)
        return img

    @staticmethod
    def image_crop_and_resize(img: Image, crop_coords: Tuple, target_size: Tuple, resample_mode: str = None):
        """Crop and resize an image

        :param img: Image to crop and resize
        :param crop_coords: Coordinates of the crop (top, left, h, w)
        :param target_size: Coordinates of the resize (height, width)
        :return: Cropped and resized image
        """
        top, left, h, w = crop_coords
        resize_height, resize_width = target_size
        img = TF.crop(img, top, left, h, w)
        resample_mode = get_pil_resample_mode(resample_mode)
        # Note: PIL's Image.resize expects the size as (width, height)
        img = img.resize((resize_width, resize_height), resample=resample_mode)
        return img
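
# Coordinate conventions used by the image transforms in this file (illustrative values):
#   crop_coords = (top, left, h, w),  e.g. (0, 64, 384, 384) crops a 384x384 window whose
#                 top-left corner sits 64 pixels from the left edge of the original image
#   target_size = (height, width),    e.g. (224, 224) is the size the crop is resized to
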
class RGBTransform(ImageTransform): | |
def __init__(self, imagenet_default_mean_and_std=True, color_jitter=False, color_jitter_strength=0.5): | |
self.rgb_mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN | |
self.rgb_std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD | |
self.color_jitter = color_jitter | |
self.color_jitter_transform = self.random_color_jitter(color_jitter_strength) | |
def random_color_jitter(self, strength=0.5): | |
# Color Jitter from Pix2Seq and SimCLR | |
# Source: https://github.com/google-research/pix2seq/blob/main/data/data_utils.py#L114 | |
t = T.Compose([ | |
T.RandomApply([T.ColorJitter(brightness=0.8 * strength, contrast=0.8 * strength, saturation=0.8 * strength, hue=0.2 * strength)], p=0.8), | |
T.RandomApply([T.Grayscale(num_output_channels=3)], p=0.2), | |
]) | |
return t | |
def rgb_to_tensor(self, img): | |
img = TF.to_tensor(img) | |
img = TF.normalize(img, mean=self.rgb_mean, std=self.rgb_std) | |
return img | |
def load(self, path): | |
# TODO: Instead of converting to RGB here, do it either in the preprocess or the postprocess step. Makes it compatible with wds dataloading. | |
sample = self.pil_loader(path) | |
return sample | |
def preprocess(self, sample): | |
sample = sample.convert('RGB') | |
if self.color_jitter: | |
sample = self.color_jitter_transform(sample) | |
return sample | |
def image_augment(self, img, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, | |
rand_aug_idx: Optional[int], resample_mode: str = None): | |
img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode=resample_mode) | |
img = self.image_hflip(img, flip) | |
return img | |
def postprocess(self, sample): | |
sample = self.rgb_to_tensor(sample) | |
return sample | |
class DepthTransform(ImageTransform): | |
def __init__(self, standardize_depth=True): | |
self.standardize_depth = standardize_depth | |
def depth_to_tensor(self, img): | |
img = torch.Tensor( img / (2 ** 16 - 1.0) ) | |
img = img.unsqueeze(0) # 1 x H x W | |
if self.standardize_depth: | |
img = self.truncated_depth_standardization(img) | |
return img | |
    @staticmethod
    def truncated_depth_standardization(depth, thresh: float = 0.1):
"""Truncated depth standardization | |
:param depth: Depth map | |
:param thresh: Threshold | |
:return: Robustly standardized depth map | |
""" | |
# Flatten depth and remove bottom and top 10% of values | |
trunc_depth = torch.sort(depth.reshape(-1), dim=0)[0] | |
trunc_depth = trunc_depth[int(thresh * trunc_depth.shape[0]): int((1 - thresh) * trunc_depth.shape[0])] | |
return (depth - trunc_depth.mean()) / torch.sqrt(trunc_depth.var() + 1e-6) | |
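    # Example (illustrative): with thresh=0.1 and 100 depth values, the 10 smallest and the
    # 10 largest values are discarded, and the mean/std of the remaining 80 values are used to
    # standardize the full depth map, making the statistics robust to extreme depth values.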
def load(self, path): | |
sample = self.pil_loader(path) | |
return sample | |
def preprocess(self, sample): | |
return sample | |
def image_augment(self, img, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, | |
rand_aug_idx: Optional[int], resample_mode: str = None): | |
img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode=resample_mode) | |
img = self.image_hflip(img, flip) | |
return img | |
def postprocess(self, sample): | |
sample = np.array(sample) | |
sample = self.depth_to_tensor(sample) | |
return sample | |
class NormalTransform(ImageTransform): | |
def __init__(self, standardize_surface_normals=False): | |
self.normal_mean = (0.5, 0.5, 0.5) if not standardize_surface_normals else IMAGENET_SURFACE_NORMAL_MEAN | |
self.normal_std = (0.5, 0.5, 0.5) if not standardize_surface_normals else IMAGENET_SURFACE_NORMAL_STD | |
def normal_to_tensor(self, img): | |
img = TF.to_tensor(img) | |
img = TF.normalize(img, mean=self.normal_mean, std=self.normal_std) | |
return img | |
def load(self, path): | |
sample = self.pil_loader(path) | |
return sample | |
def preprocess(self, sample): | |
return sample | |
def image_hflip(self, img: Image, flip: bool): | |
if flip: | |
img = TF.hflip(img) | |
flipped_np = np.array(img) | |
flipped_np[:, :, 0] = 255 - flipped_np[:, :, 0] | |
img = Image.fromarray(flipped_np) | |
return img | |
def image_augment(self, img, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, | |
rand_aug_idx: Optional[int], resample_mode: str = None): | |
img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode=resample_mode) | |
img = self.image_hflip(img, flip) | |
return img | |
def postprocess(self, sample): | |
sample = self.normal_to_tensor(sample) | |
return sample | |
class SemsegTransform(ImageTransform): | |
def __init__(self, scale_factor=1.0, shift_idx_by_one=False, id_mapping: Optional[Dict] = None, select_channel=None): | |
self.scale_factor = scale_factor | |
self.shift_idx_by_one = shift_idx_by_one | |
self.id_mapping = id_mapping | |
self.select_channel = select_channel | |
def map_semseg_values(self, sample): | |
sample = np.asarray(sample) | |
mapping_fn = lambda x: self.id_mapping.get(x, x) | |
sample = np.vectorize(mapping_fn)(sample) | |
sample = Image.fromarray(sample, mode='P') | |
return sample | |
def semseg_to_tensor(self, img): | |
# Rescale to scale factor | |
if self.scale_factor != 1.0: | |
target_height, target_width = int(img.height * self.scale_factor), int(img.width * self.scale_factor) | |
img = img.resize((target_width, target_height)) | |
# Using pil_to_tensor keeps it in uint8, to_tensor converts it to float (rescaled to [0, 1]) | |
img = TF.pil_to_tensor(img).to(torch.long).squeeze(0) | |
# 255->0, 254->0, all else shifted up by one | |
return img | |
def load(self, path): | |
sample = self.pil_loader(path) | |
if self.select_channel is not None: | |
sample = sample.split()[self.select_channel] | |
return sample | |
def preprocess(self, sample): | |
sample = sample.convert('P') | |
if self.id_mapping is not None: | |
sample = self.map_semseg_values(sample) | |
if self.shift_idx_by_one: | |
sample = np.asarray(sample) | |
sample = sample + 1 | |
sample = Image.fromarray(sample, mode='P') | |
return sample | |
def image_augment(self, img, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, | |
rand_aug_idx: Optional[int], resample_mode: str = None): | |
# Value for padding with TF.crop is always 0. | |
# Override resampling mode to 'nearest' for semseg | |
img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode='nearest') | |
img = self.image_hflip(img, flip) | |
return img | |
def postprocess(self, sample): | |
img = self.semseg_to_tensor(sample) | |
return img | |
class SAMInstanceTransform(AbstractTransform): | |
def __init__(self, mask_size=64, max_instance_n=20, bbox_area_threshold=0.0005): | |
self.mask_size = mask_size | |
self.max_instance_n = max_instance_n | |
self.bbox_area_threshold = bbox_area_threshold | |
def get_bbox(self, instance): | |
""" Gets bounding box of the given instance | |
""" | |
min_h, max_h = instance[:,:,1].min(), instance[:,:,1].max() | |
min_w, max_w = instance[:,:,0].min(), instance[:,:,0].max() | |
return [min_h, min_w, max_h, max_w] | |
def extend_instance_points(self, instance, border_fn): | |
""" Given an instance and a border function `border_fn`, extends the instance points with crossing points between the instance and | |
the crop borders. The crossing points are obtained using border_fn. | |
""" | |
p = instance[:,0] | |
p_next = np.roll(p, (-1), axis=(0)) | |
final_points = [] | |
for x, xn in zip(p, p_next): | |
final_points.append(x) | |
for r in border_fn(x, xn): | |
final_points.append(r.astype(np.int32)) | |
p = np.stack(final_points) | |
return p[:,None] | |
def remove_redundant_lines(self, orig_instance, instance): | |
""" Removes the redundant lines added during cropping. | |
""" | |
final_points = [] | |
for p in instance: | |
distance = cv2.pointPolygonTest(orig_instance, (p[0,0].item(), p[0,1].item()), measureDist=True) | |
if distance >= 0: | |
final_points.append(p[0]) | |
return np.stack(final_points)[:,None] | |
def get_border_functions(self, crop_points): | |
""" Creates and returns a function `fn` using crop region coordinates given in crop_points. | |
`fn` receives two input points x and xn and returns all the crossing points between the line connecting | |
x and xn, and the borders of the cropping rectangle. | |
""" | |
p = crop_points[:,0] | |
p_next = np.roll(p, (-1), axis=(0)) | |
def fn(x, xn): | |
output = [] | |
c_diff = p_next - p | |
x_diff = x - xn | |
for diff, c in zip(c_diff, p): | |
A = np.array([ | |
[diff[0], x_diff[0]], | |
[diff[1], x_diff[1]] | |
]) | |
b = x - c | |
                try:
                    lmbda = np.linalg.solve(A, b)
                    if 0 <= lmbda[0] <= 1 and 0 <= lmbda[1] <= 1:
                        output.append(lmbda[1] * xn + (1 - lmbda[1]) * x)
                except np.linalg.LinAlgError:
                    # Singular system: the segments are parallel, so there is no unique crossing point
                    continue
return output | |
return fn | |
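    # The 2x2 system solved in fn() finds the crossing of the border segment c + lmbda[0] * (c_next - c)
    # with the instance segment x + lmbda[1] * (xn - x): rearranging gives
    #   lmbda[0] * (c_next - c) + lmbda[1] * (x - xn) = x - c,
    # i.e. A @ lmbda = b with the columns of A given by c_diff and x_diff, and b = x - c.
    # Both lmbda values lying in [0, 1] means the two segments themselves intersect (not just their
    # supporting lines), and the appended point lmbda[1] * xn + (1 - lmbda[1]) * x is that crossing.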
def crop_sample(self, sample, crop_coords): | |
""" Crop the sample using crop coordinates. | |
""" | |
top, left, h, w = crop_coords | |
crop_region = (left, top, left + w, top + h) | |
crop_points = np.array([ | |
[crop_region[0], crop_region[1]], | |
[crop_region[2], crop_region[1]], | |
[crop_region[2], crop_region[3]], | |
[crop_region[0], crop_region[3]], | |
])[:,None] | |
border_functions = self.get_border_functions(crop_points) | |
cropped_sample = [] | |
for instance in sample: | |
instance = self.extend_instance_points(instance, border_functions) | |
filter_condition = ( | |
(instance[:, :, 0] > crop_region[0]) & | |
(instance[:, :, 0] < crop_region[2]) & | |
(instance[:, :, 1] > crop_region[1]) & | |
(instance[:, :, 1] < crop_region[3]) | |
) | |
if not np.any(filter_condition): | |
continue | |
instance_copy = instance.copy() | |
instance_copy[:, :, 0] = np.clip(instance[:, :, 0], a_min=crop_region[0], a_max=crop_region[2]) | |
instance_copy[:, :, 1] = np.clip(instance[:, :, 1], a_min=crop_region[1], a_max=crop_region[3]) | |
instance_copy = self.remove_redundant_lines(instance, instance_copy) | |
instance_copy[:, :, 0] -= crop_region[0] | |
instance_copy[:, :, 1] -= crop_region[1] | |
cropped_sample.append(instance_copy) | |
return cropped_sample | |
def resize_sample(self, sample, original_size, target_size): | |
""" Resize the sample | |
""" | |
width_scale = target_size[1] / original_size[1] | |
height_scale = target_size[0] / original_size[0] | |
resized_sample = [] | |
for instance in sample: | |
instance_copy = instance.copy() | |
instance_copy[:, :, 0] = np.round(width_scale * instance_copy[:, :, 0]) | |
instance_copy[:, :, 1] = np.round(height_scale * instance_copy[:, :, 1]) | |
resized_sample.append(instance_copy) | |
return resized_sample | |
def remove_tiny_instances(self, sample, image_size): | |
""" Remove instances that have an area ratio smaller than `bbox_area_threshold`. | |
""" | |
filtered_sample = [] | |
for instance in sample: | |
min_h, min_w, max_h, max_w = self.get_bbox(instance) | |
bbox_area_ratio = (max_h - min_h) * (max_w - min_w) / (image_size[0] * image_size[1]) | |
if bbox_area_ratio < self.bbox_area_threshold: | |
continue | |
filtered_sample.append(instance) | |
return filtered_sample | |
def hflip(self, sample, width): | |
""" Horizontal flipping the instances in a sample. | |
""" | |
flipped_sample = [] | |
for instance in sample: | |
instance_copy = instance.copy() | |
instance_copy[:, :, 0] = width - instance_copy[:, :, 0] | |
flipped_sample.append(instance_copy) | |
return flipped_sample | |
def get_binary_masks(self, sample): | |
""" Creates the binary mask of each instance in the sample. | |
""" | |
if self.max_instance_n is None: | |
max_instance_n = len(sample) | |
else: | |
max_instance_n = self.max_instance_n | |
masks = np.zeros((max_instance_n, self.mask_size, self.mask_size)) | |
bboxes = np.zeros((max_instance_n, 4)) | |
valid = np.full(max_instance_n, False) | |
for i, instance in enumerate(sample): | |
bbox = self.get_bbox(instance) | |
min_h, min_w, max_h, max_w = bbox | |
instance_copy = instance.copy() | |
mask = np.zeros((self.mask_size, self.mask_size), dtype=np.uint8) | |
instance_copy[:,:,0] = (instance_copy[:,:,0] - min_w) / (max_w - min_w) * self.mask_size | |
instance_copy[:,:,1] = (instance_copy[:,:,1] - min_h) / (max_h - min_h) * self.mask_size | |
cv2.drawContours(mask, [instance_copy], 0, (255), thickness=cv2.FILLED) | |
masks[i] = mask / 255.0 | |
bboxes[i] = np.array(bbox) | |
valid[i] = True | |
return masks, bboxes, valid | |
def load(self, path): | |
sample = np.load(path, allow_pickle=True) | |
return sample | |
def preprocess(self, sample): | |
        if self.max_instance_n is None or len(sample) <= self.max_instance_n:
            indices = np.arange(len(sample))
        else:
            indices = np.random.choice(len(sample), size=self.max_instance_n, replace=False)
        return [p['points'] for i, p in enumerate(sample) if i in indices]
def image_augment(self, v, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, | |
rand_aug_idx: Optional[int], resample_mode: str = None): | |
v = self.crop_sample(v, crop_coords) | |
_, _, h, w = crop_coords | |
v = self.resize_sample(v, (h, w), target_size) | |
v = self.remove_tiny_instances(v, target_size) | |
if flip: | |
v = self.hflip(v, target_size[0]) | |
return v | |
def postprocess(self, sample): | |
sample, bboxes, valid = self.get_binary_masks(sample) | |
return { | |
'instance': torch.from_numpy(sample).to(torch.float32), | |
'bbox': torch.from_numpy(bboxes).to(torch.float32), | |
'valid': torch.from_numpy(valid) | |
} | |
class MaskTransform(ImageTransform): | |
def __init__(self, mask_pool_size=1): | |
assert isinstance(mask_pool_size, int) | |
self.mask_pool_size = mask_pool_size # Use to expand masks | |
def mask_to_tensor(self, img): | |
mask = TF.to_tensor(img) | |
if self.mask_pool_size > 1: | |
mask = reduce(mask, 'c (h1 h2) (w1 w2) -> c h1 w1', 'min', h2=self.mask_pool_size, w2=self.mask_pool_size) | |
mask = repeat(mask, 'c h1 w1 -> c (h1 h2) (w1 w2)', h2=self.mask_pool_size, w2=self.mask_pool_size) | |
return (mask == 1.0) | |
def load(self, path): | |
sample = self.pil_loader(path) | |
return sample | |
def preprocess(self, sample): | |
return sample | |
def image_augment(self, img, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, | |
rand_aug_idx: Optional[int], resample_mode: str = None): | |
# Override resampling mode to 'nearest' for masks | |
img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode='nearest') | |
img = self.image_hflip(img, flip) | |
return img | |
def postprocess(self, sample): | |
sample = self.mask_to_tensor(sample) | |
return sample | |
class TokTransform(AbstractTransform): | |
def __init__(self): | |
pass | |
def load(self, path): | |
sample = np.load(path).astype(int) | |
return sample | |
def preprocess(self, sample): | |
return sample | |
def image_augment(self, v, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, | |
rand_aug_idx: Optional[int], resample_mode: str = None): | |
if rand_aug_idx is None: | |
raise ValueError("Crop settings / augmentation index are missing but a pre-tokenized modality is being used") | |
v = torch.tensor(v[rand_aug_idx]) | |
return v | |
def postprocess(self, sample): | |
return sample | |
class DetectionTransform(AbstractTransform): | |
def __init__(self, det_threshold=0.6, det_max_instances=None, bbox_order='dist_to_orig', coord_bins=1000, min_visibility=0.0, return_raw=False): | |
self.det_threshold = det_threshold | |
self.det_max_instances = det_max_instances | |
self.coord_bins = coord_bins | |
self.min_visibility = min_visibility | |
self.return_raw = return_raw | |
if bbox_order == 'area': | |
self.bbox_order = self.order_bboxes_by_area | |
elif bbox_order == 'score': | |
self.bbox_order = self.order_bboxes_by_score | |
elif bbox_order == 'random': | |
self.bbox_order = self.shuffle_bboxes | |
else: | |
self.bbox_order = self.order_bboxes_by_dist_to_orig | |
    @staticmethod
    def order_bboxes_by_area(bboxes):
        return sorted(bboxes, key=lambda x: (x[2] - x[0]) * (x[3] - x[1]), reverse=True)

    @staticmethod
    def order_bboxes_by_dist_to_orig(bboxes):
        return sorted(bboxes, key=lambda x: x[0] ** 2 + x[1] ** 2)

    @staticmethod
    def order_bboxes_by_score(bboxes):
        return sorted(bboxes, key=lambda x: x[5], reverse=True)

    @staticmethod
    def shuffle_bboxes(bboxes):
        return sorted(bboxes, key=lambda x: random.random())
def convert_detection_instance(self, instances): | |
"""Convert instances dict to list of lists where each list takes the form: | |
[xmin, ymin, xmax, ymax, class_name, score] | |
""" | |
instances = [inst['boxes'] + [inst['class_name'], inst['score']] for inst in instances if inst['score'] >= self.det_threshold] | |
return instances | |
def bboxes_hflip(self, bboxes: List[Tuple], image_size: Tuple, flip: bool): | |
image_height, image_width = image_size | |
if flip: | |
bboxes = [tuple(A.bbox_hflip(bbox[:4], rows=image_height, cols=image_width)) + tuple(bbox[4:]) | |
for bbox in bboxes] | |
return bboxes | |
def bboxes_crop_and_resize(self, bboxes: List[Tuple], crop_coords: Tuple, orig_size: Tuple): | |
"""Crop and resize bounding boxes | |
Args: | |
bboxes: Bounding boxes to crop and resize | |
crop_coords: Coordinates of the crop (top, left, h, w) | |
orig_size: Size of the original image | |
Returns: | |
Cropped and resized bounding boxes | |
""" | |
orig_height, orig_width = orig_size | |
top, left, h, w = crop_coords | |
xmin, ymin, xmax, ymax = left, top, left + w, top + h | |
bboxes = [tuple(A.bbox_crop(bbox[:4], x_min=xmin, y_min=ymin, x_max=xmax, y_max=ymax, rows=orig_height, | |
cols=orig_width)) + tuple(bbox[4:]) | |
for bbox in bboxes] | |
bboxes = A.core.bbox_utils.filter_bboxes(bboxes, rows=h, cols=w, min_visibility=self.min_visibility) | |
# No need to resize, bounding boxes in albumentations format are scale invariant | |
return bboxes | |
def order_and_filter_bboxes(self, bboxes): | |
if self.det_max_instances is not None and len(bboxes) > self.det_max_instances: | |
bboxes = self.order_bboxes_by_score(bboxes)[:self.det_max_instances] | |
return self.bbox_order(bboxes) | |
def convert_bboxes_to_string(self, bboxes: List[Tuple]): | |
"""Convert bounding boxes to a string. | |
xmin, ymin, xmax, ymax are mapped to v0, v1, v2, v3 special tokens. | |
Args: | |
bboxes: Bounding boxes | |
Returns: | |
String representation of the bounding boxes | |
""" | |
# Remove score, quantize coordinates | |
bins = self.coord_bins | |
bboxes = [ | |
[ | |
f"v0={round(xmin * (bins - 1))}", | |
f"v1={round(ymin * (bins - 1))}", | |
f"v2={round(xmax * (bins - 1))}", | |
f"v3={round(ymax * (bins - 1))}", | |
cls, | |
] | |
for (xmin, ymin, xmax, ymax, cls, score) in bboxes | |
] | |
# Convert each bounding box to a string | |
bboxes = [' '.join(b) for b in bboxes] | |
# Convert the list to a str | |
return ' '.join(bboxes) | |
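    # Illustrative example (hypothetical box, albumentations-normalized coordinates): with
    # coord_bins=1000, a box (0.1, 0.2, 0.5, 0.6, 'dog', 0.9) is converted to the string
    #   "v0=100 v1=200 v2=500 v3=599 dog"
    # and multiple boxes are simply concatenated with spaces.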
def load(self, path): | |
with open(path, 'r') as f: | |
sample = json.load(f) | |
return sample | |
def preprocess(self, sample): | |
instances = sample['instances'] | |
return self.convert_detection_instance(instances) | |
def image_augment(self, bboxes: List[Tuple], crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, | |
rand_aug_idx=None, resample_mode: str = None): | |
bboxes = self.bboxes_crop_and_resize(bboxes, crop_coords, orig_size) | |
bboxes = self.bboxes_hflip(bboxes, target_size, flip) | |
bboxes = self.order_and_filter_bboxes(bboxes) | |
return bboxes | |
def postprocess(self, bboxes): | |
if self.return_raw: | |
return bboxes | |
bboxes = self.convert_bboxes_to_string(bboxes) | |
return bboxes | |
class CaptionTransform(AbstractTransform): | |
def __init__(self, aligned_captions=True, no_aug=False): | |
self.aligned_captions = aligned_captions | |
self.no_aug = no_aug | |
def load(self, path): | |
        # Caption can either be stored as .txt, .json, or .json.gz (in which case it's a list of dicts)
if path.endswith('.txt'): | |
sample = Path(path).read_text() | |
elif path.endswith('.json'): | |
with open(path, 'r') as f: | |
sample = json.load(f) | |
elif path.endswith('.json.gz'): | |
with gzip.open(path, 'rb') as f: | |
sample = json.load(f) | |
return sample | |
def preprocess(self, sample): | |
return sample | |
def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, | |
rand_aug_idx: Optional[int], resample_mode: str = None): | |
if isinstance(val, list) or isinstance(val, tuple): | |
if self.aligned_captions: | |
val = val[0] if rand_aug_idx is None else val[rand_aug_idx] | |
else: | |
val = random.choice(val) if not self.no_aug else val[0] | |
if isinstance(val, dict): | |
# If each caption is saved as a dict, extract the string | |
val = val["caption"] | |
assert isinstance(val, str) | |
return val | |
def postprocess(self, sample): | |
return sample | |
class CaptionEmbTransform(AbstractTransform): | |
def __init__(self, aligned_captions=True, no_aug=False): | |
self.aligned_captions = aligned_captions | |
self.no_aug = no_aug | |
def load(self, path): | |
if path.endswith('.npz'): | |
sample = np.load(path) | |
sample = {'emb': sample['emb'], 'mask_valid': sample['mask_valid']} | |
else: | |
raise ValueError(f"Invalid file format for caption embedding: {path}") | |
return sample | |
def preprocess(self, sample): | |
return sample | |
def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, | |
rand_aug_idx: Optional[int], resample_mode: str = None): | |
emb = val['emb'] | |
mask_valid = val['mask_valid'].astype(bool) | |
num_sequences = emb.shape[0] | |
if num_sequences > 1: | |
if self.aligned_captions: | |
if rand_aug_idx is None: | |
emb, mask_valid = emb[0], mask_valid[0] | |
else: | |
emb, mask_valid = emb[rand_aug_idx], mask_valid[rand_aug_idx] | |
else: | |
if self.no_aug: | |
emb, mask_valid = emb[0], mask_valid[0] | |
else: | |
rand_idx = random.randint(0, num_sequences - 1) | |
emb, mask_valid = emb[rand_idx], mask_valid[rand_idx] | |
else: | |
emb, mask_valid = emb[0], mask_valid[0] | |
emb = emb[mask_valid] # Keep only valid embeddings | |
return emb | |
def postprocess(self, sample): | |
return torch.tensor(sample) | |
class MetadataTransform(AbstractTransform): | |
def __init__(self, | |
special_vmin: int = 0, | |
special_vmax: int = 999, | |
shuffle: bool = True, | |
random_trunc: bool = False, | |
return_chunks: bool = True, | |
return_raw: bool = False, | |
image_dim_bin_size: int = 32,): | |
"""Metadata transform that takes in a metadata dictionary and converts | |
it into a string, or list of strings (for chunked span masking). | |
Uses special tokens v1 to denote metadata types, and v0 for their values. | |
Args: | |
special_vmin: Minimum value for special tokens | |
special_vmax: Maximum value for special tokens | |
shuffle: Whether to shuffle the metadata order | |
random_trunc: Whether to randomly truncate the returned metadata | |
return_chunks: Whether to return a list of strings (for chunked span masking), | |
or a single string with all metadata concatenated | |
return_raw: Whether to return the raw metadata dictionary | |
""" | |
self.special_vmin = special_vmin | |
self.special_vmax = special_vmax | |
self.shuffle = shuffle | |
self.random_trunc = random_trunc | |
self.return_chunks = return_chunks | |
self.return_raw = return_raw | |
self.image_dim_bin_size = image_dim_bin_size | |
# Explicit map to make sure that additional entries do not change existing IDs | |
# TODO: Make this work with other text tokenizers | |
self.metadata_id_map = { | |
'original_width': 'v1=0', | |
'original_height': 'v1=1', | |
'caption_n_chars': 'v1=2', | |
'caption_n_words': 'v1=3', | |
'caption_n_sentences': 'v1=4', | |
'n_humans': 'v1=5', | |
'n_sam_instances': 'v1=6', | |
'n_coco_instances': 'v1=7', | |
'coco_instance_diversity': 'v1=8', | |
'colorfulness': 'v1=9', | |
'brightness': 'v1=10', | |
'contrast': 'v1=11', | |
'saturation': 'v1=12', | |
'entropy': 'v1=13', | |
'walkability': 'v1=14', | |
'objectness': 'v1=15', | |
'semantic_diversity': 'v1=16', | |
'geometric_complexity': 'v1=17', | |
'occlusion_score': 'v1=18', | |
'watermark_score': 'v1=19', | |
'aesthetic_score': 'v1=20', | |
} | |
self.id_metadata_map = {v: k for k, v in self.metadata_id_map.items()} | |
        # Image-dimension modalities are binned with a fixed bin size (image_dim_bin_size, default 32 pixels)
self.image_dim_modalities = ['original_height', 'original_width'] | |
# Integer modalities that don't undergo any scaling (except for truncation) | |
self.metadata_int_modalities = [ | |
'caption_n_chars', 'caption_n_words', 'caption_n_sentences', | |
'n_humans', 'n_sam_instances', 'n_coco_instances', | |
'coco_instance_diversity', 'semantic_diversity', | |
] | |
# Bin boundaries for manually defined metadata modalities. | |
# Lowest and highest bin boundaries are implicitly set to -inf and +inf | |
self.metadata_manual_bins = { | |
'watermark_score': [0.5], | |
'aesthetic_score': [4.5, 5.5], | |
} | |
# All other float or integer modalities that are binned into a defined number of bins | |
# Dictionary entries are (vmin, vmax, num_bins) | |
self.metadata_min_max_bins = { | |
'colorfulness': (0, 150, 50), | |
'brightness': (0, 255, 50), | |
'contrast': (0, 127, 50), | |
'saturation': (0, 255, 50), | |
'entropy': (0, 10, 50), | |
'walkability': (0, 1, 50), | |
'objectness': (0, 1, 50), | |
'geometric_complexity': (0, 0.75, 50), | |
'occlusion_score': (0, 0.25, 50), | |
} | |
def image_dim_to_string(self, metadata, key, bin_size=32): | |
value = metadata[key] // bin_size | |
value = max(self.special_vmin, min(value, self.special_vmax)) | |
return f"{self.metadata_id_map[key]} v0={value}" | |
def int_metadata_to_string(self, metadata, key): | |
value = max(self.special_vmin, min(metadata[key], self.special_vmax)) | |
return f"{self.metadata_id_map[key]} v0={value}" | |
def float_metadata_to_string(self, metadata, key, vmin, vmax, bins): | |
value = max(vmin, min(metadata[key], vmax)) | |
value = (value - vmin) / (vmax - vmin) | |
value = int(value * (bins-1)) | |
return f"{self.metadata_id_map[key]} v0={value}" | |
def manual_bin_metadata_to_string(self, metadata, key): | |
value = metadata[key] | |
bin_idx = 0 | |
for bin_value in self.metadata_manual_bins[key]: | |
if value < bin_value: | |
break | |
bin_idx += 1 | |
return f"{self.metadata_id_map[key]} v0={bin_idx}" | |
def metadata_to_string(self, metadata, keys: List[str] = None): | |
keys = list(metadata.keys()) if keys is None else keys | |
if self.shuffle: | |
# Randomly shuffle | |
random.shuffle(keys) | |
if self.random_trunc: | |
# Randomly truncate | |
keys = keys[:random.randint(1,len(keys))] | |
metadata_strings = [] | |
for key in keys: | |
if key in self.image_dim_modalities: | |
# Image dimension modalities | |
metadata_str = self.image_dim_to_string(metadata, key, bin_size=self.image_dim_bin_size) | |
elif key in self.metadata_int_modalities: | |
# Integer modalities that don't undergo any scaling | |
metadata_str = self.int_metadata_to_string(metadata, key) | |
elif key in self.metadata_manual_bins: | |
# Metadata modalities for which bin boundaries are manually defined | |
metadata_str = self.manual_bin_metadata_to_string(metadata, key) | |
else: | |
# All other modalities | |
vmin, vmax, bins = self.metadata_min_max_bins[key] | |
metadata_str = self.float_metadata_to_string(metadata, key, vmin, vmax, bins) | |
metadata_strings.append(metadata_str) | |
if self.return_chunks: | |
return metadata_strings | |
else: | |
return ' '.join(metadata_strings) | |
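    # Illustrative example (hypothetical metadata): {'n_humans': 2, 'brightness': 127.5} could yield
    # the chunks ['v1=5 v0=2', 'v1=10 v0=24'] (n_humans is an integer modality used as-is, while
    # brightness 127.5 on the [0, 255] range falls into bin 24 of 50). With return_chunks=False the
    # chunks are joined into one string, and with shuffle=True their order is randomized.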
def load(self, path): | |
with open(path, 'r') as f: | |
sample = json.load(f) | |
return sample | |
def preprocess(self, sample): | |
return sample | |
def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, | |
rand_aug_idx=None, resample_mode: str = None): | |
return val | |
def postprocess(self, metadata): | |
if self.return_raw: | |
return metadata | |
metadata_str = self.metadata_to_string(metadata) | |
return metadata_str | |
class HumanPoseTransform(AbstractTransform): | |
def __init__(self, coord_bins=1000, only_pose=False, return_raw=False): | |
self.coord_bins = coord_bins | |
self.return_raw = return_raw | |
self.only_pose = only_pose | |
def convert_humanpose_instance(self, instances, only_pose=False): | |
"""Convert instances dict to list of lists where each list takes the form: | |
        [human, xmin ymin xmax ymax, global, <flattened global orientation>, pose, <tokenized pose values>, shape, <shape betas>, camera, <camera parameters>]
Like for bounding boxes, xmin, ymin, xmax, and ymax map to v0, v1, v2, and v3 respectively. | |
""" | |
if only_pose: # used for tokenizer training for pose | |
if len(instances) == 0: | |
return torch.zeros(207) | |
else: | |
return torch.from_numpy(np.array(instances['pred_smpl_params']['body_pose'][0]).flatten()).float() | |
if len(instances) == 0: #empty, i.e. there are no humans | |
return 'none' | |
for k in instances: | |
if k!='pred_smpl_params': | |
instances[k] = torch.from_numpy(np.array(instances[k])) | |
smpl_params = (instances['pred_smpl_params']) | |
for k in smpl_params: | |
smpl_params[k] = torch.from_numpy(np.array(smpl_params[k])) | |
total_num_instances = len(instances['bbox_xyxy']) | |
instances_converted = [] | |
for ii in range(total_num_instances): | |
            instances_converted.append(
                ['human'] + np.array(instances['bbox_xyxy'][ii]).flatten().tolist()
                + ['global'] + np.array(instances['pred_smpl_params']['global_orient'][ii]).flatten().tolist()
                + ['pose'] + instances['pose_tokenized'][ii].flatten().tolist()
                + ['shape'] + instances['pred_smpl_params']['betas'][ii].flatten().tolist()
                + ['camera'] + instances['pred_cam'][ii].flatten().tolist()
            )
return instances_converted | |
def humanposes_crop_and_resize(self, humanposes: List[Tuple], crop_coords: Tuple, orig_size: Tuple,): | |
"""Crop and resize human poses (and their bounding boxes) | |
""" | |
orig_height, orig_width = orig_size | |
top, left, h, w = crop_coords | |
humanposes_converted_resized = [] | |
for instance in humanposes: | |
bbox_curr = instance[1:5] | |
bbox_curr = np.array(bbox_curr) | |
bbox_curr[0::2] = bbox_curr[0::2] / orig_width | |
bbox_curr[1::2] = bbox_curr[1::2] / orig_height | |
xmin, ymin, xmax, ymax = left, top, left + w, top + h | |
bbox_curr = A.bbox_crop(bbox_curr, x_min=xmin, y_min=ymin, x_max=xmax, y_max=ymax, rows=orig_height, | |
cols=orig_width) | |
bbox_curr = np.array(bbox_curr) | |
if np.all(bbox_curr[1::2]<0) or np.all(bbox_curr[0::2]<0): #bbox is out of range, remove it | |
continue | |
if np.all(bbox_curr[1::2]>1.0) or np.all(bbox_curr[0::2]>1.0): #bbox is out of range, remove it | |
continue | |
bbox_curr = np.clip(bbox_curr, a_min=0, a_max=1.) | |
instance[1:5] = bbox_curr | |
humanposes_converted_resized.append(instance) | |
        # now return all instances, or 'none' if no valid instance remains
        if len(humanposes_converted_resized) == 0:
            return 'none'
        return humanposes_converted_resized
def convert_humanposes_to_string(self, all_humanposes: List[Tuple]): | |
"""Convert humanposes to a string | |
range of global orientation: [-1, 1] | |
range of object pose: [-1, 1] | |
range of shape (betas): [-3, 3] | |
range of camera: [-1, 19] | |
""" | |
bins = self.coord_bins | |
instance_final_all = '' | |
for humanposes in all_humanposes: | |
human = humanposes[0] | |
bboxes = humanposes[1:5] | |
glob = humanposes[5] | |
global_orient = np.array(humanposes[6:15]) | |
pose = humanposes[15] | |
pose_params = np.array(humanposes[16:24]) | |
shape = humanposes[24] | |
shape_params = np.array(humanposes[25:35]) | |
camera = humanposes[35] | |
camera_params = np.clip(np.array(humanposes[36:]), a_min=-1., a_max=19.) | |
bboxes_new = [ | |
f"v0={round(bboxes[0] * (bins - 1))}", | |
f"v1={round(bboxes[1] * (bins - 1))}", | |
f"v2={round(bboxes[2] * (bins - 1))}", | |
f"v3={round(bboxes[3] * (bins - 1))}"] | |
global_orient = 499.5*global_orient | |
global_orient_new = [] | |
for ii in range(len(global_orient)): | |
global_orient_curr = f"v0={round(global_orient[ii]+499.5)}" | |
global_orient_new.append(global_orient_curr) | |
pose_params_new = [] | |
for ii in range(len(pose_params)): | |
if pose_params[ii]<512: | |
pose_params_curr = f"v0={round(pose_params[ii])}" | |
else: | |
pose_params_curr = f"v1={round(pose_params[ii] - 512)}" | |
pose_params_new.append(pose_params_curr) | |
shape_params = 166.5*shape_params | |
shape_params_new = [] | |
for ii in range(len(shape_params)): | |
shape_params_curr = f"v0={round(shape_params[ii]+499.5)}" | |
shape_params_new.append(shape_params_curr) | |
camera_params = 49.95*camera_params | |
camera_params_new = [] | |
for ii in range(len(camera_params)): | |
camera_params_curr = f"v0={round(camera_params[ii]+49.95)}" | |
camera_params_new.append(camera_params_curr) | |
#randomly shuffle everything except bbox part of the sequence | |
all_strings = [[pose]+pose_params_new, [glob] + global_orient_new, [camera] + camera_params_new, [shape] + shape_params_new ] | |
rand_perm = torch.randperm(4) | |
instance_final = [human] + bboxes_new + all_strings[rand_perm[0]] + all_strings[rand_perm[1]] + all_strings[rand_perm[2]] + all_strings[rand_perm[3]] | |
instance_final = ', '.join(instance_final) | |
instance_final = instance_final.replace(",", "") | |
instance_final_all = instance_final_all + instance_final + ' ' | |
return instance_final_all | |
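    # Quantization summary for the string above (descriptive of the code): global orientation values
    # in [-1, 1] map to v0=round(499.5 * x + 499.5) in [0, 999]; shape betas in [-3, 3] use
    # 166.5 * x + 499.5; camera parameters clipped to [-1, 19] use 49.95 * x + 49.95; and tokenized
    # pose ids below 512 become v0=<id> while ids >= 512 become v1=<id - 512>.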
def load(self, path): | |
with open(path, 'r') as f: | |
sample = json.load(f) | |
return sample | |
def preprocess(self, sample): | |
instances = sample | |
instances = self.convert_humanpose_instance(instances, only_pose=self.only_pose) | |
return instances | |
def image_augment(self, humanposes: List[Tuple], crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, | |
rand_aug_idx=None, resample_mode: str = None): | |
if humanposes=='none' or self.only_pose: | |
return humanposes | |
humanposes = self.humanposes_crop_and_resize(humanposes, crop_coords, orig_size) | |
return humanposes | |
def postprocess(self, humanposes): | |
if humanposes=='none' or self.only_pose: | |
return humanposes if not self.return_raw else [] | |
if self.return_raw: | |
return humanposes | |
humanposes = self.convert_humanposes_to_string(humanposes) | |
return humanposes | |
class ColorPaletteTransform(AbstractTransform): | |
def __init__(self, coord_bins=1000, return_raw=False): | |
self.coord_bins = coord_bins | |
self.return_raw = return_raw | |
def convert_palette_instance(self, instances): | |
"""Convert colors to v0= v0= ... | |
""" | |
length = random.randint(1,7) | |
instances_converted = np.array(instances[0][str(length)]).flatten().tolist() | |
return instances_converted | |
def palette_hflip(self, palettes: List[Tuple], image_size: Tuple, flip: bool): | |
return palettes | |
def convert_palettes_to_string(self, all_palettes: List[Tuple]): | |
"""Convert palettes to a string | |
""" | |
colors = [] | |
len_palettes = len(all_palettes) | |
colors.append(f"v1={round(len_palettes/3)}") # start with the length of the color palette to avoid confusion | |
for ii in range(len(all_palettes)): | |
color_new = f"v0={round(all_palettes[ii])}" | |
colors.append(color_new) | |
instance_final_all = colors | |
instance_final_all = ', '.join(instance_final_all) | |
instance_final_all = instance_final_all.replace(",", "") | |
return instance_final_all | |
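    # Illustrative example (hypothetical palette): the flat list [255, 0, 0, 0, 255, 0] (two RGB
    # colors) becomes "v1=2 v0=255 v0=0 v0=0 v0=0 v0=255 v0=0", i.e. the number of colors followed
    # by their channel values.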
def load(self, path): | |
with open(path, 'r') as f: | |
sample = json.load(f) | |
return sample | |
def preprocess(self, sample): | |
if self.return_raw: | |
return sample | |
instances = sample | |
instances = self.convert_palette_instance(instances) | |
return instances | |
def image_augment(self, palettes: List[Tuple], crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, | |
rand_aug_idx=None, resample_mode: str = None): | |
return palettes | |
def postprocess(self, palettes): | |
if self.return_raw: | |
return palettes | |
palettes = self.convert_palettes_to_string(palettes) | |
return palettes | |
class SAMInstanceTokTransform(AbstractTransform): | |
def __init__(self, image_size=224, points_per_side=7, point_order='random'): | |
self.H, self.W = to_2tuple(image_size) | |
self.points_per_h, self.points_per_w = to_2tuple(points_per_side) | |
assert point_order in ['random', 'grid'] | |
self.point_order = point_order | |
def get_query_points(self): | |
if self.point_order == 'grid': | |
# Create and cache grid query points | |
if not hasattr(self, 'grid_query_points'): | |
y, x = np.meshgrid(np.linspace(0, self.H, self.points_per_h + 2)[1:-1], np.linspace(0, self.W, self.points_per_w + 2)[1:-1]) | |
grid = np.stack((x, y), axis=2).astype(np.int32) | |
self.grid_query_points = grid.reshape(-1, 2) | |
return self.grid_query_points | |
elif self.point_order == 'random': | |
# Randomly sample query points | |
y = np.random.randint(0, self.H, self.points_per_h) | |
x = np.random.randint(0, self.W, self.points_per_w) | |
return np.concatenate((x[:,None], y[:,None]), axis=1) | |
else: | |
raise ValueError(f"Query point order mode {self.point_order} is not supported.") | |
def get_target_tokens(self, sample, query_points): | |
instances_coords = [coords[0] for coords in sample['points']] | |
tokens = sample['token_ids'] | |
bboxes = sample['bbox'] | |
instance_tokens_per_qpoint = dict() | |
for point in query_points: | |
point = (int(point[0].item()), int(point[1].item())) | |
instance_tokens_per_qpoint[point] = [] | |
for i, (coords, tok, bbox) in enumerate(zip(instances_coords, tokens, bboxes)): | |
# Calculate the distance from the query point to the instance | |
distance = cv2.pointPolygonTest(coords, point, measureDist=True) | |
# If the query point is inside the instance, add its corresponding token | |
if distance >= 0: | |
instance_tokens_per_qpoint[point].append((tok, bbox)) | |
return instance_tokens_per_qpoint | |
def convert_target_tokens_to_string(self, target_tokens): | |
result_text = [] | |
query_points = list(target_tokens.keys()) | |
# Randomly shuffle query points order (mainly for grid order) | |
random.shuffle(query_points) | |
for point in query_points: | |
# Add query point coordinates to the string | |
result_text.append('point') | |
result_text.append(f'v0={point[1]}') | |
result_text.append(f'v1={point[0]}') | |
# Randomly shuffle the order of instance tokens per query point | |
random.shuffle(target_tokens[point]) | |
if len(target_tokens[point]) == 0: | |
# If no instances tokens are found, add 'none' to the string | |
result_text.append('none') | |
else: | |
for tok, bbox in target_tokens[point]: | |
result_text.append(f'polygon') | |
# Add bounding box coordinates to the string | |
ymin, xmin, ymax, xmax = bbox.astype(np.int32) | |
result_text.extend([ | |
f'v0={xmin}', | |
f'v1={ymin}', | |
f'v2={xmax}', | |
f'v3={ymax}', | |
]) | |
# Add instance tokens ids to the string | |
for idx in tok.tolist(): | |
if idx < 512: | |
result_text.append(f'v0={idx}') | |
else: | |
result_text.append(f'v1={idx - 512}') | |
return " ".join(result_text) | |
def load(self, path): | |
sample = np.load(path, allow_pickle=True) | |
return sample | |
def preprocess(self, sample): | |
for s in sample: | |
s['token_ids'] = s['token_ids'].astype(np.int32) | |
return sample | |
def image_augment(self, v, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, | |
rand_aug_idx: Optional[int], resample_mode: str = None): | |
if rand_aug_idx is None: | |
raise ValueError("Crop settings / augmentation index are missing but a pre-tokenized modality is being used") | |
v = v[rand_aug_idx] | |
return v | |
def postprocess(self, sample): | |
query_points = self.get_query_points() | |
target_tokens = self.get_target_tokens(sample, query_points) | |
final_string = self.convert_target_tokens_to_string(target_tokens) | |
return final_string | |
class CropSettingsTransform(AbstractTransform): | |
def load(self, path): | |
sample = np.load(path) | |
return sample | |
def preprocess(self, sample): | |
raise NotImplementedError("CropSettingsTransform does not support preprocessing") | |
def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, | |
rand_aug_idx: Optional[int], resample_mode: str = None): | |
raise NotImplementedError("CropSettingsTransform is not meant to be used for image augmentation") | |
def postprocess(self, sample): | |
raise NotImplementedError("CropSettingsTransform does not support postprocessing") | |
class IdentityTransform(AbstractTransform): | |
def load(self, path): | |
raise NotImplementedError("IdentityTransform does not support loading") | |
def preprocess(self, sample): | |
return sample | |
def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, | |
rand_aug_idx: Optional[int], resample_mode: str = None): | |
return val | |
def postprocess(self, sample): | |
return sample | |
class JSONTransform(AbstractTransform): | |
def load(self, path): | |
if path.endswith('.json'): | |
with open(path, 'r') as f: | |
sample = json.load(f) | |
elif path.endswith('.json.gz'): | |
with gzip.open(path, 'rb') as f: | |
sample = json.load(f) | |
return sample | |
def preprocess(self, sample): | |
return sample | |
def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple, | |
rand_aug_idx: Optional[int], resample_mode: str = None): | |
return val | |
def postprocess(self, sample): | |
return sample |