Spaces:
Runtime error
Runtime error
File size: 5,395 Bytes
cc0dd3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union
import numpy as np
from mmpose.utils import adapt_mmdet_pipeline
from ...utils import get_config_path
from ..node import Node
from ..registry import NODES
# MMDetection is an optional dependency: record whether it is importable so
# that ``DetectorNode`` can report a clear error instead of failing at import.
try:
    from mmdet.apis import inference_detector, init_detector
except ImportError:  # also covers ModuleNotFoundError (a subclass)
    has_mmdet = False
else:
    has_mmdet = True
@NODES.register_module()
class DetectorNode(Node):
    """Detect objects from the frame image using MMDetection model.

    Note that MMDetection is required for this node. Please refer to
    `MMDetection documentation <https://mmdetection.readthedocs.io/en
    /latest/get_started.html>`_ for the installation guide.

    Parameters:
        name (str): The node name (also thread name)
        model_config (str): The model config file. It is resolved with
            ``get_config_path(model_config, 'mmdet')`` before use
        model_checkpoint (str): The model checkpoint file
        input_buffer (str): The name of the input buffer
        output_buffer (str|list): The name(s) of the output buffer(s)
        enable_key (str|int, optional): Set a hot-key to toggle enable/disable
            of the node. If an int value is given, it will be treated as an
            ascii code of a key. Please note: (1) If ``enable_key`` is set,
            the ``bypass()`` method needs to be overridden to define the node
            behavior when disabled; (2) Some hot-keys are reserved for
            particular use. For example: 'q', 'Q' and 27 are used for exiting.
            Default: ``None``
        enable (bool): Default enable/disable status. Default: ``True``
        device (str): Specify the device to hold model weights and run
            inference. The value is lower-cased. Default: ``'cuda:0'``
        bbox_thr (float): Set a threshold to filter out objects with low bbox
            scores. Objects whose score is below this value are dropped.
            Default: 0.5
        multi_input (bool): Whether to load all frames in the input buffer.
            If True, all frames in the buffer will be loaded and stacked. The
            latest frame is used to detect objects of interest.
            Default: False

    Example::
        >>> cfg = dict(
        ...     type='DetectorNode',
        ...     name='detector',
        ...     model_config='demo/mmdetection_cfg/'
        ...     'ssdlite_mobilenetv2_scratch_600e_coco.py',
        ...     model_checkpoint='https://download.openmmlab.com'
        ...     '/mmdetection/v2.0/ssd/'
        ...     'ssdlite_mobilenetv2_scratch_600e_coco/ssdlite_mobilenetv2_'
        ...     'scratch_600e_coco_20210629_110627-974d9307.pth',
        ...     # `_input_` is an executor-reserved buffer
        ...     input_buffer='_input_',
        ...     output_buffer='det_result')

        >>> from mmpose.apis.webcam.nodes import NODES
        >>> node = NODES.build(cfg)
    """

    def __init__(self,
                 name: str,
                 model_config: str,
                 model_checkpoint: str,
                 input_buffer: str,
                 output_buffer: Union[str, List[str]],
                 enable_key: Optional[Union[str, int]] = None,
                 enable: bool = True,
                 device: str = 'cuda:0',
                 bbox_thr: float = 0.5,
                 multi_input: bool = False):
        # Check mmdetection is installed. Note that ``assert`` is skipped
        # under ``python -O``; it is kept so the raised exception type
        # (AssertionError) stays unchanged for callers.
        assert has_mmdet, \
            f'MMDetection is required for {self.__class__.__name__}.'

        # ``multi_input`` is handled by the base class and read back as
        # ``self.multi_input`` in ``process()``.
        super().__init__(
            name=name,
            enable_key=enable_key,
            enable=enable,
            multi_input=multi_input)

        self.model_config = get_config_path(model_config, 'mmdet')
        self.model_checkpoint = model_checkpoint
        self.device = device.lower()
        self.bbox_thr = bbox_thr

        # Init model
        self.model = init_detector(
            self.model_config, self.model_checkpoint, device=self.device)
        # Adapt the detector's data pipeline config for use inside mmpose.
        self.model.cfg = adapt_mmdet_pipeline(self.model.cfg)

        # Register buffers
        self.register_input_buffer(input_buffer, 'input', trigger=True)
        self.register_output_buffer(output_buffer)

    def bypass(self, input_msgs):
        """Forward the input unchanged when the node is disabled.

        Returns ``input_msgs['input']`` as is, without running detection.
        """
        # NOTE(review): with ``multi_input=True`` this forwards the whole
        # collection of frames, whereas ``process()`` returns a single
        # message; confirm downstream nodes accept that.
        return input_msgs['input']

    def process(self, input_msgs):
        """Run detection on the (latest) input frame and attach the results.

        Args:
            input_msgs (dict): Input messages keyed by input name. Only the
                ``'input'`` entry is read. With ``multi_input`` it is an
                indexable collection of frame messages; otherwise it is a
                single frame message.

        Returns:
            The frame message with detected objects added through
            ``update_objects()``. With ``multi_input``, this is the latest
            message, and its image is replaced by all frames' images stacked
            along a new axis 0.
        """
        input_msg = input_msgs['input']

        if self.multi_input:
            # Gather every frame's image before narrowing to the latest
            # message, so they can be stacked once detection is done.
            imgs = [frame.get_image() for frame in input_msg]
            input_msg = input_msg[-1]

        # Detection always runs on a single (the latest) frame.
        img = input_msg.get_image()
        preds = inference_detector(self.model, img)
        objects = self._post_process(preds)
        input_msg.update_objects(objects)

        if self.multi_input:
            input_msg.set_image(np.stack(imgs, axis=0))

        return input_msg

    def _post_process(self, preds) -> List[Dict]:
        """Convert MMDetection predictions into a list of object dicts.

        Args:
            preds: The result of ``inference_detector``. Its
                ``pred_instances`` (moved to CPU and converted to numpy) is
                read for ``scores``, ``labels`` and ``bboxes``.

        Returns:
            list[dict]: One dict per instance whose score is not below
            ``self.bbox_thr``. Each dict has the keys ``'class_id'``,
            ``'label'``, ``'bbox'``, ``'det_model_cfg'`` and
            ``'dataset_meta'``.
        """
        instances = preds.pred_instances.cpu().numpy()

        dataset_meta = self.model.dataset_meta
        classes = dataset_meta['classes']
        if isinstance(classes, str):
            # A single class name would otherwise be indexed char by char.
            classes = (classes, )

        objects = []
        # Only read the per-field arrays when there is at least one
        # instance, matching an index loop over ``len(instances)``.
        if len(instances) == 0:
            return objects

        bbox_thr = self.bbox_thr
        det_model_cfg = self.model.cfg
        for score, class_id, bbox in zip(instances.scores, instances.labels,
                                         instances.bboxes):
            if score < bbox_thr:
                continue
            objects.append({
                'class_id': class_id,
                'label': classes[class_id],
                'bbox': bbox,
                'det_model_cfg': det_model_cfg,
                # Each object gets its own shallow copy of the metadata.
                'dataset_meta': dataset_meta.copy(),
            })
        return objects
|