Spaces:
Running
Running
import cv2 | |
import torch | |
import numpy as np | |
import pycocotools.mask as mask_utils | |
# Flip modes understood by the ``transpose`` methods in this module:
# 0 mirrors along the x axis (left <-> right), 1 along the y axis (top <-> bottom).
FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM = 0, 1
class MaskList(object):
    """
    Container for the per-instance masks of a single image.

    This class is unfinished and not meant for use yet.
    It is supposed to contain the binary masks for all instances as a list of
    2D arrays of shape (H, W).

    Two storage modes are supported:
      * ``"mask"``: ``masks`` is a list of 2D binary arrays, one per instance.
      * ``"rle"``: ``masks`` holds the run-length-encoded objects produced by
        ``pycocotools.mask.encode`` (see ``convert``).

    ``size`` is ``(image_width, image_height)``.
    """

    def __init__(self, masks, size, mode):
        assert(isinstance(masks, list))
        assert(mode in ['mask', 'rle'])
        self.masks = masks
        self.size = size  # (image_width, image_height)
        self.mode = mode

    def _stack(self):
        """Stack ``self.masks`` into an array of shape (N, H, W)."""
        stacked = np.array(self.masks)
        if stacked.ndim == 2:
            stacked = np.expand_dims(stacked, axis=0)
        return stacked

    def transpose(self, method):
        """
        Flip every mask horizontally or vertically.

        :param method: FLIP_LEFT_RIGHT or FLIP_TOP_BOTTOM
        :return: a new MaskList with the same size and mode
        :raises NotImplementedError: for any other ``method``
        """
        assert (self.mode == "mask"), "RLE masks cannot be transposed. Please convert them to binary first."
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )
        if not self.masks:
            return MaskList([], self.size, self.mode)
        # In the stacked (N, H, W) layout, axis 2 runs along x and axis 1 along y.
        axis = 2 if method == FLIP_LEFT_RIGHT else 1
        flipped = np.flip(self._stack(), axis=axis)
        return MaskList(list(flipped), self.size, self.mode)

    def resize(self, size, *args, **kwargs):
        """
        Resize the binary masks.

        Each mask is rendered as a 0/255 image, resized with cv2's default
        interpolation, and thresholded back to 0/1 at half intensity.

        :param size: tuple, (image_width, image_height)
        :param args: unused; accepted for interface compatibility
        :param kwargs: unused; accepted for interface compatibility
        :return: a new MaskList in "mask" mode holding integer 0/1 masks
        """
        assert(self.mode == "mask"), "RLE masks cannot be resized. Please convert them to binary first."
        if not self.masks:
            return MaskList([], size, "mask")
        # cv2 requires an integer (width, height) pair.
        dsize = (int(size[0]), int(size[1]))
        resized = []
        for mask in self._stack():
            # Build the 0/255 image without scaling in place. The old
            # ``*= 255`` raised on boolean masks.
            img = np.where(mask != 0, 255, 0).astype(np.uint8)
            # Resize one mask at a time. Stacking all instances as channels of
            # a single image fails when there are more instances than OpenCV's
            # per-image channel limit.
            out = cv2.resize(img, dsize)
            # Threshold at half intensity. ``// 255`` kept only pixels that are
            # exactly 255 after interpolation, which eroded every mask edge.
            resized.append((out >= 128).astype(int))
        return MaskList(resized, size, "mask")

    def pad(self, size):
        """
        Pad the binary masks with zeros to a new size.

        The new size must be at least the original size in both dimensions.
        Original content stays anchored at the top-left corner.

        :param size: New image size, (image_width, image_height)
        :return: a new MaskList in "mask" mode; the masks keep their input dtype
        """
        assert(self.mode == "mask"), "RLE masks cannot be padded. Please convert them to binary first."
        assert(size[0] >= self.size[0] and size[1] >= self.size[1]), "New size must be larger than original size in all dimensions"
        if not self.masks:
            return MaskList([], size, "mask")
        cat_mask = self._stack()
        # Keep the input dtype. np.zeros would otherwise silently promote
        # the masks to float64.
        padded_mask = np.zeros((cat_mask.shape[0], size[1], size[0]), dtype=cat_mask.dtype)
        padded_mask[:, :cat_mask.shape[1], :cat_mask.shape[2]] = cat_mask
        return MaskList(list(padded_mask), size, "mask")

    def convert(self, mode):
        """
        Convert masks between mode "mask" and mode "rle".

        :param mode: target mode, "mask" or "rle"
        :return: ``self`` if already in ``mode``, otherwise a new MaskList
        :raises ValueError: if ``mode`` is not "mask" or "rle"
        """
        if mode == self.mode:
            return self
        if mode not in ("mask", "rle"):
            raise ValueError("Unknown mask mode: {}".format(mode))
        if not self.masks:
            # Nothing to encode/decode; pycocotools cannot handle an empty stack.
            return MaskList([], self.size, mode)
        if mode == "rle":
            # pycocotools expects a Fortran-ordered (H, W, N) uint8 array.
            stacked = self._stack().transpose(1, 2, 0).astype(np.uint8)
            rle_mask_list = mask_utils.encode(np.asfortranarray(stacked))
            return MaskList(rle_mask_list, self.size, "rle")
        # mode == "mask": decode returns an (H, W, N) array; split it per instance.
        bimasks = mask_utils.decode(self.masks)
        return MaskList(list(bimasks.transpose(2, 0, 1)), self.size, "mask")

    def bbox(self, bbox_mode="xyxy"):
        """
        Generate bounding boxes from the binary masks.

        Not implemented yet: this currently returns None.

        :param bbox_mode: box format (unused for now)
        :return: None
        """
        # TODO: implement box extraction from the binary masks.
        return None

    def __len__(self):
        return len(self.masks)

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_masks={}, ".format(len(self))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={}, ".format(self.size[1])
        s += "mode={})".format(self.mode)
        return s
