import torch

# transpose
FLIP_LEFT_RIGHT = 0
FLIP_TOP_BOTTOM = 1

class BoxList(object):
    """
    This class represents a set of bounding boxes.
    The bounding boxes are represented as an Nx4 Tensor.
    In order to uniquely determine the bounding boxes with respect
    to an image, we also store the corresponding image dimensions.
    Boxes can carry extra information that is specific to each
    bounding box, such as labels.
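
    Minimal usage sketch (illustrative only; the tensor values are made up):

        boxes = BoxList(torch.tensor([[0., 0., 9., 9.]]), (10, 10), mode="xyxy")
        boxes.add_field("labels", torch.tensor([1]))
        boxes = boxes.convert("xywh")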
""" | |

    def __init__(self, bbox, image_size, mode="xyxy"):
        device = bbox.device if isinstance(bbox, torch.Tensor) else torch.device("cpu")
        # only call as_tensor if it isn't a no-op, because it hurts JIT tracing
        if (not isinstance(bbox, torch.Tensor)
                or bbox.dtype != torch.float32 or bbox.device != device):
            bbox = torch.as_tensor(bbox, dtype=torch.float32, device=device)
        if bbox.ndimension() == 1 and bbox.size(-1) == 4:
            bbox = bbox.unsqueeze(0)
        if bbox.ndimension() != 2:
            raise ValueError(
                "bbox should have 2 dimensions, got {}".format(bbox.ndimension())
            )
        if bbox.size(-1) != 4:
            raise ValueError(
                "last dimension of bbox should have a "
                "size of 4, got {}".format(bbox.size(-1))
            )
        if mode not in ("xyxy", "xywh"):
            raise ValueError("mode should be 'xyxy' or 'xywh'")

        self.bbox = bbox
        self.size = image_size  # (image_width, image_height)
        self.mode = mode
        self.extra_fields = {}

    # note: _jit_wrap/_jit_unwrap only work if the keys and the sizes don't change in between
    def _jit_unwrap(self):
        return (self.bbox,) + tuple(f for f in (self.get_field(field)
                                                for field in sorted(self.fields()))
                                    if isinstance(f, torch.Tensor))

    def _jit_wrap(self, input_stream):
        self.bbox = input_stream[0]
        num_consumed = 1
        for f in sorted(self.fields()):
            if isinstance(self.extra_fields[f], torch.Tensor):
                self.extra_fields[f] = input_stream[num_consumed]
                num_consumed += 1
        return self, input_stream[num_consumed:]
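
    # Example round trip (sketch): _jit_unwrap flattens the BoxList into a
    # tuple of tensors that JIT tracing can consume, and _jit_wrap restores
    # them in the same sorted-field order, returning any unconsumed tail:
    #
    #     tensors = boxlist._jit_unwrap()
    #     boxlist, remainder = boxlist._jit_wrap(tensors)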

    def add_field(self, field, field_data):
        self.extra_fields[field] = field_data

    def get_field(self, field):
        return self.extra_fields[field]

    def has_field(self, field):
        return field in self.extra_fields

    def fields(self):
        return list(self.extra_fields.keys())

    def _copy_extra_fields(self, bbox):
        for k, v in bbox.extra_fields.items():
            self.extra_fields[k] = v

    def convert(self, mode):
        if mode not in ("xyxy", "xywh"):
            raise ValueError("mode should be 'xyxy' or 'xywh'")
        if mode == self.mode:
            return self
        # we only have two modes, so don't need to check
        # self.mode
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        if mode == "xyxy":
            bbox = torch.cat((xmin, ymin, xmax, ymax), dim=-1)
            bbox = BoxList(bbox, self.size, mode=mode)
        else:
            TO_REMOVE = 1
            # NOTE: explicitly specify dim to avoid a tracing error on GPU
            bbox = torch.cat(
                (xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=1
            )
            bbox = BoxList(bbox, self.size, mode=mode)
        bbox._copy_extra_fields(self)
        return bbox

    def _split_into_xyxy(self):
        if self.mode == "xyxy":
            xmin, ymin, xmax, ymax = self.bbox.split(1, dim=-1)
            return xmin, ymin, xmax, ymax
        elif self.mode == "xywh":
            TO_REMOVE = 1
            xmin, ymin, w, h = self.bbox.split(1, dim=-1)
            return (
                xmin,
                ymin,
                xmin + (w - TO_REMOVE).clamp(min=0),
                ymin + (h - TO_REMOVE).clamp(min=0),
            )
        else:
            raise RuntimeError("Should not be here")

    def resize(self, size, *args, **kwargs):
        """
        Returns a resized copy of this set of bounding boxes.

        :param size: The requested size in pixels, as a 2-tuple:
            (width, height).
        """
        ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
        if ratios[0] == ratios[1]:
            ratio = ratios[0]
            scaled_box = self.bbox * ratio
            bbox = BoxList(scaled_box, size, mode=self.mode)
            # bbox._copy_extra_fields(self)
            for k, v in self.extra_fields.items():
                if not isinstance(v, torch.Tensor):
                    v = v.resize(size, *args, **kwargs)
                bbox.add_field(k, v)
            return bbox

        ratio_width, ratio_height = ratios
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        scaled_xmin = xmin * ratio_width
        scaled_xmax = xmax * ratio_width
        scaled_ymin = ymin * ratio_height
        scaled_ymax = ymax * ratio_height
        scaled_box = torch.cat(
            (scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1
        )
        bbox = BoxList(scaled_box, size, mode="xyxy")
        # bbox._copy_extra_fields(self)
        for k, v in self.extra_fields.items():
            if not isinstance(v, torch.Tensor):
                v = v.resize(size, *args, **kwargs)
            bbox.add_field(k, v)
        return bbox.convert(self.mode)
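
    # e.g. resizing boxes from a (10, 10) image to (5, 5) halves every
    # coordinate; non-tensor fields (such as segmentation masks) are assumed
    # to implement their own .resize() and are resized alongside the boxes.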

    def transpose(self, method):
        """
        Transpose bounding box (flip or rotate in 90 degree steps)

        :param method: One of :py:attr:`PIL.Image.FLIP_LEFT_RIGHT`,
            :py:attr:`PIL.Image.FLIP_TOP_BOTTOM`, :py:attr:`PIL.Image.ROTATE_90`,
            :py:attr:`PIL.Image.ROTATE_180`, :py:attr:`PIL.Image.ROTATE_270`,
            :py:attr:`PIL.Image.TRANSPOSE` or :py:attr:`PIL.Image.TRANSVERSE`.
        """
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )

        image_width, image_height = self.size
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        if method == FLIP_LEFT_RIGHT:
            TO_REMOVE = 1
            transposed_xmin = image_width - xmax - TO_REMOVE
            transposed_xmax = image_width - xmin - TO_REMOVE
            transposed_ymin = ymin
            transposed_ymax = ymax
        elif method == FLIP_TOP_BOTTOM:
            transposed_xmin = xmin
            transposed_xmax = xmax
            transposed_ymin = image_height - ymax
            transposed_ymax = image_height - ymin

        transposed_boxes = torch.cat(
            (transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1
        )
        bbox = BoxList(transposed_boxes, self.size, mode="xyxy")
        # bbox._copy_extra_fields(self)
        for k, v in self.extra_fields.items():
            if not isinstance(v, torch.Tensor):
                v = v.transpose(method)
            bbox.add_field(k, v)
        return bbox.convert(self.mode)
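
    # Worked example of FLIP_LEFT_RIGHT above: in a 10-pixel-wide image, a box
    # with xmin=2, xmax=4 maps to xmin = 10 - 4 - 1 = 5 and xmax = 10 - 2 - 1
    # = 7, mirroring coordinates under the same inclusive-pixel convention.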

    def crop(self, box):
        """
        Crops a rectangular region from this set of bounding boxes. The box
        is a 4-tuple defining the left, upper, right, and lower pixel
        coordinate.
        """
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        w, h = box[2] - box[0], box[3] - box[1]
        cropped_xmin = (xmin - box[0]).clamp(min=0, max=w)
        cropped_ymin = (ymin - box[1]).clamp(min=0, max=h)
        cropped_xmax = (xmax - box[0]).clamp(min=0, max=w)
        cropped_ymax = (ymax - box[1]).clamp(min=0, max=h)

        # TODO should I filter empty boxes here?
        cropped_box = torch.cat(
            (cropped_xmin, cropped_ymin, cropped_xmax, cropped_ymax), dim=-1
        )
        bbox = BoxList(cropped_box, (w, h), mode="xyxy")
        # bbox._copy_extra_fields(self)
        for k, v in self.extra_fields.items():
            if not isinstance(v, torch.Tensor):
                v = v.crop(box)
            bbox.add_field(k, v)
        return bbox.convert(self.mode)
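
    # e.g. cropping with box = (2, 2, 8, 8) shifts coordinates by (-2, -2) and
    # clamps them to the new (6, 6) window; boxes entirely outside the crop
    # collapse to zero-area boxes rather than being removed (see the TODO above).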

    # Tensor-like methods
    def to(self, device):
        bbox = BoxList(self.bbox.to(device), self.size, self.mode)
        for k, v in self.extra_fields.items():
            if hasattr(v, "to"):
                v = v.to(device)
            bbox.add_field(k, v)
        return bbox

    def __getitem__(self, item):
        bbox = BoxList(self.bbox[item], self.size, self.mode)
        for k, v in self.extra_fields.items():
            bbox.add_field(k, v[item])
        return bbox

    def __len__(self):
        return self.bbox.shape[0]

    def clip_to_image(self, remove_empty=True):
        TO_REMOVE = 1
        x1s = self.bbox[:, 0].clamp(min=0, max=self.size[0] - TO_REMOVE)
        y1s = self.bbox[:, 1].clamp(min=0, max=self.size[1] - TO_REMOVE)
        x2s = self.bbox[:, 2].clamp(min=0, max=self.size[0] - TO_REMOVE)
        y2s = self.bbox[:, 3].clamp(min=0, max=self.size[1] - TO_REMOVE)
        self.bbox = torch.stack((x1s, y1s, x2s, y2s), dim=-1)
        if remove_empty:
            box = self.bbox
            keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0])
            return self[keep]
        return self

    def area(self):
        if self.mode == "xyxy":
            TO_REMOVE = 1
            box = self.bbox
            area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE)
        elif self.mode == "xywh":
            box = self.bbox
            area = box[:, 2] * box[:, 3]
        else:
            raise RuntimeError("Should not be here")
        return area
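
    # Note: in "xyxy" mode the + TO_REMOVE terms make the area inclusive of
    # the boundary pixels, so a box (0, 0, 9, 9) has area 10 * 10 = 100.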

    def copy_with_fields(self, fields):
        bbox = BoxList(self.bbox, self.size, self.mode)
        if not isinstance(fields, (list, tuple)):
            fields = [fields]
        for field in fields:
            bbox.add_field(field, self.get_field(field))
        return bbox

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_boxes={}, ".format(len(self))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={}, ".format(self.size[1])
        s += "mode={})".format(self.mode)
        return s


def concate_box_list(list_of_boxes):
    boxes = torch.cat([i.bbox for i in list_of_boxes], dim=0)
    extra_fields_keys = list(list_of_boxes[0].extra_fields.keys())
    extra_fields = {}
    for key in extra_fields_keys:
        extra_fields[key] = torch.cat([i.extra_fields[key] for i in list_of_boxes], dim=0)
    final = list_of_boxes[0].copy_with_fields(extra_fields_keys)
    final.bbox = boxes
    final.extra_fields = extra_fields
    return final
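
# Minimal usage sketch for concate_box_list (assumes all BoxLists share the
# same image size, mode, and extra-field keys, with tensor-valued fields):
#
#     merged = concate_box_list([boxes_a, boxes_b])
#     assert len(merged) == len(boxes_a) + len(boxes_b)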


def _onnx_clip_boxes_to_image(boxes, size):
    # type: (Tensor, Tuple[int, int])
    """
    Clip boxes so that they lie inside an image of size `size`.
    Tensor.clamp's min/max are traced as constants, so torch.min/torch.max
    are used to work around that issue.

    Arguments:
        boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
        size (Tuple[height, width]): size of the image

    Returns:
        clipped_boxes (Tensor[N, 4])
    """
    TO_REMOVE = 1
    device = boxes.device
    dim = boxes.dim()
    boxes_x = boxes[..., 0::2]
    boxes_y = boxes[..., 1::2]
    boxes_x = torch.max(boxes_x, torch.tensor(0., dtype=torch.float).to(device))
    boxes_x = torch.min(boxes_x, torch.tensor(size[1] - TO_REMOVE, dtype=torch.float).to(device))
    boxes_y = torch.max(boxes_y, torch.tensor(0., dtype=torch.float).to(device))
    boxes_y = torch.min(boxes_y, torch.tensor(size[0] - TO_REMOVE, dtype=torch.float).to(device))
    clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim)
    return clipped_boxes.reshape(boxes.shape)
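
# Sketch of the ONNX-friendly path above: torch.stack interleaves the clipped
# x and y columns back into (x1, y1, x2, y2) order via the final reshape, so
# it mirrors the clamping in BoxList.clip_to_image without using Tensor.clamp
# (note that size here is (height, width), not (width, height)).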
if __name__ == "__main__": | |
bbox = BoxList([[0, 0, 10, 10], [0, 0, 5, 5]], (10, 10)) | |
s_bbox = bbox.resize((5, 5)) | |
print(s_bbox) | |
print(s_bbox.bbox) | |
t_bbox = bbox.transpose(0) | |
print(t_bbox) | |
print(t_bbox.bbox) |
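
    # additional sketch: clip to the image bounds and compute inclusive areas
    c_bbox = bbox.clip_to_image()
    print(c_bbox.bbox)
    print(c_bbox.area())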