|
from typing import Any, List, Optional, Set, Type |
|
|
|
from pydantic import ValidationError |
|
|
|
from inference.core.entities.requests.inference import InferenceRequestImage |
|
from inference.enterprise.workflows.entities.base import GraphNone |
|
from inference.enterprise.workflows.errors import ( |
|
InvalidStepInputDetected, |
|
VariableTypeError, |
|
) |
|
|
|
# Step types accepted as the source of a selector bound to an image field
# (checked by `validate_selector_holds_image`).
STEPS_WITH_IMAGE = {"InferenceImage", "Crop", "AbsoluteStaticCrop", "RelativeStaticCrop"}
|
|
|
|
|
def validate_image_is_valid_selector(value: Any, field_name: str = "image") -> None:
    """Ensure an image field holds a selector, or a list made only of selectors.

    Args:
        value: A single candidate selector, or a list of candidates. An empty
            list is accepted.
        field_name: Field name used in the error message.

    Raises:
        ValueError: If ``value`` (or any element of a list ``value``) is not a
            selector.
    """
    candidates = value if issubclass(type(value), list) else [value]
    if not all(is_selector(selector_or_value=candidate) for candidate in candidates):
        raise ValueError(f"`{field_name}` field can only contain selector values")
|
|
|
|
|
def validate_field_is_in_range_zero_one_or_empty_or_selector(
    value: Any, field_name: str = "confidence"
) -> None:
    """Accept a selector, ``None``, or a number within [0, 1].

    Selectors and ``None`` are accepted without further checks. Any other value
    is delegated to ``validate_value_is_empty_or_number_in_range_zero_one``.

    Args:
        value: Candidate value.
        field_name: Field name used in error messages.

    Raises:
        ValueError: Propagated from the numeric type/range check.
    """
    needs_numeric_check = (
        not is_selector(selector_or_value=value) and value is not None
    )
    if needs_numeric_check:
        validate_value_is_empty_or_number_in_range_zero_one(
            value=value, field_name=field_name
        )
    return None
|
|
|
|
|
def validate_value_is_empty_or_number_in_range_zero_one(
    value: Any, field_name: str = "confidence", error: Type[Exception] = ValueError
) -> None:
    """Check that ``value`` is ``None`` or an int/float in the closed range [0, 1].

    Args:
        value: Candidate value.
        field_name: Field name used in error messages.
        error: Exception class raised on failure.

    Raises:
        error: If ``value`` has a disallowed type or lies outside [0, 1].
    """
    validate_field_has_given_type(
        field_name=field_name,
        allowed_types=[type(None), int, float],
        value=value,
        error=error,
    )
    # The chained comparison evaluates to False for NaN, so NaN is rejected too.
    if value is not None and not 0 <= value <= 1:
        raise error(f"Parameter `{field_name}` must be in range [0.0, 1.0]")
|
|
|
|
|
def validate_value_is_empty_or_selector_or_positive_number(
    value: Any, field_name: str
) -> None:
    """Accept a selector; otherwise require ``None`` or a positive int/float.

    Args:
        value: Candidate value.
        field_name: Field name used in error messages.

    Raises:
        ValueError: Propagated from ``validate_value_is_empty_or_positive_number``.
    """
    if not is_selector(selector_or_value=value):
        validate_value_is_empty_or_positive_number(value=value, field_name=field_name)
|
|
|
|
|
def validate_value_is_empty_or_positive_number(
    value: Any, field_name: str, error: Type[Exception] = ValueError
) -> None:
    """Check that ``value`` is ``None`` or a strictly positive int/float.

    Args:
        value: Candidate value.
        field_name: Field name used in error messages.
        error: Exception class raised on failure.

    Raises:
        error: If ``value`` has a disallowed type, or is not strictly greater
            than zero (NaN included).
    """
    validate_field_has_given_type(
        field_name=field_name,
        allowed_types=[type(None), int, float],
        value=value,
        error=error,
    )
    if value is None:
        return None
    # `not value > 0` rather than `value <= 0`: every ordering comparison with
    # NaN is False, so `value <= 0` would let float("nan") pass as "positive".
    if not value > 0:
        raise error(f"Parameter `{field_name}` must be positive (> 0)")
    return None
|
|
|
|
|
def validate_field_is_list_of_selectors(
    value: Any, field_name: str, error: Type[Exception] = ValueError
) -> None:
    """Require ``value`` to be a list whose every element is a selector.

    Args:
        value: Candidate value. An empty list is accepted.
        field_name: Field name used in error messages.
        error: Exception class raised on failure.

    Raises:
        error: If ``value`` is not a list, or any element is not a selector.
    """
    if issubclass(type(value), list):
        if all(is_selector(selector_or_value=item) for item in value):
            return None
        raise error(f"Parameter `{field_name}` must be a list of selectors")
    raise error(f"`{field_name}` field must be list")
