Spaces:
Running
on
L4
Running
on
L4
File size: 5,550 Bytes
dcc8c59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
from typing import Dict, List, Tuple, Union

import torch

from matanyone.inference.object_info import ObjectInfo
class ObjectManager:
    """
    Track the objects being segmented and translate between two ID spaces.

    Object IDs are immutable: the same ID always represents the same object.
    Temporary IDs are the positions of each object in the tensor; they start
    from 1 and are re-packed (kept contiguous) whenever objects are removed.
    """

    def __init__(self):
        # ObjectInfo -> current temporary ID
        self.obj_to_tmp_id: Dict[ObjectInfo, int] = {}
        # current temporary ID -> ObjectInfo (inverse of obj_to_tmp_id)
        self.tmp_id_to_obj: Dict[int, ObjectInfo] = {}
        # object ID -> ObjectInfo; rebuilt from obj_to_tmp_id after each mutation
        self.obj_id_to_obj: Dict[int, ObjectInfo] = {}
        # every object ID ever added, in insertion order; delete_objects does
        # not touch this list
        self.all_historical_object_ids: List[int] = []

    def _recompute_obj_id_to_obj_mapping(self) -> None:
        """Rebuild ``obj_id_to_obj`` from the objects currently in ``obj_to_tmp_id``."""
        self.obj_id_to_obj = {obj.id: obj for obj in self.obj_to_tmp_id}

    def add_new_objects(
        self,
        objects: Union[List[ObjectInfo], ObjectInfo, List[int]],
    ) -> Tuple[List[int], List[int]]:
        """
        Register objects, giving each previously unseen one the next temporary ID.

        Args:
            objects: a single ObjectInfo or int object ID, or a list/tuple of
                them. Plain ints are wrapped into ``ObjectInfo(id=...)``.

        Returns:
            ``(tmp_ids, obj_ids)``: for each input object, in input order, its
            temporary ID and its object ID.
        """
        # Accept tuples as well as lists; anything else is a single object.
        if not isinstance(objects, (list, tuple)):
            objects = [objects]

        corresponding_tmp_ids = []
        corresponding_obj_ids = []
        for obj in objects:
            if isinstance(obj, int):
                obj = ObjectInfo(id=obj)

            if obj in self.obj_to_tmp_id:
                # Already tracked: reuse its temporary ID.
                tmp_id = self.obj_to_tmp_id[obj]
            else:
                # Unseen: store a fresh ObjectInfo (the caller's instance itself
                # is not kept) under the next contiguous temporary ID.
                new_obj = ObjectInfo(id=obj.id)
                tmp_id = len(self.obj_to_tmp_id) + 1
                self.obj_to_tmp_id[new_obj] = tmp_id
                self.tmp_id_to_obj[tmp_id] = new_obj
                self.all_historical_object_ids.append(new_obj.id)

            corresponding_tmp_ids.append(tmp_id)
            corresponding_obj_ids.append(obj.id)

        self._recompute_obj_id_to_obj_mapping()
        # Invariant: the returned temporary IDs are in ascending order.
        assert corresponding_tmp_ids == sorted(corresponding_tmp_ids)
        return corresponding_tmp_ids, corresponding_obj_ids

    def delete_objects(self, obj_ids_to_remove: Union[int, List[int]]) -> None:
        """
        Remove one object ID or a list of them, then re-pack temporary IDs.

        Surviving objects keep their relative order and are renumbered
        contiguously from 1. IDs that are not currently tracked are ignored.
        """
        if isinstance(obj_ids_to_remove, int):
            obj_ids_to_remove = [obj_ids_to_remove]

        local_obj_to_tmp_id = {}
        local_tmp_to_obj_id = {}
        new_tmp_id = 1
        # Walk the old temporary IDs in order so survivors keep their ordering.
        for old_tmp_id in range(1, len(self.obj_to_tmp_id) + 1):
            obj = self.tmp_id_to_obj[old_tmp_id]
            if obj.id in obj_ids_to_remove:
                continue
            local_obj_to_tmp_id[obj] = new_tmp_id
            local_tmp_to_obj_id[new_tmp_id] = obj
            new_tmp_id += 1

        self.obj_to_tmp_id = local_obj_to_tmp_id
        self.tmp_id_to_obj = local_tmp_to_obj_id
        self._recompute_obj_id_to_obj_mapping()

    def purge_inactive_objects(
        self, max_missed_detection_count: int
    ) -> Tuple[bool, List[int], List[int]]:
        """
        Delete every object whose ``poke_count`` exceeds ``max_missed_detection_count``.

        Returns:
            ``(purge_activated, tmp_id_to_keep, obj_id_to_keep)``.
            ``purge_activated`` is True iff at least one object was deleted.
            ``tmp_id_to_keep`` holds the kept objects' temporary IDs as they
            were *before* the purge, so they can index pre-purge tensors.
            ``obj_id_to_keep`` holds their object IDs.
        """
        obj_id_to_be_deleted = []
        tmp_id_to_keep = []
        obj_id_to_keep = []
        for obj, tmp_id in self.obj_to_tmp_id.items():
            if obj.poke_count > max_missed_detection_count:
                obj_id_to_be_deleted.append(obj.id)
            else:
                tmp_id_to_keep.append(tmp_id)
                obj_id_to_keep.append(obj.id)

        purge_activated = len(obj_id_to_be_deleted) > 0
        if purge_activated:
            self.delete_objects(obj_id_to_be_deleted)
        return purge_activated, tmp_id_to_keep, obj_id_to_keep

    def tmp_to_obj_cls(self, mask) -> torch.Tensor:
        """
        Remap a class mask from temporary IDs to object IDs.

        Values equal to a current temporary ID are replaced by that object's ID.
        Every other value becomes 0. The result has the dtype of ``mask``.
        """
        new_mask = torch.zeros_like(mask)
        for tmp_id, obj in self.tmp_id_to_obj.items():
            # Read from the original mask so earlier writes cannot be re-matched.
            new_mask[mask == tmp_id] = obj.id
        return new_mask

    def get_tmp_to_obj_mapping(self) -> Dict[int, int]:
        """
        Return the current mapping as a plain ``{tmp_id: obj_id}`` dict.

        Only ints are stored, so the result is easy to save (e.g. with pickle).
        """
        # tmp_id_to_obj yields (tmp_id, ObjectInfo) pairs; unpack in that order.
        return {tmp_id: obj.id for tmp_id, obj in self.tmp_id_to_obj.items()}

    def realize_dict(self, obj_dict, dim=1) -> torch.Tensor:
        """
        Stack the values of a dict keyed by object ID into one tensor.

        The values are ordered by temporary ID and stacked along ``dim``.

        Raises:
            NotImplementedError: if a tracked object ID is missing from ``obj_dict``.
        """
        # tmp_id_to_obj is always (re)built in ascending tmp-ID order.
        output = []
        for obj in self.tmp_id_to_obj.values():
            if obj.id not in obj_dict:
                raise NotImplementedError(
                    f"object id {obj.id} is missing from obj_dict")
            output.append(obj_dict[obj.id])
        # NOTE: torch.stack raises if no objects are tracked (empty list).
        return torch.stack(output, dim=dim)

    def make_one_hot(self, cls_mask) -> torch.Tensor:
        """
        Build a boolean one-hot mask from a mask of object IDs.

        Returns:
            A tensor of shape ``(num_obj, *cls_mask.shape)``. Channel ``i``
            marks the pixels of the object with temporary ID ``i + 1``.
        """
        output = [cls_mask == obj.id for obj in self.tmp_id_to_obj.values()]
        if not output:
            return torch.zeros((0, *cls_mask.shape),
                               dtype=torch.bool,
                               device=cls_mask.device)
        return torch.stack(output, dim=0)

    @property
    def all_obj_ids(self) -> List[int]:
        """Object IDs of all currently tracked objects."""
        return [k.id for k in self.obj_to_tmp_id]

    @property
    def num_obj(self) -> int:
        """Number of currently tracked objects."""
        return len(self.obj_to_tmp_id)

    def has_all(self, objects: List[int]) -> bool:
        """Return True if every element of ``objects`` is currently tracked."""
        # Relies on ObjectInfo keys matching plain ints under hash/equality.
        return all(obj in self.obj_to_tmp_id for obj in objects)

    def find_object_by_id(self, obj_id: int) -> ObjectInfo:
        """Return the tracked ObjectInfo for ``obj_id``; raises KeyError if absent."""
        return self.obj_id_to_obj[obj_id]

    def find_tmp_by_id(self, obj_id: int) -> int:
        """Return the current temporary ID of ``obj_id``; raises KeyError if absent."""
        return self.obj_to_tmp_id[self.obj_id_to_obj[obj_id]]
|