|
from dataclasses import dataclass |
|
from enum import Enum |
|
from typing import Any, Callable, Dict, List, Optional, Tuple |
|
|
|
import numpy as np |
|
|
|
from inference.core.entities.types import DatasetID, WorkspaceID |
|
from inference.core.exceptions import ActiveLearningConfigurationDecodingError |
|
|
|
# Plain type aliases used throughout the Active Learning entities. They are
# simple rebindings of builtins (not NewType), so no runtime checks apply.
LocalImageIdentifier = str
PredictionType = str
SerialisedPrediction = str
PredictionFileType = str
Prediction = dict
|
|
|
|
|
@dataclass(frozen=True)
class ImageDimensions:
    """Immutable pair of image dimensions stored as ``height`` and ``width``."""

    height: int
    width: int

    def to_hw(self) -> Tuple[int, int]:
        """Return the dimensions ordered as ``(height, width)``."""
        return (self.height, self.width)

    def to_wh(self) -> Tuple[int, int]:
        """Return the dimensions ordered as ``(width, height)``."""
        return (self.width, self.height)
|
|
|
|
|
@dataclass(frozen=True)
class SamplingMethod:
    """Immutable pairing of a sampling method's name with its predicate.

    ``sample`` is a callable taking an image array, a prediction and a
    prediction type, and returning a bool. Presumably ``True`` means the
    datapoint should be sampled -- confirm against callers.
    """

    # Identifier of this sampling method.
    name: str
    # Predicate invoked as sample(image, prediction, prediction_type) -> bool.
    sample: Callable[[np.ndarray, Prediction, PredictionType], bool]
|
|
|
|
|
class BatchReCreationInterval(Enum):
    """How often a new batch is created.

    Values are the raw strings expected in the ``recreation_interval`` field
    of the batching strategy payload.
    """

    NEVER = "never"
    DAILY = "daily"
    WEEKLY = "weekly"
    MONTHLY = "monthly"
|
|
|
|
|
class StrategyLimitType(Enum):
    """Kind of limit attached to a sampling strategy.

    Values are the raw strings expected in a limit specification's ``type``
    field. Presumably each names the time window the limit covers -- confirm.
    """

    MINUTELY = "minutely"
    HOURLY = "hourly"
    DAILY = "daily"
|
|
|
|
|
@dataclass(frozen=True)
class StrategyLimit:
    """Immutable limit attached to a sampling strategy: a kind plus a value."""

    # Which kind of limit this is.
    limit_type: StrategyLimitType
    # Limit value, copied verbatim from the specification (not coerced to int).
    value: int

    @classmethod
    def from_dict(cls, specification: dict) -> "StrategyLimit":
        """Decode a limit from a ``{"type": ..., "value": ...}`` mapping.

        Raises:
            KeyError: if ``type`` or ``value`` is missing.
            ValueError: if ``type`` is not a known ``StrategyLimitType`` value.
        """
        # "type" is decoded before "value" is read, matching the order in
        # which errors were historically reported.
        decoded_type = StrategyLimitType(specification["type"])
        raw_value = specification["value"]
        return cls(limit_type=decoded_type, value=raw_value)
