import time
from datetime import datetime
from functools import partial
from queue import Queue
from threading import Thread
from typing import Callable, Generator, List, Optional, Tuple, Union

from inference.core import logger
from inference.core.active_learning.middlewares import (
    NullActiveLearningMiddleware,
    ThreadingActiveLearningMiddleware,
)
from inference.core.cache import cache
from inference.core.env import (
    ACTIVE_LEARNING_ENABLED,
    API_KEY,
    API_KEY_ENV_NAMES,
    DISABLE_PREPROC_AUTO_ORIENT,
    PREDICTIONS_QUEUE_SIZE,
    RESTART_ATTEMPT_DELAY,
)
from inference.core.exceptions import MissingApiKeyError
from inference.core.interfaces.camera.entities import (
    StatusUpdate,
    UpdateSeverity,
    VideoFrame,
)
from inference.core.interfaces.camera.exceptions import SourceConnectionError
from inference.core.interfaces.camera.utils import get_video_frames_generator
from inference.core.interfaces.camera.video_source import (
    BufferConsumptionStrategy,
    BufferFillingStrategy,
    VideoSource,
)
from inference.core.interfaces.stream.entities import (
    ModelConfig,
    ObjectDetectionPrediction,
)
from inference.core.interfaces.stream.sinks import active_learning_sink, multi_sink
from inference.core.interfaces.stream.watchdog import (
    NullPipelineWatchdog,
    PipelineWatchDog,
)
from inference.core.models.roboflow import OnnxRoboflowInferenceModel
from inference.models.utils import get_roboflow_model

INFERENCE_PIPELINE_CONTEXT = "inference_pipeline"
SOURCE_CONNECTION_ATTEMPT_FAILED_EVENT = "SOURCE_CONNECTION_ATTEMPT_FAILED"
SOURCE_CONNECTION_LOST_EVENT = "SOURCE_CONNECTION_LOST"
INFERENCE_RESULTS_DISPATCHING_ERROR_EVENT = "INFERENCE_RESULTS_DISPATCHING_ERROR"
INFERENCE_THREAD_STARTED_EVENT = "INFERENCE_THREAD_STARTED"
INFERENCE_THREAD_FINISHED_EVENT = "INFERENCE_THREAD_FINISHED"
INFERENCE_COMPLETED_EVENT = "INFERENCE_COMPLETED"
INFERENCE_ERROR_EVENT = "INFERENCE_ERROR"


class InferencePipeline:
    @classmethod
    def init(
        cls,
        model_id: str,
        video_reference: Union[str, int],
        on_prediction: Callable[[ObjectDetectionPrediction, VideoFrame], None],
        api_key: Optional[str] = None,
        max_fps: Optional[Union[float, int]] = None,
        watchdog: Optional[PipelineWatchDog] = None,
        status_update_handlers: Optional[List[Callable[[StatusUpdate], None]]] = None,
        source_buffer_filling_strategy: Optional[BufferFillingStrategy] = None,
        source_buffer_consumption_strategy: Optional[BufferConsumptionStrategy] = None,
        class_agnostic_nms: Optional[bool] = None,
        confidence: Optional[float] = None,
        iou_threshold: Optional[float] = None,
        max_candidates: Optional[int] = None,
        max_detections: Optional[int] = None,
        mask_decode_mode: Optional[str] = "accurate",
        tradeoff_factor: Optional[float] = 0.0,
        active_learning_enabled: Optional[bool] = None,
    ) -> "InferencePipeline":
""" |
|
This class creates the abstraction for making inferences from CV models against video stream. |
|
It allows to choose Object Detection model from Roboflow platform and run predictions against |
|
video streams - just by the price of specifying which model to use and what to do with predictions. |
|
|
|
It allows to set the model post-processing parameters (via .init() or env) and intercept updates |
|
related to state of pipeline via `PipelineWatchDog` abstraction (although that is something probably |
|
useful only for advanced use-cases). |
|
|
|
For maximum efficiency, all separate chunks of processing: video decoding, inference, results dispatching |
|
are handled by separate threads. |
|
|
|
Given that reference to stream is passed and connectivity is lost - it attempts to re-connect with delay. |
|
|
|
Since version 0.9.11 it works not only for object detection models but is also compatible with stubs, |
|
classification, instance-segmentation and keypoint-detection models. |
|
|
|
Args: |
|
model_id (str): Name and version of model at Roboflow platform (example: "my-model/3") |
|
video_reference (Union[str, int]): Reference of source to be used to make predictions against. |
|
It can be video file path, stream URL and device (like camera) id (we handle whatever cv2 handles). |
|
on_prediction (Callable[ObjectDetectionPrediction, VideoFrame], None]): Function to be called |
|
once prediction is ready - passing both decoded frame, their metadata and dict with standard |
|
Roboflow Object Detection prediction. |
|
api_key (Optional[str]): Roboflow API key - if not passed - will be looked in env under "ROBOFLOW_API_KEY" |
|
and "API_KEY" variables. API key, passed in some form is required. |
|
max_fps (Optional[Union[float, int]]): Specific value passed as this parameter will be used to |
|
dictate max FPS of processing. It can be useful if we wanted to run concurrent inference pipelines |
|
on single machine making tradeoff between number of frames and number of streams handled. Disabled |
|
by default. |
|
watchdog (Optional[PipelineWatchDog]): Implementation of class that allows profiling of |
|
inference pipeline - if not given null implementation (doing nothing) will be used. |
|
status_update_handlers (Optional[List[Callable[[StatusUpdate], None]]]): List of handlers to intercept |
|
status updates of all elements of the pipeline. Should be used only if detailed inspection of |
|
pipeline behaviour in time is needed. Please point out that handlers should be possible to be executed |
|
fast - otherwise they will impair pipeline performance. All errors will be logged as warnings |
|
without re-raising. Default: None. |
|
source_buffer_filling_strategy (Optional[BufferFillingStrategy]): Parameter dictating strategy for |
|
video stream decoding behaviour. By default - tweaked to the type of source given. |
|
Please find detailed explanation in docs of [`VideoSource`](../camera/video_source.py) |
|
source_buffer_consumption_strategy (Optional[BufferConsumptionStrategy]): Parameter dictating strategy for |
|
video stream frames consumption. By default - tweaked to the type of source given. |
|
Please find detailed explanation in docs of [`VideoSource`](../camera/video_source.py) |
|
class_agnostic_nms (Optional[bool]): Parameter of model post-processing. If not given - value checked in |
|
env variable "CLASS_AGNOSTIC_NMS" with default "False" |
|
confidence (Optional[float]): Parameter of model post-processing. If not given - value checked in |
|
env variable "CONFIDENCE" with default "0.5" |
|
iou_threshold (Optional[float]): Parameter of model post-processing. If not given - value checked in |
|
env variable "IOU_THRESHOLD" with default "0.5" |
|
max_candidates (Optional[int]): Parameter of model post-processing. If not given - value checked in |
|
env variable "MAX_CANDIDATES" with default "3000" |
|
max_detections (Optional[int]): Parameter of model post-processing. If not given - value checked in |
|
env variable "MAX_DETECTIONS" with default "300" |
|
mask_decode_mode: (Optional[str]): Parameter of model post-processing. If not given - model "accurate" is |
|
used. Applicable for instance segmentation models |
|
tradeoff_factor (Optional[float]): Parameter of model post-processing. If not 0.0 - model default is used. |
|
Applicable for instance segmentation models |
|
active_learning_enabled (Optional[bool]): Flag to enable / disable Active Learning middleware (setting it |
|
true does not guarantee any data to be collected, as data collection is controlled by Roboflow backend - |
|
it just enables middleware intercepting predictions). If not given, env variable |
|
`ACTIVE_LEARNING_ENABLED` will be used. Please point out that Active Learning will be forcefully |
|
disabled in a scenario when Roboflow API key is not given, as Roboflow account is required |
|
for this feature to be operational. |
|
|
|
Other ENV variables involved in low-level configuration: |
|
* INFERENCE_PIPELINE_PREDICTIONS_QUEUE_SIZE - size of buffer for predictions that are ready for dispatching |
|
* INFERENCE_PIPELINE_RESTART_ATTEMPT_DELAY - delay for restarts on stream connection drop |
|
* ACTIVE_LEARNING_ENABLED - controls Active Learning middleware if explicit parameter not given |
|
|
|
Returns: Instance of InferencePipeline |
|
|
|
Throws: |
|
* SourceConnectionError if source cannot be connected at start, however it attempts to reconnect |
|
always if connection to stream is lost. |
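
        Example:
            A minimal usage sketch - the model id, stream URL, API key placeholder and the two callbacks
            below are illustrative values only, not defaults shipped with this module:

            ```python
            from inference.core.interfaces.stream.inference_pipeline import InferencePipeline

            def print_detections(predictions: dict, video_frame) -> None:
                # Hypothetical sink: report how many objects were detected in each frame
                print(f"frame {video_frame.frame_id}: {len(predictions.get('predictions', []))} detections")

            def log_status(update) -> None:
                # Hypothetical status handler: print severity and type of every pipeline event
                print(update.severity, update.event_type)

            pipeline = InferencePipeline.init(
                model_id="my-model/3",                      # example id, as in the Args section above
                video_reference="rtsp://example.com/live",  # or a video file path / camera device id
                on_prediction=print_detections,
                api_key="<YOUR_ROBOFLOW_API_KEY>",
                status_update_handlers=[log_status],
            )
            pipeline.start(use_main_thread=False)  # dispatch results on a background thread
            pipeline.join()                        # wait for the pipeline threads to finish
            ```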
        """
        if api_key is None:
            api_key = API_KEY
        if status_update_handlers is None:
            status_update_handlers = []
        inference_config = ModelConfig.init(
            class_agnostic_nms=class_agnostic_nms,
            confidence=confidence,
            iou_threshold=iou_threshold,
            max_candidates=max_candidates,
            max_detections=max_detections,
            mask_decode_mode=mask_decode_mode,
            tradeoff_factor=tradeoff_factor,
        )
        model = get_roboflow_model(model_id=model_id, api_key=api_key)
        if watchdog is None:
            watchdog = NullPipelineWatchdog()
        status_update_handlers.append(watchdog.on_status_update)
        video_source = VideoSource.init(
            video_reference=video_reference,
            status_update_handlers=status_update_handlers,
            buffer_filling_strategy=source_buffer_filling_strategy,
            buffer_consumption_strategy=source_buffer_consumption_strategy,
        )
        watchdog.register_video_source(video_source=video_source)
        predictions_queue = Queue(maxsize=PREDICTIONS_QUEUE_SIZE)
        active_learning_middleware = NullActiveLearningMiddleware()
        if active_learning_enabled is None:
            logger.info(
                "`active_learning_enabled` parameter not set - using env `ACTIVE_LEARNING_ENABLED` "
                f"with value: {ACTIVE_LEARNING_ENABLED}"
            )
            active_learning_enabled = ACTIVE_LEARNING_ENABLED
        if api_key is None:
            logger.info(
                "Roboflow API key not given - Active Learning is forced to be disabled."
            )
            active_learning_enabled = False
        if active_learning_enabled is True:
            active_learning_middleware = ThreadingActiveLearningMiddleware.init(
                api_key=api_key,
                model_id=model_id,
                cache=cache,
            )
            al_sink = partial(
                active_learning_sink,
                active_learning_middleware=active_learning_middleware,
                model_type=model.task_type,
                disable_preproc_auto_orient=DISABLE_PREPROC_AUTO_ORIENT,
            )
            logger.info(
                "AL enabled - wrapping `on_prediction` with multi_sink() and active_learning_sink()"
            )
            on_prediction = partial(multi_sink, sinks=[on_prediction, al_sink])
        return cls(
            model=model,
            video_source=video_source,
            on_prediction=on_prediction,
            max_fps=max_fps,
            predictions_queue=predictions_queue,
            watchdog=watchdog,
            status_update_handlers=status_update_handlers,
            inference_config=inference_config,
            active_learning_middleware=active_learning_middleware,
        )

    def __init__(
        self,
        model: OnnxRoboflowInferenceModel,
        video_source: VideoSource,
        on_prediction: Callable[[ObjectDetectionPrediction, VideoFrame], None],
        max_fps: Optional[float],
        predictions_queue: Queue,
        watchdog: PipelineWatchDog,
        status_update_handlers: List[Callable[[StatusUpdate], None]],
        inference_config: ModelConfig,
        active_learning_middleware: Union[
            NullActiveLearningMiddleware, ThreadingActiveLearningMiddleware
        ],
    ):
        self._model = model
        self._video_source = video_source
        self._on_prediction = on_prediction
        self._max_fps = max_fps
        self._predictions_queue = predictions_queue
        self._watchdog = watchdog
        self._command_handler_thread: Optional[Thread] = None
        self._inference_thread: Optional[Thread] = None
        self._dispatching_thread: Optional[Thread] = None
        self._stop = False
        self._camera_restart_ongoing = False
        self._status_update_handlers = status_update_handlers
        self._inference_config = inference_config
        self._active_learning_middleware = active_learning_middleware

    def start(self, use_main_thread: bool = True) -> None:
        self._stop = False
        self._inference_thread = Thread(target=self._execute_inference)
        self._inference_thread.start()
        if self._active_learning_middleware is not None:
            self._active_learning_middleware.start_registration_thread()
        if use_main_thread:
            # Run the results-dispatching loop in the calling thread (blocks until the pipeline ends)
            self._dispatch_inference_results()
        else:
            # Dispatch results from a dedicated thread, leaving the calling thread free
            self._dispatching_thread = Thread(target=self._dispatch_inference_results)
            self._dispatching_thread.start()

    def terminate(self) -> None:
        self._stop = True
        self._video_source.terminate()

    def pause_stream(self) -> None:
        self._video_source.pause()

    def mute_stream(self) -> None:
        self._video_source.mute()

    def resume_stream(self) -> None:
        self._video_source.resume()

    def join(self) -> None:
        if self._inference_thread is not None:
            self._inference_thread.join()
            self._inference_thread = None
        if self._dispatching_thread is not None:
            self._dispatching_thread.join()
            self._dispatching_thread = None
        if self._active_learning_middleware is not None:
            self._active_learning_middleware.stop_registration_thread()

    def _execute_inference(self) -> None:
        send_inference_pipeline_status_update(
            severity=UpdateSeverity.INFO,
            event_type=INFERENCE_THREAD_STARTED_EVENT,
            status_update_handlers=self._status_update_handlers,
        )
        logger.info("Inference thread started")
        try:
            for video_frame in self._generate_frames():
                self._watchdog.on_model_preprocessing_started(
                    frame_timestamp=video_frame.frame_timestamp,
                    frame_id=video_frame.frame_id,
                )
                preprocessed_image, preprocessing_metadata = self._model.preprocess(
                    video_frame.image
                )
                self._watchdog.on_model_inference_started(
                    frame_timestamp=video_frame.frame_timestamp,
                    frame_id=video_frame.frame_id,
                )
                predictions = self._model.predict(preprocessed_image)
                self._watchdog.on_model_postprocessing_started(
                    frame_timestamp=video_frame.frame_timestamp,
                    frame_id=video_frame.frame_id,
                )
                postprocessing_args = self._inference_config.to_postprocessing_params()
                predictions = self._model.postprocess(
                    predictions,
                    preprocessing_metadata,
                    **postprocessing_args,
                )
                if isinstance(predictions, list):
                    # The model may return a batch (list) of responses - take the first element
                    # and serialise it into the standard Roboflow prediction dict
                    predictions = predictions[0].dict(
                        by_alias=True,
                        exclude_none=True,
                    )
                self._watchdog.on_model_prediction_ready(
                    frame_timestamp=video_frame.frame_timestamp,
                    frame_id=video_frame.frame_id,
                )
                self._predictions_queue.put((predictions, video_frame))
                send_inference_pipeline_status_update(
                    severity=UpdateSeverity.DEBUG,
                    event_type=INFERENCE_COMPLETED_EVENT,
                    payload={
                        "frame_id": video_frame.frame_id,
                        "frame_timestamp": video_frame.frame_timestamp,
                    },
                    status_update_handlers=self._status_update_handlers,
                )
        except Exception as error:
            payload = {
                "error_type": error.__class__.__name__,
                "error_message": str(error),
                "error_context": "inference_thread",
            }
            send_inference_pipeline_status_update(
                severity=UpdateSeverity.ERROR,
                event_type=INFERENCE_ERROR_EVENT,
                payload=payload,
                status_update_handlers=self._status_update_handlers,
            )
            logger.exception(f"Encountered inference error: {error}")
        finally:
            # Sentinel value signalling the dispatching loop that no more predictions will arrive
            self._predictions_queue.put(None)
            send_inference_pipeline_status_update(
                severity=UpdateSeverity.INFO,
                event_type=INFERENCE_THREAD_FINISHED_EVENT,
                status_update_handlers=self._status_update_handlers,
            )
            logger.info("Inference thread finished")

    def _dispatch_inference_results(self) -> None:
        while True:
            inference_results: Optional[Tuple[dict, VideoFrame]] = (
                self._predictions_queue.get()
            )
            if inference_results is None:
                # None is the sentinel emitted by the inference thread when it finishes
                self._predictions_queue.task_done()
                break
            predictions, video_frame = inference_results
            try:
                self._on_prediction(predictions, video_frame)
            except Exception as error:
                payload = {
                    "error_type": error.__class__.__name__,
                    "error_message": str(error),
                    "error_context": "inference_results_dispatching",
                }
                send_inference_pipeline_status_update(
                    severity=UpdateSeverity.ERROR,
                    event_type=INFERENCE_RESULTS_DISPATCHING_ERROR_EVENT,
                    payload=payload,
                    status_update_handlers=self._status_update_handlers,
                )
                logger.warning(f"Error in results dispatching - {error}")
            finally:
                self._predictions_queue.task_done()

    def _generate_frames(
        self,
    ) -> Generator[VideoFrame, None, None]:
        self._video_source.start()
        while True:
            source_properties = self._video_source.describe_source().source_properties
            if source_properties is None:
                break
            # Reconnection only makes sense for live sources - video files simply end
            allow_reconnect = not source_properties.is_file
            yield from get_video_frames_generator(
                video=self._video_source, max_fps=self._max_fps
            )
            if not allow_reconnect:
                self.terminate()
                break
            if self._stop:
                break
            logger.warning("Lost connection with video source.")
            send_inference_pipeline_status_update(
                severity=UpdateSeverity.WARNING,
                event_type=SOURCE_CONNECTION_LOST_EVENT,
                payload={
                    "source_reference": self._video_source.describe_source().source_reference
                },
                status_update_handlers=self._status_update_handlers,
            )
            self._attempt_restart()

    def _attempt_restart(self) -> None:
        succeeded = False
        while not self._stop and not succeeded:
            try:
                self._video_source.restart()
                succeeded = True
            except SourceConnectionError as error:
                payload = {
                    "error_type": error.__class__.__name__,
                    "error_message": str(error),
                    "error_context": "video_frames_generator",
                }
                send_inference_pipeline_status_update(
                    severity=UpdateSeverity.WARNING,
                    event_type=SOURCE_CONNECTION_ATTEMPT_FAILED_EVENT,
                    payload=payload,
                    status_update_handlers=self._status_update_handlers,
                )
                logger.warning(
                    f"Could not connect to video source. Retrying in {RESTART_ATTEMPT_DELAY}s..."
                )
                time.sleep(RESTART_ATTEMPT_DELAY)


def send_inference_pipeline_status_update(
    severity: UpdateSeverity,
    event_type: str,
    status_update_handlers: List[Callable[[StatusUpdate], None]],
    payload: Optional[dict] = None,
    sub_context: Optional[str] = None,
) -> None:
    if payload is None:
        payload = {}
    context = INFERENCE_PIPELINE_CONTEXT
    if sub_context is not None:
        context = f"{context}.{sub_context}"
    status_update = StatusUpdate(
        timestamp=datetime.now(),
        severity=severity,
        event_type=event_type,
        payload=payload,
        context=context,
    )
    for handler in status_update_handlers:
        try:
            handler(status_update)
        except Exception as error:
            logger.warning(f"Could not execute status update handler. Cause: {error}")