|
|
|
|
|
def validate_field_is_empty_or_selector_or_list_of_string(
    value: Any, field_name: str
) -> None:
    """Accept ``None``, a selector, or a list of strings.

    Always returns ``None``, as the annotation declares and as the other
    validators in this module do.

    Args:
        value: Candidate value.
        field_name: Field name used in error messages.

    Raises:
        ValueError: If ``value`` is none of the accepted forms (raised by
            ``validate_field_is_list_of_string``).
    """
    if is_selector(selector_or_value=value) or value is None:
        # Selectors are resolved later and None means "not set"; nothing to check.
        return None
    validate_field_is_list_of_string(value=value, field_name=field_name)
    return None
|
|
|
|
|
def validate_field_is_list_of_string(
    value: Any, field_name: str, error: Type[Exception] = ValueError
) -> None:
    """Require ``value`` to be a list whose every element is a ``str``.

    Args:
        value: Candidate value. An empty list is accepted.
        field_name: Field name used in error messages.
        error: Exception class raised on failure.

    Raises:
        error: If ``value`` is not a list, or any element is not a string.
    """
    is_list = issubclass(type(value), list)
    if not is_list:
        raise error(f"`{field_name}` field must be list")
    only_strings = all(issubclass(type(item), str) for item in value)
    if not only_strings:
        raise error(f"Parameter `{field_name}` must be a list of string")
|
|
|
|
|
def validate_field_is_selector_or_one_of_values(
    value: Any, field_name: str, selected_values: set
) -> None:
    """Accept ``None``, a selector, or one of ``selected_values``.

    Always returns ``None``, as the annotation declares and as the other
    validators in this module do.

    Args:
        value: Candidate value.
        field_name: Field name used in error messages.
        selected_values: Container of accepted literal values.

    Raises:
        ValueError: If ``value`` is not accepted (raised by
            ``validate_field_is_one_of_selected_values``).
    """
    if is_selector(selector_or_value=value) or value is None:
        # Selectors are resolved later and None means "not set"; nothing to check.
        return None
    validate_field_is_one_of_selected_values(
        value=value, field_name=field_name, selected_values=selected_values
    )
    return None
|
|
|
|
|
def validate_field_is_one_of_selected_values(
    value: Any,
    field_name: str,
    selected_values: set,
    error: Type[Exception] = ValueError,
) -> None:
    """Require ``value`` to be a member of ``selected_values``.

    Args:
        value: Candidate value.
        field_name: Field name used in the error message.
        selected_values: Container of accepted values.
        error: Exception class raised on failure.

    Raises:
        error: If ``value`` is not in ``selected_values``. This includes the case
            where ``value`` is unhashable and so cannot be a member of a set.
    """
    try:
        is_allowed = value in selected_values
    except TypeError:
        # Membership tests against a set raise TypeError for unhashable values
        # (e.g. a list). Report that as an ordinary validation failure rather
        # than leaking a TypeError that callers do not expect from a validator.
        is_allowed = False
    if not is_allowed:
        raise error(
            f"Value of field `{field_name}` must be in {selected_values}. Found: {value}"
        )
|
|
|
|
|
def validate_field_is_selector_or_has_given_type(
    value: Any, field_name: str, allowed_types: List[type]
) -> None:
    """Accept a selector; otherwise require ``value``'s type to match ``allowed_types``.

    Args:
        value: Candidate value.
        field_name: Field name used in error messages.
        allowed_types: Types (or base types) a non-selector value may have.

    Raises:
        ValueError: Raised by ``validate_field_has_given_type`` when the type
            check fails.
    """
    if not is_selector(selector_or_value=value):
        validate_field_has_given_type(
            value=value,
            field_name=field_name,
            allowed_types=allowed_types,
        )
    return None
|
|
|
|
|
def validate_field_has_given_type(
    value: Any,
    field_name: str,
    allowed_types: List[type],
    error: Type[Exception] = ValueError,
) -> None:
    """Require ``type(value)`` to be one of ``allowed_types`` or a subclass of one.

    List ``type(None)`` to permit ``None``. Because ``bool`` subclasses ``int``,
    booleans pass whenever ``int`` is allowed.

    Args:
        value: Candidate value.
        field_name: Field name used in the error message.
        allowed_types: Accepted types.
        error: Exception class raised on failure.

    Raises:
        error: If the type of ``value`` matches none of ``allowed_types``.
    """
    value_type = type(value)
    if any(issubclass(value_type, candidate) for candidate in allowed_types):
        return None
    raise error(
        f"`{field_name}` field type must be one of {allowed_types}. Detected: {value}"
    )
|
|
|
|
|
def validate_image_biding(value: Any, field_name: str = "image") -> None:
    """Check that an image binding is compatible with ``InferenceRequestImage``.

    ``value`` may be a single payload or a list of payloads. Each payload is
    passed to ``InferenceRequestImage.model_validate``.

    Args:
        value: Image payload, or list of image payloads.
        field_name: Field name used in the error message.

    Raises:
        VariableTypeError: If any payload fails validation. The underlying
            ``ValueError``/``ValidationError`` is chained as the cause.
    """
    payloads = value if issubclass(type(value), list) else [value]
    try:
        for payload in payloads:
            InferenceRequestImage.model_validate(payload)
    except (ValueError, ValidationError) as error:
        raise VariableTypeError(
            f"Parameter `{field_name}` must be compatible with `InferenceRequestImage`"
        ) from error
