# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Optional

from mmdet.datasets import BaseDetDataset, CocoDataset

from ..registry import DATASETS, TASK_UTILS


class BatchShapePolicyDataset(BaseDetDataset):
    """Dataset with the batch shape policy, which pads each batch with the
    fewest pixels during batch inference, so the image scales of different
    batches do not have to be the same throughout validation.
    """

    def __init__(self,
                 *args,
                 batch_shapes_cfg: Optional[dict] = None,
                 **kwargs):
        self.batch_shapes_cfg = batch_shapes_cfg
        super().__init__(*args, **kwargs)

    def full_init(self):
        """Rewrite ``full_init()`` to be compatible with ``serialize_data``
        in BatchShapePolicy."""
        if self._fully_initialized:
            return

        # load data information
        self.data_list = self.load_data_list()

        # apply the batch shape policy if configured
        if self.batch_shapes_cfg:
            batch_shapes_policy = TASK_UTILS.build(self.batch_shapes_cfg)
            self.data_list = batch_shapes_policy(self.data_list)
            del batch_shapes_policy

        # filter illegal data, such as data that has no annotations.
        self.data_list = self.filter_data()

        # Get subset data according to indices.
        if self._indices is not None:
            self.data_list = self._get_unserialized_subset(self._indices)

        # serialize data_list
        if self.serialize_data:
            self.data_bytes, self.data_address = self._serialize_data()

        self._fully_initialized = True

    def prepare_data(self, idx: int) -> Any:
        """Pass the dataset to the pipeline during training to support mixed
        data augmentation, such as Mosaic and MixUp."""
        if self.test_mode is False:
            data_info = self.get_data_info(idx)
            data_info['dataset'] = self
            return self.pipeline(data_info)
        else:
            return super().prepare_data(idx)


@DATASETS.register_module()
class YOLOv5CocoDataset(BatchShapePolicyDataset, CocoDataset):
    """Dataset for the YOLOv5 COCO dataset.

    Compared with ``CocoDataset``, this class only adds the
    ``BatchShapePolicy``. See ``mmyolo/datasets/utils.py#BatchShapePolicy``
    for details.
    """
    pass
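

# Illustrative usage sketch (not part of the upstream file): how
# ``batch_shapes_cfg`` might be passed when building the dataset for
# validation. The BatchShapePolicy keyword arguments shown below
# (``batch_size``, ``img_size``, ``size_divisor``, ``extra_pad_ratio``) and
# the data paths are assumptions; see
# ``mmyolo/datasets/utils.py#BatchShapePolicy`` for the authoritative
# signature.
#
# val_dataset = YOLOv5CocoDataset(
#     data_root='data/coco/',
#     ann_file='annotations/instances_val2017.json',
#     data_prefix=dict(img='val2017/'),
#     test_mode=True,
#     pipeline=[...],
#     batch_shapes_cfg=dict(
#         type='BatchShapePolicy',
#         batch_size=32,
#         img_size=640,
#         size_divisor=32,
#         extra_pad_ratio=0.5))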