Spaces:
Runtime error
Runtime error
import queue | |
from queue import Queue | |
from threading import Thread | |
from typing import Any, List, Optional | |
from inference.core import logger | |
from inference.core.active_learning.accounting import image_can_be_submitted_to_batch | |
from inference.core.active_learning.batching import generate_batch_name | |
from inference.core.active_learning.configuration import ( | |
prepare_active_learning_configuration, | |
prepare_active_learning_configuration_inplace, | |
) | |
from inference.core.active_learning.core import ( | |
execute_datapoint_registration, | |
execute_sampling, | |
) | |
from inference.core.active_learning.entities import ( | |
ActiveLearningConfiguration, | |
Prediction, | |
PredictionType, | |
) | |
from inference.core.cache.base import BaseCache | |
from inference.core.utils.image_utils import load_image | |
# Default capacity of the task queue created by
# ThreadingActiveLearningMiddleware.init / init_from_config; once this many
# registration tasks are pending, further datapoints are dropped.
MAX_REGISTRATION_QUEUE_SIZE: int = 512
class NullActiveLearningMiddleware:
    """Do-nothing counterpart of the Active Learning middlewares.

    Mirrors the registration, thread-lifecycle and context-manager methods of
    the other middlewares, but every method ignores its arguments and returns
    without side effects. Presumably substituted when Active Learning is
    disabled — confirm against the call sites.
    """

    def register_batch(
        self,
        inference_inputs: List[Any],
        predictions: List[Prediction],
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        """Discard a batch of datapoints."""
        return None

    def register(
        self,
        inference_input: Any,
        prediction: dict,
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        """Discard a single datapoint."""
        return None

    def start_registration_thread(self) -> None:
        """No thread exists; nothing to start."""
        return None

    def stop_registration_thread(self) -> None:
        """No thread exists; nothing to stop."""
        return None

    def __enter__(self) -> "NullActiveLearningMiddleware":
        """Enter the context and hand back this same no-op instance."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Leave the context; returning None lets any in-flight exception propagate."""
        return None
class ActiveLearningMiddleware:
    """Active Learning middleware that registers datapoints synchronously.

    Every call to ``register`` (and each item of ``register_batch``) is handled
    in the calling thread by ``_execute_registration``: the image is loaded,
    sampling methods from the configuration are evaluated, the batch-size limit
    is checked, and matching datapoints are registered.
    """

    @classmethod
    def init(
        cls, api_key: str, model_id: str, cache: BaseCache
    ) -> "ActiveLearningMiddleware":
        """Build a middleware whose configuration is prepared for ``model_id``.

        Args:
            api_key: API key forwarded to configuration preparation and kept
                for registration.
            model_id: Model identifier the configuration is prepared for.
            cache: Cache forwarded to configuration preparation and kept for
                registration.

        Returns:
            A new instance of ``cls``.
        """
        configuration = prepare_active_learning_configuration(
            api_key=api_key,
            model_id=model_id,
            cache=cache,
        )
        return cls(
            api_key=api_key,
            configuration=configuration,
            cache=cache,
        )

    @classmethod
    def init_from_config(
        cls, api_key: str, model_id: str, cache: BaseCache, config: Optional[dict]
    ) -> "ActiveLearningMiddleware":
        """Build a middleware from an explicitly supplied configuration dict.

        Args:
            api_key: API key forwarded to configuration preparation and kept
                for registration.
            model_id: Model identifier the configuration is prepared for.
            cache: Cache kept for registration (not used while preparing the
                configuration).
            config: Raw Active Learning configuration passed to
                ``prepare_active_learning_configuration_inplace``.

        Returns:
            A new instance of ``cls``.
        """
        configuration = prepare_active_learning_configuration_inplace(
            api_key=api_key,
            model_id=model_id,
            active_learning_configuration=config,
        )
        return cls(
            api_key=api_key,
            configuration=configuration,
            cache=cache,
        )

    def __init__(
        self,
        api_key: str,
        configuration: Optional[ActiveLearningConfiguration],
        cache: BaseCache,
    ):
        """Store the collaborators used during registration.

        Args:
            api_key: API key used for batch-limit checks and registration.
            configuration: Active Learning configuration; when ``None``,
                ``_execute_registration`` returns immediately, so nothing is
                registered.
            cache: Cache passed to ``execute_datapoint_registration``.
        """
        self._api_key = api_key
        self._configuration = configuration
        self._cache = cache

    def register_batch(
        self,
        inference_inputs: List[Any],
        predictions: List[Prediction],
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        """Register each (input, prediction) pair through ``self.register``.

        Pairs are formed with ``zip``, so if the two lists differ in length
        the surplus items of the longer one are ignored. Because this calls
        ``self.register``, subclasses that override ``register`` change how
        batches are handled too.
        """
        for inference_input, prediction in zip(inference_inputs, predictions):
            self.register(
                inference_input=inference_input,
                prediction=prediction,
                prediction_type=prediction_type,
                disable_preproc_auto_orient=disable_preproc_auto_orient,
            )

    def register(
        self,
        inference_input: Any,
        prediction: dict,
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        """Register one datapoint synchronously in the calling thread."""
        self._execute_registration(
            inference_input=inference_input,
            prediction=prediction,
            prediction_type=prediction_type,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
        )

    def _execute_registration(
        self,
        inference_input: Any,
        prediction: dict,
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        """Sample a datapoint and, if selected and within limits, register it.

        Returns early (registering nothing) when there is no configuration,
        when no sampling method matches, or when the target batch is full.
        """
        if self._configuration is None:
            return None
        image, is_bgr = load_image(
            value=inference_input,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
        )
        if not is_bgr:
            # Reverse the third axis so channel order becomes BGR; the is_bgr
            # flag suggests the array is HxWxC — confirm against load_image.
            image = image[:, :, ::-1]
        matching_strategies = execute_sampling(
            image=image,
            prediction=prediction,
            prediction_type=prediction_type,
            sampling_methods=self._configuration.sampling_methods,
        )
        if len(matching_strategies) == 0:
            return None
        batch_name = generate_batch_name(configuration=self._configuration)
        if not image_can_be_submitted_to_batch(
            batch_name=batch_name,
            workspace_id=self._configuration.workspace_id,
            dataset_id=self._configuration.dataset_id,
            max_batch_images=self._configuration.max_batch_images,
            api_key=self._api_key,
        ):
            logger.debug("Limit on Active Learning batch size reached.")
            return None
        execute_datapoint_registration(
            cache=self._cache,
            matching_strategies=matching_strategies,
            image=image,
            prediction=prediction,
            prediction_type=prediction_type,
            configuration=self._configuration,
            api_key=self._api_key,
            batch_name=batch_name,
        )
class ThreadingActiveLearningMiddleware(ActiveLearningMiddleware):
    """Active Learning middleware that registers datapoints on a background thread.

    ``register`` (and therefore the inherited ``register_batch``) only enqueues
    work, dropping it when the bounded queue is full. A consumer thread started
    by ``start_registration_thread`` drains the queue and runs the inherited
    ``_execute_registration`` for each task. Used as a context manager, the
    instance starts the thread on entry and stops it on exit.
    """

    @classmethod
    def init(
        cls,
        api_key: str,
        model_id: str,
        cache: BaseCache,
        max_queue_size: int = MAX_REGISTRATION_QUEUE_SIZE,
    ) -> "ThreadingActiveLearningMiddleware":
        """Build a threaded middleware with configuration prepared for ``model_id``.

        Args:
            api_key: API key forwarded to configuration preparation and kept
                for registration.
            model_id: Model identifier the configuration is prepared for.
            cache: Cache forwarded to configuration preparation and kept for
                registration.
            max_queue_size: Capacity of the task queue, passed to
                ``queue.Queue`` (a value <= 0 means unbounded).

        Returns:
            A new instance of ``cls``; the registration thread is not started.
        """
        configuration = prepare_active_learning_configuration(
            api_key=api_key,
            model_id=model_id,
            cache=cache,
        )
        task_queue = Queue(max_queue_size)
        return cls(
            api_key=api_key,
            configuration=configuration,
            cache=cache,
            task_queue=task_queue,
        )

    @classmethod
    def init_from_config(
        cls,
        api_key: str,
        model_id: str,
        cache: BaseCache,
        config: Optional[dict],
        max_queue_size: int = MAX_REGISTRATION_QUEUE_SIZE,
    ) -> "ThreadingActiveLearningMiddleware":
        """Build a threaded middleware from an explicitly supplied configuration dict.

        Args:
            api_key: API key forwarded to configuration preparation and kept
                for registration.
            model_id: Model identifier the configuration is prepared for.
            cache: Cache kept for registration (not used while preparing the
                configuration).
            config: Raw Active Learning configuration passed to
                ``prepare_active_learning_configuration_inplace``.
            max_queue_size: Capacity of the task queue, passed to
                ``queue.Queue`` (a value <= 0 means unbounded).

        Returns:
            A new instance of ``cls``; the registration thread is not started.
        """
        configuration = prepare_active_learning_configuration_inplace(
            api_key=api_key,
            model_id=model_id,
            active_learning_configuration=config,
        )
        task_queue = Queue(max_queue_size)
        return cls(
            api_key=api_key,
            configuration=configuration,
            cache=cache,
            task_queue=task_queue,
        )

    def __init__(
        self,
        api_key: str,
        configuration: Optional[ActiveLearningConfiguration],
        cache: BaseCache,
        task_queue: Queue,
    ):
        """Store collaborators and the task queue; no thread is started here.

        Args:
            api_key: API key used during registration.
            configuration: Active Learning configuration; ``None`` makes every
                registration a no-op (handled by ``_execute_registration``).
            cache: Cache used during registration.
            task_queue: Queue shared between ``register`` (producer) and the
                registration thread (consumer).
        """
        super().__init__(api_key=api_key, configuration=configuration, cache=cache)
        self._task_queue = task_queue
        self._registration_thread: Optional[Thread] = None

    def register(
        self,
        inference_input: Any,
        prediction: dict,
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        """Enqueue one datapoint for asynchronous registration.

        Never blocks. If the queue is full, the datapoint is dropped and a
        warning is logged. Queued tasks are only processed while the
        registration thread is running.
        """
        logger.debug("Putting registration task into queue")
        try:
            self._task_queue.put_nowait(
                (
                    inference_input,
                    prediction,
                    prediction_type,
                    disable_preproc_auto_orient,
                )
            )
        except queue.Full:
            # Shed load rather than stall the caller when the consumer lags.
            logger.warning(
                "Dropping datapoint registered in Active Learning due to insufficient processing "
                "capabilities."
            )

    def start_registration_thread(self) -> None:
        """Start the consumer thread; warn and return if it is already started.

        NOTE: the thread is created without ``daemon=True``, so when started
        from a non-daemon thread it is non-daemon and interpreter shutdown
        waits for it. Call ``stop_registration_thread`` (or use the instance
        as a context manager) to end it.
        """
        if self._registration_thread is not None:
            logger.warning("Registration thread already started.")
            return None
        logger.debug("Starting registration thread")
        self._registration_thread = Thread(target=self._consume_queue)
        self._registration_thread.start()

    def stop_registration_thread(self) -> None:
        """Signal the consumer thread to finish and wait for it.

        A ``None`` sentinel is enqueued behind any pending tasks, so tasks
        already in the queue are processed before the thread exits. The
        blocking ``put`` waits while the queue is full.
        """
        if self._registration_thread is None:
            logger.warning("Registration thread is already stopped.")
            return None
        logger.debug("Stopping registration thread")
        self._task_queue.put(None)
        self._registration_thread.join()
        # join() has no timeout, so the thread should have finished by now;
        # this is a defensive check only.
        if self._registration_thread.is_alive():
            logger.warning("Registration thread stopping was unsuccessful.")
        self._registration_thread = None

    def _consume_queue(self) -> None:
        """Thread target: process tasks until the ``None`` sentinel arrives."""
        queue_closed = False
        while not queue_closed:
            queue_closed = self._consume_queue_task()

    def _consume_queue_task(self) -> bool:
        """Process one queued task.

        Returns:
            True when the ``None`` stop sentinel was received, else False.
        """
        logger.debug("Consuming registration task")
        task = self._task_queue.get()
        logger.debug("Received registration task")
        try:
            if task is None:
                logger.debug("Terminating registration thread")
                return True
            inference_input, prediction, prediction_type, disable_preproc_auto_orient = task
            try:
                self._execute_registration(
                    inference_input=inference_input,
                    prediction=prediction,
                    prediction_type=prediction_type,
                    disable_preproc_auto_orient=disable_preproc_auto_orient,
                )
            except Exception as error:
                # NOTE: error handling policy is still open; errors are logged
                # and suppressed so one bad datapoint does not kill the thread.
                logger.warning(
                    f"Error in datapoint registration for Active Learning. Details: {error}. "
                    f"Error is suppressed in favour of normal operations of registration thread."
                )
            return False
        finally:
            # Acknowledge every item taken with get(), even if something escaped
            # above, so Queue.join() is never left waiting on it.
            self._task_queue.task_done()

    def __enter__(self) -> "ThreadingActiveLearningMiddleware":
        """Start the registration thread and return this instance."""
        self.start_registration_thread()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Stop the registration thread; exceptions are not suppressed."""
        self.stop_registration_thread()