|
|
|
|
|
def validate_selector_is_inference_parameter(
    step_type: str,
    field_name: str,
    input_step: GraphNone,
    applicable_fields: Set[str],
) -> None:
    """Require the step feeding ``field_name`` to be of type ``InferenceParameter``.

    Fields outside ``applicable_fields`` are not checked.

    Args:
        step_type: Type of the step being validated (used in the message).
        field_name: Name of the field whose selector is being validated.
        input_step: Graph node the selector points at.
        applicable_fields: Field names this rule applies to.

    Raises:
        InvalidStepInputDetected: If ``field_name`` is applicable and
            ``input_step.get_type()`` is not ``"InferenceParameter"``.
    """
    if field_name in applicable_fields:
        source_type = input_step.get_type()
        if source_type not in {"InferenceParameter"}:
            raise InvalidStepInputDetected(
                f"Field {field_name} of step {step_type} comes from invalid input type: {source_type}. "
                f"Expected: `InferenceParameter`"
            )
    return None
|
|
|
|
|
def validate_selector_holds_image(
    step_type: str,
    field_name: str,
    input_step: GraphNone,
    applicable_fields: Optional[Set[str]] = None,
) -> None:
    """Require the step feeding ``field_name`` to be one of ``STEPS_WITH_IMAGE``.

    Args:
        step_type: Type of the step being validated (used in the message).
        field_name: Name of the field whose selector is being validated.
        input_step: Graph node the selector points at.
        applicable_fields: Field names this rule applies to. Defaults to
            ``{"image"}``. Other fields are not checked.

    Raises:
        InvalidStepInputDetected: If ``field_name`` is applicable and the
            type of ``input_step`` is not in ``STEPS_WITH_IMAGE``.
    """
    fields_to_check = {"image"} if applicable_fields is None else applicable_fields
    if field_name not in fields_to_check:
        return None
    source_type = input_step.get_type()
    if source_type in STEPS_WITH_IMAGE:
        return None
    raise InvalidStepInputDetected(
        f"Field {field_name} of step {step_type} comes from invalid input type: {source_type}. "
        f"Expected: {STEPS_WITH_IMAGE}"
    )
|
|
|
|
|
def validate_selector_holds_detections(
    step_name: str,
    image_selector: Optional[str],
    detections_selector: str,
    field_name: str,
    input_step: GraphNone,
    applicable_fields: Optional[Set[str]] = None,
) -> None:
    """Validate that a detections selector points at a detection-based step output.

    The checks run in order:

    1. ``field_name`` must be in ``applicable_fields`` (default
       ``{"detections"}``); otherwise nothing is validated.
    2. ``input_step`` must be one of the detection-based step types listed below.
    3. ``detections_selector`` must end with the ``predictions`` output.
    4. If ``image_selector`` is given and ``input_step`` has an ``image``
       attribute, the two must be equal.

    Args:
        step_name: Name of the step being validated (used in messages).
        image_selector: Image selector of the step being validated, if any.
        detections_selector: Selector pointing at the upstream detections.
        field_name: Name of the field whose selector is being validated.
        input_step: Graph node the detections selector points at.
        applicable_fields: Field names this rule applies to.

    Raises:
        InvalidStepInputDetected: If any of checks 2-4 fails.
    """
    if applicable_fields is None:
        applicable_fields = {"detections"}
    if field_name not in applicable_fields:
        return None
    if input_step.get_type() not in {
        "ObjectDetectionModel",
        "KeypointsDetectionModel",
        "InstanceSegmentationModel",
        "DetectionFilter",
        "DetectionsConsensus",
        "DetectionOffset",
        "YoloWorld",
    }:
        raise InvalidStepInputDetected(
            f"Step with name {step_name} cannot take as an input predictions from {input_step.get_type()}. "
            f"Step requires detection-based output."
        )
    if get_last_selector_chunk(detections_selector) != "predictions":
        raise InvalidStepInputDetected(
            f"Step with name {step_name} must take as input step output of name `predictions`"
        )
    if not hasattr(input_step, "image") or image_selector is None:
        # Nothing to compare: the upstream step exposes no image reference, or
        # the step being validated was not given an image selector.
        return None
    input_step_image_reference = input_step.image
    if image_selector != input_step_image_reference:
        raise InvalidStepInputDetected(
            f"Step with name {step_name} was given detections reference that is bound to different image: "
            f"step.image: {image_selector}, detections step image: {input_step_image_reference}"
        )
|
|
|
|
|
def is_selector(selector_or_value: Any) -> bool:
    """Return True when ``selector_or_value`` is a string starting with ``$``.

    Non-string values (``None``, lists, bytes, numbers, ...) are never selectors.
    """
    return issubclass(type(selector_or_value), str) and selector_or_value.startswith("$")
|
|
|
|
|
def get_last_selector_chunk(selector: str) -> str:
    """Return the text after the last ``.`` in ``selector``.

    If there is no ``.``, the whole string is returned.

    Example: ``"$steps.detection.predictions"`` -> ``"predictions"``.
    """
    _, _, last_chunk = selector.rpartition(".")
    return last_chunk
|
|