diff --git a/mmpl/__init__.py b/mmpl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mmpl/__pycache__/__init__.cpython-310.pyc b/mmpl/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c532da563db8bfab89df0cfe847de868cbf0978a Binary files /dev/null and b/mmpl/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmpl/__pycache__/registry.cpython-310.pyc b/mmpl/__pycache__/registry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..099058dd22662bb2bf5a23573ba26c17db1b3244 Binary files /dev/null and b/mmpl/__pycache__/registry.cpython-310.pyc differ diff --git a/mmpl/datasets/__init__.py b/mmpl/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38a145baeae7b71bdc0f726ebc0915a45e90bf20 --- /dev/null +++ b/mmpl/datasets/__init__.py @@ -0,0 +1,9 @@ +from .builder import build_dataset +from .pl_datamodule import PLDataModule +from .nwpu_ins_dataset import NWPUInsSegDataset +from .whu_ins_dataset import WHUInsSegDataset +from .ssdd_ins_dataset import SSDDInsSegDataset + +__all__ = [ + 'build_dataset', 'PLDataModule', +] diff --git a/mmpl/datasets/__pycache__/__init__.cpython-310.pyc b/mmpl/datasets/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..016e68460d866b71e71414be61860e68fa135d24 Binary files /dev/null and b/mmpl/datasets/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmpl/datasets/__pycache__/builder.cpython-310.pyc b/mmpl/datasets/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef89ae934b634dd56493290cce21591f38537e7c Binary files /dev/null and b/mmpl/datasets/__pycache__/builder.cpython-310.pyc differ diff --git a/mmpl/datasets/__pycache__/nwpu_ins_dataset.cpython-310.pyc b/mmpl/datasets/__pycache__/nwpu_ins_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1556b35a35bf6efd313acfb1a9adb41d8017a31 Binary files /dev/null and b/mmpl/datasets/__pycache__/nwpu_ins_dataset.cpython-310.pyc differ diff --git a/mmpl/datasets/__pycache__/pl_datamodule.cpython-310.pyc b/mmpl/datasets/__pycache__/pl_datamodule.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d7cfdeec01c77a0ee4d5c27b52fa396aa207d31 Binary files /dev/null and b/mmpl/datasets/__pycache__/pl_datamodule.cpython-310.pyc differ diff --git a/mmpl/datasets/__pycache__/ssdd_ins_dataset.cpython-310.pyc b/mmpl/datasets/__pycache__/ssdd_ins_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b29b607f64cf7d82660e402ab09e715b6283be1c Binary files /dev/null and b/mmpl/datasets/__pycache__/ssdd_ins_dataset.cpython-310.pyc differ diff --git a/mmpl/datasets/__pycache__/whu_ins_dataset.cpython-310.pyc b/mmpl/datasets/__pycache__/whu_ins_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0165dd3ce19c64d92e71eb542d799a579cde404 Binary files /dev/null and b/mmpl/datasets/__pycache__/whu_ins_dataset.cpython-310.pyc differ diff --git a/mmpl/datasets/base_dataset.py b/mmpl/datasets/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..579f14041d2680e3dbd3dc3e1b2035d29809ceda --- /dev/null +++ b/mmpl/datasets/base_dataset.py @@ -0,0 +1,212 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
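+# Example usage (illustrative only; the annotation path, image prefix and
+# class names below are placeholders, not files shipped with this repo):
+#
+#   from mmpl.datasets import build_dataset
+#   dataset = build_dataset(dict(
+#       type='BaseDataset',
+#       ann_file='annotations/train.json',   # OpenMMLab 2.0 style annotation file
+#       data_prefix='data/images/',
+#       classes=['class_a', 'class_b'],      # optional, merged into ``metainfo``
+#   ))
+#   print(dataset.CLASSES, len(dataset))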
+import os.path as osp +from os import PathLike +from typing import List, Optional, Sequence, Union + +import mmengine +import numpy as np +from mmengine.dataset import BaseDataset as _BaseDataset + +from .builder import DATASETS + + +def expanduser(path): + """Expand ~ and ~user constructions. + + If user or $HOME is unknown, do nothing. + """ + if isinstance(path, (str, PathLike)): + return osp.expanduser(path) + else: + return path + + +@DATASETS.register_module() +class BaseDataset(_BaseDataset): + """Base dataset for image classification task. + + This dataset support annotation file in `OpenMMLab 2.0 style annotation + format`. + + .. _OpenMMLab 2.0 style annotation format: + https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md + + Comparing with the :class:`mmengine.BaseDataset`, this class implemented + several useful methods. + + Args: + ann_file (str): Annotation file path. + metainfo (dict, optional): Meta information for dataset, such as class + information. Defaults to None. + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + data_prefix (str | dict): Prefix for training data. Defaults to ''. + filter_cfg (dict, optional): Config for filter data. Defaults to None. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Defaults to None, which means using all ``data_infos``. + serialize_data (bool): Whether to hold memory using serialized objects, + when enabled, data loader workers can use shared RAM from master + process instead of making a copy. Defaults to True. + pipeline (Sequence): Processing pipeline. Defaults to an empty tuple. + test_mode (bool): ``test_mode=True`` means in test phase. + Defaults to False. + lazy_init (bool): Whether to load annotation during instantiation. + In some cases, such as visualization, only the meta information of + the dataset is needed, which is not necessary to load annotation + file. ``Basedataset`` can skip load annotations to save time by set + ``lazy_init=False``. Defaults to False. + max_refetch (int): If ``Basedataset.prepare_data`` get a None img. + The maximum extra number of cycles to get a valid image. + Defaults to 1000. + classes (str | Sequence[str], optional): Specify names of classes. + + - If is string, it should be a file path, and the every line of + the file is a name of a class. + - If is a sequence of string, every item is a name of class. + - If is None, use categories information in ``metainfo`` argument, + annotation file or the class attribute ``METAINFO``. + + Defaults to None. 
+ """ # noqa: E501 + + def __init__(self, + ann_file: str = '', + metainfo: Optional[dict] = None, + data_root: str = '', + data_prefix: Union[str, dict] = '', + filter_cfg: Optional[dict] = None, + indices: Optional[Union[int, Sequence[int]]] = None, + serialize_data: bool = True, + pipeline: Sequence = (), + test_mode: bool = False, + lazy_init: bool = False, + max_refetch: int = 1000, + classes: Union[str, Sequence[str], None] = None): + if isinstance(data_prefix, str): + data_prefix = dict(img_path=expanduser(data_prefix)) + + ann_file = expanduser(ann_file) + metainfo = self._compat_classes(metainfo, classes) + + super().__init__( + ann_file=ann_file, + metainfo=metainfo, + data_root=data_root, + data_prefix=data_prefix, + filter_cfg=filter_cfg, + indices=indices, + serialize_data=serialize_data, + pipeline=pipeline, + test_mode=test_mode, + lazy_init=lazy_init, + max_refetch=max_refetch) + + @property + def img_prefix(self): + """The prefix of images.""" + return self.data_prefix['img_path'] + + @property + def CLASSES(self): + """Return all categories names.""" + return self._metainfo.get('classes', None) + + @property + def class_to_idx(self): + """Map mapping class name to class index. + + Returns: + dict: mapping from class name to class index. + """ + + return {cat: i for i, cat in enumerate(self.CLASSES)} + + def get_gt_labels(self): + """Get all ground-truth labels (categories). + + Returns: + np.ndarray: categories for all images. + """ + + gt_labels = np.array( + [self.get_data_info(i)['gt_label'] for i in range(len(self))]) + return gt_labels + + def get_cat_ids(self, idx: int) -> List[int]: + """Get category id by index. + + Args: + idx (int): Index of data. + + Returns: + cat_ids (List[int]): Image category of specified index. + """ + + return [int(self.get_data_info(idx)['gt_label'])] + + def _compat_classes(self, metainfo, classes): + """Merge the old style ``classes`` arguments to ``metainfo``.""" + if isinstance(classes, str): + # take it as a file path + class_names = mmengine.list_from_file(expanduser(classes)) + elif isinstance(classes, (tuple, list)): + class_names = classes + elif classes is not None: + raise ValueError(f'Unsupported type {type(classes)} of classes.') + + if metainfo is None: + metainfo = {} + + if classes is not None: + metainfo = {'classes': tuple(class_names), **metainfo} + + return metainfo + + def full_init(self): + """Load annotation file and set ``BaseDataset._fully_initialized`` to + True.""" + super().full_init() + + # To support the standard OpenMMLab 2.0 annotation format. Generate + # metainfo in internal format from standard metainfo format. + if 'categories' in self._metainfo and 'classes' not in self._metainfo: + categories = sorted( + self._metainfo['categories'], key=lambda x: x['id']) + self._metainfo['classes'] = tuple( + [cat['category_name'] for cat in categories]) + + def __repr__(self): + """Print the basic information of the dataset. + + Returns: + str: Formatted string. 
+ """ + head = 'Dataset ' + self.__class__.__name__ + body = [] + if self._fully_initialized: + body.append(f'Number of samples: \t{self.__len__()}') + else: + body.append("Haven't been initialized") + + if self.CLASSES is not None: + body.append(f'Number of categories: \t{len(self.CLASSES)}') + else: + body.append('The `CLASSES` meta info is not set.') + + body.extend(self.extra_repr()) + + if len(self.pipeline.transforms) > 0: + body.append('With transforms:') + for t in self.pipeline.transforms: + body.append(f' {t}') + + lines = [head] + [' ' * 4 + line for line in body] + return '\n'.join(lines) + + def extra_repr(self) -> List[str]: + """The extra repr information of the dataset.""" + body = [] + body.append(f'Annotation file: \t{self.ann_file}') + body.append(f'Prefix of images: \t{self.img_prefix}') + return body diff --git a/mmpl/datasets/builder.py b/mmpl/datasets/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..b33d0afef3af275f81953aa7777341ac496c498f --- /dev/null +++ b/mmpl/datasets/builder.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpl.registry import DATASETS + + +def build_dataset(cfg): + """Build dataset. + + Examples: + >>> from mmpl.datasets import build_dataset + >>> mnist_train = build_dataset( + ... dict(type='MNIST', data_prefix='data/mnist/', test_mode=False)) + >>> print(mnist_train) + Dataset MNIST + Number of samples: 60000 + Number of categories: 10 + Prefix of data: data/mnist/ + >>> mnist_test = build_dataset( + ... dict(type='MNIST', data_prefix='data/mnist/', test_mode=True)) + >>> print(mnist_test) + Dataset MNIST + Number of samples: 10000 + Number of categories: 10 + Prefix of data: data/mnist/ + """ + return DATASETS.build(cfg) diff --git a/mmpl/datasets/custom.py b/mmpl/datasets/custom.py new file mode 100644 index 0000000000000000000000000000000000000000..af1c0c140da3cbe1915f2f45134108cd7a2c232b --- /dev/null +++ b/mmpl/datasets/custom.py @@ -0,0 +1,237 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union + +from mmengine.fileio import (BaseStorageBackend, get_file_backend, + list_from_file) +from mmengine.logging import MMLogger + +from mmcls.registry import DATASETS +from .base_dataset import BaseDataset + + +def find_folders( + root: str, + backend: Optional[BaseStorageBackend] = None +) -> Tuple[List[str], Dict[str, int]]: + """Find classes by folders under a root. + + Args: + root (string): root directory of folders + backend (BaseStorageBackend | None): The file backend of the root. + If None, auto infer backend from the root path. Defaults to None. + + Returns: + Tuple[List[str], Dict[str, int]]: + + - folders: The name of sub folders under the root. + - folder_to_idx: The map from folder name to class idx. + """ + # Pre-build file backend to prevent verbose file backend inference. + backend = backend or get_file_backend(root, enable_singleton=True) + folders = list( + backend.list_dir_or_file( + root, + list_dir=True, + list_file=False, + recursive=False, + )) + folders.sort() + folder_to_idx = {folders[i]: i for i in range(len(folders))} + return folders, folder_to_idx + + +def get_samples( + root: str, + folder_to_idx: Dict[str, int], + is_valid_file: Callable, + backend: Optional[BaseStorageBackend] = None, +): + """Make dataset by walking all images under a root. 
+ + Args: + root (string): root directory of folders + folder_to_idx (dict): the map from class name to class idx + is_valid_file (Callable): A function that takes path of a file + and check if the file is a valid sample file. + backend (BaseStorageBackend | None): The file backend of the root. + If None, auto infer backend from the root path. Defaults to None. + + Returns: + Tuple[list, set]: + + - samples: a list of tuple where each element is (image, class_idx) + - empty_folders: The folders don't have any valid files. + """ + samples = [] + available_classes = set() + # Pre-build file backend to prevent verbose file backend inference. + backend = backend or get_file_backend(root, enable_singleton=True) + + for folder_name in sorted(list(folder_to_idx.keys())): + _dir = backend.join_path(root, folder_name) + files = backend.list_dir_or_file( + _dir, + list_dir=False, + list_file=True, + recursive=True, + ) + for file in sorted(list(files)): + if is_valid_file(file): + path = backend.join_path(folder_name, file) + item = (path, folder_to_idx[folder_name]) + samples.append(item) + available_classes.add(folder_name) + + empty_folders = set(folder_to_idx.keys()) - available_classes + + return samples, empty_folders + + +@DATASETS.register_module() +class CustomDataset(BaseDataset): + """Custom dataset for classification. + + The dataset supports two kinds of annotation format. + + 1. An annotation file is provided, and each line indicates a sample: + + The sample files: :: + + data_prefix/ + ├── folder_1 + │ ├── xxx.png + │ ├── xxy.png + │ └── ... + └── folder_2 + ├── 123.png + ├── nsdf3.png + └── ... + + The annotation file (the first column is the image path and the second + column is the index of category): :: + + folder_1/xxx.png 0 + folder_1/xxy.png 1 + folder_2/123.png 5 + folder_2/nsdf3.png 3 + ... + + Please specify the name of categories by the argument ``classes`` + or ``metainfo``. + + 2. The samples are arranged in the specific way: :: + + data_prefix/ + ├── class_x + │ ├── xxx.png + │ ├── xxy.png + │ └── ... + │ └── xxz.png + └── class_y + ├── 123.png + ├── nsdf3.png + ├── ... + └── asd932_.png + + If the ``ann_file`` is specified, the dataset will be generated by the + first way, otherwise, try the second way. + + Args: + ann_file (str): Annotation file path. Defaults to ''. + metainfo (dict, optional): Meta information for dataset, such as class + information. Defaults to None. + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + data_prefix (str | dict): Prefix for the data. Defaults to ''. + extensions (Sequence[str]): A sequence of allowed extensions. Defaults + to ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'). + lazy_init (bool): Whether to load annotation during instantiation. + In some cases, such as visualization, only the meta information of + the dataset is needed, which is not necessary to load annotation + file. ``Basedataset`` can skip load annotations to save time by set + ``lazy_init=False``. Defaults to False. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + ann_file: str = '', + metainfo: Optional[dict] = None, + data_root: str = '', + data_prefix: Union[str, dict] = '', + extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm', + '.bmp', '.pgm', '.tif'), + lazy_init: bool = False, + **kwargs): + assert (ann_file or data_prefix or data_root), \ + 'One of `ann_file`, `data_root` and `data_prefix` must '\ + 'be specified.' 
+ + self.extensions = tuple(set([i.lower() for i in extensions])) + + super().__init__( + # The base class requires string ann_file but this class doesn't + ann_file=ann_file, + metainfo=metainfo, + data_root=data_root, + data_prefix=data_prefix, + # Force to lazy_init for some modification before loading data. + lazy_init=True, + **kwargs) + + # Full initialize the dataset. + if not lazy_init: + self.full_init() + + def _find_samples(self): + """find samples from ``data_prefix``.""" + classes, folder_to_idx = find_folders(self.img_prefix) + samples, empty_classes = get_samples( + self.img_prefix, + folder_to_idx, + is_valid_file=self.is_valid_file, + ) + + if len(samples) == 0: + raise RuntimeError( + f'Found 0 files in subfolders of: {self.data_prefix}. ' + f'Supported extensions are: {",".join(self.extensions)}') + + if self.CLASSES is not None: + assert len(self.CLASSES) == len(classes), \ + f"The number of subfolders ({len(classes)}) doesn't match " \ + f'the number of specified classes ({len(self.CLASSES)}). ' \ + 'Please check the data folder.' + else: + self._metainfo['classes'] = tuple(classes) + + if empty_classes: + logger = MMLogger.get_current_instance() + logger.warning( + 'Found no valid file in the folder ' + f'{", ".join(empty_classes)}. ' + f"Supported extensions are: {', '.join(self.extensions)}") + + self.folder_to_idx = folder_to_idx + + return samples + + def load_data_list(self): + """Load image paths and gt_labels.""" + if not self.ann_file: + samples = self._find_samples() + else: + lines = list_from_file(self.ann_file) + samples = [x.strip().rsplit(' ', 1) for x in lines] + + # Pre-build file backend to prevent verbose file backend inference. + backend = get_file_backend(self.img_prefix, enable_singleton=True) + data_list = [] + for filename, gt_label in samples: + img_path = backend.join_path(self.img_prefix, filename) + info = {'img_path': img_path, 'gt_label': int(gt_label)} + data_list.append(info) + return data_list + + def is_valid_file(self, filename: str) -> bool: + """Check if a file is a valid sample.""" + return filename.lower().endswith(self.extensions) diff --git a/mmpl/datasets/nwpu_ins_dataset.py b/mmpl/datasets/nwpu_ins_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ac0e012808c4ab4571661067b7676483850a3f56 --- /dev/null +++ b/mmpl/datasets/nwpu_ins_dataset.py @@ -0,0 +1,59 @@ +from typing import List + +from mmpl.registry import DATASETS +from mmdet.datasets.coco import CocoDataset + + +@DATASETS.register_module() +class NWPUInsSegDataset(CocoDataset): + """Dataset for Cityscapes.""" + + METAINFO = { + 'classes': ['airplane', 'ship', 'storage_tank', 'baseball_diamond', + 'tennis_court', 'basketball_court', 'ground_track_field', + 'harbor', 'bridge', 'vehicle'], + 'palette': [(220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70), + (0, 60, 100), (0, 80, 100), (0, 0, 230), + (119, 11, 32), (0, 255, 0), (0, 0, 255)] + } + + def filter_data(self) -> List[dict]: + """Filter annotations according to filter_cfg. + + Returns: + List[dict]: Filtered results. 
+ """ + if self.test_mode: + return self.data_list + + if self.filter_cfg is None: + return self.data_list + + filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False) + min_size = self.filter_cfg.get('min_size', 0) + + # obtain images that contain annotation + ids_with_ann = set(data_info['img_id'] for data_info in self.data_list) + # obtain images that contain annotations of the required categories + ids_in_cat = set() + for i, class_id in enumerate(self.cat_ids): + ids_in_cat |= set(self.cat_img_map[class_id]) + # merge the image id sets of the two conditions and use the merged set + # to filter out images if self.filter_empty_gt=True + ids_in_cat &= ids_with_ann + + valid_data_infos = [] + for i, data_info in enumerate(self.data_list): + img_id = data_info['img_id'] + width = data_info['width'] + height = data_info['height'] + all_is_crowd = all([ + instance['ignore_flag'] == 1 + for instance in data_info['instances'] + ]) + if filter_empty_gt and (img_id not in ids_in_cat or all_is_crowd): + continue + if min(width, height) >= min_size: + valid_data_infos.append(data_info) + + return valid_data_infos diff --git a/mmpl/datasets/pl_datamodule.py b/mmpl/datasets/pl_datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..36492b5eada36c3b936aa16b9cda2b9e2ae4741f --- /dev/null +++ b/mmpl/datasets/pl_datamodule.py @@ -0,0 +1,73 @@ +from mmpl.registry import DATASETS +import lightning.pytorch as pl +from torch.utils.data import DataLoader +from .builder import build_dataset +from mmengine.registry import FUNCTIONS +from functools import partial + + +def get_collate_fn(dataloader_cfg): + collate_fn_cfg = dataloader_cfg.pop('collate_fn', dict(type='pseudo_collate')) + collate_fn_type = collate_fn_cfg.pop('type') + collate_fn = FUNCTIONS.get(collate_fn_type) + collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore + return collate_fn + + +@DATASETS.register_module() +class PLDataModule(pl.LightningDataModule): + def __init__(self, + train_loader=None, + val_loader=None, + test_loader=None, + predict_loader=None, + **kwargs + ): + super().__init__() + self.train_loader = train_loader + self.val_loader = val_loader + self.test_loader = test_loader + self.predict_loader = predict_loader + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None + self.predict_dataset = None + + def prepare_data(self): + pass + + def setup(self, stage: str): + if stage == "fit": + dataset_cfg = self.train_loader.pop('dataset') + self.train_dataset = build_dataset(dataset_cfg) + if self.val_loader is not None: + dataset_cfg = self.val_loader.pop('dataset') + self.val_dataset = build_dataset(dataset_cfg) + if stage == "val": + if self.val_loader is not None: + dataset_cfg = self.val_loader.pop('dataset') + self.val_dataset = build_dataset(dataset_cfg) + if stage == "test": + if self.test_loader is not None: + dataset_cfg = self.test_loader.pop('dataset') + self.test_dataset = build_dataset(dataset_cfg) + if stage == "predict": + if self.predict_loader is not None: + dataset_cfg = self.predict_loader.pop('dataset') + self.predict_dataset = build_dataset(dataset_cfg) + + def train_dataloader(self): + collate_fn = get_collate_fn(self.train_loader) + return DataLoader(self.train_dataset, collate_fn=collate_fn, **self.train_loader) + + def val_dataloader(self): + collate_fn = get_collate_fn(self.val_loader) + return DataLoader(self.val_dataset, collate_fn=collate_fn, **self.val_loader) + + def test_dataloader(self): + collate_fn = 
get_collate_fn(self.test_loader) + return DataLoader(self.test_dataset, collate_fn=collate_fn, **self.test_loader) + + def predict_dataloader(self): + collate_fn = get_collate_fn(self.predict_loader) + return DataLoader(self.predict_dataset, collate_fn=collate_fn, **self.predict_loader) diff --git a/mmpl/datasets/ssdd_ins_dataset.py b/mmpl/datasets/ssdd_ins_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7bab673e796aca77b34600b730b5a1ca46eee011 --- /dev/null +++ b/mmpl/datasets/ssdd_ins_dataset.py @@ -0,0 +1,54 @@ +from typing import List +from mmpl.registry import DATASETS +from mmdet.datasets.coco import CocoDataset + + +@DATASETS.register_module() +class SSDDInsSegDataset(CocoDataset): + """Dataset for Cityscapes.""" + + METAINFO = { + 'classes': ['ship'], + 'palette': [(0, 0, 255)] + } + + def filter_data(self) -> List[dict]: + """Filter annotations according to filter_cfg. + + Returns: + List[dict]: Filtered results. + """ + # if self.test_mode: + # return self.data_list + + if self.filter_cfg is None: + return self.data_list + + filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False) + min_size = self.filter_cfg.get('min_size', 0) + + # obtain images that contain annotation + ids_with_ann = set(data_info['img_id'] for data_info in self.data_list) + # obtain images that contain annotations of the required categories + ids_in_cat = set() + for i, class_id in enumerate(self.cat_ids): + ids_in_cat |= set(self.cat_img_map[class_id]) + # merge the image id sets of the two conditions and use the merged set + # to filter out images if self.filter_empty_gt=True + ids_in_cat &= ids_with_ann + + valid_data_infos = [] + for i, data_info in enumerate(self.data_list): + img_id = data_info['img_id'] + width = data_info['width'] + height = data_info['height'] + all_is_crowd = all([ + instance['ignore_flag'] == 1 + for instance in data_info['instances'] + ]) + if filter_empty_gt and (img_id not in ids_in_cat or all_is_crowd): + continue + if min(width, height) >= min_size: + valid_data_infos.append(data_info) + + return valid_data_infos diff --git a/mmpl/datasets/transforms/__init__.py b/mmpl/datasets/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mmpl/datasets/transforms/__pycache__/__init__.cpython-310.pyc b/mmpl/datasets/transforms/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..592e0ce3c3f9c23c21e9244b14fc1e8f7e19adb4 Binary files /dev/null and b/mmpl/datasets/transforms/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmpl/datasets/utils.py b/mmpl/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fcb60e432c374c1a904700a7348f706fa0e523eb --- /dev/null +++ b/mmpl/datasets/utils.py @@ -0,0 +1,243 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
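+# Example usage (illustrative only; the URL, directories and checksum are
+# placeholders):
+#
+#   from mmpl.datasets.utils import download_and_extract_archive
+#   download_and_extract_archive(
+#       url='https://example.com/dataset.tar.gz',
+#       download_root='data/downloads',    # the archive is saved here
+#       extract_root='data/dataset',       # extracted here; defaults to download_root
+#       md5=None,                          # pass an MD5 string to verify the download
+#       remove_finished=False,             # keep the archive after extraction
+#   )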
+import gzip +import hashlib +import os +import os.path +import shutil +import tarfile +import tempfile +import urllib.error +import urllib.request +import zipfile + +from mmengine.fileio import LocalBackend, get_file_backend + +__all__ = [ + 'rm_suffix', 'check_integrity', 'download_and_extract_archive', + 'open_maybe_compressed_file' +] + + +def rm_suffix(s, suffix=None): + if suffix is None: + return s[:s.rfind('.')] + else: + return s[:s.rfind(suffix)] + + +def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024): + md5 = hashlib.md5() + backend = get_file_backend(fpath, enable_singleton=True) + if isinstance(backend, LocalBackend): + # Enable chunk update for local file. + with open(fpath, 'rb') as f: + for chunk in iter(lambda: f.read(chunk_size), b''): + md5.update(chunk) + else: + md5.update(backend.get(fpath)) + return md5.hexdigest() + + +def check_md5(fpath, md5, **kwargs): + return md5 == calculate_md5(fpath, **kwargs) + + +def check_integrity(fpath, md5=None): + if not os.path.isfile(fpath): + return False + if md5 is None: + return True + return check_md5(fpath, md5) + + +def download_url_to_file(url, dst, hash_prefix=None, progress=True): + """Download object at the given URL to a local path. + + Modified from + https://pytorch.org/docs/stable/hub.html#torch.hub.download_url_to_file + + Args: + url (str): URL of the object to download + dst (str): Full path where object will be saved, + e.g. ``/tmp/temporary_file`` + hash_prefix (string, optional): If not None, the SHA256 downloaded + file should start with ``hash_prefix``. Defaults to None. + progress (bool): whether or not to display a progress bar to stderr. + Defaults to True + """ + file_size = None + req = urllib.request.Request(url) + u = urllib.request.urlopen(req) + meta = u.info() + if hasattr(meta, 'getheaders'): + content_length = meta.getheaders('Content-Length') + else: + content_length = meta.get_all('Content-Length') + if content_length is not None and len(content_length) > 0: + file_size = int(content_length[0]) + + # We deliberately save it in a temp file and move it after download is + # complete. This prevents a local file being overridden by a broken + # download. + dst = os.path.expanduser(dst) + dst_dir = os.path.dirname(dst) + f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) + + import rich.progress + columns = [ + rich.progress.DownloadColumn(), + rich.progress.BarColumn(bar_width=None), + rich.progress.TimeRemainingColumn(), + ] + try: + if hash_prefix is not None: + sha256 = hashlib.sha256() + with rich.progress.Progress(*columns) as pbar: + task = pbar.add_task('download', total=file_size, visible=progress) + while True: + buffer = u.read(8192) + if len(buffer) == 0: + break + f.write(buffer) + if hash_prefix is not None: + sha256.update(buffer) + pbar.update(task, advance=len(buffer)) + + f.close() + if hash_prefix is not None: + digest = sha256.hexdigest() + if digest[:len(hash_prefix)] != hash_prefix: + raise RuntimeError( + 'invalid hash value (expected "{}", got "{}")'.format( + hash_prefix, digest)) + shutil.move(f.name, dst) + finally: + f.close() + if os.path.exists(f.name): + os.remove(f.name) + + +def download_url(url, root, filename=None, md5=None): + """Download a file from a url and place it in root. + + Args: + url (str): URL to download file from. + root (str): Directory to place downloaded file in. + filename (str | None): Name to save the file under. + If filename is None, use the basename of the URL. + md5 (str | None): MD5 checksum of the download. 
+ If md5 is None, download without md5 check. + """ + root = os.path.expanduser(root) + if not filename: + filename = os.path.basename(url) + fpath = os.path.join(root, filename) + + os.makedirs(root, exist_ok=True) + + if check_integrity(fpath, md5): + print(f'Using downloaded and verified file: {fpath}') + else: + try: + print(f'Downloading {url} to {fpath}') + download_url_to_file(url, fpath) + except (urllib.error.URLError, IOError) as e: + if url[:5] == 'https': + url = url.replace('https:', 'http:') + print('Failed download. Trying https -> http instead.' + f' Downloading {url} to {fpath}') + download_url_to_file(url, fpath) + else: + raise e + # check integrity of downloaded file + if not check_integrity(fpath, md5): + raise RuntimeError('File not found or corrupted.') + + +def _is_tarxz(filename): + return filename.endswith('.tar.xz') + + +def _is_tar(filename): + return filename.endswith('.tar') + + +def _is_targz(filename): + return filename.endswith('.tar.gz') + + +def _is_tgz(filename): + return filename.endswith('.tgz') + + +def _is_gzip(filename): + return filename.endswith('.gz') and not filename.endswith('.tar.gz') + + +def _is_zip(filename): + return filename.endswith('.zip') + + +def extract_archive(from_path, to_path=None, remove_finished=False): + if to_path is None: + to_path = os.path.dirname(from_path) + + if _is_tar(from_path): + with tarfile.open(from_path, 'r') as tar: + tar.extractall(path=to_path) + elif _is_targz(from_path) or _is_tgz(from_path): + with tarfile.open(from_path, 'r:gz') as tar: + tar.extractall(path=to_path) + elif _is_tarxz(from_path): + with tarfile.open(from_path, 'r:xz') as tar: + tar.extractall(path=to_path) + elif _is_gzip(from_path): + to_path = os.path.join( + to_path, + os.path.splitext(os.path.basename(from_path))[0]) + with open(to_path, 'wb') as out_f, gzip.GzipFile(from_path) as zip_f: + out_f.write(zip_f.read()) + elif _is_zip(from_path): + with zipfile.ZipFile(from_path, 'r') as z: + z.extractall(to_path) + else: + raise ValueError(f'Extraction of {from_path} not supported') + + if remove_finished: + os.remove(from_path) + + +def download_and_extract_archive(url, + download_root, + extract_root=None, + filename=None, + md5=None, + remove_finished=False): + download_root = os.path.expanduser(download_root) + if extract_root is None: + extract_root = download_root + if not filename: + filename = os.path.basename(url) + + download_url(url, download_root, filename, md5) + + archive = os.path.join(download_root, filename) + print(f'Extracting {archive} to {extract_root}') + extract_archive(archive, extract_root, remove_finished) + + +def open_maybe_compressed_file(path: str): + """Return a file object that possibly decompresses 'path' on the fly. + + Decompression occurs when argument `path` is a string and ends with '.gz' + or '.xz'. 
+ """ + if not isinstance(path, str): + return path + if path.endswith('.gz'): + import gzip + return gzip.open(path, 'rb') + if path.endswith('.xz'): + import lzma + return lzma.open(path, 'rb') + return open(path, 'rb') diff --git a/mmpl/datasets/whu_ins_dataset.py b/mmpl/datasets/whu_ins_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5dd3ecc1c1285a8bd2655b5dfaa713453df735d9 --- /dev/null +++ b/mmpl/datasets/whu_ins_dataset.py @@ -0,0 +1,54 @@ +from typing import List +from mmpl.registry import DATASETS +from mmdet.datasets.coco import CocoDataset + + +@DATASETS.register_module() +class WHUInsSegDataset(CocoDataset): + """Dataset for Cityscapes.""" + + METAINFO = { + 'classes': ['building'], + 'palette': [(0, 255, 0)] + } + + def filter_data(self) -> List[dict]: + """Filter annotations according to filter_cfg. + + Returns: + List[dict]: Filtered results. + """ + # if self.test_mode: + # return self.data_list + + if self.filter_cfg is None: + return self.data_list + + filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False) + min_size = self.filter_cfg.get('min_size', 0) + + # obtain images that contain annotation + ids_with_ann = set(data_info['img_id'] for data_info in self.data_list) + # obtain images that contain annotations of the required categories + ids_in_cat = set() + for i, class_id in enumerate(self.cat_ids): + ids_in_cat |= set(self.cat_img_map[class_id]) + # merge the image id sets of the two conditions and use the merged set + # to filter out images if self.filter_empty_gt=True + ids_in_cat &= ids_with_ann + + valid_data_infos = [] + for i, data_info in enumerate(self.data_list): + img_id = data_info['img_id'] + width = data_info['width'] + height = data_info['height'] + all_is_crowd = all([ + instance['ignore_flag'] == 1 + for instance in data_info['instances'] + ]) + if filter_empty_gt and (img_id not in ids_in_cat or all_is_crowd): + continue + if min(width, height) >= min_size: + valid_data_infos.append(data_info) + + return valid_data_infos diff --git a/mmpl/engine/__init__.py b/mmpl/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b5dfffcfe73ffcf53cb87e50ce6260c189ec2a8e --- /dev/null +++ b/mmpl/engine/__init__.py @@ -0,0 +1,5 @@ +from .runner import * +from .logger import * +from .hooks import * +from .visualization import * +from .strategies import * \ No newline at end of file diff --git a/mmpl/engine/__pycache__/__init__.cpython-310.pyc b/mmpl/engine/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0eda8dce0ef1055eef7a330a81c8d102a357db2 Binary files /dev/null and b/mmpl/engine/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmpl/engine/hooks/__init__.py b/mmpl/engine/hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..77f2c33df26749d5597fb3875d9f65238a68a2b4 --- /dev/null +++ b/mmpl/engine/hooks/__init__.py @@ -0,0 +1,6 @@ +from .builder import PL_HOOKS +from .pipeline_switch_hook import PipelineSwitchHook +from .yolov5_param_scheduler_hook import YOLOv5ParamSchedulerHook +from .ema_hook import EMAHook +from .param_scheduler_hook import ParamSchedulerHook +from .visualization_hook import DetVisualizationHook diff --git a/mmpl/engine/hooks/__pycache__/__init__.cpython-310.pyc b/mmpl/engine/hooks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f34dfafe1594547ba25ce1a64cfe43626d9d200 Binary files /dev/null and 
b/mmpl/engine/hooks/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmpl/engine/hooks/__pycache__/builder.cpython-310.pyc b/mmpl/engine/hooks/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d9fd36c9f4b41cf81c6084bb980d73763f952d3 Binary files /dev/null and b/mmpl/engine/hooks/__pycache__/builder.cpython-310.pyc differ diff --git a/mmpl/engine/hooks/__pycache__/ema_hook.cpython-310.pyc b/mmpl/engine/hooks/__pycache__/ema_hook.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c89aa778aeeb9264ef6efda3560770bca3d33e0 Binary files /dev/null and b/mmpl/engine/hooks/__pycache__/ema_hook.cpython-310.pyc differ diff --git a/mmpl/engine/hooks/__pycache__/param_scheduler_hook.cpython-310.pyc b/mmpl/engine/hooks/__pycache__/param_scheduler_hook.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc7c54be0cf8c2abc7f67cd6d741cb42980bc46a Binary files /dev/null and b/mmpl/engine/hooks/__pycache__/param_scheduler_hook.cpython-310.pyc differ diff --git a/mmpl/engine/hooks/__pycache__/pipeline_switch_hook.cpython-310.pyc b/mmpl/engine/hooks/__pycache__/pipeline_switch_hook.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94a63c593b89431e13cc7290381fbe42ccedc66d Binary files /dev/null and b/mmpl/engine/hooks/__pycache__/pipeline_switch_hook.cpython-310.pyc differ diff --git a/mmpl/engine/hooks/__pycache__/visualization_hook.cpython-310.pyc b/mmpl/engine/hooks/__pycache__/visualization_hook.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bc9bd13443df281a75250405bde09f5b1f5a79a Binary files /dev/null and b/mmpl/engine/hooks/__pycache__/visualization_hook.cpython-310.pyc differ diff --git a/mmpl/engine/hooks/__pycache__/yolov5_param_scheduler_hook.cpython-310.pyc b/mmpl/engine/hooks/__pycache__/yolov5_param_scheduler_hook.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bc6b530174b1f4e2a420f0785712e9e6c8ca4e5 Binary files /dev/null and b/mmpl/engine/hooks/__pycache__/yolov5_param_scheduler_hook.cpython-310.pyc differ diff --git a/mmpl/engine/hooks/builder.py b/mmpl/engine/hooks/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..c27b4591a7a546dcce76ce5fa0233b452bf916c0 --- /dev/null +++ b/mmpl/engine/hooks/builder.py @@ -0,0 +1,31 @@ +import copy +import inspect +from typing import List, Union + +import torch +import torch.nn as nn +import lightning + +from mmengine.config import Config, ConfigDict +from mmengine.device import is_npu_available +from mmpl.registry import HOOKS + + +def register_pl_hooks() -> List[str]: + """Register callbacks in ``lightning.pytorch.callbacks`` to the ``HOOKS`` registry. + + Returns: + List[str]: A list of registered callbacks' name. 
+ """ + pl_hooks = [] + for module_name in dir(lightning.pytorch.callbacks): + if module_name.startswith('__'): + continue + _hook = getattr(lightning.pytorch.callbacks, module_name) + if inspect.isclass(_hook) and issubclass(_hook, lightning.pytorch.callbacks.Callback): + HOOKS.register_module(module=_hook) + pl_hooks.append(module_name) + return pl_hooks + + +PL_HOOKS = register_pl_hooks() diff --git a/mmpl/engine/hooks/ema_hook.py b/mmpl/engine/hooks/ema_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..54d83d1e92b299f33f2234510d1f1e180ed631ac --- /dev/null +++ b/mmpl/engine/hooks/ema_hook.py @@ -0,0 +1,240 @@ +import copy +import itertools +import logging +from typing import Dict, Optional, Any + +from lightning import Callback +from lightning.pytorch.utilities.types import STEP_OUTPUT +from mmengine.logging import print_log +from mmengine.model import is_model_wrapper +from mmpl.registry import HOOKS, MODELS + + + +@HOOKS.register_module() +class EMAHook(Callback): + """A Hook to apply Exponential Moving Average (EMA) on the model during + training. + + Note: + - EMAHook takes priority over CheckpointHook. + - The original model parameters are actually saved in ema field after + train. + - ``begin_iter`` and ``begin_epoch`` cannot be set at the same time. + + Args: + ema_type (str): The type of EMA strategy to use. You can find the + supported strategies in :mod:`mmengine.model.averaged_model`. + Defaults to 'ExponentialMovingAverage'. + strict_load (bool): Whether to strictly enforce that the keys of + ``state_dict`` in checkpoint match the keys returned by + ``self.module.state_dict``. Defaults to False. + Changed in v0.3.0. + begin_iter (int): The number of iteration to enable ``EMAHook``. + Defaults to 0. + begin_epoch (int): The number of epoch to enable ``EMAHook``. + Defaults to 0. + **kwargs: Keyword arguments passed to subclasses of + :obj:`BaseAveragedModel` + """ + + priority = 'NORMAL' + + def __init__(self, + ema_type: str = 'ExponentialMovingAverage', + strict_load: bool = False, + begin_iter: int = 0, + begin_epoch: int = 0, + **kwargs): + self.strict_load = strict_load + self.ema_cfg = dict(type=ema_type, **kwargs) + assert not (begin_iter != 0 and begin_epoch != 0), ( + '`begin_iter` and `begin_epoch` should not be both set.') + assert begin_iter >= 0, ( + '`begin_iter` must larger than or equal to 0, ' + f'but got begin_iter: {begin_iter}') + assert begin_epoch >= 0, ( + '`begin_epoch` must larger than or equal to 0, ' + f'but got begin_epoch: {begin_epoch}') + self.begin_iter = begin_iter + self.begin_epoch = begin_epoch + # If `begin_epoch` and `begin_iter` are not set, `EMAHook` will be + # enabled at 0 iteration. + self.enabled_by_epoch = self.begin_epoch > 0 + + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Create an ema copy of the model. + + Args: + runner (Runner): The runner of the training process. + """ + model = pl_module + if is_model_wrapper(model): + model = model.module + self.src_model = model + self.ema_model = MODELS.build( + self.ema_cfg, default_args=dict(model=self.src_model)) + + def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Check the begin_epoch/iter is smaller than max_epochs/iters. + + Args: + runner (Runner): The runner of the training process. 
+ """ + if self.enabled_by_epoch: + assert self.begin_epoch <= trainer.max_epochs, ( + 'self.begin_epoch should be smaller than or equal to ' + f'runner.max_epochs: {trainer.max_epochs}, but got ' + f'begin_epoch: {self.begin_epoch}') + else: + assert self.begin_iter <= trainer.max_steps or self.begin_iter <= trainer.max_epochs * len(trainer.train_dataloader), ( + 'self.begin_iter should be smaller than or equal to ' + f'runner.max_iters: {trainer.max_steps}, but got ' + f'begin_iter: {self.begin_iter}') + + def on_train_batch_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int + ) -> None: + """Update ema parameter. + + Args: + runner (Runner): The runner of the training process. + batch_idx (int): The index of the current batch in the train loop. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. + outputs (dict, optional): Outputs from model. Defaults to None. + """ + if self._ema_started(trainer): + self.ema_model.update_parameters(self.src_model) + else: + ema_params = self.ema_model.module.state_dict() + src_params = self.src_model.state_dict() + for k, p in ema_params.items(): + p.data.copy_(src_params[k].data) + + def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """We load parameter values from ema model to source model before + validation. + + Args: + runner (Runner): The runner of the training process. + """ + self._swap_ema_parameters() + + def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """We recover source model's parameter from ema model after validation. + + Args: + runner (Runner): The runner of the validation process. + metrics (Dict[str, float], optional): Evaluation results of all + metrics on validation dataset. The keys are the names of the + metrics, and the values are corresponding results. + """ + self._swap_ema_parameters() + + def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """We load parameter values from ema model to source model before test. + + Args: + runner (Runner): The runner of the training process. + """ + self._swap_ema_parameters() + + def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """We recover source model's parameter from ema model after test. + + Args: + runner (Runner): The runner of the testing process. + metrics (Dict[str, float], optional): Evaluation results of all + metrics on test dataset. The keys are the names of the + metrics, and the values are corresponding results. + """ + self._swap_ema_parameters() + + def on_save_checkpoint( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] + ) -> None: + """Save ema parameters to checkpoint. + + Args: + runner (Runner): The runner of the testing process. + """ + checkpoint['ema_state_dict'] = self.ema_model.state_dict() + # Save ema parameters to the source model's state dict so that we + # can directly load the averaged model weights for deployment. + # Swapping the state_dict key-values instead of swapping model + # parameters because the state_dict is a shallow copy of model + # parameters. + self._swap_ema_state_dict(checkpoint) + + def on_load_checkpoint( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] + ) -> None: + """Resume ema parameters from checkpoint. 
+ + Args: + runner (Runner): The runner of the testing process. + """ + from mmengine.runner.checkpoint import load_state_dict + if 'ema_state_dict' in checkpoint and not trainer._checkpoint_connector._loaded_checkpoint: + # The original model parameters are actually saved in ema + # field swap the weights back to resume ema state. + self._swap_ema_state_dict(checkpoint) + self.ema_model.load_state_dict( + checkpoint['ema_state_dict'], strict=self.strict_load) + + # Support load checkpoint without ema state dict. + else: + if not trainer._checkpoint_connector._loaded_checkpoint: + print_log( + 'There is no `ema_state_dict` in checkpoint. ' + '`EMAHook` will make a copy of `state_dict` as the ' + 'initial `ema_state_dict`', 'current', logging.WARNING) + load_state_dict( + self.ema_model.module, + copy.deepcopy(checkpoint['state_dict']), + strict=self.strict_load) + + def _swap_ema_parameters(self) -> None: + """Swap the parameter of model with ema_model.""" + avg_param = ( + itertools.chain(self.ema_model.module.parameters(), + self.ema_model.module.buffers()) + if self.ema_model.update_buffers else + self.ema_model.module.parameters()) + src_param = ( + itertools.chain(self.src_model.parameters(), + self.src_model.buffers()) + if self.ema_model.update_buffers else self.src_model.parameters()) + for p_avg, p_src in zip(avg_param, src_param): + tmp = p_avg.data.clone() + p_avg.data.copy_(p_src.data) + p_src.data.copy_(tmp) + + def _swap_ema_state_dict(self, checkpoint): + """Swap the state dict values of model with ema_model.""" + model_state = checkpoint['state_dict'] + ema_state = checkpoint['ema_state_dict'] + for k in ema_state: + if k[:7] == 'module.': + tmp = ema_state[k] + ema_state[k] = model_state[k[7:]] + model_state[k[7:]] = tmp + + def _ema_started(self, trainer) -> bool: + """Whether ``EMAHook`` has been initialized at current iteration or + epoch. + + :attr:`ema_model` will be initialized when ``runner.iter`` or + ``runner.epoch`` is greater than ``self.begin`` for the first time. + + Args: + runner (Runner): Runner of the training, validation process. + + Returns: + bool: Whether ``EMAHook`` has been initialized. + """ + if self.enabled_by_epoch: + return trainer.current_epoch + 1 >= self.begin_epoch + else: + return trainer.global_step + 1 >= self.begin_iter diff --git a/mmpl/engine/hooks/param_scheduler_hook.py b/mmpl/engine/hooks/param_scheduler_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..6fcc4887ee7243b4b1c627112fdef214c8f46c40 --- /dev/null +++ b/mmpl/engine/hooks/param_scheduler_hook.py @@ -0,0 +1,128 @@ +from typing import Dict, Optional, Union, Any + +from lightning.pytorch.utilities.types import STEP_OUTPUT +from mmengine.optim import _ParamScheduler +from mmpl.registry import HOOKS +from mmengine.utils import is_list_of +from lightning import Callback + +DATA_BATCH = Optional[Union[dict, tuple, list]] + + +@HOOKS.register_module() +class ParamSchedulerHook(Callback): + """A hook to update some hyper-parameters in optimizer, e.g., learning rate + and momentum.""" + + priority = 'LOW' + + def on_train_batch_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int + ) -> None: + """Call step function for each scheduler after each training iteration. + + Args: + runner (Runner): The runner of the training process. + batch_idx (int): The index of the current batch in the train loop. + data_batch (dict or tuple or list, optional): Data from dataloader. 
+ In order to keep this interface consistent with other hooks, + we keep ``data_batch`` here. + outputs (dict, optional): Outputs from model. + In order to keep this interface consistent with other hooks, we + keep ``data_batch`` here. + """ + param_schedulers = pl_module.lr_schedulers() + if param_schedulers is None: + return + + def step(param_schedulers): + assert isinstance(param_schedulers, list) + for scheduler in param_schedulers: + if not scheduler.by_epoch: + scheduler.step() + if isinstance(param_schedulers, _ParamScheduler): + param_schedulers = [param_schedulers] + if isinstance(param_schedulers, list): + step(param_schedulers) + elif isinstance(param_schedulers, dict): + for param_schedulers in param_schedulers.values(): + step(param_schedulers) + else: + raise TypeError( + 'runner.param_schedulers should be list of ParamScheduler or ' + 'a dict containing list of ParamScheduler, ' + f'but got {param_schedulers}') + + def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Call step function for each scheduler after each training epoch. + + Args: + runner (Runner): The runner of the training process. + """ + param_schedulers = pl_module.lr_schedulers() + if param_schedulers is None: + return + + def step(param_schedulers): + assert isinstance(param_schedulers, list) + for scheduler in param_schedulers: + if scheduler.by_epoch: + scheduler.step() + if isinstance(param_schedulers, _ParamScheduler): + param_schedulers = [param_schedulers] + if isinstance(param_schedulers, list): + step(param_schedulers) + elif isinstance(param_schedulers, dict): + for param_schedulers in param_schedulers.values(): + step(param_schedulers) + else: + raise TypeError( + 'runner.param_schedulers should be list of ParamScheduler or ' + 'a dict containing list of ParamScheduler, ' + f'but got {param_schedulers}') + + def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Call step function for each scheduler which has attribute + ``need_val_args`` after each validation epoch. + + Args: + runner (Runner): The runner of the validation process. + metrics (Dict[str, float], optional): Evaluation results of all + metrics on validation dataset. The keys are the names of the + metrics, and the values are corresponding results. + + Note: + if ``runner.param_schedulers`` is not built before, + the hook ``after_val_epoch`` will be skipped. 
+ """ + param_schedulers = pl_module.lr_schedulers() + if param_schedulers is None: + return + + # avoid counting scheduler._global_step + # it has counted in after_train_* hook + metrics = trainer.callback_metrics + if metrics is None: + return + + def step(param_schedulers): + # check param_schedulers is list and built + if not is_list_of(param_schedulers, _ParamScheduler): + return + + for scheduler in param_schedulers: + if (scheduler.by_epoch + and getattr(scheduler, 'need_val_args', False)): + scheduler.step(metrics) + if isinstance(param_schedulers, _ParamScheduler): + param_schedulers = [param_schedulers] + if isinstance(param_schedulers, list): + step(param_schedulers) + elif isinstance(param_schedulers, dict): + for param_schedulers in param_schedulers.values(): + step(param_schedulers) + else: + raise TypeError( + 'runner.param_schedulers should be list of ParamScheduler or ' + 'a dict containing list of ParamScheduler, ' + f'but got {param_schedulers}') diff --git a/mmpl/engine/hooks/pipeline_switch_hook.py b/mmpl/engine/hooks/pipeline_switch_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..3ad4a98dc47125bac0056aa3ab0f07e2c381f88d --- /dev/null +++ b/mmpl/engine/hooks/pipeline_switch_hook.py @@ -0,0 +1,41 @@ +from mmcv.transforms import Compose +from mmpl.registry import HOOKS +from lightning.pytorch.callbacks import Callback + + +@HOOKS.register_module() +class PipelineSwitchHook(Callback): + """Switch data pipeline at switch_epoch. + + Args: + switch_epoch (int): switch pipeline at this epoch. + switch_pipeline (list[dict]): the pipeline to switch to. + """ + + def __init__(self, switch_epoch, switch_pipeline): + self.switch_epoch = switch_epoch + self.switch_pipeline = switch_pipeline + self._restart_dataloader = False + + def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """switch pipeline.""" + epoch = trainer.current_epoch + train_loader = trainer.train_dataloader + if epoch == self.switch_epoch: + if trainer.local_rank == 0: + print('Switch pipeline now!') + # The dataset pipeline cannot be updated when persistent_workers + # is True, so we need to force the dataloader's multi-process + # restart. This is a very hacky approach. + train_loader.dataset.pipeline = Compose(self.switch_pipeline) + if hasattr(train_loader, 'persistent_workers' + ) and train_loader.persistent_workers is True: + train_loader._DataLoader__initialized = False + train_loader._iterator = None + self._restart_dataloader = True + + else: + # Once the restart is complete, we need to restore + # the initialization flag. + if self._restart_dataloader: + train_loader._DataLoader__initialized = True diff --git a/mmpl/engine/hooks/ppyoloe_param_scheduler_hook.py b/mmpl/engine/hooks/ppyoloe_param_scheduler_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..26dfe6ef2d5cf590ea381efb3e42cdc1c5492361 --- /dev/null +++ b/mmpl/engine/hooks/ppyoloe_param_scheduler_hook.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Optional + +from mmengine.hooks import ParamSchedulerHook +from mmengine.runner import Runner + +from mmyolo.registry import HOOKS + + +@HOOKS.register_module() +class PPYOLOEParamSchedulerHook(ParamSchedulerHook): + """A hook to update learning rate and momentum in optimizer of PPYOLOE. We + use this hook to implement adaptive computation for `warmup_total_iters`, + which is not possible with the built-in ParamScheduler in mmyolo. 
+ + Args: + warmup_min_iter (int): Minimum warmup iters. Defaults to 1000. + start_factor (float): The number we multiply learning rate in the + first epoch. The multiplication factor changes towards end_factor + in the following epochs. Defaults to 0. + warmup_epochs (int): Epochs for warmup. Defaults to 5. + min_lr_ratio (float): Minimum learning rate ratio. + total_epochs (int): In PPYOLOE, `total_epochs` is set to + training_epochs x 1.2. Defaults to 360. + """ + priority = 9 + + def __init__(self, + warmup_min_iter: int = 1000, + start_factor: float = 0., + warmup_epochs: int = 5, + min_lr_ratio: float = 0.0, + total_epochs: int = 360): + + self.warmup_min_iter = warmup_min_iter + self.start_factor = start_factor + self.warmup_epochs = warmup_epochs + self.min_lr_ratio = min_lr_ratio + self.total_epochs = total_epochs + + self._warmup_end = False + self._base_lr = None + + def before_train(self, runner: Runner): + """Operations before train. + + Args: + runner (Runner): The runner of the training process. + """ + optimizer = runner.optim_wrapper.optimizer + for group in optimizer.param_groups: + # If the param is never be scheduled, record the current value + # as the initial value. + group.setdefault('initial_lr', group['lr']) + + self._base_lr = [ + group['initial_lr'] for group in optimizer.param_groups + ] + self._min_lr = [i * self.min_lr_ratio for i in self._base_lr] + + def before_train_iter(self, + runner: Runner, + batch_idx: int, + data_batch: Optional[dict] = None): + """Operations before each training iteration. + + Args: + runner (Runner): The runner of the training process. + batch_idx (int): The index of the current batch in the train loop. + data_batch (dict or tuple or list, optional): Data from dataloader. + """ + cur_iters = runner.iter + optimizer = runner.optim_wrapper.optimizer + dataloader_len = len(runner.train_dataloader) + + # The minimum warmup is self.warmup_min_iter + warmup_total_iters = max( + round(self.warmup_epochs * dataloader_len), self.warmup_min_iter) + + if cur_iters <= warmup_total_iters: + # warm up + alpha = cur_iters / warmup_total_iters + factor = self.start_factor * (1 - alpha) + alpha + + for group_idx, param in enumerate(optimizer.param_groups): + param['lr'] = self._base_lr[group_idx] * factor + else: + for group_idx, param in enumerate(optimizer.param_groups): + total_iters = self.total_epochs * dataloader_len + lr = self._min_lr[group_idx] + ( + self._base_lr[group_idx] - + self._min_lr[group_idx]) * 0.5 * ( + math.cos((cur_iters - warmup_total_iters) * math.pi / + (total_iters - warmup_total_iters)) + 1.0) + param['lr'] = lr diff --git a/mmpl/engine/hooks/switch_to_deploy_hook.py b/mmpl/engine/hooks/switch_to_deploy_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..28ac345f40c44c974fb33b7bf9756a61fcabf820 --- /dev/null +++ b/mmpl/engine/hooks/switch_to_deploy_hook.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from mmengine.hooks import Hook +from mmengine.runner import Runner + +from mmyolo.registry import HOOKS +from mmyolo.utils import switch_to_deploy + + +@HOOKS.register_module() +class SwitchToDeployHook(Hook): + """Switch to deploy mode before testing. + + This hook converts the multi-channel structure of the training network + (high performance) to the one-way structure of the testing network (fast + speed and memory saving). 
+ """ + + def before_test_epoch(self, runner: Runner): + """Switch to deploy mode before testing.""" + switch_to_deploy(runner.model) diff --git a/mmpl/engine/hooks/visualization_hook.py b/mmpl/engine/hooks/visualization_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..16d72b958ede24e9ae888b16b46b5775e7511011 --- /dev/null +++ b/mmpl/engine/hooks/visualization_hook.py @@ -0,0 +1,199 @@ +import os.path as osp +import warnings +from typing import Optional, Sequence, Any + +import mmcv +from lightning import Callback +from mmengine.fileio import get +from mmengine.hooks import Hook +from mmengine.runner import Runner +from mmengine.utils import mkdir_or_exist +from mmengine.visualization import Visualizer + +from mmpl.registry import HOOKS +from mmdet.structures import DetDataSample + + +@HOOKS.register_module() +class DetVisualizationHook(Callback): + """Detection Visualization Hook. Used to visualize validation and testing + process prediction results. + + In the testing phase: + + 1. If ``show`` is True, it means that only the prediction results are + visualized without storing data, so ``vis_backends`` needs to + be excluded. + 2. If ``test_out_dir`` is specified, it means that the prediction results + need to be saved to ``test_out_dir``. In order to avoid vis_backends + also storing data, so ``vis_backends`` needs to be excluded. + 3. ``vis_backends`` takes effect if the user does not specify ``show`` + and `test_out_dir``. You can set ``vis_backends`` to WandbVisBackend or + TensorboardVisBackend to store the prediction result in Wandb or + Tensorboard. + + Args: + draw (bool): whether to draw prediction results. If it is False, + it means that no drawing will be done. Defaults to False. + interval (int): The interval of visualization. Defaults to 50. + score_thr (float): The threshold to visualize the bboxes + and masks. Defaults to 0.3. + show (bool): Whether to display the drawn image. Default to False. + wait_time (float): The interval of show (s). Defaults to 0. + test_out_dir (str, optional): directory where painted images + will be saved in testing process. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + """ + + def __init__(self, + draw: bool = False, + interval: int = 50, + score_thr: float = 0.3, + show: bool = False, + wait_time: float = 0., + test_out_dir: Optional[str] = None, + backend_args: dict = None): + self._visualizer: Visualizer = Visualizer.get_current_instance() + self.interval = interval + self.score_thr = score_thr + self.show = show + if self.show: + # No need to think about vis backends. + self._visualizer._vis_backends = {} + warnings.warn('The show is True, it means that only ' + 'the prediction results are visualized ' + 'without storing data, so vis_backends ' + 'needs to be excluded.') + + self.wait_time = wait_time + self.backend_args = backend_args + self.draw = draw + self.test_out_dir = test_out_dir + self._test_index = 0 + + def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict, + outputs: Sequence[DetDataSample]) -> None: + """Run after every ``self.interval`` validation iterations. + + Args: + runner (:obj:`Runner`): The runner of the validation process. + batch_idx (int): The index of the current batch in the val loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`DetDataSample`]]): A batch of data samples + that contain annotations and predictions. 
+ """ + if self.draw is False: + return + + # There is no guarantee that the same batch of images + # is visualized for each evaluation. + total_curr_iter = runner.iter + batch_idx + + # Visualize only the first data + img_path = outputs[0].img_path + img_bytes = get(img_path, backend_args=self.backend_args) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + + if total_curr_iter % self.interval == 0: + self._visualizer.add_datasample( + osp.basename(img_path) if self.show else 'val_img', + img, + data_sample=outputs[0], + show=self.show, + wait_time=self.wait_time, + pred_score_thr=self.score_thr, + step=total_curr_iter) + + def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, + outputs: Sequence[DetDataSample]) -> None: + """Run after every testing iterations. + + Args: + runner (:obj:`Runner`): The runner of the testing process. + batch_idx (int): The index of the current batch in the val loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`DetDataSample`]): A batch of data samples + that contain annotations and predictions. + """ + if self.draw is False: + return + + if self.test_out_dir is not None: + self.test_out_dir = osp.join(runner.work_dir, runner.timestamp, + self.test_out_dir) + mkdir_or_exist(self.test_out_dir) + + for data_sample in outputs: + self._test_index += 1 + + img_path = data_sample.img_path + img_bytes = get(img_path, backend_args=self.backend_args) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + + out_file = None + if self.test_out_dir is not None: + out_file = osp.basename(img_path) + out_file = osp.join(self.test_out_dir, out_file) + + self._visualizer.add_datasample( + osp.basename(img_path) if self.show else 'test_img', + img, + data_sample=data_sample, + show=self.show, + wait_time=self.wait_time, + pred_score_thr=self.score_thr, + out_file=out_file, + step=self._test_index) + + def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + # if hasattr(trainer.datamodule, f'predict_dataset'): + # dataset = getattr(trainer.datamodule, f'predict_dataset') + # if hasattr(dataset, 'metainfo') and hasattr(self._visualizer, 'dataset_meta'): + # self._visualizer.dataset_meta = dataset.metainfo + if self.test_out_dir is not None: + self.test_out_dir = osp.join(trainer.default_root_dir, self.test_out_dir) + mkdir_or_exist(self.test_out_dir) + + def on_predict_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + """Run after every testing iterations. + + Args: + runner (:obj:`Runner`): The runner of the testing process. + batch_idx (int): The index of the current batch in the val loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`DetDataSample`]): A batch of data samples + that contain annotations and predictions. 
+ """ + if self.draw is False: + return + + for data_sample in outputs: + self._test_index += 1 + + img_path = data_sample.img_path + img_bytes = get(img_path, backend_args=self.backend_args) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + + out_file = None + if self.test_out_dir is not None: + out_file = osp.basename(img_path) + out_file = osp.join(self.test_out_dir, out_file) + + self._visualizer.add_datasample( + osp.basename(img_path) if self.show else 'test_img', + img, + data_sample=data_sample, + show=self.show, + wait_time=self.wait_time, + pred_score_thr=self.score_thr, + out_file=out_file, + step=self._test_index) diff --git a/mmpl/engine/hooks/yolov5_param_scheduler_hook.py b/mmpl/engine/hooks/yolov5_param_scheduler_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..59713e9303bdb8d7505f6a7efb03303200515b4c --- /dev/null +++ b/mmpl/engine/hooks/yolov5_param_scheduler_hook.py @@ -0,0 +1,111 @@ +import math +from typing import Optional +import numpy as np +from typing import Dict, Optional, Union +from mmengine.registry import HOOKS +from .param_scheduler_hook import ParamSchedulerHook + +DATA_BATCH = Optional[Union[dict, tuple, list]] + + +def linear_fn(lr_factor: float, max_epochs: int): + """Generate linear function.""" + return lambda x: (1 - x / max_epochs) * (1.0 - lr_factor) + lr_factor + + +def cosine_fn(lr_factor: float, max_epochs: int): + """Generate cosine function.""" + return lambda x: ( + (1 - math.cos(x * math.pi / max_epochs)) / 2) * (lr_factor - 1) + 1 + + +@HOOKS.register_module() +class YOLOv5ParamSchedulerHook(ParamSchedulerHook): + """A hook to update learning rate and momentum in optimizer of YOLOv5.""" + priority = 9 + + scheduler_maps = {'linear': linear_fn, 'cosine': cosine_fn} + + def __init__(self, + scheduler_type: str = 'linear', + lr_factor: float = 0.01, + max_epochs: int = 300, + warmup_epochs: int = 3, + warmup_bias_lr: float = 0.1, + warmup_momentum: float = 0.8, + warmup_mim_iter: int = 500, + **kwargs): + + assert scheduler_type in self.scheduler_maps + + self.warmup_epochs = warmup_epochs + self.warmup_bias_lr = warmup_bias_lr + self.warmup_momentum = warmup_momentum + self.warmup_mim_iter = warmup_mim_iter + + kwargs.update({'lr_factor': lr_factor, 'max_epochs': max_epochs}) + self.scheduler_fn = self.scheduler_maps[scheduler_type](**kwargs) + + self._warmup_end = False + self._base_lr = None + self._base_momentum = None + + + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + optimizer = trainer.optimizers[0] + for group in optimizer.param_groups: + # If the param is never be scheduled, record the current value + # as the initial value. 
+ group.setdefault('initial_lr', group['lr']) + group.setdefault('initial_momentum', group.get('momentum', -1)) + + self._base_lr = [ + group['initial_lr'] for group in optimizer.param_groups + ] + self._base_momentum = [ + group['initial_momentum'] for group in optimizer.param_groups + ] + + def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss) -> None: + cur_iters = trainer.global_step + cur_epoch = trainer.current_epoch + optimizer = trainer.optimizers[0] + + # The minimum warmup is self.warmup_mim_iter + warmup_total_iters = max( + round(self.warmup_epochs * len(trainer.train_dataloader)), + self.warmup_mim_iter) + + if cur_iters <= warmup_total_iters: + xp = [0, warmup_total_iters] + for group_idx, param in enumerate(optimizer.param_groups): + if group_idx == 2: + # bias learning rate will be handled specially + yp = [ + self.warmup_bias_lr, + self._base_lr[group_idx] * self.scheduler_fn(cur_epoch) + ] + else: + yp = [ + 0.0, + self._base_lr[group_idx] * self.scheduler_fn(cur_epoch) + ] + param['lr'] = np.interp(cur_iters, xp, yp) + + if 'momentum' in param: + param['momentum'] = np.interp( + cur_iters, xp, + [self.warmup_momentum, self._base_momentum[group_idx]]) + else: + self._warmup_end = True + + def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if not self._warmup_end: + return + + cur_epoch = trainer.current_epoch + optimizer = trainer.optimizers[0] + for group_idx, param in enumerate(optimizer.param_groups): + param['lr'] = self._base_lr[group_idx] * self.scheduler_fn( + cur_epoch) + diff --git a/mmpl/engine/hooks/yolox_mode_switch_hook.py b/mmpl/engine/hooks/yolox_mode_switch_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..27711768c3f89b26410ae1373bc920d0bfded603 --- /dev/null +++ b/mmpl/engine/hooks/yolox_mode_switch_hook.py @@ -0,0 +1,54 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Sequence + +from mmengine.hooks import Hook +from mmengine.model import is_model_wrapper +from mmengine.runner import Runner + +from mmyolo.registry import HOOKS + + +@HOOKS.register_module() +class YOLOXModeSwitchHook(Hook): + """Switch the mode of YOLOX during training. + + This hook turns off the mosaic and mixup data augmentation and switches + to use L1 loss in bbox_head. + + Args: + num_last_epochs (int): The number of latter epochs in the end of the + training to close the data augmentation and switch to L1 loss. + Defaults to 15. + """ + + def __init__(self, + num_last_epochs: int = 15, + new_train_pipeline: Sequence[dict] = None): + self.num_last_epochs = num_last_epochs + self.new_train_pipeline_cfg = new_train_pipeline + + def before_train_epoch(self, runner: Runner): + """Close mosaic and mixup augmentation and switches to use L1 loss.""" + epoch = runner.epoch + model = runner.model + if is_model_wrapper(model): + model = model.module + + if (epoch + 1) == runner.max_epochs - self.num_last_epochs: + runner.logger.info(f'New Pipeline: {self.new_train_pipeline_cfg}') + + train_dataloader_cfg = copy.deepcopy(runner.cfg.train_dataloader) + train_dataloader_cfg.dataset.pipeline = self.new_train_pipeline_cfg + # Note: Why rebuild the dataset? + # When build_dataloader will make a deep copy of the dataset, + # it will lead to potential risks, such as the global instance + # object FileClient data is disordered. + # This problem needs to be solved in the future. 
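+            # Rebuilding the dataloader from its config (instead of mutating
+            # the existing one in place) ensures the worker processes actually
+            # pick up the new pipeline, even when persistent workers are used.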
+ new_train_dataloader = Runner.build_dataloader( + train_dataloader_cfg) + runner.train_loop.dataloader = new_train_dataloader + + runner.logger.info('recreate the dataloader!') + runner.logger.info('Add additional bbox reg loss now!') + model.bbox_head.use_bbox_aux = True diff --git a/mmpl/engine/logger/__init__.py b/mmpl/engine/logger/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5a7d509cbaa213acccf34153ac8df157bbe3bb86 --- /dev/null +++ b/mmpl/engine/logger/__init__.py @@ -0,0 +1 @@ +from .builder import PL_LOGGERS diff --git a/mmpl/engine/logger/__pycache__/__init__.cpython-310.pyc b/mmpl/engine/logger/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efb41ea38cea28be3f66cc025cbeb4e5729b42e8 Binary files /dev/null and b/mmpl/engine/logger/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmpl/engine/logger/__pycache__/builder.cpython-310.pyc b/mmpl/engine/logger/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b842ba45e6f7cb0855b17ba44f30fd7a1137798 Binary files /dev/null and b/mmpl/engine/logger/__pycache__/builder.cpython-310.pyc differ diff --git a/mmpl/engine/logger/builder.py b/mmpl/engine/logger/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..14acaabe52bf1e245c17430abbad01b48e71fc25 --- /dev/null +++ b/mmpl/engine/logger/builder.py @@ -0,0 +1,112 @@ +import copy +import inspect +from typing import List, Union + +import torch +import torch.nn as nn +import lightning + +from mmengine.config import Config, ConfigDict +from mmengine.device import is_npu_available +from mmpl.registry import LOGGERS + + +def register_pl_loggers() -> List[str]: + """Register loggers in ``lightning.pytorch.loggers`` to the ``LOGGERS`` registry. + + Returns: + List[str]: A list of registered optimizers' name. + """ + pl_loggers = [] + for module_name in dir(lightning.pytorch.loggers): + if module_name.startswith('__'): + continue + _logger = getattr(lightning.pytorch.loggers, module_name) + if inspect.isclass(_logger) and issubclass(_logger, lightning.pytorch.loggers.logger.Logger): + LOGGERS.register_module(module=_logger) + pl_loggers.append(module_name) + return pl_loggers + + +PL_LOGGERS = register_pl_loggers() + + +def register_dadaptation_optimizers() -> List[str]: + """Register optimizers in ``dadaptation`` to the ``OPTIMIZERS`` registry. + + Returns: + List[str]: A list of registered optimizers' name. + """ + dadaptation_optimizers = [] + try: + import dadaptation + except ImportError: + pass + else: + for module_name in ['DAdaptAdaGrad', 'DAdaptAdam', 'DAdaptSGD']: + _optim = getattr(dadaptation, module_name) + if inspect.isclass(_optim) and issubclass(_optim, + torch.optim.Optimizer): + OPTIMIZERS.register_module(module=_optim) + dadaptation_optimizers.append(module_name) + return dadaptation_optimizers + + +# DADAPTATION_OPTIMIZERS = register_dadaptation_optimizers() + + +def register_lion_optimizers() -> List[str]: + """Register Lion optimizer to the ``OPTIMIZERS`` registry. + + Returns: + List[str]: A list of registered optimizers' name. + """ + optimizers = [] + try: + from lion_pytorch import Lion + except ImportError: + pass + else: + OPTIMIZERS.register_module(module=Lion) + optimizers.append('Lion') + return optimizers + + +# LION_OPTIMIZERS = register_lion_optimizers() + + +def build_optim_wrapper(model: nn.Module, + cfg: Union[dict, Config, ConfigDict]): + """Build function of OptimWrapper. 
+ + If ``constructor`` is set in the ``cfg``, this method will build an + optimizer wrapper constructor, and use optimizer wrapper constructor to + build the optimizer wrapper. If ``constructor`` is not set, the + ``DefaultOptimWrapperConstructor`` will be used by default. + + Args: + model (nn.Module): Model to be optimized. + cfg (dict): Config of optimizer wrapper, optimizer constructor and + optimizer. + + Returns: + OptimWrapper: The built optimizer wrapper. + """ + optim_wrapper_cfg = copy.deepcopy(cfg) + constructor_type = optim_wrapper_cfg.pop('constructor', + 'DefaultOptimWrapperConstructor') + paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None) + + # Since the current generation of NPU(Ascend 910) only supports + # mixed precision training, here we turn on mixed precision by default + # on the NPU to make the training normal + if is_npu_available(): + optim_wrapper_cfg['type'] = 'AmpOptimWrapper' + + optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build( + dict( + type=constructor_type, + optim_wrapper_cfg=optim_wrapper_cfg, + paramwise_cfg=paramwise_cfg)) + optim_wrapper = optim_wrapper_constructor(model) + return optim_wrapper diff --git a/mmpl/engine/optimizers/__init__.py b/mmpl/engine/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mmpl/engine/runner/__init__.py b/mmpl/engine/runner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc6e5770a871bc6f3de091d48d9846512adfb0d8 --- /dev/null +++ b/mmpl/engine/runner/__init__.py @@ -0,0 +1,3 @@ +from .pl_runner import PLRunner + +__all__ = ['PLRunner'] diff --git a/mmpl/engine/runner/__pycache__/__init__.cpython-310.pyc b/mmpl/engine/runner/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1cd300813c94871f468df86e0e17ece43e0dfcd Binary files /dev/null and b/mmpl/engine/runner/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmpl/engine/runner/__pycache__/pl_runner.cpython-310.pyc b/mmpl/engine/runner/__pycache__/pl_runner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..748d9b493a57de1e7e5a1768f4c5f4d144492603 Binary files /dev/null and b/mmpl/engine/runner/__pycache__/pl_runner.cpython-310.pyc differ diff --git a/mmpl/engine/runner/pl_runner.py b/mmpl/engine/runner/pl_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..54b087fb1f2c15af9215c9df7bf1ee6259717ff2 --- /dev/null +++ b/mmpl/engine/runner/pl_runner.py @@ -0,0 +1,941 @@ +import copy +import logging +import os +import os.path as osp +import pickle +import platform +import time +import warnings +from collections import OrderedDict +from functools import partial +from typing import Callable, Dict, List, Optional, Sequence, Union + +import torch +import torch.nn as nn +from lightning.pytorch.loggers import Logger +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.optim import Optimizer +from torch.utils.data import DataLoader + +import mmengine +from mmengine.config import Config, ConfigDict +from mmengine.dataset import worker_init_fn +from mmengine.device import get_device +from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist, + is_distributed, master_only) +from mmengine.evaluator import Evaluator +from mmengine.fileio import FileClient, join_path +from mmengine.hooks import Hook +from mmengine.logging import MessageHub, MMLogger, print_log +from 
mmengine.model import (MMDistributedDataParallel, convert_sync_batchnorm, + is_model_wrapper, revert_sync_batchnorm) +from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler, + build_optim_wrapper) +from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, FUNCTIONS, + HOOKS, LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS, + OPTIM_WRAPPERS, PARAM_SCHEDULERS, + RUNNERS, VISUALIZERS, DefaultScope) +from mmengine.utils import digit_version, get_git_hash, is_seq_of +from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env, + set_multi_processing) +from mmengine.visualization import Visualizer +from mmengine.runner.base_loop import BaseLoop +from mmengine.runner.checkpoint import (_load_checkpoint, _load_checkpoint_to_model, + find_latest_checkpoint, get_state_dict, + save_checkpoint, weights_to_cpu) +from mmengine.runner.log_processor import LogProcessor +from mmengine.runner.loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop +from mmengine.runner.priority import Priority, get_priority +from mmengine.runner.utils import set_random_seed + +ConfigType = Union[Dict, Config, ConfigDict] +ParamSchedulerType = Union[List[_ParamScheduler], Dict[str, List[_ParamScheduler]]] +OptimWrapperType = Union[OptimWrapper, OptimWrapperDict] + +from mmpl.registry import MODELS, LOGGERS +import lightning.pytorch as pl +from mmpl.models import build_pler + + +@RUNNERS.register_module() +class PLRunner: + def __init__( + self, + trainer_cfg: Dict, + model_cfg: Union[pl.LightningModule, Dict], + datamodule_cfg: Optional[Dict] = None, + cfg: Optional[ConfigType] = None + ): + self.trainer_cfg = copy.deepcopy(trainer_cfg) + self.model_cfg = copy.deepcopy(model_cfg) + self.datamodule_cfg = copy.deepcopy(datamodule_cfg) + mmengine.mkdir_or_exist(trainer_cfg['default_root_dir']) + + timestamp = torch.tensor(time.time(), dtype=torch.float64) + # broadcast timestamp from 0 process to other processes + broadcast(timestamp) + self.timestamp = time.strftime('%Y%m%d_%H%M%S', + time.localtime(timestamp.item())) + + if cfg is not None: + if isinstance(cfg, Config): + self.cfg = copy.deepcopy(cfg) + elif isinstance(cfg, dict): + self.cfg = Config(cfg) + else: + self.cfg = Config(dict()) + + compiled_model = trainer_cfg.pop('compiled_model', False) + + # build logger + loggers = self.build_logger( + trainer_cfg.get('logger', False), + trainer_cfg.get('default_root_dir', f'{self.timestamp}') + ) + trainer_cfg['logger'] = loggers + + # build visualizer used for writing log or visualizing all kinds of data + self.visualizer = self.build_visualizer( + self.cfg.get('visualizer', None), + trainer_cfg.get('default_root_dir', f'{self.timestamp}') + ) + if self.cfg: + self.visualizer.add_config(self.cfg) + + # build callbacks + callbacks = self.build_hooks( + trainer_cfg.get('callbacks', None), + ) + trainer_cfg['callbacks'] = callbacks + + # build strategy + strategy = self.build_strategy( + trainer_cfg.get('strategy', 'auto'), + ) + trainer_cfg['strategy'] = strategy + + self.trainer = pl.Trainer(**trainer_cfg) + model_cfg.update({'config_cfg': copy.deepcopy(cfg).to_dict()}) + model = self.build_model(model_cfg) + if cfg.get('load_from', None) is not None: + self.load_checkpoint(model, cfg['load_from']) + if compiled_model: + # default, reduce-overhead, and max-autotune. + self.model = torch.compile(model) + else: + self.model = model + + # dump `cfg` to `work_dir` + self.dump_config() + # # Collect and log environment information. 
+ # self._log_env(env_cfg) + # log hooks information + # self.logger.info(f'Hooks will be executed in the following ' + # f'order:\n{self.get_hooks_info()}') + + def build_visualizer( + self, + visualizer: Optional[Union[Visualizer, + Dict]] = None, + default_root_dir = 'tmp' + ) -> Visualizer: + """Build a global asscessable Visualizer. + + Args: + visualizer (Visualizer or dict, optional): A Visualizer object + or a dict to build Visualizer object. If ``visualizer`` is a + Visualizer object, just returns itself. If not specified, + default config will be used to build Visualizer object. + Defaults to None. + + Returns: + Visualizer: A Visualizer object build from ``visualizer``. + """ + if visualizer is None: + visualizer = dict( + name=os.path.basename(default_root_dir), + vis_backends=[dict(type='LocalVisBackend')], + save_dir=default_root_dir+'/visualizer' + ) + return Visualizer.get_instance(**visualizer) + + if isinstance(visualizer, Visualizer): + return visualizer + + if isinstance(visualizer, dict): + # ensure visualizer containing name key + visualizer.setdefault('name', os.path.basename(default_root_dir)) + visualizer.setdefault('save_dir', default_root_dir+'/visualizer') + return VISUALIZERS.build(visualizer) + else: + raise TypeError( + 'visualizer should be Visualizer object, a dict or None, ' + f'but got {visualizer}') + + def build_hooks(self, hooks: Union[Dict, List[Dict]] = None) -> List[Hook]: + """Build hooks from config. + + Args: + hooks_cfg (dict): Config dict of hooks. + + Returns: + list[Hook]: A list of hooks. + """ + if hooks is not None: + if isinstance(hooks, dict): + hooks = [hooks] + tmp_hooks = [] + for hook in hooks: + hook = HOOKS.build(hook) + tmp_hooks.append(hook) + hooks = tmp_hooks + return hooks + + @classmethod + def from_cfg(cls, cfg: ConfigType) -> 'Runner': + cfg = copy.deepcopy(cfg) + runner = cls( + trainer_cfg=cfg.get('trainer_cfg'), + model_cfg=cfg['model_cfg'], + datamodule_cfg=cfg.get('datamodule_cfg'), + cfg=cfg + ) + + return runner + + def build_logger(self, loggers: Union[Dict, List[Dict]] = None, default_root_dir='logger'): + if loggers is not None and loggers: + if isinstance(loggers, Dict): + loggers = [loggers] + tmp_loggers = [] + for logger in loggers: + if logger.get('save_dir', None) is None: + logger['save_dir'] = default_root_dir + mmengine.mkdir_or_exist(logger['save_dir']) + tmp_loggers.append(LOGGERS.build(logger)) + loggers = tmp_loggers + return loggers + + def build_strategy(self, strategy='auto'): + if isinstance(strategy, str): + return strategy + elif isinstance(strategy, dict): + if strategy.get('type', '') == 'FSDPStrategy': + from torch.distributed.fsdp import CPUOffload + from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy + import functools + strategy.update( + dict( + # cpu_offload=CPUOffload(offload_params=True), + auto_wrap_policy=functools.partial( + size_based_auto_wrap_policy, min_num_params=int(5e7) + ) + ) + ) + strategy = MODEL_WRAPPERS.build(strategy) + return strategy + return strategy + + def build_model(self, model: Union[pl.LightningModule, Dict]) -> pl.LightningModule: + if isinstance(model, pl.LightningModule): + return model + elif isinstance(model, dict): + model = build_pler(model) + return model # type: ignore + else: + raise TypeError('model should be a nn.Module object or dict, ' + f'but got {model}') + + def _init_model_weights(self) -> None: + """Initialize the model weights if the model has + :meth:`init_weights`""" + if hasattr(self.model, 'module'): + model = 
self.model.module + else: + model = self.model + if hasattr(model, 'init_weights'): + model.init_weights() + # sync params and buffers + for name, params in model.state_dict().items(): + broadcast(params) + + def get_hooks_info(self) -> str: + # Get hooks info in each stage + stage_hook_map: Dict[str, list] = {stage: [] for stage in Hook.stages} + for hook in self.hooks: + try: + priority = Priority(hook.priority).name # type: ignore + except ValueError: + priority = hook.priority # type: ignore + classname = hook.__class__.__name__ + hook_info = f'({priority:<12}) {classname:<35}' + for trigger_stage in hook.get_triggered_stages(): + stage_hook_map[trigger_stage].append(hook_info) + + stage_hook_infos = [] + for stage in Hook.stages: + hook_infos = stage_hook_map[stage] + if len(hook_infos) > 0: + info = f'{stage}:\n' + info += '\n'.join(hook_infos) + info += '\n -------------------- ' + stage_hook_infos.append(info) + return '\n'.join(stage_hook_infos) + + def load_or_resume(self) -> None: + """load or resume checkpoint.""" + if self._has_loaded: + return None + + # decide to load from checkpoint or resume from checkpoint + resume_from = None + if self._resume and self._load_from is None: + # auto resume from the latest checkpoint + resume_from = find_latest_checkpoint(self.work_dir) + self.logger.info( + f'Auto resumed from the latest checkpoint {resume_from}.') + elif self._resume and self._load_from is not None: + # resume from the specified checkpoint + resume_from = self._load_from + + if resume_from is not None: + self.resume(resume_from) + self._has_loaded = True + elif self._load_from is not None: + self.load_checkpoint(self._load_from) + self._has_loaded = True + + @staticmethod + def build_datamodule(datamodule_cfg: Union[pl.LightningDataModule, Dict]): + if isinstance(datamodule_cfg, pl.LightningDataModule): + return datamodule_cfg + datamodule_cfg = copy.deepcopy(datamodule_cfg) + # build datamodule + datamodule = DATASETS.build(datamodule_cfg) + return datamodule + + def run(self, status, *args, **kwargs): + assert status in ['fit', 'test', 'predict', 'validate'] + trainer_func = self.trainer.__getattribute__(status) + self.datamodule = self.build_datamodule(self.datamodule_cfg) + return trainer_func(model=self.model, datamodule=self.datamodule, *args, **kwargs) + + # + # if is_model_wrapper(self.model): + # ori_model = self.model.module + # else: + # ori_model = self.model + # assert hasattr(ori_model, 'train_step'), ( + # 'If you want to train your model, please make sure your model ' + # 'has implemented `train_step`.') + # + # if self._val_loop is not None: + # assert hasattr(ori_model, 'val_step'), ( + # 'If you want to validate your model, please make sure your ' + # 'model has implemented `val_step`.') + # + # if self._train_loop is None: + # raise RuntimeError( + # '`self._train_loop` should not be None when calling train ' + # 'method. 
Please provide `train_dataloader`, `train_cfg`, ' + # '`optimizer` and `param_scheduler` arguments when ' + # 'initializing runner.') + # + # self._train_loop = self.build_train_loop( + # self._train_loop) # type: ignore + # + # # `build_optimizer` should be called before `build_param_scheduler` + # # because the latter depends on the former + # self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper) + # # Automatically scaling lr by linear scaling rule + # self.scale_lr(self.optim_wrapper, self.auto_scale_lr) + # + # if self.param_schedulers is not None: + # self.param_schedulers = self.build_param_scheduler( # type: ignore + # self.param_schedulers) # type: ignore + # + # if self._val_loop is not None: + # self._val_loop = self.build_val_loop( + # self._val_loop) # type: ignore + # # TODO: add a contextmanager to avoid calling `before_run` many times + # self.call_hook('before_run') + # + # # initialize the model weights + # self._init_model_weights() + # # make sure checkpoint-related hooks are triggered after `before_run` + # self.load_or_resume() + # + # # Initiate inner count of `optim_wrapper`. + # self.optim_wrapper.initialize_count_status( + # self.model, + # self._train_loop.iter, # type: ignore + # self._train_loop.max_iters) # type: ignore + # + # # Maybe compile the model according to options in self.cfg.compile + # # This must be called **AFTER** model has been wrapped. + # self._maybe_compile('train_step') + # + # model = self.train_loop.run() # type: ignore + # self.call_hook('after_run') + # return model + + + + def register_hook( + self, + hook: Union[Hook, Dict], + priority: Optional[Union[str, int, Priority]] = None) -> None: + """Register a hook into the hook list. + + The hook will be inserted into a priority queue, with the specified + priority (See :class:`Priority` for details of priorities). + For hooks with the same priority, they will be triggered in the same + order as they are registered. + + Priority of hook will be decided with the following priority: + + - ``priority`` argument. If ``priority`` is given, it will be priority + of hook. + - If ``hook`` argument is a dict and ``priority`` in it, the priority + will be the value of ``hook['priority']``. + - If ``hook`` argument is a dict but ``priority`` not in it or ``hook`` + is an instance of ``hook``, the priority will be ``hook.priority``. + + Args: + hook (:obj:`Hook` or dict): The hook to be registered. + priority (int or str or :obj:`Priority`, optional): Hook priority. + Lower value means higher priority. + """ + if not isinstance(hook, (Hook, dict)): + raise TypeError( + f'hook should be an instance of Hook or dict, but got {hook}') + + _priority = None + if isinstance(hook, dict): + if 'priority' in hook: + _priority = hook.pop('priority') + + hook_obj = HOOKS.build(hook) + else: + hook_obj = hook + + if priority is not None: + hook_obj.priority = priority + elif _priority is not None: + hook_obj.priority = _priority + + inserted = False + for i in range(len(self._hooks) - 1, -1, -1): + if get_priority(hook_obj.priority) >= get_priority( + self._hooks[i].priority): + self._hooks.insert(i + 1, hook_obj) + inserted = True + break + if not inserted: + self._hooks.insert(0, hook_obj) + + def register_default_hooks( + self, + hooks: Optional[Dict[str, Union[Hook, Dict]]] = None) -> None: + """Register default hooks into hook list. + + ``hooks`` will be registered into runner to execute some default + actions like updating model parameters or saving checkpoints. 
+ + Default hooks and their priorities: + + +----------------------+-------------------------+ + | Hooks | Priority | + +======================+=========================+ + | RuntimeInfoHook | VERY_HIGH (10) | + +----------------------+-------------------------+ + | IterTimerHook | NORMAL (50) | + +----------------------+-------------------------+ + | DistSamplerSeedHook | NORMAL (50) | + +----------------------+-------------------------+ + | LoggerHook | BELOW_NORMAL (60) | + +----------------------+-------------------------+ + | ParamSchedulerHook | LOW (70) | + +----------------------+-------------------------+ + | CheckpointHook | VERY_LOW (90) | + +----------------------+-------------------------+ + + If ``hooks`` is None, above hooks will be registered by + default:: + + default_hooks = dict( + runtime_info=dict(type='RuntimeInfoHook'), + timer=dict(type='IterTimerHook'), + sampler_seed=dict(type='DistSamplerSeedHook'), + logger=dict(type='LoggerHook'), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', interval=1), + ) + + If not None, ``hooks`` will be merged into ``default_hooks``. + If there are None value in default_hooks, the corresponding item will + be popped from ``default_hooks``:: + + hooks = dict(timer=None) + + The final registered default hooks will be :obj:`RuntimeInfoHook`, + :obj:`DistSamplerSeedHook`, :obj:`LoggerHook`, + :obj:`ParamSchedulerHook` and :obj:`CheckpointHook`. + + Args: + hooks (dict[str, Hook or dict], optional): Default hooks or configs + to be registered. + """ + default_hooks: dict = dict( + runtime_info=dict(type='RuntimeInfoHook'), + timer=dict(type='IterTimerHook'), + sampler_seed=dict(type='DistSamplerSeedHook'), + logger=dict(type='LoggerHook'), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', interval=1), + ) + if hooks is not None: + for name, hook in hooks.items(): + if name in default_hooks and hook is None: + # remove hook from _default_hooks + default_hooks.pop(name) + else: + assert hook is not None + default_hooks[name] = hook + + for hook in default_hooks.values(): + self.register_hook(hook) + + def register_custom_hooks(self, hooks: List[Union[Hook, Dict]]) -> None: + """Register custom hooks into hook list. + + Args: + hooks (list[Hook | dict]): List of hooks or configs to be + registered. + """ + for hook in hooks: + self.register_hook(hook) + + def register_hooks( + self, + default_hooks: Optional[Dict[str, Union[Hook, Dict]]] = None, + custom_hooks: Optional[List[Union[Hook, Dict]]] = None) -> None: + """Register default hooks and custom hooks into hook list. + + Args: + default_hooks (dict[str, dict] or dict[str, Hook], optional): Hooks + to execute default actions like updating model parameters and + saving checkpoints. Defaults to None. + custom_hooks (list[dict] or list[Hook], optional): Hooks to execute + custom actions like visualizing images processed by pipeline. + Defaults to None. + """ + self.register_default_hooks(default_hooks) + + if custom_hooks is not None: + self.register_custom_hooks(custom_hooks) + + def resume(self, + filename: str, + resume_optimizer: bool = True, + resume_param_scheduler: bool = True, + map_location: Union[str, Callable] = 'default') -> None: + """Resume model from checkpoint. + + Args: + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. + resume_optimizer (bool): Whether to resume optimizer state. + Defaults to True. 
+ resume_param_scheduler (bool): Whether to resume param scheduler + state. Defaults to True. + map_location (str or callable):A string or a callable function to + specifying how to remap storage locations. + Defaults to 'default'. + """ + if map_location == 'default': + device = get_device() + checkpoint = self.load_checkpoint(filename, map_location=device) + else: + checkpoint = self.load_checkpoint( + filename, map_location=map_location) + + self.train_loop._epoch = checkpoint['meta']['epoch'] + self.train_loop._iter = checkpoint['meta']['iter'] + + # check whether the number of GPU used for current experiment + # is consistent with resuming from checkpoint + if 'config' in checkpoint['meta']: + config = mmengine.Config.fromstring( + checkpoint['meta']['config'], file_format='.py') + previous_gpu_ids = config.get('gpu_ids', None) + if (previous_gpu_ids is not None and len(previous_gpu_ids) > 0 + and len(previous_gpu_ids) != self._world_size): + # TODO, should we modify the iteration? + self.logger.info( + 'Number of GPU used for current experiment is not ' + 'consistent with resuming from checkpoint') + if (self.auto_scale_lr is None + or not self.auto_scale_lr.get('enable', False)): + raise RuntimeError( + 'Cannot automatically rescale lr in resuming. Please ' + 'make sure the number of GPU is consistent with the ' + 'previous training state resuming from the checkpoint ' + 'or set `enable` in `auto_scale_lr to False.') + + # resume random seed + resumed_seed = checkpoint['meta'].get('seed', None) + current_seed = self._randomness_cfg.get('seed') + if resumed_seed is not None and resumed_seed != current_seed: + if current_seed is not None: + print_log( + f'The value of random seed in the ' + f'checkpoint "{resumed_seed}" is ' + f'different from the value in ' + f'`randomness` config "{current_seed}"', + logger='current', + level=logging.WARNING) + self._randomness_cfg.update(seed=resumed_seed) + self.set_randomness(**self._randomness_cfg) + + resumed_dataset_meta = checkpoint['meta'].get('dataset_meta', None) + dataset_meta = getattr(self.train_dataloader.dataset, 'metainfo', None) + + # `resumed_dataset_meta` and `dataset_meta` could be object like + # np.ndarray, which cannot be directly judged as equal or not, + # therefore we just compared their dumped results. 
+ if pickle.dumps(resumed_dataset_meta) != pickle.dumps(dataset_meta): + print_log( + 'The dataset metainfo from the resumed checkpoint is ' + 'different from the current training dataset, please ' + 'check the correctness of the checkpoint or the training ' + 'dataset.', + logger='current', + level=logging.WARNING) + + self.message_hub.load_state_dict(checkpoint['message_hub']) + + # resume optimizer + if 'optimizer' in checkpoint and resume_optimizer: + self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper) + self.optim_wrapper.load_state_dict( # type: ignore + checkpoint['optimizer']) + + # resume param scheduler + if resume_param_scheduler and self.param_schedulers is None: + print_log( + '`resume_param_scheduler` is True but `self.param_schedulers` ' + 'is None, so skip resuming parameter schedulers', + logger='current', + level=logging.WARNING) + resume_param_scheduler = False + if 'param_schedulers' in checkpoint and resume_param_scheduler: + self.param_schedulers = self.build_param_scheduler( # type: ignore + self.param_schedulers) # type: ignore + if isinstance(self.param_schedulers, dict): + for name, schedulers in self.param_schedulers.items(): + for scheduler, ckpt_scheduler in zip( + schedulers, checkpoint['param_schedulers'][name]): + scheduler.load_state_dict(ckpt_scheduler) + else: + for scheduler, ckpt_scheduler in zip( + self.param_schedulers, # type: ignore + checkpoint['param_schedulers']): + scheduler.load_state_dict(ckpt_scheduler) + + self._has_loaded = True + + self.logger.info(f'resumed epoch: {self.epoch}, iter: {self.iter}') + + # def load_checkpoint(self, + # filename: str, + # model, + # map_location: Union[str, Callable] = 'cpu', + # strict: bool = False, + # revise_keys: list = [(r'^module.', '')]): + # """Load checkpoint from given ``filename``. + # + # Args: + # filename (str): Accept local filepath, URL, ``torchvision://xxx``, + # ``open-mmlab://xxx``. + # map_location (str or callable): A string or a callable function to + # specifying how to remap storage locations. + # Defaults to 'cpu'. + # strict (bool): strict (bool): Whether to allow different params for + # the model and checkpoint. + # revise_keys (list): A list of customized keywords to modify the + # state_dict in checkpoint. Each item is a (pattern, replacement) + # pair of the regular expression operations. Defaults to strip + # the prefix 'module.' by [(r'^module\\.', '')]. 
+ # """ + # checkpoint = _load_checkpoint(filename, map_location=map_location) + # + # if is_model_wrapper(model): + # model = model.module + # else: + # model = model + # + # checkpoint = _load_checkpoint_to_model( + # model, checkpoint, strict, revise_keys=revise_keys) + # + # print(f'Load checkpoint from {filename}') + # + # return checkpoint + def load_checkpoint(self, model, file): + + if isinstance(file, str): + file_path = file + state_dict = torch.load(file_path, map_location='cpu')['state_dict'] + elif isinstance(file, dict): + file_path = file['file_path'] + state_dict = torch.load(file_path, map_location='cpu')['state_dict'] + for delete_key in file['delete_keys']: + del state_dict[delete_key] + else: + raise TypeError('file must be str or dict') + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + print('load from:', file_path) + print('load model missing_keys:', missing_keys) + print('load model unexpected_keys:', unexpected_keys) + + @master_only + def save_checkpoint( + self, + out_dir: str, + filename: str, + file_client_args: Optional[dict] = None, + save_optimizer: bool = True, + save_param_scheduler: bool = True, + meta: dict = None, + by_epoch: bool = True, + backend_args: Optional[dict] = None, + ): + """Save checkpoints. + + ``CheckpointHook`` invokes this method to save checkpoints + periodically. + + Args: + out_dir (str): The directory that checkpoints are saved. + filename (str): The checkpoint filename. + file_client_args (dict, optional): Arguments to instantiate a + FileClient. See :class:`mmengine.fileio.FileClient` for + details. Defaults to None. It will be deprecated in future. + Please use `backend_args` instead. + save_optimizer (bool): Whether to save the optimizer to + the checkpoint. Defaults to True. + save_param_scheduler (bool): Whether to save the param_scheduler + to the checkpoint. Defaults to True. + meta (dict, optional): The meta information to be saved in the + checkpoint. Defaults to None. + by_epoch (bool): Whether the scheduled momentum is updated by + epochs. Defaults to True. + backend_args (dict, optional): Arguments to instantiate the + prefix of uri corresponding backend. Defaults to None. + New in v0.2.0. + """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError( + f'meta should be a dict or None, but got {type(meta)}') + + if by_epoch: + # self.epoch increments 1 after + # `self.call_hook('after_train_epoch)` but `save_checkpoint` is + # called by `after_train_epoch`` method of `CheckpointHook` so + # `epoch` should be `self.epoch + 1` + meta.update(epoch=self.epoch + 1, iter=self.iter) + else: + meta.update(epoch=self.epoch, iter=self.iter + 1) + + if file_client_args is not None: + warnings.warn( + '"file_client_args" will be deprecated in future. 
' + 'Please use "backend_args" instead', DeprecationWarning) + if backend_args is not None: + raise ValueError( + '"file_client_args" and "backend_args" cannot be set at ' + 'the same time.') + + file_client = FileClient.infer_client(file_client_args, out_dir) + filepath = file_client.join_path(out_dir, filename) + else: + filepath = join_path( # type: ignore + out_dir, filename, backend_args=backend_args) + + meta.update( + cfg=self.cfg.pretty_text, + seed=self.seed, + experiment_name=self.experiment_name, + time=time.strftime('%Y%m%d_%H%M%S', time.localtime()), + mmengine_version=mmengine.__version__ + get_git_hash()) + + if hasattr(self.train_dataloader.dataset, 'metainfo'): + meta.update(dataset_meta=self.train_dataloader.dataset.metainfo) + + if is_model_wrapper(self.model): + model = self.model.module + else: + model = self.model + + checkpoint = { + 'meta': meta, + 'state_dict': weights_to_cpu(get_state_dict(model)), + 'message_hub': self.message_hub.state_dict() + } + # save optimizer state dict to checkpoint + if save_optimizer: + if isinstance(self.optim_wrapper, OptimWrapper): + checkpoint['optimizer'] = self.optim_wrapper.state_dict() + else: + raise TypeError( + 'self.optim_wrapper should be an `OptimWrapper` ' + 'or `OptimWrapperDict` instance, but got ' + f'{self.optim_wrapper}') + + # save param scheduler state dict + if save_param_scheduler and self.param_schedulers is None: + print_log( + '`save_param_scheduler` is True but `self.param_schedulers` ' + 'is None, so skip saving parameter schedulers', + logger='current', + level=logging.WARNING) + save_param_scheduler = False + if save_param_scheduler: + if isinstance(self.param_schedulers, dict): + checkpoint['param_schedulers'] = dict() + for name, schedulers in self.param_schedulers.items(): + checkpoint['param_schedulers'][name] = [] + for scheduler in schedulers: + state_dict = scheduler.state_dict() + checkpoint['param_schedulers'][name].append(state_dict) + else: + checkpoint['param_schedulers'] = [] + for scheduler in self.param_schedulers: # type: ignore + state_dict = scheduler.state_dict() # type: ignore + checkpoint['param_schedulers'].append(state_dict) + + self.call_hook('before_save_checkpoint', checkpoint=checkpoint) + save_checkpoint(checkpoint, filepath) + + @master_only + def dump_config(self) -> None: + version = '' + if len(self.trainer.loggers) > 0: + version = self.trainer.loggers[0].version + version = version if isinstance(version, str) else f"version_{version}" + if version == '': + # if no loggers, use default_root_dir + version = 'version' + + """Dump config to `work_dir`.""" + if self.cfg.filename is not None: + filename = osp.basename(self.cfg.filename) + else: + filename = f'{self.timestamp}.py' + path = f'{self.trainer.default_root_dir}/{version}_{filename}' + + self.cfg.dump(path) + + def _check_scheduler_cfg( + self, param_scheduler: Optional[Union[dict, list, + _ParamScheduler]]) -> None: + """Parse `param_scheduler` to a list of parameter schedulers, or a + `dict` of which each value is a list of parameter schedulers. + + If only one optimizer is used, the parsed config should be a + list of parameter scheduler configs or instances. If multiple + optimizers are used, the parsed config should be `dict`. + Its key should be consistent with the optimizer `dict` and its value + should be a list of parameter scheduler configs or instances. See + :meth:`build_param_scheduler` for more details. 
+ + Examples: + >>> # valid scheduler: + >>> # empty scheduler + >>> scheduler = None + >>> # Single scheduler + >>> scheduler = dict(type='MultiStepLR', milestones=[1, 2]) + >>> # Single list schedulers + >>> scheduler = [dict(type='MultiStepLR', milestones=[1, 2]), + >>> dict(type='MultiStepLR', milestones=[2, 3])] + >>> # `dict` of schedulers + >>> scheduler = dict(linear1=dict(type='MultiStepLR', milestones=[1, 2]), + >>> linear2=dict(type='MultiStepLR', milestones=[1, 2])) + >>> # `dict` of `list` of schedulers + >>> scheduler = dict(linear1=[dict(type='MultiStepLR', milestones=[1, 2])], + >>> linear2=[dict(type='MultiStepLR', milestones=[1, 2])]) + >>> # Single built scheduler + >>> from mmengine.optim import MultiStepLR + >>> scheduler = MultiStepLR(milestones=[1, 2], optimizer=optimizer) + >>> # Single built list schedulers + >>> scheduler = [MultiStepLR(milestones=[1, 2], optimizer=optimizer)] + >>> # dict of built scheduler + >>> scheduler = dict(linear1=MultiStepLR(milestones=[1, 2], optimizer=optimizer), + >>> linear2=MultiStepLR(milestones=[1, 2], optimizer=optimizer)) + >>> # dict of built list schedulers + >>> scheduler = dict(linear1=[MultiStepLR(milestones=[1, 2], optimizer=optimizer)], + >>> linear2=[MultiStepLR(milestones=[1, 2], optimizer=optimizer)]) + + Args: + param_scheduler (dict or list): The original parameter scheduler. + """ # noqa: E501 + param_schedulers: Union[dict, list, _ParamScheduler] + if param_scheduler is None: + return + if isinstance(param_scheduler, _ParamScheduler): + return + if is_seq_of(param_scheduler, _ParamScheduler): + return + + if is_seq_of(param_scheduler, dict): + for _param_scheduler in param_scheduler: + assert 'type' in _param_scheduler, ( + 'Each parameter scheduler should contain the key type, ' + f'but got {_param_scheduler}') + elif isinstance(param_scheduler, dict): + if 'type' not in param_scheduler: + for key, _param_scheduler in param_scheduler.items(): + assert isinstance( + _param_scheduler, + (dict, tuple, list, _ParamScheduler)), ( + 'Each value of `param_scheduler` should be a ' + f'dict or a list, but got {_param_scheduler} with ' + f'type {type(_ParamScheduler)}') + + else: + raise TypeError( + '`param_scheduler` should be a `_ParamScheduler`, `dict`, ' + f'list or a tuple, but got {type(param_scheduler)}. If ' + '`param_scheduler` is a list of dict, it means a list of ' + 'scheduler configs for single optimizer. If it is a dict and ' + 'contains key `type`, it means a scheduler config for a ' + 'single optimizer. If it does not contain key `type`, it ' + 'means multiple lists of schedulers for multiple optimizers.') + + def _log_env(self, env_cfg: dict) -> None: + """Logging environment information of the current task. + + Args: + env_cfg (dict): The environment config of the runner. + """ + # Collect and log environment information. 
+ env = collect_env() + runtime_env = OrderedDict() + runtime_env.update(env_cfg) + runtime_env.update(self._randomness_cfg) + runtime_env['Distributed launcher'] = self._launcher + runtime_env['Distributed training'] = self._distributed + runtime_env['GPU number'] = self._world_size + + env_info = '\n ' + '\n '.join(f'{k}: {v}' + for k, v in env.items()) + runtime_env_info = '\n ' + '\n '.join( + f'{k}: {v}' for k, v in runtime_env.items()) + dash_line = '-' * 60 + self.logger.info('\n' + dash_line + '\nSystem environment:' + + env_info + '\n' + '\nRuntime environment:' + runtime_env_info + '\n' + + dash_line + '\n') + self.logger.info(f'Config:\n{self.cfg.pretty_text}') \ No newline at end of file diff --git a/mmpl/engine/strategies/__init__.py b/mmpl/engine/strategies/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2accb5f598116547e35bcedfdd4a282e711b7d25 --- /dev/null +++ b/mmpl/engine/strategies/__init__.py @@ -0,0 +1 @@ +from .builder import PL_MODEL_WRAPPERS \ No newline at end of file diff --git a/mmpl/engine/strategies/__pycache__/__init__.cpython-310.pyc b/mmpl/engine/strategies/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3cd45ff0aab930647e8349eb1fae582fcf784d0 Binary files /dev/null and b/mmpl/engine/strategies/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmpl/engine/strategies/__pycache__/builder.cpython-310.pyc b/mmpl/engine/strategies/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3b1963abafe33afe25e80947b0f9c795c5cbdca Binary files /dev/null and b/mmpl/engine/strategies/__pycache__/builder.cpython-310.pyc differ diff --git a/mmpl/engine/strategies/builder.py b/mmpl/engine/strategies/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..01edd2ba58fd00b58eb148998348ca8a620a23b5 --- /dev/null +++ b/mmpl/engine/strategies/builder.py @@ -0,0 +1,26 @@ +import inspect +from typing import List, Union + +import torch +import lightning +from mmpl.registry import MODEL_WRAPPERS + + +def register_pl_strategies() -> List[str]: + """Register strategies in ``lightning.pytorch.strategies`` to the ``MODEL_WRAPPERS`` registry. + + Returns: + List[str]: A list of registered strategies' names.
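+
+    Example:
+        >>> # Once registered, a Lightning strategy can be built from a plain
+        >>> # config dict, mirroring how ``PLRunner.build_strategy`` consumes
+        >>> # such a dict (the argument values here are illustrative only):
+        >>> from mmpl.registry import MODEL_WRAPPERS
+        >>> strategy = MODEL_WRAPPERS.build(
+        ...     dict(type='DDPStrategy', find_unused_parameters=True))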
+ """ + pl_strategies = [] + for module_name in dir(lightning.pytorch.strategies): + if module_name.startswith('__'): + continue + _strategy = getattr(lightning.pytorch.strategies, module_name) + if inspect.isclass(_strategy) and issubclass(_strategy, lightning.pytorch.strategies.Strategy): + MODEL_WRAPPERS.register_module(module=_strategy) + pl_strategies.append(module_name) + return pl_strategies + + +PL_MODEL_WRAPPERS = register_pl_strategies() diff --git a/mmpl/engine/visualization/__init__.py b/mmpl/engine/visualization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mmpl/engine/visualization/__pycache__/__init__.cpython-310.pyc b/mmpl/engine/visualization/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f61bf6f6212d9dead96638cfd9cee181a0156284 Binary files /dev/null and b/mmpl/engine/visualization/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmpl/evaluation/__init__.py b/mmpl/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1761d1aa7d58a7c39d4d8f051fecfd5b99802c7e --- /dev/null +++ b/mmpl/evaluation/__init__.py @@ -0,0 +1 @@ +from .metrics import * diff --git a/mmpl/evaluation/__pycache__/__init__.cpython-310.pyc b/mmpl/evaluation/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..804b29d727b48e28403ef3b0d3c7b3f16fc1f8c3 Binary files /dev/null and b/mmpl/evaluation/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmpl/evaluation/metrics/__init__.py b/mmpl/evaluation/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..17edbe4e4d3a1defddeb23deceba504f9058c43e --- /dev/null +++ b/mmpl/evaluation/metrics/__init__.py @@ -0,0 +1,3 @@ +from .builder import PL_METRICS +from .coco_pl_metric import CocoPLMetric +from .mean_ap import PLMeanAveragePrecision diff --git a/mmpl/evaluation/metrics/__pycache__/__init__.cpython-310.pyc b/mmpl/evaluation/metrics/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03bd12377b2c6120fefed06fc8eabdc144c08032 Binary files /dev/null and b/mmpl/evaluation/metrics/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmpl/evaluation/metrics/__pycache__/builder.cpython-310.pyc b/mmpl/evaluation/metrics/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bccefd1018289193d55bd2bb928c805e18467b26 Binary files /dev/null and b/mmpl/evaluation/metrics/__pycache__/builder.cpython-310.pyc differ diff --git a/mmpl/evaluation/metrics/__pycache__/coco_pl_metric.cpython-310.pyc b/mmpl/evaluation/metrics/__pycache__/coco_pl_metric.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04a3efe054eb2f3741ced5c3fdd83e9d40dcfb0b Binary files /dev/null and b/mmpl/evaluation/metrics/__pycache__/coco_pl_metric.cpython-310.pyc differ diff --git a/mmpl/evaluation/metrics/__pycache__/mean_ap.cpython-310.pyc b/mmpl/evaluation/metrics/__pycache__/mean_ap.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0817e4601b27ca70828e892987d2ffa1cd0a40fe Binary files /dev/null and b/mmpl/evaluation/metrics/__pycache__/mean_ap.cpython-310.pyc differ diff --git a/mmpl/evaluation/metrics/builder.py b/mmpl/evaluation/metrics/builder.py new file mode 100644 index 
0000000000000000000000000000000000000000..bd55df759561b73656a71941e67f9c033d900dd7 --- /dev/null +++ b/mmpl/evaluation/metrics/builder.py @@ -0,0 +1,34 @@ +import copy +import inspect +from typing import List, Union + +import torch +import torch.nn as nn +import lightning +import torchmetrics +import torchmetrics.detection + +from mmengine.config import Config, ConfigDict +from mmpl.registry import METRICS + + +def register_pl_metrics() -> List[str]: + """Register loggers in ``lightning.pytorch.loggers`` to the ``LOGGERS`` registry. + + Returns: + List[str]: A list of registered optimizers' name. + """ + pl_metrics = [] + for modules in [torchmetrics, torchmetrics.detection]: + for module_name in dir(modules): + if module_name.startswith('__'): + continue + _metric = getattr(modules, module_name) + if inspect.isclass(_metric): + METRICS.register_module(module=_metric) + pl_metrics.append(module_name) + return pl_metrics + + +PL_METRICS = register_pl_metrics() + diff --git a/mmpl/evaluation/metrics/cityscapes_pl_metric.py b/mmpl/evaluation/metrics/cityscapes_pl_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..d94f3eb945e90641d86705c45c0d38b206c2b3aa --- /dev/null +++ b/mmpl/evaluation/metrics/cityscapes_pl_metric.py @@ -0,0 +1,245 @@ +import os +import os.path as osp +import shutil +import tempfile +from collections import OrderedDict +from typing import Dict, Optional, Sequence + +import mmcv +import numpy as np +from mmengine.dist import is_main_process +# from mmengine.evaluator import BaseMetric +from mmengine.logging import MMLogger +from torchmetrics import Metric + +from mmpl.registry import METRICS + +try: + import cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling as CSEval # noqa: E501 + import cityscapesscripts.helpers.labels as CSLabels + + from mmdet.evaluation.functional import evaluateImgLists + HAS_CITYSCAPESAPI = True +except ImportError: + HAS_CITYSCAPESAPI = False + + +@METRICS.register_module() +class CityScapesPLMetric(Metric): + """CityScapes metric for instance segmentation. + + Args: + outfile_prefix (str): The prefix of txt and png files. The txt and + png file will be save in a directory whose path is + "outfile_prefix.results/". + seg_prefix (str, optional): Path to the directory which contains the + cityscapes instance segmentation masks. It's necessary when + training and validation. It could be None when infer on test + dataset. Defaults to None. + format_only (bool): Format the output results without perform + evaluation. It is useful when you want to format the result + to a specific format and submit it to the test server. + Defaults to False. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + dump_matches (bool): Whether dump matches.json file during evaluating. + Defaults to False. + file_client_args (dict, optional): Arguments to instantiate the + corresponding backend in mmdet <= 3.0.0rc6. Defaults to None. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. 
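+
+    Example (an illustrative config sketch; the dataset paths and the way the
+    metric is wired into an evaluator are assumptions about the surrounding
+    setup)::
+
+        metric = dict(
+            type='CityScapesPLMetric',
+            outfile_prefix='./work_dirs/cityscapes/results',
+            seg_prefix='data/cityscapes/gtFine/val',
+            format_only=False)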
+ """ + default_prefix: Optional[str] = 'cityscapes' + + def __init__(self, + outfile_prefix: str, + seg_prefix: Optional[str] = None, + format_only: bool = False, + collect_device: str = 'cpu', + prefix: Optional[str] = None, + dump_matches: bool = False, + file_client_args: dict = None, + backend_args: dict = None, + **kwargs + ) -> None: + + if not HAS_CITYSCAPESAPI: + raise RuntimeError('Failed to import `cityscapesscripts`. ' + 'Please try to install official ' + 'cityscapesscripts by ' + '"pip install cityscapesscripts"') + super().__init__(**kwargs) + + self.tmp_dir = None + self.format_only = format_only + if self.format_only: + assert outfile_prefix is not None, ( + 'outfile_prefix must not be None when format_only is True, ' + 'otherwise the result files will be saved to a temp directory ' + 'which will be cleaned up at the end.') + else: + assert seg_prefix is not None, ( + '`seg_prefix` is necessary when ' + 'computing the CityScapes metrics') + + if outfile_prefix is None: + self.tmp_dir = tempfile.TemporaryDirectory() + self.outfile_prefix = osp.join(self.tmp_dir.name, 'results') + else: + # the directory to save predicted panoptic segmentation mask + self.outfile_prefix = osp.join(outfile_prefix, 'results') # type: ignore # yapf: disable # noqa: E501 + + dir_name = osp.expanduser(self.outfile_prefix) + + if osp.exists(dir_name) and is_main_process(): + logger: MMLogger = MMLogger.get_current_instance() + logger.info('remove previous results.') + shutil.rmtree(dir_name) + os.makedirs(dir_name, exist_ok=True) + + self.backend_args = backend_args + if file_client_args is not None: + raise RuntimeError( + 'The `file_client_args` is deprecated, ' + 'please use `backend_args` instead; please refer to ' + 'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' # noqa: E501 + ) + + self.seg_prefix = seg_prefix + self.dump_matches = dump_matches + + def __del__(self) -> None: + """Clean up the results if necessary.""" + if self.tmp_dir is not None: + self.tmp_dir.cleanup() + + def update(self, data_batch, data_samples) -> None: + """Process one batch of data samples and predictions. The processed + results should be stored in ``self.results``, which will be used to + compute the metrics when all batches have been processed. + + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of data samples that + contain annotations and predictions. 
+ """ + for data_sample in data_samples: + # parse pred + result = dict() + pred = data_sample['pred_instances'] + filename = data_sample['img_path'] + basename = osp.splitext(osp.basename(filename))[0] + pred_txt = osp.join(self.outfile_prefix, basename + '_pred.txt') + result['pred_txt'] = pred_txt + labels = pred['labels'].cpu().numpy() + masks = pred['masks'].cpu().numpy().astype(np.uint8) + if 'mask_scores' in pred: + # some detectors use different scores for bbox and mask + mask_scores = pred['mask_scores'].cpu().numpy() + else: + mask_scores = pred['scores'].cpu().numpy() + + with open(pred_txt, 'w') as f: + for i, (label, mask, mask_score) in enumerate( + zip(labels, masks, mask_scores)): + class_name = self.dataset_meta['classes'][label] + class_id = CSLabels.name2label[class_name].id + png_filename = osp.join( + self.outfile_prefix, + basename + f'_{i}_{class_name}.png') + mmcv.imwrite(mask, png_filename) + f.write(f'{osp.basename(png_filename)} ' + f'{class_id} {mask_score}\n') + + # parse gt + gt = dict() + img_path = filename.replace('leftImg8bit.png', + 'gtFine_instanceIds.png') + gt['file_name'] = img_path.replace('leftImg8bit', 'gtFine') + + self.results.append((gt, result)) + + # TODO: data_batch is no longer needed, consider adjusting the + # parameter position + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + + for data_sample in data_samples: + # parse pred + result = dict() + pred = data_sample['pred_instances'] + filename = data_sample['img_path'] + basename = osp.splitext(osp.basename(filename))[0] + pred_txt = osp.join(self.outfile_prefix, basename + '_pred.txt') + result['pred_txt'] = pred_txt + labels = pred['labels'].cpu().numpy() + masks = pred['masks'].cpu().numpy().astype(np.uint8) + if 'mask_scores' in pred: + # some detectors use different scores for bbox and mask + mask_scores = pred['mask_scores'].cpu().numpy() + else: + mask_scores = pred['scores'].cpu().numpy() + + with open(pred_txt, 'w') as f: + for i, (label, mask, mask_score) in enumerate( + zip(labels, masks, mask_scores)): + class_name = self.dataset_meta['classes'][label] + class_id = CSLabels.name2label[class_name].id + png_filename = osp.join( + self.outfile_prefix, + basename + f'_{i}_{class_name}.png') + mmcv.imwrite(mask, png_filename) + f.write(f'{osp.basename(png_filename)} ' + f'{class_id} {mask_score}\n') + + # parse gt + gt = dict() + img_path = filename.replace('leftImg8bit.png', + 'gtFine_instanceIds.png') + gt['file_name'] = img_path.replace('leftImg8bit', 'gtFine') + + self.results.append((gt, result)) + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. 
+ """ + logger: MMLogger = MMLogger.get_current_instance() + + if self.format_only: + logger.info( + f'results are saved to {osp.dirname(self.outfile_prefix)}') + return OrderedDict() + logger.info('starts to compute metric') + + gts, preds = zip(*results) + # set global states in cityscapes evaluation API + gt_instances_file = osp.join(self.outfile_prefix, 'gtInstances.json') # type: ignore # yapf: disable # noqa: E501 + # split gt and prediction list + gts, preds = zip(*results) + CSEval.args.JSONOutput = False + CSEval.args.colorized = False + CSEval.args.gtInstancesFile = gt_instances_file + + groundTruthImgList = [gt['file_name'] for gt in gts] + predictionImgList = [pred['pred_txt'] for pred in preds] + CSEval_results = evaluateImgLists( + predictionImgList, + groundTruthImgList, + CSEval.args, + self.backend_args, + dump_matches=self.dump_matches)['averages'] + + eval_results = OrderedDict() + eval_results['mAP'] = CSEval_results['allAp'] + eval_results['AP@50'] = CSEval_results['allAp50%'] + + return eval_results diff --git a/mmpl/evaluation/metrics/coco_pl_metric.py b/mmpl/evaluation/metrics/coco_pl_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..a2b31b9ec5783996ccbb12e9620c8b27e0b10d64 --- /dev/null +++ b/mmpl/evaluation/metrics/coco_pl_metric.py @@ -0,0 +1,629 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import datetime +import itertools +import os.path as osp +import tempfile +from collections import OrderedDict +from typing import Dict, List, Optional, Sequence, Union + +import lightning +import mmengine +import numpy as np +import torch +from mmengine.fileio import dump, get_local_path, load +from mmengine.logging import MMLogger +from terminaltables import AsciiTable + +from mmdet.datasets.api_wrappers import COCO, COCOeval +from mmdet.structures.mask import encode_mask_results +from mmdet.evaluation.functional import eval_recalls +from torchmetrics import Metric +from mmpl.registry import METRICS +from torchmetrics.utilities import rank_zero_info + + +@METRICS.register_module() +class CocoPLMetric(Metric): + """COCO evaluation metric. + + Evaluate AR, AP, and mAP for detection tasks including proposal/box + detection and instance segmentation. Please refer to + https://cocodataset.org/#detection-eval for more details. + + Args: + ann_file (str, optional): Path to the coco format annotation file. + If not specified, ground truth annotations from the dataset will + be converted to coco format. Defaults to None. + metric (str | List[str]): Metrics to be evaluated. Valid metrics + include 'bbox', 'segm', 'proposal', and 'proposal_fast'. + Defaults to 'bbox'. + classwise (bool): Whether to evaluate the metric class-wise. + Defaults to False. + proposal_nums (Sequence[int]): Numbers of proposals to be evaluated. + Defaults to (100, 300, 1000). + iou_thrs (float | List[float], optional): IoU threshold to compute AP + and AR. If not specified, IoUs from 0.5 to 0.95 will be used. + Defaults to None. + metric_items (List[str], optional): Metric result names to be + recorded in the evaluation result. Defaults to None. + format_only (bool): Format the output results without perform + evaluation. It is useful when you want to format the result + to a specific format and submit it to the test server. + Defaults to False. + outfile_prefix (str, optional): The prefix of json files. It includes + the file path and the prefix of filename, e.g., "a/b/prefix". + If not specified, a temp file will be created. Defaults to None. 
+ file_client_args (dict, optional): Arguments to instantiate the + corresponding backend in mmdet <= 3.0.0rc6. Defaults to None. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + sort_categories (bool): Whether to sort categories in annotations. Only + used for `Objects365V1Dataset`. Defaults to False. + """ + # default_prefix: Optional[str] = 'coco' + + def __init__(self, + ann_file: Optional[str] = None, + metric: Union[str, List[str]] = 'bbox', + classwise: bool = False, + proposal_nums: Sequence[int] = (100, 300, 1000), + iou_thrs: Optional[Union[float, Sequence[float]]] = None, + metric_items: Optional[Sequence[str]] = None, + format_only: bool = False, + outfile_prefix: Optional[str] = None, + file_client_args: dict = None, + backend_args: dict = None, + collect_device: str = 'cpu', + prefix: Optional[str] = None, + sort_categories: bool = False, + **kwargs + ) -> None: + super().__init__(**kwargs) + self._dataset_meta: Union[None, dict] = None + # coco evaluation metrics + self.metrics = metric if isinstance(metric, list) else [metric] + allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast'] + for metric in self.metrics: + if metric not in allowed_metrics: + raise KeyError( + "metric should be one of 'bbox', 'segm', 'proposal', " + f"'proposal_fast', but got {metric}.") + + # do class wise evaluation, default False + self.classwise = classwise + + # proposal_nums used to compute recall or precision. + self.proposal_nums = list(proposal_nums) + + # iou_thrs used to compute recall or precision. + if iou_thrs is None: + iou_thrs = np.linspace( + .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True) + self.iou_thrs = iou_thrs + self.metric_items = metric_items + self.format_only = format_only + if self.format_only: + assert outfile_prefix is not None, ( + 'outfile_prefix must not be None when format_only is True, ' + 'otherwise the result files will be saved to a temp directory ' + 'which will be cleaned up at the end.') + + self.outfile_prefix = outfile_prefix + + self.backend_args = backend_args + if file_client_args is not None: + raise RuntimeError( + 'The `file_client_args` is deprecated, ' + 'please use `backend_args` instead; please refer to ' + 'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' # noqa: E501 + ) + + # if ann_file is not specified, + # initialize coco api with the converted dataset + if ann_file is not None: + with get_local_path( + ann_file, backend_args=self.backend_args) as local_path: + self._coco_api = COCO(local_path) + if sort_categories: + # 'categories' list in objects365_train.json and + # objects365_val.json is inconsistent, need to sort the + # list (or dict) before getting cat_ids. 
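+                    # Sorting by id presumably keeps the category order consistent
+                    # with ``dataset_meta['classes']`` so that the ``get_cat_ids``
+                    # lookup in ``compute`` maps predicted labels to the intended
+                    # categories for Objects365-style annotation files.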
+ cats = self._coco_api.cats + sorted_cats = {i: cats[i] for i in sorted(cats)} + self._coco_api.cats = sorted_cats + categories = self._coco_api.dataset['categories'] + sorted_categories = sorted( + categories, key=lambda i: i['id']) + self._coco_api.dataset['categories'] = sorted_categories + else: + self._coco_api = None + + # handle dataset lazy init + self.cat_ids = None + self.img_ids = None + + self.add_state('results', default=[], dist_reduce_fx=None) + + @property + def dataset_meta(self) -> Optional[dict]: + """Optional[dict]: Meta info of the dataset.""" + return self._dataset_meta + + @dataset_meta.setter + def dataset_meta(self, dataset_meta: dict) -> None: + """Set the dataset meta info to the metric.""" + self._dataset_meta = dataset_meta + + def fast_eval_recall(self, + results: List[dict], + proposal_nums: Sequence[int], + iou_thrs: Sequence[float], + logger: Optional[MMLogger] = None) -> np.ndarray: + """Evaluate proposal recall with COCO's fast_eval_recall. + + Args: + results (List[dict]): Results of the dataset. + proposal_nums (Sequence[int]): Proposal numbers used for + evaluation. + iou_thrs (Sequence[float]): IoU thresholds used for evaluation. + logger (MMLogger, optional): Logger used for logging the recall + summary. + Returns: + np.ndarray: Averaged recall results. + """ + gt_bboxes = [] + pred_bboxes = [result['bboxes'] for result in results] + for i in range(len(self.img_ids)): + ann_ids = self._coco_api.get_ann_ids(img_ids=self.img_ids[i]) + ann_info = self._coco_api.load_anns(ann_ids) + if len(ann_info) == 0: + gt_bboxes.append(np.zeros((0, 4))) + continue + bboxes = [] + for ann in ann_info: + if ann.get('ignore', False) or ann['iscrowd']: + continue + x1, y1, w, h = ann['bbox'] + bboxes.append([x1, y1, x1 + w, y1 + h]) + bboxes = np.array(bboxes, dtype=np.float32) + if bboxes.shape[0] == 0: + bboxes = np.zeros((0, 4)) + gt_bboxes.append(bboxes) + + recalls = eval_recalls( + gt_bboxes, pred_bboxes, proposal_nums, iou_thrs, logger=logger) + ar = recalls.mean(axis=1) + return ar + + def xyxy2xywh(self, bbox: np.ndarray) -> list: + """Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO + evaluation. + + Args: + bbox (numpy.ndarray): The bounding boxes, shape (4, ), in + ``xyxy`` order. + + Returns: + list[float]: The converted bounding boxes, in ``xywh`` order. + """ + + _bbox: List = bbox.tolist() + return [ + _bbox[0], + _bbox[1], + _bbox[2] - _bbox[0], + _bbox[3] - _bbox[1], + ] + + def results2json(self, results: Sequence[dict], + outfile_prefix: str) -> dict: + """Dump the detection results to a COCO style json file. + + There are 3 types of results: proposals, bbox predictions, mask + predictions, and they have different data types. This method will + automatically recognize the type, and dump them to json files. + + Args: + results (Sequence[dict]): Testing results of the + dataset. + outfile_prefix (str): The filename prefix of the json files. If the + prefix is "somepath/xxx", the json files will be named + "somepath/xxx.bbox.json", "somepath/xxx.segm.json", + "somepath/xxx.proposal.json". + + Returns: + dict: Possible keys are "bbox", "segm", "proposal", and + values are corresponding filenames. 
+ """ + bbox_json_results = [] + segm_json_results = [] if 'masks' in results[0] else None + for idx, result in enumerate(results): + image_id = result.get('img_id', idx) + labels = result['labels'] + bboxes = result['bboxes'] + scores = result['scores'] + # bbox results + for i, label in enumerate(labels): + data = dict() + data['image_id'] = image_id + data['bbox'] = self.xyxy2xywh(bboxes[i]) + data['score'] = float(scores[i]) + data['category_id'] = self.cat_ids[label] + bbox_json_results.append(data) + + if segm_json_results is None: + continue + + # segm results + masks = result['masks'] + mask_scores = result.get('mask_scores', scores) + for i, label in enumerate(labels): + data = dict() + data['image_id'] = image_id + data['bbox'] = self.xyxy2xywh(bboxes[i]) + data['score'] = float(mask_scores[i]) + data['category_id'] = self.cat_ids[label] + if isinstance(masks[i]['counts'], bytes): + masks[i]['counts'] = masks[i]['counts'].decode() + data['segmentation'] = masks[i] + segm_json_results.append(data) + + result_files = dict() + result_files['bbox'] = f'{outfile_prefix}.bbox.json' + result_files['proposal'] = f'{outfile_prefix}.bbox.json' + dump(bbox_json_results, result_files['bbox']) + + if segm_json_results is not None: + result_files['segm'] = f'{outfile_prefix}.segm.json' + dump(segm_json_results, result_files['segm']) + + return result_files + + def gt_to_coco_json(self, gt_dicts: Sequence[dict], + outfile_prefix: str) -> str: + """Convert ground truth to coco format json file. + + Args: + gt_dicts (Sequence[dict]): Ground truth of the dataset. + outfile_prefix (str): The filename prefix of the json files. If the + prefix is "somepath/xxx", the json file will be named + "somepath/xxx.gt.json". + Returns: + str: The filename of the json file. + """ + categories = [ + dict(id=id, name=name) + for id, name in enumerate(self.dataset_meta['classes']) + ] + image_infos = [] + annotations = [] + + for idx, gt_dict in enumerate(gt_dicts): + img_id = gt_dict.get('img_id', idx) + image_info = dict( + id=img_id, + width=gt_dict['width'], + height=gt_dict['height'], + file_name='') + image_infos.append(image_info) + for ann in gt_dict['anns']: + label = ann['bbox_label'] + bbox = ann['bbox'] + coco_bbox = [ + bbox[0], + bbox[1], + bbox[2] - bbox[0], + bbox[3] - bbox[1], + ] + + annotation = dict( + id=len(annotations) + + 1, # coco api requires id starts with 1 + image_id=img_id, + bbox=coco_bbox, + iscrowd=ann.get('ignore_flag', 0), + category_id=int(label), + area=coco_bbox[2] * coco_bbox[3]) + if ann.get('mask', None): + mask = ann['mask'] + # area = mask_util.area(mask) + if isinstance(mask, dict) and isinstance( + mask['counts'], bytes): + mask['counts'] = mask['counts'].decode() + annotation['segmentation'] = mask + # annotation['area'] = float(area) + annotations.append(annotation) + + info = dict( + date_created=str(datetime.datetime.now()), + description='Coco json file converted by mmdet CocoMetric.') + coco_json = dict( + info=info, + images=image_infos, + categories=categories, + licenses=None, + ) + if len(annotations) > 0: + coco_json['annotations'] = annotations + converted_json_path = f'{outfile_prefix}.gt.json' + dump(coco_json, converted_json_path) + return converted_json_path + + # TODO: data_batch is no longer needed, consider adjusting the + # parameter position + def update(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + """Process one batch of data samples and predictions. 
The processed + results should be stored in ``self.results``, which will be used to + compute the metrics when all batches have been processed. + + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of data samples that + contain annotations and predictions. + """ + for data_sample in data_samples: + result = dict() + pred = data_sample.pred_instances + result['img_id'] = data_sample.img_id + result['bboxes'] = pred['bboxes'].cpu().numpy() + result['scores'] = pred['scores'].cpu().numpy() + result['labels'] = pred['labels'].cpu().numpy() + # encode mask to RLE + if 'masks' in pred: + result['masks'] = encode_mask_results( + pred['masks'].detach().cpu().numpy()) if isinstance( + pred['masks'], torch.Tensor) else pred['masks'] + # some detectors use different scores for bbox and mask + if 'mask_scores' in pred: + result['mask_scores'] = pred['mask_scores'].cpu().numpy() + + # parse gt + gt = dict() + gt['width'] = data_sample.ori_shape[1] + gt['height'] = data_sample.ori_shape[0] + gt['img_id'] = data_sample.img_id + if self._coco_api is None: + # TODO: Need to refactor to support LoadAnnotations + assert 'gt_instances' in data_sample, \ + 'ground truth is required for evaluation when ' \ + '`ann_file` is not provided' + gt['anns'] = [] + for x_data in data_sample.gt_instances: + mask_ = encode_mask_results(x_data['masks'].masks) + assert len(mask_) == 1, \ + 'Only support one mask per instance for now' + gt['anns'].append( + dict( + bbox_label=x_data['labels'].item(), + bbox=x_data['bboxes'].cpu().numpy().reshape(4), + mask=mask_[0] + ) + ) + # add converted result to the results list + self.results.append((gt, result)) + + def compute(self) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. + """ + results = self.results + logger: MMLogger = MMLogger.get_current_instance() + + # split gt and prediction list + gts, preds = zip(*results) + + tmp_dir = None + if self.outfile_prefix is None: + tmp_dir = tempfile.TemporaryDirectory() + outfile_prefix = osp.join(tmp_dir.name, 'results') + else: + outfile_prefix = self.outfile_prefix + + if self._coco_api is None: + # use converted gt json file to initialize coco api + logger.info('Converting ground truth to coco format...') + coco_json_path = self.gt_to_coco_json( + gt_dicts=gts, outfile_prefix=outfile_prefix) + self._coco_api = COCO(coco_json_path) + + # handle lazy init + if self.cat_ids is None: + self.cat_ids = self._coco_api.get_cat_ids( + cat_names=self.dataset_meta['classes']) + if self.img_ids is None: + self.img_ids = self._coco_api.get_img_ids() + + # convert predictions to coco format and dump to json file + result_files = self.results2json(preds, outfile_prefix) + + eval_results = OrderedDict() + if self.format_only: + logger.info('results are saved in ' + f'{osp.dirname(outfile_prefix)}') + return eval_results + + for metric in self.metrics: + logger.info(f'Evaluating {metric}...') + + # TODO: May refactor fast_eval_recall to an independent metric? 
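+            # Two evaluation paths follow: 'proposal_fast' computes average recall
+            # directly with mmdet's ``eval_recalls`` on the ground-truth and
+            # predicted boxes, while the remaining metrics read back the json
+            # files written by ``results2json`` above and score them with
+            # ``COCOeval``.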
+ # fast eval recall + if metric == 'proposal_fast': + ar = self.fast_eval_recall( + preds, self.proposal_nums, self.iou_thrs, logger=logger) + log_msg = [] + for i, num in enumerate(self.proposal_nums): + eval_results[f'AR@{num}'] = ar[i] + log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}') + log_msg = ''.join(log_msg) + logger.info(log_msg) + continue + + # evaluate proposal, bbox and segm + iou_type = 'bbox' if metric == 'proposal' else metric + if metric not in result_files: + raise KeyError(f'{metric} is not in results') + try: + predictions = load(result_files[metric]) + if iou_type == 'segm': + # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331 # noqa + # When evaluating mask AP, if the results contain bbox, + # cocoapi will use the box area instead of the mask area + # for calculating the instance area. Though the overall AP + # is not affected, this leads to different + # small/medium/large mask AP results. + for x in predictions: + x.pop('bbox') + coco_dt = self._coco_api.loadRes(predictions) + + except IndexError: + # for k, v in eval_results.items(): + # eval_results[k] = torch.tensor(v).to(self.device) + # self._coco_api = None + logger.error( + 'The testing results of the whole dataset is empty.') + break + + coco_eval = COCOeval(self._coco_api, coco_dt, iou_type) + + coco_eval.params.catIds = self.cat_ids + coco_eval.params.imgIds = self.img_ids + coco_eval.params.maxDets = list(self.proposal_nums) + coco_eval.params.iouThrs = self.iou_thrs + + # mapping of cocoEval.stats + coco_metric_names = { + 'mAP': 0, + 'mAP_50': 1, + 'mAP_75': 2, + 'mAP_s': 3, + 'mAP_m': 4, + 'mAP_l': 5, + 'AR@100': 6, + 'AR@300': 7, + 'AR@1000': 8, + 'AR_s@1000': 9, + 'AR_m@1000': 10, + 'AR_l@1000': 11 + } + metric_items = self.metric_items + if metric_items is not None: + for metric_item in metric_items: + if metric_item not in coco_metric_names: + raise KeyError( + f'metric item "{metric_item}" is not supported') + + if metric == 'proposal': + coco_eval.params.useCats = 0 + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + if metric_items is None: + metric_items = [ + 'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000', + 'AR_m@1000', 'AR_l@1000' + ] + + for item in metric_items: + val = float( + f'{coco_eval.stats[coco_metric_names[item]]:.3f}') + eval_results[item] = val + else: + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + if self.classwise: # Compute per-category AP + # Compute per-category AP + # from https://github.com/facebookresearch/detectron2/ + precisions = coco_eval.eval['precision'] + # precision: (iou, recall, cls, area range, max dets) + assert len(self.cat_ids) == precisions.shape[2] + + results_per_category = [] + for idx, cat_id in enumerate(self.cat_ids): + t = [] + # area range index 0: all area ranges + # max dets index -1: typically 100 per image + nm = self._coco_api.loadCats(cat_id)[0] + precision = precisions[:, :, idx, 0, -1] + precision = precision[precision > -1] + if precision.size: + ap = np.mean(precision) + else: + ap = float('nan') + t.append(f'{nm["name"]}') + t.append(f'{round(ap, 3)}') + eval_results[f'{nm["name"]}_precision'] = round(ap, 3) + + # indexes of IoU @50 and @75 + for iou in [0, 5]: + precision = precisions[iou, :, idx, 0, -1] + precision = precision[precision > -1] + if precision.size: + ap = np.mean(precision) + else: + ap = float('nan') + t.append(f'{round(ap, 3)}') + + # indexes of area of small, median and large + for area in [1, 2, 3]: + precision = precisions[:, :, 
idx, area, -1] + precision = precision[precision > -1] + if precision.size: + ap = np.mean(precision) + else: + ap = float('nan') + t.append(f'{round(ap, 3)}') + results_per_category.append(tuple(t)) + + num_columns = len(results_per_category[0]) + results_flatten = list( + itertools.chain(*results_per_category)) + headers = [ + 'category', 'mAP', 'mAP_50', 'mAP_75', 'mAP_s', + 'mAP_m', 'mAP_l' + ] + results_2d = itertools.zip_longest(*[ + results_flatten[i::num_columns] + for i in range(num_columns) + ]) + table_data = [headers] + table_data += [result for result in results_2d] + table = AsciiTable(table_data) + # if mmengine.dist.get_local_rank() == 0: + rank_zero_info('\n' + table.table) + + if metric_items is None: + metric_items = [ + 'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l' + ] + + for metric_item in metric_items: + key = f'{metric}_{metric_item}' + val = coco_eval.stats[coco_metric_names[metric_item]] + eval_results[key] = float(f'{round(val, 3)}') + + ap = coco_eval.stats[:6] + # if mmengine.dist.get_local_rank() == 0: + + rank_zero_info(f'{metric}_mAP_copypaste: {ap[0]:.3f} ' + f'{ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} ' + f'{ap[4]:.3f} {ap[5]:.3f}') + + if tmp_dir is not None: + tmp_dir.cleanup() + for k, v in eval_results.items(): + eval_results[k] = torch.tensor(v).to(self.device) + self._coco_api = None + return eval_results diff --git a/mmpl/evaluation/metrics/mean_ap.py b/mmpl/evaluation/metrics/mean_ap.py new file mode 100644 index 0000000000000000000000000000000000000000..a8243b82354ba2448b5b9d36d2de29a4eb90ec3b --- /dev/null +++ b/mmpl/evaluation/metrics/mean_ap.py @@ -0,0 +1,41 @@ +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, Callable +import torch.distributed as dist +from torchmetrics.detection import MeanAveragePrecision +from torchmetrics.utilities.distributed import gather_all_tensors +from mmpl.registry import METRICS + + +@METRICS.register_module(force=True) +class PLMeanAveragePrecision(MeanAveragePrecision): + def __init__( + self, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + + def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group: Optional[Any] = None) -> None: + super()._sync_dist(dist_sync_fn=dist_sync_fn, process_group=process_group) + + if self.iou_type == "segm": + self.detections = self._gather_tuple_list(self.detections, process_group) + self.groundtruths = self._gather_tuple_list(self.groundtruths, process_group) + + @staticmethod + def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any] = None) -> List[Any]: + world_size = dist.get_world_size(group=process_group) + list_gathered = [None] * world_size + dist.all_gather_object(list_gathered, list_to_gather, group=process_group) + + for rank in range(1, world_size): + assert ( + len(list_gathered[rank]) == list_gathered[0], + f"Rank{rank} doesn't have the same number of elements as Rank0: " + f"{list_gathered[rank]} vs. 
{list_gathered[0]}", + ) + list_merged = [] + for idx in range(len(list_gathered[0])): + for rank in range(world_size): + list_merged.append(list_gathered[rank][idx]) + + return list_merged diff --git a/mmpl/models/__init__.py b/mmpl/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4c0806f5b20bdf732ff752fe0937550484870f4b --- /dev/null +++ b/mmpl/models/__init__.py @@ -0,0 +1,9 @@ +from .builder import build_pler +from .pler import * +from .backbones import * +from .losses import * +from .heads import * +from .necks import * +from .data_preprocessors import * + +__all__ = ['build_pler'] \ No newline at end of file diff --git a/mmpl/models/__pycache__/__init__.cpython-310.pyc b/mmpl/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a42d3f7872118940180807886e7e0862a0be9ca4 Binary files /dev/null and b/mmpl/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmpl/models/__pycache__/builder.cpython-310.pyc b/mmpl/models/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..408c7e742cb919554c5bab51da210900faa0789e Binary files /dev/null and b/mmpl/models/__pycache__/builder.cpython-310.pyc differ diff --git a/mmpl/models/backbones/__init__.py b/mmpl/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mmpl/models/backbones/__pycache__/__init__.cpython-310.pyc b/mmpl/models/backbones/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7a073e79f56bf1c00aa30d32b688f1588236884 Binary files /dev/null and b/mmpl/models/backbones/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmpl/models/backbones/base_backbone.py b/mmpl/models/backbones/base_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..1ed65b7bac50568b5fd9101c967be6f0f43e7ebb --- /dev/null +++ b/mmpl/models/backbones/base_backbone.py @@ -0,0 +1,31 @@ +from abc import ABCMeta, abstractmethod +from mmengine.model import BaseModule + + +class BaseBackbone(BaseModule, metaclass=ABCMeta): + """Base backbone. + + This class defines the basic functions of a backbone. Any backbone that + inherits this class should at least define its own `forward` function. + """ + + def __init__(self, init_cfg=None): + super(BaseBackbone, self).__init__(init_cfg) + + @abstractmethod + def forward(self, x): + """Forward computation. + + Args: + x (tensor | tuple[tensor]): x could be a Torch.tensor or a tuple of + Torch.tensor, containing input data for forward computation. + """ + pass + + def train(self, mode=True): + """Set module status before forward computation. 
+ + Args: + mode (bool): Whether it is train_mode or test_mode + """ + super(BaseBackbone, self).train(mode) diff --git a/mmpl/models/backbones/huggingface_hub.py b/mmpl/models/backbones/huggingface_hub.py new file mode 100644 index 0000000000000000000000000000000000000000..7b582bebefb2e093f8c3b91926464fcacb9e0059 --- /dev/null +++ b/mmpl/models/backbones/huggingface_hub.py @@ -0,0 +1,12 @@ +import torch.nn as nn +from mmengine.model import BaseModule +from mmengine.model.weight_init import constant_init +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm +from mmpl.registry import MODELS +from .base_backbone import BaseBackbone + + +@MODELS.register_module() +class HuggingfaceModel(BaseBackbone): + def __init__(self, ): + pass diff --git a/mmpl/models/builder.py b/mmpl/models/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..269c46de8dab7aa0c3d502bb492a655d8181d9cd --- /dev/null +++ b/mmpl/models/builder.py @@ -0,0 +1,39 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpl.registry import MODELS + +BACKBONES = MODELS +NECKS = MODELS +HEADS = MODELS +LOSSES = MODELS +PLERS = MODELS +RETRIEVER = MODELS + + +def build_backbone(cfg): + """Build backbone.""" + return BACKBONES.build(cfg) + + +def build_neck(cfg): + """Build neck.""" + return NECKS.build(cfg) + + +def build_head(cfg): + """Build head.""" + return HEADS.build(cfg) + + +def build_loss(cfg): + """Build loss.""" + return LOSSES.build(cfg) + + +def build_pler(cfg): + """Build classifier.""" + return PLERS.build(cfg) + + +def build_retriever(cfg): + """Build retriever.""" + return RETRIEVER.build(cfg) diff --git a/mmpl/models/data_preprocessors/__init__.py b/mmpl/models/data_preprocessors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e484f495fb81105ad9000a0b60238fbbb5e69600 --- /dev/null +++ b/mmpl/models/data_preprocessors/__init__.py @@ -0,0 +1 @@ +from .data_preprocessor import BatchFixedSizePadTokenMaskGPT \ No newline at end of file diff --git a/mmpl/models/data_preprocessors/__pycache__/__init__.cpython-310.pyc b/mmpl/models/data_preprocessors/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc664927f74e2822ea101bf37cb1f4ffe57eee9d Binary files /dev/null and b/mmpl/models/data_preprocessors/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmpl/models/data_preprocessors/__pycache__/data_preprocessor.cpython-310.pyc b/mmpl/models/data_preprocessors/__pycache__/data_preprocessor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4263b200336dd9235b91d12eda8ed2c5c922f829 Binary files /dev/null and b/mmpl/models/data_preprocessors/__pycache__/data_preprocessor.cpython-310.pyc differ diff --git a/mmpl/models/data_preprocessors/data_preprocessor.py b/mmpl/models/data_preprocessors/data_preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..d58c527f3784a3586383adbacaa3ab821b56efa5 --- /dev/null +++ b/mmpl/models/data_preprocessors/data_preprocessor.py @@ -0,0 +1,125 @@ +import random +from numbers import Number +from typing import List, Optional, Sequence, Tuple, Union + +import mmengine +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.dist import barrier, broadcast, get_dist_info +from mmengine.logging import MessageHub +from mmengine.model import BaseDataPreprocessor, ImgDataPreprocessor +from mmengine.structures import PixelData +from mmengine.utils import is_seq_of 
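+# A minimal, self-contained sketch of the pad-and-mask scheme implemented by
+# ``BatchFixedSizePadTokenMaskGPT`` below; the helper name and the default
+# values here are illustrative only and are not used elsewhere in the codebase.
+# It relies on the ``torch`` import above.
+def _demo_pad_and_mask(token_lists, pad_token=511, p_token_keep=0.9, nb_code=512):
+    longest = max(len(t) for t in token_lists)
+    input_ids = torch.full((len(token_lists), longest), pad_token, dtype=torch.long)
+    attention_mask = torch.zeros(len(token_lists), longest, dtype=torch.long)
+    for i, t in enumerate(token_lists):
+        input_ids[i, :len(t)] = torch.as_tensor(t)
+        attention_mask[i, :len(t)] = 1
+    # keep each token with probability ``p_token_keep``; otherwise replace it
+    # with a random codebook index, mirroring the corruption applied in forward()
+    keep = torch.bernoulli(p_token_keep * torch.ones(input_ids.shape)).bool()
+    corrupted = torch.where(keep, input_ids, torch.randint_like(input_ids, nb_code))
+    # padded positions are set to -100 in the labels so a GPT-style loss ignores them
+    labels = input_ids.masked_fill(input_ids == pad_token, -100)
+    return corrupted, attention_mask, labels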
+from torch import Tensor + +from mmdet.models.utils import unfold_wo_center +from mmdet.models.utils.misc import samplelist_boxtype2tensor +from mmpl.registry import MODELS +from mmdet.structures import DetDataSample +from mmdet.structures.mask import BitmapMasks +from mmdet.utils import ConfigType + +try: + import skimage +except ImportError: + skimage = None + + +@MODELS.register_module() +class BatchFixedSizePadTokenMaskGPT(BaseDataPreprocessor): + """Fixed size padding for batch images. + + Args: + size (Tuple[int, int]): Fixed padding size. Expected padding + shape (h, w). Defaults to None. + img_pad_value (int): The padded pixel value for images. + Defaults to 0. + pad_mask (bool): Whether to pad instance masks. Defaults to False. + mask_pad_value (int): The padded pixel value for instance masks. + Defaults to 0. + pad_seg (bool): Whether to pad semantic segmentation maps. + Defaults to False. + seg_pad_value (int): The padded pixel value for semantic + segmentation maps. Defaults to 255. + """ + + def __init__(self, + pad_token: int, + p_token_keep: float = 1., + nb_code: int = 512, + ) -> None: + super().__init__() + self.pad_token = pad_token + self.p_token_keep = p_token_keep + self.nb_code = nb_code + + def forward( + self, + batch + ): + # padding the input index to the same length + + longest = max([len(item) for item in batch['motion_token']]) + bs = len(batch['motion_token']) + + attention_mask = torch.zeros(bs, longest, dtype=torch.long, device=self.device) + input_ids = torch.ones(bs, longest, dtype=torch.long, device=self.device) * self.pad_token + for i, item in enumerate(batch['motion_token']): + input_ids[i, :len(item)] = item + attention_mask[i, :len(item)] = 1 + + tgt_ids = input_ids + + if self.p_token_keep == -1: + proba = np.random.rand(1)[0] + mask = torch.bernoulli(proba * torch.ones(input_ids.shape, + device=input_ids.device)) + else: + mask = torch.bernoulli(self.p_token_keep * torch.ones(input_ids.shape, device=input_ids.device)) + mask = mask.bool() + r_indices = torch.randint_like(input_ids, self.nb_code) + a_indices = mask * input_ids + mask.logical_not() * r_indices + + tgt_ids[tgt_ids == self.pad_token] = -100 + + data = dict() + data['inputs'] = dict( + input_ids=a_indices, + attention_mask=attention_mask, + labels=tgt_ids, + ) + data['data_samples'] = batch + return data + + +@MODELS.register_module() +class NormalizationMotion(BaseDataPreprocessor): + + def __init__( + self, + mean_std_file: str, + ) -> None: + super().__init__() + self.mean_std_info = mmengine.load(mean_std_file) + + def forward( + self, + batch + ): + for k, v in self.mean_std_info.items(): + for kk, vv in v.items(): + self.mean_std_info[k][kk] = vv.to(self.device, dtype=torch.float32) + + gt_motion = batch['motion'] + gt_motion = (gt_motion - self.mean_std_info['motion']['mean']) / self.mean_std_info['motion']['std'] + + data = dict( + inputs=gt_motion, + data_samples=batch + ) + return data + + def denormalize(self, x): + return x * self.mean_std_info['motion']['std'] + self.mean_std_info['motion']['mean'] \ No newline at end of file diff --git a/mmpl/models/heads/__init__.py b/mmpl/models/heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..693f248f8f16e19f662732ab25fa0c1f589e407e --- /dev/null +++ b/mmpl/models/heads/__init__.py @@ -0,0 +1,8 @@ +from .sam_instance_head import SAMInstanceHead +from .semantic_seg_head import BinarySemanticSegHead +from .seg_upfcn_head import UpFCNHead +from .sam_semseg_head import SamSemSegHead + +from 
.sam_instance_head import SAMAnchorInstanceHead, SAMAnchorPromptRoIHead, SAMPromptMaskHead + +# __all__ = ['MotionGPTHead', 'YOLOv8SIRENSHead'] diff --git a/mmpl/models/heads/__pycache__/__init__.cpython-310.pyc b/mmpl/models/heads/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8494d63c31160ed5ed3777b30f7540349b4d91ba Binary files /dev/null and b/mmpl/models/heads/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmpl/models/heads/__pycache__/sam_instance_head.cpython-310.pyc b/mmpl/models/heads/__pycache__/sam_instance_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91a61f653ad57517cb9c3ac415ff109600726ac5 Binary files /dev/null and b/mmpl/models/heads/__pycache__/sam_instance_head.cpython-310.pyc differ diff --git a/mmpl/models/heads/__pycache__/sam_semseg_head.cpython-310.pyc b/mmpl/models/heads/__pycache__/sam_semseg_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34d6cf5e0244394eead15f847d4c6c01c846dd92 Binary files /dev/null and b/mmpl/models/heads/__pycache__/sam_semseg_head.cpython-310.pyc differ diff --git a/mmpl/models/heads/__pycache__/seg_upfcn_head.cpython-310.pyc b/mmpl/models/heads/__pycache__/seg_upfcn_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c177d4ab862d3718c31b7bc200d9c5c391ff6eb Binary files /dev/null and b/mmpl/models/heads/__pycache__/seg_upfcn_head.cpython-310.pyc differ diff --git a/mmpl/models/heads/__pycache__/semantic_seg_head.cpython-310.pyc b/mmpl/models/heads/__pycache__/semantic_seg_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c519a89648b14787acb1b96959c8b37232af938 Binary files /dev/null and b/mmpl/models/heads/__pycache__/semantic_seg_head.cpython-310.pyc differ diff --git a/mmpl/models/heads/base_head.py b/mmpl/models/heads/base_head.py new file mode 100644 index 0000000000000000000000000000000000000000..bac60c3f14f1bc2bbc2e822cf1fb5a22c8a03cd7 --- /dev/null +++ b/mmpl/models/heads/base_head.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import List, Optional, Tuple + +from mmengine.model import BaseModule +from mmengine.structures import BaseDataElement + + +class BaseHead(BaseModule, metaclass=ABCMeta): + """Base head. + + Args: + init_cfg (dict, optional): The extra init config of layers. + Defaults to None. + """ + + def __init__(self, init_cfg: Optional[dict] = None): + super(BaseHead, self).__init__(init_cfg=init_cfg) + + @abstractmethod + def loss(self, feats: Tuple, data_samples: List[BaseDataElement]): + """Calculate losses from the extracted features. + + Args: + feats (tuple): The features extracted from the backbone. + data_samples (List[BaseDataElement]): The annotation data of + every samples. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + pass + + @abstractmethod + def predict(self, + feats: Tuple, + data_samples: Optional[List[BaseDataElement]] = None): + """Predict results from the extracted features. + + Args: + feats (tuple): The features extracted from the backbone. + data_samples (List[BaseDataElement], optional): The annotation + data of every samples. If not None, set ``pred_label`` of + the input data samples. Defaults to None. + + Returns: + List[BaseDataElement]: A list of data samples which contains the + predicted results. 
+ """ + pass diff --git a/mmpl/models/heads/cls_head.py b/mmpl/models/heads/cls_head.py new file mode 100644 index 0000000000000000000000000000000000000000..26c01ac3c61170bdc8dab2377795b1a75e3fd881 --- /dev/null +++ b/mmpl/models/heads/cls_head.py @@ -0,0 +1,157 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmcls.evaluation.metrics import Accuracy +from mmcls.registry import MODELS +from mmcls.structures import ClsDataSample +from .base_head import BaseHead + + +@MODELS.register_module() +class ClsHead(BaseHead): + """Classification head. + + Args: + loss (dict): Config of classification loss. Defaults to + ``dict(type='CrossEntropyLoss', loss_weight=1.0)``. + topk (int | Tuple[int]): Top-k accuracy. Defaults to ``(1, )``. + cal_acc (bool): Whether to calculate accuracy during training. + If you use batch augmentations like Mixup and CutMix during + training, it is pointless to calculate accuracy. + Defaults to False. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0), + topk: Union[int, Tuple[int]] = (1, ), + cal_acc: bool = False, + init_cfg: Optional[dict] = None): + super(ClsHead, self).__init__(init_cfg=init_cfg) + + self.topk = topk + if not isinstance(loss, nn.Module): + loss = MODELS.build(loss) + self.loss_module = loss + self.cal_acc = cal_acc + + def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The process before the final classification head. + + The input ``feats`` is a tuple of tensor, and each tensor is the + feature of a backbone stage. In ``ClsHead``, we just obtain the feature + of the last stage. + """ + # The ClsHead doesn't have other module, just return after unpacking. + return feats[-1] + + def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + """The forward process.""" + pre_logits = self.pre_logits(feats) + # The ClsHead doesn't have the final classification head, + # just return the unpacked inputs. + return pre_logits + + def loss(self, feats: Tuple[torch.Tensor], + data_samples: List[ClsDataSample], **kwargs) -> dict: + """Calculate losses from the classification score. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[ClsDataSample]): The annotation data of + every samples. + **kwargs: Other keyword arguments to forward the loss module. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + # The part can be traced by torch.fx + cls_score = self(feats) + # import pdb + # pdb.set_trace() + # The part can not be traced by torch.fx + losses = self._get_loss(cls_score, data_samples, **kwargs) + return losses + + def _get_loss(self, cls_score: torch.Tensor, + data_samples: List[ClsDataSample], **kwargs): + """Unpack data samples and compute loss.""" + # Unpack data samples and pack targets + if 'score' in data_samples[0].gt_label: + # Batch augmentation may convert labels to one-hot format scores. 
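+            # With soft targets each ``gt_label.score`` is a per-class probability
+            # vector, so the stacked target has shape (N, num_classes); with hard
+            # labels the target is a 1-d tensor of class indices of shape (N,),
+            # which is the form the accuracy computation below expects.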
+ target = torch.stack([i.gt_label.score for i in data_samples]) + else: + target = torch.cat([i.gt_label.label for i in data_samples]) + + # compute loss + losses = dict() + loss = self.loss_module( + cls_score, target, avg_factor=cls_score.size(0), **kwargs) + losses['loss'] = loss + + # compute accuracy + if self.cal_acc: + assert target.ndim == 1, 'If you enable batch augmentation ' \ + 'like mixup during training, `cal_acc` is pointless.' + acc = Accuracy.calculate(cls_score, target, topk=self.topk) + losses.update( + {f'accuracy_top-{k}': a + for k, a in zip(self.topk, acc)}) + + return losses + + def predict( + self, + feats: Tuple[torch.Tensor], + data_samples: List[Union[ClsDataSample, None]] = None + ) -> List[ClsDataSample]: + """Inference without augmentation. + + Args: + feats (tuple[Tensor]): The features extracted from the backbone. + Multiple stage inputs are acceptable but only the last stage + will be used to classify. The shape of every item should be + ``(num_samples, num_classes)``. + data_samples (List[ClsDataSample | None], optional): The annotation + data of every samples. If not None, set ``pred_label`` of + the input data samples. Defaults to None. + + Returns: + List[ClsDataSample]: A list of data samples which contains the + predicted results. + """ + # The part can be traced by torch.fx + cls_score = self(feats) + + # The part can not be traced by torch.fx + predictions = self._get_predictions(cls_score, data_samples) + return predictions + + def _get_predictions(self, cls_score, data_samples): + """Post-process the output of head. + + Including softmax and set ``pred_label`` of data samples. + """ + pred_scores = F.softmax(cls_score, dim=1) + pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach() + + out_data_samples = [] + if data_samples is None: + data_samples = [None for _ in range(pred_scores.size(0))] + + for data_sample, score, label in zip(data_samples, pred_scores, + pred_labels): + if data_sample is None: + data_sample = ClsDataSample() + + data_sample.set_pred_score(score).set_pred_label(label) + out_data_samples.append(data_sample) + return out_data_samples diff --git a/mmpl/models/heads/sam_instance_head.py b/mmpl/models/heads/sam_instance_head.py new file mode 100644 index 0000000000000000000000000000000000000000..36bee5bb870f3b23fc9d90885f57ac694e940465 --- /dev/null +++ b/mmpl/models/heads/sam_instance_head.py @@ -0,0 +1,1015 @@ +import copy +import warnings +from typing import List, Optional, Tuple, Union, Dict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine import ConfigDict +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.models import BaseDetector, TwoStageDetector, StandardRoIHead, SinePositionalEncoding, FCNMaskHead, \ + BaseRoIHead +from mmdet.models.task_modules import SamplingResult +from mmdet.models.utils import multi_apply, unpack_gt_instances, empty_instances +from mmdet.structures import SampleList, DetDataSample +from mmdet.structures.bbox import bbox2roi +from mmdet.structures.mask import mask_target +from mmdet.utils import InstanceList, reduce_mean, OptMultiConfig +from mmpl.registry import MODELS, TASK_UTILS +from mmengine.model import BaseModel, BaseModule +from einops import rearrange, repeat +from mmpl.utils import ConfigType, OptConfigType +from mmdet.models.dense_heads import Mask2FormerHead +from mmdet.models.dense_heads.anchor_free_head import AnchorFreeHead + 
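+# A small illustrative sketch of how the heads below turn per-image queries into
+# prompts for a SAM-style mask decoder: the query axis is folded into the batch
+# axis so every (image, query) pair becomes one prompt. The function name and
+# all shapes are made up for illustration; only the ``torch`` and ``einops``
+# imports above are used.
+def _demo_queries_to_prompts(batch_size=2, num_queries=4, per_query_point=3, c=256):
+    # decoder output as produced by a point_emb MLP with the 'with_sincos' layout
+    point_emb = torch.randn(batch_size, num_queries, per_query_point, 2 * c)
+    # sine/offset mixing halves the channel dimension back to c
+    point_emb = torch.sin(point_emb[..., ::2]) + point_emb[..., 1::2]
+    # one sparse prompt per (image, query) pair: (B, Q, P, C) -> (B*Q, P, C)
+    prompts = rearrange(point_emb, 'b n p c -> (b n) p c')
+    # the SAM image embedding is repeated so each query sees its own copy
+    img_seg_feat = torch.randn(batch_size, c, 64, 64)
+    img_embeddings = torch.repeat_interleave(img_seg_feat, num_queries, dim=0)
+    return prompts.shape, img_embeddings.shape  # (B*Q, P, C), (B*Q, C, 64, 64)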
+@MODELS.register_module() +class SAMInstanceHead(Mask2FormerHead): + def __init__( + self, + num_things_classes: int = 1, + num_stuff_classes: int = 0, + prompt_neck: ConfigType = ..., + with_iou: bool = False, + with_multiscale: bool = False, + with_sincos: bool = False, + with_res_imgfeat: bool = False, + loss_cls: ConfigType = dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 133 + [0.1]), + loss_mask: ConfigType = dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice: ConfigType = dict( + type='DiceLoss', + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + init_cfg: OptMultiConfig = None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='ReLU', inplace=True), + **kwargs + ): + super(AnchorFreeHead, self).__init__(init_cfg=init_cfg) + + self.num_things_classes = num_things_classes + self.num_stuff_classes = num_stuff_classes + self.num_classes = self.num_things_classes + self.num_stuff_classes + self.with_iou = with_iou + self.with_multiscale = with_multiscale + self.with_sincos = with_sincos + self.with_res_imgfeat = with_res_imgfeat + + # self.num_transformer_feat_level = num_transformer_feat_level + # self.num_heads = transformer_decoder.layer_cfg.cross_attn_cfg.num_heads + # self.num_transformer_decoder_layers = transformer_decoder.num_layers + # assert pixel_decoder.encoder.layer_cfg. \ + # self_attn_cfg.num_levels == num_transformer_feat_level + # pixel_decoder_ = copy.deepcopy(pixel_decoder) + # pixel_decoder_.update( + # in_channels=in_channels, + # feat_channels=feat_channels, + # out_channels=out_channels) + # self.pixel_decoder = MODELS.build(pixel_decoder_) + # self.transformer_decoder = Mask2FormerTransformerDecoder( + # **transformer_decoder) + # self.decoder_embed_dims = self.transformer_decoder.embed_dims + # + # self.decoder_input_projs = ModuleList() + # # from low resolution to high resolution + # for _ in range(num_transformer_feat_level): + # if (self.decoder_embed_dims != feat_channels + # or enforce_decoder_input_project): + # self.decoder_input_projs.append( + # Conv2d( + # feat_channels, self.decoder_embed_dims, kernel_size=1)) + # else: + # self.decoder_input_projs.append(nn.Identity()) + # self.decoder_positional_encoding = SinePositionalEncoding( + # **positional_encoding) + # self.query_embed = nn.Embedding(self.num_queries, feat_channels) + # self.query_feat = nn.Embedding(self.num_queries, feat_channels) + # # from low resolution to high resolution + # self.level_embed = nn.Embedding(self.num_transformer_feat_level, + # feat_channels) + # + # self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) + # self.mask_embed = nn.Sequential( + # nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + # nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + # nn.Linear(feat_channels, out_channels)) + + self.prompt_neck = MODELS.build(prompt_neck) + self.num_queries = self.prompt_neck.num_queries + self.per_query_point = self.prompt_neck.per_query_point + out_channels = self.prompt_neck.out_channels + + self.cls_embed = nn.Sequential( + nn.Linear(out_channels, out_channels // 2), + nn.ReLU(inplace=True), + nn.Linear(out_channels // 2, self.num_classes + 1) + ) + + if self.with_sincos: + self.point_emb = nn.Sequential( + nn.Linear(out_channels, out_channels), + 
nn.ReLU(inplace=True), + nn.Linear(out_channels, out_channels), + nn.ReLU(inplace=True), + nn.Linear(out_channels, self.per_query_point * out_channels*2) + ) + else: + self.point_emb = nn.Sequential( + nn.Linear(out_channels, out_channels), + nn.ReLU(inplace=True), + nn.Linear(out_channels, out_channels), + nn.ReLU(inplace=True), + nn.Linear(out_channels, self.per_query_point * out_channels) + ) + + if self.with_res_imgfeat: + self.res_imgfeat = nn.Sequential( + nn.UpsamplingBilinear2d(scale_factor=2), + ConvModule( + out_channels, + out_channels, + kernel_size=3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg + ) + ) + + self.test_cfg = test_cfg + self.train_cfg = train_cfg + if train_cfg: + self.assigner = TASK_UTILS.build(self.train_cfg['assigner']) + self.sampler = TASK_UTILS.build( + self.train_cfg['sampler'], default_args=dict(context=self)) + self.num_points = self.train_cfg.get('num_points', 12544) + self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0) + self.importance_sample_ratio = self.train_cfg.get( + 'importance_sample_ratio', 0.75) + + self.class_weight = loss_cls.class_weight + self.loss_cls = MODELS.build(loss_cls) + self.loss_mask = MODELS.build(loss_mask) + self.loss_dice = MODELS.build(loss_dice) + + def forward(self, x: List[Tensor], + batch_data_samples: SampleList, + sam + ) -> Tuple[List[Tensor]]: + """Forward function. + + Args: + x (list[Tensor]): Multi scale Features from the + upstream network, each is a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + tuple[list[Tensor]]: A tuple contains two elements. + + - cls_pred_list (list[Tensor)]: Classification logits \ + for each decoder layer. Each is a 3D-tensor with shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred_list (list[Tensor]): Mask logits for each \ + decoder layer. Each with shape (batch_size, num_queries, \ + h, w). 
+ """ + batch_img_metas = [ + data_sample.metainfo for data_sample in batch_data_samples + ] + batch_size = len(batch_img_metas) + decoder_out, query_feat_list, res_img_feat = self.prompt_neck(x) + + if self.with_multiscale: + cls_pred_list = [self.cls_embed(query_feat) for query_feat in query_feat_list] + else: + # shape (batch_size, num_queries, c) + cls_pred_list = [self.cls_embed(decoder_out)] + # shape (batch_size, num_queries, c) + point_emb = self.point_emb(decoder_out) + # shape (batch_size, num_queries, per_query_point, c) + point_emb = point_emb.view(batch_size, self.num_queries, self.per_query_point, -1) + + img_seg_feat = x[0] + point_emb = rearrange(point_emb, 'b n p c -> (b n) p c') + if self.with_sincos: + point_emb = torch.sin(point_emb[..., ::2]) + point_emb[..., 1::2] + + nomask_dense_embeddings = sam.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + point_emb.shape[0], -1, *img_seg_feat.shape[-2:] + ) + + img_embeddings = torch.repeat_interleave(img_seg_feat, self.num_queries, dim=0) + img_pe = sam.prompt_encoder.get_dense_pe() + img_pe = repeat(img_pe, 'b c h w -> (b n) c h w', n=img_embeddings.shape[0]) + + if self.with_res_imgfeat: + res_img_feat = self.res_imgfeat(res_img_feat) + res_img_feat = torch.repeat_interleave(res_img_feat, self.num_queries, dim=0) + else: + res_img_feat = None + + low_res_masks, iou_predictions = sam.mask_decoder.forward_batch( + image_embeddings=img_embeddings, + image_pe=img_pe, + sparse_prompt_embeddings=point_emb, + dense_prompt_embeddings=nomask_dense_embeddings, + multimask_output=False, + res_img_feat=res_img_feat, + ) + mask_pred = rearrange(low_res_masks.squeeze(1), '(b n) h w -> b n h w', b=batch_size) + + # optional + # if self.with_iou: + # iou_predictions = iou_predictions.view(batch_size, self.num_queries, -1) + # cls_pred = cls_pred * iou_predictions + + if self.with_multiscale: + mask_pred_list = [mask_pred] * len(cls_pred_list) + else: + mask_pred_list = [mask_pred] + + return cls_pred_list, mask_pred_list + + def predict(self, x: Tuple[Tensor], + batch_data_samples: SampleList, + sam + ) -> Tuple[Tensor]: + """Test without augmentaton. + + Args: + x (tuple[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + tuple[Tensor]: A tuple contains two tensors. + + - mask_cls_results (Tensor): Mask classification logits,\ + shape (batch_size, num_queries, cls_out_channels). + Note `cls_out_channels` should includes background. + - mask_pred_results (Tensor): Mask logits, shape \ + (batch_size, num_queries, h, w). + """ + batch_img_metas = [ + data_sample.metainfo for data_sample in batch_data_samples + ] + all_cls_scores, all_mask_preds = self(x, batch_data_samples, sam) + mask_cls_results = all_cls_scores[-1] + mask_pred_results = all_mask_preds[-1] + + # upsample masks + img_shape = batch_img_metas[0]['batch_input_shape'] + mask_pred_results = F.interpolate( + mask_pred_results, + size=(img_shape[0], img_shape[1]), + mode='bilinear', + align_corners=False) + + return mask_cls_results, mask_pred_results + + def loss( + self, + x: Tuple[Tensor], + batch_data_samples: SampleList, + sam, + ) -> Dict[str, Tensor]: + """Perform forward propagation and loss calculation of the panoptic + head on the features of the upstream network. 
+ + Args: + x (tuple[Tensor]): Multi-level features from the upstream + network, each is a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + batch_img_metas = [] + batch_gt_instances = [] + batch_gt_semantic_segs = [] + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + if 'gt_sem_seg' in data_sample: + batch_gt_semantic_segs.append(data_sample.gt_sem_seg) + else: + batch_gt_semantic_segs.append(None) + + # forward + all_cls_scores, all_mask_preds = self(x, batch_data_samples, sam) + + # preprocess ground truth + batch_gt_instances = self.preprocess_gt(batch_gt_instances, + batch_gt_semantic_segs) + + # loss + losses = self.loss_by_feat(all_cls_scores, all_mask_preds, + batch_gt_instances, batch_img_metas) + + return losses + + +@MODELS.register_module() +class SAMAnchorInstanceHead(TwoStageDetector): + def __init__( + self, + sam_head=True, + neck: OptConfigType = None, + rpn_head: OptConfigType = None, + roi_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + **kwargs + ): + super(TwoStageDetector, self).__init__() + self.neck = MODELS.build(neck) + self.sam_head = sam_head + + if rpn_head is not None: + rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None + rpn_head_ = rpn_head.copy() + rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn) + rpn_head_num_classes = rpn_head_.get('num_classes', None) + if rpn_head_num_classes is None: + rpn_head_.update(num_classes=1) + else: + if rpn_head_num_classes != 1: + warnings.warn( + 'The `num_classes` should be 1 in RPN, but get ' + f'{rpn_head_num_classes}, please set ' + 'rpn_head.num_classes = 1 in your config file.') + rpn_head_.update(num_classes=1) + self.rpn_head = MODELS.build(rpn_head_) + + if roi_head is not None: + # update train and test cfg here for now + # TODO: refactor assigner & sampler + rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None + roi_head.update(train_cfg=rcnn_train_cfg) + roi_head.update(test_cfg=test_cfg.rcnn) + self.roi_head = MODELS.build(roi_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + def extract_feat(self, x): + x = self.neck(x) + return x + + def loss(self, + batch_inputs, + batch_data_samples: SampleList, + sam + ) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + batch_inputs (Tensor): Input images of shape (N, C, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. 
+ + Returns: + dict: A dictionary of loss components + """ + x = self.extract_feat(batch_inputs) + img_seg_feat = batch_inputs[0] + losses = dict() + + # RPN forward and loss + if self.with_rpn: + proposal_cfg = self.train_cfg.get('rpn_proposal', + self.test_cfg.rpn) + rpn_data_samples = copy.deepcopy(batch_data_samples) + # set cat_id of gt_labels to 0 in RPN + for data_sample in rpn_data_samples: + data_sample.gt_instances.labels = \ + torch.zeros_like(data_sample.gt_instances.labels) + + rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict( + x, rpn_data_samples, proposal_cfg=proposal_cfg) + # avoid get same name with roi_head loss + keys = rpn_losses.keys() + for key in list(keys): + if 'loss' in key and 'rpn' not in key: + rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key) + losses.update(rpn_losses) + else: + assert batch_data_samples[0].get('proposals', None) is not None + # use pre-defined proposals in InstanceData for the second stage + # to extract ROI features. + rpn_results_list = [ + data_sample.proposals for data_sample in batch_data_samples + ] + if self.sam_head: + roi_losses = self.roi_head.loss(x, rpn_results_list, + batch_data_samples, + sam, img_seg_feat + ) + else: + roi_losses = self.roi_head.loss(x, rpn_results_list, + batch_data_samples + ) + losses.update(roi_losses) + + return losses + + def predict(self, + batch_inputs: Tensor, + batch_data_samples: SampleList, + sam, + rescale: bool = True + ) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + batch_inputs (Tensor): Inputs with shape (N, C, H, W). + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool): Whether to rescale the results. + Defaults to True. + + Returns: + list[:obj:`DetDataSample`]: Return the detection results of the + input images. The returns value is DetDataSample, + which usually contain 'pred_instances'. And the + ``pred_instances`` usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + """ + + assert self.with_bbox, 'Bbox head must be implemented.' 
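+ # Two-stage flow of this predict(): the neck builds multi-scale features, proposals come from the RPN head unless precomputed proposals are already attached to the data samples, and the ROI head then predicts boxes and (when sam_head is set) SAM-decoded masks, using batch_inputs[0] as the SAM image embedding.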
+ x = self.extract_feat(batch_inputs) + img_seg_feat = batch_inputs[0] + + # If there are no pre-defined proposals, use RPN to get proposals + if batch_data_samples[0].get('proposals', None) is None: + rpn_results_list = self.rpn_head.predict( + x, batch_data_samples, rescale=False) + else: + rpn_results_list = [ + data_sample.proposals for data_sample in batch_data_samples + ] + if self.sam_head: + results_list = self.roi_head.predict( + x, rpn_results_list, batch_data_samples, sam, img_seg_feat, rescale=rescale) + else: + results_list = self.roi_head.predict( + x, rpn_results_list, batch_data_samples, rescale=rescale) + + batch_data_samples = self.add_pred_to_datasample( + batch_data_samples, results_list) + return batch_data_samples + + +@MODELS.register_module() +class SAMAnchorPromptRoIHead(StandardRoIHead): + def __init__( + self, + positional_encoding=dict(num_feats=128, normalize=True), + *args, + **kwargs + ): + super(StandardRoIHead, self).__init__(*args, **kwargs) + self.generator_pe = SinePositionalEncoding(**positional_encoding) + + def _mask_forward(self, + x: Tuple[Tensor], + rois: Tensor = None, + pos_inds: Optional[Tensor] = None, + bbox_feats: Optional[Tensor] = None, + sam=None, img_seg_feat=None + ) -> dict: + """Mask head forward function used in both training and testing. + + Args: + x (tuple[Tensor]): Tuple of multi-level img features. + rois (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + pos_inds (Tensor, optional): Indices of positive samples. + Defaults to None. + bbox_feats (Tensor): Extract bbox RoI features. Defaults to None. + + Returns: + dict[str, Tensor]: Usually returns a dictionary with keys: + + - `mask_preds` (Tensor): Mask prediction. + - `mask_feats` (Tensor): Extract mask RoI features. + """ + assert ((rois is not None) ^ + (pos_inds is not None and bbox_feats is not None)) + if rois is not None: + mask_feats = self.mask_roi_extractor( + x[:self.mask_roi_extractor.num_inputs], rois) + if self.with_shared_head: + mask_feats = self.shared_head(mask_feats) + else: + assert bbox_feats is not None + mask_feats = bbox_feats[pos_inds] + + mask_preds = self.mask_head(mask_feats, sam, img_seg_feat, img_flag_ids=rois[:, 0]) + mask_results = dict(mask_preds=mask_preds[0], mask_iou=mask_preds[1], mask_feats=mask_feats) + return mask_results + + def mask_loss(self, x: Tuple[Tensor], + sampling_results: List[SamplingResult], bbox_feats: Tensor, + batch_gt_instances: InstanceList, + sam, img_seg_feat + ) -> dict: + """Perform forward propagation and loss calculation of the mask head on + the features of the upstream network. + + Args: + x (tuple[Tensor]): Tuple of multi-level img features. + sampling_results (list["obj:`SamplingResult`]): Sampling results. + bbox_feats (Tensor): Extract bbox RoI features. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``labels``, and + ``masks`` attributes. + + Returns: + dict: Usually returns a dictionary with keys: + + - `mask_preds` (Tensor): Mask prediction. + - `mask_feats` (Tensor): Extract mask RoI features. + - `mask_targets` (Tensor): Mask target of each positive\ + proposals in the image. + - `loss_mask` (dict): A dictionary of mask loss components. 
+ """ + if not self.share_roi_extractor: + pos_rois = bbox2roi([res.pos_priors for res in sampling_results]) + mask_results = self._mask_forward( + x, pos_rois, sam=sam, img_seg_feat=img_seg_feat) + else: + pos_inds = [] + device = bbox_feats.device + for res in sampling_results: + pos_inds.append( + torch.ones( + res.pos_priors.shape[0], + device=device, + dtype=torch.uint8)) + pos_inds.append( + torch.zeros( + res.neg_priors.shape[0], + device=device, + dtype=torch.uint8)) + pos_inds = torch.cat(pos_inds) + + mask_results = self._mask_forward( + x, pos_inds=pos_inds, bbox_feats=bbox_feats) + + mask_loss_and_target = self.mask_head.loss_and_target( + mask_preds=mask_results['mask_preds'], + sampling_results=sampling_results, + batch_gt_instances=batch_gt_instances, + rcnn_train_cfg=self.train_cfg) + + mask_results.update(loss_mask=mask_loss_and_target['loss_mask']) + return mask_results + + def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList, + batch_data_samples: List[DetDataSample], + sam, img_seg_feat + ) -> dict: + """Perform forward propagation and loss calculation of the detection + roi on the features of the upstream network. + + Args: + x (tuple[Tensor]): List of multi-level img features. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict[str, Tensor]: A dictionary of loss components + """ + x = list(x) + bs, _, h, w = x[-1].shape + mask_pe = torch.zeros((bs, h, w), device=x[0].device, dtype=torch.bool) + img_feats_pe = self.generator_pe(mask_pe) + for i in range(len(x)): + x[i] = x[i] + torch.nn.functional.interpolate(img_feats_pe, size=x[i].shape[-2:], mode='bilinear') + + assert len(rpn_results_list) == len(batch_data_samples) + outputs = unpack_gt_instances(batch_data_samples) + batch_gt_instances, batch_gt_instances_ignore, _ = outputs + + # assign gts and sample proposals + num_imgs = len(batch_data_samples) + sampling_results = [] + for i in range(num_imgs): + # rename rpn_results.bboxes to rpn_results.priors + rpn_results = rpn_results_list[i] + rpn_results.priors = rpn_results.pop('bboxes') + + assign_result = self.bbox_assigner.assign( + rpn_results, batch_gt_instances[i], + batch_gt_instances_ignore[i]) + sampling_result = self.bbox_sampler.sample( + assign_result, + rpn_results, + batch_gt_instances[i], + feats=[lvl_feat[i][None] for lvl_feat in x]) + sampling_results.append(sampling_result) + + losses = dict() + # bbox head loss + if self.with_bbox: + bbox_results = self.bbox_loss(x, sampling_results) + losses.update(bbox_results['loss_bbox']) + + # mask head forward and loss + if self.with_mask: + mask_results = self.mask_loss(x, sampling_results, + bbox_results['bbox_feats'], + batch_gt_instances, + sam, img_seg_feat + ) + losses.update(mask_results['loss_mask']) + + return losses + + + def predict_mask(self, + x: Tuple[Tensor], + batch_img_metas: List[dict], + results_list: InstanceList, + rescale: bool = False, + sam=None, img_seg_feat=None + ) -> InstanceList: + """Perform forward propagation of the mask head and predict detection + results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Feature maps of all scale level. + batch_img_metas (list[dict]): List of image information. + results_list (list[:obj:`InstanceData`]): Detection results of + each image. 
+ rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + """ + # don't need to consider aug_test. + bboxes = [res.bboxes for res in results_list] + mask_rois = bbox2roi(bboxes) + if mask_rois.shape[0] == 0: + results_list = empty_instances( + batch_img_metas, + mask_rois.device, + task_type='mask', + instance_results=results_list, + mask_thr_binary=self.test_cfg.mask_thr_binary) + return results_list + + mask_results = self._mask_forward(x, mask_rois, sam=sam, img_seg_feat=img_seg_feat) + mask_preds = mask_results['mask_preds'] + # split batch mask prediction back to each image + num_mask_rois_per_img = [len(res) for res in results_list] + mask_preds = mask_preds.split(num_mask_rois_per_img, 0) + + # TODO: Handle the case where rescale is false + results_list = self.mask_head.predict_by_feat( + mask_preds=mask_preds, + results_list=results_list, + batch_img_metas=batch_img_metas, + rcnn_test_cfg=self.test_cfg, + rescale=rescale) + return results_list + + def predict(self, + x: Tuple[Tensor], + rpn_results_list: InstanceList, + batch_data_samples: SampleList, + sam, img_seg_feat, + rescale: bool = False) -> InstanceList: + """Perform forward propagation of the roi head and predict detection + results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Features from upstream network. Each + has shape (N, C, H, W). + rpn_results_list (list[:obj:`InstanceData`]): list of region + proposals. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool): Whether to rescale the results to + the original image. Defaults to True. + + Returns: + list[obj:`InstanceData`]: Detection results of each image. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + """ + x = list(x) + bs, _, h, w = x[-1].shape + mask_pe = torch.zeros((bs, h, w), device=x[0].device, dtype=torch.bool) + img_feats_pe = self.generator_pe(mask_pe) + for i in range(len(x)): + x[i] = x[i] + torch.nn.functional.interpolate(img_feats_pe, size=x[i].shape[-2:], mode='bilinear') + + assert self.with_bbox, 'Bbox head must be implemented.' + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + + # TODO: nms_op in mmcv need be enhanced, the bbox result may get + # difference when not rescale in bbox_head + + # If it has the mask branch, the bbox branch does not need + # to be scaled to the original image scale, because the mask + # branch will scale both bbox and mask at the same time. 
+ bbox_rescale = rescale if not self.with_mask else False + results_list = self.predict_bbox( + x, + batch_img_metas, + rpn_results_list, + rcnn_test_cfg=self.test_cfg, + rescale=bbox_rescale) + + if self.with_mask: + results_list = self.predict_mask( + x, batch_img_metas, results_list, rescale=rescale, sam=sam, img_seg_feat=img_seg_feat) + + return results_list + + +@MODELS.register_module() +class SAMPromptMaskHead(FCNMaskHead): + + def __init__(self, + per_query_point: int = 5, + with_sincos: bool = True, + class_agnostic: bool = False, + loss_mask: ConfigType = dict( + type='CrossEntropyLoss', use_mask=True, loss_weight=1.0), + *args, + **kwargs + ) -> None: + super(BaseModule, self).__init__() + + self.per_query_point = per_query_point + self.with_sincos = with_sincos + self.class_agnostic = class_agnostic + + self.loss_mask = MODELS.build(loss_mask) + + if with_sincos: + sincos = 2 + else: + sincos = 1 + self.point_emb = nn.Sequential( + nn.Conv2d(256, 256, 3, stride=2, padding=1), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.Flatten(), + nn.Linear(7*7*256, 256), + nn.ReLU(inplace=True), + nn.Linear(256, 256), + nn.ReLU(inplace=True), + nn.Linear(256, 256*sincos*per_query_point) + ) + + def forward(self, x, sam, img_seg_feat, img_flag_ids) -> Tensor: + batch_size = x.shape[0] + point_emb = self.point_emb(x) + point_emb = point_emb.view(batch_size, self.per_query_point, -1) + if self.with_sincos: + point_emb = torch.sin(point_emb[..., ::2]) + point_emb[..., 1::2] + + nomask_dense_embeddings = sam.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + point_emb.shape[0], -1, *img_seg_feat.shape[-2:] + ) + img_flag_ids = torch.bincount(img_flag_ids.long()) + padding = torch.zeros((len(img_seg_feat)-len(img_flag_ids),), device=img_flag_ids.device, dtype=img_flag_ids.dtype) + img_flag_ids = torch.cat([img_flag_ids, padding]) + img_embeddings = torch.repeat_interleave(img_seg_feat, img_flag_ids, dim=0) + img_pe = sam.prompt_encoder.get_dense_pe() + img_pe = repeat(img_pe, 'b c h w -> (b n) c h w', n=img_embeddings.shape[0]) + + res_img_feat = None + low_res_masks, iou_predictions = sam.mask_decoder.forward_batch( + image_embeddings=img_embeddings, + image_pe=img_pe, + sparse_prompt_embeddings=point_emb, + dense_prompt_embeddings=nomask_dense_embeddings, + multimask_output=False, + res_img_feat=res_img_feat, + ) + mask_pred = low_res_masks.squeeze(1) + iou_predictions = iou_predictions.squeeze(1) + return mask_pred, iou_predictions + + def get_targets(self, sampling_results: List[SamplingResult], + batch_gt_instances: InstanceList, + rcnn_train_cfg: ConfigDict) -> Tensor: + """Calculate the ground truth for all samples in a batch according to + the sampling_results. + + Args: + sampling_results (List[obj:SamplingResult]): Assign results of + all images in a batch after sampling. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``labels``, and + ``masks`` attributes. + rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN. + + Returns: + Tensor: Mask target of each positive proposals in the image. 
+ """ + pos_proposals = [res.pos_priors for res in sampling_results] + pos_assigned_gt_inds = [ + res.pos_assigned_gt_inds for res in sampling_results + ] + gt_masks = [res.masks for res in batch_gt_instances] + + mask_targets_list = [] + mask_size = (rcnn_train_cfg.mask_size,) * 2 + device = pos_proposals[0].device + for pos_gt_inds, gt_mask in zip(pos_assigned_gt_inds, gt_masks): + if len(pos_gt_inds) == 0: + mask_targets = torch.zeros((0,) + mask_size, device=device, dytpe=torch.float32) + else: + mask_targets = gt_mask[pos_gt_inds.cpu()].to_tensor(dtype=torch.float32, device=device) + mask_targets_list.append(mask_targets) + mask_targets = torch.cat(mask_targets_list) + return mask_targets + + def loss_and_target(self, mask_preds: Tensor, + sampling_results: List[SamplingResult], + batch_gt_instances: InstanceList, + rcnn_train_cfg: ConfigDict) -> dict: + """Calculate the loss based on the features extracted by the mask head. + + Args: + mask_preds (Tensor): Predicted foreground masks, has shape + (num_pos, num_classes, h, w). + sampling_results (List[obj:SamplingResult]): Assign results of + all images in a batch after sampling. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``labels``, and + ``masks`` attributes. + rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN. + + Returns: + dict: A dictionary of loss and targets components. + """ + mask_targets = self.get_targets( + sampling_results=sampling_results, + batch_gt_instances=batch_gt_instances, + rcnn_train_cfg=rcnn_train_cfg) + + pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results]) + + mask_preds = torch.nn.functional.interpolate( + mask_preds.unsqueeze(1), size=mask_targets.shape[-2:], mode='bilinear', align_corners=False) + loss = dict() + if mask_preds.size(0) == 0: + loss_mask = mask_preds.sum() + else: + if self.class_agnostic: + loss_mask = self.loss_mask(mask_preds, mask_targets, + torch.zeros_like(pos_labels)) + else: + loss_mask = self.loss_mask(mask_preds, mask_targets, + pos_labels) + loss['loss_mask'] = loss_mask + # TODO: which algorithm requires mask_targets? + return dict(loss_mask=loss, mask_targets=mask_targets) + + def _predict_by_feat_single(self, + mask_preds: Tensor, + bboxes: Tensor, + labels: Tensor, + img_meta: dict, + rcnn_test_cfg: ConfigDict, + rescale: bool = False, + activate_map: bool = False) -> Tensor: + """Get segmentation masks from mask_preds and bboxes. + + Args: + mask_preds (Tensor): Predicted foreground masks, has shape + (n, num_classes, h, w). + bboxes (Tensor): Predicted bboxes, has shape (n, 4) + labels (Tensor): Labels of bboxes, has shape (n, ) + img_meta (dict): image information. + rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + activate_map (book): Whether get results with augmentations test. + If True, the `mask_preds` will not process with sigmoid. + Defaults to False. + + Returns: + Tensor: Encoded masks, has shape (n, img_w, img_h) + + Example: + >>> from mmengine.config import Config + >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import * # NOQA + >>> N = 7 # N = number of extracted ROIs + >>> C, H, W = 11, 32, 32 + >>> # Create example instance of FCN Mask Head. 
+ >>> self = FCNMaskHead(num_classes=C, num_convs=0) + >>> inputs = torch.rand(N, self.in_channels, H, W) + >>> mask_preds = self.forward(inputs) + >>> # Each input is associated with some bounding box + >>> bboxes = torch.Tensor([[1, 1, 42, 42 ]] * N) + >>> labels = torch.randint(0, C, size=(N,)) + >>> rcnn_test_cfg = Config({'mask_thr_binary': 0, }) + >>> ori_shape = (H * 4, W * 4) + >>> scale_factor = (1, 1) + >>> rescale = False + >>> img_meta = {'scale_factor': scale_factor, + ... 'ori_shape': ori_shape} + >>> # Encoded masks are a list for each category. + >>> encoded_masks = self._get_seg_masks_single( + ... mask_preds, bboxes, labels, + ... img_meta, rcnn_test_cfg, rescale) + >>> assert encoded_masks.size()[0] == N + >>> assert encoded_masks.size()[1:] == ori_shape + """ + scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat( + (1, 2)) + img_h, img_w = img_meta['ori_shape'][:2] + device = bboxes.device + + if not activate_map: + mask_preds = mask_preds.sigmoid() + else: + # In AugTest, has been activated before + mask_preds = bboxes.new_tensor(mask_preds) + + if rescale: # in-placed rescale the bboxes + bboxes /= scale_factor + else: + w_scale, h_scale = scale_factor[0, 0], scale_factor[0, 1] + img_h = np.round(img_h * h_scale.item()).astype(np.int32) + img_w = np.round(img_w * w_scale.item()).astype(np.int32) + + threshold = rcnn_test_cfg.mask_thr_binary + + im_mask = torch.nn.functional.interpolate( + mask_preds.unsqueeze(1), size=(img_h, img_w), mode='bilinear', align_corners=False).squeeze(1) + + if threshold >= 0: + im_mask = im_mask >= threshold + else: + # for visualization and debugging + im_mask = (im_mask * 255).to(dtype=torch.uint8) + return im_mask \ No newline at end of file diff --git a/mmpl/models/heads/sam_semseg_head.py b/mmpl/models/heads/sam_semseg_head.py new file mode 100644 index 0000000000000000000000000000000000000000..f674c77e19a02a26bd6223480b54b055e4b59411 --- /dev/null +++ b/mmpl/models/heads/sam_semseg_head.py @@ -0,0 +1,226 @@ +import warnings + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from typing import Tuple, List + +from torch import Tensor + +from mmpl.registry import MODELS +from mmseg.models import build_loss +from mmseg.models.utils import resize +from mmseg.structures import build_pixel_sampler +from mmseg.utils import SampleList, ConfigType + + +@MODELS.register_module() +class SamSemSegHead(BaseModule): + def __init__(self, + in_channels=2, + inner_channels=None, + num_classes=1, + ignore_index=255, + threshold=None, + out_channels=None, + loss_decode=None, + sampler=None, + align_corners=False, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='ReLU', inplace=True), + train_cfg=None, + test_cfg=None, + ): + super().__init__() + self.in_channels = in_channels + self.ignore_index = ignore_index + self.align_corners = align_corners + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + if out_channels is None: + if num_classes == 2: + warnings.warn('For binary segmentation, we suggest using' + '`out_channels = 1` to define the output' + 'channels of segmentor, and use `threshold`' + 'to convert `seg_logits` into a prediction' + 'applying a threshold') + out_channels = num_classes + + if out_channels != num_classes and out_channels != 1: + raise ValueError( + 'out_channels should be equal to num_classes,' + 'except binary segmentation set out_channels == 1 and' + f'num_classes == 2, but got out_channels={out_channels}' + f'and 
num_classes={num_classes}') + + if out_channels == 1 and threshold is None: + threshold = 0.3 + warnings.warn('threshold is not defined for binary, and defaults' + 'to 0.3') + self.num_classes = num_classes + self.out_channels = num_classes + 1 + self.threshold = threshold + + if isinstance(loss_decode, dict): + self.loss_decode = MODELS.build(loss_decode) + elif isinstance(loss_decode, (list, tuple)): + self.loss_decode = nn.ModuleList() + for loss in loss_decode: + self.loss_decode.append(MODELS.build(loss)) + else: + raise TypeError(f'loss_decode must be a dict or sequence of dict,\ + but got {type(loss_decode)}') + + if sampler is not None: + self.sampler = build_pixel_sampler(sampler, context=self) + else: + self.sampler = None + + if inner_channels is None: + self.down_conv = nn.ModuleList([nn.Identity(), nn.Identity()]) + else: + self.down_conv = nn.ModuleList([ + nn.Conv2d(in_channels, inner_channels, 1), + nn.Conv2d(in_channels, inner_channels, 1) + ]) + in_channels = inner_channels + + self.cls_seg = nn.Sequential( + nn.Linear(in_channels, in_channels // 2), + nn.ReLU(inplace=True), + nn.Linear(in_channels // 2, self.out_channels) + ) + + self.up_conv = nn.Sequential( + nn.UpsamplingBilinear2d(scale_factor=2), + ConvModule( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ), + nn.UpsamplingBilinear2d(scale_factor=2), + ConvModule( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ), + nn.UpsamplingBilinear2d(scale_factor=2), + ConvModule( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + ) + + + def forward(self, inputs): + """Forward function.""" + x0, x1 = inputs + x0 = self.down_conv[0](x0) + x1 = self.down_conv[1](x1) + + gate_x0 = torch.sigmoid(x0) # B N H W + x1 = torch.einsum('bnhw,bchw->bnchw', gate_x0, x1) + x1 = torch.mean(x1, dim=(-2, -1)) + x1 = self.cls_seg(x1) # B N K + x0 = self.up_conv(x0) # B N H W + seg_logits = torch.einsum('bnhw,bnk->bkhw', x0, x1) + return seg_logits + + def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList, + train_cfg: ConfigType=None) -> dict: + seg_logits = self.forward(inputs) + losses = self.loss_by_feat(seg_logits, batch_data_samples) + return losses + + def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict], + test_cfg: ConfigType=None) -> Tensor: + """Forward function for prediction. + + Args: + inputs (Tuple[Tensor]): List of multi-level img features. + batch_img_metas (dict): List Image info where each dict may also + contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Outputs segmentation logits map. + """ + seg_logits = self.forward(inputs) + + return self.predict_by_feat(seg_logits, batch_img_metas) + + def predict_by_feat(self, seg_logits: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Transform a batch of output seg_logits to the input shape. + + Args: + seg_logits (Tensor): The output from decode head forward function. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + + Returns: + Tensor: Outputs segmentation logits map. 
+ """ + + seg_logits = resize( + input=seg_logits, + size=batch_img_metas[0]['img_shape'], + mode='bilinear', + align_corners=self.align_corners) + return seg_logits + + def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor: + gt_semantic_segs = [ + data_sample.gt_sem_seg.data for data_sample in batch_data_samples + ] + return torch.stack(gt_semantic_segs, dim=0) + + def loss_by_feat(self, seg_logits: Tensor, + batch_data_samples: SampleList) -> dict: + seg_label = self._stack_batch_gt(batch_data_samples) + loss = dict() + seg_logits = resize( + input=seg_logits, + size=seg_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + + if self.sampler is not None: + seg_weight = self.sampler.sample(seg_logits, seg_label) + else: + seg_weight = None + seg_label = seg_label.squeeze(1) + + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_decode in losses_decode: + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode( + seg_logits, + seg_label, + weight=seg_weight, + ignore_index=self.ignore_index) + else: + loss[loss_decode.loss_name] += loss_decode( + seg_logits, + seg_label, + weight=seg_weight, + ignore_index=self.ignore_index) + return loss diff --git a/mmpl/models/heads/seg_upfcn_head.py b/mmpl/models/heads/seg_upfcn_head.py new file mode 100644 index 0000000000000000000000000000000000000000..74f32c8c1ff393ddc0963994e65a73003285f790 --- /dev/null +++ b/mmpl/models/heads/seg_upfcn_head.py @@ -0,0 +1,181 @@ +import einops +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmseg.models import build_loss +from mmpl.registry import MODELS +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +import torch.nn.functional as F + + +@MODELS.register_module() +class UpFCNHead(BaseModule): + """Fully Convolution Networks for Semantic Segmentation. + + This head is implemented of `FCNNet `_. + + Args: + num_convs (int): Number of convs in the head. Default: 2. + kernel_size (int): The kernel size for convs in the head. Default: 3. + concat_input (bool): Whether concat the input and output of convs + before classification layer. + dilation (int): The dilation rate for convs in the head. Default: 1. 
+ """ + + def __init__(self, + in_channels, + mid_channels=[256, 128, 64], + num_classes=2, + kernel_size=3, + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + align_corners=False, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + **kwargs): + super().__init__(**kwargs) + self.in_channels = in_channels + self.mid_channels = mid_channels + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.kernel_size = kernel_size + self.num_classes = num_classes + self.align_corners = align_corners + + if isinstance(in_channels, list): + self.pre_layers = nn.ModuleList() + inner_channel = mid_channels[0] + for idx, channel in enumerate(in_channels): + self.pre_layers.append( + nn.Sequential( + ConvModule( + channel, + inner_channel, + kernel_size=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ), + ConvModule( + inner_channel, + inner_channel, + kernel_size=kernel_size, + padding=kernel_size // 2, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ), + ) + ) + self.pre_layers.append( + nn.Sequential( + ConvModule( + inner_channel*len(in_channels), + inner_channel, + kernel_size=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ), + ConvModule( + inner_channel, + inner_channel, + kernel_size=kernel_size, + padding=kernel_size // 2, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ), + ) + ) + input_channel = inner_channel + else: + input_channel = in_channels + + convs = [] + for idx, mid_channel in enumerate(mid_channels): + in_channel = input_channel if idx == 0 else mid_channels[idx-1] + convs += [ + ConvModule( + in_channel, + mid_channel, + kernel_size=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ), + ConvModule( + mid_channel, + mid_channel, + kernel_size=kernel_size, + padding=kernel_size // 2, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ), + nn.UpsamplingBilinear2d(scale_factor=2), + ] + self.convs = nn.Sequential(*convs) + if isinstance(loss_decode, dict): + self.loss_decode = MODELS.build(loss_decode) + self.conv_seg = nn.Conv2d(mid_channels[-1], num_classes, kernel_size=1) + + def _forward_feature(self, img_feat, inner_states): + if hasattr(self, 'pre_layers'): + inner_states = inner_states[-len(self.in_channels):] + inner_states = [einops.rearrange(x, 'b h w c -> b c h w') for x in inner_states] + inner_states = [layer(x) for layer, x in zip(self.pre_layers[:-1], inner_states)] + img_feat = self.pre_layers[-1](torch.cat(inner_states, dim=1)) + feats = self.convs(img_feat) + return feats + + def forward(self, img_feat, inner_states): + """Forward function.""" + output = self._forward_feature(img_feat, inner_states) + output = self.conv_seg(output) + return output + + def loss(self, img_feat, inner_states, batch_data_samples) -> dict: + """Forward function for training. + + Args: + inputs (Tuple[Tensor]): List of multi-level img features. + batch_data_samples (list[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `img_metas` or `gt_semantic_seg`. + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + seg_logits = self.forward(img_feat, inner_states) + losses = self.loss_by_feat(seg_logits, batch_data_samples) + return losses + + def _stack_batch_gt(self, batch_data_samples): + gt_semantic_segs = [ + data_sample.gt_sem_seg.data for data_sample in batch_data_samples + ] + return torch.stack(gt_semantic_segs, dim=0) + + def loss_by_feat(self, seg_logits, batch_data_samples) -> dict: + """Compute segmentation loss. 
+ + Args: + seg_logits (Tensor): The output from decode head forward function. + batch_data_samples (List[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `metainfo` and `gt_sem_seg`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + seg_label = self._stack_batch_gt(batch_data_samples) + losses = dict() + seg_logits = F.interpolate(seg_logits, seg_label.shape[-2:], mode='bilinear', align_corners=self.align_corners) + seg_label = seg_label.squeeze(1) + losses['loss_ce'] = self.loss_decode(seg_logits, seg_label) + return losses + + def predict(self, img_feat, inner_states): + seg_logits = self.forward(img_feat, inner_states) + return seg_logits diff --git a/mmpl/models/heads/semantic_seg_head.py b/mmpl/models/heads/semantic_seg_head.py new file mode 100644 index 0000000000000000000000000000000000000000..99acee3074b0475539635b6aab6d2505375bad59 --- /dev/null +++ b/mmpl/models/heads/semantic_seg_head.py @@ -0,0 +1,216 @@ +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.models.utils import multi_apply +from mmdet.utils import InstanceList, reduce_mean +from mmpl.registry import MODELS, TASK_UTILS +from mmengine.model import BaseModel +from einops import rearrange + +from mmpl.utils import ConfigType, OptConfigType + + +@MODELS.register_module() +class BinarySemanticSegHead(BaseModel): + def __init__( + self, + num_classes=1, + align_corners=False, + loss_mask: ConfigType = dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + init_cfg: Optional[dict] = None): + super(BinarySemanticSegHead, self).__init__(init_cfg=init_cfg) + self.num_classes = num_classes + self.align_corners = align_corners + + self.test_cfg = test_cfg + self.train_cfg = train_cfg + if train_cfg: + self.assigner = TASK_UTILS.build(self.train_cfg['assigner']) + self.sampler = TASK_UTILS.build( + self.train_cfg['sampler'], default_args=dict(context=self)) + self.num_points = self.train_cfg.get('num_points', 12544) + self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0) + self.importance_sample_ratio = self.train_cfg.get( + 'importance_sample_ratio', 0.75) + + self.loss_mask = MODELS.build(loss_mask) + if loss_dice is not None: + self.loss_dice = MODELS.build(loss_dice) + + def forward(self, *args, **kwargs): + pass + return + + def loss(self, + mask_preds: Tensor, + seg_labels: Tensor, + ): + bs = mask_preds.size(0) + + # dice loss + if hasattr(self, 'loss_dice'): + loss_dice = self.loss_dice(mask_preds, seg_labels, avg_factor=bs) + else: + loss_dice = torch.zeros([]).to(mask_preds.device) + + # mask loss + # FocalLoss support input of shape (n, num_class) + h, w = mask_preds.shape[-2:] + # shape (num_total_gts, h, w) -> (num_total_gts * h * w, 1) + mask_preds = mask_preds.reshape(-1, 1) + # shape (num_total_gts, h, w) -> (num_total_gts * h * w) + mask_targets = seg_labels.reshape(-1, 1) + # target is (1 - mask_targets) !!! 
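+ # Predictions and labels were flattened above to shape (num_pixels, 1), so the configured mask loss scores every pixel independently; avg_factor=h * w normalizes the summed loss by the number of pixels in one map.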
+ loss_mask = self.loss_mask(mask_preds, mask_targets, avg_factor=h * w) + + loss_dict = dict() + loss_dict['loss_mask'] = loss_mask + loss_dict['loss_dice'] = loss_dice + return loss_dict + + def get_targets( + self, + cls_scores_list: List[Tensor], + mask_preds_list: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + return_sampling_results: bool = False + ) -> Tuple[List[Union[Tensor, int]]]: + """Compute classification and mask targets for all images for a decoder + layer. + + Args: + cls_scores_list (list[Tensor]): Mask score logits from a single + decoder layer for all images. Each with shape (num_queries, + cls_out_channels). + mask_preds_list (list[Tensor]): Mask logits from a single decoder + layer for all images. Each with shape (num_queries, h, w). + batch_gt_instances (list[obj:`InstanceData`]): each contains + ``labels`` and ``masks``. + batch_img_metas (list[dict]): List of image meta information. + return_sampling_results (bool): Whether to return the sampling + results. Defaults to False. + + Returns: + tuple: a tuple containing the following targets. + + - labels_list (list[Tensor]): Labels of all images.\ + Each with shape (num_queries, ). + - label_weights_list (list[Tensor]): Label weights\ + of all images. Each with shape (num_queries, ). + - mask_targets_list (list[Tensor]): Mask targets of\ + all images. Each with shape (num_queries, h, w). + - mask_weights_list (list[Tensor]): Mask weights of\ + all images. Each with shape (num_queries, ). + - avg_factor (int): Average factor that is used to average\ + the loss. When using sampling method, avg_factor is + usually the sum of positive and negative priors. When + using `MaskPseudoSampler`, `avg_factor` is usually equal + to the number of positive priors. + + additional_returns: This function enables user-defined returns from + `self._get_targets_single`. These returns are currently refined + to properties at each feature map (i.e. having HxW dimension). + The results will be concatenated after the end. + """ + results = multi_apply(self._get_targets_single, cls_scores_list, + mask_preds_list, batch_gt_instances, + batch_img_metas) + (labels_list, label_weights_list, mask_targets_list, mask_weights_list, + pos_inds_list, neg_inds_list, sampling_results_list) = results[:7] + rest_results = list(results[7:]) + + avg_factor = sum( + [results.avg_factor for results in sampling_results_list]) + + res = (labels_list, label_weights_list, mask_targets_list, + mask_weights_list, avg_factor) + if return_sampling_results: + res = res + (sampling_results_list) + + return res + tuple(rest_results) + + def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor, + gt_instances: InstanceData, + img_meta: dict) -> Tuple[Tensor]: + """Compute classification and mask targets for one image. + + Args: + cls_score (Tensor): Mask score logits from a single decoder layer + for one image. Shape (num_queries, cls_out_channels). + mask_pred (Tensor): Mask logits for a single decoder layer for one + image. Shape (num_queries, h, w). + gt_instances (:obj:`InstanceData`): It contains ``labels`` and + ``masks``. + img_meta (dict): Image informtation. + + Returns: + tuple: a tuple containing the following for one image. + + - labels (Tensor): Labels of each image. + shape (num_queries, ). + - label_weights (Tensor): Label weights of each image. + shape (num_queries, ). + - mask_targets (Tensor): Mask targets of each image. + shape (num_queries, h, w). + - mask_weights (Tensor): Mask weights of each image. 
+ shape (num_queries, ). + - pos_inds (Tensor): Sampled positive indices for each image. + - neg_inds (Tensor): Sampled negative indices for each image. + - sampling_result (:obj:`SamplingResult`): Sampling results. + """ + gt_masks = gt_instances.masks + gt_labels = gt_instances.labels + + target_shape = mask_pred.shape[-2:] + if gt_masks.shape[0] > 0: + gt_masks_downsampled = F.interpolate( + gt_masks.unsqueeze(1).float(), target_shape, + mode='nearest').squeeze(1).long() + else: + gt_masks_downsampled = gt_masks + + pred_instances = InstanceData(scores=cls_score, masks=mask_pred) + downsampled_gt_instances = InstanceData( + labels=gt_labels, masks=gt_masks_downsampled) + # assign and sample (gt indices in assign_result are 1-based) + assign_result = self.assigner.assign( + pred_instances=pred_instances, + gt_instances=downsampled_gt_instances, + img_meta=img_meta) + sampling_result = self.sampler.sample( + assign_result=assign_result, + pred_instances=pred_instances, + gt_instances=gt_instances) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + + # label target + # class 0 is the background class + num_queries = pred_instances.scores.shape[0] + labels = gt_labels.new_full((num_queries, ), + 0, + dtype=torch.long) + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] + label_weights = gt_labels.new_ones(num_queries) + + # mask target + mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds] + mask_weights = mask_pred.new_zeros((num_queries, )) + mask_weights[pos_inds] = 1.0 + + return (labels, label_weights, mask_targets, mask_weights, pos_inds, + neg_inds, sampling_result) + diff --git a/mmpl/models/layers/__init__.py b/mmpl/models/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mmpl/models/layers/transformer_layers.py b/mmpl/models/layers/transformer_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..95e3ab189657a83facdf71bf08e2b4af2e2d371d --- /dev/null +++ b/mmpl/models/layers/transformer_layers.py @@ -0,0 +1,122 @@ +import math +import warnings + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention +from mmengine.logging import print_log +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import (constant_init, kaiming_init, + trunc_normal_) +from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.modules.utils import _pair as to_2tuple + +from mmpl.registry import MODELS + + +@MODELS.register_module() +class TransformerEncoderLayer(BaseModule): + """Implements one encoder layer in Vision Transformer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default: 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default: 0.0. + drop_path_rate (float): stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): enable bias for qkv if True. Default: True + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN').
+ batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Default: True. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + batch_first=True, + attn_cfg=dict(), + ffn_cfg=dict(), + with_cp=False): + super().__init__() + + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, embed_dims, postfix=1) + self.add_module(self.norm1_name, norm1) + + attn_cfg.update( + dict( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + batch_first=batch_first, + bias=qkv_bias)) + + self.build_attn(attn_cfg) + + self.norm2_name, norm2 = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + self.add_module(self.norm2_name, norm2) + + ffn_cfg.update( + dict( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate) + if drop_path_rate > 0 else None, + act_cfg=act_cfg)) + self.build_ffn(ffn_cfg) + self.with_cp = with_cp + + def build_attn(self, attn_cfg): + self.attn = MultiheadAttention(**attn_cfg) + + def build_ffn(self, ffn_cfg): + self.ffn = FFN(**ffn_cfg) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + @property + def norm2(self): + return getattr(self, self.norm2_name) + + def forward(self, x): + + def _inner_forward(x): + x = self.attn(self.norm1(x), identity=x) + x = self.ffn(self.norm2(x), identity=x) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + diff --git a/mmpl/models/losses/__init__.py b/mmpl/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5b4a4d2d671022a6f5fe26caef7192f6a5b6c95c --- /dev/null +++ b/mmpl/models/losses/__init__.py @@ -0,0 +1,9 @@ +from .utils import (convert_to_one_hot, reduce_loss, weight_reduce_loss, + weighted_loss) + +# __all__ = [ +# 'cross_entropy', +# 'binary_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', +# 'weight_reduce_loss', 'LabelSmoothLoss', 'weighted_loss', 'FocalLoss', +# 'sigmoid_focal_loss', 'convert_to_one_hot', 'SmoothL1Loss' +# ] diff --git a/mmpl/models/losses/__pycache__/__init__.cpython-310.pyc b/mmpl/models/losses/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f4281ed29fdd657bcdef4d2a978ecc44c261437 Binary files /dev/null and b/mmpl/models/losses/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmpl/models/losses/__pycache__/utils.cpython-310.pyc b/mmpl/models/losses/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec82af8755bcf5bea63d67918c6a9456f58b657e Binary files /dev/null and b/mmpl/models/losses/__pycache__/utils.cpython-310.pyc differ diff --git a/mmpl/models/losses/utils.py b/mmpl/models/losses/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a65b68a6590aa3fe10a023022c9c9c9bce51f935 --- /dev/null +++ b/mmpl/models/losses/utils.py @@ -0,0 +1,119 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
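+# Shared helpers for the loss modules: loss reduction (`reduce_loss`), element-wise weighting with an optional averaging factor (`weight_reduce_loss`), the `weighted_loss` decorator and one-hot target conversion (`convert_to_one_hot`).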
+import functools + +import torch +import torch.nn.functional as F + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are "none", "mean" and "sum". + + Return: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + elif reduction_enum == 2: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. + reduction (str): Same as built-in losses of PyTorch. + avg_factor (float): Average factor when computing the mean of losses. + + Returns: + Tensor: Processed loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + loss = reduce_loss(loss, reduction) + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == 'mean': + loss = loss.sum() / avg_factor + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != 'none': + raise ValueError('avg_factor can not be used with reduction="sum"') + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + ``loss_func(pred, target, **kwargs)``. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like ``loss_func(pred, target, weight=None, reduction='mean', + avg_factor=None, **kwargs)``. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, avg_factor=2) + tensor(1.5000) + """ + + @functools.wraps(loss_func) + def wrapper(pred, + target, + weight=None, + reduction='mean', + avg_factor=None, + **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + return wrapper + + +def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor: + """This function converts target class indices to one-hot vectors, given + the number of classes. + + Args: + targets (Tensor): The ground truth label of the prediction + with shape (N, 1) + classes (int): the number of classes. + + Returns: + Tensor: Processed loss values. 
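+ + Example (illustrative values): + >>> targets = torch.tensor([[1], [0], [2]]) + >>> convert_to_one_hot(targets, 3) + tensor([[0, 1, 0], + [1, 0, 0], + [0, 0, 1]])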
+ """ + assert (torch.max(targets).item() < + classes), 'Class Index must be less than number of classes' + one_hot_targets = F.one_hot( + targets.long().squeeze(-1), num_classes=classes) + return one_hot_targets diff --git a/mmpl/models/necks/__init__.py b/mmpl/models/necks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c13e7960a6d2fc1f09a16336c7a95550a4d68e6d --- /dev/null +++ b/mmpl/models/necks/__init__.py @@ -0,0 +1,11 @@ +from .transformer_neck import TransformerEncoderNeck +from .transformer_edecoder_neck import TransformerEDecoderNeck +from .linear_proj import LinearProj +from .hf_gpt_transformer_decoder import HFGPTTransformerDecoderNeck +from .sirens import Sirens, ModulatedSirens +from .sam_prompt_generator import SAMTransformerPromptGenNeck, SAMPromptConvNeck, SAMTransformerEDPromptGenNeck, SAMAggregatorNeck + +__all__ = [ + 'TransformerEncoderNeck', 'TransformerEDecoderNeck', 'LinearProj', + 'HFGPTTransformerDecoderNeck', 'Sirens', 'ModulatedSirens' +] diff --git a/mmpl/models/necks/__pycache__/__init__.cpython-310.pyc b/mmpl/models/necks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b8fe5d3f3b355b00d84ab30444637b60c69a45e Binary files /dev/null and b/mmpl/models/necks/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmpl/models/necks/__pycache__/hf_gpt_transformer_decoder.cpython-310.pyc b/mmpl/models/necks/__pycache__/hf_gpt_transformer_decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..254cdb6b0f2197041768fb7043a4ef87de82da79 Binary files /dev/null and b/mmpl/models/necks/__pycache__/hf_gpt_transformer_decoder.cpython-310.pyc differ diff --git a/mmpl/models/necks/__pycache__/linear_proj.cpython-310.pyc b/mmpl/models/necks/__pycache__/linear_proj.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d415595f8afa80a7a40c733648d91993e09cb9f Binary files /dev/null and b/mmpl/models/necks/__pycache__/linear_proj.cpython-310.pyc differ diff --git a/mmpl/models/necks/__pycache__/sam_prompt_generator.cpython-310.pyc b/mmpl/models/necks/__pycache__/sam_prompt_generator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6081f5fa7be74c10a2d941e0034e7f5cb3dcd6ad Binary files /dev/null and b/mmpl/models/necks/__pycache__/sam_prompt_generator.cpython-310.pyc differ diff --git a/mmpl/models/necks/__pycache__/sirens.cpython-310.pyc b/mmpl/models/necks/__pycache__/sirens.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46c639edef92fbe1dd43f9dda9bf7869f3714e8d Binary files /dev/null and b/mmpl/models/necks/__pycache__/sirens.cpython-310.pyc differ diff --git a/mmpl/models/necks/__pycache__/transformer_edecoder_neck.cpython-310.pyc b/mmpl/models/necks/__pycache__/transformer_edecoder_neck.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ac7a34b5711a9dededab4fd5823cfea3797ad4d Binary files /dev/null and b/mmpl/models/necks/__pycache__/transformer_edecoder_neck.cpython-310.pyc differ diff --git a/mmpl/models/necks/__pycache__/transformer_neck.cpython-310.pyc b/mmpl/models/necks/__pycache__/transformer_neck.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3e8affb9045c9b5c85aca857f9408b2b1367c9d Binary files /dev/null and b/mmpl/models/necks/__pycache__/transformer_neck.cpython-310.pyc differ diff --git a/mmpl/models/necks/hf_gpt_transformer_decoder.py 
b/mmpl/models/necks/hf_gpt_transformer_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..935785664fbbedb0f28a173d7e88958787261d95 --- /dev/null +++ b/mmpl/models/necks/hf_gpt_transformer_decoder.py @@ -0,0 +1,33 @@ +import torch.nn as nn +from mmengine.model import BaseModule +from mmengine.model.weight_init import constant_init +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm +from mmpl.registry import MODELS +from mmengine.model import BaseModule +from transformers import GPT2Model, GPT2Config + + +@MODELS.register_module() +class HFGPTTransformerDecoderNeck(BaseModule): + def __init__( + self, + model_name='gpt2', + from_pretrained=True, + update_kwargs=dict( + max_position_embeddings=512, + hidden_size=512, + ) + ): + super(HFGPTTransformerDecoderNeck, self).__init__() + self.model_name = model_name + if from_pretrained: + self.gpt_model = GPT2Model.from_pretrained(model_name) + else: + config = GPT2Config.from_pretrained(model_name) + config.update(update_kwargs) + self.gpt_model = GPT2Model(config=config) + # self.wte = nn.Embedding(0, 512) + + def forward(self, *args, **kwargs): + out_puts = self.gpt_model(*args, **kwargs) + return out_puts diff --git a/mmpl/models/necks/linear_proj.py b/mmpl/models/necks/linear_proj.py new file mode 100644 index 0000000000000000000000000000000000000000..76a7243870dfde85d1a579548da13e841ecccfb0 --- /dev/null +++ b/mmpl/models/necks/linear_proj.py @@ -0,0 +1,29 @@ +import torch +import torch.nn as nn + +from mmpl.registry import MODELS + + +@MODELS.register_module() +class LinearProj(nn.Module): + def __init__(self, in_channels, out_channels, base_channels=None, num_inner_layers=1): + super(LinearProj, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.num_inner_layers = num_inner_layers + if base_channels is None: + base_channels = out_channels + self.base_channels = base_channels + + layers = [nn.Linear(self.in_channels, self.base_channels), nn.ReLU(inplace=True)] + + for i in range(self.num_inner_layers): + layers.append(nn.Linear(self.base_channels, self.base_channels)) + layers.append(nn.ReLU(inplace=True)) + layers.append(nn.Linear(self.base_channels, self.out_channels)) + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + x = self.layers(x) + return x diff --git a/mmpl/models/necks/sam_prompt_generator.py b/mmpl/models/necks/sam_prompt_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..1a78f0c9926b3a0c60cb9285b9c27bf4c6346962 --- /dev/null +++ b/mmpl/models/necks/sam_prompt_generator.py @@ -0,0 +1,971 @@ +import copy +import math +from typing import Type, Tuple + +import einops +import torch +import torch.nn as nn +from einops import rearrange +from mmcv.cnn import ConvModule +from mmcv.cnn.bricks.transformer import build_transformer_layer +from torch import Tensor + +from mmdet.models import SinePositionalEncoding +from mmpl.registry import MODELS +import torch.nn.functional as F + + +@MODELS.register_module() +class SAMTransformerPromptGenNeck(nn.Module): + def __init__( + self, + prompt_shape=(100, 6), + in_channels=[1280]*16, + out_channels=256, + positional_encoding=dict(num_feats=128, normalize=True), + n_classes=2, + kernel_size=3, + stride=1, + norm_cfg=None, + act_cfg=dict(type='ReLU') + ): + super(SAMTransformerPromptGenNeck, self).__init__() + self.in_channels = in_channels + self.kernel_size = kernel_size + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.out_put_channels = out_channels + 
self.n_classes = n_classes + self.stride = stride + + self.prompt_shape = prompt_shape + self.num_queries = prompt_shape[0] + self.per_query_point = prompt_shape[1] + + if isinstance(in_channels, list): + self.pre_layers = nn.ModuleList() + inner_channel = 32 + for idx, channel in enumerate(in_channels): + self.pre_layers.append( + nn.Sequential( + ConvModule( + channel, + inner_channel, + kernel_size=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ), + ConvModule( + inner_channel, + inner_channel*2, + kernel_size=kernel_size, + padding=kernel_size // 2, + stride=self.stride, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ), + ConvModule( + inner_channel*2, + inner_channel, + kernel_size=kernel_size, + padding=kernel_size // 2, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ), + ) + ) + self.pre_layers.append( + nn.Sequential( + ConvModule( + inner_channel * len(in_channels), + out_channels, + kernel_size=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ), + ConvModule( + out_channels, + out_channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ), + ) + ) + + self.generator_pe = SinePositionalEncoding(**positional_encoding) + self.transformer = self.build_transformer() + self.query_feat = nn.Embedding(self.num_queries, out_channels) + self.query_emb = nn.Embedding(self.num_queries, out_channels) + + self.output_upscaling = nn.Sequential( + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.GELU(), + nn.UpsamplingBilinear2d(scale_factor=2), + nn.Conv2d(out_channels, out_channels // 4, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels // 4), + nn.GELU(), + nn.UpsamplingBilinear2d(scale_factor=2), + nn.Conv2d(out_channels // 4, out_channels // 8, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels // 8), + nn.GELU(), + nn.UpsamplingBilinear2d(scale_factor=2), + nn.Conv2d(out_channels // 8, out_channels // 8, kernel_size=3, padding=1), + ) + + self.cls_head = nn.Sequential( + nn.Linear(out_channels, out_channels//2), + nn.ReLU(), + nn.Linear(out_channels//2, n_classes) + ) + + # self.point_emb = nn.Sequential( + # nn.Linear(out_channels, out_channels), + # nn.ReLU(), + # nn.Linear(out_channels, out_channels), + # nn.ReLU(), + # nn.Linear(out_channels, self.per_query_point * out_channels) + # ) + self.output_hypernetworks_mlps = MLP(out_channels, out_channels, out_channels // 8, 3) + + + def build_transformer( + self, num_encoder_layers=2, num_decoder_layers=3, embed_dims=256, num_heads=8, + mlp_ratio=2, dropout_rate=0.0, act_cfg=dict(type="gelu")): + """Build transformer decoder.""" + # transformer = nn.Transformer( + # d_model=embed_dims, nhead=num_heads, num_encoder_layers=num_encoder_layers, + # num_decoder_layers=num_decoder_layers, dim_feedforward=mlp_ratio * embed_dims, + # dropout=dropout_rate, activation=act_cfg['type'], batch_first=True, norm_first=True, + # ) + transformer = Transformer(depth=2) + return transformer + + def init_weights(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, inputs, prompt_encoder, mask_decoder): + + img_embs, inner_states = inputs + if hasattr(self, 'pre_layers'): + inner_states = inner_states[-len(self.in_channels):] + inner_states = [einops.rearrange(x, 'b h w c -> b c h w') for x in inner_states] + inner_states = [layer(x) for layer, x in zip(self.pre_layers[:-1], inner_states)] + img_feats = self.pre_layers[-1](torch.cat(inner_states, dim=1)) + bs, c, h, 
w = img_feats.shape + mask_pe = torch.zeros((bs, h, w), device=img_feats.device) + img_feats_pe = self.generator_pe(mask_pe) + query_feat = self.query_feat.weight.unsqueeze(0).expand(bs, -1, -1) # Bx256x256 + query_emb = self.query_emb.weight.unsqueeze(0).expand(bs, -1, -1) + img_feats, query_feats = self.transformer( + image_embedding=img_feats, + image_pe=img_feats_pe, + point_embedding=query_feat, + point_pe=query_emb) + cls_logits = self.cls_head(query_feats) + # point_embs = self.point_emb(query_feats) + # point_embs = rearrange(point_embs, 'b n (t c) -> b n t c', t=self.per_query_point) # Bx100x6x256 + + src = img_feats.transpose(1, 2).view(bs, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in = self.output_hypernetworks_mlps(query_feats) + b, c, h, w = upscaled_embedding.shape + l1_masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # dense_masks = einops.rearrange(l1_masks, 'b (n t) h w -> (b n) t h w', t=1) + # sparse, dense = prompt_encoder(points=None, boxes=None, masks=dense_masks) + # dense = einops.rearrange(dense, '(b n) t h w -> b n t h w', n=self.num_queries) + + # l2_masks = [] + # iou_preds = [] + # for curr_embedding, sparse_embeddings, dense_embeddings in zip(img_embs, point_embs, dense): + # low_res_masks, iou_predictions = mask_decoder( + # image_embeddings=curr_embedding.unsqueeze(0), + # image_pe=prompt_encoder.get_dense_pe(), + # sparse_prompt_embeddings=sparse_embeddings, + # dense_prompt_embeddings=dense_embeddings + # ) + # l2_masks.append(low_res_masks[:, 0]) + # iou_preds.append(iou_predictions[:, 0]) + # l2_masks = torch.stack(l2_masks, dim=0) + # iou_preds = torch.stack(iou_preds, dim=0) + + l2_masks = None + iou_preds = None + + return cls_logits, l1_masks, l2_masks, iou_preds + + +@MODELS.register_module() +class SAMPromptConvNeck(nn.Module): + def __init__( + self, + prompt_shape=(100, 5), + img_feat_channels=1280, + out_put_channels=256, + num_img_feat_level=16, + n_cls=2, + ): + super(SAMPromptConvNeck, self).__init__() + self.prompt_shape = prompt_shape + self.num_queries = prompt_shape[0] + self.per_query_point = prompt_shape[1] + self.point_size = int(math.sqrt(prompt_shape[0])) + + self.img_feat_channels = img_feat_channels + self.out_put_channels = out_put_channels + self.num_img_feat_level = num_img_feat_level + self.n_cls = n_cls + + # decoder_embed_dims = img_feat_channels // 32 + decoder_embed_dims = 32 + self.decoder_input_projs = nn.ModuleList() + # from low resolution to high resolution + for _ in range(num_img_feat_level): + self.decoder_input_projs.append( + nn.Sequential( + nn.Conv2d(img_feat_channels, decoder_embed_dims, kernel_size=1), + # nn.BatchNorm2d(decoder_embed_dims), + nn.ReLU(), + nn.Conv2d(decoder_embed_dims, decoder_embed_dims, kernel_size=3, padding=1), + # nn.BatchNorm2d(decoder_embed_dims), + nn.ReLU(), + )) + self.level_embed = nn.Embedding(self.num_img_feat_level, decoder_embed_dims) + self.gather_img_feats = nn.Sequential( + nn.Conv2d(num_img_feat_level * decoder_embed_dims, out_put_channels, kernel_size=1), + # nn.BatchNorm2d(out_put_channels), + nn.ReLU(), + nn.Conv2d(out_put_channels, out_put_channels, 3, stride=2, padding=1), + nn.ReLU(), + nn.Conv2d(out_put_channels, out_put_channels*2, 3, stride=2, padding=1), + nn.ReLU(), + nn.Conv2d(out_put_channels * 2, out_put_channels * 2, 3, padding=1), + ) + + self.img_feats_pe = nn.Parameter(torch.zeros(1, out_put_channels*2, self.point_size, self.point_size)) + + self.cls_head = nn.Sequential( + nn.Conv2d(out_put_channels * 
2, out_put_channels, 3, padding=1), + nn.ReLU(), + nn.Conv2d(out_put_channels, n_cls, 1) + ) + + self.point_emb = nn.Sequential( + nn.Conv2d(out_put_channels * 2, out_put_channels, 3, padding=1), + nn.ReLU(), + nn.Conv2d(out_put_channels, out_put_channels, 3, padding=1), + nn.ReLU(), + nn.Conv2d(out_put_channels, self.per_query_point * out_put_channels, 1) + ) + + def forward(self, inputs): + inner_states = [x.permute(0, 3, 1, 2) for x in inputs] # from low2high, all 4 layers + bs = inner_states[0].shape[0] + # inputs: list([B, C, H, W]) + num_layers = len(inputs) + # import ipdb; ipdb.set_trace() + # select the feature maps from the selected layers + layer_start_id = num_layers - self.num_img_feat_level + decoder_inputs = [] + for i in range(self.num_img_feat_level): + decoder_input = self.decoder_input_projs[i](inner_states[i + layer_start_id]) # Bx256x64x64 + level_embed = self.level_embed.weight[i].unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand(bs, -1, -1, -1) + decoder_input = decoder_input + level_embed + decoder_inputs.append(decoder_input) + decoder_inputs = torch.cat(decoder_inputs, dim=1) # Bx256x64x64 + decoder_inputs = self.gather_img_feats(decoder_inputs) + # import pdb; + # pdb.set_trace() + decoder_inputs = torch.nn.functional.interpolate(decoder_inputs, size=(self.point_size, self.point_size), mode='bilinear', align_corners=True) + img_pe = self.img_feats_pe.expand(bs, -1, -1, -1) # Bx256x64x64 + decoder_inputs = decoder_inputs + img_pe + + cls_logits = self.cls_head(decoder_inputs) # b c h w + cls_logits = rearrange(cls_logits, 'b c h w -> b (h w) c') + point_embs = self.point_emb(decoder_inputs) # b c h w + point_embs = rearrange(point_embs, 'b (t c) h w -> b (h w) t c', t=self.per_query_point) # Bx100x6x256 + + return point_embs, cls_logits + + + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +class Transformer(nn.Module): + def __init__( + self, + depth: int = 2, + embedding_dim: int = 256, + num_heads: int = 8, + mlp_dim: int = 1024, + activation: Type[nn.Module] = nn.GELU, + attention_downsample_rate: int = 2, + ) -> None: + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + AttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + point_pe: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. 
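The low-resolution masks in SAMTransformerPromptGenNeck.forward above come from a hypernetwork-style product: each query embedding is projected to the channel width of the upscaled feature map and then matrix-multiplied against the flattened spatial features, giving one mask per query. A self-contained sketch of that single step with dummy tensors (shapes are illustrative):

import torch

b, n_queries, c, h, w = 2, 100, 32, 64, 64
hyper_in = torch.randn(b, n_queries, c)   # per-query weights from the hypernetwork MLP
upscaled = torch.randn(b, c, h, w)        # upscaled image features

# (B, N, C) @ (B, C, H*W) -> (B, N, H*W), reshaped to one mask per query
masks = (hyper_in @ upscaled.view(b, c, h * w)).view(b, n_queries, h, w)
print(masks.shape)  # torch.Size([2, 100, 64, 64])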
+ + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=image_embedding, + query_pe=image_pe, + keys=point_embedding, + key_pe=point_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + image_pe + k = keys + point_embedding + + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class AttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. 
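AttentionBlock above follows the SAM-style two-way block: self-attention over the query stream, cross-attention from queries to keys, an MLP, then cross-attention from keys back to queries, with the positional encodings re-added before every attention call. A quick shape check with random tensors, assuming the classes defined in this file are in scope (all dimensions are arbitrary):

import torch

block = AttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=1024)

queries = torch.randn(2, 4096, 256)   # e.g. flattened image tokens
keys = torch.randn(2, 100, 256)       # e.g. learned prompt/query tokens
q_pe = torch.randn_like(queries)
k_pe = torch.randn_like(keys)

out_q, out_k = block(queries=queries, keys=keys, query_pe=q_pe, key_pe=k_pe)
print(out_q.shape, out_k.shape)  # torch.Size([2, 4096, 256]) torch.Size([2, 100, 256])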
+ """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + + +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x + + +@MODELS.register_module() +class SAMTransformerEDPromptGenNeck(nn.Module): + def __init__( + self, + prompt_shape=(100, 5), + in_channels=[1280]*16, + inner_channels=128, + selected_channels: list=None, + num_encoders=2, + num_decoders=2, + out_channels=256, + positional_encoding=dict(num_feats=128, normalize=True), + kernel_size=3, + stride=1, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='ReLU', inplace=True), + init_cfg=None, + **kwargs + ): + super().__init__() + self.in_channels = in_channels + self.kernel_size = kernel_size + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.out_channels = out_channels + self.stride = stride + self.selected_channels = selected_channels + + self.prompt_shape = prompt_shape + self.num_queries = prompt_shape[0] + self.per_query_point = prompt_shape[1] + + self.down_sample_layers = 
nn.ModuleList() + for idx in self.selected_channels: + self.down_sample_layers.append( + nn.Sequential( + ConvModule( + in_channels[idx], + inner_channels, + kernel_size=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ), + ConvModule( + inner_channels, + inner_channels, + kernel_size=3, + padding=1, + stride=2, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ), + ) + ) + self.fusion_layers = nn.ModuleList() + for idx in self.selected_channels: + self.fusion_layers.append( + ConvModule( + inner_channels, + inner_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ) + ) + self.up_layers = nn.ModuleList() + self.up_layers.append( + nn.Sequential( + ConvModule( + inner_channels, + inner_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ), + ConvModule( + inner_channels, + inner_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ) + ) + ) + self.up_layers.append( + ConvModule( + inner_channels, + out_channels, + kernel_size=1, + norm_cfg=self.norm_cfg, + act_cfg=None + ) + ) + + self.generator_pe = SinePositionalEncoding(**positional_encoding) + + self.en_layers = nn.ModuleList() + self.de_layers = nn.ModuleList() + self.build_transformer(num_encoders=num_encoders, num_decoders=num_decoders) + + self.embed_dims = self.en_layers[0].embed_dims + self.pre_norm = self.en_layers[0].pre_norm + + self.query_feat = nn.Embedding(self.num_queries, out_channels) + self.query_embed = nn.Embedding(self.num_queries, out_channels) + + # self.output_upscaling = nn.Sequential( + # nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + # nn.BatchNorm2d(out_channels), + # nn.GELU(), + # nn.UpsamplingBilinear2d(scale_factor=2), + # nn.Conv2d(out_channels, out_channels // 4, kernel_size=3, padding=1), + # nn.BatchNorm2d(out_channels // 4), + # nn.GELU(), + # nn.UpsamplingBilinear2d(scale_factor=2), + # nn.Conv2d(out_channels // 4, out_channels // 8, kernel_size=3, padding=1), + # nn.BatchNorm2d(out_channels // 8), + # nn.GELU(), + # nn.UpsamplingBilinear2d(scale_factor=2), + # nn.Conv2d(out_channels // 8, out_channels // 8, kernel_size=3, padding=1), + # ) + # self.output_hypernetworks_mlps = MLP(out_channels, out_channels, out_channels // 8, 3) + + self.init_weights() + + def build_transformer(self, num_encoders=2, num_decoders=2, embed_dims=256, num_heads=8, mlp_ratio=4): + transformer_encoder_layer = dict( + type='BaseTransformerLayer', + attn_cfgs=[ + dict( + type='MultiheadAttention', + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=0.1, + proj_drop=0.1, + dropout_layer=dict(type='Dropout', drop_prob=0.1) + ), + ], + ffn_cfgs=dict( + type='FFN', + embed_dims=embed_dims, + feedforward_channels=embed_dims * mlp_ratio, + num_fcs=2, + act_cfg=dict(type='GELU'), + ffn_drop=0.1, + add_identity=True), + operation_order=('norm', 'self_attn', 'norm', 'ffn'), + norm_cfg=dict(type='LN'), + batch_first=True + ) + transformer_decoder_layer = dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=0.1, + proj_drop=0.1, + dropout_layer=dict(type='Dropout', drop_prob=0.1) + ), + ffn_cfgs=dict( + type='FFN', + embed_dims=embed_dims, + feedforward_channels=embed_dims * mlp_ratio, + num_fcs=2, + act_cfg=dict(type='GELU'), + ffn_drop=0.1, + add_identity=True), + operation_order=('norm', 'self_attn', 'norm', 'cross_attn', 'norm', 'ffn'), + norm_cfg=dict(type='LN'), + batch_first=True + ) + + 
transformer_en_layers = [ + copy.deepcopy(transformer_encoder_layer) for _ in range(num_encoders) + ] + transformer_de_layers = [ + copy.deepcopy(transformer_decoder_layer) for _ in range(num_decoders) + ] + for i in range(num_encoders): + self.en_layers.append(build_transformer_layer(transformer_en_layers[i])) + for i in range(num_decoders): + self.de_layers.append(build_transformer_layer(transformer_de_layers[i])) + + def init_weights(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, inputs): + _, inner_states = inputs + inner_states = [einops.rearrange(inner_states[idx], 'b h w c -> b c h w') for idx in self.selected_channels] + inner_states = [layer(x) for layer, x in zip(self.down_sample_layers, inner_states)] + + x = None + for inner_state, layer in zip(inner_states, self.fusion_layers): + if x is not None: + inner_state = x + inner_state + x = inner_state + layer(inner_state) + x = self.up_layers[0](x) + x + img_feats = self.up_layers[1](x) + + bs, c, h, w = img_feats.shape + + mask_pe = torch.zeros((bs, h, w), device=img_feats.device, dtype=torch.bool) + img_feats_pe = self.generator_pe(mask_pe) + + query_feat = self.query_feat.weight.unsqueeze(0).repeat( + (bs, 1, 1)) + query_embed = self.query_embed.weight.unsqueeze(0).repeat( + (bs, 1, 1)) + + encoder_inputs = rearrange(img_feats, 'b c h w -> b (h w) c') + img_feats_pe = img_feats_pe.flatten(2).permute(0, 2, 1) + + # shape (batch_size, num_total_queries, c) + memory = encoder_inputs + for layer in self.en_layers: + memory = layer( + query=memory, + query_pos=img_feats_pe + ) + # (batch_size, num_total_queries, c) + + query_feat_list = [] + for layer in self.de_layers: + query_feat = layer( + query=query_feat, + key=memory, + value=memory, + query_pos=query_embed, + key_pos=img_feats_pe + ) + query_feat_list.append(query_feat) + + img_feat = rearrange(memory, 'b (h w) c -> b c h w', h=h, w=w) + return query_feat, query_feat_list, img_feat + + +@MODELS.register_module() +class SAMAggregatorNeck(nn.Module): + def __init__( + self, + in_channels=[1280]*16, + inner_channels=128, + selected_channels: list=None, + out_channels=256, + kernel_size=3, + stride=1, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='ReLU', inplace=True), + up_sample_scale=4, + init_cfg=None, + **kwargs + ): + super().__init__() + self.in_channels = in_channels + self.kernel_size = kernel_size + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.out_channels = out_channels + self.stride = stride + self.selected_channels = selected_channels + self.up_sample_scale = up_sample_scale + + self.down_sample_layers = nn.ModuleList() + for idx in self.selected_channels: + self.down_sample_layers.append( + nn.Sequential( + ConvModule( + in_channels[idx], + inner_channels, + kernel_size=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ), + ConvModule( + inner_channels, + inner_channels, + kernel_size=3, + padding=1, + stride=2, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ), + ) + ) + self.fusion_layers = nn.ModuleList() + for idx in self.selected_channels: + self.fusion_layers.append( + ConvModule( + inner_channels, + inner_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ) + ) + self.up_layers = nn.ModuleList() + self.up_layers.append( + nn.Sequential( + ConvModule( + inner_channels, + inner_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ), + ConvModule( + inner_channels, + inner_channels, + 
kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ) + ) + ) + self.up_layers.append( + ConvModule( + inner_channels, + out_channels, + kernel_size=1, + norm_cfg=self.norm_cfg, + act_cfg=None + ) + ) + + self.up_sample_layers = nn.ModuleList() + assert up_sample_scale == 4 + self.up_sample_layers.append( + nn.Sequential( + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), + ConvModule( + out_channels, + out_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ), + ConvModule( + out_channels, + out_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ) + ) + ) + + self.up_sample_layers.append( + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + ) + + self.up_sample_layers.append( + nn.Sequential( + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), + ConvModule( + out_channels, + out_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ), + ConvModule( + out_channels, + out_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg + ) + ) + ) + + self.up_sample_layers.append( + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + ) + + def forward(self, inputs): + _, inner_states = inputs + inner_states = [einops.rearrange(inner_states[idx], 'b h w c -> b c h w') for idx in self.selected_channels] + inner_states = [layer(x) for layer, x in zip(self.down_sample_layers, inner_states)] + + x = None + for inner_state, layer in zip(inner_states, self.fusion_layers): + if x is not None: + inner_state = x + inner_state + x = inner_state + layer(inner_state) + x = self.up_layers[0](x) + x + img_feats_0 = self.up_layers[1](x) + + img_feats_1 = self.up_sample_layers[0](img_feats_0) + self.up_sample_layers[1](img_feats_0) + + img_feats_2 = self.up_sample_layers[2](img_feats_1) + self.up_sample_layers[3](img_feats_1) + + return img_feats_2, img_feats_1, img_feats_0 \ No newline at end of file diff --git a/mmpl/models/necks/sirens.py b/mmpl/models/necks/sirens.py new file mode 100644 index 0000000000000000000000000000000000000000..766433faa6a2a07cbd93b506163b80da20d54008 --- /dev/null +++ b/mmpl/models/necks/sirens.py @@ -0,0 +1,98 @@ +import torch +import torch.nn as nn + +from mmpl.registry import MODELS +from mmengine.model import BaseModule + + +@MODELS.register_module() +class Sirens(BaseModule): + def __init__(self, + in_channels, + out_channels=3, + base_channels=256, + num_inner_layers=2, + is_residual=True + ): + super(Sirens, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.base_channels = base_channels + self.num_inner_layers = num_inner_layers + self.is_residual = is_residual + + self.first_coord = nn.Linear(in_channels, base_channels) + self.inner_coords = nn.ModuleList(nn.Linear(base_channels, base_channels) for _ in range(self.num_inner_layers)) + self.last_coord = nn.Linear(base_channels, out_channels) + + def forward(self, x): + x = self.first_coord(x) + x = torch.sin(x) + for idx in range(self.num_inner_layers): + residual = x + x = self.inner_coords[idx](x) + if self.is_residual: + x = x + residual + x = torch.sin(x) + x = self.last_coord(x) + return x + + +@MODELS.register_module() +class ModulatedSirens(BaseModule): + def __init__(self, + num_inner_layers, + in_dim, + modulation_dim, + out_dim=3, + base_channels=256, + is_residual=True + ): + super(ModulatedSirens, self).__init__() + self.in_dim = in_dim + 
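SAMAggregatorNeck.forward above fuses the selected ViT stages with a running residual (each stage is added to the accumulated feature before its fusion conv) and then builds a three-level pyramid by upsampling the fused map. A toy sketch of just the fusion recurrence on dummy tensors; the real module uses ConvModule blocks with BN and ReLU, plain convolutions are used here only for brevity:

import torch
import torch.nn as nn

inner = 128
fusion_layers = nn.ModuleList([nn.Conv2d(inner, inner, 3, padding=1) for _ in range(4)])
inner_states = [torch.randn(2, inner, 32, 32) for _ in range(4)]  # downsampled ViT stages

x = None
for state, layer in zip(inner_states, fusion_layers):
    if x is not None:
        state = x + state      # carry the accumulated fusion into the next stage
    x = state + layer(state)   # residual fusion conv
print(x.shape)                  # torch.Size([2, 128, 32, 32])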
self.num_inner_layers = num_inner_layers + self.is_residual = is_residual + + self.first_mod = nn.Sequential( + nn.Conv2d(modulation_dim, base_channels, 1), + nn.ReLU() + ) + self.first_coord = nn.Conv2d(in_dim, base_channels, 1) + + self.inner_mods = nn.ModuleList() + self.inner_coords = nn.ModuleList() + for _ in range(self.num_inner_layers): + self.inner_mods.append( + nn.Sequential( + nn.Conv2d(modulation_dim+base_channels+base_channels, base_channels, 1), + nn.ReLU() + ) + ) + self.inner_coords.append( + nn.Conv2d(base_channels, base_channels, 1) + ) + self.last_coord = nn.Sequential( + # nn.Conv2d(base_channels, base_channels//2, 1), + # nn.ReLU(), + nn.Conv2d(base_channels, out_dim, 1) + ) + + def forward(self, x, ori_modulations=None): + modulations = self.first_mod(ori_modulations) + x = self.first_coord(x) # B 2 H W -> B C H W + x = x + modulations + x = torch.sin(x) + for i_layer in range(self.num_inner_layers): + modulations = self.inner_mods[i_layer]( + torch.cat((ori_modulations, modulations, x), dim=1)) + # modulations = self.inner_mods[i_layer]( + # torch.cat((ori_modulations, x), dim=1)) + residual = self.inner_coords[i_layer](x) + residual = residual + modulations + residual = torch.sin(residual) + if self.is_residual: + x = x + residual + else: + x = residual + x = self.last_coord(x) + return x diff --git a/mmpl/models/necks/transformer_edecoder_neck.py b/mmpl/models/necks/transformer_edecoder_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..c4f4ce7eb9f97361733fc26adb16487eb67ffc8b --- /dev/null +++ b/mmpl/models/necks/transformer_edecoder_neck.py @@ -0,0 +1,54 @@ +import torch +import torch.nn as nn + +from mmpl.registry import MODELS + + +@MODELS.register_module() +class TransformerEDecoderNeck(nn.Module): + """Global Average Pooling neck. + + Note that we use `view` to remove extra channel after pooling. We do not + use `squeeze` as it will also remove the batch dimension when the tensor + has a batch dimension of size 1, which can lead to unexpected errors. + + Args: + dim (int): Dimensions of each sample channel, can be one of {1, 2, 3}. 
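Sirens above is a SIREN-style MLP: sine activations applied to coordinate features, with optional residual connections between the inner layers; ModulatedSirens applies the same idea per pixel with 1x1 convolutions, adding a modulation feature map before each sine. A minimal usage sketch for Sirens on a normalized coordinate grid (values illustrative):

import torch

siren = Sirens(in_channels=2, out_channels=3, base_channels=256, num_inner_layers=2)

# Normalized (x, y) coordinates of a 64x64 grid, flattened to (H*W, 2)
ys, xs = torch.meshgrid(torch.linspace(-1, 1, 64), torch.linspace(-1, 1, 64), indexing='ij')
coords = torch.stack([xs, ys], dim=-1).reshape(-1, 2)

rgb = siren(coords)   # (4096, 3): one predicted value vector per coordinate
print(rgb.shape)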
+ Default: 2 + """ + + def __init__(self, model_dim, num_encoder_layers=3): + super(TransformerEDecoderNeck, self).__init__() + self.embed_dims = model_dim + self.with_cls_token = True + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + + self.transformer_encoder_decoder = nn.Transformer( + d_model=model_dim, num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_encoder_layers, dim_feedforward=model_dim * 2, + batch_first=True, + dropout=0.1 + ) + self.out_linear_layer = nn.Sequential( + nn.Linear(model_dim, model_dim // 2), + nn.LeakyReLU(), + nn.Linear(model_dim // 2, model_dim) + ) + + def init_weights(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, inputs): + B = inputs.shape[0] + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, inputs), dim=1) + x = self.transformer_encoder_decoder(inputs, x) + if self.with_cls_token: + x = x[:, 0] + + residual = self.out_linear_layer(x) + x = x + residual + + return x diff --git a/mmpl/models/necks/transformer_neck.py b/mmpl/models/necks/transformer_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..e743e5e8e4ee46c9b75c4c3e63641a433fbe99f5 --- /dev/null +++ b/mmpl/models/necks/transformer_neck.py @@ -0,0 +1,94 @@ +import copy + +import torch +import torch.nn as nn +from mmpl.registry import MODELS +from mmengine.model import BaseModule +from mmcv.cnn.bricks.transformer import build_transformer_layer + + +@MODELS.register_module() +class TransformerEncoderNeck(BaseModule): + """Global Average Pooling neck. + + Note that we use `view` to remove extra channel after pooling. We do not + use `squeeze` as it will also remove the batch dimension when the tensor + has a batch dimension of size 1, which can lead to unexpected errors. + + Args: + dim (int): Dimensions of each sample channel, can be one of {1, 2, 3}. 
+ Default: 2 + """ + + def __init__(self, + model_dim, + with_pe=True, + max_position_embeddings=24, + with_cls_token=True, + num_encoder_layers=3 + ): + super(TransformerEncoderNeck, self).__init__() + self.embed_dims = model_dim + self.with_cls_token = with_cls_token + self.with_pe = with_pe + + if self.with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + + if self.with_pe: + self.pe = nn.Parameter(torch.zeros(1, max_position_embeddings, self.embed_dims)) + + mlp_ratio = 4 + embed_dims = model_dim + transformer_layer = dict( + type='BaseTransformerLayer', + attn_cfgs=[ + dict( + type='MultiheadAttention', + embed_dims=embed_dims, + num_heads=8, + attn_drop=0.1, + proj_drop=0.1, + dropout_layer=dict(type='Dropout', drop_prob=0.1) + ), + ], + ffn_cfgs=dict( + type='FFN', + embed_dims=embed_dims, + feedforward_channels=embed_dims * mlp_ratio, + num_fcs=2, + act_cfg=dict(type='GELU'), + ffn_drop=0.1, + add_identity=True), + operation_order=('norm', 'self_attn', 'norm', 'ffn'), + norm_cfg=dict(type='LN'), + batch_first=True + ) + + self.layers = nn.ModuleList() + transformer_layers = [ + copy.deepcopy(transformer_layer) for _ in range(num_encoder_layers) + ] + for i in range(num_encoder_layers): + self.layers.append(build_transformer_layer(transformer_layers[i])) + self.embed_dims = self.layers[0].embed_dims + self.pre_norm = self.layers[0].pre_norm + + def init_weights(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, x): + B = x.shape[0] + if self.with_cls_token: + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + if self.with_pe: + x = x + self.pe[:, :x.shape[1], :] + for layer in self.layers: + x = layer(x) + + if self.with_cls_token: + return x[:, 0], x + return None, x diff --git a/mmpl/models/pler/__init__.py b/mmpl/models/pler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4b3eb7530e9337e719512676668a6eb63d204d7c --- /dev/null +++ b/mmpl/models/pler/__init__.py @@ -0,0 +1,11 @@ +from .gpt_pler import GPTPLer +from .seg_pler import SegPLer +from .mmseg_pler import MMSegPLer +from .seg_sam_pler import SegSAMPLer +from .mmdet_pler import MMDetPLer +from .semseg_sam_pler import SemSegSAMPLer +from .seg_sam_anchor_pler import SegSAMAnchorPLer +from .seg_samdet import SegSAMDetPLer +from .mmcls_pler import MMClsPLer + +# __all__ = ['GPTPLer', 'YoloPLer', 'SegPLer', 'MMSegPLer'] diff --git a/mmpl/models/pler/__pycache__/__init__.cpython-310.pyc b/mmpl/models/pler/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da8db9b8b0a3761ee134298382b25032e1a8d933 Binary files /dev/null and b/mmpl/models/pler/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmpl/models/pler/__pycache__/base.cpython-310.pyc b/mmpl/models/pler/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6956bfcef102e95a4246ebcc2f6c3e191bb22af6 Binary files /dev/null and b/mmpl/models/pler/__pycache__/base.cpython-310.pyc differ diff --git a/mmpl/models/pler/__pycache__/base_pler.cpython-310.pyc b/mmpl/models/pler/__pycache__/base_pler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc7f1969e47b2dfe2c1a66ae3be742e0767e4e4e Binary files /dev/null and b/mmpl/models/pler/__pycache__/base_pler.cpython-310.pyc differ diff --git a/mmpl/models/pler/__pycache__/mmcls_pler.cpython-310.pyc 
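TransformerEncoderNeck above prepends a learnable CLS token, adds a learned positional embedding, and runs a stack of pre-norm self-attention layers, returning the CLS embedding together with the full token sequence. A usage sketch, assuming mmcv's transformer bricks are installed as elsewhere in this repo (dimensions illustrative):

import torch

neck = TransformerEncoderNeck(model_dim=256, max_position_embeddings=24, num_encoder_layers=3)

tokens = torch.randn(2, 16, 256)         # (batch, sequence, model_dim)
cls_feat, all_tokens = neck(tokens)      # CLS embedding plus CLS-prefixed sequence
print(cls_feat.shape, all_tokens.shape)  # torch.Size([2, 256]) torch.Size([2, 17, 256])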
b/mmpl/models/pler/__pycache__/mmcls_pler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8816c476d755aee0a5d8cfd5d73f4eb5650781d3 Binary files /dev/null and b/mmpl/models/pler/__pycache__/mmcls_pler.cpython-310.pyc differ diff --git a/mmpl/models/pler/__pycache__/mmdet_pler.cpython-310.pyc b/mmpl/models/pler/__pycache__/mmdet_pler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2e59b0b5dfc99f3cf107adeedd1a34757aa1fc3 Binary files /dev/null and b/mmpl/models/pler/__pycache__/mmdet_pler.cpython-310.pyc differ diff --git a/mmpl/models/pler/__pycache__/mmseg_pler.cpython-310.pyc b/mmpl/models/pler/__pycache__/mmseg_pler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f03f7d917ea51bd7c66a4b37ba1225edf67602e Binary files /dev/null and b/mmpl/models/pler/__pycache__/mmseg_pler.cpython-310.pyc differ diff --git a/mmpl/models/pler/__pycache__/seg_pler.cpython-310.pyc b/mmpl/models/pler/__pycache__/seg_pler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1403c0f6184b4746014d9b2dd67b459ae05f5ecd Binary files /dev/null and b/mmpl/models/pler/__pycache__/seg_pler.cpython-310.pyc differ diff --git a/mmpl/models/pler/__pycache__/seg_sam_anchor_pler.cpython-310.pyc b/mmpl/models/pler/__pycache__/seg_sam_anchor_pler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98a58ea4d172f88a7aa36c16bec2450314abcd3b Binary files /dev/null and b/mmpl/models/pler/__pycache__/seg_sam_anchor_pler.cpython-310.pyc differ diff --git a/mmpl/models/pler/__pycache__/seg_sam_pler.cpython-310.pyc b/mmpl/models/pler/__pycache__/seg_sam_pler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8fb767ddf19ceeaa762e143a877076c6f3869142 Binary files /dev/null and b/mmpl/models/pler/__pycache__/seg_sam_pler.cpython-310.pyc differ diff --git a/mmpl/models/pler/__pycache__/seg_samdet.cpython-310.pyc b/mmpl/models/pler/__pycache__/seg_samdet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c250a1aa6cbe438cf4fdd2d5173d627b87770884 Binary files /dev/null and b/mmpl/models/pler/__pycache__/seg_samdet.cpython-310.pyc differ diff --git a/mmpl/models/pler/__pycache__/semseg_sam_pler.cpython-310.pyc b/mmpl/models/pler/__pycache__/semseg_sam_pler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42490c7c2290a0f8fbc1e6b7765b726f9ff89c06 Binary files /dev/null and b/mmpl/models/pler/__pycache__/semseg_sam_pler.cpython-310.pyc differ diff --git a/mmpl/models/pler/base.py b/mmpl/models/pler/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a65fc213f4bfe271a9298b823ba38fc4ca9f57e1 --- /dev/null +++ b/mmpl/models/pler/base.py @@ -0,0 +1,108 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import List, Optional, Sequence + +import torch +from mmengine.model import BaseModel +from mmengine.structures import BaseDataElement + + +class BaseClassifier(BaseModel, metaclass=ABCMeta): + """Base class for classifiers. + + Args: + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing input + data. If None, it will use "BaseDataPreprocessor" as type, see + :class:`mmengine.model.BaseDataPreprocessor` for more details. + Defaults to None. + + Attributes: + init_cfg (dict): Initialization config dict. 
+ data_preprocessor (:obj:`mmengine.model.BaseDataPreprocessor`): An + extra data pre-processing module, which processes data from + dataloader to the format accepted by :meth:`forward`. + """ + + def __init__(self, + init_cfg: Optional[dict] = None, + data_preprocessor: Optional[dict] = None): + super(BaseClassifier, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + @property + def with_neck(self) -> bool: + """Whether the classifier has a neck.""" + return hasattr(self, 'neck') and self.neck is not None + + @property + def with_head(self) -> bool: + """Whether the classifier has a head.""" + return hasattr(self, 'head') and self.head is not None + + @abstractmethod + def forward(self, + inputs: torch.Tensor, + data_samples: Optional[List[BaseDataElement]] = None, + mode: str = 'tensor'): + """The unified entry for a forward process in both training and test. + + The method should accept three modes: "tensor", "predict" and "loss": + + - "tensor": Forward the whole network and return tensor or tuple of + tensor without any post-processing, same as a common nn.Module. + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`BaseDataElement`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + inputs (torch.Tensor): The input tensor with shape (N, C, ...) + in general. + data_samples (List[BaseDataElement], optional): The annotation + data of every samples. It's required if ``mode="loss"``. + Defaults to None. + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor or a tuple of tensor. + - If ``mode="predict"``, return a list of + :obj:`mmengine.BaseDataElement`. + - If ``mode="loss"``, return a dict of tensor. + """ + pass + + def extract_feat(self, inputs: torch.Tensor): + """Extract features from the input tensor with shape (N, C, ...). + + The sub-classes are recommended to implement this method to extract + features from backbone and neck. + + Args: + inputs (Tensor): A batch of inputs. The shape of it should be + ``(num_samples, num_channels, *img_shape)``. + """ + raise NotImplementedError + + def extract_feats(self, multi_inputs: Sequence[torch.Tensor], + **kwargs) -> list: + """Extract features from a sequence of input tensor. + + Args: + multi_inputs (Sequence[torch.Tensor]): A sequence of input + tensor. It can be used in augmented inference. + **kwargs: Other keyword arguments accepted by :meth:`extract_feat`. + + Returns: + list: Features of every input tensor. + """ + assert isinstance(multi_inputs, Sequence), \ + '`extract_feats` is used for a sequence of inputs tensor. If you '\ + 'want to extract on single inputs tensor, use `extract_feat`.' 
+ return [self.extract_feat(inputs, **kwargs) for inputs in multi_inputs] diff --git a/mmpl/models/pler/base_pler.py b/mmpl/models/pler/base_pler.py new file mode 100644 index 0000000000000000000000000000000000000000..bf099728a9e7bb452c61ffc5d4984ac58d8ef939 --- /dev/null +++ b/mmpl/models/pler/base_pler.py @@ -0,0 +1,241 @@ +import torch +import torch.nn as nn +from lightning.pytorch.utilities import grad_norm +from mmengine import OPTIM_WRAPPERS +from mmengine.optim import build_optim_wrapper, _ParamScheduler +import copy + +from torchmetrics import MetricCollection + +from mmpl.registry import MODELS, METRICS +import lightning.pytorch as pl +from mmengine.registry import OPTIMIZERS, PARAM_SCHEDULERS +from mmengine.model import BaseModel + + +@MODELS.register_module() +class BasePLer(pl.LightningModule, BaseModel): + def __init__( + self, + hyperparameters, + data_preprocessor=None, + train_cfg=None, + test_cfg=None, + *args, + **kwargs + ): + super().__init__() + self.hyperparameters = hyperparameters + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + if data_preprocessor is not None: + if isinstance(data_preprocessor, nn.Module): + self.data_preprocessor = data_preprocessor + elif isinstance(data_preprocessor, dict): + self.data_preprocessor = MODELS.build(data_preprocessor) + else: + raise TypeError('data_preprocessor should be a `dict` or ' + f'`nn.Module` instance, but got ' + f'{type(data_preprocessor)}') + + evaluator_cfg = copy.deepcopy(self.hyperparameters.get('evaluator', None)) + if evaluator_cfg is not None: + for k, v in evaluator_cfg.items(): + metrics = [] + if isinstance(v, dict): + v = [v] + if isinstance(v, list): + for metric_cfg in v: + metric = METRICS.build(metric_cfg) + metrics.append(metric) + else: + raise TypeError('evaluator should be a `dict` or ' + f'`list` instance, but got ' + f'{type(evaluator_cfg)}') + setattr(self, k, MetricCollection(metrics, prefix=k.split('_')[0])) + + def _set_grad(self, need_train_names: list=[], noneed_train_names: list=[]): + for name, param in self.named_parameters(): + flag = False + for need_train_name in need_train_names: + if need_train_name in name: + flag = True + for noneed_train_name in noneed_train_names: + if noneed_train_name in name: + flag = False + param.requires_grad_(flag) + + not_specific_names = [] + for name, param in self.named_parameters(): + flag_find = False + for specific_name in need_train_names + noneed_train_names: + if specific_name in name: + flag_find = True + if not flag_find: + not_specific_names.append(name) + + if self.local_rank == 0: + not_specific_names = [x.split('.')[0] for x in not_specific_names] + not_specific_names = set(not_specific_names) + print(f"Turning off gradients for names: {noneed_train_names}") + print(f"Turning on gradients for names: {need_train_names}") + print(f"Turning off gradients for not specific names: {not_specific_names}") + + def _set_train_module(self, mode=True, need_train_names: list=[]): + self.training = mode + for name, module in self.named_children(): + flag = False + for need_train_name in need_train_names: + if need_train_name in name: + flag = True + if flag: + module.train(mode) + else: + module.eval() + return self + + def configure_optimizers(self): + optimizer_cfg = copy.deepcopy(self.hyperparameters.get('optimizer')) + base_lr = optimizer_cfg.get('lr') + base_wd = optimizer_cfg.get('weight_decay', None) + + sub_models = optimizer_cfg.pop('sub_model', None) + if sub_models is None: + optimizer_cfg['params'] = filter(lambda p: p.requires_grad, 
self.parameters()) + # optimizer_cfg['params'] = self.parameters() + else: + if isinstance(sub_models, str): + sub_models = {sub_models: {}} + if isinstance(sub_models, list): + sub_models = {x: {} for x in sub_models} + assert isinstance(sub_models, dict), f'sub_models should be a dict, but got {type(sub_models)}' + # import ipdb; ipdb.set_trace() + # set training parameters and lr + for sub_model_name, value in sub_models.items(): + sub_attrs = sub_model_name.split('.') + sub_model_ = self + # import ipdb; ipdb.set_trace() + for sub_attr in sub_attrs: + sub_model_ = getattr(sub_model_, sub_attr) + # sub_model_ = self.trainer.strategy.model._forward_module.get_submodule(sub_model_name) + if isinstance(sub_model_, torch.nn.Parameter): + # filter(lambda p: p.requires_grad, model.parameters()) + # sub_models[sub_model_name]['params'] = filter(lambda p: p.requires_grad, [sub_model_]) + sub_models[sub_model_name]['params'] = filter(lambda p: p.requires_grad, [sub_model_]) + else: + # import ipdb;ipdb.set_trace() + sub_models[sub_model_name]['params'] = filter(lambda p: p.requires_grad, sub_model_.parameters()) + # sub_models[sub_model_name]['params'] = sub_model_.parameters() + lr_mult = value.pop('lr_mult', 1.) + sub_models[sub_model_name]['lr'] = base_lr * lr_mult + if base_wd is not None: + decay_mult = value.pop('decay_mult', 1.) + sub_models[sub_model_name]['weight_decay'] = base_wd * decay_mult + else: + raise ModuleNotFoundError(f'{sub_model_name} not in model') + + if self.local_rank == 0: + print('All sub models:') + for name, module in self.named_children(): + print(name, end=', ') + print() + print('Needed train models:') + for name, value in sub_models.items(): + print(f'{name}', end=', ') + print() + + optimizer_cfg['params'] = [value for key, value in sub_models.items()] + + optimizer = OPTIMIZERS.build(optimizer_cfg) + if self.local_rank == 0: + print('查看优化器参数') + for param_group in optimizer.param_groups: + print([value.shape for value in param_group['params']], '学习率: ', param_group['lr']) + + schedulers = copy.deepcopy(self.hyperparameters.get('param_scheduler', None)) + if schedulers is None: + return [optimizer] + param_schedulers = [] + total_step = self.trainer.estimated_stepping_batches + for scheduler in schedulers: + if isinstance(scheduler, _ParamScheduler): + param_schedulers.append(scheduler) + elif isinstance(scheduler, dict): + _scheduler = copy.deepcopy(scheduler) + param_schedulers.append( + PARAM_SCHEDULERS.build( + _scheduler, + default_args=dict( + optimizer=optimizer, + epoch_length=self.trainer.num_training_batches, + ) + ) + ) + else: + raise TypeError( + 'scheduler should be a _ParamScheduler object or dict, ' + f'but got {scheduler}') + + return [optimizer], param_schedulers + + def lr_scheduler_step(self, scheduler, metric): + pass + + def log_grad(self, module=None) -> None: + # Compute the 2-norm for each layer + # If using mixed precision, the gradients are already unscaled here + if module is None: + module = self + norms = grad_norm(module, norm_type=2) + max_grad = max(norms.values()) + min_gead = min(norms.values()) + self.log_dict( + {'max_grad': max_grad, 'min_grad': min_gead}, + prog_bar=True, + logger=True + ) + + def setup(self, stage: str) -> None: + evaluators = ['train', 'val', 'test'] + for evaluator in evaluators: + if hasattr(self, f'{evaluator}_evaluator'): + if hasattr(self.trainer.datamodule, f'{evaluator}_dataset'): + dataset = getattr(self.trainer.datamodule, f'{evaluator}_dataset') + if hasattr(dataset, 'metainfo'): + evaluator_ = 
getattr(self, f'{evaluator}_evaluator') + for v in evaluator_.values(): + if hasattr(v, 'dataset_meta'): + v.dataset_meta = dataset.metainfo + + def on_before_optimizer_step(self, optimizer) -> None: + self.log_grad() + + def on_validation_epoch_end(self) -> None: + self._log_eval_metrics('val') + + def on_test_epoch_end(self) -> None: + self._log_eval_metrics('test') + + def on_train_epoch_end(self) -> None: + self._log_eval_metrics('train') + + def _log_eval_metrics(self, stage): + assert stage in ['train', 'val', 'test'] + if hasattr(self, f'{stage}_evaluator'): + evaluator = getattr(self, f'{stage}_evaluator') + metrics = evaluator.compute() + metrics = {k.lower(): v for k, v in metrics.items()} + keys = [] + for k, v in metrics.items(): + v = v.view(-1) + for i, data in enumerate(v): + keys.append(f'{k}_{i}') + self.log(f'{k.lower()}_{i}', data, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + evaluator.reset() + + if hasattr(self.trainer, 'checkpoint_callback'): + monitor = self.trainer.checkpoint_callback.monitor + if (monitor is not None) and (monitor not in keys): + data = torch.tensor(0., device=self.device) + self.log(f'{monitor}', data, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) \ No newline at end of file diff --git a/mmpl/models/pler/gpt_pler.py b/mmpl/models/pler/gpt_pler.py new file mode 100644 index 0000000000000000000000000000000000000000..66a07c36fab2f526803f895ca59079b3fc707e16 --- /dev/null +++ b/mmpl/models/pler/gpt_pler.py @@ -0,0 +1,34 @@ +from typing import Any + +import torch +import torch.nn as nn +from mmpl.registry import MODELS +from ..builder import build_backbone, build_loss +from .base_pler import BasePLer +from mmpl.structures import ClsDataSample +from .base import BaseClassifier +import lightning.pytorch as pl +import torch.nn.functional as F + + +@MODELS.register_module() +class GPTPLer(BasePLer): + def __init__(self, + backbone, + loss=dict(type='CrossEntropyLoss', loss_weight=1.0), + *args, **kwargs): + super().__init__(*args, **kwargs) + self.backbone = build_backbone(backbone) + self.loss = build_loss(loss) + + def training_step(self, batch, batch_idx): + x, gt_label = batch['x'], batch['gt_label'] + outputs = self(input_ids=x, labels=gt_label) + loss, logits = outputs['loss'], outputs['logits'] + return loss + + def forward(self, *args: Any, **kwargs: Any) -> Any: + return self.backbone(*args, **kwargs) + + def validation_step(self, batch, batch_idx): + pass diff --git a/mmpl/models/pler/image.py b/mmpl/models/pler/image.py new file mode 100644 index 0000000000000000000000000000000000000000..7569ddfa4d0b7898615458bceacebaad75d549e0 --- /dev/null +++ b/mmpl/models/pler/image.py @@ -0,0 +1,241 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import torch +import torch.nn as nn + +from mmpl.registry import MODELS +from mmpl.structures import ClsDataSample +from .base import BaseClassifier + + +@MODELS.register_module() +class ImageClassifier(BaseClassifier): + """Image classifiers for supervised classification task. + + Args: + backbone (dict): The backbone module. See + :mod:`mmcls.models.backbones`. + neck (dict, optional): The neck module to process features from + backbone. See :mod:`mmcls.models.necks`. Defaults to None. + head (dict, optional): The head module to do prediction and calculate + loss from processed features. See :mod:`mmcls.models.heads`. 
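configure_optimizers in BasePLer above reads everything from the `hyperparameters` dict: an mmengine-style optimizer config, optionally restricted to named sub-modules via `sub_model` with per-module lr/decay multipliers, an optional list of parameter scheduler configs, and optional torchmetrics evaluators keyed as `train_evaluator`/`val_evaluator`/`test_evaluator`. A sketch of the structure it expects; the sub-module names and the metric type are assumptions for illustration, not values taken from this repository's configs:

# Hypothetical hyperparameters dict consumed by BasePLer
hyperparameters = dict(
    optimizer=dict(
        type='AdamW',
        lr=2e-4,
        weight_decay=1e-3,
        # Only these sub-modules are optimized; lr_mult/decay_mult scale the base values.
        sub_model={
            'sam_neck': {},                 # assumed attribute name
            'sam_head': dict(lr_mult=0.1),  # assumed attribute name
        },
    ),
    param_scheduler=[
        dict(type='CosineAnnealingLR', T_max=100, by_epoch=True),
    ],
    evaluator=dict(
        # Must be a metric type registered in the METRICS registry (assumed here).
        val_evaluator=dict(type='MeanAveragePrecision'),
    ),
)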
+ Notice that if the head is not set, almost all methods cannot be + used except :meth:`extract_feat`. Defaults to None. + pretrained (str, optional): The pretrained checkpoint path, support + local path and remote path. Defaults to None. + train_cfg (dict, optional): The training setting. The acceptable + fields are: + + - augments (List[dict]): The batch augmentation methods to use. + More details can be found in :mod:`mmcls.model.utils.augment`. + + Defaults to None. + data_preprocessor (dict, optional): The config for preprocessing input + data. If None or no specified type, it will use + "ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor` for + more details. Defaults to None. + init_cfg (dict, optional): the config to control the initialization. + Defaults to None. + """ + + def __init__(self, + backbone: dict, + neck: Optional[dict] = None, + head: Optional[dict] = None, + pretrained: Optional[str] = None, + train_cfg: Optional[dict] = None, + data_preprocessor: Optional[dict] = None, + init_cfg: Optional[dict] = None): + if pretrained is not None: + init_cfg = dict(type='Pretrained', checkpoint=pretrained) + + if data_preprocessor is None: + data_preprocessor = {} + # The build process is in MMEngine, so we need to add scope here. + data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor') + + if train_cfg is not None and 'augments' in train_cfg: + # Set batch augmentations by `train_cfg` + data_preprocessor['batch_augments'] = train_cfg + + super(ImageClassifier, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + if not isinstance(backbone, nn.Module): + backbone = MODELS.build(backbone) + if neck is not None and not isinstance(neck, nn.Module): + neck = MODELS.build(neck) + if head is not None and not isinstance(head, nn.Module): + head = MODELS.build(head) + + self.backbone = backbone + self.neck = neck + self.head = head + + def forward(self, + inputs: torch.Tensor, + data_samples: Optional[List[ClsDataSample]] = None, + mode: str = 'tensor'): + """The unified entry for a forward process in both training and test. + + The method should accept three modes: "tensor", "predict" and "loss": + + - "tensor": Forward the whole network and return tensor or tuple of + tensor without any post-processing, same as a common nn.Module. + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`ClsDataSample`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[ClsDataSample], optional): The annotation + data of every samples. It's required if ``mode="loss"``. + Defaults to None. + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor or a tuple of tensor. + - If ``mode="predict"``, return a list of + :obj:`mmcls.structures.ClsDataSample`. + - If ``mode="loss"``, return a dict of tensor. 
+ """ + if mode == 'tensor': + feats = self.extract_feat(inputs) + return self.head(feats) if self.with_head else feats + elif mode == 'loss': + return self.loss(inputs, data_samples) + elif mode == 'predict': + return self.predict(inputs, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}".') + + def extract_feat(self, inputs, stage='neck'): + """Extract features from the input tensor with shape (N, C, ...). + + Args: + inputs (Tensor): A batch of inputs. The shape of it should be + ``(num_samples, num_channels, *img_shape)``. + stage (str): Which stage to output the feature. Choose from: + + - "backbone": The output of backbone network. Returns a tuple + including multiple stages features. + - "neck": The output of neck module. Returns a tuple including + multiple stages features. + - "pre_logits": The feature before the final classification + linear layer. Usually returns a tensor. + + Defaults to "neck". + + Returns: + tuple | Tensor: The output of specified stage. + The output depends on detailed implementation. In general, the + output of backbone and neck is a tuple and the output of + pre_logits is a tensor. + + Examples: + 1. Backbone output + + >>> import torch + >>> from mmengine import Config + >>> from mmcls.models import build_classifier + >>> + >>> cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model + >>> cfg.backbone.out_indices = (0, 1, 2, 3) # Output multi-scale feature maps + >>> model = build_classifier(cfg) + >>> outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='backbone') + >>> for out in outs: + ... print(out.shape) + torch.Size([1, 64, 56, 56]) + torch.Size([1, 128, 28, 28]) + torch.Size([1, 256, 14, 14]) + torch.Size([1, 512, 7, 7]) + + 2. Neck output + + >>> import torch + >>> from mmengine import Config + >>> from mmcls.models import build_classifier + >>> + >>> cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model + >>> cfg.backbone.out_indices = (0, 1, 2, 3) # Output multi-scale feature maps + >>> model = build_classifier(cfg) + >>> + >>> outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='neck') + >>> for out in outs: + ... print(out.shape) + torch.Size([1, 64]) + torch.Size([1, 128]) + torch.Size([1, 256]) + torch.Size([1, 512]) + + 3. Pre-logits output (without the final linear classifier head) + + >>> import torch + >>> from mmengine import Config + >>> from mmcls.models import build_classifier + >>> + >>> cfg = Config.fromfile('configs/vision_transformer/vit-base-p16_pt-64xb64_in1k-224.py').model + >>> model = build_classifier(cfg) + >>> + >>> out = model.extract_feat(torch.rand(1, 3, 224, 224), stage='pre_logits') + >>> print(out.shape) # The hidden dims in head is 3072 + torch.Size([1, 3072]) + """ # noqa: E501 + assert stage in ['backbone', 'neck', 'pre_logits'], \ + (f'Invalid output stage "{stage}", please choose from "backbone", ' + '"neck" and "pre_logits"') + + x = self.backbone(inputs) + + if stage == 'backbone': + return x + + if self.with_neck: + x = self.neck(x) + if stage == 'neck': + return x + + assert self.with_head and hasattr(self.head, 'pre_logits'), \ + "No head or the head doesn't implement `pre_logits` method." + return self.head.pre_logits(x) + + def loss(self, inputs: torch.Tensor, + data_samples: List[ClsDataSample]) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[ClsDataSample]): The annotation data of + every samples. 
+ + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + feats = self.extract_feat(inputs) + return self.head.loss(feats, data_samples) + + def predict(self, + inputs: torch.Tensor, + data_samples: Optional[List[ClsDataSample]] = None, + **kwargs) -> List[ClsDataSample]: + """Predict results from a batch of inputs. + + Args: + inputs (torch.Tensor): The input tensor with shape + (N, C, ...) in general. + data_samples (List[ClsDataSample], optional): The annotation + data of every samples. Defaults to None. + **kwargs: Other keyword arguments accepted by the ``predict`` + method of :attr:`head`. + """ + feats = self.extract_feat(inputs) + return self.head.predict(feats, data_samples, **kwargs) diff --git a/mmpl/models/pler/mmcls_pler.py b/mmpl/models/pler/mmcls_pler.py new file mode 100644 index 0000000000000000000000000000000000000000..c42456e90149b9c61f9740a77f9f063f776b33e5 --- /dev/null +++ b/mmpl/models/pler/mmcls_pler.py @@ -0,0 +1,60 @@ +import os +from typing import Any + +import mmengine +import torch +import torch.nn as nn +from einops import rearrange + +from mmdet.models.utils import samplelist_boxtype2tensor +from mmdet.structures import SampleList +from mmdet.utils import InstanceList +from mmpl.registry import MODELS +from ..builder import build_backbone, build_loss, build_neck, build_head +from .base_pler import BasePLer +from mmpl.structures import ClsDataSample +from .base import BaseClassifier +import lightning.pytorch as pl +import torch.nn.functional as F + + +@MODELS.register_module() +class MMClsPLer(BasePLer): + def __init__(self, + whole_model=None, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.save_hyperparameters() + self.whole_model = MODELS.build(whole_model) + + def setup(self, stage: str) -> None: + super().setup(stage) + + def validation_step(self, batch, batch_idx): + data = self.whole_model.data_preprocessor(batch, False) + batch_data_samples = self.whole_model._run_forward(data, mode='predict') # type: ignore + + pred_label = torch.cat([data_sample.pred_label for data_sample in batch_data_samples]) + gt_label = torch.cat([data_sample.gt_label for data_sample in batch_data_samples]) + self.val_evaluator.update(pred_label, gt_label) + # self.val_evaluator.update(batch, batch_data_samples) + + def training_step(self, batch, batch_idx): + data = self.whole_model.data_preprocessor(batch, True) + losses = self.whole_model._run_forward(data, mode='loss') # type: ignore + parsed_losses, log_vars = self.parse_losses(losses) + log_vars = {f'train_{k}': v for k, v in log_vars.items()} + log_vars['loss'] = parsed_losses + self.log_dict(log_vars, prog_bar=True) + return log_vars + + def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any): + data = self.whole_model.data_preprocessor(batch, False) + batch_data_samples = self.whole_model._run_forward(data, mode='predict') # type: ignore + self.test_evaluator.update(batch, batch_data_samples) + + + + + + diff --git a/mmpl/models/pler/mmdet_pler.py b/mmpl/models/pler/mmdet_pler.py new file mode 100644 index 0000000000000000000000000000000000000000..5e85cf0b0349e20923c8d4ddaf066171215fc269 --- /dev/null +++ b/mmpl/models/pler/mmdet_pler.py @@ -0,0 +1,103 @@ +import os +from typing import Any + +import mmengine +import torch +import torch.nn as nn +from einops import rearrange + +from mmdet.models.utils import samplelist_boxtype2tensor +from mmdet.structures import SampleList +from mmdet.utils import InstanceList +from mmpl.registry import MODELS +from ..builder import 
build_backbone, build_loss, build_neck, build_head +from .base_pler import BasePLer +from mmpl.structures import ClsDataSample +from .base import BaseClassifier +import lightning.pytorch as pl +import torch.nn.functional as F + + +@MODELS.register_module() +class MMDetPLer(BasePLer): + def __init__(self, + whole_model=None, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.save_hyperparameters() + self.whole_model = MODELS.build(whole_model) + + def setup(self, stage: str) -> None: + super().setup(stage) + + def validation_step(self, batch, batch_idx): + data = self.whole_model.data_preprocessor(batch, False) + batch_data_samples = self.whole_model._run_forward(data, mode='predict') # type: ignore + # preds = [] + # targets = [] + # for data_sample in batch_data_samples: + # result = dict() + # pred = data_sample.pred_instances + # result['boxes'] = pred['bboxes'] + # result['scores'] = pred['scores'] + # result['labels'] = pred['labels'] + # if 'masks' in pred: + # result['masks'] = pred['masks'] + # preds.append(result) + # # parse gt + # gt = dict() + # gt_data = data_sample.get('gt_instances', None) + # gt['boxes'] = gt_data['bboxes'] + # gt['labels'] = gt_data['labels'] + # if 'masks' in pred: + # gt['masks'] = gt_data['masks'].to_tensor(dtype=torch.bool, device=result['masks'].device) + # targets.append(gt) + + # self.val_evaluator.update(preds, targets) + self.val_evaluator.update(batch, batch_data_samples) + + def training_step(self, batch, batch_idx): + data = self.whole_model.data_preprocessor(batch, True) + losses = self.whole_model._run_forward(data, mode='loss') # type: ignore + parsed_losses, log_vars = self.parse_losses(losses) + log_vars = {f'train_{k}': v for k, v in log_vars.items()} + log_vars['loss'] = parsed_losses + self.log_dict(log_vars, prog_bar=True) + return log_vars + + def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any): + data = self.whole_model.data_preprocessor(batch, False) + batch_data_samples = self.whole_model._run_forward(data, mode='predict') # type: ignore + preds = [] + targets = [] + for data_sample in batch_data_samples: + result = dict() + pred = data_sample.pred_instances + result['boxes'] = pred['bboxes'] + result['scores'] = pred['scores'] + result['labels'] = pred['labels'] + if 'masks' in pred: + result['masks'] = pred['masks'] + preds.append(result) + # parse gt + gt = dict() + gt_data = data_sample.get('gt_instances', None) + gt['boxes'] = gt_data['bboxes'] + gt['labels'] = gt_data['labels'] + if 'masks' in pred: + gt['masks'] = gt_data['masks'].to_tensor(dtype=torch.bool, device=result['masks'].device) + targets.append(gt) + + # self.test_evaluator.update(preds, targets) + self.test_evaluator.update(batch, batch_data_samples) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + data = self.whole_model.data_preprocessor(batch, False) + batch_data_samples = self.whole_model._run_forward(data, mode='predict') # type: ignore + return batch_data_samples + + + + + + diff --git a/mmpl/models/pler/mmseg_pler.py b/mmpl/models/pler/mmseg_pler.py new file mode 100644 index 0000000000000000000000000000000000000000..fa73f067f78800e8b0de4e61988522a821c27bea --- /dev/null +++ b/mmpl/models/pler/mmseg_pler.py @@ -0,0 +1,68 @@ +import os +from typing import Any + +import mmengine +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange + +from mmpl.registry import MODELS +from ..builder import build_backbone, build_loss, build_neck, build_head +from .base_pler 
import BasePLer +from mmpl.structures import ClsDataSample +from .base import BaseClassifier +import lightning.pytorch as pl +import torch.nn.functional as F + + +@MODELS.register_module() +class MMSegPLer(BasePLer): + def __init__(self, + whole_model=None, + train_cfg=None, + test_cfg=None, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.save_hyperparameters() + self.whole_model = MODELS.build(whole_model) + + def setup(self, stage: str) -> None: + pass + + def init_weights(self): + import ipdb; ipdb.set_trace() + pass + + def training_step(self, batch, batch_idx): + data = self.whole_model.data_preprocessor(batch, True) + losses = self.whole_model._run_forward(data, mode='loss') # type: ignore + parsed_losses, log_vars = self.parse_losses(losses) + log_vars = {f'train_{k}': v for k, v in log_vars.items()} + log_vars['loss'] = parsed_losses + self.log_dict(log_vars, prog_bar=True) + return log_vars + # return torch.tensor(0.0, requires_grad=True, device=self.device) + + def validation_step(self, batch, batch_idx): + data = self.whole_model.data_preprocessor(batch, False) + data_samples = self.whole_model._run_forward(data, mode='predict') + pred = [data_sample.pred_sem_seg.data for data_sample in data_samples] + label = [data_sample.gt_sem_seg.data for data_sample in data_samples] + pred = torch.cat(pred, dim=0) + label = torch.cat(label, dim=0) + self.val_evaluator.update(pred, label) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + data = self.whole_model.data_preprocessor(batch, False) + data_samples = self.whole_model._run_forward(data, mode='predict') + return data_samples + + + + + + + + + diff --git a/mmpl/models/pler/seg_pler.py b/mmpl/models/pler/seg_pler.py new file mode 100644 index 0000000000000000000000000000000000000000..f9212863df91078c177ee36a1836509d5440c68b --- /dev/null +++ b/mmpl/models/pler/seg_pler.py @@ -0,0 +1,419 @@ +import os +from typing import Any + +import einops +import mmengine +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from lightning.pytorch.utilities import grad_norm +from mmengine.structures import InstanceData + +from mmpl.registry import MODELS +from mmseg.utils import SampleList +from ..builder import build_backbone, build_loss, build_neck, build_head +from .base_pler import BasePLer +from mmpl.structures import ClsDataSample +from .base import BaseClassifier +import lightning.pytorch as pl +import torch.nn.functional as F + + +@MODELS.register_module() +class SegPLer(BasePLer): + def __init__(self, + sam=None, + sam_checkpoint='', + points_per_side=None, + sam_prompt_generator=None, + only_img_encoder=False, + only_decoder=False, + global_prompt=None, + need_train_names=None, + head=None, + with_clip=False, + train_head=False, + threshold=0.5, + ignore_index=255, + train_cfg=None, + test_cfg=None, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.save_hyperparameters() + self.need_train_names = need_train_names + self.ignore_index = ignore_index + self.threshold = threshold + self.only_img_encoder = only_img_encoder + self.only_decoder = only_decoder + self.global_prompt = global_prompt + self.train_head = train_head + + if sam is not None: + if self.only_img_encoder: + self.sam = sam_model_registry[sam](sam_checkpoint).image_encoder + elif self.only_decoder: + self.prompt_encoder = sam_model_registry[sam](sam_checkpoint).prompt_encoder + self.mask_decoder = sam_model_registry[sam](sam_checkpoint).mask_decoder + else: + sam = 
sam_model_registry[sam](sam_checkpoint, train_head=train_head) + self.img_encoder = sam.image_encoder + self.prompt_encoder = sam.prompt_encoder + self.mask_decoder = sam.mask_decoder + self.prompt_encoder_no_mask_embed = sam.prompt_encoder.no_mask_embed + + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, 0, 1) + if sam_prompt_generator is not None: + self.sam_prompt_generator = MODELS.build(sam_prompt_generator) + if head is not None: + self.head = MODELS.build(head) + self.with_clip = with_clip + if global_prompt is not None: + if with_clip: + self.logits_prompt = nn.Sequential( + nn.Linear(1, 8), + nn.ReLU(), + nn.Linear(8, 16) + ) + self.global_prompt = nn.Sequential( + nn.Conv2d(768+16, 256, kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv2d(256, 256, kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv2d(256, 1, kernel_size=3, padding=1), + ) + else: + self.global_prompt = nn.Sequential( + nn.Conv2d(256, 128, kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv2d(128, 1, kernel_size=3, padding=1), + ) + + def setup(self, stage: str) -> None: + if self.need_train_names is not None: + self._set_grad(self.need_train_names, noneed_train_names=[]) + + def configure_sharded_model(self) -> None: + if self.trainer.strategy.__class__.__name__ == 'FSDPStrategy': + from torch.distributed.fsdp.wrap import wrap + self.sam_prompt_generator = wrap(self.sam_prompt_generator) + self.img_encoder = wrap(self.img_encoder) + self.prompt_encoder_no_mask_embed = wrap(self.prompt_encoder_no_mask_embed) + self.mask_decoder = wrap(self.mask_decoder) + self.prompt_encoder = wrap(self.prompt_encoder) + from torch.distributed.fsdp import CPUOffload + # from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy + # import functools + # strategy = dict( + # type='FSDPStrategy', + # cpu_offload=CPUOffload(offload_params=True), + # auto_wrap_policy=functools.partial( + # size_based_auto_wrap_policy, min_num_params=int(1e8) + # ) + # + # ) + else: + super().configure_sharded_model() + + def configure_optimizers(self): + if self.trainer.strategy.__class__.__name__ == 'DeepSpeedStrategy': + import deepspeed + # optimizer = deepspeed.runtime. 
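+            # Note: only the prompt generator's parameters are optimized here.
+            # FusedAdam keeps optimizer states on the GPU; the commented-out
+            # DeepSpeedCPUAdam variant below would be the choice when
+            # DeepSpeed's CPU offloading of optimizer states is enabled.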
+ optimizer = deepspeed.ops.adam.FusedAdam(self.sam_prompt_generator.parameters(), lr=1e-4) + # optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam(self.sam_prompt_generator.parameters(), lr=1e-4) + # optimizer = torch.optim.Adam(self.sam_prompt_generator.parameters(), lr=1e-4) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5) + return [optimizer], [lr_scheduler] + else: + return super().configure_optimizers() + + def init_weights(self): + import ipdb; ipdb.set_trace() + pass + + # def on_fit_start(self) -> None: + # if hasattr(self, 'train_evaluator'): + # self.train_evaluator = self.train_evaluator.to(self.device) + # if hasattr(self, 'val_evaluator'): + # self.val_evaluator = self.val_evaluator.to(self.device) + + def train(self, mode=True): + if self.need_train_names is not None: + return self._set_train_module(mode, self.need_train_names) + else: + super().train(mode) + return self + + def validation_step(self, batch, batch_idx): + seg_label = torch.stack([x.gt_sem_seg.data for x in batch['data_samples']], dim=0) + if self.only_img_encoder: + masks_pred = self.forward_only_img_encoder(batch) + masks_pred = F.interpolate(masks_pred, size=seg_label.shape[-2:], mode='bilinear', + align_corners=True) + seg_logits = masks_pred > 0 + elif self.only_decoder: + cls_logits, masks, n_iou_preds = self.forward_sam_prompt_generator(batch) # 1x100x2, 1x100x1x256x256, 1x100x1 + masks = masks.squeeze(2) + masks = F.interpolate(masks, size=seg_label.shape[-2:], mode='bilinear', align_corners=True) + # cls_logits[..., 1:2] = cls_logits[..., 1:2] * n_iou_preds + seg_logits = self.post_process(cls_logits.detach(), masks.detach()) + seg_logits = seg_logits > self.threshold + else: + cls_logits, pred_masks, n_iou_preds = self.forward_sam_prompt_generator_all( + batch) # 1x100x2, 1x100x1x256x256, 1x100x1 + pred_masks = pred_masks.squeeze(2) + pred_masks = F.interpolate(pred_masks, size=seg_label.shape[-2:], mode='bilinear', align_corners=True) + # cls_logits[..., 1:2] = cls_logits[..., 1:2] * n_iou_preds + seg_logits = self.post_process(cls_logits.detach(), pred_masks.detach()) + seg_logits = seg_logits > self.threshold + # import ipdb; ipdb.set_trace() + self.val_evaluator.update(seg_logits, seg_label) + + def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any): + cls_logits, n_img_masks = self.forward(batch) + + seg_label = torch.stack([x.gt_sem_seg.data for x in batch['data_samples']], dim=0) + seg_label = seg_label.squeeze(1) + masks = F.interpolate(n_img_masks, size=seg_label.shape[-2:], mode='bilinear', align_corners=True) + masks = masks.squeeze(1) > 0 + self.evaluator.update(masks, seg_label) + + def _seg_data_to_instance_data(self, batch_data_samples: SampleList): + """Perform forward propagation to convert paradigm from MMSegmentation + to MMDetection to ensure ``MMDET_Mask2FormerHead`` could be called + normally. Specifically, ``batch_gt_instances`` would be added. + + Args: + batch_data_samples (List[:obj:`SegDataSample`]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + + Returns: + tuple[Tensor]: A tuple contains two lists. + + - batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``labels``, each is + unique ground truth label id of images, with + shape (num_gt, ) and ``masks``, each is ground truth + masks of each instances of a image, shape (num_gt, h, w). + - batch_img_metas (list[dict]): List of image meta information. 
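+
+        A minimal sketch of how the outputs are consumed (it mirrors the call
+        in ``training_step`` below; variable names are illustrative):
+
+        >>> batch_gt_instances, batch_img_metas = \
+        ...     self._seg_data_to_instance_data(batch['data_samples'])
+        >>> losses = self.head.loss(cls_logits, masks, batch_gt_instances,
+        ...                         batch_img_metas)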
+ """ + batch_img_metas = [] + batch_gt_instances = [] + + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + gt_masks = data_sample.instances_data.long() + gt_labels = data_sample.instances_label.long() + + instance_data = InstanceData(labels=gt_labels, masks=gt_masks) + batch_gt_instances.append(instance_data) + return batch_gt_instances, batch_img_metas + + def training_step(self, batch, batch_idx): + if self.only_img_encoder: + masks_pred = self.forward_only_img_encoder(batch) + seg_label = torch.stack([x.gt_sem_seg.data for x in batch['data_samples']], dim=0) + masks_pred = F.interpolate(masks_pred, size=seg_label.shape[-2:], mode='bilinear', align_corners=True) + losses = self.head.loss(masks_pred, seg_label) + masks_pred_result = masks_pred > 0 + self.train_evaluator.update(masks_pred_result.detach(), seg_label.detach()) + + elif self.only_decoder: + cls_logits, masks, n_iou_preds = self.forward_sam_prompt_generator(batch) # 1x100x2, 1x100x1x256x256, 1x100x1 + masks = masks.squeeze(2) + seg_label = torch.stack([x.gt_sem_seg.data for x in batch['data_samples']], dim=0) + masks = F.interpolate(masks, size=seg_label.shape[-2:], mode='bilinear', align_corners=True) + # cls_logits[..., 1:2] = cls_logits[..., 1:2] * n_iou_preds + seg_logits = self.post_process(cls_logits.clone().detach(), masks.clone().detach()) + seg_logits = seg_logits > self.threshold + self.train_evaluator.update(seg_logits, seg_label) + + batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data( + batch['data_samples']) + + losses = self.head.loss(cls_logits, masks, batch_gt_instances, batch_img_metas) + else: + cls_logits, pred_masks, n_iou_preds = self.forward_sam_prompt_generator_all( + batch) # 1x100x2, 1x100x1x256x256, 1x100x1 + pred_masks = pred_masks.squeeze(2) + if torch.isinf(pred_masks).any() or torch.isnan(pred_masks).any(): + # import ipdb; + # ipdb.set_trace() + # raise ValueError('cost is nan in CrossEntropyLossCost') + print('!!!!!!!!!!!!!!!!!!!!loss is nan or inf!!!!!!!!!!!!!!!!!!') + return torch.tensor(0.0, requires_grad=True, device=self.device) + seg_label = torch.stack([x.gt_sem_seg.data for x in batch['data_samples']], dim=0) + pred_masks = F.interpolate(pred_masks, size=seg_label.shape[-2:], mode='bilinear', align_corners=True) + # cls_logits[..., 1:2] = cls_logits[..., 1:2] * n_iou_preds + seg_logits = self.post_process(cls_logits.clone().detach(), pred_masks.clone().detach()) + seg_logits = seg_logits > self.threshold + self.train_evaluator.update(seg_logits, seg_label) + + batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data( + batch['data_samples']) + + losses = self.head.loss(cls_logits, pred_masks, batch_gt_instances, batch_img_metas) + + parsed_losses, log_vars = self.parse_losses(losses) + log_vars = {f'train_{k}': v for k, v in log_vars.items()} + log_vars['loss'] = parsed_losses + self.log_dict(log_vars, prog_bar=True) + return log_vars + + def on_before_optimizer_step(self, optimizer) -> None: + self.log_grad(module=self.sam_prompt_generator) + + def post_process(self, mask_cls_results, mask_pred_results): + cls_score = F.softmax(mask_cls_results, dim=-1)[..., 1:2] + mask_pred = mask_pred_results.sigmoid() + seg_logits = torch.einsum('bqc, bqhw->bchw', cls_score, mask_pred) + return seg_logits + + def forward_only_img_encoder(self, batch, *args: Any, **kwargs: Any) -> Any: + if self.with_clip: + clip_dense_embs = torch.stack([x.clip_dense_embs for x in batch['data_samples']], dim=0) + logits_per_images = 
torch.stack([x.logits_per_image for x in batch['data_samples']], dim=0) + logits_per_images = self.logits_prompt(logits_per_images) # Bx576x16 + clip_dense_embs = torch.cat([clip_dense_embs, logits_per_images], dim=-1) + clip_dense_embs = rearrange(clip_dense_embs, 'b (h w) c -> b c h w', h=int(clip_dense_embs.shape[1]**0.5)) + masks_pred = self.global_prompt(clip_dense_embs) + else: + image_embeddings = torch.stack([x.image_embeddings for x in batch['data_samples']], dim=0) + masks_pred = self.global_prompt(image_embeddings) + return masks_pred + + def forward_sam_prompt_generator(self, batch, *args: Any, **kwargs: Any) -> Any: + inner_states = [x.inner_states for x in batch['data_samples']] + image_embeddings = torch.stack([x.image_embeddings for x in batch['data_samples']], dim=0) + + inner_states_tmp = [] + for idx in range(len(inner_states[0])): + inner_states_tmp.append(torch.stack([x[idx] for x in inner_states], dim=0).to(image_embeddings.device)) + + point_embs, cls_logits = self.sam_prompt_generator(inner_states_tmp) + + # if has points prompt, then get points embeddings + if hasattr(self, 'point_grids'): + points_scale = np.array(img.shape[-2:], dtype=np.float32).reshape(1, -1) # 2, + points_for_image = self.point_grids[0] * points_scale + in_points = torch.as_tensor(points_for_image, device=img.device) + in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) + in_points = rearrange(in_points, 'n c -> n () c') + in_labels = rearrange(in_labels, 'n -> n ()') + points = (in_points, in_labels) + + sparse_embeddings, dense_embeddings = self.sam.prompt_encoder( + points=points, + boxes=None, + masks=None, + ) # 1024x2x256; 1024x256x64x64 + else: + # ponits_embeddings B T N C + sparse_embeddings = point_embs + dense_embeddings = self.prompt_encoder.no_mask_embed.weight.view(1, 1, -1, 1, 1).expand( + sparse_embeddings.shape[0], sparse_embeddings.shape[1], -1, + self.prompt_encoder.image_embedding_size[0], self.prompt_encoder.image_embedding_size[1] + ) + + + n_img_masks = [] + n_iou_preds = [] + n_class_aware_probs = [] + for curr_img_embedding, cur_s_emb, cur_d_emb in zip(image_embeddings, sparse_embeddings, dense_embeddings): + lr_masks, iou_pred, class_aware_prob = self.mask_decoder( + image_embeddings=curr_img_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=cur_s_emb, + dense_prompt_embeddings=cur_d_emb + ) + mask_slice = slice(0, 1) + masks = lr_masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + class_aware_prob = class_aware_prob[:, mask_slice] + + n_img_masks.append(masks) + n_iou_preds.append(iou_pred) + n_img_masks = torch.stack(n_img_masks, dim=0) + n_iou_preds = torch.stack(n_iou_preds, dim=0) + + return cls_logits, n_img_masks, n_iou_preds + + def forward_sam_prompt_generator_all(self, batch, *args: Any, **kwargs: Any) -> Any: + x = torch.stack(batch['inputs'], dim=0) + # if self.local_rank == 0: + # import pdb; pdb.set_trace() + # self.trainer.strategy.barrier() + x = x[:, [2, 1, 0], :, :] # BGR -> RGB + x = (x - self.img_encoder.pixel_mean) / self.img_encoder.pixel_std + with torch.no_grad(): + image_embeddings, inner_states = self.img_encoder(x) + + point_embs, cls_logits = self.sam_prompt_generator(inner_states) + + # if has points prompt, then get points embeddings + if hasattr(self, 'point_grids'): + points_scale = np.array(img.shape[-2:], dtype=np.float32).reshape(1, -1) # 2, + points_for_image = self.point_grids[0] * points_scale + in_points = 
torch.as_tensor(points_for_image, device=img.device) + in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) + in_points = rearrange(in_points, 'n c -> n () c') + in_labels = rearrange(in_labels, 'n -> n ()') + points = (in_points, in_labels) + + sparse_embeddings, dense_embeddings = self.sam.prompt_encoder( + points=points, + boxes=None, + masks=None, + ) # 1024x2x256; 1024x256x64x64 + else: + # ponits_embeddings B T N C + sparse_embeddings = point_embs + dense_embeddings = self.prompt_encoder_no_mask_embed(torch.tensor([0], device=self.device)).view(1, 1, -1, 1, 1).expand( + sparse_embeddings.shape[0], sparse_embeddings.shape[1], -1, + image_embeddings.shape[-2], image_embeddings.shape[-1] + ) + + + n_img_masks = [] + n_iou_preds = [] + n_class_aware_probs = [] + for curr_img_embedding, cur_s_emb, cur_d_emb in zip(image_embeddings, sparse_embeddings, dense_embeddings): + lr_masks, iou_pred, class_aware_prob = self.mask_decoder( + image_embeddings=curr_img_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=cur_s_emb, + dense_prompt_embeddings=cur_d_emb + ) + if self.train_head: + masks = lr_masks + iou_pred = iou_pred + else: + mask_slice = slice(0, 1) + masks = lr_masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + n_img_masks.append(masks) + n_iou_preds.append(iou_pred) + n_img_masks = torch.stack(n_img_masks, dim=0) + n_iou_preds = torch.stack(n_iou_preds, dim=0) + + return cls_logits, n_img_masks, n_iou_preds + + def vis_inter_states(self, batch, masks, *args: Any, **kwargs: Any): + folder = 'results/tmp' + import cv2 + cv2.imwrite(os.path.join(folder, f'img.png'), batch['inputs'][0].permute((1, 2, 0)).detach().cpu().numpy()) + cv2.imwrite(os.path.join(folder, f'label_mask.png'), seg_label[0][0].detach().cpu().numpy() * 255) + masks = masks > 0 + for idx, mask_pred in enumerate(masks[0]): + cv2.imwrite(os.path.join(folder, f'pred_mask_{idx}.png'), mask_pred[0].detach().cpu().numpy() * 255) + import ipdb; ipdb.set_trace() + + + + + + diff --git a/mmpl/models/pler/seg_sam_anchor_pler.py b/mmpl/models/pler/seg_sam_anchor_pler.py new file mode 100644 index 0000000000000000000000000000000000000000..9520c79629a6eb9a926ac7b31de3638d9aaa5e8b --- /dev/null +++ b/mmpl/models/pler/seg_sam_anchor_pler.py @@ -0,0 +1,104 @@ +import torch +from mmengine.structures import InstanceData +from typing import List, Any + +from mmpl.registry import MODELS +from mmseg.utils import SampleList +from .base_pler import BasePLer +import torch.nn.functional as F +from modules.sam import sam_model_registry + + +@MODELS.register_module() +class SegSAMAnchorPLer(BasePLer): + def __init__(self, + backbone, + neck=None, + panoptic_head=None, + need_train_names=None, + train_cfg=None, + test_cfg=None, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.save_hyperparameters() + self.need_train_names = need_train_names + + backbone_type = backbone.pop('type') + self.backbone = sam_model_registry[backbone_type](**backbone) + + if neck is not None: + self.neck = MODELS.build(neck) + + self.panoptic_head = MODELS.build(panoptic_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + def setup(self, stage: str) -> None: + super().setup(stage) + if self.need_train_names is not None: + self._set_grad(self.need_train_names, noneed_train_names=[]) + + def init_weights(self): + import ipdb; ipdb.set_trace() + pass + + def train(self, mode=True): + if self.need_train_names is not None: + return 
self._set_train_module(mode, self.need_train_names) + else: + super().train(mode) + return self + + @torch.no_grad() + def extract_feat(self, batch_inputs): + feat, inter_features = self.backbone.image_encoder(batch_inputs) + return feat, inter_features + + def validation_step(self, batch, batch_idx): + data = self.data_preprocessor(batch, False) + batch_inputs = data['inputs'] + batch_data_samples = data['data_samples'] + + x = self.extract_feat(batch_inputs) + # x = ( + # torch.rand(2, 256, 64, 64).to(self.device), [torch.rand(2, 64, 64, 768).to(self.device) for _ in range(12)]) + results = self.panoptic_head.predict( + x, batch_data_samples, self.backbone) + self.val_evaluator.update(batch, results) + + def training_step(self, batch, batch_idx): + data = self.data_preprocessor(batch, True) + batch_inputs = data['inputs'] + batch_data_samples = data['data_samples'] + x = self.extract_feat(batch_inputs) + # x = (torch.rand(2, 256, 64, 64).to(self.device), [torch.rand(2, 64, 64, 768).to(self.device) for _ in range(12)]) + losses = self.panoptic_head.loss(x, batch_data_samples, self.backbone) + + parsed_losses, log_vars = self.parse_losses(losses) + log_vars = {f'train_{k}': v for k, v in log_vars.items()} + log_vars['loss'] = parsed_losses + self.log_dict(log_vars, prog_bar=True) + return log_vars + + def on_before_optimizer_step(self, optimizer) -> None: + self.log_grad(module=self.panoptic_head) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + data = self.data_preprocessor(batch, False) + batch_inputs = data['inputs'] + batch_data_samples = data['data_samples'] + + x = self.extract_feat(batch_inputs) + # x = ( + # torch.rand(2, 256, 64, 64).to(self.device), [torch.rand(2, 64, 64, 768).to(self.device) for _ in range(12)]) + results = self.panoptic_head.predict( + x, batch_data_samples, self.backbone) + return results + + + + + + + diff --git a/mmpl/models/pler/seg_sam_pler.py b/mmpl/models/pler/seg_sam_pler.py new file mode 100644 index 0000000000000000000000000000000000000000..c2207a6eebff2a1b36af7c8206f9f9f58e494143 --- /dev/null +++ b/mmpl/models/pler/seg_sam_pler.py @@ -0,0 +1,201 @@ +import torch +from mmengine.structures import InstanceData +from typing import List, Any + +from mmpl.registry import MODELS +from mmseg.utils import SampleList +from .base_pler import BasePLer +import torch.nn.functional as F +from modules.sam import sam_model_registry + + +@MODELS.register_module() +class SegSAMPLer(BasePLer): + def __init__(self, + backbone, + sam_neck=None, + panoptic_head=None, + panoptic_fusion_head=None, + need_train_names=None, + train_cfg=None, + test_cfg=None, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.save_hyperparameters() + self.need_train_names = need_train_names + + backbone_type = backbone.pop('type') + self.backbone = sam_model_registry[backbone_type](**backbone) + + if sam_neck is not None: + self.sam_neck = MODELS.build(sam_neck) + + panoptic_head_ = panoptic_head.deepcopy() + panoptic_head_.update(train_cfg=train_cfg) + panoptic_head_.update(test_cfg=test_cfg) + self.panoptic_head = MODELS.build(panoptic_head_) + + panoptic_fusion_head_ = panoptic_fusion_head.deepcopy() + panoptic_fusion_head_.update(test_cfg=test_cfg) + self.panoptic_fusion_head = MODELS.build(panoptic_fusion_head_) + + self.num_things_classes = self.panoptic_head.num_things_classes + self.num_stuff_classes = self.panoptic_head.num_stuff_classes + self.num_classes = self.panoptic_head.num_classes + + self.train_cfg = train_cfg + 
self.test_cfg = test_cfg + + def setup(self, stage: str) -> None: + super().setup(stage) + if self.need_train_names is not None: + self._set_grad(self.need_train_names, noneed_train_names=[]) + + def init_weights(self): + import ipdb; ipdb.set_trace() + pass + + def train(self, mode=True): + if self.need_train_names is not None: + return self._set_train_module(mode, self.need_train_names) + else: + super().train(mode) + return self + + @torch.no_grad() + def extract_feat(self, batch_inputs): + feat, inter_features = self.backbone.image_encoder(batch_inputs) + return feat, inter_features + + def validation_step(self, batch, batch_idx): + data = self.data_preprocessor(batch, False) + batch_inputs = data['inputs'] + batch_data_samples = data['data_samples'] + + feats = self.extract_feat(batch_inputs) + if hasattr(self, 'sam_neck'): + feats = self.sam_neck(feats) + mask_cls_results, mask_pred_results = self.panoptic_head.predict( + feats, batch_data_samples) + else: + mask_cls_results, mask_pred_results = self.panoptic_head.predict( + feats, batch_data_samples, self.backbone) + + results_list = self.panoptic_fusion_head.predict( + mask_cls_results, + mask_pred_results, + batch_data_samples, + rescale=True) + results = self.add_pred_to_datasample(batch_data_samples, results_list) + + # preds = [] + # targets = [] + # for data_sample in results: + # result = dict() + # pred = data_sample.pred_instances + # result['boxes'] = pred['bboxes'] + # result['scores'] = pred['scores'] + # result['labels'] = pred['labels'] + # if 'masks' in pred: + # result['masks'] = pred['masks'] + # preds.append(result) + # # parse gt + # gt = dict() + # gt_data = data_sample.get('gt_instances', None) + # gt['boxes'] = gt_data['bboxes'] + # gt['labels'] = gt_data['labels'] + # if 'masks' in pred: + # gt['masks'] = gt_data['masks'].to_tensor(dtype=torch.bool, device=result['masks'].device) + # targets.append(gt) + # + # self.val_evaluator.update(preds, targets) + self.val_evaluator.update(batch, results) + + def training_step(self, batch, batch_idx): + data = self.data_preprocessor(batch, True) + batch_inputs = data['inputs'] + batch_data_samples = data['data_samples'] + x = self.extract_feat(batch_inputs) + if hasattr(self, 'sam_neck'): + x = self.sam_neck(x) + losses = self.panoptic_head.loss(x, batch_data_samples) + else: + losses = self.panoptic_head.loss(x, batch_data_samples, self.backbone) + parsed_losses, log_vars = self.parse_losses(losses) + log_vars = {f'train_{k}': v for k, v in log_vars.items()} + log_vars['loss'] = parsed_losses + self.log_dict(log_vars, prog_bar=True) + return log_vars + + def on_before_optimizer_step(self, optimizer) -> None: + self.log_grad(module=self.panoptic_head) + + + def add_pred_to_datasample(self, data_samples: SampleList, + results_list: List[dict]) -> SampleList: + """Add predictions to `DetDataSample`. + + Args: + data_samples (list[:obj:`DetDataSample`], optional): A batch of + data samples that contain annotations and predictions. + results_list (List[dict]): Instance segmentation, segmantic + segmentation and panoptic segmentation results. + + Returns: + list[:obj:`DetDataSample`]: Detection results of the + input images. Each DetDataSample usually contain + 'pred_instances' and `pred_panoptic_seg`. And the + ``pred_instances`` usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). 
+ - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + + And the ``pred_panoptic_seg`` contains the following key + + - sem_seg (Tensor): panoptic segmentation mask, has a + shape (1, h, w). + """ + for data_sample, pred_results in zip(data_samples, results_list): + if 'pan_results' in pred_results: + data_sample.pred_panoptic_seg = pred_results['pan_results'] + + if 'ins_results' in pred_results: + data_sample.pred_instances = pred_results['ins_results'] + + assert 'sem_results' not in pred_results, 'segmantic ' \ + 'segmentation results are not supported yet.' + + return data_samples + + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + data = self.data_preprocessor(batch, False) + batch_inputs = data['inputs'] + batch_data_samples = data['data_samples'] + # import ipdb; ipdb.set_trace() + feats = self.extract_feat(batch_inputs) + if hasattr(self, 'sam_neck'): + feats = self.sam_neck(feats) + mask_cls_results, mask_pred_results = self.panoptic_head.predict( + feats, batch_data_samples) + else: + mask_cls_results, mask_pred_results = self.panoptic_head.predict( + feats, batch_data_samples, self.backbone) + + results_list = self.panoptic_fusion_head.predict( + mask_cls_results, + mask_pred_results, + batch_data_samples, + rescale=True) + results = self.add_pred_to_datasample(batch_data_samples, results_list) + return results + + + + + diff --git a/mmpl/models/pler/seg_samdet.py b/mmpl/models/pler/seg_samdet.py new file mode 100644 index 0000000000000000000000000000000000000000..070b2a4a039c4cdc20501dc83d0807d071512f00 --- /dev/null +++ b/mmpl/models/pler/seg_samdet.py @@ -0,0 +1,160 @@ +import torch +from mmengine.structures import InstanceData +from typing import List, Any + +from mmpl.registry import MODELS +from mmseg.utils import SampleList +from .base_pler import BasePLer +import torch.nn.functional as F +from modules.sam import sam_model_registry + + +@MODELS.register_module() +class SegSAMDetPLer(BasePLer): + def __init__(self, + whole_model, + backbone, + neck=None, + panoptic_head=None, + need_train_names=None, + train_cfg=None, + test_cfg=None, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.save_hyperparameters() + self.need_train_names = need_train_names + + self.whole_model = MODELS.build(whole_model) + backbone_type = backbone.pop('type') + self.backbone = sam_model_registry[backbone_type](**backbone) + + if neck is not None: + self.neck = MODELS.build(neck) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + def setup(self, stage: str) -> None: + super().setup(stage) + if self.need_train_names is not None: + self._set_grad(self.need_train_names, noneed_train_names=[]) + + def init_weights(self): + import ipdb; ipdb.set_trace() + pass + + def train(self, mode=True): + if self.need_train_names is not None: + return self._set_train_module(mode, self.need_train_names) + else: + super().train(mode) + return self + + def validation_step(self, batch, batch_idx): + data = self.whole_model.data_preprocessor(batch, False) + batch_data_samples = self.whole_model._run_forward(data, mode='predict') # type: ignore + + batch_inputs = data['inputs'] + feat, inter_features = self.backbone.image_encoder(batch_inputs) + # import ipdb; ipdb.set_trace() + for idx, data_sample in enumerate(batch_data_samples): + bboxes = data_sample.pred_instances['bboxes'] + ori_img_shape = data_sample.ori_shape + if len(bboxes) == 0: + im_mask 
= torch.zeros( + 0, + ori_img_shape[0], + ori_img_shape[1], + device=self.device, + dtype=torch.bool) + else: + scale_factor = data_sample.scale_factor + repeat_num = int(bboxes.size(-1) / 2) + scale_factor = bboxes.new_tensor(scale_factor).repeat((1, repeat_num)) + bboxes = bboxes * scale_factor + + # Embed prompts + sparse_embeddings, dense_embeddings = self.backbone.prompt_encoder( + points=None, + boxes=bboxes, + masks=None, + ) + + # Predict masks + low_res_masks, iou_predictions = self.backbone.mask_decoder( + image_embeddings=feat[idx:idx + 1], + image_pe=self.backbone.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=False, + ) + # Upscale the masks to the original image resolution + im_mask = F.interpolate(low_res_masks, ori_img_shape, mode="bilinear", align_corners=False) + im_mask = im_mask > 0 + im_mask = im_mask.squeeze(1) + data_sample.pred_instances.masks = im_mask + + self.val_evaluator.update(batch, batch_data_samples) + + def training_step(self, batch, batch_idx): + data = self.whole_model.data_preprocessor(batch, True) + losses = self.whole_model._run_forward(data, mode='loss') # type: ignore + parsed_losses, log_vars = self.parse_losses(losses) + log_vars = {f'train_{k}': v for k, v in log_vars.items()} + log_vars['loss'] = parsed_losses + self.log_dict(log_vars, prog_bar=True) + return log_vars + + def on_before_optimizer_step(self, optimizer) -> None: + self.log_grad(module=self.whole_model) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + data = self.whole_model.data_preprocessor(batch, False) + batch_data_samples = self.whole_model._run_forward(data, mode='predict') # type: ignore + + batch_inputs = data['inputs'] + feat, inter_features = self.backbone.image_encoder(batch_inputs) + # import ipdb; ipdb.set_trace() + for idx, data_sample in enumerate(batch_data_samples): + bboxes = data_sample.pred_instances['bboxes'] + ori_img_shape = data_sample.ori_shape + if len(bboxes) == 0: + im_mask = torch.zeros( + 0, + ori_img_shape[0], + ori_img_shape[1], + device=self.device, + dtype=torch.bool) + else: + scale_factor = data_sample.scale_factor + repeat_num = int(bboxes.size(-1) / 2) + scale_factor = bboxes.new_tensor(scale_factor).repeat((1, repeat_num)) + bboxes = bboxes * scale_factor + + # Embed prompts + sparse_embeddings, dense_embeddings = self.backbone.prompt_encoder( + points=None, + boxes=bboxes, + masks=None, + ) + + # Predict masks + low_res_masks, iou_predictions = self.backbone.mask_decoder( + image_embeddings=feat[idx:idx + 1], + image_pe=self.backbone.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=False, + ) + # Upscale the masks to the original image resolution + im_mask = F.interpolate(low_res_masks, ori_img_shape, mode="bilinear", align_corners=False) + im_mask = im_mask > 0 + im_mask = im_mask.squeeze(1) + data_sample.pred_instances.masks = im_mask + + return batch_data_samples + + + + + diff --git a/mmpl/models/pler/semseg_sam_pler.py b/mmpl/models/pler/semseg_sam_pler.py new file mode 100644 index 0000000000000000000000000000000000000000..9343ec39a70bde288cf0831197b44766017b15fe --- /dev/null +++ b/mmpl/models/pler/semseg_sam_pler.py @@ -0,0 +1,198 @@ +import torch +from mmengine.structures import InstanceData, PixelData +from typing import List + +from torch import Tensor + +from mmpl.registry import MODELS +from mmseg.models.utils 
import resize +from mmseg.structures import SegDataSample +from mmseg.utils import SampleList, OptSampleList +from .base_pler import BasePLer +import torch.nn.functional as F +from modules.sam import sam_model_registry + + +@MODELS.register_module() +class SemSegSAMPLer(BasePLer): + def __init__(self, + backbone, + adaphead=None, + decode_head=None, + need_train_names=None, + align_corners=False, + train_cfg=None, + test_cfg=None, + *args, **kwargs): + super().__init__(*args, **kwargs) + self.save_hyperparameters() + self.need_train_names = need_train_names + self.align_corners = align_corners + + backbone_type = backbone.pop('type') + delete_submodel = backbone.pop('delete_submodel', []) + self.backbone = sam_model_registry[backbone_type](**backbone) + for submodel in delete_submodel: + delattr(self.backbone, submodel) + + if adaphead is not None: + self.adaphead = MODELS.build(adaphead) + + decode_head_ = decode_head.deepcopy() + decode_head_.update(train_cfg=train_cfg) + decode_head_.update(test_cfg=test_cfg) + self.decode_head = MODELS.build(decode_head_) + + self.num_classes = self.decode_head.num_classes + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + def setup(self, stage: str) -> None: + if self.need_train_names is not None: + self._set_grad(self.need_train_names, noneed_train_names=[]) + + def init_weights(self): + import ipdb; ipdb.set_trace() + pass + + def train(self, mode=True): + if self.need_train_names is not None: + return self._set_train_module(mode, self.need_train_names) + else: + super().train(mode) + return self + + def extract_feat(self, batch_inputs): + x0, x1 = self.adaphead(batch_inputs, self.backbone.image_encoder) + return x0, x1 + + def validation_step(self, batch, batch_idx): + data = self.data_preprocessor(batch, False) + batch_inputs = data['inputs'] + batch_data_samples = data['data_samples'] + + if batch_data_samples is not None: + batch_img_metas = [ + data_sample.metainfo for data_sample in batch_data_samples + ] + else: + batch_img_metas = [ + dict( + ori_shape=batch_inputs.shape[2:], + img_shape=batch_inputs.shape[2:], + pad_shape=batch_inputs.shape[2:], + padding_size=[0, 0, 0, 0]) + ] * batch_inputs.shape[0] + + x = self.extract_feat(batch_inputs) + seg_logits = self.decode_head.predict(x, batch_img_metas, self.test_cfg) + + results = self.postprocess_result(seg_logits, batch_data_samples) + + preds = [] + targets = [] + for data_sample in results: + pred_label = data_sample.pred_sem_seg.data.squeeze() + label = data_sample.gt_sem_seg.data.squeeze().to(pred_label) + + preds.append(pred_label) + targets.append(label) + preds = torch.stack(preds, dim=0) + targets = torch.stack(targets, dim=0) + self.val_evaluator.update(preds, targets) + + def training_step(self, batch, batch_idx): + # import ipdb; ipdb.set_trace() + data = self.data_preprocessor(batch, True) + batch_inputs = data['inputs'] + batch_data_samples = data['data_samples'] + x = self.extract_feat(batch_inputs) + losses = self.decode_head.loss(x, batch_data_samples) + # import ipdb; ipdb.set_trace() + parsed_losses, log_vars = self.parse_losses(losses) + log_vars = {f'train_{k}': v for k, v in log_vars.items()} + log_vars['loss'] = parsed_losses + self.log_dict(log_vars, prog_bar=True) + return log_vars + + def on_before_optimizer_step(self, optimizer) -> None: + self.log_grad(module=self.adaphead) + + def postprocess_result(self, + seg_logits: Tensor, + data_samples: OptSampleList = None) -> SampleList: + """ Convert results list to `SegDataSample`. 
+ Args: + seg_logits (Tensor): The segmentation results, seg_logits from + model of each input image. + data_samples (list[:obj:`SegDataSample`]): The seg data samples. + It usually includes information such as `metainfo` and + `gt_sem_seg`. Default to None. + Returns: + list[:obj:`SegDataSample`]: Segmentation results of the + input images. Each SegDataSample usually contain: + + - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation. + - ``seg_logits``(PixelData): Predicted logits of semantic + segmentation before normalization. + """ + batch_size, C, H, W = seg_logits.shape + + if data_samples is None: + data_samples = [SegDataSample() for _ in range(batch_size)] + only_prediction = True + else: + only_prediction = False + + for i in range(batch_size): + if not only_prediction: + img_meta = data_samples[i].metainfo + # remove padding area + if 'img_padding_size' not in img_meta: + padding_size = img_meta.get('padding_size', [0] * 4) + else: + padding_size = img_meta['img_padding_size'] + padding_left, padding_right, padding_top, padding_bottom =\ + padding_size + # i_seg_logits shape is 1, C, H, W after remove padding + i_seg_logits = seg_logits[i:i + 1, :, + padding_top:H - padding_bottom, + padding_left:W - padding_right] + + flip = img_meta.get('flip', None) + if flip: + flip_direction = img_meta.get('flip_direction', None) + assert flip_direction in ['horizontal', 'vertical'] + if flip_direction == 'horizontal': + i_seg_logits = i_seg_logits.flip(dims=(3, )) + else: + i_seg_logits = i_seg_logits.flip(dims=(2, )) + + # resize as original shape + i_seg_logits = resize( + i_seg_logits, + size=img_meta['ori_shape'], + mode='bilinear', + align_corners=self.align_corners, + warning=False).squeeze(0) + else: + i_seg_logits = seg_logits[i] + + if C > 1: + i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True) + else: + i_seg_pred = (i_seg_logits > + self.decode_head.threshold).to(i_seg_logits) + data_samples[i].set_data({ + 'seg_logits': + PixelData(**{'data': i_seg_logits}), + 'pred_sem_seg': + PixelData(**{'data': i_seg_pred}) + }) + + return data_samples + + + + diff --git a/mmpl/models/utils/__init__.py b/mmpl/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cdfeaaf0f206fd62dda27cbf44f519777da56ea8 --- /dev/null +++ b/mmpl/models/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .misc import gt_instances_preprocess, make_divisible, make_round + +__all__ = ['make_divisible', 'make_round', 'gt_instances_preprocess'] diff --git a/mmpl/models/utils/misc.py b/mmpl/models/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..8d1780bb20921f19c86ea01f93807e6690f59ad3 --- /dev/null +++ b/mmpl/models/utils/misc.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import math +from typing import Sequence, Union + +import torch +from mmdet.structures.bbox.transforms import get_box_tensor +from torch import Tensor + + +def make_divisible(x: float, + widen_factor: float = 1.0, + divisor: int = 8) -> int: + """Make sure that x*widen_factor is divisible by divisor.""" + return math.ceil(x * widen_factor / divisor) * divisor + + +def make_round(x: float, deepen_factor: float = 1.0) -> int: + """Make sure that x*deepen_factor becomes an integer not less than 1.""" + return max(round(x * deepen_factor), 1) if x > 1 else x + + +def gt_instances_preprocess(batch_gt_instances: Union[Tensor, Sequence], + batch_size: int) -> Tensor: + """Split batch_gt_instances with batch size. + + From [all_gt_bboxes, box_dim+2] to [batch_size, number_gt, box_dim+1]. + For horizontal box, box_dim=4, for rotated box, box_dim=5 + + If some shape of single batch smaller than + gt bbox len, then using zeros to fill. + + Args: + batch_gt_instances (Sequence[Tensor]): Ground truth + instances for whole batch, shape [all_gt_bboxes, box_dim+2] + batch_size (int): Batch size. + + Returns: + Tensor: batch gt instances data, shape + [batch_size, number_gt, box_dim+1] + """ + if isinstance(batch_gt_instances, Sequence): + max_gt_bbox_len = max( + [len(gt_instances) for gt_instances in batch_gt_instances]) + # fill zeros with length box_dim+1 if some shape of + # single batch not equal max_gt_bbox_len + batch_instance_list = [] + for index, gt_instance in enumerate(batch_gt_instances): + bboxes = gt_instance.bboxes + labels = gt_instance.labels + box_dim = get_box_tensor(bboxes).size(-1) + batch_instance_list.append( + torch.cat((labels[:, None], bboxes), dim=-1)) + + if bboxes.shape[0] >= max_gt_bbox_len: + continue + + fill_tensor = bboxes.new_full( + [max_gt_bbox_len - bboxes.shape[0], box_dim + 1], 0) + batch_instance_list[index] = torch.cat( + (batch_instance_list[index], fill_tensor), dim=0) + + return torch.stack(batch_instance_list) + else: + # faster version + # format of batch_gt_instances: [img_ind, cls_ind, (box)] + # For example horizontal box should be: + # [img_ind, cls_ind, x1, y1, x2, y2] + # Rotated box should be + # [img_ind, cls_ind, x, y, w, h, a] + + # sqlit batch gt instance [all_gt_bboxes, box_dim+2] -> + # [batch_size, max_gt_bbox_len, box_dim+1] + assert isinstance(batch_gt_instances, Tensor) + box_dim = batch_gt_instances.size(-1) - 2 + if len(batch_gt_instances) > 0: + gt_images_indexes = batch_gt_instances[:, 0] + # 注意 + max_gt_bbox_len = torch.unique(gt_images_indexes.to('cpu'), return_counts=True)[1].max().to(gt_images_indexes.device) + # fill zeros with length box_dim+1 if some shape of + # single batch not equal max_gt_bbox_len + batch_instance = torch.zeros( + (batch_size, max_gt_bbox_len, box_dim + 1), + dtype=batch_gt_instances.dtype, + device=batch_gt_instances.device) + + for i in range(batch_size): + match_indexes = gt_images_indexes == i + gt_num = match_indexes.sum() + if gt_num: + batch_instance[i, :gt_num] = batch_gt_instances[ + match_indexes, 1:] + else: + batch_instance = torch.zeros((batch_size, 0, box_dim + 1), + dtype=batch_gt_instances.dtype, + device=batch_gt_instances.device) + + return batch_instance diff --git a/mmpl/registry.py b/mmpl/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..1692f6ade7e6c7740ef0f2677d74369317a6021f --- /dev/null +++ b/mmpl/registry.py @@ -0,0 +1,97 @@ +from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS +from mmengine.registry import DATASETS as 
MMENGINE_DATASETS +from mmengine.registry import HOOKS as MMENGINE_HOOKS +from mmengine.registry import LOOPS as MMENGINE_LOOPS +from mmengine.registry import METRICS as MMENGINE_METRICS +from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS +from mmengine.registry import MODELS as MMENGINE_MODELS +from mmengine.registry import \ + OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS +from mmengine.registry import OPTIM_WRAPPERS as MMENGINE_OPTIM_WRAPPERS +from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS +from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS +from mmengine.registry import \ + RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS +from mmengine.registry import RUNNERS as MMENGINE_RUNNERS +from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS +from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS +from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS +from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS +from mmengine.registry import \ + WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS +from mmengine.registry import Registry + +LOGGERS = Registry('logger', locations=['mmpl.engine.logger']) +# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner` +RUNNERS = Registry( + 'runner', parent=MMENGINE_RUNNERS, locations=['mmpl.engine']) +# manage runner constructors that define how to initialize runners +RUNNER_CONSTRUCTORS = Registry( + 'runner constructor', + parent=MMENGINE_RUNNER_CONSTRUCTORS, + locations=['mmpl.engine']) +# manage all kinds of loops like `EpochBasedTrainLoop` +LOOPS = Registry('loop', parent=MMENGINE_LOOPS, locations=['mmpl.engine']) +# manage all kinds of hooks like `CheckpointHook` +HOOKS = Registry( + 'hook', parent=MMENGINE_HOOKS, locations=['mmpl.engine.hooks']) + +# manage data-related modules +DATASETS = Registry( + 'dataset', parent=MMENGINE_DATASETS, locations=['mmpl.datasets']) + +DATA_SAMPLERS = Registry( + 'data sampler', + parent=MMENGINE_DATA_SAMPLERS, + locations=['mmpl.datasets']) +TRANSFORMS = Registry( + 'transform', + parent=MMENGINE_TRANSFORMS, + locations=['mmpl.datasets.transforms']) + +# manage all kinds of modules inheriting `nn.Module` +MODELS = Registry('model', parent=MMENGINE_MODELS, locations=['mmpl.models']) +# manage all kinds of model wrappers like 'MMDistributedDataParallel' +MODEL_WRAPPERS = Registry( + 'model_wrapper', + parent=MMENGINE_MODEL_WRAPPERS, + locations=['mmpl.models']) +# manage all kinds of weight initialization modules like `Uniform` +WEIGHT_INITIALIZERS = Registry( + 'weight initializer', + parent=MMENGINE_WEIGHT_INITIALIZERS, + locations=['mmpl.models']) + +# manage all kinds of optimizers like `SGD` and `Adam` +OPTIMIZERS = Registry( + 'optimizer', + parent=MMENGINE_OPTIMIZERS, + locations=['mmpl.engine.optimizers']) +OPTIM_WRAPPERS = Registry( + 'optim_wrapper', + parent=MMENGINE_OPTIM_WRAPPERS, + locations=['mmpl.engine.optimizers']) +# manage constructors that customize the optimization hyperparameters. 
+OPTIM_WRAPPER_CONSTRUCTORS = Registry( + 'optimizer constructor', + parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS, + locations=['mmpl.engine.optimizers']) +# manage all kinds of parameter schedulers like `MultiStepLR` +PARAM_SCHEDULERS = Registry( + 'parameter scheduler', + parent=MMENGINE_PARAM_SCHEDULERS, + locations=['mmpl.engine.optimizers']) +# manage all kinds of metrics +METRICS = Registry( + 'metric', parent=MMENGINE_METRICS, locations=['mmpl.engine']) + +# manage task-specific modules like anchor generators and box coders +TASK_UTILS = Registry( + 'task util', parent=MMENGINE_TASK_UTILS, locations=['mmpl.models']) + +# manage visualizer +VISUALIZERS = Registry( + 'visualizer', parent=MMENGINE_VISUALIZERS, locations=['mmpl.visualization']) +# manage visualizer backend +VISBACKENDS = Registry( + 'vis_backend', parent=MMENGINE_VISBACKENDS, locations=['mmpl.utils']) diff --git a/mmpl/structures/__init__.py b/mmpl/structures/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3021d0a7d0b7fb1b342295ad0a4e99c675b4e52c --- /dev/null +++ b/mmpl/structures/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .cls_data_sample import ClsDataSample +from .multi_task_data_sample import MultiTaskDataSample +from .utils import (batch_label_to_onehot, cat_batch_labels, + stack_batch_scores, tensor_split) + +__all__ = [ + 'ClsDataSample', 'batch_label_to_onehot', 'cat_batch_labels', + 'stack_batch_scores', 'tensor_split', 'MultiTaskDataSample' +] diff --git a/mmpl/structures/__pycache__/__init__.cpython-310.pyc b/mmpl/structures/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c30a93328eeaa5b6f843521ebde9eed4fcba851e Binary files /dev/null and b/mmpl/structures/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmpl/structures/__pycache__/cls_data_sample.cpython-310.pyc b/mmpl/structures/__pycache__/cls_data_sample.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4747e5bb0d22cc806ec37de0f755d71491b3b26 Binary files /dev/null and b/mmpl/structures/__pycache__/cls_data_sample.cpython-310.pyc differ diff --git a/mmpl/structures/__pycache__/multi_task_data_sample.cpython-310.pyc b/mmpl/structures/__pycache__/multi_task_data_sample.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8fd460b8e347f7cee082c721bc868e9b37406c15 Binary files /dev/null and b/mmpl/structures/__pycache__/multi_task_data_sample.cpython-310.pyc differ diff --git a/mmpl/structures/__pycache__/utils.cpython-310.pyc b/mmpl/structures/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecabea3ffc1d5972fcf6a3fb7439e461f5716060 Binary files /dev/null and b/mmpl/structures/__pycache__/utils.cpython-310.pyc differ diff --git a/mmpl/structures/cls_data_sample.py b/mmpl/structures/cls_data_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..9e319a7bb830ccc2d39e76c94047b14e229278a7 --- /dev/null +++ b/mmpl/structures/cls_data_sample.py @@ -0,0 +1,235 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+ +from multiprocessing.reduction import ForkingPickler +from numbers import Number +from typing import Sequence, Union + +import numpy as np +import torch +from mmengine.structures import BaseDataElement, LabelData +from mmengine.utils import is_str + + +def format_label( + value: Union[torch.Tensor, np.ndarray, Sequence, int]) -> torch.Tensor: + """Convert various python types to label-format tensor. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int`. + + Args: + value (torch.Tensor | numpy.ndarray | Sequence | int): Label value. + + Returns: + :obj:`torch.Tensor`: The foramtted label tensor. + """ + + # Handle single number + if isinstance(value, (torch.Tensor, np.ndarray)) and value.ndim == 0: + value = int(value.item()) + + if isinstance(value, np.ndarray): + value = torch.from_numpy(value).to(torch.long) + elif isinstance(value, Sequence) and not is_str(value): + value = torch.tensor(value).to(torch.long) + elif isinstance(value, int): + value = torch.LongTensor([value]) + elif not isinstance(value, torch.Tensor): + raise TypeError(f'Type {type(value)} is not an available label type.') + assert value.ndim == 1, \ + f'The dims of value should be 1, but got {value.ndim}.' + + return value + + +def format_score( + value: Union[torch.Tensor, np.ndarray, Sequence, int]) -> torch.Tensor: + """Convert various python types to score-format tensor. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`. + + Args: + value (torch.Tensor | numpy.ndarray | Sequence): Score values. + + Returns: + :obj:`torch.Tensor`: The foramtted score tensor. + """ + + if isinstance(value, np.ndarray): + value = torch.from_numpy(value).float() + elif isinstance(value, Sequence) and not is_str(value): + value = torch.tensor(value).float() + elif not isinstance(value, torch.Tensor): + raise TypeError(f'Type {type(value)} is not an available label type.') + assert value.ndim == 1, \ + f'The dims of value should be 1, but got {value.ndim}.' + + return value + + +class ClsDataSample(BaseDataElement): + """A data structure interface of classification task. + + It's used as interfaces between different components. + + Meta fields: + img_shape (Tuple): The shape of the corresponding input image. + Used for visualization. + ori_shape (Tuple): The original shape of the corresponding image. + Used for visualization. + num_classes (int): The number of all categories. + Used for label format conversion. + + Data fields: + gt_label (:obj:`~mmengine.structures.LabelData`): The ground truth + label. + pred_label (:obj:`~mmengine.structures.LabelData`): The predicted + label. + scores (torch.Tensor): The outputs of model. + logits (torch.Tensor): The outputs of model without softmax nor + sigmoid. 
+ + Examples: + >>> import torch + >>> from mmcls.structures import ClsDataSample + >>> + >>> img_meta = dict(img_shape=(960, 720), num_classes=5) + >>> data_sample = ClsDataSample(metainfo=img_meta) + >>> data_sample.set_gt_label(3) + >>> print(data_sample) + + ) at 0x7f21fb1b9880> + >>> # For multi-label data + >>> data_sample.set_gt_label([0, 1, 4]) + >>> print(data_sample.gt_label) + + >>> # Set one-hot format score + >>> score = torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1]) + >>> data_sample.set_pred_score(score) + >>> print(data_sample.pred_label) + + """ + + def set_gt_label( + self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number] + ) -> 'ClsDataSample': + """Set label of ``gt_label``.""" + label_data = getattr(self, '_gt_label', LabelData()) + label_data.label = format_label(value) + self.gt_label = label_data + return self + + def set_gt_score(self, value: torch.Tensor) -> 'ClsDataSample': + """Set score of ``gt_label``.""" + label_data = getattr(self, '_gt_label', LabelData()) + label_data.score = format_score(value) + if hasattr(self, 'num_classes'): + assert len(label_data.score) == self.num_classes, \ + f'The length of score {len(label_data.score)} should be '\ + f'equal to the num_classes {self.num_classes}.' + else: + self.set_field( + name='num_classes', + value=len(label_data.score), + field_type='metainfo') + self.gt_label = label_data + return self + + def set_pred_label( + self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number] + ) -> 'ClsDataSample': + """Set label of ``pred_label``.""" + label_data = getattr(self, '_pred_label', LabelData()) + label_data.label = format_label(value) + self.pred_label = label_data + return self + + def set_pred_score(self, value: torch.Tensor) -> 'ClsDataSample': + """Set score of ``pred_label``.""" + label_data = getattr(self, '_pred_label', LabelData()) + label_data.score = format_score(value) + if hasattr(self, 'num_classes'): + assert len(label_data.score) == self.num_classes, \ + f'The length of score {len(label_data.score)} should be '\ + f'equal to the num_classes {self.num_classes}.' + else: + self.set_field( + name='num_classes', + value=len(label_data.score), + field_type='metainfo') + self.pred_label = label_data + return self + + @property + def gt_label(self): + return self._gt_label + + @gt_label.setter + def gt_label(self, value: LabelData): + self.set_field(value, '_gt_label', dtype=LabelData) + + @gt_label.deleter + def gt_label(self): + del self._gt_label + + @property + def pred_label(self): + return self._pred_label + + @pred_label.setter + def pred_label(self, value: LabelData): + self.set_field(value, '_pred_label', dtype=LabelData) + + @pred_label.deleter + def pred_label(self): + del self._pred_label + + +def _reduce_cls_datasample(data_sample): + """reduce ClsDataSample.""" + attr_dict = data_sample.__dict__ + convert_keys = [] + for k, v in attr_dict.items(): + if isinstance(v, LabelData): + attr_dict[k] = v.numpy() + convert_keys.append(k) + return _rebuild_cls_datasample, (attr_dict, convert_keys) + + +def _rebuild_cls_datasample(attr_dict, convert_keys): + """rebuild ClsDataSample.""" + data_sample = ClsDataSample() + for k in convert_keys: + attr_dict[k] = attr_dict[k].to_tensor() + data_sample.__dict__ = attr_dict + return data_sample + + +# Due to the multi-processing strategy of PyTorch, ClsDataSample may consume +# many file descriptors because it contains multiple LabelData with tensors. 
+# Here we overwrite the reduce function of ClsDataSample in ForkingPickler and +# convert these tensors to np.ndarray during pickling. It may influence the +# performance of dataloader, but slightly because these tensors in LabelData +# are very small. +ForkingPickler.register(ClsDataSample, _reduce_cls_datasample) diff --git a/mmpl/structures/multi_task_data_sample.py b/mmpl/structures/multi_task_data_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..f00993861bfb4f35fb7d145198f81c5e9f0a5993 --- /dev/null +++ b/mmpl/structures/multi_task_data_sample.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from mmengine.structures import BaseDataElement + + +class MultiTaskDataSample(BaseDataElement): + + @property + def tasks(self): + return self._data_fields diff --git a/mmpl/structures/utils.py b/mmpl/structures/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8c8f0f3da643ba3355890c939a1483d19bdd3738 --- /dev/null +++ b/mmpl/structures/utils.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +import torch.nn.functional as F +from mmengine.structures import LabelData + +if hasattr(torch, 'tensor_split'): + tensor_split = torch.tensor_split +else: + # A simple implementation of `tensor_split`. + def tensor_split(input: torch.Tensor, indices: list): + outs = [] + for start, end in zip([0] + indices, indices + [input.size(0)]): + outs.append(input[start:end]) + return outs + + +def cat_batch_labels(elements: List[LabelData], device=None): + """Concat the ``label`` of a batch of :obj:`LabelData` to a tensor. + + Args: + elements (List[LabelData]): A batch of :obj`LabelData`. + device (torch.device, optional): The output device of the batch label. + Defaults to None. + + Returns: + Tuple[torch.Tensor, List[int]]: The first item is the concated label + tensor, and the second item is the split indices of every sample. + """ + item = elements[0] + if 'label' not in item._data_fields: + return None, None + + labels = [] + splits = [0] + for element in elements: + labels.append(element.label) + splits.append(splits[-1] + element.label.size(0)) + batch_label = torch.cat(labels) + if device is not None: + batch_label = batch_label.to(device=device) + return batch_label, splits[1:-1] + + +def batch_label_to_onehot(batch_label, split_indices, num_classes): + """Convert a concated label tensor to onehot format. + + Args: + batch_label (torch.Tensor): A concated label tensor from multiple + samples. + split_indices (List[int]): The split indices of every sample. + num_classes (int): The number of classes. + + Returns: + torch.Tensor: The onehot format label tensor. + + Examples: + >>> import torch + >>> from mmcls.structures import batch_label_to_onehot + >>> # Assume a concated label from 3 samples. + >>> # label 1: [0, 1], label 2: [0, 2, 4], label 3: [3, 1] + >>> batch_label = torch.tensor([0, 1, 0, 2, 4, 3, 1]) + >>> split_indices = [2, 5] + >>> batch_label_to_onehot(batch_label, split_indices, num_classes=5) + tensor([[1, 1, 0, 0, 0], + [1, 0, 1, 0, 1], + [0, 1, 0, 1, 0]]) + """ + sparse_onehot_list = F.one_hot(batch_label, num_classes) + onehot_list = [ + sparse_onehot.sum(0) + for sparse_onehot in tensor_split(sparse_onehot_list, split_indices) + ] + return torch.stack(onehot_list) + + +def stack_batch_scores(elements, device=None): + """Stack the ``score`` of a batch of :obj:`LabelData` to a tensor. + + Args: + elements (List[LabelData]): A batch of :obj`LabelData`. 
+ device (torch.device, optional): The output device of the batch label. + Defaults to None. + + Returns: + torch.Tensor: The stacked score tensor. + """ + item = elements[0] + if 'score' not in item._data_fields: + return None + + batch_score = torch.stack([element.score for element in elements]) + if device is not None: + batch_score = batch_score.to(device) + return batch_score diff --git a/mmpl/utils/__init__.py b/mmpl/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..53e90f15c7d80a807ede74e6192fc039da92d83a --- /dev/null +++ b/mmpl/utils/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .collect_env import collect_env +from .misc import is_metainfo_lower, switch_to_deploy +from .setup_env import register_all_modules +from .typing_utils import * + +__all__ = [ + 'register_all_modules', 'collect_env', 'switch_to_deploy', + 'is_metainfo_lower', 'ConfigType', 'OptMultiConfig', 'MultiConfig', +] diff --git a/mmpl/utils/__pycache__/__init__.cpython-310.pyc b/mmpl/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..959e5804045385d61203ba106dc53c2d8ee15cb7 Binary files /dev/null and b/mmpl/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmpl/utils/__pycache__/collect_env.cpython-310.pyc b/mmpl/utils/__pycache__/collect_env.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e10b438fdd1fe0e4713fa3b5a74b8f6044c2bf8a Binary files /dev/null and b/mmpl/utils/__pycache__/collect_env.cpython-310.pyc differ diff --git a/mmpl/utils/__pycache__/misc.cpython-310.pyc b/mmpl/utils/__pycache__/misc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..004bfa567b8072a1b367b4c42c299e027da272d6 Binary files /dev/null and b/mmpl/utils/__pycache__/misc.cpython-310.pyc differ diff --git a/mmpl/utils/__pycache__/setup_env.cpython-310.pyc b/mmpl/utils/__pycache__/setup_env.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8e84b7d0e2960225843c48859166f21bca4629e Binary files /dev/null and b/mmpl/utils/__pycache__/setup_env.cpython-310.pyc differ diff --git a/mmpl/utils/__pycache__/typing_utils.cpython-310.pyc b/mmpl/utils/__pycache__/typing_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35a30b7a04652cf3073d8404dd98a190711e03d5 Binary files /dev/null and b/mmpl/utils/__pycache__/typing_utils.cpython-310.pyc differ diff --git a/mmpl/utils/boxam_utils.py b/mmpl/utils/boxam_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4a46f21c1b5b40e7bc106ae7a15281816ae3efcc --- /dev/null +++ b/mmpl/utils/boxam_utils.py @@ -0,0 +1,512 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
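+# Rough sketch of how the helpers below fit together (the config, checkpoint
+# and target layers are placeholders, not part of this module):
+#     model = BoxAMDetectorWrapper(cfg, 'epoch_300.pth', score_thr=0.3)
+#     model.set_input_data(image)
+#     visualizer = BoxAMDetectorVisualizer(GradCAM, model, target_layers)
+#     grayscale_am = visualizer(image, targets=[DetBoxScoreTarget(pred_instances)])
+#     am_image = visualizer.show_am(image, pred_instances, grayscale_am)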
+import bisect +import copy +import warnings +from pathlib import Path +from typing import Callable, List, Optional, Tuple, Union + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torchvision +from mmcv.transforms import Compose +from mmdet.evaluation import get_classes +from mmdet.utils import ConfigType +from mmengine.config import Config +from mmengine.registry import init_default_scope +from mmengine.runner import load_checkpoint +from mmengine.structures import InstanceData +from torch import Tensor + +from mmyolo.registry import MODELS + +try: + from pytorch_grad_cam import (AblationCAM, AblationLayer, + ActivationsAndGradients) + from pytorch_grad_cam import GradCAM as Base_GradCAM + from pytorch_grad_cam import GradCAMPlusPlus as Base_GradCAMPlusPlus + from pytorch_grad_cam.base_cam import BaseCAM + from pytorch_grad_cam.utils.image import scale_cam_image, show_cam_on_image + from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection +except ImportError: + pass + + +def init_detector( + config: Union[str, Path, Config], + checkpoint: Optional[str] = None, + palette: str = 'coco', + device: str = 'cuda:0', + cfg_options: Optional[dict] = None, +) -> nn.Module: + """Initialize a detector from config file. + + Args: + config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path, + :obj:`Path`, or the config object. + checkpoint (str, optional): Checkpoint path. If left as None, the model + will not load any weights. + palette (str): Color palette used for visualization. If palette + is stored in checkpoint, use checkpoint's palette first, otherwise + use externally passed palette. Currently, supports 'coco', 'voc', + 'citys' and 'random'. Defaults to coco. + device (str): The device where the anchors will be put on. + Defaults to cuda:0. + cfg_options (dict, optional): Options to override some settings in + the used config. + + Returns: + nn.Module: The constructed detector. + """ + if isinstance(config, (str, Path)): + config = Config.fromfile(config) + elif not isinstance(config, Config): + raise TypeError('config must be a filename or Config object, ' + f'but got {type(config)}') + if cfg_options is not None: + config.merge_from_dict(cfg_options) + elif 'init_cfg' in config.model.backbone: + config.model.backbone.init_cfg = None + + # only change this + # grad based method requires train_cfg + # config.model.train_cfg = None + init_default_scope(config.get('default_scope', 'mmyolo')) + + model = MODELS.build(config.model) + if checkpoint is not None: + checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') + # Weights converted from elsewhere may not have meta fields. 
+ checkpoint_meta = checkpoint.get('meta', {}) + # save the dataset_meta in the model for convenience + if 'dataset_meta' in checkpoint_meta: + # mmdet 3.x, all keys should be lowercase + model.dataset_meta = { + k.lower(): v + for k, v in checkpoint_meta['dataset_meta'].items() + } + elif 'CLASSES' in checkpoint_meta: + # < mmdet 3.x + classes = checkpoint_meta['CLASSES'] + model.dataset_meta = {'classes': classes, 'palette': palette} + else: + warnings.simplefilter('once') + warnings.warn( + 'dataset_meta or class names are not saved in the ' + 'checkpoint\'s meta data, use COCO classes by default.') + model.dataset_meta = { + 'classes': get_classes('coco'), + 'palette': palette + } + + model.cfg = config # save the config in the model for convenience + model.to(device) + model.eval() + return model + + +def reshape_transform(feats: Union[Tensor, List[Tensor]], + max_shape: Tuple[int, int] = (20, 20), + is_need_grad: bool = False): + """Reshape and aggregate feature maps when the input is a multi-layer + feature map. + + Takes these tensors with different sizes, resizes them to a common shape, + and concatenates them. + """ + if len(max_shape) == 1: + max_shape = max_shape * 2 + + if isinstance(feats, torch.Tensor): + feats = [feats] + else: + if is_need_grad: + raise NotImplementedError('The `grad_base` method does not ' + 'support output multi-activation layers') + + max_h = max([im.shape[-2] for im in feats]) + max_w = max([im.shape[-1] for im in feats]) + if -1 in max_shape: + max_shape = (max_h, max_w) + else: + max_shape = (min(max_h, max_shape[0]), min(max_w, max_shape[1])) + + activations = [] + for feat in feats: + activations.append( + torch.nn.functional.interpolate( + torch.abs(feat), max_shape, mode='bilinear')) + + activations = torch.cat(activations, axis=1) + return activations + + +class BoxAMDetectorWrapper(nn.Module): + """Wrap the mmdet model class to facilitate handling of non-tensor + situations during inference.""" + + def __init__(self, + cfg: ConfigType, + checkpoint: str, + score_thr: float, + device: str = 'cuda:0'): + super().__init__() + self.cfg = cfg + self.device = device + self.score_thr = score_thr + self.checkpoint = checkpoint + self.detector = init_detector(self.cfg, self.checkpoint, device=device) + + pipeline_cfg = copy.deepcopy(self.cfg.test_dataloader.dataset.pipeline) + pipeline_cfg[0].type = 'mmdet.LoadImageFromNDArray' + + new_test_pipeline = [] + for pipeline in pipeline_cfg: + if not pipeline['type'].endswith('LoadAnnotations'): + new_test_pipeline.append(pipeline) + self.test_pipeline = Compose(new_test_pipeline) + + self.is_need_loss = False + self.input_data = None + self.image = None + + def need_loss(self, is_need_loss: bool): + """Grad-based methods require loss.""" + self.is_need_loss = is_need_loss + + def set_input_data(self, + image: np.ndarray, + pred_instances: Optional[InstanceData] = None): + """Set the input data to be used in the next step.""" + self.image = image + + if self.is_need_loss: + assert pred_instances is not None + pred_instances = pred_instances.numpy() + data = dict( + img=self.image, + img_id=0, + gt_bboxes=pred_instances.bboxes, + gt_bboxes_labels=pred_instances.labels) + data = self.test_pipeline(data) + else: + data = dict(img=self.image, img_id=0) + data = self.test_pipeline(data) + data['inputs'] = [data['inputs']] + data['data_samples'] = [data['data_samples']] + self.input_data = data + + def __call__(self, *args, **kwargs): + assert self.input_data is not None + if self.is_need_loss: + # Maybe this is a 
direction that can be optimized + # self.detector.init_weights() + + self.detector.bbox_head.head_module.training = True + if hasattr(self.detector.bbox_head, 'featmap_sizes'): + # Prevent the model algorithm error when calculating loss + self.detector.bbox_head.featmap_sizes = None + + data_ = {} + data_['inputs'] = [self.input_data['inputs']] + data_['data_samples'] = [self.input_data['data_samples']] + data = self.detector.data_preprocessor(data_, training=False) + loss = self.detector._run_forward(data, mode='loss') + + if hasattr(self.detector.bbox_head, 'featmap_sizes'): + self.detector.bbox_head.featmap_sizes = None + + return [loss] + else: + self.detector.bbox_head.head_module.training = False + with torch.no_grad(): + results = self.detector.test_step(self.input_data) + return results + + +class BoxAMDetectorVisualizer: + """Box AM visualization class.""" + + def __init__(self, + method_class, + model: nn.Module, + target_layers: List, + reshape_transform: Optional[Callable] = None, + is_need_grad: bool = False, + extra_params: Optional[dict] = None): + self.target_layers = target_layers + self.reshape_transform = reshape_transform + self.is_need_grad = is_need_grad + + if method_class.__name__ == 'AblationCAM': + batch_size = extra_params.get('batch_size', 1) + ratio_channels_to_ablate = extra_params.get( + 'ratio_channels_to_ablate', 1.) + self.cam = AblationCAM( + model, + target_layers, + use_cuda=True if 'cuda' in model.device else False, + reshape_transform=reshape_transform, + batch_size=batch_size, + ablation_layer=extra_params['ablation_layer'], + ratio_channels_to_ablate=ratio_channels_to_ablate) + else: + self.cam = method_class( + model, + target_layers, + use_cuda=True if 'cuda' in model.device else False, + reshape_transform=reshape_transform, + ) + if self.is_need_grad: + self.cam.activations_and_grads.release() + + self.classes = model.detector.dataset_meta['classes'] + self.COLORS = np.random.uniform(0, 255, size=(len(self.classes), 3)) + + def switch_activations_and_grads(self, model) -> None: + """In the grad-based method, we need to switch + ``ActivationsAndGradients`` layer, otherwise an error will occur.""" + self.cam.model = model + + if self.is_need_grad is True: + self.cam.activations_and_grads = ActivationsAndGradients( + model, self.target_layers, self.reshape_transform) + self.is_need_grad = False + else: + self.cam.activations_and_grads.release() + self.is_need_grad = True + + def __call__(self, img, targets, aug_smooth=False, eigen_smooth=False): + img = torch.from_numpy(img)[None].permute(0, 3, 1, 2) + return self.cam(img, targets, aug_smooth, eigen_smooth)[0, :] + + def show_am(self, + image: np.ndarray, + pred_instance: InstanceData, + grayscale_am: np.ndarray, + with_norm_in_bboxes: bool = False): + """Normalize the AM to be in the range [0, 1] inside every bounding + boxes, and zero outside of the bounding boxes.""" + + boxes = pred_instance.bboxes + labels = pred_instance.labels + + if with_norm_in_bboxes is True: + boxes = boxes.astype(np.int32) + renormalized_am = np.zeros(grayscale_am.shape, dtype=np.float32) + images = [] + for x1, y1, x2, y2 in boxes: + img = renormalized_am * 0 + img[y1:y2, x1:x2] = scale_cam_image( + [grayscale_am[y1:y2, x1:x2].copy()])[0] + images.append(img) + + renormalized_am = np.max(np.float32(images), axis=0) + renormalized_am = scale_cam_image([renormalized_am])[0] + else: + renormalized_am = grayscale_am + + am_image_renormalized = show_cam_on_image( + image / 255, renormalized_am, use_rgb=False) + + 
image_with_bounding_boxes = self._draw_boxes( + boxes, labels, am_image_renormalized, pred_instance.get('scores')) + return image_with_bounding_boxes + + def _draw_boxes(self, + boxes: List, + labels: List, + image: np.ndarray, + scores: Optional[List] = None): + """draw boxes on image.""" + for i, box in enumerate(boxes): + label = labels[i] + color = self.COLORS[label] + cv2.rectangle(image, (int(box[0]), int(box[1])), + (int(box[2]), int(box[3])), color, 2) + if scores is not None: + score = scores[i] + text = str(self.classes[label]) + ': ' + str( + round(score * 100, 1)) + else: + text = self.classes[label] + + cv2.putText( + image, + text, (int(box[0]), int(box[1] - 5)), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + color, + 1, + lineType=cv2.LINE_AA) + return image + + +class DetAblationLayer(AblationLayer): + """Det AblationLayer.""" + + def __init__(self): + super().__init__() + self.activations = None + + def set_next_batch(self, input_batch_index, activations, + num_channels_to_ablate): + """Extract the next batch member from activations, and repeat it + num_channels_to_ablate times.""" + if isinstance(activations, torch.Tensor): + return super().set_next_batch(input_batch_index, activations, + num_channels_to_ablate) + + self.activations = [] + for activation in activations: + activation = activation[ + input_batch_index, :, :, :].clone().unsqueeze(0) + self.activations.append( + activation.repeat(num_channels_to_ablate, 1, 1, 1)) + + def __call__(self, x): + """Go over the activation indices to be ablated, stored in + self.indices.""" + result = self.activations + + if isinstance(result, torch.Tensor): + return super().__call__(x) + + channel_cumsum = np.cumsum([r.shape[1] for r in result]) + num_channels_to_ablate = result[0].size(0) # batch + for i in range(num_channels_to_ablate): + pyramid_layer = bisect.bisect_right(channel_cumsum, + self.indices[i]) + if pyramid_layer > 0: + index_in_pyramid_layer = self.indices[i] - channel_cumsum[ + pyramid_layer - 1] + else: + index_in_pyramid_layer = self.indices[i] + result[pyramid_layer][i, index_in_pyramid_layer, :, :] = -1000 + return result + + +class DetBoxScoreTarget: + """Det Score calculation class. + + In the case of the grad-free method, the calculation method is that + for every original detected bounding box specified in "bboxes", + assign a score on how the current bounding boxes match it, + + 1. In Bbox IoU + 2. In the classification score. + 3. In Mask IoU if ``segms`` exist. + + If there is not a large enough overlap, or the category changed, + assign a score of 0. The total score is the sum of all the box scores. + + In the case of the grad-based method, the calculation method is + the sum of losses after excluding a specific key. 
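+
+    For example, in the grad-free case, if the best-matching prediction ``p``
+    for a focal box ``b`` satisfies ``IoU(b, p) > match_iou_thr`` with the same
+    label, the pair contributes ``IoU(b, p) + score(p)`` to the total score;
+    otherwise it contributes 0.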
+ """ + + def __init__(self, + pred_instance: InstanceData, + match_iou_thr: float = 0.5, + device: str = 'cuda:0', + ignore_loss_params: Optional[List] = None): + self.focal_bboxes = pred_instance.bboxes + self.focal_labels = pred_instance.labels + self.match_iou_thr = match_iou_thr + self.device = device + self.ignore_loss_params = ignore_loss_params + if ignore_loss_params is not None: + assert isinstance(self.ignore_loss_params, list) + + def __call__(self, results): + output = torch.tensor([0.], device=self.device) + + if 'loss_cls' in results: + # grad-based method + # results is dict + for loss_key, loss_value in results.items(): + if 'loss' not in loss_key or \ + loss_key in self.ignore_loss_params: + continue + if isinstance(loss_value, list): + output += sum(loss_value) + else: + output += loss_value + return output + else: + # grad-free method + # results is DetDataSample + pred_instances = results.pred_instances + if len(pred_instances) == 0: + return output + + pred_bboxes = pred_instances.bboxes + pred_scores = pred_instances.scores + pred_labels = pred_instances.labels + + for focal_box, focal_label in zip(self.focal_bboxes, + self.focal_labels): + ious = torchvision.ops.box_iou(focal_box[None], + pred_bboxes[..., :4]) + index = ious.argmax() + if ious[0, index] > self.match_iou_thr and pred_labels[ + index] == focal_label: + # TODO: Adaptive adjustment of weights based on algorithms + score = ious[0, index] + pred_scores[index] + output = output + score + return output + + +class SpatialBaseCAM(BaseCAM): + """CAM that maintains spatial information. + + Gradients are often averaged over the spatial dimension in CAM + visualization for classification, but this is unreasonable in detection + tasks. There is no need to average the gradients in the detection task. + """ + + def get_cam_image(self, + input_tensor: torch.Tensor, + target_layer: torch.nn.Module, + targets: List[torch.nn.Module], + activations: torch.Tensor, + grads: torch.Tensor, + eigen_smooth: bool = False) -> np.ndarray: + + weights = self.get_cam_weights(input_tensor, target_layer, targets, + activations, grads) + weighted_activations = weights * activations + if eigen_smooth: + cam = get_2d_projection(weighted_activations) + else: + cam = weighted_activations.sum(axis=1) + return cam + + +class GradCAM(SpatialBaseCAM, Base_GradCAM): + """Gradients are no longer averaged over the spatial dimension.""" + + def get_cam_weights(self, input_tensor, target_layer, target_category, + activations, grads): + return grads + + +class GradCAMPlusPlus(SpatialBaseCAM, Base_GradCAMPlusPlus): + """Gradients are no longer averaged over the spatial dimension.""" + + def get_cam_weights(self, input_tensor, target_layers, target_category, + activations, grads): + grads_power_2 = grads**2 + grads_power_3 = grads_power_2 * grads + # Equation 19 in https://arxiv.org/abs/1710.11063 + sum_activations = np.sum(activations, axis=(2, 3)) + eps = 0.000001 + aij = grads_power_2 / ( + 2 * grads_power_2 + + sum_activations[:, :, None, None] * grads_power_3 + eps) + # Now bring back the ReLU from eq.7 in the paper, + # And zero out aijs where the activations are 0 + aij = np.where(grads != 0, aij, 0) + + weights = np.maximum(grads, 0) * aij + return weights diff --git a/mmpl/utils/collect_env.py b/mmpl/utils/collect_env.py new file mode 100644 index 0000000000000000000000000000000000000000..94c675c841d74af49964c17ab360a6d3d754b4e2 --- /dev/null +++ b/mmpl/utils/collect_env.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
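+# `mmdet` and `mmyolo` are imported for their version strings used below.
+import mmdet
+import mmyolo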
+import mmcv +from mmengine.utils import get_git_hash +from mmengine.utils.dl_utils import collect_env as collect_base_env + + +def collect_env() -> dict: + """Collect the information of the running environments.""" + env_info = collect_base_env() + env_info['MMCV'] = mmcv.__version__ + env_info['MMDetection'] = mmdet.__version__ + env_info['MMYOLO'] = mmyolo.__version__ + '+' + get_git_hash()[:7] + return env_info + + +if __name__ == '__main__': + for name, val in collect_env().items(): + print(f'{name}: {val}') diff --git a/mmpl/utils/labelme_utils.py b/mmpl/utils/labelme_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0981919771a617ca79b29c3ddf96ea14c82fccc6 --- /dev/null +++ b/mmpl/utils/labelme_utils.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os.path + +from mmengine.structures import InstanceData + + +class LabelmeFormat: + """Predict results save into labelme file. + + Base on https://github.com/wkentaro/labelme/blob/main/labelme/label_file.py + + Args: + classes (tuple): Model classes name. + """ + + def __init__(self, classes: tuple): + super().__init__() + self.classes = classes + + def __call__(self, pred_instances: InstanceData, metainfo: dict, + output_path: str, selected_classes: list): + """Get image data field for labelme. + + Args: + pred_instances (InstanceData): Candidate prediction info. + metainfo (dict): Meta info of prediction. + output_path (str): Image file path. + selected_classes (list): Selected class name. + + Labelme file eg. + { + "version": "5.1.1", + "flags": {}, + "imagePath": "/data/cat/1.jpg", + "imageData": null, + "imageHeight": 3000, + "imageWidth": 4000, + "shapes": [ + { + "label": "cat", + "points": [ + [ + 1148.076923076923, + 1188.4615384615383 + ], + [ + 2471.1538461538457, + 2176.923076923077 + ] + ], + "group_id": null, + "shape_type": "rectangle", + "flags": {} + }, + {...} + ] + } + """ + + image_path = os.path.abspath(metainfo['img_path']) + + json_info = { + 'version': '5.1.1', + 'flags': {}, + 'imagePath': image_path, + 'imageData': None, + 'imageHeight': metainfo['ori_shape'][0], + 'imageWidth': metainfo['ori_shape'][1], + 'shapes': [] + } + + for pred_instance in pred_instances: + pred_bbox = pred_instance.bboxes.cpu().numpy().tolist()[0] + pred_label = self.classes[pred_instance.labels] + + if selected_classes is not None and \ + pred_label not in selected_classes: + # filter class name + continue + + sub_dict = { + 'label': pred_label, + 'points': [pred_bbox[:2], pred_bbox[2:]], + 'group_id': None, + 'shape_type': 'rectangle', + 'flags': {} + } + json_info['shapes'].append(sub_dict) + + with open(output_path, 'w', encoding='utf-8') as f_json: + json.dump(json_info, f_json, ensure_ascii=False, indent=2) diff --git a/mmpl/utils/large_image.py b/mmpl/utils/large_image.py new file mode 100644 index 0000000000000000000000000000000000000000..8670804684f6dcdc6dc1846cf85260d900b3474e --- /dev/null +++ b/mmpl/utils/large_image.py @@ -0,0 +1,103 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence, Tuple + +import torch +from mmcv.ops import batched_nms +from mmdet.structures import DetDataSample, SampleList +from mmengine.structures import InstanceData + + +def shift_rbboxes(bboxes: torch.Tensor, offset: Sequence[int]): + """Shift rotated bboxes with offset. + + Args: + bboxes (Tensor): The rotated bboxes need to be translated. + With shape (n, 5), which means (x, y, w, h, a). 
+ offset (Sequence[int]): The translation offsets with shape of (2, ). + Returns: + Tensor: Shifted rotated bboxes. + """ + offset_tensor = bboxes.new_tensor(offset) + shifted_bboxes = bboxes.clone() + shifted_bboxes[:, 0:2] = shifted_bboxes[:, 0:2] + offset_tensor + return shifted_bboxes + + +def shift_predictions(det_data_samples: SampleList, + offsets: Sequence[Tuple[int, int]], + src_image_shape: Tuple[int, int]) -> SampleList: + """Shift predictions to the original image. + + Args: + det_data_samples (List[:obj:`DetDataSample`]): A list of patch results. + offsets (Sequence[Tuple[int, int]]): Positions of the left top points + of patches. + src_image_shape (Tuple[int, int]): A (height, width) tuple of the large + image's width and height. + Returns: + (List[:obj:`DetDataSample`]): shifted results. + """ + try: + from sahi.slicing import shift_bboxes, shift_masks + except ImportError: + raise ImportError('Please run "pip install -U sahi" ' + 'to install sahi first for large image inference.') + + assert len(det_data_samples) == len( + offsets), 'The `results` should has the ' 'same length with `offsets`.' + shifted_predictions = [] + for det_data_sample, offset in zip(det_data_samples, offsets): + pred_inst = det_data_sample.pred_instances.clone() + + # Check bbox type + if pred_inst.bboxes.size(-1) == 4: + # Horizontal bboxes + shifted_bboxes = shift_bboxes(pred_inst.bboxes, offset) + elif pred_inst.bboxes.size(-1) == 5: + # Rotated bboxes + shifted_bboxes = shift_rbboxes(pred_inst.bboxes, offset) + else: + raise NotImplementedError + + # shift bboxes and masks + pred_inst.bboxes = shifted_bboxes + if 'masks' in det_data_sample: + pred_inst.masks = shift_masks(pred_inst.masks, offset, + src_image_shape) + + shifted_predictions.append(pred_inst.clone()) + + shifted_predictions = InstanceData.cat(shifted_predictions) + + return shifted_predictions + + +def merge_results_by_nms(results: SampleList, offsets: Sequence[Tuple[int, + int]], + src_image_shape: Tuple[int, int], + nms_cfg: dict) -> DetDataSample: + """Merge patch results by nms. + + Args: + results (List[:obj:`DetDataSample`]): A list of patch results. + offsets (Sequence[Tuple[int, int]]): Positions of the left top points + of patches. + src_image_shape (Tuple[int, int]): A (height, width) tuple of the large + image's width and height. + nms_cfg (dict): it should specify nms type and other parameters + like `iou_threshold`. + Returns: + :obj:`DetDataSample`: merged results. + """ + shifted_instances = shift_predictions(results, offsets, src_image_shape) + + _, keeps = batched_nms( + boxes=shifted_instances.bboxes, + scores=shifted_instances.scores, + idxs=shifted_instances.labels, + nms_cfg=nms_cfg) + merged_instances = shifted_instances[keeps] + + merged_result = results[0].clone() + merged_result.pred_instances = merged_instances + return merged_result diff --git a/mmpl/utils/misc.py b/mmpl/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..8633db7d95a1446586a469a873f7123a89b5f6f8 --- /dev/null +++ b/mmpl/utils/misc.py @@ -0,0 +1,133 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
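+# Example (sketch, assuming images loaded with cv2): gather image paths from a
+# folder and tile them into a grid for display:
+#     files, source_type = get_file_list('demo/images')
+#     grid = auto_arrange_images([cv2.imread(p) for p in files], image_column=2)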
+import os +import urllib + +import numpy as np +import torch +from mmengine.utils import scandir +from prettytable import PrettyTable + +# from mmyolo.models import RepVGGBlock + +IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', + '.tiff', '.webp') + + +def switch_to_deploy(model): + """Model switch to deploy status.""" + for layer in model.modules(): + if isinstance(layer, RepVGGBlock): + layer.switch_to_deploy() + + print('Switch model to deploy modality.') + + +def auto_arrange_images(image_list: list, image_column: int = 2) -> np.ndarray: + """Auto arrange image to image_column x N row. + + Args: + image_list (list): cv2 image list. + image_column (int): Arrange to N column. Default: 2. + Return: + (np.ndarray): image_column x N row merge image + """ + img_count = len(image_list) + if img_count <= image_column: + # no need to arrange + image_show = np.concatenate(image_list, axis=1) + else: + # arrange image according to image_column + image_row = round(img_count / image_column) + fill_img_list = [np.ones(image_list[0].shape, dtype=np.uint8) * 255 + ] * ( + image_row * image_column - img_count) + image_list.extend(fill_img_list) + merge_imgs_col = [] + for i in range(image_row): + start_col = image_column * i + end_col = image_column * (i + 1) + merge_col = np.hstack(image_list[start_col:end_col]) + merge_imgs_col.append(merge_col) + + # merge to one image + image_show = np.vstack(merge_imgs_col) + + return image_show + + +def get_file_list(source_root: str) -> [list, dict]: + """Get file list. + + Args: + source_root (str): image or video source path + + Return: + source_file_path_list (list): A list for all source file. + source_type (dict): Source type: file or url or dir. + """ + is_dir = os.path.isdir(source_root) + is_url = source_root.startswith(('http:/', 'https:/')) + is_file = os.path.splitext(source_root)[-1].lower() in IMG_EXTENSIONS + + source_file_path_list = [] + if is_dir: + # when input source is dir + for file in scandir(source_root, IMG_EXTENSIONS, recursive=True): + source_file_path_list.append(os.path.join(source_root, file)) + elif is_url: + # when input source is url + filename = os.path.basename( + urllib.parse.unquote(source_root).split('?')[0]) + file_save_path = os.path.join(os.getcwd(), filename) + print(f'Downloading source file to {file_save_path}') + torch.hub.download_url_to_file(source_root, file_save_path) + source_file_path_list = [file_save_path] + elif is_file: + # when input source is single image + source_file_path_list = [source_root] + else: + print('Cannot find image file.') + + source_type = dict(is_dir=is_dir, is_url=is_url, is_file=is_file) + + return source_file_path_list, source_type + + +def show_data_classes(data_classes): + """When printing an error, all class names of the dataset.""" + print('\n\nThe name of the class contained in the dataset:') + data_classes_info = PrettyTable() + data_classes_info.title = 'Information of dataset class' + # List Print Settings + # If the quantity is too large, 25 rows will be displayed in each column + if len(data_classes) < 25: + data_classes_info.add_column('Class name', data_classes) + elif len(data_classes) % 25 != 0 and len(data_classes) > 25: + col_num = int(len(data_classes) / 25) + 1 + data_name_list = list(data_classes) + for i in range(0, (col_num * 25) - len(data_classes)): + data_name_list.append('') + for i in range(0, len(data_name_list), 25): + data_classes_info.add_column('Class name', + data_name_list[i:i + 25]) + + # Align display data to the left + 
data_classes_info.align['Class name'] = 'l' + print(data_classes_info) + + +def is_metainfo_lower(cfg): + """Determine whether the custom metainfo fields are all lowercase.""" + + def judge_keys(dataloader_cfg): + while 'dataset' in dataloader_cfg: + dataloader_cfg = dataloader_cfg['dataset'] + if 'metainfo' in dataloader_cfg: + all_keys = dataloader_cfg['metainfo'].keys() + all_is_lower = all([str(k).islower() for k in all_keys]) + assert all_is_lower, f'The keys in dataset metainfo must be all lowercase, but got {all_keys}. ' \ + f'Please refer to https://github.com/open-mmlab/mmyolo/blob/e62c8c4593/configs/yolov5/yolov5_s-v61_syncbn_fast_1xb4-300e_balloon.py#L8' # noqa + + judge_keys(cfg.get('train_dataloader', {})) + judge_keys(cfg.get('val_dataloader', {})) + judge_keys(cfg.get('test_dataloader', {})) diff --git a/mmpl/utils/setup_env.py b/mmpl/utils/setup_env.py new file mode 100644 index 0000000000000000000000000000000000000000..602caac2b5e4d644fe60eb8311d4d2166197d80d --- /dev/null +++ b/mmpl/utils/setup_env.py @@ -0,0 +1,38 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import datetime +import warnings +from mmengine import DefaultScope + + +def register_all_modules(init_default_scope: bool = True): + """Register all modules in mmdet into the registries. + + Args: + init_default_scope (bool): Whether initialize the mmdet default scope. + When `init_default_scope=True`, the global default scope will be + set to `mmpl`, and all registries will build modules from mmdet's + registry node. To understand more about the registry, please refer + to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md + Defaults to True. + """ # noqa + import mmpl.datasets # noqa: F401,F403 + import mmpl.engine # noqa: F401,F403 + import mmpl.models # noqa: F401,F403 + import mmpl.evaluation # noqa: F401,F403 + + if init_default_scope: + never_created = DefaultScope.get_current_instance() is None \ + or not DefaultScope.check_instance_created('mmpl') + if never_created: + DefaultScope.get_instance('mmpl', scope_name='mmpl') + return + current_scope = DefaultScope.get_current_instance() + if current_scope.scope_name != 'mmpl': + warnings.warn('The current default scope ' + f'"{current_scope.scope_name}" is not "mmpl", ' + '`register_all_modules` will force the current' + 'default scope to be "mmpl". If this is not ' + 'expected, please set `init_default_scope=False`.') + # avoid name conflict + new_instance_name = f'mmpl-{datetime.datetime.now()}' + DefaultScope.get_instance(new_instance_name, scope_name='mmpl') diff --git a/mmpl/utils/typing_utils.py b/mmpl/utils/typing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6caf6de53274594e139dbe7c1973c747229bf010 --- /dev/null +++ b/mmpl/utils/typing_utils.py @@ -0,0 +1,22 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
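+# These aliases are intended for annotating config-style arguments, e.g. (sketch):
+#     def __init__(self, loss_cfg: ConfigType, init_cfg: OptMultiConfig = None): ...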
+"""Collecting some commonly used type hint in mmdetection.""" +from typing import List, Optional, Sequence, Tuple, Union + +from mmengine.config import ConfigDict +from mmengine.structures import InstanceData, PixelData + +# TODO: Need to avoid circular import with assigner and sampler +# Type hint of config data +ConfigType = Union[ConfigDict, dict] +OptConfigType = Optional[ConfigType] +# Type hint of one or more config data +MultiConfig = Union[ConfigType, List[ConfigType]] +OptMultiConfig = Optional[MultiConfig] + +InstanceList = List[InstanceData] +OptInstanceList = Optional[InstanceList] + +PixelList = List[PixelData] +OptPixelList = Optional[PixelList] + +RangeType = Sequence[Tuple[int, int]] diff --git a/mmpl/visualization/__init__.py b/mmpl/visualization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mmpl/visualization/__pycache__/__init__.cpython-310.pyc b/mmpl/visualization/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..269e8c5dc13d726477f762f930201543766ba87f Binary files /dev/null and b/mmpl/visualization/__pycache__/__init__.cpython-310.pyc differ