diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f918f02bd80144d580904036b70ef81e66d70921 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__ +*.pth diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..cd6769d9e9599f2057775a6047ecfec37e000e38 --- /dev/null +++ b/Makefile @@ -0,0 +1,8 @@ +PYTHON=3.9 +BASENAME=$(shell basename $(CURDIR)) + +env: + conda create -n $(BASENAME) python=$(PYTHON) + +setup: + pip install -r requirements.txt diff --git a/app.py b/app.py index 0c2724eb62631bb1aa7cbe259be21860a45160e6..bbef47146755d1ac57d35206611bc8625ed3bee7 100644 --- a/app.py +++ b/app.py @@ -1,4 +1,85 @@ import streamlit as st +import torch +import numpy as np +import cv2 +import wget +import os -x = st.slider('Select a value') -st.write(x, 'squared is', x * x) +from PIL import Image +from streamlit_drawable_canvas import st_canvas + +from isegm.inference import clicker as ck +from isegm.inference import utils +from isegm.inference.predictors import get_predictor + +# Model Path +prefix = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main" +models = { + "RITM": "ritm_coco_lvis_h18_itermask.pth", +} + +# Items in the sidebar. +model = st.sidebar.selectbox("Select a Model:", tuple(models.keys())) +threshold = st.sidebar.slider("Threshold: ", 0.0, 1.0, 0.5) +marking_type = st.sidebar.radio("Marking Type:", ("positive", "negative")) +image_path = st.sidebar.file_uploader("Background Image:", type=["png", "jpg", "jpeg"]) + +# Objects for prediction. +clicker = ck.Clicker() +device = torch.device("cpu") +predictor = None +with st.spinner("Wait for downloading a model..."): + if not os.path.exists(models[model]): + _ = wget.download(f"{prefix}/{models[model]}") + +with st.spinner("Wait for loading a model..."): + model = utils.load_is_model(models[model], device, cpu_dist_maps=True) + predictor_params = {"brs_mode": "NoBRS"} + predictor = get_predictor(model, device=device, **predictor_params) + +# Create a canvas component. +image = None +if image_path: + image = Image.open(image_path) +canvas_height, canvas_width = 600, 600 +pos_color, neg_color = "#3498DB", "#C70039" +st.title("Canvas:") +canvas_result = st_canvas( + fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity + stroke_width=3, + stroke_color=pos_color if marking_type == "positive" else neg_color, + background_color="#eee", + background_image=image, + update_streamlit=True, + drawing_mode="point", + point_display_radius=3, + key="canvas", + width=canvas_width, + height=canvas_height, +) + +# Check the user inputs ans execute predictions. +st.title("Prediction:") +if canvas_result.json_data and canvas_result.json_data["objects"] and image: + objects = canvas_result.json_data["objects"] + image_width, image_height = image.size + ratio_h, ratio_w = image_height / canvas_height, image_width / canvas_width + + err_x, err_y = 5.5, 1.0 + pos_clicks, neg_clicks = [], [] + for click in objects: + x, y = (click["left"] + err_x) * ratio_w, (click["top"] + err_y) * ratio_h + x, y = min(image_width, max(0, x)), min(image_height, max(0, y)) + + is_positive = click["stroke"] == pos_color + click = ck.Click(is_positive=is_positive, coords=(y, x)) + clicker.add_click(click) + + # prediction. 
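+    # Run the predictor on the uploaded image with the accumulated clicks, resize the
+    # probability map to the canvas size, and binarize it with the selected threshold.
+    # NOTE: cv2.resize expects dsize as (width, height); passing (canvas_height, canvas_width) only works here because the canvas is square.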
+ pred = None + predictor.set_input_image(np.array(image)) + with st.spinner("Wait for prediction..."): + pred = predictor.get_prediction(clicker, prev_mask=None) + pred = cv2.resize(pred, dsize=(canvas_height, canvas_width), interpolation=cv2.INTER_CUBIC) + pred = np.where(pred > threshold, 1.0, 0) + st.image(pred, caption="") diff --git a/isegm/data/base.py b/isegm/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..ee2a532643d2ebf9c234a19e46df652ac56497cb --- /dev/null +++ b/isegm/data/base.py @@ -0,0 +1,99 @@ +import random +import pickle +import numpy as np +import torch +from torchvision import transforms +from .points_sampler import MultiPointSampler +from .sample import DSample + + +class ISDataset(torch.utils.data.dataset.Dataset): + def __init__(self, + augmentator=None, + points_sampler=MultiPointSampler(max_num_points=12), + min_object_area=0, + keep_background_prob=0.0, + with_image_info=False, + samples_scores_path=None, + samples_scores_gamma=1.0, + epoch_len=-1): + super(ISDataset, self).__init__() + self.epoch_len = epoch_len + self.augmentator = augmentator + self.min_object_area = min_object_area + self.keep_background_prob = keep_background_prob + self.points_sampler = points_sampler + self.with_image_info = with_image_info + self.samples_precomputed_scores = self._load_samples_scores(samples_scores_path, samples_scores_gamma) + self.to_tensor = transforms.ToTensor() + + self.dataset_samples = None + + def __getitem__(self, index): + if self.samples_precomputed_scores is not None: + index = np.random.choice(self.samples_precomputed_scores['indices'], + p=self.samples_precomputed_scores['probs']) + else: + if self.epoch_len > 0: + index = random.randrange(0, len(self.dataset_samples)) + + sample = self.get_sample(index) + sample = self.augment_sample(sample) + sample.remove_small_objects(self.min_object_area) + + self.points_sampler.sample_object(sample) + points = np.array(self.points_sampler.sample_points()) + mask = self.points_sampler.selected_mask + + output = { + 'images': self.to_tensor(sample.image), + 'points': points.astype(np.float32), + 'instances': mask + } + + if self.with_image_info: + output['image_info'] = sample.sample_id + + return output + + def augment_sample(self, sample) -> DSample: + if self.augmentator is None: + return sample + + valid_augmentation = False + while not valid_augmentation: + sample.augment(self.augmentator) + keep_sample = (self.keep_background_prob < 0.0 or + random.random() < self.keep_background_prob) + valid_augmentation = len(sample) > 0 or keep_sample + + return sample + + def get_sample(self, index) -> DSample: + raise NotImplementedError + + def __len__(self): + if self.epoch_len > 0: + return self.epoch_len + else: + return self.get_samples_number() + + def get_samples_number(self): + return len(self.dataset_samples) + + @staticmethod + def _load_samples_scores(samples_scores_path, samples_scores_gamma): + if samples_scores_path is None: + return None + + with open(samples_scores_path, 'rb') as f: + images_scores = pickle.load(f) + + probs = np.array([(1.0 - x[2]) ** samples_scores_gamma for x in images_scores]) + probs /= probs.sum() + samples_scores = { + 'indices': [x[0] for x in images_scores], + 'probs': probs + } + print(f'Loaded {len(probs)} weights with gamma={samples_scores_gamma}') + return samples_scores diff --git a/isegm/data/compose.py b/isegm/data/compose.py new file mode 100644 index 0000000000000000000000000000000000000000..e6e458cfd5693a3b5a73b9717c268213914f8430 --- 
/dev/null +++ b/isegm/data/compose.py @@ -0,0 +1,39 @@ +import numpy as np +from math import isclose +from .base import ISDataset + + +class ComposeDataset(ISDataset): + def __init__(self, datasets, **kwargs): + super(ComposeDataset, self).__init__(**kwargs) + + self._datasets = datasets + self.dataset_samples = [] + for dataset_indx, dataset in enumerate(self._datasets): + self.dataset_samples.extend([(dataset_indx, i) for i in range(len(dataset))]) + + def get_sample(self, index): + dataset_indx, sample_indx = self.dataset_samples[index] + return self._datasets[dataset_indx].get_sample(sample_indx) + + +class ProportionalComposeDataset(ISDataset): + def __init__(self, datasets, ratios, **kwargs): + super().__init__(**kwargs) + + assert len(ratios) == len(datasets),\ + "The number of datasets must match the number of ratios" + assert isclose(sum(ratios), 1.0),\ + "The sum of ratios must be equal to 1" + + self._ratios = ratios + self._datasets = datasets + self.dataset_samples = [] + for dataset_indx, dataset in enumerate(self._datasets): + self.dataset_samples.extend([(dataset_indx, i) for i in range(len(dataset))]) + + def get_sample(self, index): + dataset_indx = np.random.choice(len(self._datasets), p=self._ratios) + sample_indx = np.random.choice(len(self._datasets[dataset_indx])) + + return self._datasets[dataset_indx].get_sample(sample_indx) diff --git a/isegm/data/datasets/__init__.py b/isegm/data/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..966ffff2028cd494f785011eb037628890c06b94 --- /dev/null +++ b/isegm/data/datasets/__init__.py @@ -0,0 +1,12 @@ +from isegm.data.compose import ComposeDataset, ProportionalComposeDataset +from .berkeley import BerkeleyDataset +from .coco import CocoDataset +from .davis import DavisDataset +from .grabcut import GrabCutDataset +from .coco_lvis import CocoLvisDataset +from .lvis import LvisDataset +from .openimages import OpenImagesDataset +from .sbd import SBDDataset, SBDEvaluationDataset +from .images_dir import ImagesDirDataset +from .ade20k import ADE20kDataset +from .pascalvoc import PascalVocDataset diff --git a/isegm/data/datasets/ade20k.py b/isegm/data/datasets/ade20k.py new file mode 100644 index 0000000000000000000000000000000000000000..6791b8353a2d34c5e6e36ca5cdc6e4bdb62339c2 --- /dev/null +++ b/isegm/data/datasets/ade20k.py @@ -0,0 +1,55 @@ +import os +import random +import pickle as pkl +from pathlib import Path + +import cv2 +import numpy as np + +from isegm.data.base import ISDataset +from isegm.data.sample import DSample +from isegm.utils.misc import get_labels_with_sizes + + +class ADE20kDataset(ISDataset): + def __init__(self, dataset_path, split='train', stuff_prob=0.0, **kwargs): + super().__init__(**kwargs) + assert split in {'train', 'val'} + + self.dataset_path = Path(dataset_path) + self.dataset_split = split + self.dataset_split_folder = 'training' if split == 'train' else 'validation' + self.stuff_prob = stuff_prob + + anno_path = self.dataset_path / f'{split}-annotations-object-segmentation.pkl' + if os.path.exists(anno_path): + with anno_path.open('rb') as f: + annotations = pkl.load(f) + else: + raise RuntimeError(f"Can't find annotations at {anno_path}") + self.annotations = annotations + self.dataset_samples = list(annotations.keys()) + + def get_sample(self, index) -> DSample: + image_id = self.dataset_samples[index] + sample_annos = self.annotations[image_id] + + image_path = str(self.dataset_path / sample_annos['folder'] / f'{image_id}.jpg') + image = 
cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + # select random mask for an image + layer = random.choice(sample_annos['layers']) + mask_path = str(self.dataset_path / sample_annos['folder'] / layer['mask_name']) + instances_mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)[:, :, 0] # the B channel holds instances + instances_mask = instances_mask.astype(np.int32) + object_ids, _ = get_labels_with_sizes(instances_mask) + + if (self.stuff_prob <= 0) or (random.random() > self.stuff_prob): + # remove stuff objects + for i, object_id in enumerate(object_ids): + if i in layer['stuff_instances']: + instances_mask[instances_mask == object_id] = 0 + object_ids, _ = get_labels_with_sizes(instances_mask) + + return DSample(image, instances_mask, objects_ids=object_ids, sample_id=index) diff --git a/isegm/data/datasets/berkeley.py b/isegm/data/datasets/berkeley.py new file mode 100644 index 0000000000000000000000000000000000000000..5c269d84afdc8350cf92f0deddf732c5b62c0687 --- /dev/null +++ b/isegm/data/datasets/berkeley.py @@ -0,0 +1,6 @@ +from .grabcut import GrabCutDataset + + +class BerkeleyDataset(GrabCutDataset): + def __init__(self, dataset_path, **kwargs): + super().__init__(dataset_path, images_dir_name='images', masks_dir_name='masks', **kwargs) diff --git a/isegm/data/datasets/coco.py b/isegm/data/datasets/coco.py new file mode 100644 index 0000000000000000000000000000000000000000..985eb768579636ca9fcae68c56654af94d477f2a --- /dev/null +++ b/isegm/data/datasets/coco.py @@ -0,0 +1,74 @@ +import cv2 +import json +import random +import numpy as np +from pathlib import Path +from isegm.data.base import ISDataset +from isegm.data.sample import DSample + + +class CocoDataset(ISDataset): + def __init__(self, dataset_path, split='train', stuff_prob=0.0, **kwargs): + super(CocoDataset, self).__init__(**kwargs) + self.split = split + self.dataset_path = Path(dataset_path) + self.stuff_prob = stuff_prob + + self.load_samples() + + def load_samples(self): + annotation_path = self.dataset_path / 'annotations' / f'panoptic_{self.split}.json' + self.labels_path = self.dataset_path / 'annotations' / f'panoptic_{self.split}' + self.images_path = self.dataset_path / self.split + + with open(annotation_path, 'r') as f: + annotation = json.load(f) + + self.dataset_samples = annotation['annotations'] + + self._categories = annotation['categories'] + self._stuff_labels = [x['id'] for x in self._categories if x['isthing'] == 0] + self._things_labels = [x['id'] for x in self._categories if x['isthing'] == 1] + self._things_labels_set = set(self._things_labels) + self._stuff_labels_set = set(self._stuff_labels) + + def get_sample(self, index) -> DSample: + dataset_sample = self.dataset_samples[index] + + image_path = self.images_path / self.get_image_name(dataset_sample['file_name']) + label_path = self.labels_path / dataset_sample['file_name'] + + image = cv2.imread(str(image_path)) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + label = cv2.imread(str(label_path), cv2.IMREAD_UNCHANGED).astype(np.int32) + label = 256 * 256 * label[:, :, 0] + 256 * label[:, :, 1] + label[:, :, 2] + + instance_map = np.full_like(label, 0) + things_ids = [] + stuff_ids = [] + + for segment in dataset_sample['segments_info']: + class_id = segment['category_id'] + obj_id = segment['id'] + if class_id in self._things_labels_set: + if segment['iscrowd'] == 1: + continue + things_ids.append(obj_id) + else: + stuff_ids.append(obj_id) + + instance_map[label == obj_id] = obj_id + + if self.stuff_prob > 0 and 
random.random() < self.stuff_prob: + instances_ids = things_ids + stuff_ids + else: + instances_ids = things_ids + + for stuff_id in stuff_ids: + instance_map[instance_map == stuff_id] = 0 + + return DSample(image, instance_map, objects_ids=instances_ids) + + @classmethod + def get_image_name(cls, panoptic_name): + return panoptic_name.replace('.png', '.jpg') diff --git a/isegm/data/datasets/coco_lvis.py b/isegm/data/datasets/coco_lvis.py new file mode 100644 index 0000000000000000000000000000000000000000..03691036178a588e02167faf512ed473cae10e25 --- /dev/null +++ b/isegm/data/datasets/coco_lvis.py @@ -0,0 +1,67 @@ +from pathlib import Path +import pickle +import random +import numpy as np +import json +import cv2 +from copy import deepcopy +from isegm.data.base import ISDataset +from isegm.data.sample import DSample + + +class CocoLvisDataset(ISDataset): + def __init__(self, dataset_path, split='train', stuff_prob=0.0, + allow_list_name=None, anno_file='hannotation.pickle', **kwargs): + super(CocoLvisDataset, self).__init__(**kwargs) + dataset_path = Path(dataset_path) + self._split_path = dataset_path / split + self.split = split + self._images_path = self._split_path / 'images' + self._masks_path = self._split_path / 'masks' + self.stuff_prob = stuff_prob + + with open(self._split_path / anno_file, 'rb') as f: + self.dataset_samples = sorted(pickle.load(f).items()) + + if allow_list_name is not None: + allow_list_path = self._split_path / allow_list_name + with open(allow_list_path, 'r') as f: + allow_images_ids = json.load(f) + allow_images_ids = set(allow_images_ids) + + self.dataset_samples = [sample for sample in self.dataset_samples + if sample[0] in allow_images_ids] + + def get_sample(self, index) -> DSample: + image_id, sample = self.dataset_samples[index] + image_path = self._images_path / f'{image_id}.jpg' + + image = cv2.imread(str(image_path)) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + packed_masks_path = self._masks_path / f'{image_id}.pickle' + with open(packed_masks_path, 'rb') as f: + encoded_layers, objs_mapping = pickle.load(f) + layers = [cv2.imdecode(x, cv2.IMREAD_UNCHANGED) for x in encoded_layers] + layers = np.stack(layers, axis=2) + + instances_info = deepcopy(sample['hierarchy']) + for inst_id, inst_info in list(instances_info.items()): + if inst_info is None: + inst_info = {'children': [], 'parent': None, 'node_level': 0} + instances_info[inst_id] = inst_info + inst_info['mapping'] = objs_mapping[inst_id] + + if self.stuff_prob > 0 and random.random() < self.stuff_prob: + for inst_id in range(sample['num_instance_masks'], len(objs_mapping)): + instances_info[inst_id] = { + 'mapping': objs_mapping[inst_id], + 'parent': None, + 'children': [] + } + else: + for inst_id in range(sample['num_instance_masks'], len(objs_mapping)): + layer_indx, mask_id = objs_mapping[inst_id] + layers[:, :, layer_indx][layers[:, :, layer_indx] == mask_id] = 0 + + return DSample(image, layers, objects=instances_info) diff --git a/isegm/data/datasets/davis.py b/isegm/data/datasets/davis.py new file mode 100644 index 0000000000000000000000000000000000000000..de36b96be27f12a286865086a4e070a452987169 --- /dev/null +++ b/isegm/data/datasets/davis.py @@ -0,0 +1,33 @@ +from pathlib import Path + +import cv2 +import numpy as np + +from isegm.data.base import ISDataset +from isegm.data.sample import DSample + + +class DavisDataset(ISDataset): + def __init__(self, dataset_path, + images_dir_name='img', masks_dir_name='gt', + **kwargs): + super(DavisDataset, self).__init__(**kwargs) + + 
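+        # Images and masks are paired by file stem: every image in `img` is expected to have a mask with the same stem in `gt`.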
self.dataset_path = Path(dataset_path) + self._images_path = self.dataset_path / images_dir_name + self._insts_path = self.dataset_path / masks_dir_name + + self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.*'))] + self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.*')} + + def get_sample(self, index) -> DSample: + image_name = self.dataset_samples[index] + image_path = str(self._images_path / image_name) + mask_path = str(self._masks_paths[image_name.split('.')[0]]) + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + instances_mask = np.max(cv2.imread(mask_path).astype(np.int32), axis=2) + instances_mask[instances_mask > 0] = 1 + + return DSample(image, instances_mask, objects_ids=[1], sample_id=index) diff --git a/isegm/data/datasets/grabcut.py b/isegm/data/datasets/grabcut.py new file mode 100644 index 0000000000000000000000000000000000000000..ff00446d613183e3a0deed29cd8ed8dae53fd5b3 --- /dev/null +++ b/isegm/data/datasets/grabcut.py @@ -0,0 +1,34 @@ +from pathlib import Path + +import cv2 +import numpy as np + +from isegm.data.base import ISDataset +from isegm.data.sample import DSample + + +class GrabCutDataset(ISDataset): + def __init__(self, dataset_path, + images_dir_name='data_GT', masks_dir_name='boundary_GT', + **kwargs): + super(GrabCutDataset, self).__init__(**kwargs) + + self.dataset_path = Path(dataset_path) + self._images_path = self.dataset_path / images_dir_name + self._insts_path = self.dataset_path / masks_dir_name + + self.dataset_samples = [x.name for x in sorted(self._images_path.glob('*.*'))] + self._masks_paths = {x.stem: x for x in self._insts_path.glob('*.*')} + + def get_sample(self, index) -> DSample: + image_name = self.dataset_samples[index] + image_path = str(self._images_path / image_name) + mask_path = str(self._masks_paths[image_name.split('.')[0]]) + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + instances_mask = cv2.imread(mask_path)[:, :, 0].astype(np.int32) + instances_mask[instances_mask == 128] = -1 + instances_mask[instances_mask > 128] = 1 + + return DSample(image, instances_mask, objects_ids=[1], ignore_ids=[-1], sample_id=index) diff --git a/isegm/data/datasets/images_dir.py b/isegm/data/datasets/images_dir.py new file mode 100644 index 0000000000000000000000000000000000000000..db7d0fa6288ebdd7a7865648965e64183543ac87 --- /dev/null +++ b/isegm/data/datasets/images_dir.py @@ -0,0 +1,59 @@ +import cv2 +import numpy as np +from pathlib import Path + +from isegm.data.base import ISDataset +from isegm.data.sample import DSample + + +class ImagesDirDataset(ISDataset): + def __init__(self, dataset_path, + images_dir_name='images', masks_dir_name='masks', + **kwargs): + super(ImagesDirDataset, self).__init__(**kwargs) + + self.dataset_path = Path(dataset_path) + self._images_path = self.dataset_path / images_dir_name + self._insts_path = self.dataset_path / masks_dir_name + + images_list = [x for x in sorted(self._images_path.glob('*.*'))] + + samples = {x.stem: {'image': x, 'masks': []} for x in images_list} + for mask_path in self._insts_path.glob('*.*'): + mask_name = mask_path.stem + if mask_name in samples: + samples[mask_name]['masks'].append(mask_path) + continue + + mask_name_split = mask_name.split('_') + if mask_name_split[-1].isdigit(): + mask_name = '_'.join(mask_name_split[:-1]) + assert mask_name in samples + samples[mask_name]['masks'].append(mask_path) + + for x in samples.values(): + assert len(x['masks']) > 0, x['image'] + + 
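+        # Sort samples by image stem so dataset indices stay deterministic across runs.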
self.dataset_samples = [v for k, v in sorted(samples.items())] + + def get_sample(self, index) -> DSample: + sample = self.dataset_samples[index] + image_path = str(sample['image']) + + objects = [] + ignored_regions = [] + masks = [] + for indx, mask_path in enumerate(sample['masks']): + gt_mask = cv2.imread(str(mask_path))[:, :, 0].astype(np.int32) + instances_mask = np.zeros_like(gt_mask) + instances_mask[gt_mask == 128] = 2 + instances_mask[gt_mask > 128] = 1 + masks.append(instances_mask) + objects.append((indx, 1)) + ignored_regions.append((indx, 2)) + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + return DSample(image, np.stack(masks, axis=2), + objects_ids=objects, ignore_ids=ignored_regions, sample_id=index) diff --git a/isegm/data/datasets/lvis.py b/isegm/data/datasets/lvis.py new file mode 100644 index 0000000000000000000000000000000000000000..fd94b431d97effcff96ee3bee607f97375b88325 --- /dev/null +++ b/isegm/data/datasets/lvis.py @@ -0,0 +1,97 @@ +import json +import random +from collections import defaultdict +from pathlib import Path + +import cv2 +import numpy as np + +from isegm.data.base import ISDataset +from isegm.data.sample import DSample + + +class LvisDataset(ISDataset): + def __init__(self, dataset_path, split='train', + max_overlap_ratio=0.5, + **kwargs): + super(LvisDataset, self).__init__(**kwargs) + dataset_path = Path(dataset_path) + train_categories_path = dataset_path / 'train_categories.json' + self._train_path = dataset_path / 'train' + self._val_path = dataset_path / 'val' + + self.split = split + self.max_overlap_ratio = max_overlap_ratio + + with open( dataset_path / split / f'lvis_{self.split}.json', 'r') as f: + json_annotation = json.loads(f.read()) + + self.annotations = defaultdict(list) + for x in json_annotation['annotations']: + self.annotations[x['image_id']].append(x) + + if not train_categories_path.exists(): + self.generate_train_categories(dataset_path, train_categories_path) + self.dataset_samples = [x for x in json_annotation['images'] + if len(self.annotations[x['id']]) > 0] + + def get_sample(self, index) -> DSample: + image_info = self.dataset_samples[index] + image_id, image_url = image_info['id'], image_info['coco_url'] + image_filename = image_url.split('/')[-1] + image_annotations = self.annotations[image_id] + random.shuffle(image_annotations) + + # LVISv1 splits do not match older LVIS splits (some images in val may come from COCO train2017) + if 'train2017' in image_url: + image_path = self._train_path / 'images' / image_filename + else: + image_path = self._val_path / 'images' / image_filename + image = cv2.imread(str(image_path)) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + instances_mask = None + instances_area = defaultdict(int) + objects_ids = [] + for indx, obj_annotation in enumerate(image_annotations): + mask = self.get_mask_from_polygon(obj_annotation, image) + object_mask = mask > 0 + object_area = object_mask.sum() + + if instances_mask is None: + instances_mask = np.zeros_like(object_mask, dtype=np.int32) + + overlap_ids = np.bincount(instances_mask[object_mask].flatten()) + overlap_areas = [overlap_area / instances_area[inst_id] for inst_id, overlap_area in enumerate(overlap_ids) + if overlap_area > 0 and inst_id > 0] + overlap_ratio = np.logical_and(object_mask, instances_mask > 0).sum() / object_area + if overlap_areas: + overlap_ratio = max(overlap_ratio, max(overlap_areas)) + if overlap_ratio > self.max_overlap_ratio: + continue + + instance_id = indx + 1 + 
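+            # Paint the accepted object into the shared instance map and record its area for the overlap checks on later objects.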
instances_mask[object_mask] = instance_id + instances_area[instance_id] = object_area + objects_ids.append(instance_id) + + return DSample(image, instances_mask, objects_ids=objects_ids) + + + @staticmethod + def get_mask_from_polygon(annotation, image): + mask = np.zeros(image.shape[:2], dtype=np.int32) + for contour_points in annotation['segmentation']: + contour_points = np.array(contour_points).reshape((-1, 2)) + contour_points = np.round(contour_points).astype(np.int32)[np.newaxis, :] + cv2.fillPoly(mask, contour_points, 1) + + return mask + + @staticmethod + def generate_train_categories(dataset_path, train_categories_path): + with open(dataset_path / 'train/lvis_train.json', 'r') as f: + annotation = json.load(f) + + with open(train_categories_path, 'w') as f: + json.dump(annotation['categories'], f, indent=1) diff --git a/isegm/data/datasets/openimages.py b/isegm/data/datasets/openimages.py new file mode 100644 index 0000000000000000000000000000000000000000..d0a81cfbf08b9b5ddd3fe565a00e778733e9ee4a --- /dev/null +++ b/isegm/data/datasets/openimages.py @@ -0,0 +1,58 @@ +import os +import random +import pickle as pkl +from pathlib import Path + +import cv2 +import numpy as np + +from isegm.data.base import ISDataset +from isegm.data.sample import DSample + + +class OpenImagesDataset(ISDataset): + def __init__(self, dataset_path, split='train', **kwargs): + super().__init__(**kwargs) + assert split in {'train', 'val', 'test'} + + self.dataset_path = Path(dataset_path) + self._split_path = self.dataset_path / split + self._images_path = self._split_path / 'images' + self._masks_path = self._split_path / 'masks' + self.dataset_split = split + + clean_anno_path = self._split_path / f'{split}-annotations-object-segmentation_clean.pkl' + if os.path.exists(clean_anno_path): + with clean_anno_path.open('rb') as f: + annotations = pkl.load(f) + else: + raise RuntimeError(f"Can't find annotations at {clean_anno_path}") + self.image_id_to_masks = annotations['image_id_to_masks'] + self.dataset_samples = annotations['dataset_samples'] + + def get_sample(self, index) -> DSample: + image_id = self.dataset_samples[index] + + image_path = str(self._images_path / f'{image_id}.jpg') + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + mask_paths = self.image_id_to_masks[image_id] + # select random mask for an image + mask_path = str(self._masks_path / random.choice(mask_paths)) + instances_mask = cv2.imread(mask_path) + instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY) + instances_mask[instances_mask > 0] = 1 + instances_mask = instances_mask.astype(np.int32) + + min_width = min(image.shape[1], instances_mask.shape[1]) + min_height = min(image.shape[0], instances_mask.shape[0]) + + if image.shape[0] != min_height or image.shape[1] != min_width: + image = cv2.resize(image, (min_width, min_height), interpolation=cv2.INTER_LINEAR) + if instances_mask.shape[0] != min_height or instances_mask.shape[1] != min_width: + instances_mask = cv2.resize(instances_mask, (min_width, min_height), interpolation=cv2.INTER_NEAREST) + + object_ids = [1] if instances_mask.sum() > 0 else [] + + return DSample(image, instances_mask, objects_ids=object_ids, sample_id=index) diff --git a/isegm/data/datasets/pascalvoc.py b/isegm/data/datasets/pascalvoc.py new file mode 100644 index 0000000000000000000000000000000000000000..4e1ad488f2228c1a94040d1bde21cd421ff70b3e --- /dev/null +++ b/isegm/data/datasets/pascalvoc.py @@ -0,0 +1,48 @@ +import pickle as pkl +from pathlib import Path 
+ +import cv2 +import numpy as np + +from isegm.data.base import ISDataset +from isegm.data.sample import DSample + + +class PascalVocDataset(ISDataset): + def __init__(self, dataset_path, split='train', **kwargs): + super().__init__(**kwargs) + assert split in {'train', 'val', 'trainval', 'test'} + + self.dataset_path = Path(dataset_path) + self._images_path = self.dataset_path / "JPEGImages" + self._insts_path = self.dataset_path / "SegmentationObject" + self.dataset_split = split + + if split == 'test': + with open(self.dataset_path / f'ImageSets/Segmentation/test.pickle', 'rb') as f: + self.dataset_samples, self.instance_ids = pkl.load(f) + else: + with open(self.dataset_path / f'ImageSets/Segmentation/{split}.txt', 'r') as f: + self.dataset_samples = [name.strip() for name in f.readlines()] + + def get_sample(self, index) -> DSample: + sample_id = self.dataset_samples[index] + image_path = str(self._images_path / f'{sample_id}.jpg') + mask_path = str(self._insts_path / f'{sample_id}.png') + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + instances_mask = cv2.imread(mask_path) + instances_mask = cv2.cvtColor(instances_mask, cv2.COLOR_BGR2GRAY).astype(np.int32) + if self.dataset_split == 'test': + instance_id = self.instance_ids[index] + mask = np.zeros_like(instances_mask) + mask[instances_mask == 220] = 220 # ignored area + mask[instances_mask == instance_id] = 1 + objects_ids = [1] + instances_mask = mask + else: + objects_ids = np.unique(instances_mask) + objects_ids = [x for x in objects_ids if x != 0 and x != 220] + + return DSample(image, instances_mask, objects_ids=objects_ids, ignore_ids=[220], sample_id=index) diff --git a/isegm/data/datasets/sbd.py b/isegm/data/datasets/sbd.py new file mode 100644 index 0000000000000000000000000000000000000000..b6a05e4b370f4b6486ebc24ceb961f545f256f81 --- /dev/null +++ b/isegm/data/datasets/sbd.py @@ -0,0 +1,111 @@ +import pickle as pkl +from pathlib import Path + +import cv2 +import numpy as np +from scipy.io import loadmat + +from isegm.utils.misc import get_bbox_from_mask, get_labels_with_sizes +from isegm.data.base import ISDataset +from isegm.data.sample import DSample + + +class SBDDataset(ISDataset): + def __init__(self, dataset_path, split='train', buggy_mask_thresh=0.08, **kwargs): + super(SBDDataset, self).__init__(**kwargs) + assert split in {'train', 'val'} + + self.dataset_path = Path(dataset_path) + self.dataset_split = split + self._images_path = self.dataset_path / 'img' + self._insts_path = self.dataset_path / 'inst' + self._buggy_objects = dict() + self._buggy_mask_thresh = buggy_mask_thresh + + with open(self.dataset_path / f'{split}.txt', 'r') as f: + self.dataset_samples = [x.strip() for x in f.readlines()] + + def get_sample(self, index): + image_name = self.dataset_samples[index] + image_path = str(self._images_path / f'{image_name}.jpg') + inst_info_path = str(self._insts_path / f'{image_name}.mat') + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32) + instances_mask = self.remove_buggy_masks(index, instances_mask) + instances_ids, _ = get_labels_with_sizes(instances_mask) + + return DSample(image, instances_mask, objects_ids=instances_ids, sample_id=index) + + def remove_buggy_masks(self, index, instances_mask): + if self._buggy_mask_thresh > 0.0: + buggy_image_objects = self._buggy_objects.get(index, None) + if buggy_image_objects is None: + buggy_image_objects = [] + 
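+                # An object is flagged as buggy when its mask covers too small a fraction of its bounding box.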
instances_ids, _ = get_labels_with_sizes(instances_mask) + for obj_id in instances_ids: + obj_mask = instances_mask == obj_id + mask_area = obj_mask.sum() + bbox = get_bbox_from_mask(obj_mask) + bbox_area = (bbox[1] - bbox[0] + 1) * (bbox[3] - bbox[2] + 1) + obj_area_ratio = mask_area / bbox_area + if obj_area_ratio < self._buggy_mask_thresh: + buggy_image_objects.append(obj_id) + + self._buggy_objects[index] = buggy_image_objects + for obj_id in buggy_image_objects: + instances_mask[instances_mask == obj_id] = 0 + + return instances_mask + + +class SBDEvaluationDataset(ISDataset): + def __init__(self, dataset_path, split='val', **kwargs): + super(SBDEvaluationDataset, self).__init__(**kwargs) + assert split in {'train', 'val'} + + self.dataset_path = Path(dataset_path) + self.dataset_split = split + self._images_path = self.dataset_path / 'img' + self._insts_path = self.dataset_path / 'inst' + + with open(self.dataset_path / f'{split}.txt', 'r') as f: + self.dataset_samples = [x.strip() for x in f.readlines()] + + self.dataset_samples = self.get_sbd_images_and_ids_list() + + def get_sample(self, index) -> DSample: + image_name, instance_id = self.dataset_samples[index] + image_path = str(self._images_path / f'{image_name}.jpg') + inst_info_path = str(self._insts_path / f'{image_name}.mat') + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32) + instances_mask[instances_mask != instance_id] = 0 + instances_mask[instances_mask > 0] = 1 + + return DSample(image, instances_mask, objects_ids=[1], sample_id=index) + + def get_sbd_images_and_ids_list(self): + pkl_path = self.dataset_path / f'{self.dataset_split}_images_and_ids_list.pkl' + + if pkl_path.exists(): + with open(str(pkl_path), 'rb') as fp: + images_and_ids_list = pkl.load(fp) + else: + images_and_ids_list = [] + + for sample in self.dataset_samples: + inst_info_path = str(self._insts_path / f'{sample}.mat') + instances_mask = loadmat(str(inst_info_path))['GTinst'][0][0][0].astype(np.int32) + instances_ids, _ = get_labels_with_sizes(instances_mask) + + for instances_id in instances_ids: + images_and_ids_list.append((sample, instances_id)) + + with open(str(pkl_path), 'wb') as fp: + pkl.dump(images_and_ids_list, fp) + + return images_and_ids_list diff --git a/isegm/data/points_sampler.py b/isegm/data/points_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..43cc6380b173599457099ab35bd288db5cee0193 --- /dev/null +++ b/isegm/data/points_sampler.py @@ -0,0 +1,305 @@ +import cv2 +import math +import random +import numpy as np +from functools import lru_cache +from .sample import DSample + + +class BasePointSampler: + def __init__(self): + self._selected_mask = None + self._selected_masks = None + + def sample_object(self, sample: DSample): + raise NotImplementedError + + def sample_points(self): + raise NotImplementedError + + @property + def selected_mask(self): + assert self._selected_mask is not None + return self._selected_mask + + @selected_mask.setter + def selected_mask(self, mask): + self._selected_mask = mask[np.newaxis, :].astype(np.float32) + + +class MultiPointSampler(BasePointSampler): + def __init__(self, max_num_points, prob_gamma=0.7, expand_ratio=0.1, + positive_erode_prob=0.9, positive_erode_iters=3, + negative_bg_prob=0.1, negative_other_prob=0.4, negative_border_prob=0.5, + merge_objects_prob=0.0, max_num_merged_objects=2, + use_hierarchy=False, soft_targets=False, + 
first_click_center=False, only_one_first_click=False, + sfc_inner_k=1.7, sfc_full_inner_prob=0.0): + super().__init__() + self.max_num_points = max_num_points + self.expand_ratio = expand_ratio + self.positive_erode_prob = positive_erode_prob + self.positive_erode_iters = positive_erode_iters + self.merge_objects_prob = merge_objects_prob + self.use_hierarchy = use_hierarchy + self.soft_targets = soft_targets + self.first_click_center = first_click_center + self.only_one_first_click = only_one_first_click + self.sfc_inner_k = sfc_inner_k + self.sfc_full_inner_prob = sfc_full_inner_prob + + if max_num_merged_objects == -1: + max_num_merged_objects = max_num_points + self.max_num_merged_objects = max_num_merged_objects + + self.neg_strategies = ['bg', 'other', 'border'] + self.neg_strategies_prob = [negative_bg_prob, negative_other_prob, negative_border_prob] + assert math.isclose(sum(self.neg_strategies_prob), 1.0) + + self._pos_probs = generate_probs(max_num_points, gamma=prob_gamma) + self._neg_probs = generate_probs(max_num_points + 1, gamma=prob_gamma) + self._neg_masks = None + + def sample_object(self, sample: DSample): + if len(sample) == 0: + bg_mask = sample.get_background_mask() + self.selected_mask = np.zeros_like(bg_mask, dtype=np.float32) + self._selected_masks = [[]] + self._neg_masks = {strategy: bg_mask for strategy in self.neg_strategies} + self._neg_masks['required'] = [] + return + + gt_mask, pos_masks, neg_masks = self._sample_mask(sample) + binary_gt_mask = gt_mask > 0.5 if self.soft_targets else gt_mask > 0 + + self.selected_mask = gt_mask + self._selected_masks = pos_masks + + neg_mask_bg = np.logical_not(binary_gt_mask) + neg_mask_border = self._get_border_mask(binary_gt_mask) + if len(sample) <= len(self._selected_masks): + neg_mask_other = neg_mask_bg + else: + neg_mask_other = np.logical_and(np.logical_not(sample.get_background_mask()), + np.logical_not(binary_gt_mask)) + + self._neg_masks = { + 'bg': neg_mask_bg, + 'other': neg_mask_other, + 'border': neg_mask_border, + 'required': neg_masks + } + + def _sample_mask(self, sample: DSample): + root_obj_ids = sample.root_objects + + if len(root_obj_ids) > 1 and random.random() < self.merge_objects_prob: + max_selected_objects = min(len(root_obj_ids), self.max_num_merged_objects) + num_selected_objects = np.random.randint(2, max_selected_objects + 1) + random_ids = random.sample(root_obj_ids, num_selected_objects) + else: + random_ids = [random.choice(root_obj_ids)] + + gt_mask = None + pos_segments = [] + neg_segments = [] + for obj_id in random_ids: + obj_gt_mask, obj_pos_segments, obj_neg_segments = self._sample_from_masks_layer(obj_id, sample) + if gt_mask is None: + gt_mask = obj_gt_mask + else: + gt_mask = np.maximum(gt_mask, obj_gt_mask) + + pos_segments.extend(obj_pos_segments) + neg_segments.extend(obj_neg_segments) + + pos_masks = [self._positive_erode(x) for x in pos_segments] + neg_masks = [self._positive_erode(x) for x in neg_segments] + + return gt_mask, pos_masks, neg_masks + + def _sample_from_masks_layer(self, obj_id, sample: DSample): + objs_tree = sample._objects + + if not self.use_hierarchy: + node_mask = sample.get_object_mask(obj_id) + gt_mask = sample.get_soft_object_mask(obj_id) if self.soft_targets else node_mask + return gt_mask, [node_mask], [] + + def _select_node(node_id): + node_info = objs_tree[node_id] + if not node_info['children'] or random.random() < 0.5: + return node_id + return _select_node(random.choice(node_info['children'])) + + selected_node = _select_node(obj_id) + 
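+        # _select_node walks down the object hierarchy, stopping at each node with probability 0.5 (or at a leaf).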
node_info = objs_tree[selected_node] + node_mask = sample.get_object_mask(selected_node) + gt_mask = sample.get_soft_object_mask(selected_node) if self.soft_targets else node_mask + pos_mask = node_mask.copy() + + negative_segments = [] + if node_info['parent'] is not None and node_info['parent'] in objs_tree: + parent_mask = sample.get_object_mask(node_info['parent']) + negative_segments.append(np.logical_and(parent_mask, np.logical_not(node_mask))) + + for child_id in node_info['children']: + if objs_tree[child_id]['area'] / node_info['area'] < 0.10: + child_mask = sample.get_object_mask(child_id) + pos_mask = np.logical_and(pos_mask, np.logical_not(child_mask)) + + if node_info['children']: + max_disabled_children = min(len(node_info['children']), 3) + num_disabled_children = np.random.randint(0, max_disabled_children + 1) + disabled_children = random.sample(node_info['children'], num_disabled_children) + + for child_id in disabled_children: + child_mask = sample.get_object_mask(child_id) + pos_mask = np.logical_and(pos_mask, np.logical_not(child_mask)) + if self.soft_targets: + soft_child_mask = sample.get_soft_object_mask(child_id) + gt_mask = np.minimum(gt_mask, 1.0 - soft_child_mask) + else: + gt_mask = np.logical_and(gt_mask, np.logical_not(child_mask)) + negative_segments.append(child_mask) + + return gt_mask, [pos_mask], negative_segments + + def sample_points(self): + assert self._selected_mask is not None + pos_points = self._multi_mask_sample_points(self._selected_masks, + is_negative=[False] * len(self._selected_masks), + with_first_click=self.first_click_center) + + neg_strategy = [(self._neg_masks[k], prob) + for k, prob in zip(self.neg_strategies, self.neg_strategies_prob)] + neg_masks = self._neg_masks['required'] + [neg_strategy] + neg_points = self._multi_mask_sample_points(neg_masks, + is_negative=[False] * len(self._neg_masks['required']) + [True]) + + return pos_points + neg_points + + def _multi_mask_sample_points(self, selected_masks, is_negative, with_first_click=False): + selected_masks = selected_masks[:self.max_num_points] + + each_obj_points = [ + self._sample_points(mask, is_negative=is_negative[i], + with_first_click=with_first_click) + for i, mask in enumerate(selected_masks) + ] + each_obj_points = [x for x in each_obj_points if len(x) > 0] + + points = [] + if len(each_obj_points) == 1: + points = each_obj_points[0] + elif len(each_obj_points) > 1: + if self.only_one_first_click: + each_obj_points = each_obj_points[:1] + + points = [obj_points[0] for obj_points in each_obj_points] + + aggregated_masks_with_prob = [] + for indx, x in enumerate(selected_masks): + if isinstance(x, (list, tuple)) and x and isinstance(x[0], (list, tuple)): + for t, prob in x: + aggregated_masks_with_prob.append((t, prob / len(selected_masks))) + else: + aggregated_masks_with_prob.append((x, 1.0 / len(selected_masks))) + + other_points_union = self._sample_points(aggregated_masks_with_prob, is_negative=True) + if len(other_points_union) + len(points) <= self.max_num_points: + points.extend(other_points_union) + else: + points.extend(random.sample(other_points_union, self.max_num_points - len(points))) + + if len(points) < self.max_num_points: + points.extend([(-1, -1, -1)] * (self.max_num_points - len(points))) + + return points + + def _sample_points(self, mask, is_negative=False, with_first_click=False): + if is_negative: + num_points = np.random.choice(np.arange(self.max_num_points + 1), p=self._neg_probs) + else: + num_points = 1 + 
np.random.choice(np.arange(self.max_num_points), p=self._pos_probs) + + indices_probs = None + if isinstance(mask, (list, tuple)): + indices_probs = [x[1] for x in mask] + indices = [(np.argwhere(x), prob) for x, prob in mask] + if indices_probs: + assert math.isclose(sum(indices_probs), 1.0) + else: + indices = np.argwhere(mask) + + points = [] + for j in range(num_points): + first_click = with_first_click and j == 0 and indices_probs is None + + if first_click: + point_indices = get_point_candidates(mask, k=self.sfc_inner_k, full_prob=self.sfc_full_inner_prob) + elif indices_probs: + point_indices_indx = np.random.choice(np.arange(len(indices)), p=indices_probs) + point_indices = indices[point_indices_indx][0] + else: + point_indices = indices + + num_indices = len(point_indices) + if num_indices > 0: + point_indx = 0 if first_click else 100 + click = point_indices[np.random.randint(0, num_indices)].tolist() + [point_indx] + points.append(click) + + return points + + def _positive_erode(self, mask): + if random.random() > self.positive_erode_prob: + return mask + + kernel = np.ones((3, 3), np.uint8) + eroded_mask = cv2.erode(mask.astype(np.uint8), + kernel, iterations=self.positive_erode_iters).astype(np.bool) + + if eroded_mask.sum() > 10: + return eroded_mask + else: + return mask + + def _get_border_mask(self, mask): + expand_r = int(np.ceil(self.expand_ratio * np.sqrt(mask.sum()))) + kernel = np.ones((3, 3), np.uint8) + expanded_mask = cv2.dilate(mask.astype(np.uint8), kernel, iterations=expand_r) + expanded_mask[mask.astype(np.bool)] = 0 + return expanded_mask + + +@lru_cache(maxsize=None) +def generate_probs(max_num_points, gamma): + probs = [] + last_value = 1 + for i in range(max_num_points): + probs.append(last_value) + last_value *= gamma + + probs = np.array(probs) + probs /= probs.sum() + + return probs + + +def get_point_candidates(obj_mask, k=1.7, full_prob=0.0): + if full_prob > 0 and random.random() < full_prob: + return obj_mask + + padded_mask = np.pad(obj_mask, ((1, 1), (1, 1)), 'constant') + + dt = cv2.distanceTransform(padded_mask.astype(np.uint8), cv2.DIST_L2, 0)[1:-1, 1:-1] + if k > 0: + inner_mask = dt > dt.max() / k + return np.argwhere(inner_mask) + else: + prob_map = dt.flatten() + prob_map /= max(prob_map.sum(), 1e-6) + click_indx = np.random.choice(len(prob_map), p=prob_map) + click_coords = np.unravel_index(click_indx, dt.shape) + return np.array([click_coords]) diff --git a/isegm/data/sample.py b/isegm/data/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..d57794ca405a02f2e3a317b87efdeb94352cf138 --- /dev/null +++ b/isegm/data/sample.py @@ -0,0 +1,148 @@ +import numpy as np +from copy import deepcopy +from isegm.utils.misc import get_labels_with_sizes +from isegm.data.transforms import remove_image_only_transforms +from albumentations import ReplayCompose + + +class DSample: + def __init__(self, image, encoded_masks, objects=None, + objects_ids=None, ignore_ids=None, sample_id=None): + self.image = image + self.sample_id = sample_id + + if len(encoded_masks.shape) == 2: + encoded_masks = encoded_masks[:, :, np.newaxis] + self._encoded_masks = encoded_masks + self._ignored_regions = [] + + if objects_ids is not None: + if not objects_ids or not isinstance(objects_ids[0], tuple): + assert encoded_masks.shape[2] == 1 + objects_ids = [(0, obj_id) for obj_id in objects_ids] + + self._objects = dict() + for indx, obj_mapping in enumerate(objects_ids): + self._objects[indx] = { + 'parent': None, + 'mapping': obj_mapping, + 'children': 
[] + } + + if ignore_ids: + if isinstance(ignore_ids[0], tuple): + self._ignored_regions = ignore_ids + else: + self._ignored_regions = [(0, region_id) for region_id in ignore_ids] + else: + self._objects = deepcopy(objects) + + self._augmented = False + self._soft_mask_aug = None + self._original_data = self.image, self._encoded_masks, deepcopy(self._objects) + + def augment(self, augmentator): + self.reset_augmentation() + aug_output = augmentator(image=self.image, mask=self._encoded_masks) + self.image = aug_output['image'] + self._encoded_masks = aug_output['mask'] + + aug_replay = aug_output.get('replay', None) + if aug_replay: + assert len(self._ignored_regions) == 0 + mask_replay = remove_image_only_transforms(aug_replay) + self._soft_mask_aug = ReplayCompose._restore_for_replay(mask_replay) + + self._compute_objects_areas() + self.remove_small_objects(min_area=1) + + self._augmented = True + + def reset_augmentation(self): + if not self._augmented: + return + orig_image, orig_masks, orig_objects = self._original_data + self.image = orig_image + self._encoded_masks = orig_masks + self._objects = deepcopy(orig_objects) + self._augmented = False + self._soft_mask_aug = None + + def remove_small_objects(self, min_area): + if self._objects and not 'area' in list(self._objects.values())[0]: + self._compute_objects_areas() + + for obj_id, obj_info in list(self._objects.items()): + if obj_info['area'] < min_area: + self._remove_object(obj_id) + + def get_object_mask(self, obj_id): + layer_indx, mask_id = self._objects[obj_id]['mapping'] + obj_mask = (self._encoded_masks[:, :, layer_indx] == mask_id).astype(np.int32) + if self._ignored_regions: + for layer_indx, mask_id in self._ignored_regions: + ignore_mask = self._encoded_masks[:, :, layer_indx] == mask_id + obj_mask[ignore_mask] = -1 + + return obj_mask + + def get_soft_object_mask(self, obj_id): + assert self._soft_mask_aug is not None + original_encoded_masks = self._original_data[1] + layer_indx, mask_id = self._objects[obj_id]['mapping'] + obj_mask = (original_encoded_masks[:, :, layer_indx] == mask_id).astype(np.float32) + obj_mask = self._soft_mask_aug(image=obj_mask, mask=original_encoded_masks)['image'] + return np.clip(obj_mask, 0, 1) + + def get_background_mask(self): + return np.max(self._encoded_masks, axis=2) == 0 + + @property + def objects_ids(self): + return list(self._objects.keys()) + + @property + def gt_mask(self): + assert len(self._objects) == 1 + return self.get_object_mask(self.objects_ids[0]) + + @property + def root_objects(self): + return [obj_id for obj_id, obj_info in self._objects.items() if obj_info['parent'] is None] + + def _compute_objects_areas(self): + inverse_index = {node['mapping']: node_id for node_id, node in self._objects.items()} + ignored_regions_keys = set(self._ignored_regions) + + for layer_indx in range(self._encoded_masks.shape[2]): + objects_ids, objects_areas = get_labels_with_sizes(self._encoded_masks[:, :, layer_indx]) + for obj_id, obj_area in zip(objects_ids, objects_areas): + inv_key = (layer_indx, obj_id) + if inv_key in ignored_regions_keys: + continue + try: + self._objects[inverse_index[inv_key]]['area'] = obj_area + del inverse_index[inv_key] + except KeyError: + layer = self._encoded_masks[:, :, layer_indx] + layer[layer == obj_id] = 0 + self._encoded_masks[:, :, layer_indx] = layer + + for obj_id in inverse_index.values(): + self._objects[obj_id]['area'] = 0 + + def _remove_object(self, obj_id): + obj_info = self._objects[obj_id] + obj_parent = obj_info['parent'] + for 
child_id in obj_info['children']: + self._objects[child_id]['parent'] = obj_parent + + if obj_parent is not None: + parent_children = self._objects[obj_parent]['children'] + parent_children = [x for x in parent_children if x != obj_id] + self._objects[obj_parent]['children'] = parent_children + obj_info['children'] + + del self._objects[obj_id] + + def __len__(self): + return len(self._objects) diff --git a/isegm/data/transforms.py b/isegm/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..0a3fd67f6969ba7e120d03ce85b67a8b4651281d --- /dev/null +++ b/isegm/data/transforms.py @@ -0,0 +1,178 @@ +import cv2 +import random +import numpy as np + +from albumentations.core.serialization import SERIALIZABLE_REGISTRY +from albumentations import ImageOnlyTransform, DualTransform +from albumentations.core.transforms_interface import to_tuple +from albumentations.augmentations import functional as F +from isegm.utils.misc import get_bbox_from_mask, expand_bbox, clamp_bbox, get_labels_with_sizes + + +class UniformRandomResize(DualTransform): + def __init__(self, scale_range=(0.9, 1.1), interpolation=cv2.INTER_LINEAR, always_apply=False, p=1): + super().__init__(always_apply, p) + self.scale_range = scale_range + self.interpolation = interpolation + + def get_params_dependent_on_targets(self, params): + scale = random.uniform(*self.scale_range) + height = int(round(params['image'].shape[0] * scale)) + width = int(round(params['image'].shape[1] * scale)) + return {'new_height': height, 'new_width': width} + + def apply(self, img, new_height=0, new_width=0, interpolation=cv2.INTER_LINEAR, **params): + return F.resize(img, height=new_height, width=new_width, interpolation=interpolation) + + def apply_to_keypoint(self, keypoint, new_height=0, new_width=0, **params): + scale_x = new_width / params["cols"] + scale_y = new_height / params["rows"] + return F.keypoint_scale(keypoint, scale_x, scale_y) + + def get_transform_init_args_names(self): + return "scale_range", "interpolation" + + @property + def targets_as_params(self): + return ["image"] + + +class ZoomIn(DualTransform): + def __init__( + self, + height, + width, + bbox_jitter=0.1, + expansion_ratio=1.4, + min_crop_size=200, + min_area=100, + always_resize=False, + always_apply=False, + p=0.5, + ): + super(ZoomIn, self).__init__(always_apply, p) + self.height = height + self.width = width + self.bbox_jitter = to_tuple(bbox_jitter) + self.expansion_ratio = expansion_ratio + self.min_crop_size = min_crop_size + self.min_area = min_area + self.always_resize = always_resize + + def apply(self, img, selected_object, bbox, **params): + if selected_object is None: + if self.always_resize: + img = F.resize(img, height=self.height, width=self.width) + return img + + rmin, rmax, cmin, cmax = bbox + img = img[rmin:rmax + 1, cmin:cmax + 1] + img = F.resize(img, height=self.height, width=self.width) + + return img + + def apply_to_mask(self, mask, selected_object, bbox, **params): + if selected_object is None: + if self.always_resize: + mask = F.resize(mask, height=self.height, width=self.width, + interpolation=cv2.INTER_NEAREST) + return mask + + rmin, rmax, cmin, cmax = bbox + mask = mask[rmin:rmax + 1, cmin:cmax + 1] + if isinstance(selected_object, tuple): + layer_indx, mask_id = selected_object + obj_mask = mask[:, :, layer_indx] == mask_id + new_mask = np.zeros_like(mask) + new_mask[:, :, layer_indx][obj_mask] = mask_id + else: + obj_mask = mask == selected_object + new_mask = mask.copy() + 
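+            # Keep only the selected object's pixels; everything else becomes background before resizing.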
new_mask[np.logical_not(obj_mask)] = 0 + + new_mask = F.resize(new_mask, height=self.height, width=self.width, + interpolation=cv2.INTER_NEAREST) + return new_mask + + def get_params_dependent_on_targets(self, params): + instances = params['mask'] + + is_mask_layer = len(instances.shape) > 2 + candidates = [] + if is_mask_layer: + for layer_indx in range(instances.shape[2]): + labels, areas = get_labels_with_sizes(instances[:, :, layer_indx]) + candidates.extend([(layer_indx, obj_id) + for obj_id, area in zip(labels, areas) + if area > self.min_area]) + else: + labels, areas = get_labels_with_sizes(instances) + candidates = [obj_id for obj_id, area in zip(labels, areas) + if area > self.min_area] + + selected_object = None + bbox = None + if candidates: + selected_object = random.choice(candidates) + if is_mask_layer: + layer_indx, mask_id = selected_object + obj_mask = instances[:, :, layer_indx] == mask_id + else: + obj_mask = instances == selected_object + + bbox = get_bbox_from_mask(obj_mask) + + if isinstance(self.expansion_ratio, tuple): + expansion_ratio = random.uniform(*self.expansion_ratio) + else: + expansion_ratio = self.expansion_ratio + + bbox = expand_bbox(bbox, expansion_ratio, self.min_crop_size) + bbox = self._jitter_bbox(bbox) + bbox = clamp_bbox(bbox, 0, obj_mask.shape[0] - 1, 0, obj_mask.shape[1] - 1) + + return { + 'selected_object': selected_object, + 'bbox': bbox + } + + def _jitter_bbox(self, bbox): + rmin, rmax, cmin, cmax = bbox + height = rmax - rmin + 1 + width = cmax - cmin + 1 + rmin = int(rmin + random.uniform(*self.bbox_jitter) * height) + rmax = int(rmax + random.uniform(*self.bbox_jitter) * height) + cmin = int(cmin + random.uniform(*self.bbox_jitter) * width) + cmax = int(cmax + random.uniform(*self.bbox_jitter) * width) + + return rmin, rmax, cmin, cmax + + def apply_to_bbox(self, bbox, **params): + raise NotImplementedError + + def apply_to_keypoint(self, keypoint, **params): + raise NotImplementedError + + @property + def targets_as_params(self): + return ["mask"] + + def get_transform_init_args_names(self): + return ("height", "width", "bbox_jitter", + "expansion_ratio", "min_crop_size", "min_area", "always_resize") + + +def remove_image_only_transforms(sdict): + if not 'transforms' in sdict: + return sdict + + keep_transforms = [] + for tdict in sdict['transforms']: + cls = SERIALIZABLE_REGISTRY[tdict['__class_fullname__']] + if 'transforms' in tdict: + keep_transforms.append(remove_image_only_transforms(tdict)) + elif not issubclass(cls, ImageOnlyTransform): + keep_transforms.append(tdict) + sdict['transforms'] = keep_transforms + + return sdict diff --git a/isegm/engine/optimizer.py b/isegm/engine/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..fd03d8cfc368ee6807fce420ad73e0024a5b6401 --- /dev/null +++ b/isegm/engine/optimizer.py @@ -0,0 +1,27 @@ +import torch +import math +from isegm.utils.log import logger + + +def get_optimizer(model, opt_name, opt_kwargs): + params = [] + base_lr = opt_kwargs['lr'] + for name, param in model.named_parameters(): + param_group = {'params': [param]} + if not param.requires_grad: + params.append(param_group) + continue + + if not math.isclose(getattr(param, 'lr_mult', 1.0), 1.0): + logger.info(f'Applied lr_mult={param.lr_mult} to "{name}" parameter.') + param_group['lr'] = param_group.get('lr', base_lr) * param.lr_mult + + params.append(param_group) + + optimizer = { + 'sgd': torch.optim.SGD, + 'adam': torch.optim.Adam, + 'adamw': torch.optim.AdamW + }[opt_name.lower()](params, 
**opt_kwargs) + + return optimizer diff --git a/isegm/engine/trainer.py b/isegm/engine/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..ba56323dbc0e909ba0c48025bf620868ff90cfb9 --- /dev/null +++ b/isegm/engine/trainer.py @@ -0,0 +1,413 @@ +import os +import random +import logging +from copy import deepcopy +from collections import defaultdict + +import cv2 +import torch +import numpy as np +from tqdm import tqdm +from torch.utils.data import DataLoader + +from isegm.utils.log import logger, TqdmToLogger, SummaryWriterAvg +from isegm.utils.vis import draw_probmap, draw_points +from isegm.utils.misc import save_checkpoint +from isegm.utils.serialization import get_config_repr +from isegm.utils.distributed import get_dp_wrapper, get_sampler, reduce_loss_dict +from .optimizer import get_optimizer + + +class ISTrainer(object): + def __init__(self, model, cfg, model_cfg, loss_cfg, + trainset, valset, + optimizer='adam', + optimizer_params=None, + image_dump_interval=200, + checkpoint_interval=10, + tb_dump_period=25, + max_interactive_points=0, + lr_scheduler=None, + metrics=None, + additional_val_metrics=None, + net_inputs=('images', 'points'), + max_num_next_clicks=0, + click_models=None, + prev_mask_drop_prob=0.0, + ): + self.cfg = cfg + self.model_cfg = model_cfg + self.max_interactive_points = max_interactive_points + self.loss_cfg = loss_cfg + self.val_loss_cfg = deepcopy(loss_cfg) + self.tb_dump_period = tb_dump_period + self.net_inputs = net_inputs + self.max_num_next_clicks = max_num_next_clicks + + self.click_models = click_models + self.prev_mask_drop_prob = prev_mask_drop_prob + + if cfg.distributed: + cfg.batch_size //= cfg.ngpus + cfg.val_batch_size //= cfg.ngpus + + if metrics is None: + metrics = [] + self.train_metrics = metrics + self.val_metrics = deepcopy(metrics) + if additional_val_metrics is not None: + self.val_metrics.extend(additional_val_metrics) + + self.checkpoint_interval = checkpoint_interval + self.image_dump_interval = image_dump_interval + self.task_prefix = '' + self.sw = None + + self.trainset = trainset + self.valset = valset + + logger.info(f'Dataset of {trainset.get_samples_number()} samples was loaded for training.') + logger.info(f'Dataset of {valset.get_samples_number()} samples was loaded for validation.') + + self.train_data = DataLoader( + trainset, cfg.batch_size, + sampler=get_sampler(trainset, shuffle=True, distributed=cfg.distributed), + drop_last=True, pin_memory=True, + num_workers=cfg.workers + ) + + self.val_data = DataLoader( + valset, cfg.val_batch_size, + sampler=get_sampler(valset, shuffle=False, distributed=cfg.distributed), + drop_last=True, pin_memory=True, + num_workers=cfg.workers + ) + + self.optim = get_optimizer(model, optimizer, optimizer_params) + model = self._load_weights(model) + + if cfg.multi_gpu: + model = get_dp_wrapper(cfg.distributed)(model, device_ids=cfg.gpu_ids, + output_device=cfg.gpu_ids[0]) + + if self.is_master: + logger.info(model) + logger.info(get_config_repr(model._config)) + + self.device = cfg.device + self.net = model.to(self.device) + self.lr = optimizer_params['lr'] + + if lr_scheduler is not None: + self.lr_scheduler = lr_scheduler(optimizer=self.optim) + if cfg.start_epoch > 0: + for _ in range(cfg.start_epoch): + self.lr_scheduler.step() + + self.tqdm_out = TqdmToLogger(logger, level=logging.INFO) + + if self.click_models is not None: + for click_model in self.click_models: + for param in click_model.parameters(): + param.requires_grad = False + 
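+                # Click models are kept frozen; they are only used to produce intermediate masks when simulating additional clicks.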
click_model.to(self.device) + click_model.eval() + + def run(self, num_epochs, start_epoch=None, validation=True): + if start_epoch is None: + start_epoch = self.cfg.start_epoch + + logger.info(f'Starting Epoch: {start_epoch}') + logger.info(f'Total Epochs: {num_epochs}') + for epoch in range(start_epoch, num_epochs): + self.training(epoch) + if validation: + self.validation(epoch) + + def training(self, epoch): + if self.sw is None and self.is_master: + self.sw = SummaryWriterAvg(log_dir=str(self.cfg.LOGS_PATH), + flush_secs=10, dump_period=self.tb_dump_period) + + if self.cfg.distributed: + self.train_data.sampler.set_epoch(epoch) + + log_prefix = 'Train' + self.task_prefix.capitalize() + tbar = tqdm(self.train_data, file=self.tqdm_out, ncols=100)\ + if self.is_master else self.train_data + + for metric in self.train_metrics: + metric.reset_epoch_stats() + + self.net.train() + train_loss = 0.0 + for i, batch_data in enumerate(tbar): + global_step = epoch * len(self.train_data) + i + + loss, losses_logging, splitted_batch_data, outputs = \ + self.batch_forward(batch_data) + + self.optim.zero_grad() + loss.backward() + self.optim.step() + + losses_logging['overall'] = loss + reduce_loss_dict(losses_logging) + + train_loss += losses_logging['overall'].item() + + if self.is_master: + for loss_name, loss_value in losses_logging.items(): + self.sw.add_scalar(tag=f'{log_prefix}Losses/{loss_name}', + value=loss_value.item(), + global_step=global_step) + + for k, v in self.loss_cfg.items(): + if '_loss' in k and hasattr(v, 'log_states') and self.loss_cfg.get(k + '_weight', 0.0) > 0: + v.log_states(self.sw, f'{log_prefix}Losses/{k}', global_step) + + if self.image_dump_interval > 0 and global_step % self.image_dump_interval == 0: + self.save_visualization(splitted_batch_data, outputs, global_step, prefix='train') + + self.sw.add_scalar(tag=f'{log_prefix}States/learning_rate', + value=self.lr if not hasattr(self, 'lr_scheduler') else self.lr_scheduler.get_lr()[-1], + global_step=global_step) + + tbar.set_description(f'Epoch {epoch}, training loss {train_loss/(i+1):.4f}') + for metric in self.train_metrics: + metric.log_states(self.sw, f'{log_prefix}Metrics/{metric.name}', global_step) + + if self.is_master: + for metric in self.train_metrics: + self.sw.add_scalar(tag=f'{log_prefix}Metrics/{metric.name}', + value=metric.get_epoch_value(), + global_step=epoch, disable_avg=True) + + save_checkpoint(self.net, self.cfg.CHECKPOINTS_PATH, prefix=self.task_prefix, + epoch=None, multi_gpu=self.cfg.multi_gpu) + + if isinstance(self.checkpoint_interval, (list, tuple)): + checkpoint_interval = [x for x in self.checkpoint_interval if x[0] <= epoch][-1][1] + else: + checkpoint_interval = self.checkpoint_interval + + if epoch % checkpoint_interval == 0: + save_checkpoint(self.net, self.cfg.CHECKPOINTS_PATH, prefix=self.task_prefix, + epoch=epoch, multi_gpu=self.cfg.multi_gpu) + + if hasattr(self, 'lr_scheduler'): + self.lr_scheduler.step() + + def validation(self, epoch): + if self.sw is None and self.is_master: + self.sw = SummaryWriterAvg(log_dir=str(self.cfg.LOGS_PATH), + flush_secs=10, dump_period=self.tb_dump_period) + + log_prefix = 'Val' + self.task_prefix.capitalize() + tbar = tqdm(self.val_data, file=self.tqdm_out, ncols=100) if self.is_master else self.val_data + + for metric in self.val_metrics: + metric.reset_epoch_stats() + + val_loss = 0 + losses_logging = defaultdict(list) + + self.net.eval() + for i, batch_data in enumerate(tbar): + global_step = epoch * len(self.val_data) + i + loss, 
batch_losses_logging, splitted_batch_data, outputs = \ + self.batch_forward(batch_data, validation=True) + + batch_losses_logging['overall'] = loss + reduce_loss_dict(batch_losses_logging) + for loss_name, loss_value in batch_losses_logging.items(): + losses_logging[loss_name].append(loss_value.item()) + + val_loss += batch_losses_logging['overall'].item() + + if self.is_master: + tbar.set_description(f'Epoch {epoch}, validation loss: {val_loss/(i + 1):.4f}') + for metric in self.val_metrics: + metric.log_states(self.sw, f'{log_prefix}Metrics/{metric.name}', global_step) + + if self.is_master: + for loss_name, loss_values in losses_logging.items(): + self.sw.add_scalar(tag=f'{log_prefix}Losses/{loss_name}', value=np.array(loss_values).mean(), + global_step=epoch, disable_avg=True) + + for metric in self.val_metrics: + self.sw.add_scalar(tag=f'{log_prefix}Metrics/{metric.name}', value=metric.get_epoch_value(), + global_step=epoch, disable_avg=True) + + def batch_forward(self, batch_data, validation=False): + metrics = self.val_metrics if validation else self.train_metrics + losses_logging = dict() + + with torch.set_grad_enabled(not validation): + batch_data = {k: v.to(self.device) for k, v in batch_data.items()} + image, gt_mask, points = batch_data['images'], batch_data['instances'], batch_data['points'] + orig_image, orig_gt_mask, orig_points = image.clone(), gt_mask.clone(), points.clone() + + prev_output = torch.zeros_like(image, dtype=torch.float32)[:, :1, :, :] + + last_click_indx = None + + with torch.no_grad(): + num_iters = random.randint(0, self.max_num_next_clicks) + + for click_indx in range(num_iters): + last_click_indx = click_indx + + if not validation: + self.net.eval() + + if self.click_models is None or click_indx >= len(self.click_models): + eval_model = self.net + else: + eval_model = self.click_models[click_indx] + + net_input = torch.cat((image, prev_output), dim=1) if self.net.with_prev_mask else image + prev_output = torch.sigmoid(eval_model(net_input, points)['instances']) + + points = get_next_points(prev_output, orig_gt_mask, points, click_indx + 1) + + if not validation: + self.net.train() + + if self.net.with_prev_mask and self.prev_mask_drop_prob > 0 and last_click_indx is not None: + zero_mask = np.random.random(size=prev_output.size(0)) < self.prev_mask_drop_prob + prev_output[zero_mask] = torch.zeros_like(prev_output[zero_mask]) + + batch_data['points'] = points + + net_input = torch.cat((image, prev_output), dim=1) if self.net.with_prev_mask else image + output = self.net(net_input, points) + + loss = 0.0 + loss = self.add_loss('instance_loss', loss, losses_logging, validation, + lambda: (output['instances'], batch_data['instances'])) + loss = self.add_loss('instance_aux_loss', loss, losses_logging, validation, + lambda: (output['instances_aux'], batch_data['instances'])) + + if self.is_master: + with torch.no_grad(): + for m in metrics: + m.update(*(output.get(x) for x in m.pred_outputs), + *(batch_data[x] for x in m.gt_outputs)) + return loss, losses_logging, batch_data, output + + def add_loss(self, loss_name, total_loss, losses_logging, validation, lambda_loss_inputs): + loss_cfg = self.loss_cfg if not validation else self.val_loss_cfg + loss_weight = loss_cfg.get(loss_name + '_weight', 0.0) + if loss_weight > 0.0: + loss_criterion = loss_cfg.get(loss_name) + loss = loss_criterion(*lambda_loss_inputs()) + loss = torch.mean(loss) + losses_logging[loss_name] = loss + loss = loss_weight * loss + total_loss = total_loss + loss + + return total_loss + + def 
save_visualization(self, splitted_batch_data, outputs, global_step, prefix): + output_images_path = self.cfg.VIS_PATH / prefix + if self.task_prefix: + output_images_path /= self.task_prefix + + if not output_images_path.exists(): + output_images_path.mkdir(parents=True) + image_name_prefix = f'{global_step:06d}' + + def _save_image(suffix, image): + cv2.imwrite(str(output_images_path / f'{image_name_prefix}_{suffix}.jpg'), + image, [cv2.IMWRITE_JPEG_QUALITY, 85]) + + images = splitted_batch_data['images'] + points = splitted_batch_data['points'] + instance_masks = splitted_batch_data['instances'] + + gt_instance_masks = instance_masks.cpu().numpy() + predicted_instance_masks = torch.sigmoid(outputs['instances']).detach().cpu().numpy() + points = points.detach().cpu().numpy() + + image_blob, points = images[0], points[0] + gt_mask = np.squeeze(gt_instance_masks[0], axis=0) + predicted_mask = np.squeeze(predicted_instance_masks[0], axis=0) + + image = image_blob.cpu().numpy() * 255 + image = image.transpose((1, 2, 0)) + + image_with_points = draw_points(image, points[:self.max_interactive_points], (0, 255, 0)) + image_with_points = draw_points(image_with_points, points[self.max_interactive_points:], (0, 0, 255)) + + gt_mask[gt_mask < 0] = 0.25 + gt_mask = draw_probmap(gt_mask) + predicted_mask = draw_probmap(predicted_mask) + viz_image = np.hstack((image_with_points, gt_mask, predicted_mask)).astype(np.uint8) + + _save_image('instance_segmentation', viz_image[:, :, ::-1]) + + def _load_weights(self, net): + if self.cfg.weights is not None: + if os.path.isfile(self.cfg.weights): + load_weights(net, self.cfg.weights) + self.cfg.weights = None + else: + raise RuntimeError(f"=> no checkpoint found at '{self.cfg.weights}'") + elif self.cfg.resume_exp is not None: + checkpoints = list(self.cfg.CHECKPOINTS_PATH.glob(f'{self.cfg.resume_prefix}*.pth')) + assert len(checkpoints) == 1 + + checkpoint_path = checkpoints[0] + logger.info(f'Load checkpoint from path: {checkpoint_path}') + load_weights(net, str(checkpoint_path)) + return net + + @property + def is_master(self): + return self.cfg.local_rank == 0 + + +def get_next_points(pred, gt, points, click_indx, pred_thresh=0.49): + assert click_indx > 0 + pred = pred.cpu().numpy()[:, 0, :, :] + gt = gt.cpu().numpy()[:, 0, :, :] > 0.5 + + fn_mask = np.logical_and(gt, pred < pred_thresh) + fp_mask = np.logical_and(np.logical_not(gt), pred > pred_thresh) + + fn_mask = np.pad(fn_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8) + fp_mask = np.pad(fp_mask, ((0, 0), (1, 1), (1, 1)), 'constant').astype(np.uint8) + num_points = points.size(1) // 2 + points = points.clone() + + for bindx in range(fn_mask.shape[0]): + fn_mask_dt = cv2.distanceTransform(fn_mask[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1] + fp_mask_dt = cv2.distanceTransform(fp_mask[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1] + + fn_max_dist = np.max(fn_mask_dt) + fp_max_dist = np.max(fp_mask_dt) + + is_positive = fn_max_dist > fp_max_dist + dt = fn_mask_dt if is_positive else fp_mask_dt + inner_mask = dt > max(fn_max_dist, fp_max_dist) / 2.0 + indices = np.argwhere(inner_mask) + if len(indices) > 0: + coords = indices[np.random.randint(0, len(indices))] + if is_positive: + points[bindx, num_points - click_indx, 0] = float(coords[0]) + points[bindx, num_points - click_indx, 1] = float(coords[1]) + points[bindx, num_points - click_indx, 2] = float(click_indx) + else: + points[bindx, 2 * num_points - click_indx, 0] = float(coords[0]) + points[bindx, 2 * num_points - click_indx, 1] = float(coords[1]) + 
points[bindx, 2 * num_points - click_indx, 2] = float(click_indx) + + return points + + +def load_weights(model, path_to_weights): + current_state_dict = model.state_dict() + new_state_dict = torch.load(path_to_weights, map_location='cpu')['state_dict'] + current_state_dict.update(new_state_dict) + model.load_state_dict(current_state_dict) diff --git a/isegm/inference/__init__.py b/isegm/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/isegm/inference/clicker.py b/isegm/inference/clicker.py new file mode 100644 index 0000000000000000000000000000000000000000..8789e117b139cd8f99914892022176b774698b2a --- /dev/null +++ b/isegm/inference/clicker.py @@ -0,0 +1,118 @@ +import numpy as np +from copy import deepcopy +import cv2 + + +class Clicker(object): + def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1, click_indx_offset=0): + self.click_indx_offset = click_indx_offset + if gt_mask is not None: + self.gt_mask = gt_mask == 1 + self.not_ignore_mask = gt_mask != ignore_label + else: + self.gt_mask = None + + self.reset_clicks() + + if init_clicks is not None: + for click in init_clicks: + self.add_click(click) + + def make_next_click(self, pred_mask): + assert self.gt_mask is not None + click = self._get_next_click(pred_mask) + self.add_click(click) + + def get_clicks(self, clicks_limit=None): + return self.clicks_list[:clicks_limit] + + def _get_next_click(self, pred_mask, padding=True): + fn_mask = np.logical_and(np.logical_and(self.gt_mask, np.logical_not(pred_mask)), self.not_ignore_mask) + fp_mask = np.logical_and(np.logical_and(np.logical_not(self.gt_mask), pred_mask), self.not_ignore_mask) + + if padding: + fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), 'constant') + fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), 'constant') + + fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0) + fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0) + + if padding: + fn_mask_dt = fn_mask_dt[1:-1, 1:-1] + fp_mask_dt = fp_mask_dt[1:-1, 1:-1] + + fn_mask_dt = fn_mask_dt * self.not_clicked_map + fp_mask_dt = fp_mask_dt * self.not_clicked_map + + fn_max_dist = np.max(fn_mask_dt) + fp_max_dist = np.max(fp_mask_dt) + + is_positive = fn_max_dist > fp_max_dist + if is_positive: + coords_y, coords_x = np.where(fn_mask_dt == fn_max_dist) # coords is [y, x] + else: + coords_y, coords_x = np.where(fp_mask_dt == fp_max_dist) # coords is [y, x] + + return Click(is_positive=is_positive, coords=(coords_y[0], coords_x[0])) + + def add_click(self, click): + coords = click.coords + + click.indx = self.click_indx_offset + self.num_pos_clicks + self.num_neg_clicks + if click.is_positive: + self.num_pos_clicks += 1 + else: + self.num_neg_clicks += 1 + + self.clicks_list.append(click) + if self.gt_mask is not None: + self.not_clicked_map[coords[0], coords[1]] = False + + def _remove_last_click(self): + click = self.clicks_list.pop() + coords = click.coords + + if click.is_positive: + self.num_pos_clicks -= 1 + else: + self.num_neg_clicks -= 1 + + if self.gt_mask is not None: + self.not_clicked_map[coords[0], coords[1]] = True + + def reset_clicks(self): + if self.gt_mask is not None: + self.not_clicked_map = np.ones_like(self.gt_mask, dtype=np.bool) + + self.num_pos_clicks = 0 + self.num_neg_clicks = 0 + + self.clicks_list = [] + + def get_state(self): + return deepcopy(self.clicks_list) + + def set_state(self, state): + self.reset_clicks() + for click in state: + 
self.add_click(click) + + def __len__(self): + return len(self.clicks_list) + + +class Click: + def __init__(self, is_positive, coords, indx=None): + self.is_positive = is_positive + self.coords = coords + self.indx = indx + + @property + def coords_and_indx(self): + return (*self.coords, self.indx) + + def copy(self, **kwargs): + self_copy = deepcopy(self) + for k, v in kwargs.items(): + setattr(self_copy, k, v) + return self_copy diff --git a/isegm/inference/evaluation.py b/isegm/inference/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..ef46e40849a9890151bda1aa1e9b2b55814dfce6 --- /dev/null +++ b/isegm/inference/evaluation.py @@ -0,0 +1,56 @@ +from time import time + +import numpy as np +import torch + +from isegm.inference import utils +from isegm.inference.clicker import Clicker + +try: + get_ipython() + from tqdm import tqdm_notebook as tqdm +except NameError: + from tqdm import tqdm + + +def evaluate_dataset(dataset, predictor, **kwargs): + all_ious = [] + + start_time = time() + for index in tqdm(range(len(dataset)), leave=False): + sample = dataset.get_sample(index) + + _, sample_ious, _ = evaluate_sample(sample.image, sample.gt_mask, predictor, + sample_id=index, **kwargs) + all_ious.append(sample_ious) + end_time = time() + elapsed_time = end_time - start_time + + return all_ious, elapsed_time + + +def evaluate_sample(image, gt_mask, predictor, max_iou_thr, + pred_thr=0.49, min_clicks=1, max_clicks=20, + sample_id=None, callback=None): + clicker = Clicker(gt_mask=gt_mask) + pred_mask = np.zeros_like(gt_mask) + ious_list = [] + + with torch.no_grad(): + predictor.set_input_image(image) + + for click_indx in range(max_clicks): + clicker.make_next_click(pred_mask) + pred_probs = predictor.get_prediction(clicker) + pred_mask = pred_probs > pred_thr + + if callback is not None: + callback(image, gt_mask, pred_probs, sample_id, click_indx, clicker.clicks_list) + + iou = utils.get_iou(gt_mask, pred_mask) + ious_list.append(iou) + + if iou >= max_iou_thr and click_indx + 1 >= min_clicks: + break + + return clicker.clicks_list, np.array(ious_list, dtype=np.float32), pred_probs diff --git a/isegm/inference/predictors/__init__.py b/isegm/inference/predictors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1e5a4f7b58fa6234c898d42f43e91a42669308cd --- /dev/null +++ b/isegm/inference/predictors/__init__.py @@ -0,0 +1,98 @@ +from .base import BasePredictor +from .brs import InputBRSPredictor, FeatureBRSPredictor, HRNetFeatureBRSPredictor +from .brs_functors import InputOptimizer, ScaleBiasOptimizer +from isegm.inference.transforms import ZoomIn +from isegm.model.is_hrnet_model import HRNetModel + + +def get_predictor(net, brs_mode, device, + prob_thresh=0.49, + with_flip=True, + zoom_in_params=dict(), + predictor_params=None, + brs_opt_func_params=None, + lbfgs_params=None): + lbfgs_params_ = { + 'm': 20, + 'factr': 0, + 'pgtol': 1e-8, + 'maxfun': 20, + } + + predictor_params_ = { + 'optimize_after_n_clicks': 1 + } + + if zoom_in_params is not None: + zoom_in = ZoomIn(**zoom_in_params) + else: + zoom_in = None + + if lbfgs_params is not None: + lbfgs_params_.update(lbfgs_params) + lbfgs_params_['maxiter'] = 2 * lbfgs_params_['maxfun'] + + if brs_opt_func_params is None: + brs_opt_func_params = dict() + + if isinstance(net, (list, tuple)): + assert brs_mode == 'NoBRS', "Multi-stage models support only NoBRS mode." 
+ + if brs_mode == 'NoBRS': + if predictor_params is not None: + predictor_params_.update(predictor_params) + predictor = BasePredictor(net, device, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_) + elif brs_mode.startswith('f-BRS'): + predictor_params_.update({ + 'net_clicks_limit': 8, + }) + if predictor_params is not None: + predictor_params_.update(predictor_params) + + insertion_mode = { + 'f-BRS-A': 'after_c4', + 'f-BRS-B': 'after_aspp', + 'f-BRS-C': 'after_deeplab' + }[brs_mode] + + opt_functor = ScaleBiasOptimizer(prob_thresh=prob_thresh, + with_flip=with_flip, + optimizer_params=lbfgs_params_, + **brs_opt_func_params) + + if isinstance(net, HRNetModel): + FeaturePredictor = HRNetFeatureBRSPredictor + insertion_mode = {'after_c4': 'A', 'after_aspp': 'A', 'after_deeplab': 'C'}[insertion_mode] + else: + FeaturePredictor = FeatureBRSPredictor + + predictor = FeaturePredictor(net, device, + opt_functor=opt_functor, + with_flip=with_flip, + insertion_mode=insertion_mode, + zoom_in=zoom_in, + **predictor_params_) + elif brs_mode == 'RGB-BRS' or brs_mode == 'DistMap-BRS': + use_dmaps = brs_mode == 'DistMap-BRS' + + predictor_params_.update({ + 'net_clicks_limit': 5, + }) + if predictor_params is not None: + predictor_params_.update(predictor_params) + + opt_functor = InputOptimizer(prob_thresh=prob_thresh, + with_flip=with_flip, + optimizer_params=lbfgs_params_, + **brs_opt_func_params) + + predictor = InputBRSPredictor(net, device, + optimize_target='dmaps' if use_dmaps else 'rgb', + opt_functor=opt_functor, + with_flip=with_flip, + zoom_in=zoom_in, + **predictor_params_) + else: + raise NotImplementedError + + return predictor diff --git a/isegm/inference/predictors/base.py b/isegm/inference/predictors/base.py new file mode 100644 index 0000000000000000000000000000000000000000..870311726adfc0d5a6600f590f834a973dbefce0 --- /dev/null +++ b/isegm/inference/predictors/base.py @@ -0,0 +1,126 @@ +import torch +import torch.nn.functional as F +from torchvision import transforms +from isegm.inference.transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSide + + +class BasePredictor(object): + def __init__(self, model, device, + net_clicks_limit=None, + with_flip=False, + zoom_in=None, + max_size=None, + **kwargs): + self.with_flip = with_flip + self.net_clicks_limit = net_clicks_limit + self.original_image = None + self.device = device + self.zoom_in = zoom_in + self.prev_prediction = None + self.model_indx = 0 + self.click_models = None + self.net_state_dict = None + + if isinstance(model, tuple): + self.net, self.click_models = model + else: + self.net = model + + self.to_tensor = transforms.ToTensor() + + self.transforms = [zoom_in] if zoom_in is not None else [] + if max_size is not None: + self.transforms.append(LimitLongestSide(max_size=max_size)) + self.transforms.append(SigmoidForPred()) + if with_flip: + self.transforms.append(AddHorizontalFlip()) + + def set_input_image(self, image): + image_nd = self.to_tensor(image) + for transform in self.transforms: + transform.reset() + self.original_image = image_nd.to(self.device) + if len(self.original_image.shape) == 3: + self.original_image = self.original_image.unsqueeze(0) + self.prev_prediction = torch.zeros_like(self.original_image[:, :1, :, :]) + + def get_prediction(self, clicker, prev_mask=None): + clicks_list = clicker.get_clicks() + + if self.click_models is not None: + model_indx = min(clicker.click_indx_offset + len(clicks_list), len(self.click_models)) - 1 + if model_indx != self.model_indx: + 
self.model_indx = model_indx + self.net = self.click_models[model_indx] + + input_image = self.original_image + if prev_mask is None: + prev_mask = self.prev_prediction + if hasattr(self.net, 'with_prev_mask') and self.net.with_prev_mask: + input_image = torch.cat((input_image, prev_mask), dim=1) + image_nd, clicks_lists, is_image_changed = self.apply_transforms( + input_image, [clicks_list] + ) + + pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed) + prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True, + size=image_nd.size()[2:]) + + for t in reversed(self.transforms): + prediction = t.inv_transform(prediction) + + if self.zoom_in is not None and self.zoom_in.check_possible_recalculation(): + return self.get_prediction(clicker) + + self.prev_prediction = prediction + return prediction.cpu().numpy()[0, 0] + + def _get_prediction(self, image_nd, clicks_lists, is_image_changed): + points_nd = self.get_points_nd(clicks_lists) + return self.net(image_nd, points_nd)['instances'] + + def _get_transform_states(self): + return [x.get_state() for x in self.transforms] + + def _set_transform_states(self, states): + assert len(states) == len(self.transforms) + for state, transform in zip(states, self.transforms): + transform.set_state(state) + + def apply_transforms(self, image_nd, clicks_lists): + is_image_changed = False + for t in self.transforms: + image_nd, clicks_lists = t.transform(image_nd, clicks_lists) + is_image_changed |= t.image_changed + + return image_nd, clicks_lists, is_image_changed + + def get_points_nd(self, clicks_lists): + total_clicks = [] + num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists] + num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)] + num_max_points = max(num_pos_clicks + num_neg_clicks) + if self.net_clicks_limit is not None: + num_max_points = min(self.net_clicks_limit, num_max_points) + num_max_points = max(1, num_max_points) + + for clicks_list in clicks_lists: + clicks_list = clicks_list[:self.net_clicks_limit] + pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive] + pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)] + + neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive] + neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)] + total_clicks.append(pos_clicks + neg_clicks) + + return torch.tensor(total_clicks, device=self.device) + + def get_states(self): + return { + 'transform_states': self._get_transform_states(), + 'prev_prediction': self.prev_prediction.clone() + } + + def set_states(self, states): + self._set_transform_states(states['transform_states']) + self.prev_prediction = states['prev_prediction'] diff --git a/isegm/inference/predictors/brs.py b/isegm/inference/predictors/brs.py new file mode 100644 index 0000000000000000000000000000000000000000..910e3fd52471c39fe56668575765adcc00393d3d --- /dev/null +++ b/isegm/inference/predictors/brs.py @@ -0,0 +1,307 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy.optimize import fmin_l_bfgs_b + +from .base import BasePredictor + + +class BRSBasePredictor(BasePredictor): + def __init__(self, model, device, opt_functor, optimize_after_n_clicks=1, **kwargs): + super().__init__(model, device, **kwargs) + self.optimize_after_n_clicks = optimize_after_n_clicks + self.opt_functor = opt_functor + + self.opt_data = None + 
self.input_data = None + + def set_input_image(self, image): + super().set_input_image(image) + self.opt_data = None + self.input_data = None + + def _get_clicks_maps_nd(self, clicks_lists, image_shape, radius=1): + pos_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32) + neg_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32) + + for list_indx, clicks_list in enumerate(clicks_lists): + for click in clicks_list: + y, x = click.coords + y, x = int(round(y)), int(round(x)) + y1, x1 = y - radius, x - radius + y2, x2 = y + radius + 1, x + radius + 1 + + if click.is_positive: + pos_clicks_map[list_indx, 0, y1:y2, x1:x2] = True + else: + neg_clicks_map[list_indx, 0, y1:y2, x1:x2] = True + + with torch.no_grad(): + pos_clicks_map = torch.from_numpy(pos_clicks_map).to(self.device) + neg_clicks_map = torch.from_numpy(neg_clicks_map).to(self.device) + + return pos_clicks_map, neg_clicks_map + + def get_states(self): + return {'transform_states': self._get_transform_states(), 'opt_data': self.opt_data} + + def set_states(self, states): + self._set_transform_states(states['transform_states']) + self.opt_data = states['opt_data'] + + +class FeatureBRSPredictor(BRSBasePredictor): + def __init__(self, model, device, opt_functor, insertion_mode='after_deeplab', **kwargs): + super().__init__(model, device, opt_functor=opt_functor, **kwargs) + self.insertion_mode = insertion_mode + self._c1_features = None + + if self.insertion_mode == 'after_deeplab': + self.num_channels = model.feature_extractor.ch + elif self.insertion_mode == 'after_c4': + self.num_channels = model.feature_extractor.aspp_in_channels + elif self.insertion_mode == 'after_aspp': + self.num_channels = model.feature_extractor.ch + 32 + else: + raise NotImplementedError + + def _get_prediction(self, image_nd, clicks_lists, is_image_changed): + points_nd = self.get_points_nd(clicks_lists) + pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:]) + + num_clicks = len(clicks_lists[0]) + bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0] + + if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs: + self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32) + + if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None: + self.input_data = self._get_head_input(image_nd, points_nd) + + def get_prediction_logits(scale, bias): + scale = scale.view(bs, -1, 1, 1) + bias = bias.view(bs, -1, 1, 1) + if self.with_flip: + scale = scale.repeat(2, 1, 1, 1) + bias = bias.repeat(2, 1, 1, 1) + + scaled_backbone_features = self.input_data * scale + scaled_backbone_features = scaled_backbone_features + bias + if self.insertion_mode == 'after_c4': + x = self.net.feature_extractor.aspp(scaled_backbone_features) + x = F.interpolate(x, mode='bilinear', size=self._c1_features.size()[2:], + align_corners=True) + x = torch.cat((x, self._c1_features), dim=1) + scaled_backbone_features = self.net.feature_extractor.head(x) + elif self.insertion_mode == 'after_aspp': + scaled_backbone_features = self.net.feature_extractor.head(scaled_backbone_features) + + pred_logits = self.net.head(scaled_backbone_features) + pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', + align_corners=True) + return pred_logits + + self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device) + if num_clicks > self.optimize_after_n_clicks: + opt_result = 
fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data, + **self.opt_functor.optimizer_params) + self.opt_data = opt_result[0] + + with torch.no_grad(): + if self.opt_functor.best_prediction is not None: + opt_pred_logits = self.opt_functor.best_prediction + else: + opt_data_nd = torch.from_numpy(self.opt_data).to(self.device) + opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd) + opt_pred_logits = get_prediction_logits(*opt_vars) + + return opt_pred_logits + + def _get_head_input(self, image_nd, points): + with torch.no_grad(): + image_nd, prev_mask = self.net.prepare_input(image_nd) + coord_features = self.net.get_coord_features(image_nd, prev_mask, points) + + if self.net.rgb_conv is not None: + x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1)) + additional_features = None + elif hasattr(self.net, 'maps_transform'): + x = image_nd + additional_features = self.net.maps_transform(coord_features) + + if self.insertion_mode == 'after_c4' or self.insertion_mode == 'after_aspp': + c1, _, c3, c4 = self.net.feature_extractor.backbone(x, additional_features) + c1 = self.net.feature_extractor.skip_project(c1) + + if self.insertion_mode == 'after_aspp': + x = self.net.feature_extractor.aspp(c4) + x = F.interpolate(x, size=c1.size()[2:], mode='bilinear', align_corners=True) + x = torch.cat((x, c1), dim=1) + backbone_features = x + else: + backbone_features = c4 + self._c1_features = c1 + else: + backbone_features = self.net.feature_extractor(x, additional_features)[0] + + return backbone_features + + +class HRNetFeatureBRSPredictor(BRSBasePredictor): + def __init__(self, model, device, opt_functor, insertion_mode='A', **kwargs): + super().__init__(model, device, opt_functor=opt_functor, **kwargs) + self.insertion_mode = insertion_mode + self._c1_features = None + + if self.insertion_mode == 'A': + self.num_channels = sum(k * model.feature_extractor.width for k in [1, 2, 4, 8]) + elif self.insertion_mode == 'C': + self.num_channels = 2 * model.feature_extractor.ocr_width + else: + raise NotImplementedError + + def _get_prediction(self, image_nd, clicks_lists, is_image_changed): + points_nd = self.get_points_nd(clicks_lists) + pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:]) + num_clicks = len(clicks_lists[0]) + bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0] + + if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs: + self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32) + + if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None: + self.input_data = self._get_head_input(image_nd, points_nd) + + def get_prediction_logits(scale, bias): + scale = scale.view(bs, -1, 1, 1) + bias = bias.view(bs, -1, 1, 1) + if self.with_flip: + scale = scale.repeat(2, 1, 1, 1) + bias = bias.repeat(2, 1, 1, 1) + + scaled_backbone_features = self.input_data * scale + scaled_backbone_features = scaled_backbone_features + bias + if self.insertion_mode == 'A': + if self.net.feature_extractor.ocr_width > 0: + out_aux = self.net.feature_extractor.aux_head(scaled_backbone_features) + feats = self.net.feature_extractor.conv3x3_ocr(scaled_backbone_features) + + context = self.net.feature_extractor.ocr_gather_head(feats, out_aux) + feats = self.net.feature_extractor.ocr_distri_head(feats, context) + else: + feats = scaled_backbone_features + pred_logits = self.net.feature_extractor.cls_head(feats) + elif self.insertion_mode == 'C': + pred_logits = 
self.net.feature_extractor.cls_head(scaled_backbone_features) + else: + raise NotImplementedError + + pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', + align_corners=True) + return pred_logits + + self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device) + if num_clicks > self.optimize_after_n_clicks: + opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data, + **self.opt_functor.optimizer_params) + self.opt_data = opt_result[0] + + with torch.no_grad(): + if self.opt_functor.best_prediction is not None: + opt_pred_logits = self.opt_functor.best_prediction + else: + opt_data_nd = torch.from_numpy(self.opt_data).to(self.device) + opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd) + opt_pred_logits = get_prediction_logits(*opt_vars) + + return opt_pred_logits + + def _get_head_input(self, image_nd, points): + with torch.no_grad(): + image_nd, prev_mask = self.net.prepare_input(image_nd) + coord_features = self.net.get_coord_features(image_nd, prev_mask, points) + + if self.net.rgb_conv is not None: + x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1)) + additional_features = None + elif hasattr(self.net, 'maps_transform'): + x = image_nd + additional_features = self.net.maps_transform(coord_features) + + feats = self.net.feature_extractor.compute_hrnet_feats(x, additional_features) + + if self.insertion_mode == 'A': + backbone_features = feats + elif self.insertion_mode == 'C': + out_aux = self.net.feature_extractor.aux_head(feats) + feats = self.net.feature_extractor.conv3x3_ocr(feats) + + context = self.net.feature_extractor.ocr_gather_head(feats, out_aux) + backbone_features = self.net.feature_extractor.ocr_distri_head(feats, context) + else: + raise NotImplementedError + + return backbone_features + + +class InputBRSPredictor(BRSBasePredictor): + def __init__(self, model, device, opt_functor, optimize_target='rgb', **kwargs): + super().__init__(model, device, opt_functor=opt_functor, **kwargs) + self.optimize_target = optimize_target + + def _get_prediction(self, image_nd, clicks_lists, is_image_changed): + points_nd = self.get_points_nd(clicks_lists) + pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:]) + num_clicks = len(clicks_lists[0]) + + if self.opt_data is None or is_image_changed: + if self.optimize_target == 'dmaps': + opt_channels = self.net.coord_feature_ch - 1 if self.net.with_prev_mask else self.net.coord_feature_ch + else: + opt_channels = 3 + bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0] + self.opt_data = torch.zeros((bs, opt_channels, image_nd.shape[2], image_nd.shape[3]), + device=self.device, dtype=torch.float32) + + def get_prediction_logits(opt_bias): + input_image, prev_mask = self.net.prepare_input(image_nd) + dmaps = self.net.get_coord_features(input_image, prev_mask, points_nd) + + if self.optimize_target == 'rgb': + input_image = input_image + opt_bias + elif self.optimize_target == 'dmaps': + if self.net.with_prev_mask: + dmaps[:, 1:, :, :] = dmaps[:, 1:, :, :] + opt_bias + else: + dmaps = dmaps + opt_bias + + if self.net.rgb_conv is not None: + x = self.net.rgb_conv(torch.cat((input_image, dmaps), dim=1)) + if self.optimize_target == 'all': + x = x + opt_bias + coord_features = None + elif hasattr(self.net, 'maps_transform'): + x = input_image + coord_features = self.net.maps_transform(dmaps) + + pred_logits = self.net.backbone_forward(x, coord_features=coord_features)['instances'] + pred_logits = 
F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', align_corners=True) + + return pred_logits + + self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device, + shape=self.opt_data.shape) + if num_clicks > self.optimize_after_n_clicks: + opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data.cpu().numpy().ravel(), + **self.opt_functor.optimizer_params) + + self.opt_data = torch.from_numpy(opt_result[0]).view(self.opt_data.shape).to(self.device) + + with torch.no_grad(): + if self.opt_functor.best_prediction is not None: + opt_pred_logits = self.opt_functor.best_prediction + else: + opt_vars, _ = self.opt_functor.unpack_opt_params(self.opt_data) + opt_pred_logits = get_prediction_logits(*opt_vars) + + return opt_pred_logits diff --git a/isegm/inference/predictors/brs_functors.py b/isegm/inference/predictors/brs_functors.py new file mode 100644 index 0000000000000000000000000000000000000000..f919e13c6c9edb6a9eb7c4afc37933db7b303c12 --- /dev/null +++ b/isegm/inference/predictors/brs_functors.py @@ -0,0 +1,109 @@ +import torch +import numpy as np + +from isegm.model.metrics import _compute_iou +from .brs_losses import BRSMaskLoss + + +class BaseOptimizer: + def __init__(self, optimizer_params, + prob_thresh=0.49, + reg_weight=1e-3, + min_iou_diff=0.01, + brs_loss=BRSMaskLoss(), + with_flip=False, + flip_average=False, + **kwargs): + self.brs_loss = brs_loss + self.optimizer_params = optimizer_params + self.prob_thresh = prob_thresh + self.reg_weight = reg_weight + self.min_iou_diff = min_iou_diff + self.with_flip = with_flip + self.flip_average = flip_average + + self.best_prediction = None + self._get_prediction_logits = None + self._opt_shape = None + self._best_loss = None + self._click_masks = None + self._last_mask = None + self.device = None + + def init_click(self, get_prediction_logits, pos_mask, neg_mask, device, shape=None): + self.best_prediction = None + self._get_prediction_logits = get_prediction_logits + self._click_masks = (pos_mask, neg_mask) + self._opt_shape = shape + self._last_mask = None + self.device = device + + def __call__(self, x): + opt_params = torch.from_numpy(x).float().to(self.device) + opt_params.requires_grad_(True) + + with torch.enable_grad(): + opt_vars, reg_loss = self.unpack_opt_params(opt_params) + result_before_sigmoid = self._get_prediction_logits(*opt_vars) + result = torch.sigmoid(result_before_sigmoid) + + pos_mask, neg_mask = self._click_masks + if self.with_flip and self.flip_average: + result, result_flipped = torch.chunk(result, 2, dim=0) + result = 0.5 * (result + torch.flip(result_flipped, dims=[3])) + pos_mask, neg_mask = pos_mask[:result.shape[0]], neg_mask[:result.shape[0]] + + loss, f_max_pos, f_max_neg = self.brs_loss(result, pos_mask, neg_mask) + loss = loss + reg_loss + + f_val = loss.detach().cpu().numpy() + if self.best_prediction is None or f_val < self._best_loss: + self.best_prediction = result_before_sigmoid.detach() + self._best_loss = f_val + + if f_max_pos < (1 - self.prob_thresh) and f_max_neg < self.prob_thresh: + return [f_val, np.zeros_like(x)] + + current_mask = result > self.prob_thresh + if self._last_mask is not None and self.min_iou_diff > 0: + diff_iou = _compute_iou(current_mask, self._last_mask) + if len(diff_iou) > 0 and diff_iou.mean() > 1 - self.min_iou_diff: + return [f_val, np.zeros_like(x)] + self._last_mask = current_mask + + loss.backward() + f_grad = opt_params.grad.cpu().numpy().ravel().astype(np.float) + + return [f_val, f_grad] + + def 
unpack_opt_params(self, opt_params): + raise NotImplementedError + + +class InputOptimizer(BaseOptimizer): + def unpack_opt_params(self, opt_params): + opt_params = opt_params.view(self._opt_shape) + if self.with_flip: + opt_params_flipped = torch.flip(opt_params, dims=[3]) + opt_params = torch.cat([opt_params, opt_params_flipped], dim=0) + reg_loss = self.reg_weight * torch.sum(opt_params**2) + + return (opt_params,), reg_loss + + +class ScaleBiasOptimizer(BaseOptimizer): + def __init__(self, *args, scale_act=None, reg_bias_weight=10.0, **kwargs): + super().__init__(*args, **kwargs) + self.scale_act = scale_act + self.reg_bias_weight = reg_bias_weight + + def unpack_opt_params(self, opt_params): + scale, bias = torch.chunk(opt_params, 2, dim=0) + reg_loss = self.reg_weight * (torch.sum(scale**2) + self.reg_bias_weight * torch.sum(bias**2)) + + if self.scale_act == 'tanh': + scale = torch.tanh(scale) + elif self.scale_act == 'sin': + scale = torch.sin(scale) + + return (1 + scale, bias), reg_loss diff --git a/isegm/inference/predictors/brs_losses.py b/isegm/inference/predictors/brs_losses.py new file mode 100644 index 0000000000000000000000000000000000000000..ea98824356cf5a4d09094fb92c13ee8d8dfe15dc --- /dev/null +++ b/isegm/inference/predictors/brs_losses.py @@ -0,0 +1,58 @@ +import torch + +from isegm.model.losses import SigmoidBinaryCrossEntropyLoss + + +class BRSMaskLoss(torch.nn.Module): + def __init__(self, eps=1e-5): + super().__init__() + self._eps = eps + + def forward(self, result, pos_mask, neg_mask): + pos_diff = (1 - result) * pos_mask + pos_target = torch.sum(pos_diff ** 2) + pos_target = pos_target / (torch.sum(pos_mask) + self._eps) + + neg_diff = result * neg_mask + neg_target = torch.sum(neg_diff ** 2) + neg_target = neg_target / (torch.sum(neg_mask) + self._eps) + + loss = pos_target + neg_target + + with torch.no_grad(): + f_max_pos = torch.max(torch.abs(pos_diff)).item() + f_max_neg = torch.max(torch.abs(neg_diff)).item() + + return loss, f_max_pos, f_max_neg + + +class OracleMaskLoss(torch.nn.Module): + def __init__(self): + super().__init__() + self.gt_mask = None + self.loss = SigmoidBinaryCrossEntropyLoss(from_sigmoid=True) + self.predictor = None + self.history = [] + + def set_gt_mask(self, gt_mask): + self.gt_mask = gt_mask + self.history = [] + + def forward(self, result, pos_mask, neg_mask): + gt_mask = self.gt_mask.to(result.device) + if self.predictor.object_roi is not None: + r1, r2, c1, c2 = self.predictor.object_roi[:4] + gt_mask = gt_mask[:, :, r1:r2 + 1, c1:c2 + 1] + gt_mask = torch.nn.functional.interpolate(gt_mask, result.size()[2:], mode='bilinear', align_corners=True) + + if result.shape[0] == 2: + gt_mask_flipped = torch.flip(gt_mask, dims=[3]) + gt_mask = torch.cat([gt_mask, gt_mask_flipped], dim=0) + + loss = self.loss(result, gt_mask) + self.history.append(loss.detach().cpu().numpy()[0]) + + if len(self.history) > 5 and abs(self.history[-5] - self.history[-1]) < 1e-5: + return 0, 0, 0 + + return loss, 1.0, 1.0 diff --git a/isegm/inference/transforms/__init__.py b/isegm/inference/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cbd54e38a2f84b3fef481672a7ceab070eb01b82 --- /dev/null +++ b/isegm/inference/transforms/__init__.py @@ -0,0 +1,5 @@ +from .base import SigmoidForPred +from .flip import AddHorizontalFlip +from .zoom_in import ZoomIn +from .limit_longest_side import LimitLongestSide +from .crops import Crops diff --git a/isegm/inference/transforms/base.py b/isegm/inference/transforms/base.py new 
file mode 100644 index 0000000000000000000000000000000000000000..eb5a2deb3c44f5aed7530fd1e299fff1273737b8 --- /dev/null +++ b/isegm/inference/transforms/base.py @@ -0,0 +1,38 @@ +import torch + + +class BaseTransform(object): + def __init__(self): + self.image_changed = False + + def transform(self, image_nd, clicks_lists): + raise NotImplementedError + + def inv_transform(self, prob_map): + raise NotImplementedError + + def reset(self): + raise NotImplementedError + + def get_state(self): + raise NotImplementedError + + def set_state(self, state): + raise NotImplementedError + + +class SigmoidForPred(BaseTransform): + def transform(self, image_nd, clicks_lists): + return image_nd, clicks_lists + + def inv_transform(self, prob_map): + return torch.sigmoid(prob_map) + + def reset(self): + pass + + def get_state(self): + return None + + def set_state(self, state): + pass diff --git a/isegm/inference/transforms/crops.py b/isegm/inference/transforms/crops.py new file mode 100644 index 0000000000000000000000000000000000000000..428d977295e2ff973b5aa1bf0a0c955df1235614 --- /dev/null +++ b/isegm/inference/transforms/crops.py @@ -0,0 +1,97 @@ +import math + +import torch +import numpy as np +from typing import List + +from isegm.inference.clicker import Click +from .base import BaseTransform + + +class Crops(BaseTransform): + def __init__(self, crop_size=(320, 480), min_overlap=0.2): + super().__init__() + self.crop_height, self.crop_width = crop_size + self.min_overlap = min_overlap + + self.x_offsets = None + self.y_offsets = None + self._counts = None + + def transform(self, image_nd, clicks_lists: List[List[Click]]): + assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 + image_height, image_width = image_nd.shape[2:4] + self._counts = None + + if image_height < self.crop_height or image_width < self.crop_width: + return image_nd, clicks_lists + + self.x_offsets = get_offsets(image_width, self.crop_width, self.min_overlap) + self.y_offsets = get_offsets(image_height, self.crop_height, self.min_overlap) + self._counts = np.zeros((image_height, image_width)) + + image_crops = [] + for dy in self.y_offsets: + for dx in self.x_offsets: + self._counts[dy:dy + self.crop_height, dx:dx + self.crop_width] += 1 + image_crop = image_nd[:, :, dy:dy + self.crop_height, dx:dx + self.crop_width] + image_crops.append(image_crop) + image_crops = torch.cat(image_crops, dim=0) + self._counts = torch.tensor(self._counts, device=image_nd.device, dtype=torch.float32) + + clicks_list = clicks_lists[0] + clicks_lists = [] + for dy in self.y_offsets: + for dx in self.x_offsets: + crop_clicks = [x.copy(coords=(x.coords[0] - dy, x.coords[1] - dx)) for x in clicks_list] + clicks_lists.append(crop_clicks) + + return image_crops, clicks_lists + + def inv_transform(self, prob_map): + if self._counts is None: + return prob_map + + new_prob_map = torch.zeros((1, 1, *self._counts.shape), + dtype=prob_map.dtype, device=prob_map.device) + + crop_indx = 0 + for dy in self.y_offsets: + for dx in self.x_offsets: + new_prob_map[0, 0, dy:dy + self.crop_height, dx:dx + self.crop_width] += prob_map[crop_indx, 0] + crop_indx += 1 + new_prob_map = torch.div(new_prob_map, self._counts) + + return new_prob_map + + def get_state(self): + return self.x_offsets, self.y_offsets, self._counts + + def set_state(self, state): + self.x_offsets, self.y_offsets, self._counts = state + + def reset(self): + self.x_offsets = None + self.y_offsets = None + self._counts = None + + +def get_offsets(length, crop_size, min_overlap_ratio=0.2): + if 
length == crop_size: + return [0] + + N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio) + N = math.ceil(N) + + overlap_ratio = (N - length / crop_size) / (N - 1) + overlap_width = int(crop_size * overlap_ratio) + + offsets = [0] + for i in range(1, N): + new_offset = offsets[-1] + crop_size - overlap_width + if new_offset + crop_size > length: + new_offset = length - crop_size + + offsets.append(new_offset) + + return offsets diff --git a/isegm/inference/transforms/flip.py b/isegm/inference/transforms/flip.py new file mode 100644 index 0000000000000000000000000000000000000000..373640ebe153ae8a53c136c72f13e0c14aa788ec --- /dev/null +++ b/isegm/inference/transforms/flip.py @@ -0,0 +1,37 @@ +import torch + +from typing import List +from isegm.inference.clicker import Click +from .base import BaseTransform + + +class AddHorizontalFlip(BaseTransform): + def transform(self, image_nd, clicks_lists: List[List[Click]]): + assert len(image_nd.shape) == 4 + image_nd = torch.cat([image_nd, torch.flip(image_nd, dims=[3])], dim=0) + + image_width = image_nd.shape[3] + clicks_lists_flipped = [] + for clicks_list in clicks_lists: + clicks_list_flipped = [click.copy(coords=(click.coords[0], image_width - click.coords[1] - 1)) + for click in clicks_list] + clicks_lists_flipped.append(clicks_list_flipped) + clicks_lists = clicks_lists + clicks_lists_flipped + + return image_nd, clicks_lists + + def inv_transform(self, prob_map): + assert len(prob_map.shape) == 4 and prob_map.shape[0] % 2 == 0 + num_maps = prob_map.shape[0] // 2 + prob_map, prob_map_flipped = prob_map[:num_maps], prob_map[num_maps:] + + return 0.5 * (prob_map + torch.flip(prob_map_flipped, dims=[3])) + + def get_state(self): + return None + + def set_state(self, state): + pass + + def reset(self): + pass diff --git a/isegm/inference/transforms/limit_longest_side.py b/isegm/inference/transforms/limit_longest_side.py new file mode 100644 index 0000000000000000000000000000000000000000..50c5a53d2670df52285621dc0d33e86df520d77c --- /dev/null +++ b/isegm/inference/transforms/limit_longest_side.py @@ -0,0 +1,22 @@ +from .zoom_in import ZoomIn, get_roi_image_nd + + +class LimitLongestSide(ZoomIn): + def __init__(self, max_size=800): + super().__init__(target_size=max_size, skip_clicks=0) + + def transform(self, image_nd, clicks_lists): + assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 + image_max_size = max(image_nd.shape[2:4]) + self.image_changed = False + + if image_max_size <= self.target_size: + return image_nd, clicks_lists + self._input_image = image_nd + + self._object_roi = (0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1) + self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size) + self.image_changed = True + + tclicks_lists = [self._transform_clicks(clicks_lists[0])] + return self._roi_image, tclicks_lists diff --git a/isegm/inference/transforms/zoom_in.py b/isegm/inference/transforms/zoom_in.py new file mode 100644 index 0000000000000000000000000000000000000000..04b576a3e351aa7ad723fd447b309615648bc55d --- /dev/null +++ b/isegm/inference/transforms/zoom_in.py @@ -0,0 +1,175 @@ +import torch + +from typing import List +from isegm.inference.clicker import Click +from isegm.utils.misc import get_bbox_iou, get_bbox_from_mask, expand_bbox, clamp_bbox +from .base import BaseTransform + + +class ZoomIn(BaseTransform): + def __init__(self, + target_size=400, + skip_clicks=1, + expansion_ratio=1.4, + min_crop_size=200, + recompute_thresh_iou=0.5, + prob_thresh=0.50): + super().__init__() + 
self.target_size = target_size + self.min_crop_size = min_crop_size + self.skip_clicks = skip_clicks + self.expansion_ratio = expansion_ratio + self.recompute_thresh_iou = recompute_thresh_iou + self.prob_thresh = prob_thresh + + self._input_image_shape = None + self._prev_probs = None + self._object_roi = None + self._roi_image = None + + def transform(self, image_nd, clicks_lists: List[List[Click]]): + assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 + self.image_changed = False + + clicks_list = clicks_lists[0] + if len(clicks_list) <= self.skip_clicks: + return image_nd, clicks_lists + + self._input_image_shape = image_nd.shape + + current_object_roi = None + if self._prev_probs is not None: + current_pred_mask = (self._prev_probs > self.prob_thresh)[0, 0] + if current_pred_mask.sum() > 0: + current_object_roi = get_object_roi(current_pred_mask, clicks_list, + self.expansion_ratio, self.min_crop_size) + + if current_object_roi is None: + if self.skip_clicks >= 0: + return image_nd, clicks_lists + else: + current_object_roi = 0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1 + + update_object_roi = False + if self._object_roi is None: + update_object_roi = True + elif not check_object_roi(self._object_roi, clicks_list): + update_object_roi = True + elif get_bbox_iou(current_object_roi, self._object_roi) < self.recompute_thresh_iou: + update_object_roi = True + + if update_object_roi: + self._object_roi = current_object_roi + self.image_changed = True + self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size) + + tclicks_lists = [self._transform_clicks(clicks_list)] + return self._roi_image.to(image_nd.device), tclicks_lists + + def inv_transform(self, prob_map): + if self._object_roi is None: + self._prev_probs = prob_map.cpu().numpy() + return prob_map + + assert prob_map.shape[0] == 1 + rmin, rmax, cmin, cmax = self._object_roi + prob_map = torch.nn.functional.interpolate(prob_map, size=(rmax - rmin + 1, cmax - cmin + 1), + mode='bilinear', align_corners=True) + + if self._prev_probs is not None: + new_prob_map = torch.zeros(*self._prev_probs.shape, device=prob_map.device, dtype=prob_map.dtype) + new_prob_map[:, :, rmin:rmax + 1, cmin:cmax + 1] = prob_map + else: + new_prob_map = prob_map + + self._prev_probs = new_prob_map.cpu().numpy() + + return new_prob_map + + def check_possible_recalculation(self): + if self._prev_probs is None or self._object_roi is not None or self.skip_clicks > 0: + return False + + pred_mask = (self._prev_probs > self.prob_thresh)[0, 0] + if pred_mask.sum() > 0: + possible_object_roi = get_object_roi(pred_mask, [], + self.expansion_ratio, self.min_crop_size) + image_roi = (0, self._input_image_shape[2] - 1, 0, self._input_image_shape[3] - 1) + if get_bbox_iou(possible_object_roi, image_roi) < 0.50: + return True + return False + + def get_state(self): + roi_image = self._roi_image.cpu() if self._roi_image is not None else None + return self._input_image_shape, self._object_roi, self._prev_probs, roi_image, self.image_changed + + def set_state(self, state): + self._input_image_shape, self._object_roi, self._prev_probs, self._roi_image, self.image_changed = state + + def reset(self): + self._input_image_shape = None + self._object_roi = None + self._prev_probs = None + self._roi_image = None + self.image_changed = False + + def _transform_clicks(self, clicks_list): + if self._object_roi is None: + return clicks_list + + rmin, rmax, cmin, cmax = self._object_roi + crop_height, crop_width = self._roi_image.shape[2:] + + 
transformed_clicks = [] + for click in clicks_list: + new_r = crop_height * (click.coords[0] - rmin) / (rmax - rmin + 1) + new_c = crop_width * (click.coords[1] - cmin) / (cmax - cmin + 1) + transformed_clicks.append(click.copy(coords=(new_r, new_c))) + return transformed_clicks + + +def get_object_roi(pred_mask, clicks_list, expansion_ratio, min_crop_size): + pred_mask = pred_mask.copy() + + for click in clicks_list: + if click.is_positive: + pred_mask[int(click.coords[0]), int(click.coords[1])] = 1 + + bbox = get_bbox_from_mask(pred_mask) + bbox = expand_bbox(bbox, expansion_ratio, min_crop_size) + h, w = pred_mask.shape[0], pred_mask.shape[1] + bbox = clamp_bbox(bbox, 0, h - 1, 0, w - 1) + + return bbox + + +def get_roi_image_nd(image_nd, object_roi, target_size): + rmin, rmax, cmin, cmax = object_roi + + height = rmax - rmin + 1 + width = cmax - cmin + 1 + + if isinstance(target_size, tuple): + new_height, new_width = target_size + else: + scale = target_size / max(height, width) + new_height = int(round(height * scale)) + new_width = int(round(width * scale)) + + with torch.no_grad(): + roi_image_nd = image_nd[:, :, rmin:rmax + 1, cmin:cmax + 1] + roi_image_nd = torch.nn.functional.interpolate(roi_image_nd, size=(new_height, new_width), + mode='bilinear', align_corners=True) + + return roi_image_nd + + +def check_object_roi(object_roi, clicks_list): + for click in clicks_list: + if click.is_positive: + if click.coords[0] < object_roi[0] or click.coords[0] >= object_roi[1]: + return False + if click.coords[1] < object_roi[2] or click.coords[1] >= object_roi[3]: + return False + + return True diff --git a/isegm/inference/utils.py b/isegm/inference/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7102d4075c9812de6a26b21a8a8946c44c3ddb3f --- /dev/null +++ b/isegm/inference/utils.py @@ -0,0 +1,143 @@ +from datetime import timedelta +from pathlib import Path + +import torch +import numpy as np + +from isegm.data.datasets import GrabCutDataset, BerkeleyDataset, DavisDataset, SBDEvaluationDataset, PascalVocDataset +from isegm.utils.serialization import load_model + + +def get_time_metrics(all_ious, elapsed_time): + n_images = len(all_ious) + n_clicks = sum(map(len, all_ious)) + + mean_spc = elapsed_time / n_clicks + mean_spi = elapsed_time / n_images + + return mean_spc, mean_spi + + +def load_is_model(checkpoint, device, **kwargs): + if isinstance(checkpoint, (str, Path)): + state_dict = torch.load(checkpoint, map_location='cpu') + else: + state_dict = checkpoint + + if isinstance(state_dict, list): + model = load_single_is_model(state_dict[0], device, **kwargs) + models = [load_single_is_model(x, device, **kwargs) for x in state_dict] + + return model, models + else: + return load_single_is_model(state_dict, device, **kwargs) + + +def load_single_is_model(state_dict, device, **kwargs): + model = load_model(state_dict['config'], **kwargs) + model.load_state_dict(state_dict['state_dict'], strict=False) + + for param in model.parameters(): + param.requires_grad = False + model.to(device) + model.eval() + + return model + + +def get_dataset(dataset_name, cfg): + if dataset_name == 'GrabCut': + dataset = GrabCutDataset(cfg.GRABCUT_PATH) + elif dataset_name == 'Berkeley': + dataset = BerkeleyDataset(cfg.BERKELEY_PATH) + elif dataset_name == 'DAVIS': + dataset = DavisDataset(cfg.DAVIS_PATH) + elif dataset_name == 'SBD': + dataset = SBDEvaluationDataset(cfg.SBD_PATH) + elif dataset_name == 'SBD_Train': + dataset = SBDEvaluationDataset(cfg.SBD_PATH, split='train') + elif 
dataset_name == 'PascalVOC': + dataset = PascalVocDataset(cfg.PASCALVOC_PATH, split='test') + elif dataset_name == 'COCO_MVal': + dataset = DavisDataset(cfg.COCO_MVAL_PATH) + else: + dataset = None + + return dataset + + +def get_iou(gt_mask, pred_mask, ignore_label=-1): + ignore_gt_mask_inv = gt_mask != ignore_label + obj_gt_mask = gt_mask == 1 + + intersection = np.logical_and(np.logical_and(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum() + union = np.logical_and(np.logical_or(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum() + + return intersection / union + + +def compute_noc_metric(all_ious, iou_thrs, max_clicks=20): + def _get_noc(iou_arr, iou_thr): + vals = iou_arr >= iou_thr + return np.argmax(vals) + 1 if np.any(vals) else max_clicks + + noc_list = [] + over_max_list = [] + for iou_thr in iou_thrs: + scores_arr = np.array([_get_noc(iou_arr, iou_thr) + for iou_arr in all_ious], dtype=np.int) + + score = scores_arr.mean() + over_max = (scores_arr == max_clicks).sum() + + noc_list.append(score) + over_max_list.append(over_max) + + return noc_list, over_max_list + + +def find_checkpoint(weights_folder, checkpoint_name): + weights_folder = Path(weights_folder) + if ':' in checkpoint_name: + model_name, checkpoint_name = checkpoint_name.split(':') + models_candidates = [x for x in weights_folder.glob(f'{model_name}*') if x.is_dir()] + assert len(models_candidates) == 1 + model_folder = models_candidates[0] + else: + model_folder = weights_folder + + if checkpoint_name.endswith('.pth'): + if Path(checkpoint_name).exists(): + checkpoint_path = checkpoint_name + else: + checkpoint_path = weights_folder / checkpoint_name + else: + model_checkpoints = list(model_folder.rglob(f'{checkpoint_name}*.pth')) + assert len(model_checkpoints) == 1 + checkpoint_path = model_checkpoints[0] + + return str(checkpoint_path) + + +def get_results_table(noc_list, over_max_list, brs_type, dataset_name, mean_spc, elapsed_time, + n_clicks=20, model_name=None): + table_header = (f'|{"BRS Type":^13}|{"Dataset":^11}|' + f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|' + f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|' + f'{"SPC,s":^7}|{"Time":^9}|') + row_width = len(table_header) + + header = f'Eval results for model: {model_name}\n' if model_name is not None else '' + header += '-' * row_width + '\n' + header += table_header + '\n' + '-' * row_width + + eval_time = str(timedelta(seconds=int(elapsed_time))) + table_row = f'|{brs_type:^13}|{dataset_name:^11}|' + table_row += f'{noc_list[0]:^9.2f}|' + table_row += f'{noc_list[1]:^9.2f}|' if len(noc_list) > 1 else f'{"?":^9}|' + table_row += f'{noc_list[2]:^9.2f}|' if len(noc_list) > 2 else f'{"?":^9}|' + table_row += f'{over_max_list[1]:^9}|' if len(noc_list) > 1 else f'{"?":^9}|' + table_row += f'{over_max_list[2]:^9}|' if len(noc_list) > 2 else f'{"?":^9}|' + table_row += f'{mean_spc:^7.3f}|{eval_time:^9}|' + + return header, table_row \ No newline at end of file diff --git a/isegm/model/initializer.py b/isegm/model/initializer.py new file mode 100644 index 0000000000000000000000000000000000000000..470c7df4659bc1e80ceec80a170b3b2e0302fb84 --- /dev/null +++ b/isegm/model/initializer.py @@ -0,0 +1,105 @@ +import torch +import torch.nn as nn +import numpy as np + + +class Initializer(object): + def __init__(self, local_init=True, gamma=None): + self.local_init = local_init + self.gamma = gamma + + def __call__(self, m): + if getattr(m, '__initialized', False): + return + + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, + 
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, + nn.GroupNorm, nn.SyncBatchNorm)) or 'BatchNorm' in m.__class__.__name__: + if m.weight is not None: + self._init_gamma(m.weight.data) + if m.bias is not None: + self._init_beta(m.bias.data) + else: + if getattr(m, 'weight', None) is not None: + self._init_weight(m.weight.data) + if getattr(m, 'bias', None) is not None: + self._init_bias(m.bias.data) + + if self.local_init: + object.__setattr__(m, '__initialized', True) + + def _init_weight(self, data): + nn.init.uniform_(data, -0.07, 0.07) + + def _init_bias(self, data): + nn.init.constant_(data, 0) + + def _init_gamma(self, data): + if self.gamma is None: + nn.init.constant_(data, 1.0) + else: + nn.init.normal_(data, 1.0, self.gamma) + + def _init_beta(self, data): + nn.init.constant_(data, 0) + + +class Bilinear(Initializer): + def __init__(self, scale, groups, in_channels, **kwargs): + super().__init__(**kwargs) + self.scale = scale + self.groups = groups + self.in_channels = in_channels + + def _init_weight(self, data): + """Reset the weight and bias.""" + bilinear_kernel = self.get_bilinear_kernel(self.scale) + weight = torch.zeros_like(data) + for i in range(self.in_channels): + if self.groups == 1: + j = i + else: + j = 0 + weight[i, j] = bilinear_kernel + data[:] = weight + + @staticmethod + def get_bilinear_kernel(scale): + """Generate a bilinear upsampling kernel.""" + kernel_size = 2 * scale - scale % 2 + scale = (kernel_size + 1) // 2 + center = scale - 0.5 * (1 + kernel_size % 2) + + og = np.ogrid[:kernel_size, :kernel_size] + kernel = (1 - np.abs(og[0] - center) / scale) * (1 - np.abs(og[1] - center) / scale) + + return torch.tensor(kernel, dtype=torch.float32) + + +class XavierGluon(Initializer): + def __init__(self, rnd_type='uniform', factor_type='avg', magnitude=3, **kwargs): + super().__init__(**kwargs) + + self.rnd_type = rnd_type + self.factor_type = factor_type + self.magnitude = float(magnitude) + + def _init_weight(self, arr): + fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr) + + if self.factor_type == 'avg': + factor = (fan_in + fan_out) / 2.0 + elif self.factor_type == 'in': + factor = fan_in + elif self.factor_type == 'out': + factor = fan_out + else: + raise ValueError('Incorrect factor type') + scale = np.sqrt(self.magnitude / factor) + + if self.rnd_type == 'uniform': + nn.init.uniform_(arr, -scale, scale) + elif self.rnd_type == 'gaussian': + nn.init.normal_(arr, 0, scale) + else: + raise ValueError('Unknown random type') diff --git a/isegm/model/is_deeplab_model.py b/isegm/model/is_deeplab_model.py new file mode 100644 index 0000000000000000000000000000000000000000..45fa55364d14d129889fce083a791be1e48a35c9 --- /dev/null +++ b/isegm/model/is_deeplab_model.py @@ -0,0 +1,25 @@ +import torch.nn as nn + +from isegm.utils.serialization import serialize +from .is_model import ISModel +from .modeling.deeplab_v3 import DeepLabV3Plus +from .modeling.basic_blocks import SepConvHead +from isegm.model.modifiers import LRMult + + +class DeeplabModel(ISModel): + @serialize + def __init__(self, backbone='resnet50', deeplab_ch=256, aspp_dropout=0.5, + backbone_norm_layer=None, backbone_lr_mult=0.1, norm_layer=nn.BatchNorm2d, **kwargs): + super().__init__(norm_layer=norm_layer, **kwargs) + + self.feature_extractor = DeepLabV3Plus(backbone=backbone, ch=deeplab_ch, project_dropout=aspp_dropout, + norm_layer=norm_layer, backbone_norm_layer=backbone_norm_layer) + self.feature_extractor.backbone.apply(LRMult(backbone_lr_mult)) + self.head = SepConvHead(1, 
in_channels=deeplab_ch, mid_channels=deeplab_ch // 2, + num_layers=2, norm_layer=norm_layer) + + def backbone_forward(self, image, coord_features=None): + backbone_features = self.feature_extractor(image, coord_features) + + return {'instances': self.head(backbone_features[0])} diff --git a/isegm/model/is_hrnet_model.py b/isegm/model/is_hrnet_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b8a82e746adf49e44d7ff011bef3c7cb105ae4cb --- /dev/null +++ b/isegm/model/is_hrnet_model.py @@ -0,0 +1,26 @@ +import torch.nn as nn + +from isegm.utils.serialization import serialize +from .is_model import ISModel +from .modeling.hrnet_ocr import HighResolutionNet +from isegm.model.modifiers import LRMult + + +class HRNetModel(ISModel): + @serialize + def __init__(self, width=48, ocr_width=256, small=False, backbone_lr_mult=0.1, + norm_layer=nn.BatchNorm2d, **kwargs): + super().__init__(norm_layer=norm_layer, **kwargs) + + self.feature_extractor = HighResolutionNet(width=width, ocr_width=ocr_width, small=small, + num_classes=1, norm_layer=norm_layer) + self.feature_extractor.apply(LRMult(backbone_lr_mult)) + if ocr_width > 0: + self.feature_extractor.ocr_distri_head.apply(LRMult(1.0)) + self.feature_extractor.ocr_gather_head.apply(LRMult(1.0)) + self.feature_extractor.conv3x3_ocr.apply(LRMult(1.0)) + + def backbone_forward(self, image, coord_features=None): + net_outputs = self.feature_extractor(image, coord_features) + + return {'instances': net_outputs[0], 'instances_aux': net_outputs[1]} diff --git a/isegm/model/is_model.py b/isegm/model/is_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f6555401b15a0f72c252745da726beaa602e6231 --- /dev/null +++ b/isegm/model/is_model.py @@ -0,0 +1,141 @@ +import torch +import torch.nn as nn +import numpy as np + +from isegm.model.ops import DistMaps, ScaleLayer, BatchImageNormalize +from isegm.model.modifiers import LRMult + + +class ISModel(nn.Module): + def __init__(self, use_rgb_conv=True, with_aux_output=False, + norm_radius=260, use_disks=False, cpu_dist_maps=False, + clicks_groups=None, with_prev_mask=False, use_leaky_relu=False, + binary_prev_mask=False, conv_extend=False, norm_layer=nn.BatchNorm2d, + norm_mean_std=([.485, .456, .406], [.229, .224, .225])): + super().__init__() + self.with_aux_output = with_aux_output + self.clicks_groups = clicks_groups + self.with_prev_mask = with_prev_mask + self.binary_prev_mask = binary_prev_mask + self.normalization = BatchImageNormalize(norm_mean_std[0], norm_mean_std[1]) + + self.coord_feature_ch = 2 + if clicks_groups is not None: + self.coord_feature_ch *= len(clicks_groups) + + if self.with_prev_mask: + self.coord_feature_ch += 1 + + if use_rgb_conv: + rgb_conv_layers = [ + nn.Conv2d(in_channels=3 + self.coord_feature_ch, out_channels=6 + self.coord_feature_ch, kernel_size=1), + norm_layer(6 + self.coord_feature_ch), + nn.LeakyReLU(negative_slope=0.2) if use_leaky_relu else nn.ReLU(inplace=True), + nn.Conv2d(in_channels=6 + self.coord_feature_ch, out_channels=3, kernel_size=1) + ] + self.rgb_conv = nn.Sequential(*rgb_conv_layers) + elif conv_extend: + self.rgb_conv = None + self.maps_transform = nn.Conv2d(in_channels=self.coord_feature_ch, out_channels=64, + kernel_size=3, stride=2, padding=1) + self.maps_transform.apply(LRMult(0.1)) + else: + self.rgb_conv = None + mt_layers = [ + nn.Conv2d(in_channels=self.coord_feature_ch, out_channels=16, kernel_size=1), + nn.LeakyReLU(negative_slope=0.2) if use_leaky_relu else nn.ReLU(inplace=True), + 
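+ # The 3x3/stride-2 conv below brings the encoded click maps to the 64-channel,
+ # stride-2 resolution of the backbone stem, where they are summed with the early
+ # image features; ScaleLayer keeps their initial contribution small (0.05).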
nn.Conv2d(in_channels=16, out_channels=64, kernel_size=3, stride=2, padding=1), + ScaleLayer(init_value=0.05, lr_mult=1) + ] + self.maps_transform = nn.Sequential(*mt_layers) + + if self.clicks_groups is not None: + self.dist_maps = nn.ModuleList() + for click_radius in self.clicks_groups: + self.dist_maps.append(DistMaps(norm_radius=click_radius, spatial_scale=1.0, + cpu_mode=cpu_dist_maps, use_disks=use_disks)) + else: + self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0, + cpu_mode=cpu_dist_maps, use_disks=use_disks) + + def forward(self, image, points): + image, prev_mask = self.prepare_input(image) + coord_features = self.get_coord_features(image, prev_mask, points) + + if self.rgb_conv is not None: + x = self.rgb_conv(torch.cat((image, coord_features), dim=1)) + outputs = self.backbone_forward(x) + else: + coord_features = self.maps_transform(coord_features) + outputs = self.backbone_forward(image, coord_features) + + outputs['instances'] = nn.functional.interpolate(outputs['instances'], size=image.size()[2:], + mode='bilinear', align_corners=True) + if self.with_aux_output: + outputs['instances_aux'] = nn.functional.interpolate(outputs['instances_aux'], size=image.size()[2:], + mode='bilinear', align_corners=True) + + return outputs + + def prepare_input(self, image): + prev_mask = None + if self.with_prev_mask: + prev_mask = image[:, 3:, :, :] + image = image[:, :3, :, :] + if self.binary_prev_mask: + prev_mask = (prev_mask > 0.5).float() + + image = self.normalization(image) + return image, prev_mask + + def backbone_forward(self, image, coord_features=None): + raise NotImplementedError + + def get_coord_features(self, image, prev_mask, points): + if self.clicks_groups is not None: + points_groups = split_points_by_order(points, groups=(2,) + (1, ) * (len(self.clicks_groups) - 2) + (-1,)) + coord_features = [dist_map(image, pg) for dist_map, pg in zip(self.dist_maps, points_groups)] + coord_features = torch.cat(coord_features, dim=1) + else: + coord_features = self.dist_maps(image, points) + + if prev_mask is not None: + coord_features = torch.cat((prev_mask, coord_features), dim=1) + + return coord_features + + +def split_points_by_order(tpoints: torch.Tensor, groups): + points = tpoints.cpu().numpy() + num_groups = len(groups) + bs = points.shape[0] + num_points = points.shape[1] // 2 + + groups = [x if x > 0 else num_points for x in groups] + group_points = [np.full((bs, 2 * x, 3), -1, dtype=np.float32) + for x in groups] + + last_point_indx_group = np.zeros((bs, num_groups, 2), dtype=np.int) + for group_indx, group_size in enumerate(groups): + last_point_indx_group[:, group_indx, 1] = group_size + + for bindx in range(bs): + for pindx in range(2 * num_points): + point = points[bindx, pindx, :] + group_id = int(point[2]) + if group_id < 0: + continue + + is_negative = int(pindx >= num_points) + if group_id >= num_groups or (group_id == 0 and is_negative): # disable negative first click + group_id = num_groups - 1 + + new_point_indx = last_point_indx_group[bindx, group_id, is_negative] + last_point_indx_group[bindx, group_id, is_negative] += 1 + + group_points[group_id][bindx, new_point_indx, :] = point + + group_points = [torch.tensor(x, dtype=tpoints.dtype, device=tpoints.device) + for x in group_points] + + return group_points diff --git a/isegm/model/losses.py b/isegm/model/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..b90f18f31e7718cf6c79a267be0ccb0d99797325 --- /dev/null +++ b/isegm/model/losses.py @@ -0,0 +1,161 @@ 
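+ # Training losses for the interactive segmentation models defined above.
+ # Minimal usage sketch (illustrative only -- the training wiring and the
+ # hyper-parameters shown are assumptions, not part of this diff):
+ #     criterion = NormalizedFocalLossSigmoid(alpha=0.5, gamma=2)
+ #     loss = criterion(logits, gt_mask).mean()  # logits: (B,1,H,W); gt_mask values in {0, 1, ignore_label}
+ # Predictions are treated as raw logits unless from_sigmoid=True; pixels whose
+ # label equals ignore_label (-1 by default) are excluded from both the loss and
+ # its normalization term.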
+import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from isegm.utils import misc + + +class NormalizedFocalLossSigmoid(nn.Module): + def __init__(self, axis=-1, alpha=0.25, gamma=2, max_mult=-1, eps=1e-12, + from_sigmoid=False, detach_delimeter=True, + batch_axis=0, weight=None, size_average=True, + ignore_label=-1): + super(NormalizedFocalLossSigmoid, self).__init__() + self._axis = axis + self._alpha = alpha + self._gamma = gamma + self._ignore_label = ignore_label + self._weight = weight if weight is not None else 1.0 + self._batch_axis = batch_axis + + self._from_logits = from_sigmoid + self._eps = eps + self._size_average = size_average + self._detach_delimeter = detach_delimeter + self._max_mult = max_mult + self._k_sum = 0 + self._m_max = 0 + + def forward(self, pred, label): + one_hot = label > 0.5 + sample_weight = label != self._ignore_label + + if not self._from_logits: + pred = torch.sigmoid(pred) + + alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight) + pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred)) + + beta = (1 - pt) ** self._gamma + + sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True) + beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True) + mult = sw_sum / (beta_sum + self._eps) + if self._detach_delimeter: + mult = mult.detach() + beta = beta * mult + if self._max_mult > 0: + beta = torch.clamp_max(beta, self._max_mult) + + with torch.no_grad(): + ignore_area = torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim()))).cpu().numpy() + sample_mult = torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy() + if np.any(ignore_area == 0): + self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean() + + beta_pmax, _ = torch.flatten(beta, start_dim=1).max(dim=1) + beta_pmax = beta_pmax.mean().item() + self._m_max = 0.8 * self._m_max + 0.2 * beta_pmax + + loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device))) + loss = self._weight * (loss * sample_weight) + + if self._size_average: + bsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis)) + loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (bsum + self._eps) + else: + loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) + + return loss + + def log_states(self, sw, name, global_step): + sw.add_scalar(tag=name + '_k', value=self._k_sum, global_step=global_step) + sw.add_scalar(tag=name + '_m', value=self._m_max, global_step=global_step) + + +class FocalLoss(nn.Module): + def __init__(self, axis=-1, alpha=0.25, gamma=2, + from_logits=False, batch_axis=0, + weight=None, num_class=None, + eps=1e-9, size_average=True, scale=1.0, + ignore_label=-1): + super(FocalLoss, self).__init__() + self._axis = axis + self._alpha = alpha + self._gamma = gamma + self._ignore_label = ignore_label + self._weight = weight if weight is not None else 1.0 + self._batch_axis = batch_axis + + self._scale = scale + self._num_class = num_class + self._from_logits = from_logits + self._eps = eps + self._size_average = size_average + + def forward(self, pred, label, sample_weight=None): + one_hot = label > 0.5 + sample_weight = label != self._ignore_label + + if not self._from_logits: + pred = torch.sigmoid(pred) + + alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight) + pt = 
torch.where(sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred)) + + beta = (1 - pt) ** self._gamma + + loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device))) + loss = self._weight * (loss * sample_weight) + + if self._size_average: + tsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis)) + loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (tsum + self._eps) + else: + loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) + + return self._scale * loss + + +class SoftIoU(nn.Module): + def __init__(self, from_sigmoid=False, ignore_label=-1): + super().__init__() + self._from_sigmoid = from_sigmoid + self._ignore_label = ignore_label + + def forward(self, pred, label): + label = label.view(pred.size()) + sample_weight = label != self._ignore_label + + if not self._from_sigmoid: + pred = torch.sigmoid(pred) + + loss = 1.0 - torch.sum(pred * label * sample_weight, dim=(1, 2, 3)) \ + / (torch.sum(torch.max(pred, label) * sample_weight, dim=(1, 2, 3)) + 1e-8) + + return loss + + +class SigmoidBinaryCrossEntropyLoss(nn.Module): + def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, ignore_label=-1): + super(SigmoidBinaryCrossEntropyLoss, self).__init__() + self._from_sigmoid = from_sigmoid + self._ignore_label = ignore_label + self._weight = weight if weight is not None else 1.0 + self._batch_axis = batch_axis + + def forward(self, pred, label): + label = label.view(pred.size()) + sample_weight = label != self._ignore_label + label = torch.where(sample_weight, label, torch.zeros_like(label)) + + if not self._from_sigmoid: + loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred)) + else: + eps = 1e-12 + loss = -(torch.log(pred + eps) * label + + torch.log(1. - pred + eps) * (1. 
- label)) + + loss = self._weight * (loss * sample_weight) + return torch.mean(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) diff --git a/isegm/model/metrics.py b/isegm/model/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..a572dcd97ed2dac222fa51a33657aa5b403dbb2a --- /dev/null +++ b/isegm/model/metrics.py @@ -0,0 +1,101 @@ +import torch +import numpy as np + +from isegm.utils import misc + + +class TrainMetric(object): + def __init__(self, pred_outputs, gt_outputs): + self.pred_outputs = pred_outputs + self.gt_outputs = gt_outputs + + def update(self, *args, **kwargs): + raise NotImplementedError + + def get_epoch_value(self): + raise NotImplementedError + + def reset_epoch_stats(self): + raise NotImplementedError + + def log_states(self, sw, tag_prefix, global_step): + pass + + @property + def name(self): + return type(self).__name__ + + +class AdaptiveIoU(TrainMetric): + def __init__(self, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9, + ignore_label=-1, from_logits=True, + pred_output='instances', gt_output='instances'): + super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,)) + self._ignore_label = ignore_label + self._from_logits = from_logits + self._iou_thresh = init_thresh + self._thresh_step = thresh_step + self._thresh_beta = thresh_beta + self._iou_beta = iou_beta + self._ema_iou = 0.0 + self._epoch_iou_sum = 0.0 + self._epoch_batch_count = 0 + + def update(self, pred, gt): + gt_mask = gt > 0.5 + if self._from_logits: + pred = torch.sigmoid(pred) + + gt_mask_area = torch.sum(gt_mask, dim=(1, 2)).detach().cpu().numpy() + if np.all(gt_mask_area == 0): + return + + ignore_mask = gt == self._ignore_label + max_iou = _compute_iou(pred > self._iou_thresh, gt_mask, ignore_mask).mean() + best_thresh = self._iou_thresh + for t in [best_thresh - self._thresh_step, best_thresh + self._thresh_step]: + temp_iou = _compute_iou(pred > t, gt_mask, ignore_mask).mean() + if temp_iou > max_iou: + max_iou = temp_iou + best_thresh = t + + self._iou_thresh = self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh + self._ema_iou = self._iou_beta * self._ema_iou + (1 - self._iou_beta) * max_iou + self._epoch_iou_sum += max_iou + self._epoch_batch_count += 1 + + def get_epoch_value(self): + if self._epoch_batch_count > 0: + return self._epoch_iou_sum / self._epoch_batch_count + else: + return 0.0 + + def reset_epoch_stats(self): + self._epoch_iou_sum = 0.0 + self._epoch_batch_count = 0 + + def log_states(self, sw, tag_prefix, global_step): + sw.add_scalar(tag=tag_prefix + '_ema_iou', value=self._ema_iou, global_step=global_step) + sw.add_scalar(tag=tag_prefix + '_iou_thresh', value=self._iou_thresh, global_step=global_step) + + @property + def iou_thresh(self): + return self._iou_thresh + + +def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False): + if ignore_mask is not None: + pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask) + + reduction_dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0) + union = torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims).detach().cpu().numpy() + intersection = torch.mean((pred_mask & gt_mask).float(), dim=reduction_dims).detach().cpu().numpy() + nonzero = union > 0 + + iou = intersection[nonzero] / union[nonzero] + if not keep_ignore: + return iou + else: + result = np.full_like(intersection, -1) + result[nonzero] = iou + return result diff --git a/isegm/model/modeling/basic_blocks.py 
b/isegm/model/modeling/basic_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..13753e85353ed9250aa3888ab2e715350b1b2c50 --- /dev/null +++ b/isegm/model/modeling/basic_blocks.py @@ -0,0 +1,71 @@ +import torch.nn as nn + +from isegm.model import ops + + +class ConvHead(nn.Module): + def __init__(self, out_channels, in_channels=32, num_layers=1, + kernel_size=3, padding=1, + norm_layer=nn.BatchNorm2d): + super(ConvHead, self).__init__() + convhead = [] + + for i in range(num_layers): + convhead.extend([ + nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding), + nn.ReLU(), + norm_layer(in_channels) if norm_layer is not None else nn.Identity() + ]) + convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0)) + + self.convhead = nn.Sequential(*convhead) + + def forward(self, *inputs): + return self.convhead(inputs[0]) + + +class SepConvHead(nn.Module): + def __init__(self, num_outputs, in_channels, mid_channels, num_layers=1, + kernel_size=3, padding=1, dropout_ratio=0.0, dropout_indx=0, + norm_layer=nn.BatchNorm2d): + super(SepConvHead, self).__init__() + + sepconvhead = [] + + for i in range(num_layers): + sepconvhead.append( + SeparableConv2d(in_channels=in_channels if i == 0 else mid_channels, + out_channels=mid_channels, + dw_kernel=kernel_size, dw_padding=padding, + norm_layer=norm_layer, activation='relu') + ) + if dropout_ratio > 0 and dropout_indx == i: + sepconvhead.append(nn.Dropout(dropout_ratio)) + + sepconvhead.append( + nn.Conv2d(in_channels=mid_channels, out_channels=num_outputs, kernel_size=1, padding=0) + ) + + self.layers = nn.Sequential(*sepconvhead) + + def forward(self, *inputs): + x = inputs[0] + + return self.layers(x) + + +class SeparableConv2d(nn.Module): + def __init__(self, in_channels, out_channels, dw_kernel, dw_padding, dw_stride=1, + activation=None, use_bias=False, norm_layer=None): + super(SeparableConv2d, self).__init__() + _activation = ops.select_activation_function(activation) + self.body = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=dw_kernel, stride=dw_stride, + padding=dw_padding, bias=use_bias, groups=in_channels), + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias), + norm_layer(out_channels) if norm_layer is not None else nn.Identity(), + _activation() + ) + + def forward(self, x): + return self.body(x) diff --git a/isegm/model/modeling/deeplab_v3.py b/isegm/model/modeling/deeplab_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..8219a4ef18048a0fc79fdf3e5b603af7eac03892 --- /dev/null +++ b/isegm/model/modeling/deeplab_v3.py @@ -0,0 +1,176 @@ +from contextlib import ExitStack + +import torch +from torch import nn +import torch.nn.functional as F + +from .basic_blocks import SeparableConv2d +from .resnet import ResNetBackbone +from isegm.model import ops + + +class DeepLabV3Plus(nn.Module): + def __init__(self, backbone='resnet50', norm_layer=nn.BatchNorm2d, + backbone_norm_layer=None, + ch=256, + project_dropout=0.5, + inference_mode=False, + **kwargs): + super(DeepLabV3Plus, self).__init__() + if backbone_norm_layer is None: + backbone_norm_layer = norm_layer + + self.backbone_name = backbone + self.norm_layer = norm_layer + self.backbone_norm_layer = backbone_norm_layer + self.inference_mode = False + self.ch = ch + self.aspp_in_channels = 2048 + self.skip_project_in_channels = 256 # layer 1 out_channels + + self._kwargs = kwargs + if backbone == 'resnet34': + self.aspp_in_channels = 512 + self.skip_project_in_channels = 64 + + 
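+ # The backbone is built here without pretrained weights; load_pretrained_weights()
+ # below constructs a pretrained copy and transfers its state_dict into this backbone.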
self.backbone = ResNetBackbone(backbone=self.backbone_name, pretrained_base=False, + norm_layer=self.backbone_norm_layer, **kwargs) + + self.head = _DeepLabHead(in_channels=ch + 32, mid_channels=ch, out_channels=ch, + norm_layer=self.norm_layer) + self.skip_project = _SkipProject(self.skip_project_in_channels, 32, norm_layer=self.norm_layer) + self.aspp = _ASPP(in_channels=self.aspp_in_channels, + atrous_rates=[12, 24, 36], + out_channels=ch, + project_dropout=project_dropout, + norm_layer=self.norm_layer) + + if inference_mode: + self.set_prediction_mode() + + def load_pretrained_weights(self): + pretrained = ResNetBackbone(backbone=self.backbone_name, pretrained_base=True, + norm_layer=self.backbone_norm_layer, **self._kwargs) + backbone_state_dict = self.backbone.state_dict() + pretrained_state_dict = pretrained.state_dict() + + backbone_state_dict.update(pretrained_state_dict) + self.backbone.load_state_dict(backbone_state_dict) + + if self.inference_mode: + for param in self.backbone.parameters(): + param.requires_grad = False + + def set_prediction_mode(self): + self.inference_mode = True + self.eval() + + def forward(self, x, additional_features=None): + with ExitStack() as stack: + if self.inference_mode: + stack.enter_context(torch.no_grad()) + + c1, _, c3, c4 = self.backbone(x, additional_features) + c1 = self.skip_project(c1) + + x = self.aspp(c4) + x = F.interpolate(x, c1.size()[2:], mode='bilinear', align_corners=True) + x = torch.cat((x, c1), dim=1) + x = self.head(x) + + return x, + + +class _SkipProject(nn.Module): + def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d): + super(_SkipProject, self).__init__() + _activation = ops.select_activation_function("relu") + + self.skip_project = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), + norm_layer(out_channels), + _activation() + ) + + def forward(self, x): + return self.skip_project(x) + + +class _DeepLabHead(nn.Module): + def __init__(self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d): + super(_DeepLabHead, self).__init__() + + self.block = nn.Sequential( + SeparableConv2d(in_channels=in_channels, out_channels=mid_channels, dw_kernel=3, + dw_padding=1, activation='relu', norm_layer=norm_layer), + SeparableConv2d(in_channels=mid_channels, out_channels=mid_channels, dw_kernel=3, + dw_padding=1, activation='relu', norm_layer=norm_layer), + nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1) + ) + + def forward(self, x): + return self.block(x) + + +class _ASPP(nn.Module): + def __init__(self, in_channels, atrous_rates, out_channels=256, + project_dropout=0.5, norm_layer=nn.BatchNorm2d): + super(_ASPP, self).__init__() + + b0 = nn.Sequential( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False), + norm_layer(out_channels), + nn.ReLU() + ) + + rate1, rate2, rate3 = tuple(atrous_rates) + b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer) + b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer) + b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer) + b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer) + + self.concurent = nn.ModuleList([b0, b1, b2, b3, b4]) + + project = [ + nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels, + kernel_size=1, bias=False), + norm_layer(out_channels), + nn.ReLU() + ] + if project_dropout > 0: + project.append(nn.Dropout(project_dropout)) + self.project = nn.Sequential(*project) + + def forward(self, x): + 
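+ # Run the 1x1, the three atrous, and the image-pooling branches in parallel,
+ # concatenate along channels (5 * out_channels), then project back to
+ # out_channels with an optional dropout.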
x = torch.cat([block(x) for block in self.concurent], dim=1) + + return self.project(x) + + +class _AsppPooling(nn.Module): + def __init__(self, in_channels, out_channels, norm_layer): + super(_AsppPooling, self).__init__() + + self.gap = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=1, bias=False), + norm_layer(out_channels), + nn.ReLU() + ) + + def forward(self, x): + pool = self.gap(x) + return F.interpolate(pool, x.size()[2:], mode='bilinear', align_corners=True) + + +def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer): + block = nn.Sequential( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=3, padding=atrous_rate, + dilation=atrous_rate, bias=False), + norm_layer(out_channels), + nn.ReLU() + ) + + return block diff --git a/isegm/model/modeling/hrnet_ocr.py b/isegm/model/modeling/hrnet_ocr.py new file mode 100644 index 0000000000000000000000000000000000000000..d386ee0d376df2d498ef3c05f743caaf83374273 --- /dev/null +++ b/isegm/model/modeling/hrnet_ocr.py @@ -0,0 +1,416 @@ +import os +import numpy as np +import torch +import torch.nn as nn +import torch._utils +import torch.nn.functional as F +from .ocr import SpatialOCR_Module, SpatialGather_Module +from .resnetv1b import BasicBlockV1b, BottleneckV1b + +relu_inplace = True + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method,multi_scale_output=True, + norm_layer=nn.BatchNorm2d, align_corners=True): + super(HighResolutionModule, self).__init__() + self._check_branches(num_branches, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + self.norm_layer = norm_layer + self.align_corners = align_corners + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=relu_inplace) + + def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + self.norm_layer(num_channels[branch_index] * block.expansion), + ) + + layers = [] + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, + downsample=downsample, norm_layer=self.norm_layer)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], + 
norm_layer=self.norm_layer)) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(in_channels=num_inchannels[j], + out_channels=num_inchannels[i], + kernel_size=1, + bias=False), + self.norm_layer(num_inchannels[i]))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + kernel_size=3, stride=2, padding=1, bias=False), + self.norm_layer(num_outchannels_conv3x3))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + kernel_size=3, stride=2, padding=1, bias=False), + self.norm_layer(num_outchannels_conv3x3), + nn.ReLU(inplace=relu_inplace))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + elif j > i: + width_output = x[i].shape[-1] + height_output = x[i].shape[-2] + y = y + F.interpolate( + self.fuse_layers[i][j](x[j]), + size=[height_output, width_output], + mode='bilinear', align_corners=self.align_corners) + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +class HighResolutionNet(nn.Module): + def __init__(self, width, num_classes, ocr_width=256, small=False, + norm_layer=nn.BatchNorm2d, align_corners=True): + super(HighResolutionNet, self).__init__() + self.norm_layer = norm_layer + self.width = width + self.ocr_width = ocr_width + self.align_corners = align_corners + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = norm_layer(64) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn2 = norm_layer(64) + self.relu = nn.ReLU(inplace=relu_inplace) + + num_blocks = 2 if small else 4 + + stage1_num_channels = 64 + self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks) + stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels + + self.stage2_num_branches = 2 + num_channels = [width, 2 * width] + num_inchannels = [ + num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer( + [stage1_out_channel], num_inchannels) + self.stage2, pre_stage_channels = self._make_stage( + BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches, + num_blocks=2 * [num_blocks], num_channels=num_channels) + + 
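+ # Stages 3 and 4 repeat the same recipe: a transition layer appends one extra
+ # branch at half the previous resolution, then a stack of HighResolutionModules
+ # exchanges information between all branches.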
self.stage3_num_branches = 3 + num_channels = [width, 2 * width, 4 * width] + num_inchannels = [ + num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer( + pre_stage_channels, num_inchannels) + self.stage3, pre_stage_channels = self._make_stage( + BasicBlockV1b, num_inchannels=num_inchannels, + num_modules=3 if small else 4, num_branches=self.stage3_num_branches, + num_blocks=3 * [num_blocks], num_channels=num_channels) + + self.stage4_num_branches = 4 + num_channels = [width, 2 * width, 4 * width, 8 * width] + num_inchannels = [ + num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer( + pre_stage_channels, num_inchannels) + self.stage4, pre_stage_channels = self._make_stage( + BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3, + num_branches=self.stage4_num_branches, + num_blocks=4 * [num_blocks], num_channels=num_channels) + + last_inp_channels = np.int(np.sum(pre_stage_channels)) + if self.ocr_width > 0: + ocr_mid_channels = 2 * self.ocr_width + ocr_key_channels = self.ocr_width + + self.conv3x3_ocr = nn.Sequential( + nn.Conv2d(last_inp_channels, ocr_mid_channels, + kernel_size=3, stride=1, padding=1), + norm_layer(ocr_mid_channels), + nn.ReLU(inplace=relu_inplace), + ) + self.ocr_gather_head = SpatialGather_Module(num_classes) + + self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels, + key_channels=ocr_key_channels, + out_channels=ocr_mid_channels, + scale=1, + dropout=0.05, + norm_layer=norm_layer, + align_corners=align_corners) + self.cls_head = nn.Conv2d( + ocr_mid_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True) + + self.aux_head = nn.Sequential( + nn.Conv2d(last_inp_channels, last_inp_channels, + kernel_size=1, stride=1, padding=0), + norm_layer(last_inp_channels), + nn.ReLU(inplace=relu_inplace), + nn.Conv2d(last_inp_channels, num_classes, + kernel_size=1, stride=1, padding=0, bias=True) + ) + else: + self.cls_head = nn.Sequential( + nn.Conv2d(last_inp_channels, last_inp_channels, + kernel_size=3, stride=1, padding=1), + norm_layer(last_inp_channels), + nn.ReLU(inplace=relu_inplace), + nn.Conv2d(last_inp_channels, num_classes, + kernel_size=1, stride=1, padding=0, bias=True) + ) + + def _make_transition_layer( + self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], + num_channels_cur_layer[i], + kernel_size=3, + stride=1, + padding=1, + bias=False), + self.norm_layer(num_channels_cur_layer[i]), + nn.ReLU(inplace=relu_inplace))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i + 1 - num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i - num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d(inchannels, outchannels, + kernel_size=3, stride=2, padding=1, bias=False), + self.norm_layer(outchannels), + nn.ReLU(inplace=relu_inplace))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or 
inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + self.norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, + downsample=downsample, norm_layer=self.norm_layer)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes, norm_layer=self.norm_layer)) + + return nn.Sequential(*layers) + + def _make_stage(self, block, num_inchannels, + num_modules, num_branches, num_blocks, num_channels, + fuse_method='SUM', + multi_scale_output=True): + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + modules.append( + HighResolutionModule(num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output, + norm_layer=self.norm_layer, + align_corners=self.align_corners) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x, additional_features=None): + feats = self.compute_hrnet_feats(x, additional_features) + if self.ocr_width > 0: + out_aux = self.aux_head(feats) + feats = self.conv3x3_ocr(feats) + + context = self.ocr_gather_head(feats, out_aux) + feats = self.ocr_distri_head(feats, context) + out = self.cls_head(feats) + return [out, out_aux] + else: + return [self.cls_head(feats), None] + + def compute_hrnet_feats(self, x, additional_features): + x = self.compute_pre_stage_features(x, additional_features) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_num_branches): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_num_branches): + if self.transition2[i] is not None: + if i < self.stage2_num_branches: + x_list.append(self.transition2[i](y_list[i])) + else: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_num_branches): + if self.transition3[i] is not None: + if i < self.stage3_num_branches: + x_list.append(self.transition3[i](y_list[i])) + else: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + x = self.stage4(x_list) + + return self.aggregate_hrnet_features(x) + + def compute_pre_stage_features(self, x, additional_features): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + if additional_features is not None: + x = x + additional_features + x = self.conv2(x) + x = self.bn2(x) + return self.relu(x) + + def aggregate_hrnet_features(self, x): + # Upsampling + x0_h, x0_w = x[0].size(2), x[0].size(3) + x1 = F.interpolate(x[1], size=(x0_h, x0_w), + mode='bilinear', align_corners=self.align_corners) + x2 = F.interpolate(x[2], size=(x0_h, x0_w), + mode='bilinear', align_corners=self.align_corners) + x3 = F.interpolate(x[3], size=(x0_h, x0_w), + mode='bilinear', align_corners=self.align_corners) + + return torch.cat([x[0], x1, x2, x3], 1) + + def load_pretrained_weights(self, pretrained_path=''): + model_dict = self.state_dict() + + if not os.path.exists(pretrained_path): + print(f'\nFile "{pretrained_path}" does not exist.') + print('You need to specify the correct path to the pre-trained weights.\n' + 
'You can download the weights for HRNet from the repository:\n' + 'https://github.com/HRNet/HRNet-Image-Classification') + exit(1) + pretrained_dict = torch.load(pretrained_path, map_location={'cuda:0': 'cpu'}) + pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in + pretrained_dict.items()} + + pretrained_dict = {k: v for k, v in pretrained_dict.items() + if k in model_dict.keys()} + + model_dict.update(pretrained_dict) + self.load_state_dict(model_dict) diff --git a/isegm/model/modeling/ocr.py b/isegm/model/modeling/ocr.py new file mode 100644 index 0000000000000000000000000000000000000000..df3b4f67959fc6a088b93ee7a34b15c1e07402df --- /dev/null +++ b/isegm/model/modeling/ocr.py @@ -0,0 +1,141 @@ +import torch +import torch.nn as nn +import torch._utils +import torch.nn.functional as F + + +class SpatialGather_Module(nn.Module): + """ + Aggregate the context features according to the initial + predicted probability distribution. + Employ the soft-weighted method to aggregate the context. + """ + + def __init__(self, cls_num=0, scale=1): + super(SpatialGather_Module, self).__init__() + self.cls_num = cls_num + self.scale = scale + + def forward(self, feats, probs): + batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3) + probs = probs.view(batch_size, c, -1) + feats = feats.view(batch_size, feats.size(1), -1) + feats = feats.permute(0, 2, 1) # batch x hw x c + probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw + ocr_context = torch.matmul(probs, feats) \ + .permute(0, 2, 1).unsqueeze(3) # batch x k x c + return ocr_context + + +class SpatialOCR_Module(nn.Module): + """ + Implementation of the OCR module: + We aggregate the global object representation to update the representation for each pixel. 
+ """ + + def __init__(self, + in_channels, + key_channels, + out_channels, + scale=1, + dropout=0.1, + norm_layer=nn.BatchNorm2d, + align_corners=True): + super(SpatialOCR_Module, self).__init__() + self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale, + norm_layer, align_corners) + _in_channels = 2 * in_channels + + self.conv_bn_dropout = nn.Sequential( + nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False), + nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)), + nn.Dropout2d(dropout) + ) + + def forward(self, feats, proxy_feats): + context = self.object_context_block(feats, proxy_feats) + + output = self.conv_bn_dropout(torch.cat([context, feats], 1)) + + return output + + +class ObjectAttentionBlock2D(nn.Module): + ''' + The basic implementation for object context block + Input: + N X C X H X W + Parameters: + in_channels : the dimension of the input feature map + key_channels : the dimension after the key/query transform + scale : choose the scale to downsample the input feature maps (save memory cost) + bn_type : specify the bn type + Return: + N X C X H X W + ''' + + def __init__(self, + in_channels, + key_channels, + scale=1, + norm_layer=nn.BatchNorm2d, + align_corners=True): + super(ObjectAttentionBlock2D, self).__init__() + self.scale = scale + self.in_channels = in_channels + self.key_channels = key_channels + self.align_corners = align_corners + + self.pool = nn.MaxPool2d(kernel_size=(scale, scale)) + self.f_pixel = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)), + nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) + ) + self.f_object = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)), + nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) + ) + self.f_down = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) + ) + self.f_up = nn.Sequential( + nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True)) + ) + + def forward(self, x, proxy): + batch_size, h, w = x.size(0), x.size(2), x.size(3) + if self.scale > 1: + x = self.pool(x) + + query = self.f_pixel(x).view(batch_size, self.key_channels, -1) + query = query.permute(0, 2, 1) + key = self.f_object(proxy).view(batch_size, self.key_channels, -1) + value = self.f_down(proxy).view(batch_size, self.key_channels, -1) + value = value.permute(0, 2, 1) + + sim_map = torch.matmul(query, key) + sim_map = (self.key_channels ** -.5) * sim_map + sim_map = F.softmax(sim_map, dim=-1) + + # add bg context ... 
+ context = torch.matmul(sim_map, value) + context = context.permute(0, 2, 1).contiguous() + context = context.view(batch_size, self.key_channels, *x.size()[2:]) + context = self.f_up(context) + if self.scale > 1: + context = F.interpolate(input=context, size=(h, w), + mode='bilinear', align_corners=self.align_corners) + + return context diff --git a/isegm/model/modeling/resnet.py b/isegm/model/modeling/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..65fe949cef0035ba691ee319b25a0132d8ad37fe --- /dev/null +++ b/isegm/model/modeling/resnet.py @@ -0,0 +1,43 @@ +import torch +from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s + + +class ResNetBackbone(torch.nn.Module): + def __init__(self, backbone='resnet50', pretrained_base=True, dilated=True, **kwargs): + super(ResNetBackbone, self).__init__() + + if backbone == 'resnet34': + pretrained = resnet34_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs) + elif backbone == 'resnet50': + pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) + elif backbone == 'resnet101': + pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) + elif backbone == 'resnet152': + pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) + else: + raise RuntimeError(f'unknown backbone: {backbone}') + + self.conv1 = pretrained.conv1 + self.bn1 = pretrained.bn1 + self.relu = pretrained.relu + self.maxpool = pretrained.maxpool + self.layer1 = pretrained.layer1 + self.layer2 = pretrained.layer2 + self.layer3 = pretrained.layer3 + self.layer4 = pretrained.layer4 + + def forward(self, x, additional_features=None): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + if additional_features is not None: + x = x + torch.nn.functional.pad(additional_features, + [0, 0, 0, 0, 0, x.size(1) - additional_features.size(1)], + mode='constant', value=0) + x = self.maxpool(x) + c1 = self.layer1(x) + c2 = self.layer2(c1) + c3 = self.layer3(c2) + c4 = self.layer4(c3) + + return c1, c2, c3, c4 diff --git a/isegm/model/modeling/resnetv1b.py b/isegm/model/modeling/resnetv1b.py new file mode 100644 index 0000000000000000000000000000000000000000..4ad24cef5bde19f2627cfd3f755636f37cfb39ac --- /dev/null +++ b/isegm/model/modeling/resnetv1b.py @@ -0,0 +1,276 @@ +import torch +import torch.nn as nn +GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet' + + +class BasicBlockV1b(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, + previous_dilation=1, norm_layer=nn.BatchNorm2d): + super(BasicBlockV1b, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation, bias=False) + self.bn1 = norm_layer(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, + padding=previous_dilation, dilation=previous_dilation, bias=False) + self.bn2 = norm_layer(planes) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = out + residual + out = self.relu(out) + + return out + + +class BottleneckV1b(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, + previous_dilation=1, 
norm_layer=nn.BatchNorm2d): + super(BottleneckV1b, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = norm_layer(planes) + + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation, bias=False) + self.bn2 = norm_layer(planes) + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = norm_layer(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = out + residual + out = self.relu(out) + + return out + + +class ResNetV1b(nn.Module): + """ Pre-trained ResNetV1b Model, which produces the strides of 8 featuremaps at conv5. + + Parameters + ---------- + block : Block + Class for the residual block. Options are BasicBlockV1, BottleneckV1. + layers : list of int + Numbers of layers in each block + classes : int, default 1000 + Number of classification classes. + dilated : bool, default False + Applying dilation strategy to pretrained ResNet yielding a stride-8 model, + typically used in Semantic Segmentation. + norm_layer : object + Normalization layer used (default: :class:`nn.BatchNorm2d`) + deep_stem : bool, default False + Whether to replace the 7x7 conv1 with 3 3x3 convolution layers. + avg_down : bool, default False + Whether to use average pooling for projection skip connection between stages/downsample. + final_drop : float, default 0.0 + Dropout ratio before the final classification layer. + + Reference: + - He, Kaiming, et al. "Deep residual learning for image recognition." + Proceedings of the IEEE conference on computer vision and pattern recognition. 2016. + + - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions." 
+ """ + def __init__(self, block, layers, classes=1000, dilated=True, deep_stem=False, stem_width=32, + avg_down=False, final_drop=0.0, norm_layer=nn.BatchNorm2d): + self.inplanes = stem_width*2 if deep_stem else 64 + super(ResNetV1b, self).__init__() + if not deep_stem: + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + else: + self.conv1 = nn.Sequential( + nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False), + norm_layer(stem_width), + nn.ReLU(True), + nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False), + norm_layer(stem_width), + nn.ReLU(True), + nn.Conv2d(stem_width, 2*stem_width, kernel_size=3, stride=1, padding=1, bias=False) + ) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(True) + self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down, + norm_layer=norm_layer) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down, + norm_layer=norm_layer) + if dilated: + self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2, + avg_down=avg_down, norm_layer=norm_layer) + self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, + avg_down=avg_down, norm_layer=norm_layer) + else: + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + avg_down=avg_down, norm_layer=norm_layer) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + avg_down=avg_down, norm_layer=norm_layer) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.drop = None + if final_drop > 0.0: + self.drop = nn.Dropout(final_drop) + self.fc = nn.Linear(512 * block.expansion, classes) + + def _make_layer(self, block, planes, blocks, stride=1, dilation=1, + avg_down=False, norm_layer=nn.BatchNorm2d): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = [] + if avg_down: + if dilation == 1: + downsample.append( + nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False) + ) + else: + downsample.append( + nn.AvgPool2d(kernel_size=1, stride=1, ceil_mode=True, count_include_pad=False) + ) + downsample.extend([ + nn.Conv2d(self.inplanes, out_channels=planes * block.expansion, + kernel_size=1, stride=1, bias=False), + norm_layer(planes * block.expansion) + ]) + downsample = nn.Sequential(*downsample) + else: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, out_channels=planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + norm_layer(planes * block.expansion) + ) + + layers = [] + if dilation in (1, 2): + layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample, + previous_dilation=dilation, norm_layer=norm_layer)) + elif dilation == 4: + layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample, + previous_dilation=dilation, norm_layer=norm_layer)) + else: + raise RuntimeError("=> unknown dilation size: {}".format(dilation)) + + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, dilation=dilation, + previous_dilation=dilation, norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + if self.drop is not None: + x = self.drop(x) + 
x = self.fc(x) + + return x + + +def _safe_state_dict_filtering(orig_dict, model_dict_keys): + filtered_orig_dict = {} + for k, v in orig_dict.items(): + if k in model_dict_keys: + filtered_orig_dict[k] = v + else: + print(f"[ERROR] Failed to load <{k}> in backbone") + return filtered_orig_dict + + +def resnet34_v1b(pretrained=False, **kwargs): + model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs) + if pretrained: + model_dict = model.state_dict() + filtered_orig_dict = _safe_state_dict_filtering( + torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet34_v1b', pretrained=True).state_dict(), + model_dict.keys() + ) + model_dict.update(filtered_orig_dict) + model.load_state_dict(model_dict) + return model + + +def resnet50_v1s(pretrained=False, **kwargs): + model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs) + if pretrained: + model_dict = model.state_dict() + filtered_orig_dict = _safe_state_dict_filtering( + torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet50_v1s', pretrained=True).state_dict(), + model_dict.keys() + ) + model_dict.update(filtered_orig_dict) + model.load_state_dict(model_dict) + return model + + +def resnet101_v1s(pretrained=False, **kwargs): + model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs) + if pretrained: + model_dict = model.state_dict() + filtered_orig_dict = _safe_state_dict_filtering( + torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet101_v1s', pretrained=True).state_dict(), + model_dict.keys() + ) + model_dict.update(filtered_orig_dict) + model.load_state_dict(model_dict) + return model + + +def resnet152_v1s(pretrained=False, **kwargs): + model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs) + if pretrained: + model_dict = model.state_dict() + filtered_orig_dict = _safe_state_dict_filtering( + torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet152_v1s', pretrained=True).state_dict(), + model_dict.keys() + ) + model_dict.update(filtered_orig_dict) + model.load_state_dict(model_dict) + return model diff --git a/isegm/model/modifiers.py b/isegm/model/modifiers.py new file mode 100644 index 0000000000000000000000000000000000000000..046221838069e90ae201b9169db159cc69c13244 --- /dev/null +++ b/isegm/model/modifiers.py @@ -0,0 +1,11 @@ + + +class LRMult(object): + def __init__(self, lr_mult=1.): + self.lr_mult = lr_mult + + def __call__(self, m): + if getattr(m, 'weight', None) is not None: + m.weight.lr_mult = self.lr_mult + if getattr(m, 'bias', None) is not None: + m.bias.lr_mult = self.lr_mult diff --git a/isegm/model/ops.py b/isegm/model/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9be9c73cbef7b83645af93e1fa7338fa6513a92b --- /dev/null +++ b/isegm/model/ops.py @@ -0,0 +1,116 @@ +import torch +from torch import nn as nn +import numpy as np +import isegm.model.initializer as initializer + + +def select_activation_function(activation): + if isinstance(activation, str): + if activation.lower() == 'relu': + return nn.ReLU + elif activation.lower() == 'softplus': + return nn.Softplus + else: + raise ValueError(f"Unknown activation type {activation}") + elif isinstance(activation, nn.Module): + return activation + else: + raise ValueError(f"Unknown activation type {activation}") + + +class BilinearConvTranspose2d(nn.ConvTranspose2d): + def __init__(self, in_channels, out_channels, scale, groups=1): + kernel_size = 2 * scale - scale % 2 + self.scale = scale + + super().__init__( + in_channels, out_channels, + 
kernel_size=kernel_size, + stride=scale, + padding=1, + groups=groups, + bias=False) + + self.apply(initializer.Bilinear(scale=scale, in_channels=in_channels, groups=groups)) + + +class DistMaps(nn.Module): + def __init__(self, norm_radius, spatial_scale=1.0, cpu_mode=False, use_disks=False): + super(DistMaps, self).__init__() + self.spatial_scale = spatial_scale + self.norm_radius = norm_radius + self.cpu_mode = cpu_mode + self.use_disks = use_disks + if self.cpu_mode: + from isegm.utils.cython import get_dist_maps + self._get_dist_maps = get_dist_maps + + def get_coord_features(self, points, batchsize, rows, cols): + if self.cpu_mode: + coords = [] + for i in range(batchsize): + norm_delimeter = 1.0 if self.use_disks else self.spatial_scale * self.norm_radius + coords.append(self._get_dist_maps(points[i].cpu().float().numpy(), rows, cols, + norm_delimeter)) + coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float() + else: + num_points = points.shape[1] // 2 + points = points.view(-1, points.size(2)) + points, points_order = torch.split(points, [2, 1], dim=1) + + invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0 + row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device=points.device) + col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device=points.device) + + coord_rows, coord_cols = torch.meshgrid(row_array, col_array) + coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1) + + add_xy = (points * self.spatial_scale).view(points.size(0), points.size(1), 1, 1) + coords.add_(-add_xy) + if not self.use_disks: + coords.div_(self.norm_radius * self.spatial_scale) + coords.mul_(coords) + + coords[:, 0] += coords[:, 1] + coords = coords[:, :1] + + coords[invalid_points, :, :, :] = 1e6 + + coords = coords.view(-1, num_points, 1, rows, cols) + coords = coords.min(dim=1)[0] # -> (bs * num_masks * 2) x 1 x h x w + coords = coords.view(-1, 2, rows, cols) + + if self.use_disks: + coords = (coords <= (self.norm_radius * self.spatial_scale) ** 2).float() + else: + coords.sqrt_().mul_(2).tanh_() + + return coords + + def forward(self, x, coords): + return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3]) + + +class ScaleLayer(nn.Module): + def __init__(self, init_value=1.0, lr_mult=1): + super().__init__() + self.lr_mult = lr_mult + self.scale = nn.Parameter( + torch.full((1,), init_value / lr_mult, dtype=torch.float32) + ) + + def forward(self, x): + scale = torch.abs(self.scale * self.lr_mult) + return x * scale + + +class BatchImageNormalize: + def __init__(self, mean, std, dtype=torch.float): + self.mean = torch.as_tensor(mean, dtype=dtype)[None, :, None, None] + self.std = torch.as_tensor(std, dtype=dtype)[None, :, None, None] + + def __call__(self, tensor): + tensor = tensor.clone() + + tensor.sub_(self.mean.to(tensor.device)).div_(self.std.to(tensor.device)) + return tensor diff --git a/isegm/utils/cython/__init__.py b/isegm/utils/cython/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb66bdbba883b9477bbc1a52d8355131d32a04cb --- /dev/null +++ b/isegm/utils/cython/__init__.py @@ -0,0 +1,2 @@ +# noinspection PyUnresolvedReferences +from .dist_maps import get_dist_maps \ No newline at end of file diff --git a/isegm/utils/cython/_get_dist_maps.pyx b/isegm/utils/cython/_get_dist_maps.pyx new file mode 100644 index 0000000000000000000000000000000000000000..779a7f02ad7c2ba25e68302c6fc6683cd4ab54f7 --- /dev/null +++ 
b/isegm/utils/cython/_get_dist_maps.pyx @@ -0,0 +1,63 @@ +import numpy as np +cimport cython +cimport numpy as np +from libc.stdlib cimport malloc, free + +ctypedef struct qnode: + int row + int col + int layer + int orig_row + int orig_col + +@cython.infer_types(True) +@cython.boundscheck(False) +@cython.wraparound(False) +@cython.nonecheck(False) +def get_dist_maps(np.ndarray[np.float32_t, ndim=2, mode="c"] points, + int height, int width, float norm_delimeter): + cdef np.ndarray[np.float32_t, ndim=3, mode="c"] dist_maps = \ + np.full((2, height, width), 1e6, dtype=np.float32, order="C") + + cdef int *dxy = [-1, 0, 0, -1, 0, 1, 1, 0] + cdef int i, j, x, y, dx, dy + cdef qnode v + cdef qnode *q = malloc((4 * height * width + 1) * sizeof(qnode)) + cdef int qhead = 0, qtail = -1 + cdef float ndist + + for i in range(points.shape[0]): + x, y = round(points[i, 0]), round(points[i, 1]) + if x >= 0: + qtail += 1 + q[qtail].row = x + q[qtail].col = y + q[qtail].orig_row = x + q[qtail].orig_col = y + if i >= points.shape[0] / 2: + q[qtail].layer = 1 + else: + q[qtail].layer = 0 + dist_maps[q[qtail].layer, x, y] = 0 + + while qtail - qhead + 1 > 0: + v = q[qhead] + qhead += 1 + + for k in range(4): + x = v.row + dxy[2 * k] + y = v.col + dxy[2 * k + 1] + + ndist = ((x - v.orig_row)/norm_delimeter) ** 2 + ((y - v.orig_col)/norm_delimeter) ** 2 + if (x >= 0 and y >= 0 and x < height and y < width and + dist_maps[v.layer, x, y] > ndist): + qtail += 1 + q[qtail].orig_col = v.orig_col + q[qtail].orig_row = v.orig_row + q[qtail].layer = v.layer + q[qtail].row = x + q[qtail].col = y + dist_maps[v.layer, x, y] = ndist + + free(q) + return dist_maps diff --git a/isegm/utils/cython/_get_dist_maps.pyxbld b/isegm/utils/cython/_get_dist_maps.pyxbld new file mode 100644 index 0000000000000000000000000000000000000000..bd4451729201b5ebc6bbbd8f392389ab6b530636 --- /dev/null +++ b/isegm/utils/cython/_get_dist_maps.pyxbld @@ -0,0 +1,7 @@ +import numpy + +def make_ext(modname, pyxfilename): + from distutils.extension import Extension + return Extension(modname, [pyxfilename], + include_dirs=[numpy.get_include()], + extra_compile_args=['-O3'], language='c++') diff --git a/isegm/utils/cython/dist_maps.py b/isegm/utils/cython/dist_maps.py new file mode 100644 index 0000000000000000000000000000000000000000..8ffa1e3f25231cd7c48b66ef8ef5167235c3ea4e --- /dev/null +++ b/isegm/utils/cython/dist_maps.py @@ -0,0 +1,3 @@ +import pyximport; pyximport.install(pyximport=True, language_level=3) +# noinspection PyUnresolvedReferences +from ._get_dist_maps import get_dist_maps \ No newline at end of file diff --git a/isegm/utils/distributed.py b/isegm/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..a1e48f50500ee7440be035b17107573e86bb5d24 --- /dev/null +++ b/isegm/utils/distributed.py @@ -0,0 +1,67 @@ +import torch +from torch import distributed as dist +from torch.utils import data + + +def get_rank(): + if not dist.is_available() or not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def synchronize(): + if not dist.is_available() or not dist.is_initialized() or dist.get_world_size() == 1: + return + dist.barrier() + + +def get_world_size(): + if not dist.is_available() or not dist.is_initialized(): + return 1 + + return dist.get_world_size() + + +def reduce_loss_dict(loss_dict): + world_size = get_world_size() + + if world_size < 2: + return loss_dict + + with torch.no_grad(): + keys = [] + losses = [] + + for k in loss_dict.keys(): + keys.append(k) + 
losses.append(loss_dict[k]) + + losses = torch.stack(losses, 0) + dist.reduce(losses, dst=0) + + if dist.get_rank() == 0: + losses /= world_size + + reduced_losses = {k: v for k, v in zip(keys, losses)} + + return reduced_losses + + +def get_sampler(dataset, shuffle, distributed): + if distributed: + return data.distributed.DistributedSampler(dataset, shuffle=shuffle) + + if shuffle: + return data.RandomSampler(dataset) + else: + return data.SequentialSampler(dataset) + + +def get_dp_wrapper(distributed): + class DPWrapper(torch.nn.parallel.DistributedDataParallel if distributed else torch.nn.DataParallel): + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.module, name) + return DPWrapper diff --git a/isegm/utils/exp.py b/isegm/utils/exp.py new file mode 100644 index 0000000000000000000000000000000000000000..1ff63ccb3524f06d76475f6b7a77058431b2fe14 --- /dev/null +++ b/isegm/utils/exp.py @@ -0,0 +1,187 @@ +import os +import sys +import shutil +import pprint +from pathlib import Path +from datetime import datetime + +import yaml +import torch +from easydict import EasyDict as edict + +from .log import logger, add_logging +from .distributed import synchronize, get_world_size + + +def init_experiment(args, model_name): + model_path = Path(args.model_path) + ftree = get_model_family_tree(model_path, model_name=model_name) + + if ftree is None: + print('Models can only be located in the "models" directory in the root of the repository') + sys.exit(1) + + cfg = load_config(model_path) + update_config(cfg, args) + + cfg.distributed = args.distributed + cfg.local_rank = args.local_rank + if cfg.distributed: + torch.distributed.init_process_group(backend='nccl', init_method='env://') + if args.workers > 0: + torch.multiprocessing.set_start_method('forkserver', force=True) + + experiments_path = Path(cfg.EXPS_PATH) + exp_parent_path = experiments_path / '/'.join(ftree) + exp_parent_path.mkdir(parents=True, exist_ok=True) + + if cfg.resume_exp: + exp_path = find_resume_exp(exp_parent_path, cfg.resume_exp) + else: + last_exp_indx = find_last_exp_indx(exp_parent_path) + exp_name = f'{last_exp_indx:03d}' + if cfg.exp_name: + exp_name += '_' + cfg.exp_name + exp_path = exp_parent_path / exp_name + synchronize() + if cfg.local_rank == 0: + exp_path.mkdir(parents=True) + + cfg.EXP_PATH = exp_path + cfg.CHECKPOINTS_PATH = exp_path / 'checkpoints' + cfg.VIS_PATH = exp_path / 'vis' + cfg.LOGS_PATH = exp_path / 'logs' + + if cfg.local_rank == 0: + cfg.LOGS_PATH.mkdir(exist_ok=True) + cfg.CHECKPOINTS_PATH.mkdir(exist_ok=True) + cfg.VIS_PATH.mkdir(exist_ok=True) + + dst_script_path = exp_path / (model_path.stem + datetime.strftime(datetime.today(), '_%Y-%m-%d-%H-%M-%S.py')) + if args.temp_model_path: + shutil.copy(args.temp_model_path, dst_script_path) + os.remove(args.temp_model_path) + else: + shutil.copy(model_path, dst_script_path) + + synchronize() + + if cfg.gpus != '': + gpu_ids = [int(id) for id in cfg.gpus.split(',')] + else: + gpu_ids = list(range(max(cfg.ngpus, get_world_size()))) + cfg.gpus = ','.join([str(id) for id in gpu_ids]) + + cfg.gpu_ids = gpu_ids + cfg.ngpus = len(gpu_ids) + cfg.multi_gpu = cfg.ngpus > 1 + + if cfg.distributed: + cfg.device = torch.device('cuda') + cfg.gpu_ids = [cfg.gpu_ids[cfg.local_rank]] + torch.cuda.set_device(cfg.gpu_ids[0]) + else: + if cfg.multi_gpu: + os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpus + ngpus = torch.cuda.device_count() + assert ngpus == cfg.ngpus + cfg.device = 
torch.device(f'cuda:{cfg.gpu_ids[0]}') + + if cfg.local_rank == 0: + add_logging(cfg.LOGS_PATH, prefix='train_') + logger.info(f'Number of GPUs: {cfg.ngpus}') + if cfg.distributed: + logger.info(f'Multi-Process Multi-GPU Distributed Training') + + logger.info('Run experiment with config:') + logger.info(pprint.pformat(cfg, indent=4)) + + return cfg + + +def get_model_family_tree(model_path, terminate_name='models', model_name=None): + if model_name is None: + model_name = model_path.stem + family_tree = [model_name] + for x in model_path.parents: + if x.stem == terminate_name: + break + family_tree.append(x.stem) + else: + return None + + return family_tree[::-1] + + +def find_last_exp_indx(exp_parent_path): + indx = 0 + for x in exp_parent_path.iterdir(): + if not x.is_dir(): + continue + + exp_name = x.stem + if exp_name[:3].isnumeric(): + indx = max(indx, int(exp_name[:3]) + 1) + + return indx + + +def find_resume_exp(exp_parent_path, exp_pattern): + candidates = sorted(exp_parent_path.glob(f'{exp_pattern}*')) + if len(candidates) == 0: + print(f'No experiments could be found that satisfies the pattern = "*{exp_pattern}"') + sys.exit(1) + elif len(candidates) > 1: + print('More than one experiment found:') + for x in candidates: + print(x) + sys.exit(1) + else: + exp_path = candidates[0] + print(f'Continue with experiment "{exp_path}"') + + return exp_path + + +def update_config(cfg, args): + for param_name, value in vars(args).items(): + if param_name.lower() in cfg or param_name.upper() in cfg: + continue + cfg[param_name] = value + + +def load_config(model_path): + model_name = model_path.stem + config_path = model_path.parent / (model_name + '.yml') + + if config_path.exists(): + cfg = load_config_file(config_path) + else: + cfg = dict() + + cwd = Path.cwd() + config_parent = config_path.parent.absolute() + while len(config_parent.parents) > 0: + config_path = config_parent / 'config.yml' + + if config_path.exists(): + local_config = load_config_file(config_path, model_name=model_name) + cfg.update({k: v for k, v in local_config.items() if k not in cfg}) + + if config_parent.absolute() == cwd: + break + config_parent = config_parent.parent + + return edict(cfg) + + +def load_config_file(config_path, model_name=None, return_edict=False): + with open(config_path, 'r') as f: + cfg = yaml.safe_load(f) + + if 'SUBCONFIGS' in cfg: + if model_name is not None and model_name in cfg['SUBCONFIGS']: + cfg.update(cfg['SUBCONFIGS'][model_name]) + del cfg['SUBCONFIGS'] + + return edict(cfg) if return_edict else cfg diff --git a/isegm/utils/exp_imports/default.py b/isegm/utils/exp_imports/default.py new file mode 100644 index 0000000000000000000000000000000000000000..e78e21c85013af8ccd4c23d860c792bc40a2d822 --- /dev/null +++ b/isegm/utils/exp_imports/default.py @@ -0,0 +1,16 @@ +import torch +from functools import partial +from easydict import EasyDict as edict +from albumentations import * + +from isegm.data.datasets import * +from isegm.model.losses import * +from isegm.data.transforms import * +from isegm.engine.trainer import ISTrainer +from isegm.model.metrics import AdaptiveIoU +from isegm.data.points_sampler import MultiPointSampler +from isegm.utils.log import logger +from isegm.model import initializer + +from isegm.model.is_hrnet_model import HRNetModel +from isegm.model.is_deeplab_model import DeeplabModel \ No newline at end of file diff --git a/isegm/utils/log.py b/isegm/utils/log.py new file mode 100644 index 
0000000000000000000000000000000000000000..1f9f8bdb4bdd74d72514db8cf9cecef51001a588 --- /dev/null +++ b/isegm/utils/log.py @@ -0,0 +1,97 @@ +import io +import time +import logging +from datetime import datetime + +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +LOGGER_NAME = 'root' +LOGGER_DATEFMT = '%Y-%m-%d %H:%M:%S' + +handler = logging.StreamHandler() + +logger = logging.getLogger(LOGGER_NAME) +logger.setLevel(logging.INFO) +logger.addHandler(handler) + + +def add_logging(logs_path, prefix): + log_name = prefix + datetime.strftime(datetime.today(), '%Y-%m-%d_%H-%M-%S') + '.log' + stdout_log_path = logs_path / log_name + + fh = logging.FileHandler(str(stdout_log_path)) + formatter = logging.Formatter(fmt='(%(levelname)s) %(asctime)s: %(message)s', + datefmt=LOGGER_DATEFMT) + fh.setFormatter(formatter) + logger.addHandler(fh) + + +class TqdmToLogger(io.StringIO): + logger = None + level = None + buf = '' + + def __init__(self, logger, level=None, mininterval=5): + super(TqdmToLogger, self).__init__() + self.logger = logger + self.level = level or logging.INFO + self.mininterval = mininterval + self.last_time = 0 + + def write(self, buf): + self.buf = buf.strip('\r\n\t ') + + def flush(self): + if len(self.buf) > 0 and time.time() - self.last_time > self.mininterval: + self.logger.log(self.level, self.buf) + self.last_time = time.time() + + +class SummaryWriterAvg(SummaryWriter): + def __init__(self, *args, dump_period=20, **kwargs): + super().__init__(*args, **kwargs) + self._dump_period = dump_period + self._avg_scalars = dict() + + def add_scalar(self, tag, value, global_step=None, disable_avg=False): + if disable_avg or isinstance(value, (tuple, list, dict)): + super().add_scalar(tag, np.array(value), global_step=global_step) + else: + if tag not in self._avg_scalars: + self._avg_scalars[tag] = ScalarAccumulator(self._dump_period) + avg_scalar = self._avg_scalars[tag] + avg_scalar.add(value) + + if avg_scalar.is_full(): + super().add_scalar(tag, avg_scalar.value, + global_step=global_step) + avg_scalar.reset() + + +class ScalarAccumulator(object): + def __init__(self, period): + self.sum = 0 + self.cnt = 0 + self.period = period + + def add(self, value): + self.sum += value + self.cnt += 1 + + @property + def value(self): + if self.cnt > 0: + return self.sum / self.cnt + else: + return 0 + + def reset(self): + self.cnt = 0 + self.sum = 0 + + def is_full(self): + return self.cnt >= self.period + + def __len__(self): + return self.cnt diff --git a/isegm/utils/misc.py b/isegm/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..688c11e182f1aaea0f23d8e58811f713cf816da9 --- /dev/null +++ b/isegm/utils/misc.py @@ -0,0 +1,86 @@ +import torch +import numpy as np + +from .log import logger + + +def get_dims_with_exclusion(dim, exclude=None): + dims = list(range(dim)) + if exclude is not None: + dims.remove(exclude) + + return dims + + +def save_checkpoint(net, checkpoints_path, epoch=None, prefix='', verbose=True, multi_gpu=False): + if epoch is None: + checkpoint_name = 'last_checkpoint.pth' + else: + checkpoint_name = f'{epoch:03d}.pth' + + if prefix: + checkpoint_name = f'{prefix}_{checkpoint_name}' + + if not checkpoints_path.exists(): + checkpoints_path.mkdir(parents=True) + + checkpoint_path = checkpoints_path / checkpoint_name + if verbose: + logger.info(f'Save checkpoint to {str(checkpoint_path)}') + + net = net.module if multi_gpu else net + torch.save({'state_dict': net.state_dict(), + 'config': net._config}, str(checkpoint_path)) 
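+ + # Illustrative usage (a hedged sketch, not part of the original file): in a training loop where `cfg` was produced by init_experiment() and the model's __init__ is wrapped with the @serialize decorator from isegm/utils/serialization.py (which sets the `_config` attribute read above), a checkpoint could be written with: + #     save_checkpoint(model, cfg.CHECKPOINTS_PATH, epoch=epoch, multi_gpu=cfg.multi_gpu)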
+ + +def get_bbox_from_mask(mask): + rows = np.any(mask, axis=1) + cols = np.any(mask, axis=0) + rmin, rmax = np.where(rows)[0][[0, -1]] + cmin, cmax = np.where(cols)[0][[0, -1]] + + return rmin, rmax, cmin, cmax + + +def expand_bbox(bbox, expand_ratio, min_crop_size=None): + rmin, rmax, cmin, cmax = bbox + rcenter = 0.5 * (rmin + rmax) + ccenter = 0.5 * (cmin + cmax) + height = expand_ratio * (rmax - rmin + 1) + width = expand_ratio * (cmax - cmin + 1) + if min_crop_size is not None: + height = max(height, min_crop_size) + width = max(width, min_crop_size) + + rmin = int(round(rcenter - 0.5 * height)) + rmax = int(round(rcenter + 0.5 * height)) + cmin = int(round(ccenter - 0.5 * width)) + cmax = int(round(ccenter + 0.5 * width)) + + return rmin, rmax, cmin, cmax + + +def clamp_bbox(bbox, rmin, rmax, cmin, cmax): + return (max(rmin, bbox[0]), min(rmax, bbox[1]), + max(cmin, bbox[2]), min(cmax, bbox[3])) + + +def get_bbox_iou(b1, b2): + h_iou = get_segments_iou(b1[:2], b2[:2]) + w_iou = get_segments_iou(b1[2:4], b2[2:4]) + return h_iou * w_iou + + +def get_segments_iou(s1, s2): + a, b = s1 + c, d = s2 + intersection = max(0, min(b, d) - max(a, c) + 1) + union = max(1e-6, max(b, d) - min(a, c) + 1) + return intersection / union + + +def get_labels_with_sizes(x): + obj_sizes = np.bincount(x.flatten()) + labels = np.nonzero(obj_sizes)[0].tolist() + labels = [x for x in labels if x != 0] + return labels, obj_sizes[labels].tolist() diff --git a/isegm/utils/serialization.py b/isegm/utils/serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..c73935b9aa7e7f2f5a11c685c4192321da78c5f3 --- /dev/null +++ b/isegm/utils/serialization.py @@ -0,0 +1,107 @@ +from functools import wraps +from copy import deepcopy +import inspect +import torch.nn as nn + + +def serialize(init): + parameters = list(inspect.signature(init).parameters) + + @wraps(init) + def new_init(self, *args, **kwargs): + params = deepcopy(kwargs) + for pname, value in zip(parameters[1:], args): + params[pname] = value + + config = { + 'class': get_classname(self.__class__), + 'params': dict() + } + specified_params = set(params.keys()) + + for pname, param in get_default_params(self.__class__).items(): + if pname not in params: + params[pname] = param.default + + for name, value in list(params.items()): + param_type = 'builtin' + if inspect.isclass(value): + param_type = 'class' + value = get_classname(value) + + config['params'][name] = { + 'type': param_type, + 'value': value, + 'specified': name in specified_params + } + + setattr(self, '_config', config) + init(self, *args, **kwargs) + + return new_init + + +def load_model(config, **kwargs): + model_class = get_class_from_str(config['class']) + model_default_params = get_default_params(model_class) + + model_args = dict() + for pname, param in config['params'].items(): + value = param['value'] + if param['type'] == 'class': + value = get_class_from_str(value) + + if pname not in model_default_params and not param['specified']: + continue + + assert pname in model_default_params + if not param['specified'] and model_default_params[pname].default == value: + continue + model_args[pname] = value + + model_args.update(kwargs) + + return model_class(**model_args) + + +def get_config_repr(config): + config_str = f'Model: {config["class"]}\n' + for pname, param in config['params'].items(): + value = param["value"] + if param['type'] == 'class': + value = value.split('.')[-1] + param_str = f'{pname:<22} = {str(value):<12}' + if not param['specified']: + param_str 
+= ' (default)' + config_str += param_str + '\n' + return config_str + + +def get_default_params(some_class): + params = dict() + for mclass in some_class.mro(): + if mclass is nn.Module or mclass is object: + continue + + mclass_params = inspect.signature(mclass.__init__).parameters + for pname, param in mclass_params.items(): + if param.default != param.empty and pname not in params: + params[pname] = param + + return params + + +def get_classname(cls): + module = cls.__module__ + name = cls.__qualname__ + if module is not None and module != "__builtin__": + name = module + "." + name + return name + + +def get_class_from_str(class_str): + components = class_str.split('.') + mod = __import__('.'.join(components[:-1])) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod diff --git a/isegm/utils/vis.py b/isegm/utils/vis.py new file mode 100644 index 0000000000000000000000000000000000000000..9790a4ca76e7768fd95980cd2d3a492800bfdd1e --- /dev/null +++ b/isegm/utils/vis.py @@ -0,0 +1,135 @@ +from functools import lru_cache + +import cv2 +import numpy as np + + +def visualize_instances(imask, bg_color=255, + boundaries_color=None, boundaries_width=1, boundaries_alpha=0.8): + num_objects = imask.max() + 1 + palette = get_palette(num_objects) + if bg_color is not None: + palette[0] = bg_color + + result = palette[imask].astype(np.uint8) + if boundaries_color is not None: + boundaries_mask = get_boundaries(imask, boundaries_width=boundaries_width) + tresult = result.astype(np.float32) + tresult[boundaries_mask] = boundaries_color + tresult = tresult * boundaries_alpha + (1 - boundaries_alpha) * result + result = tresult.astype(np.uint8) + + return result + + +@lru_cache(maxsize=16) +def get_palette(num_cls): + palette = np.zeros(3 * num_cls, dtype=np.int32) + + for j in range(0, num_cls): + lab = j + i = 0 + + while lab > 0: + palette[j*3 + 0] |= (((lab >> 0) & 1) << (7-i)) + palette[j*3 + 1] |= (((lab >> 1) & 1) << (7-i)) + palette[j*3 + 2] |= (((lab >> 2) & 1) << (7-i)) + i = i + 1 + lab >>= 3 + + return palette.reshape((-1, 3)) + + +def visualize_mask(mask, num_cls): + palette = get_palette(num_cls) + mask[mask == -1] = 0 + + return palette[mask].astype(np.uint8) + + +def visualize_proposals(proposals_info, point_color=(255, 0, 0), point_radius=1): + proposal_map, colors, candidates = proposals_info + + proposal_map = draw_probmap(proposal_map) + for x, y in candidates: + proposal_map = cv2.circle(proposal_map, (y, x), point_radius, point_color, -1) + + return proposal_map + + +def draw_probmap(x): + return cv2.applyColorMap((x * 255).astype(np.uint8), cv2.COLORMAP_HOT) + + +def draw_points(image, points, color, radius=3): + image = image.copy() + for p in points: + if p[0] < 0: + continue + if len(p) == 3: + pradius = {0: 8, 1: 6, 2: 4}[p[2]] if p[2] < 3 else 2 + else: + pradius = radius + image = cv2.circle(image, (int(p[1]), int(p[0])), pradius, color, -1) + + return image + + +def draw_instance_map(x, palette=None): + num_colors = x.max() + 1 + if palette is None: + palette = get_palette(num_colors) + + return palette[x].astype(np.uint8) + + +def blend_mask(image, mask, alpha=0.6): + if mask.min() == -1: + mask = mask.copy() + 1 + + imap = draw_instance_map(mask) + result = (image * (1 - alpha) + alpha * imap).astype(np.uint8) + return result + + +def get_boundaries(instances_masks, boundaries_width=1): + boundaries = np.zeros((instances_masks.shape[0], instances_masks.shape[1]), dtype=np.bool) + + for obj_id in np.unique(instances_masks.flatten()): + if obj_id == 0: + 
continue + + obj_mask = instances_masks == obj_id + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) + inner_mask = cv2.erode(obj_mask.astype(np.uint8), kernel, iterations=boundaries_width).astype(np.bool) + + obj_boundary = np.logical_xor(obj_mask, np.logical_and(inner_mask, obj_mask)) + boundaries = np.logical_or(boundaries, obj_boundary) + return boundaries + + +def draw_with_blend_and_clicks(img, mask=None, alpha=0.6, clicks_list=None, pos_color=(0, 255, 0), + neg_color=(255, 0, 0), radius=4): + result = img.copy() + + if mask is not None: + palette = get_palette(np.max(mask) + 1) + rgb_mask = palette[mask.astype(np.uint8)] + + mask_region = (mask > 0).astype(np.uint8) + result = result * (1 - mask_region[:, :, np.newaxis]) + \ + (1 - alpha) * mask_region[:, :, np.newaxis] * result + \ + alpha * rgb_mask + result = result.astype(np.uint8) + + # result = (result * (1 - alpha) + alpha * rgb_mask).astype(np.uint8) + + if clicks_list is not None and len(clicks_list) > 0: + pos_points = [click.coords for click in clicks_list if click.is_positive] + neg_points = [click.coords for click in clicks_list if not click.is_positive] + + result = draw_points(result, pos_points, pos_color, radius=radius) + result = draw_points(result, neg_points, neg_color, radius=radius) + + return result + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..b0eada64f4a5bbb6b07717c975520f549286fd65 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +streamlit == 1.20.0 +streamlit-drawable-canvas == 0.9.2 +opencv-python == 4.7.0.72 +torch == 2.0.0 +torchvision == 0.15.1 +tensorboard == 2.12.0 +albumentations == 1.3.0 +numpy == 1.23.5 +Cython == 0.29.33 +wget == 3.2
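As a standalone illustration of the click encoding implemented in isegm/model/ops.py above (a hedged sketch, not part of the diff; it assumes only that torch is installed and that the isegm package introduced by this diff is importable), the snippet below feeds one positive and one negative click through DistMaps with use_disks=True and prints the shape of the resulting two-channel disk map:

import torch
from isegm.model.ops import DistMaps

# One positive and one negative click on a 64x64 image. In the points tensor the
# first half of the slots holds positive clicks and the second half negative
# clicks; unused slots are padded with -1 and each entry is (row, col, click_index).
dist_maps = DistMaps(norm_radius=5, spatial_scale=1.0, use_disks=True)
image = torch.zeros(1, 3, 64, 64)  # dummy RGB batch; only its shape is used
points = torch.tensor([[[10.0, 20.0, 0.0], [-1.0, -1.0, -1.0],
                        [40.0, 50.0, 0.0], [-1.0, -1.0, -1.0]]])

features = dist_maps(image, points)  # binary disks of radius norm_radius around each click
print(features.shape)  # torch.Size([1, 2, 64, 64]): channel 0 positive, channel 1 negative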