# Copyright (c) OpenMMLab. All rights reserved.
# written by lzx
from mmdet.registry import DATASETS
from mmdet.datasets.api_wrappers import COCO

from .HSI import HSIDataset


@DATASETS.register_module()
class SIRSTDataset(HSIDataset):
    """Single-class dataset registered in ``DATASETS`` as ``SIRSTDataset``.

    All loading, parsing and filtering behaviour is inherited unchanged from
    ``HSIDataset``; this subclass only overrides class-level configuration:

    * ``METAINFO`` -- a single foreground class, ``'object'``, drawn with the
      colour tuple (220, 20, 60) when visualising.
    * ``COCOAPI`` -- mmdet's ``COCO`` wrapper, exposed to the base class
      (presumably used there to read the COCO-format annotation file, as
      mmdet's ``CocoDataset`` does -- confirm against ``HSIDataset``).

    NOTE(review): "SIRST" presumably refers to the single-frame infrared
    small-target benchmark; confirm with the dataset owner.
    """

    METAINFO = {
        'classes': ('object', ),
        # palette is a list of color tuples, which is used for visualization.
        'palette': [(220, 20, 60), ],
    }
    COCOAPI = COCO


# ---------------------------------------------------------------------------
# Disabled alternative (kept for reference): the same dataset built directly
# on ``CocoDataset``, with ``load_data_list``, ``parse_data_info`` and
# ``filter_data`` spelled out inline. Before re-enabling it:
#   * add the imports it relies on but this module does not provide:
#     ``copy``, ``os.path as osp``, ``typing.List`` / ``typing.Union``,
#     ``get_local_path`` (presumably from ``mmengine.fileio``) and
#     ``CocoDataset``;
#   * remove or rename the active ``SIRSTDataset`` above -- both classes are
#     registered under the same name via ``@DATASETS.register_module()``.
# ---------------------------------------------------------------------------
# @DATASETS.register_module()
# class SIRSTDataset(CocoDataset):
#     """Dataset for COCO."""
#
#     METAINFO = {
#         'classes':
#         ('object',),
#         # palette is a list of color tuples, which is used for visualization.
#         'palette':
#         [(220, 20, 60),]
#     }
#     COCOAPI = COCO
#     # ann_id is unique in coco dataset.
#     ANN_ID_UNIQUE = True
#
#     def load_data_list(self) -> List[dict]:
#         """Load annotations from an annotation file named as ``self.ann_file``
#
#         Returns:
#             List[dict]: A list of annotation.
#         """  # noqa: E501
#         with get_local_path(
#                 self.ann_file, backend_args=self.backend_args) as local_path:
#             self.coco = self.COCOAPI(local_path)
#         # The order of returned `cat_ids` will not
#         # change with the order of the `classes`
#         self.cat_ids = self.coco.get_cat_ids(
#             cat_names=self.metainfo['classes'])
#         self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
#         self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)
#
#         img_ids = self.coco.get_img_ids()
#         data_list = []
#         total_ann_ids = []
#         for img_id in img_ids:
#             raw_img_info = self.coco.load_imgs([img_id])[0]
#             raw_img_info['img_id'] = img_id
#
#             ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
#             raw_ann_info = self.coco.load_anns(ann_ids)
#             total_ann_ids.extend(ann_ids)
#
#             parsed_data_info = self.parse_data_info({
#                 'raw_ann_info':
#                 raw_ann_info,
#                 'raw_img_info':
#                 raw_img_info
#             })
#             data_list.append(parsed_data_info)
#         if self.ANN_ID_UNIQUE:
#             assert len(set(total_ann_ids)) == len(
#                 total_ann_ids
#             ), f"Annotation ids in '{self.ann_file}' are not unique!"
#
#         del self.coco
#
#         return data_list
#
#     def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
#         """Parse raw annotation to target format.
#
#         Args:
#             raw_data_info (dict): Raw data information load from ``ann_file``
#
#         Returns:
#             Union[dict, List[dict]]: Parsed annotation.
#         """
#         img_info = raw_data_info['raw_img_info']
#         ann_info = raw_data_info['raw_ann_info']
#
#         data_info = {}
#
#         # TODO: need to change data_prefix['img'] to data_prefix['img_path']
#         img_path = osp.join(self.data_prefix['img'], img_info['file_name'])
#         if self.data_prefix.get('seg', None):
#             seg_map_path = osp.join(
#                 self.data_prefix['seg'],
#                 img_info['file_name'].rsplit('.', 1)[0] + self.seg_map_suffix)
#         else:
#             seg_map_path = None
#         data_info['img_path'] = img_path
#         data_info['img_id'] = img_info['img_id']
#         data_info['seg_map_path'] = seg_map_path
#         data_info['height'] = img_info['height']
#         data_info['width'] = img_info['width']
#
#         instances = []
#         for i, ann in enumerate(ann_info):
#             instance = {}
#
#             if ann.get('ignore', False):
#                 continue
#             x1, y1, w, h = ann['bbox']
#             inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
#             inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
#             if inter_w * inter_h == 0:
#                 continue
#             if ann['area'] <= 0 or w < 1 or h < 1:
#                 continue
#             if ann['category_id'] not in self.cat_ids:
#                 continue
#             bbox = [x1, y1, x1 + w, y1 + h]
#
#             if ann.get('iscrowd', False):
#                 instance['ignore_flag'] = 1
#             else:
#                 instance['ignore_flag'] = 0
#             instance['bbox'] = bbox
#             instance['bbox_label'] = self.cat2label[ann['category_id']]
#
#             if ann.get('segmentation', None):
#                 instance['mask'] = ann['segmentation']
#
#             instances.append(instance)
#         data_info['instances'] = instances
#         return data_info
#
#     def filter_data(self) -> List[dict]:
#         """Filter annotations according to filter_cfg.
#
#         Returns:
#             List[dict]: Filtered results.
#         """
#         if self.test_mode:
#             return self.data_list
#
#         if self.filter_cfg is None:
#             return self.data_list
#
#         filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
#         min_size = self.filter_cfg.get('min_size', 0)
#
#         # obtain images that contain annotation
#         ids_with_ann = set(data_info['img_id'] for data_info in self.data_list)
#         # obtain images that contain annotations of the required categories
#         ids_in_cat = set()
#         for i, class_id in enumerate(self.cat_ids):
#             ids_in_cat |= set(self.cat_img_map[class_id])
#         # merge the image id sets of the two conditions and use the merged set
#         # to filter out images if self.filter_empty_gt=True
#         ids_in_cat &= ids_with_ann
#
#         valid_data_infos = []
#         for i, data_info in enumerate(self.data_list):
#             img_id = data_info['img_id']
#             width = data_info['width']
#             height = data_info['height']
#             if filter_empty_gt and img_id not in ids_in_cat:
#                 continue
#             if min(width, height) >= min_size:
#                 valid_data_infos.append(data_info)
#
#         return valid_data_infos