|
import asyncio |
|
from copy import deepcopy |
|
from functools import partial |
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
from uuid import uuid4 |
|
|
|
from inference.core.entities.requests.clip import ClipCompareRequest |
|
from inference.core.entities.requests.doctr import DoctrOCRInferenceRequest |
|
from inference.core.entities.requests.inference import ( |
|
ClassificationInferenceRequest, |
|
InstanceSegmentationInferenceRequest, |
|
KeypointsDetectionInferenceRequest, |
|
ObjectDetectionInferenceRequest, |
|
) |
|
from inference.core.entities.requests.yolo_world import YOLOWorldInferenceRequest |
|
from inference.core.env import ( |
|
HOSTED_CLASSIFICATION_URL, |
|
HOSTED_CORE_MODEL_URL, |
|
HOSTED_DETECT_URL, |
|
HOSTED_INSTANCE_SEGMENTATION_URL, |
|
LOCAL_INFERENCE_API_URL, |
|
WORKFLOWS_REMOTE_API_TARGET, |
|
WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE, |
|
WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS, |
|
) |
|
from inference.core.managers.base import ModelManager |
|
from inference.enterprise.workflows.complier.entities import StepExecutionMode |
|
from inference.enterprise.workflows.complier.steps_executors.constants import ( |
|
CENTER_X_KEY, |
|
CENTER_Y_KEY, |
|
ORIGIN_COORDINATES_KEY, |
|
ORIGIN_SIZE_KEY, |
|
PARENT_COORDINATES_SUFFIX, |
|
) |
|
from inference.enterprise.workflows.complier.steps_executors.types import ( |
|
NextStepReference, |
|
OutputsLookup, |
|
) |
|
from inference.enterprise.workflows.complier.steps_executors.utils import ( |
|
get_image, |
|
make_batches, |
|
resolve_parameter, |
|
) |
|
from inference.enterprise.workflows.complier.utils import construct_step_selector |
|
from inference.enterprise.workflows.entities.steps import ( |
|
ClassificationModel, |
|
ClipComparison, |
|
InstanceSegmentationModel, |
|
KeypointsDetectionModel, |
|
MultiLabelClassificationModel, |
|
ObjectDetectionModel, |
|
OCRModel, |
|
RoboflowModel, |
|
StepInterface, |
|
YoloWorld, |
|
) |
|
from inference_sdk import InferenceConfiguration, InferenceHTTPClient |
|
|
|
# Step type name -> value written under the "prediction_type" key of every
# serialised Roboflow model result. Both classification flavours share a tag.
MODEL_TYPE2PREDICTION_TYPE = dict(
    ClassificationModel="classification",
    MultiLabelClassificationModel="classification",
    ObjectDetectionModel="object-detection",
    InstanceSegmentationModel="instance-segmentation",
    KeypointsDetectionModel="keypoint-detection",
)
|
|
|
|
|
async def run_roboflow_model_step(
    step: RoboflowModel,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
    model_manager: ModelManager,
    api_key: Optional[str],
    step_execution_mode: StepExecutionMode,
) -> Tuple[NextStepReference, OutputsLookup]:
    """Execute a Roboflow model step and store its serialised predictions.

    Predictions come from ``get_roboflow_model_predictions_locally`` when
    ``step_execution_mode`` is ``StepExecutionMode.LOCAL`` and from
    ``get_roboflow_model_predictions_from_remote_api`` otherwise. Every result
    is tagged with its prediction type and parent id. For non-classification
    steps, parent ids are also attached to each nested prediction, and the
    detections are anchored in parent coordinates.

    Returns:
        ``(None, outputs_lookup)``, with the results stored in
        ``outputs_lookup`` under the step's selector.
    """
    model_id = resolve_parameter(
        selector_or_value=step.model_id,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    image = get_image(
        step=step,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    if step_execution_mode is StepExecutionMode.LOCAL:
        predictions = await get_roboflow_model_predictions_locally(
            image=image,
            model_id=model_id,
            step=step,
            runtime_parameters=runtime_parameters,
            outputs_lookup=outputs_lookup,
            model_manager=model_manager,
            api_key=api_key,
        )
    else:
        predictions = await get_roboflow_model_predictions_from_remote_api(
            image=image,
            model_id=model_id,
            step=step,
            runtime_parameters=runtime_parameters,
            outputs_lookup=outputs_lookup,
            api_key=api_key,
        )
    predictions = attach_prediction_type_info(
        results=predictions,
        prediction_type=MODEL_TYPE2PREDICTION_TYPE[step.get_type()],
    )
    is_classification = step.type in {
        "ClassificationModel",
        "MultiLabelClassificationModel",
    }
    if is_classification:
        # Classification results have no nested detections to tag or anchor.
        predictions = attach_parent_info(
            image=image, results=predictions, nested_key=None
        )
    else:
        predictions = anchor_detections_in_parent_coordinates(
            image=image,
            serialised_result=attach_parent_info(image=image, results=predictions),
        )
    outputs_lookup[construct_step_selector(step_name=step.name)] = predictions
    return None, outputs_lookup
|
|
|
|
|
async def get_roboflow_model_predictions_locally(
    image: List[dict],
    model_id: str,
    step: RoboflowModel,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
    model_manager: ModelManager,
    api_key: Optional[str],
) -> List[dict]:
    """Run a Roboflow model in-process through ``model_manager``.

    The request is built by the constructor registered for ``step.type`` in
    ``MODEL_TYPE2REQUEST_CONSTRUCTOR``. The model is registered with the
    manager before inference.

    Returns:
        The responses serialised with ``by_alias=True, exclude_none=True``.
        The return value is always a list, even when the manager produced a
        single response.
    """
    build_request = MODEL_TYPE2REQUEST_CONSTRUCTOR[step.type]
    request = build_request(
        step=step,
        image=image,
        api_key=api_key,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    model_manager.add_model(
        model_id=model_id,
        api_key=api_key,
    )
    response = await model_manager.infer_from_request(
        model_id=model_id, request=request
    )
    responses = response if isinstance(response, list) else [response]
    return [item.dict(by_alias=True, exclude_none=True) for item in responses]
|
|
|
|
|
def construct_classification_request(
    step: Union[ClassificationModel, MultiLabelClassificationModel],
    image: Any,
    api_key: Optional[str],
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> ClassificationInferenceRequest:
    """Build an in-process classification request from a step definition.

    ``model_id``, ``confidence`` and ``disable_active_learning`` may be literal
    values or selectors. Each is passed through ``resolve_parameter`` against
    ``runtime_parameters`` / ``outputs_lookup``.
    """

    def resolved(value: Any) -> Any:
        return resolve_parameter(
            value,
            runtime_parameters=runtime_parameters,
            outputs_lookup=outputs_lookup,
        )

    model_id = resolved(step.model_id)
    confidence = resolved(step.confidence)
    disable_active_learning = resolved(step.disable_active_learning)
    return ClassificationInferenceRequest(
        api_key=api_key,
        model_id=model_id,
        image=image,
        confidence=confidence,
        disable_active_learning=disable_active_learning,
    )
|
|
|
|
|
def construct_object_detection_request(
    step: ObjectDetectionModel,
    image: Any,
    api_key: Optional[str],
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> ObjectDetectionInferenceRequest:
    """Build an in-process object-detection request from a step definition.

    ``model_id`` and every tuning option listed below are read from the step.
    Each is passed through ``resolve_parameter``, which replaces selectors
    with their values.
    """
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    model_id = resolve(step.model_id)
    # Step attribute names double as request argument names; resolved in order.
    options = {
        field_name: resolve(getattr(step, field_name))
        for field_name in (
            "disable_active_learning",
            "class_agnostic_nms",
            "class_filter",
            "confidence",
            "iou_threshold",
            "max_detections",
            "max_candidates",
        )
    }
    return ObjectDetectionInferenceRequest(
        api_key=api_key,
        model_id=model_id,
        image=image,
        **options,
    )
|
|
|
|
|
def construct_instance_segmentation_request(
    step: InstanceSegmentationModel,
    image: Any,
    api_key: Optional[str],
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> InstanceSegmentationInferenceRequest:
    """Build an in-process instance-segmentation request from a step.

    The options are the detection options plus the mask decoding settings
    (``mask_decode_mode``, ``tradeoff_factor``). All of them are resolved
    through ``resolve_parameter``.
    """
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    model_id = resolve(step.model_id)
    # Step attribute names double as request argument names; resolved in order.
    option_names = (
        "disable_active_learning",
        "class_agnostic_nms",
        "class_filter",
        "confidence",
        "iou_threshold",
        "max_detections",
        "max_candidates",
        "mask_decode_mode",
        "tradeoff_factor",
    )
    options = {name: resolve(getattr(step, name)) for name in option_names}
    return InstanceSegmentationInferenceRequest(
        api_key=api_key,
        model_id=model_id,
        image=image,
        **options,
    )
|
|
|
|
|
def construct_keypoints_detection_request(
    step: KeypointsDetectionModel,
    image: Any,
    api_key: Optional[str],
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> KeypointsDetectionInferenceRequest:
    """Build an in-process keypoints-detection request from a step.

    The options are the detection options plus ``keypoint_confidence``. All of
    them are resolved through ``resolve_parameter``.
    """
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    model_id = resolve(step.model_id)
    # Step attribute names double as request argument names; resolved in order.
    option_names = (
        "disable_active_learning",
        "class_agnostic_nms",
        "class_filter",
        "confidence",
        "iou_threshold",
        "max_detections",
        "max_candidates",
        "keypoint_confidence",
    )
    options = {name: resolve(getattr(step, name)) for name in option_names}
    return KeypointsDetectionInferenceRequest(
        api_key=api_key,
        model_id=model_id,
        image=image,
        **options,
    )
|
|
|
|
|
# Step type name -> builder of the in-process inference request, used by
# get_roboflow_model_predictions_locally. Both classification flavours share
# a single builder.
MODEL_TYPE2REQUEST_CONSTRUCTOR = {
    **dict.fromkeys(
        ("ClassificationModel", "MultiLabelClassificationModel"),
        construct_classification_request,
    ),
    "ObjectDetectionModel": construct_object_detection_request,
    "InstanceSegmentationModel": construct_instance_segmentation_request,
    "KeypointsDetectionModel": construct_keypoints_detection_request,
}
|
|
|
|
|
async def get_roboflow_model_predictions_from_remote_api(
    image: List[dict],
    model_id: str,
    step: RoboflowModel,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
    api_key: Optional[str],
) -> List[dict]:
    """Run a Roboflow model through the inference HTTP API.

    The API URL comes from ``resolve_model_api_url``. The v0 API is selected
    when ``WORKFLOWS_REMOTE_API_TARGET`` is ``"hosted"``. Client options are
    built by the constructor registered for ``step.type`` in
    ``MODEL_TYPE2HTTP_CLIENT_CONSTRUCTOR``.

    Returns:
        The predictions as a list. A non-list client result is wrapped.
    """
    client = InferenceHTTPClient(
        api_url=resolve_model_api_url(step=step),
        api_key=api_key,
    )
    if WORKFLOWS_REMOTE_API_TARGET == "hosted":
        client.select_api_v0()
    build_configuration = MODEL_TYPE2HTTP_CLIENT_CONSTRUCTOR[step.type]
    client.configure(
        inference_configuration=build_configuration(
            step=step,
            runtime_parameters=runtime_parameters,
            outputs_lookup=outputs_lookup,
        )
    )
    predictions = await client.infer_async(
        inference_input=[element["value"] for element in image],
        model_id=model_id,
    )
    return predictions if isinstance(predictions, list) else [predictions]
|
|
|
|
|
def construct_http_client_configuration_for_classification_step(
    step: Union[ClassificationModel, MultiLabelClassificationModel],
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> InferenceConfiguration:
    """Translate a classification step into HTTP client configuration.

    The step's ``confidence`` is resolved and passed as
    ``confidence_threshold``. Batch size and request concurrency come from the
    workflow remote-execution settings.
    """

    def resolved(value: Any) -> Any:
        return resolve_parameter(
            value,
            runtime_parameters=runtime_parameters,
            outputs_lookup=outputs_lookup,
        )

    confidence_threshold = resolved(step.confidence)
    disable_active_learning = resolved(step.disable_active_learning)
    return InferenceConfiguration(
        confidence_threshold=confidence_threshold,
        disable_active_learning=disable_active_learning,
        max_batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE,
        max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
    )
|
|
|
|
|
def construct_http_client_configuration_for_detection_step(
    step: ObjectDetectionModel,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> InferenceConfiguration:
    """Translate an object-detection step into HTTP client configuration.

    Step options are resolved and mapped onto ``InferenceConfiguration``
    arguments; the step's ``confidence`` becomes ``confidence_threshold``.
    Batch size and request concurrency come from the workflow remote-execution
    settings.
    """
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    # (configuration argument, step attribute) pairs, in resolution order.
    option_sources = (
        ("disable_active_learning", "disable_active_learning"),
        ("class_agnostic_nms", "class_agnostic_nms"),
        ("class_filter", "class_filter"),
        ("confidence_threshold", "confidence"),
        ("iou_threshold", "iou_threshold"),
        ("max_detections", "max_detections"),
        ("max_candidates", "max_candidates"),
    )
    options = {
        argument: resolve(getattr(step, attribute))
        for argument, attribute in option_sources
    }
    return InferenceConfiguration(
        **options,
        max_batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE,
        max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
    )
|
|
|
|
|
def construct_http_client_configuration_for_segmentation_step(
    step: InstanceSegmentationModel,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> InferenceConfiguration:
    """Translate an instance-segmentation step into HTTP client configuration.

    The options are the detection options plus ``mask_decode_mode`` and
    ``tradeoff_factor``; the step's ``confidence`` becomes
    ``confidence_threshold``. Batch size and concurrency come from the
    workflow remote-execution settings.
    """
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    # (configuration argument, step attribute) pairs, in resolution order.
    option_sources = (
        ("disable_active_learning", "disable_active_learning"),
        ("class_agnostic_nms", "class_agnostic_nms"),
        ("class_filter", "class_filter"),
        ("confidence_threshold", "confidence"),
        ("iou_threshold", "iou_threshold"),
        ("max_detections", "max_detections"),
        ("max_candidates", "max_candidates"),
        ("mask_decode_mode", "mask_decode_mode"),
        ("tradeoff_factor", "tradeoff_factor"),
    )
    options = {
        argument: resolve(getattr(step, attribute))
        for argument, attribute in option_sources
    }
    return InferenceConfiguration(
        **options,
        max_batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE,
        max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
    )
|
|
|
|
|
def construct_http_client_configuration_for_keypoints_detection_step(
    step: KeypointsDetectionModel,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> InferenceConfiguration:
    """Translate a keypoints-detection step into HTTP client configuration.

    The options are the detection options plus the keypoint threshold. The
    step's ``confidence`` and ``keypoint_confidence`` become
    ``confidence_threshold`` and ``keypoint_confidence_threshold``.
    """
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    # (configuration argument, step attribute) pairs, in resolution order.
    option_sources = (
        ("disable_active_learning", "disable_active_learning"),
        ("class_agnostic_nms", "class_agnostic_nms"),
        ("class_filter", "class_filter"),
        ("confidence_threshold", "confidence"),
        ("iou_threshold", "iou_threshold"),
        ("max_detections", "max_detections"),
        ("max_candidates", "max_candidates"),
        ("keypoint_confidence_threshold", "keypoint_confidence"),
    )
    options = {
        argument: resolve(getattr(step, attribute))
        for argument, attribute in option_sources
    }
    return InferenceConfiguration(
        **options,
        max_batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE,
        max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
    )
|
|
|
|
|
# Step type name -> builder of the InferenceConfiguration used by
# get_roboflow_model_predictions_from_remote_api. Both classification
# flavours share a single builder.
MODEL_TYPE2HTTP_CLIENT_CONSTRUCTOR = {
    **dict.fromkeys(
        ("ClassificationModel", "MultiLabelClassificationModel"),
        construct_http_client_configuration_for_classification_step,
    ),
    "ObjectDetectionModel": construct_http_client_configuration_for_detection_step,
    "InstanceSegmentationModel": construct_http_client_configuration_for_segmentation_step,
    "KeypointsDetectionModel": construct_http_client_configuration_for_keypoints_detection_step,
}
|
|
|
|
|
async def run_yolo_world_model_step(
    step: YoloWorld,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
    model_manager: ModelManager,
    api_key: Optional[str],
    step_execution_mode: StepExecutionMode,
) -> Tuple[NextStepReference, OutputsLookup]:
    """Execute a YOLO-World step and store its serialised detections.

    ``class_names``, ``version`` and ``confidence`` are resolved from the step.
    Inference runs through ``model_manager`` when ``step_execution_mode`` is
    ``StepExecutionMode.LOCAL``, and through the remote HTTP API otherwise.
    Results are tagged ``"object-detection"``, given parent ids and anchored
    in parent coordinates.

    Returns:
        ``(None, outputs_lookup)``, with the results stored under the step's
        selector.
    """
    image = get_image(
        step=step,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    class_names = resolve(selector_or_value=step.class_names)
    model_version = resolve(selector_or_value=step.version)
    confidence = resolve(selector_or_value=step.confidence)
    if step_execution_mode is StepExecutionMode.LOCAL:
        detections = await get_yolo_world_predictions_locally(
            image=image,
            class_names=class_names,
            model_version=model_version,
            confidence=confidence,
            model_manager=model_manager,
            api_key=api_key,
        )
    else:
        detections = await get_yolo_world_predictions_from_remote_api(
            image=image,
            class_names=class_names,
            model_version=model_version,
            confidence=confidence,
            step=step,
            api_key=api_key,
        )
    detections = attach_prediction_type_info(
        results=detections,
        prediction_type="object-detection",
    )
    detections = anchor_detections_in_parent_coordinates(
        image=image,
        serialised_result=attach_parent_info(image=image, results=detections),
    )
    outputs_lookup[construct_step_selector(step_name=step.name)] = detections
    return None, outputs_lookup
|
|
|
|
|
async def get_yolo_world_predictions_locally(
    image: List[dict],
    class_names: List[str],
    model_version: Optional[str],
    confidence: Optional[float],
    model_manager: ModelManager,
    api_key: Optional[str],
) -> List[dict]:
    """Run the in-process YOLO-World model over each image, one at a time.

    For every image, a ``YOLOWorldInferenceRequest`` is built with
    ``class_names`` as the text prompt. The matching core model is loaded via
    ``load_core_model`` and the request is inferred.

    Returns:
        One ``.dict()``-serialised response per image, in input order.
    """

    async def predict(single_image: dict) -> dict:
        request = YOLOWorldInferenceRequest(
            image=single_image,
            yolo_world_version_id=model_version,
            confidence=confidence,
            text=class_names,
        )
        yolo_world_model_id = load_core_model(
            model_manager=model_manager,
            inference_request=request,
            core_model="yolo_world",
            api_key=api_key,
        )
        response = await model_manager.infer_from_request(
            yolo_world_model_id, request
        )
        return response.dict()

    return [await predict(single_image) for single_image in image]
|
|
|
|
|
async def get_yolo_world_predictions_from_remote_api(
    image: List[dict],
    class_names: List[str],
    model_version: Optional[str],
    confidence: Optional[float],
    step: YoloWorld,
    api_key: Optional[str],
) -> List[dict]:
    """Run YOLO-World through the inference HTTP API.

    Images are sent in batches of
    ``WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS``, one batch
    after another.

    Returns:
        The per-batch results concatenated into a single list.
    """
    client = InferenceHTTPClient(
        api_url=resolve_model_api_url(step=step),
        api_key=api_key,
    )
    client.configure(
        inference_configuration=InferenceConfiguration(
            max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
        )
    )
    if WORKFLOWS_REMOTE_API_TARGET == "hosted":
        client.select_api_v0()
    batches = list(
        make_batches(
            iterable=image,
            batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
        )
    )
    detections: List[dict] = []
    for batch in batches:
        detections.extend(
            await client.infer_from_yolo_world_async(
                inference_input=[element["value"] for element in batch],
                class_names=class_names,
                model_version=model_version,
                confidence=confidence,
            )
        )
    return detections
|
|
|
|
|
async def run_ocr_model_step(
    step: OCRModel,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
    model_manager: ModelManager,
    api_key: Optional[str],
    step_execution_mode: StepExecutionMode,
) -> Tuple[NextStepReference, OutputsLookup]:
    """Execute an OCR step and store its serialised results.

    OCR runs through ``model_manager`` when ``step_execution_mode`` is
    ``StepExecutionMode.LOCAL``, and through the remote HTTP API otherwise.
    Each result gets its image's parent id and a ``"ocr"`` prediction type.

    Returns:
        ``(None, outputs_lookup)``, with the results stored under the step's
        selector.
    """
    image = get_image(
        step=step,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    if step_execution_mode is StepExecutionMode.LOCAL:
        ocr_results = await get_ocr_predictions_locally(
            image=image,
            model_manager=model_manager,
            api_key=api_key,
        )
    else:
        ocr_results = await get_ocr_predictions_from_remote_api(
            step=step,
            image=image,
            api_key=api_key,
        )
    # OCR results carry no nested detections, hence nested_key=None.
    with_parents = attach_parent_info(
        image=image,
        results=ocr_results,
        nested_key=None,
    )
    outputs_lookup[construct_step_selector(step_name=step.name)] = (
        attach_prediction_type_info(results=with_parents, prediction_type="ocr")
    )
    return None, outputs_lookup
|
|
|
|
|
async def get_ocr_predictions_locally(
    image: List[dict],
    model_manager: ModelManager,
    api_key: Optional[str],
) -> List[dict]:
    """Run the in-process DocTR OCR model over each image, one at a time.

    Returns:
        One ``.dict()``-serialised OCR response per image, in input order.
    """

    async def recognise(single_image: dict) -> dict:
        request = DoctrOCRInferenceRequest(
            image=single_image,
        )
        doctr_model_id = load_core_model(
            model_manager=model_manager,
            inference_request=request,
            core_model="doctr",
            api_key=api_key,
        )
        response = await model_manager.infer_from_request(doctr_model_id, request)
        return response.dict()

    return [await recognise(single_image) for single_image in image]
|
|
|
|
|
async def get_ocr_predictions_from_remote_api(
    step: OCRModel,
    image: List[dict],
    api_key: Optional[str],
) -> List[dict]:
    """Run DocTR OCR for every image through the inference HTTP API.

    The v0 API is selected when ``WORKFLOWS_REMOTE_API_TARGET`` is
    ``"hosted"``. Batch size and concurrency come from the workflow
    remote-execution settings.

    Returns:
        One OCR result dict per image. The client can hand back a bare dict
        (e.g. for a single input) instead of a list. The result is therefore
        normalised to a list based on its type, not on the number of inputs.
    """
    api_url = resolve_model_api_url(step=step)
    client = InferenceHTTPClient(
        api_url=api_url,
        api_key=api_key,
    )
    if WORKFLOWS_REMOTE_API_TARGET == "hosted":
        client.select_api_v0()
    configuration = InferenceConfiguration(
        max_batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE,
        max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
    )
    client.configure(inference_configuration=configuration)
    result = await client.ocr_image_async(
        inference_input=[i["value"] for i in image],
    )
    # Inspect the result itself: a single OCR response is a dict, a batch is
    # a list. Keying off len(image) would mis-wrap if the client's
    # single-element unwrapping ever changed.
    if not isinstance(result, list):
        return [result]
    return result
|
|
|
|
|
async def run_clip_comparison_step(
    step: ClipComparison,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
    model_manager: ModelManager,
    api_key: Optional[str],
    step_execution_mode: StepExecutionMode,
) -> Tuple[NextStepReference, OutputsLookup]:
    """Execute a CLIP image-vs-text comparison step and store the results.

    ``step.text`` is resolved first. The comparison runs through
    ``model_manager`` when ``step_execution_mode`` is
    ``StepExecutionMode.LOCAL``, and through the remote HTTP API otherwise.
    Each result gets its image's parent id and an ``"embeddings-comparison"``
    prediction type.

    Returns:
        ``(None, outputs_lookup)``, with the results stored under the step's
        selector.
    """
    image = get_image(
        step=step,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    text = resolve_parameter(
        selector_or_value=step.text,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    if step_execution_mode is StepExecutionMode.LOCAL:
        comparisons = await get_clip_comparison_locally(
            image=image,
            text=text,
            model_manager=model_manager,
            api_key=api_key,
        )
    else:
        comparisons = await get_clip_comparison_from_remote_api(
            step=step,
            image=image,
            text=text,
            api_key=api_key,
        )
    # Comparison results carry no nested detections, hence nested_key=None.
    with_parents = attach_parent_info(
        image=image,
        results=comparisons,
        nested_key=None,
    )
    outputs_lookup[construct_step_selector(step_name=step.name)] = (
        attach_prediction_type_info(
            results=with_parents,
            prediction_type="embeddings-comparison",
        )
    )
    return None, outputs_lookup
|
|
|
|
|
async def get_clip_comparison_locally(
    image: List[dict],
    text: str,
    model_manager: ModelManager,
    api_key: Optional[str],
) -> List[dict]:
    """Compare every image against ``text`` with the in-process CLIP model.

    One ``ClipCompareRequest`` (image subject, text prompt) is built and
    inferred per image, sequentially.

    Returns:
        One ``.dict()``-serialised response per image, in input order.
    """

    async def compare(single_image: dict) -> dict:
        request = ClipCompareRequest(
            subject=single_image,
            subject_type="image",
            prompt=text,
            prompt_type="text",
        )
        clip_model_id = load_core_model(
            model_manager=model_manager,
            inference_request=request,
            core_model="clip",
            api_key=api_key,
        )
        response = await model_manager.infer_from_request(clip_model_id, request)
        return response.dict()

    return [await compare(single_image) for single_image in image]
|
|
|
|
|
async def get_clip_comparison_from_remote_api(
    step: ClipComparison,
    image: List[dict],
    text: str,
    api_key: Optional[str],
) -> List[dict]:
    """Compare each image against ``text`` through the remote CLIP endpoint.

    Images are split into batches of
    ``WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS``. Comparisons
    within a batch are awaited concurrently with ``asyncio.gather``, and
    batches run one after another.

    Returns:
        The per-batch results concatenated into a single list.
    """
    client = InferenceHTTPClient(
        api_url=resolve_model_api_url(step=step),
        api_key=api_key,
    )
    if WORKFLOWS_REMOTE_API_TARGET == "hosted":
        client.select_api_v0()
    batches = list(
        make_batches(
            iterable=image,
            batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
        )
    )
    comparisons: List[dict] = []
    for batch in batches:
        pending = [
            client.clip_compare_async(subject=element["value"], prompt=text)
            for element in batch
        ]
        comparisons.extend(await asyncio.gather(*pending))
    return comparisons
|
|
|
|
|
def load_core_model(
    model_manager: ModelManager,
    inference_request: Union[
        DoctrOCRInferenceRequest, ClipCompareRequest, YOLOWorldInferenceRequest
    ],
    core_model: str,
    api_key: Optional[str] = None,
) -> str:
    """Register a core model with ``model_manager`` and return its model id.

    Args:
        model_manager: Manager the model is added to.
        inference_request: Request whose ``<core_model>_version_id`` attribute
            selects the model version. Its ``api_key`` is overwritten when a
            truthy ``api_key`` is given.
        core_model: Core model name, e.g. ``"doctr"``, ``"clip"`` or
            ``"yolo_world"``.
        api_key: Optional key that takes precedence over the request's own.

    Returns:
        The id ``"<core_model>/<version id>"`` passed to ``add_model``.
    """
    if api_key:
        inference_request.api_key = api_key
    version_id = getattr(inference_request, f"{core_model}_version_id")
    core_model_id = f"{core_model}/{version_id}"
    model_manager.add_model(core_model_id, inference_request.api_key)
    return core_model_id
|
|
|
|
|
def attach_prediction_type_info(
    results: List[Dict[str, Any]],
    prediction_type: str,
    key: str = "prediction_type",
) -> List[Dict[str, Any]]:
    """Tag every result dict, in place, with ``prediction_type`` under ``key``.

    Returns:
        The very same ``results`` list object that was passed in.
    """
    for entry in results:
        entry.update({key: prediction_type})
    return results
|
|
|
|
|
def attach_parent_info(
    image: List[Dict[str, Any]],
    results: List[Dict[str, Any]],
    nested_key: Optional[str] = "predictions",
) -> List[Dict[str, Any]]:
    """Copy each image's ``parent_id`` onto the result at the same position.

    Images and results are paired with ``zip``, so any surplus elements in the
    longer list are dropped. ``nested_key`` is forwarded to
    ``attach_parent_info_to_image_detections``.
    """
    tagged = []
    for image_element, result in zip(image, results):
        tagged.append(
            attach_parent_info_to_image_detections(
                image=image_element,
                predictions=result,
                nested_key=nested_key,
            )
        )
    return tagged
|
|
|
|
|
def attach_parent_info_to_image_detections(
    image: Dict[str, Any],
    predictions: Dict[str, Any],
    nested_key: Optional[str],
) -> Dict[str, Any]:
    """Stamp ``image["parent_id"]`` onto a result dict, in place.

    The top-level dict always receives ``parent_id``. When ``nested_key`` is
    not ``None``, every element of ``predictions[nested_key]`` receives it too.

    Returns:
        The same ``predictions`` dict that was passed in.
    """
    parent_id = image["parent_id"]
    predictions["parent_id"] = parent_id
    if nested_key is not None:
        for detection in predictions[nested_key]:
            detection["parent_id"] = parent_id
    return predictions
|
|
|
|
|
def anchor_detections_in_parent_coordinates(
    image: List[Dict[str, Any]],
    serialised_result: List[Dict[str, Any]],
    image_metadata_key: str = "image",
    detections_key: str = "predictions",
) -> List[Dict[str, Any]]:
    """Anchor each result in its image's parent coordinates.

    Applies ``anchor_image_detections_in_parent_coordinates`` to every
    (image, result) pair. Pairing is positional via ``zip``, so surplus
    elements in the longer list are dropped.
    """
    anchored = []
    for image_element, single_result in zip(image, serialised_result):
        anchored.append(
            anchor_image_detections_in_parent_coordinates(
                image=image_element,
                serialised_result=single_result,
                image_metadata_key=image_metadata_key,
                detections_key=detections_key,
            )
        )
    return anchored
|
|
|
|
|
def anchor_image_detections_in_parent_coordinates(
    image: Dict[str, Any],
    serialised_result: Dict[str, Any],
    image_metadata_key: str = "image",
    detections_key: str = "predictions",
) -> Dict[str, Any]:
    """Add parent-image copies of one result's detections and image metadata.

    ``serialised_result`` gains two keys: the detections key and the metadata
    key, each with ``PARENT_COORDINATES_SUFFIX`` appended. They hold deep
    copies of the originals.

    When ``image`` carries ``ORIGIN_COORDINATES_KEY``, it was cut out of a
    larger image. In that case the copied detections are shifted into the
    parent's frame, and the copied metadata is replaced with the parent's
    size (``ORIGIN_SIZE_KEY``).

    The origin entry holds the crop *centre* (``CENTER_X_KEY`` /
    ``CENTER_Y_KEY``). Model coordinates are measured from the top-left corner
    of the image the model saw, i.e. the crop. So the shift is the crop's
    top-left corner, reconstructed as centre minus half of the crop size
    reported in ``serialised_result[image_metadata_key]``. Polygon ``points``
    and ``keypoints`` of each detection are moved along with its ``x``/``y``
    so the whole detection stays consistent.

    Returns:
        The same ``serialised_result`` dict, mutated in place.
    """
    parent_detections_key = f"{detections_key}{PARENT_COORDINATES_SUFFIX}"
    parent_metadata_key = f"{image_metadata_key}{PARENT_COORDINATES_SUFFIX}"
    serialised_result[parent_detections_key] = deepcopy(
        serialised_result[detections_key]
    )
    serialised_result[parent_metadata_key] = deepcopy(
        serialised_result[image_metadata_key]
    )
    if ORIGIN_COORDINATES_KEY not in image:
        return serialised_result
    origin = image[ORIGIN_COORDINATES_KEY]
    crop_size = serialised_result[image_metadata_key]
    shift_x = origin[CENTER_X_KEY] - crop_size["width"] / 2
    shift_y = origin[CENTER_Y_KEY] - crop_size["height"] / 2
    for detection in serialised_result[parent_detections_key]:
        detection["x"] += shift_x
        detection["y"] += shift_y
        # Segmentation polygons and keypoints live in the same crop frame.
        for point in detection.get("points", []):
            point["x"] += shift_x
            point["y"] += shift_y
        for keypoint in detection.get("keypoints", []):
            keypoint["x"] += shift_x
            keypoint["y"] += shift_y
    serialised_result[parent_metadata_key] = origin[ORIGIN_SIZE_KEY]
    return serialised_result
|
|
|
|
|
# Step type name -> hosted Roboflow API base URL. resolve_model_api_url looks
# up step.get_type() here whenever WORKFLOWS_REMOTE_API_TARGET is "hosted", so
# every step type that can run remotely needs an entry. Keys follow the step
# class names, as with the other entries.
ROBOFLOW_MODEL2HOSTED_ENDPOINT = {
    "ClassificationModel": HOSTED_CLASSIFICATION_URL,
    "MultiLabelClassificationModel": HOSTED_CLASSIFICATION_URL,
    "ObjectDetectionModel": HOSTED_DETECT_URL,
    "KeypointsDetectionModel": HOSTED_DETECT_URL,
    "InstanceSegmentationModel": HOSTED_INSTANCE_SEGMENTATION_URL,
    "OCRModel": HOSTED_CORE_MODEL_URL,
    "ClipComparison": HOSTED_CORE_MODEL_URL,
    # YOLO-World is a core model, like DocTR and CLIP. Without this entry,
    # get_yolo_world_predictions_from_remote_api raised KeyError on the
    # hosted target. NOTE(review): confirm YoloWorld.get_type() == "YoloWorld".
    "YoloWorld": HOSTED_CORE_MODEL_URL,
}
|
|
|
|
|
def resolve_model_api_url(step: StepInterface) -> str:
    """Pick the inference API base URL for ``step``.

    When ``WORKFLOWS_REMOTE_API_TARGET`` is ``"hosted"``, the URL is looked up
    by ``step.get_type()`` in ``ROBOFLOW_MODEL2HOSTED_ENDPOINT``; an
    unmapped type raises ``KeyError``. Any other target uses
    ``LOCAL_INFERENCE_API_URL``.
    """
    if WORKFLOWS_REMOTE_API_TARGET == "hosted":
        return ROBOFLOW_MODEL2HOSTED_ENDPOINT[step.get_type()]
    return LOCAL_INFERENCE_API_URL
|
|