|
|
|
|
|
@dataclass(frozen=True)
class ActiveLearningConfiguration:
    """Decoded Active Learning settings for a workspace/dataset/model.

    Instances are typically built with :meth:`init` from a raw configuration
    dict (presumably fetched from the Roboflow API -- confirm at call site).
    """

    # Optional cap on stored image size; None means no cap was configured.
    max_image_size: Optional[ImageDimensions]
    # Defaults to 95 when absent from the payload.
    jpeg_compression_level: int
    persist_predictions: bool
    sampling_methods: List[SamplingMethod]
    batches_name_prefix: str
    batch_recreation_interval: BatchReCreationInterval
    # None when the batching strategy does not specify a maximum.
    max_batch_images: Optional[int]
    workspace_id: WorkspaceID
    dataset_id: DatasetID
    model_id: str
    # Strategy name -> limits declared for that strategy (may be empty).
    strategies_limits: Dict[str, List[StrategyLimit]]
    # Global tags; empty list when absent.
    tags: List[str]
    # Strategy name -> tags declared for that strategy (may be empty).
    strategies_tags: Dict[str, List[str]]

    @classmethod
    def init(
        cls,
        roboflow_api_configuration: Dict[str, Any],
        sampling_methods: List[SamplingMethod],
        workspace_id: WorkspaceID,
        dataset_id: DatasetID,
        model_id: str,
    ) -> "ActiveLearningConfiguration":
        """Decode a raw configuration payload into a configuration object.

        Args:
            roboflow_api_configuration: Raw configuration mapping.
                Required keys:
                  - ``persist_predictions``
                  - ``sampling_strategies``: iterable of mappings, each with
                    ``name`` and optional ``limits`` / ``tags``
                  - ``batching_strategy``: mapping with
                    ``batches_name_prefix``, ``recreation_interval`` and an
                    optional ``max_batch_images``
                Optional keys:
                  - ``max_image_size``: ``[height, width]``
                  - ``jpeg_compression_level``: defaults to 95
                  - ``tags``
            sampling_methods: Stored as-is on the result.
            workspace_id: Stored as-is on the result.
            dataset_id: Stored as-is on the result.
            model_id: Stored as-is on the result.

        Returns:
            The decoded ``ActiveLearningConfiguration``.

        Raises:
            ActiveLearningConfigurationDecodingError: if a required key is
                missing, an enum value is unknown, or a value has the wrong
                shape or type (e.g. a one-element ``max_image_size``, or
                ``None`` where a mapping or list is expected).
        """
        try:
            max_image_size = roboflow_api_configuration.get("max_image_size")
            if max_image_size is not None:
                # Index the already-fetched value rather than the dict again.
                max_image_size = ImageDimensions(
                    height=max_image_size[0],
                    width=max_image_size[1],
                )
            strategies_limits: Dict[str, List[StrategyLimit]] = {}
            strategies_tags: Dict[str, List[str]] = {}
            # Single pass over the strategies. As before, a repeated strategy
            # name silently overwrites the earlier entry.
            for strategy in roboflow_api_configuration["sampling_strategies"]:
                strategy_name = strategy["name"]
                strategies_limits[strategy_name] = [
                    StrategyLimit.from_dict(specification=specification)
                    for specification in strategy.get("limits", [])
                ]
                strategies_tags[strategy_name] = strategy.get("tags", [])
            jpeg_compression_level = roboflow_api_configuration.get(
                "jpeg_compression_level", 95
            )
            # Look up persist_predictions before batching_strategy so the
            # first missing key reported matches historical behaviour.
            persist_predictions = roboflow_api_configuration["persist_predictions"]
            batching_strategy = roboflow_api_configuration["batching_strategy"]
            return cls(
                max_image_size=max_image_size,
                jpeg_compression_level=jpeg_compression_level,
                persist_predictions=persist_predictions,
                sampling_methods=sampling_methods,
                batches_name_prefix=batching_strategy["batches_name_prefix"],
                batch_recreation_interval=BatchReCreationInterval(
                    batching_strategy["recreation_interval"]
                ),
                max_batch_images=batching_strategy.get("max_batch_images"),
                workspace_id=workspace_id,
                dataset_id=dataset_id,
                model_id=model_id,
                strategies_limits=strategies_limits,
                tags=roboflow_api_configuration.get("tags", []),
                strategies_tags=strategies_tags,
            )
        # IndexError / TypeError cover malformed shapes (e.g. max_image_size
        # of length < 2, or None in place of a mapping/list). Without them,
        # these would escape as raw built-in errors instead of the decoding
        # error callers expect.
        except (KeyError, ValueError, IndexError, TypeError) as e:
            raise ActiveLearningConfigurationDecodingError(
                f"Failed to initialise Active Learning configuration. Cause: {str(e)}"
            ) from e
|
|
|
|
|
@dataclass(frozen=True)
class RoboflowProjectMetadata:
    """Immutable bundle of project identifiers plus its raw AL configuration."""

    dataset_id: DatasetID
    version_id: str
    workspace_id: WorkspaceID
    dataset_type: str
    # Undecoded configuration mapping. Presumably this is the payload later
    # passed to ActiveLearningConfiguration.init -- confirm against callers.
    active_learning_configuration: dict
|
|