File size: 6,613 Bytes
938e515 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
# Copyright (c) Facebook, Inc. and its affiliates.
import itertools
import warnings
from typing import Any, Dict, List, Tuple, Union
import torch
class Instances:
    """
    This class represents a list of instances in an image.
    It stores the attributes of instances (e.g., boxes, masks, labels, scores) as "fields".
    All fields must have the same ``__len__`` which is the number of instances.

    All other (non-field) attributes of this class are considered private:
    they must start with '_' and are not modifiable by a user.

    Some basic usage:

    1. Set/get/check a field:

       .. code-block:: python

          instances.gt_boxes = Boxes(...)
          print(instances.pred_masks)  # a tensor of shape (N, H, W)
          print('gt_masks' in instances)

    2. ``len(instances)`` returns the number of instances
    3. Indexing: ``instances[indices]`` will apply the indexing on all the fields
       and returns a new :class:`Instances`.
       Typically, ``indices`` is an integer vector of indices,
       or a binary mask of length ``num_instances``

       .. code-block:: python

          category_3_detections = instances[instances.pred_classes == 3]
          confident_detections = instances[instances.scores > 0.9]
    """

    def __init__(self, image_size: Tuple[int, int], **kwargs: Any):
        """
        Args:
            image_size (height, width): the spatial size of the image.
            kwargs: fields to add to this `Instances`.
        """
        self._image_size = image_size
        self._fields: Dict[str, Any] = {}
        for k, v in kwargs.items():
            self.set(k, v)

    @property
    def image_size(self) -> Tuple[int, int]:
        """
        Returns:
            tuple: height, width
        """
        return self._image_size

    def __setattr__(self, name: str, val: Any) -> None:
        # Names starting with "_" are private attributes stored on the object
        # itself; every other attribute assignment becomes a field.
        if name.startswith("_"):
            super().__setattr__(name, val)
        else:
            self.set(name, val)

    def __getattr__(self, name: str) -> Any:
        # __getattr__ only runs when normal lookup fails. Refusing "_fields"
        # here avoids infinite recursion when the object has no `_fields`
        # attribute yet (e.g. it was created without running __init__).
        if name == "_fields" or name not in self._fields:
            raise AttributeError("Cannot find field '{}' in the given Instances!".format(name))
        return self._fields[name]

    def set(self, name: str, value: Any) -> None:
        """
        Set the field named `name` to `value`.
        The length of `value` must be the number of instances,
        and must agree with other existing fields in this object.
        """
        # Record (and thereby silence) any warnings emitted by len(value);
        # presumably tracer warnings when running under torch.jit tracing.
        with warnings.catch_warnings(record=True):
            data_len = len(value)
        if len(self._fields):
            assert (
                len(self) == data_len
            ), "Adding a field of length {} to a Instances of length {}".format(data_len, len(self))
        self._fields[name] = value

    def has(self, name: str) -> bool:
        """
        Returns:
            bool: whether the field called `name` exists.
        """
        return name in self._fields

    def __contains__(self, name: str) -> bool:
        """
        Support ``name in instances``; equivalent to :meth:`has`.

        Without this method Python's ``in`` would fall back to ``__iter__``,
        which deliberately raises ``NotImplementedError``.
        """
        return self.has(name)

    def remove(self, name: str) -> None:
        """
        Remove the field called `name`.

        Raises:
            KeyError: if no such field exists.
        """
        del self._fields[name]

    def get(self, name: str) -> Any:
        """
        Returns the field called `name`.

        Raises:
            KeyError: if no such field exists.
        """
        return self._fields[name]

    def get_fields(self) -> Dict[str, Any]:
        """
        Returns:
            dict: a dict which maps names (str) to data of the fields

        Modifying the returned dict will modify this instance.
        """
        return self._fields

    # Tensor-like methods
    def to(self, *args: Any, **kwargs: Any) -> "Instances":
        """
        Returns:
            Instances: all fields are called with a `to(device)`, if the field has this method.
            Fields without a `to` method are shared (not copied) with the new object.
        """
        ret = Instances(self._image_size)
        for k, v in self._fields.items():
            if hasattr(v, "to"):
                v = v.to(*args, **kwargs)
            ret.set(k, v)
        return ret

    def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Instances":
        """
        Args:
            item: an index-like object and will be used to index all the fields.
                Field names are not accepted here; use :meth:`get` to read a field.

        Returns:
            An `Instances` where all fields are indexed by `item`.

        Raises:
            IndexError: if `item` is a Python int outside ``[-len(self), len(self))``.
        """
        # Exact type check: bools (an int subclass) and non-builtin integer
        # types fall through to the generic indexing below.
        if type(item) == int:
            if item >= len(self) or item < -len(self):
                raise IndexError("Instances index out of range!")
            # Turn the scalar index into a one-element slice so every field keeps
            # its leading (instance) dimension; a step of len(self) guarantees
            # that only the element at `item` is selected.
            item = slice(item, None, len(self))

        ret = Instances(self._image_size)
        for k, v in self._fields.items():
            ret.set(k, v[item])
        return ret

    def __len__(self) -> int:
        """
        Number of instances, taken from the first field.

        Raises:
            NotImplementedError: if this `Instances` has no fields.
        """
        for v in self._fields.values():
            # use __len__ because len() has to be int and is not friendly to tracing
            return v.__len__()
        raise NotImplementedError("Empty Instances does not support __len__!")

    def __iter__(self):
        # Defined explicitly so Python does not fall back to iterating via
        # __getitem__ with integer indices.
        raise NotImplementedError("`Instances` object is not iterable!")

    @staticmethod
    def cat(instance_lists: List["Instances"]) -> "Instances":
        """
        Concatenate several `Instances` field by field.

        Args:
            instance_lists (list[Instances]): non-empty; all elements must have the
                same image size and (at least) the fields of the first element.

        Returns:
            Instances. When only one element is given it is returned as-is (not copied).

        Raises:
            ValueError: if a field type has no known way to be concatenated.
        """
        assert all(isinstance(i, Instances) for i in instance_lists)
        assert len(instance_lists) > 0
        if len(instance_lists) == 1:
            return instance_lists[0]

        image_size = instance_lists[0].image_size
        if not isinstance(image_size, torch.Tensor):  # could be a tensor in tracing
            for i in instance_lists[1:]:
                assert i.image_size == image_size
        ret = Instances(image_size)
        for k in instance_lists[0]._fields:
            values = [i.get(k) for i in instance_lists]
            v0 = values[0]
            if isinstance(v0, torch.Tensor):
                values = torch.cat(values, dim=0)
            elif isinstance(v0, list):
                values = list(itertools.chain.from_iterable(values))
            elif hasattr(type(v0), "cat"):
                # e.g. custom structure types exposing a classmethod/staticmethod `cat`
                values = type(v0).cat(values)
            else:
                raise ValueError("Unsupported type {} for concatenation".format(type(v0)))
            ret.set(k, values)
        return ret

    def __str__(self) -> str:
        # An empty Instances has no field to measure, so report 0 instead of
        # letting __len__ raise NotImplementedError while printing.
        num_instances = len(self) if self._fields else 0
        s = self.__class__.__name__ + "("
        s += "num_instances={}, ".format(num_instances)
        s += "image_height={}, ".format(self._image_size[0])
        s += "image_width={}, ".format(self._image_size[1])
        s += "fields=[{}])".format(", ".join((f"{k}: {v}" for k, v in self._fields.items())))
        return s

    __repr__ = __str__
|