|
import random |
|
from functools import partial |
|
from typing import Any, Dict |
|
|
|
import numpy as np |
|
|
|
from inference.core.active_learning.entities import ( |
|
Prediction, |
|
PredictionType, |
|
SamplingMethod, |
|
) |
|
from inference.core.exceptions import ActiveLearningConfigurationError |
|
|
|
|
|
def initialize_random_sampling(strategy_config: Dict[str, Any]) -> SamplingMethod:
    """Build the random-sampling `SamplingMethod` from a strategy configuration.

    Args:
        strategy_config: Strategy settings. The keys read here are
            ``"traffic_percentage"`` (bound into the sampling callable) and
            ``"name"`` (used as the name of the returned method).

    Returns:
        A `SamplingMethod` whose ``sample`` callable is `sample_randomly`
        with ``traffic_percentage`` pre-bound.

    Raises:
        ActiveLearningConfigurationError: When a required key is missing
            from ``strategy_config``.
    """
    try:
        # Looked up first, so when both keys are absent this is the one reported.
        traffic_percentage = strategy_config["traffic_percentage"]
        sampler = partial(sample_randomly, traffic_percentage=traffic_percentage)
        method_name = strategy_config["name"]
        return SamplingMethod(name=method_name, sample=sampler)
    except KeyError as error:
        # Surface missing keys as a configuration problem, keeping the cause chained.
        message = (
            f"In configuration of `random_sampling` missing key detected: {error}."
        )
        raise ActiveLearningConfigurationError(message) from error
|
|
|
|
|
def sample_randomly(
    image: np.ndarray,
    prediction: Prediction,
    prediction_type: PredictionType,
    traffic_percentage: float,
) -> bool:
    """Decide at random whether the current datapoint should be sampled.

    Args:
        image: Not used by this strategy.
        prediction: Not used by this strategy.
        prediction_type: Not used by this strategy.
        traffic_percentage: Threshold compared directly against a uniform
            draw from ``[0.0, 1.0)``. Despite the name it therefore acts as
            a fraction: ``0.0`` never samples, ``1.0`` or more always samples.

    Returns:
        ``True`` when the random draw is strictly below ``traffic_percentage``.
    """
    draw = random.random()
    return draw < traffic_percentage
|
|