class Polygons(object):
    """
    This class holds a set of polygons that represents a single instance
    of an object mask. The object can be represented as a set of
    polygons.

    Each polygon is stored as a flat float32 tensor of interleaved
    coordinates: even indices are x values and odd indices are y values, as
    the slicing in ``transpose``/``crop``/``resize`` shows. ``size`` is
    ``(image_width, image_height)``.
    """

    def __init__(self, polygons, size, mode):
        # Accept a list/tuple of coordinate sequences, or another Polygons
        # instance. In the latter case its tensors are shared, not copied.
        if isinstance(polygons, (list, tuple)):
            polygons = [torch.as_tensor(p, dtype=torch.float32) for p in polygons]
        elif isinstance(polygons, Polygons):
            polygons = polygons.polygons
        self.polygons = polygons
        self.size = size
        self.mode = mode

    def transpose(self, method):
        """
        Mirror the polygons horizontally or vertically.

        A coordinate ``c`` along the flipped axis becomes ``dim - c - 1``.

        :param method: FLIP_LEFT_RIGHT or FLIP_TOP_BOTTOM
        :return: a new Polygons with the same size and mode
        :raises NotImplementedError: for any other ``method``
        """
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )
        width, height = self.size
        if method == FLIP_LEFT_RIGHT:
            dim, idx = width, 0
        else:
            dim, idx = height, 1
        TO_REMOVE = 1
        flipped_polygons = []
        for poly in self.polygons:
            p = poly.clone()
            p[idx::2] = dim - poly[idx::2] - TO_REMOVE
            flipped_polygons.append(p)
        return Polygons(flipped_polygons, size=self.size, mode=self.mode)

    def crop(self, box):
        """
        Shift the polygons into the coordinate frame of ``box``.

        Coordinates are translated but not clamped, so points may fall outside
        the new frame.

        :param box: (x0, y0, x1, y1)
        :return: a new Polygons of size (max(x1 - x0, 1), max(y1 - y0, 1))
        """
        w, h = box[2] - box[0], box[3] - box[1]
        # TODO check if necessary
        w = max(w, 1)
        h = max(h, 1)
        cropped_polygons = []
        for poly in self.polygons:
            p = poly.clone()
            p[0::2] = p[0::2] - box[0]  # .clamp(min=0, max=w)
            p[1::2] = p[1::2] - box[1]  # .clamp(min=0, max=h)
            cropped_polygons.append(p)
        return Polygons(cropped_polygons, size=(w, h), mode=self.mode)

    def resize(self, size, *args, **kwargs):
        """
        Scale the polygons to a new image size.

        :param size: (new_width, new_height)
        :return: a new Polygons with the given size
        """
        ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
        if ratios[0] == ratios[1]:
            # Uniform scaling: no need to treat x and y separately.
            ratio = ratios[0]
            scaled_polys = [p * ratio for p in self.polygons]
            return Polygons(scaled_polys, size, mode=self.mode)
        ratio_w, ratio_h = ratios
        scaled_polygons = []
        for poly in self.polygons:
            p = poly.clone()
            p[0::2] *= ratio_w
            p[1::2] *= ratio_h
            scaled_polygons.append(p)
        return Polygons(scaled_polygons, size=size, mode=self.mode)

    def convert(self, mode):
        """
        Rasterize the polygons into a single merged binary mask.

        :param mode: target representation; only "mask" is supported
        :return: a torch tensor built from pycocotools' decode of the merged
            RLE (presumably (image_height, image_width) uint8; verify)
        :raises NotImplementedError: for any mode other than "mask"
        """
        if mode != "mask":
            # Previously this fell through and silently returned None.
            raise NotImplementedError("Only conversion to 'mask' is implemented")
        width, height = self.size
        rles = mask_utils.frPyObjects(
            [p.detach().numpy() for p in self.polygons], height, width
        )
        rle = mask_utils.merge(rles)
        mask = mask_utils.decode(rle)
        mask = torch.from_numpy(mask)
        # TODO add squeeze?
        return mask

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_polygons={}, ".format(len(self.polygons))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={}, ".format(self.size[1])
        s += "mode={})".format(self.mode)
        return s
