diff --git a/mmseg/__init__.py b/mmseg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9f171ccb0a6a7e6abde613bf37cb07eed22fb09f --- /dev/null +++ b/mmseg/__init__.py @@ -0,0 +1,74 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import mmcv +import mmengine +from packaging.version import parse + +from .version import __version__, version_info + +MMCV_MIN = '2.0.0rc4' +MMCV_MAX = '2.1.0' +MMENGINE_MIN = '0.5.0' +MMENGINE_MAX = '1.0.0' + + +def digit_version(version_str: str, length: int = 4): + """Convert a version string into a tuple of integers. + + This method is usually used for comparing two versions. For pre-release + versions: alpha < beta < rc. + + Args: + version_str (str): The version string. + length (int): The maximum number of version levels. Default: 4. + + Returns: + tuple[int]: The version info in digits (integers). + """ + version = parse(version_str) + assert version.release, f'failed to parse version {version_str}' + release = list(version.release) + release = release[:length] + if len(release) < length: + release = release + [0] * (length - len(release)) + if version.is_prerelease: + mapping = {'a': -3, 'b': -2, 'rc': -1} + val = -4 + # version.pre can be None + if version.pre: + if version.pre[0] not in mapping: + warnings.warn(f'unknown prerelease version {version.pre[0]}, ' + 'version checking may go wrong') + else: + val = mapping[version.pre[0]] + release.extend([val, version.pre[-1]]) + else: + release.extend([val, 0]) + + elif version.is_postrelease: + release.extend([1, version.post]) + else: + release.extend([0, 0]) + return tuple(release) + + +mmcv_min_version = digit_version(MMCV_MIN) +mmcv_max_version = digit_version(MMCV_MAX) +mmcv_version = digit_version(mmcv.__version__) + + +assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \ + f'MMCV=={mmcv.__version__} is used but incompatible. ' \ + f'Please install mmcv>=2.0.0rc4.' + +mmengine_min_version = digit_version(MMENGINE_MIN) +mmengine_max_version = digit_version(MMENGINE_MAX) +mmengine_version = digit_version(mmengine.__version__) + +assert (mmengine_min_version <= mmengine_version < mmengine_max_version), \ + f'MMEngine=={mmengine.__version__} is used but incompatible. ' \ + f'Please install mmengine>={mmengine_min_version}, '\ + f'<{mmengine_max_version}.' + +__all__ = ['__version__', 'version_info', 'digit_version'] diff --git a/mmseg/__pycache__/__init__.cpython-310.pyc b/mmseg/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c7f91144e1c5f5ee48df9b7ca253f99c5473ffe Binary files /dev/null and b/mmseg/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmseg/__pycache__/version.cpython-310.pyc b/mmseg/__pycache__/version.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d93eb1269aa838f18dcb24bd969ee7976c81b5e Binary files /dev/null and b/mmseg/__pycache__/version.cpython-310.pyc differ diff --git a/mmseg/apis/__init__.py b/mmseg/apis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d22dc3f0ada938b3164497746b5a999191b0ff65 --- /dev/null +++ b/mmseg/apis/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
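As a quick illustration of the ordering produced by digit_version above (a minimal sketch; the version strings are arbitrary examples, not values used elsewhere in this diff):

from mmseg import digit_version

# Pre-releases sort below the final release: alpha < beta < rc < final.
assert digit_version('1.0.0a1') < digit_version('1.0.0b1') \
    < digit_version('1.0.0rc1') < digit_version('1.0.0')
# The MMCV range check above relies on this, e.g. 2.0.0rc4 lies in [2.0.0rc4, 2.1.0).
assert digit_version('2.0.0rc4') <= digit_version('2.0.0rc5') < digit_version('2.1.0')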
+from .inference import inference_model, init_model, show_result_pyplot +from .mmseg_inferencer import MMSegInferencer + +__all__ = [ + 'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer' +] diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..4aadffc7982851f913f788915a11149d7e5459d7 --- /dev/null +++ b/mmseg/apis/inference.py @@ -0,0 +1,214 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from collections import defaultdict +from pathlib import Path +from typing import Optional, Sequence, Union + +import mmcv +import numpy as np +import torch +from mmengine import Config +from mmengine.dataset import Compose +from mmengine.registry import init_default_scope +from mmengine.runner import load_checkpoint +from mmengine.utils import mkdir_or_exist + +from mmseg.models import BaseSegmentor +from mmseg.registry import MODELS +from mmseg.structures import SegDataSample +from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette +from mmseg.visualization import SegLocalVisualizer + + +def init_model(config: Union[str, Path, Config], + checkpoint: Optional[str] = None, + device: str = 'cuda:0', + cfg_options: Optional[dict] = None): + """Initialize a segmentor from config file. + + Args: + config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path, + :obj:`Path`, or the config object. + checkpoint (str, optional): Checkpoint path. If left as None, the model + will not load any weights. + device (str, optional) CPU/CUDA device option. Default 'cuda:0'. + Use 'cpu' for loading model on CPU. + cfg_options (dict, optional): Options to override some settings in + the used config. + Returns: + nn.Module: The constructed segmentor. 
+ """ + if isinstance(config, (str, Path)): + config = Config.fromfile(config) + elif not isinstance(config, Config): + raise TypeError('config must be a filename or Config object, ' + 'but got {}'.format(type(config))) + if cfg_options is not None: + config.merge_from_dict(cfg_options) + elif 'init_cfg' in config.model.backbone: + config.model.backbone.init_cfg = None + config.model.pretrained = None + config.model.train_cfg = None + init_default_scope(config.get('default_scope', 'mmseg')) + + model = MODELS.build(config.model) + if checkpoint is not None: + checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') + dataset_meta = checkpoint['meta'].get('dataset_meta', None) + # save the dataset_meta in the model for convenience + if 'dataset_meta' in checkpoint.get('meta', {}): + # mmseg 1.x + model.dataset_meta = dataset_meta + elif 'CLASSES' in checkpoint.get('meta', {}): + # < mmseg 1.x + classes = checkpoint['meta']['CLASSES'] + palette = checkpoint['meta']['PALETTE'] + model.dataset_meta = {'classes': classes, 'palette': palette} + else: + warnings.simplefilter('once') + warnings.warn( + 'dataset_meta or class names are not saved in the ' + 'checkpoint\'s meta data, classes and palette will be' + 'set according to num_classes ') + num_classes = model.decode_head.num_classes + dataset_name = None + for name in dataset_aliases.keys(): + if len(get_classes(name)) == num_classes: + dataset_name = name + break + if dataset_name is None: + warnings.warn( + 'No suitable dataset found, use Cityscapes by default') + dataset_name = 'cityscapes' + model.dataset_meta = { + 'classes': get_classes(dataset_name), + 'palette': get_palette(dataset_name) + } + model.cfg = config # save the config in the model for convenience + model.to(device) + model.eval() + return model + + +ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]] + + +def _preprare_data(imgs: ImageType, model: BaseSegmentor): + + cfg = model.cfg + for t in cfg.test_pipeline: + if t.get('type') == 'LoadAnnotations': + cfg.test_pipeline.remove(t) + + is_batch = True + if not isinstance(imgs, (list, tuple)): + imgs = [imgs] + is_batch = False + + if isinstance(imgs[0], np.ndarray): + cfg.test_pipeline[0]['type'] = 'LoadImageFromNDArray' + + # TODO: Consider using the singleton pattern to avoid building + # a pipeline for each inference + pipeline = Compose(cfg.test_pipeline) + + data = defaultdict(list) + for img in imgs: + if isinstance(img, np.ndarray): + data_ = dict(img=img) + else: + data_ = dict(img_path=img) + data_ = pipeline(data_) + data['inputs'].append(data_['inputs']) + data['data_samples'].append(data_['data_samples']) + + return data, is_batch + + +def inference_model(model: BaseSegmentor, + img: ImageType) -> Union[SegDataSample, SampleList]: + """Inference image(s) with the segmentor. + + Args: + model (nn.Module): The loaded segmentor. + imgs (str/ndarray or list[str/ndarray]): Either image files or loaded + images. + + Returns: + :obj:`SegDataSample` or list[:obj:`SegDataSample`]: + If imgs is a list or tuple, the same length list type results + will be returned, otherwise return the segmentation results directly. 
+ """ + # prepare data + data, is_batch = _preprare_data(img, model) + + # forward the model + with torch.no_grad(): + results = model.test_step(data) + + return results if is_batch else results[0] + + +def show_result_pyplot(model: BaseSegmentor, + img: Union[str, np.ndarray], + result: SegDataSample, + opacity: float = 0.5, + title: str = '', + draw_gt: bool = True, + draw_pred: bool = True, + wait_time: float = 0, + show: bool = True, + save_dir=None, + out_file=None): + """Visualize the segmentation results on the image. + + Args: + model (nn.Module): The loaded segmentor. + img (str or np.ndarray): Image filename or loaded image. + result (SegDataSample): The prediction SegDataSample result. + opacity(float): Opacity of painted segmentation map. + Default 0.5. Must be in (0, 1] range. + title (str): The title of pyplot figure. + Default is ''. + draw_gt (bool): Whether to draw GT SegDataSample. Default to True. + draw_pred (bool): Whether to draw Prediction SegDataSample. + Defaults to True. + wait_time (float): The interval of show (s). 0 is the special value + that means "forever". Defaults to 0. + show (bool): Whether to display the drawn image. + Default to True. + save_dir (str, optional): Save file dir for all storage backends. + If it is None, the backend storage will not save any data. + out_file (str, optional): Path to output file. Default to None. + + Returns: + np.ndarray: the drawn image which channel is RGB. + """ + if hasattr(model, 'module'): + model = model.module + if isinstance(img, str): + image = mmcv.imread(img) + else: + image = img + if save_dir is not None: + mkdir_or_exist(save_dir) + # init visualizer + visualizer = SegLocalVisualizer( + vis_backends=[dict(type='LocalVisBackend')], + save_dir=save_dir, + alpha=opacity) + visualizer.dataset_meta = dict( + classes=model.dataset_meta['classes'], + palette=model.dataset_meta['palette']) + visualizer.add_datasample( + name=title, + image=image, + data_sample=result, + draw_gt=draw_gt, + draw_pred=draw_pred, + wait_time=wait_time, + out_file=out_file, + show=show) + vis_img = visualizer.get_image() + + return vis_img diff --git a/mmseg/apis/mmseg_inferencer.py b/mmseg/apis/mmseg_inferencer.py new file mode 100644 index 0000000000000000000000000000000000000000..1c72285c56207545273b9eae5d5c1668404431da --- /dev/null +++ b/mmseg/apis/mmseg_inferencer.py @@ -0,0 +1,361 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import warnings +from typing import List, Optional, Sequence, Union + +import mmcv +import mmengine +import numpy as np +import torch +import torch.nn as nn +from mmcv.transforms import Compose +from mmengine.infer.infer import BaseInferencer, ModelType +from mmengine.model import revert_sync_batchnorm +from mmengine.registry import init_default_scope +from mmengine.runner.checkpoint import _load_checkpoint_to_model +from PIL import Image + +from mmseg.structures import SegDataSample +from mmseg.utils import ConfigType, SampleList, get_classes, get_palette +from mmseg.visualization import SegLocalVisualizer + +InputType = Union[str, np.ndarray] +InputsType = Union[InputType, Sequence[InputType]] +PredType = Union[SegDataSample, SampleList] + + +class MMSegInferencer(BaseInferencer): + """Semantic segmentation inferencer, provides inference and visualization + interfaces. Note: MMEngine >= 0.5.0 is required. + + Args: + model (str, optional): Path to the config file or the model name + defined in metafile. 
+            Take the mmseg metafile as an example: the `model` could be
+            "fcn_r50-d8_4xb2-40k_cityscapes-512x1024", and the weights of the
+            model will be downloaded automatically. If a config file is used,
+            like "configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py",
+            `weights` should be defined.
+        weights (str, optional): Path to the checkpoint. If it is not specified
+            and model is a model name of metafile, the weights will be loaded
+            from metafile. Defaults to None.
+        classes (list, optional): Input classes for result rendering, as the
+            prediction of segmentation model is a segment map with label
+            indices, `classes` is a list which includes items corresponding to
+            the label indices. If classes is not defined, visualizer will take
+            `cityscapes` classes by default. Defaults to None.
+        palette (list, optional): Input palette for result rendering, which is
+            a list of colors corresponding to the classes. If palette is not
+            defined, visualizer will take `cityscapes` palette by default.
+            Defaults to None.
+        dataset_name (str, optional): Dataset name or alias. The visualizer
+            will use the meta information of the dataset, i.e. classes and
+            palette, but the `classes` and `palette` have higher priority.
+            Defaults to None.
+        device (str, optional): Device to run inference. If None, the available
+            device will be automatically used. Defaults to None.
+        scope (str, optional): The scope of the model. Defaults to 'mmseg'.
+    """ # noqa
+
+    preprocess_kwargs: set = set()
+    forward_kwargs: set = {'mode', 'out_dir'}
+    visualize_kwargs: set = {'show', 'wait_time', 'img_out_dir', 'opacity'}
+    postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}
+
+    def __init__(self,
+                 model: Union[ModelType, str],
+                 weights: Optional[str] = None,
+                 classes: Optional[Union[str, List]] = None,
+                 palette: Optional[Union[str, List]] = None,
+                 dataset_name: Optional[str] = None,
+                 device: Optional[str] = None,
+                 scope: Optional[str] = 'mmseg') -> None:
+        # A global counter tracking the number of images processed, for
+        # naming of the output images
+        self.num_visualized_imgs = 0
+        self.num_pred_imgs = 0
+        init_default_scope(scope if scope else 'mmseg')
+        super().__init__(
+            model=model, weights=weights, device=device, scope=scope)
+
+        if device == 'cpu' or not torch.cuda.is_available():
+            self.model = revert_sync_batchnorm(self.model)
+
+        assert isinstance(self.visualizer, SegLocalVisualizer)
+        self.visualizer.set_dataset_meta(palette, classes, dataset_name)
+
+    def _load_weights_to_model(self, model: nn.Module,
+                               checkpoint: Optional[dict],
+                               cfg: Optional[ConfigType]) -> None:
+        """Load model weights and meta information from cfg and checkpoint.
+
+        Subclasses could override this method to load extra meta information
+        from ``checkpoint`` and ``cfg`` to model.
+
+        Args:
+            model (nn.Module): Model to load weights and meta information.
+            checkpoint (dict, optional): The loaded checkpoint.
+            cfg (Config or ConfigDict, optional): The loaded config.
+ """ + + if checkpoint is not None: + _load_checkpoint_to_model(model, checkpoint) + checkpoint_meta = checkpoint.get('meta', {}) + # save the dataset_meta in the model for convenience + if 'dataset_meta' in checkpoint_meta: + # mmsegmentation 1.x + model.dataset_meta = { + 'classes': checkpoint_meta['dataset_meta'].get('classes'), + 'palette': checkpoint_meta['dataset_meta'].get('palette') + } + elif 'CLASSES' in checkpoint_meta: + # mmsegmentation 0.x + classes = checkpoint_meta['CLASSES'] + palette = checkpoint_meta.get('PALETTE', None) + model.dataset_meta = {'classes': classes, 'palette': palette} + else: + warnings.warn( + 'dataset_meta or class names are not saved in the ' + 'checkpoint\'s meta data, use classes of Cityscapes by ' + 'default.') + model.dataset_meta = { + 'classes': get_classes('cityscapes'), + 'palette': get_palette('cityscapes') + } + else: + warnings.warn('Checkpoint is not loaded, and the inference ' + 'result is calculated by the randomly initialized ' + 'model!') + warnings.warn( + 'weights is None, use cityscapes classes by default.') + model.dataset_meta = { + 'classes': get_classes('cityscapes'), + 'palette': get_palette('cityscapes') + } + + def __call__(self, + inputs: InputsType, + return_datasamples: bool = False, + batch_size: int = 1, + show: bool = False, + wait_time: int = 0, + out_dir: str = '', + img_out_dir: str = 'vis', + pred_out_dir: str = 'pred', + **kwargs) -> dict: + """Call the inferencer. + + Args: + inputs (Union[list, str, np.ndarray]): Inputs for the inferencer. + return_datasamples (bool): Whether to return results as + :obj:`SegDataSample`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + show (bool): Whether to display the rendering color segmentation + mask in a popup window. Defaults to False. + wait_time (float): The interval of show (s). Defaults to 0. + out_dir (str): Output directory of inference results. Defaults + to ''. + img_out_dir (str): Subdirectory of `out_dir`, used to save + rendering color segmentation mask, so `out_dir` must be defined + if you would like to save predicted mask. Defaults to 'vis'. + pred_out_dir (str): Subdirectory of `out_dir`, used to save + predicted mask file, so `out_dir` must be defined if you would + like to save predicted mask. Defaults to 'pred'. + + **kwargs: Other keyword arguments passed to :meth:`preprocess`, + :meth:`forward`, :meth:`visualize` and :meth:`postprocess`. + Each key in kwargs should be in the corresponding set of + ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs`` + and ``postprocess_kwargs``. + + + Returns: + dict: Inference and visualization results. + """ + + if out_dir != '': + pred_out_dir = osp.join(out_dir, pred_out_dir) + img_out_dir = osp.join(out_dir, img_out_dir) + else: + pred_out_dir = '' + img_out_dir = '' + + return super().__call__( + inputs=inputs, + return_datasamples=return_datasamples, + batch_size=batch_size, + show=show, + wait_time=wait_time, + img_out_dir=img_out_dir, + pred_out_dir=pred_out_dir, + **kwargs) + + def visualize(self, + inputs: list, + preds: List[dict], + show: bool = False, + wait_time: int = 0, + img_out_dir: str = '', + opacity: float = 0.8) -> List[np.ndarray]: + """Visualize predictions. + + Args: + inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`. + preds (Any): Predictions of the model. + show (bool): Whether to display the image in a popup window. + Defaults to False. + wait_time (float): The interval of show (s). Defaults to 0. 
+ img_out_dir (str): Output directory of rendering prediction i.e. + color segmentation mask. Defaults: '' + opacity (int, float): The transparency of segmentation mask. + Defaults to 0.8. + + Returns: + List[np.ndarray]: Visualization results. + """ + if self.visualizer is None or (not show and img_out_dir == ''): + return None + + if getattr(self, 'visualizer') is None: + raise ValueError('Visualization needs the "visualizer" term' + 'defined in the config, but got None') + self.visualizer.set_dataset_meta(**self.model.dataset_meta) + self.visualizer.alpha = opacity + + results = [] + + for single_input, pred in zip(inputs, preds): + if isinstance(single_input, str): + img_bytes = mmengine.fileio.get(single_input) + img = mmcv.imfrombytes(img_bytes) + img = img[:, :, ::-1] + img_name = osp.basename(single_input) + elif isinstance(single_input, np.ndarray): + img = single_input.copy() + img_num = str(self.num_visualized_imgs).zfill(8) + '_vis' + img_name = f'{img_num}.jpg' + else: + raise ValueError('Unsupported input type:' + f'{type(single_input)}') + + out_file = osp.join(img_out_dir, img_name) if img_out_dir != ''\ + else None + + self.visualizer.add_datasample( + img_name, + img, + pred, + show=show, + wait_time=wait_time, + draw_gt=False, + draw_pred=True, + out_file=out_file) + results.append(self.visualizer.get_image()) + self.num_visualized_imgs += 1 + + return results + + def postprocess(self, + preds: PredType, + visualization: List[np.ndarray], + return_datasample: bool = False, + pred_out_dir: str = '') -> dict: + """Process the predictions and visualization results from ``forward`` + and ``visualize``. + + This method should be responsible for the following tasks: + + 1. Pack the predictions and visualization results and return them. + 2. Save the predictions, if it needed. + + Args: + preds (List[Dict]): Predictions of the model. + visualization (List[np.ndarray]): The list of rendering color + segmentation mask. + return_datasample (bool): Whether to return results as datasamples. + Defaults to False. + pred_out_dir: File to save the inference results w/o + visualization. If left as empty, no file will be saved. + Defaults to ''. + + Returns: + dict: Inference and visualization results with key ``predictions`` + and ``visualization`` + + - ``visualization (Any)``: Returned by :meth:`visualize` + - ``predictions`` (List[np.ndarray], np.ndarray): Returned by + :meth:`forward` and processed in :meth:`postprocess`. + If ``return_datasample=False``, it will be the segmentation mask + with label indice. 
+ """ + if return_datasample: + if len(preds) == 1: + return preds[0] + else: + return preds + + results_dict = {} + + results_dict['predictions'] = [] + results_dict['visualization'] = [] + + for i, pred in enumerate(preds): + pred_data = pred.pred_sem_seg.numpy().data[0] + results_dict['predictions'].append(pred_data) + if visualization is not None: + vis = visualization[i] + results_dict['visualization'].append(vis) + if pred_out_dir != '': + mmengine.mkdir_or_exist(pred_out_dir) + img_name = str(self.num_pred_imgs).zfill(8) + '_pred.png' + img_path = osp.join(pred_out_dir, img_name) + output = Image.fromarray(pred_data.astype(np.uint8)) + output.save(img_path) + self.num_pred_imgs += 1 + + if len(results_dict['predictions']) == 1: + results_dict['predictions'] = results_dict['predictions'][0] + if visualization is not None: + results_dict['visualization'] = \ + results_dict['visualization'][0] + return results_dict + + def _init_pipeline(self, cfg: ConfigType) -> Compose: + """Initialize the test pipeline. + + Return a pipeline to handle various input data, such as ``str``, + ``np.ndarray``. It is an abstract method in BaseInferencer, and should + be implemented in subclasses. + + The returned pipeline will be used to process a single data. + It will be used in :meth:`preprocess` like this: + + .. code-block:: python + def preprocess(self, inputs, batch_size, **kwargs): + ... + dataset = map(self.pipeline, dataset) + ... + """ + pipeline_cfg = cfg.test_dataloader.dataset.pipeline + # Loading annotations is also not applicable + idx = self._get_transform_idx(pipeline_cfg, 'LoadAnnotations') + if idx != -1: + del pipeline_cfg[idx] + load_img_idx = self._get_transform_idx(pipeline_cfg, + 'LoadImageFromFile') + + if load_img_idx == -1: + raise ValueError( + 'LoadImageFromFile is not found in the test pipeline') + pipeline_cfg[load_img_idx]['type'] = 'InferencerLoader' + return Compose(pipeline_cfg) + + def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int: + """Returns the index of the transform in a pipeline. + + If the transform is not found, returns -1. + """ + for i, transform in enumerate(pipeline_cfg): + if transform['type'] == name: + return i + return -1 diff --git a/mmseg/datasets/__init__.py b/mmseg/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a90d53c88e03850e56e765332cc87533466892f0 --- /dev/null +++ b/mmseg/datasets/__init__.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
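The two inference entry points added above (the functional API in mmseg/apis/inference.py and the MMSegInferencer class) can be exercised roughly as follows. This is a hedged sketch: the image name, checkpoint path and output directories are placeholders rather than files introduced by this diff; the config path and model name come from the docstring examples above.

from mmseg.apis import MMSegInferencer, inference_model, init_model, show_result_pyplot

# Functional API: build the model, run one image, render the prediction.
model = init_model('configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py',
                   checkpoint='path/to/checkpoint.pth', device='cpu')
result = inference_model(model, 'demo.png')  # a single SegDataSample
vis = show_result_pyplot(model, 'demo.png', result, show=False,
                         out_file='demo_pred.png')

# Single-call interface: weights are resolved from the metafile by model name.
inferencer = MMSegInferencer(model='fcn_r50-d8_4xb2-40k_cityscapes-512x1024',
                             device='cpu')
outputs = inferencer('demo.png', out_dir='work_dirs/inference')
# Rendered masks go to work_dirs/inference/vis, raw predictions to .../pred.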
+# yapf: disable +from .ade import ADE20KDataset +from .basesegdataset import BaseSegDataset +from .chase_db1 import ChaseDB1Dataset +from .cityscapes import CityscapesDataset +from .coco_stuff import COCOStuffDataset +from .dark_zurich import DarkZurichDataset +from .dataset_wrappers import MultiImageMixDataset +from .decathlon import DecathlonDataset +from .drive import DRIVEDataset +from .hrf import HRFDataset +from .isaid import iSAIDDataset +from .isprs import ISPRSDataset +from .lip import LIPDataset +from .loveda import LoveDADataset +from .mapillary import MapillaryDataset_v1, MapillaryDataset_v2 +from .night_driving import NightDrivingDataset +from .pascal_context import PascalContextDataset, PascalContextDataset59 +from .potsdam import PotsdamDataset +from .refuge import REFUGEDataset +from .stare import STAREDataset +from .synapse import SynapseDataset +# yapf: disable +from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad, + BioMedical3DRandomCrop, BioMedical3DRandomFlip, + BioMedicalGaussianBlur, BioMedicalGaussianNoise, + BioMedicalRandomGamma, GenerateEdge, LoadAnnotations, + LoadBiomedicalAnnotation, LoadBiomedicalData, + LoadBiomedicalImageFromFile, LoadImageFromNDArray, + PackSegInputs, PhotoMetricDistortion, RandomCrop, + RandomCutOut, RandomMosaic, RandomRotate, + RandomRotFlip, Rerange, ResizeShortestEdge, + ResizeToMultiple, RGB2Gray, SegRescale) +from .voc import PascalVOCDataset + +# yapf: enable +__all__ = [ + 'BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip', + 'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset', + 'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset', + 'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset', + 'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset', + 'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset', + 'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion', + 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', + 'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple', + 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile', + 'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge', + 'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge', + 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur', + 'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip', + 'SynapseDataset', 'REFUGEDataset', 'MapillaryDataset_v1', + 'MapillaryDataset_v2' +] diff --git a/mmseg/datasets/__pycache__/__init__.cpython-310.pyc b/mmseg/datasets/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e41c4d26557c839e7ebad9cc33325ebee345047 Binary files /dev/null and b/mmseg/datasets/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmseg/datasets/__pycache__/ade.cpython-310.pyc b/mmseg/datasets/__pycache__/ade.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..379875639aca97b4f970f9915f227c42502173c6 Binary files /dev/null and b/mmseg/datasets/__pycache__/ade.cpython-310.pyc differ diff --git a/mmseg/datasets/__pycache__/basesegdataset.cpython-310.pyc b/mmseg/datasets/__pycache__/basesegdataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..842ed8bc01ee14fff850a05ea8410d6a03ff1df6 Binary files /dev/null and b/mmseg/datasets/__pycache__/basesegdataset.cpython-310.pyc differ diff --git a/mmseg/datasets/__pycache__/chase_db1.cpython-310.pyc b/mmseg/datasets/__pycache__/chase_db1.cpython-310.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..e31a90c2317a5214aa8aaa53be517f56846f21ce Binary files /dev/null and b/mmseg/datasets/__pycache__/chase_db1.cpython-310.pyc differ diff --git a/mmseg/datasets/__pycache__/cityscapes.cpython-310.pyc b/mmseg/datasets/__pycache__/cityscapes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c49f2d87166971a7d4d390343ff1bd0c2f0dca0e Binary files /dev/null and b/mmseg/datasets/__pycache__/cityscapes.cpython-310.pyc differ diff --git a/mmseg/datasets/__pycache__/coco_stuff.cpython-310.pyc b/mmseg/datasets/__pycache__/coco_stuff.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9f6961f3e5f28b404e11dbbe3fb8bcb36efda17 Binary files /dev/null and b/mmseg/datasets/__pycache__/coco_stuff.cpython-310.pyc differ diff --git a/mmseg/datasets/__pycache__/dark_zurich.cpython-310.pyc b/mmseg/datasets/__pycache__/dark_zurich.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26d492ae38d0d0ca3062b49cc1612ee839e02197 Binary files /dev/null and b/mmseg/datasets/__pycache__/dark_zurich.cpython-310.pyc differ diff --git a/mmseg/datasets/__pycache__/dataset_wrappers.cpython-310.pyc b/mmseg/datasets/__pycache__/dataset_wrappers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a71b6b6c99736472a97c8ab97b7f29512a626c28 Binary files /dev/null and b/mmseg/datasets/__pycache__/dataset_wrappers.cpython-310.pyc differ diff --git a/mmseg/datasets/__pycache__/decathlon.cpython-310.pyc b/mmseg/datasets/__pycache__/decathlon.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2fd17c275a80d3c63d9bc1a4bd152fed35e5963 Binary files /dev/null and b/mmseg/datasets/__pycache__/decathlon.cpython-310.pyc differ diff --git a/mmseg/datasets/__pycache__/drive.cpython-310.pyc b/mmseg/datasets/__pycache__/drive.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2222b17dd4d281612bf2871e7cdefdad8909da8 Binary files /dev/null and b/mmseg/datasets/__pycache__/drive.cpython-310.pyc differ diff --git a/mmseg/datasets/__pycache__/hrf.cpython-310.pyc b/mmseg/datasets/__pycache__/hrf.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a629737c543340f40747cce74ad03edf02718818 Binary files /dev/null and b/mmseg/datasets/__pycache__/hrf.cpython-310.pyc differ diff --git a/mmseg/datasets/__pycache__/isaid.cpython-310.pyc b/mmseg/datasets/__pycache__/isaid.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3f8aa4906c04809eced5b73fa0638b2e91b9dab Binary files /dev/null and b/mmseg/datasets/__pycache__/isaid.cpython-310.pyc differ diff --git a/mmseg/datasets/__pycache__/isprs.cpython-310.pyc b/mmseg/datasets/__pycache__/isprs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8a8498e6c48d85ecaf91af67f10d9b13e2ff857 Binary files /dev/null and b/mmseg/datasets/__pycache__/isprs.cpython-310.pyc differ diff --git a/mmseg/datasets/__pycache__/lip.cpython-310.pyc b/mmseg/datasets/__pycache__/lip.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9363899e7e7c1d1f5ed4ee8b7862fdc3affb6f92 Binary files /dev/null and b/mmseg/datasets/__pycache__/lip.cpython-310.pyc differ diff --git a/mmseg/datasets/__pycache__/loveda.cpython-310.pyc b/mmseg/datasets/__pycache__/loveda.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..08c7bffaae2209862e5457d0d3d30ee2b95d3f8c Binary files /dev/null and b/mmseg/datasets/__pycache__/loveda.cpython-310.pyc differ diff --git a/mmseg/datasets/__pycache__/mapillary.cpython-310.pyc b/mmseg/datasets/__pycache__/mapillary.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..926b3debe40129bcc335dc05a9a1132ca1d2fef1 Binary files /dev/null and b/mmseg/datasets/__pycache__/mapillary.cpython-310.pyc differ diff --git a/mmseg/datasets/__pycache__/night_driving.cpython-310.pyc b/mmseg/datasets/__pycache__/night_driving.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e77a46077afe877286a86eec0a1a3dbbcb3dc06 Binary files /dev/null and b/mmseg/datasets/__pycache__/night_driving.cpython-310.pyc differ diff --git a/mmseg/datasets/__pycache__/pascal_context.cpython-310.pyc b/mmseg/datasets/__pycache__/pascal_context.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e1c41a370dba6415f516a8afea6a8cdcadb3a30 Binary files /dev/null and b/mmseg/datasets/__pycache__/pascal_context.cpython-310.pyc differ diff --git a/mmseg/datasets/__pycache__/potsdam.cpython-310.pyc b/mmseg/datasets/__pycache__/potsdam.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efcc2b8630ca1d99d3dbcf5bdec504cdda912d06 Binary files /dev/null and b/mmseg/datasets/__pycache__/potsdam.cpython-310.pyc differ diff --git a/mmseg/datasets/__pycache__/refuge.cpython-310.pyc b/mmseg/datasets/__pycache__/refuge.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c35ed8314508b8208271a010289389573c51e65 Binary files /dev/null and b/mmseg/datasets/__pycache__/refuge.cpython-310.pyc differ diff --git a/mmseg/datasets/__pycache__/stare.cpython-310.pyc b/mmseg/datasets/__pycache__/stare.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15873de15019cd973eaf7c55769d3ddc3b87f14b Binary files /dev/null and b/mmseg/datasets/__pycache__/stare.cpython-310.pyc differ diff --git a/mmseg/datasets/__pycache__/synapse.cpython-310.pyc b/mmseg/datasets/__pycache__/synapse.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cabd12194a838b5f173be568042e6b15b704af4c Binary files /dev/null and b/mmseg/datasets/__pycache__/synapse.cpython-310.pyc differ diff --git a/mmseg/datasets/__pycache__/voc.cpython-310.pyc b/mmseg/datasets/__pycache__/voc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..718cf0f38e3c86cefca3e9a6a93857d50331cc00 Binary files /dev/null and b/mmseg/datasets/__pycache__/voc.cpython-310.pyc differ diff --git a/mmseg/datasets/ade.py b/mmseg/datasets/ade.py new file mode 100644 index 0000000000000000000000000000000000000000..d5d138af6c00e87330ab58a6a960255df7a21746 --- /dev/null +++ b/mmseg/datasets/ade.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class ADE20KDataset(BaseSegDataset): + """ADE20K dataset. + + In segmentation map annotation for ADE20K, 0 stands for background, which + is not included in 150 categories. ``reduce_zero_label`` is fixed to True. + The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to + '.png'. 
+ """ + # METAINFO = dict( + # classes=('wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', + # 'bed ', 'windowpane', 'grass', 'cabinet', 'sidewalk', + # 'person', 'earth', 'door', 'table', 'mountain', 'plant', + # 'curtain', 'chair', 'car', 'water', 'painting', 'sofa', + # 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair', + # 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp', + # 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', + # 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', + # 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', + # 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow', + # 'screen door', 'stairway', 'river', 'bridge', 'bookcase', + # 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill', + # 'bench', 'countertop', 'stove', 'palm', 'kitchen island', + # 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine', + # 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', + # 'chandelier', 'awning', 'streetlight', 'booth', + # 'television receiver', 'airplane', 'dirt track', 'apparel', + # 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', + # 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain', + # 'conveyer belt', 'canopy', 'washer', 'plaything', + # 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', + # 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', + # 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal', + # 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket', + # 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', + # 'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'plate', + # 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', + # 'clock', 'flag'), + # palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + # [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + # [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + # [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + # [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + # [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + # [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + # [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + # [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + # [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + # [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + # [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + # [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + # [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + # [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], + # [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], + # [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], + # [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], + # [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], + # [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], + # [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], + # [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], + # [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], + # [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], + # [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], + # [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], + # [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], + # [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], + # [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 
184, 160], + # [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], + # [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], + # [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], + # [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], + # [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], + # [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], + # [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], + # [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], + # [102, 255, 0], [92, 0, 255]]) + + METAINFO = dict(classes=('building',), palette=[(0, 0, 255)]) + + def __init__(self, + img_suffix='.jpg', + seg_map_suffix='.png', + reduce_zero_label=True, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) diff --git a/mmseg/datasets/basesegdataset.py b/mmseg/datasets/basesegdataset.py new file mode 100644 index 0000000000000000000000000000000000000000..eadf8482492c63407f7b6218fc2b1d554a678ac4 --- /dev/null +++ b/mmseg/datasets/basesegdataset.py @@ -0,0 +1,269 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os.path as osp +from typing import Callable, Dict, List, Optional, Sequence, Union + +import mmengine +import mmengine.fileio as fileio +import numpy as np +from mmengine.dataset import BaseDataset, Compose + +from mmseg.registry import DATASETS + + +@DATASETS.register_module() +class BaseSegDataset(BaseDataset): + """Custom dataset for semantic segmentation. An example of file structure + is as followed. + + .. code-block:: none + + ├── data + │ ├── my_dataset + │ │ ├── img_dir + │ │ │ ├── train + │ │ │ │ ├── xxx{img_suffix} + │ │ │ │ ├── yyy{img_suffix} + │ │ │ │ ├── zzz{img_suffix} + │ │ │ ├── val + │ │ ├── ann_dir + │ │ │ ├── train + │ │ │ │ ├── xxx{seg_map_suffix} + │ │ │ │ ├── yyy{seg_map_suffix} + │ │ │ │ ├── zzz{seg_map_suffix} + │ │ │ ├── val + + The img/gt_semantic_seg pair of BaseSegDataset should be of the same + except suffix. A valid img/gt_semantic_seg filename pair should be like + ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included + in the suffix). If split is given, then ``xxx`` is specified in txt file. + Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded. + Please refer to ``docs/en/tutorials/new_dataset.md`` for more details. + + + Args: + ann_file (str): Annotation file path. Defaults to ''. + metainfo (dict, optional): Meta information for dataset, such as + specify classes to load. Defaults to None. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to None. + data_prefix (dict, optional): Prefix for training data. Defaults to + dict(img_path=None, seg_map_path=None). + img_suffix (str): Suffix of images. Default: '.jpg' + seg_map_suffix (str): Suffix of segmentation maps. Default: '.png' + filter_cfg (dict, optional): Config for filter data. Defaults to None. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Defaults to None which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. Defaults + to True. + pipeline (list, optional): Processing pipeline. Defaults to []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. 
+ Defaults to False. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=True``. Defaults to False. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Defaults to 1000. + ignore_index (int): The label index to be ignored. Default: 255 + reduce_zero_label (bool): Whether to mark label zero as ignored. + Default to False. + backend_args (dict, Optional): Arguments to instantiate a file backend. + See https://mmengine.readthedocs.io/en/latest/api/fileio.htm + for details. Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + """ + METAINFO: dict = dict() + + def __init__(self, + ann_file: str = '', + img_suffix='.jpg', + seg_map_suffix='.png', + metainfo: Optional[dict] = None, + data_root: Optional[str] = None, + data_prefix: dict = dict(img_path='', seg_map_path=''), + filter_cfg: Optional[dict] = None, + indices: Optional[Union[int, Sequence[int]]] = None, + serialize_data: bool = True, + pipeline: List[Union[dict, Callable]] = [], + test_mode: bool = False, + lazy_init: bool = False, + max_refetch: int = 1000, + ignore_index: int = 255, + reduce_zero_label: bool = False, + backend_args: Optional[dict] = None) -> None: + + self.img_suffix = img_suffix + self.seg_map_suffix = seg_map_suffix + self.ignore_index = ignore_index + self.reduce_zero_label = reduce_zero_label + self.backend_args = backend_args.copy() if backend_args else None + + self.data_root = data_root + self.data_prefix = copy.copy(data_prefix) + self.ann_file = ann_file + self.filter_cfg = copy.deepcopy(filter_cfg) + self._indices = indices + self.serialize_data = serialize_data + self.test_mode = test_mode + self.max_refetch = max_refetch + self.data_list: List[dict] = [] + self.data_bytes: np.ndarray + + # Set meta information. + self._metainfo = self._load_metainfo(copy.deepcopy(metainfo)) + + # Get label map for custom classes + new_classes = self._metainfo.get('classes', None) + self.label_map = self.get_label_map(new_classes) + self._metainfo.update( + dict( + label_map=self.label_map, + reduce_zero_label=self.reduce_zero_label)) + + # Update palette based on label map or generate palette + # if it is not defined + updated_palette = self._update_palette() + self._metainfo.update(dict(palette=updated_palette)) + + # Join paths. + if self.data_root is not None: + self._join_prefix() + + # Build pipeline. + # import ipdb; ipdb.set_trace() + self.pipeline = Compose(pipeline) + # Full initialize the dataset. + if not lazy_init: + self.full_init() + + if test_mode: + assert self._metainfo.get('classes') is not None, \ + 'dataset metainfo `classes` should be specified when testing' + + @classmethod + def get_label_map(cls, + new_classes: Optional[Sequence] = None + ) -> Union[Dict, None]: + """Require label mapping. + + The ``label_map`` is a dictionary, its keys are the old label ids and + its values are the new label ids, and is used for changing pixel + labels in load_annotations. If and only if old classes in cls.METAINFO + is not equal to new classes in self._metainfo and nether of them is not + None, `label_map` is not None. + + Args: + new_classes (list, tuple, optional): The new classes name from + metainfo. Default to None. 
+ + + Returns: + dict, optional: The mapping from old classes in cls.METAINFO to + new classes in self._metainfo + """ + old_classes = cls.METAINFO.get('classes', None) + if (new_classes is not None and old_classes is not None + and list(new_classes) != list(old_classes)): + + label_map = {} + if not set(new_classes).issubset(cls.METAINFO['classes']): + raise ValueError( + f'new classes {new_classes} is not a ' + f'subset of classes {old_classes} in METAINFO.') + for i, c in enumerate(old_classes): + if c not in new_classes: + label_map[i] = 255 + else: + label_map[i] = new_classes.index(c) + return label_map + else: + return None + + def _update_palette(self) -> list: + """Update palette after loading metainfo. + + If length of palette is equal to classes, just return the palette. + If palette is not defined, it will randomly generate a palette. + If classes is updated by customer, it will return the subset of + palette. + + Returns: + Sequence: Palette for current dataset. + """ + palette = self._metainfo.get('palette', []) + classes = self._metainfo.get('classes', []) + # palette does match classes + if len(palette) == len(classes): + return palette + + if len(palette) == 0: + # Get random state before set seed, and restore + # random state later. + # It will prevent loss of randomness, as the palette + # may be different in each iteration if not specified. + # See: https://github.com/open-mmlab/mmdetection/issues/5844 + state = np.random.get_state() + np.random.seed(42) + # random palette + new_palette = np.random.randint( + 0, 255, size=(len(classes), 3)).tolist() + np.random.set_state(state) + elif len(palette) >= len(classes) and self.label_map is not None: + new_palette = [] + # return subset of palette + for old_id, new_id in sorted( + self.label_map.items(), key=lambda x: x[1]): + if new_id != 255: + new_palette.append(palette[old_id]) + new_palette = type(palette)(new_palette) + else: + raise ValueError('palette does not match classes ' + f'as metainfo is {self._metainfo}.') + return new_palette + + def load_data_list(self) -> List[dict]: + """Load annotation from directory or annotation file. + + Returns: + list[dict]: All data info of dataset. 
+ """ + data_list = [] + img_dir = self.data_prefix.get('img_path', None) + ann_dir = self.data_prefix.get('seg_map_path', None) + if osp.isfile(self.ann_file): + lines = mmengine.list_from_file( + self.ann_file, backend_args=self.backend_args) + for line in lines: + img_name = line.strip() + data_info = dict( + img_path=osp.join(img_dir, img_name + self.img_suffix)) + if ann_dir is not None: + seg_map = img_name + self.seg_map_suffix + data_info['seg_map_path'] = osp.join(ann_dir, seg_map) + data_info['label_map'] = self.label_map + data_info['reduce_zero_label'] = self.reduce_zero_label + data_info['seg_fields'] = [] + data_list.append(data_info) + else: + for img in fileio.list_dir_or_file( + dir_path=img_dir, + list_dir=False, + suffix=self.img_suffix, + recursive=True, + backend_args=self.backend_args): + data_info = dict(img_path=osp.join(img_dir, img)) + if ann_dir is not None: + seg_map = img.replace(self.img_suffix, self.seg_map_suffix) + data_info['seg_map_path'] = osp.join(ann_dir, seg_map) + data_info['label_map'] = self.label_map + data_info['reduce_zero_label'] = self.reduce_zero_label + data_info['seg_fields'] = [] + data_list.append(data_info) + data_list = sorted(data_list, key=lambda x: x['img_path']) + return data_list diff --git a/mmseg/datasets/chase_db1.py b/mmseg/datasets/chase_db1.py new file mode 100644 index 0000000000000000000000000000000000000000..5cc1fc56773715139b23ddf1e7ac475540914a42 --- /dev/null +++ b/mmseg/datasets/chase_db1.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class ChaseDB1Dataset(BaseSegDataset): + """Chase_db1 dataset. + + In segmentation map annotation for Chase_db1, 0 stands for background, + which is included in 2 categories. ``reduce_zero_label`` is fixed to False. + The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '_1stHO.png'. + """ + METAINFO = dict( + classes=('background', 'vessel'), + palette=[[120, 120, 120], [6, 230, 230]]) + + def __init__(self, + img_suffix='.png', + seg_map_suffix='_1stHO.png', + reduce_zero_label=False, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) + assert self.file_client.exists(self.data_prefix['img_path']) diff --git a/mmseg/datasets/cityscapes.py b/mmseg/datasets/cityscapes.py new file mode 100644 index 0000000000000000000000000000000000000000..f494d62424a39581961ab705b3308e7e07bee110 --- /dev/null +++ b/mmseg/datasets/cityscapes.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class CityscapesDataset(BaseSegDataset): + """Cityscapes dataset. + + The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is + fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset. 
+ """ + METAINFO = dict( + classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', + 'traffic light', 'traffic sign', 'vegetation', 'terrain', + 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', + 'motorcycle', 'bicycle'), + palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], + [190, 153, 153], [153, 153, 153], [250, 170, + 30], [220, 220, 0], + [107, 142, 35], [152, 251, 152], [70, 130, 180], + [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70], + [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]]) + + def __init__(self, + img_suffix='_leftImg8bit.png', + seg_map_suffix='_gtFine_labelTrainIds.png', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) diff --git a/mmseg/datasets/coco_stuff.py b/mmseg/datasets/coco_stuff.py new file mode 100644 index 0000000000000000000000000000000000000000..1e1574d9702330cc5b10bab084841df61e7121ff --- /dev/null +++ b/mmseg/datasets/coco_stuff.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class COCOStuffDataset(BaseSegDataset): + """COCO-Stuff dataset. + + In segmentation map annotation for COCO-Stuff, Train-IDs of the 10k version + are from 1 to 171, where 0 is the ignore index, and Train-ID of COCO Stuff + 164k is from 0 to 170, where 255 is the ignore index. So, they are all 171 + semantic categories. ``reduce_zero_label`` is set to True and False for the + 10k and 164k versions, respectively. The ``img_suffix`` is fixed to '.jpg', + and ``seg_map_suffix`` is fixed to '.png'. + """ + METAINFO = dict( + classes=( + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', + 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', + 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', + 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', + 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', + 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', + 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', + 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', + 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', + 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', + 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', + 'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet', + 'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile', + 'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain', + 'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble', + 'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower', + 'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel', + 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal', + 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', + 'paper', 'pavement', 'pillow', 'plant-other', 'plastic', + 'platform', 'playingfield', 'railing', 'railroad', 'river', 'road', + 'rock', 'roof', 'rug', 'salad', 'sand', 'sea', 'shelf', + 'sky-other', 'skyscraper', 'snow', 'solid-other', 'stairs', + 'stone', 'straw', 'structural-other', 'table', 'tent', + 'textile-other', 'towel', 'tree', 'vegetable', 'wall-brick', + 'wall-concrete', 'wall-other', 
'wall-panel', 'wall-stone', + 'wall-tile', 'wall-wood', 'water-other', 'waterdrops', + 'window-blind', 'window-other', 'wood'), + palette=[[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], + [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64], + [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], + [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192], + [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192], + [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128], + [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160], + [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0], + [0, 128, 0], [192, 128, 32], [128, 96, 128], [0, 0, 128], + [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160], + [0, 96, 128], [128, 128, 128], [64, 0, 160], [128, 224, 128], + [128, 128, 64], [192, 0, 32], [128, 96, 0], [128, 0, 192], + [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160], + [64, 96, 0], [0, 128, 192], [0, 128, 160], [192, 224, 0], + [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192], + [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160], + [64, 32, 128], [128, 192, 192], [0, 0, 160], [192, 160, 128], + [128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128], + [64, 128, 96], [64, 160, 0], [0, 64, 0], [192, 128, 224], + [64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0], + [0, 192, 0], [192, 128, 96], [192, 96, 128], [0, 64, 128], + [64, 0, 96], [64, 224, 128], [128, 64, 0], [192, 0, 224], + [64, 96, 128], [128, 192, 128], [64, 0, 224], [192, 224, 128], + [128, 192, 64], [192, 0, 96], [192, 96, 0], [128, 64, 192], + [0, 128, 96], [0, 224, 0], [64, 64, 64], [128, 128, 224], + [0, 96, 0], [64, 192, 192], [0, 128, 224], [128, 224, 0], + [64, 192, 64], [128, 128, 96], [128, 32, 128], [64, 0, 192], + [0, 64, 96], [0, 160, 128], [192, 0, 64], [128, 64, 224], + [0, 32, 128], [192, 128, 192], [0, 64, 224], [128, 160, 128], + [192, 128, 0], [128, 64, 32], [128, 32, 64], [192, 0, 128], + [64, 192, 32], [0, 160, 64], [64, 0, 0], [192, 192, 160], + [0, 32, 64], [64, 128, 128], [64, 192, 160], [128, 160, 64], + [64, 128, 0], [192, 192, 32], [128, 96, 192], [64, 0, 128], + [64, 64, 32], [0, 224, 192], [192, 0, 0], [192, 64, 160], + [0, 96, 192], [192, 128, 128], [64, 64, 160], [128, 224, 192], + [192, 128, 64], [192, 64, 32], [128, 96, 64], [192, 0, 192], + [0, 192, 32], [64, 224, 64], [64, 0, 64], [128, 192, 160], + [64, 96, 64], [64, 128, 192], [0, 192, 160], [192, 224, 64], + [64, 128, 64], [128, 192, 32], [192, 32, 192], [64, 64, 192], + [0, 64, 32], [64, 160, 192], [192, 64, 64], [128, 64, 160], + [64, 32, 192], [192, 192, 192], [0, 64, 160], [192, 160, 192], + [192, 192, 0], [128, 64, 96], [192, 32, 64], [192, 64, 128], + [64, 192, 96], [64, 160, 64], [64, 64, 0]]) + + def __init__(self, + img_suffix='.jpg', + seg_map_suffix='_labelTrainIds.png', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) diff --git a/mmseg/datasets/dark_zurich.py b/mmseg/datasets/dark_zurich.py new file mode 100644 index 0000000000000000000000000000000000000000..9b5393fa9e5047e81790f91829cfe4b7f33cc707 --- /dev/null +++ b/mmseg/datasets/dark_zurich.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
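When a dataset is built with a `metainfo` that narrows the classes to a subset, BaseSegDataset.get_label_map (defined earlier in basesegdataset.py) remaps the old label ids and sends dropped classes to the ignore index 255. A small sketch, using CityscapesDataset only as an example:

from mmseg.datasets import CityscapesDataset

# Keep only 'car' and 'person'; every other Cityscapes id maps to 255 (ignored).
label_map = CityscapesDataset.get_label_map(new_classes=('car', 'person'))
assert label_map[13] == 0    # 'car' is index 13 in the full class list
assert label_map[11] == 1    # 'person' is index 11
assert label_map[0] == 255   # 'road' is dropped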
+from mmseg.registry import DATASETS +from .cityscapes import CityscapesDataset + + +@DATASETS.register_module() +class DarkZurichDataset(CityscapesDataset): + """DarkZurichDataset dataset.""" + + def __init__(self, + img_suffix='_rgb_anon.png', + seg_map_suffix='_gt_labelTrainIds.png', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) diff --git a/mmseg/datasets/dataset_wrappers.py b/mmseg/datasets/dataset_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..082c116ff4582ecc7064dba1aba3c164dd556af5 --- /dev/null +++ b/mmseg/datasets/dataset_wrappers.py @@ -0,0 +1,136 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import collections +import copy +from typing import List, Optional, Sequence, Union + +from mmengine.dataset import ConcatDataset, force_full_init + +from mmseg.registry import DATASETS, TRANSFORMS + + +@DATASETS.register_module() +class MultiImageMixDataset: + """A wrapper of multiple images mixed dataset. + + Suitable for training on multiple images mixed data augmentation like + mosaic and mixup. + + Args: + dataset (ConcatDataset or dict): The dataset to be mixed. + pipeline (Sequence[dict]): Sequence of transform object or + config dict to be composed. + skip_type_keys (list[str], optional): Sequence of type string to + be skip pipeline. Default to None. + """ + + def __init__(self, + dataset: Union[ConcatDataset, dict], + pipeline: Sequence[dict], + skip_type_keys: Optional[List[str]] = None, + lazy_init: bool = False) -> None: + assert isinstance(pipeline, collections.abc.Sequence) + + if isinstance(dataset, dict): + self.dataset = DATASETS.build(dataset) + elif isinstance(dataset, ConcatDataset): + self.dataset = dataset + else: + raise TypeError( + 'elements in datasets sequence should be config or ' + f'`ConcatDataset` instance, but got {type(dataset)}') + + if skip_type_keys is not None: + assert all([ + isinstance(skip_type_key, str) + for skip_type_key in skip_type_keys + ]) + self._skip_type_keys = skip_type_keys + + self.pipeline = [] + self.pipeline_types = [] + for transform in pipeline: + if isinstance(transform, dict): + self.pipeline_types.append(transform['type']) + transform = TRANSFORMS.build(transform) + self.pipeline.append(transform) + else: + raise TypeError('pipeline must be a dict') + + self._metainfo = self.dataset.metainfo + self.num_samples = len(self.dataset) + + self._fully_initialized = False + if not lazy_init: + self.full_init() + + @property + def metainfo(self) -> dict: + """Get the meta information of the multi-image-mixed dataset. + + Returns: + dict: The meta information of multi-image-mixed dataset. + """ + return copy.deepcopy(self._metainfo) + + def full_init(self): + """Loop to ``full_init`` each dataset.""" + if self._fully_initialized: + return + + self.dataset.full_init() + self._ori_len = len(self.dataset) + self._fully_initialized = True + + @force_full_init + def get_data_info(self, idx: int) -> dict: + """Get annotation by index. + + Args: + idx (int): Global index of ``ConcatDataset``. + + Returns: + dict: The idx-th annotation of the datasets. 
+ """ + return self.dataset.get_data_info(idx) + + @force_full_init + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + results = copy.deepcopy(self.dataset[idx]) + for (transform, transform_type) in zip(self.pipeline, + self.pipeline_types): + if self._skip_type_keys is not None and \ + transform_type in self._skip_type_keys: + continue + + if hasattr(transform, 'get_indices'): + indices = transform.get_indices(self.dataset) + if not isinstance(indices, collections.abc.Sequence): + indices = [indices] + mix_results = [ + copy.deepcopy(self.dataset[index]) for index in indices + ] + results['mix_results'] = mix_results + + results = transform(results) + + if 'mix_results' in results: + results.pop('mix_results') + + return results + + def update_skip_type_keys(self, skip_type_keys): + """Update skip_type_keys. + + It is called by an external hook. + + Args: + skip_type_keys (list[str], optional): Sequence of type + string to be skip pipeline. + """ + assert all([ + isinstance(skip_type_key, str) for skip_type_key in skip_type_keys + ]) + self._skip_type_keys = skip_type_keys diff --git a/mmseg/datasets/decathlon.py b/mmseg/datasets/decathlon.py new file mode 100644 index 0000000000000000000000000000000000000000..26aa4ef0d7f44e55d4400ed6151ea1f6cb3930ec --- /dev/null +++ b/mmseg/datasets/decathlon.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os.path as osp +from typing import List + +from mmengine.fileio import load + +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class DecathlonDataset(BaseSegDataset): + """Dataset for Dacathlon dataset. + + The dataset.json format is shown as follows + + .. code-block:: none + + { + "name": "BRATS", + "tensorImageSize": "4D", + "modality": + { + "0": "FLAIR", + "1": "T1w", + "2": "t1gd", + "3": "T2w" + }, + "labels": { + "0": "background", + "1": "edema", + "2": "non-enhancing tumor", + "3": "enhancing tumour" + }, + "numTraining": 484, + "numTest": 266, + "training": + [ + { + "image": "./imagesTr/BRATS_306.nii.gz" + "label": "./labelsTr/BRATS_306.nii.gz" + ... + } + ] + "test": + [ + "./imagesTs/BRATS_557.nii.gz" + ... + ] + } + """ + + def load_data_list(self) -> List[dict]: + """Load annotation from directory or annotation file. + + Returns: + list[dict]: All data info of dataset. + """ + # `self.ann_file` denotes the absolute annotation file path if + # `self.root=None` or relative path if `self.root=/path/to/data/`. + annotations = load(self.ann_file) + if not isinstance(annotations, dict): + raise TypeError(f'The annotations loaded from annotation file ' + f'should be a dict, but got {type(annotations)}!') + raw_data_list = annotations[ + 'training'] if not self.test_mode else annotations['test'] + data_list = [] + for raw_data_info in raw_data_list: + # `2:` works for removing './' in file path, which will break + # loading from cloud storage. 
+ if isinstance(raw_data_info, dict): + data_info = dict( + img_path=osp.join(self.data_root, raw_data_info['image'] + [2:])) + data_info['seg_map_path'] = osp.join( + self.data_root, raw_data_info['label'][2:]) + else: + data_info = dict( + img_path=osp.join(self.data_root, raw_data_info)[2:]) + data_info['label_map'] = self.label_map + data_info['reduce_zero_label'] = self.reduce_zero_label + data_info['seg_fields'] = [] + data_list.append(data_info) + annotations.pop('training') + annotations.pop('test') + + metainfo = copy.deepcopy(annotations) + metainfo['classes'] = [*metainfo['labels'].values()] + # Meta information load from annotation file will not influence the + # existed meta information load from `BaseDataset.METAINFO` and + # `metainfo` arguments defined in constructor. + for k, v in metainfo.items(): + self._metainfo.setdefault(k, v) + + return data_list diff --git a/mmseg/datasets/drive.py b/mmseg/datasets/drive.py new file mode 100644 index 0000000000000000000000000000000000000000..c42e18e711a8bc66b80e2b3106c39d20f7239d4a --- /dev/null +++ b/mmseg/datasets/drive.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class DRIVEDataset(BaseSegDataset): + """DRIVE dataset. + + In segmentation map annotation for DRIVE, 0 stands for background, which is + included in 2 categories. ``reduce_zero_label`` is fixed to False. The + ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '_manual1.png'. + """ + METAINFO = dict( + classes=('background', 'vessel'), + palette=[[120, 120, 120], [6, 230, 230]]) + + def __init__(self, + img_suffix='.png', + seg_map_suffix='_manual1.png', + reduce_zero_label=False, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) + assert self.file_client.exists(self.data_prefix['img_path']) diff --git a/mmseg/datasets/hrf.py b/mmseg/datasets/hrf.py new file mode 100644 index 0000000000000000000000000000000000000000..0df6ccc49c2f55a102373da9f68d02c41648d964 --- /dev/null +++ b/mmseg/datasets/hrf.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class HRFDataset(BaseSegDataset): + """HRF dataset. + + In segmentation map annotation for HRF, 0 stands for background, which is + included in 2 categories. ``reduce_zero_label`` is fixed to False. The + ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '.png'. + """ + METAINFO = dict( + classes=('background', 'vessel'), + palette=[[120, 120, 120], [6, 230, 230]]) + + def __init__(self, + img_suffix='.png', + seg_map_suffix='.png', + reduce_zero_label=False, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) + assert self.file_client.exists(self.data_prefix['img_path']) diff --git a/mmseg/datasets/isaid.py b/mmseg/datasets/isaid.py new file mode 100644 index 0000000000000000000000000000000000000000..61942ec1ea33e76c65c22d8e7fc71fb8194841dd --- /dev/null +++ b/mmseg/datasets/isaid.py @@ -0,0 +1,39 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import mmengine.fileio as fileio + +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class iSAIDDataset(BaseSegDataset): + """ iSAID: A Large-scale Dataset for Instance Segmentation in Aerial Images + In segmentation map annotation for iSAID dataset, which is included + in 16 categories. ``reduce_zero_label`` is fixed to False. The + ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '_manual1.png'. + """ + + METAINFO = dict( + classes=('background', 'ship', 'store_tank', 'baseball_diamond', + 'tennis_court', 'basketball_court', 'Ground_Track_Field', + 'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter', + 'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane', + 'Harbor'), + palette=[[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127], + [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127], + [0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 191, 127], + [0, 127, 191], [0, 127, 255], [0, 100, 155]]) + + def __init__(self, + img_suffix='.png', + seg_map_suffix='_instance_color_RGB.png', + ignore_index=255, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + ignore_index=ignore_index, + **kwargs) + assert fileio.exists( + self.data_prefix['img_path'], backend_args=self.backend_args) diff --git a/mmseg/datasets/isprs.py b/mmseg/datasets/isprs.py new file mode 100644 index 0000000000000000000000000000000000000000..30af53c569b05c9be1218e9a58655c36c8aa9931 --- /dev/null +++ b/mmseg/datasets/isprs.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class ISPRSDataset(BaseSegDataset): + """ISPRS dataset. + + In segmentation map annotation for ISPRS, 0 is the ignore index. + ``reduce_zero_label`` should be set to True. The ``img_suffix`` and + ``seg_map_suffix`` are both fixed to '.png'. + """ + METAINFO = dict( + classes=('impervious_surface', 'building', 'low_vegetation', 'tree', + 'car', 'clutter'), + palette=[[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], + [255, 255, 0], [255, 0, 0]]) + + def __init__(self, + img_suffix='.png', + seg_map_suffix='.png', + reduce_zero_label=True, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) diff --git a/mmseg/datasets/lip.py b/mmseg/datasets/lip.py new file mode 100644 index 0000000000000000000000000000000000000000..3a32a193aff990ae9f819d4a0a1be82df1d049cb --- /dev/null +++ b/mmseg/datasets/lip.py @@ -0,0 +1,47 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class LIPDataset(BaseSegDataset): + """LIP dataset. + + The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to + '.png'. 
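    Example (illustrative only; the directory names below are placeholders
    for a prepared LIP layout, not values enforced by this class):

    .. code-block:: python

        train_dataloader = dict(
            batch_size=4,
            dataset=dict(
                type='LIPDataset',
                data_root='data/LIP',
                data_prefix=dict(
                    img_path='train_images',
                    seg_map_path='train_segmentations'),
                pipeline=[
                    dict(type='LoadImageFromFile'),
                    dict(type='LoadAnnotations'),
                    dict(type='PackSegInputs'),
                ]))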
+ """ + METAINFO = dict( + classes=('Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', + 'UpperClothes', 'Dress', 'Coat', 'Socks', 'Pants', + 'Jumpsuits', 'Scarf', 'Skirt', 'Face', 'Left-arm', + 'Right-arm', 'Left-leg', 'Right-leg', 'Left-shoe', + 'Right-shoe'), + palette=( + [0, 0, 0], + [128, 0, 0], + [255, 0, 0], + [0, 85, 0], + [170, 0, 51], + [255, 85, 0], + [0, 0, 85], + [0, 119, 221], + [85, 85, 0], + [0, 85, 85], + [85, 51, 0], + [52, 86, 128], + [0, 128, 0], + [0, 0, 255], + [51, 170, 221], + [0, 255, 255], + [85, 255, 170], + [170, 255, 85], + [255, 255, 0], + [255, 170, 0], + )) + + def __init__(self, + img_suffix='.jpg', + seg_map_suffix='.png', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) diff --git a/mmseg/datasets/loveda.py b/mmseg/datasets/loveda.py new file mode 100644 index 0000000000000000000000000000000000000000..5c16db503adee6f1a1cac67e1dc72ff873ccd5ea --- /dev/null +++ b/mmseg/datasets/loveda.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class LoveDADataset(BaseSegDataset): + """LoveDA dataset. + + In segmentation map annotation for LoveDA, 0 is the ignore index. + ``reduce_zero_label`` should be set to True. The ``img_suffix`` and + ``seg_map_suffix`` are both fixed to '.png'. + """ + METAINFO = dict( + classes=('background', 'building', 'road', 'water', 'barren', 'forest', + 'agricultural'), + palette=[[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255], + [159, 129, 183], [0, 255, 0], [255, 195, 128]]) + + def __init__(self, + img_suffix='.png', + seg_map_suffix='.png', + reduce_zero_label=True, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) diff --git a/mmseg/datasets/mapillary.py b/mmseg/datasets/mapillary.py new file mode 100644 index 0000000000000000000000000000000000000000..6c2947338ec79b3d8558cee0387a2a84e41f0421 --- /dev/null +++ b/mmseg/datasets/mapillary.py @@ -0,0 +1,176 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class MapillaryDataset_v1(BaseSegDataset): + """Mapillary Vistas Dataset. + + Dataset paper link: + http://ieeexplore.ieee.org/document/8237796/ + + v1.2 contain 66 object classes. + (37 instance-specific) + + v2.0 contain 124 object classes. + (70 instance-specific, 46 stuff, 8 void or crowd). + + The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is + fixed to '.png' for Mapillary Vistas Dataset. 
+ """ + METAINFO = dict( + classes=('Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail', + 'Barrier', 'Wall', 'Bike Lane', 'Crosswalk - Plain', + 'Curb Cut', 'Parking', 'Pedestrian Area', 'Rail Track', + 'Road', 'Service Lane', 'Sidewalk', 'Bridge', 'Building', + 'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist', + 'Other Rider', 'Lane Marking - Crosswalk', + 'Lane Marking - General', 'Mountain', 'Sand', 'Sky', 'Snow', + 'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench', + 'Bike Rack', 'Billboard', 'Catch Basin', 'CCTV Camera', + 'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole', + 'Phone Booth', 'Pothole', 'Street Light', 'Pole', + 'Traffic Sign Frame', 'Utility Pole', 'Traffic Light', + 'Traffic Sign (Back)', 'Traffic Sign (Front)', 'Trash Can', + 'Bicycle', 'Boat', 'Bus', 'Car', 'Caravan', 'Motorcycle', + 'On Rails', 'Other Vehicle', 'Trailer', 'Truck', + 'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled'), + palette=[[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153], + [180, 165, 180], [90, 120, 150], [102, 102, 156], + [128, 64, 255], [140, 140, 200], [170, 170, 170], + [250, 170, 160], [96, 96, 96], + [230, 150, 140], [128, 64, 128], [110, 110, 110], + [244, 35, 232], [150, 100, 100], [70, 70, 70], [150, 120, 90], + [220, 20, 60], [255, 0, 0], [255, 0, 100], [255, 0, 200], + [200, 128, 128], [255, 255, 255], [64, 170, + 64], [230, 160, 50], + [70, 130, 180], [190, 255, 255], [152, 251, 152], + [107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30], + [100, 140, 180], [220, 220, 220], [220, 128, 128], + [222, 40, 40], [100, 170, 30], [40, 40, 40], [33, 33, 33], + [100, 128, 160], [142, 0, 0], [70, 100, 150], [210, 170, 100], + [153, 153, 153], [128, 128, 128], [0, 0, 80], [250, 170, 30], + [192, 192, 192], [220, 220, 0], [140, 140, 20], [119, 11, 32], + [150, 0, 255], [0, 60, 100], [0, 0, 142], [0, 0, 90], + [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110], + [0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10, + 10], [0, 0, 0]]) + + def __init__(self, + img_suffix='.jpg', + seg_map_suffix='.png', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) + + +@DATASETS.register_module() +class MapillaryDataset_v2(BaseSegDataset): + """Mapillary Vistas Dataset. + + Dataset paper link: + http://ieeexplore.ieee.org/document/8237796/ + + v1.2 contain 66 object classes. + (37 instance-specific) + + v2.0 contain 124 object classes. + (70 instance-specific, 46 stuff, 8 void or crowd). + + The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is + fixed to '.png' for Mapillary Vistas Dataset. 
+ """ + METAINFO = dict( + classes=( + 'Bird', 'Ground Animal', 'Ambiguous Barrier', 'Concrete Block', + 'Curb', 'Fence', 'Guard Rail', 'Barrier', 'Road Median', + 'Road Side', 'Lane Separator', 'Temporary Barrier', 'Wall', + 'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Driveway', + 'Parking', 'Parking Aisle', 'Pedestrian Area', 'Rail Track', + 'Road', 'Road Shoulder', 'Service Lane', 'Sidewalk', + 'Traffic Island', 'Bridge', 'Building', 'Garage', 'Tunnel', + 'Person', 'Person Group', 'Bicyclist', 'Motorcyclist', + 'Other Rider', 'Lane Marking - Dashed Line', + 'Lane Marking - Straight Line', 'Lane Marking - Zigzag Line', + 'Lane Marking - Ambiguous', 'Lane Marking - Arrow (Left)', + 'Lane Marking - Arrow (Other)', 'Lane Marking - Arrow (Right)', + 'Lane Marking - Arrow (Split Left or Straight)', + 'Lane Marking - Arrow (Split Right or Straight)', + 'Lane Marking - Arrow (Straight)', 'Lane Marking - Crosswalk', + 'Lane Marking - Give Way (Row)', + 'Lane Marking - Give Way (Single)', + 'Lane Marking - Hatched (Chevron)', + 'Lane Marking - Hatched (Diagonal)', 'Lane Marking - Other', + 'Lane Marking - Stop Line', 'Lane Marking - Symbol (Bicycle)', + 'Lane Marking - Symbol (Other)', 'Lane Marking - Text', + 'Lane Marking (only) - Dashed Line', + 'Lane Marking (only) - Crosswalk', 'Lane Marking (only) - Other', + 'Lane Marking (only) - Test', 'Mountain', 'Sand', 'Sky', 'Snow', + 'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench', 'Bike Rack', + 'Catch Basin', 'CCTV Camera', 'Fire Hydrant', 'Junction Box', + 'Mailbox', 'Manhole', 'Parking Meter', 'Phone Booth', 'Pothole', + 'Signage - Advertisement', 'Signage - Ambiguous', 'Signage - Back', + 'Signage - Information', 'Signage - Other', 'Signage - Store', + 'Street Light', 'Pole', 'Pole Group', 'Traffic Sign Frame', + 'Utility Pole', 'Traffic Cone', 'Traffic Light - General (Single)', + 'Traffic Light - Pedestrians', 'Traffic Light - General (Upright)', + 'Traffic Light - General (Horizontal)', 'Traffic Light - Cyclists', + 'Traffic Light - Other', 'Traffic Sign - Ambiguous', + 'Traffic Sign (Back)', 'Traffic Sign - Direction (Back)', + 'Traffic Sign - Direction (Front)', 'Traffic Sign (Front)', + 'Traffic Sign - Parking', 'Traffic Sign - Temporary (Back)', + 'Traffic Sign - Temporary (Front)', 'Trash Can', 'Bicycle', 'Boat', + 'Bus', 'Car', 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle', + 'Trailer', 'Truck', 'Vehicle Group', 'Wheeled Slow', 'Water Valve', + 'Car Mount', 'Dynamic', 'Ego Vehicle', 'Ground', 'Static', + 'Unlabeled'), + palette=[[165, 42, 42], [0, 192, 0], [250, 170, 31], [250, 170, 32], + [196, 196, 196], [190, 153, 153], [180, 165, 180], + [90, 120, 150], [250, 170, 33], [250, 170, 34], + [128, 128, 128], [250, 170, 35], [102, 102, 156], + [128, 64, 255], [140, 140, 200], [170, 170, 170], + [250, 170, 36], [250, 170, 160], [250, 170, 37], [96, 96, 96], + [230, 150, 140], [128, 64, 128], [110, 110, 110], + [110, 110, 110], [244, 35, 232], [128, 196, + 128], [150, 100, 100], + [70, 70, 70], [150, 150, 150], [150, 120, 90], [220, 20, 60], + [220, 20, 60], [255, 0, 0], [255, 0, 100], [255, 0, 200], + [255, 255, 255], [255, 255, 255], [250, 170, 29], + [250, 170, 28], [250, 170, 26], [250, 170, + 25], [250, 170, 24], + [250, 170, 22], [250, 170, 21], [250, 170, + 20], [255, 255, 255], + [250, 170, 19], [250, 170, 18], [250, 170, + 12], [250, 170, 11], + [255, 255, 255], [255, 255, 255], [250, 170, 16], + [250, 170, 15], [250, 170, 15], [255, 255, 255], + [255, 255, 255], [255, 255, 255], [255, 255, 255], + [64, 170, 64], 
[230, 160, 50], + [70, 130, 180], [190, 255, 255], [152, 251, 152], + [107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30], + [100, 140, 180], [220, 128, 128], [222, 40, + 40], [100, 170, 30], + [40, 40, 40], [33, 33, 33], [100, 128, 160], [20, 20, 255], + [142, 0, 0], [70, 100, 150], [250, 171, 30], [250, 172, 30], + [250, 173, 30], [250, 174, 30], [250, 175, + 30], [250, 176, 30], + [210, 170, 100], [153, 153, 153], [153, 153, 153], + [128, 128, 128], [0, 0, 80], [210, 60, 60], [250, 170, 30], + [250, 170, 30], [250, 170, 30], [250, 170, + 30], [250, 170, 30], + [250, 170, 30], [192, 192, 192], [192, 192, 192], + [192, 192, 192], [220, 220, 0], [220, 220, 0], [0, 0, 196], + [192, 192, 192], [220, 220, 0], [140, 140, 20], [119, 11, 32], + [150, 0, 255], [0, 60, 100], [0, 0, 142], [0, 0, 90], + [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110], + [0, 0, 70], [0, 0, 142], [0, 0, 192], [170, 170, 170], + [32, 32, 32], [111, 74, 0], [120, 10, 10], [81, 0, 81], + [111, 111, 0], [0, 0, 0]]) + + def __init__(self, + img_suffix='.jpg', + seg_map_suffix='.png', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) diff --git a/mmseg/datasets/night_driving.py b/mmseg/datasets/night_driving.py new file mode 100644 index 0000000000000000000000000000000000000000..3ead91ec77cbd8e3f0a870dee3462549183e9c9b --- /dev/null +++ b/mmseg/datasets/night_driving.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.registry import DATASETS +from .cityscapes import CityscapesDataset + + +@DATASETS.register_module() +class NightDrivingDataset(CityscapesDataset): + """NightDrivingDataset dataset.""" + + def __init__(self, + img_suffix='_leftImg8bit.png', + seg_map_suffix='_gtCoarse_labelTrainIds.png', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) diff --git a/mmseg/datasets/pascal_context.py b/mmseg/datasets/pascal_context.py new file mode 100644 index 0000000000000000000000000000000000000000..a6b2fba7b420f9544cf95057e8b585cd8514d725 --- /dev/null +++ b/mmseg/datasets/pascal_context.py @@ -0,0 +1,115 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp + +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class PascalContextDataset(BaseSegDataset): + """PascalContext dataset. + + In segmentation map annotation for PascalContext, 0 stands for background, + which is included in 60 categories. ``reduce_zero_label`` is fixed to + False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is + fixed to '.png'. + + Args: + ann_file (str): Annotation file path. 
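    Example (illustrative; the paths follow the conventional VOC2010 layout
    and may differ in practice). ``ann_file`` selects the split:

    .. code-block:: python

        dict(
            type='PascalContextDataset',
            data_root='data/VOCdevkit/VOC2010',
            data_prefix=dict(
                img_path='JPEGImages',
                seg_map_path='SegmentationClassContext'),
            ann_file='ImageSets/SegmentationContext/train.txt',
            pipeline=[
                dict(type='LoadImageFromFile'),
                dict(type='LoadAnnotations'),
                dict(type='PackSegInputs'),
            ])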
+ """ + + METAINFO = dict( + classes=('background', 'aeroplane', 'bag', 'bed', 'bedclothes', + 'bench', 'bicycle', 'bird', 'boat', 'book', 'bottle', + 'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling', + 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain', 'dog', + 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground', + 'horse', 'keyboard', 'light', 'motorbike', 'mountain', + 'mouse', 'person', 'plate', 'platform', 'pottedplant', 'road', + 'rock', 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', + 'sofa', 'table', 'track', 'train', 'tree', 'truck', + 'tvmonitor', 'wall', 'water', 'window', 'wood'), + palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]) + + def __init__(self, + ann_file: str, + img_suffix='.jpg', + seg_map_suffix='.png', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + ann_file=ann_file, + reduce_zero_label=False, + **kwargs) + assert self.file_client.exists( + self.data_prefix['img_path']) and osp.isfile(self.ann_file) + + +@DATASETS.register_module() +class PascalContextDataset59(BaseSegDataset): + """PascalContext dataset. + + In segmentation map annotation for PascalContext, 0 stands for background, + which is included in 60 categories. ``reduce_zero_label`` is fixed to + False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is + fixed to '.png'. + + Args: + ann_file (str): Annotation file path. 
+ """ + METAINFO = dict( + classes=('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', + 'bird', 'boat', 'book', 'bottle', 'building', 'bus', + 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth', + 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence', + 'floor', 'flower', 'food', 'grass', 'ground', 'horse', + 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', + 'person', 'plate', 'platform', 'pottedplant', 'road', 'rock', + 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', + 'table', 'track', 'train', 'tree', 'truck', 'tvmonitor', + 'wall', 'water', 'window', 'wood'), + palette=[[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], + [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]) + + def __init__(self, + ann_file: str, + img_suffix='.jpg', + seg_map_suffix='.png', + reduce_zero_label=True, + **kwargs): + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + ann_file=ann_file, + reduce_zero_label=reduce_zero_label, + **kwargs) + assert self.file_client.exists( + self.data_prefix['img_path']) and osp.isfile(self.ann_file) diff --git a/mmseg/datasets/potsdam.py b/mmseg/datasets/potsdam.py new file mode 100644 index 0000000000000000000000000000000000000000..6892de3dd29fda569527342377c6e83ce0d972bf --- /dev/null +++ b/mmseg/datasets/potsdam.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class PotsdamDataset(BaseSegDataset): + """ISPRS Potsdam dataset. + + In segmentation map annotation for Potsdam dataset, 0 is the ignore index. + ``reduce_zero_label`` should be set to True. The ``img_suffix`` and + ``seg_map_suffix`` are both fixed to '.png'. + """ + METAINFO = dict( + classes=('impervious_surface', 'building', 'low_vegetation', 'tree', + 'car', 'clutter'), + palette=[[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], + [255, 255, 0], [255, 0, 0]]) + + def __init__(self, + img_suffix='.png', + seg_map_suffix='.png', + reduce_zero_label=True, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) diff --git a/mmseg/datasets/refuge.py b/mmseg/datasets/refuge.py new file mode 100644 index 0000000000000000000000000000000000000000..4016a825a37cdd0162f9c3e72df2fcabc6984991 --- /dev/null +++ b/mmseg/datasets/refuge.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmengine.fileio as fileio + +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class REFUGEDataset(BaseSegDataset): + """REFUGE dataset. 
+ + In segmentation map annotation for REFUGE, 0 stands for background, which + is not included in 2 categories. ``reduce_zero_label`` is fixed to True. + The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '.png'. + """ + METAINFO = dict( + classes=('background', ' Optic Cup', 'Optic Disc'), + palette=[[120, 120, 120], [6, 230, 230], [56, 59, 120]]) + + def __init__(self, **kwargs) -> None: + super().__init__( + img_suffix='.png', + seg_map_suffix='.png', + reduce_zero_label=False, + **kwargs) + assert fileio.exists( + self.data_prefix['img_path'], backend_args=self.backend_args) diff --git a/mmseg/datasets/stare.py b/mmseg/datasets/stare.py new file mode 100644 index 0000000000000000000000000000000000000000..2bfce234494b37aeabac6956e69c8a45403f5103 --- /dev/null +++ b/mmseg/datasets/stare.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class STAREDataset(BaseSegDataset): + """STARE dataset. + + In segmentation map annotation for STARE, 0 stands for background, which is + included in 2 categories. ``reduce_zero_label`` is fixed to False. The + ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '.ah.png'. + """ + METAINFO = dict( + classes=('background', 'vessel'), + palette=[[120, 120, 120], [6, 230, 230]]) + + def __init__(self, + img_suffix='.png', + seg_map_suffix='.ah.png', + reduce_zero_label=False, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) + assert self.file_client.exists(self.data_prefix['img_path']) diff --git a/mmseg/datasets/synapse.py b/mmseg/datasets/synapse.py new file mode 100644 index 0000000000000000000000000000000000000000..6f83b6415046667fb24086083c43083040f4487c --- /dev/null +++ b/mmseg/datasets/synapse.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class SynapseDataset(BaseSegDataset): + """Synapse dataset. + + Before dataset preprocess of Synapse, there are total 13 categories of + foreground which does not include background. After preprocessing, 8 + foreground categories are kept while the other 5 foreground categories are + handled as background. The ``img_suffix`` is fixed to '.jpg' and + ``seg_map_suffix`` is fixed to '.png'. + """ + METAINFO = dict( + classes=('background', 'aorta', 'gallbladder', 'left_kidney', + 'right_kidney', 'liver', 'pancreas', 'spleen', 'stomach'), + palette=[[0, 0, 0], [0, 0, 255], [0, 255, 0], [255, 0, 0], + [0, 255, 255], [255, 0, 255], [255, 255, 0], [60, 255, 255], + [240, 240, 240]]) + + def __init__(self, + img_suffix='.jpg', + seg_map_suffix='.png', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) diff --git a/mmseg/datasets/transforms/__init__.py b/mmseg/datasets/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..25f4ee4a98733a41a8b38315746e6df3a192e25f --- /dev/null +++ b/mmseg/datasets/transforms/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .formatting import PackSegInputs +from .loading import (LoadAnnotations, LoadBiomedicalAnnotation, + LoadBiomedicalData, LoadBiomedicalImageFromFile, + LoadImageFromNDArray) +# yapf: disable +from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad, + BioMedical3DRandomCrop, BioMedical3DRandomFlip, + BioMedicalGaussianBlur, BioMedicalGaussianNoise, + BioMedicalRandomGamma, GenerateEdge, + PhotoMetricDistortion, RandomCrop, RandomCutOut, + RandomMosaic, RandomRotate, RandomRotFlip, Rerange, + ResizeShortestEdge, ResizeToMultiple, RGB2Gray, + SegRescale) + +# yapf: enable +__all__ = [ + 'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale', + 'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', + 'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs', + 'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile', + 'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge', + 'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur', + 'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad', + 'RandomRotFlip' +] diff --git a/mmseg/datasets/transforms/__pycache__/__init__.cpython-310.pyc b/mmseg/datasets/transforms/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71f5685b7cc734426e0e252a639d714cf3c2b6f8 Binary files /dev/null and b/mmseg/datasets/transforms/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmseg/datasets/transforms/__pycache__/formatting.cpython-310.pyc b/mmseg/datasets/transforms/__pycache__/formatting.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb4bc3cffa5d21b60e4812c38e91102706ceb364 Binary files /dev/null and b/mmseg/datasets/transforms/__pycache__/formatting.cpython-310.pyc differ diff --git a/mmseg/datasets/transforms/__pycache__/loading.cpython-310.pyc b/mmseg/datasets/transforms/__pycache__/loading.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d6823ee6032e6456b26e994ef2b895e41e8df25 Binary files /dev/null and b/mmseg/datasets/transforms/__pycache__/loading.cpython-310.pyc differ diff --git a/mmseg/datasets/transforms/__pycache__/transforms.cpython-310.pyc b/mmseg/datasets/transforms/__pycache__/transforms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1f4481f19369d25c8dfbe9e70a5b897e9207881 Binary files /dev/null and b/mmseg/datasets/transforms/__pycache__/transforms.cpython-310.pyc differ diff --git a/mmseg/datasets/transforms/formatting.py b/mmseg/datasets/transforms/formatting.py new file mode 100644 index 0000000000000000000000000000000000000000..89fd8837913f9d8cc0e33dcbb2e500e3c996c5f5 --- /dev/null +++ b/mmseg/datasets/transforms/formatting.py @@ -0,0 +1,107 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import numpy as np +from mmcv.transforms import to_tensor +from mmcv.transforms.base import BaseTransform +from mmengine.structures import PixelData + +from mmseg.registry import TRANSFORMS +from mmseg.structures import SegDataSample + + +@TRANSFORMS.register_module() +class PackSegInputs(BaseTransform): + """Pack the inputs data for the semantic segmentation. + + The ``img_meta`` item is always populated. The contents of the + ``img_meta`` dictionary depends on ``meta_keys``. 
By default this includes: + + - ``img_path``: filename of the image + + - ``ori_shape``: original shape of the image as a tuple (h, w, c) + + - ``img_shape``: shape of the image input to the network as a tuple \ + (h, w, c). Note that images may be zero padded on the \ + bottom/right if the batch tensor is larger than this shape. + + - ``pad_shape``: shape of padded images + + - ``scale_factor``: a float indicating the preprocessing scale + + - ``flip``: a boolean indicating if image flip transform was used + + - ``flip_direction``: the flipping direction + + Args: + meta_keys (Sequence[str], optional): Meta keys to be packed from + ``SegDataSample`` and collected in ``data[img_metas]``. + Default: ``('img_path', 'ori_shape', + 'img_shape', 'pad_shape', 'scale_factor', 'flip', + 'flip_direction')`` + """ + + def __init__(self, + meta_keys=('img_path', 'seg_map_path', 'ori_shape', + 'img_shape', 'pad_shape', 'scale_factor', 'flip', + 'flip_direction', 'reduce_zero_label')): + self.meta_keys = meta_keys + + def transform(self, results: dict) -> dict: + """Method to pack the input data. + + Args: + results (dict): Result dict from the data pipeline. + + Returns: + dict: + + - 'inputs' (obj:`torch.Tensor`): The forward data of models. + - 'data_sample' (obj:`SegDataSample`): The annotation info of the + sample. + """ + packed_results = dict() + if 'img' in results: + img = results['img'] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + if not img.flags.c_contiguous: + img = to_tensor(np.ascontiguousarray(img.transpose(2, 0, 1))) + else: + img = img.transpose(2, 0, 1) + img = to_tensor(img).contiguous() + packed_results['inputs'] = img + + data_sample = SegDataSample() + if 'gt_seg_map' in results: + if len(results['gt_seg_map'].shape) == 2: + data = to_tensor(results['gt_seg_map'][None, + ...].astype(np.int64)) + else: + warnings.warn('Please pay attention your ground truth ' + 'segmentation map, usually the segmentation ' + 'map is 2D, but got ' + f'{results["gt_seg_map"].shape}') + data = to_tensor(results['gt_seg_map'].astype(np.int64)) + gt_sem_seg_data = dict(data=data) + data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data) + + if 'gt_edge_map' in results: + gt_edge_data = dict( + data=to_tensor(results['gt_edge_map'][None, + ...].astype(np.int64))) + data_sample.set_data(dict(gt_edge_map=PixelData(**gt_edge_data))) + + img_meta = {} + for key in self.meta_keys: + if key in results: + img_meta[key] = results[key] + data_sample.set_metainfo(img_meta) + packed_results['data_samples'] = data_sample + + return packed_results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(meta_keys={self.meta_keys})' + return repr_str diff --git a/mmseg/datasets/transforms/loading.py b/mmseg/datasets/transforms/loading.py new file mode 100644 index 0000000000000000000000000000000000000000..c5cc8eae9774a42cb77c7520db94c90c304d5e47 --- /dev/null +++ b/mmseg/datasets/transforms/loading.py @@ -0,0 +1,502 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import Dict, Optional, Union + +import mmcv +import mmengine.fileio as fileio +import numpy as np +from mmcv.transforms import BaseTransform +from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations +from mmcv.transforms import LoadImageFromFile + +from mmseg.registry import TRANSFORMS +from mmseg.utils import datafrombytes + + +@TRANSFORMS.register_module() +class LoadAnnotations(MMCV_LoadAnnotations): + """Load annotations for semantic segmentation provided by dataset. 
+ + The annotation format is as the following: + + .. code-block:: python + + { + # Filename of semantic segmentation ground truth file. + 'seg_map_path': 'a/b/c' + } + + After this module, the annotation has been changed to the format below: + + .. code-block:: python + + { + # in str + 'seg_fields': List + # In uint8 type. + 'gt_seg_map': np.ndarray (H, W) + } + + Required Keys: + + - seg_map_path (str): Path of semantic segmentation ground truth file. + + Added Keys: + + - seg_fields (List) + - gt_seg_map (np.uint8) + + Args: + reduce_zero_label (bool, optional): Whether reduce all label value + by 1. Usually used for datasets where 0 is background label. + Defaults to None. + imdecode_backend (str): The image decoding backend type. The backend + argument for :func:``mmcv.imfrombytes``. + See :fun:``mmcv.imfrombytes`` for details. + Defaults to 'pillow'. + backend_args (dict): Arguments to instantiate a file backend. + See https://mmengine.readthedocs.io/en/latest/api/fileio.htm + for details. Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + """ + + def __init__( + self, + label_id_map={}, + reduce_zero_label=None, + backend_args=None, + imdecode_backend='pillow', + ) -> None: + super().__init__( + with_bbox=False, + with_label=False, + with_seg=True, + with_keypoints=False, + imdecode_backend=imdecode_backend, + backend_args=backend_args) + self.label_id_map = label_id_map + self.reduce_zero_label = reduce_zero_label + if self.reduce_zero_label is not None: + warnings.warn('`reduce_zero_label` will be deprecated, ' + 'if you would like to ignore the zero label, please ' + 'set `reduce_zero_label=True` when dataset ' + 'initialized') + self.imdecode_backend = imdecode_backend + + def _load_seg_map(self, results: dict) -> None: + """Private function to load semantic segmentation annotations. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded semantic segmentation annotations. 
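        A toy illustration of the ``reduce_zero_label`` remapping applied
        below (values are arbitrary and only for illustration):

        .. code-block:: python

            import numpy as np

            gt = np.array([[0, 1], [2, 0]], dtype=np.uint8)
            gt[gt == 0] = 255     # move the zero label out of the way first
            gt = gt - 1           # shift every remaining label down by one
            gt[gt == 254] = 255   # restore the ignore value
            # gt is now [[255, 0], [1, 255]]: label 0 became 255 (ignored),
            # labels 1 and 2 became 0 and 1.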
+ """ + img_bytes = fileio.get( + results['seg_map_path'], backend_args=self.backend_args) + gt_semantic_seg = mmcv.imfrombytes( + img_bytes, flag='unchanged', + backend=self.imdecode_backend).squeeze().astype(np.uint8) + + if np.any(gt_semantic_seg > 1): + raise ValueError('gt_semantic_seg should not contain value 255.') + + for ori_id, new_id in self.label_id_map.items(): + gt_semantic_seg[gt_semantic_seg == int(ori_id)] = new_id + + # reduce zero_label + if self.reduce_zero_label is None: + self.reduce_zero_label = results['reduce_zero_label'] + assert self.reduce_zero_label == results['reduce_zero_label'], \ + 'Initialize dataset with `reduce_zero_label` as ' \ + f'{results["reduce_zero_label"]} but when load annotation ' \ + f'the `reduce_zero_label` is {self.reduce_zero_label}' + if self.reduce_zero_label: + # avoid using underflow conversion + gt_semantic_seg[gt_semantic_seg == 0] = 255 + gt_semantic_seg = gt_semantic_seg - 1 + gt_semantic_seg[gt_semantic_seg == 254] = 255 + # modify if custom classes + if results.get('label_map', None) is not None: + # Add deep copy to solve bug of repeatedly + # replace `gt_semantic_seg`, which is reported in + # https://github.com/open-mmlab/mmsegmentation/pull/1445/ + gt_semantic_seg_copy = gt_semantic_seg.copy() + for old_id, new_id in results['label_map'].items(): + gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id + results['gt_seg_map'] = gt_semantic_seg + results['seg_fields'].append('gt_seg_map') + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(reduce_zero_label={self.reduce_zero_label}, ' + repr_str += f"imdecode_backend='{self.imdecode_backend}', " + repr_str += f'backend_args={self.backend_args})' + return repr_str + + +@TRANSFORMS.register_module() +class LoadImageFromNDArray(LoadImageFromFile): + """Load an image from ``results['img']``. + + Similar with :obj:`LoadImageFromFile`, but the image has been loaded as + :obj:`np.ndarray` in ``results['img']``. Can be used when loading image + from webcam. + + Required Keys: + + - img + + Modified Keys: + + - img + - img_path + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + """ + + def transform(self, results: dict) -> dict: + """Transform function to add image meta information. + + Args: + results (dict): Result dict with Webcam read image in + ``results['img']``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + img = results['img'] + if self.to_float32: + img = img.astype(np.float32) + + results['img_path'] = None + results['img'] = img + results['img_shape'] = img.shape[:2] + results['ori_shape'] = img.shape[:2] + return results + + +@TRANSFORMS.register_module() +class LoadBiomedicalImageFromFile(BaseTransform): + """Load an biomedical mage from file. + + Required Keys: + + - img_path + + Added Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X) by default, + N is the number of modalities, and data type is float32 + if set to_float32 = True, or float64 if decode_backend is 'nifti' and + to_float32 is False. + - img_shape + - ori_shape + + Args: + decode_backend (str): The data decoding backend type. Options are + 'numpy'and 'nifti', and there is a convention that when backend is + 'nifti' the axis of data loaded is XYZ, and when backend is + 'numpy', the the axis is ZYX. The data will be transposed if the + backend is 'nifti'. 
Defaults to 'nifti'. + to_xyz (bool): Whether transpose data from Z, Y, X to X, Y, Z. + Defaults to False. + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an float64 array. + Defaults to True. + backend_args (dict, Optional): Arguments to instantiate a file backend. + See https://mmengine.readthedocs.io/en/latest/api/fileio.htm + for details. Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + """ + + def __init__(self, + decode_backend: str = 'nifti', + to_xyz: bool = False, + to_float32: bool = True, + backend_args: Optional[dict] = None) -> None: + self.decode_backend = decode_backend + self.to_xyz = to_xyz + self.to_float32 = to_float32 + self.backend_args = backend_args.copy() if backend_args else None + + def transform(self, results: Dict) -> Dict: + """Functions to load image. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + filename = results['img_path'] + + data_bytes = fileio.get(filename, self.backend_args) + img = datafrombytes(data_bytes, backend=self.decode_backend) + + if self.to_float32: + img = img.astype(np.float32) + + if len(img.shape) == 3: + img = img[None, ...] + + if self.decode_backend == 'nifti': + img = img.transpose(0, 3, 2, 1) + + if self.to_xyz: + img = img.transpose(0, 3, 2, 1) + + results['img'] = img + results['img_shape'] = img.shape[1:] + results['ori_shape'] = img.shape[1:] + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f"decode_backend='{self.decode_backend}', " + f'to_xyz={self.to_xyz}, ' + f'to_float32={self.to_float32}, ' + f'backend_args={self.backend_args})') + return repr_str + + +@TRANSFORMS.register_module() +class LoadBiomedicalAnnotation(BaseTransform): + """Load ``seg_map`` annotation provided by biomedical dataset. + + The annotation format is as the following: + + .. code-block:: python + + { + 'gt_seg_map': np.ndarray (X, Y, Z) or (Z, Y, X) + } + + Required Keys: + + - seg_map_path + + Added Keys: + + - gt_seg_map (np.ndarray): Biomedical seg map with shape (Z, Y, X) by + default, and data type is float32 if set to_float32 = True, or + float64 if decode_backend is 'nifti' and to_float32 is False. + + Args: + decode_backend (str): The data decoding backend type. Options are + 'numpy'and 'nifti', and there is a convention that when backend is + 'nifti' the axis of data loaded is XYZ, and when backend is + 'numpy', the the axis is ZYX. The data will be transposed if the + backend is 'nifti'. Defaults to 'nifti'. + to_xyz (bool): Whether transpose data from Z, Y, X to X, Y, Z. + Defaults to False. + to_float32 (bool): Whether to convert the loaded seg map to a float32 + numpy array. If set to False, the loaded image is an float64 array. + Defaults to True. + backend_args (dict, Optional): Arguments to instantiate a file backend. + See :class:`mmengine.fileio` for details. + Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + """ + + def __init__(self, + decode_backend: str = 'nifti', + to_xyz: bool = False, + to_float32: bool = True, + backend_args: Optional[dict] = None) -> None: + super().__init__() + self.decode_backend = decode_backend + self.to_xyz = to_xyz + self.to_float32 = to_float32 + self.backend_args = backend_args.copy() if backend_args else None + + def transform(self, results: Dict) -> Dict: + """Functions to load image. 
+ + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + data_bytes = fileio.get(results['seg_map_path'], self.backend_args) + gt_seg_map = datafrombytes(data_bytes, backend=self.decode_backend) + + if self.to_float32: + gt_seg_map = gt_seg_map.astype(np.float32) + + if self.decode_backend == 'nifti': + gt_seg_map = gt_seg_map.transpose(2, 1, 0) + + if self.to_xyz: + gt_seg_map = gt_seg_map.transpose(2, 1, 0) + + results['gt_seg_map'] = gt_seg_map + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f"decode_backend='{self.decode_backend}', " + f'to_xyz={self.to_xyz}, ' + f'to_float32={self.to_float32}, ' + f'backend_args={self.backend_args})') + return repr_str + + +@TRANSFORMS.register_module() +class LoadBiomedicalData(BaseTransform): + """Load an biomedical image and annotation from file. + + The loading data format is as the following: + + .. code-block:: python + + { + 'img': np.ndarray data[:-1, X, Y, Z] + 'seg_map': np.ndarray data[-1, X, Y, Z] + } + + + Required Keys: + + - img_path + + Added Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X) by default, + N is the number of modalities. + - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape + (Z, Y, X) by default. + - img_shape + - ori_shape + + Args: + with_seg (bool): Whether to parse and load the semantic segmentation + annotation. Defaults to False. + decode_backend (str): The data decoding backend type. Options are + 'numpy'and 'nifti', and there is a convention that when backend is + 'nifti' the axis of data loaded is XYZ, and when backend is + 'numpy', the the axis is ZYX. The data will be transposed if the + backend is 'nifti'. Defaults to 'nifti'. + to_xyz (bool): Whether transpose data from Z, Y, X to X, Y, Z. + Defaults to False. + backend_args (dict, Optional): Arguments to instantiate a file backend. + See https://mmengine.readthedocs.io/en/latest/api/fileio.htm + for details. Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + """ + + def __init__(self, + with_seg=False, + decode_backend: str = 'numpy', + to_xyz: bool = False, + backend_args: Optional[dict] = None) -> None: # noqa + self.with_seg = with_seg + self.decode_backend = decode_backend + self.to_xyz = to_xyz + self.backend_args = backend_args.copy() if backend_args else None + + def transform(self, results: Dict) -> Dict: + """Functions to load image. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded image and meta information. 
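        A toy illustration of the stacked layout this transform expects
        (the shapes below are made up for the example):

        .. code-block:: python

            import numpy as np

            # 4 imaging modalities stacked together with 1 label volume
            data = np.zeros((5, 240, 240, 155), dtype=np.float32)
            img, seg = data[:-1], data[-1]
            # img.shape == (4, 240, 240, 155), seg.shape == (240, 240, 155)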
+ """ + data_bytes = fileio.get(results['img_path'], self.backend_args) + data = datafrombytes(data_bytes, backend=self.decode_backend) + # img is 4D data (N, X, Y, Z), N is the number of protocol + img = data[:-1, :] + + if self.decode_backend == 'nifti': + img = img.transpose(0, 3, 2, 1) + + if self.to_xyz: + img = img.transpose(0, 3, 2, 1) + + results['img'] = img + results['img_shape'] = img.shape[1:] + results['ori_shape'] = img.shape[1:] + + if self.with_seg: + gt_seg_map = data[-1, :] + if self.decode_backend == 'nifti': + gt_seg_map = gt_seg_map.transpose(2, 1, 0) + + if self.to_xyz: + gt_seg_map = gt_seg_map.transpose(2, 1, 0) + results['gt_seg_map'] = gt_seg_map + return results + + def __repr__(self) -> str: + repr_str = (f'{self.__class__.__name__}(' + f'with_seg={self.with_seg}, ' + f"decode_backend='{self.decode_backend}', " + f'to_xyz={self.to_xyz}, ' + f'backend_args={self.backend_args})') + return repr_str + + +@TRANSFORMS.register_module() +class InferencerLoader(BaseTransform): + """Load an image from ``results['img']``. + + Similar with :obj:`LoadImageFromFile`, but the image has been loaded as + :obj:`np.ndarray` in ``results['img']``. Can be used when loading image + from webcam. + + Required Keys: + + - img + + Modified Keys: + + - img + - img_path + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + """ + + def __init__(self, **kwargs) -> None: + super().__init__() + self.from_file = TRANSFORMS.build( + dict(type='LoadImageFromFile', **kwargs)) + self.from_ndarray = TRANSFORMS.build( + dict(type='LoadImageFromNDArray', **kwargs)) + + def transform(self, single_input: Union[str, np.ndarray, dict]) -> dict: + """Transform function to add image meta information. + + Args: + results (dict): Result dict with Webcam read image in + ``results['img']``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + if isinstance(single_input, str): + inputs = dict(img_path=single_input) + elif isinstance(single_input, np.ndarray): + inputs = dict(img=single_input) + elif isinstance(single_input, dict): + inputs = single_input + else: + raise NotImplementedError + + if 'img' in inputs: + return self.from_ndarray(inputs) + return self.from_file(inputs) diff --git a/mmseg/datasets/transforms/transforms.py b/mmseg/datasets/transforms/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..fb7e2a0e66577a6b1c90009c0f8a6e2fba8a228a --- /dev/null +++ b/mmseg/datasets/transforms/transforms.py @@ -0,0 +1,2137 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import warnings +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import cv2 +import mmcv +import numpy as np +from mmcv.transforms.base import BaseTransform +from mmcv.transforms.utils import cache_randomness +from mmengine.utils import is_tuple_of +from numpy import random +from scipy.ndimage import gaussian_filter + +from mmseg.datasets.dataset_wrappers import MultiImageMixDataset +from mmseg.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class ResizeToMultiple(BaseTransform): + """Resize images & seg to multiple of divisor. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - img_shape + - pad_shape + + Args: + size_divisor (int): images and gt seg maps need to resize to multiple + of size_divisor. Default: 32. 
+ interpolation (str, optional): The interpolation mode of image resize. + Default: None + """ + + def __init__(self, size_divisor=32, interpolation=None): + self.size_divisor = size_divisor + self.interpolation = interpolation + + def transform(self, results: dict) -> dict: + """Call function to resize images, semantic segmentation map to + multiple of size divisor. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'img_shape', 'pad_shape' keys are updated. + """ + # Align image to multiple of size divisor. + img = results['img'] + img = mmcv.imresize_to_multiple( + img, + self.size_divisor, + scale_factor=1, + interpolation=self.interpolation + if self.interpolation else 'bilinear') + + results['img'] = img + results['img_shape'] = img.shape[:2] + results['pad_shape'] = img.shape[:2] + + # Align segmentation map to multiple of size divisor. + for key in results.get('seg_fields', []): + gt_seg = results[key] + gt_seg = mmcv.imresize_to_multiple( + gt_seg, + self.size_divisor, + scale_factor=1, + interpolation='nearest') + results[key] = gt_seg + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(size_divisor={self.size_divisor}, ' + f'interpolation={self.interpolation})') + return repr_str + + +@TRANSFORMS.register_module() +class Rerange(BaseTransform): + """Rerange the image pixel value. + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + min_value (float or int): Minimum value of the reranged image. + Default: 0. + max_value (float or int): Maximum value of the reranged image. + Default: 255. + """ + + def __init__(self, min_value=0, max_value=255): + assert isinstance(min_value, float) or isinstance(min_value, int) + assert isinstance(max_value, float) or isinstance(max_value, int) + assert min_value < max_value + self.min_value = min_value + self.max_value = max_value + + def transform(self, results: dict) -> dict: + """Call function to rerange images. + + Args: + results (dict): Result dict from loading pipeline. + Returns: + dict: Reranged results. + """ + + img = results['img'] + img_min_value = np.min(img) + img_max_value = np.max(img) + + assert img_min_value < img_max_value + # rerange to [0, 1] + img = (img - img_min_value) / (img_max_value - img_min_value) + # rerange to [min_value, max_value] + img = img * (self.max_value - self.min_value) + self.min_value + results['img'] = img + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(min_value={self.min_value}, max_value={self.max_value})' + return repr_str + + +@TRANSFORMS.register_module() +class CLAHE(BaseTransform): + """Use CLAHE method to process the image. + + See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J]. + Graphics Gems, 1994:474-485.` for more information. + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + clip_limit (float): Threshold for contrast limiting. Default: 40.0. + tile_grid_size (tuple[int]): Size of grid for histogram equalization. + Input image will be divided into equally sized rectangular tiles. + It defines the number of tiles in row and column. Default: (8, 8). 
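    A minimal usage sketch (the input image is random data, only to show the
    expected uint8, channel-last format):

    .. code-block:: python

        import numpy as np

        from mmseg.datasets.transforms import CLAHE

        clahe = CLAHE(clip_limit=40.0, tile_grid_size=(8, 8))
        results = dict(
            img=np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8))
        results = clahe(results)  # equalizes each channel independently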
+ """ + + def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)): + assert isinstance(clip_limit, (float, int)) + self.clip_limit = clip_limit + assert is_tuple_of(tile_grid_size, int) + assert len(tile_grid_size) == 2 + self.tile_grid_size = tile_grid_size + + def transform(self, results: dict) -> dict: + """Call function to Use CLAHE method process images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Processed results. + """ + + for i in range(results['img'].shape[2]): + results['img'][:, :, i] = mmcv.clahe( + np.array(results['img'][:, :, i], dtype=np.uint8), + self.clip_limit, self.tile_grid_size) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(clip_limit={self.clip_limit}, '\ + f'tile_grid_size={self.tile_grid_size})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomCrop(BaseTransform): + """Random crop the image & seg. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - img_shape + - gt_seg_map + + + Args: + crop_size (Union[int, Tuple[int, int]]): Expected size after cropping + with the format of (h, w). If set to an integer, then cropping + width and height are equal to this integer. + cat_max_ratio (float): The maximum ratio that single category could + occupy. + ignore_index (int): The label index to be ignored. Default: 255 + """ + + def __init__(self, + crop_size: Union[int, Tuple[int, int]], + cat_max_ratio: float = 1., + ignore_index: int = 255): + super().__init__() + assert isinstance(crop_size, int) or ( + isinstance(crop_size, tuple) and len(crop_size) == 2 + ), 'The expected crop_size is an integer, or a tuple containing two ' + 'intergers' + + if isinstance(crop_size, int): + crop_size = (crop_size, crop_size) + assert crop_size[0] > 0 and crop_size[1] > 0 + self.crop_size = crop_size + self.cat_max_ratio = cat_max_ratio + self.ignore_index = ignore_index + + @cache_randomness + def crop_bbox(self, results: dict) -> tuple: + """get a crop bounding box. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + tuple: Coordinates of the cropped image. + """ + + def generate_crop_bbox(img: np.ndarray) -> tuple: + """Randomly get a crop bounding box. + + Args: + img (np.ndarray): Original input image. + + Returns: + tuple: Coordinates of the cropped image. + """ + + margin_h = max(img.shape[0] - self.crop_size[0], 0) + margin_w = max(img.shape[1] - self.crop_size[1], 0) + offset_h = np.random.randint(0, margin_h + 1) + offset_w = np.random.randint(0, margin_w + 1) + crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0] + crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1] + + return crop_y1, crop_y2, crop_x1, crop_x2 + + img = results['img'] + crop_bbox = generate_crop_bbox(img) + if self.cat_max_ratio < 1.: + # Repeat 10 times + for _ in range(10): + seg_temp = self.crop(results['gt_seg_map'], crop_bbox) + labels, cnt = np.unique(seg_temp, return_counts=True) + cnt = cnt[labels != self.ignore_index] + if len(cnt) > 1 and np.max(cnt) / np.sum( + cnt) < self.cat_max_ratio: + break + crop_bbox = generate_crop_bbox(img) + + return crop_bbox + + def crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray: + """Crop from ``img`` + + Args: + img (np.ndarray): Original input image. + crop_bbox (tuple): Coordinates of the cropped image. + + Returns: + np.ndarray: The cropped image. + """ + + crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox + img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...] 
+ return img + + def transform(self, results: dict) -> dict: + """Transform function to randomly crop images, semantic segmentation + maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly cropped results, 'img_shape' key in result dict is + updated according to crop size. + """ + + img = results['img'] + crop_bbox = self.crop_bbox(results) + + # crop the image + img = self.crop(img, crop_bbox) + + # crop semantic seg + for key in results.get('seg_fields', []): + results[key] = self.crop(results[key], crop_bbox) + + results['img'] = img + results['img_shape'] = img.shape[:2] + return results + + def __repr__(self): + return self.__class__.__name__ + f'(crop_size={self.crop_size})' + + +@TRANSFORMS.register_module() +class RandomRotate(BaseTransform): + """Rotate the image & seg. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - gt_seg_map + + Args: + prob (float): The rotation probability. + degree (float, tuple[float]): Range of degrees to select from. If + degree is a number instead of tuple like (min, max), + the range of degree will be (``-degree``, ``+degree``) + pad_val (float, optional): Padding value of image. Default: 0. + seg_pad_val (float, optional): Padding value of segmentation map. + Default: 255. + center (tuple[float], optional): Center point (w, h) of the rotation in + the source image. If not specified, the center of the image will be + used. Default: None. + auto_bound (bool): Whether to adjust the image size to cover the whole + rotated image. Default: False + """ + + def __init__(self, + prob, + degree, + pad_val=0, + seg_pad_val=255, + center=None, + auto_bound=False): + self.prob = prob + assert prob >= 0 and prob <= 1 + if isinstance(degree, (float, int)): + assert degree > 0, f'degree {degree} should be positive' + self.degree = (-degree, degree) + else: + self.degree = degree + assert len(self.degree) == 2, f'degree {self.degree} should be a ' \ + f'tuple of (min, max)' + self.pal_val = pad_val + self.seg_pad_val = seg_pad_val + self.center = center + self.auto_bound = auto_bound + + @cache_randomness + def generate_degree(self): + return np.random.rand() < self.prob, np.random.uniform( + min(*self.degree), max(*self.degree)) + + def transform(self, results: dict) -> dict: + """Call function to rotate image, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Rotated results. + """ + + rotate, degree = self.generate_degree() + if rotate: + # rotate image + results['img'] = mmcv.imrotate( + results['img'], + angle=degree, + border_value=self.pal_val, + center=self.center, + auto_bound=self.auto_bound) + + # rotate segs + for key in results.get('seg_fields', []): + results[key] = mmcv.imrotate( + results[key], + angle=degree, + border_value=self.seg_pad_val, + center=self.center, + auto_bound=self.auto_bound, + interpolation='nearest') + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' \ + f'degree={self.degree}, ' \ + f'pad_val={self.pal_val}, ' \ + f'seg_pad_val={self.seg_pad_val}, ' \ + f'center={self.center}, ' \ + f'auto_bound={self.auto_bound})' + return repr_str + + +@TRANSFORMS.register_module() +class RGB2Gray(BaseTransform): + """Convert RGB image to grayscale image. 
+ + Required Keys: + + - img + + Modified Keys: + + - img + - img_shape + + This transform calculate the weighted mean of input image channels with + ``weights`` and then expand the channels to ``out_channels``. When + ``out_channels`` is None, the number of output channels is the same as + input channels. + + Args: + out_channels (int): Expected number of output channels after + transforming. Default: None. + weights (tuple[float]): The weights to calculate the weighted mean. + Default: (0.299, 0.587, 0.114). + """ + + def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)): + assert out_channels is None or out_channels > 0 + self.out_channels = out_channels + assert isinstance(weights, tuple) + for item in weights: + assert isinstance(item, (float, int)) + self.weights = weights + + def transform(self, results: dict) -> dict: + """Call function to convert RGB image to grayscale image. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with grayscale image. + """ + img = results['img'] + assert len(img.shape) == 3 + assert img.shape[2] == len(self.weights) + weights = np.array(self.weights).reshape((1, 1, -1)) + img = (img * weights).sum(2, keepdims=True) + if self.out_channels is None: + img = img.repeat(weights.shape[2], axis=2) + else: + img = img.repeat(self.out_channels, axis=2) + + results['img'] = img + results['img_shape'] = img.shape + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(out_channels={self.out_channels}, ' \ + f'weights={self.weights})' + return repr_str + + +@TRANSFORMS.register_module() +class AdjustGamma(BaseTransform): + """Using gamma correction to process the image. + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + gamma (float or int): Gamma value used in gamma correction. + Default: 1.0. + """ + + def __init__(self, gamma=1.0): + assert isinstance(gamma, float) or isinstance(gamma, int) + assert gamma > 0 + self.gamma = gamma + inv_gamma = 1.0 / gamma + self.table = np.array([(i / 255.0)**inv_gamma * 255 + for i in np.arange(256)]).astype('uint8') + + def transform(self, results: dict) -> dict: + """Call function to process the image with gamma correction. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Processed results. + """ + + results['img'] = mmcv.lut_transform( + np.array(results['img'], dtype=np.uint8), self.table) + + return results + + def __repr__(self): + return self.__class__.__name__ + f'(gamma={self.gamma})' + + +@TRANSFORMS.register_module() +class SegRescale(BaseTransform): + """Rescale semantic segmentation maps. + + Required Keys: + + - gt_seg_map + + Modified Keys: + + - gt_seg_map + + Args: + scale_factor (float): The scale factor of the final output. + """ + + def __init__(self, scale_factor=1): + self.scale_factor = scale_factor + + def transform(self, results: dict) -> dict: + """Call function to scale the semantic segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with semantic segmentation map scaled. 
+ """ + for key in results.get('seg_fields', []): + if self.scale_factor != 1: + results[key] = mmcv.imrescale( + results[key], self.scale_factor, interpolation='nearest') + return results + + def __repr__(self): + return self.__class__.__name__ + f'(scale_factor={self.scale_factor})' + + +@TRANSFORMS.register_module() +class PhotoMetricDistortion(BaseTransform): + """Apply photometric distortion to image sequentially, every transformation + is applied with a probability of 0.5. The position of random contrast is in + second or second to last. + + 1. random brightness + 2. random contrast (mode 0) + 3. convert color from BGR to HSV + 4. random saturation + 5. random hue + 6. convert color from HSV to BGR + 7. random contrast (mode 1) + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + brightness_delta (int): delta of brightness. + contrast_range (tuple): range of contrast. + saturation_range (tuple): range of saturation. + hue_delta (int): delta of hue. + """ + + def __init__(self, + brightness_delta: int = 32, + contrast_range: Sequence[float] = (0.5, 1.5), + saturation_range: Sequence[float] = (0.5, 1.5), + hue_delta: int = 18): + self.brightness_delta = brightness_delta + self.contrast_lower, self.contrast_upper = contrast_range + self.saturation_lower, self.saturation_upper = saturation_range + self.hue_delta = hue_delta + + def convert(self, + img: np.ndarray, + alpha: int = 1, + beta: int = 0) -> np.ndarray: + """Multiple with alpha and add beat with clip. + + Args: + img (np.ndarray): The input image. + alpha (int): Image weights, change the contrast/saturation + of the image. Default: 1 + beta (int): Image bias, change the brightness of the + image. Default: 0 + + Returns: + np.ndarray: The transformed image. + """ + + img = img.astype(np.float32) * alpha + beta + img = np.clip(img, 0, 255) + return img.astype(np.uint8) + + def brightness(self, img: np.ndarray) -> np.ndarray: + """Brightness distortion. + + Args: + img (np.ndarray): The input image. + Returns: + np.ndarray: Image after brightness change. + """ + + if random.randint(2): + return self.convert( + img, + beta=random.uniform(-self.brightness_delta, + self.brightness_delta)) + return img + + def contrast(self, img: np.ndarray) -> np.ndarray: + """Contrast distortion. + + Args: + img (np.ndarray): The input image. + Returns: + np.ndarray: Image after contrast change. + """ + + if random.randint(2): + return self.convert( + img, + alpha=random.uniform(self.contrast_lower, self.contrast_upper)) + return img + + def saturation(self, img: np.ndarray) -> np.ndarray: + """Saturation distortion. + + Args: + img (np.ndarray): The input image. + Returns: + np.ndarray: Image after saturation change. + """ + + if random.randint(2): + img = mmcv.bgr2hsv(img) + img[:, :, 1] = self.convert( + img[:, :, 1], + alpha=random.uniform(self.saturation_lower, + self.saturation_upper)) + img = mmcv.hsv2bgr(img) + return img + + def hue(self, img: np.ndarray) -> np.ndarray: + """Hue distortion. + + Args: + img (np.ndarray): The input image. + Returns: + np.ndarray: Image after hue change. + """ + + if random.randint(2): + img = mmcv.bgr2hsv(img) + img[:, :, + 0] = (img[:, :, 0].astype(int) + + random.randint(-self.hue_delta, self.hue_delta)) % 180 + img = mmcv.hsv2bgr(img) + return img + + def transform(self, results: dict) -> dict: + """Transform function to perform photometric distortion on images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with images distorted. 
+ """ + + img = results['img'] + # random brightness + img = self.brightness(img) + + # mode == 0 --> do random contrast first + # mode == 1 --> do random contrast last + mode = random.randint(2) + if mode == 1: + img = self.contrast(img) + + # random saturation + img = self.saturation(img) + + # random hue + img = self.hue(img) + + # random contrast + if mode == 0: + img = self.contrast(img) + + results['img'] = img + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(brightness_delta={self.brightness_delta}, ' + f'contrast_range=({self.contrast_lower}, ' + f'{self.contrast_upper}), ' + f'saturation_range=({self.saturation_lower}, ' + f'{self.saturation_upper}), ' + f'hue_delta={self.hue_delta})') + return repr_str + + +@TRANSFORMS.register_module() +class RandomCutOut(BaseTransform): + """CutOut operation. + + Randomly drop some regions of image used in + `Cutout `_. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - gt_seg_map + + Args: + prob (float): cutout probability. + n_holes (int | tuple[int, int]): Number of regions to be dropped. + If it is given as a list, number of holes will be randomly + selected from the closed interval [`n_holes[0]`, `n_holes[1]`]. + cutout_shape (tuple[int, int] | list[tuple[int, int]]): The candidate + shape of dropped regions. It can be `tuple[int, int]` to use a + fixed cutout shape, or `list[tuple[int, int]]` to randomly choose + shape from the list. + cutout_ratio (tuple[float, float] | list[tuple[float, float]]): The + candidate ratio of dropped regions. It can be `tuple[float, float]` + to use a fixed ratio or `list[tuple[float, float]]` to randomly + choose ratio from the list. Please note that `cutout_shape` + and `cutout_ratio` cannot be both given at the same time. + fill_in (tuple[float, float, float] | tuple[int, int, int]): The value + of pixel to fill in the dropped regions. Default: (0, 0, 0). + seg_fill_in (int): The labels of pixel to fill in the dropped regions. + If seg_fill_in is None, skip. Default: None. + """ + + def __init__(self, + prob, + n_holes, + cutout_shape=None, + cutout_ratio=None, + fill_in=(0, 0, 0), + seg_fill_in=None): + + assert 0 <= prob and prob <= 1 + assert (cutout_shape is None) ^ (cutout_ratio is None), \ + 'Either cutout_shape or cutout_ratio should be specified.' 
+ assert (isinstance(cutout_shape, (list, tuple)) + or isinstance(cutout_ratio, (list, tuple))) + if isinstance(n_holes, tuple): + assert len(n_holes) == 2 and 0 <= n_holes[0] < n_holes[1] + else: + n_holes = (n_holes, n_holes) + if seg_fill_in is not None: + assert (isinstance(seg_fill_in, int) and 0 <= seg_fill_in + and seg_fill_in <= 255) + self.prob = prob + self.n_holes = n_holes + self.fill_in = fill_in + self.seg_fill_in = seg_fill_in + self.with_ratio = cutout_ratio is not None + self.candidates = cutout_ratio if self.with_ratio else cutout_shape + if not isinstance(self.candidates, list): + self.candidates = [self.candidates] + + @cache_randomness + def do_cutout(self): + return np.random.rand() < self.prob + + @cache_randomness + def generate_patches(self, results): + cutout = self.do_cutout() + + h, w, _ = results['img'].shape + if cutout: + n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1) + else: + n_holes = 0 + x1_lst = [] + y1_lst = [] + index_lst = [] + for _ in range(n_holes): + x1_lst.append(np.random.randint(0, w)) + y1_lst.append(np.random.randint(0, h)) + index_lst.append(np.random.randint(0, len(self.candidates))) + return cutout, n_holes, x1_lst, y1_lst, index_lst + + def transform(self, results: dict) -> dict: + """Call function to drop some regions of image.""" + cutout, n_holes, x1_lst, y1_lst, index_lst = self.generate_patches( + results) + if cutout: + h, w, c = results['img'].shape + for i in range(n_holes): + x1 = x1_lst[i] + y1 = y1_lst[i] + index = index_lst[i] + if not self.with_ratio: + cutout_w, cutout_h = self.candidates[index] + else: + cutout_w = int(self.candidates[index][0] * w) + cutout_h = int(self.candidates[index][1] * h) + + x2 = np.clip(x1 + cutout_w, 0, w) + y2 = np.clip(y1 + cutout_h, 0, h) + results['img'][y1:y2, x1:x2, :] = self.fill_in + + if self.seg_fill_in is not None: + for key in results.get('seg_fields', []): + results[key][y1:y2, x1:x2] = self.seg_fill_in + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'n_holes={self.n_holes}, ' + repr_str += (f'cutout_ratio={self.candidates}, ' if self.with_ratio + else f'cutout_shape={self.candidates}, ') + repr_str += f'fill_in={self.fill_in}, ' + repr_str += f'seg_fill_in={self.seg_fill_in})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomRotFlip(BaseTransform): + """Rotate and flip the image & seg or just rotate the image & seg. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - gt_seg_map + + Args: + rotate_prob (float): The probability of rotate image. + flip_prob (float): The probability of rotate&flip image. + degree (float, tuple[float]): Range of degrees to select from. 
If + degree is a number instead of tuple like (min, max), + the range of degree will be (``-degree``, ``+degree``) + """ + + def __init__(self, rotate_prob=0.5, flip_prob=0.5, degree=(-20, 20)): + self.rotate_prob = rotate_prob + self.flip_prob = flip_prob + assert 0 <= rotate_prob <= 1 and 0 <= flip_prob <= 1 + if isinstance(degree, (float, int)): + assert degree > 0, f'degree {degree} should be positive' + self.degree = (-degree, degree) + else: + self.degree = degree + assert len(self.degree) == 2, f'degree {self.degree} should be a ' \ + f'tuple of (min, max)' + + def random_rot_flip(self, results: dict) -> dict: + k = np.random.randint(0, 4) + results['img'] = np.rot90(results['img'], k) + for key in results.get('seg_fields', []): + results[key] = np.rot90(results[key], k) + axis = np.random.randint(0, 2) + results['img'] = np.flip(results['img'], axis=axis).copy() + for key in results.get('seg_fields', []): + results[key] = np.flip(results[key], axis=axis).copy() + return results + + def random_rotate(self, results: dict) -> dict: + angle = np.random.uniform(min(*self.degree), max(*self.degree)) + results['img'] = mmcv.imrotate(results['img'], angle=angle) + for key in results.get('seg_fields', []): + results[key] = mmcv.imrotate(results[key], angle=angle) + return results + + def transform(self, results: dict) -> dict: + """Call function to rotate or rotate & flip image, semantic + segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Rotated or rotated & flipped results. + """ + rotate_flag = 0 + if random.random() < self.rotate_prob: + results = self.random_rotate(results) + rotate_flag = 1 + if random.random() < self.flip_prob and rotate_flag == 0: + results = self.random_rot_flip(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(rotate_prob={self.rotate_prob}, ' \ + f'flip_prob={self.flip_prob}, ' \ + f'degree={self.degree})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomMosaic(BaseTransform): + """Mosaic augmentation. Given 4 images, mosaic transform combines them into + one output image. The output image is composed of the parts from each sub- + image. + + .. code:: text + + mosaic transform + center_x + +------------------------------+ + | pad | pad | + | +-----------+ | + | | | | + | | image1 |--------+ | + | | | | | + | | | image2 | | + center_y |----+-------------+-----------| + | | cropped | | + |pad | image3 | image4 | + | | | | + +----|-------------+-----------+ + | | + +-------------+ + + The mosaic transform steps are as follows: + 1. Choose the mosaic center as the intersections of 4 images + 2. Get the left top image according to the index, and randomly + sample another 3 images from the custom dataset. + 3. Sub image will be cropped if image is larger than mosaic patch + + Required Keys: + + - img + - gt_seg_map + - mix_results + + Modified Keys: + + - img + - img_shape + - ori_shape + - gt_seg_map + + Args: + prob (float): mosaic probability. + img_scale (Sequence[int]): Image size after mosaic pipeline of + a single image. The size of the output image is four times + that of a single image. The output image comprises 4 single images. + Default: (640, 640). + center_ratio_range (Sequence[float]): Center ratio range of mosaic + output. Default: (0.5, 1.5). + pad_val (int): Pad value. Default: 0. + seg_pad_val (int): Pad value of segmentation map. Default: 255. 
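+
+    An illustrative pipeline sketch (assumed configuration, shown only to
+    make the ``mix_results`` requirement concrete)::
+
+        train_pipeline = [
+            dict(type='RandomMosaic', prob=1.0, img_scale=(512, 512)),
+            dict(type='PackSegInputs'),
+        ]
+        # RandomMosaic has to run inside ``MultiImageMixDataset`` so that the
+        # three extra images are supplied through ``results['mix_results']``.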
+ """ + + def __init__(self, + prob, + img_scale=(640, 640), + center_ratio_range=(0.5, 1.5), + pad_val=0, + seg_pad_val=255): + assert 0 <= prob and prob <= 1 + assert isinstance(img_scale, tuple) + self.prob = prob + self.img_scale = img_scale + self.center_ratio_range = center_ratio_range + self.pad_val = pad_val + self.seg_pad_val = seg_pad_val + + @cache_randomness + def do_mosaic(self): + return np.random.rand() < self.prob + + def transform(self, results: dict) -> dict: + """Call function to make a mosaic of image. + + Args: + results (dict): Result dict. + + Returns: + dict: Result dict with mosaic transformed. + """ + mosaic = self.do_mosaic() + if mosaic: + results = self._mosaic_transform_img(results) + results = self._mosaic_transform_seg(results) + return results + + def get_indices(self, dataset: MultiImageMixDataset) -> list: + """Call function to collect indices. + + Args: + dataset (:obj:`MultiImageMixDataset`): The dataset. + + Returns: + list: indices. + """ + + indices = [random.randint(0, len(dataset)) for _ in range(3)] + return indices + + @cache_randomness + def generate_mosaic_center(self): + # mosaic center x, y + center_x = int( + random.uniform(*self.center_ratio_range) * self.img_scale[1]) + center_y = int( + random.uniform(*self.center_ratio_range) * self.img_scale[0]) + return center_x, center_y + + def _mosaic_transform_img(self, results: dict) -> dict: + """Mosaic transform function. + + Args: + results (dict): Result dict. + + Returns: + dict: Updated result dict. + """ + + assert 'mix_results' in results + if len(results['img'].shape) == 3: + c = results['img'].shape[2] + mosaic_img = np.full( + (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2), c), + self.pad_val, + dtype=results['img'].dtype) + else: + mosaic_img = np.full( + (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)), + self.pad_val, + dtype=results['img'].dtype) + + # mosaic center x, y + self.center_x, self.center_y = self.generate_mosaic_center() + center_position = (self.center_x, self.center_y) + + loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right') + for i, loc in enumerate(loc_strs): + if loc == 'top_left': + result_patch = copy.deepcopy(results) + else: + result_patch = copy.deepcopy(results['mix_results'][i - 1]) + + img_i = result_patch['img'] + h_i, w_i = img_i.shape[:2] + # keep_ratio resize + scale_ratio_i = min(self.img_scale[0] / h_i, + self.img_scale[1] / w_i) + img_i = mmcv.imresize( + img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i))) + + # compute the combine parameters + paste_coord, crop_coord = self._mosaic_combine( + loc, center_position, img_i.shape[:2][::-1]) + x1_p, y1_p, x2_p, y2_p = paste_coord + x1_c, y1_c, x2_c, y2_c = crop_coord + + # crop and paste image + mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c] + + results['img'] = mosaic_img + results['img_shape'] = mosaic_img.shape + results['ori_shape'] = mosaic_img.shape + + return results + + def _mosaic_transform_seg(self, results: dict) -> dict: + """Mosaic transform function for label annotations. + + Args: + results (dict): Result dict. + + Returns: + dict: Updated result dict. 
+ """ + + assert 'mix_results' in results + for key in results.get('seg_fields', []): + mosaic_seg = np.full( + (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)), + self.seg_pad_val, + dtype=results[key].dtype) + + # mosaic center x, y + center_position = (self.center_x, self.center_y) + + loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right') + for i, loc in enumerate(loc_strs): + if loc == 'top_left': + result_patch = copy.deepcopy(results) + else: + result_patch = copy.deepcopy(results['mix_results'][i - 1]) + + gt_seg_i = result_patch[key] + h_i, w_i = gt_seg_i.shape[:2] + # keep_ratio resize + scale_ratio_i = min(self.img_scale[0] / h_i, + self.img_scale[1] / w_i) + gt_seg_i = mmcv.imresize( + gt_seg_i, + (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)), + interpolation='nearest') + + # compute the combine parameters + paste_coord, crop_coord = self._mosaic_combine( + loc, center_position, gt_seg_i.shape[:2][::-1]) + x1_p, y1_p, x2_p, y2_p = paste_coord + x1_c, y1_c, x2_c, y2_c = crop_coord + + # crop and paste image + mosaic_seg[y1_p:y2_p, x1_p:x2_p] = gt_seg_i[y1_c:y2_c, + x1_c:x2_c] + + results[key] = mosaic_seg + + return results + + def _mosaic_combine(self, loc: str, center_position_xy: Sequence[float], + img_shape_wh: Sequence[int]) -> tuple: + """Calculate global coordinate of mosaic image and local coordinate of + cropped sub-image. + + Args: + loc (str): Index for the sub-image, loc in ('top_left', + 'top_right', 'bottom_left', 'bottom_right'). + center_position_xy (Sequence[float]): Mixing center for 4 images, + (x, y). + img_shape_wh (Sequence[int]): Width and height of sub-image + + Returns: + tuple[tuple[float]]: Corresponding coordinate of pasting and + cropping + - paste_coord (tuple): paste corner coordinate in mosaic image. + - crop_coord (tuple): crop corner coordinate in mosaic image. 
+ """ + + assert loc in ('top_left', 'top_right', 'bottom_left', 'bottom_right') + if loc == 'top_left': + # index0 to top left part of image + x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \ + max(center_position_xy[1] - img_shape_wh[1], 0), \ + center_position_xy[0], \ + center_position_xy[1] + crop_coord = img_shape_wh[0] - (x2 - x1), img_shape_wh[1] - ( + y2 - y1), img_shape_wh[0], img_shape_wh[1] + + elif loc == 'top_right': + # index1 to top right part of image + x1, y1, x2, y2 = center_position_xy[0], \ + max(center_position_xy[1] - img_shape_wh[1], 0), \ + min(center_position_xy[0] + img_shape_wh[0], + self.img_scale[1] * 2), \ + center_position_xy[1] + crop_coord = 0, img_shape_wh[1] - (y2 - y1), min( + img_shape_wh[0], x2 - x1), img_shape_wh[1] + + elif loc == 'bottom_left': + # index2 to bottom left part of image + x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \ + center_position_xy[1], \ + center_position_xy[0], \ + min(self.img_scale[0] * 2, center_position_xy[1] + + img_shape_wh[1]) + crop_coord = img_shape_wh[0] - (x2 - x1), 0, img_shape_wh[0], min( + y2 - y1, img_shape_wh[1]) + + else: + # index3 to bottom right part of image + x1, y1, x2, y2 = center_position_xy[0], \ + center_position_xy[1], \ + min(center_position_xy[0] + img_shape_wh[0], + self.img_scale[1] * 2), \ + min(self.img_scale[0] * 2, center_position_xy[1] + + img_shape_wh[1]) + crop_coord = 0, 0, min(img_shape_wh[0], + x2 - x1), min(y2 - y1, img_shape_wh[1]) + + paste_coord = x1, y1, x2, y2 + return paste_coord, crop_coord + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'img_scale={self.img_scale}, ' + repr_str += f'center_ratio_range={self.center_ratio_range}, ' + repr_str += f'pad_val={self.pad_val}, ' + repr_str += f'seg_pad_val={self.pad_val})' + return repr_str + + +@TRANSFORMS.register_module() +class GenerateEdge(BaseTransform): + """Generate Edge for CE2P approach. + + Edge will be used to calculate loss of + `CE2P `_. + + Modified from https://github.com/liutinglt/CE2P/blob/master/dataset/target_generation.py # noqa:E501 + + Required Keys: + + - img_shape + - gt_seg_map + + Added Keys: + - gt_edge_map (np.ndarray, uint8): The edge annotation generated from the + seg map by extracting border between different semantics. + + Args: + edge_width (int): The width of edge. Default to 3. + ignore_index (int): Index that will be ignored. Default to 255. + """ + + def __init__(self, edge_width: int = 3, ignore_index: int = 255) -> None: + super().__init__() + self.edge_width = edge_width + self.ignore_index = ignore_index + + def transform(self, results: Dict) -> Dict: + """Call function to generate edge from segmentation map. + + Args: + results (dict): Result dict. + + Returns: + dict: Result dict with edge mask. 
+ """ + h, w = results['img_shape'] + edge = np.zeros((h, w), dtype=np.uint8) + seg_map = results['gt_seg_map'] + + # down + edge_down = edge[1:h, :] + edge_down[(seg_map[1:h, :] != seg_map[:h - 1, :]) + & (seg_map[1:h, :] != self.ignore_index) & + (seg_map[:h - 1, :] != self.ignore_index)] = 1 + # left + edge_left = edge[:, :w - 1] + edge_left[(seg_map[:, :w - 1] != seg_map[:, 1:w]) + & (seg_map[:, :w - 1] != self.ignore_index) & + (seg_map[:, 1:w] != self.ignore_index)] = 1 + # up_left + edge_upleft = edge[:h - 1, :w - 1] + edge_upleft[(seg_map[:h - 1, :w - 1] != seg_map[1:h, 1:w]) + & (seg_map[:h - 1, :w - 1] != self.ignore_index) & + (seg_map[1:h, 1:w] != self.ignore_index)] = 1 + # up_right + edge_upright = edge[:h - 1, 1:w] + edge_upright[(seg_map[:h - 1, 1:w] != seg_map[1:h, :w - 1]) + & (seg_map[:h - 1, 1:w] != self.ignore_index) & + (seg_map[1:h, :w - 1] != self.ignore_index)] = 1 + + kernel = cv2.getStructuringElement(cv2.MORPH_RECT, + (self.edge_width, self.edge_width)) + edge = cv2.dilate(edge, kernel) + + results['gt_edge_map'] = edge + results['edge_width'] = self.edge_width + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'edge_width={self.edge_width}, ' + repr_str += f'ignore_index={self.ignore_index})' + return repr_str + + +@TRANSFORMS.register_module() +class ResizeShortestEdge(BaseTransform): + """Resize the image and mask while keeping the aspect ratio unchanged. + + Modified from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/augmentation_impl.py#L130 # noqa:E501 + Copyright (c) Facebook, Inc. and its affiliates. + Licensed under the Apache-2.0 License + + This transform attempts to scale the shorter edge to the given + `scale`, as long as the longer edge does not exceed `max_size`. + If `max_size` is reached, then downscale so that the longer + edge does not exceed `max_size`. + + Required Keys: + + - img + - gt_seg_map (optional) + + Modified Keys: + + - img + - img_shape + - gt_seg_map (optional)) + + Added Keys: + + - scale + - scale_factor + - keep_ratio + + + Args: + scale (Union[int, Tuple[int, int]]): The target short edge length. + If it's tuple, will select the min value as the short edge length. + max_size (int): The maximum allowed longest edge length. + """ + + def __init__(self, scale: Union[int, Tuple[int, int]], + max_size: int) -> None: + super().__init__() + self.scale = scale + self.max_size = max_size + + # Create a empty Resize object + self.resize = TRANSFORMS.build({ + 'type': 'Resize', + 'scale': 0, + 'keep_ratio': True + }) + + def _get_output_shape(self, img, short_edge_length) -> Tuple[int, int]: + """Compute the target image shape with the given `short_edge_length`. + + Args: + img (np.ndarray): The input image. + short_edge_length (Union[int, Tuple[int, int]]): The target short + edge length. If it's tuple, will select the min value as the + short edge length. 
+ """ + h, w = img.shape[:2] + if isinstance(short_edge_length, int): + size = short_edge_length * 1.0 + elif isinstance(short_edge_length, tuple): + size = min(short_edge_length) * 1.0 + scale = size / min(h, w) + if h < w: + new_h, new_w = size, scale * w + else: + new_h, new_w = scale * h, size + + if max(new_h, new_w) > self.max_size: + scale = self.max_size * 1.0 / max(new_h, new_w) + new_h *= scale + new_w *= scale + + new_h = int(new_h + 0.5) + new_w = int(new_w + 0.5) + return (new_w, new_h) + + def transform(self, results: Dict) -> Dict: + self.resize.scale = self._get_output_shape(results['img'], self.scale) + return self.resize(results) + + +@TRANSFORMS.register_module() +class BioMedical3DRandomCrop(BaseTransform): + """Crop the input patch for medical image & segmentation mask. + + Required Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X), + N is the number of modalities, and data type is float32. + - gt_seg_map (np.ndarray, optional): Biomedical semantic segmentation mask + with shape (Z, Y, X). + + Modified Keys: + + - img + - img_shape + - gt_seg_map (optional) + + Args: + crop_shape (Union[int, Tuple[int, int, int]]): Expected size after + cropping with the format of (z, y, x). If set to an integer, + then cropping width and height are equal to this integer. + keep_foreground (bool): If keep_foreground is True, it will sample a + voxel of foreground classes randomly, and will take it as the + center of the crop bounding-box. Default to True. + """ + + def __init__(self, + crop_shape: Union[int, Tuple[int, int, int]], + keep_foreground: bool = True): + super().__init__() + assert isinstance(crop_shape, int) or ( + isinstance(crop_shape, tuple) and len(crop_shape) == 3 + ), 'The expected crop_shape is an integer, or a tuple containing ' + 'three integers' + + if isinstance(crop_shape, int): + crop_shape = (crop_shape, crop_shape, crop_shape) + assert crop_shape[0] > 0 and crop_shape[1] > 0 and crop_shape[2] > 0 + self.crop_shape = crop_shape + self.keep_foreground = keep_foreground + + def random_sample_location(self, seg_map: np.ndarray) -> dict: + """sample foreground voxel when keep_foreground is True. + + Args: + seg_map (np.ndarray): gt seg map. + + Returns: + dict: Coordinates of selected foreground voxel. + """ + num_samples = 10000 + # at least 1% of the class voxels need to be selected, + # otherwise it may be too sparse + min_percent_coverage = 0.01 + class_locs = {} + foreground_classes = [] + all_classes = np.unique(seg_map) + for c in all_classes: + if c == 0: + # to avoid the segmentation mask full of background 0 + # and the class_locs is just void dictionary {} when it return + # there add a void list for background 0. + class_locs[c] = [] + else: + all_locs = np.argwhere(seg_map == c) + target_num_samples = min(num_samples, len(all_locs)) + target_num_samples = max( + target_num_samples, + int(np.ceil(len(all_locs) * min_percent_coverage))) + + selected = all_locs[np.random.choice( + len(all_locs), target_num_samples, replace=False)] + class_locs[c] = selected + foreground_classes.append(c) + + selected_voxel = None + if len(foreground_classes) > 0: + selected_class = np.random.choice(foreground_classes) + voxels_of_that_class = class_locs[selected_class] + selected_voxel = voxels_of_that_class[np.random.choice( + len(voxels_of_that_class))] + + return selected_voxel + + def random_generate_crop_bbox(self, margin_z: int, margin_y: int, + margin_x: int) -> tuple: + """Randomly get a crop bounding box. 
+ + Args: + seg_map (np.ndarray): Ground truth segmentation map. + + Returns: + tuple: Coordinates of the cropped image. + """ + offset_z = np.random.randint(0, margin_z + 1) + offset_y = np.random.randint(0, margin_y + 1) + offset_x = np.random.randint(0, margin_x + 1) + crop_z1, crop_z2 = offset_z, offset_z + self.crop_shape[0] + crop_y1, crop_y2 = offset_y, offset_y + self.crop_shape[1] + crop_x1, crop_x2 = offset_x, offset_x + self.crop_shape[2] + + return crop_z1, crop_z2, crop_y1, crop_y2, crop_x1, crop_x2 + + def generate_margin(self, results: dict) -> tuple: + """Generate margin of crop bounding-box. + + If keep_foreground is True, it will sample a voxel of foreground + classes randomly, and will take it as the center of the bounding-box, + and return the margin between of the bounding-box and image. + If keep_foreground is False, it will return the difference from crop + shape and image shape. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + tuple: The margin for 3 dimensions of crop bounding-box and image. + """ + + seg_map = results['gt_seg_map'] + if self.keep_foreground: + selected_voxel = self.random_sample_location(seg_map) + if selected_voxel is None: + # this only happens if some image does not contain + # foreground voxels at all + warnings.warn(f'case does not contain any foreground classes' + f': {results["img_path"]}') + margin_z = max(seg_map.shape[0] - self.crop_shape[0], 0) + margin_y = max(seg_map.shape[1] - self.crop_shape[1], 0) + margin_x = max(seg_map.shape[2] - self.crop_shape[2], 0) + else: + margin_z = max(0, selected_voxel[0] - self.crop_shape[0] // 2) + margin_y = max(0, selected_voxel[1] - self.crop_shape[1] // 2) + margin_x = max(0, selected_voxel[2] - self.crop_shape[2] // 2) + margin_z = max( + 0, min(seg_map.shape[0] - self.crop_shape[0], margin_z)) + margin_y = max( + 0, min(seg_map.shape[1] - self.crop_shape[1], margin_y)) + margin_x = max( + 0, min(seg_map.shape[2] - self.crop_shape[2], margin_x)) + else: + margin_z = max(seg_map.shape[0] - self.crop_shape[0], 0) + margin_y = max(seg_map.shape[1] - self.crop_shape[1], 0) + margin_x = max(seg_map.shape[2] - self.crop_shape[2], 0) + + return margin_z, margin_y, margin_x + + def crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray: + """Crop from ``img`` + + Args: + img (np.ndarray): Original input image. + crop_bbox (tuple): Coordinates of the cropped image. + + Returns: + np.ndarray: The cropped image. + """ + crop_z1, crop_z2, crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox + if len(img.shape) == 3: + # crop seg map + img = img[crop_z1:crop_z2, crop_y1:crop_y2, crop_x1:crop_x2] + else: + # crop image + assert len(img.shape) == 4 + img = img[:, crop_z1:crop_z2, crop_y1:crop_y2, crop_x1:crop_x2] + return img + + def transform(self, results: dict) -> dict: + """Transform function to randomly crop images, semantic segmentation + maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly cropped results, 'img_shape' key in result dict is + updated according to crop size. 
+ """ + margin = self.generate_margin(results) + crop_bbox = self.random_generate_crop_bbox(*margin) + + # crop the image + img = results['img'] + results['img'] = self.crop(img, crop_bbox) + results['img_shape'] = results['img'].shape[1:] + + # crop semantic seg + seg_map = results['gt_seg_map'] + results['gt_seg_map'] = self.crop(seg_map, crop_bbox) + + return results + + def __repr__(self): + return self.__class__.__name__ + f'(crop_shape={self.crop_shape})' + + +@TRANSFORMS.register_module() +class BioMedicalGaussianNoise(BaseTransform): + """Add random Gaussian noise to image. + + Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/noise_transforms.py#L53 # noqa:E501 + + Copyright (c) German Cancer Research Center (DKFZ) + Licensed under the Apache License, Version 2.0 + + Required Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X), + N is the number of modalities, and data type is float32. + + Modified Keys: + + - img + + Args: + prob (float): Probability to add Gaussian noise for + each sample. Default to 0.1. + mean (float): Mean or “centre” of the distribution. Default to 0.0. + std (float): Standard deviation of distribution. Default to 0.1. + """ + + def __init__(self, + prob: float = 0.1, + mean: float = 0.0, + std: float = 0.1) -> None: + super().__init__() + assert 0.0 <= prob <= 1.0 and std >= 0.0 + self.prob = prob + self.mean = mean + self.std = std + + def transform(self, results: Dict) -> Dict: + """Call function to add random Gaussian noise to image. + + Args: + results (dict): Result dict. + + Returns: + dict: Result dict with random Gaussian noise. + """ + if np.random.rand() < self.prob: + rand_std = np.random.uniform(0, self.std) + noise = np.random.normal( + self.mean, rand_std, size=results['img'].shape) + # noise is float64 array, convert to the results['img'].dtype + noise = noise.astype(results['img'].dtype) + results['img'] = results['img'] + noise + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'mean={self.mean}, ' + repr_str += f'std={self.std})' + return repr_str + + +@TRANSFORMS.register_module() +class BioMedicalGaussianBlur(BaseTransform): + """Add Gaussian blur with random sigma to image. + + Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/noise_transforms.py#L81 # noqa:E501 + + Copyright (c) German Cancer Research Center (DKFZ) + Licensed under the Apache License, Version 2.0 + + Required Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X), + N is the number of modalities, and data type is float32. + + Modified Keys: + + - img + + Args: + sigma_range (Tuple[float, float]|float): range to randomly + select sigma value. Default to (0.5, 1.0). + prob (float): Probability to apply Gaussian blur + for each sample. Default to 0.2. + prob_per_channel (float): Probability to apply Gaussian blur + for each channel (axis N of the image). Default to 0.5. + different_sigma_per_channel (bool): whether to use different + sigma for each channel (axis N of the image). Default to True. + different_sigma_per_axis (bool): whether to use different + sigma for axis Z, X and Y of the image. Default to True. 
+ """ + + def __init__(self, + sigma_range: Tuple[float, float] = (0.5, 1.0), + prob: float = 0.2, + prob_per_channel: float = 0.5, + different_sigma_per_channel: bool = True, + different_sigma_per_axis: bool = True) -> None: + super().__init__() + assert 0.0 <= prob <= 1.0 + assert 0.0 <= prob_per_channel <= 1.0 + assert isinstance(sigma_range, Sequence) and len(sigma_range) == 2 + self.sigma_range = sigma_range + self.prob = prob + self.prob_per_channel = prob_per_channel + self.different_sigma_per_channel = different_sigma_per_channel + self.different_sigma_per_axis = different_sigma_per_axis + + def _get_valid_sigma(self, value_range) -> Tuple[float, ...]: + """Ensure the `value_range` to be either a single value or a sequence + of two values. If the `value_range` is a sequence, generate a random + value with `[value_range[0], value_range[1]]` based on uniform + sampling. + + Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/augmentations/utils.py#L625 # noqa:E501 + + Args: + value_range (tuple|list|float|int): the input value range + """ + if (isinstance(value_range, (list, tuple))): + if (value_range[0] == value_range[1]): + value = value_range[0] + else: + orig_type = type(value_range[0]) + value = np.random.uniform(value_range[0], value_range[1]) + value = orig_type(value) + return value + + def _gaussian_blur(self, data_sample: np.ndarray) -> np.ndarray: + """Random generate sigma and apply Gaussian Blur to the data + Args: + data_sample (np.ndarray): data sample with multiple modalities, + the data shape is (N, Z, Y, X) + """ + sigma = None + for c in range(data_sample.shape[0]): + if np.random.rand() < self.prob_per_channel: + # if no `sigma` is generated, generate one + # if `self.different_sigma_per_channel` is True, + # re-generate random sigma for each channel + if (sigma is None or self.different_sigma_per_channel): + if (not self.different_sigma_per_axis): + sigma = self._get_valid_sigma(self.sigma_range) + else: + sigma = [ + self._get_valid_sigma(self.sigma_range) + for _ in data_sample.shape[1:] + ] + # apply gaussian filter with `sigma` + data_sample[c] = gaussian_filter( + data_sample[c], sigma, order=0) + return data_sample + + def transform(self, results: Dict) -> Dict: + """Call function to add random Gaussian blur to image. + + Args: + results (dict): Result dict. + + Returns: + dict: Result dict with random Gaussian noise. + """ + if np.random.rand() < self.prob: + results['img'] = self._gaussian_blur(results['img']) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'prob_per_channel={self.prob_per_channel}, ' + repr_str += f'sigma_range={self.sigma_range}, ' + repr_str += 'different_sigma_per_channel='\ + f'{self.different_sigma_per_channel}, ' + repr_str += 'different_sigma_per_axis='\ + f'{self.different_sigma_per_axis})' + return repr_str + + +@TRANSFORMS.register_module() +class BioMedicalRandomGamma(BaseTransform): + """Using random gamma correction to process the biomedical image. + + Modified from + https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/transforms/color_transforms.py#L132 # noqa:E501 + With licence: Apache 2.0 + + Required Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X), + N is the number of modalities, and data type is float32. + + Modified Keys: + - img + + Args: + prob (float): The probability to perform this transform. Default: 0.5. 
+ gamma_range (Tuple[float]): Range of gamma values. Default: (0.5, 2). + invert_image (bool): Whether invert the image before applying gamma + augmentation. Default: False. + per_channel (bool): Whether perform the transform each channel + individually. Default: False + retain_stats (bool): Gamma transformation will alter the mean and std + of the data in the patch. If retain_stats=True, the data will be + transformed to match the mean and standard deviation before gamma + augmentation. Default: False. + """ + + def __init__(self, + prob: float = 0.5, + gamma_range: Tuple[float] = (0.5, 2), + invert_image: bool = False, + per_channel: bool = False, + retain_stats: bool = False): + assert 0 <= prob and prob <= 1 + assert isinstance(gamma_range, tuple) and len(gamma_range) == 2 + assert isinstance(invert_image, bool) + assert isinstance(per_channel, bool) + assert isinstance(retain_stats, bool) + self.prob = prob + self.gamma_range = gamma_range + self.invert_image = invert_image + self.per_channel = per_channel + self.retain_stats = retain_stats + + @cache_randomness + def _do_gamma(self): + """Whether do adjust gamma for image.""" + return np.random.rand() < self.prob + + def _adjust_gamma(self, img: np.array): + """Gamma adjustment for image. + + Args: + img (np.array): Input image before gamma adjust. + + Returns: + np.arrays: Image after gamma adjust. + """ + + if self.invert_image: + img = -img + + def _do_adjust(img): + if retain_stats_here: + img_mean = img.mean() + img_std = img.std() + if np.random.random() < 0.5 and self.gamma_range[0] < 1: + gamma = np.random.uniform(self.gamma_range[0], 1) + else: + gamma = np.random.uniform( + max(self.gamma_range[0], 1), self.gamma_range[1]) + img_min = img.min() + img_range = img.max() - img_min # range + img = np.power(((img - img_min) / float(img_range + 1e-7)), + gamma) * img_range + img_min + if retain_stats_here: + img = img - img.mean() + img = img / (img.std() + 1e-8) * img_std + img = img + img_mean + return img + + if not self.per_channel: + retain_stats_here = self.retain_stats + img = _do_adjust(img) + else: + for c in range(img.shape[0]): + img[c] = _do_adjust(img[c]) + if self.invert_image: + img = -img + return img + + def transform(self, results: dict) -> dict: + """Call function to perform random gamma correction + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with random gamma correction performed. + """ + do_gamma = self._do_gamma() + + if do_gamma: + results['img'] = self._adjust_gamma(results['img']) + else: + pass + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'gamma_range={self.gamma_range},' + repr_str += f'invert_image={self.invert_image},' + repr_str += f'per_channel={self.per_channel},' + repr_str += f'retain_stats={self.retain_stats}' + return repr_str + + +@TRANSFORMS.register_module() +class BioMedical3DPad(BaseTransform): + """Pad the biomedical 3d image & biomedical 3d semantic segmentation maps. + + Required Keys: + + - img (np.ndarry): Biomedical image with shape (N, Z, Y, X) by default, + N is the number of modalities. + - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape + (Z, Y, X) by default. + + Modified Keys: + + - img (np.ndarry): Biomedical image with shape (N, Z, Y, X) by default, + N is the number of modalities. + - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape + (Z, Y, X) by default. 
+ + Added Keys: + + - pad_shape (Tuple[int, int, int]): The padded shape. + + Args: + pad_shape (Tuple[int, int, int]): Fixed padding size. + Expected padding shape (Z, Y, X). + pad_val (float): Padding value for biomedical image. + The padding mode is set to "constant". The value + to be filled in padding area. Default: 0. + seg_pad_val (int): Padding value for biomedical 3d semantic + segmentation maps. The padding mode is set to "constant". + The value to be filled in padding area. Default: 0. + """ + + def __init__(self, + pad_shape: Tuple[int, int, int], + pad_val: float = 0., + seg_pad_val: int = 0) -> None: + + # check pad_shape + assert pad_shape is not None + if not isinstance(pad_shape, tuple): + assert len(pad_shape) == 3 + + self.pad_shape = pad_shape + self.pad_val = pad_val + self.seg_pad_val = seg_pad_val + + def _pad_img(self, results: dict) -> None: + """Pad images according to ``self.pad_shape`` + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: The dict contains the padded image and shape + information. + """ + padded_img = self._to_pad( + results['img'], pad_shape=self.pad_shape, pad_val=self.pad_val) + + results['img'] = padded_img + results['pad_shape'] = padded_img.shape[1:] + + def _pad_seg(self, results: dict) -> None: + """Pad semantic segmentation map according to ``self.pad_shape`` if + ``gt_seg_map`` is not None in results dict. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Update the padded gt seg map in dict. + """ + if results.get('gt_seg_map', None) is not None: + pad_gt_seg = self._to_pad( + results['gt_seg_map'][None, ...], + pad_shape=results['pad_shape'], + pad_val=self.seg_pad_val) + results['gt_seg_map'] = pad_gt_seg[1:] + + @staticmethod + def _to_pad(img: np.ndarray, + pad_shape: Tuple[int, int, int], + pad_val: Union[int, float] = 0) -> np.ndarray: + """Pad the given 3d image to a certain shape with specified padding + value. + + Args: + img (ndarray): Biomedical image with shape (N, Z, Y, X) + to be padded. N is the number of modalities. + pad_shape (Tuple[int,int,int]): Expected padding shape (Z, Y, X). + pad_val (float, int): Values to be filled in padding areas + and the padding_mode is set to 'constant'. Default: 0. + + Returns: + ndarray: The padded image. + """ + # compute pad width + d = max(pad_shape[0] - img.shape[1], 0) + pad_d = (d // 2, d - d // 2) + h = max(pad_shape[1] - img.shape[2], 0) + pad_h = (h // 2, h - h // 2) + w = max(pad_shape[2] - img.shape[2], 0) + pad_w = (w // 2, w - w // 2) + + pad_list = [(0, 0), pad_d, pad_h, pad_w] + + img = np.pad(img, pad_list, mode='constant', constant_values=pad_val) + return img + + def transform(self, results: dict) -> dict: + """Call function to pad images, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Updated result dict. + """ + self._pad_img(results) + self._pad_seg(results) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'pad_shape={self.pad_shape}, ' + repr_str += f'pad_val={self.pad_val}), ' + repr_str += f'seg_pad_val={self.seg_pad_val})' + return repr_str + + +@TRANSFORMS.register_module() +class BioMedical3DRandomFlip(BaseTransform): + """Flip biomedical 3D images and segmentations. 
+ + Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/transforms/spatial_transforms.py # noqa:E501 + + Copyright 2021 Division of + Medical Image Computing, German Cancer Research Center (DKFZ) and Applied + Computer Vision Lab, Helmholtz Imaging Platform. + Licensed under the Apache-2.0 License. + + Required Keys: + + - img (np.ndarry): Biomedical image with shape (N, Z, Y, X) by default, + N is the number of modalities. + - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape + (Z, Y, X) by default. + + Modified Keys: + + - img (np.ndarry): Biomedical image with shape (N, Z, Y, X) by default, + N is the number of modalities. + - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape + (Z, Y, X) by default. + + Added Keys: + + - do_flip + - flip_axes + + Args: + prob (float): Flipping probability. + axes (Tuple[int, ...]): Flipping axes with order 'ZXY'. + swap_label_pairs (Optional[List[Tuple[int, int]]]): + The segmentation label pairs that are swapped when flipping. + """ + + def __init__(self, + prob: float, + axes: Tuple[int, ...], + swap_label_pairs: Optional[List[Tuple[int, int]]] = None): + self.prob = prob + self.axes = axes + self.swap_label_pairs = swap_label_pairs + assert prob >= 0 and prob <= 1 + if axes is not None: + assert max(axes) <= 2 + + @staticmethod + def _flip(img, direction: Tuple[bool, bool, bool]) -> np.ndarray: + if direction[0]: + img[:, :] = img[:, ::-1] + if direction[1]: + img[:, :, :] = img[:, :, ::-1] + if direction[2]: + img[:, :, :, :] = img[:, :, :, ::-1] + return img + + def _do_flip(self, img: np.ndarray) -> Tuple[bool, bool, bool]: + """Call function to determine which axis to flip. + + Args: + img (np.ndarry): Image or segmentation map array. + Returns: + tuple: Flip action, whether to flip on the z, x, and y axes. + """ + flip_c, flip_x, flip_y = False, False, False + if self.axes is not None: + flip_c = 0 in self.axes and np.random.rand() < self.prob + flip_x = 1 in self.axes and np.random.rand() < self.prob + if len(img.shape) == 4: + flip_y = 2 in self.axes and np.random.rand() < self.prob + return flip_c, flip_x, flip_y + + def _swap_label(self, seg: np.ndarray) -> np.ndarray: + out = seg.copy() + for first, second in self.swap_label_pairs: + first_area = (seg == first) + second_area = (seg == second) + out[first_area] = second + out[second_area] = first + return out + + def transform(self, results: Dict) -> Dict: + """Call function to flip and swap pair labels. + + Args: + results (dict): Result dict. + Returns: + dict: Flipped results, 'do_flip', 'flip_axes' keys are added into + result dict. 
+ """ + # get actual flipped axis + if 'do_flip' not in results: + results['do_flip'] = self._do_flip(results['img']) + if 'flip_axes' not in results: + results['flip_axes'] = self.axes + # flip image + results['img'] = self._flip( + results['img'], direction=results['do_flip']) + # flip seg + if results['gt_seg_map'] is not None: + if results['gt_seg_map'].shape != results['img'].shape: + results['gt_seg_map'] = results['gt_seg_map'][None, :] + results['gt_seg_map'] = self._flip( + results['gt_seg_map'], direction=results['do_flip']) + results['gt_seg_map'] = results['gt_seg_map'].squeeze() + # swap label pairs + if self.swap_label_pairs is not None: + results['gt_seg_map'] = self._swap_label(results['gt_seg_map']) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, axes={self.axes}, ' \ + f'swap_label_pairs={self.swap_label_pairs})' + return repr_str diff --git a/mmseg/datasets/voc.py b/mmseg/datasets/voc.py new file mode 100644 index 0000000000000000000000000000000000000000..5e5d6025c03760953a82f80e337185afc51f1386 --- /dev/null +++ b/mmseg/datasets/voc.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp + +import mmengine.fileio as fileio + +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + + +@DATASETS.register_module() +class PascalVOCDataset(BaseSegDataset): + """Pascal VOC dataset. + + Args: + split (str): Split txt file for Pascal VOC. + """ + METAINFO = dict( + classes=('background', 'aeroplane', 'bicycle', 'bird', 'boat', + 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', + 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', + 'sofa', 'train', 'tvmonitor'), + palette=[[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], + [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], + [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], + [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], + [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], + [0, 64, 128]]) + + def __init__(self, + ann_file, + img_suffix='.jpg', + seg_map_suffix='.png', + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + ann_file=ann_file, + **kwargs) + assert fileio.exists(self.data_prefix['img_path'], + self.backend_args) and osp.isfile(self.ann_file) diff --git a/mmseg/engine/__init__.py b/mmseg/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ada40570121dfe29f44808d4f3afd685e3054b5a --- /dev/null +++ b/mmseg/engine/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .hooks import SegVisualizationHook +from .optimizers import (LayerDecayOptimizerConstructor, + LearningRateDecayOptimizerConstructor) + +__all__ = [ + 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor', + 'SegVisualizationHook' +] diff --git a/mmseg/engine/hooks/__init__.py b/mmseg/engine/hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6048088a7fd322890ced17569e855acee826eca --- /dev/null +++ b/mmseg/engine/hooks/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .visualization_hook import SegVisualizationHook + +__all__ = ['SegVisualizationHook'] diff --git a/mmseg/engine/hooks/visualization_hook.py b/mmseg/engine/hooks/visualization_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..1e7c97afe8a1a9c8f7ce6f4fffdcbd3d60f018a5 --- /dev/null +++ b/mmseg/engine/hooks/visualization_hook.py @@ -0,0 +1,98 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import warnings +from typing import Optional, Sequence + +import mmcv +import mmengine.fileio as fileio +from mmengine.hooks import Hook +from mmengine.runner import Runner + +from mmseg.registry import HOOKS +from mmseg.structures import SegDataSample +from mmseg.visualization import SegLocalVisualizer + + +@HOOKS.register_module() +class SegVisualizationHook(Hook): + """Segmentation Visualization Hook. Used to visualize validation and + testing process prediction results. + + In the testing phase: + + 1. If ``show`` is True, it means that only the prediction results are + visualized without storing data, so ``vis_backends`` needs to + be excluded. + + Args: + draw (bool): whether to draw prediction results. If it is False, + it means that no drawing will be done. Defaults to False. + interval (int): The interval of visualization. Defaults to 50. + show (bool): Whether to display the drawn image. Default to False. + wait_time (float): The interval of show (s). Defaults to 0. + backend_args (dict, Optional): Arguments to instantiate a file backend. + See https://mmengine.readthedocs.io/en/latest/api/fileio.htm + for details. Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + """ + + def __init__(self, + draw: bool = False, + interval: int = 50, + show: bool = False, + wait_time: float = 0., + backend_args: Optional[dict] = None): + self._visualizer: SegLocalVisualizer = \ + SegLocalVisualizer.get_current_instance() + self.interval = interval + self.show = show + if self.show: + # No need to think about vis backends. + self._visualizer._vis_backends = {} + warnings.warn('The show is True, it means that only ' + 'the prediction results are visualized ' + 'without storing data, so vis_backends ' + 'needs to be excluded.') + + self.wait_time = wait_time + self.backend_args = backend_args.copy() if backend_args else None + self.draw = draw + if not self.draw: + warnings.warn('The draw is False, it means that the ' + 'hook for visualization will not take ' + 'effect. The results will NOT be ' + 'visualized or stored.') + + def _after_iter(self, + runner: Runner, + batch_idx: int, + data_batch: dict, + outputs: Sequence[SegDataSample], + mode: str = 'val') -> None: + """Run after every ``self.interval`` validation iterations. + + Args: + runner (:obj:`Runner`): The runner of the validation process. + batch_idx (int): The index of the current batch in the val loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`SegDataSample`]): Outputs from model. + mode (str): mode (str): Current mode of runner. Defaults to 'val'. 
+ """ + if self.draw is False or mode == 'train': + return + + if self.every_n_inner_iters(batch_idx, self.interval): + for output in outputs: + img_path = output.img_path + img_bytes = fileio.get( + img_path, backend_args=self.backend_args) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + window_name = f'{mode}_{osp.basename(img_path)}' + + self._visualizer.add_datasample( + window_name, + img, + data_sample=output, + show=self.show, + wait_time=self.wait_time, + step=runner.iter) diff --git a/mmseg/engine/optimizers/__init__.py b/mmseg/engine/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4fbf4ecfcd4d1f0834322e2964b55d9637c844ba --- /dev/null +++ b/mmseg/engine/optimizers/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .layer_decay_optimizer_constructor import ( + LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor) + +__all__ = [ + 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor' +] diff --git a/mmseg/engine/optimizers/layer_decay_optimizer_constructor.py b/mmseg/engine/optimizers/layer_decay_optimizer_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..fdae3ca698c65879056b969f04185f80452ff8d0 --- /dev/null +++ b/mmseg/engine/optimizers/layer_decay_optimizer_constructor.py @@ -0,0 +1,207 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import warnings + +from mmengine.dist import get_dist_info +from mmengine.logging import print_log +from mmengine.optim import DefaultOptimWrapperConstructor + +from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS + + +def get_layer_id_for_convnext(var_name, max_layer_id): + """Get the layer id to set the different learning rates in ``layer_wise`` + decay_type. + + Args: + var_name (str): The key of the model. + max_layer_id (int): Maximum number of backbone layers. + + Returns: + int: The id number corresponding to different learning rate in + ``LearningRateDecayOptimizerConstructor``. + """ + + if var_name in ('backbone.cls_token', 'backbone.mask_token', + 'backbone.pos_embed'): + return 0 + elif var_name.startswith('backbone.downsample_layers'): + stage_id = int(var_name.split('.')[2]) + if stage_id == 0: + layer_id = 0 + elif stage_id == 1: + layer_id = 2 + elif stage_id == 2: + layer_id = 3 + elif stage_id == 3: + layer_id = max_layer_id + return layer_id + elif var_name.startswith('backbone.stages'): + stage_id = int(var_name.split('.')[2]) + block_id = int(var_name.split('.')[3]) + if stage_id == 0: + layer_id = 1 + elif stage_id == 1: + layer_id = 2 + elif stage_id == 2: + layer_id = 3 + block_id // 3 + elif stage_id == 3: + layer_id = max_layer_id + return layer_id + else: + return max_layer_id + 1 + + +def get_stage_id_for_convnext(var_name, max_stage_id): + """Get the stage id to set the different learning rates in ``stage_wise`` + decay_type. + + Args: + var_name (str): The key of the model. + max_stage_id (int): Maximum number of backbone layers. + + Returns: + int: The id number corresponding to different learning rate in + ``LearningRateDecayOptimizerConstructor``. 
+ """ + + if var_name in ('backbone.cls_token', 'backbone.mask_token', + 'backbone.pos_embed'): + return 0 + elif var_name.startswith('backbone.downsample_layers'): + return 0 + elif var_name.startswith('backbone.stages'): + stage_id = int(var_name.split('.')[2]) + return stage_id + 1 + else: + return max_stage_id - 1 + + +def get_layer_id_for_vit(var_name, max_layer_id): + """Get the layer id to set the different learning rates. + + Args: + var_name (str): The key of the model. + num_max_layer (int): Maximum number of backbone layers. + + Returns: + int: Returns the layer id of the key. + """ + + if var_name in ('backbone.cls_token', 'backbone.mask_token', + 'backbone.pos_embed'): + return 0 + elif var_name.startswith('backbone.patch_embed'): + return 0 + elif var_name.startswith('backbone.layers'): + layer_id = int(var_name.split('.')[2]) + return layer_id + 1 + else: + return max_layer_id - 1 + + +@OPTIM_WRAPPER_CONSTRUCTORS.register_module() +class LearningRateDecayOptimizerConstructor(DefaultOptimWrapperConstructor): + """Different learning rates are set for different layers of backbone. + + Note: Currently, this optimizer constructor is built for ConvNeXt, + BEiT and MAE. + """ + + def add_params(self, params, module, **kwargs): + """Add all parameters of module to the params list. + + The parameters of the given module will be added to the list of param + groups, with specific rules defined by paramwise_cfg. + + Args: + params (list[dict]): A list of param groups, it will be modified + in place. + module (nn.Module): The module to be added. + """ + + parameter_groups = {} + print_log(f'self.paramwise_cfg is {self.paramwise_cfg}') + num_layers = self.paramwise_cfg.get('num_layers') + 2 + decay_rate = self.paramwise_cfg.get('decay_rate') + decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise') + print_log('Build LearningRateDecayOptimizerConstructor ' + f'{decay_type} {decay_rate} - {num_layers}') + weight_decay = self.base_wd + for name, param in module.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith('.bias') or name in ( + 'pos_embed', 'cls_token'): + group_name = 'no_decay' + this_weight_decay = 0. 
+ else: + group_name = 'decay' + this_weight_decay = weight_decay + if 'layer_wise' in decay_type: + if 'ConvNeXt' in module.backbone.__class__.__name__: + layer_id = get_layer_id_for_convnext( + name, self.paramwise_cfg.get('num_layers')) + print_log(f'set param {name} as id {layer_id}') + elif 'BEiT' in module.backbone.__class__.__name__ or \ + 'MAE' in module.backbone.__class__.__name__: + layer_id = get_layer_id_for_vit(name, num_layers) + print_log(f'set param {name} as id {layer_id}') + else: + raise NotImplementedError() + elif decay_type == 'stage_wise': + if 'ConvNeXt' in module.backbone.__class__.__name__: + layer_id = get_stage_id_for_convnext(name, num_layers) + print_log(f'set param {name} as id {layer_id}') + else: + raise NotImplementedError() + group_name = f'layer_{layer_id}_{group_name}' + + if group_name not in parameter_groups: + scale = decay_rate**(num_layers - layer_id - 1) + + parameter_groups[group_name] = { + 'weight_decay': this_weight_decay, + 'params': [], + 'param_names': [], + 'lr_scale': scale, + 'group_name': group_name, + 'lr': scale * self.base_lr, + } + + parameter_groups[group_name]['params'].append(param) + parameter_groups[group_name]['param_names'].append(name) + rank, _ = get_dist_info() + if rank == 0: + to_display = {} + for key in parameter_groups: + to_display[key] = { + 'param_names': parameter_groups[key]['param_names'], + 'lr_scale': parameter_groups[key]['lr_scale'], + 'lr': parameter_groups[key]['lr'], + 'weight_decay': parameter_groups[key]['weight_decay'], + } + print_log(f'Param groups = {json.dumps(to_display, indent=2)}') + params.extend(parameter_groups.values()) + + +@OPTIM_WRAPPER_CONSTRUCTORS.register_module() +class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor): + """Different learning rates are set for different layers of backbone. + + Note: Currently, this optimizer constructor is built for BEiT, + and it will be deprecated. + Please use ``LearningRateDecayOptimizerConstructor`` instead. + """ + + def __init__(self, optim_wrapper_cfg, paramwise_cfg): + warnings.warn('DeprecationWarning: Original ' + 'LayerDecayOptimizerConstructor of BEiT ' + 'will be deprecated. Please use ' + 'LearningRateDecayOptimizerConstructor instead, ' + 'and set decay_type = layer_wise_vit in paramwise_cfg.') + paramwise_cfg.update({'decay_type': 'layer_wise_vit'}) + warnings.warn('DeprecationWarning: Layer_decay_rate will ' + 'be deleted, please use decay_rate instead.') + paramwise_cfg['decay_rate'] = paramwise_cfg.pop('layer_decay_rate') + super().__init__(optim_wrapper_cfg, paramwise_cfg) diff --git a/mmseg/evaluation/__init__.py b/mmseg/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a82008f3ad3148a23e297f8ad8c22d968f285968 --- /dev/null +++ b/mmseg/evaluation/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .metrics import CityscapesMetric, IoUMetric + +__all__ = ['IoUMetric', 'CityscapesMetric'] diff --git a/mmseg/evaluation/metrics/__init__.py b/mmseg/evaluation/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0aa39e480cdb930a9f8b9550d84ee60c474aca4b --- /dev/null +++ b/mmseg/evaluation/metrics/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
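A usage note for the optimizer constructor defined above: it is selected by name in the optim_wrapper config, and add_params reads num_layers, decay_rate and decay_type from paramwise_cfg. A hedged sketch for a 12-layer ViT-style backbone (the AdamW hyper-parameters are illustrative assumptions):

optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=1e-4, betas=(0.9, 0.999), weight_decay=0.05),
    constructor='LearningRateDecayOptimizerConstructor',
    paramwise_cfg=dict(num_layers=12, decay_rate=0.9, decay_type='layer_wise'))

With this config, add_params gives each parameter group the scale decay_rate ** (num_layers + 2 - layer_id - 1), so earlier layers receive progressively smaller learning rates than the decode head.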
+from .citys_metric import CityscapesMetric +from .iou_metric import IoUMetric + +__all__ = ['IoUMetric', 'CityscapesMetric'] diff --git a/mmseg/evaluation/metrics/citys_metric.py b/mmseg/evaluation/metrics/citys_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..32984653c3fa9c13d8c6a7402033001012b5031f --- /dev/null +++ b/mmseg/evaluation/metrics/citys_metric.py @@ -0,0 +1,158 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import shutil +from collections import OrderedDict +from typing import Dict, Optional, Sequence + +try: + + import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa + import cityscapesscripts.helpers.labels as CSLabels +except ImportError: + CSLabels = None + CSEval = None + +import numpy as np +from mmengine.dist import is_main_process, master_only +from mmengine.evaluator import BaseMetric +from mmengine.logging import MMLogger, print_log +from mmengine.utils import mkdir_or_exist +from PIL import Image + +from mmseg.registry import METRICS + + +@METRICS.register_module() +class CityscapesMetric(BaseMetric): + """Cityscapes evaluation metric. + + Args: + output_dir (str): The directory for output prediction + ignore_index (int): Index that will be ignored in evaluation. + Default: 255. + format_only (bool): Only format result for results commit without + perform evaluation. It is useful when you want to format the result + to a specific format and submit it to the test server. + Defaults to False. + keep_results (bool): Whether to keep the results. When ``format_only`` + is True, ``keep_results`` must be True. Defaults to False. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + """ + + def __init__(self, + output_dir: str, + ignore_index: int = 255, + format_only: bool = False, + keep_results: bool = False, + collect_device: str = 'cpu', + prefix: Optional[str] = None, + **kwargs) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + if CSEval is None: + raise ImportError('Please run "pip install cityscapesscripts" to ' + 'install cityscapesscripts first.') + self.output_dir = output_dir + self.ignore_index = ignore_index + + self.format_only = format_only + if format_only: + assert keep_results, ( + 'When format_only is True, the results must be keep, please ' + f'set keep_results as True, but got {keep_results}') + self.keep_results = keep_results + self.prefix = prefix + if is_main_process(): + mkdir_or_exist(self.output_dir) + + @master_only + def __del__(self) -> None: + """Clean up.""" + if not self.keep_results: + shutil.rmtree(self.output_dir) + + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + """Process one batch of data and data_samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. 
+ """ + mkdir_or_exist(self.output_dir) + + for data_sample in data_samples: + pred_label = data_sample['pred_sem_seg']['data'][0].cpu().numpy() + # when evaluating with official cityscapesscripts, + # labelIds should be used + pred_label = self._convert_to_label_id(pred_label) + basename = osp.splitext(osp.basename(data_sample['img_path']))[0] + png_filename = osp.abspath( + osp.join(self.output_dir, f'{basename}.png')) + output = Image.fromarray(pred_label.astype(np.uint8)).convert('P') + output.save(png_filename) + if self.format_only: + # format_only always for test dataset without ground truth + gt_filename = '' + else: + # when evaluating with official cityscapesscripts, + # **_gtFine_labelIds.png is used + gt_filename = data_sample['seg_map_path'].replace( + 'labelTrainIds.png', 'labelIds.png') + self.results.append((png_filename, gt_filename)) + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): Testing results of the dataset. + + Returns: + dict[str: float]: Cityscapes evaluation results. + """ + logger: MMLogger = MMLogger.get_current_instance() + if self.format_only: + logger.info(f'results are saved to {osp.dirname(self.output_dir)}') + return OrderedDict() + + msg = 'Evaluating in Cityscapes style' + if logger is None: + msg = '\n' + msg + print_log(msg, logger=logger) + + eval_results = dict() + print_log( + f'Evaluating results under {self.output_dir} ...', logger=logger) + + CSEval.args.evalInstLevelScore = True + CSEval.args.predictionPath = osp.abspath(self.output_dir) + CSEval.args.evalPixelAccuracy = True + CSEval.args.JSONOutput = False + + pred_list, gt_list = zip(*results) + metric = dict() + eval_results.update( + CSEval.evaluateImgLists(pred_list, gt_list, CSEval.args)) + metric['averageScoreCategories'] = eval_results[ + 'averageScoreCategories'] + metric['averageScoreInstCategories'] = eval_results[ + 'averageScoreInstCategories'] + return metric + + @staticmethod + def _convert_to_label_id(result): + """Convert trainId to id for cityscapes.""" + if isinstance(result, str): + result = np.load(result) + result_copy = result.copy() + for trainId, label in CSLabels.trainId2label.items(): + result_copy[result == trainId] = label.id + + return result_copy diff --git a/mmseg/evaluation/metrics/iou_metric.py b/mmseg/evaluation/metrics/iou_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..16014c74001d7295f9fff8f03ef185077e3f613b --- /dev/null +++ b/mmseg/evaluation/metrics/iou_metric.py @@ -0,0 +1,286 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from collections import OrderedDict +from typing import Dict, List, Optional, Sequence + +import numpy as np +import torch +from mmengine.dist import is_main_process +from mmengine.evaluator import BaseMetric +from mmengine.logging import MMLogger, print_log +from mmengine.utils import mkdir_or_exist +from PIL import Image +from prettytable import PrettyTable + +from mmseg.registry import METRICS + + +@METRICS.register_module() +class IoUMetric(BaseMetric): + """IoU evaluation metric. + + Args: + ignore_index (int): Index that will be ignored in evaluation. + Default: 255. + iou_metrics (list[str] | str): Metrics to be calculated, the options + includes 'mIoU', 'mDice' and 'mFscore'. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + beta (int): Determines the weight of recall in the combined score. 
+ Default: 1. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + output_dir (str): The directory for output prediction. Defaults to + None. + format_only (bool): Only format result for results commit without + perform evaluation. It is useful when you want to save the result + to a specific format and submit it to the test server. + Defaults to False. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + """ + + def __init__(self, + ignore_index: int = 255, + iou_metrics: List[str] = ['mIoU'], + nan_to_num: Optional[int] = None, + beta: int = 1, + collect_device: str = 'cpu', + output_dir: Optional[str] = None, + format_only: bool = False, + prefix: Optional[str] = None, + **kwargs) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + + self.ignore_index = ignore_index + self.metrics = iou_metrics + self.nan_to_num = nan_to_num + self.beta = beta + self.output_dir = output_dir + if self.output_dir and is_main_process(): + mkdir_or_exist(self.output_dir) + self.format_only = format_only + + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + """Process one batch of data and data_samples. + + The processed results should be stored in ``self.results``, which will + be used to compute the metrics when all batches have been processed. + + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + num_classes = len(self.dataset_meta['classes']) + for data_sample in data_samples: + pred_label = data_sample['pred_sem_seg']['data'].squeeze() + # format_only always for test dataset without ground truth + if not self.format_only: + label = data_sample['gt_sem_seg']['data'].squeeze().to( + pred_label) + self.results.append( + self.intersect_and_union(pred_label, label, num_classes, + self.ignore_index)) + # format_result + if self.output_dir is not None: + basename = osp.splitext(osp.basename( + data_sample['img_path']))[0] + png_filename = osp.abspath( + osp.join(self.output_dir, f'{basename}.png')) + output_mask = pred_label.cpu().numpy() + # The index range of official ADE20k dataset is from 0 to 150. + # But the index range of output is from 0 to 149. + # That is because we set reduce_zero_label=True. + if data_sample.get('reduce_zero_label', False): + output_mask = output_mask + 1 + output = Image.fromarray(output_mask.astype(np.uint8)) + output.save(png_filename) + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. The key + mainly includes aAcc, mIoU, mAcc, mDice, mFscore, mPrecision, + mRecall. + """ + logger: MMLogger = MMLogger.get_current_instance() + if self.format_only: + logger.info(f'results are saved to {osp.dirname(self.output_dir)}') + return OrderedDict() + # convert list of tuples to tuple of lists, e.g. 
+ # [(A_1, B_1, C_1, D_1), ..., (A_n, B_n, C_n, D_n)] to + # ([A_1, ..., A_n], ..., [D_1, ..., D_n]) + results = tuple(zip(*results)) + assert len(results) == 4 + + total_area_intersect = sum(results[0]) + total_area_union = sum(results[1]) + total_area_pred_label = sum(results[2]) + total_area_label = sum(results[3]) + ret_metrics = self.total_area_to_metrics( + total_area_intersect, total_area_union, total_area_pred_label, + total_area_label, self.metrics, self.nan_to_num, self.beta) + + class_names = self.dataset_meta['classes'] + + # summary table + ret_metrics_summary = OrderedDict({ + ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2) + for ret_metric, ret_metric_value in ret_metrics.items() + }) + metrics = dict() + for key, val in ret_metrics_summary.items(): + if key == 'aAcc': + metrics[key] = val + else: + metrics['m' + key] = val + + # each class table + ret_metrics.pop('aAcc', None) + ret_metrics_class = OrderedDict({ + ret_metric: np.round(ret_metric_value * 100, 2) + for ret_metric, ret_metric_value in ret_metrics.items() + }) + ret_metrics_class.update({'Class': class_names}) + ret_metrics_class.move_to_end('Class', last=False) + class_table_data = PrettyTable() + for key, val in ret_metrics_class.items(): + class_table_data.add_column(key, val) + + print_log('per class results:', logger) + print_log('\n' + class_table_data.get_string(), logger=logger) + + return metrics + + @staticmethod + def intersect_and_union(pred_label: torch.tensor, label: torch.tensor, + num_classes: int, ignore_index: int): + """Calculate Intersection and Union. + + Args: + pred_label (torch.tensor): Prediction segmentation map + or predict result filename. The shape is (H, W). + label (torch.tensor): Ground truth segmentation map + or label filename. The shape is (H, W). + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + + Returns: + torch.Tensor: The intersection of prediction and ground truth + histogram on all classes. + torch.Tensor: The union of prediction and ground truth histogram on + all classes. + torch.Tensor: The prediction histogram on all classes. + torch.Tensor: The ground truth histogram on all classes. + """ + + mask = (label != ignore_index) + pred_label = pred_label[mask] + label = label[mask] + + intersect = pred_label[pred_label == label] + area_intersect = torch.histc( + intersect.float(), bins=(num_classes), min=0, + max=num_classes - 1).cpu() + area_pred_label = torch.histc( + pred_label.float(), bins=(num_classes), min=0, + max=num_classes - 1).cpu() + area_label = torch.histc( + label.float(), bins=(num_classes), min=0, + max=num_classes - 1).cpu() + area_union = area_pred_label + area_label - area_intersect + return area_intersect, area_union, area_pred_label, area_label + + @staticmethod + def total_area_to_metrics(total_area_intersect: np.ndarray, + total_area_union: np.ndarray, + total_area_pred_label: np.ndarray, + total_area_label: np.ndarray, + metrics: List[str] = ['mIoU'], + nan_to_num: Optional[int] = None, + beta: int = 1): + """Calculate evaluation metrics + Args: + total_area_intersect (np.ndarray): The intersection of prediction + and ground truth histogram on all classes. + total_area_union (np.ndarray): The union of prediction and ground + truth histogram on all classes. + total_area_pred_label (np.ndarray): The prediction histogram on + all classes. + total_area_label (np.ndarray): The ground truth histogram on + all classes. 
+ metrics (List[str] | str): Metrics to be evaluated, 'mIoU' and + 'mDice'. + nan_to_num (int, optional): If specified, NaN values will be + replaced by the numbers defined by the user. Default: None. + beta (int): Determines the weight of recall in the combined score. + Default: 1. + Returns: + Dict[str, np.ndarray]: per category evaluation metrics, + shape (num_classes, ). + """ + + def f_score(precision, recall, beta=1): + """calculate the f-score value. + + Args: + precision (float | torch.Tensor): The precision value. + recall (float | torch.Tensor): The recall value. + beta (int): Determines the weight of recall in the combined + score. Default: 1. + + Returns: + [torch.tensor]: The f-score value. + """ + score = (1 + beta**2) * (precision * recall) / ( + (beta**2 * precision) + recall) + return score + + if isinstance(metrics, str): + metrics = [metrics] + allowed_metrics = ['mIoU', 'mDice', 'mFscore'] + if not set(metrics).issubset(set(allowed_metrics)): + raise KeyError(f'metrics {metrics} is not supported') + + all_acc = total_area_intersect.sum() / total_area_label.sum() + ret_metrics = OrderedDict({'aAcc': all_acc}) + for metric in metrics: + if metric == 'mIoU': + iou = total_area_intersect / total_area_union + acc = total_area_intersect / total_area_label + ret_metrics['IoU'] = iou + ret_metrics['Acc'] = acc + elif metric == 'mDice': + dice = 2 * total_area_intersect / ( + total_area_pred_label + total_area_label) + acc = total_area_intersect / total_area_label + ret_metrics['Dice'] = dice + ret_metrics['Acc'] = acc + elif metric == 'mFscore': + precision = total_area_intersect / total_area_pred_label + recall = total_area_intersect / total_area_label + f_value = torch.tensor([ + f_score(x[0], x[1], beta) for x in zip(precision, recall) + ]) + ret_metrics['Fscore'] = f_value + ret_metrics['Precision'] = precision + ret_metrics['Recall'] = recall + + ret_metrics = { + metric: value.numpy() + for metric, value in ret_metrics.items() + } + if nan_to_num is not None: + ret_metrics = OrderedDict({ + metric: np.nan_to_num(metric_value, nan=nan_to_num) + for metric, metric_value in ret_metrics.items() + }) + return ret_metrics diff --git a/mmseg/models/__init__.py b/mmseg/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a520fb2fa4da477ecb57cb6bb4b4af936e2a4da --- /dev/null +++ b/mmseg/models/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
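To make the bookkeeping in IoUMetric concrete, here is a small hedged sketch of its static helper on a toy 2x3 prediction/label pair (the values are made up for illustration); the commented outputs follow directly from the histogram logic above:

import torch
from mmseg.evaluation import IoUMetric

pred = torch.tensor([[0, 1, 1], [2, 2, 255]])
label = torch.tensor([[0, 1, 2], [2, 2, 255]])  # pixels equal to 255 are ignored
inter, union, pred_hist, label_hist = IoUMetric.intersect_and_union(
    pred, label, num_classes=3, ignore_index=255)
# inter -> tensor([1., 1., 2.])  per-class intersection
# union -> tensor([1., 2., 3.])  per-class union
# per-class IoU = inter / union -> [1.00, 0.50, 0.67]
# total_area_to_metrics aggregates these counts over the whole dataset
# before taking the per-class and mean scores.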
+from .backbones import * # noqa: F401,F403 +from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone, + build_head, build_loss, build_segmentor) +from .data_preprocessor import SegDataPreProcessor +from .decode_heads import * # noqa: F401,F403 +from .losses import * # noqa: F401,F403 +from .necks import * # noqa: F401,F403 +from .segmentors import * # noqa: F401,F403 + +__all__ = [ + 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone', + 'build_head', 'build_loss', 'build_segmentor', 'SegDataPreProcessor' +] diff --git a/mmseg/models/__pycache__/__init__.cpython-310.pyc b/mmseg/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0dead8ad03fb469158c655140769b423b23d281 Binary files /dev/null and b/mmseg/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmseg/models/__pycache__/builder.cpython-310.pyc b/mmseg/models/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db7c8c9a6b8a8ec30ff83af2af53c2b881d74990 Binary files /dev/null and b/mmseg/models/__pycache__/builder.cpython-310.pyc differ diff --git a/mmseg/models/__pycache__/data_preprocessor.cpython-310.pyc b/mmseg/models/__pycache__/data_preprocessor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a27698097cb774f1fe0fa61301328af8c1e114c Binary files /dev/null and b/mmseg/models/__pycache__/data_preprocessor.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__init__.py b/mmseg/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e3107306eae5c4f23a00c3ed544960e280f2dfd0 --- /dev/null +++ b/mmseg/models/backbones/__init__.py @@ -0,0 +1,32 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
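The builder aliases re-exported above are thin wrappers around the MODELS registry, so any backbone registered in the list that follows can be constructed from a plain config dict. A minimal hedged sketch (the ResNetV1c arguments and the dummy input size are assumptions for illustration):

import torch
from mmseg.models import build_backbone

backbone = build_backbone(
    dict(type='ResNetV1c', depth=50, num_stages=4, out_indices=(0, 1, 2, 3),
         norm_cfg=dict(type='BN', requires_grad=True)))
feats = backbone(torch.randn(1, 3, 512, 512))  # tuple of 4 multi-scale feature maps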
+from .beit import BEiT +from .bisenetv1 import BiSeNetV1 +from .bisenetv2 import BiSeNetV2 +from .cgnet import CGNet +from .erfnet import ERFNet +from .fast_scnn import FastSCNN +from .hrnet import HRNet +from .icnet import ICNet +from .mae import MAE +from .mit import MixVisionTransformer +from .mobilenet_v2 import MobileNetV2 +from .mobilenet_v3 import MobileNetV3 +from .mscan import MSCAN +from .pidnet import PIDNet +from .resnest import ResNeSt +from .resnet import ResNet, ResNetV1c, ResNetV1d +from .resnext import ResNeXt +from .stdc import STDCContextPathNet, STDCNet +from .swin import SwinTransformer +from .timm_backbone import TIMMBackbone +from .twins import PCPVT, SVT +from .unet import UNet +from .vit import VisionTransformer + +__all__ = [ + 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN', + 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3', + 'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer', + 'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT', + 'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN' +] diff --git a/mmseg/models/backbones/__pycache__/__init__.cpython-310.pyc b/mmseg/models/backbones/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfd8e78790d53ce252e9ecbb0e2e462811649e15 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/beit.cpython-310.pyc b/mmseg/models/backbones/__pycache__/beit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76d420ccb34638404dc60610cbc5ada58744ed10 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/beit.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/bisenetv1.cpython-310.pyc b/mmseg/models/backbones/__pycache__/bisenetv1.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c6ce9e22c75e2da36b5aeecd00f14f97fcf8ae7 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/bisenetv1.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/bisenetv2.cpython-310.pyc b/mmseg/models/backbones/__pycache__/bisenetv2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9024f596e3a3444b176b434a870d2edb00a724ae Binary files /dev/null and b/mmseg/models/backbones/__pycache__/bisenetv2.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/cgnet.cpython-310.pyc b/mmseg/models/backbones/__pycache__/cgnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b5eff652c72f95e1ad2c1050e9d0d43e68d3f8d Binary files /dev/null and b/mmseg/models/backbones/__pycache__/cgnet.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/erfnet.cpython-310.pyc b/mmseg/models/backbones/__pycache__/erfnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c778cb196b99cc2524f17000ab9aa2c0692552bd Binary files /dev/null and b/mmseg/models/backbones/__pycache__/erfnet.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/fast_scnn.cpython-310.pyc b/mmseg/models/backbones/__pycache__/fast_scnn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec4ebb3562fa5411fa57a278985bef372a767c73 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/fast_scnn.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/hrnet.cpython-310.pyc 
b/mmseg/models/backbones/__pycache__/hrnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..166341037e8e035165d01a8dbb750275e307b6b2 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/hrnet.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/icnet.cpython-310.pyc b/mmseg/models/backbones/__pycache__/icnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4931c850a355c72c28d88678dd39dd8580b2fd19 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/icnet.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/mae.cpython-310.pyc b/mmseg/models/backbones/__pycache__/mae.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f53ce70dd2e8628d5e1ae5662c96aedd797972a7 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/mae.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/mit.cpython-310.pyc b/mmseg/models/backbones/__pycache__/mit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32a1561e973ea2e5b30ec88f534482c58eec7f77 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/mit.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/mobilenet_v2.cpython-310.pyc b/mmseg/models/backbones/__pycache__/mobilenet_v2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2ef2f6abe3fe09b535d45e008bc4b8fb960bae1 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/mobilenet_v2.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/mobilenet_v3.cpython-310.pyc b/mmseg/models/backbones/__pycache__/mobilenet_v3.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a392b0de71b259621e8779829b24aa6e7cf07b7 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/mobilenet_v3.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/mscan.cpython-310.pyc b/mmseg/models/backbones/__pycache__/mscan.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1bf565ad7420a75d186894133d05c25870060be Binary files /dev/null and b/mmseg/models/backbones/__pycache__/mscan.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/pidnet.cpython-310.pyc b/mmseg/models/backbones/__pycache__/pidnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..225a263a0689ec153f90df155558529c62b37edf Binary files /dev/null and b/mmseg/models/backbones/__pycache__/pidnet.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/resnest.cpython-310.pyc b/mmseg/models/backbones/__pycache__/resnest.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1983959721a694799f1a0575138b2d02cd034aa4 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/resnest.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/resnet.cpython-310.pyc b/mmseg/models/backbones/__pycache__/resnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9791383bb671a6faacb354553382cd40238bf570 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/resnet.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/resnext.cpython-310.pyc b/mmseg/models/backbones/__pycache__/resnext.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cf52b2b65aa9e48082c2a7dcbfbb10ed76d397b Binary files 
/dev/null and b/mmseg/models/backbones/__pycache__/resnext.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/stdc.cpython-310.pyc b/mmseg/models/backbones/__pycache__/stdc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f93a801740d6108cac704a0415ddab71abb0bf1b Binary files /dev/null and b/mmseg/models/backbones/__pycache__/stdc.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/swin.cpython-310.pyc b/mmseg/models/backbones/__pycache__/swin.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ea092eadde6650f7a83da97e778dd4f316d4bb7 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/swin.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/timm_backbone.cpython-310.pyc b/mmseg/models/backbones/__pycache__/timm_backbone.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48b9c07deb938720101ed3d1da0c58ddd44fc683 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/timm_backbone.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/twins.cpython-310.pyc b/mmseg/models/backbones/__pycache__/twins.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8db79d7bdb4fde561fe782e80353c63202bf060 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/twins.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/unet.cpython-310.pyc b/mmseg/models/backbones/__pycache__/unet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94b32405bfd31477b59b9624771abdf422a0eb9e Binary files /dev/null and b/mmseg/models/backbones/__pycache__/unet.cpython-310.pyc differ diff --git a/mmseg/models/backbones/__pycache__/vit.cpython-310.pyc b/mmseg/models/backbones/__pycache__/vit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f61495391373c69987a03e1954b74e5050744250 Binary files /dev/null and b/mmseg/models/backbones/__pycache__/vit.cpython-310.pyc differ diff --git a/mmseg/models/backbones/beit.py b/mmseg/models/backbones/beit.py new file mode 100644 index 0000000000000000000000000000000000000000..e5da71e729256a9dd12b70d32886c9db27d9fa3c --- /dev/null +++ b/mmseg/models/backbones/beit.py @@ -0,0 +1,554 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.drop import build_dropout +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import (constant_init, kaiming_init, + trunc_normal_) +from mmengine.runner.checkpoint import _load_checkpoint +from scipy import interpolate +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.modules.utils import _pair as to_2tuple + +from mmseg.registry import MODELS +from ..utils import PatchEmbed +from .vit import TransformerEncoderLayer as VisionTransformerEncoderLayer + + +class BEiTAttention(BaseModule): + """Window based multi-head self-attention (W-MSA) module with relative + position bias. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): The height and width of the window. + bias (bool): The option to add leanable bias for q, k, v. If bias is + True, it will add leanable bias. If bias is 'qv_bias', it will only + add leanable bias for q, v. 
If bias is False, it will not add bias + for q, k, v. Default to 'qv_bias'. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float): Dropout ratio of output. Default: 0. + init_cfg (dict | None, optional): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + bias='qv_bias', + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + init_cfg=None, + **kwargs): + super().__init__(init_cfg=init_cfg) + self.embed_dims = embed_dims + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.bias = bias + self.scale = qk_scale or head_embed_dims**-0.5 + + qkv_bias = bias + if bias == 'qv_bias': + self._init_qv_bias() + qkv_bias = False + + self.window_size = window_size + self._init_rel_pos_embedding() + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + + def _init_qv_bias(self): + self.q_bias = nn.Parameter(torch.zeros(self.embed_dims)) + self.v_bias = nn.Parameter(torch.zeros(self.embed_dims)) + + def _init_rel_pos_embedding(self): + Wh, Ww = self.window_size + # cls to token & token 2 cls & cls to cls + self.num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3 + # relative_position_bias_table shape is (2*Wh-1 * 2*Ww-1 + 3, nH) + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, self.num_heads)) + + # get pair-wise relative position index for + # each token inside the window + coords_h = torch.arange(Wh) + coords_w = torch.arange(Ww) + # coords shape is (2, Wh, Ww) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) + # coords_flatten shape is (2, Wh*Ww) + coords_flatten = torch.flatten(coords, 1) + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :]) + # relative_coords shape is (Wh*Ww, Wh*Ww, 2) + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + # shift to start from 0 + relative_coords[:, :, 0] += Wh - 1 + relative_coords[:, :, 1] += Ww - 1 + relative_coords[:, :, 0] *= 2 * Ww - 1 + relative_position_index = torch.zeros( + size=(Wh * Ww + 1, ) * 2, dtype=relative_coords.dtype) + # relative_position_index shape is (Wh*Ww, Wh*Ww) + relative_position_index[1:, 1:] = relative_coords.sum(-1) + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer('relative_position_index', + relative_position_index) + + def init_weights(self): + trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward(self, x): + """ + Args: + x (tensor): input features with shape of (num_windows*B, N, C). 
+ """ + B, N, C = x.shape + + if self.bias == 'qv_bias': + k_bias = torch.zeros_like(self.v_bias, requires_grad=False) + qkv_bias = torch.cat((self.q_bias, k_bias, self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + else: + qkv = self.qkv(x) + + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + if self.relative_position_bias_table is not None: + Wh = self.window_size[0] + Ww = self.window_size[1] + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + Wh * Ww + 1, Wh * Ww + 1, -1) + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class BEiTTransformerEncoderLayer(VisionTransformerEncoderLayer): + """Implements one encoder layer in Vision Transformer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + attn_drop_rate (float): The drop out rate for attention layer. + Default: 0.0. + drop_path_rate (float): Stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + bias (bool): The option to add leanable bias for q, k, v. If bias is + True, it will add leanable bias. If bias is 'qv_bias', it will only + add leanable bias for q, v. If bias is False, it will not add bias + for q, k, v. Default to 'qv_bias'. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + window_size (tuple[int], optional): The height and width of the window. + Default: None. + init_values (float, optional): Initialize the values of BEiTAttention + and FFN with learnable scaling. Default: None. 
+ """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + bias='qv_bias', + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + window_size=None, + attn_cfg=dict(), + ffn_cfg=dict(add_identity=False), + init_values=None): + attn_cfg.update(dict(window_size=window_size, qk_scale=None)) + + super().__init__( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=feedforward_channels, + attn_drop_rate=attn_drop_rate, + drop_path_rate=0., + drop_rate=0., + num_fcs=num_fcs, + qkv_bias=bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + attn_cfg=attn_cfg, + ffn_cfg=ffn_cfg) + + # NOTE: drop path for stochastic depth, we shall see if + # this is better than dropout here + dropout_layer = dict(type='DropPath', drop_prob=drop_path_rate) + self.drop_path = build_dropout( + dropout_layer) if dropout_layer else nn.Identity() + self.gamma_1 = nn.Parameter( + init_values * torch.ones(embed_dims), requires_grad=True) + self.gamma_2 = nn.Parameter( + init_values * torch.ones(embed_dims), requires_grad=True) + + def build_attn(self, attn_cfg): + self.attn = BEiTAttention(**attn_cfg) + + def forward(self, x): + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) + x = x + self.drop_path(self.gamma_2 * self.ffn(self.norm2(x))) + return x + + +@MODELS.register_module() +class BEiT(BaseModule): + """BERT Pre-Training of Image Transformers. + + Args: + img_size (int | tuple): Input image size. Default: 224. + patch_size (int): The patch size. Default: 16. + in_channels (int): Number of input channels. Default: 3. + embed_dims (int): Embedding dimension. Default: 768. + num_layers (int): Depth of transformer. Default: 12. + num_heads (int): Number of attention heads. Default: 12. + mlp_ratio (int): Ratio of mlp hidden dim to embedding dim. + Default: 4. + out_indices (list | tuple | int): Output from which stages. + Default: -1. + qv_bias (bool): Enable bias for qv if True. Default: True. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + drop_path_rate (float): Stochastic depth rate. Default 0.0. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + patch_norm (bool): Whether to add a norm in PatchEmbed Block. + Default: False. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Default: False. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + pretrained (str, optional): Model pretrained path. Default: None. + init_values (float): Initialize the values of BEiTAttention and FFN + with learnable scaling. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. 
+ """ + + def __init__(self, + img_size=224, + patch_size=16, + in_channels=3, + embed_dims=768, + num_layers=12, + num_heads=12, + mlp_ratio=4, + out_indices=-1, + qv_bias=True, + attn_drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + patch_norm=False, + final_norm=False, + num_fcs=2, + norm_eval=False, + pretrained=None, + init_values=0.1, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + if isinstance(img_size, int): + img_size = to_2tuple(img_size) + elif isinstance(img_size, tuple): + if len(img_size) == 1: + img_size = to_2tuple(img_size[0]) + assert len(img_size) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(img_size)}' + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + + self.in_channels = in_channels + self.img_size = img_size + self.patch_size = patch_size + self.norm_eval = norm_eval + self.pretrained = pretrained + self.num_layers = num_layers + self.embed_dims = embed_dims + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.attn_drop_rate = attn_drop_rate + self.drop_path_rate = drop_path_rate + self.num_fcs = num_fcs + self.qv_bias = qv_bias + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + self.patch_norm = patch_norm + self.init_values = init_values + self.window_size = (img_size[0] // patch_size, + img_size[1] // patch_size) + self.patch_shape = self.window_size + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + + self._build_patch_embedding() + self._build_layers() + + if isinstance(out_indices, int): + if out_indices == -1: + out_indices = num_layers - 1 + self.out_indices = [out_indices] + elif isinstance(out_indices, list) or isinstance(out_indices, tuple): + self.out_indices = out_indices + else: + raise TypeError('out_indices must be type of int, list or tuple') + + self.final_norm = final_norm + if final_norm: + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, embed_dims, postfix=1) + self.add_module(self.norm1_name, norm1) + + def _build_patch_embedding(self): + """Build patch embedding layer.""" + self.patch_embed = PatchEmbed( + in_channels=self.in_channels, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=self.patch_size, + stride=self.patch_size, + padding=0, + norm_cfg=self.norm_cfg if self.patch_norm else None, + init_cfg=None) + + def _build_layers(self): + """Build transformer encoding layers.""" + + dpr = [ + x.item() + for x in torch.linspace(0, self.drop_path_rate, self.num_layers) + ] + self.layers = ModuleList() + for i in range(self.num_layers): + self.layers.append( + BEiTTransformerEncoderLayer( + embed_dims=self.embed_dims, + num_heads=self.num_heads, + feedforward_channels=self.mlp_ratio * self.embed_dims, + attn_drop_rate=self.attn_drop_rate, + drop_path_rate=dpr[i], + num_fcs=self.num_fcs, + bias='qv_bias' if self.qv_bias else False, + act_cfg=self.act_cfg, + norm_cfg=self.norm_cfg, + window_size=self.window_size, + init_values=self.init_values)) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + def _geometric_sequence_interpolation(self, src_size, dst_size, sequence, + num): + """Get new sequence via geometric sequence interpolation. 
+ + Args: + src_size (int): Pos_embedding size in pre-trained model. + dst_size (int): Pos_embedding size in the current model. + sequence (tensor): The relative position bias of the pretrain + model after removing the extra tokens. + num (int): Number of attention heads. + Returns: + new_sequence (tensor): Geometric sequence interpolate the + pre-trained relative position bias to the size of + the current model. + """ + + def geometric_progression(a, r, n): + return a * (1.0 - r**n) / (1.0 - r) + + # Here is a binary function. + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src_size // 2) + if gp > dst_size // 2: + right = q + else: + left = q + # The position of each interpolated point is determined + # by the ratio obtained by dichotomy. + dis = [] + cur = 1 + for i in range(src_size // 2): + dis.append(cur) + cur += q**(i + 1) + r_ids = [-_ for _ in reversed(dis)] + x = r_ids + [0] + dis + y = r_ids + [0] + dis + t = dst_size // 2.0 + dx = np.arange(-t, t + 0.1, 1.0) + dy = np.arange(-t, t + 0.1, 1.0) + # Interpolation functions are being executed and called. + new_sequence = [] + for i in range(num): + z = sequence[:, i].view(src_size, src_size).float().numpy() + f = interpolate.interp2d(x, y, z, kind='cubic') + new_sequence.append( + torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(sequence)) + new_sequence = torch.cat(new_sequence, dim=-1) + return new_sequence + + def resize_rel_pos_embed(self, checkpoint): + """Resize relative pos_embed weights. + + This function is modified from + https://github.com/microsoft/unilm/blob/master/beit/semantic_segmentation/mmcv_custom/checkpoint.py. # noqa: E501 + Copyright (c) Microsoft Corporation + Licensed under the MIT License + Args: + checkpoint (dict): Key and value of the pretrain model. + Returns: + state_dict (dict): Interpolate the relative pos_embed weights + in the pre-train model to the current model size. + """ + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + all_keys = list(state_dict.keys()) + for key in all_keys: + if 'relative_position_index' in key: + state_dict.pop(key) + # In order to keep the center of pos_bias as consistent as + # possible after interpolation, and vice versa in the edge + # area, the geometric sequence interpolation method is adopted. + if 'relative_position_bias_table' in key: + rel_pos_bias = state_dict[key] + src_num_pos, num_attn_heads = rel_pos_bias.size() + dst_num_pos, _ = self.state_dict()[key].size() + dst_patch_shape = self.patch_shape + if dst_patch_shape[0] != dst_patch_shape[1]: + raise NotImplementedError() + # Count the number of extra tokens. 
+ num_extra_tokens = dst_num_pos - ( + dst_patch_shape[0] * 2 - 1) * ( + dst_patch_shape[1] * 2 - 1) + src_size = int((src_num_pos - num_extra_tokens)**0.5) + dst_size = int((dst_num_pos - num_extra_tokens)**0.5) + if src_size != dst_size: + extra_tokens = rel_pos_bias[-num_extra_tokens:, :] + rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] + new_rel_pos_bias = self._geometric_sequence_interpolation( + src_size, dst_size, rel_pos_bias, num_attn_heads) + new_rel_pos_bias = torch.cat( + (new_rel_pos_bias, extra_tokens), dim=0) + state_dict[key] = new_rel_pos_bias + + return state_dict + + def init_weights(self): + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + + if (isinstance(self.init_cfg, dict) + and self.init_cfg.get('type') == 'Pretrained'): + checkpoint = _load_checkpoint( + self.init_cfg['checkpoint'], logger=None, map_location='cpu') + state_dict = self.resize_rel_pos_embed(checkpoint) + self.load_state_dict(state_dict, False) + elif self.init_cfg is not None: + super().init_weights() + else: + # We only implement the 'jax_impl' initialization implemented at + # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 + # Copyright 2019 Ross Wightman + # Licensed under the Apache License, Version 2.0 (the "License") + trunc_normal_(self.cls_token, std=.02) + for n, m in self.named_modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + if 'ffn' in n: + nn.init.normal_(m.bias, mean=0., std=1e-6) + else: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + kaiming_init(m, mode='fan_in', bias=0.) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m, val=1.0, bias=0.) + + def forward(self, inputs): + B = inputs.shape[0] + + x, hw_shape = self.patch_embed(inputs) + + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i == len(self.layers) - 1: + if self.final_norm: + x = self.norm1(x) + if i in self.out_indices: + # Remove class token and reshape token for decoder head + out = x[:, 1:] + B, _, C = out.shape + out = out.reshape(B, hw_shape[0], hw_shape[1], + C).permute(0, 3, 1, 2).contiguous() + outs.append(out) + + return tuple(outs) + + def train(self, mode=True): + super().train(mode) + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.LayerNorm): + m.eval() diff --git a/mmseg/models/backbones/bisenetv1.py b/mmseg/models/backbones/bisenetv1.py new file mode 100644 index 0000000000000000000000000000000000000000..ca58bf9c597836937bc384739ff77001b5402942 --- /dev/null +++ b/mmseg/models/backbones/bisenetv1.py @@ -0,0 +1,332 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..utils import resize + + +class SpatialPath(BaseModule): + """Spatial Path to preserve the spatial size of the original input image + and encode affluent spatial information. + + Args: + in_channels(int): The number of channels of input + image. Default: 3. 
+ num_channels (Tuple[int]): The number of channels of + each layers in Spatial Path. + Default: (64, 64, 64, 128). + Returns: + x (torch.Tensor): Feature map for Feature Fusion Module. + """ + + def __init__(self, + in_channels=3, + num_channels=(64, 64, 64, 128), + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + assert len(num_channels) == 4, 'Length of input channels \ + of Spatial Path must be 4!' + + self.layers = [] + for i in range(len(num_channels)): + layer_name = f'layer{i + 1}' + self.layers.append(layer_name) + if i == 0: + self.add_module( + layer_name, + ConvModule( + in_channels=in_channels, + out_channels=num_channels[i], + kernel_size=7, + stride=2, + padding=3, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + elif i == len(num_channels) - 1: + self.add_module( + layer_name, + ConvModule( + in_channels=num_channels[i - 1], + out_channels=num_channels[i], + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + else: + self.add_module( + layer_name, + ConvModule( + in_channels=num_channels[i - 1], + out_channels=num_channels[i], + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, x): + for i, layer_name in enumerate(self.layers): + layer_stage = getattr(self, layer_name) + x = layer_stage(x) + return x + + +class AttentionRefinementModule(BaseModule): + """Attention Refinement Module (ARM) to refine the features of each stage. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + Returns: + x_out (torch.Tensor): Feature map of Attention Refinement Module. + """ + + def __init__(self, + in_channels, + out_channel, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.conv_layer = ConvModule( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.atten_conv_layer = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + ConvModule( + in_channels=out_channel, + out_channels=out_channel, + kernel_size=1, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), nn.Sigmoid()) + + def forward(self, x): + x = self.conv_layer(x) + x_atten = self.atten_conv_layer(x) + x_out = x * x_atten + return x_out + + +class ContextPath(BaseModule): + """Context Path to provide sufficient receptive field. + + Args: + backbone_cfg:(dict): Config of backbone of + Context Path. + context_channels (Tuple[int]): The number of channel numbers + of various modules in Context Path. + Default: (128, 256, 512). + align_corners (bool, optional): The align_corners argument of + resize operation. Default: False. + Returns: + x_16_up, x_32_up (torch.Tensor, torch.Tensor): Two feature maps + undergoing upsampling from 1/16 and 1/32 downsampling + feature maps. These two feature maps are used for Feature + Fusion Module and Auxiliary Head. + """ + + def __init__(self, + backbone_cfg, + context_channels=(128, 256, 512), + align_corners=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + assert len(context_channels) == 3, 'Length of input channels \ + of Context Path must be 3!' 
+ + self.backbone = MODELS.build(backbone_cfg) + + self.align_corners = align_corners + self.arm16 = AttentionRefinementModule(context_channels[1], + context_channels[0]) + self.arm32 = AttentionRefinementModule(context_channels[2], + context_channels[0]) + self.conv_head32 = ConvModule( + in_channels=context_channels[0], + out_channels=context_channels[0], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.conv_head16 = ConvModule( + in_channels=context_channels[0], + out_channels=context_channels[0], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.gap_conv = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + ConvModule( + in_channels=context_channels[2], + out_channels=context_channels[0], + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, x): + x_4, x_8, x_16, x_32 = self.backbone(x) + x_gap = self.gap_conv(x_32) + + x_32_arm = self.arm32(x_32) + x_32_sum = x_32_arm + x_gap + x_32_up = resize(input=x_32_sum, size=x_16.shape[2:], mode='nearest') + x_32_up = self.conv_head32(x_32_up) + + x_16_arm = self.arm16(x_16) + x_16_sum = x_16_arm + x_32_up + x_16_up = resize(input=x_16_sum, size=x_8.shape[2:], mode='nearest') + x_16_up = self.conv_head16(x_16_up) + + return x_16_up, x_32_up + + +class FeatureFusionModule(BaseModule): + """Feature Fusion Module to fuse low level output feature of Spatial Path + and high level output feature of Context Path. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + Returns: + x_out (torch.Tensor): Feature map of Feature Fusion Module. + """ + + def __init__(self, + in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.conv1 = ConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + self.conv_atten = nn.Sequential( + ConvModule( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), nn.Sigmoid()) + + def forward(self, x_sp, x_cp): + x_concat = torch.cat([x_sp, x_cp], dim=1) + x_fuse = self.conv1(x_concat) + x_atten = self.gap(x_fuse) + # Note: No BN and more 1x1 conv in paper. + x_atten = self.conv_atten(x_atten) + x_atten = x_fuse * x_atten + x_out = x_atten + x_fuse + return x_out + + +@MODELS.register_module() +class BiSeNetV1(BaseModule): + """BiSeNetV1 backbone. + + This backbone is the implementation of `BiSeNet: Bilateral + Segmentation Network for Real-time Semantic + Segmentation `_. + + Args: + backbone_cfg:(dict): Config of backbone of + Context Path. + in_channels (int): The number of channels of input + image. Default: 3. + spatial_channels (Tuple[int]): Size of channel numbers of + various layers in Spatial Path. + Default: (64, 64, 64, 128). + context_channels (Tuple[int]): Size of channel numbers of + various modules in Context Path. + Default: (128, 256, 512). + out_indices (Tuple[int] | int, optional): Output from which stages. + Default: (0, 1, 2). + align_corners (bool, optional): The align_corners argument of + resize operation in Bilateral Guided Aggregation Layer. + Default: False. 
+ out_channels(int): The number of channels of output. + It must be the same with `in_channels` of decode_head. + Default: 256. + """ + + def __init__(self, + backbone_cfg, + in_channels=3, + spatial_channels=(64, 64, 64, 128), + context_channels=(128, 256, 512), + out_indices=(0, 1, 2), + align_corners=False, + out_channels=256, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='ReLU'), + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + assert len(spatial_channels) == 4, 'Length of input channels \ + of Spatial Path must be 4!' + + assert len(context_channels) == 3, 'Length of input channels \ + of Context Path must be 3!' + + self.out_indices = out_indices + self.align_corners = align_corners + self.context_path = ContextPath(backbone_cfg, context_channels, + self.align_corners) + self.spatial_path = SpatialPath(in_channels, spatial_channels) + self.ffm = FeatureFusionModule(context_channels[1], out_channels) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + def forward(self, x): + # stole refactoring code from Coin Cheung, thanks + x_context8, x_context16 = self.context_path(x) + x_spatial = self.spatial_path(x) + x_fuse = self.ffm(x_spatial, x_context8) + + outs = [x_fuse, x_context8, x_context16] + outs = [outs[i] for i in self.out_indices] + return tuple(outs) diff --git a/mmseg/models/backbones/bisenetv2.py b/mmseg/models/backbones/bisenetv2.py new file mode 100644 index 0000000000000000000000000000000000000000..32aa49822f7d0c3bd4839b3796a15689e1f4cbc0 --- /dev/null +++ b/mmseg/models/backbones/bisenetv2.py @@ -0,0 +1,622 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, + build_activation_layer, build_norm_layer) +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..utils import resize + + +class DetailBranch(BaseModule): + """Detail Branch with wide channels and shallow layers to capture low-level + details and generate high-resolution feature representation. + + Args: + detail_channels (Tuple[int]): Size of channel numbers of each stage + in Detail Branch, in paper it has 3 stages. + Default: (64, 64, 128). + in_channels (int): Number of channels of input image. Default: 3. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + x (torch.Tensor): Feature map of Detail Branch. 
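+
+    Example:
+        A minimal illustrative sketch with the default channels; the three
+        stride-2 stages reduce the input to 1/8 resolution.
+        >>> import torch
+        >>> self = DetailBranch(detail_channels=(64, 64, 128), in_channels=3)
+        >>> x = torch.rand(1, 3, 64, 64)
+        >>> self.forward(x).shape
+        torch.Size([1, 128, 8, 8])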
+ """ + + def __init__(self, + detail_channels=(64, 64, 128), + in_channels=3, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + detail_branch = [] + for i in range(len(detail_channels)): + if i == 0: + detail_branch.append( + nn.Sequential( + ConvModule( + in_channels=in_channels, + out_channels=detail_channels[i], + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=detail_channels[i], + out_channels=detail_channels[i], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg))) + else: + detail_branch.append( + nn.Sequential( + ConvModule( + in_channels=detail_channels[i - 1], + out_channels=detail_channels[i], + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=detail_channels[i], + out_channels=detail_channels[i], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=detail_channels[i], + out_channels=detail_channels[i], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg))) + self.detail_branch = nn.ModuleList(detail_branch) + + def forward(self, x): + for stage in self.detail_branch: + x = stage(x) + return x + + +class StemBlock(BaseModule): + """Stem Block at the beginning of Semantic Branch. + + Args: + in_channels (int): Number of input channels. + Default: 3. + out_channels (int): Number of output channels. + Default: 16. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + x (torch.Tensor): First feature map in Semantic Branch. + """ + + def __init__(self, + in_channels=3, + out_channels=16, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.conv_first = ConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.convs = nn.Sequential( + ConvModule( + in_channels=out_channels, + out_channels=out_channels // 2, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=out_channels // 2, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.pool = nn.MaxPool2d( + kernel_size=3, stride=2, padding=1, ceil_mode=False) + self.fuse_last = ConvModule( + in_channels=out_channels * 2, + out_channels=out_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x): + x = self.conv_first(x) + x_left = self.convs(x) + x_right = self.pool(x) + x = self.fuse_last(torch.cat([x_left, x_right], dim=1)) + return x + + +class GELayer(BaseModule): + """Gather-and-Expansion Layer. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + exp_ratio (int): Expansion ratio for middle channels. + Default: 6. + stride (int): Stride of GELayer. 
Default: 1 + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + x (torch.Tensor): Intermediate feature map in + Semantic Branch. + """ + + def __init__(self, + in_channels, + out_channels, + exp_ratio=6, + stride=1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + mid_channel = in_channels * exp_ratio + self.conv1 = ConvModule( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + if stride == 1: + self.dwconv = nn.Sequential( + # ReLU in ConvModule not shown in paper + ConvModule( + in_channels=in_channels, + out_channels=mid_channel, + kernel_size=3, + stride=stride, + padding=1, + groups=in_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.shortcut = None + else: + self.dwconv = nn.Sequential( + ConvModule( + in_channels=in_channels, + out_channels=mid_channel, + kernel_size=3, + stride=stride, + padding=1, + groups=in_channels, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + # ReLU in ConvModule not shown in paper + ConvModule( + in_channels=mid_channel, + out_channels=mid_channel, + kernel_size=3, + stride=1, + padding=1, + groups=mid_channel, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ) + self.shortcut = nn.Sequential( + DepthwiseSeparableConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + padding=1, + dw_norm_cfg=norm_cfg, + dw_act_cfg=None, + pw_norm_cfg=norm_cfg, + pw_act_cfg=None, + )) + + self.conv2 = nn.Sequential( + ConvModule( + in_channels=mid_channel, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None, + )) + + self.act = build_activation_layer(act_cfg) + + def forward(self, x): + identity = x + x = self.conv1(x) + x = self.dwconv(x) + x = self.conv2(x) + if self.shortcut is not None: + shortcut = self.shortcut(identity) + x = x + shortcut + else: + x = x + identity + x = self.act(x) + return x + + +class CEBlock(BaseModule): + """Context Embedding Block for large receptive filed in Semantic Branch. + + Args: + in_channels (int): Number of input channels. + Default: 3. + out_channels (int): Number of output channels. + Default: 16. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + x (torch.Tensor): Last feature map in Semantic Branch. 
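+
+    Example:
+        A minimal illustrative sketch; the block keeps the spatial size and
+        adds a broadcast global-context vector to the input feature. eval()
+        is used because the pooled 1x1 map cannot provide batch statistics
+        in training mode.
+        >>> import torch
+        >>> self = CEBlock(in_channels=128, out_channels=128).eval()
+        >>> x = torch.rand(1, 128, 8, 8)
+        >>> self.forward(x).shape
+        torch.Size([1, 128, 8, 8])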
+ """ + + def __init__(self, + in_channels=3, + out_channels=16, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + self.gap = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + build_norm_layer(norm_cfg, self.in_channels)[1]) + self.conv_gap = ConvModule( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + # Note: in paper here is naive conv2d, no bn-relu + self.conv_last = ConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x): + identity = x + x = self.gap(x) + x = self.conv_gap(x) + x = identity + x + x = self.conv_last(x) + return x + + +class SemanticBranch(BaseModule): + """Semantic Branch which is lightweight with narrow channels and deep + layers to obtain high-level semantic context. + + Args: + semantic_channels(Tuple[int]): Size of channel numbers of + various stages in Semantic Branch. + Default: (16, 32, 64, 128). + in_channels (int): Number of channels of input image. Default: 3. + exp_ratio (int): Expansion ratio for middle channels. + Default: 6. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + semantic_outs (List[torch.Tensor]): List of several feature maps + for auxiliary heads (Booster) and Bilateral + Guided Aggregation Layer. + """ + + def __init__(self, + semantic_channels=(16, 32, 64, 128), + in_channels=3, + exp_ratio=6, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.semantic_channels = semantic_channels + self.semantic_stages = [] + for i in range(len(semantic_channels)): + stage_name = f'stage{i + 1}' + self.semantic_stages.append(stage_name) + if i == 0: + self.add_module( + stage_name, + StemBlock(self.in_channels, semantic_channels[i])) + elif i == (len(semantic_channels) - 1): + self.add_module( + stage_name, + nn.Sequential( + GELayer(semantic_channels[i - 1], semantic_channels[i], + exp_ratio, 2), + GELayer(semantic_channels[i], semantic_channels[i], + exp_ratio, 1), + GELayer(semantic_channels[i], semantic_channels[i], + exp_ratio, 1), + GELayer(semantic_channels[i], semantic_channels[i], + exp_ratio, 1))) + else: + self.add_module( + stage_name, + nn.Sequential( + GELayer(semantic_channels[i - 1], semantic_channels[i], + exp_ratio, 2), + GELayer(semantic_channels[i], semantic_channels[i], + exp_ratio, 1))) + + self.add_module(f'stage{len(semantic_channels)}_CEBlock', + CEBlock(semantic_channels[-1], semantic_channels[-1])) + self.semantic_stages.append(f'stage{len(semantic_channels)}_CEBlock') + + def forward(self, x): + semantic_outs = [] + for stage_name in self.semantic_stages: + semantic_stage = getattr(self, stage_name) + x = semantic_stage(x) + semantic_outs.append(x) + return semantic_outs + + +class BGALayer(BaseModule): + """Bilateral Guided Aggregation Layer to fuse the complementary information + from both Detail Branch and Semantic Branch. + + Args: + out_channels (int): Number of output channels. + Default: 128. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). 
+ act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + output (torch.Tensor): Output feature map for Segment heads. + """ + + def __init__(self, + out_channels=128, + align_corners=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.out_channels = out_channels + self.align_corners = align_corners + self.detail_dwconv = nn.Sequential( + DepthwiseSeparableConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + dw_norm_cfg=norm_cfg, + dw_act_cfg=None, + pw_norm_cfg=None, + pw_act_cfg=None, + )) + self.detail_down = nn.Sequential( + ConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)) + self.semantic_conv = nn.Sequential( + ConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None)) + self.semantic_dwconv = nn.Sequential( + DepthwiseSeparableConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + dw_norm_cfg=norm_cfg, + dw_act_cfg=None, + pw_norm_cfg=None, + pw_act_cfg=None, + )) + self.conv = ConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + inplace=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + + def forward(self, x_d, x_s): + detail_dwconv = self.detail_dwconv(x_d) + detail_down = self.detail_down(x_d) + semantic_conv = self.semantic_conv(x_s) + semantic_dwconv = self.semantic_dwconv(x_s) + semantic_conv = resize( + input=semantic_conv, + size=detail_dwconv.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + fuse_1 = detail_dwconv * torch.sigmoid(semantic_conv) + fuse_2 = detail_down * torch.sigmoid(semantic_dwconv) + fuse_2 = resize( + input=fuse_2, + size=fuse_1.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + output = self.conv(fuse_1 + fuse_2) + return output + + +@MODELS.register_module() +class BiSeNetV2(BaseModule): + """BiSeNetV2: Bilateral Network with Guided Aggregation for + Real-time Semantic Segmentation. + + This backbone is the implementation of + `BiSeNetV2 `_. + + Args: + in_channels (int): Number of channel of input image. Default: 3. + detail_channels (Tuple[int], optional): Channels of each stage + in Detail Branch. Default: (64, 64, 128). + semantic_channels (Tuple[int], optional): Channels of each stage + in Semantic Branch. Default: (16, 32, 64, 128). + See Table 1 and Figure 3 of paper for more details. + semantic_expansion_ratio (int, optional): The expansion factor + expanding channel number of middle channels in Semantic Branch. + Default: 6. + bga_channels (int, optional): Number of middle channels in + Bilateral Guided Aggregation Layer. Default: 128. + out_indices (Tuple[int] | int, optional): Output from which stages. + Default: (0, 1, 2, 3, 4). + align_corners (bool, optional): The align_corners argument of + resize operation in Bilateral Guided Aggregation Layer. + Default: False. + conv_cfg (dict | None): Config of conv layers. + Default: None. 
+ norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels=3, + detail_channels=(64, 64, 128), + semantic_channels=(16, 32, 64, 128), + semantic_expansion_ratio=6, + bga_channels=128, + out_indices=(0, 1, 2, 3, 4), + align_corners=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + if init_cfg is None: + init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ] + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_indices = out_indices + self.detail_channels = detail_channels + self.semantic_channels = semantic_channels + self.semantic_expansion_ratio = semantic_expansion_ratio + self.bga_channels = bga_channels + self.align_corners = align_corners + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.detail = DetailBranch(self.detail_channels, self.in_channels) + self.semantic = SemanticBranch(self.semantic_channels, + self.in_channels, + self.semantic_expansion_ratio) + self.bga = BGALayer(self.bga_channels, self.align_corners) + + def forward(self, x): + # stole refactoring code from Coin Cheung, thanks + x_detail = self.detail(x) + x_semantic_lst = self.semantic(x) + x_head = self.bga(x_detail, x_semantic_lst[-1]) + outs = [x_head] + x_semantic_lst[:-1] + outs = [outs[i] for i in self.out_indices] + return tuple(outs) diff --git a/mmseg/models/backbones/cgnet.py b/mmseg/models/backbones/cgnet.py new file mode 100644 index 0000000000000000000000000000000000000000..b74b494f53466d1c608e50d088632aa952a5e534 --- /dev/null +++ b/mmseg/models/backbones/cgnet.py @@ -0,0 +1,372 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer +from mmengine.model import BaseModule +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmseg.registry import MODELS + + +class GlobalContextExtractor(nn.Module): + """Global Context Extractor for CGNet. + + This class is employed to refine the joint feature of both local feature + and surrounding context. + + Args: + channel (int): Number of input feature channels. + reduction (int): Reductions for global context extractor. Default: 16. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + def __init__(self, channel, reduction=16, with_cp=False): + super().__init__() + self.channel = channel + self.reduction = reduction + assert reduction >= 1 and channel >= reduction + self.with_cp = with_cp + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), nn.Sigmoid()) + + def forward(self, x): + + def _inner_forward(x): + num_batch, num_channel = x.size()[:2] + y = self.avg_pool(x).view(num_batch, num_channel) + y = self.fc(y).view(num_batch, num_channel, 1, 1) + return x * y + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class ContextGuidedBlock(nn.Module): + """Context Guided Block for CGNet. 
+ + This class consists of four components: local feature extractor, + surrounding feature extractor, joint feature extractor and global + context extractor. + + Args: + in_channels (int): Number of input feature channels. + out_channels (int): Number of output feature channels. + dilation (int): Dilation rate for surrounding context extractor. + Default: 2. + reduction (int): Reduction for global context extractor. Default: 16. + skip_connect (bool): Add input to output or not. Default: True. + downsample (bool): Downsample the input to 1/2 or not. Default: False. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='PReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + in_channels, + out_channels, + dilation=2, + reduction=16, + skip_connect=True, + downsample=False, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='PReLU'), + with_cp=False): + super().__init__() + self.with_cp = with_cp + self.downsample = downsample + + channels = out_channels if downsample else out_channels // 2 + if 'type' in act_cfg and act_cfg['type'] == 'PReLU': + act_cfg['num_parameters'] = channels + kernel_size = 3 if downsample else 1 + stride = 2 if downsample else 1 + padding = (kernel_size - 1) // 2 + + self.conv1x1 = ConvModule( + in_channels, + channels, + kernel_size, + stride, + padding, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.f_loc = build_conv_layer( + conv_cfg, + channels, + channels, + kernel_size=3, + padding=1, + groups=channels, + bias=False) + self.f_sur = build_conv_layer( + conv_cfg, + channels, + channels, + kernel_size=3, + padding=dilation, + groups=channels, + dilation=dilation, + bias=False) + + self.bn = build_norm_layer(norm_cfg, 2 * channels)[1] + self.activate = nn.PReLU(2 * channels) + + if downsample: + self.bottleneck = build_conv_layer( + conv_cfg, + 2 * channels, + out_channels, + kernel_size=1, + bias=False) + + self.skip_connect = skip_connect and not downsample + self.f_glo = GlobalContextExtractor(out_channels, reduction, with_cp) + + def forward(self, x): + + def _inner_forward(x): + out = self.conv1x1(x) + loc = self.f_loc(out) + sur = self.f_sur(out) + + joi_feat = torch.cat([loc, sur], 1) # the joint feature + joi_feat = self.bn(joi_feat) + joi_feat = self.activate(joi_feat) + if self.downsample: + joi_feat = self.bottleneck(joi_feat) # channel = out_channels + # f_glo is employed to refine the joint feature + out = self.f_glo(joi_feat) + + if self.skip_connect: + return x + out + else: + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class InputInjection(nn.Module): + """Downsampling module for CGNet.""" + + def __init__(self, num_downsampling): + super().__init__() + self.pool = nn.ModuleList() + for i in range(num_downsampling): + self.pool.append(nn.AvgPool2d(3, stride=2, padding=1)) + + def forward(self, x): + for pool in self.pool: + x = pool(x) + return x + + +@MODELS.register_module() +class CGNet(BaseModule): + """CGNet backbone. + + This backbone is the implementation of `A Light-weight Context Guided + Network for Semantic Segmentation `_. 
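+
+    In the forward pass the backbone returns three feature maps at 1/2, 1/4
+    and 1/8 of the input resolution: the stem output concatenated with a
+    2x-downsampled copy of the input, the stage 1 output concatenated with
+    its entry feature and a 4x-downsampled input, and the stage 2 output
+    concatenated with its entry feature.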
+ + Args: + in_channels (int): Number of input image channels. Normally 3. + num_channels (tuple[int]): Numbers of feature channels at each stages. + Default: (32, 64, 128). + num_blocks (tuple[int]): Numbers of CG blocks at stage 1 and stage 2. + Default: (3, 21). + dilations (tuple[int]): Dilation rate for surrounding context + extractors at stage 1 and stage 2. Default: (2, 4). + reductions (tuple[int]): Reductions for global context extractors at + stage 1 and stage 2. Default: (8, 16). + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='PReLU'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels=3, + num_channels=(32, 64, 128), + num_blocks=(3, 21), + dilations=(2, 4), + reductions=(8, 16), + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='PReLU'), + norm_eval=False, + with_cp=False, + pretrained=None, + init_cfg=None): + + super().__init__(init_cfg) + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer=['Conv2d', 'Linear']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']), + dict(type='Constant', val=0, layer='PReLU') + ] + else: + raise TypeError('pretrained must be a str or None') + + self.in_channels = in_channels + self.num_channels = num_channels + assert isinstance(self.num_channels, tuple) and len( + self.num_channels) == 3 + self.num_blocks = num_blocks + assert isinstance(self.num_blocks, tuple) and len(self.num_blocks) == 2 + self.dilations = dilations + assert isinstance(self.dilations, tuple) and len(self.dilations) == 2 + self.reductions = reductions + assert isinstance(self.reductions, tuple) and len(self.reductions) == 2 + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + if 'type' in self.act_cfg and self.act_cfg['type'] == 'PReLU': + self.act_cfg['num_parameters'] = num_channels[0] + self.norm_eval = norm_eval + self.with_cp = with_cp + + cur_channels = in_channels + self.stem = nn.ModuleList() + for i in range(3): + self.stem.append( + ConvModule( + cur_channels, + num_channels[0], + 3, + 2 if i == 0 else 1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + cur_channels = num_channels[0] + + self.inject_2x = InputInjection(1) # down-sample for Input, factor=2 + self.inject_4x = InputInjection(2) # down-sample for Input, factor=4 + + cur_channels += in_channels + self.norm_prelu_0 = nn.Sequential( + build_norm_layer(norm_cfg, cur_channels)[1], + nn.PReLU(cur_channels)) + + # stage 1 + self.level1 = nn.ModuleList() + for i in 
range(num_blocks[0]): + self.level1.append( + ContextGuidedBlock( + cur_channels if i == 0 else num_channels[1], + num_channels[1], + dilations[0], + reductions[0], + downsample=(i == 0), + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + with_cp=with_cp)) # CG block + + cur_channels = 2 * num_channels[1] + in_channels + self.norm_prelu_1 = nn.Sequential( + build_norm_layer(norm_cfg, cur_channels)[1], + nn.PReLU(cur_channels)) + + # stage 2 + self.level2 = nn.ModuleList() + for i in range(num_blocks[1]): + self.level2.append( + ContextGuidedBlock( + cur_channels if i == 0 else num_channels[2], + num_channels[2], + dilations[1], + reductions[1], + downsample=(i == 0), + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + with_cp=with_cp)) # CG block + + cur_channels = 2 * num_channels[2] + self.norm_prelu_2 = nn.Sequential( + build_norm_layer(norm_cfg, cur_channels)[1], + nn.PReLU(cur_channels)) + + def forward(self, x): + output = [] + + # stage 0 + inp_2x = self.inject_2x(x) + inp_4x = self.inject_4x(x) + for layer in self.stem: + x = layer(x) + x = self.norm_prelu_0(torch.cat([x, inp_2x], 1)) + output.append(x) + + # stage 1 + for i, layer in enumerate(self.level1): + x = layer(x) + if i == 0: + down1 = x + x = self.norm_prelu_1(torch.cat([x, down1, inp_4x], 1)) + output.append(x) + + # stage 2 + for i, layer in enumerate(self.level2): + x = layer(x) + if i == 0: + down2 = x + x = self.norm_prelu_2(torch.cat([down2, x], 1)) + output.append(x) + + return output + + def train(self, mode=True): + """Convert the model into training mode will keeping the normalization + layer freezed.""" + super().train(mode) + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmseg/models/backbones/erfnet.py b/mmseg/models/backbones/erfnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2c5ec672a086b5d67568514140023ce402eef92f --- /dev/null +++ b/mmseg/models/backbones/erfnet.py @@ -0,0 +1,329 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..utils import resize + + +class DownsamplerBlock(BaseModule): + """Downsampler block of ERFNet. + + This module is a little different from basical ConvModule. + The features from Conv and MaxPool layers are + concatenated before BatchNorm. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. 
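+
+    Example:
+        A minimal illustrative sketch: the conv branch provides
+        ``out_channels - in_channels`` feature channels and the max-pool
+        branch the remaining ``in_channels``, both at half resolution.
+        >>> import torch
+        >>> self = DownsamplerBlock(in_channels=3, out_channels=16)
+        >>> x = torch.rand(1, 3, 64, 64)
+        >>> self.forward(x).shape
+        torch.Size([1, 16, 32, 32])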
+ """ + + def __init__(self, + in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-3), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.conv = build_conv_layer( + self.conv_cfg, + in_channels, + out_channels - in_channels, + kernel_size=3, + stride=2, + padding=1) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.bn = build_norm_layer(self.norm_cfg, out_channels)[1] + self.act = build_activation_layer(self.act_cfg) + + def forward(self, input): + conv_out = self.conv(input) + pool_out = self.pool(input) + pool_out = resize( + input=pool_out, + size=conv_out.size()[2:], + mode='bilinear', + align_corners=False) + output = torch.cat([conv_out, pool_out], 1) + output = self.bn(output) + output = self.act(output) + return output + + +class NonBottleneck1d(BaseModule): + """Non-bottleneck block of ERFNet. + + Args: + channels (int): Number of channels in Non-bottleneck block. + drop_rate (float): Probability of an element to be zeroed. + Default 0. + dilation (int): Dilation rate for last two conv layers. + Default 1. + num_conv_layer (int): Number of 3x1 and 1x3 convolution layers. + Default 2. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + channels, + drop_rate=0, + dilation=1, + num_conv_layer=2, + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-3), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.act = build_activation_layer(self.act_cfg) + + self.convs_layers = nn.ModuleList() + for conv_layer in range(num_conv_layer): + first_conv_padding = (1, 0) if conv_layer == 0 else (dilation, 0) + first_conv_dilation = 1 if conv_layer == 0 else (dilation, 1) + second_conv_padding = (0, 1) if conv_layer == 0 else (0, dilation) + second_conv_dilation = 1 if conv_layer == 0 else (1, dilation) + + self.convs_layers.append( + build_conv_layer( + self.conv_cfg, + channels, + channels, + kernel_size=(3, 1), + stride=1, + padding=first_conv_padding, + bias=True, + dilation=first_conv_dilation)) + self.convs_layers.append(self.act) + self.convs_layers.append( + build_conv_layer( + self.conv_cfg, + channels, + channels, + kernel_size=(1, 3), + stride=1, + padding=second_conv_padding, + bias=True, + dilation=second_conv_dilation)) + self.convs_layers.append( + build_norm_layer(self.norm_cfg, channels)[1]) + if conv_layer == 0: + self.convs_layers.append(self.act) + else: + self.convs_layers.append(nn.Dropout(p=drop_rate)) + + def forward(self, input): + output = input + for conv in self.convs_layers: + output = conv(output) + output = self.act(output + input) + return output + + +class UpsamplerBlock(BaseModule): + """Upsampler block of ERFNet. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-3), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.conv = nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + bias=True) + self.bn = build_norm_layer(self.norm_cfg, out_channels)[1] + self.act = build_activation_layer(self.act_cfg) + + def forward(self, input): + output = self.conv(input) + output = self.bn(output) + output = self.act(output) + return output + + +@MODELS.register_module() +class ERFNet(BaseModule): + """ERFNet backbone. + + This backbone is the implementation of `ERFNet: Efficient Residual + Factorized ConvNet for Real-time SemanticSegmentation + `_. + + Args: + in_channels (int): The number of channels of input + image. Default: 3. + enc_downsample_channels (Tuple[int]): Size of channel + numbers of various Downsampler block in encoder. + Default: (16, 64, 128). + enc_stage_non_bottlenecks (Tuple[int]): Number of stages of + Non-bottleneck block in encoder. + Default: (5, 8). + enc_non_bottleneck_dilations (Tuple[int]): Dilation rate of each + stage of Non-bottleneck block of encoder. + Default: (2, 4, 8, 16). + enc_non_bottleneck_channels (Tuple[int]): Size of channel + numbers of various Non-bottleneck block in encoder. + Default: (64, 128). + dec_upsample_channels (Tuple[int]): Size of channel numbers of + various Deconvolution block in decoder. + Default: (64, 16). + dec_stages_non_bottleneck (Tuple[int]): Number of stages of + Non-bottleneck block in decoder. + Default: (2, 2). + dec_non_bottleneck_channels (Tuple[int]): Size of channel + numbers of various Non-bottleneck block in decoder. + Default: (64, 16). + drop_rate (float): Probability of an element to be zeroed. + Default 0.1. + """ + + def __init__(self, + in_channels=3, + enc_downsample_channels=(16, 64, 128), + enc_stage_non_bottlenecks=(5, 8), + enc_non_bottleneck_dilations=(2, 4, 8, 16), + enc_non_bottleneck_channels=(64, 128), + dec_upsample_channels=(64, 16), + dec_stages_non_bottleneck=(2, 2), + dec_non_bottleneck_channels=(64, 16), + dropout_ratio=0.1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='ReLU'), + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + assert len(enc_downsample_channels) \ + == len(dec_upsample_channels)+1, 'Number of downsample\ + block of encoder does not \ + match number of upsample block of decoder!' + assert len(enc_downsample_channels) \ + == len(enc_stage_non_bottlenecks)+1, 'Number of \ + downsample block of encoder does not match \ + number of Non-bottleneck block of encoder!' + assert len(enc_downsample_channels) \ + == len(enc_non_bottleneck_channels)+1, 'Number of \ + downsample block of encoder does not match \ + number of channels of Non-bottleneck block of encoder!' + assert enc_stage_non_bottlenecks[-1] \ + % len(enc_non_bottleneck_dilations) == 0, 'Number of \ + Non-bottleneck block of encoder does not match \ + number of Non-bottleneck block of encoder!' + assert len(dec_upsample_channels) \ + == len(dec_stages_non_bottleneck), 'Number of \ + upsample block of decoder does not match \ + number of Non-bottleneck block of decoder!' 
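+        # Sanity note: with the default arguments every consistency check
+        # above (and the final one below) holds, e.g. 3 downsample channels
+        # vs. 2 upsample channels, 2 encoder stages, 2 encoder channel
+        # entries, and 8 dilated blocks divisible by the 4 dilation rates.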
+ assert len(dec_stages_non_bottleneck) \ + == len(dec_non_bottleneck_channels), 'Number of \ + Non-bottleneck block of decoder does not match \ + number of channels of Non-bottleneck block of decoder!' + + self.in_channels = in_channels + self.enc_downsample_channels = enc_downsample_channels + self.enc_stage_non_bottlenecks = enc_stage_non_bottlenecks + self.enc_non_bottleneck_dilations = enc_non_bottleneck_dilations + self.enc_non_bottleneck_channels = enc_non_bottleneck_channels + self.dec_upsample_channels = dec_upsample_channels + self.dec_stages_non_bottleneck = dec_stages_non_bottleneck + self.dec_non_bottleneck_channels = dec_non_bottleneck_channels + self.dropout_ratio = dropout_ratio + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.encoder.append( + DownsamplerBlock(self.in_channels, enc_downsample_channels[0])) + + for i in range(len(enc_downsample_channels) - 1): + self.encoder.append( + DownsamplerBlock(enc_downsample_channels[i], + enc_downsample_channels[i + 1])) + # Last part of encoder is some dilated NonBottleneck1d blocks. + if i == len(enc_downsample_channels) - 2: + iteration_times = int(enc_stage_non_bottlenecks[-1] / + len(enc_non_bottleneck_dilations)) + for j in range(iteration_times): + for k in range(len(enc_non_bottleneck_dilations)): + self.encoder.append( + NonBottleneck1d(enc_downsample_channels[-1], + self.dropout_ratio, + enc_non_bottleneck_dilations[k])) + else: + for j in range(enc_stage_non_bottlenecks[i]): + self.encoder.append( + NonBottleneck1d(enc_downsample_channels[i + 1], + self.dropout_ratio)) + + for i in range(len(dec_upsample_channels)): + if i == 0: + self.decoder.append( + UpsamplerBlock(enc_downsample_channels[-1], + dec_non_bottleneck_channels[i])) + else: + self.decoder.append( + UpsamplerBlock(dec_non_bottleneck_channels[i - 1], + dec_non_bottleneck_channels[i])) + for j in range(dec_stages_non_bottleneck[i]): + self.decoder.append( + NonBottleneck1d(dec_non_bottleneck_channels[i])) + + def forward(self, x): + for enc in self.encoder: + x = enc(x) + for dec in self.decoder: + x = dec(x) + return [x] diff --git a/mmseg/models/backbones/fast_scnn.py b/mmseg/models/backbones/fast_scnn.py new file mode 100644 index 0000000000000000000000000000000000000000..6ff7a3191d2fee904c5200e0a526214a65f58b32 --- /dev/null +++ b/mmseg/models/backbones/fast_scnn.py @@ -0,0 +1,408 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule +from mmengine.model import BaseModule + +from mmseg.models.decode_heads.psp_head import PPM +from mmseg.registry import MODELS +from ..utils import InvertedResidual, resize + + +class LearningToDownsample(nn.Module): + """Learning to downsample module. + + Args: + in_channels (int): Number of input channels. + dw_channels (tuple[int]): Number of output channels of the first and + the second depthwise conv (dwconv) layers. + out_channels (int): Number of output channels of the whole + 'learning to downsample' module. + conv_cfg (dict | None): Config of conv layers. Default: None + norm_cfg (dict | None): Config of norm layers. Default: + dict(type='BN') + act_cfg (dict): Config of activation layers. Default: + dict(type='ReLU') + dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config + of depthwise ConvModule. If it is 'default', it will be the same + as `act_cfg`. Default: None. 
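+
+    Example:
+        A minimal illustrative sketch with the Fast-SCNN defaults; the three
+        stride-2 layers yield a 1/8 resolution feature map.
+        >>> import torch
+        >>> self = LearningToDownsample(3, (32, 48), 64)
+        >>> x = torch.rand(1, 3, 512, 1024)
+        >>> self.forward(x).shape
+        torch.Size([1, 64, 64, 128])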
+ """ + + def __init__(self, + in_channels, + dw_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + dw_act_cfg=None): + super().__init__() + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.dw_act_cfg = dw_act_cfg + dw_channels1 = dw_channels[0] + dw_channels2 = dw_channels[1] + + self.conv = ConvModule( + in_channels, + dw_channels1, + 3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.dsconv1 = DepthwiseSeparableConvModule( + dw_channels1, + dw_channels2, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + dw_act_cfg=self.dw_act_cfg) + + self.dsconv2 = DepthwiseSeparableConvModule( + dw_channels2, + out_channels, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + dw_act_cfg=self.dw_act_cfg) + + def forward(self, x): + x = self.conv(x) + x = self.dsconv1(x) + x = self.dsconv2(x) + return x + + +class GlobalFeatureExtractor(nn.Module): + """Global feature extractor module. + + Args: + in_channels (int): Number of input channels of the GFE module. + Default: 64 + block_channels (tuple[int]): Tuple of ints. Each int specifies the + number of output channels of each Inverted Residual module. + Default: (64, 96, 128) + out_channels(int): Number of output channels of the GFE module. + Default: 128 + expand_ratio (int): Adjusts number of channels of the hidden layer + in InvertedResidual by this amount. + Default: 6 + num_blocks (tuple[int]): Tuple of ints. Each int specifies the + number of times each Inverted Residual module is repeated. + The repeated Inverted Residual modules are called a 'group'. + Default: (3, 3, 3) + strides (tuple[int]): Tuple of ints. Each int specifies + the downsampling factor of each 'group'. + Default: (2, 2, 1) + pool_scales (tuple[int]): Tuple of ints. Each int specifies + the parameter required in 'global average pooling' within PPM. + Default: (1, 2, 3, 6) + conv_cfg (dict | None): Config of conv layers. Default: None + norm_cfg (dict | None): Config of norm layers. Default: + dict(type='BN') + act_cfg (dict): Config of activation layers. Default: + dict(type='ReLU') + align_corners (bool): align_corners argument of F.interpolate. 
+ Default: False + """ + + def __init__(self, + in_channels=64, + block_channels=(64, 96, 128), + out_channels=128, + expand_ratio=6, + num_blocks=(3, 3, 3), + strides=(2, 2, 1), + pool_scales=(1, 2, 3, 6), + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + align_corners=False): + super().__init__() + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + assert len(block_channels) == len(num_blocks) == 3 + self.bottleneck1 = self._make_layer(in_channels, block_channels[0], + num_blocks[0], strides[0], + expand_ratio) + self.bottleneck2 = self._make_layer(block_channels[0], + block_channels[1], num_blocks[1], + strides[1], expand_ratio) + self.bottleneck3 = self._make_layer(block_channels[1], + block_channels[2], num_blocks[2], + strides[2], expand_ratio) + self.ppm = PPM( + pool_scales, + block_channels[2], + block_channels[2] // 4, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=align_corners) + + self.out = ConvModule( + block_channels[2] * 2, + out_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def _make_layer(self, + in_channels, + out_channels, + blocks, + stride=1, + expand_ratio=6): + layers = [ + InvertedResidual( + in_channels, + out_channels, + stride, + expand_ratio, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + ] + for i in range(1, blocks): + layers.append( + InvertedResidual( + out_channels, + out_channels, + 1, + expand_ratio, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + return nn.Sequential(*layers) + + def forward(self, x): + x = self.bottleneck1(x) + x = self.bottleneck2(x) + x = self.bottleneck3(x) + x = torch.cat([x, *self.ppm(x)], dim=1) + x = self.out(x) + return x + + +class FeatureFusionModule(nn.Module): + """Feature fusion module. + + Args: + higher_in_channels (int): Number of input channels of the + higher-resolution branch. + lower_in_channels (int): Number of input channels of the + lower-resolution branch. + out_channels (int): Number of output channels. + conv_cfg (dict | None): Config of conv layers. Default: None + norm_cfg (dict | None): Config of norm layers. Default: + dict(type='BN') + dwconv_act_cfg (dict): Config of activation layers in 3x3 conv. + Default: dict(type='ReLU'). + conv_act_cfg (dict): Config of activation layers in the two 1x1 conv. + Default: None. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. 
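+
+    Example:
+        A minimal illustrative sketch: the lower-resolution feature is
+        upsampled to the higher-resolution size before the two branches are
+        added and passed through a ReLU.
+        >>> import torch
+        >>> self = FeatureFusionModule(64, 128, 128)
+        >>> higher = torch.rand(1, 64, 64, 128)
+        >>> lower = torch.rand(1, 128, 16, 32)
+        >>> self.forward(higher, lower).shape
+        torch.Size([1, 128, 64, 128])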
+ """ + + def __init__(self, + higher_in_channels, + lower_in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dwconv_act_cfg=dict(type='ReLU'), + conv_act_cfg=None, + align_corners=False): + super().__init__() + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.dwconv_act_cfg = dwconv_act_cfg + self.conv_act_cfg = conv_act_cfg + self.align_corners = align_corners + self.dwconv = ConvModule( + lower_in_channels, + out_channels, + 3, + padding=1, + groups=out_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.dwconv_act_cfg) + self.conv_lower_res = ConvModule( + out_channels, + out_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.conv_act_cfg) + + self.conv_higher_res = ConvModule( + higher_in_channels, + out_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.conv_act_cfg) + + self.relu = nn.ReLU(True) + + def forward(self, higher_res_feature, lower_res_feature): + lower_res_feature = resize( + lower_res_feature, + size=higher_res_feature.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + lower_res_feature = self.dwconv(lower_res_feature) + lower_res_feature = self.conv_lower_res(lower_res_feature) + + higher_res_feature = self.conv_higher_res(higher_res_feature) + out = higher_res_feature + lower_res_feature + return self.relu(out) + + +@MODELS.register_module() +class FastSCNN(BaseModule): + """Fast-SCNN Backbone. + + This backbone is the implementation of `Fast-SCNN: Fast Semantic + Segmentation Network `_. + + Args: + in_channels (int): Number of input image channels. Default: 3. + downsample_dw_channels (tuple[int]): Number of output channels after + the first conv layer & the second conv layer in + Learning-To-Downsample (LTD) module. + Default: (32, 48). + global_in_channels (int): Number of input channels of + Global Feature Extractor(GFE). + Equal to number of output channels of LTD. + Default: 64. + global_block_channels (tuple[int]): Tuple of integers that describe + the output channels for each of the MobileNet-v2 bottleneck + residual blocks in GFE. + Default: (64, 96, 128). + global_block_strides (tuple[int]): Tuple of integers + that describe the strides (downsampling factors) for each of the + MobileNet-v2 bottleneck residual blocks in GFE. + Default: (2, 2, 1). + global_out_channels (int): Number of output channels of GFE. + Default: 128. + higher_in_channels (int): Number of input channels of the higher + resolution branch in FFM. + Equal to global_in_channels. + Default: 64. + lower_in_channels (int): Number of input channels of the lower + resolution branch in FFM. + Equal to global_out_channels. + Default: 128. + fusion_out_channels (int): Number of output channels of FFM. + Default: 128. + out_indices (tuple): Tuple of indices of list + [higher_res_features, lower_res_features, fusion_output]. + Often set to (0,1,2) to enable aux. heads. + Default: (0, 1, 2). + conv_cfg (dict | None): Config of conv layers. Default: None + norm_cfg (dict | None): Config of norm layers. Default: + dict(type='BN') + act_cfg (dict): Config of activation layers. Default: + dict(type='ReLU') + align_corners (bool): align_corners argument of F.interpolate. + Default: False + dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config + of depthwise ConvModule. If it is 'default', it will be the same + as `act_cfg`. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ Default: None + """ + + def __init__(self, + in_channels=3, + downsample_dw_channels=(32, 48), + global_in_channels=64, + global_block_channels=(64, 96, 128), + global_block_strides=(2, 2, 1), + global_out_channels=128, + higher_in_channels=64, + lower_in_channels=128, + fusion_out_channels=128, + out_indices=(0, 1, 2), + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + align_corners=False, + dw_act_cfg=None, + init_cfg=None): + + super().__init__(init_cfg) + + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ] + + if global_in_channels != higher_in_channels: + raise AssertionError('Global Input Channels must be the same \ + with Higher Input Channels!') + elif global_out_channels != lower_in_channels: + raise AssertionError('Global Output Channels must be the same \ + with Lower Input Channels!') + + self.in_channels = in_channels + self.downsample_dw_channels1 = downsample_dw_channels[0] + self.downsample_dw_channels2 = downsample_dw_channels[1] + self.global_in_channels = global_in_channels + self.global_block_channels = global_block_channels + self.global_block_strides = global_block_strides + self.global_out_channels = global_out_channels + self.higher_in_channels = higher_in_channels + self.lower_in_channels = lower_in_channels + self.fusion_out_channels = fusion_out_channels + self.out_indices = out_indices + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.align_corners = align_corners + self.learning_to_downsample = LearningToDownsample( + in_channels, + downsample_dw_channels, + global_in_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + dw_act_cfg=dw_act_cfg) + self.global_feature_extractor = GlobalFeatureExtractor( + global_in_channels, + global_block_channels, + global_out_channels, + strides=self.global_block_strides, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + self.feature_fusion = FeatureFusionModule( + higher_in_channels, + lower_in_channels, + fusion_out_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + dwconv_act_cfg=self.act_cfg, + align_corners=self.align_corners) + + def forward(self, x): + higher_res_features = self.learning_to_downsample(x) + lower_res_features = self.global_feature_extractor(higher_res_features) + fusion_output = self.feature_fusion(higher_res_features, + lower_res_features) + + outs = [higher_res_features, lower_res_features, fusion_output] + outs = [outs[i] for i in self.out_indices] + return tuple(outs) diff --git a/mmseg/models/backbones/hrnet.py b/mmseg/models/backbones/hrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2da755e731cfea911d47729f455c54c3d38a68e4 --- /dev/null +++ b/mmseg/models/backbones/hrnet.py @@ -0,0 +1,642 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import BaseModule, ModuleList, Sequential +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmseg.registry import MODELS +from ..utils import Upsample, resize +from .resnet import BasicBlock, Bottleneck + + +class HRModule(BaseModule): + """High-Resolution Module for HRNet. + + In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange + is in this module. 
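+
+    Roughly, output branch ``i`` is ``relu(sum_j f_ij(x_j))`` where ``f_ij``
+    is the identity for ``j == i``, a 1x1 conv + BN followed by bilinear
+    upsampling for ``j > i``, and a stack of ``i - j`` stride-2 3x3 convs
+    for ``j < i``.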
+ """ + + def __init__(self, + num_branches, + blocks, + num_blocks, + in_channels, + num_channels, + multiscale_output=True, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + block_init_cfg=None, + init_cfg=None): + super().__init__(init_cfg) + self.block_init_cfg = block_init_cfg + self._check_branches(num_branches, num_blocks, in_channels, + num_channels) + + self.in_channels = in_channels + self.num_branches = num_branches + + self.multiscale_output = multiscale_output + self.norm_cfg = norm_cfg + self.conv_cfg = conv_cfg + self.with_cp = with_cp + self.branches = self._make_branches(num_branches, blocks, num_blocks, + num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=False) + + def _check_branches(self, num_branches, num_blocks, in_channels, + num_channels): + """Check branches configuration.""" + if num_branches != len(num_blocks): + error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_BLOCKS(' \ + f'{len(num_blocks)})' + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_CHANNELS(' \ + f'{len(num_channels)})' + raise ValueError(error_msg) + + if num_branches != len(in_channels): + error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_INCHANNELS(' \ + f'{len(in_channels)})' + raise ValueError(error_msg) + + def _make_one_branch(self, + branch_index, + block, + num_blocks, + num_channels, + stride=1): + """Build one branch.""" + downsample = None + if stride != 1 or \ + self.in_channels[branch_index] != \ + num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + build_conv_layer( + self.conv_cfg, + self.in_channels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + build_norm_layer(self.norm_cfg, num_channels[branch_index] * + block.expansion)[1]) + + layers = [] + layers.append( + block( + self.in_channels[branch_index], + num_channels[branch_index], + stride, + downsample=downsample, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + init_cfg=self.block_init_cfg)) + self.in_channels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append( + block( + self.in_channels[branch_index], + num_channels[branch_index], + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + init_cfg=self.block_init_cfg)) + + return Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + """Build multiple branch.""" + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return ModuleList(branches) + + def _make_fuse_layers(self): + """Build fuse layer.""" + if self.num_branches == 1: + return None + + num_branches = self.num_branches + in_channels = self.in_channels + fuse_layers = [] + num_out_branches = num_branches if self.multiscale_output else 1 + for i in range(num_out_branches): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=1, + stride=1, + padding=0, + bias=False), + build_norm_layer(self.norm_cfg, in_channels[i])[1], + # we set align_corners=False for HRNet + Upsample( + scale_factor=2**(j - i), + mode='bilinear', + align_corners=False))) + elif j == i: + fuse_layer.append(None) + else: + conv_downsamples 
= [] + for k in range(i - j): + if k == i - j - 1: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[i])[1])) + else: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[j], + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[j])[1], + nn.ReLU(inplace=False))) + fuse_layer.append(nn.Sequential(*conv_downsamples)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def forward(self, x): + """Forward function.""" + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = 0 + for j in range(self.num_branches): + if i == j: + y += x[j] + elif j > i: + y = y + resize( + self.fuse_layers[i][j](x[j]), + size=x[i].shape[2:], + mode='bilinear', + align_corners=False) + else: + y += self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + return x_fuse + + +@MODELS.register_module() +class HRNet(BaseModule): + """HRNet backbone. + + This backbone is the implementation of `High-Resolution Representations + for Labeling Pixels and Regions `_. + + Args: + extra (dict): Detailed configuration for each stage of HRNet. + There must be 4 stages, the configuration for each stage must have + 5 keys: + + - num_modules (int): The number of HRModule in this stage. + - num_branches (int): The number of branches in the HRModule. + - block (str): The type of convolution block. + - num_blocks (tuple): The number of blocks in each branch. + The length must be equal to num_branches. + - num_channels (tuple): The number of channels in each branch. + The length must be equal to num_branches. + in_channels (int): Number of input image channels. Normally 3. + conv_cfg (dict): Dictionary to construct and config conv layer. + Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Use `BN` by default. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: False. + multiscale_output (bool): Whether to output multi-level features + produced by multiple branches. If False, only the first level + feature will be output. Default: True. + pretrained (str, optional): Model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. 
+ + Example: + >>> from mmseg.models import HRNet + >>> import torch + >>> extra = dict( + >>> stage1=dict( + >>> num_modules=1, + >>> num_branches=1, + >>> block='BOTTLENECK', + >>> num_blocks=(4, ), + >>> num_channels=(64, )), + >>> stage2=dict( + >>> num_modules=1, + >>> num_branches=2, + >>> block='BASIC', + >>> num_blocks=(4, 4), + >>> num_channels=(32, 64)), + >>> stage3=dict( + >>> num_modules=4, + >>> num_branches=3, + >>> block='BASIC', + >>> num_blocks=(4, 4, 4), + >>> num_channels=(32, 64, 128)), + >>> stage4=dict( + >>> num_modules=3, + >>> num_branches=4, + >>> block='BASIC', + >>> num_blocks=(4, 4, 4, 4), + >>> num_channels=(32, 64, 128, 256))) + >>> self = HRNet(extra, in_channels=1) + >>> self.eval() + >>> inputs = torch.rand(1, 1, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 32, 8, 8) + (1, 64, 4, 4) + (1, 128, 2, 2) + (1, 256, 1, 1) + """ + + blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck} + + def __init__(self, + extra, + in_channels=3, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + with_cp=False, + frozen_stages=-1, + zero_init_residual=False, + multiscale_output=True, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg) + + self.pretrained = pretrained + self.zero_init_residual = zero_init_residual + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + + # Assert configurations of 4 stages are in extra + assert 'stage1' in extra and 'stage2' in extra \ + and 'stage3' in extra and 'stage4' in extra + # Assert whether the length of `num_blocks` and `num_channels` are + # equal to `num_branches` + for i in range(4): + cfg = extra[f'stage{i + 1}'] + assert len(cfg['num_blocks']) == cfg['num_branches'] and \ + len(cfg['num_channels']) == cfg['num_branches'] + + self.extra = extra + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + self.frozen_stages = frozen_stages + + # stem net + self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1) + self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2) + + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + 64, + kernel_size=3, + stride=2, + padding=1, + bias=False) + + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + self.conv_cfg, + 64, + 64, + kernel_size=3, + stride=2, + padding=1, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.relu = nn.ReLU(inplace=True) + + # stage 1 + self.stage1_cfg = self.extra['stage1'] + num_channels = self.stage1_cfg['num_channels'][0] + block_type = self.stage1_cfg['block'] + num_blocks = self.stage1_cfg['num_blocks'][0] + + block = self.blocks_dict[block_type] + stage1_out_channels = num_channels * block.expansion + self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) + + # stage 2 + self.stage2_cfg = self.extra['stage2'] + num_channels = self.stage2_cfg['num_channels'] + block_type = 
self.stage2_cfg['block'] + + block = self.blocks_dict[block_type] + num_channels = [channel * block.expansion for channel in num_channels] + self.transition1 = self._make_transition_layer([stage1_out_channels], + num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + # stage 3 + self.stage3_cfg = self.extra['stage3'] + num_channels = self.stage3_cfg['num_channels'] + block_type = self.stage3_cfg['block'] + + block = self.blocks_dict[block_type] + num_channels = [channel * block.expansion for channel in num_channels] + self.transition2 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + # stage 4 + self.stage4_cfg = self.extra['stage4'] + num_channels = self.stage4_cfg['num_channels'] + block_type = self.stage4_cfg['block'] + + block = self.blocks_dict[block_type] + num_channels = [channel * block.expansion for channel in num_channels] + self.transition3 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multiscale_output=multiscale_output) + + self._freeze_stages() + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: the normalization layer named "norm2" """ + return getattr(self, self.norm2_name) + + def _make_transition_layer(self, num_channels_pre_layer, + num_channels_cur_layer): + """Make transition layer.""" + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + num_channels_pre_layer[i], + num_channels_cur_layer[i], + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + num_channels_cur_layer[i])[1], + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv_downsamples = [] + for j in range(i + 1 - num_branches_pre): + in_channels = num_channels_pre_layer[-1] + out_channels = num_channels_cur_layer[i] \ + if j == i - num_branches_pre else in_channels + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, out_channels)[1], + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv_downsamples)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + """Make each layer.""" + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + build_conv_layer( + self.conv_cfg, + inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + build_norm_layer(self.norm_cfg, planes * block.expansion)[1]) + + layers = [] + block_init_cfg = None + if self.pretrained is None and not hasattr( + self, 'init_cfg') and self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm3')) + + layers.append( + block( + inplanes, + planes, 
+ stride, + downsample=downsample, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + init_cfg=block_init_cfg)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append( + block( + inplanes, + planes, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + init_cfg=block_init_cfg)) + + return Sequential(*layers) + + def _make_stage(self, layer_config, in_channels, multiscale_output=True): + """Make each stage.""" + num_modules = layer_config['num_modules'] + num_branches = layer_config['num_branches'] + num_blocks = layer_config['num_blocks'] + num_channels = layer_config['num_channels'] + block = self.blocks_dict[layer_config['block']] + + hr_modules = [] + block_init_cfg = None + if self.pretrained is None and not hasattr( + self, 'init_cfg') and self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm3')) + + for i in range(num_modules): + # multi_scale_output is only used for the last module + if not multiscale_output and i == num_modules - 1: + reset_multiscale_output = False + else: + reset_multiscale_output = True + + hr_modules.append( + HRModule( + num_branches, + block, + num_blocks, + in_channels, + num_channels, + reset_multiscale_output, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + block_init_cfg=block_init_cfg)) + + return Sequential(*hr_modules), in_channels + + def _freeze_stages(self): + """Freeze stages param and norm stats.""" + if self.frozen_stages >= 0: + + self.norm1.eval() + self.norm2.eval() + for m in [self.conv1, self.norm1, self.conv2, self.norm2]: + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + if i == 1: + m = getattr(self, f'layer{i}') + t = getattr(self, f'transition{i}') + elif i == 4: + m = getattr(self, f'stage{i}') + else: + m = getattr(self, f'stage{i}') + t = getattr(self, f'transition{i}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + t.eval() + for param in t.parameters(): + param.requires_grad = False + + def forward(self, x): + """Forward function.""" + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['num_branches']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['num_branches']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['num_branches']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage4(x_list) + + return y_list + + def train(self, mode=True): + """Convert the model into training mode will keeping the normalization + layer freezed.""" + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmseg/models/backbones/icnet.py b/mmseg/models/backbones/icnet.py new file mode 100644 index 
0000000000000000000000000000000000000000..8ff3448569c5a3ec82a12726767fcbb48b3870d2 --- /dev/null +++ b/mmseg/models/backbones/icnet.py @@ -0,0 +1,166 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..decode_heads.psp_head import PPM +from ..utils import resize + + +@MODELS.register_module() +class ICNet(BaseModule): + """ICNet for Real-Time Semantic Segmentation on High-Resolution Images. + + This backbone is the implementation of + `ICNet `_. + + Args: + backbone_cfg (dict): Config dict to build backbone. Usually it is + ResNet but it can also be other backbones. + in_channels (int): The number of input image channels. Default: 3. + layer_channels (Sequence[int]): The numbers of feature channels at + layer 2 and layer 4 in ResNet. It can also be other backbones. + Default: (512, 2048). + light_branch_middle_channels (int): The number of channels of the + middle layer in light branch. Default: 32. + psp_out_channels (int): The number of channels of the output of PSP + module. Default: 512. + out_channels (Sequence[int]): The numbers of output feature channels + at each branches. Default: (64, 256, 256). + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. Default: (1, 2, 3, 6). + conv_cfg (dict): Dictionary to construct and config conv layer. + Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN'). + act_cfg (dict): Dictionary to construct and config act layer. + Default: dict(type='ReLU'). + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + backbone_cfg, + in_channels=3, + layer_channels=(512, 2048), + light_branch_middle_channels=32, + psp_out_channels=512, + out_channels=(64, 256, 256), + pool_scales=(1, 2, 3, 6), + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='ReLU'), + align_corners=False, + init_cfg=None): + if backbone_cfg is None: + raise TypeError('backbone_cfg must be passed from config file!') + if init_cfg is None: + init_cfg = [ + dict(type='Kaiming', mode='fan_out', layer='Conv2d'), + dict(type='Constant', val=1, layer='_BatchNorm'), + dict(type='Normal', mean=0.01, layer='Linear') + ] + super().__init__(init_cfg=init_cfg) + self.align_corners = align_corners + self.backbone = MODELS.build(backbone_cfg) + + # Note: Default `ceil_mode` is false in nn.MaxPool2d, set + # `ceil_mode=True` to keep information in the corner of feature map. 
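+        # (e.g. with kernel_size=3, stride=2 and padding=1, a 32x32 feature
+        # map pools to 17x17 under ceil_mode=True instead of 16x16)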
+ self.backbone.maxpool = nn.MaxPool2d( + kernel_size=3, stride=2, padding=1, ceil_mode=True) + + self.psp_modules = PPM( + pool_scales=pool_scales, + in_channels=layer_channels[1], + channels=psp_out_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + align_corners=align_corners) + + self.psp_bottleneck = ConvModule( + layer_channels[1] + len(pool_scales) * psp_out_channels, + psp_out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.conv_sub1 = nn.Sequential( + ConvModule( + in_channels=in_channels, + out_channels=light_branch_middle_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg), + ConvModule( + in_channels=light_branch_middle_channels, + out_channels=light_branch_middle_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg), + ConvModule( + in_channels=light_branch_middle_channels, + out_channels=out_channels[0], + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg)) + + self.conv_sub2 = ConvModule( + layer_channels[0], + out_channels[1], + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg) + + self.conv_sub4 = ConvModule( + psp_out_channels, + out_channels[2], + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg) + + def forward(self, x): + output = [] + + # sub 1 + output.append(self.conv_sub1(x)) + + # sub 2 + x = resize( + x, + scale_factor=0.5, + mode='bilinear', + align_corners=self.align_corners) + x = self.backbone.stem(x) + x = self.backbone.maxpool(x) + x = self.backbone.layer1(x) + x = self.backbone.layer2(x) + output.append(self.conv_sub2(x)) + + # sub 4 + x = resize( + x, + scale_factor=0.5, + mode='bilinear', + align_corners=self.align_corners) + x = self.backbone.layer3(x) + x = self.backbone.layer4(x) + psp_outs = self.psp_modules(x) + [x] + psp_outs = torch.cat(psp_outs, dim=1) + x = self.psp_bottleneck(psp_outs) + + output.append(self.conv_sub4(x)) + + return output diff --git a/mmseg/models/backbones/mae.py b/mmseg/models/backbones/mae.py new file mode 100644 index 0000000000000000000000000000000000000000..a1f243f0857b9aca5454e8c1410075bff9281285 --- /dev/null +++ b/mmseg/models/backbones/mae.py @@ -0,0 +1,260 @@ +# Copyright (c) OpenMMLab. All rights reserved.import math +import math + +import torch +import torch.nn as nn +from mmengine.model import ModuleList +from mmengine.model.weight_init import (constant_init, kaiming_init, + trunc_normal_) +from mmengine.runner.checkpoint import _load_checkpoint +from torch.nn.modules.batchnorm import _BatchNorm + +from mmseg.registry import MODELS +from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer + + +class MAEAttention(BEiTAttention): + """Multi-head self-attention with relative position bias used in MAE. + + This module is different from ``BEiTAttention`` by initializing the + relative bias table with zeros. + """ + + def init_weights(self): + """Initialize relative position bias with zeros.""" + + # As MAE initializes relative position bias as zeros and this class + # inherited from BEiT which initializes relative position bias + # with `trunc_normal`, `init_weights` here does + # nothing and just passes directly + + pass + + +class MAETransformerEncoderLayer(BEiTTransformerEncoderLayer): + """Implements one encoder layer in Vision Transformer. + + This module is different from ``BEiTTransformerEncoderLayer`` by replacing + ``BEiTAttention`` with ``MAEAttention``. 
+ """ + + def build_attn(self, attn_cfg): + self.attn = MAEAttention(**attn_cfg) + + +@MODELS.register_module() +class MAE(BEiT): + """VisionTransformer with support for patch. + + Args: + img_size (int | tuple): Input image size. Default: 224. + patch_size (int): The patch size. Default: 16. + in_channels (int): Number of input channels. Default: 3. + embed_dims (int): embedding dimension. Default: 768. + num_layers (int): depth of transformer. Default: 12. + num_heads (int): number of attention heads. Default: 12. + mlp_ratio (int): ratio of mlp hidden dim to embedding dim. + Default: 4. + out_indices (list | tuple | int): Output from which stages. + Default: -1. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + drop_path_rate (float): stochastic depth rate. Default 0.0. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + patch_norm (bool): Whether to add a norm in PatchEmbed Block. + Default: False. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Default: False. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + pretrained (str, optional): model pretrained path. Default: None. + init_values (float): Initialize the values of Attention and FFN + with learnable scaling. Defaults to 0.1. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_channels=3, + embed_dims=768, + num_layers=12, + num_heads=12, + mlp_ratio=4, + out_indices=-1, + attn_drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + patch_norm=False, + final_norm=False, + num_fcs=2, + norm_eval=False, + pretrained=None, + init_values=0.1, + init_cfg=None): + super().__init__( + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dims=embed_dims, + num_layers=num_layers, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + out_indices=out_indices, + qv_bias=False, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + patch_norm=patch_norm, + final_norm=final_norm, + num_fcs=num_fcs, + norm_eval=norm_eval, + pretrained=pretrained, + init_values=init_values, + init_cfg=init_cfg) + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + + self.num_patches = self.patch_shape[0] * self.patch_shape[1] + self.pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches + 1, embed_dims)) + + def _build_layers(self): + dpr = [ + x.item() + for x in torch.linspace(0, self.drop_path_rate, self.num_layers) + ] + self.layers = ModuleList() + for i in range(self.num_layers): + self.layers.append( + MAETransformerEncoderLayer( + embed_dims=self.embed_dims, + num_heads=self.num_heads, + feedforward_channels=self.mlp_ratio * self.embed_dims, + attn_drop_rate=self.attn_drop_rate, + drop_path_rate=dpr[i], + num_fcs=self.num_fcs, + bias=True, + act_cfg=self.act_cfg, + norm_cfg=self.norm_cfg, + window_size=self.patch_shape, + init_values=self.init_values)) + + def fix_init_weight(self): + """Rescale the initialization according to layer id. + + This function is copied from https://github.com/microsoft/unilm/blob/master/beit/modeling_pretrain.py. 
# noqa: E501 + Copyright (c) Microsoft Corporation + Licensed under the MIT License + """ + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.layers): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.ffn.layers[1].weight.data, layer_id + 1) + + def init_weights(self): + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + self.fix_init_weight() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg.get('type') == 'Pretrained'): + checkpoint = _load_checkpoint( + self.init_cfg['checkpoint'], logger=None, map_location='cpu') + state_dict = self.resize_rel_pos_embed(checkpoint) + state_dict = self.resize_abs_pos_embed(state_dict) + self.load_state_dict(state_dict, False) + elif self.init_cfg is not None: + super().init_weights() + else: + # We only implement the 'jax_impl' initialization implemented at + # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 + # Copyright 2019 Ross Wightman + # Licensed under the Apache License, Version 2.0 (the "License") + trunc_normal_(self.cls_token, std=.02) + for n, m in self.named_modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + if 'ffn' in n: + nn.init.normal_(m.bias, mean=0., std=1e-6) + else: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + kaiming_init(m, mode='fan_in', bias=0.) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m, val=1.0, bias=0.) 
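+    # Worked example (illustrative): an MAE checkpoint pretrained at 224x224
+    # with patch_size=16 stores a 14x14 grid of position tokens; fine-tuning
+    # at 512x512 gives a 32x32 grid, so ``resize_abs_pos_embed`` below
+    # reshapes the tokens to (1, C, 14, 14), bicubically interpolates them to
+    # (1, C, 32, 32) and re-attaches the extra (cls) tokens before loading.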
+ + def resize_abs_pos_embed(self, state_dict): + if 'pos_embed' in state_dict: + pos_embed_checkpoint = state_dict['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches + # height (== width) for the checkpoint position embedding + orig_size = int( + (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5) + # height (== width) for the new position embedding + new_size = int(self.num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, + embedding_size).permute( + 0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, + size=(new_size, new_size), + mode='bicubic', + align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + state_dict['pos_embed'] = new_pos_embed + return state_dict + + def forward(self, inputs): + B = inputs.shape[0] + + x, hw_shape = self.patch_embed(inputs) + + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i == len(self.layers) - 1: + if self.final_norm: + x = self.norm1(x) + if i in self.out_indices: + out = x[:, 1:] + B, _, C = out.shape + out = out.reshape(B, hw_shape[0], hw_shape[1], + C).permute(0, 3, 1, 2).contiguous() + outs.append(out) + + return tuple(outs) diff --git a/mmseg/models/backbones/mit.py b/mmseg/models/backbones/mit.py new file mode 100644 index 0000000000000000000000000000000000000000..66556bdfca2b0bcb180afd23c2923c68b9ff3a69 --- /dev/null +++ b/mmseg/models/backbones/mit.py @@ -0,0 +1,450 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import warnings + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.transformer import MultiheadAttention +from mmengine.model import BaseModule, ModuleList, Sequential +from mmengine.model.weight_init import (constant_init, normal_init, + trunc_normal_init) + +from mmseg.registry import MODELS +from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw + + +class MixFFN(BaseModule): + """An implementation of MixFFN of Segformer. + + The differences between MixFFN & FFN: + 1. Use 1X1 Conv to replace Linear layer. + 2. Introduce 3X3 Conv to encode positional information. + Args: + embed_dims (int): The feature dimension. Same as + `MultiheadAttention`. Defaults: 256. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 1024. + act_cfg (dict, optional): The activation config for FFNs. + Default: dict(type='ReLU') + ffn_drop (float, optional): Probability of an element to be + zeroed in FFN. Default 0.0. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. 
+ """ + + def __init__(self, + embed_dims, + feedforward_channels, + act_cfg=dict(type='GELU'), + ffn_drop=0., + dropout_layer=None, + init_cfg=None): + super().__init__(init_cfg) + + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.act_cfg = act_cfg + self.activate = build_activation_layer(act_cfg) + + in_channels = embed_dims + fc1 = Conv2d( + in_channels=in_channels, + out_channels=feedforward_channels, + kernel_size=1, + stride=1, + bias=True) + # 3x3 depth wise conv to provide positional encode information + pe_conv = Conv2d( + in_channels=feedforward_channels, + out_channels=feedforward_channels, + kernel_size=3, + stride=1, + padding=(3 - 1) // 2, + bias=True, + groups=feedforward_channels) + fc2 = Conv2d( + in_channels=feedforward_channels, + out_channels=in_channels, + kernel_size=1, + stride=1, + bias=True) + drop = nn.Dropout(ffn_drop) + layers = [fc1, pe_conv, self.activate, drop, fc2, drop] + self.layers = Sequential(*layers) + self.dropout_layer = build_dropout( + dropout_layer) if dropout_layer else torch.nn.Identity() + + def forward(self, x, hw_shape, identity=None): + out = nlc_to_nchw(x, hw_shape) + out = self.layers(out) + out = nchw_to_nlc(out) + if identity is None: + identity = x + return identity + self.dropout_layer(out) + + +class EfficientMultiheadAttention(MultiheadAttention): + """An implementation of Efficient Multi-head Attention of Segformer. + + This module is modified from MultiheadAttention which is a module from + mmcv.cnn.bricks.transformer. + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + attn_drop (float): A Dropout layer on attn_output_weights. + Default: 0.0. + proj_drop (float): A Dropout layer after `nn.MultiheadAttention`. + Default: 0.0. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. Default: None. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Default: False. + qkv_bias (bool): enable bias for qkv if True. Default True. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head + Attention of Segformer. Default: 1. + """ + + def __init__(self, + embed_dims, + num_heads, + attn_drop=0., + proj_drop=0., + dropout_layer=None, + init_cfg=None, + batch_first=True, + qkv_bias=False, + norm_cfg=dict(type='LN'), + sr_ratio=1): + super().__init__( + embed_dims, + num_heads, + attn_drop, + proj_drop, + dropout_layer=dropout_layer, + init_cfg=init_cfg, + batch_first=batch_first, + bias=qkv_bias) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = Conv2d( + in_channels=embed_dims, + out_channels=embed_dims, + kernel_size=sr_ratio, + stride=sr_ratio) + # The ret[0] of build_norm_layer is norm name. + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + + # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa + from mmseg import digit_version, mmcv_version + if mmcv_version < digit_version('1.3.17'): + warnings.warn('The legacy version of forward function in' + 'EfficientMultiheadAttention is deprecated in' + 'mmcv>=1.3.17 and will no longer support in the' + 'future. 
Please upgrade your mmcv.') + self.forward = self.legacy_forward + + def forward(self, x, hw_shape, identity=None): + + x_q = x + if self.sr_ratio > 1: + x_kv = nlc_to_nchw(x, hw_shape) + x_kv = self.sr(x_kv) + x_kv = nchw_to_nlc(x_kv) + x_kv = self.norm(x_kv) + else: + x_kv = x + + if identity is None: + identity = x_q + + # Because the dataflow('key', 'query', 'value') of + # ``torch.nn.MultiheadAttention`` is (num_query, batch, + # embed_dims), We should adjust the shape of dataflow from + # batch_first (batch, num_query, embed_dims) to num_query_first + # (num_query ,batch, embed_dims), and recover ``attn_output`` + # from num_query_first to batch_first. + if self.batch_first: + x_q = x_q.transpose(0, 1) + x_kv = x_kv.transpose(0, 1) + + out = self.attn(query=x_q, key=x_kv, value=x_kv)[0] + + if self.batch_first: + out = out.transpose(0, 1) + + return identity + self.dropout_layer(self.proj_drop(out)) + + def legacy_forward(self, x, hw_shape, identity=None): + """multi head attention forward in mmcv version < 1.3.17.""" + + x_q = x + if self.sr_ratio > 1: + x_kv = nlc_to_nchw(x, hw_shape) + x_kv = self.sr(x_kv) + x_kv = nchw_to_nlc(x_kv) + x_kv = self.norm(x_kv) + else: + x_kv = x + + if identity is None: + identity = x_q + + # `need_weights=True` will let nn.MultiHeadAttention + # `return attn_output, attn_output_weights.sum(dim=1) / num_heads` + # The `attn_output_weights.sum(dim=1)` may cause cuda error. So, we set + # `need_weights=False` to ignore `attn_output_weights.sum(dim=1)`. + # This issue - `https://github.com/pytorch/pytorch/issues/37583` report + # the error that large scale tensor sum operation may cause cuda error. + out = self.attn(query=x_q, key=x_kv, value=x_kv, need_weights=False)[0] + + return identity + self.dropout_layer(self.proj_drop(out)) + + +class TransformerEncoderLayer(BaseModule): + """Implements one encoder layer in Segformer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed. + after the feed forward layer. Default 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0. + drop_path_rate (float): stochastic depth rate. Default 0.0. + qkv_bias (bool): enable bias for qkv if True. + Default: True. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Default: False. + init_cfg (dict, optional): Initialization config dict. + Default:None. + sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head + Attention of Segformer. Default: 1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + batch_first=True, + sr_ratio=1, + with_cp=False): + super().__init__() + + # The ret[0] of build_norm_layer is norm name. 
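+        # build_norm_layer returns a (name, module) tuple, so the [1] below
+        # keeps only the nn.Module instance.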
+ self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + + self.attn = EfficientMultiheadAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + batch_first=batch_first, + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + sr_ratio=sr_ratio) + + # The ret[0] of build_norm_layer is norm name. + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + + self.ffn = MixFFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg) + + self.with_cp = with_cp + + def forward(self, x, hw_shape): + + def _inner_forward(x): + x = self.attn(self.norm1(x), hw_shape, identity=x) + x = self.ffn(self.norm2(x), hw_shape, identity=x) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + +@MODELS.register_module() +class MixVisionTransformer(BaseModule): + """The backbone of Segformer. + + This backbone is the implementation of `SegFormer: Simple and + Efficient Design for Semantic Segmentation with + Transformers `_. + Args: + in_channels (int): Number of input channels. Default: 3. + embed_dims (int): Embedding dimension. Default: 768. + num_stags (int): The num of stages. Default: 4. + num_layers (Sequence[int]): The layer number of each transformer encode + layer. Default: [3, 4, 6, 3]. + num_heads (Sequence[int]): The attention heads of each transformer + encode layer. Default: [1, 2, 4, 8]. + patch_sizes (Sequence[int]): The patch_size of each overlapped patch + embedding. Default: [7, 3, 3, 3]. + strides (Sequence[int]): The stride of each overlapped patch embedding. + Default: [4, 2, 2, 2]. + sr_ratios (Sequence[int]): The spatial reduction rate of each + transformer encode layer. Default: [8, 4, 2, 1]. + out_indices (Sequence[int] | int): Output from which stages. + Default: (0, 1, 2, 3). + mlp_ratio (int): ratio of mlp hidden dim to embedding dim. + Default: 4. + qkv_bias (bool): Enable bias for qkv if True. Default: True. + drop_rate (float): Probability of an element to be zeroed. + Default 0.0 + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + drop_path_rate (float): stochastic depth rate. Default 0.0 + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + pretrained (str, optional): model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. Default: False. 
+ """ + + def __init__(self, + in_channels=3, + embed_dims=64, + num_stages=4, + num_layers=[3, 4, 6, 3], + num_heads=[1, 2, 4, 8], + patch_sizes=[7, 3, 3, 3], + strides=[4, 2, 2, 2], + sr_ratios=[8, 4, 2, 1], + out_indices=(0, 1, 2, 3), + mlp_ratio=4, + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN', eps=1e-6), + pretrained=None, + init_cfg=None, + with_cp=False): + super().__init__(init_cfg=init_cfg) + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + + self.embed_dims = embed_dims + self.num_stages = num_stages + self.num_layers = num_layers + self.num_heads = num_heads + self.patch_sizes = patch_sizes + self.strides = strides + self.sr_ratios = sr_ratios + self.with_cp = with_cp + assert num_stages == len(num_layers) == len(num_heads) \ + == len(patch_sizes) == len(strides) == len(sr_ratios) + + self.out_indices = out_indices + assert max(out_indices) < self.num_stages + + # transformer encoder + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(num_layers)) + ] # stochastic num_layer decay rule + + cur = 0 + self.layers = ModuleList() + for i, num_layer in enumerate(num_layers): + embed_dims_i = embed_dims * num_heads[i] + patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims_i, + kernel_size=patch_sizes[i], + stride=strides[i], + padding=patch_sizes[i] // 2, + norm_cfg=norm_cfg) + layer = ModuleList([ + TransformerEncoderLayer( + embed_dims=embed_dims_i, + num_heads=num_heads[i], + feedforward_channels=mlp_ratio * embed_dims_i, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[cur + idx], + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + sr_ratio=sr_ratios[i]) for idx in range(num_layer) + ]) + in_channels = embed_dims_i + # The ret[0] of build_norm_layer is norm name. + norm = build_norm_layer(norm_cfg, embed_dims_i)[1] + self.layers.append(ModuleList([patch_embed, layer, norm])) + cur += num_layer + + def init_weights(self): + if self.init_cfg is None: + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, nn.LayerNorm): + constant_init(m, val=1.0, bias=0.) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[ + 1] * m.out_channels + fan_out //= m.groups + normal_init( + m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) + else: + super().init_weights() + + def forward(self, x): + outs = [] + + for i, layer in enumerate(self.layers): + x, hw_shape = layer[0](x) + for block in layer[1]: + x = block(x, hw_shape) + x = layer[2](x) + x = nlc_to_nchw(x, hw_shape) + if i in self.out_indices: + outs.append(x) + + return outs diff --git a/mmseg/models/backbones/mobilenet_v2.py b/mmseg/models/backbones/mobilenet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..1c21b5df97dade148136e8b0e6b039512f9e03f9 --- /dev/null +++ b/mmseg/models/backbones/mobilenet_v2.py @@ -0,0 +1,197 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import warnings + +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from torch.nn.modules.batchnorm import _BatchNorm + +from mmseg.registry import MODELS +from ..utils import InvertedResidual, make_divisible + + +@MODELS.register_module() +class MobileNetV2(BaseModule): + """MobileNetV2 backbone. + + This backbone is the implementation of + `MobileNetV2: Inverted Residuals and Linear Bottlenecks + `_. + + Args: + widen_factor (float): Width multiplier, multiply number of + channels in each layer by this amount. Default: 1.0. + strides (Sequence[int], optional): Strides of the first block of each + layer. If not specified, default config in ``arch_setting`` will + be used. + dilations (Sequence[int]): Dilation of each layer. + out_indices (None or Sequence[int]): Output from which stages. + Default: (7, ). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + # Parameters to build layers. 3 parameters are needed to construct a + # layer, from left to right: expand_ratio, channel, num_blocks. + arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4], + [6, 96, 3], [6, 160, 3], [6, 320, 1]] + + def __init__(self, + widen_factor=1., + strides=(1, 2, 2, 2, 1, 2, 1), + dilations=(1, 1, 1, 1, 1, 1, 1), + out_indices=(1, 2, 4, 6), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + norm_eval=False, + with_cp=False, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg) + + self.pretrained = pretrained + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + + self.widen_factor = widen_factor + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == len(self.arch_settings) + self.out_indices = out_indices + for index in out_indices: + if index not in range(0, 7): + raise ValueError('the item in out_indices must in ' + f'range(0, 7). But received {index}') + + if frozen_stages not in range(-1, 7): + raise ValueError('frozen_stages must be in range(-1, 7). 
' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.in_channels = make_divisible(32 * widen_factor, 8) + + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.layers = [] + + for i, layer_cfg in enumerate(self.arch_settings): + expand_ratio, channel, num_blocks = layer_cfg + stride = self.strides[i] + dilation = self.dilations[i] + out_channels = make_divisible(channel * widen_factor, 8) + inverted_res_layer = self.make_layer( + out_channels=out_channels, + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + expand_ratio=expand_ratio) + layer_name = f'layer{i + 1}' + self.add_module(layer_name, inverted_res_layer) + self.layers.append(layer_name) + + def make_layer(self, out_channels, num_blocks, stride, dilation, + expand_ratio): + """Stack InvertedResidual blocks to build a layer for MobileNetV2. + + Args: + out_channels (int): out_channels of block. + num_blocks (int): Number of blocks. + stride (int): Stride of the first block. + dilation (int): Dilation of the first block. + expand_ratio (int): Expand the number of channels of the + hidden layer in InvertedResidual by this ratio. + """ + layers = [] + for i in range(num_blocks): + layers.append( + InvertedResidual( + self.in_channels, + out_channels, + stride if i == 0 else 1, + expand_ratio=expand_ratio, + dilation=dilation if i == 0 else 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + + if len(outs) == 1: + return outs[0] + else: + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(1, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmseg/models/backbones/mobilenet_v3.py b/mmseg/models/backbones/mobilenet_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..1efb6e097472d53a5269e52a39ff2cae48e834db --- /dev/null +++ b/mmseg/models/backbones/mobilenet_v3.py @@ -0,0 +1,267 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +from mmcv.cnn import ConvModule +from mmcv.cnn.bricks import Conv2dAdaptivePadding +from mmengine.model import BaseModule +from mmengine.utils import is_tuple_of +from torch.nn.modules.batchnorm import _BatchNorm + +from mmseg.registry import MODELS +from ..utils import InvertedResidualV3 as InvertedResidual + + +@MODELS.register_module() +class MobileNetV3(BaseModule): + """MobileNetV3 backbone. + + This backbone is the improved implementation of `Searching for MobileNetV3 + `_. + + Args: + arch (str): Architecture of mobilnetv3, from {'small', 'large'}. + Default: 'small'. 
+ conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + out_indices (tuple[int]): Output from which layer. + Default: (0, 1, 12). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. + Default: False. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + # Parameters to build each block: + # [kernel size, mid channels, out channels, with_se, act type, stride] + arch_settings = { + 'small': [[3, 16, 16, True, 'ReLU', 2], # block0 layer1 os=4 + [3, 72, 24, False, 'ReLU', 2], # block1 layer2 os=8 + [3, 88, 24, False, 'ReLU', 1], + [5, 96, 40, True, 'HSwish', 2], # block2 layer4 os=16 + [5, 240, 40, True, 'HSwish', 1], + [5, 240, 40, True, 'HSwish', 1], + [5, 120, 48, True, 'HSwish', 1], # block3 layer7 os=16 + [5, 144, 48, True, 'HSwish', 1], + [5, 288, 96, True, 'HSwish', 2], # block4 layer9 os=32 + [5, 576, 96, True, 'HSwish', 1], + [5, 576, 96, True, 'HSwish', 1]], + 'large': [[3, 16, 16, False, 'ReLU', 1], # block0 layer1 os=2 + [3, 64, 24, False, 'ReLU', 2], # block1 layer2 os=4 + [3, 72, 24, False, 'ReLU', 1], + [5, 72, 40, True, 'ReLU', 2], # block2 layer4 os=8 + [5, 120, 40, True, 'ReLU', 1], + [5, 120, 40, True, 'ReLU', 1], + [3, 240, 80, False, 'HSwish', 2], # block3 layer7 os=16 + [3, 200, 80, False, 'HSwish', 1], + [3, 184, 80, False, 'HSwish', 1], + [3, 184, 80, False, 'HSwish', 1], + [3, 480, 112, True, 'HSwish', 1], # block4 layer11 os=16 + [3, 672, 112, True, 'HSwish', 1], + [5, 672, 160, True, 'HSwish', 2], # block5 layer13 os=32 + [5, 960, 160, True, 'HSwish', 1], + [5, 960, 160, True, 'HSwish', 1]] + } # yapf: disable + + def __init__(self, + arch='small', + conv_cfg=None, + norm_cfg=dict(type='BN'), + out_indices=(0, 1, 12), + frozen_stages=-1, + reduction_factor=1, + norm_eval=False, + with_cp=False, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg) + + self.pretrained = pretrained + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + + assert arch in self.arch_settings + assert isinstance(reduction_factor, int) and reduction_factor > 0 + assert is_tuple_of(out_indices, int) + for index in out_indices: + if index not in range(0, len(self.arch_settings[arch]) + 2): + raise ValueError( + 'the item in out_indices must in ' + f'range(0, {len(self.arch_settings[arch])+2}). ' + f'But received {index}') + + if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2): + raise ValueError('frozen_stages must be in range(-1, ' + f'{len(self.arch_settings[arch])+2}). 
' + f'But received {frozen_stages}') + self.arch = arch + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.reduction_factor = reduction_factor + self.norm_eval = norm_eval + self.with_cp = with_cp + self.layers = self._make_layer() + + def _make_layer(self): + layers = [] + + # build the first layer (layer0) + in_channels = 16 + layer = ConvModule( + in_channels=3, + out_channels=in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=dict(type='Conv2dAdaptivePadding'), + norm_cfg=self.norm_cfg, + act_cfg=dict(type='HSwish')) + self.add_module('layer0', layer) + layers.append('layer0') + + layer_setting = self.arch_settings[self.arch] + for i, params in enumerate(layer_setting): + (kernel_size, mid_channels, out_channels, with_se, act, + stride) = params + + if self.arch == 'large' and i >= 12 or self.arch == 'small' and \ + i >= 8: + mid_channels = mid_channels // self.reduction_factor + out_channels = out_channels // self.reduction_factor + + if with_se: + se_cfg = dict( + channels=mid_channels, + ratio=4, + act_cfg=(dict(type='ReLU'), + dict(type='HSigmoid', bias=3.0, divisor=6.0))) + else: + se_cfg = None + + layer = InvertedResidual( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + se_cfg=se_cfg, + with_expand_conv=(in_channels != mid_channels), + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type=act), + with_cp=self.with_cp) + in_channels = out_channels + layer_name = f'layer{i + 1}' + self.add_module(layer_name, layer) + layers.append(layer_name) + + # build the last layer + # block5 layer12 os=32 for small model + # block6 layer16 os=32 for large model + layer = ConvModule( + in_channels=in_channels, + out_channels=576 if self.arch == 'small' else 960, + kernel_size=1, + stride=1, + dilation=4, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type='HSwish')) + layer_name = f'layer{len(layer_setting) + 1}' + self.add_module(layer_name, layer) + layers.append(layer_name) + + # next, convert backbone MobileNetV3 to a semantic segmentation version + if self.arch == 'small': + self.layer4.depthwise_conv.conv.stride = (1, 1) + self.layer9.depthwise_conv.conv.stride = (1, 1) + for i in range(4, len(layers)): + layer = getattr(self, layers[i]) + if isinstance(layer, InvertedResidual): + modified_module = layer.depthwise_conv.conv + else: + modified_module = layer.conv + + if i < 9: + modified_module.dilation = (2, 2) + pad = 2 + else: + modified_module.dilation = (4, 4) + pad = 4 + + if not isinstance(modified_module, Conv2dAdaptivePadding): + # Adjust padding + pad *= (modified_module.kernel_size[0] - 1) // 2 + modified_module.padding = (pad, pad) + else: + self.layer7.depthwise_conv.conv.stride = (1, 1) + self.layer13.depthwise_conv.conv.stride = (1, 1) + for i in range(7, len(layers)): + layer = getattr(self, layers[i]) + if isinstance(layer, InvertedResidual): + modified_module = layer.depthwise_conv.conv + else: + modified_module = layer.conv + + if i < 13: + modified_module.dilation = (2, 2) + pad = 2 + else: + modified_module.dilation = (4, 4) + pad = 4 + + if not isinstance(modified_module, Conv2dAdaptivePadding): + # Adjust padding + pad *= (modified_module.kernel_size[0] - 1) // 2 + modified_module.padding = (pad, pad) + + return layers + + def forward(self, x): + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = 
layer(x) + if i in self.out_indices: + outs.append(x) + return outs + + def _freeze_stages(self): + for i in range(self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmseg/models/backbones/mscan.py b/mmseg/models/backbones/mscan.py new file mode 100644 index 0000000000000000000000000000000000000000..7150cb7a1c13d11dcdcc6fbbc72931154853929e --- /dev/null +++ b/mmseg/models/backbones/mscan.py @@ -0,0 +1,467 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Originally from https://github.com/visual-attention-network/segnext +# Licensed under the Apache License, Version 2.0 (the "License") +import math +import warnings + +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer, build_norm_layer +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule +from mmengine.model.weight_init import (constant_init, normal_init, + trunc_normal_init) + +from mmseg.registry import MODELS + + +class Mlp(BaseModule): + """Multi Layer Perceptron (MLP) Module. + + Args: + in_features (int): The dimension of input features. + hidden_features (int): The dimension of hidden features. + Defaults: None. + out_features (int): The dimension of output features. + Defaults: None. + act_cfg (dict): Config dict for activation layer in block. + Default: dict(type='GELU'). + drop (float): The number of dropout rate in MLP block. + Defaults: 0.0. + """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_cfg=dict(type='GELU'), + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, 1) + self.dwconv = nn.Conv2d( + hidden_features, + hidden_features, + 3, + 1, + 1, + bias=True, + groups=hidden_features) + self.act = build_activation_layer(act_cfg) + self.fc2 = nn.Conv2d(hidden_features, out_features, 1) + self.drop = nn.Dropout(drop) + + def forward(self, x): + """Forward function.""" + + x = self.fc1(x) + + x = self.dwconv(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + + return x + + +class StemConv(BaseModule): + """Stem Block at the beginning of Semantic Branch. + + Args: + in_channels (int): The dimension of input channels. + out_channels (int): The dimension of output channels. + act_cfg (dict): Config dict for activation layer in block. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Defaults: dict(type='SyncBN', requires_grad=True). 
+ """ + + def __init__(self, + in_channels, + out_channels, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='SyncBN', requires_grad=True)): + super().__init__() + + self.proj = nn.Sequential( + nn.Conv2d( + in_channels, + out_channels // 2, + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1)), + build_norm_layer(norm_cfg, out_channels // 2)[1], + build_activation_layer(act_cfg), + nn.Conv2d( + out_channels // 2, + out_channels, + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1)), + build_norm_layer(norm_cfg, out_channels)[1], + ) + + def forward(self, x): + """Forward function.""" + + x = self.proj(x) + _, _, H, W = x.size() + x = x.flatten(2).transpose(1, 2) + return x, H, W + + +class MSCAAttention(BaseModule): + """Attention Module in Multi-Scale Convolutional Attention Module (MSCA). + + Args: + channels (int): The dimension of channels. + kernel_sizes (list): The size of attention + kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]]. + paddings (list): The number of + corresponding padding value in attention module. + Defaults: [2, [0, 3], [0, 5], [0, 10]]. + """ + + def __init__(self, + channels, + kernel_sizes=[5, [1, 7], [1, 11], [1, 21]], + paddings=[2, [0, 3], [0, 5], [0, 10]]): + super().__init__() + self.conv0 = nn.Conv2d( + channels, + channels, + kernel_size=kernel_sizes[0], + padding=paddings[0], + groups=channels) + for i, (kernel_size, + padding) in enumerate(zip(kernel_sizes[1:], paddings[1:])): + kernel_size_ = [kernel_size, kernel_size[::-1]] + padding_ = [padding, padding[::-1]] + conv_name = [f'conv{i}_1', f'conv{i}_2'] + for i_kernel, i_pad, i_conv in zip(kernel_size_, padding_, + conv_name): + self.add_module( + i_conv, + nn.Conv2d( + channels, + channels, + tuple(i_kernel), + padding=i_pad, + groups=channels)) + self.conv3 = nn.Conv2d(channels, channels, 1) + + def forward(self, x): + """Forward function.""" + + u = x.clone() + + attn = self.conv0(x) + + # Multi-Scale Feature extraction + attn_0 = self.conv0_1(attn) + attn_0 = self.conv0_2(attn_0) + + attn_1 = self.conv1_1(attn) + attn_1 = self.conv1_2(attn_1) + + attn_2 = self.conv2_1(attn) + attn_2 = self.conv2_2(attn_2) + + attn = attn + attn_0 + attn_1 + attn_2 + # Channel Mixing + attn = self.conv3(attn) + + # Convolutional Attention + x = attn * u + + return x + + +class MSCASpatialAttention(BaseModule): + """Spatial Attention Module in Multi-Scale Convolutional Attention Module + (MSCA). + + Args: + in_channels (int): The dimension of channels. + attention_kernel_sizes (list): The size of attention + kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]]. + attention_kernel_paddings (list): The number of + corresponding padding value in attention module. + Defaults: [2, [0, 3], [0, 5], [0, 10]]. + act_cfg (dict): Config dict for activation layer in block. + Default: dict(type='GELU'). 
+ """ + + def __init__(self, + in_channels, + attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]], + attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]], + act_cfg=dict(type='GELU')): + super().__init__() + self.proj_1 = nn.Conv2d(in_channels, in_channels, 1) + self.activation = build_activation_layer(act_cfg) + self.spatial_gating_unit = MSCAAttention(in_channels, + attention_kernel_sizes, + attention_kernel_paddings) + self.proj_2 = nn.Conv2d(in_channels, in_channels, 1) + + def forward(self, x): + """Forward function.""" + + shorcut = x.clone() + x = self.proj_1(x) + x = self.activation(x) + x = self.spatial_gating_unit(x) + x = self.proj_2(x) + x = x + shorcut + return x + + +class MSCABlock(BaseModule): + """Basic Multi-Scale Convolutional Attention Block. It leverage the large- + kernel attention (LKA) mechanism to build both channel and spatial + attention. In each branch, it uses two depth-wise strip convolutions to + approximate standard depth-wise convolutions with large kernels. The kernel + size for each branch is set to 7, 11, and 21, respectively. + + Args: + channels (int): The dimension of channels. + attention_kernel_sizes (list): The size of attention + kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]]. + attention_kernel_paddings (list): The number of + corresponding padding value in attention module. + Defaults: [2, [0, 3], [0, 5], [0, 10]]. + mlp_ratio (float): The ratio of multiple input dimension to + calculate hidden feature in MLP layer. Defaults: 4.0. + drop (float): The number of dropout rate in MLP block. + Defaults: 0.0. + drop_path (float): The ratio of drop paths. + Defaults: 0.0. + act_cfg (dict): Config dict for activation layer in block. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Defaults: dict(type='SyncBN', requires_grad=True). + """ + + def __init__(self, + channels, + attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]], + attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]], + mlp_ratio=4., + drop=0., + drop_path=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='SyncBN', requires_grad=True)): + super().__init__() + self.norm1 = build_norm_layer(norm_cfg, channels)[1] + self.attn = MSCASpatialAttention(channels, attention_kernel_sizes, + attention_kernel_paddings, act_cfg) + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = build_norm_layer(norm_cfg, channels)[1] + mlp_hidden_channels = int(channels * mlp_ratio) + self.mlp = Mlp( + in_features=channels, + hidden_features=mlp_hidden_channels, + act_cfg=act_cfg, + drop=drop) + layer_scale_init_value = 1e-2 + self.layer_scale_1 = nn.Parameter( + layer_scale_init_value * torch.ones(channels), requires_grad=True) + self.layer_scale_2 = nn.Parameter( + layer_scale_init_value * torch.ones(channels), requires_grad=True) + + def forward(self, x, H, W): + """Forward function.""" + + B, N, C = x.shape + x = x.permute(0, 2, 1).view(B, C, H, W) + x = x + self.drop_path( + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * + self.attn(self.norm1(x))) + x = x + self.drop_path( + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * + self.mlp(self.norm2(x))) + x = x.view(B, C, N).permute(0, 2, 1) + return x + + +class OverlapPatchEmbed(BaseModule): + """Image to Patch Embedding. + + Args: + patch_size (int): The patch size. + Defaults: 7. + stride (int): Stride of the convolutional layer. + Default: 4. + in_channels (int): The number of input channels. + Defaults: 3. + embed_dims (int): The dimensions of embedding. 
+ Defaults: 768. + norm_cfg (dict): Config dict for normalization layer. + Defaults: dict(type='SyncBN', requires_grad=True). + """ + + def __init__(self, + patch_size=7, + stride=4, + in_channels=3, + embed_dim=768, + norm_cfg=dict(type='SyncBN', requires_grad=True)): + super().__init__() + + self.proj = nn.Conv2d( + in_channels, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=patch_size // 2) + self.norm = build_norm_layer(norm_cfg, embed_dim)[1] + + def forward(self, x): + """Forward function.""" + + x = self.proj(x) + _, _, H, W = x.shape + x = self.norm(x) + + x = x.flatten(2).transpose(1, 2) + + return x, H, W + + +@MODELS.register_module() +class MSCAN(BaseModule): + """SegNeXt Multi-Scale Convolutional Attention Network (MCSAN) backbone. + + This backbone is the implementation of `SegNeXt: Rethinking + Convolutional Attention Design for Semantic + Segmentation `_. + Inspiration from https://github.com/visual-attention-network/segnext. + + Args: + in_channels (int): The number of input channels. Defaults: 3. + embed_dims (list[int]): Embedding dimension. + Defaults: [64, 128, 256, 512]. + mlp_ratios (list[int]): Ratio of mlp hidden dim to embedding dim. + Defaults: [4, 4, 4, 4]. + drop_rate (float): Dropout rate. Defaults: 0. + drop_path_rate (float): Stochastic depth rate. Defaults: 0. + depths (list[int]): Depths of each Swin Transformer stage. + Default: [3, 4, 6, 3]. + num_stages (int): MSCAN stages. Default: 4. + attention_kernel_sizes (list): Size of attention kernel in + Attention Module (Figure 2(b) of original paper). + Defaults: [5, [1, 7], [1, 11], [1, 21]]. + attention_kernel_paddings (list): Size of attention paddings + in Attention Module (Figure 2(b) of original paper). + Defaults: [2, [0, 3], [0, 5], [0, 10]]. + norm_cfg (dict): Config of norm layers. + Defaults: dict(type='SyncBN', requires_grad=True). + pretrained (str, optional): model pretrained path. + Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. 
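Example (an editorial usage sketch): the four stages output feature maps at strides 4, 8, 16 and 32. The embedding dims and depths below are illustrative small values, and `BN` replaces the default `SyncBN` so the snippet runs without distributed training.
>>> import torch
>>> from mmseg.models import MSCAN
>>> model = MSCAN(
...     embed_dims=[32, 64, 160, 256],
...     depths=[3, 3, 5, 2],
...     norm_cfg=dict(type='BN', requires_grad=True))
>>> model.eval()
>>> outs = model(torch.rand(1, 3, 64, 64))
>>> [tuple(out.shape) for out in outs]
>>> # expected: [(1, 32, 16, 16), (1, 64, 8, 8), (1, 160, 4, 4), (1, 256, 2, 2)]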
+ """ + + def __init__(self, + in_channels=3, + embed_dims=[64, 128, 256, 512], + mlp_ratios=[4, 4, 4, 4], + drop_rate=0., + drop_path_rate=0., + depths=[3, 4, 6, 3], + num_stages=4, + attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]], + attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]], + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='SyncBN', requires_grad=True), + pretrained=None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + + self.depths = depths + self.num_stages = num_stages + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + cur = 0 + + for i in range(num_stages): + if i == 0: + patch_embed = StemConv(3, embed_dims[0], norm_cfg=norm_cfg) + else: + patch_embed = OverlapPatchEmbed( + patch_size=7 if i == 0 else 3, + stride=4 if i == 0 else 2, + in_channels=in_channels if i == 0 else embed_dims[i - 1], + embed_dim=embed_dims[i], + norm_cfg=norm_cfg) + + block = nn.ModuleList([ + MSCABlock( + channels=embed_dims[i], + attention_kernel_sizes=attention_kernel_sizes, + attention_kernel_paddings=attention_kernel_paddings, + mlp_ratio=mlp_ratios[i], + drop=drop_rate, + drop_path=dpr[cur + j], + act_cfg=act_cfg, + norm_cfg=norm_cfg) for j in range(depths[i]) + ]) + norm = nn.LayerNorm(embed_dims[i]) + cur += depths[i] + + setattr(self, f'patch_embed{i + 1}', patch_embed) + setattr(self, f'block{i + 1}', block) + setattr(self, f'norm{i + 1}', norm) + + def init_weights(self): + """Initialize modules of MSCAN.""" + + print('init cfg', self.init_cfg) + if self.init_cfg is None: + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, nn.LayerNorm): + constant_init(m, val=1.0, bias=0.) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[ + 1] * m.out_channels + fan_out //= m.groups + normal_init( + m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) + else: + super().init_weights() + + def forward(self, x): + """Forward function.""" + + B = x.shape[0] + outs = [] + + for i in range(self.num_stages): + patch_embed = getattr(self, f'patch_embed{i + 1}') + block = getattr(self, f'block{i + 1}') + norm = getattr(self, f'norm{i + 1}') + x, H, W = patch_embed(x) + for blk in block: + x = blk(x, H, W) + x = norm(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + return outs diff --git a/mmseg/models/backbones/pidnet.py b/mmseg/models/backbones/pidnet.py new file mode 100644 index 0000000000000000000000000000000000000000..0b711a373701c0771c5c5997bbb8e5b345d70924 --- /dev/null +++ b/mmseg/models/backbones/pidnet.py @@ -0,0 +1,522 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from mmengine.runner import CheckpointLoader +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.utils import OptConfigType +from ..utils import DAPPM, PAPPM, BasicBlock, Bottleneck + + +class PagFM(BaseModule): + """Pixel-attention-guided fusion module. + + Args: + in_channels (int): The number of input channels. + channels (int): The number of channels. + after_relu (bool): Whether to use ReLU before attention. + Default: False. + with_channel (bool): Whether to use channel attention. + Default: False. + upsample_mode (str): The mode of upsample. Default: 'bilinear'. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(typ='ReLU', inplace=True). + init_cfg (dict): Config dict for initialization. Default: None. + """ + + def __init__(self, + in_channels: int, + channels: int, + after_relu: bool = False, + with_channel: bool = False, + upsample_mode: str = 'bilinear', + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = dict(typ='ReLU', inplace=True), + init_cfg: OptConfigType = None): + super().__init__(init_cfg) + self.after_relu = after_relu + self.with_channel = with_channel + self.upsample_mode = upsample_mode + self.f_i = ConvModule( + in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=None) + self.f_p = ConvModule( + in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=None) + if with_channel: + self.up = ConvModule( + channels, in_channels, 1, norm_cfg=norm_cfg, act_cfg=None) + if after_relu: + self.relu = MODELS.build(act_cfg) + + def forward(self, x_p: Tensor, x_i: Tensor) -> Tensor: + """Forward function. + + Args: + x_p (Tensor): The featrue map from P branch. + x_i (Tensor): The featrue map from I branch. + + Returns: + Tensor: The feature map with pixel-attention-guided fusion. + """ + if self.after_relu: + x_p = self.relu(x_p) + x_i = self.relu(x_i) + + f_i = self.f_i(x_i) + f_i = F.interpolate( + f_i, + size=x_p.shape[2:], + mode=self.upsample_mode, + align_corners=False) + + f_p = self.f_p(x_p) + + if self.with_channel: + sigma = torch.sigmoid(self.up(f_p * f_i)) + else: + sigma = torch.sigmoid(torch.sum(f_p * f_i, dim=1).unsqueeze(1)) + + x_i = F.interpolate( + x_i, + size=x_p.shape[2:], + mode=self.upsample_mode, + align_corners=False) + + out = sigma * x_i + (1 - sigma) * x_p + return out + + +class Bag(BaseModule): + """Boundary-attention-guided fusion module. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + kernel_size (int): The kernel size of the convolution. Default: 3. + padding (int): The padding of the convolution. Default: 1. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU', inplace=True). + conv_cfg (dict): Config dict for convolution layer. + Default: dict(order=('norm', 'act', 'conv')). + init_cfg (dict): Config dict for initialization. Default: None. 
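Example (an editorial sketch): Bag gates between the detail (P) and context (I) features with a sigmoid computed from the boundary (D) features, then applies a norm-act-conv block, so the output keeps the input spatial size. PagFM above uses the same sigmoid-gated blend, but derives the gate from the P and I features themselves and upsamples the I branch first. The shapes below are illustrative.
>>> import torch
>>> from mmseg.models.backbones.pidnet import Bag
>>> bag = Bag(in_channels=128, out_channels=128)
>>> x_p = torch.rand(1, 128, 32, 32)  # detail (P) branch
>>> x_i = torch.rand(1, 128, 32, 32)  # context (I) branch
>>> x_d = torch.rand(1, 128, 32, 32)  # boundary (D) branch, provides the gate
>>> bag(x_p, x_i, x_d).shape  # expected: torch.Size([1, 128, 32, 32])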
+ """ + + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + padding: int = 1, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + conv_cfg: OptConfigType = dict(order=('norm', 'act', 'conv')), + init_cfg: OptConfigType = None): + super().__init__(init_cfg) + + self.conv = ConvModule( + in_channels, + out_channels, + kernel_size, + padding=padding, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **conv_cfg) + + def forward(self, x_p: Tensor, x_i: Tensor, x_d: Tensor) -> Tensor: + """Forward function. + + Args: + x_p (Tensor): The featrue map from P branch. + x_i (Tensor): The featrue map from I branch. + x_d (Tensor): The featrue map from D branch. + + Returns: + Tensor: The feature map with boundary-attention-guided fusion. + """ + sigma = torch.sigmoid(x_d) + return self.conv(sigma * x_p + (1 - sigma) * x_i) + + +class LightBag(BaseModule): + """Light Boundary-attention-guided fusion module. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. Default: None. + init_cfg (dict): Config dict for initialization. Default: None. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = None, + init_cfg: OptConfigType = None): + super().__init__(init_cfg) + self.f_p = ConvModule( + in_channels, + out_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.f_i = ConvModule( + in_channels, + out_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x_p: Tensor, x_i: Tensor, x_d: Tensor) -> Tensor: + """Forward function. + Args: + x_p (Tensor): The featrue map from P branch. + x_i (Tensor): The featrue map from I branch. + x_d (Tensor): The featrue map from D branch. + + Returns: + Tensor: The feature map with light boundary-attention-guided + fusion. + """ + sigma = torch.sigmoid(x_d) + + f_p = self.f_p((1 - sigma) * x_i + x_p) + f_i = self.f_i(x_i + sigma * x_p) + + return f_p + f_i + + +@MODELS.register_module() +class PIDNet(BaseModule): + """PIDNet backbone. + + This backbone is the implementation of `PIDNet: A Real-time Semantic + Segmentation Network Inspired from PID Controller + `_. + Modified from https://github.com/XuJiacong/PIDNet. + + Licensed under the MIT License. + + Args: + in_channels (int): The number of input channels. Default: 3. + channels (int): The number of channels in the stem layer. Default: 64. + ppm_channels (int): The number of channels in the PPM layer. + Default: 96. + num_stem_blocks (int): The number of blocks in the stem layer. + Default: 2. + num_branch_blocks (int): The number of blocks in the branch layer. + Default: 3. + align_corners (bool): The align_corners argument of F.interpolate. + Default: False. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU', inplace=True). + init_cfg (dict): Config dict for initialization. Default: None. 
+ """ + + def __init__(self, + in_channels: int = 3, + channels: int = 64, + ppm_channels: int = 96, + num_stem_blocks: int = 2, + num_branch_blocks: int = 3, + align_corners: bool = False, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + init_cfg: OptConfigType = None, + **kwargs): + super().__init__(init_cfg) + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.align_corners = align_corners + + # stem layer + self.stem = self._make_stem_layer(in_channels, channels, + num_stem_blocks) + self.relu = nn.ReLU() + + # I Branch + self.i_branch_layers = nn.ModuleList() + for i in range(3): + self.i_branch_layers.append( + self._make_layer( + block=BasicBlock if i < 2 else Bottleneck, + in_channels=channels * 2**(i + 1), + channels=channels * 8 if i > 0 else channels * 4, + num_blocks=num_branch_blocks if i < 2 else 2, + stride=2)) + + # P Branch + self.p_branch_layers = nn.ModuleList() + for i in range(3): + self.p_branch_layers.append( + self._make_layer( + block=BasicBlock if i < 2 else Bottleneck, + in_channels=channels * 2, + channels=channels * 2, + num_blocks=num_stem_blocks if i < 2 else 1)) + self.compression_1 = ConvModule( + channels * 4, + channels * 2, + kernel_size=1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=None) + self.compression_2 = ConvModule( + channels * 8, + channels * 2, + kernel_size=1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=None) + self.pag_1 = PagFM(channels * 2, channels) + self.pag_2 = PagFM(channels * 2, channels) + + # D Branch + if num_stem_blocks == 2: + self.d_branch_layers = nn.ModuleList([ + self._make_single_layer(BasicBlock, channels * 2, channels), + self._make_layer(Bottleneck, channels, channels, 1) + ]) + channel_expand = 1 + spp_module = PAPPM + dfm_module = LightBag + act_cfg_dfm = None + else: + self.d_branch_layers = nn.ModuleList([ + self._make_single_layer(BasicBlock, channels * 2, + channels * 2), + self._make_single_layer(BasicBlock, channels * 2, channels * 2) + ]) + channel_expand = 2 + spp_module = DAPPM + dfm_module = Bag + act_cfg_dfm = act_cfg + + self.diff_1 = ConvModule( + channels * 4, + channels * channel_expand, + kernel_size=3, + padding=1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=None) + self.diff_2 = ConvModule( + channels * 8, + channels * 2, + kernel_size=3, + padding=1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=None) + + self.spp = spp_module( + channels * 16, ppm_channels, channels * 4, num_scales=5) + self.dfm = dfm_module( + channels * 4, channels * 4, norm_cfg=norm_cfg, act_cfg=act_cfg_dfm) + + self.d_branch_layers.append( + self._make_layer(Bottleneck, channels * 2, channels * 2, 1)) + + def _make_stem_layer(self, in_channels: int, channels: int, + num_blocks: int) -> nn.Sequential: + """Make stem layer. + + Args: + in_channels (int): Number of input channels. + channels (int): Number of output channels. + num_blocks (int): Number of blocks. + + Returns: + nn.Sequential: The stem layer. 
+ """ + + layers = [ + ConvModule( + in_channels, + channels, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + ConvModule( + channels, + channels, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + ] + + layers.append( + self._make_layer(BasicBlock, channels, channels, num_blocks)) + layers.append(nn.ReLU()) + layers.append( + self._make_layer( + BasicBlock, channels, channels * 2, num_blocks, stride=2)) + layers.append(nn.ReLU()) + + return nn.Sequential(*layers) + + def _make_layer(self, + block: BasicBlock, + in_channels: int, + channels: int, + num_blocks: int, + stride: int = 1) -> nn.Sequential: + """Make layer for PIDNet backbone. + Args: + block (BasicBlock): Basic block. + in_channels (int): Number of input channels. + channels (int): Number of output channels. + num_blocks (int): Number of blocks. + stride (int): Stride of the first block. Default: 1. + + Returns: + nn.Sequential: The Branch Layer. + """ + downsample = None + if stride != 1 or in_channels != channels * block.expansion: + downsample = ConvModule( + in_channels, + channels * block.expansion, + kernel_size=1, + stride=stride, + norm_cfg=self.norm_cfg, + act_cfg=None) + + layers = [block(in_channels, channels, stride, downsample)] + in_channels = channels * block.expansion + for i in range(1, num_blocks): + layers.append( + block( + in_channels, + channels, + stride=1, + act_cfg_out=None if i == num_blocks - 1 else self.act_cfg)) + return nn.Sequential(*layers) + + def _make_single_layer(self, + block: Union[BasicBlock, Bottleneck], + in_channels: int, + channels: int, + stride: int = 1) -> nn.Module: + """Make single layer for PIDNet backbone. + Args: + block (BasicBlock or Bottleneck): Basic block or Bottleneck. + in_channels (int): Number of input channels. + channels (int): Number of output channels. + stride (int): Stride of the first block. Default: 1. + + Returns: + nn.Module + """ + + downsample = None + if stride != 1 or in_channels != channels * block.expansion: + downsample = ConvModule( + in_channels, + channels * block.expansion, + kernel_size=1, + stride=stride, + norm_cfg=self.norm_cfg, + act_cfg=None) + return block( + in_channels, channels, stride, downsample, act_cfg_out=None) + + def init_weights(self): + """Initialize the weights in backbone. + + Since the D branch is not initialized by the pre-trained model, we + initialize it with the same method as the ResNet. + """ + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + if self.init_cfg is not None: + assert 'checkpoint' in self.init_cfg, f'Only support ' \ + f'specify `Pretrained` in ' \ + f'`init_cfg` in ' \ + f'{self.__class__.__name__} ' + ckpt = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], map_location='cpu') + self.load_state_dict(ckpt, strict=False) + + def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor]]: + """Forward function. + + Args: + x (Tensor): Input tensor with shape (B, C, H, W). + + Returns: + Tensor or tuple[Tensor]: If self.training is True, return + tuple[Tensor], else return Tensor. 
+ """ + w_out = x.shape[-1] // 8 + h_out = x.shape[-2] // 8 + + # stage 0-2 + x = self.stem(x) + + # stage 3 + x_i = self.relu(self.i_branch_layers[0](x)) + x_p = self.p_branch_layers[0](x) + x_d = self.d_branch_layers[0](x) + + comp_i = self.compression_1(x_i) + x_p = self.pag_1(x_p, comp_i) + diff_i = self.diff_1(x_i) + x_d += F.interpolate( + diff_i, + size=[h_out, w_out], + mode='bilinear', + align_corners=self.align_corners) + if self.training: + temp_p = x_p.clone() + + # stage 4 + x_i = self.relu(self.i_branch_layers[1](x_i)) + x_p = self.p_branch_layers[1](self.relu(x_p)) + x_d = self.d_branch_layers[1](self.relu(x_d)) + + comp_i = self.compression_2(x_i) + x_p = self.pag_2(x_p, comp_i) + diff_i = self.diff_2(x_i) + x_d += F.interpolate( + diff_i, + size=[h_out, w_out], + mode='bilinear', + align_corners=self.align_corners) + if self.training: + temp_d = x_d.clone() + + # stage 5 + x_i = self.i_branch_layers[2](x_i) + x_p = self.p_branch_layers[2](self.relu(x_p)) + x_d = self.d_branch_layers[2](self.relu(x_d)) + + x_i = self.spp(x_i) + x_i = F.interpolate( + x_i, + size=[h_out, w_out], + mode='bilinear', + align_corners=self.align_corners) + out = self.dfm(x_p, x_i, x_d) + return (temp_p, out, temp_d) if self.training else out diff --git a/mmseg/models/backbones/resnest.py b/mmseg/models/backbones/resnest.py new file mode 100644 index 0000000000000000000000000000000000000000..3cc380b4460915f476ffc1febcfc145a94fc7c7a --- /dev/null +++ b/mmseg/models/backbones/resnest.py @@ -0,0 +1,318 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmseg.registry import MODELS +from ..utils import ResLayer +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResNetV1d + + +class RSoftmax(nn.Module): + """Radix Softmax module in ``SplitAttentionConv2d``. + + Args: + radix (int): Radix of input. + groups (int): Groups of input. + """ + + def __init__(self, radix, groups): + super().__init__() + self.radix = radix + self.groups = groups + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x + + +class SplitAttentionConv2d(nn.Module): + """Split-Attention Conv2d in ResNeSt. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int | tuple[int]): Same as nn.Conv2d. + stride (int | tuple[int]): Same as nn.Conv2d. + padding (int | tuple[int]): Same as nn.Conv2d. + dilation (int | tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of inter_channels. Default: 4. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. Default: None. + dcn (dict): Config dict for DCN. Default: None. 
+ """ + + def __init__(self, + in_channels, + channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + radix=2, + reduction_factor=4, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None): + super().__init__() + inter_channels = max(in_channels * radix // reduction_factor, 32) + self.radix = radix + self.groups = groups + self.channels = channels + self.with_dcn = dcn is not None + self.dcn = dcn + fallback_on_stride = False + if self.with_dcn: + fallback_on_stride = self.dcn.pop('fallback_on_stride', False) + if self.with_dcn and not fallback_on_stride: + assert conv_cfg is None, 'conv_cfg must be None for DCN' + conv_cfg = dcn + self.conv = build_conv_layer( + conv_cfg, + in_channels, + channels * radix, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups * radix, + bias=False) + self.norm0_name, norm0 = build_norm_layer( + norm_cfg, channels * radix, postfix=0) + self.add_module(self.norm0_name, norm0) + self.relu = nn.ReLU(inplace=True) + self.fc1 = build_conv_layer( + None, channels, inter_channels, 1, groups=self.groups) + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, inter_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.fc2 = build_conv_layer( + None, inter_channels, channels * radix, 1, groups=self.groups) + self.rsoftmax = RSoftmax(radix, groups) + + @property + def norm0(self): + """nn.Module: the normalization layer named "norm0" """ + return getattr(self, self.norm0_name) + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + def forward(self, x): + x = self.conv(x) + x = self.norm0(x) + x = self.relu(x) + + batch, rchannel = x.shape[:2] + batch = x.size(0) + if self.radix > 1: + splits = x.view(batch, self.radix, -1, *x.shape[2:]) + gap = splits.sum(dim=1) + else: + gap = x + gap = F.adaptive_avg_pool2d(gap, 1) + gap = self.fc1(gap) + + gap = self.norm1(gap) + gap = self.relu(gap) + + atten = self.fc2(gap) + atten = self.rsoftmax(atten).view(batch, -1, 1, 1) + + if self.radix > 1: + attens = atten.view(batch, self.radix, -1, *atten.shape[2:]) + out = torch.sum(attens * splits, dim=1) + else: + out = atten * x + return out.contiguous() + + +class Bottleneck(_Bottleneck): + """Bottleneck block for ResNeSt. + + Args: + inplane (int): Input planes of this block. + planes (int): Middle planes of this block. + groups (int): Groups of conv2. + width_per_group (int): Width per group of conv2. 64x4d indicates + ``groups=64, width_per_group=4`` and 32x8d indicates + ``groups=32, width_per_group=8``. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of inter_channels in + SplitAttentionConv2d. Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + kwargs (dict): Key word arguments for base class. 
+ """ + expansion = 4 + + def __init__(self, + inplanes, + planes, + groups=1, + base_width=4, + base_channels=64, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + """Bottleneck block for ResNeSt.""" + super().__init__(inplanes, planes, **kwargs) + + if groups == 1: + width = self.planes + else: + width = math.floor(self.planes * + (base_width / base_channels)) * groups + + self.avg_down_stride = avg_down_stride and self.conv2_stride > 1 + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, width, postfix=1) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.inplanes, + width, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.with_modulated_dcn = False + self.conv2 = SplitAttentionConv2d( + width, + width, + kernel_size=3, + stride=1 if self.avg_down_stride else self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + radix=radix, + reduction_factor=reduction_factor, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + dcn=self.dcn) + delattr(self, self.norm2_name) + + if self.avg_down_stride: + self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1) + + self.conv3 = build_conv_layer( + self.conv_cfg, + width, + self.planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv1_plugin_names) + + out = self.conv2(out) + + if self.avg_down_stride: + out = self.avd_layer(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv2_plugin_names) + + out = self.conv3(out) + out = self.norm3(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv3_plugin_names) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@MODELS.register_module() +class ResNeSt(ResNetV1d): + """ResNeSt backbone. + + This backbone is the implementation of `ResNeSt: + Split-Attention Networks `_. + + Args: + groups (int): Number of groups of Bottleneck. Default: 1 + base_width (int): Base width of Bottleneck. Default: 4 + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of inter_channels in + SplitAttentionConv2d. Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + kwargs (dict): Keyword arguments for ResNet. 
+ """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)), + 200: (Bottleneck, (3, 24, 36, 3)) + } + + def __init__(self, + groups=1, + base_width=4, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + self.groups = groups + self.base_width = base_width + self.radix = radix + self.reduction_factor = reduction_factor + self.avg_down_stride = avg_down_stride + super().__init__(**kwargs) + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``.""" + return ResLayer( + groups=self.groups, + base_width=self.base_width, + base_channels=self.base_channels, + radix=self.radix, + reduction_factor=self.reduction_factor, + avg_down_stride=self.avg_down_stride, + **kwargs) diff --git a/mmseg/models/backbones/resnet.py b/mmseg/models/backbones/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9226c90d85c938e76f322e58643ee9d7b17ba27b --- /dev/null +++ b/mmseg/models/backbones/resnet.py @@ -0,0 +1,712 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer +from mmengine.model import BaseModule +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmseg.registry import MODELS +from ..utils import ResLayer + + +class BasicBlock(BaseModule): + """Basic block for ResNet.""" + + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None, + plugins=None, + init_cfg=None): + super().__init__(init_cfg) + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, planes, planes, 3, padding=1, bias=False) + self.add_module(self.norm2_name, norm2) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.with_cp = with_cp + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class Bottleneck(BaseModule): + """Bottleneck block for ResNet. + + If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is + "caffe", the stride-two layer is the first 1x1 conv layer. 
+ """ + + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None, + plugins=None, + init_cfg=None): + super().__init__(init_cfg) + assert style in ['pytorch', 'caffe'] + assert dcn is None or isinstance(dcn, dict) + assert plugins is None or isinstance(plugins, list) + if plugins is not None: + allowed_position = ['after_conv1', 'after_conv2', 'after_conv3'] + assert all(p['position'] in allowed_position for p in plugins) + + self.inplanes = inplanes + self.planes = planes + self.stride = stride + self.dilation = dilation + self.style = style + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.dcn = dcn + self.with_dcn = dcn is not None + self.plugins = plugins + self.with_plugins = plugins is not None + + if self.with_plugins: + # collect plugins for conv1/conv2/conv3 + self.after_conv1_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv1' + ] + self.after_conv2_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv2' + ] + self.after_conv3_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv3' + ] + + if self.style == 'pytorch': + self.conv1_stride = 1 + self.conv2_stride = stride + else: + self.conv1_stride = stride + self.conv2_stride = 1 + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + norm_cfg, planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + fallback_on_stride = False + if self.with_dcn: + fallback_on_stride = dcn.pop('fallback_on_stride', False) + if not self.with_dcn or fallback_on_stride: + self.conv2 = build_conv_layer( + conv_cfg, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + else: + assert self.conv_cfg is None, 'conv_cfg must be None for DCN' + self.conv2 = build_conv_layer( + dcn, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + conv_cfg, + planes, + planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + if self.with_plugins: + self.after_conv1_plugin_names = self.make_block_plugins( + planes, self.after_conv1_plugins) + self.after_conv2_plugin_names = self.make_block_plugins( + planes, self.after_conv2_plugins) + self.after_conv3_plugin_names = self.make_block_plugins( + planes * self.expansion, self.after_conv3_plugins) + + def make_block_plugins(self, in_channels, plugins): + """make plugins for block. + + Args: + in_channels (int): Input channels of plugin. + plugins (list[dict]): List of plugins cfg to build. + + Returns: + list[str]: List of the names of plugin. 
+ """ + assert isinstance(plugins, list) + plugin_names = [] + for plugin in plugins: + plugin = plugin.copy() + name, layer = build_plugin_layer( + plugin, + in_channels=in_channels, + postfix=plugin.pop('postfix', '')) + assert not hasattr(self, name), f'duplicate plugin {name}' + self.add_module(name, layer) + plugin_names.append(name) + return plugin_names + + def forward_plugin(self, x, plugin_names): + """Forward function for plugins.""" + out = x + for name in plugin_names: + out = getattr(self, name)(x) + return out + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) + + @property + def norm3(self): + """nn.Module: normalization layer after the third convolution layer""" + return getattr(self, self.norm3_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv1_plugin_names) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv2_plugin_names) + + out = self.conv3(out) + out = self.norm3(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv3_plugin_names) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@MODELS.register_module() +class ResNet(BaseModule): + """ResNet backbone. + + This backbone is the improved implementation of `Deep Residual Learning + for Image Recognition `_. + + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Number of stem channels. Default: 64. + base_channels (int): Number of base channels of res layer. Default: 64. + num_stages (int): Resnet stages, normally 4. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: (1, 2, 2, 2). + dilations (Sequence[int]): Dilation of each stage. + Default: (1, 1, 1, 1). + out_indices (Sequence[int]): Output from which stages. + Default: (0, 1, 2, 3). + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. Default: 'pytorch'. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): Dictionary to construct and config conv layer. + When conv_cfg is None, cfg will be set to dict(type='Conv2d'). + Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. 
+ dcn (dict | None): Dictionary to construct and config DCN conv layer. + When dcn is not None, conv_cfg must be None. Default: None. + stage_with_dcn (Sequence[bool]): Whether to set DCN conv for each + stage. The length of stage_with_dcn is equal to num_stages. + Default: (False, False, False, False). + plugins (list[dict]): List of plugins for stages, each dict contains: + + - cfg (dict, required): Cfg dict to build plugin. + + - position (str, required): Position inside block to insert plugin, + options: 'after_conv1', 'after_conv2', 'after_conv3'. + + - stages (tuple[bool], optional): Stages to apply plugin, length + should be same as 'num_stages'. + Default: None. + multi_grid (Sequence[int]|None): Multi grid dilation rates of last + stage. Default: None. + contract_dilation (bool): Whether contract first dilation of each layer + Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + pretrained (str, optional): model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + + Example: + >>> from mmseg.models import ResNet + >>> import torch + >>> self = ResNet(depth=18) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 64, 8, 8) + (1, 128, 4, 4) + (1, 256, 2, 2) + (1, 512, 1, 1) + """ + + arch_settings = { + 18: (BasicBlock, (2, 2, 2, 2)), + 34: (BasicBlock, (3, 4, 6, 3)), + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + depth, + in_channels=3, + stem_channels=64, + base_channels=64, + num_stages=4, + strides=(1, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(0, 1, 2, 3), + style='pytorch', + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + dcn=None, + stage_with_dcn=(False, False, False, False), + plugins=None, + multi_grid=None, + contract_dilation=False, + with_cp=False, + zero_init_residual=True, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg) + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for resnet') + + self.pretrained = pretrained + self.zero_init_residual = zero_init_residual + block_init_cfg = None + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + block = self.arch_settings[depth][0] + if self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict( + type='Constant', + val=0, + override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', + val=0, + override=dict(name='norm3')) + else: + raise TypeError('pretrained must be a str or None') + + self.depth = depth + self.stem_channels = stem_channels + self.base_channels = base_channels + 
self.num_stages = num_stages + assert num_stages >= 1 and num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.dcn = dcn + self.stage_with_dcn = stage_with_dcn + if dcn is not None: + assert len(stage_with_dcn) == num_stages + self.plugins = plugins + self.multi_grid = multi_grid + self.contract_dilation = contract_dilation + self.block, stage_blocks = self.arch_settings[depth] + self.stage_blocks = stage_blocks[:num_stages] + self.inplanes = stem_channels + + self._make_stem_layer(in_channels, stem_channels) + + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + stride = strides[i] + dilation = dilations[i] + dcn = self.dcn if self.stage_with_dcn[i] else None + if plugins is not None: + stage_plugins = self.make_stage_plugins(plugins, i) + else: + stage_plugins = None + # multi grid is applied to last layer only + stage_multi_grid = multi_grid if i == len( + self.stage_blocks) - 1 else None + planes = base_channels * 2**i + res_layer = self.make_res_layer( + block=self.block, + inplanes=self.inplanes, + planes=planes, + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + dcn=dcn, + plugins=stage_plugins, + multi_grid=stage_multi_grid, + contract_dilation=contract_dilation, + init_cfg=block_init_cfg) + self.inplanes = planes * self.block.expansion + layer_name = f'layer{i+1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = self.block.expansion * base_channels * 2**( + len(self.stage_blocks) - 1) + + def make_stage_plugins(self, plugins, stage_idx): + """make plugins for ResNet 'stage_idx'th stage . + + Currently we support to insert 'context_block', + 'empirical_attention_block', 'nonlocal_block' into the backbone like + ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of + Bottleneck. + + An example of plugins format could be : + >>> plugins=[ + ... dict(cfg=dict(type='xxx', arg1='xxx'), + ... stages=(False, True, True, True), + ... position='after_conv2'), + ... dict(cfg=dict(type='yyy'), + ... stages=(True, True, True, True), + ... position='after_conv3'), + ... dict(cfg=dict(type='zzz', postfix='1'), + ... stages=(True, True, True, True), + ... position='after_conv3'), + ... dict(cfg=dict(type='zzz', postfix='2'), + ... stages=(True, True, True, True), + ... position='after_conv3') + ... ] + >>> self = ResNet(depth=18) + >>> stage_plugins = self.make_stage_plugins(plugins, 0) + >>> assert len(stage_plugins) == 3 + + Suppose 'stage_idx=0', the structure of blocks in the stage would be: + conv1-> conv2->conv3->yyy->zzz1->zzz2 + Suppose 'stage_idx=1', the structure of blocks in the stage would be: + conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2 + + If stages is missing, the plugin would be applied to all stages. + + Args: + plugins (list[dict]): List of plugins cfg to build. The postfix is + required if multiple same type plugins are inserted. 
+ stage_idx (int): Index of stage to build + + Returns: + list[dict]: Plugins for current stage + """ + stage_plugins = [] + for plugin in plugins: + plugin = plugin.copy() + stages = plugin.pop('stages', None) + assert stages is None or len(stages) == self.num_stages + # whether to insert plugin into current stage + if stages is None or stages[stage_idx]: + stage_plugins.append(plugin) + + return stage_plugins + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``.""" + return ResLayer(**kwargs) + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + def _make_stem_layer(self, in_channels, stem_channels): + """Make stem layer for ResNet.""" + if self.deep_stem: + self.stem = nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels // 2)[1], + nn.ReLU(inplace=True), + build_conv_layer( + self.conv_cfg, + stem_channels // 2, + stem_channels // 2, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels // 2)[1], + nn.ReLU(inplace=True), + build_conv_layer( + self.conv_cfg, + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels)[1], + nn.ReLU(inplace=True)) + else: + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, stem_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def _freeze_stages(self): + """Freeze stages param and norm stats.""" + if self.frozen_stages >= 0: + if self.deep_stem: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + else: + self.norm1.eval() + for m in [self.conv1, self.norm1]: + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = getattr(self, f'layer{i}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def forward(self, x): + """Forward function.""" + if self.deep_stem: + x = self.stem(x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.maxpool(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) + + def train(self, mode=True): + """Convert the model into training mode while keep normalization layer + freezed.""" + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + +@MODELS.register_module() +class ResNetV1c(ResNet): + """ResNetV1c variant described in [1]_. + + Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv in + the input stem with three 3x3 convs. For more details please refer to `Bag + of Tricks for Image Classification with Convolutional Neural Networks + `_. + """ + + def __init__(self, **kwargs): + super().__init__(deep_stem=True, avg_down=False, **kwargs) + + +@MODELS.register_module() +class ResNetV1d(ResNet): + """ResNetV1d variant described in [1]_. 
+ + Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in + the input stem with three 3x3 convs. And in the downsampling block, a 2x2 + avg_pool with stride 2 is added before conv, whose stride is changed to 1. + """ + + def __init__(self, **kwargs): + super().__init__(deep_stem=True, avg_down=True, **kwargs) diff --git a/mmseg/models/backbones/resnext.py b/mmseg/models/backbones/resnext.py new file mode 100644 index 0000000000000000000000000000000000000000..67a244a12f61b78ee12e89e8b45868781208614c --- /dev/null +++ b/mmseg/models/backbones/resnext.py @@ -0,0 +1,150 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmseg.registry import MODELS +from ..utils import ResLayer +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResNet + + +class Bottleneck(_Bottleneck): + """Bottleneck block for ResNeXt. + + If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is + "caffe", the stride-two layer is the first 1x1 conv layer. + """ + + def __init__(self, + inplanes, + planes, + groups=1, + base_width=4, + base_channels=64, + **kwargs): + super().__init__(inplanes, planes, **kwargs) + + if groups == 1: + width = self.planes + else: + width = math.floor(self.planes * + (base_width / base_channels)) * groups + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, width, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + self.norm_cfg, width, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.inplanes, + width, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + fallback_on_stride = False + self.with_modulated_dcn = False + if self.with_dcn: + fallback_on_stride = self.dcn.pop('fallback_on_stride', False) + if not self.with_dcn or fallback_on_stride: + self.conv2 = build_conv_layer( + self.conv_cfg, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + else: + assert self.conv_cfg is None, 'conv_cfg must be None for DCN' + self.conv2 = build_conv_layer( + self.dcn, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + self.conv_cfg, + width, + self.planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + +@MODELS.register_module() +class ResNeXt(ResNet): + """ResNeXt backbone. + + This backbone is the implementation of `Aggregated + Residual Transformations for Deep Neural + Networks `_. + + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Normally 3. + num_stages (int): Resnet stages, normally 4. + groups (int): Group of resnext. + base_width (int): Base width of resnext. + strides (Sequence[int]): Strides of the first block of each stage. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + frozen_stages (int): Stages to be frozen (all param fixed). 
-1 means + not freezing any parameters. + norm_cfg (dict): dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + zero_init_residual (bool): whether to use zero init for last norm layer + in resblocks to let them behave as identity. + + Example: + >>> from mmseg.models import ResNeXt + >>> import torch + >>> self = ResNeXt(depth=50) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 256, 8, 8) + (1, 512, 4, 4) + (1, 1024, 2, 2) + (1, 2048, 1, 1) + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, groups=1, base_width=4, **kwargs): + self.groups = groups + self.base_width = base_width + super().__init__(**kwargs) + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``""" + return ResLayer( + groups=self.groups, + base_width=self.base_width, + base_channels=self.base_channels, + **kwargs) diff --git a/mmseg/models/backbones/stdc.py b/mmseg/models/backbones/stdc.py new file mode 100644 index 0000000000000000000000000000000000000000..758a3c92e07dc8d2051f670adf00d163019d758c --- /dev/null +++ b/mmseg/models/backbones/stdc.py @@ -0,0 +1,422 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Modified from https://github.com/MichaelFan01/STDC-Seg.""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule, ModuleList, Sequential + +from mmseg.registry import MODELS +from ..utils import resize +from .bisenetv1 import AttentionRefinementModule + + +class STDCModule(BaseModule): + """STDCModule. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels before scaling. + stride (int): The number of stride for the first conv layer. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): The activation config for conv layers. + num_convs (int): Numbers of conv layers. + fusion_type (str): Type of fusion operation. Default: 'add'. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. 
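+
+    The snippet below is a minimal usage sketch; the tensor sizes are
+    illustrative only.
+
+    Example:
+        >>> import torch
+        >>> from mmseg.models.backbones.stdc import STDCModule
+        >>> module = STDCModule(32, 64, stride=2,
+        ...                     norm_cfg=dict(type='BN'), fusion_type='cat')
+        >>> inputs = torch.rand(1, 32, 64, 64)
+        >>> # The four conv outputs (32 + 16 + 8 + 8 channels) are
+        >>> # concatenated to reach `out_channels`, downsampled by `stride`.
+        >>> module(inputs).shape
+        torch.Size([1, 64, 32, 32])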
+ """ + + def __init__(self, + in_channels, + out_channels, + stride, + norm_cfg=None, + act_cfg=None, + num_convs=4, + fusion_type='add', + init_cfg=None): + super().__init__(init_cfg=init_cfg) + assert num_convs > 1 + assert fusion_type in ['add', 'cat'] + self.stride = stride + self.with_downsample = True if self.stride == 2 else False + self.fusion_type = fusion_type + + self.layers = ModuleList() + conv_0 = ConvModule( + in_channels, out_channels // 2, kernel_size=1, norm_cfg=norm_cfg) + + if self.with_downsample: + self.downsample = ConvModule( + out_channels // 2, + out_channels // 2, + kernel_size=3, + stride=2, + padding=1, + groups=out_channels // 2, + norm_cfg=norm_cfg, + act_cfg=None) + + if self.fusion_type == 'add': + self.layers.append(nn.Sequential(conv_0, self.downsample)) + self.skip = Sequential( + ConvModule( + in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=1, + groups=in_channels, + norm_cfg=norm_cfg, + act_cfg=None), + ConvModule( + in_channels, + out_channels, + 1, + norm_cfg=norm_cfg, + act_cfg=None)) + else: + self.layers.append(conv_0) + self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + else: + self.layers.append(conv_0) + + for i in range(1, num_convs): + out_factor = 2**(i + 1) if i != num_convs - 1 else 2**i + self.layers.append( + ConvModule( + out_channels // 2**i, + out_channels // out_factor, + kernel_size=3, + stride=1, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, inputs): + if self.fusion_type == 'add': + out = self.forward_add(inputs) + else: + out = self.forward_cat(inputs) + return out + + def forward_add(self, inputs): + layer_outputs = [] + x = inputs.clone() + for layer in self.layers: + x = layer(x) + layer_outputs.append(x) + if self.with_downsample: + inputs = self.skip(inputs) + + return torch.cat(layer_outputs, dim=1) + inputs + + def forward_cat(self, inputs): + x0 = self.layers[0](inputs) + layer_outputs = [x0] + for i, layer in enumerate(self.layers[1:]): + if i == 0: + if self.with_downsample: + x = layer(self.downsample(x0)) + else: + x = layer(x0) + else: + x = layer(x) + layer_outputs.append(x) + if self.with_downsample: + layer_outputs[0] = self.skip(x0) + return torch.cat(layer_outputs, dim=1) + + +class FeatureFusionModule(BaseModule): + """Feature Fusion Module. This module is different from FeatureFusionModule + in BiSeNetV1. It uses two ConvModules in `self.attention` whose inter + channel number is calculated by given `scale_factor`, while + FeatureFusionModule in BiSeNetV1 only uses one ConvModule in + `self.conv_atten`. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + scale_factor (int): The number of channel scale factor. + Default: 4. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): The activation config for conv layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. 
+ """ + + def __init__(self, + in_channels, + out_channels, + scale_factor=4, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + channels = out_channels // scale_factor + self.conv0 = ConvModule( + in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg) + self.attention = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + ConvModule( + out_channels, + channels, + 1, + norm_cfg=None, + bias=False, + act_cfg=act_cfg), + ConvModule( + channels, + out_channels, + 1, + norm_cfg=None, + bias=False, + act_cfg=None), nn.Sigmoid()) + + def forward(self, spatial_inputs, context_inputs): + inputs = torch.cat([spatial_inputs, context_inputs], dim=1) + x = self.conv0(inputs) + attn = self.attention(x) + x_attn = x * attn + return x_attn + x + + +@MODELS.register_module() +class STDCNet(BaseModule): + """This backbone is the implementation of `Rethinking BiSeNet For Real-time + Semantic Segmentation `_. + + Args: + stdc_type (int): The type of backbone structure, + `STDCNet1` and`STDCNet2` denotes two main backbones in paper, + whose FLOPs is 813M and 1446M, respectively. + in_channels (int): The num of input_channels. + channels (tuple[int]): The output channels for each stage. + bottleneck_type (str): The type of STDC Module type, the value must + be 'add' or 'cat'. + norm_cfg (dict): Config dict for normalization layer. + act_cfg (dict): The activation config for conv layers. + num_convs (int): Numbers of conv layer at each STDC Module. + Default: 4. + with_final_conv (bool): Whether add a conv layer at the Module output. + Default: True. + pretrained (str, optional): Model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + + Example: + >>> import torch + >>> stdc_type = 'STDCNet1' + >>> in_channels = 3 + >>> channels = (32, 64, 256, 512, 1024) + >>> bottleneck_type = 'cat' + >>> inputs = torch.rand(1, 3, 1024, 2048) + >>> self = STDCNet(stdc_type, in_channels, + ... channels, bottleneck_type).eval() + >>> outputs = self.forward(inputs) + >>> for i in range(len(outputs)): + ... print(f'outputs[{i}].shape = {outputs[i].shape}') + outputs[0].shape = torch.Size([1, 256, 128, 256]) + outputs[1].shape = torch.Size([1, 512, 64, 128]) + outputs[2].shape = torch.Size([1, 1024, 32, 64]) + """ + + arch_settings = { + 'STDCNet1': [(2, 1), (2, 1), (2, 1)], + 'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)] + } + + def __init__(self, + stdc_type, + in_channels, + channels, + bottleneck_type, + norm_cfg, + act_cfg, + num_convs=4, + with_final_conv=False, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + assert stdc_type in self.arch_settings, \ + f'invalid structure {stdc_type} for STDCNet.' + assert bottleneck_type in ['add', 'cat'],\ + f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}' + + assert len(channels) == 5,\ + f'invalid channels length {len(channels)} for STDCNet.' 
+ + self.in_channels = in_channels + self.channels = channels + self.stage_strides = self.arch_settings[stdc_type] + self.prtrained = pretrained + self.num_convs = num_convs + self.with_final_conv = with_final_conv + + self.stages = ModuleList([ + ConvModule( + self.in_channels, + self.channels[0], + kernel_size=3, + stride=2, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + self.channels[0], + self.channels[1], + kernel_size=3, + stride=2, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + ]) + # `self.num_shallow_features` is the number of shallow modules in + # `STDCNet`, which is noted as `Stage1` and `Stage2` in original paper. + # They are both not used for following modules like Attention + # Refinement Module and Feature Fusion Module. + # Thus they would be cut from `outs`. Please refer to Figure 4 + # of original paper for more details. + self.num_shallow_features = len(self.stages) + + for strides in self.stage_strides: + idx = len(self.stages) - 1 + self.stages.append( + self._make_stage(self.channels[idx], self.channels[idx + 1], + strides, norm_cfg, act_cfg, bottleneck_type)) + # After appending, `self.stages` is a ModuleList including several + # shallow modules and STDCModules. + # (len(self.stages) == + # self.num_shallow_features + len(self.stage_strides)) + if self.with_final_conv: + self.final_conv = ConvModule( + self.channels[-1], + max(1024, self.channels[-1]), + 1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def _make_stage(self, in_channels, out_channels, strides, norm_cfg, + act_cfg, bottleneck_type): + layers = [] + for i, stride in enumerate(strides): + layers.append( + STDCModule( + in_channels if i == 0 else out_channels, + out_channels, + stride, + norm_cfg, + act_cfg, + num_convs=self.num_convs, + fusion_type=bottleneck_type)) + return Sequential(*layers) + + def forward(self, x): + outs = [] + for stage in self.stages: + x = stage(x) + outs.append(x) + if self.with_final_conv: + outs[-1] = self.final_conv(outs[-1]) + outs = outs[self.num_shallow_features:] + return tuple(outs) + + +@MODELS.register_module() +class STDCContextPathNet(BaseModule): + """STDCNet with Context Path. The `outs` below is a list of three feature + maps from deep to shallow, whose height and width is from small to big, + respectively. The biggest feature map of `outs` is outputted for + `STDCHead`, where Detail Loss would be calculated by Detail Ground-truth. + The other two feature maps are used for Attention Refinement Module, + respectively. Besides, the biggest feature map of `outs` and the last + output of Attention Refinement Module are concatenated for Feature Fusion + Module. Then, this fusion feature map `feat_fuse` would be outputted for + `decode_head`. More details please refer to Figure 4 of original paper. + + Args: + backbone_cfg (dict): Config dict for stdc backbone. + last_in_channels (tuple(int)), The number of channels of last + two feature maps from stdc backbone. Default: (1024, 512). + out_channels (int): The channels of output feature maps. + Default: 128. + ffm_cfg (dict): Config dict for Feature Fusion Module. Default: + `dict(in_channels=512, out_channels=256, scale_factor=4)`. + upsample_mode (str): Algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'``. Default: ``'nearest'``. + align_corners (str): align_corners argument of F.interpolate. It + must be `None` if upsample_mode is ``'nearest'``. Default: None. + norm_cfg (dict): Config dict for normalization layer. 
+ Default: dict(type='BN'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + + Return: + outputs (tuple): The tuple of list of output feature map for + auxiliary heads and decoder head. + """ + + def __init__(self, + backbone_cfg, + last_in_channels=(1024, 512), + out_channels=128, + ffm_cfg=dict( + in_channels=512, out_channels=256, scale_factor=4), + upsample_mode='nearest', + align_corners=None, + norm_cfg=dict(type='BN'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.backbone = MODELS.build(backbone_cfg) + self.arms = ModuleList() + self.convs = ModuleList() + for channels in last_in_channels: + self.arms.append(AttentionRefinementModule(channels, out_channels)) + self.convs.append( + ConvModule( + out_channels, + out_channels, + 3, + padding=1, + norm_cfg=norm_cfg)) + self.conv_avg = ConvModule( + last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg) + + self.ffm = FeatureFusionModule(**ffm_cfg) + + self.upsample_mode = upsample_mode + self.align_corners = align_corners + + def forward(self, x): + outs = list(self.backbone(x)) + avg = F.adaptive_avg_pool2d(outs[-1], 1) + avg_feat = self.conv_avg(avg) + + feature_up = resize( + avg_feat, + size=outs[-1].shape[2:], + mode=self.upsample_mode, + align_corners=self.align_corners) + arms_out = [] + for i in range(len(self.arms)): + x_arm = self.arms[i](outs[len(outs) - 1 - i]) + feature_up + feature_up = resize( + x_arm, + size=outs[len(outs) - 1 - i - 1].shape[2:], + mode=self.upsample_mode, + align_corners=self.align_corners) + feature_up = self.convs[i](feature_up) + arms_out.append(feature_up) + + feat_fuse = self.ffm(outs[0], arms_out[1]) + + # The `outputs` has four feature maps. + # `outs[0]` is outputted for `STDCHead` auxiliary head. + # Two feature maps of `arms_out` are outputted for auxiliary head. + # `feat_fuse` is outputted for decoder head. + outputs = [outs[0]] + list(arms_out) + [feat_fuse] + return tuple(outputs) diff --git a/mmseg/models/backbones/swin.py b/mmseg/models/backbones/swin.py new file mode 100644 index 0000000000000000000000000000000000000000..c0ace3c1391c1e8bf85961f69581f6d69be0b9af --- /dev/null +++ b/mmseg/models/backbones/swin.py @@ -0,0 +1,755 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from collections import OrderedDict +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, build_dropout +from mmengine.logging import print_log +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import (constant_init, trunc_normal_, + trunc_normal_init) +from mmengine.runner import CheckpointLoader +from mmengine.utils import to_2tuple + +from mmseg.registry import MODELS +from ..utils.embed import PatchEmbed, PatchMerging + + +class WindowMSA(BaseModule): + """Window based multi-head self-attention (W-MSA) module with relative + position bias. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): The height and width of the window. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. 
+ Default: 0.0 + proj_drop_rate (float, optional): Dropout ratio of output. Default: 0. + init_cfg (dict | None, optional): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + self.embed_dims = embed_dims + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.scale = qk_scale or head_embed_dims**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # About 2x faster than original impl + Wh, Ww = self.window_size + rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww) + rel_position_index = rel_index_coords + rel_index_coords.T + rel_position_index = rel_position_index.flip(1).contiguous() + self.register_buffer('relative_position_index', rel_position_index) + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + + self.softmax = nn.Softmax(dim=-1) + + def init_weights(self): + trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward(self, x, mask=None): + """ + Args: + + x (tensor): input features with shape of (num_windows*B, N, C) + mask (tensor | None, Optional): mask with shape of (num_windows, + Wh*Ww, Wh*Ww), value should be between (-inf, 0]. + """ + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + # make torchscript happy (cannot use tensor as tuple) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + @staticmethod + def double_step_seq(step1, len1, step2, len2): + seq1 = torch.arange(0, step1 * len1, step1) + seq2 = torch.arange(0, step2 * len2, step2) + return (seq1[:, None] + seq2[None, :]).reshape(1, -1) + + +class ShiftWindowMSA(BaseModule): + """Shifted Window Multihead Self-Attention Module. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. + shift_size (int, optional): The shift step of each window towards + right-bottom. If zero, act as regular window-msa. Defaults to 0. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Defaults: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. 
+ Defaults: 0. + proj_drop_rate (float, optional): Dropout ratio of output. + Defaults: 0. + dropout_layer (dict, optional): The dropout_layer used before output. + Defaults: dict(type='DropPath', drop_prob=0.). + init_cfg (dict, optional): The extra config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + shift_size=0, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0, + proj_drop_rate=0, + dropout_layer=dict(type='DropPath', drop_prob=0.), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.window_size = window_size + self.shift_size = shift_size + assert 0 <= self.shift_size < self.window_size + + self.w_msa = WindowMSA( + embed_dims=embed_dims, + num_heads=num_heads, + window_size=to_2tuple(window_size), + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=proj_drop_rate, + init_cfg=None) + + self.drop = build_dropout(dropout_layer) + + def forward(self, query, hw_shape): + B, L, C = query.shape + H, W = hw_shape + assert L == H * W, 'input feature has wrong size' + query = query.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b)) + H_pad, W_pad = query.shape[1], query.shape[2] + + # cyclic shift + if self.shift_size > 0: + shifted_query = torch.roll( + query, + shifts=(-self.shift_size, -self.shift_size), + dims=(1, 2)) + + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device) + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + # nW, window_size, window_size, 1 + mask_windows = self.window_partition(img_mask) + mask_windows = mask_windows.view( + -1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + else: + shifted_query = query + attn_mask = None + + # nW*B, window_size, window_size, C + query_windows = self.window_partition(shifted_query) + # nW*B, window_size*window_size, C + query_windows = query_windows.view(-1, self.window_size**2, C) + + # W-MSA/SW-MSA (nW*B, window_size*window_size, C) + attn_windows = self.w_msa(query_windows, mask=attn_mask) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, + self.window_size, C) + + # B H' W' C + shifted_x = self.window_reverse(attn_windows, H_pad, W_pad) + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, + shifts=(self.shift_size, self.shift_size), + dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + x = self.drop(x) + return x + + def window_reverse(self, windows, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + window_size = self.window_size + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, 
window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + def window_partition(self, x): + """ + Args: + x: (B, H, W, C) + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + window_size = self.window_size + x = x.view(B, H // window_size, window_size, W // window_size, + window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + windows = windows.view(-1, window_size, window_size, C) + return windows + + +class SwinBlock(BaseModule): + """" + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + window_size (int, optional): The local window scale. Default: 7. + shift (bool, optional): whether to shift window or not. Default False. + qkv_bias (bool, optional): enable bias for qkv if True. Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + drop_rate (float, optional): Dropout rate. Default: 0. + attn_drop_rate (float, optional): Attention dropout rate. Default: 0. + drop_path_rate (float, optional): Stochastic depth rate. Default: 0. + act_cfg (dict, optional): The config dict of activation function. + Default: dict(type='GELU'). + norm_cfg (dict, optional): The config dict of normalization. + Default: dict(type='LN'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + init_cfg (dict | list | None, optional): The init config. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + window_size=7, + shift=False, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + with_cp=False, + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + + self.with_cp = with_cp + + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + self.attn = ShiftWindowMSA( + embed_dims=embed_dims, + num_heads=num_heads, + window_size=window_size, + shift_size=window_size // 2 if shift else 0, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + init_cfg=None) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=2, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=True, + init_cfg=None) + + def forward(self, x, hw_shape): + + def _inner_forward(x): + identity = x + x = self.norm1(x) + x = self.attn(x, hw_shape) + + x = x + identity + + identity = x + x = self.norm2(x) + x = self.ffn(x, identity=identity) + + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + + return x + + +class SwinBlockSequence(BaseModule): + """Implements one stage in Swin Transformer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + depth (int): The number of blocks in this stage. + window_size (int, optional): The local window scale. Default: 7. + qkv_bias (bool, optional): enable bias for qkv if True. Default: True. 
+ qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + drop_rate (float, optional): Dropout rate. Default: 0. + attn_drop_rate (float, optional): Attention dropout rate. Default: 0. + drop_path_rate (float | list[float], optional): Stochastic depth + rate. Default: 0. + downsample (BaseModule | None, optional): The downsample operation + module. Default: None. + act_cfg (dict, optional): The config dict of activation function. + Default: dict(type='GELU'). + norm_cfg (dict, optional): The config dict of normalization. + Default: dict(type='LN'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + init_cfg (dict | list | None, optional): The init config. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + depth, + window_size=7, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + downsample=None, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + with_cp=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + if isinstance(drop_path_rate, list): + drop_path_rates = drop_path_rate + assert len(drop_path_rates) == depth + else: + drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)] + + self.blocks = ModuleList() + for i in range(depth): + block = SwinBlock( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=feedforward_channels, + window_size=window_size, + shift=False if i % 2 == 0 else True, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rates[i], + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + init_cfg=None) + self.blocks.append(block) + + self.downsample = downsample + + def forward(self, x, hw_shape): + for block in self.blocks: + x = block(x, hw_shape) + + if self.downsample: + x_down, down_hw_shape = self.downsample(x, hw_shape) + return x_down, down_hw_shape, x, hw_shape + else: + return x, hw_shape, x, hw_shape + + +@MODELS.register_module() +class SwinTransformer(BaseModule): + """Swin Transformer backbone. + + This backbone is the implementation of `Swin Transformer: + Hierarchical Vision Transformer using Shifted + Windows `_. + Inspiration from https://github.com/microsoft/Swin-Transformer. + + Args: + pretrain_img_size (int | tuple[int]): The size of input image when + pretrain. Defaults: 224. + in_channels (int): The num of input channels. + Defaults: 3. + embed_dims (int): The feature dimension. Default: 96. + patch_size (int | tuple[int]): Patch size. Default: 4. + window_size (int): Window size. Default: 7. + mlp_ratio (int | float): Ratio of mlp hidden dim to embedding dim. + Default: 4. + depths (tuple[int]): Depths of each Swin Transformer stage. + Default: (2, 2, 6, 2). + num_heads (tuple[int]): Parallel attention heads of each Swin + Transformer stage. Default: (3, 6, 12, 24). + strides (tuple[int]): The patch merging or patch embedding stride of + each Swin Transformer stage. (In swin, we set kernel size equal to + stride.) Default: (4, 2, 2, 2). + out_indices (tuple[int]): Output from which stages. + Default: (0, 1, 2, 3). + qkv_bias (bool, optional): If True, add a learnable bias to query, key, + value. Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. 
+ patch_norm (bool): If add a norm layer for patch embed and patch + merging. Default: True. + drop_rate (float): Dropout rate. Defaults: 0. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Defaults: 0.1. + use_abs_pos_embed (bool): If True, add absolute position embedding to + the patch embedding. Defaults: False. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LN'). + norm_cfg (dict): Config dict for normalization layer at + output of backone. Defaults: dict(type='LN'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + pretrained (str, optional): model pretrained path. Default: None. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + pretrain_img_size=224, + in_channels=3, + embed_dims=96, + patch_size=4, + window_size=7, + mlp_ratio=4, + depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), + strides=(4, 2, 2, 2), + out_indices=(0, 1, 2, 3), + qkv_bias=True, + qk_scale=None, + patch_norm=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + use_abs_pos_embed=False, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + with_cp=False, + pretrained=None, + frozen_stages=-1, + init_cfg=None): + self.frozen_stages = frozen_stages + + if isinstance(pretrain_img_size, int): + pretrain_img_size = to_2tuple(pretrain_img_size) + elif isinstance(pretrain_img_size, tuple): + if len(pretrain_img_size) == 1: + pretrain_img_size = to_2tuple(pretrain_img_size[0]) + assert len(pretrain_img_size) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(pretrain_img_size)}' + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be specified at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + init_cfg = init_cfg + else: + raise TypeError('pretrained must be a str or None') + + super().__init__(init_cfg=init_cfg) + + num_layers = len(depths) + self.out_indices = out_indices + self.use_abs_pos_embed = use_abs_pos_embed + + assert strides[0] == patch_size, 'Use non-overlapping patch embed.' 
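+
+        # With the default strides (4, 2, 2, 2) and non-overlapping patch
+        # embedding, the four stages output features at 1/4, 1/8, 1/16 and
+        # 1/32 of the input resolution.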
+ + self.patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=strides[0], + padding='corner', + norm_cfg=norm_cfg if patch_norm else None, + init_cfg=None) + + if self.use_abs_pos_embed: + patch_row = pretrain_img_size[0] // patch_size + patch_col = pretrain_img_size[1] // patch_size + num_patches = patch_row * patch_col + self.absolute_pos_embed = nn.Parameter( + torch.zeros((1, num_patches, embed_dims))) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + + # set stochastic depth decay rule + total_depth = sum(depths) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] + + self.stages = ModuleList() + in_channels = embed_dims + for i in range(num_layers): + if i < num_layers - 1: + downsample = PatchMerging( + in_channels=in_channels, + out_channels=2 * in_channels, + stride=strides[i + 1], + norm_cfg=norm_cfg if patch_norm else None, + init_cfg=None) + else: + downsample = None + + stage = SwinBlockSequence( + embed_dims=in_channels, + num_heads=num_heads[i], + feedforward_channels=int(mlp_ratio * in_channels), + depth=depths[i], + window_size=window_size, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])], + downsample=downsample, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + init_cfg=None) + self.stages.append(stage) + if downsample: + in_channels = downsample.out_channels + + self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)] + # Add a norm layer for each output + for i in out_indices: + layer = build_norm_layer(norm_cfg, self.num_features[i])[1] + layer_name = f'norm{i}' + self.add_module(layer_name, layer) + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super().train(mode) + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + if self.use_abs_pos_embed: + self.absolute_pos_embed.requires_grad = False + self.drop_after_pos.eval() + + for i in range(1, self.frozen_stages + 1): + + if (i - 1) in self.out_indices: + norm_layer = getattr(self, f'norm{i-1}') + norm_layer.eval() + for param in norm_layer.parameters(): + param.requires_grad = False + + m = self.stages[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self): + if self.init_cfg is None: + print_log(f'No pre-trained weights for ' + f'{self.__class__.__name__}, ' + f'training start from scratch') + if self.use_abs_pos_embed: + trunc_normal_(self.absolute_pos_embed, std=0.02) + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, nn.LayerNorm): + constant_init(m, val=1.0, bias=0.) 
+ else: + assert 'checkpoint' in self.init_cfg, f'Only support ' \ + f'specify `Pretrained` in ' \ + f'`init_cfg` in ' \ + f'{self.__class__.__name__} ' + ckpt = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], logger=None, map_location='cpu') + if 'state_dict' in ckpt: + _state_dict = ckpt['state_dict'] + elif 'model' in ckpt: + _state_dict = ckpt['model'] + else: + _state_dict = ckpt + + state_dict = OrderedDict() + for k, v in _state_dict.items(): + if k.startswith('backbone.'): + state_dict[k[9:]] = v + else: + state_dict[k] = v + + # strip prefix of state_dict + if list(state_dict.keys())[0].startswith('module.'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + # reshape absolute position embedding + if state_dict.get('absolute_pos_embed') is not None: + absolute_pos_embed = state_dict['absolute_pos_embed'] + N1, L, C1 = absolute_pos_embed.size() + N2, C2, H, W = self.absolute_pos_embed.size() + if N1 != N2 or C1 != C2 or L != H * W: + print_log('Error in loading absolute_pos_embed, pass') + else: + state_dict['absolute_pos_embed'] = absolute_pos_embed.view( + N2, H, W, C2).permute(0, 3, 1, 2).contiguous() + + # interpolate position bias table if needed + relative_position_bias_table_keys = [ + k for k in state_dict.keys() + if 'relative_position_bias_table' in k + ] + for table_key in relative_position_bias_table_keys: + table_pretrained = state_dict[table_key] + table_current = self.state_dict()[table_key] + L1, nH1 = table_pretrained.size() + L2, nH2 = table_current.size() + if nH1 != nH2: + print_log(f'Error in loading {table_key}, pass') + elif L1 != L2: + S1 = int(L1**0.5) + S2 = int(L2**0.5) + table_pretrained_resized = F.interpolate( + table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1), + size=(S2, S2), + mode='bicubic') + state_dict[table_key] = table_pretrained_resized.view( + nH2, L2).permute(1, 0).contiguous() + + # load state_dict + self.load_state_dict(state_dict, strict=False) + + def forward(self, x): + x, hw_shape = self.patch_embed(x) + + if self.use_abs_pos_embed: + x = x + self.absolute_pos_embed + x = self.drop_after_pos(x) + + outs = [] + for i, stage in enumerate(self.stages): + x, hw_shape, out, out_hw_shape = stage(x, hw_shape) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + out = norm_layer(out) + out = out.view(-1, *out_hw_shape, + self.num_features[i]).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + + return outs diff --git a/mmseg/models/backbones/timm_backbone.py b/mmseg/models/backbones/timm_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..1eef302bddeac3cee71412bcb481b68b796e515f --- /dev/null +++ b/mmseg/models/backbones/timm_backbone.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +try: + import timm +except ImportError: + timm = None + +from mmengine.model import BaseModule +from mmengine.registry import MODELS as MMENGINE_MODELS + +from mmseg.registry import MODELS + + +@MODELS.register_module() +class TIMMBackbone(BaseModule): + """Wrapper to use backbones from timm library. More details can be found in + `timm `_ . + + Args: + model_name (str): Name of timm model to instantiate. + pretrained (bool): Load pretrained weights if True. + checkpoint_path (str): Path of checkpoint to load after + model is initialized. + in_channels (int): Number of input image channels. Default: 3. + init_cfg (dict, optional): Initialization config dict + **kwargs: Other timm & model specific arguments. 
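+
+    The snippet below is a minimal usage sketch; it assumes the `timm`
+    package is installed and that the `resnet18` model used here for
+    illustration is available.
+
+    Example:
+        >>> import torch
+        >>> from mmseg.models.backbones.timm_backbone import TIMMBackbone
+        >>> backbone = TIMMBackbone(model_name='resnet18', pretrained=False)
+        >>> inputs = torch.rand(1, 3, 64, 64)
+        >>> outputs = backbone(inputs)
+        >>> # `outputs` is a list of multi-scale feature maps, because the
+        >>> # wrapped model is created with `features_only=True`.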
+ """ + + def __init__( + self, + model_name, + features_only=True, + pretrained=True, + checkpoint_path='', + in_channels=3, + init_cfg=None, + **kwargs, + ): + if timm is None: + raise RuntimeError('timm is not installed') + super().__init__(init_cfg) + if 'norm_layer' in kwargs: + kwargs['norm_layer'] = MMENGINE_MODELS.get(kwargs['norm_layer']) + self.timm_model = timm.create_model( + model_name=model_name, + features_only=features_only, + pretrained=pretrained, + in_chans=in_channels, + checkpoint_path=checkpoint_path, + **kwargs, + ) + + # Make unused parameters None + self.timm_model.global_pool = None + self.timm_model.fc = None + self.timm_model.classifier = None + + # Hack to use pretrained weights from timm + if pretrained or checkpoint_path: + self._is_init = True + + def forward(self, x): + features = self.timm_model(x) + return features diff --git a/mmseg/models/backbones/twins.py b/mmseg/models/backbones/twins.py new file mode 100644 index 0000000000000000000000000000000000000000..b6a6eea795cf53bee6b52ece80d5d90ecc969970 --- /dev/null +++ b/mmseg/models/backbones/twins.py @@ -0,0 +1,588 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.transformer import FFN +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import (constant_init, normal_init, + trunc_normal_init) +from torch.nn.modules.batchnorm import _BatchNorm + +from mmseg.models.backbones.mit import EfficientMultiheadAttention +from mmseg.registry import MODELS +from ..utils.embed import PatchEmbed + + +class GlobalSubsampledAttention(EfficientMultiheadAttention): + """Global Sub-sampled Attention (Spatial Reduction Attention) + + This module is modified from EfficientMultiheadAttention, + which is a module from mmseg.models.backbones.mit.py. + Specifically, there is no difference between + `GlobalSubsampledAttention` and `EfficientMultiheadAttention`, + `GlobalSubsampledAttention` is built as a brand new class + because it is renamed as `Global sub-sampled attention (GSA)` + in paper. + + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + attn_drop (float): A Dropout layer on attn_output_weights. + Default: 0.0. + proj_drop (float): A Dropout layer after `nn.MultiheadAttention`. + Default: 0.0. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. Default: None. + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dims) + or (n, batch, embed_dims). Default: False. + qkv_bias (bool): enable bias for qkv if True. Default: True. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + sr_ratio (int): The ratio of spatial reduction of GSA of PCPVT. + Default: 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. 
+ """ + + def __init__(self, + embed_dims, + num_heads, + attn_drop=0., + proj_drop=0., + dropout_layer=None, + batch_first=True, + qkv_bias=True, + norm_cfg=dict(type='LN'), + sr_ratio=1, + init_cfg=None): + super().__init__( + embed_dims, + num_heads, + attn_drop=attn_drop, + proj_drop=proj_drop, + dropout_layer=dropout_layer, + batch_first=batch_first, + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + sr_ratio=sr_ratio, + init_cfg=init_cfg) + + +class GSAEncoderLayer(BaseModule): + """Implements one encoder layer with GSA. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default: 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default: 0.0. + drop_path_rate (float): Stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): Enable bias for qkv if True. Default: True + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + sr_ratio (float): Kernel_size of conv in Attention modules. Default: 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + sr_ratio=1., + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1] + self.attn = GlobalSubsampledAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + sr_ratio=sr_ratio) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=False) + + self.drop_path = build_dropout( + dict(type='DropPath', drop_prob=drop_path_rate) + ) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x, hw_shape): + x = x + self.drop_path(self.attn(self.norm1(x), hw_shape, identity=0.)) + x = x + self.drop_path(self.ffn(self.norm2(x))) + return x + + +class LocallyGroupedSelfAttention(BaseModule): + """Locally-grouped Self Attention (LSA) module. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. Default: 8 + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: False. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float, optional): Dropout ratio of output. Default: 0. + window_size(int): Window size of LSA. Default: 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. 
+ """ + + def __init__(self, + embed_dims, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + window_size=1, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + assert embed_dims % num_heads == 0, f'dim {embed_dims} should be ' \ + f'divided by num_heads ' \ + f'{num_heads}.' + self.embed_dims = embed_dims + self.num_heads = num_heads + head_dim = embed_dims // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + self.window_size = window_size + + def forward(self, x, hw_shape): + b, n, c = x.shape + h, w = hw_shape + x = x.view(b, h, w, c) + + # pad feature maps to multiples of Local-groups + pad_l = pad_t = 0 + pad_r = (self.window_size - w % self.window_size) % self.window_size + pad_b = (self.window_size - h % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + # calculate attention mask for LSA + Hp, Wp = x.shape[1:-1] + _h, _w = Hp // self.window_size, Wp // self.window_size + mask = torch.zeros((1, Hp, Wp), device=x.device) + mask[:, -pad_b:, :].fill_(1) + mask[:, :, -pad_r:].fill_(1) + + # [B, _h, _w, window_size, window_size, C] + x = x.reshape(b, _h, self.window_size, _w, self.window_size, + c).transpose(2, 3) + mask = mask.reshape(1, _h, self.window_size, _w, + self.window_size).transpose(2, 3).reshape( + 1, _h * _w, + self.window_size * self.window_size) + # [1, _h*_w, window_size*window_size, window_size*window_size] + attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-1000.0)).masked_fill( + attn_mask == 0, float(0.0)) + + # [3, B, _w*_h, nhead, window_size*window_size, dim] + qkv = self.qkv(x).reshape(b, _h * _w, + self.window_size * self.window_size, 3, + self.num_heads, c // self.num_heads).permute( + 3, 0, 1, 4, 2, 5) + q, k, v = qkv[0], qkv[1], qkv[2] + # [B, _h*_w, n_head, window_size*window_size, window_size*window_size] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn + attn_mask.unsqueeze(2) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(b, _h, _w, self.window_size, + self.window_size, c) + x = attn.transpose(2, 3).reshape(b, _h * self.window_size, + _w * self.window_size, c) + if pad_r > 0 or pad_b > 0: + x = x[:, :h, :w, :].contiguous() + + x = x.reshape(b, n, c) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LSAEncoderLayer(BaseModule): + """Implements one encoder layer in Twins-SVT. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default: 0.0. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + drop_path_rate (float): Stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): Enable bias for qkv if True. Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). 
+ window_size (int): Window size of LSA. Default: 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + qk_scale=None, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + window_size=1, + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1] + self.attn = LocallyGroupedSelfAttention(embed_dims, num_heads, + qkv_bias, qk_scale, + attn_drop_rate, drop_rate, + window_size) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=False) + + self.drop_path = build_dropout( + dict(type='DropPath', drop_prob=drop_path_rate) + ) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x, hw_shape): + x = x + self.drop_path(self.attn(self.norm1(x), hw_shape)) + x = x + self.drop_path(self.ffn(self.norm2(x))) + return x + + +class ConditionalPositionEncoding(BaseModule): + """The Conditional Position Encoding (CPE) module. + + The CPE is the implementation of 'Conditional Positional Encodings + for Vision Transformers '_. + + Args: + in_channels (int): Number of input channels. + embed_dims (int): The feature dimension. Default: 768. + stride (int): Stride of conv layer. Default: 1. + """ + + def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.proj = nn.Conv2d( + in_channels, + embed_dims, + kernel_size=3, + stride=stride, + padding=1, + bias=True, + groups=embed_dims) + self.stride = stride + + def forward(self, x, hw_shape): + b, n, c = x.shape + h, w = hw_shape + feat_token = x + cnn_feat = feat_token.transpose(1, 2).view(b, c, h, w) + if self.stride == 1: + x = self.proj(cnn_feat) + cnn_feat + else: + x = self.proj(cnn_feat) + x = x.flatten(2).transpose(1, 2) + return x + + +@MODELS.register_module() +class PCPVT(BaseModule): + """The backbone of Twins-PCPVT. + + This backbone is the implementation of `Twins: Revisiting the Design + of Spatial Attention in Vision Transformers + `_. + + Args: + in_channels (int): Number of input channels. Default: 3. + embed_dims (list): Embedding dimension. Default: [64, 128, 256, 512]. + patch_sizes (list): The patch sizes. Default: [4, 2, 2, 2]. + strides (list): The strides. Default: [4, 2, 2, 2]. + num_heads (int): Number of attention heads. Default: [1, 2, 4, 8]. + mlp_ratios (int): Ratio of mlp hidden dim to embedding dim. + Default: [4, 4, 4, 4]. + out_indices (tuple[int]): Output from which stages. + Default: (0, 1, 2, 3). + qkv_bias (bool): Enable bias for qkv if True. Default: False. + drop_rate (float): Probability of an element to be zeroed. + Default 0. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + drop_path_rate (float): Stochastic depth rate. Default 0.0 + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + depths (list): Depths of each stage. Default [3, 4, 6, 3] + sr_ratios (list): Kernel_size of conv in each Attn module in + Transformer encoder layer. Default: [8, 4, 2, 1]. + norm_after_stage(bool): Add extra norm. Default False. + init_cfg (dict, optional): The Config for initialization. 
+ Defaults to None. + """ + + def __init__(self, + in_channels=3, + embed_dims=[64, 128, 256, 512], + patch_sizes=[4, 2, 2, 2], + strides=[4, 2, 2, 2], + num_heads=[1, 2, 4, 8], + mlp_ratios=[4, 4, 4, 4], + out_indices=(0, 1, 2, 3), + qkv_bias=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + depths=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + norm_after_stage=False, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + self.depths = depths + + # patch_embed + self.patch_embeds = ModuleList() + self.position_encoding_drops = ModuleList() + self.layers = ModuleList() + + for i in range(len(depths)): + self.patch_embeds.append( + PatchEmbed( + in_channels=in_channels if i == 0 else embed_dims[i - 1], + embed_dims=embed_dims[i], + conv_type='Conv2d', + kernel_size=patch_sizes[i], + stride=strides[i], + padding='corner', + norm_cfg=norm_cfg)) + + self.position_encoding_drops.append(nn.Dropout(p=drop_rate)) + + self.position_encodings = ModuleList([ + ConditionalPositionEncoding(embed_dim, embed_dim) + for embed_dim in embed_dims + ]) + + # transformer encoder + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + cur = 0 + + for k in range(len(depths)): + _block = ModuleList([ + GSAEncoderLayer( + embed_dims=embed_dims[k], + num_heads=num_heads[k], + feedforward_channels=mlp_ratios[k] * embed_dims[k], + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + drop_path_rate=dpr[cur + i], + num_fcs=2, + qkv_bias=qkv_bias, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + sr_ratio=sr_ratios[k]) for i in range(depths[k]) + ]) + self.layers.append(_block) + cur += depths[k] + + self.norm_name, norm = build_norm_layer( + norm_cfg, embed_dims[-1], postfix=1) + + self.out_indices = out_indices + self.norm_after_stage = norm_after_stage + if self.norm_after_stage: + self.norm_list = ModuleList() + for dim in embed_dims: + self.norm_list.append(build_norm_layer(norm_cfg, dim)[1]) + + def init_weights(self): + if self.init_cfg is not None: + super().init_weights() + else: + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m, val=1.0, bias=0.) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[ + 1] * m.out_channels + fan_out //= m.groups + normal_init( + m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) + + def forward(self, x): + outputs = list() + + b = x.shape[0] + + for i in range(len(self.depths)): + x, hw_shape = self.patch_embeds[i](x) + h, w = hw_shape + x = self.position_encoding_drops[i](x) + for j, blk in enumerate(self.layers[i]): + x = blk(x, hw_shape) + if j == 0: + x = self.position_encodings[i](x, hw_shape) + if self.norm_after_stage: + x = self.norm_list[i](x) + x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous() + + if i in self.out_indices: + outputs.append(x) + + return tuple(outputs) + + +@MODELS.register_module() +class SVT(PCPVT): + """The backbone of Twins-SVT. 
+ + This backbone is the implementation of `Twins: Revisiting the Design + of Spatial Attention in Vision Transformers + `_. + + Args: + in_channels (int): Number of input channels. Default: 3. + embed_dims (list): Embedding dimension. Default: [64, 128, 256, 512]. + patch_sizes (list): The patch sizes. Default: [4, 2, 2, 2]. + strides (list): The strides. Default: [4, 2, 2, 2]. + num_heads (int): Number of attention heads. Default: [1, 2, 4]. + mlp_ratios (int): Ratio of mlp hidden dim to embedding dim. + Default: [4, 4, 4]. + out_indices (tuple[int]): Output from which stages. + Default: (0, 1, 2, 3). + qkv_bias (bool): Enable bias for qkv if True. Default: False. + drop_rate (float): Dropout rate. Default 0. + attn_drop_rate (float): Dropout ratio of attention weight. + Default 0.0 + drop_path_rate (float): Stochastic depth rate. Default 0.2. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + depths (list): Depths of each stage. Default [4, 4, 4]. + sr_ratios (list): Kernel_size of conv in each Attn module in + Transformer encoder layer. Default: [4, 2, 1]. + windiow_sizes (list): Window size of LSA. Default: [7, 7, 7], + input_features_slice(bool): Input features need slice. Default: False. + norm_after_stage(bool): Add extra norm. Default False. + strides (list): Strides in patch-Embedding modules. Default: (2, 2, 2) + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + in_channels=3, + embed_dims=[64, 128, 256], + patch_sizes=[4, 2, 2, 2], + strides=[4, 2, 2, 2], + num_heads=[1, 2, 4], + mlp_ratios=[4, 4, 4], + out_indices=(0, 1, 2, 3), + qkv_bias=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_cfg=dict(type='LN'), + depths=[4, 4, 4], + sr_ratios=[4, 2, 1], + windiow_sizes=[7, 7, 7], + norm_after_stage=True, + pretrained=None, + init_cfg=None): + super().__init__(in_channels, embed_dims, patch_sizes, strides, + num_heads, mlp_ratios, out_indices, qkv_bias, + drop_rate, attn_drop_rate, drop_path_rate, norm_cfg, + depths, sr_ratios, norm_after_stage, pretrained, + init_cfg) + # transformer encoder + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + for k in range(len(depths)): + for i in range(depths[k]): + if i % 2 == 0: + self.layers[k][i] = \ + LSAEncoderLayer( + embed_dims=embed_dims[k], + num_heads=num_heads[k], + feedforward_channels=mlp_ratios[k] * embed_dims[k], + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[sum(depths[:k])+i], + qkv_bias=qkv_bias, + window_size=windiow_sizes[k]) diff --git a/mmseg/models/backbones/unet.py b/mmseg/models/backbones/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..545921db8e14668e454f5834f9a1618fe0c04ffe --- /dev/null +++ b/mmseg/models/backbones/unet.py @@ -0,0 +1,436 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer +from mmengine.model import BaseModule +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmseg.registry import MODELS +from ..utils import UpConvBlock, Upsample + + +class BasicConvBlock(nn.Module): + """Basic convolutional block for UNet. + + This module consists of several plain convolutional layers. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. 
+ num_convs (int): Number of convolutional layers. Default: 2. + stride (int): Whether use stride convolution to downsample + the input feature map. If stride=2, it only uses stride convolution + in the first convolutional layer to downsample the input feature + map. Options are 1 or 2. Default: 1. + dilation (int): Whether use dilated convolution to expand the + receptive field. Set dilation rate of each convolutional layer and + the dilation rate of the first convolutional layer is always 1. + Default: 1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + num_convs=2, + stride=1, + dilation=1, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + dcn=None, + plugins=None): + super().__init__() + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + + self.with_cp = with_cp + convs = [] + for i in range(num_convs): + convs.append( + ConvModule( + in_channels=in_channels if i == 0 else out_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride if i == 0 else 1, + dilation=1 if i == 0 else dilation, + padding=1 if i == 0 else dilation, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + self.convs = nn.Sequential(*convs) + + def forward(self, x): + """Forward function.""" + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.convs, x) + else: + out = self.convs(x) + return out + + +@MODELS.register_module() +class DeconvModule(nn.Module): + """Deconvolution upsample module in decoder for UNet (2X upsample). + + This module uses deconvolution to upsample feature map in the decoder + of UNet. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + kernel_size (int): Kernel size of the convolutional layer. Default: 4. + """ + + def __init__(self, + in_channels, + out_channels, + with_cp=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + *, + kernel_size=4, + scale_factor=2): + super().__init__() + + assert (kernel_size - scale_factor >= 0) and\ + (kernel_size - scale_factor) % 2 == 0,\ + f'kernel_size should be greater than or equal to scale_factor '\ + f'and (kernel_size - scale_factor) should be even numbers, '\ + f'while the kernel size is {kernel_size} and scale_factor is '\ + f'{scale_factor}.' 
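To make the kernel_size / scale_factor constraint above concrete, here is a quick check of the transposed-convolution arithmetic the module relies on (stride = scale_factor and padding = (kernel_size - scale_factor) // 2, as set just below); the channel counts are illustrative.

import torch
import torch.nn as nn

# kernel_size=4, scale_factor=2 satisfies the assert: 4 - 2 >= 0 and (4 - 2) % 2 == 0
deconv = nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=(4 - 2) // 2)
x = torch.randn(1, 16, 32, 32)
# output size = (H_in - 1) * stride - 2 * padding + kernel_size = 31 * 2 - 2 + 4 = 64, i.e. exactly 2x upsampling
print(deconv(x).shape)                # torch.Size([1, 8, 64, 64])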
+ + stride = scale_factor + padding = (kernel_size - scale_factor) // 2 + self.with_cp = with_cp + deconv = nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding) + + norm_name, norm = build_norm_layer(norm_cfg, out_channels) + activate = build_activation_layer(act_cfg) + self.deconv_upsamping = nn.Sequential(deconv, norm, activate) + + def forward(self, x): + """Forward function.""" + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.deconv_upsamping, x) + else: + out = self.deconv_upsamping(x) + return out + + +@MODELS.register_module() +class InterpConv(nn.Module): + """Interpolation upsample module in decoder for UNet. + + This module uses interpolation to upsample feature map in the decoder + of UNet. It consists of one interpolation upsample layer and one + convolutional layer. It can be one interpolation upsample layer followed + by one convolutional layer (conv_first=False) or one convolutional layer + followed by one interpolation upsample layer (conv_first=True). + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + conv_first (bool): Whether convolutional layer or interpolation + upsample layer first. Default: False. It means interpolation + upsample layer followed by one convolutional layer. + kernel_size (int): Kernel size of the convolutional layer. Default: 1. + stride (int): Stride of the convolutional layer. Default: 1. + padding (int): Padding of the convolutional layer. Default: 1. + upsample_cfg (dict): Interpolation config of the upsample layer. + Default: dict( + scale_factor=2, mode='bilinear', align_corners=False). + """ + + def __init__(self, + in_channels, + out_channels, + with_cp=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + *, + conv_cfg=None, + conv_first=False, + kernel_size=1, + stride=1, + padding=0, + upsample_cfg=dict( + scale_factor=2, mode='bilinear', align_corners=False)): + super().__init__() + + self.with_cp = with_cp + conv = ConvModule( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + upsample = Upsample(**upsample_cfg) + if conv_first: + self.interp_upsample = nn.Sequential(conv, upsample) + else: + self.interp_upsample = nn.Sequential(upsample, conv) + + def forward(self, x): + """Forward function.""" + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.interp_upsample, x) + else: + out = self.interp_upsample(x) + return out + + +@MODELS.register_module() +class UNet(BaseModule): + """UNet backbone. + + This backbone is the implementation of `U-Net: Convolutional Networks + for Biomedical Image Segmentation `_. + + Args: + in_channels (int): Number of input image channels. Default" 3. + base_channels (int): Number of base channels of each stage. + The output channels of the first stage. Default: 64. + num_stages (int): Number of stages in encoder, normally 5. Default: 5. + strides (Sequence[int 1 | 2]): Strides of each stage in encoder. + len(strides) is equal to num_stages. 
Normally the stride of the + first stage in encoder is 1. If strides[i]=2, it uses stride + convolution to downsample in the correspondence encoder stage. + Default: (1, 1, 1, 1, 1). + enc_num_convs (Sequence[int]): Number of convolutional layers in the + convolution block of the correspondence encoder stage. + Default: (2, 2, 2, 2, 2). + dec_num_convs (Sequence[int]): Number of convolutional layers in the + convolution block of the correspondence decoder stage. + Default: (2, 2, 2, 2). + downsamples (Sequence[int]): Whether use MaxPool to downsample the + feature map after the first stage of encoder + (stages: [1, num_stages)). If the correspondence encoder stage use + stride convolution (strides[i]=2), it will never use MaxPool to + downsample, even downsamples[i-1]=True. + Default: (True, True, True, True). + enc_dilations (Sequence[int]): Dilation rate of each stage in encoder. + Default: (1, 1, 1, 1, 1). + dec_dilations (Sequence[int]): Dilation rate of each stage in decoder. + Default: (1, 1, 1, 1). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + upsample_cfg (dict): The upsample config of the upsample module in + decoder. Default: dict(type='InterpConv'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + + Notice: + The input image size should be divisible by the whole downsample rate + of the encoder. More detail of the whole downsample rate can be found + in UNet._check_input_divisible. + """ + + def __init__(self, + in_channels=3, + base_channels=64, + num_stages=5, + strides=(1, 1, 1, 1, 1), + enc_num_convs=(2, 2, 2, 2, 2), + dec_num_convs=(2, 2, 2, 2), + downsamples=(True, True, True, True), + enc_dilations=(1, 1, 1, 1, 1), + dec_dilations=(1, 1, 1, 1), + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + norm_eval=False, + dcn=None, + plugins=None, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg) + + self.pretrained = pretrained + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' 
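A small sanity-check sketch for the UNet backbone being defined here, assuming the default arguments and that the class is exported from mmseg.models.backbones as in upstream MMSegmentation. With num_stages=5 and downsamples=(True, True, True, True) the whole downsample rate is 2**4 = 16, so the input height and width must be multiples of 16; the commented shapes are indicative for base_channels=64.

import torch
from mmseg.models.backbones import UNet

unet = UNet()                          # defaults: base_channels=64, num_stages=5
x = torch.randn(1, 3, 64, 64)          # 64 is divisible by the whole downsample rate 16
outs = unet(x)                         # deepest encoder feature first, then the decoder outputs
for o in outs:
    print(o.shape)
# roughly: (1, 1024, 4, 4), (1, 512, 8, 8), (1, 256, 16, 16), (1, 128, 32, 32), (1, 64, 64, 64)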
+ assert len(strides) == num_stages, \ + 'The length of strides should be equal to num_stages, '\ + f'while the strides is {strides}, the length of '\ + f'strides is {len(strides)}, and the num_stages is '\ + f'{num_stages}.' + assert len(enc_num_convs) == num_stages, \ + 'The length of enc_num_convs should be equal to num_stages, '\ + f'while the enc_num_convs is {enc_num_convs}, the length of '\ + f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\ + f'{num_stages}.' + assert len(dec_num_convs) == (num_stages-1), \ + 'The length of dec_num_convs should be equal to (num_stages-1), '\ + f'while the dec_num_convs is {dec_num_convs}, the length of '\ + f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\ + f'{num_stages}.' + assert len(downsamples) == (num_stages-1), \ + 'The length of downsamples should be equal to (num_stages-1), '\ + f'while the downsamples is {downsamples}, the length of '\ + f'downsamples is {len(downsamples)}, and the num_stages is '\ + f'{num_stages}.' + assert len(enc_dilations) == num_stages, \ + 'The length of enc_dilations should be equal to num_stages, '\ + f'while the enc_dilations is {enc_dilations}, the length of '\ + f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\ + f'{num_stages}.' + assert len(dec_dilations) == (num_stages-1), \ + 'The length of dec_dilations should be equal to (num_stages-1), '\ + f'while the dec_dilations is {dec_dilations}, the length of '\ + f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\ + f'{num_stages}.' + self.num_stages = num_stages + self.strides = strides + self.downsamples = downsamples + self.norm_eval = norm_eval + self.base_channels = base_channels + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + for i in range(num_stages): + enc_conv_block = [] + if i != 0: + if strides[i] == 1 and downsamples[i - 1]: + enc_conv_block.append(nn.MaxPool2d(kernel_size=2)) + upsample = (strides[i] != 1 or downsamples[i - 1]) + self.decoder.append( + UpConvBlock( + conv_block=BasicConvBlock, + in_channels=base_channels * 2**i, + skip_channels=base_channels * 2**(i - 1), + out_channels=base_channels * 2**(i - 1), + num_convs=dec_num_convs[i - 1], + stride=1, + dilation=dec_dilations[i - 1], + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + upsample_cfg=upsample_cfg if upsample else None, + dcn=None, + plugins=None)) + + enc_conv_block.append( + BasicConvBlock( + in_channels=in_channels, + out_channels=base_channels * 2**i, + num_convs=enc_num_convs[i], + stride=strides[i], + dilation=enc_dilations[i], + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + dcn=None, + plugins=None)) + self.encoder.append(nn.Sequential(*enc_conv_block)) + in_channels = base_channels * 2**i + + def forward(self, x): + self._check_input_divisible(x) + enc_outs = [] + for enc in self.encoder: + x = enc(x) + enc_outs.append(x) + dec_outs = [x] + for i in reversed(range(len(self.decoder))): + x = self.decoder[i](enc_outs[i], x) + dec_outs.append(x) + + return dec_outs + + def train(self, mode=True): + """Convert the model into training mode while keep normalization layer + freezed.""" + super().train(mode) + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def _check_input_divisible(self, x): + h, w = x.shape[-2:] + whole_downsample_rate = 1 + for i in range(1, self.num_stages): + if self.strides[i] == 2 or self.downsamples[i - 
1]: + whole_downsample_rate *= 2 + assert (h % whole_downsample_rate == 0) \ + and (w % whole_downsample_rate == 0),\ + f'The input image size {(h, w)} should be divisible by the whole '\ + f'downsample rate {whole_downsample_rate}, when num_stages is '\ + f'{self.num_stages}, strides is {self.strides}, and downsamples '\ + f'is {self.downsamples}.' diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..3c96f6549372717acfb2ef442bc222d277ee37a8 --- /dev/null +++ b/mmseg/models/backbones/vit.py @@ -0,0 +1,438 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import warnings + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention +from mmengine.logging import print_log +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import (constant_init, kaiming_init, + trunc_normal_) +from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.modules.utils import _pair as to_2tuple + +from mmseg.registry import MODELS +from ..utils import PatchEmbed, resize + + +class TransformerEncoderLayer(BaseModule): + """Implements one encoder layer in Vision Transformer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default: 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default: 0.0. + drop_path_rate (float): stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): enable bias for qkv if True. Default: True + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Default: True. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. Default: False. 
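For readers unfamiliar with the mmcv MultiheadAttention/FFN wrappers used below: passing identity=x makes each call add a residual connection, so the layer implements the usual pre-norm update y = x + Attn(LN(x)), z = y + FFN(LN(y)). A schematic, torch-only sketch of that computation (not the registered module itself; dimensions are illustrative):

import torch
import torch.nn as nn

class PreNormBlockSketch(nn.Module):
    def __init__(self, dim=768, heads=12, mlp_dim=3072):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, dim))

    def forward(self, x):
        h = self.ln1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]   # residual added explicitly
        return x + self.mlp(self.ln2(x))                     # second residual around the FFN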
+ """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + batch_first=True, + attn_cfg=dict(), + ffn_cfg=dict(), + with_cp=False): + super().__init__() + + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, embed_dims, postfix=1) + self.add_module(self.norm1_name, norm1) + + attn_cfg.update( + dict( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + batch_first=batch_first, + bias=qkv_bias)) + + self.build_attn(attn_cfg) + + self.norm2_name, norm2 = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + self.add_module(self.norm2_name, norm2) + + ffn_cfg.update( + dict( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate) + if drop_path_rate > 0 else None, + act_cfg=act_cfg)) + self.build_ffn(ffn_cfg) + self.with_cp = with_cp + + def build_attn(self, attn_cfg): + self.attn = MultiheadAttention(**attn_cfg) + + def build_ffn(self, ffn_cfg): + self.ffn = FFN(**ffn_cfg) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + @property + def norm2(self): + return getattr(self, self.norm2_name) + + def forward(self, x): + + def _inner_forward(x): + x = self.attn(self.norm1(x), identity=x) + x = self.ffn(self.norm2(x), identity=x) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + +@MODELS.register_module() +class VisionTransformer(BaseModule): + """Vision Transformer. + + This backbone is the implementation of `An Image is Worth 16x16 Words: + Transformers for Image Recognition at + Scale `_. + + Args: + img_size (int | tuple): Input image size. Default: 224. + patch_size (int): The patch size. Default: 16. + in_channels (int): Number of input channels. Default: 3. + embed_dims (int): embedding dimension. Default: 768. + num_layers (int): depth of transformer. Default: 12. + num_heads (int): number of attention heads. Default: 12. + mlp_ratio (int): ratio of mlp hidden dim to embedding dim. + Default: 4. + out_indices (list | tuple | int): Output from which stages. + Default: -1. + qkv_bias (bool): enable bias for qkv if True. Default: True. + drop_rate (float): Probability of an element to be zeroed. + Default 0.0 + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + drop_path_rate (float): stochastic depth rate. Default 0.0 + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Default: True. + output_cls_token (bool): Whether output the cls_token. If set True, + `with_cls_token` must be True. Default: False. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + patch_norm (bool): Whether to add a norm in PatchEmbed Block. + Default: False. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Default: False. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Default: bicubic. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). 
Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. Default: False. + pretrained (str, optional): model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_channels=3, + embed_dims=768, + num_layers=12, + num_heads=12, + mlp_ratio=4, + out_indices=-1, + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + with_cls_token=True, + output_cls_token=False, + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + patch_norm=False, + final_norm=False, + interpolate_mode='bicubic', + num_fcs=2, + norm_eval=False, + with_cp=False, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + if isinstance(img_size, int): + img_size = to_2tuple(img_size) + elif isinstance(img_size, tuple): + if len(img_size) == 1: + img_size = to_2tuple(img_size[0]) + assert len(img_size) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(img_size)}' + + if output_cls_token: + assert with_cls_token is True, f'with_cls_token must be True if' \ + f'set output_cls_token to True, but got {with_cls_token}' + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + + self.img_size = img_size + self.patch_size = patch_size + self.interpolate_mode = interpolate_mode + self.norm_eval = norm_eval + self.with_cp = with_cp + self.pretrained = pretrained + + self.patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + padding='corner', + norm_cfg=norm_cfg if patch_norm else None, + init_cfg=None, + ) + + num_patches = (img_size[0] // patch_size) * \ + (img_size[1] // patch_size) + + self.with_cls_token = with_cls_token + self.output_cls_token = output_cls_token + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, embed_dims)) + self.drop_after_pos = nn.Dropout(p=drop_rate) + + if isinstance(out_indices, int): + if out_indices == -1: + out_indices = num_layers - 1 + self.out_indices = [out_indices] + elif isinstance(out_indices, list) or isinstance(out_indices, tuple): + self.out_indices = out_indices + else: + raise TypeError('out_indices must be type of int, list or tuple') + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, num_layers) + ] # stochastic depth decay rule + + self.layers = ModuleList() + for i in range(num_layers): + self.layers.append( + TransformerEncoderLayer( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=mlp_ratio * embed_dims, + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + drop_path_rate=dpr[i], + num_fcs=num_fcs, + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + batch_first=True)) + + self.final_norm = final_norm + if final_norm: + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, embed_dims, postfix=1) + self.add_module(self.norm1_name, norm1) + + @property + def norm1(self): + return 
getattr(self, self.norm1_name) + + def init_weights(self): + if (isinstance(self.init_cfg, dict) + and self.init_cfg.get('type') == 'Pretrained'): + checkpoint = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], logger=None, map_location='cpu') + + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + if 'pos_embed' in state_dict.keys(): + if self.pos_embed.shape != state_dict['pos_embed'].shape: + print_log(msg=f'Resize the pos_embed shape from ' + f'{state_dict["pos_embed"].shape} to ' + f'{self.pos_embed.shape}') + h, w = self.img_size + pos_size = int( + math.sqrt(state_dict['pos_embed'].shape[1] - 1)) + state_dict['pos_embed'] = self.resize_pos_embed( + state_dict['pos_embed'], + (h // self.patch_size, w // self.patch_size), + (pos_size, pos_size), self.interpolate_mode) + + load_state_dict(self, state_dict, strict=False, logger=None) + elif self.init_cfg is not None: + super().init_weights() + else: + # We only implement the 'jax_impl' initialization implemented at + # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + for n, m in self.named_modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + if 'ffn' in n: + nn.init.normal_(m.bias, mean=0., std=1e-6) + else: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + kaiming_init(m, mode='fan_in', bias=0.) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m, val=1.0, bias=0.) + + def _pos_embeding(self, patched_img, hw_shape, pos_embed): + """Positioning embeding method. + + Resize the pos_embed, if the input image size doesn't match + the training size. + Args: + patched_img (torch.Tensor): The patched image, it should be + shape of [B, L1, C]. + hw_shape (tuple): The downsampled image resolution. + pos_embed (torch.Tensor): The pos_embed weighs, it should be + shape of [B, L2, c]. + Return: + torch.Tensor: The pos encoded image feature. + """ + assert patched_img.ndim == 3 and pos_embed.ndim == 3, \ + 'the shapes of patched_img and pos_embed must be [B, L, C]' + x_len, pos_len = patched_img.shape[1], pos_embed.shape[1] + if x_len != pos_len: + if pos_len == (self.img_size[0] // self.patch_size) * ( + self.img_size[1] // self.patch_size) + 1: + pos_h = self.img_size[0] // self.patch_size + pos_w = self.img_size[1] // self.patch_size + else: + raise ValueError( + 'Unexpected shape of pos_embed, got {}.'.format( + pos_embed.shape)) + pos_embed = self.resize_pos_embed(pos_embed, hw_shape, + (pos_h, pos_w), + self.interpolate_mode) + return self.drop_after_pos(patched_img + pos_embed) + + @staticmethod + def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode): + """Resize pos_embed weights. + + Resize pos_embed using bicubic interpolate method. + Args: + pos_embed (torch.Tensor): Position embedding weights. + input_shpae (tuple): Tuple for (downsampled input image height, + downsampled input image width). + pos_shape (tuple): The resolution of downsampled origin training + image. + mode (str): Algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'``. 
Default: ``'nearest'`` + Return: + torch.Tensor: The resized pos_embed of shape [B, L_new, C] + """ + assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]' + pos_h, pos_w = pos_shape + cls_token_weight = pos_embed[:, 0] + pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):] + pos_embed_weight = pos_embed_weight.reshape( + 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2) + pos_embed_weight = resize( + pos_embed_weight, size=input_shpae, align_corners=False, mode=mode) + cls_token_weight = cls_token_weight.unsqueeze(1) + pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2) + pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1) + return pos_embed + + def forward(self, inputs): + B = inputs.shape[0] + + x, hw_shape = self.patch_embed(inputs) + + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = self._pos_embeding(x, hw_shape, self.pos_embed) + + if not self.with_cls_token: + # Remove class token for transformer encoder input + x = x[:, 1:] + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i == len(self.layers) - 1: + if self.final_norm: + x = self.norm1(x) + if i in self.out_indices: + if self.with_cls_token: + # Remove class token and reshape token for decoder head + out = x[:, 1:] + else: + out = x + B, _, C = out.shape + out = out.reshape(B, hw_shape[0], hw_shape[1], + C).permute(0, 3, 1, 2).contiguous() + if self.output_cls_token: + out = [out, x[:, 0]] + outs.append(out) + + return tuple(outs) + + def train(self, mode=True): + super().train(mode) + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.LayerNorm): + m.eval() diff --git a/mmseg/models/builder.py b/mmseg/models/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..081c646b49b8ff1ea6c42d1ea4e24e63cdf6b43a --- /dev/null +++ b/mmseg/models/builder.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
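A worked example of the position-embedding resize implemented by _pos_embeding/resize_pos_embed above, assuming a ViT configured for 224x224 images with patch size 16 but run on 512x512 input, and assuming VisionTransformer is exported from mmseg.models.backbones as in upstream MMSegmentation; shapes are illustrative.

import torch
from mmseg.models.backbones import VisionTransformer

vit = VisionTransformer(img_size=224, patch_size=16)   # pos_embed holds 14*14 + 1 = 197 positions
x = torch.randn(1, 3, 512, 512)                         # 512 / 16 = 32 patches per side -> 1025 tokens
outs = vit(x)                                           # pos_embed is resized 14x14 -> 32x32 on the fly
print(outs[-1].shape)                                   # (1, 768, 32, 32) with the default out_indices=-1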
+import warnings + +from mmseg.registry import MODELS + +BACKBONES = MODELS +NECKS = MODELS +HEADS = MODELS +LOSSES = MODELS +SEGMENTORS = MODELS + + +def build_backbone(cfg): + """Build backbone.""" + warnings.warn('``build_backbone`` would be deprecated soon, please use ' + '``mmseg.registry.MODELS.build()`` ') + return BACKBONES.build(cfg) + + +def build_neck(cfg): + """Build neck.""" + warnings.warn('``build_neck`` would be deprecated soon, please use ' + '``mmseg.registry.MODELS.build()`` ') + return NECKS.build(cfg) + + +def build_head(cfg): + """Build head.""" + warnings.warn('``build_head`` would be deprecated soon, please use ' + '``mmseg.registry.MODELS.build()`` ') + return HEADS.build(cfg) + + +def build_loss(cfg): + """Build loss.""" + warnings.warn('``build_loss`` would be deprecated soon, please use ' + '``mmseg.registry.MODELS.build()`` ') + return LOSSES.build(cfg) + + +def build_segmentor(cfg, train_cfg=None, test_cfg=None): + """Build segmentor.""" + if train_cfg is not None or test_cfg is not None: + warnings.warn( + 'train_cfg and test_cfg is deprecated, ' + 'please specify them in model', UserWarning) + assert cfg.get('train_cfg') is None or train_cfg is None, \ + 'train_cfg specified in both outer field and model field ' + assert cfg.get('test_cfg') is None or test_cfg is None, \ + 'test_cfg specified in both outer field and model field ' + return SEGMENTORS.build( + cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) diff --git a/mmseg/models/data_preprocessor.py b/mmseg/models/data_preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..173d80c9aafdae8f7ae7b13fe7529c9da2711dee --- /dev/null +++ b/mmseg/models/data_preprocessor.py @@ -0,0 +1,151 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from numbers import Number +from typing import Any, Dict, List, Optional, Sequence + +import torch +from mmengine.model import BaseDataPreprocessor + +from mmseg.registry import MODELS +from mmseg.utils import stack_batch + + +@MODELS.register_module() +class SegDataPreProcessor(BaseDataPreprocessor): + """Image pre-processor for segmentation tasks. + + Comparing with the :class:`mmengine.ImgDataPreprocessor`, + + 1. It won't do normalization if ``mean`` is not specified. + 2. It does normalization and color space conversion after stacking batch. + 3. It supports batch augmentations like mixup and cutmix. + + + It provides the data pre-processing as follows + + - Collate and move data to the target device. + - Pad inputs to the input size with defined ``pad_val``, and pad seg map + with defined ``seg_pad_val``. + - Stack inputs to batch_inputs. + - Convert inputs from bgr to rgb if the shape of input is (3, H, W). + - Normalize image with defined std and mean. + - Do batch augmentations like Mixup and Cutmix during training. + + Args: + mean (Sequence[Number], optional): The pixel mean of R, G, B channels. + Defaults to None. + std (Sequence[Number], optional): The pixel standard deviation of + R, G, B channels. Defaults to None. + size (tuple, optional): Fixed padding size. + size_divisor (int, optional): The divisor of padded size. + pad_val (float, optional): Padding value. Default: 0. + seg_pad_val (float, optional): Padding value of segmentation map. + Default: 255. + padding_mode (str): Type of padding. Default: constant. + - constant: pads with a constant value, this value is specified + with pad_val. + bgr_to_rgb (bool): whether to convert image from BGR to RGB. + Defaults to False. 
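The build_* helpers above are thin deprecation shims that forward to the MODELS registry; new code is expected to call the registry directly. A hedged sketch of that usage (the config dict below is hypothetical and abbreviated; a real config also specifies a data_preprocessor, full backbone/head arguments, train_cfg and test_cfg):

from mmseg.registry import MODELS

cfg = dict(
    type='EncoderDecoder',                 # hypothetical, minimal segmentor config
    backbone=dict(type='UNet'),
    decode_head=dict(
        type='FCNHead', in_channels=64, channels=64, num_classes=19, in_index=-1),
)
model = MODELS.build(cfg)                  # plays the role of the deprecated build_segmentor(cfg)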
+ rgb_to_bgr (bool): whether to convert image from RGB to RGB. + Defaults to False. + batch_augments (list[dict], optional): Batch-level augmentations + test_cfg (dict, optional): The padding size config in testing, if not + specify, will use `size` and `size_divisor` params as default. + Defaults to None, only supports keys `size` or `size_divisor`. + """ + + def __init__( + self, + mean: Sequence[Number] = None, + std: Sequence[Number] = None, + size: Optional[tuple] = None, + size_divisor: Optional[int] = None, + pad_val: Number = 0, + seg_pad_val: Number = 255, + bgr_to_rgb: bool = False, + rgb_to_bgr: bool = False, + batch_augments: Optional[List[dict]] = None, + test_cfg: dict = None, + ): + super().__init__() + self.size = size + self.size_divisor = size_divisor + self.pad_val = pad_val + self.seg_pad_val = seg_pad_val + + assert not (bgr_to_rgb and rgb_to_bgr), ( + '`bgr2rgb` and `rgb2bgr` cannot be set to True at the same time') + self.channel_conversion = rgb_to_bgr or bgr_to_rgb + + if mean is not None: + assert std is not None, 'To enable the normalization in ' \ + 'preprocessing, please specify both ' \ + '`mean` and `std`.' + # Enable the normalization in preprocessing. + self._enable_normalize = True + self.register_buffer('mean', + torch.tensor(mean).view(-1, 1, 1), False) + self.register_buffer('std', + torch.tensor(std).view(-1, 1, 1), False) + else: + self._enable_normalize = False + + # TODO: support batch augmentations. + self.batch_augments = batch_augments + + # Support different padding methods in testing + self.test_cfg = test_cfg + + def forward(self, data: dict, training: bool = False) -> Dict[str, Any]: + """Perform normalization、padding and bgr2rgb conversion based on + ``BaseDataPreprocessor``. + + Args: + data (dict): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. + + Returns: + Dict: Data in the same format as the model input. + """ + data = self.cast_data(data) # type: ignore + inputs = data['inputs'] + data_samples = data.get('data_samples', None) + # TODO: whether normalize should be after stack_batch + if self.channel_conversion and inputs[0].size(0) == 3: + inputs = [_input[[2, 1, 0], ...] 
for _input in inputs] + + inputs = [_input.float() for _input in inputs] + if self._enable_normalize: + inputs = [(_input - self.mean) / self.std for _input in inputs] + + if training: + assert data_samples is not None, ('During training, ', + '`data_samples` must be define.') + inputs, data_samples = stack_batch( + inputs=inputs, + data_samples=data_samples, + size=self.size, + size_divisor=self.size_divisor, + pad_val=self.pad_val, + seg_pad_val=self.seg_pad_val) + + if self.batch_augments is not None: + inputs, data_samples = self.batch_augments( + inputs, data_samples) + else: + # assert len(inputs) == 1, ( + # 'Batch inference is not support currently, ' + # 'as the image size might be different in a batch') + # pad images when testing + if self.test_cfg: + inputs, padded_samples = stack_batch( + inputs=inputs, + size=self.test_cfg.get('size', None), + size_divisor=self.test_cfg.get('size_divisor', None), + pad_val=self.pad_val, + seg_pad_val=self.seg_pad_val) + for data_sample, pad_info in zip(data_samples, padded_samples): + data_sample.set_metainfo({**pad_info}) + else: + inputs = torch.stack(inputs, dim=0) + + return dict(inputs=inputs, data_samples=data_samples) diff --git a/mmseg/models/decode_heads/__init__.py b/mmseg/models/decode_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18235456bc99775c29efba15be851e4c55558559 --- /dev/null +++ b/mmseg/models/decode_heads/__init__.py @@ -0,0 +1,45 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .ann_head import ANNHead +from .apc_head import APCHead +from .aspp_head import ASPPHead +from .cc_head import CCHead +from .da_head import DAHead +from .dm_head import DMHead +from .dnl_head import DNLHead +from .dpt_head import DPTHead +from .ema_head import EMAHead +from .enc_head import EncHead +from .fcn_head import FCNHead +from .fpn_head import FPNHead +from .gc_head import GCHead +from .ham_head import LightHamHead +from .isa_head import ISAHead +from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator +from .lraspp_head import LRASPPHead +from .mask2former_head import Mask2FormerHead +from .maskformer_head import MaskFormerHead +from .nl_head import NLHead +from .ocr_head import OCRHead +from .pid_head import PIDHead +from .point_head import PointHead +from .psa_head import PSAHead +from .psp_head import PSPHead +from .segformer_head import SegformerHead +from .segmenter_mask_head import SegmenterMaskTransformerHead +from .sep_aspp_head import DepthwiseSeparableASPPHead +from .sep_fcn_head import DepthwiseSeparableFCNHead +from .setr_mla_head import SETRMLAHead +from .setr_up_head import SETRUPHead +from .stdc_head import STDCHead +from .uper_head import UPerHead + +__all__ = [ + 'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead', + 'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead', + 'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead', + 'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead', + 'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead', + 'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead', + 'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead', + 'LightHamHead', 'PIDHead' +] diff --git a/mmseg/models/decode_heads/__pycache__/__init__.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f207641b7765e1be847c7a1c4f0155164798c5e1 Binary files 
/dev/null and b/mmseg/models/decode_heads/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/ann_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/ann_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51ed471a31f8bc0e69414cc2aeec22ffd1d550dc Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/ann_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/apc_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/apc_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98febba8f7856936ba1b774ab4faee0610f04488 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/apc_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/aspp_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/aspp_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d16496cff9aec8103b7e1b7ac150715d9f05d10 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/aspp_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/cascade_decode_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/cascade_decode_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..494d80405b224a83ee62fe2f404cdab7626c968c Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/cascade_decode_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/cc_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/cc_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d89f9bf747181c78e6437d9d45555fc21bf87ea3 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/cc_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/da_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/da_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e0b7716fe1be053293f290f95dafd7ad6553984 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/da_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/decode_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/decode_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..121d321b0b5859535695a9a69f37b082a7866adf Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/decode_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/dm_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/dm_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..add73b07b6d4e9d5f47724a7f4a389a90d131b8d Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/dm_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/dnl_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/dnl_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6acbd63171970aecaf8ff54fab387d22d4b11754 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/dnl_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/dpt_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/dpt_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5f744889d3a5f46c5a14e4bb0c7383bc19975fd Binary files /dev/null and 
b/mmseg/models/decode_heads/__pycache__/dpt_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/ema_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/ema_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78c30d0b02eda2353b86ac1005ab8740831ae288 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/ema_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/enc_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/enc_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..afaca8c33b6fa0e29d94fec0f63505853fea82d9 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/enc_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/fcn_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/fcn_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69bae2920a2eb9cfd655fb8584b348ac60c1d7e1 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/fcn_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/fpn_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/fpn_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5540d8fca1801a67c2478ef2b4df6c119e685e75 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/fpn_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/gc_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/gc_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53ffdeb761d9e57fb07c490365eb0b67dc99e518 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/gc_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/ham_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/ham_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0ad29cba1b451d69b547f4d758839766004192d Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/ham_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/isa_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/isa_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48ce84541ce4687dada41a81142020988d2b7529 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/isa_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/knet_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/knet_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..800bc645ca8852da03f15872ed67bb901d1aa079 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/knet_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/lraspp_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/lraspp_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8ee8e4184d542e6f763a82cbeccdc0b5a072aeb Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/lraspp_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/mask2former_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/mask2former_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68b8ed454b88e5c4b8465b9933ff7089fc8ea643 Binary files /dev/null and 
b/mmseg/models/decode_heads/__pycache__/mask2former_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/maskformer_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/maskformer_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44706156ae6888ebfef72d0444e0fc3364551ead Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/maskformer_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/nl_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/nl_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70de9f7712d3e983f95232a70f9f4bc4e57c5e56 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/nl_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/ocr_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/ocr_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ef4cbc74bfd854d42b6c4ca3fe881549fe8a918 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/ocr_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/pid_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/pid_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..705ba0e5d5e2d7d201cd14d5adeed9149ae3a83f Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/pid_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/point_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/point_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5439fff49889e91c338e420bacc2c3fc42b79ff2 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/point_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/psa_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/psa_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6aebb210a416aeab6de5332f85d433e8b1c71dbd Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/psa_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/psp_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/psp_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6a8697c1205161e91dbc4373aefefe1e4ff7ec5 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/psp_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/segformer_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/segformer_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5af9c9ca7670a97c8537edfd1807abea7f3d305 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/segformer_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/segmenter_mask_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/segmenter_mask_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48cc65122a574326392f438db9e555faa665a045 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/segmenter_mask_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/sep_aspp_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/sep_aspp_head.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..2b79405ba531260d44614a1019c633454b594f48 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/sep_aspp_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/sep_fcn_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/sep_fcn_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae2ece32c2f45860fe38cb79b18fca936e15c930 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/sep_fcn_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/setr_mla_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/setr_mla_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f85f67260cd27725c9b69b37a82806c9999a1c62 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/setr_mla_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/setr_up_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/setr_up_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3862796e91f56d8af40b54c7de6df091bb05983 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/setr_up_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/stdc_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/stdc_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..feafd543eb8f486f532d33351ceed686f0f766ff Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/stdc_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/__pycache__/uper_head.cpython-310.pyc b/mmseg/models/decode_heads/__pycache__/uper_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f4c22f0a8c0f97a7fa4bc0f6b3c691c1d004709 Binary files /dev/null and b/mmseg/models/decode_heads/__pycache__/uper_head.cpython-310.pyc differ diff --git a/mmseg/models/decode_heads/ann_head.py b/mmseg/models/decode_heads/ann_head.py new file mode 100644 index 0000000000000000000000000000000000000000..2b40ef5aa1da0bc2473597fedca5b3f33973beb0 --- /dev/null +++ b/mmseg/models/decode_heads/ann_head.py @@ -0,0 +1,245 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import SelfAttentionBlock as _SelfAttentionBlock +from .decode_head import BaseDecodeHead + + +class PPMConcat(nn.ModuleList): + """Pyramid Pooling Module that only concat the features of each layer. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. + """ + + def __init__(self, pool_scales=(1, 3, 6, 8)): + super().__init__( + [nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales]) + + def forward(self, feats): + """Forward function.""" + ppm_outs = [] + for ppm in self: + ppm_out = ppm(feats) + ppm_outs.append(ppm_out.view(*feats.shape[:2], -1)) + concat_outs = torch.cat(ppm_outs, dim=2) + return concat_outs + + +class SelfAttentionBlock(_SelfAttentionBlock): + """Make a ANN used SelfAttentionBlock. + + Args: + low_in_channels (int): Input channels of lower level feature, + which is the key feature for self-attention. + high_in_channels (int): Input channels of higher level feature, + which is the query feature for self-attention. + channels (int): Output channels of key/query transform. + out_channels (int): Output channels. 
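A quick shape check for PPMConcat above: with the default pool_scales (1, 3, 6, 8) each adaptive-pooling branch contributes scale**2 positions, so the concatenated key/value tensor always has 1 + 9 + 36 + 64 = 110 positions per channel, independent of the input resolution. The sizes below are illustrative.

import torch
from mmseg.models.decode_heads.ann_head import PPMConcat

ppm = PPMConcat(pool_scales=(1, 3, 6, 8))
feats = torch.randn(2, 256, 64, 128)       # arbitrary spatial size
print(ppm(feats).shape)                     # torch.Size([2, 256, 110])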
+ share_key_query (bool): Whether share projection weight between key + and query projection. + query_scale (int): The scale of query feature map. + key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module of key feature. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict|None): Config of activation layers. + """ + + def __init__(self, low_in_channels, high_in_channels, channels, + out_channels, share_key_query, query_scale, key_pool_scales, + conv_cfg, norm_cfg, act_cfg): + key_psp = PPMConcat(key_pool_scales) + if query_scale > 1: + query_downsample = nn.MaxPool2d(kernel_size=query_scale) + else: + query_downsample = None + super().__init__( + key_in_channels=low_in_channels, + query_in_channels=high_in_channels, + channels=channels, + out_channels=out_channels, + share_key_query=share_key_query, + query_downsample=query_downsample, + key_downsample=key_psp, + key_query_num_convs=1, + key_query_norm=True, + value_out_num_convs=1, + value_out_norm=False, + matmul_norm=True, + with_out=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + +class AFNB(nn.Module): + """Asymmetric Fusion Non-local Block(AFNB) + + Args: + low_in_channels (int): Input channels of lower level feature, + which is the key feature for self-attention. + high_in_channels (int): Input channels of higher level feature, + which is the query feature for self-attention. + channels (int): Output channels of key/query transform. + out_channels (int): Output channels. + and query projection. + query_scales (tuple[int]): The scales of query feature map. + Default: (1,) + key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module of key feature. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict|None): Config of activation layers. + """ + + def __init__(self, low_in_channels, high_in_channels, channels, + out_channels, query_scales, key_pool_scales, conv_cfg, + norm_cfg, act_cfg): + super().__init__() + self.stages = nn.ModuleList() + for query_scale in query_scales: + self.stages.append( + SelfAttentionBlock( + low_in_channels=low_in_channels, + high_in_channels=high_in_channels, + channels=channels, + out_channels=out_channels, + share_key_query=False, + query_scale=query_scale, + key_pool_scales=key_pool_scales, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.bottleneck = ConvModule( + out_channels + high_in_channels, + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + def forward(self, low_feats, high_feats): + """Forward function.""" + priors = [stage(high_feats, low_feats) for stage in self.stages] + context = torch.stack(priors, dim=0).sum(dim=0) + output = self.bottleneck(torch.cat([context, high_feats], 1)) + return output + + +class APNB(nn.Module): + """Asymmetric Pyramid Non-local Block (APNB) + + Args: + in_channels (int): Input channels of key/query feature, + which is the key feature for self-attention. + channels (int): Output channels of key/query transform. + out_channels (int): Output channels. + query_scales (tuple[int]): The scales of query feature map. + Default: (1,) + key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module of key feature. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict|None): Config of activation layers. 
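The point of routing the key/value branch through PPMConcat in the blocks above is to make the attention asymmetric: the query keeps every spatial position while keys and values are pooled down to the fixed 110 positions computed earlier. As a rough worked example, for a 1/8-resolution feature map of a 512x1024 image (64 x 128 = 8192 positions) the affinity matrix is 8192 x 110 instead of 8192 x 8192, roughly a 75x reduction in the cost of the attention product.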
+ """ + + def __init__(self, in_channels, channels, out_channels, query_scales, + key_pool_scales, conv_cfg, norm_cfg, act_cfg): + super().__init__() + self.stages = nn.ModuleList() + for query_scale in query_scales: + self.stages.append( + SelfAttentionBlock( + low_in_channels=in_channels, + high_in_channels=in_channels, + channels=channels, + out_channels=out_channels, + share_key_query=True, + query_scale=query_scale, + key_pool_scales=key_pool_scales, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.bottleneck = ConvModule( + 2 * in_channels, + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, feats): + """Forward function.""" + priors = [stage(feats, feats) for stage in self.stages] + context = torch.stack(priors, dim=0).sum(dim=0) + output = self.bottleneck(torch.cat([context, feats], 1)) + return output + + +@MODELS.register_module() +class ANNHead(BaseDecodeHead): + """Asymmetric Non-local Neural Networks for Semantic Segmentation. + + This head is the implementation of `ANNNet + `_. + + Args: + project_channels (int): Projection channels for Nonlocal. + query_scales (tuple[int]): The scales of query feature map. + Default: (1,) + key_pool_scales (tuple[int]): The pooling scales of key feature map. + Default: (1, 3, 6, 8). + """ + + def __init__(self, + project_channels, + query_scales=(1, ), + key_pool_scales=(1, 3, 6, 8), + **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + assert len(self.in_channels) == 2 + low_in_channels, high_in_channels = self.in_channels + self.project_channels = project_channels + self.fusion = AFNB( + low_in_channels=low_in_channels, + high_in_channels=high_in_channels, + out_channels=high_in_channels, + channels=project_channels, + query_scales=query_scales, + key_pool_scales=key_pool_scales, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.bottleneck = ConvModule( + high_in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.context = APNB( + in_channels=self.channels, + out_channels=self.channels, + channels=project_channels, + query_scales=query_scales, + key_pool_scales=key_pool_scales, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + low_feats, high_feats = self._transform_inputs(inputs) + output = self.fusion(low_feats, high_feats) + output = self.dropout(output) + output = self.bottleneck(output) + output = self.context(output) + output = self.cls_seg(output) + + return output diff --git a/mmseg/models/decode_heads/apc_head.py b/mmseg/models/decode_heads/apc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..728f39659c63680944306fddc9e33b7c9172c1ba --- /dev/null +++ b/mmseg/models/decode_heads/apc_head.py @@ -0,0 +1,159 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead + + +class ACM(nn.Module): + """Adaptive Context Module used in APCNet. + + Args: + pool_scale (int): Pooling scale used in Adaptive Context + Module to extract region features. + fusion (bool): Add one conv to fuse residual feature. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. 
+ conv_cfg (dict | None): Config of conv layers. + norm_cfg (dict | None): Config of norm layers. + act_cfg (dict): Config of activation layers. + """ + + def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg, + norm_cfg, act_cfg): + super().__init__() + self.pool_scale = pool_scale + self.fusion = fusion + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.pooled_redu_conv = ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.input_redu_conv = ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.global_info = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0) + + self.residual_conv = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + if self.fusion: + self.fusion_conv = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, x): + """Forward function.""" + pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale) + # [batch_size, channels, h, w] + x = self.input_redu_conv(x) + # [batch_size, channels, pool_scale, pool_scale] + pooled_x = self.pooled_redu_conv(pooled_x) + batch_size = x.size(0) + # [batch_size, pool_scale * pool_scale, channels] + pooled_x = pooled_x.view(batch_size, self.channels, + -1).permute(0, 2, 1).contiguous() + # [batch_size, h * w, pool_scale * pool_scale] + affinity_matrix = self.gla(x + resize( + self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:]) + ).permute(0, 2, 3, 1).reshape( + batch_size, -1, self.pool_scale**2) + affinity_matrix = F.sigmoid(affinity_matrix) + # [batch_size, h * w, channels] + z_out = torch.matmul(affinity_matrix, pooled_x) + # [batch_size, channels, h * w] + z_out = z_out.permute(0, 2, 1).contiguous() + # [batch_size, channels, h, w] + z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3)) + z_out = self.residual_conv(z_out) + z_out = F.relu(z_out + x) + if self.fusion: + z_out = self.fusion_conv(z_out) + + return z_out + + +@MODELS.register_module() +class APCHead(BaseDecodeHead): + """Adaptive Pyramid Context Network for Semantic Segmentation. + + This head is the implementation of + `APCNet `_. + + Args: + pool_scales (tuple[int]): Pooling scales used in Adaptive Context + Module. Default: (1, 2, 3, 6). + fusion (bool): Add one conv to fuse residual feature. 
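+
+    A minimal usage sketch (sizes below are illustrative assumptions, not
+    values from any provided config):
+
+    Example:
+        >>> import torch
+        >>> head = APCHead(pool_scales=(1, 2, 3, 6), fusion=True,
+        ...                in_channels=32, channels=16, num_classes=19)
+        >>> inputs = [torch.rand(2, 32, 40, 40)]
+        >>> seg_logits = head(inputs)  # (2, 19, 40, 40)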
+ """ + + def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs): + super().__init__(**kwargs) + assert isinstance(pool_scales, (list, tuple)) + self.pool_scales = pool_scales + self.fusion = fusion + acm_modules = [] + for pool_scale in self.pool_scales: + acm_modules.append( + ACM(pool_scale, + self.fusion, + self.in_channels, + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.acm_modules = nn.ModuleList(acm_modules) + self.bottleneck = ConvModule( + self.in_channels + len(pool_scales) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + acm_outs = [x] + for acm_module in self.acm_modules: + acm_outs.append(acm_module(x)) + acm_outs = torch.cat(acm_outs, dim=1) + output = self.bottleneck(acm_outs) + output = self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/aspp_head.py b/mmseg/models/decode_heads/aspp_head.py new file mode 100644 index 0000000000000000000000000000000000000000..6d7185d7de58d35ef17e5d54e0e75b045e8724c4 --- /dev/null +++ b/mmseg/models/decode_heads/aspp_head.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead + + +class ASPPModule(nn.ModuleList): + """Atrous Spatial Pyramid Pooling (ASPP) Module. + + Args: + dilations (tuple[int]): Dilation rate of each layer. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict): Config of activation layers. + """ + + def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg, + act_cfg): + super().__init__() + self.dilations = dilations + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + for dilation in dilations: + self.append( + ConvModule( + self.in_channels, + self.channels, + 1 if dilation == 1 else 3, + dilation=dilation, + padding=0 if dilation == 1 else dilation, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + def forward(self, x): + """Forward function.""" + aspp_outs = [] + for aspp_module in self: + aspp_outs.append(aspp_module(x)) + + return aspp_outs + + +@MODELS.register_module() +class ASPPHead(BaseDecodeHead): + """Rethinking Atrous Convolution for Semantic Image Segmentation. + + This head is the implementation of `DeepLabV3 + `_. + + Args: + dilations (tuple[int]): Dilation rates for ASPP module. + Default: (1, 6, 12, 18). 
+ """ + + def __init__(self, dilations=(1, 6, 12, 18), **kwargs): + super().__init__(**kwargs) + assert isinstance(dilations, (list, tuple)) + self.dilations = dilations + self.image_pool = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.aspp_modules = ASPPModule( + dilations, + self.in_channels, + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.bottleneck = ConvModule( + (len(dilations) + 1) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + x = self._transform_inputs(inputs) + aspp_outs = [ + resize( + self.image_pool(x), + size=x.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + ] + aspp_outs.extend(self.aspp_modules(x)) + aspp_outs = torch.cat(aspp_outs, dim=1) + feats = self.bottleneck(aspp_outs) + return feats + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + output = self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/cascade_decode_head.py b/mmseg/models/decode_heads/cascade_decode_head.py new file mode 100644 index 0000000000000000000000000000000000000000..fe2bcb9302235e3881696dff6657e3e7fb12609b --- /dev/null +++ b/mmseg/models/decode_heads/cascade_decode_head.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import List + +from torch import Tensor + +from mmseg.utils import ConfigType +from .decode_head import BaseDecodeHead + + +class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta): + """Base class for cascade decode head used in + :class:`CascadeEncoderDecoder.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @abstractmethod + def forward(self, inputs, prev_output): + """Placeholder of forward function.""" + pass + + def loss(self, inputs: List[Tensor], prev_output: Tensor, + batch_data_samples: List[dict], train_cfg: ConfigType) -> Tensor: + """Forward function for training. + + Args: + inputs (List[Tensor]): List of multi-level img features. + prev_output (Tensor): The output of previous decode head. + batch_data_samples (List[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `metainfo` and `gt_sem_seg`. + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + seg_logits = self.forward(inputs, prev_output) + losses = self.loss_by_feat(seg_logits, batch_data_samples) + + return losses + + def predict(self, inputs: List[Tensor], prev_output: Tensor, + batch_img_metas: List[dict], tese_cfg: ConfigType): + """Forward function for testing. + + Args: + inputs (List[Tensor]): List of multi-level img features. + prev_output (Tensor): The output of previous decode head. + batch_img_metas (dict): List Image info where each dict may also + contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. 
+ For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Output segmentation map. + """ + seg_logits = self.forward(inputs, prev_output) + + return self.predict_by_feat(seg_logits, batch_img_metas) diff --git a/mmseg/models/decode_heads/cc_head.py b/mmseg/models/decode_heads/cc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..e9075a2648d77f6bca6bb29f3e7db52a329f7afb --- /dev/null +++ b/mmseg/models/decode_heads/cc_head.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmseg.registry import MODELS +from .fcn_head import FCNHead + +try: + from mmcv.ops import CrissCrossAttention +except ModuleNotFoundError: + CrissCrossAttention = None + + +@MODELS.register_module() +class CCHead(FCNHead): + """CCNet: Criss-Cross Attention for Semantic Segmentation. + + This head is the implementation of `CCNet + `_. + + Args: + recurrence (int): Number of recurrence of Criss Cross Attention + module. Default: 2. + """ + + def __init__(self, recurrence=2, **kwargs): + if CrissCrossAttention is None: + raise RuntimeError('Please install mmcv-full for ' + 'CrissCrossAttention ops') + super().__init__(num_convs=2, **kwargs) + self.recurrence = recurrence + self.cca = CrissCrossAttention(self.channels) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + output = self.convs[0](x) + for _ in range(self.recurrence): + output = self.cca(output) + output = self.convs[1](output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/da_head.py b/mmseg/models/decode_heads/da_head.py new file mode 100644 index 0000000000000000000000000000000000000000..d87214365d2f8695b60ccab0c1850669ff8dd295 --- /dev/null +++ b/mmseg/models/decode_heads/da_head.py @@ -0,0 +1,184 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +import torch.nn.functional as F +from mmcv.cnn import ConvModule, Scale +from torch import Tensor, nn + +from mmseg.registry import MODELS +from mmseg.utils import SampleList, add_prefix +from ..utils import SelfAttentionBlock as _SelfAttentionBlock +from .decode_head import BaseDecodeHead + + +class PAM(_SelfAttentionBlock): + """Position Attention Module (PAM) + + Args: + in_channels (int): Input channels of key/query feature. + channels (int): Output channels of key/query transform. 
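+
+    A minimal usage sketch (sizes below are illustrative assumptions); the
+    output has the same shape as the input because the attention result is
+    scaled by a learnable ``gamma`` (initialized to 0) and added back:
+
+    Example:
+        >>> import torch
+        >>> pam = PAM(in_channels=16, channels=8)
+        >>> x = torch.rand(2, 16, 32, 32)
+        >>> out = pam(x)  # (2, 16, 32, 32)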
+ """ + + def __init__(self, in_channels, channels): + super().__init__( + key_in_channels=in_channels, + query_in_channels=in_channels, + channels=channels, + out_channels=in_channels, + share_key_query=False, + query_downsample=None, + key_downsample=None, + key_query_num_convs=1, + key_query_norm=False, + value_out_num_convs=1, + value_out_norm=False, + matmul_norm=False, + with_out=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None) + + self.gamma = Scale(0) + + def forward(self, x): + """Forward function.""" + out = super().forward(x, x) + + out = self.gamma(out) + x + return out + + +class CAM(nn.Module): + """Channel Attention Module (CAM)""" + + def __init__(self): + super().__init__() + self.gamma = Scale(0) + + def forward(self, x): + """Forward function.""" + batch_size, channels, height, width = x.size() + proj_query = x.view(batch_size, channels, -1) + proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max( + energy, -1, keepdim=True)[0].expand_as(energy) - energy + attention = F.softmax(energy_new, dim=-1) + proj_value = x.view(batch_size, channels, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(batch_size, channels, height, width) + + out = self.gamma(out) + x + return out + + +@MODELS.register_module() +class DAHead(BaseDecodeHead): + """Dual Attention Network for Scene Segmentation. + + This head is the implementation of `DANet + `_. + + Args: + pam_channels (int): The channels of Position Attention Module(PAM). + """ + + def __init__(self, pam_channels, **kwargs): + super().__init__(**kwargs) + self.pam_channels = pam_channels + self.pam_in_conv = ConvModule( + self.in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.pam = PAM(self.channels, pam_channels) + self.pam_out_conv = ConvModule( + self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.pam_conv_seg = nn.Conv2d( + self.channels, self.num_classes, kernel_size=1) + + self.cam_in_conv = ConvModule( + self.in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.cam = CAM() + self.cam_out_conv = ConvModule( + self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.cam_conv_seg = nn.Conv2d( + self.channels, self.num_classes, kernel_size=1) + + def pam_cls_seg(self, feat): + """PAM feature classification.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.pam_conv_seg(feat) + return output + + def cam_cls_seg(self, feat): + """CAM feature classification.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.cam_conv_seg(feat) + return output + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + pam_feat = self.pam_in_conv(x) + pam_feat = self.pam(pam_feat) + pam_feat = self.pam_out_conv(pam_feat) + pam_out = self.pam_cls_seg(pam_feat) + + cam_feat = self.cam_in_conv(x) + cam_feat = self.cam(cam_feat) + cam_feat = self.cam_out_conv(cam_feat) + cam_out = self.cam_cls_seg(cam_feat) + + feat_sum = pam_feat + cam_feat + pam_cam_out = self.cls_seg(feat_sum) + + return pam_cam_out, pam_out, cam_out + + def predict(self, inputs, batch_img_metas: List[dict], test_cfg, + **kwargs) -> List[Tensor]: + """Forward function for testing, only 
``pam_cam`` is used.""" + seg_logits = self.forward(inputs)[0] + return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs) + + def loss_by_feat(self, seg_logit: Tuple[Tensor], + batch_data_samples: SampleList, **kwargs) -> dict: + """Compute ``pam_cam``, ``pam``, ``cam`` loss.""" + pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit + loss = dict() + loss.update( + add_prefix( + super().loss_by_feat(pam_cam_seg_logit, batch_data_samples), + 'pam_cam')) + loss.update( + add_prefix(super().loss_by_feat(pam_seg_logit, batch_data_samples), + 'pam')) + loss.update( + add_prefix(super().loss_by_feat(cam_seg_logit, batch_data_samples), + 'cam')) + return loss diff --git a/mmseg/models/decode_heads/decode_head.py b/mmseg/models/decode_heads/decode_head.py new file mode 100644 index 0000000000000000000000000000000000000000..8bdbb24a1cc12e34fc11cf3af40f382a4c991ca9 --- /dev/null +++ b/mmseg/models/decode_heads/decode_head.py @@ -0,0 +1,358 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from abc import ABCMeta, abstractmethod +from typing import List, Tuple + +import torch +import torch.nn as nn +from mmengine.model import BaseModule +from torch import Tensor + +from mmseg.structures import build_pixel_sampler +from mmseg.utils import ConfigType, SampleList +from ..builder import build_loss +from ..losses import accuracy +from ..utils import resize + + +class BaseDecodeHead(BaseModule, metaclass=ABCMeta): + """Base class for BaseDecodeHead. + + 1. The ``init_weights`` method is used to initialize decode_head's + model parameters. After segmentor initialization, ``init_weights`` + is triggered when ``segmentor.init_weights()`` is called externally. + + 2. The ``loss`` method is used to calculate the loss of decode_head, + which includes two steps: (1) the decode_head model performs forward + propagation to obtain the feature maps (2) The ``loss_by_feat`` method + is called based on the feature maps to calculate the loss. + + .. code:: text + + loss(): forward() -> loss_by_feat() + + 3. The ``predict`` method is used to predict segmentation results, + which includes two steps: (1) the decode_head model performs forward + propagation to obtain the feature maps (2) The ``predict_by_feat`` method + is called based on the feature maps to predict segmentation results + including post-processing. + + .. code:: text + + predict(): forward() -> predict_by_feat() + + Args: + in_channels (int|Sequence[int]): Input channels. + channels (int): Channels after modules, before conv_seg. + num_classes (int): Number of classes. + out_channels (int): Output channels of conv_seg. + threshold (float): Threshold for binary segmentation in the case of + `num_classes==1`. Default: None. + dropout_ratio (float): Ratio of dropout layer. Default: 0.1. + conv_cfg (dict|None): Config of conv layers. Default: None. + norm_cfg (dict|None): Config of norm layers. Default: None. + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU') + in_index (int|Sequence[int]): Input feature index. Default: -1 + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + None: Only one select feature map is allowed. + Default: None. 
+ loss_decode (dict | Sequence[dict]): Config of decode loss. + The `loss_name` is property of corresponding loss function which + could be shown in training log. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_ce'. + e.g. dict(type='CrossEntropyLoss'), + [dict(type='CrossEntropyLoss', loss_name='loss_ce'), + dict(type='DiceLoss', loss_name='loss_dice')] + Default: dict(type='CrossEntropyLoss'). + ignore_index (int | None): The label index to be ignored. When using + masked BCE loss, ignore_index should be set to None. Default: 255. + sampler (dict|None): The config of segmentation map sampler. + Default: None. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__(self, + in_channels, + channels, + *, + num_classes, + out_channels=None, + threshold=None, + dropout_ratio=0.1, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + in_index=-1, + input_transform=None, + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + ignore_index=255, + sampler=None, + align_corners=False, + init_cfg=dict( + type='Normal', std=0.01, override=dict(name='conv_seg'))): + super().__init__(init_cfg) + self._init_inputs(in_channels, in_index, input_transform) + self.channels = channels + self.dropout_ratio = dropout_ratio + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.in_index = in_index + + self.ignore_index = ignore_index + self.align_corners = align_corners + + if out_channels is None: + if num_classes == 2: + warnings.warn('For binary segmentation, we suggest using' + '`out_channels = 1` to define the output' + 'channels of segmentor, and use `threshold`' + 'to convert `seg_logits` into a prediction' + 'applying a threshold') + out_channels = num_classes + + if out_channels != num_classes and out_channels != 1: + raise ValueError( + 'out_channels should be equal to num_classes,' + 'except binary segmentation set out_channels == 1 and' + f'num_classes == 2, but got out_channels={out_channels}' + f'and num_classes={num_classes}') + + if out_channels == 1 and threshold is None: + threshold = 0.3 + warnings.warn('threshold is not defined for binary, and defaults' + 'to 0.3') + self.num_classes = num_classes + self.out_channels = out_channels + self.threshold = threshold + + if isinstance(loss_decode, dict): + self.loss_decode = build_loss(loss_decode) + elif isinstance(loss_decode, (list, tuple)): + self.loss_decode = nn.ModuleList() + for loss in loss_decode: + self.loss_decode.append(build_loss(loss)) + else: + raise TypeError(f'loss_decode must be a dict or sequence of dict,\ + but got {type(loss_decode)}') + + if sampler is not None: + self.sampler = build_pixel_sampler(sampler, context=self) + else: + self.sampler = None + + self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1) + if dropout_ratio > 0: + self.dropout = nn.Dropout2d(dropout_ratio) + else: + self.dropout = None + + def extra_repr(self): + """Extra repr.""" + s = f'input_transform={self.input_transform}, ' \ + f'ignore_index={self.ignore_index}, ' \ + f'align_corners={self.align_corners}' + return s + + def _init_inputs(self, in_channels, in_index, input_transform): + """Check and initialize input transforms. + + The in_channels, in_index and input_transform must match. 
+ Specifically, when input_transform is None, only single feature map + will be selected. So in_channels and in_index must be of type int. + When input_transform + + Args: + in_channels (int|Sequence[int]): Input channels. + in_index (int|Sequence[int]): Input feature index. + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + None: Only one select feature map is allowed. + """ + + if input_transform is not None: + assert input_transform in ['resize_concat', 'multiple_select'] + self.input_transform = input_transform + self.in_index = in_index + if input_transform is not None: + assert isinstance(in_channels, (list, tuple)) + assert isinstance(in_index, (list, tuple)) + assert len(in_channels) == len(in_index) + if input_transform == 'resize_concat': + self.in_channels = sum(in_channels) + else: + self.in_channels = in_channels + else: + assert isinstance(in_channels, int) + assert isinstance(in_index, int) + self.in_channels = in_channels + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + Tensor: The transformed inputs + """ + + if self.input_transform == 'resize_concat': + inputs = [inputs[i] for i in self.in_index] + upsampled_inputs = [ + resize( + input=x, + size=inputs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) for x in inputs + ] + inputs = torch.cat(upsampled_inputs, dim=1) + elif self.input_transform == 'multiple_select': + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + @abstractmethod + def forward(self, inputs): + """Placeholder of forward function.""" + pass + + def cls_seg(self, feat): + """Classify each pixel.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.conv_seg(feat) + return output + + def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList, + train_cfg: ConfigType) -> dict: + """Forward function for training. + + Args: + inputs (Tuple[Tensor]): List of multi-level img features. + batch_data_samples (list[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `img_metas` or `gt_semantic_seg`. + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + seg_logits = self.forward(inputs) + losses = self.loss_by_feat(seg_logits, batch_data_samples) + return losses + + def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict], + test_cfg: ConfigType) -> Tensor: + """Forward function for prediction. + + Args: + inputs (Tuple[Tensor]): List of multi-level img features. + batch_img_metas (dict): List Image info where each dict may also + contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Outputs segmentation logits map. 
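+
+        A minimal sketch of the call convention, shown with ``FCNHead`` (a
+        concrete subclass added elsewhere in this patch); the sizes and the
+        metadata dict are illustrative assumptions only:
+
+        Example:
+            >>> import torch
+            >>> from mmseg.models.decode_heads import FCNHead
+            >>> head = FCNHead(in_channels=8, channels=4, num_classes=2,
+            ...                num_convs=1)
+            >>> feats = [torch.rand(1, 8, 16, 16)]
+            >>> metas = [dict(img_shape=(64, 64))]
+            >>> logits = head.predict(feats, metas, test_cfg=None)
+            >>> # logits are resized to ``img_shape``: (1, 2, 64, 64)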
+ """ + seg_logits = self.forward(inputs) + + return self.predict_by_feat(seg_logits, batch_img_metas) + + def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor: + gt_semantic_segs = [ + data_sample.gt_sem_seg.data for data_sample in batch_data_samples + ] + return torch.stack(gt_semantic_segs, dim=0) + + def loss_by_feat(self, seg_logits: Tensor, + batch_data_samples: SampleList) -> dict: + """Compute segmentation loss. + + Args: + seg_logits (Tensor): The output from decode head forward function. + batch_data_samples (List[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `metainfo` and `gt_sem_seg`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + seg_label = self._stack_batch_gt(batch_data_samples) + loss = dict() + seg_logits = resize( + input=seg_logits, + size=seg_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + if self.sampler is not None: + seg_weight = self.sampler.sample(seg_logits, seg_label) + else: + seg_weight = None + seg_label = seg_label.squeeze(1) + + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_decode in losses_decode: + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode( + seg_logits, + seg_label, + weight=seg_weight, + ignore_index=self.ignore_index) + else: + loss[loss_decode.loss_name] += loss_decode( + seg_logits, + seg_label, + weight=seg_weight, + ignore_index=self.ignore_index) + + loss['acc_seg'] = accuracy( + seg_logits, seg_label, ignore_index=self.ignore_index) + return loss + + def predict_by_feat(self, seg_logits: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Transform a batch of output seg_logits to the input shape. + + Args: + seg_logits (Tensor): The output from decode head forward function. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + + Returns: + Tensor: Outputs segmentation logits map. + """ + + seg_logits = resize( + input=seg_logits, + size=batch_img_metas[0]['img_shape'], + mode='bilinear', + align_corners=self.align_corners) + return seg_logits diff --git a/mmseg/models/decode_heads/dm_head.py b/mmseg/models/decode_heads/dm_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7694abd8ac3a470d543c580bd97adceb5b647f7c --- /dev/null +++ b/mmseg/models/decode_heads/dm_head.py @@ -0,0 +1,141 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer + +from mmseg.registry import MODELS +from .decode_head import BaseDecodeHead + + +class DCM(nn.Module): + """Dynamic Convolutional Module used in DMNet. + + Args: + filter_size (int): The filter size of generated convolution kernel + used in Dynamic Convolutional Module. + fusion (bool): Add one conv to fuse DCM output feature. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + conv_cfg (dict | None): Config of conv layers. + norm_cfg (dict | None): Config of norm layers. + act_cfg (dict): Config of activation layers. 
+ """ + + def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg, + norm_cfg, act_cfg): + super().__init__() + self.filter_size = filter_size + self.fusion = fusion + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1, + 0) + + self.input_redu_conv = ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + if self.norm_cfg is not None: + self.norm = build_norm_layer(self.norm_cfg, self.channels)[1] + else: + self.norm = None + self.activate = build_activation_layer(self.act_cfg) + + if self.fusion: + self.fusion_conv = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, x): + """Forward function.""" + generated_filter = self.filter_gen_conv( + F.adaptive_avg_pool2d(x, self.filter_size)) + x = self.input_redu_conv(x) + b, c, h, w = x.shape + # [1, b * c, h, w], c = self.channels + x = x.view(1, b * c, h, w) + # [b * c, 1, filter_size, filter_size] + generated_filter = generated_filter.view(b * c, 1, self.filter_size, + self.filter_size) + pad = (self.filter_size - 1) // 2 + if (self.filter_size - 1) % 2 == 0: + p2d = (pad, pad, pad, pad) + else: + p2d = (pad + 1, pad, pad + 1, pad) + x = F.pad(input=x, pad=p2d, mode='constant', value=0) + # [1, b * c, h, w] + output = F.conv2d(input=x, weight=generated_filter, groups=b * c) + # [b, c, h, w] + output = output.view(b, c, h, w) + if self.norm is not None: + output = self.norm(output) + output = self.activate(output) + + if self.fusion: + output = self.fusion_conv(output) + + return output + + +@MODELS.register_module() +class DMHead(BaseDecodeHead): + """Dynamic Multi-scale Filters for Semantic Segmentation. + + This head is the implementation of + `DMNet `_. + + Args: + filter_sizes (tuple[int]): The size of generated convolutional filters + used in Dynamic Convolutional Module. Default: (1, 3, 5, 7). + fusion (bool): Add one conv to fuse DCM output feature. + """ + + def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs): + super().__init__(**kwargs) + assert isinstance(filter_sizes, (list, tuple)) + self.filter_sizes = filter_sizes + self.fusion = fusion + dcm_modules = [] + for filter_size in self.filter_sizes: + dcm_modules.append( + DCM(filter_size, + self.fusion, + self.in_channels, + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.dcm_modules = nn.ModuleList(dcm_modules) + self.bottleneck = ConvModule( + self.in_channels + len(filter_sizes) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + dcm_outs = [x] + for dcm_module in self.dcm_modules: + dcm_outs.append(dcm_module(x)) + dcm_outs = torch.cat(dcm_outs, dim=1) + output = self.bottleneck(dcm_outs) + output = self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/dnl_head.py b/mmseg/models/decode_heads/dnl_head.py new file mode 100644 index 0000000000000000000000000000000000000000..248c11814108d02e88fa7e0cada061b3366e33ff --- /dev/null +++ b/mmseg/models/decode_heads/dnl_head.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch +from mmcv.cnn import NonLocal2d +from torch import nn + +from mmseg.registry import MODELS +from .fcn_head import FCNHead + + +class DisentangledNonLocal2d(NonLocal2d): + """Disentangled Non-Local Blocks. + + Args: + temperature (float): Temperature to adjust attention. Default: 0.05 + """ + + def __init__(self, *arg, temperature, **kwargs): + super().__init__(*arg, **kwargs) + self.temperature = temperature + self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1) + + def embedded_gaussian(self, theta_x, phi_x): + """Embedded gaussian with temperature.""" + + # NonLocal2d pairwise_weight: [N, HxW, HxW] + pairwise_weight = torch.matmul(theta_x, phi_x) + if self.use_scale: + # theta_x.shape[-1] is `self.inter_channels` + pairwise_weight /= torch.tensor( + theta_x.shape[-1], + dtype=torch.float, + device=pairwise_weight.device)**torch.tensor( + 0.5, device=pairwise_weight.device) + pairwise_weight /= torch.tensor( + self.temperature, device=pairwise_weight.device) + pairwise_weight = pairwise_weight.softmax(dim=-1) + return pairwise_weight + + def forward(self, x): + # x: [N, C, H, W] + n = x.size(0) + + # g_x: [N, HxW, C] + g_x = self.g(x).view(n, self.inter_channels, -1) + g_x = g_x.permute(0, 2, 1) + + # theta_x: [N, HxW, C], phi_x: [N, C, HxW] + if self.mode == 'gaussian': + theta_x = x.view(n, self.in_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + if self.sub_sample: + phi_x = self.phi(x).view(n, self.in_channels, -1) + else: + phi_x = x.view(n, self.in_channels, -1) + elif self.mode == 'concatenation': + theta_x = self.theta(x).view(n, self.inter_channels, -1, 1) + phi_x = self.phi(x).view(n, self.inter_channels, 1, -1) + else: + theta_x = self.theta(x).view(n, self.inter_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + phi_x = self.phi(x).view(n, self.inter_channels, -1) + + # subtract mean + theta_x -= theta_x.mean(dim=-2, keepdim=True) + phi_x -= phi_x.mean(dim=-1, keepdim=True) + + pairwise_func = getattr(self, self.mode) + # pairwise_weight: [N, HxW, HxW] + pairwise_weight = pairwise_func(theta_x, phi_x) + + # y: [N, HxW, C] + y = torch.matmul(pairwise_weight, g_x) + # y: [N, C, H, W] + y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels, + *x.size()[2:]) + + # unary_mask: [N, 1, HxW] + unary_mask = self.conv_mask(x) + unary_mask = unary_mask.view(n, 1, -1) + unary_mask = unary_mask.softmax(dim=-1) + # unary_x: [N, 1, C] + unary_x = torch.matmul(unary_mask, g_x) + # unary_x: [N, C, 1, 1] + unary_x = unary_x.permute(0, 2, 1).contiguous().reshape( + n, self.inter_channels, 1, 1) + + output = x + self.conv_out(y + unary_x) + + return output + + +@MODELS.register_module() +class DNLHead(FCNHead): + """Disentangled Non-Local Neural Networks. + + This head is the implementation of `DNLNet + `_. + + Args: + reduction (int): Reduction factor of projection transform. Default: 2. + use_scale (bool): Whether to scale pairwise_weight by + sqrt(1/inter_channels). Default: True. + mode (str): The nonlocal mode. Options are 'embedded_gaussian', + 'dot_product'. Default: 'embedded_gaussian'. + temperature (float): Temperature to adjust attention.
Default: 0.05 + """ + + def __init__(self, + reduction=2, + use_scale=True, + mode='embedded_gaussian', + temperature=0.05, + **kwargs): + super().__init__(num_convs=2, **kwargs) + self.reduction = reduction + self.use_scale = use_scale + self.mode = mode + self.temperature = temperature + self.dnl_block = DisentangledNonLocal2d( + in_channels=self.channels, + reduction=self.reduction, + use_scale=self.use_scale, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + mode=self.mode, + temperature=self.temperature) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + output = self.convs[0](x) + output = self.dnl_block(output) + output = self.convs[1](output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/dpt_head.py b/mmseg/models/decode_heads/dpt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..d2cfd89daa4df48601e930cfd158dcf3c9a6a837 --- /dev/null +++ b/mmseg/models/decode_heads/dpt_head.py @@ -0,0 +1,294 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, Linear, build_activation_layer +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead + + +class ReassembleBlocks(BaseModule): + """ViTPostProcessBlock, process cls_token in ViT backbone output and + rearrange the feature vector to feature map. + + Args: + in_channels (int): ViT feature channels. Default: 768. + out_channels (List): output channels of each stage. + Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + init_cfg (dict, optional): Initialization config dict. Default: None. 
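+
+    A minimal usage sketch (channel and token-grid sizes below are
+    illustrative assumptions); each input is a ``(feature_map, cls_token)``
+    pair as produced by a ViT-style backbone:
+
+    Example:
+        >>> import torch
+        >>> blocks = ReassembleBlocks(in_channels=32,
+        ...                           out_channels=[16, 32, 64, 128],
+        ...                           readout_type='ignore')
+        >>> vit_outs = [(torch.rand(1, 32, 24, 24), torch.rand(1, 32))
+        ...             for _ in range(4)]
+        >>> outs = blocks(vit_outs)
+        >>> # spatial sizes: 96x96, 48x48, 24x24 and 12x12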
+ """ + + def __init__(self, + in_channels=768, + out_channels=[96, 192, 384, 768], + readout_type='ignore', + patch_size=16, + init_cfg=None): + super().__init__(init_cfg) + + assert readout_type in ['ignore', 'add', 'project'] + self.readout_type = readout_type + self.patch_size = patch_size + + self.projects = nn.ModuleList([ + ConvModule( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + act_cfg=None, + ) for out_channel in out_channels + ]) + + self.resize_layers = nn.ModuleList([ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1) + ]) + if self.readout_type == 'project': + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential( + Linear(2 * in_channels, in_channels), + build_activation_layer(dict(type='GELU')))) + + def forward(self, inputs): + assert isinstance(inputs, list) + out = [] + for i, x in enumerate(inputs): + assert len(x) == 2 + x, cls_token = x[0], x[1] + feature_shape = x.shape + if self.readout_type == 'project': + x = x.flatten(2).permute((0, 2, 1)) + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + x = x.permute(0, 2, 1).reshape(feature_shape) + elif self.readout_type == 'add': + x = x.flatten(2) + cls_token.unsqueeze(-1) + x = x.reshape(feature_shape) + else: + pass + x = self.projects[i](x) + x = self.resize_layers[i](x) + out.append(x) + return out + + +class PreActResidualConvUnit(BaseModule): + """ResidualConvUnit, pre-activate residual unit. + + Args: + in_channels (int): number of channels in the input feature map. + act_cfg (dict): dictionary to construct and config activation layer. + norm_cfg (dict): dictionary to construct and config norm layer. + stride (int): stride of the first block. Default: 1 + dilation (int): dilation rate for convs layers. Default: 1. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + def __init__(self, + in_channels, + act_cfg, + norm_cfg, + stride=1, + dilation=1, + init_cfg=None): + super().__init__(init_cfg) + + self.conv1 = ConvModule( + in_channels, + in_channels, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + bias=False, + order=('act', 'conv', 'norm')) + + self.conv2 = ConvModule( + in_channels, + in_channels, + 3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + bias=False, + order=('act', 'conv', 'norm')) + + def forward(self, inputs): + inputs_ = inputs.clone() + x = self.conv1(inputs) + x = self.conv2(x) + return x + inputs_ + + +class FeatureFusionBlock(BaseModule): + """FeatureFusionBlock, merge feature map from different stages. + + Args: + in_channels (int): Input channels. + act_cfg (dict): The activation config for ResidualConvUnit. + norm_cfg (dict): Config dict for normalization layer. + expand (bool): Whether expand the channels in post process block. + Default: False. + align_corners (bool): align_corner setting for bilinear upsample. + Default: True. + init_cfg (dict, optional): Initialization config dict. Default: None. 
+ """ + + def __init__(self, + in_channels, + act_cfg, + norm_cfg, + expand=False, + align_corners=True, + init_cfg=None): + super().__init__(init_cfg) + + self.in_channels = in_channels + self.expand = expand + self.align_corners = align_corners + + self.out_channels = in_channels + if self.expand: + self.out_channels = in_channels // 2 + + self.project = ConvModule( + self.in_channels, + self.out_channels, + kernel_size=1, + act_cfg=None, + bias=True) + + self.res_conv_unit1 = PreActResidualConvUnit( + in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg) + self.res_conv_unit2 = PreActResidualConvUnit( + in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg) + + def forward(self, *inputs): + x = inputs[0] + if len(inputs) == 2: + if x.shape != inputs[1].shape: + res = resize( + inputs[1], + size=(x.shape[2], x.shape[3]), + mode='bilinear', + align_corners=False) + else: + res = inputs[1] + x = x + self.res_conv_unit1(res) + x = self.res_conv_unit2(x) + x = resize( + x, + scale_factor=2, + mode='bilinear', + align_corners=self.align_corners) + x = self.project(x) + return x + + +@MODELS.register_module() +class DPTHead(BaseDecodeHead): + """Vision Transformers for Dense Prediction. + + This head is implemented of `DPT `_. + + Args: + embed_dims (int): The embed dimension of the ViT backbone. + Default: 768. + post_process_channels (List): Out channels of post process conv + layers. Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + expand_channels (bool): Whether expand the channels in post process + block. Default: False. + act_cfg (dict): The activation config for residual conv unit. + Default dict(type='ReLU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). 
+ """ + + def __init__(self, + embed_dims=768, + post_process_channels=[96, 192, 384, 768], + readout_type='ignore', + patch_size=16, + expand_channels=False, + act_cfg=dict(type='ReLU'), + norm_cfg=dict(type='BN'), + **kwargs): + super().__init__(**kwargs) + + self.in_channels = self.in_channels + self.expand_channels = expand_channels + self.reassemble_blocks = ReassembleBlocks(embed_dims, + post_process_channels, + readout_type, patch_size) + + self.post_process_channels = [ + channel * math.pow(2, i) if expand_channels else channel + for i, channel in enumerate(post_process_channels) + ] + self.convs = nn.ModuleList() + for channel in self.post_process_channels: + self.convs.append( + ConvModule( + channel, + self.channels, + kernel_size=3, + padding=1, + act_cfg=None, + bias=False)) + self.fusion_blocks = nn.ModuleList() + for _ in range(len(self.convs)): + self.fusion_blocks.append( + FeatureFusionBlock(self.channels, act_cfg, norm_cfg)) + self.fusion_blocks[0].res_conv_unit1 = None + self.project = ConvModule( + self.channels, + self.channels, + kernel_size=3, + padding=1, + norm_cfg=norm_cfg) + self.num_fusion_blocks = len(self.fusion_blocks) + self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers) + self.num_post_process_channels = len(self.post_process_channels) + assert self.num_fusion_blocks == self.num_reassemble_blocks + assert self.num_reassemble_blocks == self.num_post_process_channels + + def forward(self, inputs): + assert len(inputs) == self.num_reassemble_blocks + x = self._transform_inputs(inputs) + x = self.reassemble_blocks(x) + x = [self.convs[i](feature) for i, feature in enumerate(x)] + out = self.fusion_blocks[0](x[-1]) + for i in range(1, len(self.fusion_blocks)): + out = self.fusion_blocks[i](out, x[-(i + 1)]) + out = self.project(out) + out = self.cls_seg(out) + return out diff --git a/mmseg/models/decode_heads/ema_head.py b/mmseg/models/decode_heads/ema_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ab8dbb0c29b9b533dad962e48d71ae055f20aa07 --- /dev/null +++ b/mmseg/models/decode_heads/ema_head.py @@ -0,0 +1,169 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from .decode_head import BaseDecodeHead + + +def reduce_mean(tensor): + """Reduce mean when distributed training.""" + if not (dist.is_available() and dist.is_initialized()): + return tensor + tensor = tensor.clone() + dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM) + return tensor + + +class EMAModule(nn.Module): + """Expectation Maximization Attention Module used in EMANet. + + Args: + channels (int): Channels of the whole module. + num_bases (int): Number of bases. + num_stages (int): Number of the EM iterations. + """ + + def __init__(self, channels, num_bases, num_stages, momentum): + super().__init__() + assert num_stages >= 1, 'num_stages must be at least 1!' + self.num_bases = num_bases + self.num_stages = num_stages + self.momentum = momentum + + bases = torch.zeros(1, channels, self.num_bases) + bases.normal_(0, math.sqrt(2. 
/ self.num_bases)) + # [1, channels, num_bases] + bases = F.normalize(bases, dim=1, p=2) + self.register_buffer('bases', bases) + + def forward(self, feats): + """Forward function.""" + batch_size, channels, height, width = feats.size() + # [batch_size, channels, height*width] + feats = feats.view(batch_size, channels, height * width) + # [batch_size, channels, num_bases] + bases = self.bases.repeat(batch_size, 1, 1) + + with torch.no_grad(): + for i in range(self.num_stages): + # [batch_size, height*width, num_bases] + attention = torch.einsum('bcn,bck->bnk', feats, bases) + attention = F.softmax(attention, dim=2) + # l1 norm + attention_normed = F.normalize(attention, dim=1, p=1) + # [batch_size, channels, num_bases] + bases = torch.einsum('bcn,bnk->bck', feats, attention_normed) + # l2 norm + bases = F.normalize(bases, dim=1, p=2) + + feats_recon = torch.einsum('bck,bnk->bcn', bases, attention) + feats_recon = feats_recon.view(batch_size, channels, height, width) + + if self.training: + bases = bases.mean(dim=0, keepdim=True) + bases = reduce_mean(bases) + # l2 norm + bases = F.normalize(bases, dim=1, p=2) + self.bases = (1 - + self.momentum) * self.bases + self.momentum * bases + + return feats_recon + + +@MODELS.register_module() +class EMAHead(BaseDecodeHead): + """Expectation Maximization Attention Networks for Semantic Segmentation. + + This head is the implementation of `EMANet + `_. + + Args: + ema_channels (int): EMA module channels + num_bases (int): Number of bases. + num_stages (int): Number of the EM iterations. + concat_input (bool): Whether concat the input and output of convs + before classification layer. Default: True + momentum (float): Momentum to update the base. Default: 0.1. + """ + + def __init__(self, + ema_channels, + num_bases, + num_stages, + concat_input=True, + momentum=0.1, + **kwargs): + super().__init__(**kwargs) + self.ema_channels = ema_channels + self.num_bases = num_bases + self.num_stages = num_stages + self.concat_input = concat_input + self.momentum = momentum + self.ema_module = EMAModule(self.ema_channels, self.num_bases, + self.num_stages, self.momentum) + + self.ema_in_conv = ConvModule( + self.in_channels, + self.ema_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + # project (0, inf) -> (-inf, inf) + self.ema_mid_conv = ConvModule( + self.ema_channels, + self.ema_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=None, + act_cfg=None) + for param in self.ema_mid_conv.parameters(): + param.requires_grad = False + + self.ema_out_conv = ConvModule( + self.ema_channels, + self.ema_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=None) + self.bottleneck = ConvModule( + self.ema_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + if self.concat_input: + self.conv_cat = ConvModule( + self.in_channels + self.channels, + self.channels, + kernel_size=3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + feats = self.ema_in_conv(x) + identity = feats + feats = self.ema_mid_conv(feats) + recon = self.ema_module(feats) + recon = F.relu(recon, inplace=True) + recon = self.ema_out_conv(recon) + output = F.relu(identity + recon, inplace=True) + output = self.bottleneck(output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = 
self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/enc_head.py b/mmseg/models/decode_heads/enc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ef48fb6995365ba374b29ea265608087500f27dc --- /dev/null +++ b/mmseg/models/decode_heads/enc_head.py @@ -0,0 +1,197 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, build_norm_layer +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.utils import ConfigType, SampleList +from ..builder import build_loss +from ..utils import Encoding, resize +from .decode_head import BaseDecodeHead + + +class EncModule(nn.Module): + """Encoding Module used in EncNet. + + Args: + in_channels (int): Input channels. + num_codes (int): Number of code words. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict): Config of activation layers. + """ + + def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg): + super().__init__() + self.encoding_project = ConvModule( + in_channels, + in_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + # TODO: resolve this hack + # change to 1d + if norm_cfg is not None: + encoding_norm_cfg = norm_cfg.copy() + if encoding_norm_cfg['type'] in ['BN', 'IN']: + encoding_norm_cfg['type'] += '1d' + else: + encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace( + '2d', '1d') + else: + # fallback to BN1d + encoding_norm_cfg = dict(type='BN1d') + self.encoding = nn.Sequential( + Encoding(channels=in_channels, num_codes=num_codes), + build_norm_layer(encoding_norm_cfg, num_codes)[1], + nn.ReLU(inplace=True)) + self.fc = nn.Sequential( + nn.Linear(in_channels, in_channels), nn.Sigmoid()) + + def forward(self, x): + """Forward function.""" + encoding_projection = self.encoding_project(x) + encoding_feat = self.encoding(encoding_projection).mean(dim=1) + batch_size, channels, _, _ = x.size() + gamma = self.fc(encoding_feat) + y = gamma.view(batch_size, channels, 1, 1) + output = F.relu_(x + x * y) + return encoding_feat, output + + +@MODELS.register_module() +class EncHead(BaseDecodeHead): + """Context Encoding for Semantic Segmentation. + + This head is the implementation of `EncNet + `_. + + Args: + num_codes (int): Number of code words. Default: 32. + use_se_loss (bool): Whether use Semantic Encoding Loss (SE-loss) to + regularize the training. Default: True. + add_lateral (bool): Whether use lateral connection to fuse features. + Default: False. + loss_se_decode (dict): Config of decode loss. + Default: dict(type='CrossEntropyLoss', use_sigmoid=True). 
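+
+    A minimal usage sketch (sizes below are illustrative assumptions); with
+    ``use_se_loss=True`` the forward pass returns both the segmentation
+    logits and the per-image SE-loss logits:
+
+    Example:
+        >>> import torch
+        >>> head = EncHead(in_channels=[16, 32], channels=16, num_classes=4,
+        ...                in_index=(0, 1), num_codes=8,
+        ...                norm_cfg=dict(type='BN'))
+        >>> feats = [torch.rand(1, 16, 32, 32), torch.rand(1, 32, 16, 16)]
+        >>> seg_logits, se_logits = head(feats)
+        >>> # seg_logits: (1, 4, 16, 16); se_logits: (1, 4)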
+ """ + + def __init__(self, + num_codes=32, + use_se_loss=True, + add_lateral=False, + loss_se_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=0.2), + **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + self.use_se_loss = use_se_loss + self.add_lateral = add_lateral + self.num_codes = num_codes + self.bottleneck = ConvModule( + self.in_channels[-1], + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + if add_lateral: + self.lateral_convs = nn.ModuleList() + for in_channels in self.in_channels[:-1]: # skip the last one + self.lateral_convs.append( + ConvModule( + in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.fusion = ConvModule( + len(self.in_channels) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.enc_module = EncModule( + self.channels, + num_codes=num_codes, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + if self.use_se_loss: + self.loss_se_decode = build_loss(loss_se_decode) + self.se_layer = nn.Linear(self.channels, self.num_classes) + + def forward(self, inputs): + """Forward function.""" + inputs = self._transform_inputs(inputs) + feat = self.bottleneck(inputs[-1]) + if self.add_lateral: + laterals = [ + resize( + lateral_conv(inputs[i]), + size=feat.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + feat = self.fusion(torch.cat([feat, *laterals], 1)) + encode_feat, output = self.enc_module(feat) + output = self.cls_seg(output) + if self.use_se_loss: + se_output = self.se_layer(encode_feat) + return output, se_output + else: + return output + + def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict], + test_cfg: ConfigType): + """Forward function for testing, ignore se_loss.""" + if self.use_se_loss: + seg_logits = self.forward(inputs)[0] + else: + seg_logits = self.forward(inputs) + return self.predict_by_feat(seg_logits, batch_img_metas) + + @staticmethod + def _convert_to_onehot_labels(seg_label, num_classes): + """Convert segmentation label to onehot. + + Args: + seg_label (Tensor): Segmentation label of shape (N, H, W). + num_classes (int): Number of classes. + + Returns: + Tensor: Onehot labels of shape (N, num_classes). + """ + + batch_size = seg_label.size(0) + onehot_labels = seg_label.new_zeros((batch_size, num_classes)) + for i in range(batch_size): + hist = seg_label[i].float().histc( + bins=num_classes, min=0, max=num_classes - 1) + onehot_labels[i] = hist > 0 + return onehot_labels + + def loss_by_feat(self, seg_logit: Tuple[Tensor], + batch_data_samples: SampleList, **kwargs) -> dict: + """Compute segmentation and semantic encoding loss.""" + seg_logit, se_seg_logit = seg_logit + loss = dict() + loss.update(super().loss_by_feat(seg_logit, batch_data_samples)) + + seg_label = self._stack_batch_gt(batch_data_samples) + se_loss = self.loss_se_decode( + se_seg_logit, + self._convert_to_onehot_labels(seg_label, self.num_classes)) + loss['loss_se'] = se_loss + return loss diff --git a/mmseg/models/decode_heads/fcn_head.py b/mmseg/models/decode_heads/fcn_head.py new file mode 100644 index 0000000000000000000000000000000000000000..341801888368d307da6b926a2c89f72b6b06476d --- /dev/null +++ b/mmseg/models/decode_heads/fcn_head.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from .decode_head import BaseDecodeHead + + +@MODELS.register_module() +class FCNHead(BaseDecodeHead): + """Fully Convolution Networks for Semantic Segmentation. + + This head is implemented of `FCNNet `_. + + Args: + num_convs (int): Number of convs in the head. Default: 2. + kernel_size (int): The kernel size for convs in the head. Default: 3. + concat_input (bool): Whether concat the input and output of convs + before classification layer. + dilation (int): The dilation rate for convs in the head. Default: 1. + """ + + def __init__(self, + num_convs=2, + kernel_size=3, + concat_input=True, + dilation=1, + **kwargs): + assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int) + self.num_convs = num_convs + self.concat_input = concat_input + self.kernel_size = kernel_size + super().__init__(**kwargs) + if num_convs == 0: + assert self.in_channels == self.channels + + conv_padding = (kernel_size // 2) * dilation + convs = [] + convs.append( + ConvModule( + self.in_channels, + self.channels, + kernel_size=kernel_size, + padding=conv_padding, + dilation=dilation, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + for i in range(num_convs - 1): + convs.append( + ConvModule( + self.channels, + self.channels, + kernel_size=kernel_size, + padding=conv_padding, + dilation=dilation, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + if num_convs == 0: + self.convs = nn.Identity() + else: + self.convs = nn.Sequential(*convs) + if self.concat_input: + self.conv_cat = ConvModule( + self.in_channels + self.channels, + self.channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + x = self._transform_inputs(inputs) + feats = self.convs(x) + if self.concat_input: + feats = self.conv_cat(torch.cat([x, feats], dim=1)) + return feats + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + output = self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/fpn_head.py b/mmseg/models/decode_heads/fpn_head.py new file mode 100644 index 0000000000000000000000000000000000000000..25f481fe81c5f4f0aa37903aaf135dc63c930bf8 --- /dev/null +++ b/mmseg/models/decode_heads/fpn_head.py @@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import Upsample, resize +from .decode_head import BaseDecodeHead + + +@MODELS.register_module() +class FPNHead(BaseDecodeHead): + """Panoptic Feature Pyramid Networks. + + This head is the implementation of `Semantic FPN + `_. + + Args: + feature_strides (tuple[int]): The strides for input feature maps. + stack_lateral. All strides suppose to be power of 2. The first + one is of largest resolution. 
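+
+    A minimal usage sketch (strides and sizes below are illustrative
+    assumptions, not values from any provided config):
+
+    Example:
+        >>> import torch
+        >>> head = FPNHead(feature_strides=[4, 8, 16, 32],
+        ...                in_channels=[16, 16, 16, 16], channels=16,
+        ...                num_classes=4, in_index=[0, 1, 2, 3])
+        >>> feats = [torch.rand(1, 16, 64 // s, 64 // s)
+        ...          for s in [4, 8, 16, 32]]
+        >>> seg_logits = head(feats)  # (1, 4, 16, 16), i.e. stride-4 logits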
+ """ + + def __init__(self, feature_strides, **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + self.scale_heads = nn.ModuleList() + for i in range(len(feature_strides)): + head_length = max( + 1, + int(np.log2(feature_strides[i]) - np.log2(feature_strides[0]))) + scale_head = [] + for k in range(head_length): + scale_head.append( + ConvModule( + self.in_channels[i] if k == 0 else self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + if feature_strides[i] != feature_strides[0]: + scale_head.append( + Upsample( + scale_factor=2, + mode='bilinear', + align_corners=self.align_corners)) + self.scale_heads.append(nn.Sequential(*scale_head)) + + def forward(self, inputs): + + x = self._transform_inputs(inputs) + + output = self.scale_heads[0](x[0]) + for i in range(1, len(self.feature_strides)): + # non inplace + output = output + resize( + self.scale_heads[i](x[i]), + size=output.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + + output = self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/gc_head.py b/mmseg/models/decode_heads/gc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..14f0ef021c1143d493e17f347f1f4da1145470b8 --- /dev/null +++ b/mmseg/models/decode_heads/gc_head.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmcv.cnn import ContextBlock + +from mmseg.registry import MODELS +from .fcn_head import FCNHead + + +@MODELS.register_module() +class GCHead(FCNHead): + """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond. + + This head is the implementation of `GCNet + `_. + + Args: + ratio (float): Multiplier of channels ratio. Default: 1/4. + pooling_type (str): The pooling type of context aggregation. + Options are 'att', 'avg'. Default: 'avg'. + fusion_types (tuple[str]): The fusion type for feature fusion. + Options are 'channel_add', 'channel_mul'. Default: ('channel_add',) + """ + + def __init__(self, + ratio=1 / 4., + pooling_type='att', + fusion_types=('channel_add', ), + **kwargs): + super().__init__(num_convs=2, **kwargs) + self.ratio = ratio + self.pooling_type = pooling_type + self.fusion_types = fusion_types + self.gc_block = ContextBlock( + in_channels=self.channels, + ratio=self.ratio, + pooling_type=self.pooling_type, + fusion_types=self.fusion_types) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + output = self.convs[0](x) + output = self.gc_block(output) + output = self.convs[1](output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/ham_head.py b/mmseg/models/decode_heads/ham_head.py new file mode 100644 index 0000000000000000000000000000000000000000..d80025f77d261be18369d54ea85f53717b2c15d1 --- /dev/null +++ b/mmseg/models/decode_heads/ham_head.py @@ -0,0 +1,257 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# Originally from https://github.com/visual-attention-network/segnext +# Licensed under the Apache License, Version 2.0 (the "License") +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead + + +class Matrix_Decomposition_2D_Base(nn.Module): + """Base class of 2D Matrix Decomposition. + + Args: + MD_S (int): The number of spatial coefficient in + Matrix Decomposition, it may be used for calculation + of the number of latent dimension D in Matrix + Decomposition. Defaults: 1. + MD_R (int): The number of latent dimension R in + Matrix Decomposition. Defaults: 64. + train_steps (int): The number of iteration steps in + Multiplicative Update (MU) rule to solve Non-negative + Matrix Factorization (NMF) in training. Defaults: 6. + eval_steps (int): The number of iteration steps in + Multiplicative Update (MU) rule to solve Non-negative + Matrix Factorization (NMF) in evaluation. Defaults: 7. + inv_t (int): Inverted multiple number to make coefficient + smaller in softmax. Defaults: 100. + rand_init (bool): Whether to initialize randomly. + Defaults: True. + """ + + def __init__(self, + MD_S=1, + MD_R=64, + train_steps=6, + eval_steps=7, + inv_t=100, + rand_init=True): + super().__init__() + + self.S = MD_S + self.R = MD_R + + self.train_steps = train_steps + self.eval_steps = eval_steps + + self.inv_t = inv_t + + self.rand_init = rand_init + + def _build_bases(self, B, S, D, R, cuda=False): + raise NotImplementedError + + def local_step(self, x, bases, coef): + raise NotImplementedError + + def local_inference(self, x, bases): + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + coef = torch.bmm(x.transpose(1, 2), bases) + coef = F.softmax(self.inv_t * coef, dim=-1) + + steps = self.train_steps if self.training else self.eval_steps + for _ in range(steps): + bases, coef = self.local_step(x, bases, coef) + + return bases, coef + + def compute_coef(self, x, bases, coef): + raise NotImplementedError + + def forward(self, x, return_bases=False): + """Forward Function.""" + B, C, H, W = x.shape + + # (B, C, H, W) -> (B * S, D, N) + D = C // self.S + N = H * W + x = x.view(B * self.S, D, N) + cuda = 'cuda' in str(x.device) + if not self.rand_init and not hasattr(self, 'bases'): + bases = self._build_bases(1, self.S, D, self.R, cuda=cuda) + self.register_buffer('bases', bases) + + # (S, D, R) -> (B * S, D, R) + if self.rand_init: + bases = self._build_bases(B, self.S, D, self.R, cuda=cuda) + else: + bases = self.bases.repeat(B, 1, 1) + + bases, coef = self.local_inference(x, bases) + + # (B * S, N, R) + coef = self.compute_coef(x, bases, coef) + + # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N) + x = torch.bmm(bases, coef.transpose(1, 2)) + + # (B * S, D, N) -> (B, C, H, W) + x = x.view(B, C, H, W) + + return x + + +class NMF2D(Matrix_Decomposition_2D_Base): + """Non-negative Matrix Factorization (NMF) module. + + It is inherited from ``Matrix_Decomposition_2D_Base`` module. 
+ """ + + def __init__(self, args=dict()): + super().__init__(**args) + + self.inv_t = 1 + + def _build_bases(self, B, S, D, R, cuda=False): + """Build bases in initialization.""" + if cuda: + bases = torch.rand((B * S, D, R)).cuda() + else: + bases = torch.rand((B * S, D, R)) + + bases = F.normalize(bases, dim=1) + + return bases + + def local_step(self, x, bases, coef): + """Local step in iteration to renew bases and coefficient.""" + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + numerator = torch.bmm(x.transpose(1, 2), bases) + # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R) + denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) + # Multiplicative Update + coef = coef * numerator / (denominator + 1e-6) + + # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R) + numerator = torch.bmm(x, coef) + # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R) + denominator = bases.bmm(coef.transpose(1, 2).bmm(coef)) + # Multiplicative Update + bases = bases * numerator / (denominator + 1e-6) + + return bases, coef + + def compute_coef(self, x, bases, coef): + """Compute coefficient.""" + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + numerator = torch.bmm(x.transpose(1, 2), bases) + # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R) + denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) + # multiplication update + coef = coef * numerator / (denominator + 1e-6) + + return coef + + +class Hamburger(nn.Module): + """Hamburger Module. It consists of one slice of "ham" (matrix + decomposition) and two slices of "bread" (linear transformation). + + Args: + ham_channels (int): Input and output channels of feature. + ham_kwargs (dict): Config of matrix decomposition module. + norm_cfg (dict | None): Config of norm layers. + """ + + def __init__(self, + ham_channels=512, + ham_kwargs=dict(), + norm_cfg=None, + **kwargs): + super().__init__() + + self.ham_in = ConvModule( + ham_channels, ham_channels, 1, norm_cfg=None, act_cfg=None) + + self.ham = NMF2D(ham_kwargs) + + self.ham_out = ConvModule( + ham_channels, ham_channels, 1, norm_cfg=norm_cfg, act_cfg=None) + + def forward(self, x): + enjoy = self.ham_in(x) + enjoy = F.relu(enjoy, inplace=True) + enjoy = self.ham(enjoy) + enjoy = self.ham_out(enjoy) + ham = F.relu(x + enjoy, inplace=True) + + return ham + + +@MODELS.register_module() +class LightHamHead(BaseDecodeHead): + """SegNeXt decode head. + + This decode head is the implementation of `SegNeXt: Rethinking + Convolutional Attention Design for Semantic + Segmentation `_. + Inspiration from https://github.com/visual-attention-network/segnext. + + Specifically, LightHamHead is inspired by HamNet from + `Is Attention Better Than Matrix Decomposition? + `. + + Args: + ham_channels (int): input channels for Hamburger. + Defaults: 512. + ham_kwargs (int): kwagrs for Ham. Defaults: dict(). 
+ """ + + def __init__(self, ham_channels=512, ham_kwargs=dict(), **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + self.ham_channels = ham_channels + + self.squeeze = ConvModule( + sum(self.in_channels), + self.ham_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.hamburger = Hamburger(ham_channels, ham_kwargs, **kwargs) + + self.align = ConvModule( + self.ham_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + inputs = self._transform_inputs(inputs) + + inputs = [ + resize( + level, + size=inputs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) for level in inputs + ] + + inputs = torch.cat(inputs, dim=1) + # apply a conv block to squeeze feature map + x = self.squeeze(inputs) + # apply hamburger module + x = self.hamburger(x) + + # apply a conv block to align feature map + output = self.align(x) + output = self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/isa_head.py b/mmseg/models/decode_heads/isa_head.py new file mode 100644 index 0000000000000000000000000000000000000000..355f215f39007d0153c2fdb3b22a40e7f11a01e3 --- /dev/null +++ b/mmseg/models/decode_heads/isa_head.py @@ -0,0 +1,143 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import SelfAttentionBlock as _SelfAttentionBlock +from .decode_head import BaseDecodeHead + + +class SelfAttentionBlock(_SelfAttentionBlock): + """Self-Attention Module. + + Args: + in_channels (int): Input channels of key/query feature. + channels (int): Output channels of key/query transform. + conv_cfg (dict | None): Config of conv layers. + norm_cfg (dict | None): Config of norm layers. + act_cfg (dict | None): Config of activation layers. + """ + + def __init__(self, in_channels, channels, conv_cfg, norm_cfg, act_cfg): + super().__init__( + key_in_channels=in_channels, + query_in_channels=in_channels, + channels=channels, + out_channels=in_channels, + share_key_query=False, + query_downsample=None, + key_downsample=None, + key_query_num_convs=2, + key_query_norm=True, + value_out_num_convs=1, + value_out_norm=False, + matmul_norm=True, + with_out=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.output_project = self.build_project( + in_channels, + in_channels, + num_convs=1, + use_conv_module=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x): + """Forward function.""" + context = super().forward(x, x) + return self.output_project(context) + + +@MODELS.register_module() +class ISAHead(BaseDecodeHead): + """Interlaced Sparse Self-Attention for Semantic Segmentation. + + This head is the implementation of `ISA + `_. + + Args: + isa_channels (int): The channels of ISA Module. + down_factor (tuple[int]): The local group size of ISA. 
+ """ + + def __init__(self, isa_channels, down_factor=(8, 8), **kwargs): + super().__init__(**kwargs) + self.down_factor = down_factor + + self.in_conv = ConvModule( + self.in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.global_relation = SelfAttentionBlock( + self.channels, + isa_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.local_relation = SelfAttentionBlock( + self.channels, + isa_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.out_conv = ConvModule( + self.channels * 2, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x_ = self._transform_inputs(inputs) + x = self.in_conv(x_) + residual = x + + n, c, h, w = x.size() + loc_h, loc_w = self.down_factor # size of local group in H- and W-axes + glb_h, glb_w = math.ceil(h / loc_h), math.ceil(w / loc_w) + pad_h, pad_w = glb_h * loc_h - h, glb_w * loc_w - w + if pad_h > 0 or pad_w > 0: # pad if the size is not divisible + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2) + x = F.pad(x, padding) + + # global relation + x = x.view(n, c, glb_h, loc_h, glb_w, loc_w) + # do permutation to gather global group + x = x.permute(0, 3, 5, 1, 2, 4) # (n, loc_h, loc_w, c, glb_h, glb_w) + x = x.reshape(-1, c, glb_h, glb_w) + # apply attention within each global group + x = self.global_relation(x) # (n * loc_h * loc_w, c, glb_h, glb_w) + + # local relation + x = x.view(n, loc_h, loc_w, c, glb_h, glb_w) + # do permutation to gather local group + x = x.permute(0, 4, 5, 3, 1, 2) # (n, glb_h, glb_w, c, loc_h, loc_w) + x = x.reshape(-1, c, loc_h, loc_w) + # apply attention within each local group + x = self.local_relation(x) # (n * glb_h * glb_w, c, loc_h, loc_w) + + # permute each pixel back to its original position + x = x.view(n, glb_h, glb_w, c, loc_h, loc_w) + x = x.permute(0, 3, 1, 4, 2, 5) # (n, c, glb_h, loc_h, glb_w, loc_w) + x = x.reshape(n, c, glb_h * loc_h, glb_w * loc_w) + if pad_h > 0 or pad_w > 0: # remove padding + x = x[:, :, pad_h // 2:pad_h // 2 + h, pad_w // 2:pad_w // 2 + w] + + x = self.out_conv(torch.cat([x, residual], dim=1)) + out = self.cls_seg(x) + + return out diff --git a/mmseg/models/decode_heads/knet_head.py b/mmseg/models/decode_heads/knet_head.py new file mode 100644 index 0000000000000000000000000000000000000000..82d3a2807685cdc896c881095f46fd50a450018e --- /dev/null +++ b/mmseg/models/decode_heads/knet_head.py @@ -0,0 +1,461 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer +from mmcv.cnn.bricks.transformer import (FFN, MultiheadAttention, + build_transformer_layer) +from mmengine.logging import print_log +from torch import Tensor + +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.registry import MODELS +from mmseg.utils import SampleList + + +@MODELS.register_module() +class KernelUpdator(nn.Module): + """Dynamic Kernel Updator in Kernel Update Head. + + Args: + in_channels (int): The number of channels of input feature map. + Default: 256. + feat_channels (int): The number of middle-stage channels in + the kernel updator. Default: 64. + out_channels (int): The number of output channels. 
+ gate_sigmoid (bool): Whether use sigmoid function in gate + mechanism. Default: True. + gate_norm_act (bool): Whether add normalization and activation + layer in gate mechanism. Default: False. + activate_out: Whether add activation after gate mechanism. + Default: False. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='LN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + """ + + def __init__( + self, + in_channels=256, + feat_channels=64, + out_channels=None, + gate_sigmoid=True, + gate_norm_act=False, + activate_out=False, + norm_cfg=dict(type='LN'), + act_cfg=dict(type='ReLU', inplace=True), + ): + super().__init__() + self.in_channels = in_channels + self.feat_channels = feat_channels + self.out_channels_raw = out_channels + self.gate_sigmoid = gate_sigmoid + self.gate_norm_act = gate_norm_act + self.activate_out = activate_out + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + self.out_channels = out_channels if out_channels else in_channels + + self.num_params_in = self.feat_channels + self.num_params_out = self.feat_channels + self.dynamic_layer = nn.Linear( + self.in_channels, self.num_params_in + self.num_params_out) + self.input_layer = nn.Linear(self.in_channels, + self.num_params_in + self.num_params_out, + 1) + self.input_gate = nn.Linear(self.in_channels, self.feat_channels, 1) + self.update_gate = nn.Linear(self.in_channels, self.feat_channels, 1) + if self.gate_norm_act: + self.gate_norm = build_norm_layer(norm_cfg, self.feat_channels)[1] + + self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.input_norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.input_norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1] + + self.activation = build_activation_layer(act_cfg) + + self.fc_layer = nn.Linear(self.feat_channels, self.out_channels, 1) + self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1] + + def forward(self, update_feature, input_feature): + """Forward function of KernelUpdator. + + Args: + update_feature (torch.Tensor): Feature map assembled from + each group. It would be reshaped with last dimension + shape: `self.in_channels`. + input_feature (torch.Tensor): Intermediate feature + with shape: (N, num_classes, conv_kernel_size**2, channels). + Returns: + Tensor: The output tensor of shape (N*C1/C2, K*K, C2), where N is + the number of classes, C1 and C2 are the feature map channels of + KernelUpdateHead and KernelUpdator, respectively. 
+ """ + + update_feature = update_feature.reshape(-1, self.in_channels) + num_proposals = update_feature.size(0) + # dynamic_layer works for + # phi_1 and psi_3 in Eq.(4) and (5) of K-Net paper + parameters = self.dynamic_layer(update_feature) + param_in = parameters[:, :self.num_params_in].view( + -1, self.feat_channels) + param_out = parameters[:, -self.num_params_out:].view( + -1, self.feat_channels) + + # input_layer works for + # phi_2 and psi_4 in Eq.(4) and (5) of K-Net paper + input_feats = self.input_layer( + input_feature.reshape(num_proposals, -1, self.feat_channels)) + input_in = input_feats[..., :self.num_params_in] + input_out = input_feats[..., -self.num_params_out:] + + # `gate_feats` is F^G in K-Net paper + gate_feats = input_in * param_in.unsqueeze(-2) + if self.gate_norm_act: + gate_feats = self.activation(self.gate_norm(gate_feats)) + + input_gate = self.input_norm_in(self.input_gate(gate_feats)) + update_gate = self.norm_in(self.update_gate(gate_feats)) + if self.gate_sigmoid: + input_gate = input_gate.sigmoid() + update_gate = update_gate.sigmoid() + param_out = self.norm_out(param_out) + input_out = self.input_norm_out(input_out) + + if self.activate_out: + param_out = self.activation(param_out) + input_out = self.activation(input_out) + + # Gate mechanism. Eq.(5) in original paper. + # param_out has shape (batch_size, feat_channels, out_channels) + features = update_gate * param_out.unsqueeze( + -2) + input_gate * input_out + + features = self.fc_layer(features) + features = self.fc_norm(features) + features = self.activation(features) + + return features + + +@MODELS.register_module() +class KernelUpdateHead(nn.Module): + """Kernel Update Head in K-Net. + + Args: + num_classes (int): Number of classes. Default: 150. + num_ffn_fcs (int): The number of fully-connected layers in + FFNs. Default: 2. + num_heads (int): The number of parallel attention heads. + Default: 8. + num_mask_fcs (int): The number of fully connected layers for + mask prediction. Default: 3. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 2048. + in_channels (int): The number of channels of input feature map. + Default: 256. + out_channels (int): The number of output channels. + Default: 256. + dropout (float): The Probability of an element to be + zeroed in MultiheadAttention and FFN. Default 0.0. + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + ffn_act_cfg (dict): Config of activation layers in FFN. + Default: dict(type='ReLU'). + conv_kernel_size (int): The kernel size of convolution in + Kernel Update Head for dynamic kernel updation. + Default: 1. + feat_transform_cfg (dict | None): Config of feature transform. + Default: None. + kernel_init (bool): Whether initiate mask kernel in mask head. + Default: False. + with_ffn (bool): Whether add FFN in kernel update head. + Default: True. + feat_gather_stride (int): Stride of convolution in feature transform. + Default: 1. + mask_transform_stride (int): Stride of mask transform. + Default: 1. + kernel_updator_cfg (dict): Config of kernel updator. + Default: dict( + type='DynamicConv', + in_channels=256, + feat_channels=64, + out_channels=256, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN')). 
+ """ + + def __init__(self, + num_classes=150, + num_ffn_fcs=2, + num_heads=8, + num_mask_fcs=3, + feedforward_channels=2048, + in_channels=256, + out_channels=256, + dropout=0.0, + act_cfg=dict(type='ReLU', inplace=True), + ffn_act_cfg=dict(type='ReLU', inplace=True), + conv_kernel_size=1, + feat_transform_cfg=None, + kernel_init=False, + with_ffn=True, + feat_gather_stride=1, + mask_transform_stride=1, + kernel_updator_cfg=dict( + type='DynamicConv', + in_channels=256, + feat_channels=64, + out_channels=256, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN'))): + super().__init__() + self.num_classes = num_classes + self.in_channels = in_channels + self.out_channels = out_channels + self.fp16_enabled = False + self.dropout = dropout + self.num_heads = num_heads + self.kernel_init = kernel_init + self.with_ffn = with_ffn + self.conv_kernel_size = conv_kernel_size + self.feat_gather_stride = feat_gather_stride + self.mask_transform_stride = mask_transform_stride + + self.attention = MultiheadAttention(in_channels * conv_kernel_size**2, + num_heads, dropout) + self.attention_norm = build_norm_layer( + dict(type='LN'), in_channels * conv_kernel_size**2)[1] + self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg) + + if feat_transform_cfg is not None: + kernel_size = feat_transform_cfg.pop('kernel_size', 1) + transform_channels = in_channels + self.feat_transform = ConvModule( + transform_channels, + in_channels, + kernel_size, + stride=feat_gather_stride, + padding=int(feat_gather_stride // 2), + **feat_transform_cfg) + else: + self.feat_transform = None + + if self.with_ffn: + self.ffn = FFN( + in_channels, + feedforward_channels, + num_ffn_fcs, + act_cfg=ffn_act_cfg, + dropout=dropout) + self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1] + + self.mask_fcs = nn.ModuleList() + for _ in range(num_mask_fcs): + self.mask_fcs.append( + nn.Linear(in_channels, in_channels, bias=False)) + self.mask_fcs.append( + build_norm_layer(dict(type='LN'), in_channels)[1]) + self.mask_fcs.append(build_activation_layer(act_cfg)) + + self.fc_mask = nn.Linear(in_channels, out_channels) + + def init_weights(self): + """Use xavier initialization for all weight parameter and set + classification head bias as a specific value when use focal loss.""" + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + else: + # adopt the default initialization for + # the weight and bias of the layer norm + pass + if self.kernel_init: + print_log( + 'mask kernel in mask head is normal initialized by std 0.01') + nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01) + + def forward(self, x, proposal_feat, mask_preds, mask_shape=None): + """Forward function of Dynamic Instance Interactive Head. + + Args: + x (Tensor): Feature map from FPN with shape + (batch_size, feature_dimensions, H , W). + proposal_feat (Tensor): Intermediate feature get from + diihead in last stage, has shape + (batch_size, num_proposals, feature_dimensions) + mask_preds (Tensor): mask prediction from the former stage in shape + (batch_size, num_proposals, H, W). + + Returns: + Tuple: The first tensor is predicted mask with shape + (N, num_classes, H, W), the second tensor is dynamic kernel + with shape (N, num_classes, channels, K, K). 
+ """ + N, num_proposals = proposal_feat.shape[:2] + if self.feat_transform is not None: + x = self.feat_transform(x) + + C, H, W = x.shape[-3:] + + mask_h, mask_w = mask_preds.shape[-2:] + if mask_h != H or mask_w != W: + gather_mask = F.interpolate( + mask_preds, (H, W), align_corners=False, mode='bilinear') + else: + gather_mask = mask_preds + + sigmoid_masks = gather_mask.softmax(dim=1) + + # Group Feature Assembling. Eq.(3) in original paper. + # einsum is faster than bmm by 30% + x_feat = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x) + + # obj_feat in shape [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C] + proposal_feat = proposal_feat.reshape(N, num_proposals, + self.in_channels, + -1).permute(0, 1, 3, 2) + obj_feat = self.kernel_update_conv(x_feat, proposal_feat) + + # [B, N, K*K, C] -> [B, N, K*K*C] -> [N, B, K*K*C] + obj_feat = obj_feat.reshape(N, num_proposals, -1).permute(1, 0, 2) + obj_feat = self.attention_norm(self.attention(obj_feat)) + # [N, B, K*K*C] -> [B, N, K*K*C] + obj_feat = obj_feat.permute(1, 0, 2) + + # obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C] + obj_feat = obj_feat.reshape(N, num_proposals, -1, self.in_channels) + + # FFN + if self.with_ffn: + obj_feat = self.ffn_norm(self.ffn(obj_feat)) + + mask_feat = obj_feat + + for reg_layer in self.mask_fcs: + mask_feat = reg_layer(mask_feat) + + # [B, N, K*K, C] -> [B, N, C, K*K] + mask_feat = self.fc_mask(mask_feat).permute(0, 1, 3, 2) + + if (self.mask_transform_stride == 2 and self.feat_gather_stride == 1): + mask_x = F.interpolate( + x, scale_factor=0.5, mode='bilinear', align_corners=False) + H, W = mask_x.shape[-2:] + else: + mask_x = x + # group conv is 5x faster than unfold and uses about 1/5 memory + # Group conv vs. unfold vs. concat batch, 2.9ms :13.5ms :3.8ms + # Group conv vs. unfold vs. concat batch, 278 : 1420 : 369 + # but in real training group conv is slower than concat batch + # so we keep using concat batch. + # fold_x = F.unfold( + # mask_x, + # self.conv_kernel_size, + # padding=int(self.conv_kernel_size // 2)) + # mask_feat = mask_feat.reshape(N, num_proposals, -1) + # new_mask_preds = torch.einsum('bnc,bcl->bnl', mask_feat, fold_x) + # [B, N, C, K*K] -> [B*N, C, K, K] + mask_feat = mask_feat.reshape(N, num_proposals, C, + self.conv_kernel_size, + self.conv_kernel_size) + # [B, C, H, W] -> [1, B*C, H, W] + new_mask_preds = [] + for i in range(N): + new_mask_preds.append( + F.conv2d( + mask_x[i:i + 1], + mask_feat[i], + padding=int(self.conv_kernel_size // 2))) + + new_mask_preds = torch.cat(new_mask_preds, dim=0) + new_mask_preds = new_mask_preds.reshape(N, num_proposals, H, W) + if self.mask_transform_stride == 2: + new_mask_preds = F.interpolate( + new_mask_preds, + scale_factor=2, + mode='bilinear', + align_corners=False) + + if mask_shape is not None and mask_shape[0] != H: + new_mask_preds = F.interpolate( + new_mask_preds, + mask_shape, + align_corners=False, + mode='bilinear') + + return new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape( + N, num_proposals, self.in_channels, self.conv_kernel_size, + self.conv_kernel_size) + + +@MODELS.register_module() +class IterativeDecodeHead(BaseDecodeHead): + """K-Net: Towards Unified Image Segmentation. + + This head is the implementation of + `K-Net: `_. + + Args: + num_stages (int): The number of stages (kernel update heads) + in IterativeDecodeHead. Default: 3. + kernel_generate_head:(dict): Config of kernel generate head which + generate mask predictions, dynamic kernels and class predictions + for next kernel update heads. 
+ kernel_update_head (dict): Config of kernel update head which refine + dynamic kernels and class predictions iteratively. + + """ + + def __init__(self, num_stages, kernel_generate_head, kernel_update_head, + **kwargs): + # ``IterativeDecodeHead`` would skip initialization of + # ``BaseDecodeHead`` which would be called when building + # ``self.kernel_generate_head``. + super(BaseDecodeHead, self).__init__(**kwargs) + assert num_stages == len(kernel_update_head) + self.num_stages = num_stages + self.kernel_generate_head = MODELS.build(kernel_generate_head) + self.kernel_update_head = nn.ModuleList() + self.align_corners = self.kernel_generate_head.align_corners + self.num_classes = self.kernel_generate_head.num_classes + self.input_transform = self.kernel_generate_head.input_transform + self.ignore_index = self.kernel_generate_head.ignore_index + self.out_channels = self.num_classes + + for head_cfg in kernel_update_head: + self.kernel_update_head.append(MODELS.build(head_cfg)) + + def forward(self, inputs): + """Forward function.""" + feats = self.kernel_generate_head._forward_feature(inputs) + sem_seg = self.kernel_generate_head.cls_seg(feats) + seg_kernels = self.kernel_generate_head.conv_seg.weight.clone() + seg_kernels = seg_kernels[None].expand( + feats.size(0), *seg_kernels.size()) + + stage_segs = [sem_seg] + for i in range(self.num_stages): + sem_seg, seg_kernels = self.kernel_update_head[i](feats, + seg_kernels, + sem_seg) + stage_segs.append(sem_seg) + if self.training: + return stage_segs + # only return the prediction of the last stage during testing + return stage_segs[-1] + + def loss_by_feat(self, seg_logits: List[Tensor], + batch_data_samples: SampleList, **kwargs) -> dict: + losses = dict() + for i, logit in enumerate(seg_logits): + loss = self.kernel_generate_head.loss_by_feat( + logit, batch_data_samples) + for k, v in loss.items(): + losses[f'{k}.s{i}'] = v + + return losses diff --git a/mmseg/models/decode_heads/lraspp_head.py b/mmseg/models/decode_heads/lraspp_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ba2465f27522e6ff106fcdf94a46aab42881260a --- /dev/null +++ b/mmseg/models/decode_heads/lraspp_head.py @@ -0,0 +1,91 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.utils import is_tuple_of + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead + + +@MODELS.register_module() +class LRASPPHead(BaseDecodeHead): + """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3. + + This head is the improved implementation of `Searching for MobileNetV3 + `_. + + Args: + branch_channels (tuple[int]): The number of output channels in every + each branch. Default: (32, 64). + """ + + def __init__(self, branch_channels=(32, 64), **kwargs): + super().__init__(**kwargs) + if self.input_transform != 'multiple_select': + raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform ' + f'must be \'multiple_select\'. 
But received ' + f'\'{self.input_transform}\'') + assert is_tuple_of(branch_channels, int) + assert len(branch_channels) == len(self.in_channels) - 1 + self.branch_channels = branch_channels + + self.convs = nn.Sequential() + self.conv_ups = nn.Sequential() + for i in range(len(branch_channels)): + self.convs.add_module( + f'conv{i}', + nn.Conv2d( + self.in_channels[i], branch_channels[i], 1, bias=False)) + self.conv_ups.add_module( + f'conv_up{i}', + ConvModule( + self.channels + branch_channels[i], + self.channels, + 1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + bias=False)) + + self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1) + + self.aspp_conv = ConvModule( + self.in_channels[-1], + self.channels, + 1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + bias=False) + self.image_pool = nn.Sequential( + nn.AvgPool2d(kernel_size=49, stride=(16, 20)), + ConvModule( + self.in_channels[2], + self.channels, + 1, + act_cfg=dict(type='Sigmoid'), + bias=False)) + + def forward(self, inputs): + """Forward function.""" + inputs = self._transform_inputs(inputs) + + x = inputs[-1] + + x = self.aspp_conv(x) * resize( + self.image_pool(x), + size=x.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + x = self.conv_up_input(x) + + for i in range(len(self.branch_channels) - 1, -1, -1): + x = resize( + x, + size=inputs[i].size()[2:], + mode='bilinear', + align_corners=self.align_corners) + x = torch.cat([x, self.convs[i](inputs[i])], 1) + x = self.conv_ups[i](x) + + return self.cls_seg(x) diff --git a/mmseg/models/decode_heads/mask2former_head.py b/mmseg/models/decode_heads/mask2former_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0135af0645830f5cf98595318c4bb20220e64b0b --- /dev/null +++ b/mmseg/models/decode_heads/mask2former_head.py @@ -0,0 +1,163 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule + +try: + from mmdet.models.dense_heads import \ + Mask2FormerHead as MMDET_Mask2FormerHead +except ModuleNotFoundError: + MMDET_Mask2FormerHead = BaseModule + +from mmengine.structures import InstanceData +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.structures.seg_data_sample import SegDataSample +from mmseg.utils import ConfigType, SampleList + + +@MODELS.register_module() +class Mask2FormerHead(MMDET_Mask2FormerHead): + """Implements the Mask2Former head. + + See `Mask2Former: Masked-attention Mask Transformer for Universal Image + Segmentation `_ for details. + + Args: + num_classes (int): Number of classes. Default: 150. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + ignore_index (int): The label index to be ignored. Default: 255. + """ + + def __init__(self, + num_classes, + align_corners=False, + ignore_index=255, + **kwargs): + super().__init__(**kwargs) + + self.num_classes = num_classes + self.align_corners = align_corners + self.out_channels = num_classes + self.ignore_index = ignore_index + + feat_channels = kwargs['feat_channels'] + self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) + + def _seg_data_to_instance_data(self, batch_data_samples: SampleList): + """Perform forward propagation to convert paradigm from MMSegmentation + to MMDetection to ensure ``MMDET_Mask2FormerHead`` could be called + normally. Specifically, ``batch_gt_instances`` would be added. 
+ + Args: + batch_data_samples (List[:obj:`SegDataSample`]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + + Returns: + tuple[Tensor]: A tuple contains two lists. + + - batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``labels``, each is + unique ground truth label id of images, with + shape (num_gt, ) and ``masks``, each is ground truth + masks of each instances of a image, shape (num_gt, h, w). + - batch_img_metas (list[dict]): List of image meta information. + """ + batch_img_metas = [] + batch_gt_instances = [] + + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + gt_sem_seg = data_sample.gt_sem_seg.data + classes = torch.unique( + gt_sem_seg, + sorted=False, + return_inverse=False, + return_counts=False) + + # remove ignored region + gt_labels = classes[classes != self.ignore_index] + + masks = [] + for class_id in gt_labels: + masks.append(gt_sem_seg == class_id) + + if len(masks) == 0: + gt_masks = torch.zeros( + (0, gt_sem_seg.shape[-2], + gt_sem_seg.shape[-1])).to(gt_sem_seg).long() + else: + gt_masks = torch.stack(masks).squeeze(1).long() + + instance_data = InstanceData(labels=gt_labels, masks=gt_masks) + batch_gt_instances.append(instance_data) + return batch_gt_instances, batch_img_metas + + def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList, + train_cfg: ConfigType) -> dict: + """Perform forward propagation and loss calculation of the decoder head + on the features of the upstream network. + + Args: + x (tuple[Tensor]): Multi-level features from the upstream + network, each is a 4D-tensor. + batch_data_samples (List[:obj:`SegDataSample`]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + train_cfg (ConfigType): Training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components. + """ + # batch SegDataSample to InstanceDataSample + batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data( + batch_data_samples) + + # forward + all_cls_scores, all_mask_preds = self(x, batch_data_samples) + + # loss + losses = self.loss_by_feat(all_cls_scores, all_mask_preds, + batch_gt_instances, batch_img_metas) + + return losses + + def predict(self, x: Tuple[Tensor], batch_img_metas: List[dict], + test_cfg: ConfigType) -> Tuple[Tensor]: + """Test without augmentaton. + + Args: + x (tuple[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + batch_img_metas (List[:obj:`SegDataSample`]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + test_cfg (ConfigType): Test config. + + Returns: + Tensor: A tensor of segmentation mask. 
+ """ + batch_data_samples = [ + SegDataSample(metainfo=metainfo) for metainfo in batch_img_metas + ] + + all_cls_scores, all_mask_preds = self(x, batch_data_samples) + mask_cls_results = all_cls_scores[-1] + mask_pred_results = all_mask_preds[-1] + if 'pad_shape' in batch_img_metas[0]: + size = batch_img_metas[0]['pad_shape'] + else: + size = batch_img_metas[0]['img_shape'] + # upsample mask + mask_pred_results = F.interpolate( + mask_pred_results, size=size, mode='bilinear', align_corners=False) + cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1] + mask_pred = mask_pred_results.sigmoid() + seg_logits = torch.einsum('bqc, bqhw->bchw', cls_score, mask_pred) + return seg_logits diff --git a/mmseg/models/decode_heads/maskformer_head.py b/mmseg/models/decode_heads/maskformer_head.py new file mode 100644 index 0000000000000000000000000000000000000000..2d7881b5d29f85ba29ae818b25aed84162b912f2 --- /dev/null +++ b/mmseg/models/decode_heads/maskformer_head.py @@ -0,0 +1,178 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule + +try: + from mmdet.models.dense_heads import MaskFormerHead as MMDET_MaskFormerHead +except ModuleNotFoundError: + MMDET_MaskFormerHead = BaseModule + +from mmengine.structures import InstanceData +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.structures.seg_data_sample import SegDataSample +from mmseg.utils import ConfigType, SampleList + + +@MODELS.register_module() +class MaskFormerHead(MMDET_MaskFormerHead): + """Implements the MaskFormer head. + + See `Per-Pixel Classification is Not All You Need for Semantic Segmentation + `_ for details. + + Args: + num_classes (int): Number of classes. Default: 150. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + ignore_index (int): The label index to be ignored. Default: 255. + """ + + def __init__(self, + threshold: float = 0.5, + num_classes: int = 150, + align_corners: bool = False, + ignore_index: int = 255, + **kwargs) -> None: + super().__init__(**kwargs) + self.threshold = threshold + self.out_channels = kwargs['out_channels'] + self.align_corners = True + self.num_classes = num_classes + self.align_corners = align_corners + self.out_channels = num_classes + self.ignore_index = ignore_index + + feat_channels = kwargs['feat_channels'] + self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) + + def _seg_data_to_instance_data(self, batch_data_samples: SampleList): + """Perform forward propagation to convert paradigm from MMSegmentation + to MMDetection to ensure ``MMDET_MaskFormerHead`` could be called + normally. Specifically, ``batch_gt_instances`` would be added. + + Args: + batch_data_samples (List[:obj:`SegDataSample`]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + + Returns: + tuple[Tensor]: A tuple contains two lists. + + - batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``labels``, each is + unique ground truth label id of images, with + shape (num_gt, ) and ``masks``, each is ground truth + masks of each instances of a image, shape (num_gt, h, w). + - batch_img_metas (list[dict]): List of image meta information. + """ + batch_img_metas = [] + batch_gt_instances = [] + for data_sample in batch_data_samples: + # Add `batch_input_shape` in metainfo of data_sample, which would + # be used in MaskFormerHead of MMDetection. 
+ metainfo = data_sample.metainfo + metainfo['batch_input_shape'] = metainfo['img_shape'] + data_sample.set_metainfo(metainfo) + batch_img_metas.append(data_sample.metainfo) + gt_sem_seg = data_sample.gt_sem_seg.data + classes = torch.unique( + gt_sem_seg, + sorted=False, + return_inverse=False, + return_counts=False) + + # remove ignored region + gt_labels = classes[classes != self.ignore_index] + + masks = [] + for class_id in gt_labels: + masks.append(gt_sem_seg == class_id) + + if len(masks) == 0: + gt_masks = torch.zeros((0, gt_sem_seg.shape[-2], + gt_sem_seg.shape[-1])).to(gt_sem_seg) + else: + gt_masks = torch.stack(masks).squeeze(1) + + if hasattr(data_sample, 'instances_data'): + instance_data = InstanceData(labels=data_sample.instances_label, masks=data_sample.instances_data.long()) + else: + instance_data = InstanceData(labels=gt_labels, masks=gt_masks.long()) + + batch_gt_instances.append(instance_data) + return batch_gt_instances, batch_img_metas + + def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList, + train_cfg: ConfigType) -> dict: + """Perform forward propagation and loss calculation of the decoder head + on the features of the upstream network. + + Args: + x (tuple[Tensor]): Multi-level features from the upstream + network, each is a 4D-tensor. + batch_data_samples (List[:obj:`SegDataSample`]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + train_cfg (ConfigType): Training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components. + """ + # batch SegDataSample to InstanceDataSample + batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data( + batch_data_samples) + + # forward + all_cls_scores, all_mask_preds = self(x, batch_data_samples) + + # loss + losses = self.loss_by_feat(all_cls_scores, all_mask_preds, + batch_gt_instances, batch_img_metas) + + return losses + + def predict(self, x: Tuple[Tensor], batch_img_metas: List[dict], + test_cfg: ConfigType) -> Tuple[Tensor]: + """Test without augmentaton. + + Args: + x (tuple[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + batch_img_metas (List[:obj:`SegDataSample`]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + test_cfg (ConfigType): Test config. + + Returns: + Tensor: A tensor of segmentation mask. + """ + + batch_data_samples = [] + for metainfo in batch_img_metas: + metainfo['batch_input_shape'] = metainfo['img_shape'] + batch_data_samples.append(SegDataSample(metainfo=metainfo)) + # Forward function of MaskFormerHead from MMDetection needs + # 'batch_data_samples' as inputs, which is image shape actually. + all_cls_scores, all_mask_preds = self(x, batch_data_samples) + mask_cls_results = all_cls_scores[-1] + mask_pred_results = all_mask_preds[-1] + + # upsample masks + img_shape = batch_img_metas[0]['batch_input_shape'] + mask_pred_results = F.interpolate( + mask_pred_results, + size=img_shape, + mode='bilinear', + align_corners=False) + + # semantic inference + cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1] + mask_pred = mask_pred_results.sigmoid() + seg_logits = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred) + return seg_logits diff --git a/mmseg/models/decode_heads/nl_head.py b/mmseg/models/decode_heads/nl_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0ffcc2a2f081127f109deb0ad5bd1be0d6f50493 --- /dev/null +++ b/mmseg/models/decode_heads/nl_head.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
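The `predict` paths of both `MaskFormerHead` and `Mask2FormerHead` above end with the same mask-classification inference: per-query class probabilities (with the "no object" slot dropped) are combined with per-query mask probabilities via an einsum into dense per-class logits. A shape-only sketch, with all sizes assumed for illustration:

```python
import torch
import torch.nn.functional as F

B, Q, C, H, W = 2, 100, 150, 64, 64            # batch, queries, classes, height, width

mask_cls_results = torch.randn(B, Q, C + 1)     # class scores incl. a "no object" slot
mask_pred_results = torch.randn(B, Q, H, W)     # one mask logit map per query

cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1]   # (B, Q, C), drop "no object"
mask_pred = mask_pred_results.sigmoid()                      # (B, Q, H, W)

# Each pixel's class logit is a query-weighted sum of that query's class score.
seg_logits = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred)
print(seg_logits.shape)  # torch.Size([2, 150, 64, 64])
```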
+import torch +from mmcv.cnn import NonLocal2d + +from mmseg.registry import MODELS +from .fcn_head import FCNHead + + +@MODELS.register_module() +class NLHead(FCNHead): + """Non-local Neural Networks. + + This head is the implementation of `NLNet + `_. + + Args: + reduction (int): Reduction factor of projection transform. Default: 2. + use_scale (bool): Whether to scale pairwise_weight by + sqrt(1/inter_channels). Default: True. + mode (str): The nonlocal mode. Options are 'embedded_gaussian', + 'dot_product'. Default: 'embedded_gaussian.'. + """ + + def __init__(self, + reduction=2, + use_scale=True, + mode='embedded_gaussian', + **kwargs): + super().__init__(num_convs=2, **kwargs) + self.reduction = reduction + self.use_scale = use_scale + self.mode = mode + self.nl_block = NonLocal2d( + in_channels=self.channels, + reduction=self.reduction, + use_scale=self.use_scale, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + mode=self.mode) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + output = self.convs[0](x) + output = self.nl_block(output) + output = self.convs[1](output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/ocr_head.py b/mmseg/models/decode_heads/ocr_head.py new file mode 100644 index 0000000000000000000000000000000000000000..9afe37bebd6c16ff184dc482ae358eb7ae9a093a --- /dev/null +++ b/mmseg/models/decode_heads/ocr_head.py @@ -0,0 +1,127 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import SelfAttentionBlock as _SelfAttentionBlock +from ..utils import resize +from .cascade_decode_head import BaseCascadeDecodeHead + + +class SpatialGatherModule(nn.Module): + """Aggregate the context features according to the initial predicted + probability distribution. + + Employ the soft-weighted method to aggregate the context. 
+ """ + + def __init__(self, scale): + super().__init__() + self.scale = scale + + def forward(self, feats, probs): + """Forward function.""" + batch_size, num_classes, height, width = probs.size() + channels = feats.size(1) + probs = probs.view(batch_size, num_classes, -1) + feats = feats.view(batch_size, channels, -1) + # [batch_size, height*width, num_classes] + feats = feats.permute(0, 2, 1) + # [batch_size, channels, height*width] + probs = F.softmax(self.scale * probs, dim=2) + # [batch_size, channels, num_classes] + ocr_context = torch.matmul(probs, feats) + ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3) + return ocr_context + + +class ObjectAttentionBlock(_SelfAttentionBlock): + """Make a OCR used SelfAttentionBlock.""" + + def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg, + act_cfg): + if scale > 1: + query_downsample = nn.MaxPool2d(kernel_size=scale) + else: + query_downsample = None + super().__init__( + key_in_channels=in_channels, + query_in_channels=in_channels, + channels=channels, + out_channels=in_channels, + share_key_query=False, + query_downsample=query_downsample, + key_downsample=None, + key_query_num_convs=2, + key_query_norm=True, + value_out_num_convs=1, + value_out_norm=True, + matmul_norm=True, + with_out=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.bottleneck = ConvModule( + in_channels * 2, + in_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, query_feats, key_feats): + """Forward function.""" + context = super().forward(query_feats, key_feats) + output = self.bottleneck(torch.cat([context, query_feats], dim=1)) + if self.query_downsample is not None: + output = resize(query_feats) + + return output + + +@MODELS.register_module() +class OCRHead(BaseCascadeDecodeHead): + """Object-Contextual Representations for Semantic Segmentation. + + This head is the implementation of `OCRNet + `_. + + Args: + ocr_channels (int): The intermediate channels of OCR block. + scale (int): The scale of probability map in SpatialGatherModule in + Default: 1. + """ + + def __init__(self, ocr_channels, scale=1, **kwargs): + super().__init__(**kwargs) + self.ocr_channels = ocr_channels + self.scale = scale + self.object_context_block = ObjectAttentionBlock( + self.channels, + self.ocr_channels, + self.scale, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.spatial_gather_module = SpatialGatherModule(self.scale) + + self.bottleneck = ConvModule( + self.in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs, prev_output): + """Forward function.""" + x = self._transform_inputs(inputs) + feats = self.bottleneck(x) + context = self.spatial_gather_module(feats, prev_output) + object_context = self.object_context_block(feats, context) + output = self.cls_seg(object_context) + + return output diff --git a/mmseg/models/decode_heads/pid_head.py b/mmseg/models/decode_heads/pid_head.py new file mode 100644 index 0000000000000000000000000000000000000000..c092cb32d07c279c1d6a45d2e02baccb8e5ffa33 --- /dev/null +++ b/mmseg/models/decode_heads/pid_head.py @@ -0,0 +1,183 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer +from mmengine.model import BaseModule +from torch import Tensor + +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.models.losses import accuracy +from mmseg.models.utils import resize +from mmseg.registry import MODELS +from mmseg.utils import OptConfigType, SampleList + + +class BasePIDHead(BaseModule): + """Base class for PID head. + + Args: + in_channels (int): Number of input channels. + channels (int): Number of output channels. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU', inplace=True). + init_cfg (dict or list[dict], optional): Init config dict. + Default: None. + """ + + def __init__(self, + in_channels: int, + channels: int, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + init_cfg: OptConfigType = None): + super().__init__(init_cfg) + self.conv = ConvModule( + in_channels, + channels, + kernel_size=3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + order=('norm', 'act', 'conv')) + _, self.norm = build_norm_layer(norm_cfg, num_features=channels) + self.act = build_activation_layer(act_cfg) + + def forward(self, x: Tensor, cls_seg: Optional[nn.Module]) -> Tensor: + """Forward function. + Args: + x (Tensor): Input tensor. + cls_seg (nn.Module, optional): The classification head. + + Returns: + Tensor: Output tensor. + """ + x = self.conv(x) + x = self.norm(x) + x = self.act(x) + if cls_seg is not None: + x = cls_seg(x) + return x + + +@MODELS.register_module() +class PIDHead(BaseDecodeHead): + """Decode head for PIDNet. + + Args: + in_channels (int): Number of input channels. + channels (int): Number of output channels. + num_classes (int): Number of classes. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU', inplace=True). + """ + + def __init__(self, + in_channels: int, + channels: int, + num_classes: int, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + **kwargs): + super().__init__( + in_channels, + channels, + num_classes=num_classes, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **kwargs) + self.i_head = BasePIDHead(in_channels, channels, norm_cfg, act_cfg) + self.p_head = BasePIDHead(in_channels // 2, channels, norm_cfg, + act_cfg) + self.d_head = BasePIDHead( + in_channels // 2, + in_channels // 4, + norm_cfg, + ) + self.p_cls_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1) + self.d_cls_seg = nn.Conv2d(in_channels // 4, 1, kernel_size=1) + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward( + self, + inputs: Union[Tensor, + Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]: + """Forward function. + Args: + inputs (Tensor | tuple[Tensor]): Input tensor or tuple of + Tensor. When training, the input is a tuple of three tensors, + (p_feat, i_feat, d_feat), and the output is a tuple of three + tensors, (p_seg_logit, i_seg_logit, d_seg_logit). 
+ When inference, only the head of integral branch is used, and + input is a tensor of integral feature map, and the output is + the segmentation logit. + + Returns: + Tensor | tuple[Tensor]: Output tensor or tuple of tensors. + """ + if self.training: + x_p, x_i, x_d = inputs + x_p = self.p_head(x_p, self.p_cls_seg) + x_i = self.i_head(x_i, self.cls_seg) + x_d = self.d_head(x_d, self.d_cls_seg) + return x_p, x_i, x_d + else: + return self.i_head(inputs, self.cls_seg) + + def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tuple[Tensor]: + gt_semantic_segs = [ + data_sample.gt_sem_seg.data for data_sample in batch_data_samples + ] + gt_edge_segs = [ + data_sample.gt_edge_map.data for data_sample in batch_data_samples + ] + gt_sem_segs = torch.stack(gt_semantic_segs, dim=0) + gt_edge_segs = torch.stack(gt_edge_segs, dim=0) + return gt_sem_segs, gt_edge_segs + + def loss_by_feat(self, seg_logits: Tuple[Tensor], + batch_data_samples: SampleList) -> dict: + loss = dict() + p_logit, i_logit, d_logit = seg_logits + sem_label, bd_label = self._stack_batch_gt(batch_data_samples) + p_logit = resize( + input=p_logit, + size=sem_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + i_logit = resize( + input=i_logit, + size=sem_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + d_logit = resize( + input=d_logit, + size=bd_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + sem_label = sem_label.squeeze(1) + bd_label = bd_label.squeeze(1) + loss['loss_sem_p'] = self.loss_decode[0]( + p_logit, sem_label, ignore_index=self.ignore_index) + loss['loss_sem_i'] = self.loss_decode[1](i_logit, sem_label) + loss['loss_bd'] = self.loss_decode[2](d_logit, bd_label) + filler = torch.ones_like(sem_label) * self.ignore_index + sem_bd_label = torch.where( + torch.sigmoid(d_logit[:, 0, :, :]) > 0.8, sem_label, filler) + loss['loss_sem_bd'] = self.loss_decode[3](i_logit, sem_bd_label) + loss['acc_seg'] = accuracy( + i_logit, sem_label, ignore_index=self.ignore_index) + return loss diff --git a/mmseg/models/decode_heads/point_head.py b/mmseg/models/decode_heads/point_head.py new file mode 100644 index 0000000000000000000000000000000000000000..e8e433d66249a4690cea3e33e95ec54d58ee3a07 --- /dev/null +++ b/mmseg/models/decode_heads/point_head.py @@ -0,0 +1,367 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +try: + from mmcv.ops import point_sample +except ModuleNotFoundError: + point_sample = None + +from typing import List + +from mmseg.registry import MODELS +from mmseg.utils import SampleList +from ..losses import accuracy +from ..utils import resize +from .cascade_decode_head import BaseCascadeDecodeHead + + +def calculate_uncertainty(seg_logits): + """Estimate uncertainty based on seg logits. + + For each location of the prediction ``seg_logits`` we estimate + uncertainty as the difference between top first and top second + predicted logits. + + Args: + seg_logits (Tensor): Semantic segmentation logits, + shape (batch_size, num_classes, height, width). 
+ + Returns: + scores (Tensor): T uncertainty scores with the most uncertain + locations having the highest uncertainty score, shape ( + batch_size, 1, height, width) + """ + top2_scores = torch.topk(seg_logits, k=2, dim=1)[0] + return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1) + + +@MODELS.register_module() +class PointHead(BaseCascadeDecodeHead): + """A mask point head use in PointRend. + + This head is implemented of `PointRend: Image Segmentation as + Rendering `_. + ``PointHead`` use shared multi-layer perceptron (equivalent to + nn.Conv1d) to predict the logit of input points. The fine-grained feature + and coarse feature will be concatenate together for predication. + + Args: + num_fcs (int): Number of fc layers in the head. Default: 3. + in_channels (int): Number of input channels. Default: 256. + fc_channels (int): Number of fc channels. Default: 256. + num_classes (int): Number of classes for logits. Default: 80. + class_agnostic (bool): Whether use class agnostic classification. + If so, the output channels of logits will be 1. Default: False. + coarse_pred_each_layer (bool): Whether concatenate coarse feature with + the output of each fc layer. Default: True. + conv_cfg (dict|None): Dictionary to construct and config conv layer. + Default: dict(type='Conv1d')) + norm_cfg (dict|None): Dictionary to construct and config norm layer. + Default: None. + loss_point (dict): Dictionary to construct and config loss layer of + point head. Default: dict(type='CrossEntropyLoss', use_mask=True, + loss_weight=1.0). + """ + + def __init__(self, + num_fcs=3, + coarse_pred_each_layer=True, + conv_cfg=dict(type='Conv1d'), + norm_cfg=None, + act_cfg=dict(type='ReLU', inplace=False), + **kwargs): + super().__init__( + input_transform='multiple_select', + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + init_cfg=dict( + type='Normal', std=0.01, override=dict(name='fc_seg')), + **kwargs) + if point_sample is None: + raise RuntimeError('Please install mmcv-full for ' + 'point_sample ops') + + self.num_fcs = num_fcs + self.coarse_pred_each_layer = coarse_pred_each_layer + + fc_in_channels = sum(self.in_channels) + self.num_classes + fc_channels = self.channels + self.fcs = nn.ModuleList() + for k in range(num_fcs): + fc = ConvModule( + fc_in_channels, + fc_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.fcs.append(fc) + fc_in_channels = fc_channels + fc_in_channels += self.num_classes if self.coarse_pred_each_layer \ + else 0 + self.fc_seg = nn.Conv1d( + fc_in_channels, + self.num_classes, + kernel_size=1, + stride=1, + padding=0) + if self.dropout_ratio > 0: + self.dropout = nn.Dropout(self.dropout_ratio) + delattr(self, 'conv_seg') + + def cls_seg(self, feat): + """Classify each pixel with fc.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.fc_seg(feat) + return output + + def forward(self, fine_grained_point_feats, coarse_point_feats): + x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1) + for fc in self.fcs: + x = fc(x) + if self.coarse_pred_each_layer: + x = torch.cat((x, coarse_point_feats), dim=1) + return self.cls_seg(x) + + def _get_fine_grained_point_feats(self, x, points): + """Sample from fine grained features. + + Args: + x (list[Tensor]): Feature pyramid from by neck or backbone. + points (Tensor): Point coordinates, shape (batch_size, + num_points, 2). 
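# Channel bookkeeping for the shared point MLP above (illustrative numbers,
# not from any config in this diff): with in_channels=[256], channels=256,
# num_classes=19 and coarse_pred_each_layer=True, the first fc maps
# 256 + 19 -> 256, every later fc maps 256 + 19 -> 256 (the 19 coarse logits
# are re-concatenated after each layer), and fc_seg maps 256 + 19 -> 19
# per-point logits.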
+ + Returns: + fine_grained_feats (Tensor): Sampled fine grained feature, + shape (batch_size, sum(channels of x), num_points). + """ + + fine_grained_feats_list = [ + point_sample(_, points, align_corners=self.align_corners) + for _ in x + ] + if len(fine_grained_feats_list) > 1: + fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1) + else: + fine_grained_feats = fine_grained_feats_list[0] + + return fine_grained_feats + + def _get_coarse_point_feats(self, prev_output, points): + """Sample from fine grained features. + + Args: + prev_output (list[Tensor]): Prediction of previous decode head. + points (Tensor): Point coordinates, shape (batch_size, + num_points, 2). + + Returns: + coarse_feats (Tensor): Sampled coarse feature, shape (batch_size, + num_classes, num_points). + """ + + coarse_feats = point_sample( + prev_output, points, align_corners=self.align_corners) + + return coarse_feats + + def loss(self, inputs, prev_output, batch_data_samples: SampleList, + train_cfg, **kwargs): + """Forward function for training. + Args: + inputs (list[Tensor]): List of multi-level img features. + prev_output (Tensor): The output of previous decode head. + batch_data_samples (list[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `img_metas` or `gt_semantic_seg`. + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + x = self._transform_inputs(inputs) + with torch.no_grad(): + points = self.get_points_train( + prev_output, calculate_uncertainty, cfg=train_cfg) + fine_grained_point_feats = self._get_fine_grained_point_feats( + x, points) + coarse_point_feats = self._get_coarse_point_feats(prev_output, points) + point_logits = self.forward(fine_grained_point_feats, + coarse_point_feats) + + losses = self.loss_by_feat(point_logits, points, batch_data_samples) + + return losses + + def predict(self, inputs, prev_output, batch_img_metas: List[dict], + test_cfg, **kwargs): + """Forward function for testing. + + Args: + inputs (list[Tensor]): List of multi-level img features. + prev_output (Tensor): The output of previous decode head. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Output segmentation map. 
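# A plausible test_cfg for the iterative refinement implemented below; the
# field names follow the attributes read in `predict` and `get_points_test`,
# while the concrete values are illustrative assumptions:
test_cfg = dict(
    subdivision_steps=2,          # upsample-and-refine rounds
    subdivision_num_points=8196,  # most-uncertain points re-predicted per round
    scale_factor=2)               # upsampling factor applied each round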
+ """ + + x = self._transform_inputs(inputs) + refined_seg_logits = prev_output.clone() + for _ in range(test_cfg.subdivision_steps): + refined_seg_logits = resize( + refined_seg_logits, + scale_factor=test_cfg.scale_factor, + mode='bilinear', + align_corners=self.align_corners) + batch_size, channels, height, width = refined_seg_logits.shape + point_indices, points = self.get_points_test( + refined_seg_logits, calculate_uncertainty, cfg=test_cfg) + fine_grained_point_feats = self._get_fine_grained_point_feats( + x, points) + coarse_point_feats = self._get_coarse_point_feats( + prev_output, points) + point_logits = self.forward(fine_grained_point_feats, + coarse_point_feats) + + point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1) + refined_seg_logits = refined_seg_logits.reshape( + batch_size, channels, height * width) + refined_seg_logits = refined_seg_logits.scatter_( + 2, point_indices, point_logits) + refined_seg_logits = refined_seg_logits.view( + batch_size, channels, height, width) + + return self.predict_by_feat(refined_seg_logits, batch_img_metas, + **kwargs) + + def loss_by_feat(self, point_logits, points, batch_data_samples, **kwargs): + """Compute segmentation loss.""" + gt_semantic_seg = self._stack_batch_gt(batch_data_samples) + point_label = point_sample( + gt_semantic_seg.float(), + points, + mode='nearest', + align_corners=self.align_corners) + point_label = point_label.squeeze(1).long() + + loss = dict() + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_module in losses_decode: + loss['point' + loss_module.loss_name] = loss_module( + point_logits, point_label, ignore_index=self.ignore_index) + + loss['acc_point'] = accuracy( + point_logits, point_label, ignore_index=self.ignore_index) + return loss + + def get_points_train(self, seg_logits, uncertainty_func, cfg): + """Sample points for training. + + Sample points in [0, 1] x [0, 1] coordinate space based on their + uncertainty. The uncertainties are calculated for each point using + 'uncertainty_func' function that takes point's logit prediction as + input. + + Args: + seg_logits (Tensor): Semantic segmentation logits, shape ( + batch_size, num_classes, height, width). + uncertainty_func (func): uncertainty calculation function. + cfg (dict): Training config of point head. + + Returns: + point_coords (Tensor): A tensor of shape (batch_size, num_points, + 2) that contains the coordinates of ``num_points`` sampled + points. + """ + num_points = cfg.num_points + oversample_ratio = cfg.oversample_ratio + importance_sample_ratio = cfg.importance_sample_ratio + assert oversample_ratio >= 1 + assert 0 <= importance_sample_ratio <= 1 + batch_size = seg_logits.shape[0] + num_sampled = int(num_points * oversample_ratio) + point_coords = torch.rand( + batch_size, num_sampled, 2, device=seg_logits.device) + point_logits = point_sample(seg_logits, point_coords) + # It is crucial to calculate uncertainty based on the sampled + # prediction value for the points. Calculating uncertainties of the + # coarse predictions first and sampling them for points leads to + # incorrect results. To illustrate this: assume uncertainty func( + # logits)=-abs(logits), a sampled point between two coarse + # predictions with -1 and 1 logits has 0 logits, and therefore 0 + # uncertainty value. However, if we calculate uncertainties for the + # coarse predictions first, both will have -1 uncertainty, + # and sampled point will get -1 uncertainty. 
+ point_uncertainties = uncertainty_func(point_logits) + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + idx = torch.topk( + point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + shift = num_sampled * torch.arange( + batch_size, dtype=torch.long, device=seg_logits.device) + idx += shift[:, None] + point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( + batch_size, num_uncertain_points, 2) + if num_random_points > 0: + rand_point_coords = torch.rand( + batch_size, num_random_points, 2, device=seg_logits.device) + point_coords = torch.cat((point_coords, rand_point_coords), dim=1) + return point_coords + + def get_points_test(self, seg_logits, uncertainty_func, cfg): + """Sample points for testing. + + Find ``num_points`` most uncertain points from ``uncertainty_map``. + + Args: + seg_logits (Tensor): A tensor of shape (batch_size, num_classes, + height, width) for class-specific or class-agnostic prediction. + uncertainty_func (func): uncertainty calculation function. + cfg (dict): Testing config of point head. + + Returns: + point_indices (Tensor): A tensor of shape (batch_size, num_points) + that contains indices from [0, height x width) of the most + uncertain points. + point_coords (Tensor): A tensor of shape (batch_size, num_points, + 2) that contains [0, 1] x [0, 1] normalized coordinates of the + most uncertain points from the ``height x width`` grid . + """ + + num_points = cfg.subdivision_num_points + uncertainty_map = uncertainty_func(seg_logits) + batch_size, _, height, width = uncertainty_map.shape + h_step = 1.0 / height + w_step = 1.0 / width + + uncertainty_map = uncertainty_map.view(batch_size, height * width) + num_points = min(height * width, num_points) + point_indices = uncertainty_map.topk(num_points, dim=1)[1] + point_coords = torch.zeros( + batch_size, + num_points, + 2, + dtype=torch.float, + device=seg_logits.device) + point_coords[:, :, 0] = w_step / 2.0 + (point_indices % + width).float() * w_step + point_coords[:, :, 1] = h_step / 2.0 + (point_indices // + width).float() * h_step + return point_indices, point_coords diff --git a/mmseg/models/decode_heads/psa_head.py b/mmseg/models/decode_heads/psa_head.py new file mode 100644 index 0000000000000000000000000000000000000000..13ee5c58a569bb46612625b85685cd61b7e9df3e --- /dev/null +++ b/mmseg/models/decode_heads/psa_head.py @@ -0,0 +1,197 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead + +try: + from mmcv.ops import PSAMask +except ModuleNotFoundError: + PSAMask = None + + +@MODELS.register_module() +class PSAHead(BaseDecodeHead): + """Point-wise Spatial Attention Network for Scene Parsing. + + This head is the implementation of `PSANet + `_. + + Args: + mask_size (tuple[int]): The PSA mask size. It usually equals input + size. + psa_type (str): The type of psa module. Options are 'collect', + 'distribute', 'bi-direction'. Default: 'bi-direction' + compact (bool): Whether use compact map for 'collect' mode. + Default: True. + shrink_factor (int): The downsample factors of psa mask. Default: 2. + normalization_factor (float): The normalize factor of attention. + psa_softmax (bool): Whether use softmax for attention. 
+ """ + + def __init__(self, + mask_size, + psa_type='bi-direction', + compact=False, + shrink_factor=2, + normalization_factor=1.0, + psa_softmax=True, + **kwargs): + if PSAMask is None: + raise RuntimeError('Please install mmcv-full for PSAMask ops') + super().__init__(**kwargs) + assert psa_type in ['collect', 'distribute', 'bi-direction'] + self.psa_type = psa_type + self.compact = compact + self.shrink_factor = shrink_factor + self.mask_size = mask_size + mask_h, mask_w = mask_size + self.psa_softmax = psa_softmax + if normalization_factor is None: + normalization_factor = mask_h * mask_w + self.normalization_factor = normalization_factor + + self.reduce = ConvModule( + self.in_channels, + self.channels, + kernel_size=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.attention = nn.Sequential( + ConvModule( + self.channels, + self.channels, + kernel_size=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + nn.Conv2d( + self.channels, mask_h * mask_w, kernel_size=1, bias=False)) + if psa_type == 'bi-direction': + self.reduce_p = ConvModule( + self.in_channels, + self.channels, + kernel_size=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.attention_p = nn.Sequential( + ConvModule( + self.channels, + self.channels, + kernel_size=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + nn.Conv2d( + self.channels, mask_h * mask_w, kernel_size=1, bias=False)) + self.psamask_collect = PSAMask('collect', mask_size) + self.psamask_distribute = PSAMask('distribute', mask_size) + else: + self.psamask = PSAMask(psa_type, mask_size) + self.proj = ConvModule( + self.channels * (2 if psa_type == 'bi-direction' else 1), + self.in_channels, + kernel_size=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.bottleneck = ConvModule( + self.in_channels * 2, + self.channels, + kernel_size=3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + identity = x + align_corners = self.align_corners + if self.psa_type in ['collect', 'distribute']: + out = self.reduce(x) + n, c, h, w = out.size() + if self.shrink_factor != 1: + if h % self.shrink_factor and w % self.shrink_factor: + h = (h - 1) // self.shrink_factor + 1 + w = (w - 1) // self.shrink_factor + 1 + align_corners = True + else: + h = h // self.shrink_factor + w = w // self.shrink_factor + align_corners = False + out = resize( + out, + size=(h, w), + mode='bilinear', + align_corners=align_corners) + y = self.attention(out) + if self.compact: + if self.psa_type == 'collect': + y = y.view(n, h * w, + h * w).transpose(1, 2).view(n, h * w, h, w) + else: + y = self.psamask(y) + if self.psa_softmax: + y = F.softmax(y, dim=1) + out = torch.bmm( + out.view(n, c, h * w), y.view(n, h * w, h * w)).view( + n, c, h, w) * (1.0 / self.normalization_factor) + else: + x_col = self.reduce(x) + x_dis = self.reduce_p(x) + n, c, h, w = x_col.size() + if self.shrink_factor != 1: + if h % self.shrink_factor and w % self.shrink_factor: + h = (h - 1) // self.shrink_factor + 1 + w = (w - 1) // self.shrink_factor + 1 + align_corners = True + else: + h = h // self.shrink_factor + w = w // self.shrink_factor + align_corners = False + x_col = resize( + x_col, + size=(h, w), + mode='bilinear', + align_corners=align_corners) + x_dis = resize( + x_dis, + size=(h, w), + 
mode='bilinear', + align_corners=align_corners) + y_col = self.attention(x_col) + y_dis = self.attention_p(x_dis) + if self.compact: + y_dis = y_dis.view(n, h * w, + h * w).transpose(1, 2).view(n, h * w, h, w) + else: + y_col = self.psamask_collect(y_col) + y_dis = self.psamask_distribute(y_dis) + if self.psa_softmax: + y_col = F.softmax(y_col, dim=1) + y_dis = F.softmax(y_dis, dim=1) + x_col = torch.bmm( + x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view( + n, c, h, w) * (1.0 / self.normalization_factor) + x_dis = torch.bmm( + x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view( + n, c, h, w) * (1.0 / self.normalization_factor) + out = torch.cat([x_col, x_dis], 1) + out = self.proj(out) + out = resize( + out, + size=identity.shape[2:], + mode='bilinear', + align_corners=align_corners) + out = self.bottleneck(torch.cat((identity, out), dim=1)) + out = self.cls_seg(out) + return out diff --git a/mmseg/models/decode_heads/psp_head.py b/mmseg/models/decode_heads/psp_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a40ec41dec281e53815e9753ee2ba1a5da76bd05 --- /dev/null +++ b/mmseg/models/decode_heads/psp_head.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead + + +class PPM(nn.ModuleList): + """Pooling Pyramid Module used in PSPNet. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict): Config of activation layers. + align_corners (bool): align_corners argument of F.interpolate. + """ + + def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, + act_cfg, align_corners, **kwargs): + super().__init__() + self.pool_scales = pool_scales + self.align_corners = align_corners + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + for pool_scale in pool_scales: + self.append( + nn.Sequential( + nn.AdaptiveAvgPool2d(pool_scale), + ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + **kwargs))) + + def forward(self, x): + """Forward function.""" + ppm_outs = [] + for ppm in self: + ppm_out = ppm(x) + upsampled_ppm_out = resize( + ppm_out, + size=x.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + ppm_outs.append(upsampled_ppm_out) + return ppm_outs + + +@MODELS.register_module() +class PSPHead(BaseDecodeHead): + """Pyramid Scene Parsing Network. + + This head is the implementation of + `PSPNet `_. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. Default: (1, 2, 3, 6). 
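# Rough shape check for the PPM defined above (standalone sketch; the channel
# counts are arbitrary examples, not from any config in this diff):
import torch
ppm = PPM(pool_scales=(1, 2, 3, 6), in_channels=2048, channels=512,
          conv_cfg=None, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'),
          align_corners=False)
x = torch.rand(2, 2048, 64, 64)
outs = ppm(x)  # four tensors, each (2, 512, 64, 64), resized back to x's spatial size
# PSPHead below concatenates [x] + outs, i.e. 2048 + 4 * 512 channels, into its bottleneck.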
+ """ + + def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): + super().__init__(**kwargs) + assert isinstance(pool_scales, (list, tuple)) + self.pool_scales = pool_scales + self.psp_modules = PPM( + self.pool_scales, + self.in_channels, + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + self.bottleneck = ConvModule( + self.in_channels + len(pool_scales) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + x = self._transform_inputs(inputs) + psp_outs = [x] + psp_outs.extend(self.psp_modules(x)) + psp_outs = torch.cat(psp_outs, dim=1) + feats = self.bottleneck(psp_outs) + return feats + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + output = self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/segformer_head.py b/mmseg/models/decode_heads/segformer_head.py new file mode 100644 index 0000000000000000000000000000000000000000..f9eb0b320b4e7b892e0540cea5ba5ea7054f8008 --- /dev/null +++ b/mmseg/models/decode_heads/segformer_head.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.registry import MODELS +from ..utils import resize + + +@MODELS.register_module() +class SegformerHead(BaseDecodeHead): + """The all mlp Head of segformer. + + This head is the implementation of + `Segformer ` _. + + Args: + interpolate_mode: The interpolate mode of MLP head upsample operation. + Default: 'bilinear'. + """ + + def __init__(self, interpolate_mode='bilinear', **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + + self.interpolate_mode = interpolate_mode + num_inputs = len(self.in_channels) + + assert num_inputs == len(self.in_index) + + self.convs = nn.ModuleList() + for i in range(num_inputs): + self.convs.append( + ConvModule( + in_channels=self.in_channels[i], + out_channels=self.channels, + kernel_size=1, + stride=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + self.fusion_conv = ConvModule( + in_channels=self.channels * num_inputs, + out_channels=self.channels, + kernel_size=1, + norm_cfg=self.norm_cfg) + + def forward(self, inputs): + # Receive 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32 + inputs = self._transform_inputs(inputs) + outs = [] + for idx in range(len(inputs)): + x = inputs[idx] + conv = self.convs[idx] + outs.append( + resize( + input=conv(x), + size=inputs[0].shape[2:], + mode=self.interpolate_mode, + align_corners=self.align_corners)) + + out = self.fusion_conv(torch.cat(outs, dim=1)) + + out = self.cls_seg(out) + + return out diff --git a/mmseg/models/decode_heads/segmenter_mask_head.py b/mmseg/models/decode_heads/segmenter_mask_head.py new file mode 100644 index 0000000000000000000000000000000000000000..85d27735ba8015772324177716b5e8d5f357295c --- /dev/null +++ b/mmseg/models/decode_heads/segmenter_mask_head.py @@ -0,0 +1,132 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_norm_layer +from mmengine.model import ModuleList +from mmengine.model.weight_init import (constant_init, trunc_normal_, + trunc_normal_init) + +from mmseg.models.backbones.vit import TransformerEncoderLayer +from mmseg.registry import MODELS +from .decode_head import BaseDecodeHead + + +@MODELS.register_module() +class SegmenterMaskTransformerHead(BaseDecodeHead): + """Segmenter: Transformer for Semantic Segmentation. + + This head is the implementation of + `Segmenter: `_. + + Args: + backbone_cfg:(dict): Config of backbone of + Context Path. + in_channels (int): The number of channels of input image. + num_layers (int): The depth of transformer. + num_heads (int): The number of attention heads. + embed_dims (int): The number of embedding dimension. + mlp_ratio (int): ratio of mlp hidden dim to embedding dim. + Default: 4. + drop_path_rate (float): stochastic depth rate. Default 0.1. + drop_rate (float): Probability of an element to be zeroed. + Default 0.0 + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): Enable bias for qkv if True. Default: True. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + init_std (float): The value of std in weight initialization. + Default: 0.02. + """ + + def __init__( + self, + in_channels, + num_layers, + num_heads, + embed_dims, + mlp_ratio=4, + drop_path_rate=0.1, + drop_rate=0.0, + attn_drop_rate=0.0, + num_fcs=2, + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_std=0.02, + **kwargs, + ): + super().__init__(in_channels=in_channels, **kwargs) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] + self.layers = ModuleList() + for i in range(num_layers): + self.layers.append( + TransformerEncoderLayer( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=mlp_ratio * embed_dims, + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + drop_path_rate=dpr[i], + num_fcs=num_fcs, + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + batch_first=True, + )) + + self.dec_proj = nn.Linear(in_channels, embed_dims) + + self.cls_emb = nn.Parameter( + torch.randn(1, self.num_classes, embed_dims)) + self.patch_proj = nn.Linear(embed_dims, embed_dims, bias=False) + self.classes_proj = nn.Linear(embed_dims, embed_dims, bias=False) + + self.decoder_norm = build_norm_layer( + norm_cfg, embed_dims, postfix=1)[1] + self.mask_norm = build_norm_layer( + norm_cfg, self.num_classes, postfix=2)[1] + + self.init_std = init_std + + delattr(self, 'conv_seg') + + def init_weights(self): + trunc_normal_(self.cls_emb, std=self.init_std) + trunc_normal_init(self.patch_proj, std=self.init_std) + trunc_normal_init(self.classes_proj, std=self.init_std) + for n, m in self.named_modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=self.init_std, bias=0) + elif isinstance(m, nn.LayerNorm): + constant_init(m, val=1.0, bias=0.0) + + def forward(self, inputs): + x = self._transform_inputs(inputs) + b, c, h, w = x.shape + x = x.permute(0, 2, 3, 1).contiguous().view(b, -1, c) + + x = self.dec_proj(x) + cls_emb = self.cls_emb.expand(x.size(0), -1, -1) + x = torch.cat((x, cls_emb), 1) + for layer in self.layers: + x = layer(x) + x = self.decoder_norm(x) + + patches = 
self.patch_proj(x[:, :-self.num_classes]) + cls_seg_feat = self.classes_proj(x[:, -self.num_classes:]) + + patches = F.normalize(patches, dim=2, p=2) + cls_seg_feat = F.normalize(cls_seg_feat, dim=2, p=2) + + masks = patches @ cls_seg_feat.transpose(1, 2) + masks = self.mask_norm(masks) + masks = masks.permute(0, 2, 1).contiguous().view(b, -1, h, w) + + return masks diff --git a/mmseg/models/decode_heads/sep_aspp_head.py b/mmseg/models/decode_heads/sep_aspp_head.py new file mode 100644 index 0000000000000000000000000000000000000000..9dba68c9ecc6909e47da4f2da6169d529910355d --- /dev/null +++ b/mmseg/models/decode_heads/sep_aspp_head.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule + +from mmseg.registry import MODELS +from ..utils import resize +from .aspp_head import ASPPHead, ASPPModule + + +class DepthwiseSeparableASPPModule(ASPPModule): + """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable + conv.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + for i, dilation in enumerate(self.dilations): + if dilation > 1: + self[i] = DepthwiseSeparableConvModule( + self.in_channels, + self.channels, + 3, + dilation=dilation, + padding=dilation, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + +@MODELS.register_module() +class DepthwiseSeparableASPPHead(ASPPHead): + """Encoder-Decoder with Atrous Separable Convolution for Semantic Image + Segmentation. + + This head is the implementation of `DeepLabV3+ + `_. + + Args: + c1_in_channels (int): The input channels of c1 decoder. If is 0, + the no decoder will be used. + c1_channels (int): The intermediate channels of c1 decoder. + """ + + def __init__(self, c1_in_channels, c1_channels, **kwargs): + super().__init__(**kwargs) + assert c1_in_channels >= 0 + self.aspp_modules = DepthwiseSeparableASPPModule( + dilations=self.dilations, + in_channels=self.in_channels, + channels=self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + if c1_in_channels > 0: + self.c1_bottleneck = ConvModule( + c1_in_channels, + c1_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + else: + self.c1_bottleneck = None + self.sep_bottleneck = nn.Sequential( + DepthwiseSeparableConvModule( + self.channels + c1_channels, + self.channels, + 3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + DepthwiseSeparableConvModule( + self.channels, + self.channels, + 3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + aspp_outs = [ + resize( + self.image_pool(x), + size=x.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + ] + aspp_outs.extend(self.aspp_modules(x)) + aspp_outs = torch.cat(aspp_outs, dim=1) + output = self.bottleneck(aspp_outs) + if self.c1_bottleneck is not None: + c1_output = self.c1_bottleneck(inputs[0]) + output = resize( + input=output, + size=c1_output.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + output = torch.cat([output, c1_output], dim=1) + output = self.sep_bottleneck(output) + output = self.cls_seg(output) + return output diff --git a/mmseg/models/decode_heads/sep_fcn_head.py b/mmseg/models/decode_heads/sep_fcn_head.py new file mode 100644 index 0000000000000000000000000000000000000000..3b15983bceaeff48534bbceedfdf1c434a8d1d1f --- /dev/null +++ 
b/mmseg/models/decode_heads/sep_fcn_head.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import DepthwiseSeparableConvModule + +from mmseg.registry import MODELS +from .fcn_head import FCNHead + + +@MODELS.register_module() +class DepthwiseSeparableFCNHead(FCNHead): + """Depthwise-Separable Fully Convolutional Network for Semantic + Segmentation. + + This head is implemented according to `Fast-SCNN: Fast Semantic + Segmentation Network `_. + + Args: + in_channels(int): Number of output channels of FFM. + channels(int): Number of middle-stage channels in the decode head. + concat_input(bool): Whether to concatenate original decode input into + the result of several consecutive convolution layers. + Default: True. + num_classes(int): Used to determine the dimension of + final prediction tensor. + in_index(int): Correspond with 'out_indices' in FastSCNN backbone. + norm_cfg (dict | None): Config of norm layers. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + loss_decode(dict): Config of loss type and some + relevant additional options. + dw_act_cfg (dict):Activation config of depthwise ConvModule. If it is + 'default', it will be the same as `act_cfg`. Default: None. + """ + + def __init__(self, dw_act_cfg=None, **kwargs): + super().__init__(**kwargs) + self.convs[0] = DepthwiseSeparableConvModule( + self.in_channels, + self.channels, + kernel_size=self.kernel_size, + padding=self.kernel_size // 2, + norm_cfg=self.norm_cfg, + dw_act_cfg=dw_act_cfg) + + for i in range(1, self.num_convs): + self.convs[i] = DepthwiseSeparableConvModule( + self.channels, + self.channels, + kernel_size=self.kernel_size, + padding=self.kernel_size // 2, + norm_cfg=self.norm_cfg, + dw_act_cfg=dw_act_cfg) + + if self.concat_input: + self.conv_cat = DepthwiseSeparableConvModule( + self.in_channels + self.channels, + self.channels, + kernel_size=self.kernel_size, + padding=self.kernel_size // 2, + norm_cfg=self.norm_cfg, + dw_act_cfg=dw_act_cfg) diff --git a/mmseg/models/decode_heads/setr_mla_head.py b/mmseg/models/decode_heads/setr_mla_head.py new file mode 100644 index 0000000000000000000000000000000000000000..1975991a60cc720650b880060efe10753f213131 --- /dev/null +++ b/mmseg/models/decode_heads/setr_mla_head.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import Upsample +from .decode_head import BaseDecodeHead + + +@MODELS.register_module() +class SETRMLAHead(BaseDecodeHead): + """Multi level feature aggretation head of SETR. + + MLA head of `SETR `_. + + Args: + mlahead_channels (int): Channels of conv-conv-4x of multi-level feature + aggregation. Default: 128. + up_scale (int): The scale factor of interpolate. Default:4. 
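# Channel bookkeeping for the MLA head introduced above (illustrative values):
# the up-convolved branches are concatenated before cls_seg (see the forward
# below), so `channels` must equal num_inputs * mla_channels. A hypothetical
# config fragment under that assumption:
decode_head = dict(
    type='SETRMLAHead',
    in_channels=(1024, 1024, 1024, 1024),
    in_index=(0, 1, 2, 3),
    channels=512,      # 4 * mla_channels
    mla_channels=128,
    up_scale=4,
    num_classes=19)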
+ """ + + def __init__(self, mla_channels=128, up_scale=4, **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + self.mla_channels = mla_channels + + num_inputs = len(self.in_channels) + + # Refer to self.cls_seg settings of BaseDecodeHead + assert self.channels == num_inputs * mla_channels + + self.up_convs = nn.ModuleList() + for i in range(num_inputs): + self.up_convs.append( + nn.Sequential( + ConvModule( + in_channels=self.in_channels[i], + out_channels=mla_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + ConvModule( + in_channels=mla_channels, + out_channels=mla_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + Upsample( + scale_factor=up_scale, + mode='bilinear', + align_corners=self.align_corners))) + + def forward(self, inputs): + inputs = self._transform_inputs(inputs) + outs = [] + for x, up_conv in zip(inputs, self.up_convs): + outs.append(up_conv(x)) + out = torch.cat(outs, dim=1) + out = self.cls_seg(out) + return out diff --git a/mmseg/models/decode_heads/setr_up_head.py b/mmseg/models/decode_heads/setr_up_head.py new file mode 100644 index 0000000000000000000000000000000000000000..9c796d8161088c2d7effe17f5ba71e43ff62e50c --- /dev/null +++ b/mmseg/models/decode_heads/setr_up_head.py @@ -0,0 +1,81 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import ConvModule, build_norm_layer + +from mmseg.registry import MODELS +from ..utils import Upsample +from .decode_head import BaseDecodeHead + + +@MODELS.register_module() +class SETRUPHead(BaseDecodeHead): + """Naive upsampling head and Progressive upsampling head of SETR. + + Naive or PUP head of `SETR `_. + + Args: + norm_layer (dict): Config dict for input normalization. + Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True). + num_convs (int): Number of decoder convolutions. Default: 1. + up_scale (int): The scale factor of interpolate. Default:4. + kernel_size (int): The kernel size of convolution when decoding + feature information from backbone. Default: 3. + init_cfg (dict | list[dict] | None): Initialization config dict. + Default: dict( + type='Constant', val=1.0, bias=0, layer='LayerNorm'). + """ + + def __init__(self, + norm_layer=dict(type='LN', eps=1e-6, requires_grad=True), + num_convs=1, + up_scale=4, + kernel_size=3, + init_cfg=[ + dict(type='Constant', val=1.0, bias=0, layer='LayerNorm'), + dict( + type='Normal', + std=0.01, + override=dict(name='conv_seg')) + ], + **kwargs): + + assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.' 
+ + super().__init__(init_cfg=init_cfg, **kwargs) + + assert isinstance(self.in_channels, int) + + _, self.norm = build_norm_layer(norm_layer, self.in_channels) + + self.up_convs = nn.ModuleList() + in_channels = self.in_channels + out_channels = self.channels + for _ in range(num_convs): + self.up_convs.append( + nn.Sequential( + ConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=int(kernel_size - 1) // 2, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + Upsample( + scale_factor=up_scale, + mode='bilinear', + align_corners=self.align_corners))) + in_channels = out_channels + + def forward(self, x): + x = self._transform_inputs(x) + + n, c, h, w = x.shape + x = x.reshape(n, c, h * w).transpose(2, 1).contiguous() + x = self.norm(x) + x = x.transpose(1, 2).reshape(n, c, h, w).contiguous() + + for up_conv in self.up_convs: + x = up_conv(x) + out = self.cls_seg(x) + return out diff --git a/mmseg/models/decode_heads/stdc_head.py b/mmseg/models/decode_heads/stdc_head.py new file mode 100644 index 0000000000000000000000000000000000000000..1c1c21e3083fcb5098d2458e44538c0cf5b8f0e4 --- /dev/null +++ b/mmseg/models/decode_heads/stdc_head.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn.functional as F +from mmengine.structures import PixelData +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.structures import SegDataSample +from mmseg.utils import SampleList +from .fcn_head import FCNHead + + +@MODELS.register_module() +class STDCHead(FCNHead): + """This head is the implementation of `Rethinking BiSeNet For Real-time + Semantic Segmentation `_. + + Args: + boundary_threshold (float): The threshold of calculating boundary. + Default: 0.1. + """ + + def __init__(self, boundary_threshold=0.1, **kwargs): + super().__init__(**kwargs) + self.boundary_threshold = boundary_threshold + # Using register buffer to make laplacian kernel on the same + # device of `seg_label`. + self.register_buffer( + 'laplacian_kernel', + torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1], + dtype=torch.float32, + requires_grad=False).reshape((1, 1, 3, 3))) + self.fusion_kernel = torch.nn.Parameter( + torch.tensor([[6. / 10], [3. / 10], [1. / 10]], + dtype=torch.float32).reshape(1, 3, 1, 1), + requires_grad=False) + + def loss_by_feat(self, seg_logits: Tensor, + batch_data_samples: SampleList) -> dict: + """Compute Detail Aggregation Loss.""" + # Note: The paper claims `fusion_kernel` is a trainable 1x1 conv + # parameters. However, it is a constant in original repo and other + # codebase because it would not be added into computation graph + # after threshold operation. 
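# Tiny illustration of the detail-aggregation target built below (sketch only):
# convolving a label map with the 3x3 Laplacian kernel responds only where the
# labels change, so thresholding yields a binary boundary map.
import torch
import torch.nn.functional as F
lap = torch.tensor([-1., -1., -1., -1., 8., -1., -1., -1., -1.]).reshape(1, 1, 3, 3)
label = torch.zeros(1, 1, 6, 6)
label[..., 3:] = 1.  # vertical edge between columns 2 and 3
edge = F.conv2d(label, lap, padding=1).clamp(min=0)
print((edge > 0.1).float())  # ones along the class boundary (and the zero-padded border of the 1-region)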
+ seg_label = self._stack_batch_gt(batch_data_samples).to( + self.laplacian_kernel) + boundary_targets = F.conv2d( + seg_label, self.laplacian_kernel, padding=1) + boundary_targets = boundary_targets.clamp(min=0) + boundary_targets[boundary_targets > self.boundary_threshold] = 1 + boundary_targets[boundary_targets <= self.boundary_threshold] = 0 + + boundary_targets_x2 = F.conv2d( + seg_label, self.laplacian_kernel, stride=2, padding=1) + boundary_targets_x2 = boundary_targets_x2.clamp(min=0) + + boundary_targets_x4 = F.conv2d( + seg_label, self.laplacian_kernel, stride=4, padding=1) + boundary_targets_x4 = boundary_targets_x4.clamp(min=0) + + boundary_targets_x4_up = F.interpolate( + boundary_targets_x4, boundary_targets.shape[2:], mode='nearest') + boundary_targets_x2_up = F.interpolate( + boundary_targets_x2, boundary_targets.shape[2:], mode='nearest') + + boundary_targets_x2_up[ + boundary_targets_x2_up > self.boundary_threshold] = 1 + boundary_targets_x2_up[ + boundary_targets_x2_up <= self.boundary_threshold] = 0 + + boundary_targets_x4_up[ + boundary_targets_x4_up > self.boundary_threshold] = 1 + boundary_targets_x4_up[ + boundary_targets_x4_up <= self.boundary_threshold] = 0 + + boundary_targets_pyramids = torch.stack( + (boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up), + dim=1) + + boundary_targets_pyramids = boundary_targets_pyramids.squeeze(2) + boudary_targets_pyramid = F.conv2d(boundary_targets_pyramids, + self.fusion_kernel) + + boudary_targets_pyramid[ + boudary_targets_pyramid > self.boundary_threshold] = 1 + boudary_targets_pyramid[ + boudary_targets_pyramid <= self.boundary_threshold] = 0 + + seg_labels = boudary_targets_pyramid.long() + batch_sample_list = [] + for label in seg_labels: + seg_data_sample = SegDataSample() + seg_data_sample.gt_sem_seg = PixelData(data=label) + batch_sample_list.append(seg_data_sample) + + loss = super().loss_by_feat(seg_logits, batch_sample_list) + return loss diff --git a/mmseg/models/decode_heads/uper_head.py b/mmseg/models/decode_heads/uper_head.py new file mode 100644 index 0000000000000000000000000000000000000000..b1ccc3173c0f1193e89ad48861aa7b5ee3b329cc --- /dev/null +++ b/mmseg/models/decode_heads/uper_head.py @@ -0,0 +1,139 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead +from .psp_head import PPM + + +@MODELS.register_module() +class UPerHead(BaseDecodeHead): + """Unified Perceptual Parsing for Scene Understanding. + + This head is the implementation of `UPerNet + `_. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module applied on the last feature. Default: (1, 2, 3, 6). 
+ """ + + def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + # PSP Module + self.psp_modules = PPM( + pool_scales, + self.in_channels[-1], + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + self.bottleneck = ConvModule( + self.in_channels[-1] + len(pool_scales) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + # FPN Module + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + for in_channels in self.in_channels[:-1]: # skip the top layer + l_conv = ConvModule( + in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + inplace=False) + fpn_conv = ConvModule( + self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + inplace=False) + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + self.fpn_bottleneck = ConvModule( + len(self.in_channels) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def psp_forward(self, inputs): + """Forward function of PSP module.""" + x = inputs[-1] + psp_outs = [x] + psp_outs.extend(self.psp_modules(x)) + psp_outs = torch.cat(psp_outs, dim=1) + output = self.bottleneck(psp_outs) + + return output + + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + inputs = self._transform_inputs(inputs) + + # build laterals + laterals = [ + lateral_conv(inputs[i]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + laterals.append(self.psp_forward(inputs)) + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + resize( + laterals[i], + size=prev_shape, + mode='bilinear', + align_corners=self.align_corners) + + # build outputs + fpn_outs = [ + self.fpn_convs[i](laterals[i]) + for i in range(used_backbone_levels - 1) + ] + # append psp feature + fpn_outs.append(laterals[-1]) + + for i in range(used_backbone_levels - 1, 0, -1): + fpn_outs[i] = resize( + fpn_outs[i], + size=fpn_outs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) + fpn_outs = torch.cat(fpn_outs, dim=1) + feats = self.fpn_bottleneck(fpn_outs) + return feats + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + output = self.cls_seg(output) + return output diff --git a/mmseg/models/losses/__init__.py b/mmseg/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2f7e39cb28b4ff8f9db19d3c0663b7a88a65d908 --- /dev/null +++ b/mmseg/models/losses/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .accuracy import Accuracy, accuracy +from .boundary_loss import BoundaryLoss +from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, + cross_entropy, mask_cross_entropy) +from .dice_loss import DiceLoss +from .focal_loss import FocalLoss +from .lovasz_loss import LovaszLoss +from .ohem_cross_entropy_loss import OhemCrossEntropy +from .tversky_loss import TverskyLoss +from .utils import reduce_loss, weight_reduce_loss, weighted_loss + +__all__ = [ + 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', + 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', + 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss', + 'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss' +] diff --git a/mmseg/models/losses/__pycache__/__init__.cpython-310.pyc b/mmseg/models/losses/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..948ea645b7a060c2d8a2edf07d6342ecc3550b11 Binary files /dev/null and b/mmseg/models/losses/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmseg/models/losses/__pycache__/accuracy.cpython-310.pyc b/mmseg/models/losses/__pycache__/accuracy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e59e504fd57e97b345b9eba333f496afea80aacd Binary files /dev/null and b/mmseg/models/losses/__pycache__/accuracy.cpython-310.pyc differ diff --git a/mmseg/models/losses/__pycache__/boundary_loss.cpython-310.pyc b/mmseg/models/losses/__pycache__/boundary_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07b73b1905d31e09a7f0473852340f96f05a347c Binary files /dev/null and b/mmseg/models/losses/__pycache__/boundary_loss.cpython-310.pyc differ diff --git a/mmseg/models/losses/__pycache__/cross_entropy_loss.cpython-310.pyc b/mmseg/models/losses/__pycache__/cross_entropy_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d05a3cd0d17b6f334aa0a19839c0dbe1b3dfa15f Binary files /dev/null and b/mmseg/models/losses/__pycache__/cross_entropy_loss.cpython-310.pyc differ diff --git a/mmseg/models/losses/__pycache__/dice_loss.cpython-310.pyc b/mmseg/models/losses/__pycache__/dice_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6774e62c146f87014b4d5309d8f6fda6d6173e42 Binary files /dev/null and b/mmseg/models/losses/__pycache__/dice_loss.cpython-310.pyc differ diff --git a/mmseg/models/losses/__pycache__/focal_loss.cpython-310.pyc b/mmseg/models/losses/__pycache__/focal_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ceec32ffc4783279bd4d5056821c73febc0b927 Binary files /dev/null and b/mmseg/models/losses/__pycache__/focal_loss.cpython-310.pyc differ diff --git a/mmseg/models/losses/__pycache__/lovasz_loss.cpython-310.pyc b/mmseg/models/losses/__pycache__/lovasz_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22bff29ad2679243c17e1a1190841be5a5b68a77 Binary files /dev/null and b/mmseg/models/losses/__pycache__/lovasz_loss.cpython-310.pyc differ diff --git a/mmseg/models/losses/__pycache__/ohem_cross_entropy_loss.cpython-310.pyc b/mmseg/models/losses/__pycache__/ohem_cross_entropy_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d004da4b4fb67a071af61690293850a22d3083a Binary files /dev/null and b/mmseg/models/losses/__pycache__/ohem_cross_entropy_loss.cpython-310.pyc differ diff --git 
a/mmseg/models/losses/__pycache__/tversky_loss.cpython-310.pyc b/mmseg/models/losses/__pycache__/tversky_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f28e4684f840f078d885fca07449507726503dcd Binary files /dev/null and b/mmseg/models/losses/__pycache__/tversky_loss.cpython-310.pyc differ diff --git a/mmseg/models/losses/__pycache__/utils.cpython-310.pyc b/mmseg/models/losses/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13399d4752695c7fe133e385fb0aabf15c50c37b Binary files /dev/null and b/mmseg/models/losses/__pycache__/utils.cpython-310.pyc differ diff --git a/mmseg/models/losses/accuracy.py b/mmseg/models/losses/accuracy.py new file mode 100644 index 0000000000000000000000000000000000000000..1d9e2d7701088adadd5b6bb71c718c986b87a066 --- /dev/null +++ b/mmseg/models/losses/accuracy.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + + +def accuracy(pred, target, topk=1, thresh=None, ignore_index=None): + """Calculate accuracy according to the prediction and target. + + Args: + pred (torch.Tensor): The model prediction, shape (N, num_class, ...) + target (torch.Tensor): The target of each prediction, shape (N, , ...) + ignore_index (int | None): The label index to be ignored. Default: None + topk (int | tuple[int], optional): If the predictions in ``topk`` + matches the target, the predictions will be regarded as + correct ones. Defaults to 1. + thresh (float, optional): If not None, predictions with scores under + this threshold are considered incorrect. Default to None. + + Returns: + float | tuple[float]: If the input ``topk`` is a single integer, + the function will return a single float as accuracy. If + ``topk`` is a tuple containing multiple integers, the + function will return a tuple containing accuracies of + each ``topk`` number. + """ + assert isinstance(topk, (int, tuple)) + if isinstance(topk, int): + topk = (topk, ) + return_single = True + else: + return_single = False + + maxk = max(topk) + if pred.size(0) == 0: + accu = [pred.new_tensor(0.) for i in range(len(topk))] + return accu[0] if return_single else accu + assert pred.ndim == target.ndim + 1 + assert pred.size(0) == target.size(0) + assert maxk <= pred.size(1), \ + f'maxk {maxk} exceeds pred dimension {pred.size(1)}' + pred_value, pred_label = pred.topk(maxk, dim=1) + # transpose to shape (maxk, N, ...) + pred_label = pred_label.transpose(0, 1) + correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label)) + if thresh is not None: + # Only prediction values larger than thresh are counted as correct + correct = correct & (pred_value > thresh).t() + if ignore_index is not None: + correct = correct[:, target != ignore_index] + res = [] + eps = torch.finfo(torch.float32).eps + for k in topk: + # Avoid causing ZeroDivisionError when all pixels + # of an image are ignored + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps + if ignore_index is not None: + total_num = target[target != ignore_index].numel() + eps + else: + total_num = target.numel() + eps + res.append(correct_k.mul_(100.0 / total_num)) + return res[0] if return_single else res + + +class Accuracy(nn.Module): + """Accuracy calculation module.""" + + def __init__(self, topk=(1, ), thresh=None, ignore_index=None): + """Module to calculate the accuracy. + + Args: + topk (tuple, optional): The criterion used to calculate the + accuracy. Defaults to (1,). 
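# Small sanity check for `accuracy` above (sketch): a 2x2 prediction map with
# one ignored pixel; top-1 accuracy is computed over non-ignored pixels only.
import torch
pred = torch.tensor([[[[2.0, 0.1],
                       [0.2, 0.3]],
                      [[0.5, 1.5],
                       [0.9, 0.4]]]])  # (N=1, num_class=2, H=2, W=2)
target = torch.tensor([[[0, 1],
                        [255, 0]]])    # (N=1, H=2, W=2), 255 is ignored
print(accuracy(pred, target, topk=1, ignore_index=255))  # ~66.67: 2 of the 3 counted pixels are correct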
+ thresh (float, optional): If not None, predictions with scores + under this threshold are considered incorrect. Default to None. + """ + super().__init__() + self.topk = topk + self.thresh = thresh + self.ignore_index = ignore_index + + def forward(self, pred, target): + """Forward function to calculate accuracy. + + Args: + pred (torch.Tensor): Prediction of models. + target (torch.Tensor): Target for each prediction. + + Returns: + tuple[float]: The accuracies under different topk criterions. + """ + return accuracy(pred, target, self.topk, self.thresh, + self.ignore_index) diff --git a/mmseg/models/losses/boundary_loss.py b/mmseg/models/losses/boundary_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..e86b850d87e1d26be8cbb700758dae8dead82c58 --- /dev/null +++ b/mmseg/models/losses/boundary_loss.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from mmseg.registry import MODELS + + +@MODELS.register_module() +class BoundaryLoss(nn.Module): + """Boundary loss. + + This function is modified from + `PIDNet `_. # noqa + Licensed under the MIT License. + + + Args: + loss_weight (float): Weight of the loss. Defaults to 1.0. + loss_name (str): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_boundary'. + """ + + def __init__(self, + loss_weight: float = 1.0, + loss_name: str = 'loss_boundary'): + super().__init__() + self.loss_weight = loss_weight + self.loss_name_ = loss_name + + def forward(self, bd_pre: Tensor, bd_gt: Tensor) -> Tensor: + """Forward function. + Args: + bd_pre (Tensor): Predictions of the boundary head. + bd_gt (Tensor): Ground truth of the boundary. + + Returns: + Tensor: Loss tensor. + """ + log_p = bd_pre.permute(0, 2, 3, 1).contiguous().view(1, -1) + target_t = bd_gt.view(1, -1).float() + + pos_index = (target_t == 1) + neg_index = (target_t == 0) + + weight = torch.zeros_like(log_p) + pos_num = pos_index.sum() + neg_num = neg_index.sum() + sum_num = pos_num + neg_num + weight[pos_index] = neg_num * 1.0 / sum_num + weight[neg_index] = pos_num * 1.0 / sum_num + + loss = F.binary_cross_entropy_with_logits( + log_p, target_t, weight, reduction='mean') + + return self.loss_weight * loss + + @property + def loss_name(self): + return self.loss_name_ diff --git a/mmseg/models/losses/cross_entropy_loss.py b/mmseg/models/losses/cross_entropy_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..780fdda3b95bceecf4e1013c74f20b08c1d032a8 --- /dev/null +++ b/mmseg/models/losses/cross_entropy_loss.py @@ -0,0 +1,297 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmseg.registry import MODELS +from .utils import get_class_weight, weight_reduce_loss + + +def cross_entropy(pred, + label, + weight=None, + class_weight=None, + reduction='mean', + avg_factor=None, + ignore_index=-100, + avg_non_ignore=False): + """cross_entropy. The wrapper function for :func:`F.cross_entropy` + + Args: + pred (torch.Tensor): The prediction with shape (N, 1). + label (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + Default: None. + class_weight (list[float], optional): The weight for each class. + Default: None. + reduction (str, optional): The method used to reduce the loss. 
+            Options are 'none', 'mean' and 'sum'. Default: 'mean'.
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Default: None.
+        ignore_index (int): Specifies a target value that is ignored and
+            does not contribute to the input gradients. When
+            ``avg_non_ignore`` is ``True``, and the ``reduction`` is
+            ``'mean'``, the loss is averaged over non-ignored targets.
+            Default: -100.
+        avg_non_ignore (bool): Whether the loss is only averaged over
+            non-ignored targets. Default: False.
+            `New in version 0.23.0.`
+    """
+
+    # class_weight is a manual rescaling weight given to each class.
+    # If given, it has to be a Tensor of size C.
+    loss = F.cross_entropy(
+        pred,
+        label,
+        weight=class_weight,
+        reduction='none',
+        ignore_index=ignore_index)
+
+    # apply weights and do the reduction
+    # average loss over non-ignored elements
+    # PyTorch's official cross_entropy averages loss over non-ignored elements
+    # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
+    if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
+        avg_factor = label.numel() - (label == ignore_index).sum().item()
+    if weight is not None:
+        weight = weight.float()
+    loss = weight_reduce_loss(
+        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
+
+    return loss
+
+
+def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
+    """Expand onehot labels to match the size of prediction."""
+    bin_labels = labels.new_zeros(target_shape)
+    valid_mask = (labels >= 0) & (labels != ignore_index)
+    inds = torch.nonzero(valid_mask, as_tuple=True)
+
+    if inds[0].numel() > 0:
+        if labels.dim() == 3:
+            bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
+        else:
+            bin_labels[inds[0], labels[valid_mask]] = 1
+
+    valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
+
+    if label_weights is None:
+        bin_label_weights = valid_mask
+    else:
+        bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
+        bin_label_weights = bin_label_weights * valid_mask
+
+    return bin_labels, bin_label_weights, valid_mask
+
+
+def binary_cross_entropy(pred,
+                         label,
+                         weight=None,
+                         reduction='mean',
+                         avg_factor=None,
+                         class_weight=None,
+                         ignore_index=-100,
+                         avg_non_ignore=False,
+                         **kwargs):
+    """Calculate the binary CrossEntropy loss.
+
+    Args:
+        pred (torch.Tensor): The prediction with shape (N, 1).
+        label (torch.Tensor): The learning label of the prediction.
+            Note: In bce loss, label < 0 is invalid.
+        weight (torch.Tensor, optional): Sample-wise loss weight.
+        reduction (str, optional): The method used to reduce the loss.
+            Options are "none", "mean" and "sum".
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Defaults to None.
+        class_weight (list[float], optional): The weight for each class.
+        ignore_index (int): The label index to be ignored. Default: -100.
+        avg_non_ignore (bool): Whether the loss is only averaged over
+            non-ignored targets. Default: False.
+            `New in version 0.23.0.`
+
+    Returns:
+        torch.Tensor: The calculated loss
+    """
+    if pred.size(1) == 1:
+        # For binary class segmentation, the shape of pred is
+        # [N, 1, H, W] and that of label is [N, H, W].
+ # As the ignore_index often set as 255, so the + # binary class label check should mask out + # ignore_index + assert label[label != ignore_index].max() <= 1, \ + 'For pred with shape [N, 1, H, W], its label must have at ' \ + 'most 2 classes' + pred = pred.squeeze(dim=1) + if pred.dim() != label.dim(): + assert (pred.dim() == 2 and label.dim() == 1) or ( + pred.dim() == 4 and label.dim() == 3), \ + 'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \ + 'H, W], label shape [N, H, W] are supported' + # `weight` returned from `_expand_onehot_labels` + # has been treated for valid (non-ignore) pixels + label, weight, valid_mask = _expand_onehot_labels( + label, weight, pred.shape, ignore_index) + else: + # should mask out the ignored elements + valid_mask = ((label >= 0) & (label != ignore_index)).float() + if weight is not None: + weight = weight * valid_mask + else: + weight = valid_mask + # average loss over non-ignored and valid elements + if reduction == 'mean' and avg_factor is None and avg_non_ignore: + avg_factor = valid_mask.sum().item() + + loss = F.binary_cross_entropy_with_logits( + pred, label.float(), pos_weight=class_weight, reduction='none') + # do the reduction for the weighted loss + loss = weight_reduce_loss( + loss, weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def mask_cross_entropy(pred, + target, + label, + reduction='mean', + avg_factor=None, + class_weight=None, + ignore_index=None, + **kwargs): + """Calculate the CrossEntropy loss for masks. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + target (torch.Tensor): The learning label of the prediction. + label (torch.Tensor): ``label`` indicates the class label of the mask' + corresponding object. This will be used to select the mask in the + of the class which the object belongs to when the mask prediction + if not class-agnostic. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (None): Placeholder, to be consistent with other loss. + Default: None. + + Returns: + torch.Tensor: The calculated loss + """ + assert ignore_index is None, 'BCE loss does not support ignore_index' + # TODO: handle these two reserved arguments + assert reduction == 'mean' and avg_factor is None + num_rois = pred.size()[0] + inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device) + pred_slice = pred[inds, label].squeeze(1) + return F.binary_cross_entropy_with_logits( + pred_slice, target, weight=class_weight, reduction='mean')[None] + + +@MODELS.register_module() +class CrossEntropyLoss(nn.Module): + """CrossEntropyLoss. + + Args: + use_sigmoid (bool, optional): Whether the prediction uses sigmoid + of softmax. Defaults to False. + use_mask (bool, optional): Whether to use mask cross entropy loss. + Defaults to False. + reduction (str, optional): . Defaults to 'mean'. + Options are "none", "mean" and "sum". + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Defaults to 1.0. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_ce'. 
+ avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + `New in version 0.23.0.` + """ + + def __init__(self, + use_sigmoid=False, + use_mask=False, + reduction='mean', + class_weight=None, + loss_weight=1.0, + loss_name='loss_ce', + avg_non_ignore=False): + super().__init__() + assert (use_sigmoid is False) or (use_mask is False) + self.use_sigmoid = use_sigmoid + self.use_mask = use_mask + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = get_class_weight(class_weight) + self.avg_non_ignore = avg_non_ignore + if not self.avg_non_ignore and self.reduction == 'mean': + warnings.warn( + 'Default ``avg_non_ignore`` is False, if you would like to ' + 'ignore the certain label and average loss over non-ignore ' + 'labels, which is the same with PyTorch official ' + 'cross_entropy, set ``avg_non_ignore=True``.') + + if self.use_sigmoid: + self.cls_criterion = binary_cross_entropy + elif self.use_mask: + self.cls_criterion = mask_cross_entropy + else: + self.cls_criterion = cross_entropy + self._loss_name = loss_name + + def extra_repr(self): + """Extra repr.""" + s = f'avg_non_ignore={self.avg_non_ignore}' + return s + + def forward(self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + ignore_index=-100, + **kwargs): + """Forward function.""" + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + # Note: for BCE loss, label < 0 is invalid. + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + weight, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + avg_non_ignore=self.avg_non_ignore, + ignore_index=ignore_index, + **kwargs) + return loss_cls + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/mmseg/models/losses/dice_loss.py b/mmseg/models/losses/dice_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..2ee89a81f4e8758913491d9740fae5a9e58f081f --- /dev/null +++ b/mmseg/models/losses/dice_loss.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
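Before the Dice loss below, here is a minimal, hedged sketch of exercising the `CrossEntropyLoss` module defined above on random data; the module path follows the file added in this patch, while the shapes and the 255 ignore value are illustrative assumptions only:

    import torch

    # assumed import path: mmseg/models/losses/cross_entropy_loss.py (this patch)
    from mmseg.models.losses.cross_entropy_loss import CrossEntropyLoss

    logits = torch.randn(2, 19, 64, 64)            # (N, C, H, W) decode-head logits
    labels = torch.randint(0, 19, (2, 64, 64))     # (N, H, W) class indices
    labels[:, :8, :8] = 255                        # illustrative ignored region

    # avg_non_ignore=True averages only over non-ignored pixels, matching
    # the behaviour of PyTorch's own F.cross_entropy.
    loss_fn = CrossEntropyLoss(use_sigmoid=False, avg_non_ignore=True)
    print(loss_fn(logits, labels, ignore_index=255))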
+"""Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/ +segmentron/solver/loss.py (Apache-2.0 License)""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmseg.registry import MODELS +from .utils import get_class_weight, weighted_loss + + +@weighted_loss +def dice_loss(pred, + target, + valid_mask, + smooth=1, + exponent=2, + class_weight=None, + ignore_index=255): + assert pred.shape[0] == target.shape[0] + total_loss = 0 + num_classes = pred.shape[1] + for i in range(num_classes): + if i != ignore_index: + dice_loss = binary_dice_loss( + pred[:, i], + target[..., i], + valid_mask=valid_mask, + smooth=smooth, + exponent=exponent) + if class_weight is not None: + dice_loss *= class_weight[i] + total_loss += dice_loss + return total_loss / num_classes + + +@weighted_loss +def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards): + assert pred.shape[0] == target.shape[0] + pred = pred.reshape(pred.shape[0], -1) + target = target.reshape(target.shape[0], -1) + valid_mask = valid_mask.reshape(valid_mask.shape[0], -1) + + num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth + den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth + + return 1 - num / den + + +@MODELS.register_module() +class DiceLoss(nn.Module): + """DiceLoss. + + This loss is proposed in `V-Net: Fully Convolutional Neural Networks for + Volumetric Medical Image Segmentation `_. + + Args: + smooth (float): A float number to smooth loss, and avoid NaN error. + Default: 1 + exponent (float): An float number to calculate denominator + value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Default to 1.0. + ignore_index (int | None): The label index to be ignored. Default: 255. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_dice'. 
+ """ + + def __init__(self, + smooth=1, + exponent=2, + reduction='mean', + class_weight=None, + loss_weight=1.0, + ignore_index=255, + loss_name='loss_dice', + **kwards): + super().__init__() + self.smooth = smooth + self.exponent = exponent + self.reduction = reduction + self.class_weight = get_class_weight(class_weight) + self.loss_weight = loss_weight + self.ignore_index = ignore_index + self._loss_name = loss_name + + def forward(self, + pred, + target, + avg_factor=None, + reduction_override=None, + **kwards): + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.class_weight is not None: + class_weight = pred.new_tensor(self.class_weight) + else: + class_weight = None + + pred = F.softmax(pred, dim=1) + num_classes = pred.shape[1] + one_hot_target = F.one_hot( + torch.clamp(target.long(), 0, num_classes - 1), + num_classes=num_classes) + valid_mask = (target != self.ignore_index).long() + + loss = self.loss_weight * dice_loss( + pred, + one_hot_target, + valid_mask=valid_mask, + reduction=reduction, + avg_factor=avg_factor, + smooth=self.smooth, + exponent=self.exponent, + class_weight=class_weight, + ignore_index=self.ignore_index) + return loss + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/mmseg/models/losses/focal_loss.py b/mmseg/models/losses/focal_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..104d6602c80b91af58f09963494288098a3b0572 --- /dev/null +++ b/mmseg/models/losses/focal_loss.py @@ -0,0 +1,327 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/open-mmlab/mmdetection +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss + +from mmseg.registry import MODELS +from .utils import weight_reduce_loss + + +# This method is used when cuda is not available +def py_sigmoid_focal_loss(pred, + target, + one_hot_target=None, + weight=None, + gamma=2.0, + alpha=0.5, + class_weight=None, + valid_mask=None, + reduction='mean', + avg_factor=None): + """PyTorch version of `Focal Loss `_. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the + number of classes + target (torch.Tensor): The learning label of the prediction with + shape (N, C) + one_hot_target (None): Placeholder. It should be None. + weight (torch.Tensor, optional): Sample-wise loss weight. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float | list[float], optional): A balanced form for Focal Loss. + Defaults to 0.5. + class_weight (list[float], optional): Weight of each class. + Defaults to None. + valid_mask (torch.Tensor, optional): A mask uses 1 to mark the valid + samples and uses 0 to mark the ignored samples. Default: None. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. 
+ """ + if isinstance(alpha, list): + alpha = pred.new_tensor(alpha) + pred_sigmoid = pred.sigmoid() + target = target.type_as(pred) + one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target) + focal_weight = (alpha * target + (1 - alpha) * + (1 - target)) * one_minus_pt.pow(gamma) + + loss = F.binary_cross_entropy_with_logits( + pred, target, reduction='none') * focal_weight + final_weight = torch.ones(1, pred.size(1)).type_as(loss) + if weight is not None: + if weight.shape != loss.shape and weight.size(0) == loss.size(0): + # For most cases, weight is of shape (N, ), + # which means it does not have the second axis num_class + weight = weight.view(-1, 1) + assert weight.dim() == loss.dim() + final_weight = final_weight * weight + if class_weight is not None: + final_weight = final_weight * pred.new_tensor(class_weight) + if valid_mask is not None: + final_weight = final_weight * valid_mask + loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor) + return loss + + +def sigmoid_focal_loss(pred, + target, + one_hot_target, + weight=None, + gamma=2.0, + alpha=0.5, + class_weight=None, + valid_mask=None, + reduction='mean', + avg_factor=None): + r"""A wrapper of cuda version `Focal Loss + `_. + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + target (torch.Tensor): The learning label of the prediction. It's shape + should be (N, ) + one_hot_target (torch.Tensor): The learning label with shape (N, C) + weight (torch.Tensor, optional): Sample-wise loss weight. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float | list[float], optional): A balanced form for Focal Loss. + Defaults to 0.5. + class_weight (list[float], optional): Weight of each class. + Defaults to None. + valid_mask (torch.Tensor, optional): A mask uses 1 to mark the valid + samples and uses 0 to mark the ignored samples. Default: None. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + # Function.apply does not accept keyword arguments, so the decorator + # "weighted_loss" is not applicable + final_weight = torch.ones(1, pred.size(1)).type_as(pred) + if isinstance(alpha, list): + # _sigmoid_focal_loss doesn't accept alpha of list type. Therefore, if + # a list is given, we set the input alpha as 0.5. This means setting + # equal weight for foreground class and background class. By + # multiplying the loss by 2, the effect of setting alpha as 0.5 is + # undone. The alpha of type list is used to regulate the loss in the + # post-processing process. 
+ loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), + gamma, 0.5, None, 'none') * 2 + alpha = pred.new_tensor(alpha) + final_weight = final_weight * ( + alpha * one_hot_target + (1 - alpha) * (1 - one_hot_target)) + else: + loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), + gamma, alpha, None, 'none') + if weight is not None: + if weight.shape != loss.shape and weight.size(0) == loss.size(0): + # For most cases, weight is of shape (N, ), + # which means it does not have the second axis num_class + weight = weight.view(-1, 1) + assert weight.dim() == loss.dim() + final_weight = final_weight * weight + if class_weight is not None: + final_weight = final_weight * pred.new_tensor(class_weight) + if valid_mask is not None: + final_weight = final_weight * valid_mask + loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor) + return loss + + +@MODELS.register_module() +class FocalLoss(nn.Module): + + def __init__(self, + use_sigmoid=True, + gamma=2.0, + alpha=0.5, + reduction='mean', + class_weight=None, + loss_weight=1.0, + loss_name='loss_focal'): + """`Focal Loss `_ + Args: + use_sigmoid (bool, optional): Whether to the prediction is + used for sigmoid or softmax. Defaults to True. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float | list[float], optional): A balanced form for Focal + Loss. Defaults to 0.5. When a list is provided, the length + of the list should be equal to the number of classes. + Please be careful that this parameter is not the + class-wise weight but the weight of a binary classification + problem. This binary classification problem regards the + pixels which belong to one class as the foreground + and the other pixels as the background, each element in + the list is the weight of the corresponding foreground class. + The value of alpha or each element of alpha should be a float + in the interval [0, 1]. If you want to specify the class-wise + weight, please use `class_weight` parameter. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. Options are "none", "mean" and + "sum". + class_weight (list[float], optional): Weight of each class. + Defaults to None. + loss_weight (float, optional): Weight of loss. Defaults to 1.0. + loss_name (str, optional): Name of the loss item. If you want this + loss item to be included into the backward graph, `loss_` must + be the prefix of the name. Defaults to 'loss_focal'. + """ + super().__init__() + assert use_sigmoid is True, \ + 'AssertionError: Only sigmoid focal loss supported now.' 
+ assert reduction in ('none', 'mean', 'sum'), \ + "AssertionError: reduction should be 'none', 'mean' or " \ + "'sum'" + assert isinstance(alpha, (float, list)), \ + 'AssertionError: alpha should be of type float' + assert isinstance(gamma, float), \ + 'AssertionError: gamma should be of type float' + assert isinstance(loss_weight, float), \ + 'AssertionError: loss_weight should be of type float' + assert isinstance(loss_name, str), \ + 'AssertionError: loss_name should be of type str' + assert isinstance(class_weight, list) or class_weight is None, \ + 'AssertionError: class_weight must be None or of type list' + self.use_sigmoid = use_sigmoid + self.gamma = gamma + self.alpha = alpha + self.reduction = reduction + self.class_weight = class_weight + self.loss_weight = loss_weight + self._loss_name = loss_name + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None, + ignore_index=255, + **kwargs): + """Forward function. + + Args: + pred (torch.Tensor): The prediction with shape + (N, C) where C = number of classes, or + (N, C, d_1, d_2, ..., d_K) with K≥1 in the + case of K-dimensional loss. + target (torch.Tensor): The ground truth. If containing class + indices, shape (N) where each value is 0≤targets[i]≤C−1, + or (N, d_1, d_2, ..., d_K) with K≥1 in the case of + K-dimensional loss. If containing class probabilities, + same shape as the input. + weight (torch.Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to + average the loss. Defaults to None. + reduction_override (str, optional): The reduction method used + to override the original reduction method of the loss. + Options are "none", "mean" and "sum". + ignore_index (int, optional): The label index to be ignored. + Default: 255 + Returns: + torch.Tensor: The calculated loss + """ + assert isinstance(ignore_index, int), \ + 'ignore_index must be of type int' + assert reduction_override in (None, 'none', 'mean', 'sum'), \ + "AssertionError: reduction should be 'none', 'mean' or " \ + "'sum'" + assert pred.shape == target.shape or \ + (pred.size(0) == target.size(0) and + pred.shape[2:] == target.shape[1:]), \ + "The shape of pred doesn't match the shape of target" + + original_shape = pred.shape + + # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k] + pred = pred.transpose(0, 1) + # [C, B, d_1, d_2, ..., d_k] -> [C, N] + pred = pred.reshape(pred.size(0), -1) + # [C, N] -> [N, C] + pred = pred.transpose(0, 1).contiguous() + + if original_shape == target.shape: + # target with shape [B, C, d_1, d_2, ...] + # transform it's shape into [N, C] + # [B, C, d_1, d_2, ...] -> [C, B, d_1, d_2, ..., d_k] + target = target.transpose(0, 1) + # [C, B, d_1, d_2, ..., d_k] -> [C, N] + target = target.reshape(target.size(0), -1) + # [C, N] -> [N, C] + target = target.transpose(0, 1).contiguous() + else: + # target with shape [B, d_1, d_2, ...] 
+ # transform it's shape into [N, ] + target = target.view(-1).contiguous() + valid_mask = (target != ignore_index).view(-1, 1) + # avoid raising error when using F.one_hot() + target = torch.where(target == ignore_index, target.new_tensor(0), + target) + + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.use_sigmoid: + num_classes = pred.size(1) + if torch.cuda.is_available() and pred.is_cuda: + if target.dim() == 1: + one_hot_target = F.one_hot(target, num_classes=num_classes) + else: + one_hot_target = target + target = target.argmax(dim=1) + valid_mask = (target != ignore_index).view(-1, 1) + calculate_loss_func = sigmoid_focal_loss + else: + one_hot_target = None + if target.dim() == 1: + target = F.one_hot(target, num_classes=num_classes) + else: + valid_mask = (target.argmax(dim=1) != ignore_index).view( + -1, 1) + calculate_loss_func = py_sigmoid_focal_loss + + loss_cls = self.loss_weight * calculate_loss_func( + pred, + target, + one_hot_target, + weight, + gamma=self.gamma, + alpha=self.alpha, + class_weight=self.class_weight, + valid_mask=valid_mask, + reduction=reduction, + avg_factor=avg_factor) + + if reduction == 'none': + # [N, C] -> [C, N] + loss_cls = loss_cls.transpose(0, 1) + # [C, N] -> [C, B, d1, d2, ...] + # original_shape: [B, C, d1, d2, ...] + loss_cls = loss_cls.reshape(original_shape[1], + original_shape[0], + *original_shape[2:]) + # [C, B, d1, d2, ...] -> [B, C, d1, d2, ...] + loss_cls = loss_cls.transpose(0, 1).contiguous() + else: + raise NotImplementedError + return loss_cls + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/mmseg/models/losses/lovasz_loss.py b/mmseg/models/losses/lovasz_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b47f9d8a15330a45d0d2d25f3c18d9386e2b335e --- /dev/null +++ b/mmseg/models/losses/lovasz_loss.py @@ -0,0 +1,323 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor +ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim +Berman 2018 ESAT-PSI KU Leuven (MIT License)""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.utils import is_list_of + +from mmseg.registry import MODELS +from .utils import get_class_weight, weight_reduce_loss + + +def lovasz_grad(gt_sorted): + """Computes gradient of the Lovasz extension w.r.t sorted errors. + + See Alg. 1 in paper. + """ + p = len(gt_sorted) + gts = gt_sorted.sum() + intersection = gts - gt_sorted.float().cumsum(0) + union = gts + (1 - gt_sorted).float().cumsum(0) + jaccard = 1. 
- intersection / union + if p > 1: # cover 1-pixel case + jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] + return jaccard + + +def flatten_binary_logits(logits, labels, ignore_index=None): + """Flattens predictions in the batch (binary case) Remove labels equal to + 'ignore_index'.""" + logits = logits.view(-1) + labels = labels.view(-1) + if ignore_index is None: + return logits, labels + valid = (labels != ignore_index) + vlogits = logits[valid] + vlabels = labels[valid] + return vlogits, vlabels + + +def flatten_probs(probs, labels, ignore_index=None): + """Flattens predictions in the batch.""" + if probs.dim() == 3: + # assumes output of a sigmoid layer + B, H, W = probs.size() + probs = probs.view(B, 1, H, W) + B, C, H, W = probs.size() + probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C) # B*H*W, C=P,C + labels = labels.view(-1) + if ignore_index is None: + return probs, labels + valid = (labels != ignore_index) + vprobs = probs[valid.nonzero().squeeze()] + vlabels = labels[valid] + return vprobs, vlabels + + +def lovasz_hinge_flat(logits, labels): + """Binary Lovasz hinge loss. + + Args: + logits (torch.Tensor): [P], logits at each prediction + (between -infty and +infty). + labels (torch.Tensor): [P], binary ground truth labels (0 or 1). + + Returns: + torch.Tensor: The calculated loss. + """ + if len(labels) == 0: + # only void pixels, the gradients should be 0 + return logits.sum() * 0. + signs = 2. * labels.float() - 1. + errors = (1. - logits * signs) + errors_sorted, perm = torch.sort(errors, dim=0, descending=True) + perm = perm.data + gt_sorted = labels[perm] + grad = lovasz_grad(gt_sorted) + loss = torch.dot(F.relu(errors_sorted), grad) + return loss + + +def lovasz_hinge(logits, + labels, + classes='present', + per_image=False, + class_weight=None, + reduction='mean', + avg_factor=None, + ignore_index=255): + """Binary Lovasz hinge loss. + + Args: + logits (torch.Tensor): [B, H, W], logits at each pixel + (between -infty and +infty). + labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1). + classes (str | list[int], optional): Placeholder, to be consistent with + other loss. Default: None. + per_image (bool, optional): If per_image is True, compute the loss per + image instead of per batch. Default: False. + class_weight (list[float], optional): Placeholder, to be consistent + with other loss. Default: None. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. This parameter only works when per_image is True. + Default: None. + ignore_index (int | None): The label index to be ignored. Default: 255. + + Returns: + torch.Tensor: The calculated loss. + """ + if per_image: + loss = [ + lovasz_hinge_flat(*flatten_binary_logits( + logit.unsqueeze(0), label.unsqueeze(0), ignore_index)) + for logit, label in zip(logits, labels) + ] + loss = weight_reduce_loss( + torch.stack(loss), None, reduction, avg_factor) + else: + loss = lovasz_hinge_flat( + *flatten_binary_logits(logits, labels, ignore_index)) + return loss + + +def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None): + """Multi-class Lovasz-Softmax loss. + + Args: + probs (torch.Tensor): [P, C], class probabilities at each prediction + (between 0 and 1). + labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1). 
+ classes (str | list[int], optional): Classes chosen to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Default: 'present'. + class_weight (list[float], optional): The weight for each class. + Default: None. + + Returns: + torch.Tensor: The calculated loss. + """ + if probs.numel() == 0: + # only void pixels, the gradients should be 0 + return probs * 0. + C = probs.size(1) + losses = [] + class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes + for c in class_to_sum: + fg = (labels == c).float() # foreground for class c + if (classes == 'present' and fg.sum() == 0): + continue + if C == 1: + if len(classes) > 1: + raise ValueError('Sigmoid output possible only with 1 class') + class_pred = probs[:, 0] + else: + class_pred = probs[:, c] + errors = (fg - class_pred).abs() + errors_sorted, perm = torch.sort(errors, 0, descending=True) + perm = perm.data + fg_sorted = fg[perm] + loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted)) + if class_weight is not None: + loss *= class_weight[c] + losses.append(loss) + return torch.stack(losses).mean() + + +def lovasz_softmax(probs, + labels, + classes='present', + per_image=False, + class_weight=None, + reduction='mean', + avg_factor=None, + ignore_index=255): + """Multi-class Lovasz-Softmax loss. + + Args: + probs (torch.Tensor): [B, C, H, W], class probabilities at each + prediction (between 0 and 1). + labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and + C - 1). + classes (str | list[int], optional): Classes chosen to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Default: 'present'. + per_image (bool, optional): If per_image is True, compute the loss per + image instead of per batch. Default: False. + class_weight (list[float], optional): The weight for each class. + Default: None. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. This parameter only works when per_image is True. + Default: None. + ignore_index (int | None): The label index to be ignored. Default: 255. + + Returns: + torch.Tensor: The calculated loss. + """ + + if per_image: + loss = [ + lovasz_softmax_flat( + *flatten_probs( + prob.unsqueeze(0), label.unsqueeze(0), ignore_index), + classes=classes, + class_weight=class_weight) + for prob, label in zip(probs, labels) + ] + loss = weight_reduce_loss( + torch.stack(loss), None, reduction, avg_factor) + else: + loss = lovasz_softmax_flat( + *flatten_probs(probs, labels, ignore_index), + classes=classes, + class_weight=class_weight) + return loss + + +@MODELS.register_module() +class LovaszLoss(nn.Module): + """LovaszLoss. + + This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate + for the optimization of the intersection-over-union measure in neural + networks `_. + + Args: + loss_type (str, optional): Binary or multi-class loss. + Default: 'multi_class'. Options are "binary" and "multi_class". + classes (str | list[int], optional): Classes chosen to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Default: 'present'. + per_image (bool, optional): If per_image is True, compute the loss per + image instead of per batch. Default: False. 
+ reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Defaults to 1.0. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_lovasz'. + """ + + def __init__(self, + loss_type='multi_class', + classes='present', + per_image=False, + reduction='mean', + class_weight=None, + loss_weight=1.0, + loss_name='loss_lovasz'): + super().__init__() + assert loss_type in ('binary', 'multi_class'), "loss_type should be \ + 'binary' or 'multi_class'." + + if loss_type == 'binary': + self.cls_criterion = lovasz_hinge + else: + self.cls_criterion = lovasz_softmax + assert classes in ('all', 'present') or is_list_of(classes, int) + if not per_image: + assert reduction == 'none', "reduction should be 'none' when \ + per_image is False." + + self.classes = classes + self.per_image = per_image + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = get_class_weight(class_weight) + self._loss_name = loss_name + + def forward(self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + """Forward function.""" + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + + # if multi-class loss, transform logits to probs + if self.cls_criterion == lovasz_softmax: + cls_score = F.softmax(cls_score, dim=1) + + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + self.classes, + self.per_image, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss_cls + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/mmseg/models/losses/ohem_cross_entropy_loss.py b/mmseg/models/losses/ohem_cross_entropy_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a519b4d84e1dbf86ebc7ad07372ddbdfb0ff3d13 --- /dev/null +++ b/mmseg/models/losses/ohem_cross_entropy_loss.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from mmseg.registry import MODELS + + +@MODELS.register_module() +class OhemCrossEntropy(nn.Module): + """OhemCrossEntropy loss. + + This func is modified from + `PIDNet `_. # noqa + + Licensed under the MIT License. + + Args: + ignore_label (int): Labels to ignore when computing the loss. + Default: 255 + thresh (float, optional): The threshold for hard example selection. + Below which, are prediction with low confidence. If not + specified, the hard examples will be pixels of top ``min_kept`` + loss. Default: 0.7. 
+        min_kept (int, optional): The minimum number of predictions to keep.
+            Default: 100000.
+        loss_weight (float): Weight of the loss. Defaults to 1.0.
+        class_weight (list[float] | str, optional): Weight of each class. If in
+            str format, read them from a file. Defaults to None.
+        loss_name (str): Name of the loss item. If you want this loss
+            item to be included into the backward graph, `loss_` must be the
+            prefix of the name. Defaults to 'loss_ohem'.
+    """
+
+    def __init__(self,
+                 ignore_label: int = 255,
+                 thres: float = 0.7,
+                 min_kept: int = 100000,
+                 loss_weight: float = 1.0,
+                 class_weight: Optional[Union[List[float], str]] = None,
+                 loss_name: str = 'loss_ohem'):
+        super().__init__()
+        self.thresh = thres
+        self.min_kept = max(1, min_kept)
+        self.ignore_label = ignore_label
+        self.loss_weight = loss_weight
+        self.loss_name_ = loss_name
+        self.class_weight = class_weight
+
+    def forward(self, score: Tensor, target: Tensor) -> Tensor:
+        """Forward function.
+        Args:
+            score (Tensor): Predictions of the segmentation head.
+            target (Tensor): Ground truth of the image.
+
+        Returns:
+            Tensor: Loss tensor.
+        """
+        # score: (N, C, H, W)
+        pred = F.softmax(score, dim=1)
+        if self.class_weight is not None:
+            class_weight = score.new_tensor(self.class_weight)
+        else:
+            class_weight = None
+
+        pixel_losses = F.cross_entropy(
+            score,
+            target,
+            weight=class_weight,
+            ignore_index=self.ignore_label,
+            reduction='none').contiguous().view(-1)  # (N*H*W)
+        mask = target.contiguous().view(-1) != self.ignore_label  # (N*H*W)
+
+        tmp_target = target.clone()  # (N, H, W)
+        tmp_target[tmp_target == self.ignore_label] = 0
+        # gather the predicted probability of the (clamped) GT class
+        # pred: (N, C, H, W) -> (N, 1, H, W)
+        pred = pred.gather(1, tmp_target.unsqueeze(1))
+        # flatten, keep valid pixels and sort ascending
+        # pred: (num_valid, ), ind: (num_valid, )
+        pred, ind = pred.contiguous().view(-1, )[mask].contiguous().sort()
+        if pred.numel() > 0:
+            min_value = pred[min(self.min_kept, pred.numel() - 1)]
+        else:
+            return score.new_tensor(0.0)
+        threshold = max(min_value, self.thresh)
+
+        pixel_losses = pixel_losses[mask][ind]
+        pixel_losses = pixel_losses[pred < threshold]
+        return self.loss_weight * pixel_losses.mean()
+
+    @property
+    def loss_name(self):
+        return self.loss_name_
diff --git a/mmseg/models/losses/tversky_loss.py b/mmseg/models/losses/tversky_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfca1af6669e3ac328492da11758a084999ef906
--- /dev/null
+++ b/mmseg/models/losses/tversky_loss.py
@@ -0,0 +1,137 @@
+# Copyright (c) OpenMMLab. All rights reserved.
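Before the Tversky loss below, a quick, hedged sketch of driving the `OhemCrossEntropy` module defined above; the import path follows the file added in this patch, while the shapes and threshold values are illustrative assumptions:

    import torch

    # assumed import path: mmseg/models/losses/ohem_cross_entropy_loss.py (this patch)
    from mmseg.models.losses.ohem_cross_entropy_loss import OhemCrossEntropy

    # Pixels whose predicted GT-class probability falls below the threshold
    # (at least `thres`, possibly raised based on `min_kept`) contribute.
    ohem = OhemCrossEntropy(ignore_label=255, thres=0.7, min_kept=50000)
    score = torch.randn(2, 19, 128, 128)           # decode-head logits
    target = torch.randint(0, 19, (2, 128, 128))   # ground-truth labels
    print(ohem(score, target))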
+"""Modified from +https://github.com/JunMa11/SegLoss/blob/master/losses_pytorch/dice_loss.py#L333 +(Apache-2.0 License)""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..builder import LOSSES +from .utils import get_class_weight, weighted_loss + + +@weighted_loss +def tversky_loss(pred, + target, + valid_mask, + alpha=0.3, + beta=0.7, + smooth=1, + class_weight=None, + ignore_index=255): + assert pred.shape[0] == target.shape[0] + total_loss = 0 + num_classes = pred.shape[1] + for i in range(num_classes): + if i != ignore_index: + tversky_loss = binary_tversky_loss( + pred[:, i], + target[..., i], + valid_mask=valid_mask, + alpha=alpha, + beta=beta, + smooth=smooth) + if class_weight is not None: + tversky_loss *= class_weight[i] + total_loss += tversky_loss + return total_loss / num_classes + + +@weighted_loss +def binary_tversky_loss(pred, + target, + valid_mask, + alpha=0.3, + beta=0.7, + smooth=1): + assert pred.shape[0] == target.shape[0] + pred = pred.reshape(pred.shape[0], -1) + target = target.reshape(target.shape[0], -1) + valid_mask = valid_mask.reshape(valid_mask.shape[0], -1) + + TP = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) + FP = torch.sum(torch.mul(pred, 1 - target) * valid_mask, dim=1) + FN = torch.sum(torch.mul(1 - pred, target) * valid_mask, dim=1) + tversky = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth) + + return 1 - tversky + + +@LOSSES.register_module() +class TverskyLoss(nn.Module): + """TverskyLoss. This loss is proposed in `Tversky loss function for image + segmentation using 3D fully convolutional deep networks. + + `_. + Args: + smooth (float): A float number to smooth loss, and avoid NaN error. + Default: 1. + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Default to 1.0. + ignore_index (int | None): The label index to be ignored. Default: 255. + alpha(float, in [0, 1]): + The coefficient of false positives. Default: 0.3. + beta (float, in [0, 1]): + The coefficient of false negatives. Default: 0.7. + Note: alpha + beta = 1. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_tversky'. + """ + + def __init__(self, + smooth=1, + class_weight=None, + loss_weight=1.0, + ignore_index=255, + alpha=0.3, + beta=0.7, + loss_name='loss_tversky'): + super().__init__() + self.smooth = smooth + self.class_weight = get_class_weight(class_weight) + self.loss_weight = loss_weight + self.ignore_index = ignore_index + assert (alpha + beta == 1.0), 'Sum of alpha and beta but be 1.0!' + self.alpha = alpha + self.beta = beta + self._loss_name = loss_name + + def forward(self, pred, target, **kwargs): + if self.class_weight is not None: + class_weight = pred.new_tensor(self.class_weight) + else: + class_weight = None + + pred = F.softmax(pred, dim=1) + num_classes = pred.shape[1] + one_hot_target = F.one_hot( + torch.clamp(target.long(), 0, num_classes - 1), + num_classes=num_classes) + valid_mask = (target != self.ignore_index).long() + + loss = self.loss_weight * tversky_loss( + pred, + one_hot_target, + valid_mask=valid_mask, + alpha=self.alpha, + beta=self.beta, + smooth=self.smooth, + class_weight=class_weight, + ignore_index=self.ignore_index) + return loss + + @property + def loss_name(self): + """Loss Name. 
+ + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/mmseg/models/losses/utils.py b/mmseg/models/losses/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f74efcf35cec5aba06dcf564e23bde5a4e811423 --- /dev/null +++ b/mmseg/models/losses/utils.py @@ -0,0 +1,126 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools + +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.fileio import load + + +def get_class_weight(class_weight): + """Get class weight for loss function. + + Args: + class_weight (list[float] | str | None): If class_weight is a str, + take it as a file name and read from it. + """ + if isinstance(class_weight, str): + # take it as a file path + if class_weight.endswith('.npy'): + class_weight = np.load(class_weight) + else: + # pkl, json or yaml + class_weight = load(class_weight) + + return class_weight + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are "none", "mean" and "sum". + + Return: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + elif reduction_enum == 2: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. + reduction (str): Same as built-in losses of PyTorch. + avg_factor (float): Average factor when computing the mean of losses. + + Returns: + Tensor: Processed loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + assert weight.dim() == loss.dim() + if weight.dim() > 1: + assert weight.size(1) == 1 or weight.size(1) == loss.size(1) + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + loss = reduce_loss(loss, reduction) + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == 'mean': + # Avoid causing ZeroDivisionError when avg_factor is 0.0, + # i.e., all labels of an image belong to ignore index. + eps = torch.finfo(torch.float32).eps + loss = loss.sum() / (avg_factor + eps) + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != 'none': + raise ValueError('avg_factor can not be used with reduction="sum"') + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + avg_factor=None, **kwargs)`. 
+ + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, avg_factor=2) + tensor(1.5000) + """ + + @functools.wraps(loss_func) + def wrapper(pred, + target, + weight=None, + reduction='mean', + avg_factor=None, + **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + return wrapper diff --git a/mmseg/models/necks/__init__.py b/mmseg/models/necks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ff03186a92b78f942e79cff9eec9f5e2784c359a --- /dev/null +++ b/mmseg/models/necks/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .featurepyramid import Feature2Pyramid +from .fpn import FPN +from .ic_neck import ICNeck +from .jpu import JPU +from .mla_neck import MLANeck +from .multilevel_neck import MultiLevelNeck + +__all__ = [ + 'FPN', 'MultiLevelNeck', 'MLANeck', 'ICNeck', 'JPU', 'Feature2Pyramid' +] diff --git a/mmseg/models/necks/__pycache__/__init__.cpython-310.pyc b/mmseg/models/necks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..994ef29d79eb4449153949cac8dc4c2c49f49114 Binary files /dev/null and b/mmseg/models/necks/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmseg/models/necks/__pycache__/featurepyramid.cpython-310.pyc b/mmseg/models/necks/__pycache__/featurepyramid.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02c2ad60a5975facee56a69f1ddf7006aff3b1f6 Binary files /dev/null and b/mmseg/models/necks/__pycache__/featurepyramid.cpython-310.pyc differ diff --git a/mmseg/models/necks/__pycache__/fpn.cpython-310.pyc b/mmseg/models/necks/__pycache__/fpn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a0b0f7cabac9d2f386491a04951031f1612093b Binary files /dev/null and b/mmseg/models/necks/__pycache__/fpn.cpython-310.pyc differ diff --git a/mmseg/models/necks/__pycache__/ic_neck.cpython-310.pyc b/mmseg/models/necks/__pycache__/ic_neck.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b0796b9ac80b097a235411e36678c4862635534 Binary files /dev/null and b/mmseg/models/necks/__pycache__/ic_neck.cpython-310.pyc differ diff --git a/mmseg/models/necks/__pycache__/jpu.cpython-310.pyc b/mmseg/models/necks/__pycache__/jpu.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63ec2580786a538ce9254404437699b4ffffb44c Binary files /dev/null and b/mmseg/models/necks/__pycache__/jpu.cpython-310.pyc differ diff --git a/mmseg/models/necks/__pycache__/mla_neck.cpython-310.pyc b/mmseg/models/necks/__pycache__/mla_neck.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01ee15d36886aa7d03532aee03af79c6362d15bd Binary files /dev/null and b/mmseg/models/necks/__pycache__/mla_neck.cpython-310.pyc differ diff --git a/mmseg/models/necks/__pycache__/multilevel_neck.cpython-310.pyc b/mmseg/models/necks/__pycache__/multilevel_neck.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..6f39bd021af0a0536f5494d6ee146f6d982a6a3b Binary files /dev/null and b/mmseg/models/necks/__pycache__/multilevel_neck.cpython-310.pyc differ diff --git a/mmseg/models/necks/featurepyramid.py b/mmseg/models/necks/featurepyramid.py new file mode 100644 index 0000000000000000000000000000000000000000..dc1250d39dafcf78880aa282bcba4215520ad94e --- /dev/null +++ b/mmseg/models/necks/featurepyramid.py @@ -0,0 +1,67 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import build_norm_layer + +from mmseg.registry import MODELS + + +@MODELS.register_module() +class Feature2Pyramid(nn.Module): + """Feature2Pyramid. + + A neck structure connect ViT backbone and decoder_heads. + + Args: + embed_dims (int): Embedding dimension. + rescales (list[float]): Different sampling multiples were + used to obtain pyramid features. Default: [4, 2, 1, 0.5]. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='SyncBN', requires_grad=True). + """ + + def __init__(self, + embed_dim, + rescales=[4, 2, 1, 0.5], + norm_cfg=dict(type='SyncBN', requires_grad=True)): + super().__init__() + self.rescales = rescales + self.upsample_4x = None + for k in self.rescales: + if k == 4: + self.upsample_4x = nn.Sequential( + nn.ConvTranspose2d( + embed_dim, embed_dim, kernel_size=2, stride=2), + build_norm_layer(norm_cfg, embed_dim)[1], + nn.GELU(), + nn.ConvTranspose2d( + embed_dim, embed_dim, kernel_size=2, stride=2), + ) + elif k == 2: + self.upsample_2x = nn.Sequential( + nn.ConvTranspose2d( + embed_dim, embed_dim, kernel_size=2, stride=2)) + elif k == 1: + self.identity = nn.Identity() + elif k == 0.5: + self.downsample_2x = nn.MaxPool2d(kernel_size=2, stride=2) + elif k == 0.25: + self.downsample_4x = nn.MaxPool2d(kernel_size=4, stride=4) + else: + raise KeyError(f'invalid {k} for feature2pyramid') + + def forward(self, inputs): + assert len(inputs) == len(self.rescales) + outputs = [] + if self.upsample_4x is not None: + ops = [ + self.upsample_4x, self.upsample_2x, self.identity, + self.downsample_2x + ] + else: + ops = [ + self.upsample_2x, self.identity, self.downsample_2x, + self.downsample_4x + ] + for i in range(len(inputs)): + outputs.append(ops[i](inputs[i])) + return tuple(outputs) diff --git a/mmseg/models/necks/fpn.py b/mmseg/models/necks/fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..ddab74c00a262a89031fda44824c5de0e2e9a362 --- /dev/null +++ b/mmseg/models/necks/fpn.py @@ -0,0 +1,212 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..utils import resize + + +@MODELS.register_module() +class FPN(BaseModule): + """Feature Pyramid Network. + + This neck is the implementation of `Feature Pyramid Networks for Object + Detection `_. + + Args: + in_channels (list[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + num_outs (int): Number of output scales. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Default: 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Default: -1, which means the last level. + add_extra_convs (bool | str): If bool, it decides whether to add conv + layers on top of the original feature maps. Default to False. 
+ If True, its actual mode is specified by `extra_convs_on_inputs`. + If str, it specifies the source feature map of the extra convs. + Only the following options are allowed + + - 'on_input': Last feat map of neck inputs (i.e. backbone feature). + - 'on_lateral': Last feature map after lateral convs. + - 'on_output': The last output feature map after fpn convs. + extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs + on the original feature from the backbone. If True, + it is equivalent to `add_extra_convs='on_input'`. If False, it is + equivalent to set `add_extra_convs='on_output'`. Default to True. + relu_before_extra_convs (bool): Whether to apply relu before the extra + conv. Default: False. + no_norm_on_lateral (bool): Whether to apply norm on lateral. + Default: False. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: None. + upsample_cfg (dict): Config dict for interpolate layer. + Default: dict(mode='nearest'). + init_cfg (dict or list[dict], optional): Initialization config dict. + + Example: + >>> import torch + >>> in_channels = [2, 3, 5, 7] + >>> scales = [340, 170, 84, 43] + >>> inputs = [torch.rand(1, c, s, s) + ... for c, s in zip(in_channels, scales)] + >>> self = FPN(in_channels, 11, len(in_channels)).eval() + >>> outputs = self.forward(inputs) + >>> for i in range(len(outputs)): + ... print(f'outputs[{i}].shape = {outputs[i].shape}') + outputs[0].shape = torch.Size([1, 11, 340, 340]) + outputs[1].shape = torch.Size([1, 11, 170, 170]) + outputs[2].shape = torch.Size([1, 11, 84, 84]) + outputs[3].shape = torch.Size([1, 11, 43, 43]) + """ + + def __init__(self, + in_channels, + out_channels, + num_outs, + start_level=0, + end_level=-1, + add_extra_convs=False, + extra_convs_on_inputs=False, + relu_before_extra_convs=False, + no_norm_on_lateral=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None, + upsample_cfg=dict(mode='nearest'), + init_cfg=dict( + type='Xavier', layer='Conv2d', distribution='uniform')): + super().__init__(init_cfg) + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.relu_before_extra_convs = relu_before_extra_convs + self.no_norm_on_lateral = no_norm_on_lateral + self.fp16_enabled = False + self.upsample_cfg = upsample_cfg.copy() + + if end_level == -1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level < inputs, no extra level is allowed + self.backbone_end_level = end_level + assert end_level <= len(in_channels) + assert num_outs == end_level - start_level + self.start_level = start_level + self.end_level = end_level + self.add_extra_convs = add_extra_convs + assert isinstance(add_extra_convs, (str, bool)) + if isinstance(add_extra_convs, str): + # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' + assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') + elif add_extra_convs: # True + if extra_convs_on_inputs: + # For compatibility with previous release + # TODO: deprecate `extra_convs_on_inputs` + self.add_extra_convs = 'on_input' + else: + self.add_extra_convs = 'on_output' + + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + + for i in range(self.start_level, self.backbone_end_level): + l_conv = ConvModule( + in_channels[i], + 
out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, + act_cfg=act_cfg, + inplace=False) + fpn_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + # add extra conv layers (e.g., RetinaNet) + extra_levels = num_outs - self.backbone_end_level + self.start_level + if self.add_extra_convs and extra_levels >= 1: + for i in range(extra_levels): + if i == 0 and self.add_extra_convs == 'on_input': + in_channels = self.in_channels[self.backbone_end_level - 1] + else: + in_channels = out_channels + extra_fpn_conv = ConvModule( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.fpn_convs.append(extra_fpn_conv) + + def forward(self, inputs): + assert len(inputs) == len(self.in_channels) + + # build laterals + laterals = [ + lateral_conv(inputs[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + # In some cases, fixing `scale factor` (e.g. 2) is preferred, but + # it cannot co-exist with `size` in `F.interpolate`. + if 'scale_factor' in self.upsample_cfg: + laterals[i - 1] = laterals[i - 1] + resize( + laterals[i], **self.upsample_cfg) + else: + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + resize( + laterals[i], size=prev_shape, **self.upsample_cfg) + + # build outputs + # part 1: from original levels + outs = [ + self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) + ] + # part 2: add extra levels + if self.num_outs > len(outs): + # use max pool to get more levels on top of outputs + # (e.g., Faster R-CNN, Mask R-CNN) + if not self.add_extra_convs: + for i in range(self.num_outs - used_backbone_levels): + outs.append(F.max_pool2d(outs[-1], 1, stride=2)) + # add conv layers on top of original feature maps (RetinaNet) + else: + if self.add_extra_convs == 'on_input': + extra_source = inputs[self.backbone_end_level - 1] + elif self.add_extra_convs == 'on_lateral': + extra_source = laterals[-1] + elif self.add_extra_convs == 'on_output': + extra_source = outs[-1] + else: + raise NotImplementedError + outs.append(self.fpn_convs[used_backbone_levels](extra_source)) + for i in range(used_backbone_levels + 1, self.num_outs): + if self.relu_before_extra_convs: + outs.append(self.fpn_convs[i](F.relu(outs[-1]))) + else: + outs.append(self.fpn_convs[i](outs[-1])) + return tuple(outs) diff --git a/mmseg/models/necks/ic_neck.py b/mmseg/models/necks/ic_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..9763541e0980cb0ec53a342b656e64c99d87ed7e --- /dev/null +++ b/mmseg/models/necks/ic_neck.py @@ -0,0 +1,148 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..utils import resize + + +class CascadeFeatureFusion(BaseModule): + """Cascade Feature Fusion Unit in ICNet. + + Args: + low_channels (int): The number of input channels for + low resolution feature map. + high_channels (int): The number of input channels for + high resolution feature map. + out_channels (int): The number of output channels. 
+ conv_cfg (dict): Dictionary to construct and config conv layer. + Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN'). + act_cfg (dict): Dictionary to construct and config act layer. + Default: dict(type='ReLU'). + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + + Returns: + x (Tensor): The output tensor of shape (N, out_channels, H, W). + x_low (Tensor): The output tensor of shape (N, out_channels, H, W) + for Cascade Label Guidance in auxiliary heads. + """ + + def __init__(self, + low_channels, + high_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + align_corners=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.align_corners = align_corners + self.conv_low = ConvModule( + low_channels, + out_channels, + 3, + padding=2, + dilation=2, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.conv_high = ConvModule( + high_channels, + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x_low, x_high): + x_low = resize( + x_low, + size=x_high.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + # Note: Different from original paper, `x_low` is underwent + # `self.conv_low` rather than another 1x1 conv classifier + # before being used for auxiliary head. + x_low = self.conv_low(x_low) + x_high = self.conv_high(x_high) + x = x_low + x_high + x = F.relu(x, inplace=True) + return x, x_low + + +@MODELS.register_module() +class ICNeck(BaseModule): + """ICNet for Real-Time Semantic Segmentation on High-Resolution Images. + + This head is the implementation of `ICHead + `_. + + Args: + in_channels (int): The number of input image channels. Default: 3. + out_channels (int): The numbers of output feature channels. + Default: 128. + conv_cfg (dict): Dictionary to construct and config conv layer. + Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN'). + act_cfg (dict): Dictionary to construct and config act layer. + Default: dict(type='ReLU'). + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels=(64, 256, 256), + out_channels=128, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + align_corners=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + assert len(in_channels) == 3, 'Length of input channels \ + must be 3!' + + self.in_channels = in_channels + self.out_channels = out_channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.align_corners = align_corners + self.cff_24 = CascadeFeatureFusion( + self.in_channels[2], + self.in_channels[1], + self.out_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + + self.cff_12 = CascadeFeatureFusion( + self.out_channels, + self.in_channels[0], + self.out_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + + def forward(self, inputs): + assert len(inputs) == 3, 'Length of input feature \ + maps must be 3!' 
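+ # Comment added for clarity (inferred from the fusion calls below):
+ # `x_sub1` is the highest-resolution branch and `x_sub4` the lowest;
+ # `cff_24` first fuses the two coarser maps, then `cff_12` fuses that
+ # result with the finest one.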
+ + x_sub1, x_sub2, x_sub4 = inputs + x_cff_24, x_24 = self.cff_24(x_sub4, x_sub2) + x_cff_12, x_12 = self.cff_12(x_cff_24, x_sub1) + # Note: `x_cff_12` is used for decode_head, + # `x_24` and `x_12` are used for auxiliary head. + return x_24, x_12, x_cff_12 diff --git a/mmseg/models/necks/jpu.py b/mmseg/models/necks/jpu.py new file mode 100644 index 0000000000000000000000000000000000000000..3ea0fe2183377d3e3c1a87ca8a0df909b123cdfa --- /dev/null +++ b/mmseg/models/necks/jpu.py @@ -0,0 +1,131 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..utils import resize + + +@MODELS.register_module() +class JPU(BaseModule): + """FastFCN: Rethinking Dilated Convolution in the Backbone + for Semantic Segmentation. + + This Joint Pyramid Upsampling (JPU) neck is the implementation of + `FastFCN `_. + + Args: + in_channels (Tuple[int], optional): The number of input channels + for each convolution operations before upsampling. + Default: (512, 1024, 2048). + mid_channels (int): The number of output channels of JPU. + Default: 512. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Default: 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Default: -1, which means the last level. + dilations (tuple[int]): Dilation rate of each Depthwise + Separable ConvModule. Default: (1, 2, 4, 8). + align_corners (bool, optional): The align_corners argument of + resize operation. Default: False. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. 
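+ Example:
+ A minimal shape-check sketch; the input sizes below are
+ illustrative assumptions, not values prescribed by FastFCN.
+ >>> import torch
+ >>> inputs = [torch.rand(1, c, s, s)
+ ... for c, s in zip((512, 1024, 2048), (64, 32, 16))]
+ >>> neck = JPU()
+ >>> outs = neck(inputs)
+ >>> # outs[:-1] are the untouched backbone maps; outs[-1] is the joint
+ >>> # upsampled feature with mid_channels * len(dilations) channels at
+ >>> # the resolution of the first input.
+ >>> outs[-1].shape
+ torch.Size([1, 2048, 64, 64])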
+ """ + + def __init__(self, + in_channels=(512, 1024, 2048), + mid_channels=512, + start_level=0, + end_level=-1, + dilations=(1, 2, 4, 8), + align_corners=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + assert isinstance(in_channels, tuple) + assert isinstance(dilations, tuple) + self.in_channels = in_channels + self.mid_channels = mid_channels + self.start_level = start_level + self.num_ins = len(in_channels) + if end_level == -1: + self.backbone_end_level = self.num_ins + else: + self.backbone_end_level = end_level + assert end_level <= len(in_channels) + + self.dilations = dilations + self.align_corners = align_corners + + self.conv_layers = nn.ModuleList() + self.dilation_layers = nn.ModuleList() + for i in range(self.start_level, self.backbone_end_level): + conv_layer = nn.Sequential( + ConvModule( + self.in_channels[i], + self.mid_channels, + kernel_size=3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.conv_layers.append(conv_layer) + for i in range(len(dilations)): + dilation_layer = nn.Sequential( + DepthwiseSeparableConvModule( + in_channels=(self.backbone_end_level - self.start_level) * + self.mid_channels, + out_channels=self.mid_channels, + kernel_size=3, + stride=1, + padding=dilations[i], + dilation=dilations[i], + dw_norm_cfg=norm_cfg, + dw_act_cfg=None, + pw_norm_cfg=norm_cfg, + pw_act_cfg=act_cfg)) + self.dilation_layers.append(dilation_layer) + + def forward(self, inputs): + """Forward function.""" + assert len(inputs) == len(self.in_channels), 'Length of inputs must \ + be the same with self.in_channels!' + + feats = [ + self.conv_layers[i - self.start_level](inputs[i]) + for i in range(self.start_level, self.backbone_end_level) + ] + + h, w = feats[0].shape[2:] + for i in range(1, len(feats)): + feats[i] = resize( + feats[i], + size=(h, w), + mode='bilinear', + align_corners=self.align_corners) + + feat = torch.cat(feats, dim=1) + concat_feat = torch.cat([ + self.dilation_layers[i](feat) for i in range(len(self.dilations)) + ], + dim=1) + + outs = [] + + # Default: outs[2] is the output of JPU for decoder head, outs[1] is + # the feature map from backbone for auxiliary head. Additionally, + # outs[0] can also be used for auxiliary head. + for i in range(self.start_level, self.backbone_end_level - 1): + outs.append(inputs[i]) + outs.append(concat_feat) + return tuple(outs) diff --git a/mmseg/models/necks/mla_neck.py b/mmseg/models/necks/mla_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..db250aefbfa45beaa98855be79ddc7f5e7276cca --- /dev/null +++ b/mmseg/models/necks/mla_neck.py @@ -0,0 +1,118 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch.nn as nn +from mmcv.cnn import ConvModule, build_norm_layer + +from mmseg.registry import MODELS + + +class MLAModule(nn.Module): + + def __init__(self, + in_channels=[1024, 1024, 1024, 1024], + out_channels=256, + norm_cfg=None, + act_cfg=None): + super().__init__() + self.channel_proj = nn.ModuleList() + for i in range(len(in_channels)): + self.channel_proj.append( + ConvModule( + in_channels=in_channels[i], + out_channels=out_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.feat_extract = nn.ModuleList() + for i in range(len(in_channels)): + self.feat_extract.append( + ConvModule( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, inputs): + + # feat_list -> [p2, p3, p4, p5] + feat_list = [] + for x, conv in zip(inputs, self.channel_proj): + feat_list.append(conv(x)) + + # feat_list -> [p5, p4, p3, p2] + # mid_list -> [m5, m4, m3, m2] + feat_list = feat_list[::-1] + mid_list = [] + for feat in feat_list: + if len(mid_list) == 0: + mid_list.append(feat) + else: + mid_list.append(mid_list[-1] + feat) + + # mid_list -> [m5, m4, m3, m2] + # out_list -> [o2, o3, o4, o5] + out_list = [] + for mid, conv in zip(mid_list, self.feat_extract): + out_list.append(conv(mid)) + + return tuple(out_list) + + +@MODELS.register_module() +class MLANeck(nn.Module): + """Multi-level Feature Aggregation. + + This neck is `The Multi-level Feature Aggregation construction of + SETR `_. + + + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + norm_layer (dict): Config dict for input normalization. + Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True). + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + norm_layer=dict(type='LN', eps=1e-6, requires_grad=True), + norm_cfg=None, + act_cfg=None): + super().__init__() + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + + # In order to build general vision transformer backbone, we have to + # move MLA to neck. + self.norm = nn.ModuleList([ + build_norm_layer(norm_layer, in_channels[i])[1] + for i in range(len(in_channels)) + ]) + + self.mla = MLAModule( + in_channels=in_channels, + out_channels=out_channels, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, inputs): + assert len(inputs) == len(self.in_channels) + + # Convert from nchw to nlc + outs = [] + for i in range(len(inputs)): + x = inputs[i] + n, c, h, w = x.shape + x = x.reshape(n, c, h * w).transpose(2, 1).contiguous() + x = self.norm[i](x) + x = x.transpose(1, 2).reshape(n, c, h, w).contiguous() + outs.append(x) + + outs = self.mla(outs) + return tuple(outs) diff --git a/mmseg/models/necks/multilevel_neck.py b/mmseg/models/necks/multilevel_neck.py new file mode 100644 index 0000000000000000000000000000000000000000..c997125f24791b1c01248c60a27fa37a986c6c82 --- /dev/null +++ b/mmseg/models/necks/multilevel_neck.py @@ -0,0 +1,79 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model.weight_init import xavier_init + +from mmseg.registry import MODELS +from ..utils import resize + + +@MODELS.register_module() +class MultiLevelNeck(nn.Module): + """MultiLevelNeck. 
+ + A neck structure connect vit backbone and decoder_heads. + + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + scales (List[float]): Scale factors for each input feature map. + Default: [0.5, 1, 2, 4] + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + scales=[0.5, 1, 2, 4], + norm_cfg=None, + act_cfg=None): + super().__init__() + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.scales = scales + self.num_outs = len(scales) + self.lateral_convs = nn.ModuleList() + self.convs = nn.ModuleList() + for in_channel in in_channels: + self.lateral_convs.append( + ConvModule( + in_channel, + out_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + for _ in range(self.num_outs): + self.convs.append( + ConvModule( + out_channels, + out_channels, + kernel_size=3, + padding=1, + stride=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + # default init_weights for conv(msra) and norm in ConvModule + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + xavier_init(m, distribution='uniform') + + def forward(self, inputs): + assert len(inputs) == len(self.in_channels) + inputs = [ + lateral_conv(inputs[i]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + # for len(inputs) not equal to self.num_outs + if len(inputs) == 1: + inputs = [inputs[0] for _ in range(self.num_outs)] + outs = [] + for i in range(self.num_outs): + x_resize = resize( + inputs[i], scale_factor=self.scales[i], mode='bilinear') + outs.append(self.convs[i](x_resize)) + return tuple(outs) diff --git a/mmseg/models/segmentors/__init__.py b/mmseg/models/segmentors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fec0d52c3a43cec4dab46080e5e01f83f06c3d27 --- /dev/null +++ b/mmseg/models/segmentors/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .base import BaseSegmentor +from .cascade_encoder_decoder import CascadeEncoderDecoder +from .encoder_decoder import EncoderDecoder +from .seg_tta import SegTTAModel + +__all__ = [ + 'BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder', 'SegTTAModel' +] diff --git a/mmseg/models/segmentors/__pycache__/__init__.cpython-310.pyc b/mmseg/models/segmentors/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cddd0c33d991ec1a66953dda11aaaf22ba9ceaed Binary files /dev/null and b/mmseg/models/segmentors/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmseg/models/segmentors/__pycache__/base.cpython-310.pyc b/mmseg/models/segmentors/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb934587e28913796fe654e7fd02ab027dd33a22 Binary files /dev/null and b/mmseg/models/segmentors/__pycache__/base.cpython-310.pyc differ diff --git a/mmseg/models/segmentors/__pycache__/cascade_encoder_decoder.cpython-310.pyc b/mmseg/models/segmentors/__pycache__/cascade_encoder_decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..deb64559c91783b67000e348e89300275177a333 Binary files /dev/null and b/mmseg/models/segmentors/__pycache__/cascade_encoder_decoder.cpython-310.pyc differ diff --git a/mmseg/models/segmentors/__pycache__/encoder_decoder.cpython-310.pyc b/mmseg/models/segmentors/__pycache__/encoder_decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f395fa234564af0bffc8e76a4e641610984aca1 Binary files /dev/null and b/mmseg/models/segmentors/__pycache__/encoder_decoder.cpython-310.pyc differ diff --git a/mmseg/models/segmentors/__pycache__/seg_tta.cpython-310.pyc b/mmseg/models/segmentors/__pycache__/seg_tta.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03aa704f0368631c50ecab2730692cf347606b55 Binary files /dev/null and b/mmseg/models/segmentors/__pycache__/seg_tta.cpython-310.pyc differ diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py new file mode 100644 index 0000000000000000000000000000000000000000..25487de5ab8aa975a18d376b312972090ffa8f24 --- /dev/null +++ b/mmseg/models/segmentors/base.py @@ -0,0 +1,199 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import List, Tuple + +from mmengine.model import BaseModel +from mmengine.structures import PixelData +from torch import Tensor + +from mmseg.structures import SegDataSample +from mmseg.utils import (ForwardResults, OptConfigType, OptMultiConfig, + OptSampleList, SampleList) +from ..utils import resize + + +class BaseSegmentor(BaseModel, metaclass=ABCMeta): + """Base class for segmentors. + + Args: + data_preprocessor (dict, optional): Model preprocessing config + for processing the input data. it usually includes + ``to_rgb``, ``pad_size_divisor``, ``pad_val``, + ``mean`` and ``std``. Default to None. + init_cfg (dict, optional): the config to control the + initialization. Default to None. 
+ """ + + def __init__(self, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None): + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + + @property + def with_neck(self) -> bool: + """bool: whether the segmentor has neck""" + return hasattr(self, 'neck') and self.neck is not None + + @property + def with_auxiliary_head(self) -> bool: + """bool: whether the segmentor has auxiliary head""" + return hasattr(self, + 'auxiliary_head') and self.auxiliary_head is not None + + @property + def with_decode_head(self) -> bool: + """bool: whether the segmentor has decode head""" + return hasattr(self, 'decode_head') and self.decode_head is not None + + @abstractmethod + def extract_feat(self, inputs: Tensor) -> bool: + """Placeholder for extract features from images.""" + pass + + @abstractmethod + def encode_decode(self, inputs: Tensor, batch_data_samples: SampleList): + """Placeholder for encode images with backbone and decode into a + semantic segmentation map of the same size as input.""" + pass + + def forward(self, + inputs: Tensor, + data_samples: OptSampleList = None, + mode: str = 'tensor') -> ForwardResults: + """The unified entry for a forward process in both training and test. + + The method should accept three modes: "tensor", "predict" and "loss": + + - "tensor": Forward the whole network and return tensor or tuple of + tensor without any post-processing, same as a common nn.Module. + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`SegDataSample`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + inputs (torch.Tensor): The input tensor with shape (N, C, ...) in + general. + data_samples (list[:obj:`SegDataSample`]): The seg data samples. + It usually includes information such as `metainfo` and + `gt_sem_seg`. Default to None. + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor or a tuple of tensor. + - If ``mode="predict"``, return a list of :obj:`DetDataSample`. + - If ``mode="loss"``, return a dict of tensor. + """ + if mode == 'loss': + return self.loss(inputs, data_samples) + elif mode == 'predict': + return self.predict(inputs, data_samples) + elif mode == 'tensor': + return self._forward(inputs, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}". ' + 'Only supports loss, predict and tensor mode') + + @abstractmethod + def loss(self, inputs: Tensor, data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and data samples.""" + pass + + @abstractmethod + def predict(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing.""" + pass + + @abstractmethod + def _forward(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> Tuple[List[Tensor]]: + """Network forward process. + + Usually includes backbone, neck and head forward without any post- + processing. + """ + pass + + def postprocess_result(self, + seg_logits: Tensor, + data_samples: OptSampleList = None) -> SampleList: + """ Convert results list to `SegDataSample`. + Args: + seg_logits (Tensor): The segmentation results, seg_logits from + model of each input image. 
+ data_samples (list[:obj:`SegDataSample`]): The seg data samples. + It usually includes information such as `metainfo` and + `gt_sem_seg`. Default to None. + Returns: + list[:obj:`SegDataSample`]: Segmentation results of the + input images. Each SegDataSample usually contain: + + - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation. + - ``seg_logits``(PixelData): Predicted logits of semantic + segmentation before normalization. + """ + batch_size, C, H, W = seg_logits.shape + + if data_samples is None: + data_samples = [SegDataSample() for _ in range(batch_size)] + only_prediction = True + else: + only_prediction = False + + for i in range(batch_size): + if not only_prediction: + img_meta = data_samples[i].metainfo + # remove padding area + if 'img_padding_size' not in img_meta: + padding_size = img_meta.get('padding_size', [0] * 4) + else: + padding_size = img_meta['img_padding_size'] + padding_left, padding_right, padding_top, padding_bottom =\ + padding_size + # i_seg_logits shape is 1, C, H, W after remove padding + i_seg_logits = seg_logits[i:i + 1, :, + padding_top:H - padding_bottom, + padding_left:W - padding_right] + + flip = img_meta.get('flip', None) + if flip: + flip_direction = img_meta.get('flip_direction', None) + assert flip_direction in ['horizontal', 'vertical'] + if flip_direction == 'horizontal': + i_seg_logits = i_seg_logits.flip(dims=(3, )) + else: + i_seg_logits = i_seg_logits.flip(dims=(2, )) + + # resize as original shape + i_seg_logits = resize( + i_seg_logits, + size=img_meta['ori_shape'], + mode='bilinear', + align_corners=self.align_corners, + warning=False).squeeze(0) + else: + i_seg_logits = seg_logits[i] + + if C > 1: + i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True) + else: + i_seg_pred = (i_seg_logits > + self.decode_head.threshold).to(i_seg_logits) + data_samples[i].set_data({ + 'seg_logits': + PixelData(**{'data': i_seg_logits}), + 'pred_sem_seg': + PixelData(**{'data': i_seg_pred}) + }) + + return data_samples diff --git a/mmseg/models/segmentors/cascade_encoder_decoder.py b/mmseg/models/segmentors/cascade_encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0184a3533a18cbe96a28bbb645c3e73bbffcdeee --- /dev/null +++ b/mmseg/models/segmentors/cascade_encoder_decoder.py @@ -0,0 +1,138 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +from torch import Tensor, nn + +from mmseg.registry import MODELS +from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig, + OptSampleList, SampleList, add_prefix) +from .encoder_decoder import EncoderDecoder + + +@MODELS.register_module() +class CascadeEncoderDecoder(EncoderDecoder): + """Cascade Encoder Decoder segmentors. + + CascadeEncoderDecoder almost the same as EncoderDecoder, while decoders of + CascadeEncoderDecoder are cascaded. The output of previous decoder_head + will be the input of next decoder_head. + + Args: + + num_stages (int): How many stages will be cascaded. + backbone (ConfigType): The config for the backnone of segmentor. + decode_head (ConfigType): The config for the decode head of segmentor. + neck (OptConfigType): The config for the neck of segmentor. + Defaults to None. + auxiliary_head (OptConfigType): The config for the auxiliary head of + segmentor. Defaults to None. + train_cfg (OptConfigType): The config for training. Defaults to None. + test_cfg (OptConfigType): The config for testing. Defaults to None. 
+ data_preprocessor (dict, optional): The pre-process config of + :class:`BaseDataPreprocessor`. + pretrained (str, optional): The path for pretrained model. + Defaults to None. + init_cfg (dict, optional): The weight initialized config for + :class:`BaseModule`. + """ + + def __init__(self, + num_stages: int, + backbone: ConfigType, + decode_head: ConfigType, + neck: OptConfigType = None, + auxiliary_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + pretrained: Optional[str] = None, + init_cfg: OptMultiConfig = None): + self.num_stages = num_stages + super().__init__( + backbone=backbone, + decode_head=decode_head, + neck=neck, + auxiliary_head=auxiliary_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + pretrained=pretrained, + init_cfg=init_cfg) + + def _init_decode_head(self, decode_head: ConfigType) -> None: + """Initialize ``decode_head``""" + assert isinstance(decode_head, list) + assert len(decode_head) == self.num_stages + self.decode_head = nn.ModuleList() + for i in range(self.num_stages): + self.decode_head.append(MODELS.build(decode_head[i])) + self.align_corners = self.decode_head[-1].align_corners + self.num_classes = self.decode_head[-1].num_classes + self.out_channels = self.decode_head[-1].out_channels + + def encode_decode(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Encode images with backbone and decode into a semantic segmentation + map of the same size as input.""" + x = self.extract_feat(inputs) + out = self.decode_head[0].forward(x) + for i in range(1, self.num_stages - 1): + out = self.decode_head[i].forward(x, out) + seg_logits_list = self.decode_head[-1].predict(x, out, batch_img_metas, + self.test_cfg) + + return seg_logits_list + + def _decode_head_forward_train(self, inputs: Tensor, + data_samples: SampleList) -> dict: + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + + loss_decode = self.decode_head[0].loss(inputs, data_samples, + self.train_cfg) + + losses.update(add_prefix(loss_decode, 'decode_0')) + # get batch_img_metas + batch_size = len(data_samples) + batch_img_metas = [] + for batch_index in range(batch_size): + metainfo = data_samples[batch_index].metainfo + batch_img_metas.append(metainfo) + + for i in range(1, self.num_stages): + # forward test again, maybe unnecessary for most methods. + if i == 1: + prev_outputs = self.decode_head[0].forward(inputs) + else: + prev_outputs = self.decode_head[i - 1].forward( + inputs, prev_outputs) + loss_decode = self.decode_head[i].loss(inputs, prev_outputs, + data_samples, + self.train_cfg) + losses.update(add_prefix(loss_decode, f'decode_{i}')) + + return losses + + def _forward(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> Tensor: + """Network forward process. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`SegDataSample`]): The seg data samples. + It usually includes information such as `metainfo` and + `gt_semantic_seg`. + + Returns: + Tensor: Forward output of model without any post-processes. 
+ """ + x = self.extract_feat(inputs) + + out = self.decode_head[0].forward(x) + for i in range(1, self.num_stages): + # TODO support PointRend tensor mode + out = self.decode_head[i].forward(x, out) + + return out diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0a8db3ec7de41f0b29f84a58f6f59c58db5d8dd4 --- /dev/null +++ b/mmseg/models/segmentors/encoder_decoder.py @@ -0,0 +1,357 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig, + OptSampleList, SampleList, add_prefix) +from .base import BaseSegmentor + + +@MODELS.register_module() +class EncoderDecoder(BaseSegmentor): + """Encoder Decoder segmentors. + + EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. + Note that auxiliary_head is only used for deep supervision during training, + which could be dumped during inference. + + 1. The ``loss`` method is used to calculate the loss of model, + which includes two steps: (1) Extracts features to obtain the feature maps + (2) Call the decode head loss function to forward decode head model and + calculate losses. + + .. code:: text + + loss(): extract_feat() -> _decode_head_forward_train() -> _auxiliary_head_forward_train (optional) + _decode_head_forward_train(): decode_head.loss() + _auxiliary_head_forward_train(): auxiliary_head.loss (optional) + + 2. The ``predict`` method is used to predict segmentation results, + which includes two steps: (1) Run inference function to obtain the list of + seg_logits (2) Call post-processing function to obtain list of + ``SegDataSampel`` including ``pred_sem_seg`` and ``seg_logits``. + + .. code:: text + + predict(): inference() -> postprocess_result() + infercen(): whole_inference()/slide_inference() + whole_inference()/slide_inference(): encoder_decoder() + encoder_decoder(): extract_feat() -> decode_head.predict() + + 3. The ``_forward`` method is used to output the tensor by running the model, + which includes two steps: (1) Extracts features to obtain the feature maps + (2)Call the decode head forward function to forward decode head model. + + .. code:: text + + _forward(): extract_feat() -> _decode_head.forward() + + Args: + + backbone (ConfigType): The config for the backnone of segmentor. + decode_head (ConfigType): The config for the decode head of segmentor. + neck (OptConfigType): The config for the neck of segmentor. + Defaults to None. + auxiliary_head (OptConfigType): The config for the auxiliary head of + segmentor. Defaults to None. + train_cfg (OptConfigType): The config for training. Defaults to None. + test_cfg (OptConfigType): The config for testing. Defaults to None. + data_preprocessor (dict, optional): The pre-process config of + :class:`BaseDataPreprocessor`. + pretrained (str, optional): The path for pretrained model. + Defaults to None. + init_cfg (dict, optional): The weight initialized config for + :class:`BaseModule`. 
+ """ # noqa: E501 + + def __init__(self, + backbone: ConfigType, + decode_head: ConfigType, + neck: OptConfigType = None, + auxiliary_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + pretrained: Optional[str] = None, + init_cfg: OptMultiConfig = None): + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + if pretrained is not None: + assert backbone.get('pretrained') is None, \ + 'both backbone and segmentor set pretrained weight' + backbone.pretrained = pretrained + self.backbone = MODELS.build(backbone) + if neck is not None: + self.neck = MODELS.build(neck) + self._init_decode_head(decode_head) + self._init_auxiliary_head(auxiliary_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + assert self.with_decode_head + + def _init_decode_head(self, decode_head: ConfigType) -> None: + """Initialize ``decode_head``""" + self.decode_head = MODELS.build(decode_head) + self.align_corners = self.decode_head.align_corners + self.num_classes = self.decode_head.num_classes + self.out_channels = self.decode_head.out_channels + + def _init_auxiliary_head(self, auxiliary_head: ConfigType) -> None: + """Initialize ``auxiliary_head``""" + if auxiliary_head is not None: + if isinstance(auxiliary_head, list): + self.auxiliary_head = nn.ModuleList() + for head_cfg in auxiliary_head: + self.auxiliary_head.append(MODELS.build(head_cfg)) + else: + self.auxiliary_head = MODELS.build(auxiliary_head) + + def extract_feat(self, inputs: Tensor) -> List[Tensor]: + """Extract features from images.""" + x = self.backbone(inputs) + if self.with_neck: + x = self.neck(x) + return x + + def encode_decode(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Encode images with backbone and decode into a semantic segmentation + map of the same size as input.""" + x = self.extract_feat(inputs) + seg_logits = self.decode_head.predict(x, batch_img_metas, + self.test_cfg) + + return seg_logits + + def _decode_head_forward_train(self, inputs: List[Tensor], + data_samples: SampleList) -> dict: + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.loss(inputs, data_samples, + self.train_cfg) + + losses.update(add_prefix(loss_decode, 'decode')) + return losses + + def _auxiliary_head_forward_train(self, inputs: List[Tensor], + data_samples: SampleList) -> dict: + """Run forward function and calculate loss for auxiliary head in + training.""" + losses = dict() + if isinstance(self.auxiliary_head, nn.ModuleList): + for idx, aux_head in enumerate(self.auxiliary_head): + loss_aux = aux_head.loss(inputs, data_samples, self.train_cfg) + losses.update(add_prefix(loss_aux, f'aux_{idx}')) + else: + loss_aux = self.auxiliary_head.loss(inputs, data_samples, + self.train_cfg) + losses.update(add_prefix(loss_aux, 'aux')) + + return losses + + def loss(self, inputs: Tensor, data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (Tensor): Input images. + data_samples (list[:obj:`SegDataSample`]): The seg data samples. + It usually includes information such as `metainfo` and + `gt_sem_seg`. 
+ + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(inputs) + + losses = dict() + + loss_decode = self._decode_head_forward_train(x, data_samples) + losses.update(loss_decode) + + if self.with_auxiliary_head: + loss_aux = self._auxiliary_head_forward_train(x, data_samples) + losses.update(loss_aux) + + return losses + + def predict(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`SegDataSample`], optional): The seg data + samples. It usually includes information such as `metainfo` + and `gt_sem_seg`. + + Returns: + list[:obj:`SegDataSample`]: Segmentation results of the + input images. Each SegDataSample usually contain: + + - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation. + - ``seg_logits``(PixelData): Predicted logits of semantic + segmentation before normalization. + """ + if data_samples is not None: + batch_img_metas = [ + data_sample.metainfo for data_sample in data_samples + ] + else: + batch_img_metas = [ + dict( + ori_shape=inputs.shape[2:], + img_shape=inputs.shape[2:], + pad_shape=inputs.shape[2:], + padding_size=[0, 0, 0, 0]) + ] * inputs.shape[0] + + seg_logits = self.inference(inputs, batch_img_metas) + + return self.postprocess_result(seg_logits, data_samples) + + def _forward(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> Tensor: + """Network forward process. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `metainfo` and `gt_sem_seg`. + + Returns: + Tensor: Forward output of model without any post-processes. + """ + x = self.extract_feat(inputs) + return self.decode_head.forward(x) + + def slide_inference(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + + Args: + inputs (tensor): the tensor should have a shape NxCxHxW, + which contains all images in the batch. + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. 
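+ Example:
+ Illustrative window arithmetic only, assuming
+ ``test_cfg.crop_size=(256, 256)`` and ``test_cfg.stride=(170, 170)``
+ on a 512x512 input; these numbers are an assumption for the sketch.
+ >>> h_img, h_crop, h_stride = 512, 256, 170
+ >>> h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
+ >>> h_grids
+ 3
+ >>> # After clamping, the window starts are y1 = 0, 170, 256, so the
+ >>> # crops [0, 256), [170, 426), [256, 512) cover every pixel and the
+ >>> # overlapping logits are averaged through `count_mat`.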
+ """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = inputs.size() + num_classes = self.num_classes + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = inputs.new_zeros((batch_size, num_classes, h_img, w_img)) + count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = inputs[:, :, y1:y2, x1:x2] + # change the image shape to patch shape + batch_img_metas[0]['img_shape'] = crop_img.shape[2:] + # the output of encode_decode is seg logits tensor map + # with shape [N, C, H, W] + crop_seg_logit = self.encode_decode(crop_img, batch_img_metas) + preds += F.pad(crop_seg_logit, + (int(x1), int(preds.shape[3] - x2), int(y1), + int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + seg_logits = preds / count_mat + + return seg_logits + + def whole_inference(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Inference with full image. + + Args: + inputs (Tensor): The tensor should have a shape NxCxHxW, which + contains all images in the batch. + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. + """ + + seg_logits = self.encode_decode(inputs, batch_img_metas) + + return seg_logits + + def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor: + """Inference with slide/whole style. + + Args: + inputs (Tensor): The input image of shape (N, 3, H, W). + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', 'pad_shape', and 'padding_size'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. + """ + + assert self.test_cfg.mode in ['slide', 'whole'] + ori_shape = batch_img_metas[0]['ori_shape'] + assert all(_['ori_shape'] == ori_shape for _ in batch_img_metas) + if self.test_cfg.mode == 'slide': + seg_logit = self.slide_inference(inputs, batch_img_metas) + else: + seg_logit = self.whole_inference(inputs, batch_img_metas) + + return seg_logit + + def aug_test(self, inputs, batch_img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. 
+ """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented seg logit inplace + seg_logit = self.inference(inputs[0], batch_img_metas[0], rescale) + for i in range(1, len(inputs)): + cur_seg_logit = self.inference(inputs[i], batch_img_metas[i], + rescale) + seg_logit += cur_seg_logit + seg_logit /= len(inputs) + seg_pred = seg_logit.argmax(dim=1) + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred diff --git a/mmseg/models/segmentors/seg_tta.py b/mmseg/models/segmentors/seg_tta.py new file mode 100644 index 0000000000000000000000000000000000000000..eacb6c00a9a398c226df12b64a64997f28a35bb7 --- /dev/null +++ b/mmseg/models/segmentors/seg_tta.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +from mmengine.model import BaseTTAModel +from mmengine.structures import PixelData + +from mmseg.registry import MODELS +from mmseg.structures import SegDataSample +from mmseg.utils import SampleList + + +@MODELS.register_module() +class SegTTAModel(BaseTTAModel): + + def merge_preds(self, data_samples_list: List[SampleList]) -> SampleList: + """Merge predictions of enhanced data to one prediction. + + Args: + data_samples_list (List[SampleList]): List of predictions + of all enhanced data. + + Returns: + SampleList: Merged prediction. + """ + predictions = [] + for data_samples in data_samples_list: + seg_logits = data_samples[0].seg_logits.data + logits = torch.zeros(seg_logits.shape).to(seg_logits) + for data_sample in data_samples: + seg_logit = data_sample.seg_logits.data + if self.module.out_channels > 1: + logits += seg_logit.softmax(dim=0) + else: + logits += seg_logit.sigmoid() + logits /= len(data_samples) + if self.module.out_channels == 1: + seg_pred = (logits > self.module.decode_head.threshold + ).to(logits).squeeze(1) + else: + seg_pred = logits.argmax(dim=0) + data_sample = SegDataSample( + **{ + 'pred_sem_seg': PixelData(data=seg_pred), + 'gt_sem_seg': data_samples[0].gt_sem_seg + }) + predictions.append(data_sample) + return predictions diff --git a/mmseg/models/utils/__init__.py b/mmseg/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc142f16fc9c0d485973c26448468541c8b7e48a --- /dev/null +++ b/mmseg/models/utils/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .basic_block import BasicBlock, Bottleneck +from .embed import PatchEmbed +from .encoding import Encoding +from .inverted_residual import InvertedResidual, InvertedResidualV3 +from .make_divisible import make_divisible +from .ppm import DAPPM, PAPPM +from .res_layer import ResLayer +from .se_layer import SELayer +from .self_attention_block import SelfAttentionBlock +from .shape_convert import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc, + nlc_to_nchw) +from .up_conv_block import UpConvBlock +from .wrappers import Upsample, resize + +__all__ = [ + 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual', + 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed', + 'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc', 'Encoding', + 'Upsample', 'resize', 'DAPPM', 'PAPPM', 'BasicBlock', 'Bottleneck' +] diff --git a/mmseg/models/utils/__pycache__/__init__.cpython-310.pyc b/mmseg/models/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..520b10adb2590a085f2f0ec86ce875211eee8c0a Binary files /dev/null and b/mmseg/models/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmseg/models/utils/__pycache__/basic_block.cpython-310.pyc b/mmseg/models/utils/__pycache__/basic_block.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3065dd543d851664126de4d370db5996baa94daf Binary files /dev/null and b/mmseg/models/utils/__pycache__/basic_block.cpython-310.pyc differ diff --git a/mmseg/models/utils/__pycache__/embed.cpython-310.pyc b/mmseg/models/utils/__pycache__/embed.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87a1fe7e539af79e7c4f5b2fa4897a60fd4e1964 Binary files /dev/null and b/mmseg/models/utils/__pycache__/embed.cpython-310.pyc differ diff --git a/mmseg/models/utils/__pycache__/encoding.cpython-310.pyc b/mmseg/models/utils/__pycache__/encoding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1642cb9f9ed2bab4acdccb8a61614a97057953f Binary files /dev/null and b/mmseg/models/utils/__pycache__/encoding.cpython-310.pyc differ diff --git a/mmseg/models/utils/__pycache__/inverted_residual.cpython-310.pyc b/mmseg/models/utils/__pycache__/inverted_residual.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a78394c38f5d9e97e79a23c04aa51b0b9087c2ea Binary files /dev/null and b/mmseg/models/utils/__pycache__/inverted_residual.cpython-310.pyc differ diff --git a/mmseg/models/utils/__pycache__/make_divisible.cpython-310.pyc b/mmseg/models/utils/__pycache__/make_divisible.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc5ed54d02b0df3d66bda179d4d2df544f1078f5 Binary files /dev/null and b/mmseg/models/utils/__pycache__/make_divisible.cpython-310.pyc differ diff --git a/mmseg/models/utils/__pycache__/ppm.cpython-310.pyc b/mmseg/models/utils/__pycache__/ppm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41ab806338b73f30aeba9e5d24c0745f634b5a18 Binary files /dev/null and b/mmseg/models/utils/__pycache__/ppm.cpython-310.pyc differ diff --git a/mmseg/models/utils/__pycache__/res_layer.cpython-310.pyc b/mmseg/models/utils/__pycache__/res_layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49daf19d038e44e1a4f054fac72c55142e1b4bc7 Binary files /dev/null and b/mmseg/models/utils/__pycache__/res_layer.cpython-310.pyc differ diff --git 
a/mmseg/models/utils/__pycache__/se_layer.cpython-310.pyc b/mmseg/models/utils/__pycache__/se_layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc2222d7ae1e496e93806b10b1ff3108b4215ad6 Binary files /dev/null and b/mmseg/models/utils/__pycache__/se_layer.cpython-310.pyc differ diff --git a/mmseg/models/utils/__pycache__/self_attention_block.cpython-310.pyc b/mmseg/models/utils/__pycache__/self_attention_block.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18a8d6d593238179354767a954764b6cfefd5514 Binary files /dev/null and b/mmseg/models/utils/__pycache__/self_attention_block.cpython-310.pyc differ diff --git a/mmseg/models/utils/__pycache__/shape_convert.cpython-310.pyc b/mmseg/models/utils/__pycache__/shape_convert.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28c67eef6b9ec3dbb552ed19a1e0f4b6c724eb46 Binary files /dev/null and b/mmseg/models/utils/__pycache__/shape_convert.cpython-310.pyc differ diff --git a/mmseg/models/utils/__pycache__/up_conv_block.cpython-310.pyc b/mmseg/models/utils/__pycache__/up_conv_block.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a7a62a7dd930ae605d6e036f5d4788ba67a6881 Binary files /dev/null and b/mmseg/models/utils/__pycache__/up_conv_block.cpython-310.pyc differ diff --git a/mmseg/models/utils/__pycache__/wrappers.cpython-310.pyc b/mmseg/models/utils/__pycache__/wrappers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8fd7cbfe69e2653dcf638eec207353e85413b23f Binary files /dev/null and b/mmseg/models/utils/__pycache__/wrappers.cpython-310.pyc differ diff --git a/mmseg/models/utils/basic_block.py b/mmseg/models/utils/basic_block.py new file mode 100644 index 0000000000000000000000000000000000000000..4e1ad8146dd200c5f1e543adf22ada654ee196a4 --- /dev/null +++ b/mmseg/models/utils/basic_block.py @@ -0,0 +1,143 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.utils import OptConfigType + + +class BasicBlock(BaseModule): + """Basic block from `ResNet `_. + + Args: + in_channels (int): Input channels. + channels (int): Output channels. + stride (int): Stride of the first block. Default: 1. + downsample (nn.Module, optional): Downsample operation on identity. + Default: None. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict, optional): Config dict for activation layer in + ConvModule. Default: dict(type='ReLU', inplace=True). + act_cfg_out (dict, optional): Config dict for activation layer at the + last of the block. Default: None. + init_cfg (dict, optional): Initialization config dict. Default: None. 
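+ Example:
+ A minimal sketch assuming an identity shortcut (no downsample) and an
+ illustrative 64-channel input.
+ >>> import torch
+ >>> block = BasicBlock(in_channels=64, channels=64)
+ >>> out = block(torch.rand(1, 64, 32, 32))
+ >>> out.shape
+ torch.Size([1, 64, 32, 32])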
+ """ + + expansion = 1 + + def __init__(self, + in_channels: int, + channels: int, + stride: int = 1, + downsample: nn.Module = None, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + act_cfg_out: OptConfigType = dict(type='ReLU', inplace=True), + init_cfg: OptConfigType = None): + super().__init__(init_cfg) + self.conv1 = ConvModule( + in_channels, + channels, + kernel_size=3, + stride=stride, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.conv2 = ConvModule( + channels, + channels, + kernel_size=3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=None) + self.downsample = downsample + if act_cfg_out: + self.act = MODELS.build(act_cfg_out) + + def forward(self, x: Tensor) -> Tensor: + residual = x + out = self.conv1(x) + out = self.conv2(out) + + if self.downsample: + residual = self.downsample(x) + + out += residual + + if hasattr(self, 'act'): + out = self.act(out) + + return out + + +class Bottleneck(BaseModule): + """Bottleneck block from `ResNet `_. + + Args: + in_channels (int): Input channels. + channels (int): Output channels. + stride (int): Stride of the first block. Default: 1. + downsample (nn.Module, optional): Downsample operation on identity. + Default: None. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict, optional): Config dict for activation layer in + ConvModule. Default: dict(type='ReLU', inplace=True). + act_cfg_out (dict, optional): Config dict for activation layer at + the last of the block. Default: None. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + expansion = 2 + + def __init__(self, + in_channels: int, + channels: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + act_cfg_out: OptConfigType = None, + init_cfg: OptConfigType = None): + super().__init__(init_cfg) + self.conv1 = ConvModule( + in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg) + self.conv2 = ConvModule( + channels, + channels, + 3, + stride, + 1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.conv3 = ConvModule( + channels, + channels * self.expansion, + 1, + norm_cfg=norm_cfg, + act_cfg=None) + if act_cfg_out: + self.act = MODELS.build(act_cfg_out) + self.downsample = downsample + + def forward(self, x: Tensor) -> Tensor: + residual = x + + out = self.conv1(x) + out = self.conv2(out) + out = self.conv3(out) + + if self.downsample: + residual = self.downsample(x) + + out += residual + + if hasattr(self, 'act'): + out = self.act(out) + + return out diff --git a/mmseg/models/utils/embed.py b/mmseg/models/utils/embed.py new file mode 100644 index 0000000000000000000000000000000000000000..aef0a40b0a87bb6616db96fe2c72c19cc6f5b366 --- /dev/null +++ b/mmseg/models/utils/embed.py @@ -0,0 +1,330 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Sequence + +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import BaseModule +from mmengine.utils import to_2tuple + + +class AdaptivePadding(nn.Module): + """Applies padding to input (if needed) so that input can get fully covered + by filter you specified. It support two modes "same" and "corner". The + "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around + input. The "corner" mode would pad zero to bottom right. 
+ + Args: + kernel_size (int | tuple): Size of the kernel: + stride (int | tuple): Stride of the filter. Default: 1: + dilation (int | tuple): Spacing between kernel elements. + Default: 1. + padding (str): Support "same" and "corner", "corner" mode + would pad zero to bottom right, and "same" mode would + pad zero around input. Default: "corner". + Example: + >>> kernel_size = 16 + >>> stride = 16 + >>> dilation = 1 + >>> input = torch.rand(1, 1, 15, 17) + >>> adap_pad = AdaptivePadding( + >>> kernel_size=kernel_size, + >>> stride=stride, + >>> dilation=dilation, + >>> padding="corner") + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + >>> input = torch.rand(1, 1, 16, 17) + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + """ + + def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'): + + super().__init__() + + assert padding in ('same', 'corner') + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + self.padding = padding + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + + def get_pad_shape(self, input_shape): + input_h, input_w = input_shape + kernel_h, kernel_w = self.kernel_size + stride_h, stride_w = self.stride + output_h = math.ceil(input_h / stride_h) + output_w = math.ceil(input_w / stride_w) + pad_h = max((output_h - 1) * stride_h + + (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0) + pad_w = max((output_w - 1) * stride_w + + (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0) + return pad_h, pad_w + + def forward(self, x): + pad_h, pad_w = self.get_pad_shape(x.size()[-2:]) + if pad_h > 0 or pad_w > 0: + if self.padding == 'corner': + x = F.pad(x, [0, pad_w, 0, pad_h]) + elif self.padding == 'same': + x = F.pad(x, [ + pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2 + ]) + return x + + +class PatchEmbed(BaseModule): + """Image to Patch Embedding. + + We use a conv layer to implement PatchEmbed. + + Args: + in_channels (int): The num of input channels. Default: 3 + embed_dims (int): The dimensions of embedding. Default: 768 + conv_type (str): The config dict for embedding + conv layer type selection. Default: "Conv2d". + kernel_size (int): The kernel_size of embedding conv. Default: 16. + stride (int, optional): The slide stride of embedding conv. + Default: None (Would be set as `kernel_size`). + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int): The dilation rate of embedding conv. Default: 1. + bias (bool): Bias of embed conv. Default: True. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: None. + input_size (int | tuple | None): The size of input, which will be + used to calculate the out size. Only work when `dynamic_size` + is False. Default: None. + init_cfg (`mmengine.ConfigDict`, optional): The Config for + initialization. Default: None. 
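+ Example:
+ A minimal sketch with illustrative ViT-style settings
+ (16x16 patches on a 224x224 image).
+ >>> import torch
+ >>> patch_embed = PatchEmbed(in_channels=3, embed_dims=768,
+ ... kernel_size=16, stride=16)
+ >>> x, out_size = patch_embed(torch.rand(1, 3, 224, 224))
+ >>> x.shape, out_size
+ (torch.Size([1, 196, 768]), (14, 14))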
+ """ + + def __init__(self, + in_channels=3, + embed_dims=768, + conv_type='Conv2d', + kernel_size=16, + stride=None, + padding='corner', + dilation=1, + bias=True, + norm_cfg=None, + input_size=None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + if stride is None: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + # disable the padding of conv + padding = 0 + else: + self.adap_padding = None + padding = to_2tuple(padding) + + self.projection = build_conv_layer( + dict(type=conv_type), + in_channels=in_channels, + out_channels=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + else: + self.norm = None + + if input_size: + input_size = to_2tuple(input_size) + # `init_out_size` would be used outside to + # calculate the num_patches + # when `use_abs_pos_embed` outside + self.init_input_size = input_size + if self.adap_padding: + pad_h, pad_w = self.adap_padding.get_pad_shape(input_size) + input_h, input_w = input_size + input_h = input_h + pad_h + input_w = input_w + pad_w + input_size = (input_h, input_w) + + # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + h_out = (input_size[0] + 2 * padding[0] - dilation[0] * + (kernel_size[0] - 1) - 1) // stride[0] + 1 + w_out = (input_size[1] + 2 * padding[1] - dilation[1] * + (kernel_size[1] - 1) - 1) // stride[1] + 1 + self.init_out_size = (h_out, w_out) + else: + self.init_input_size = None + self.init_out_size = None + + def forward(self, x): + """ + Args: + x (Tensor): Has shape (B, C, H, W). In most case, C is 3. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, out_h * out_w, embed_dims) + - out_size (tuple[int]): Spatial shape of x, arrange as + (out_h, out_w). + """ + + if self.adap_padding: + x = self.adap_padding(x) + + x = self.projection(x) + out_size = (x.shape[2], x.shape[3]) + x = x.flatten(2).transpose(1, 2) + if self.norm is not None: + x = self.norm(x) + return x, out_size + + +class PatchMerging(BaseModule): + """Merge patch feature map. + + This layer groups feature map by kernel_size, and applies norm and linear + layers to the grouped feature map. Our implementation uses `nn.Unfold` to + merge patch, which is about 25% faster than original implementation. + Instead, we need to modify pretrained models for compatibility. + + Args: + in_channels (int): The num of input channels. + out_channels (int): The num of output channels. + kernel_size (int | tuple, optional): the kernel size in the unfold + layer. Defaults to 2. + stride (int | tuple, optional): the stride of the sliding blocks in the + unfold layer. Default: None. (Would be set as `kernel_size`) + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int | tuple, optional): dilation parameter in the unfold + layer. Default: 1. + bias (bool, optional): Whether to add bias in linear layer or not. + Defaults: False. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='LN'). 
+ init_cfg (dict, optional): The extra config for initialization. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=2, + stride=None, + padding='corner', + dilation=1, + bias=False, + norm_cfg=dict(type='LN'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + if stride: + stride = stride + else: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + # disable the padding of unfold + padding = 0 + else: + self.adap_padding = None + + padding = to_2tuple(padding) + self.sampler = nn.Unfold( + kernel_size=kernel_size, + dilation=dilation, + padding=padding, + stride=stride) + + sample_dim = kernel_size[0] * kernel_size[1] * in_channels + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, sample_dim)[1] + else: + self.norm = None + + self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) + + def forward(self, x, input_size): + """ + Args: + x (Tensor): Has shape (B, H*W, C_in). + input_size (tuple[int]): The spatial shape of x, arrange as (H, W). + Default: None. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out) + - out_size (tuple[int]): Spatial shape of x, arrange as + (Merged_H, Merged_W). + """ + B, L, C = x.shape + assert isinstance(input_size, Sequence), f'Expect ' \ + f'input_size is ' \ + f'`Sequence` ' \ + f'but get {input_size}' + + H, W = input_size + assert L == H * W, 'input feature has wrong size' + + x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W + # Use nn.Unfold to merge patch. About 25% faster than original method, + # but need to modify pretrained model for compatibility + + if self.adap_padding: + x = self.adap_padding(x) + H, W = x.shape[-2:] + + x = self.sampler(x) + # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2) + + out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * + (self.sampler.kernel_size[0] - 1) - + 1) // self.sampler.stride[0] + 1 + out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * + (self.sampler.kernel_size[1] - 1) - + 1) // self.sampler.stride[1] + 1 + + output_size = (out_h, out_w) + x = x.transpose(1, 2) # B, H/2*W/2, 4*C + x = self.norm(x) if self.norm else x + x = self.reduction(x) + return x, output_size diff --git a/mmseg/models/utils/encoding.py b/mmseg/models/utils/encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..ee4f0574fbc1957cf8da591a0e4befd6d8a125d3 --- /dev/null +++ b/mmseg/models/utils/encoding.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn +from torch.nn import functional as F + + +class Encoding(nn.Module): + """Encoding Layer: a learnable residual encoder. + + Input is of shape (batch_size, channels, height, width). + Output is of shape (batch_size, num_codes, channels). + + Args: + channels: dimension of the features or feature channels + num_codes: number of code words + """ + + def __init__(self, channels, num_codes): + super().__init__() + # init codewords and smoothing factor + self.channels, self.num_codes = channels, num_codes + std = 1. 
/ ((num_codes * channels)**0.5) + # [num_codes, channels] + self.codewords = nn.Parameter( + torch.empty(num_codes, channels, + dtype=torch.float).uniform_(-std, std), + requires_grad=True) + # [num_codes] + self.scale = nn.Parameter( + torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0), + requires_grad=True) + + @staticmethod + def scaled_l2(x, codewords, scale): + num_codes, channels = codewords.size() + batch_size = x.size(0) + reshaped_scale = scale.view((1, 1, num_codes)) + expanded_x = x.unsqueeze(2).expand( + (batch_size, x.size(1), num_codes, channels)) + reshaped_codewords = codewords.view((1, 1, num_codes, channels)) + + scaled_l2_norm = reshaped_scale * ( + expanded_x - reshaped_codewords).pow(2).sum(dim=3) + return scaled_l2_norm + + @staticmethod + def aggregate(assignment_weights, x, codewords): + num_codes, channels = codewords.size() + reshaped_codewords = codewords.view((1, 1, num_codes, channels)) + batch_size = x.size(0) + + expanded_x = x.unsqueeze(2).expand( + (batch_size, x.size(1), num_codes, channels)) + encoded_feat = (assignment_weights.unsqueeze(3) * + (expanded_x - reshaped_codewords)).sum(dim=1) + return encoded_feat + + def forward(self, x): + assert x.dim() == 4 and x.size(1) == self.channels + # [batch_size, channels, height, width] + batch_size = x.size(0) + # [batch_size, height x width, channels] + x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous() + # assignment_weights: [batch_size, channels, num_codes] + assignment_weights = F.softmax( + self.scaled_l2(x, self.codewords, self.scale), dim=2) + # aggregate + encoded_feat = self.aggregate(assignment_weights, x, self.codewords) + return encoded_feat + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \ + f'x{self.channels})' + return repr_str diff --git a/mmseg/models/utils/inverted_residual.py b/mmseg/models/utils/inverted_residual.py new file mode 100644 index 0000000000000000000000000000000000000000..56190b3bfe7cc8fe98bf34c3812db18dd34a8f02 --- /dev/null +++ b/mmseg/models/utils/inverted_residual.py @@ -0,0 +1,213 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import ConvModule +from torch import nn +from torch.utils import checkpoint as cp + +from .se_layer import SELayer + + +class InvertedResidual(nn.Module): + """InvertedResidual block for MobileNetV2. + + Args: + in_channels (int): The input channels of the InvertedResidual block. + out_channels (int): The output channels of the InvertedResidual block. + stride (int): Stride of the middle (first) 3x3 convolution. + expand_ratio (int): Adjusts number of channels of the hidden layer + in InvertedResidual by this amount. + dilation (int): Dilation rate of depthwise conv. Default: 1 + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, + in_channels, + out_channels, + stride, + expand_ratio, + dilation=1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + with_cp=False, + **kwargs): + super().__init__() + self.stride = stride + assert stride in [1, 2], f'stride must in [1, 2]. 
' \ + f'But received {stride}.' + self.with_cp = with_cp + self.use_res_connect = self.stride == 1 and in_channels == out_channels + hidden_dim = int(round(in_channels * expand_ratio)) + + layers = [] + if expand_ratio != 1: + layers.append( + ConvModule( + in_channels=in_channels, + out_channels=hidden_dim, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **kwargs)) + layers.extend([ + ConvModule( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + groups=hidden_dim, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **kwargs), + ConvModule( + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None, + **kwargs) + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + + def _inner_forward(x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class InvertedResidualV3(nn.Module): + """Inverted Residual Block for MobileNetV3. + + Args: + in_channels (int): The input channels of this Module. + out_channels (int): The output channels of this Module. + mid_channels (int): The input channels of the depthwise convolution. + kernel_size (int): The kernel size of the depthwise convolution. + Default: 3. + stride (int): The stride of the depthwise convolution. Default: 1. + se_cfg (dict): Config dict for se layer. Default: None, which means no + se layer. + with_expand_conv (bool): Use expand conv or not. If set False, + mid_channels must be the same with in_channels. Default: True. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + + Returns: + Tensor: The output tensor. 
+ """ + + def __init__(self, + in_channels, + out_channels, + mid_channels, + kernel_size=3, + stride=1, + se_cfg=None, + with_expand_conv=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False): + super().__init__() + self.with_res_shortcut = (stride == 1 and in_channels == out_channels) + assert stride in [1, 2] + self.with_cp = with_cp + self.with_se = se_cfg is not None + self.with_expand_conv = with_expand_conv + + if self.with_se: + assert isinstance(se_cfg, dict) + if not self.with_expand_conv: + assert mid_channels == in_channels + + if self.with_expand_conv: + self.expand_conv = ConvModule( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.depthwise_conv = ConvModule( + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + groups=mid_channels, + conv_cfg=dict( + type='Conv2dAdaptivePadding') if stride == 2 else conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + if self.with_se: + self.se = SELayer(**se_cfg) + + self.linear_conv = ConvModule( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + def forward(self, x): + + def _inner_forward(x): + out = x + + if self.with_expand_conv: + out = self.expand_conv(out) + + out = self.depthwise_conv(out) + + if self.with_se: + out = self.se(out) + + out = self.linear_conv(out) + + if self.with_res_shortcut: + return x + out + else: + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out diff --git a/mmseg/models/utils/make_divisible.py b/mmseg/models/utils/make_divisible.py new file mode 100644 index 0000000000000000000000000000000000000000..ed42c2eeea2a6aed03a0be5516b8d1ef1139e486 --- /dev/null +++ b/mmseg/models/utils/make_divisible.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +def make_divisible(value, divisor, min_value=None, min_ratio=0.9): + """Make divisible function. + + This function rounds the channel number to the nearest value that can be + divisible by the divisor. It is taken from the original tf repo. It ensures + that all layers have a channel number that is divisible by divisor. It can + be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa + + Args: + value (int): The original channel number. + divisor (int): The divisor to fully divide the channel number. + min_value (int): The minimum value of the output channel. + Default: None, means that the minimum value equal to the divisor. + min_ratio (float): The minimum ratio of the rounded channel number to + the original channel number. Default: 0.9. + + Returns: + int: The modified output channel number. + """ + + if min_value is None: + min_value = divisor + new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than (1-min_ratio). + if new_value < min_ratio * value: + new_value += divisor + return new_value diff --git a/mmseg/models/utils/ppm.py b/mmseg/models/utils/ppm.py new file mode 100644 index 0000000000000000000000000000000000000000..5fe6ff26fae6869b989cecde96af3ceff1a37b38 --- /dev/null +++ b/mmseg/models/utils/ppm.py @@ -0,0 +1,193 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
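+# A minimal usage sketch for the DAPPM/PAPPM modules defined below (the
+# channel numbers and input shape are illustrative assumptions only):
+#
+#   >>> import torch
+#   >>> dappm = DAPPM(in_channels=64, branch_channels=96, out_channels=128,
+#   ...               num_scales=5)
+#   >>> out = dappm(torch.rand(2, 64, 32, 64))
+#   >>> out.shape
+#   torch.Size([2, 128, 32, 64])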
+from typing import Dict, List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule, ModuleList, Sequential +from torch import Tensor + + +class DAPPM(BaseModule): + """DAPPM module in `DDRNet `_. + + Args: + in_channels (int): Input channels. + branch_channels (int): Branch channels. + out_channels (int): Output channels. + num_scales (int): Number of scales. + kernel_sizes (list[int]): Kernel sizes of each scale. + strides (list[int]): Strides of each scale. + paddings (list[int]): Paddings of each scale. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU', inplace=True). + conv_cfg (dict): Config dict for convolution layer in ConvModule. + Default: dict(order=('norm', 'act', 'conv'), bias=False). + upsample_mode (str): Upsample mode. Default: 'bilinear'. + """ + + def __init__(self, + in_channels: int, + branch_channels: int, + out_channels: int, + num_scales: int, + kernel_sizes: List[int] = [5, 9, 17], + strides: List[int] = [2, 4, 8], + paddings: List[int] = [2, 4, 8], + norm_cfg: Dict = dict(type='BN', momentum=0.1), + act_cfg: Dict = dict(type='ReLU', inplace=True), + conv_cfg: Dict = dict( + order=('norm', 'act', 'conv'), bias=False), + upsample_mode: str = 'bilinear'): + super().__init__() + + self.num_scales = num_scales + self.unsample_mode = upsample_mode + self.in_channels = in_channels + self.branch_channels = branch_channels + self.out_channels = out_channels + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.conv_cfg = conv_cfg + + self.scales = ModuleList([ + ConvModule( + in_channels, + branch_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **conv_cfg) + ]) + for i in range(1, num_scales - 1): + self.scales.append( + Sequential(*[ + nn.AvgPool2d( + kernel_size=kernel_sizes[i - 1], + stride=strides[i - 1], + padding=paddings[i - 1]), + ConvModule( + in_channels, + branch_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **conv_cfg) + ])) + self.scales.append( + Sequential(*[ + nn.AdaptiveAvgPool2d((1, 1)), + ConvModule( + in_channels, + branch_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **conv_cfg) + ])) + self.processes = ModuleList() + for i in range(num_scales - 1): + self.processes.append( + ConvModule( + branch_channels, + branch_channels, + kernel_size=3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **conv_cfg)) + + self.compression = ConvModule( + branch_channels * num_scales, + out_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **conv_cfg) + + self.shortcut = ConvModule( + in_channels, + out_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **conv_cfg) + + def forward(self, inputs: Tensor): + feats = [] + feats.append(self.scales[0](inputs)) + + for i in range(1, self.num_scales): + feat_up = F.interpolate( + self.scales[i](inputs), + size=inputs.shape[2:], + mode=self.unsample_mode) + feats.append(self.processes[i - 1](feat_up + feats[i - 1])) + + return self.compression(torch.cat(feats, + dim=1)) + self.shortcut(inputs) + + +class PAPPM(DAPPM): + """PAPPM module in `PIDNet `_. + + Args: + in_channels (int): Input channels. + branch_channels (int): Branch channels. + out_channels (int): Output channels. + num_scales (int): Number of scales. + kernel_sizes (list[int]): Kernel sizes of each scale. 
+ strides (list[int]): Strides of each scale. + paddings (list[int]): Paddings of each scale. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN', momentum=0.1). + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU', inplace=True). + conv_cfg (dict): Config dict for convolution layer in ConvModule. + Default: dict(order=('norm', 'act', 'conv'), bias=False). + upsample_mode (str): Upsample mode. Default: 'bilinear'. + """ + + def __init__(self, + in_channels: int, + branch_channels: int, + out_channels: int, + num_scales: int, + kernel_sizes: List[int] = [5, 9, 17], + strides: List[int] = [2, 4, 8], + paddings: List[int] = [2, 4, 8], + norm_cfg: Dict = dict(type='BN', momentum=0.1), + act_cfg: Dict = dict(type='ReLU', inplace=True), + conv_cfg: Dict = dict( + order=('norm', 'act', 'conv'), bias=False), + upsample_mode: str = 'bilinear'): + super().__init__(in_channels, branch_channels, out_channels, + num_scales, kernel_sizes, strides, paddings, norm_cfg, + act_cfg, conv_cfg, upsample_mode) + + self.processes = ConvModule( + self.branch_channels * (self.num_scales - 1), + self.branch_channels * (self.num_scales - 1), + kernel_size=3, + padding=1, + groups=self.num_scales - 1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + **self.conv_cfg) + + def forward(self, inputs: Tensor): + x_ = self.scales[0](inputs) + feats = [] + for i in range(1, self.num_scales): + feat_up = F.interpolate( + self.scales[i](inputs), + size=inputs.shape[2:], + mode=self.unsample_mode, + align_corners=False) + feats.append(feat_up + x_) + scale_out = self.processes(torch.cat(feats, dim=1)) + return self.compression(torch.cat([x_, scale_out], + dim=1)) + self.shortcut(inputs) diff --git a/mmseg/models/utils/res_layer.py b/mmseg/models/utils/res_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..3dd7a6f75a168f2f7e3c61f82d309b1cf0d502bc --- /dev/null +++ b/mmseg/models/utils/res_layer.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import Sequential +from torch import nn as nn + + +class ResLayer(Sequential): + """ResLayer to build ResNet style backbone. + + Args: + block (nn.Module): block used to build ResLayer. + inplanes (int): inplanes of block. + planes (int): planes of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Default: 1 + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + multi_grid (int | None): Multi grid dilation rates of last + stage. 
Default: None + contract_dilation (bool): Whether contract first dilation of each layer + Default: False + """ + + def __init__(self, + block, + inplanes, + planes, + num_blocks, + stride=1, + dilation=1, + avg_down=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + multi_grid=None, + contract_dilation=False, + **kwargs): + self.block = block + + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = [] + conv_stride = stride + if avg_down: + conv_stride = 1 + downsample.append( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False)) + downsample.extend([ + build_conv_layer( + conv_cfg, + inplanes, + planes * block.expansion, + kernel_size=1, + stride=conv_stride, + bias=False), + build_norm_layer(norm_cfg, planes * block.expansion)[1] + ]) + downsample = nn.Sequential(*downsample) + + layers = [] + if multi_grid is None: + if dilation > 1 and contract_dilation: + first_dilation = dilation // 2 + else: + first_dilation = dilation + else: + first_dilation = multi_grid[0] + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=stride, + dilation=first_dilation, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + inplanes = planes * block.expansion + for i in range(1, num_blocks): + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=1, + dilation=dilation if multi_grid is None else multi_grid[i], + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + super().__init__(*layers) diff --git a/mmseg/models/utils/se_layer.py b/mmseg/models/utils/se_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..0ff632cfea728a7ffd99f1578c828c588d78f3db --- /dev/null +++ b/mmseg/models/utils/se_layer.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.utils import is_tuple_of + +from .make_divisible import make_divisible + + +class SELayer(nn.Module): + """Squeeze-and-Excitation Module. + + Args: + channels (int): The input (and output) channels of the SE layer. + ratio (int): Squeeze ratio in SELayer, the intermediate channel will be + ``int(channels/ratio)``. Default: 16. + conv_cfg (None or dict): Config dict for convolution layer. + Default: None, which means using conv2d. + act_cfg (dict or Sequence[dict]): Config dict for activation layer. + If act_cfg is a dict, two activation layers will be configured + by this dict. If act_cfg is a sequence of dicts, the first + activation layer will be configured by the first dict and the + second activation layer will be configured by the second dict. + Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0, + divisor=6.0)). 
+ """ + + def __init__(self, + channels, + ratio=16, + conv_cfg=None, + act_cfg=(dict(type='ReLU'), + dict(type='HSigmoid', bias=3.0, divisor=6.0))): + super().__init__() + if isinstance(act_cfg, dict): + act_cfg = (act_cfg, act_cfg) + assert len(act_cfg) == 2 + assert is_tuple_of(act_cfg, dict) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.conv1 = ConvModule( + in_channels=channels, + out_channels=make_divisible(channels // ratio, 8), + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + act_cfg=act_cfg[0]) + self.conv2 = ConvModule( + in_channels=make_divisible(channels // ratio, 8), + out_channels=channels, + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + act_cfg=act_cfg[1]) + + def forward(self, x): + out = self.global_avgpool(x) + out = self.conv1(out) + out = self.conv2(out) + return x * out diff --git a/mmseg/models/utils/self_attention_block.py b/mmseg/models/utils/self_attention_block.py new file mode 100644 index 0000000000000000000000000000000000000000..5bb6e8284e599637c12553e27199338a820709e3 --- /dev/null +++ b/mmseg/models/utils/self_attention_block.py @@ -0,0 +1,161 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmcv.cnn import ConvModule +from mmengine.model.weight_init import constant_init +from torch import nn as nn +from torch.nn import functional as F + + +class SelfAttentionBlock(nn.Module): + """General self-attention block/non-local block. + + Please refer to https://arxiv.org/abs/1706.03762 for details about key, + query and value. + + Args: + key_in_channels (int): Input channels of key feature. + query_in_channels (int): Input channels of query feature. + channels (int): Output channels of key/query transform. + out_channels (int): Output channels. + share_key_query (bool): Whether share projection weight between key + and query projection. + query_downsample (nn.Module): Query downsample module. + key_downsample (nn.Module): Key downsample module. + key_query_num_convs (int): Number of convs for key/query projection. + value_num_convs (int): Number of convs for value projection. + matmul_norm (bool): Whether normalize attention map with sqrt of + channels + with_out (bool): Whether use out projection. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict|None): Config of activation layers. 
+ """ + + def __init__(self, key_in_channels, query_in_channels, channels, + out_channels, share_key_query, query_downsample, + key_downsample, key_query_num_convs, value_out_num_convs, + key_query_norm, value_out_norm, matmul_norm, with_out, + conv_cfg, norm_cfg, act_cfg): + super().__init__() + if share_key_query: + assert key_in_channels == query_in_channels + self.key_in_channels = key_in_channels + self.query_in_channels = query_in_channels + self.out_channels = out_channels + self.channels = channels + self.share_key_query = share_key_query + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.key_project = self.build_project( + key_in_channels, + channels, + num_convs=key_query_num_convs, + use_conv_module=key_query_norm, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + if share_key_query: + self.query_project = self.key_project + else: + self.query_project = self.build_project( + query_in_channels, + channels, + num_convs=key_query_num_convs, + use_conv_module=key_query_norm, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.value_project = self.build_project( + key_in_channels, + channels if with_out else out_channels, + num_convs=value_out_num_convs, + use_conv_module=value_out_norm, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + if with_out: + self.out_project = self.build_project( + channels, + out_channels, + num_convs=value_out_num_convs, + use_conv_module=value_out_norm, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + else: + self.out_project = None + + self.query_downsample = query_downsample + self.key_downsample = key_downsample + self.matmul_norm = matmul_norm + + self.init_weights() + + def init_weights(self): + """Initialize weight of later layer.""" + if self.out_project is not None: + if not isinstance(self.out_project, ConvModule): + constant_init(self.out_project, 0) + + def build_project(self, in_channels, channels, num_convs, use_conv_module, + conv_cfg, norm_cfg, act_cfg): + """Build projection layer for key/query/value/out.""" + if use_conv_module: + convs = [ + ConvModule( + in_channels, + channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + ] + for _ in range(num_convs - 1): + convs.append( + ConvModule( + channels, + channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + else: + convs = [nn.Conv2d(in_channels, channels, 1)] + for _ in range(num_convs - 1): + convs.append(nn.Conv2d(channels, channels, 1)) + if len(convs) > 1: + convs = nn.Sequential(*convs) + else: + convs = convs[0] + return convs + + def forward(self, query_feats, key_feats): + """Forward function.""" + batch_size = query_feats.size(0) + query = self.query_project(query_feats) + if self.query_downsample is not None: + query = self.query_downsample(query) + query = query.reshape(*query.shape[:2], -1) + query = query.permute(0, 2, 1).contiguous() + + key = self.key_project(key_feats) + value = self.value_project(key_feats) + if self.key_downsample is not None: + key = self.key_downsample(key) + value = self.key_downsample(value) + key = key.reshape(*key.shape[:2], -1) + value = value.reshape(*value.shape[:2], -1) + value = value.permute(0, 2, 1).contiguous() + + sim_map = torch.matmul(query, key) + if self.matmul_norm: + sim_map = (self.channels**-.5) * sim_map + sim_map = F.softmax(sim_map, dim=-1) + + context = torch.matmul(sim_map, value) + context = context.permute(0, 2, 1).contiguous() + context = context.reshape(batch_size, -1, 
*query_feats.shape[2:]) + if self.out_project is not None: + context = self.out_project(context) + return context diff --git a/mmseg/models/utils/shape_convert.py b/mmseg/models/utils/shape_convert.py new file mode 100644 index 0000000000000000000000000000000000000000..cce1e220b645d4b02df1ec2d9ed3137c8acba707 --- /dev/null +++ b/mmseg/models/utils/shape_convert.py @@ -0,0 +1,107 @@ +# Copyright (c) OpenMMLab. All rights reserved. +def nlc_to_nchw(x, hw_shape): + """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor. + + Args: + x (Tensor): The input tensor of shape [N, L, C] before conversion. + hw_shape (Sequence[int]): The height and width of output feature map. + + Returns: + Tensor: The output tensor of shape [N, C, H, W] after conversion. + """ + H, W = hw_shape + assert len(x.shape) == 3 + B, L, C = x.shape + assert L == H * W, 'The seq_len doesn\'t match H, W' + return x.transpose(1, 2).reshape(B, C, H, W) + + +def nchw_to_nlc(x): + """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor. + + Args: + x (Tensor): The input tensor of shape [N, C, H, W] before conversion. + + Returns: + Tensor: The output tensor of shape [N, L, C] after conversion. + """ + assert len(x.shape) == 4 + return x.flatten(2).transpose(1, 2).contiguous() + + +def nchw2nlc2nchw(module, x, contiguous=False, **kwargs): + """Flatten [N, C, H, W] shape tensor `x` to [N, L, C] shape tensor. Use the + reshaped tensor as the input of `module`, and the convert the output of + `module`, whose shape is. + + [N, L, C], to [N, C, H, W]. + + Args: + module (Callable): A callable object the takes a tensor + with shape [N, L, C] as input. + x (Tensor): The input tensor of shape [N, C, H, W]. + contiguous: + contiguous (Bool): Whether to make the tensor contiguous + after each shape transform. + + Returns: + Tensor: The output tensor of shape [N, C, H, W]. + + Example: + >>> import torch + >>> import torch.nn as nn + >>> norm = nn.LayerNorm(4) + >>> feature_map = torch.rand(4, 4, 5, 5) + >>> output = nchw2nlc2nchw(norm, feature_map) + """ + B, C, H, W = x.shape + if not contiguous: + x = x.flatten(2).transpose(1, 2) + x = module(x, **kwargs) + x = x.transpose(1, 2).reshape(B, C, H, W) + else: + x = x.flatten(2).transpose(1, 2).contiguous() + x = module(x, **kwargs) + x = x.transpose(1, 2).reshape(B, C, H, W).contiguous() + return x + + +def nlc2nchw2nlc(module, x, hw_shape, contiguous=False, **kwargs): + """Convert [N, L, C] shape tensor `x` to [N, C, H, W] shape tensor. Use the + reshaped tensor as the input of `module`, and convert the output of + `module`, whose shape is. + + [N, C, H, W], to [N, L, C]. + + Args: + module (Callable): A callable object the takes a tensor + with shape [N, C, H, W] as input. + x (Tensor): The input tensor of shape [N, L, C]. + hw_shape: (Sequence[int]): The height and width of the + feature map with shape [N, C, H, W]. + contiguous (Bool): Whether to make the tensor contiguous + after each shape transform. + + Returns: + Tensor: The output tensor of shape [N, L, C]. 
+ + Example: + >>> import torch + >>> import torch.nn as nn + >>> conv = nn.Conv2d(16, 16, 3, 1, 1) + >>> feature_map = torch.rand(4, 25, 16) + >>> output = nlc2nchw2nlc(conv, feature_map, (5, 5)) + """ + H, W = hw_shape + assert len(x.shape) == 3 + B, L, C = x.shape + assert L == H * W, 'The seq_len doesn\'t match H, W' + if not contiguous: + x = x.transpose(1, 2).reshape(B, C, H, W) + x = module(x, **kwargs) + x = x.flatten(2).transpose(1, 2) + else: + x = x.transpose(1, 2).reshape(B, C, H, W).contiguous() + x = module(x, **kwargs) + x = x.flatten(2).transpose(1, 2).contiguous() + return x diff --git a/mmseg/models/utils/up_conv_block.py b/mmseg/models/utils/up_conv_block.py new file mode 100644 index 0000000000000000000000000000000000000000..4fa3b598de96d53c169232d9c89ac458f6921e8d --- /dev/null +++ b/mmseg/models/utils/up_conv_block.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, build_upsample_layer + + +class UpConvBlock(nn.Module): + """Upsample convolution block in decoder for UNet. + + This upsample convolution block consists of one upsample module + followed by one convolution block. The upsample module expands the + high-level low-resolution feature map and the convolution block fuses + the upsampled high-level low-resolution feature map and the low-level + high-resolution feature map from encoder. + + Args: + conv_block (nn.Sequential): Sequential of convolutional layers. + in_channels (int): Number of input channels of the high-level + skip_channels (int): Number of input channels of the low-level + high-resolution feature map from encoder. + out_channels (int): Number of output channels. + num_convs (int): Number of convolutional layers in the conv_block. + Default: 2. + stride (int): Stride of convolutional layer in conv_block. Default: 1. + dilation (int): Dilation rate of convolutional layer in conv_block. + Default: 1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + upsample_cfg (dict): The upsample config of the upsample module in + decoder. Default: dict(type='InterpConv'). If the size of + high-level feature map is the same as that of skip feature map + (low-level feature map from encoder), it does not need upsample the + high-level feature map and the upsample_cfg is None. + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. + """ + + def __init__(self, + conv_block, + in_channels, + skip_channels, + out_channels, + num_convs=2, + stride=1, + dilation=1, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + dcn=None, + plugins=None): + super().__init__() + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' 
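+        # The conv block fuses the encoder skip feature with the upsampled
+        # high-level feature; both carry ``skip_channels`` channels after
+        # upsampling, hence ``in_channels=2 * skip_channels`` below.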
+ + self.conv_block = conv_block( + in_channels=2 * skip_channels, + out_channels=out_channels, + num_convs=num_convs, + stride=stride, + dilation=dilation, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + dcn=None, + plugins=None) + if upsample_cfg is not None: + self.upsample = build_upsample_layer( + cfg=upsample_cfg, + in_channels=in_channels, + out_channels=skip_channels, + with_cp=with_cp, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + else: + self.upsample = ConvModule( + in_channels, + skip_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, skip, x): + """Forward function.""" + + x = self.upsample(x) + out = torch.cat([skip, x], dim=1) + out = self.conv_block(out) + + return out diff --git a/mmseg/models/utils/wrappers.py b/mmseg/models/utils/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..abbd0c029623b4f480a067e4b78adfec234ef8d0 --- /dev/null +++ b/mmseg/models/utils/wrappers.py @@ -0,0 +1,51 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +import torch.nn.functional as F + + +def resize(input, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None, + warning=True): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > output_h: + if ((output_h > 1 and output_w > 1 and input_h > 1 + and input_w > 1) and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1)): + warnings.warn( + f'When align_corners={align_corners}, ' + 'the output would more aligned if ' + f'input size {(input_h, input_w)} is `x+1` and ' + f'out size {(output_h, output_w)} is `nx+1`') + return F.interpolate(input, size, scale_factor, mode, align_corners) + + +class Upsample(nn.Module): + + def __init__(self, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None): + super().__init__() + self.size = size + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + if not self.size: + size = [int(t * self.scale_factor) for t in x.shape[-2:]] + else: + size = self.size + return resize(x, size, None, self.mode, self.align_corners) diff --git a/mmseg/registry/__init__.py b/mmseg/registry/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ee514d1a2a2bdd54a0a9b017ec227160ee502be5 --- /dev/null +++ b/mmseg/registry/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
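+# A minimal usage sketch (illustrative only; ``cfg`` must be a dict whose
+# ``type`` key names a class registered in the corresponding registry):
+#
+#   >>> from mmseg.registry import MODELS
+#   >>> model = MODELS.build(cfg)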
+from .registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, INFERENCERS, + LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS, MODELS, + OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS, + PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, + TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS, + WEIGHT_INITIALIZERS) + +__all__ = [ + 'HOOKS', 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', + 'WEIGHT_INITIALIZERS', 'OPTIMIZERS', 'OPTIM_WRAPPER_CONSTRUCTORS', + 'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', + 'VISBACKENDS', 'VISUALIZERS', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'LOOPS', + 'EVALUATOR', 'LOG_PROCESSORS', 'OPTIM_WRAPPERS', 'INFERENCERS' +] diff --git a/mmseg/registry/__pycache__/__init__.cpython-310.pyc b/mmseg/registry/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bef6b27c116181deaf72eefa029f7e1d586365b1 Binary files /dev/null and b/mmseg/registry/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmseg/registry/__pycache__/registry.cpython-310.pyc b/mmseg/registry/__pycache__/registry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc60aebbcabcffb389f79df38a0efe667b066c5f Binary files /dev/null and b/mmseg/registry/__pycache__/registry.cpython-310.pyc differ diff --git a/mmseg/registry/registry.py b/mmseg/registry/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..32684e758f98f72ddad82a035e500caca28c1dcc --- /dev/null +++ b/mmseg/registry/registry.py @@ -0,0 +1,121 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""MMSegmentation provides 17 registry nodes to support using modules across +projects. Each node is a child of the root registry in MMEngine. + +More details can be found at +https://mmengine.readthedocs.io/en/latest/tutorials/registry.html. 
+""" + +from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS +from mmengine.registry import DATASETS as MMENGINE_DATASETS +from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR +from mmengine.registry import HOOKS as MMENGINE_HOOKS +from mmengine.registry import INFERENCERS as MMENGINE_INFERENCERS +from mmengine.registry import LOG_PROCESSORS as MMENGINE_LOG_PROCESSORS +from mmengine.registry import LOOPS as MMENGINE_LOOPS +from mmengine.registry import METRICS as MMENGINE_METRICS +from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS +from mmengine.registry import MODELS as MMENGINE_MODELS +from mmengine.registry import \ + OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS +from mmengine.registry import OPTIM_WRAPPERS as MMENGINE_OPTIM_WRAPPERS +from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS +from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS +from mmengine.registry import \ + RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS +from mmengine.registry import RUNNERS as MMENGINE_RUNNERS +from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS +from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS +from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS +from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS +from mmengine.registry import \ + WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS +from mmengine.registry import Registry + +# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner` +RUNNERS = Registry('runner', parent=MMENGINE_RUNNERS) +# manage runner constructors that define how to initialize runners +RUNNER_CONSTRUCTORS = Registry( + 'runner constructor', parent=MMENGINE_RUNNER_CONSTRUCTORS) +# manage all kinds of loops like `EpochBasedTrainLoop` +LOOPS = Registry('loop', parent=MMENGINE_LOOPS) +# manage all kinds of hooks like `CheckpointHook` +HOOKS = Registry( + 'hook', parent=MMENGINE_HOOKS, locations=['mmseg.engine.hooks']) + +# manage data-related modules +DATASETS = Registry( + 'dataset', parent=MMENGINE_DATASETS, locations=['mmseg.datasets']) +DATA_SAMPLERS = Registry( + 'data sampler', + parent=MMENGINE_DATA_SAMPLERS, + locations=['mmseg.datasets.samplers']) +TRANSFORMS = Registry( + 'transform', + parent=MMENGINE_TRANSFORMS, + locations=['mmseg.datasets.transforms']) + +# mangage all kinds of modules inheriting `nn.Module` +MODELS = Registry('model', parent=MMENGINE_MODELS, locations=['mmseg.models']) +# mangage all kinds of model wrappers like 'MMDistributedDataParallel' +MODEL_WRAPPERS = Registry( + 'model_wrapper', + parent=MMENGINE_MODEL_WRAPPERS, + locations=['mmseg.models']) +# mangage all kinds of weight initialization modules like `Uniform` +WEIGHT_INITIALIZERS = Registry( + 'weight initializer', + parent=MMENGINE_WEIGHT_INITIALIZERS, + locations=['mmseg.models']) + +# mangage all kinds of optimizers like `SGD` and `Adam` +OPTIMIZERS = Registry( + 'optimizer', + parent=MMENGINE_OPTIMIZERS, + locations=['mmseg.engine.optimizers']) +# manage optimizer wrapper +OPTIM_WRAPPERS = Registry( + 'optim_wrapper', + parent=MMENGINE_OPTIM_WRAPPERS, + locations=['mmseg.engine.optimizers']) +# manage constructors that customize the optimization hyperparameters. 
+OPTIM_WRAPPER_CONSTRUCTORS = Registry( + 'optimizer wrapper constructor', + parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS, + locations=['mmseg.engine.optimizers']) +# mangage all kinds of parameter schedulers like `MultiStepLR` +PARAM_SCHEDULERS = Registry( + 'parameter scheduler', + parent=MMENGINE_PARAM_SCHEDULERS, + locations=['mmseg.engine.schedulers']) + +# manage all kinds of metrics +METRICS = Registry( + 'metric', parent=MMENGINE_METRICS, locations=['mmseg.evaluation']) +# manage evaluator +EVALUATOR = Registry( + 'evaluator', parent=MMENGINE_EVALUATOR, locations=['mmseg.evaluation']) + +# manage task-specific modules like ohem pixel sampler +TASK_UTILS = Registry( + 'task util', parent=MMENGINE_TASK_UTILS, locations=['mmseg.models']) + +# manage visualizer +VISUALIZERS = Registry( + 'visualizer', + parent=MMENGINE_VISUALIZERS, + locations=['mmseg.visualization']) +# manage visualizer backend +VISBACKENDS = Registry( + 'vis_backend', + parent=MMENGINE_VISBACKENDS, + locations=['mmseg.visualization']) + +# manage logprocessor +LOG_PROCESSORS = Registry( + 'log_processor', + parent=MMENGINE_LOG_PROCESSORS, + locations=['mmseg.visualization']) + +# manage inferencer +INFERENCERS = Registry('inferencer', parent=MMENGINE_INFERENCERS) diff --git a/mmseg/structures/__init__.py b/mmseg/structures/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..63d118dca3ebcff30ca241f9378475bcce072627 --- /dev/null +++ b/mmseg/structures/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .sampler import BasePixelSampler, OHEMPixelSampler, build_pixel_sampler +from .seg_data_sample import SegDataSample + +__all__ = [ + 'SegDataSample', 'BasePixelSampler', 'OHEMPixelSampler', + 'build_pixel_sampler' +] diff --git a/mmseg/structures/__pycache__/__init__.cpython-310.pyc b/mmseg/structures/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c5dc31c472e24c7b6427f1311e74727f931bbf2 Binary files /dev/null and b/mmseg/structures/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmseg/structures/__pycache__/seg_data_sample.cpython-310.pyc b/mmseg/structures/__pycache__/seg_data_sample.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9871ef3bce0081f59f11d5f23d88743a322bf6e Binary files /dev/null and b/mmseg/structures/__pycache__/seg_data_sample.cpython-310.pyc differ diff --git a/mmseg/structures/sampler/__init__.py b/mmseg/structures/sampler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..91d762d1b4552b391ece046fa3d094409011bcec --- /dev/null +++ b/mmseg/structures/sampler/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .base_pixel_sampler import BasePixelSampler +from .builder import build_pixel_sampler +from .ohem_pixel_sampler import OHEMPixelSampler + +__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler'] diff --git a/mmseg/structures/sampler/__pycache__/__init__.cpython-310.pyc b/mmseg/structures/sampler/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84fba2c21b86eaf22156368be01d8019737846b1 Binary files /dev/null and b/mmseg/structures/sampler/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmseg/structures/sampler/__pycache__/base_pixel_sampler.cpython-310.pyc b/mmseg/structures/sampler/__pycache__/base_pixel_sampler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f623945e798ec76de8a4b4df56ad733bbf266c5c Binary files /dev/null and b/mmseg/structures/sampler/__pycache__/base_pixel_sampler.cpython-310.pyc differ diff --git a/mmseg/structures/sampler/__pycache__/builder.cpython-310.pyc b/mmseg/structures/sampler/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba95c76b44df606805739c9c979222f423935de2 Binary files /dev/null and b/mmseg/structures/sampler/__pycache__/builder.cpython-310.pyc differ diff --git a/mmseg/structures/sampler/__pycache__/ohem_pixel_sampler.cpython-310.pyc b/mmseg/structures/sampler/__pycache__/ohem_pixel_sampler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7470471ea733248ad0176a3a9d675b2d970d3f57 Binary files /dev/null and b/mmseg/structures/sampler/__pycache__/ohem_pixel_sampler.cpython-310.pyc differ diff --git a/mmseg/structures/sampler/base_pixel_sampler.py b/mmseg/structures/sampler/base_pixel_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..03672cd478a2e464cc734ae92686c86f219da0a9 --- /dev/null +++ b/mmseg/structures/sampler/base_pixel_sampler.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod + + +class BasePixelSampler(metaclass=ABCMeta): + """Base class of pixel sampler.""" + + def __init__(self, **kwargs): + pass + + @abstractmethod + def sample(self, seg_logit, seg_label): + """Placeholder for sample function.""" diff --git a/mmseg/structures/sampler/builder.py b/mmseg/structures/sampler/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..48e14790264a3d4c4ff54d84e5bab67b1623a1df --- /dev/null +++ b/mmseg/structures/sampler/builder.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +from mmseg.registry import TASK_UTILS + +PIXEL_SAMPLERS = TASK_UTILS + + +def build_pixel_sampler(cfg, **default_args): + """Build pixel sampler for segmentation map.""" + warnings.warn( + '``build_pixel_sampler`` would be deprecated soon, please use ' + '``mmseg.registry.TASK_UTILS.build()`` ') + return TASK_UTILS.build(cfg, default_args=default_args) diff --git a/mmseg/structures/sampler/ohem_pixel_sampler.py b/mmseg/structures/sampler/ohem_pixel_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..a974273cab504be269e7f391e23a521b97bd8588 --- /dev/null +++ b/mmseg/structures/sampler/ohem_pixel_sampler.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
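+# A minimal usage sketch (the threshold and ``min_kept`` values are
+# illustrative; decode heads build this sampler from their ``sampler``
+# config field and pass themselves as ``context``):
+#
+#   sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000)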
+import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_pixel_sampler import BasePixelSampler +from .builder import PIXEL_SAMPLERS + + +@PIXEL_SAMPLERS.register_module() +class OHEMPixelSampler(BasePixelSampler): + """Online Hard Example Mining Sampler for segmentation. + + Args: + context (nn.Module): The context of sampler, subclass of + :obj:`BaseDecodeHead`. + thresh (float, optional): The threshold for hard example selection. + Below which, are prediction with low confidence. If not + specified, the hard examples will be pixels of top ``min_kept`` + loss. Default: None. + min_kept (int, optional): The minimum number of predictions to keep. + Default: 100000. + """ + + def __init__(self, context, thresh=None, min_kept=100000): + super().__init__() + self.context = context + assert min_kept > 1 + self.thresh = thresh + self.min_kept = min_kept + + def sample(self, seg_logit, seg_label): + """Sample pixels that have high loss or with low prediction confidence. + + Args: + seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W) + seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W) + + Returns: + torch.Tensor: segmentation weight, shape (N, H, W) + """ + with torch.no_grad(): + assert seg_logit.shape[2:] == seg_label.shape[2:] + assert seg_label.shape[1] == 1 + seg_label = seg_label.squeeze(1).long() + batch_kept = self.min_kept * seg_label.size(0) + valid_mask = seg_label != self.context.ignore_index + seg_weight = seg_logit.new_zeros(size=seg_label.size()) + valid_seg_weight = seg_weight[valid_mask] + if self.thresh is not None: + seg_prob = F.softmax(seg_logit, dim=1) + + tmp_seg_label = seg_label.clone().unsqueeze(1) + tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0 + seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1) + sort_prob, sort_indices = seg_prob[valid_mask].sort() + + if sort_prob.numel() > 0: + min_threshold = sort_prob[min(batch_kept, + sort_prob.numel() - 1)] + else: + min_threshold = 0.0 + threshold = max(min_threshold, self.thresh) + valid_seg_weight[seg_prob[valid_mask] < threshold] = 1. + else: + if not isinstance(self.context.loss_decode, nn.ModuleList): + losses_decode = [self.context.loss_decode] + else: + losses_decode = self.context.loss_decode + losses = 0.0 + for loss_module in losses_decode: + losses += loss_module( + seg_logit, + seg_label, + weight=None, + ignore_index=self.context.ignore_index, + reduction_override='none') + + # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa + _, sort_indices = losses[valid_mask].sort(descending=True) + valid_seg_weight[sort_indices[:batch_kept]] = 1. + + seg_weight[valid_mask] = valid_seg_weight + + return seg_weight diff --git a/mmseg/structures/seg_data_sample.py b/mmseg/structures/seg_data_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..ce68b5474330e2149d7d1c4de2d2406ae5b0345e --- /dev/null +++ b/mmseg/structures/seg_data_sample.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.structures import BaseDataElement, PixelData + + +class SegDataSample(BaseDataElement): + """A data structure interface of MMSegmentation. They are used as + interfaces between different components. + + The attributes in ``SegDataSample`` are divided into several parts: + + - ``gt_sem_seg``(PixelData): Ground truth of semantic segmentation. + - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation. 
+ - ``seg_logits``(PixelData): Predicted logits of semantic segmentation. + + Examples: + >>> import torch + >>> import numpy as np + >>> from mmengine.structures import PixelData + >>> from mmseg.structures import SegDataSample + + >>> data_sample = SegDataSample() + >>> img_meta = dict(img_shape=(4, 4, 3), + ... pad_shape=(4, 4, 3)) + >>> gt_segmentations = PixelData(metainfo=img_meta) + >>> gt_segmentations.data = torch.randint(0, 2, (1, 4, 4)) + >>> data_sample.gt_sem_seg = gt_segmentations + >>> assert 'img_shape' in data_sample.gt_sem_seg.metainfo_keys() + >>> data_sample.gt_sem_seg.shape + (4, 4) + >>> print(data_sample) + + ) at 0x1c2aae44d60> + + >>> data_sample = SegDataSample() + >>> gt_sem_seg_data = dict(sem_seg=torch.rand(1, 4, 4)) + >>> gt_sem_seg = PixelData(**gt_sem_seg_data) + >>> data_sample.gt_sem_seg = gt_sem_seg + >>> assert 'gt_sem_seg' in data_sample + >>> assert 'sem_seg' in data_sample.gt_sem_seg + """ + + @property + def gt_sem_seg(self) -> PixelData: + return self._gt_sem_seg + + @gt_sem_seg.setter + def gt_sem_seg(self, value: PixelData) -> None: + self.set_field(value, '_gt_sem_seg', dtype=PixelData) + + @gt_sem_seg.deleter + def gt_sem_seg(self) -> None: + del self._gt_sem_seg + + @property + def pred_sem_seg(self) -> PixelData: + return self._pred_sem_seg + + @pred_sem_seg.setter + def pred_sem_seg(self, value: PixelData) -> None: + self.set_field(value, '_pred_sem_seg', dtype=PixelData) + + @pred_sem_seg.deleter + def pred_sem_seg(self) -> None: + del self._pred_sem_seg + + @property + def seg_logits(self) -> PixelData: + return self._seg_logits + + @seg_logits.setter + def seg_logits(self, value: PixelData) -> None: + self.set_field(value, '_seg_logits', dtype=PixelData) + + @seg_logits.deleter + def seg_logits(self) -> None: + del self._seg_logits diff --git a/mmseg/utils/__init__.py b/mmseg/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb1436c1980f8258dae17aa1394170d2f91cb382 --- /dev/null +++ b/mmseg/utils/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
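+# A minimal usage sketch (``'cityscapes'`` is one of the aliases registered
+# in ``dataset_aliases``):
+#
+#   >>> from mmseg.utils import get_classes, get_palette
+#   >>> get_classes('cityscapes')   # 19 class names
+#   >>> get_palette('cityscapes')   # one color triplet per class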
+# yapf: disable +from .class_names import (ade_classes, ade_palette, cityscapes_classes, + cityscapes_palette, cocostuff_classes, + cocostuff_palette, dataset_aliases, get_classes, + get_palette, isaid_classes, isaid_palette, + loveda_classes, loveda_palette, potsdam_classes, + potsdam_palette, stare_classes, stare_palette, + synapse_classes, synapse_palette, vaihingen_classes, + vaihingen_palette, voc_classes, voc_palette) +# yapf: enable +from .collect_env import collect_env +from .io import datafrombytes +from .misc import add_prefix, stack_batch +from .set_env import register_all_modules +from .typing_utils import (ConfigType, ForwardResults, MultiConfig, + OptConfigType, OptMultiConfig, OptSampleList, + SampleList, TensorDict, TensorList) + +__all__ = [ + 'collect_env', 'register_all_modules', 'stack_batch', 'add_prefix', + 'ConfigType', 'OptConfigType', 'MultiConfig', 'OptMultiConfig', + 'SampleList', 'OptSampleList', 'TensorDict', 'TensorList', + 'ForwardResults', 'cityscapes_classes', 'ade_classes', 'voc_classes', + 'cocostuff_classes', 'loveda_classes', 'potsdam_classes', + 'vaihingen_classes', 'isaid_classes', 'stare_classes', + 'cityscapes_palette', 'ade_palette', 'voc_palette', 'cocostuff_palette', + 'loveda_palette', 'potsdam_palette', 'vaihingen_palette', 'isaid_palette', + 'stare_palette', 'dataset_aliases', 'get_classes', 'get_palette', + 'datafrombytes', 'synapse_palette', 'synapse_classes' +] diff --git a/mmseg/utils/__pycache__/__init__.cpython-310.pyc b/mmseg/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19c7a2d2dfd3b041c1e94db81a73d6c6dea7ff8d Binary files /dev/null and b/mmseg/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/mmseg/utils/__pycache__/class_names.cpython-310.pyc b/mmseg/utils/__pycache__/class_names.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d303cb78e176fc6d51f5a8fcafc15eeeebdcdb6 Binary files /dev/null and b/mmseg/utils/__pycache__/class_names.cpython-310.pyc differ diff --git a/mmseg/utils/__pycache__/collect_env.cpython-310.pyc b/mmseg/utils/__pycache__/collect_env.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8eee509643e14d1ed25160cd2e09de7bb3292844 Binary files /dev/null and b/mmseg/utils/__pycache__/collect_env.cpython-310.pyc differ diff --git a/mmseg/utils/__pycache__/io.cpython-310.pyc b/mmseg/utils/__pycache__/io.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50fbd78892b7403280f8c819fc27f20b74d8920f Binary files /dev/null and b/mmseg/utils/__pycache__/io.cpython-310.pyc differ diff --git a/mmseg/utils/__pycache__/misc.cpython-310.pyc b/mmseg/utils/__pycache__/misc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..164696fab81562d72fb081073cd7c6495b7f9afa Binary files /dev/null and b/mmseg/utils/__pycache__/misc.cpython-310.pyc differ diff --git a/mmseg/utils/__pycache__/set_env.cpython-310.pyc b/mmseg/utils/__pycache__/set_env.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6be5fd343260008a299452ab20f82aba52e76584 Binary files /dev/null and b/mmseg/utils/__pycache__/set_env.cpython-310.pyc differ diff --git a/mmseg/utils/__pycache__/typing_utils.cpython-310.pyc b/mmseg/utils/__pycache__/typing_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ee5a608c52aca39952967ae402be0d1feecebd1 Binary files /dev/null and 
b/mmseg/utils/__pycache__/typing_utils.cpython-310.pyc differ diff --git a/mmseg/utils/class_names.py b/mmseg/utils/class_names.py new file mode 100644 index 0000000000000000000000000000000000000000..961a08520d212bfd3f7aac551e0bd73c0dc150c4 --- /dev/null +++ b/mmseg/utils/class_names.py @@ -0,0 +1,473 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.utils import is_str + + +def cityscapes_classes(): + """Cityscapes class names for external use.""" + return [ + 'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', + 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', + 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', + 'bicycle' + ] + + +def ade_classes(): + """ADE20K class names for external use.""" + return [ + 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ', + 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', + 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', + 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', + 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', + 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', + 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', + 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', + 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', + 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', + 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', + 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', + 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', + 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver', + 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', + 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', + 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', + 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent', + 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', + 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', + 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', + 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen', + 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', + 'clock', 'flag' + ] + + +def voc_classes(): + """Pascal VOC class names for external use.""" + return [ + 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', + 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', + 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', + 'tvmonitor' + ] + + +def cocostuff_classes(): + """CocoStuff class names for external use.""" + return [ + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', + 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', + 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', + 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', + 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', + 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', + 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', + 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', + 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', + 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', + 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', + 'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet', + 'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile', + 'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain', + 'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble', + 'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower', + 'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel', + 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal', + 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', 'paper', + 'pavement', 'pillow', 'plant-other', 'plastic', 'platform', + 'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof', + 'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper', + 'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other', + 'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable', + 'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel', + 'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops', + 'window-blind', 'window-other', 'wood' + ] + + +def loveda_classes(): + """LoveDA class names for external use.""" + return [ + 'background', 'building', 'road', 'water', 'barren', 'forest', + 'agricultural' + ] + + +def potsdam_classes(): + """Potsdam class names for external use.""" + return [ + 'impervious_surface', 'building', 'low_vegetation', 'tree', 'car', + 'clutter' + ] + + +def vaihingen_classes(): + """Vaihingen class names for external use.""" + return [ + 'impervious_surface', 'building', 'low_vegetation', 'tree', 'car', + 'clutter' + ] + + +def isaid_classes(): + """iSAID class names for external use.""" + return [ + 'background', 'ship', 'store_tank', 'baseball_diamond', 'tennis_court', + 'basketball_court', 'Ground_Track_Field', 'Bridge', 'Large_Vehicle', + 'Small_Vehicle', 'Helicopter', 'Swimming_pool', 'Roundabout', + 'Soccer_ball_field', 'plane', 'Harbor' + ] + + +def stare_classes(): + """stare class names for external use.""" + return ['background', 'vessel'] + + +def mapillary_v1_classes(): + """mapillary_v1 class names for external use.""" + return [ + 'Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail', 'Barrier', + 'Wall', 'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Parking', + 'Pedestrian Area', 'Rail Track', 'Road', 'Service Lane', 'Sidewalk', + 'Bridge', 'Building', 'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist', + 'Other Rider', 'Lane Marking - Crosswalk', 'Lane Marking - General', + 'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', 'Water', + 'Banner', 'Bench', 'Bike Rack', 'Billboard', 'Catch Basin', + 'CCTV Camera', 'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole', + 'Phone Booth', 'Pothole', 'Street Light', 'Pole', 'Traffic Sign Frame', + 'Utility Pole', 'Traffic Light', 'Traffic Sign (Back)', + 'Traffic Sign (Front)', 'Trash Can', 'Bicycle', 'Boat', 'Bus', 'Car', + 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle', 'Trailer', + 'Truck', 'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled' + ] + + +def mapillary_v1_palette(): + """mapillary_v1_ palette for external use.""" + return [[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153], + [180, 165, 180], [90, 120, 150], [102, 102, 156], [128, 64, 255], + [140, 140, 200], [170, 170, 170], [250, 170, 160], [96, 96, 96], + [230, 150, 140], [128, 64, 128], [110, 110, 110], [244, 35, 232], + [150, 100, 100], [70, 70, 70], [150, 120, 90], 
[220, 20, 60], + [255, 0, 0], [255, 0, 100], [255, 0, 200], [200, 128, 128], + [255, 255, 255], [64, 170, 64], [230, 160, 50], [70, 130, 180], + [190, 255, 255], [152, 251, 152], [107, 142, 35], [0, 170, 30], + [255, 255, 128], [250, 0, 30], [100, 140, 180], [220, 220, 220], + [220, 128, 128], [222, 40, 40], [100, 170, 30], [40, 40, 40], + [33, 33, 33], [100, 128, 160], [142, 0, 0], [70, 100, 150], + [210, 170, 100], [153, 153, 153], [128, 128, 128], [0, 0, 80], + [250, 170, 30], [192, 192, 192], [220, 220, 0], [140, 140, 20], + [119, 11, 32], [150, 0, 255], [0, 60, 100], [0, 0, 142], + [0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110], + [0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10, 10], [0, 0, 0]] + + +def mapillary_v2_classes(): + """mapillary_v2 class names for external use.""" + return [ + 'Bird', 'Ground Animal', 'Ambiguous Barrier', 'Concrete Block', 'Curb', + 'Fence', 'Guard Rail', 'Barrier', 'Road Median', 'Road Side', + 'Lane Separator', 'Temporary Barrier', 'Wall', 'Bike Lane', + 'Crosswalk - Plain', 'Curb Cut', 'Driveway', 'Parking', + 'Parking Aisle', 'Pedestrian Area', 'Rail Track', 'Road', + 'Road Shoulder', 'Service Lane', 'Sidewalk', 'Traffic Island', + 'Bridge', 'Building', 'Garage', 'Tunnel', 'Person', 'Person Group', + 'Bicyclist', 'Motorcyclist', 'Other Rider', + 'Lane Marking - Dashed Line', 'Lane Marking - Straight Line', + 'Lane Marking - Zigzag Line', 'Lane Marking - Ambiguous', + 'Lane Marking - Arrow (Left)', 'Lane Marking - Arrow (Other)', + 'Lane Marking - Arrow (Right)', + 'Lane Marking - Arrow (Split Left or Straight)', + 'Lane Marking - Arrow (Split Right or Straight)', + 'Lane Marking - Arrow (Straight)', 'Lane Marking - Crosswalk', + 'Lane Marking - Give Way (Row)', 'Lane Marking - Give Way (Single)', + 'Lane Marking - Hatched (Chevron)', + 'Lane Marking - Hatched (Diagonal)', 'Lane Marking - Other', + 'Lane Marking - Stop Line', 'Lane Marking - Symbol (Bicycle)', + 'Lane Marking - Symbol (Other)', 'Lane Marking - Text', + 'Lane Marking (only) - Dashed Line', 'Lane Marking (only) - Crosswalk', + 'Lane Marking (only) - Other', 'Lane Marking (only) - Test', + 'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', 'Water', + 'Banner', 'Bench', 'Bike Rack', 'Catch Basin', 'CCTV Camera', + 'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole', 'Parking Meter', + 'Phone Booth', 'Pothole', 'Signage - Advertisement', + 'Signage - Ambiguous', 'Signage - Back', 'Signage - Information', + 'Signage - Other', 'Signage - Store', 'Street Light', 'Pole', + 'Pole Group', 'Traffic Sign Frame', 'Utility Pole', 'Traffic Cone', + 'Traffic Light - General (Single)', 'Traffic Light - Pedestrians', + 'Traffic Light - General (Upright)', + 'Traffic Light - General (Horizontal)', 'Traffic Light - Cyclists', + 'Traffic Light - Other', 'Traffic Sign - Ambiguous', + 'Traffic Sign (Back)', 'Traffic Sign - Direction (Back)', + 'Traffic Sign - Direction (Front)', 'Traffic Sign (Front)', + 'Traffic Sign - Parking', 'Traffic Sign - Temporary (Back)', + 'Traffic Sign - Temporary (Front)', 'Trash Can', 'Bicycle', 'Boat', + 'Bus', 'Car', 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle', + 'Trailer', 'Truck', 'Vehicle Group', 'Wheeled Slow', 'Water Valve', + 'Car Mount', 'Dynamic', 'Ego Vehicle', 'Ground', 'Static', 'Unlabeled' + ] + + +def mapillary_v2_palette(): + """mapillary_v2_ palette for external use.""" + return [[165, 42, 42], [0, 192, 0], [250, 170, 31], [250, 170, 32], + [196, 196, 196], [190, 153, 153], [180, 165, 180], [90, 120, 150], + [250, 170, 33], 
[250, 170, 34], [128, 128, 128], [250, 170, 35], + [102, 102, 156], [128, 64, 255], [140, 140, 200], [170, 170, 170], + [250, 170, 36], [250, 170, 160], [250, 170, 37], [96, 96, 96], + [230, 150, 140], [128, 64, 128], [110, 110, 110], [110, 110, 110], + [244, 35, 232], [128, 196, 128], [150, 100, 100], [70, 70, 70], + [150, 150, 150], [150, 120, 90], [220, 20, 60], [220, 20, 60], + [255, 0, 0], [255, 0, 100], [255, 0, 200], [255, 255, 255], + [255, 255, 255], [250, 170, 29], [250, 170, 28], [250, 170, 26], + [250, 170, 25], [250, 170, 24], [250, 170, 22], [250, 170, 21], + [250, 170, 20], [255, 255, 255], [250, 170, 19], [250, 170, 18], + [250, 170, 12], [250, 170, 11], [255, 255, 255], [255, 255, 255], + [250, 170, 16], [250, 170, 15], [250, 170, 15], [255, 255, 255], + [255, 255, 255], [255, 255, 255], [255, 255, 255], [64, 170, 64], + [230, 160, 50], [70, 130, 180], [190, 255, 255], [152, 251, 152], + [107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30], + [100, 140, 180], [220, 128, 128], [222, 40, 40], [100, 170, 30], + [40, 40, 40], [33, 33, 33], [100, 128, 160], [20, 20, 255], + [142, 0, 0], [70, 100, 150], [250, 171, 30], [250, 172, 30], + [250, 173, 30], [250, 174, 30], [250, 175, 30], [250, 176, 30], + [210, 170, 100], [153, 153, 153], [153, 153, 153], [128, 128, 128], + [0, 0, 80], [210, 60, 60], [250, 170, 30], [250, 170, 30], + [250, 170, 30], [250, 170, 30], [250, 170, 30], [250, 170, 30], + [192, 192, 192], [192, 192, 192], [192, 192, 192], [220, 220, 0], + [220, 220, 0], [0, 0, 196], [192, 192, 192], [220, 220, 0], + [140, 140, 20], [119, 11, 32], [150, 0, 255], [0, 60, 100], + [0, 0, 142], [0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64], + [0, 0, 110], [0, 0, 70], [0, 0, 142], [0, 0, 192], [170, 170, 170], + [32, 32, 32], [111, 74, 0], [120, 10, 10], [81, 0, 81], + [111, 111, 0], [0, 0, 0]] + + +def cityscapes_palette(): + """Cityscapes palette for external use.""" + return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], + [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], + [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60], + [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], + [0, 0, 230], [119, 11, 32]] + + +def ade_palette(): + """ADE20K palette for external use.""" + return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], + [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], + [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], + [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], + [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], + [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], + [255, 0, 245], [255, 
0, 102], [255, 173, 0], [255, 0, 20], + [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], + [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], + [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], + [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], + [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], + [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], + [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], + [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], + [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], + [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], + [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], + [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], + [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], + [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], + [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], + [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], + [102, 255, 0], [92, 0, 255]] + + +def voc_palette(): + """Pascal VOC palette for external use.""" + return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], + [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], + [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], + [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], + [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] + + +def cocostuff_palette(): + """CocoStuff palette for external use.""" + return [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], + [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64], + [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], + [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192], + [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192], + [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128], + [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160], [0, 32, 0], + [0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0], + [192, 128, 32], [128, 96, 128], [0, 0, 128], [64, 0, 32], + [0, 224, 128], [128, 0, 0], [192, 0, 160], [0, 96, 128], + [128, 128, 128], [64, 0, 160], [128, 224, 128], [128, 128, 64], + [192, 0, 32], [128, 96, 0], [128, 0, 192], [0, 128, 32], + [64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0], + [0, 128, 192], [0, 128, 160], [192, 224, 0], [0, 128, 64], + [128, 128, 32], [192, 32, 128], [0, 64, 192], [0, 0, 32], + [64, 160, 128], [128, 64, 64], [128, 0, 160], [64, 32, 128], + [128, 192, 192], [0, 0, 160], [192, 160, 128], [128, 192, 0], + [128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96], + [64, 160, 0], [0, 64, 0], [192, 128, 224], [64, 32, 0], + [0, 192, 128], [64, 128, 224], [192, 160, 0], [0, 192, 0], + [192, 128, 96], [192, 96, 128], [0, 64, 128], [64, 0, 96], + [64, 224, 128], [128, 64, 0], [192, 0, 224], [64, 96, 128], + [128, 192, 128], [64, 0, 224], [192, 224, 128], [128, 192, 64], + [192, 0, 96], [192, 96, 0], [128, 64, 192], [0, 128, 96], + [0, 224, 0], [64, 64, 64], [128, 128, 224], [0, 96, 0], + [64, 192, 192], [0, 128, 224], [128, 224, 0], [64, 192, 64], + [128, 128, 96], [128, 32, 128], [64, 0, 192], [0, 64, 96], + [0, 160, 128], [192, 0, 64], [128, 64, 224], [0, 32, 128], + [192, 128, 192], [0, 64, 224], [128, 160, 128], [192, 128, 0], + [128, 64, 32], [128, 32, 64], [192, 0, 128], [64, 192, 32], + [0, 160, 64], [64, 0, 0], [192, 192, 160], [0, 32, 64], + [64, 128, 128], [64, 192, 160], [128, 160, 64], [64, 128, 0], + [192, 192, 32], [128, 96, 192], [64, 0, 128], [64, 64, 32], + [0, 224, 
192], [192, 0, 0], [192, 64, 160], [0, 96, 192], + [192, 128, 128], [64, 64, 160], [128, 224, 192], [192, 128, 64], + [192, 64, 32], [128, 96, 64], [192, 0, 192], [0, 192, 32], + [64, 224, 64], [64, 0, 64], [128, 192, 160], [64, 96, 64], + [64, 128, 192], [0, 192, 160], [192, 224, 64], [64, 128, 64], + [128, 192, 32], [192, 32, 192], [64, 64, 192], [0, 64, 32], + [64, 160, 192], [192, 64, 64], [128, 64, 160], [64, 32, 192], + [192, 192, 192], [0, 64, 160], [192, 160, 192], [192, 192, 0], + [128, 64, 96], [192, 32, 64], [192, 64, 128], [64, 192, 96], + [64, 160, 64], [64, 64, 0]] + + +def loveda_palette(): + """LoveDA palette for external use.""" + return [[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255], + [159, 129, 183], [0, 255, 0], [255, 195, 128]] + + +def potsdam_palette(): + """Potsdam palette for external use.""" + return [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], + [255, 255, 0], [255, 0, 0]] + + +def vaihingen_palette(): + """Vaihingen palette for external use.""" + return [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], + [255, 255, 0], [255, 0, 0]] + + +def isaid_palette(): + """iSAID palette for external use.""" + return [[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127], + [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, + 127], [0, 0, 127], + [0, 0, 191], [0, 0, 255], [0, 191, 127], [0, 127, 191], + [0, 127, 255], [0, 100, 155]] + + +def stare_palette(): + """STARE palette for external use.""" + return [[120, 120, 120], [6, 230, 230]] + + +def synapse_palette(): + """Synapse palette for external use.""" + return [[0, 0, 0], [0, 0, 255], [0, 255, 0], [255, 0, 0], [0, 255, 255], + [255, 0, 255], [255, 255, 0], [60, 255, 255], [240, 240, 240]] + + +def synapse_classes(): + """Synapse class names for external use.""" + return [ + 'background', 'aorta', 'gallbladder', 'left_kidney', 'right_kidney', + 'liver', 'pancreas', 'spleen', 'stomach' + ] + + +def lip_classes(): + """LIP class names for external use.""" + return [ + 'background', 'hat', 'hair', 'glove', 'sunglasses', 'upperclothes', + 'dress', 'coat', 'socks', 'pants', 'jumpsuits', 'scarf', 'skirt', + 'face', 'leftArm', 'rightArm', 'leftLeg', 'rightLeg', 'leftShoe', + 'rightShoe' + ] + + +def lip_palette(): + """LIP palette for external use.""" + return [ + 'Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'UpperClothes', + 'Dress', 'Coat', 'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt', + 'Face', 'Left-arm', 'Right-arm', 'Left-leg', 'Right-leg', 'Left-shoe', + 'Right-shoe' + ] + + +dataset_aliases = { + 'cityscapes': ['cityscapes'], + 'ade': ['ade', 'ade20k'], + 'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug'], + 'loveda': ['loveda'], + 'potsdam': ['potsdam'], + 'vaihingen': ['vaihingen'], + 'cocostuff': [ + 'cocostuff', 'cocostuff10k', 'cocostuff164k', 'coco-stuff', + 'coco-stuff10k', 'coco-stuff164k', 'coco_stuff', 'coco_stuff10k', + 'coco_stuff164k' + ], + 'isaid': ['isaid', 'iSAID'], + 'stare': ['stare', 'STARE'], + 'lip': ['LIP', 'lip'], + 'mapillary_v1': ['mapillary_v1'], + 'mapillary_v2': ['mapillary_v2'] +} + + +def get_classes(dataset): + """Get class names of a dataset.""" + alias2name = {} + for name, aliases in dataset_aliases.items(): + for alias in aliases: + alias2name[alias] = name + + if is_str(dataset): + if dataset in alias2name: + labels = eval(alias2name[dataset] + '_classes()') + else: + raise ValueError(f'Unrecognized dataset: {dataset}') + else: + raise TypeError(f'dataset must a str, but got {type(dataset)}') + return labels + + +def 
get_palette(dataset): + """Get class palette (RGB) of a dataset.""" + alias2name = {} + for name, aliases in dataset_aliases.items(): + for alias in aliases: + alias2name[alias] = name + + if is_str(dataset): + if dataset in alias2name: + labels = eval(alias2name[dataset] + '_palette()') + else: + raise ValueError(f'Unrecognized dataset: {dataset}') + else: + raise TypeError(f'dataset must a str, but got {type(dataset)}') + return labels diff --git a/mmseg/utils/collect_env.py b/mmseg/utils/collect_env.py new file mode 100644 index 0000000000000000000000000000000000000000..d5d6ea290283e3af2f29475f82d225072cf39d99 --- /dev/null +++ b/mmseg/utils/collect_env.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.utils import get_git_hash +from mmengine.utils.dl_utils import collect_env as collect_base_env + +import mmseg + + +def collect_env(): + """Collect the information of the running environments.""" + env_info = collect_base_env() + env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}' + + return env_info + + +if __name__ == '__main__': + for name, val in collect_env().items(): + print(f'{name}: {val}') diff --git a/mmseg/utils/io.py b/mmseg/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..d03517401c5cef499bb5b11f04dfb3fc7b2a8d30 --- /dev/null +++ b/mmseg/utils/io.py @@ -0,0 +1,38 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import gzip +import io +import pickle + +import numpy as np + + +def datafrombytes(content: bytes, backend: str = 'numpy') -> np.ndarray: + """Data decoding from bytes. + + Args: + content (bytes): The data bytes got from files or other streams. + backend (str): The data decoding backend type. Options are 'numpy', + 'nifti' and 'pickle'. Defaults to 'numpy'. + + Returns: + numpy.ndarray: Loaded data array. + """ + if backend == 'pickle': + data = pickle.loads(content) + else: + with io.BytesIO(content) as f: + if backend == 'nifti': + f = gzip.open(f) + try: + from nibabel import FileHolder, Nifti1Image + except ImportError: + print('nifti files io depends on nibabel, please run' + '`pip install nibabel` to install it') + fh = FileHolder(fileobj=f) + data = Nifti1Image.from_file_map({'header': fh, 'image': fh}) + data = Nifti1Image.from_bytes(data.to_bytes()).get_fdata() + elif backend == 'numpy': + data = np.load(f) + else: + raise ValueError + return data diff --git a/mmseg/utils/misc.py b/mmseg/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..0a561732e9a0bd07b5065aa3ec96e8070117f53d --- /dev/null +++ b/mmseg/utils/misc.py @@ -0,0 +1,118 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F + +from .typing_utils import SampleList + + +def add_prefix(inputs, prefix): + """Add prefix for dict. + + Args: + inputs (dict): The input dict with str keys. + prefix (str): The prefix to add. + + Returns: + + dict: The dict with keys updated with ``prefix``. 
+ """ + + outputs = dict() + for name, value in inputs.items(): + outputs[f'{prefix}.{name}'] = value + + return outputs + + +def stack_batch(inputs: List[torch.Tensor], + data_samples: Optional[SampleList] = None, + size: Optional[tuple] = None, + size_divisor: Optional[int] = None, + pad_val: Union[int, float] = 0, + seg_pad_val: Union[int, float] = 255) -> torch.Tensor: + """Stack multiple inputs to form a batch and pad the images and gt_sem_segs + to the max shape use the right bottom padding mode. + + Args: + inputs (List[Tensor]): The input multiple tensors. each is a + CHW 3D-tensor. + data_samples (list[:obj:`SegDataSample`]): The list of data samples. + It usually includes information such as `gt_sem_seg`. + size (tuple, optional): Fixed padding size. + size_divisor (int, optional): The divisor of padded size. + pad_val (int, float): The padding value. Defaults to 0 + seg_pad_val (int, float): The padding value. Defaults to 255 + + Returns: + Tensor: The 4D-tensor. + List[:obj:`SegDataSample`]: After the padding of the gt_seg_map. + """ + assert isinstance(inputs, list), \ + f'Expected input type to be list, but got {type(inputs)}' + assert len({tensor.ndim for tensor in inputs}) == 1, \ + f'Expected the dimensions of all inputs must be the same, ' \ + f'but got {[tensor.ndim for tensor in inputs]}' + assert inputs[0].ndim == 3, f'Expected tensor dimension to be 3, ' \ + f'but got {inputs[0].ndim}' + assert len({tensor.shape[0] for tensor in inputs}) == 1, \ + f'Expected the channels of all inputs must be the same, ' \ + f'but got {[tensor.shape[0] for tensor in inputs]}' + + # only one of size and size_divisor should be valid + assert (size is not None) ^ (size_divisor is not None), \ + 'only one of size and size_divisor should be valid' + + padded_inputs = [] + padded_samples = [] + inputs_sizes = [(img.shape[-2], img.shape[-1]) for img in inputs] + max_size = np.stack(inputs_sizes).max(0) + if size_divisor is not None and size_divisor > 1: + # the last two dims are H,W, both subject to divisibility requirement + max_size = (max_size + + (size_divisor - 1)) // size_divisor * size_divisor + + for i in range(len(inputs)): + tensor = inputs[i] + if size is not None: + width = max(size[-1] - tensor.shape[-1], 0) + height = max(size[-2] - tensor.shape[-2], 0) + # (padding_left, padding_right, padding_top, padding_bottom) + padding_size = (0, width, 0, height) + elif size_divisor is not None: + width = max(max_size[-1] - tensor.shape[-1], 0) + height = max(max_size[-2] - tensor.shape[-2], 0) + padding_size = (0, width, 0, height) + else: + padding_size = [0, 0, 0, 0] + + # pad img + pad_img = F.pad(tensor, padding_size, value=pad_val) + padded_inputs.append(pad_img) + # pad gt_sem_seg + if data_samples is not None: + data_sample = data_samples[i] + gt_sem_seg = data_sample.gt_sem_seg.data + del data_sample.gt_sem_seg.data + data_sample.gt_sem_seg.data = F.pad( + gt_sem_seg, padding_size, value=seg_pad_val) + if 'gt_edge_map' in data_sample: + gt_edge_map = data_sample.gt_edge_map.data + del data_sample.gt_edge_map.data + data_sample.gt_edge_map.data = F.pad( + gt_edge_map, padding_size, value=seg_pad_val) + data_sample.set_metainfo({ + 'img_shape': tensor.shape[-2:], + 'pad_shape': data_sample.gt_sem_seg.shape, + 'padding_size': padding_size + }) + padded_samples.append(data_sample) + else: + padded_samples.append( + dict( + img_padding_size=padding_size, + pad_shape=pad_img.shape[-2:])) + + return torch.stack(padded_inputs, dim=0), padded_samples diff --git a/mmseg/utils/set_env.py 
b/mmseg/utils/set_env.py new file mode 100644 index 0000000000000000000000000000000000000000..c948950d62a7463295c1055a27a9a0ce881d9fad --- /dev/null +++ b/mmseg/utils/set_env.py @@ -0,0 +1,40 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import datetime
+import warnings
+
+from mmengine import DefaultScope
+
+
+def register_all_modules(init_default_scope: bool = True) -> None:
+    """Register all modules in mmseg into the registries.
+
+    Args:
+        init_default_scope (bool): Whether to initialize the mmseg default scope.
+            When `init_default_scope=True`, the global default scope will be
+            set to `mmseg`, and all registries will build modules from mmseg's
+            registry node. To understand more about the registry, please refer
+            to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
+            Defaults to True.
+    """  # noqa
+    import mmseg.datasets  # noqa: F401,F403
+    import mmseg.engine  # noqa: F401,F403
+    import mmseg.evaluation  # noqa: F401,F403
+    import mmseg.models  # noqa: F401,F403
+    import mmseg.structures  # noqa: F401,F403
+
+    if init_default_scope:
+        never_created = DefaultScope.get_current_instance() is None \
+            or not DefaultScope.check_instance_created('mmseg')
+        if never_created:
+            DefaultScope.get_instance('mmseg', scope_name='mmseg')
+            return
+        current_scope = DefaultScope.get_current_instance()
+        if current_scope.scope_name != 'mmseg':
+            warnings.warn('The current default scope '
+                          f'"{current_scope.scope_name}" is not "mmseg", '
+                          '`register_all_modules` will force the current '
+                          'default scope to be "mmseg". If this is not '
+                          'expected, please set `init_default_scope=False`.')
+            # avoid name conflict
+            new_instance_name = f'mmseg-{datetime.datetime.now()}'
+            DefaultScope.get_instance(new_instance_name, scope_name='mmseg')
diff --git a/mmseg/utils/typing_utils.py b/mmseg/utils/typing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fba7d3b92bba8301171d2a0fffadfabfcd112976 --- /dev/null +++ b/mmseg/utils/typing_utils.py @@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Collecting some commonly used type hints in mmseg."""
+from typing import Dict, List, Optional, Sequence, Tuple, Union
+
+import torch
+from mmengine.config import ConfigDict
+
+from mmseg.structures import SegDataSample
+
+# Type hint of config data
+ConfigType = Union[ConfigDict, dict]
+OptConfigType = Optional[ConfigType]
+# Type hint of one or more config data
+MultiConfig = Union[ConfigType, Sequence[ConfigType]]
+OptMultiConfig = Optional[MultiConfig]
+
+SampleList = Sequence[SegDataSample]
+OptSampleList = Optional[SampleList]
+
+# Type hint of Tensor
+TensorDict = Dict[str, torch.Tensor]
+TensorList = Sequence[torch.Tensor]
+
+ForwardResults = Union[Dict[str, torch.Tensor], List[SegDataSample],
+                       Tuple[torch.Tensor], torch.Tensor]
diff --git a/mmseg/version.py b/mmseg/version.py new file mode 100644 index 0000000000000000000000000000000000000000..ef8e391a299c2d44ed39b7c361293ad946e8f715 --- /dev/null +++ b/mmseg/version.py @@ -0,0 +1,18 @@
+# Copyright (c) Open-MMLab. All rights reserved.
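A minimal sketch of how downstream code is expected to call ``register_all_modules`` before building mmseg components (the assertion assumes no other default scope has been created in the process):

from mmengine import DefaultScope

from mmseg.utils import register_all_modules

# populate mmseg's registries and set 'mmseg' as the global default scope
register_all_modules(init_default_scope=True)
assert DefaultScope.get_current_instance().scope_name == 'mmseg'

The aliases in typing_utils.py (``SampleList``, ``OptConfigType``, ...) are ordinary typing constructs and can be used directly in user annotations.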
+ +__version__ = '1.0.0rc6' + + +def parse_version_info(version_str): + version_info = [] + for x in version_str.split('.'): + if x.isdigit(): + version_info.append(int(x)) + elif x.find('rc') != -1: + patch_version = x.split('rc') + version_info.append(int(patch_version[0])) + version_info.append(f'rc{patch_version[1]}') + return tuple(version_info) + + +version_info = parse_version_info(__version__) diff --git a/mmseg/visualization/__init__.py b/mmseg/visualization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8cbb211e5243aafb4ab3d91f6a6f7ce0735b13a9 --- /dev/null +++ b/mmseg/visualization/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .local_visualizer import SegLocalVisualizer + +__all__ = ['SegLocalVisualizer'] diff --git a/mmseg/visualization/local_visualizer.py b/mmseg/visualization/local_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..504004dfcb8a64288b02188cd120e43d2597e27e --- /dev/null +++ b/mmseg/visualization/local_visualizer.py @@ -0,0 +1,231 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional + +import mmcv +import numpy as np +from mmengine.dist import master_only +from mmengine.structures import PixelData +from mmengine.visualization import Visualizer + +from mmseg.registry import VISUALIZERS +from mmseg.structures import SegDataSample +from mmseg.utils import get_classes, get_palette + + +@VISUALIZERS.register_module() +class SegLocalVisualizer(Visualizer): + """Local Visualizer. + + Args: + name (str): Name of the instance. Defaults to 'visualizer'. + image (np.ndarray, optional): the origin image to draw. The format + should be RGB. Defaults to None. + vis_backends (list, optional): Visual backend config list. + Defaults to None. + save_dir (str, optional): Save file dir for all storage backends. + If it is None, the backend storage will not save any data. + classes (list, optional): Input classes for result rendering, as the + prediction of segmentation model is a segment map with label + indices, `classes` is a list which includes items responding to the + label indices. If classes is not defined, visualizer will take + `cityscapes` classes by default. Defaults to None. + palette (list, optional): Input palette for result rendering, which is + a list of color palette responding to the classes. Defaults to None. + dataset_name (str, optional): `Dataset name or alias `_ + visulizer will use the meta information of the dataset i.e. classes + and palette, but the `classes` and `palette` have higher priority. + Defaults to None. + alpha (int, float): The transparency of segmentation mask. + Defaults to 0.8. + + Examples: + >>> import numpy as np + >>> import torch + >>> from mmengine.structures import PixelData + >>> from mmseg.data import SegDataSample + >>> from mmseg.engine.visualization import SegLocalVisualizer + + >>> seg_local_visualizer = SegLocalVisualizer() + >>> image = np.random.randint(0, 256, + ... size=(10, 12, 3)).astype('uint8') + >>> gt_sem_seg_data = dict(data=torch.randint(0, 2, (1, 10, 12))) + >>> gt_sem_seg = PixelData(**gt_sem_seg_data) + >>> gt_seg_data_sample = SegDataSample() + >>> gt_seg_data_sample.gt_sem_seg = gt_sem_seg + >>> seg_local_visualizer.dataset_meta = dict( + >>> classes=('background', 'foreground'), + >>> palette=[[120, 120, 120], [6, 230, 230]]) + >>> seg_local_visualizer.add_datasample('visualizer_example', + ... 
image, gt_seg_data_sample) + >>> seg_local_visualizer.add_datasample( + ... 'visualizer_example', image, + ... gt_seg_data_sample, show=True) + """ # noqa + + def __init__(self, + name: str = 'visualizer', + image: Optional[np.ndarray] = None, + vis_backends: Optional[Dict] = None, + save_dir: Optional[str] = None, + classes: Optional[List] = None, + palette: Optional[List] = None, + dataset_name: Optional[str] = None, + alpha: float = 0.8, + **kwargs): + super().__init__(name, image, vis_backends, save_dir, **kwargs) + self.alpha: float = alpha + self.set_dataset_meta(palette, classes, dataset_name) + + def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData, + classes: Optional[List], + palette: Optional[List]) -> np.ndarray: + """Draw semantic seg of GT or prediction. + + Args: + image (np.ndarray): The image to draw. + sem_seg (:obj:`PixelData`): Data structure for pixel-level + annotations or predictions. + classes (list, optional): Input classes for result rendering, as + the prediction of segmentation model is a segment map with + label indices, `classes` is a list which includes items + responding to the label indices. If classes is not defined, + visualizer will take `cityscapes` classes by default. + Defaults to None. + palette (list, optional): Input palette for result rendering, which + is a list of color palette responding to the classes. + Defaults to None. + + Returns: + np.ndarray: the drawn image which channel is RGB. + """ + num_classes = len(classes) + + sem_seg = sem_seg.cpu().data + ids = np.unique(sem_seg)[::-1] + legal_indices = ids < num_classes + ids = ids[legal_indices] + labels = np.array(ids, dtype=np.int64) + + colors = [palette[label] for label in labels] + + self.set_image(image) + + # draw semantic masks + for label, color in zip(labels, colors): + self.draw_binary_masks( + sem_seg == label, colors=[color], alphas=self.alpha) + + return self.get_image() + + def set_dataset_meta(self, + classes: Optional[List] = None, + palette: Optional[List] = None, + dataset_name: Optional[str] = None) -> None: + """Set meta information to visualizer. + + Args: + classes (list, optional): Input classes for result rendering, as + the prediction of segmentation model is a segment map with + label indices, `classes` is a list which includes items + responding to the label indices. If classes is not defined, + visualizer will take `cityscapes` classes by default. + Defaults to None. + palette (list, optional): Input palette for result rendering, which + is a list of color palette responding to the classes. + Defaults to None. + dataset_name (str, optional): `Dataset name or alias `_ + visulizer will use the meta information of the dataset i.e. + classes and palette, but the `classes` and `palette` have + higher priority. Defaults to None. + """ # noqa + # Set default value. When calling + # `SegLocalVisualizer().dataset_meta=xxx`, + # it will override the default value. + if dataset_name is None: + dataset_name = 'cityscapes' + classes = classes if classes else get_classes(dataset_name) + palette = palette if palette else get_palette(dataset_name) + assert len(classes) == len( + palette), 'The length of classes should be equal to palette' + self.dataset_meta: dict = {'classes': classes, 'palette': palette} + + @master_only + def add_datasample( + self, + name: str, + image: np.ndarray, + data_sample: Optional[SegDataSample] = None, + draw_gt: bool = True, + draw_pred: bool = True, + show: bool = False, + wait_time: float = 0, + # TODO: Supported in mmengine's Viusalizer. 
+ out_file: Optional[str] = None, + step: int = 0) -> None: + """Draw datasample and save to all backends. + + - If GT and prediction are plotted at the same time, they are + displayed in a stitched image where the left image is the + ground truth and the right image is the prediction. + - If ``show`` is True, all storage backends are ignored, and + the images will be displayed in a local window. + - If ``out_file`` is specified, the drawn image will be + saved to ``out_file``. it is usually used when the display + is not available. + + Args: + name (str): The image identifier. + image (np.ndarray): The image to draw. + gt_sample (:obj:`SegDataSample`, optional): GT SegDataSample. + Defaults to None. + pred_sample (:obj:`SegDataSample`, optional): Prediction + SegDataSample. Defaults to None. + draw_gt (bool): Whether to draw GT SegDataSample. Default to True. + draw_pred (bool): Whether to draw Prediction SegDataSample. + Defaults to True. + show (bool): Whether to display the drawn image. Default to False. + wait_time (float): The interval of show (s). Defaults to 0. + out_file (str): Path to output file. Defaults to None. + step (int): Global step value to record. Defaults to 0. + """ + classes = self.dataset_meta.get('classes', None) + palette = self.dataset_meta.get('palette', None) + + gt_img_data = None + pred_img_data = None + + if draw_gt and data_sample is not None and 'gt_sem_seg' in data_sample: + gt_img_data = image + assert classes is not None, 'class information is ' \ + 'not provided when ' \ + 'visualizing semantic ' \ + 'segmentation results.' + gt_img_data = self._draw_sem_seg(gt_img_data, + data_sample.gt_sem_seg, classes, + palette) + + if (draw_pred and data_sample is not None + and 'pred_sem_seg' in data_sample): + pred_img_data = image + assert classes is not None, 'class information is ' \ + 'not provided when ' \ + 'visualizing semantic ' \ + 'segmentation results.' + pred_img_data = self._draw_sem_seg(pred_img_data, + data_sample.pred_sem_seg, + classes, palette) + + if gt_img_data is not None and pred_img_data is not None: + drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1) + elif gt_img_data is not None: + drawn_img = gt_img_data + else: + drawn_img = pred_img_data + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + mmcv.imwrite(mmcv.bgr2rgb(drawn_img), out_file) + else: + self.add_image(name, drawn_img, step)
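To close, a hedged end-to-end sketch of the visualizer added above; the image size, the class indices, and the output filename are made-up values, and the default cityscapes metadata (19 classes) is assumed:

import numpy as np
import torch
from mmengine.structures import PixelData

from mmseg.structures import SegDataSample
from mmseg.visualization import SegLocalVisualizer

# with no classes/palette/dataset_name given, set_dataset_meta falls back to cityscapes
visualizer = SegLocalVisualizer(name='demo_vis')

image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
data_sample = SegDataSample()
data_sample.pred_sem_seg = PixelData(data=torch.randint(0, 19, (1, 64, 64)))

# draw only the prediction over the image and write the blended result to disk
visualizer.add_datasample(
    'demo', image, data_sample, draw_gt=False, out_file='demo_pred.png')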