# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union

import numpy as np

from mmpose.utils import adapt_mmdet_pipeline
from ...utils import get_config_path
from ..node import Node
from ..registry import NODES

try:
    from mmdet.apis import inference_detector, init_detector
    has_mmdet = True
except (ImportError, ModuleNotFoundError):
    has_mmdet = False


@NODES.register_module()
class DetectorNode(Node):
    """Detect objects from the frame image using MMDetection model.

    Note that MMDetection is required for this node. Please refer to
    `MMDetection documentation <https://mmdetection.readthedocs.io/>`_
    for the installation guide.

    Parameters:
        name (str): The node name (also thread name)
        model_config (str): The model config file
        model_checkpoint (str): The model checkpoint file
        input_buffer (str): The name of the input buffer
        output_buffer (str|list): The name(s) of the output buffer(s)
        enable_key (str|int, optional): Set a hot-key to toggle enable/disable
            of the node. If an int value is given, it will be treated as an
            ascii code of a key. Please note: (1) If ``enable_key`` is set,
            the ``bypass()`` method needs to be overridden to define the node
            behavior when disabled; (2) Some hot-keys are reserved for
            particular use. For example: 'q', 'Q' and 27 are used for exiting.
            Default: ``None``
        enable (bool): Default enable/disable status. Default: ``True``
        device (str): Specify the device to hold model weights and inference
            the model. Default: ``'cuda:0'``
        bbox_thr (float): Set a threshold to filter out objects with low bbox
            scores. Default: 0.5
        multi_input (bool): Whether to load all frames in the input buffer.
            If True, all frames in the buffer will be loaded and stacked. The
            latest frame is used to detect objects of interest.
            Default: False

    Example::
        >>> cfg = dict(
        ...    type='DetectorNode',
        ...    name='detector',
        ...    model_config='demo/mmdetection_cfg/'
        ...    'ssdlite_mobilenetv2_scratch_600e_coco.py',
        ...    model_checkpoint='https://download.openmmlab.com'
        ...    '/mmdetection/v2.0/ssd/'
        ...    'ssdlite_mobilenetv2_scratch_600e_coco/ssdlite_mobilenetv2_'
        ...    'scratch_600e_coco_20210629_110627-974d9307.pth',
        ...    # `_input_` is an executor-reserved buffer
        ...    input_buffer='_input_',
        ...    output_buffer='det_result')

        >>> from mmpose.apis.webcam.nodes import NODES
        >>> node = NODES.build(cfg)
    """

    def __init__(self,
                 name: str,
                 model_config: str,
                 model_checkpoint: str,
                 input_buffer: str,
                 output_buffer: Union[str, List[str]],
                 enable_key: Optional[Union[str, int]] = None,
                 enable: bool = True,
                 device: str = 'cuda:0',
                 bbox_thr: float = 0.5,
                 multi_input: bool = False):
        """Build the MMDetection model and register the node's buffers.

        See the class docstring for the meaning of each argument.

        Raises:
            AssertionError: If MMDetection could not be imported.
        """
        # Check mmdetection is installed. Kept as ``assert`` for backward
        # compatibility with callers that may expect AssertionError; note
        # that it is skipped when Python runs with ``-O``.
        assert has_mmdet, \
            f'MMDetection is required for {self.__class__.__name__}.'

        super().__init__(
            name=name,
            enable_key=enable_key,
            enable=enable,
            multi_input=multi_input)

        # Resolve the config path relative to the 'mmdet' scope.
        self.model_config = get_config_path(model_config, 'mmdet')
        self.model_checkpoint = model_checkpoint
        self.device = device.lower()
        self.bbox_thr = bbox_thr

        # Init model
        self.model = init_detector(
            self.model_config, self.model_checkpoint, device=self.device)
        # Adjust the detector's data pipeline for use from MMPose.
        self.model.cfg = adapt_mmdet_pipeline(self.model.cfg)

        # Register buffers; a new message on 'input' triggers processing.
        self.register_input_buffer(input_buffer, 'input', trigger=True)
        self.register_output_buffer(output_buffer)

    def bypass(self, input_msgs):
        """Forward the input message unchanged while the node is disabled."""
        return input_msgs['input']

    def process(self, input_msgs):
        """Run detection on the latest frame and attach the results to it.

        Args:
            input_msgs (dict): Input messages keyed by buffer role. When
                ``multi_input`` is False, ``input_msgs['input']`` is a single
                frame message; otherwise it is a list of frame messages whose
                last element is the frame used for detection.

        Returns:
            The (last) frame message, updated with the detected objects. In
            multi-input mode its image is replaced by the images of all
            frames stacked along a new leading axis.
        """
        input_msg = input_msgs['input']

        if self.multi_input:
            # Collect every frame's image before narrowing to the last one.
            imgs = [frame.get_image() for frame in input_msg]
            input_msg = input_msg[-1]

        # Detection always runs on the single (latest) frame image.
        img = input_msg.get_image()

        preds = inference_detector(self.model, img)
        objects = self._post_process(preds)
        input_msg.update_objects(objects)

        if self.multi_input:
            input_msg.set_image(np.stack(imgs, axis=0))

        return input_msg

    def _post_process(self, preds) -> List[Dict]:
        """Convert MMDetection predictions into a list of object dicts.

        Args:
            preds: The result of ``inference_detector`` for one image. Its
                ``pred_instances`` field is moved to CPU and converted to
                numpy before filtering.

        Returns:
            list[dict]: One dict per instance whose score is not below
            ``self.bbox_thr``, with keys ``'class_id'``, ``'label'``,
            ``'bbox'``, ``'det_model_cfg'`` and ``'dataset_meta'``.
        """
        instances = preds.pred_instances.cpu().numpy()

        classes = self.model.dataset_meta['classes']
        # A single class name may be stored as a bare string; wrap it so that
        # ``classes[class_id]`` yields the name rather than one character.
        if isinstance(classes, str):
            classes = (classes, )

        objects = []
        # Walk the per-instance fields in lockstep instead of indexing each
        # field by position.
        for score, class_id, bbox in zip(instances.scores, instances.labels,
                                         instances.bboxes):
            if score < self.bbox_thr:
                continue

            objects.append({
                'class_id': class_id,
                'label': classes[class_id],
                'bbox': bbox,
                'det_model_cfg': self.model.cfg,
                # A fresh shallow copy per object, so objects do not share
                # the same dict instance.
                'dataset_meta': self.model.dataset_meta.copy(),
            })

        return objects