""" Author: Siyuan Li Licensed: Apache-2.0 License """ import copy import logging import re import warnings from mmdet.registry import MODELS from mmengine.logging import MMLogger, print_log from mmengine.model.weight_init import (PretrainedInit, initialize, update_init_info) from .grounding_dino import GroundingDINO def clean_label_name(name: str) -> str: name = re.sub(r"\(.*\)", "", name) name = re.sub(r"_", " ", name) name = re.sub(r" ", " ", name) return name def chunks(lst: list, n: int) -> list: """Yield successive n-sized chunks from lst.""" all_ = [] for i in range(0, len(lst), n): data_index = lst[i : i + n] all_.append(data_index) counter = 0 for i in all_: counter += len(i) assert counter == len(lst) return all_ @MODELS.register_module() class GroundingDINOMasa(GroundingDINO): """Implementation of `Grounding DINO: Marrying DINO with Grounded Pre- Training for Open-Set Object Detection. `_ Code is modified from the `official github repo `_. """ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.track_text_prompt = None self.track_text_dict = None self.token_positive_maps = None self.track_entities = None def init_weights(self) -> None: """Initialize weights for Transformer and other components.""" if self.init_cfg: print_log( f"initialize {self.__class__.__name__} with init_cfg {self.init_cfg}", logger="current", level=logging.DEBUG, ) init_cfgs = self.init_cfg if isinstance(self.init_cfg, dict): init_cfgs = [self.init_cfg] # PretrainedInit has higher priority than any other init_cfg. # Therefore we initialize `pretrained_cfg` last to overwrite # the previous initialized weights. # See details in https://github.com/open-mmlab/mmengine/issues/691 # noqa E501 other_cfgs = [] pretrained_cfg = [] for init_cfg in init_cfgs: assert isinstance(init_cfg, dict) if ( init_cfg["type"] == "Pretrained" or init_cfg["type"] is PretrainedInit ): pretrained_cfg.append(init_cfg) else: other_cfgs.append(init_cfg) initialize(self, other_cfgs) else: super().init_weights() initialize(self, pretrained_cfg) def predict( self, batch_inputs, detection_features, batch_data_samples, rescale: bool = True ): text_prompts = [] enhanced_text_prompts = [] tokens_positives = [] for data_samples in batch_data_samples: text_prompts.append(data_samples.text) if "caption_prompt" in data_samples: enhanced_text_prompts.append(data_samples.caption_prompt) else: enhanced_text_prompts.append(None) tokens_positives.append(data_samples.get("tokens_positive", None)) if "custom_entities" in batch_data_samples[0]: # Assuming that the `custom_entities` flag # inside a batch is always the same. For single image inference custom_entities = batch_data_samples[0].custom_entities else: custom_entities = False if self.track_text_dict is not None and self.track_text_prompt == text_prompts: # text feature map layer is_rec_tasks = [] for i, data_samples in enumerate(batch_data_samples): if self.token_positive_maps[i] is not None: is_rec_tasks.append(False) else: is_rec_tasks.append(True) data_samples.token_positive_map = self.token_positive_maps[i] visual_feats = detection_features head_inputs_dict = self.forward_transformer( visual_feats, self.track_text_dict, batch_data_samples ) results_list = self.bbox_head.predict( **head_inputs_dict, rescale=rescale, batch_data_samples=batch_data_samples, ) entities = self.track_entities else: self.track_text_prompt = text_prompts if len(text_prompts) == 1: # All the text prompts are the same, # so there is no need to calculate them multiple times. _positive_maps_and_prompts = [ self.get_tokens_positive_and_prompts( text_prompts[0], custom_entities, enhanced_text_prompts[0], tokens_positives[0], ) ] * len(batch_inputs) else: _positive_maps_and_prompts = [ self.get_tokens_positive_and_prompts( text_prompt, custom_entities, enhanced_text_prompt, tokens_positive, ) for text_prompt, enhanced_text_prompt, tokens_positive in zip( text_prompts, enhanced_text_prompts, tokens_positives ) ] token_positive_maps, text_prompts, _, entities = zip( *_positive_maps_and_prompts ) self.token_positive_maps = token_positive_maps self.track_entities = entities # image feature extraction visual_feats = detection_features if isinstance(text_prompts[0], list): # chunked text prompts, only bs=1 is supported assert len(batch_inputs) == 1 count = 0 results_list = [] entities = [[item for lst in entities[0] for item in lst]] for b in range(len(text_prompts[0])): text_prompts_once = [text_prompts[0][b]] token_positive_maps_once = token_positive_maps[0][b] text_dict = self.language_model(text_prompts_once) # text feature map layer if self.text_feat_map is not None: text_dict["embedded"] = self.text_feat_map( text_dict["embedded"] ) batch_data_samples[0].token_positive_map = token_positive_maps_once head_inputs_dict = self.forward_transformer( copy.deepcopy(visual_feats), text_dict, batch_data_samples ) pred_instances = self.bbox_head.predict( **head_inputs_dict, rescale=rescale, batch_data_samples=batch_data_samples, )[0] if len(pred_instances) > 0: pred_instances.labels += count count += len(token_positive_maps_once) results_list.append(pred_instances) results_list = [results_list[0].cat(results_list)] is_rec_tasks = [False] * len(results_list) else: # extract text feats text_dict = self.language_model(list(text_prompts)) # text feature map layer if self.text_feat_map is not None: text_dict["embedded"] = self.text_feat_map(text_dict["embedded"]) is_rec_tasks = [] for i, data_samples in enumerate(batch_data_samples): if token_positive_maps[i] is not None: is_rec_tasks.append(False) else: is_rec_tasks.append(True) data_samples.token_positive_map = token_positive_maps[i] if self.track_text_dict is None: self.track_text_dict = text_dict head_inputs_dict = self.forward_transformer( visual_feats, text_dict, batch_data_samples ) results_list = self.bbox_head.predict( **head_inputs_dict, rescale=rescale, batch_data_samples=batch_data_samples, ) for data_sample, pred_instances, entity, is_rec_task in zip( batch_data_samples, results_list, entities, is_rec_tasks ): if len(pred_instances) > 0: label_names = [] for labels in pred_instances.labels: if is_rec_task: label_names.append(entity) continue if labels >= len(entity): warnings.warn( "The unexpected output indicates an issue with " "named entity recognition. You can try " "setting custom_entities=True and running " "again to see if it helps." ) label_names.append("unobject") else: label_names.append(entity[labels]) # for visualization pred_instances.label_names = label_names data_sample.pred_instances = pred_instances return batch_data_samples