class SegmentationMask(object):
    """
    This class stores the segmentations for all objects in the image.

    Each instance is held as a ``Polygons`` object. ``size`` is
    ``(image_width, image_height)``.
    """

    def __init__(self, polygons, size, mode=None):
        """
        Arguments:
            polygons: a list of list of lists of numbers. The first
                level of the list correspond to individual instances,
                the second level to all the polygons that compose the
                object, and the third level to the polygon coordinates.
                Existing ``Polygons`` instances are also accepted as
                first-level elements.
        """
        assert isinstance(polygons, list)
        self.polygons = [Polygons(p, size, mode) for p in polygons]
        self.size = size
        self.mode = mode

    def transpose(self, method):
        """Flip every instance; ``method`` is FLIP_LEFT_RIGHT or FLIP_TOP_BOTTOM."""
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )
        flipped = [polygon.transpose(method) for polygon in self.polygons]
        return SegmentationMask(flipped, size=self.size, mode=self.mode)

    def crop(self, box):
        """Shift every instance into the frame of ``box`` = (x0, y0, x1, y1)."""
        w, h = box[2] - box[0], box[3] - box[1]
        cropped = [polygon.crop(box) for polygon in self.polygons]
        return SegmentationMask(cropped, size=(w, h), mode=self.mode)

    def resize(self, size, *args, **kwargs):
        """Scale every instance to ``size`` = (new_width, new_height)."""
        scaled = [polygon.resize(size, *args, **kwargs) for polygon in self.polygons]
        return SegmentationMask(scaled, size=size, mode=self.mode)

    def to(self, *args, **kwargs):
        # Polygons are kept where they are; device/dtype moves are a no-op.
        return self

    def __getitem__(self, item):
        """
        Select instances by int, slice, index sequence/tensor, or bool tensor.

        :return: a new SegmentationMask with the selected instances
        """
        if isinstance(item, (int, np.integer)):
            selected_polygons = [self.polygons[item]]
        elif isinstance(item, slice):
            # Slicing a list already yields a list. Wrapping it again nested
            # the instances and broke Polygons construction.
            selected_polygons = self.polygons[item]
        else:
            # advanced indexing on a single dimension
            if isinstance(item, torch.Tensor) and item.dtype == torch.bool:
                item = item.nonzero()
                item = item.squeeze(1) if item.numel() > 0 else item
                item = item.tolist()
            selected_polygons = [self.polygons[i] for i in item]
        return SegmentationMask(selected_polygons, size=self.size, mode=self.mode)

    def __len__(self):
        return len(self.polygons)

    def __iter__(self):
        return iter(self.polygons)

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_instances={}, ".format(len(self.polygons))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={})".format(self.size[1])
        return s