|
from typing import List, Optional |
|
|
|
from inference.core.entities.types import DatasetID, WorkspaceID |
|
from inference.core.roboflow_api import ( |
|
get_roboflow_labeling_batches, |
|
get_roboflow_labeling_jobs, |
|
) |
|
|
|
|
|
def image_can_be_submitted_to_batch(
    batch_name: str,
    workspace_id: WorkspaceID,
    dataset_id: DatasetID,
    max_batch_images: Optional[int],
    api_key: str,
) -> bool:
    """Decide whether one more image may be added to the named batch.

    Args:
        batch_name: Name used to look the batch up among the dataset's
            labeling batches.
        workspace_id: Workspace passed to the Roboflow API calls.
        dataset_id: Dataset passed to the Roboflow API calls.
        max_batch_images: Upper bound on images for the batch; ``None``
            disables the check entirely.
        api_key: API key passed to the Roboflow API calls.

    Returns:
        ``True`` when no limit is set, or when the limit is strictly greater
        than the number of images counted for the batch (the batch's own
        ``images`` value plus the images of its labeling jobs); ``False``
        otherwise.
    """
    # Without a limit there is nothing to verify, so no API call is made.
    if max_batch_images is None:
        return True

    batches_response = get_roboflow_labeling_batches(
        api_key=api_key,
        workspace_id=workspace_id,
        dataset_id=dataset_id,
    )
    target_batch = get_matching_labeling_batch(
        all_labeling_batches=batches_response["batches"],
        batch_name=batch_name,
    )
    if target_batch is None:
        # No batch with this name was returned: only a positive limit
        # permits the submission.
        return max_batch_images > 0

    # Labeling jobs are fetched only when the batch reports at least one job.
    # NOTE(review): presumably images handed over to jobs are not included in
    # the batch's own "images" value - confirm against the API.
    if target_batch["numJobs"] > 0:
        jobs_response = get_roboflow_labeling_jobs(
            api_key=api_key,
            workspace_id=workspace_id,
            dataset_id=dataset_id,
        )
        images_in_jobs = get_images_in_labeling_jobs_of_specific_batch(
            all_labeling_jobs=jobs_response["jobs"],
            batch_id=target_batch["id"],
        )
    else:
        images_in_jobs = 0

    return max_batch_images > target_batch["images"] + images_in_jobs
|
|
|
|
|
def get_matching_labeling_batch(
    all_labeling_batches: List[dict],
    batch_name: str,
) -> Optional[dict]:
    """Return the first labeling batch whose ``"name"`` equals ``batch_name``.

    Args:
        all_labeling_batches: Batch dictionaries to search; each one is
            read through its ``"name"`` key.
        batch_name: Name to look for (compared with ``==``).

    Returns:
        The first batch dictionary with a matching name (the same object
        that is held in ``all_labeling_batches``), or ``None`` when no
        batch matches.
    """
    # The generator is lazy, so the search stops at the first match.
    return next(
        (
            candidate
            for candidate in all_labeling_batches
            if candidate["name"] == batch_name
        ),
        None,
    )
|
|
|
|
|
def get_images_in_labeling_jobs_of_specific_batch(
    all_labeling_jobs: List[dict],
    batch_id: str,
) -> int:
    """Sum ``"numImages"`` over the labeling jobs that reference ``batch_id``.

    Args:
        all_labeling_jobs: Job dictionaries; each is read through its
            ``"sourceBatch"`` key, and matching ones also through
            ``"numImages"``.
        batch_id: Batch identifier; a job is selected when
            ``batch_id in job["sourceBatch"]`` holds.

    Returns:
        Total of ``"numImages"`` across the selected jobs, ``0`` when no job
        is selected.
    """
    # Membership uses `in`, not equality.
    # NOTE(review): if "sourceBatch" is a string this is a substring match -
    # confirm the field's format against the API.
    jobs_for_batch = [
        job for job in all_labeling_jobs if batch_id in job["sourceBatch"]
    ]
    image_count = 0
    for job in jobs_for_batch:
        image_count += job["numImages"]
    return image_count
|
|