from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from .task_configs import task_parameters
# Rescaling transforms. transforms.Normalize(mean, std) computes (x - mean) / std per channel.
MAKE_RESCALE_0_1_NEG1_POS1 = lambda n_chan: transforms.Normalize([0.5] * n_chan, [0.5] * n_chan)  # [0, 1] -> [-1, 1]
RESCALE_0_1_NEG1_POS1 = transforms.Normalize([0.5], [0.5])  # single-channel [0, 1] -> [-1, 1]; use the MAKE_ variant for other channel counts
MAKE_RESCALE_0_MAX_NEG1_POS1 = lambda maxx: transforms.Normalize([maxx / 2.0], [maxx / 2.0])  # [0, max] -> [-1, 1]
RESCALE_0_255_NEG1_POS1 = transforms.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5])  # [0, 255] -> [-1, 1]
MAKE_RESCALE_0_MAX_0_POS1 = lambda maxx: transforms.Normalize([0.0], [maxx * 1.0])  # [0, max] -> [0, 1]
STD_IMAGENET = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet mean/std
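# Worked example (illustrative values, not taken from task_parameters):
# MAKE_RESCALE_0_MAX_NEG1_POS1(128.0) computes (x - 64) / 64, mapping
# x = 0 -> -1.0 and x = 128 -> +1.0, while MAKE_RESCALE_0_MAX_0_POS1(128.0)
# computes x / 128, mapping [0, 128] -> [0, 1].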
# For semantic segmentation
transform_dense_labels = lambda img: torch.Tensor(np.array(img)).long() # avoids normalizing
# Converts a PIL image to a 3-channel float tensor, rescaling [0, 255] -> [0, 1]
transform_8bit = transforms.Compose([
    transforms.ToTensor(),
])
# Converts to an n-channel float tensor, rescaling [0, 255] -> [0, 1].
# If crop_channels is True, keeps only the first n channels.
def transform_8bit_n_channel(n_channel=1, crop_channels=True):
    if crop_channels:
        crop_channels_fn = lambda x: x[:n_channel] if x.shape[0] > n_channel else x
    else:
        crop_channels_fn = lambda x: x
    return transforms.Compose([
        transforms.ToTensor(),
        crop_channels_fn,
    ])
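# Example usage (illustrative; 'curvature.png' is a hypothetical path to an
# 8-bit image with at least two channels):
#
#   from PIL import Image
#   to_curvature = transform_8bit_n_channel(n_channel=2)
#   x = to_curvature(Image.open('curvature.png'))  # (2, H, W) float tensor in [0, 1]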
# Converts a 16-bit image to a 1-channel float tensor, rescaling [0, 65535] -> [0, 1].
def transform_16bit_single_channel(im):
    # ToTensor() only rescales uint8 inputs, so divide by the 16-bit maximum explicitly.
    im = transforms.ToTensor()(np.array(im))
    im = im.float() / (2 ** 16 - 1.0)
    return im
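# Quick sanity check (illustrative): a synthetic array at the 16-bit maximum
# should map to exactly 1.0 (int32 is used here because torch.from_numpy does
# not accept uint16):
#
#   arr = np.full((8, 8), 2 ** 16 - 1, dtype=np.int32)
#   transform_16bit_single_channel(arr).max()  # -> tensor(1.)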
def make_valid_mask(mask_float, max_pool_size=4):
    '''
    Creates a mask indicating the valid parts of the image(s).
    Enlarges the masked (invalid) area using a max-pooling operation.

    Args:
        mask_float: A (b x c x h x w) or (c x h x w) mask as loaded from the Taskonomy loader.
        max_pool_size: Parameter choosing how much to enlarge the masked area.
    '''
    squeeze = False
    if len(mask_float.shape) == 3:
        mask_float = mask_float.unsqueeze(0)
        squeeze = True
    _, _, h, w = mask_float.shape
    # Invert so invalid pixels become 1, dilate them with max pooling,
    # then upsample back to the original resolution.
    mask_float = 1 - mask_float
    mask_float = F.max_pool2d(mask_float, kernel_size=max_pool_size)
    mask_float = F.interpolate(mask_float, (h, w), mode='nearest')
    mask_valid = mask_float == 0
    mask_valid = mask_valid[0] if squeeze else mask_valid
    return mask_valid
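# Example usage (illustrative): a single invalid pixel is dilated to a
# 4 x 4 invalid block before the validity test:
#
#   mask = torch.ones(1, 16, 16)
#   mask[0, 8, 8] = 0              # one invalid pixel
#   valid = make_valid_mask(mask)  # BoolTensor of shape (1, 16, 16)
#   valid.sum()                    # -> tensor(240): 256 minus the 4 x 4 invalid block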
def task_transform(file, task: str, image_size: Optional[int] = None):
    transform = None

    if task in ['rgb']:
        transform = transforms.Compose([
            transform_8bit,
            STD_IMAGENET,
        ])
    elif task in ['normal']:
        transform = transform_8bit
    elif task in ['mask_valid']:
        transform = transforms.Compose([
            transforms.ToTensor(),
            make_valid_mask,
        ])
    elif task in ['keypoints2d', 'keypoints3d', 'depth_euclidean', 'depth_zbuffer', 'edge_texture']:
        transform = transform_16bit_single_channel
    elif task in ['edge_occlusion']:
        transform = transforms.Compose([
            transform_16bit_single_channel,
            transforms.GaussianBlur(3, sigma=1),
        ])
    elif task in ['principal_curvature', 'curvature']:
        transform = transform_8bit_n_channel(2)
    elif task in ['reshading']:
        transform = transform_8bit_n_channel(1)
    elif task in ['segment_semantic', 'segment_instance', 'segment_panoptic', 'fragments', 'segment_unsup2d', 'segment_unsup25d']:
        # Stored as a 1-channel image (H, W) where each pixel value is a different class.
        transform = transform_dense_labels
    elif task in ['class_object', 'class_scene']:
        transform = torch.Tensor
        image_size = None

    if 'threshold_min' in task_parameters[task]:
        threshold = task_parameters[task]['threshold_min']
        transform = transforms.Compose([
            transform,
            lambda x: torch.threshold(x, threshold, 0.0),
        ])
    if 'clamp_to' in task_parameters[task]:
        minn, maxx = task_parameters[task]['clamp_to']
        if minn > 0:
            raise NotImplementedError(
                "Rescaling (min1, max1) -> (min2, max2) not implemented for min1, min2 != 0 (task {})".format(task))
        transform = transforms.Compose([
            transform,
            lambda x: torch.clamp(x, minn, maxx),
            MAKE_RESCALE_0_MAX_0_POS1(maxx),
        ])

    if image_size is not None:
        if task == 'fragments':
            # Fragments are (H, W, C) integer labels: resize with nearest-neighbor
            # to avoid interpolating between class ids.
            resize_frag = lambda frag: F.interpolate(
                frag.permute(2, 0, 1).unsqueeze(0).float(),
                image_size, mode='nearest').long()[0].permute(1, 2, 0)
            transform = transforms.Compose([
                transform,
                resize_frag,
            ])
        else:
            resize_method = transforms.InterpolationMode.BILINEAR if task in ['rgb'] else transforms.InterpolationMode.NEAREST
            transform = transforms.Compose([
                transforms.Resize(image_size, resize_method),
                transform,
            ])

    if transform is not None:
        file = transform(file)
    return file
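# Minimal usage sketch (illustrative; 'example_rgb.png' is a hypothetical file,
# and the task must have an entry in task_parameters):
#
#   from PIL import Image
#   img = Image.open('example_rgb.png').convert('RGB')
#   x = task_transform(img, task='rgb', image_size=256)
#   # x: (3, 256, 256) ImageNet-normalized float tensor (for a square input)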