from time import perf_counter
from typing import Any, Dict, List, Tuple, Union

import clip
import numpy as np
import onnxruntime
from PIL import Image

from inference.core.entities.requests.clip import (
    ClipCompareRequest,
    ClipImageEmbeddingRequest,
    ClipInferenceRequest,
    ClipTextEmbeddingRequest,
)
from inference.core.entities.requests.inference import InferenceRequestImage
from inference.core.entities.responses.clip import (
    ClipCompareResponse,
    ClipEmbeddingResponse,
)
from inference.core.entities.responses.inference import InferenceResponse
from inference.core.env import (
    CLIP_MAX_BATCH_SIZE,
    CLIP_MODEL_ID,
    ONNXRUNTIME_EXECUTION_PROVIDERS,
    REQUIRED_ONNX_PROVIDERS,
    TENSORRT_CACHE_PATH,
)
from inference.core.exceptions import OnnxProviderNotAvailable
from inference.core.models.roboflow import OnnxRoboflowCoreModel
from inference.core.models.types import PreprocessReturnMetadata
from inference.core.utils.image_utils import load_image_rgb
from inference.core.utils.onnx import get_onnxruntime_execution_providers
from inference.core.utils.postprocess import cosine_similarity

class Clip(OnnxRoboflowCoreModel):
    """Roboflow ONNX CLIP model.

    This class is responsible for handling the ONNX CLIP model, including
    loading the model, preprocessing the input, and performing inference.

    Attributes:
        visual_onnx_session (onnxruntime.InferenceSession): ONNX Runtime session for visual inference.
        textual_onnx_session (onnxruntime.InferenceSession): ONNX Runtime session for textual inference.
        resolution (int): The resolution of the input image.
        clip_preprocess (function): Function to preprocess the image.
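
    Example:
        A minimal usage sketch, assuming the CLIP ONNX weights referenced by
        CLIP_MODEL_ID are available (downloaded or cached) and that the image
        reference can be resolved by load_image_rgb (e.g. a local file path):

        >>> model = Clip()
        >>> similarities = model.compare(
        ...     subject="path/to/image.jpg",
        ...     prompt=["a dog", "a cat"],
        ...     subject_type="image",
        ...     prompt_type="text",
        ... )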
    """

    def __init__(
        self,
        *args,
        model_id: str = CLIP_MODEL_ID,
        onnxruntime_execution_providers: List[
            str
        ] = get_onnxruntime_execution_providers(ONNXRUNTIME_EXECUTION_PROVIDERS),
        **kwargs,
    ):
        """Initializes the Clip model with the given arguments and keyword arguments."""
        self.onnxruntime_execution_providers = onnxruntime_execution_providers
        t1 = perf_counter()
        super().__init__(*args, model_id=model_id, **kwargs)

        self.log("Creating inference sessions")
        self.visual_onnx_session = onnxruntime.InferenceSession(
            self.cache_file("visual.onnx"),
            providers=self.onnxruntime_execution_providers,
        )

        self.textual_onnx_session = onnxruntime.InferenceSession(
            self.cache_file("textual.onnx"),
            providers=self.onnxruntime_execution_providers,
        )

        if REQUIRED_ONNX_PROVIDERS:
            available_providers = onnxruntime.get_available_providers()
            for provider in REQUIRED_ONNX_PROVIDERS:
                if provider not in available_providers:
                    raise OnnxProviderNotAvailable(
                        f"Required ONNX Execution Provider {provider} is not available. Check that you are using the correct docker image on a supported device."
                    )

        self.resolution = self.visual_onnx_session.get_inputs()[0].shape[2]

        self.clip_preprocess = clip.clip._transform(self.resolution)
        self.log(f"CLIP model loaded in {perf_counter() - t1:.2f} seconds")
        self.task_type = "embedding"

    def compare(
        self,
        subject: Any,
        prompt: Any,
        subject_type: str = "image",
        prompt_type: Union[str, List[str], Dict[str, Any]] = "text",
        **kwargs,
    ) -> Union[List[float], Dict[str, float]]:
        """
        Compares the subject with the prompt to calculate similarity scores.

        Args:
            subject (Any): The subject data to be compared. Can be either an image or text.
            prompt (Any): The prompt data to be compared against the subject. Can be a single value (image/text), a list of values, or a dictionary of values.
            subject_type (str, optional): Specifies the type of the subject data. Must be either "image" or "text". Defaults to "image".
            prompt_type (Union[str, List[str], Dict[str, Any]], optional): Specifies the type of the prompt data. Can be "image", "text", a list of these types, or a dictionary containing these types. Defaults to "text".
            **kwargs: Additional keyword arguments.

        Returns:
            Union[List[float], Dict[str, float]]: A list or dictionary containing cosine similarity scores between the subject and prompt(s). If the prompt is a dictionary, returns a dictionary with keys corresponding to the original prompt dictionary's keys.

        Raises:
            ValueError: If subject_type or prompt_type is neither "image" nor "text".
            ValueError: If the number of prompts exceeds the maximum batch size.
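
        Example:
            Illustrative sketch, assuming an instantiated model bound to the name
            `model` and an image reference that load_image_rgb can resolve (e.g. a
            local file path):

            >>> scores = model.compare(
            ...     subject="path/to/image.jpg",
            ...     prompt={"dog": "a photo of a dog", "cat": "a photo of a cat"},
            ...     subject_type="image",
            ...     prompt_type="text",
            ... )
            >>> best_match = max(scores, key=scores.get)  # key with the highest similarity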
        """

        if subject_type == "image":
            subject_embeddings = self.embed_image(subject)
        elif subject_type == "text":
            subject_embeddings = self.embed_text(subject)
        else:
            raise ValueError(
                f"subject_type must be either 'image' or 'text', but got {subject_type}"
            )

        if isinstance(prompt, dict) and not ("type" in prompt and "value" in prompt):
            prompt_keys = list(prompt.keys())
            prompt = [prompt[k] for k in prompt_keys]
            prompt_obj = "dict"
        else:
            if not isinstance(prompt, list):
                prompt = [prompt]
            prompt_obj = "list"

        if len(prompt) > CLIP_MAX_BATCH_SIZE:
            raise ValueError(
                f"The maximum number of prompts that can be compared at once is {CLIP_MAX_BATCH_SIZE}"
            )

        if prompt_type == "image":
            prompt_embeddings = self.embed_image(prompt)
        elif prompt_type == "text":
            prompt_embeddings = self.embed_text(prompt)
        else:
            raise ValueError(
                f"prompt_type must be either 'image' or 'text', but got {prompt_type}"
            )

        similarities = [
            cosine_similarity(subject_embeddings, p) for p in prompt_embeddings
        ]

        if prompt_obj == "dict":
            similarities = dict(zip(prompt_keys, similarities))

        return similarities

    def make_compare_response(
        self, similarities: Union[List[float], Dict[str, float]]
    ) -> ClipCompareResponse:
        """
        Creates a ClipCompareResponse object from the provided similarity data.

        Args:
            similarities (Union[List[float], Dict[str, float]]): A list or dictionary containing similarity scores.

        Returns:
            ClipCompareResponse: An instance of ClipCompareResponse with the given similarity scores.

        Example:
            Assuming `ClipCompareResponse` expects a dictionary of string-float pairs:

            >>> make_compare_response({"image1": 0.98, "image2": 0.76})
            ClipCompareResponse(similarity={"image1": 0.98, "image2": 0.76})
        """
        response = ClipCompareResponse(similarity=similarities)
        return response

    def embed_image(
        self,
        image: Any,
        **kwargs,
    ) -> np.ndarray:
        """
        Embeds an image or a list of images using the Clip model.

        Args:
            image (Any): The image or list of images to be embedded. Images can be in any format accepted by the preproc_image method.
            **kwargs: Additional keyword arguments.

        Returns:
            np.ndarray: The embeddings of the image(s) as a numpy array.

        Raises:
            ValueError: If the number of images in the list exceeds the maximum batch size.

        Notes:
            The function records a start time with perf_counter and computes the embeddings with the visual ONNX session.
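
        Example:
            Illustrative sketch, assuming an instantiated model bound to the name
            `model` and image references that preproc_image accepts (e.g. local
            file paths); N is the model's embedding dimension:

            >>> single = model.embed_image("path/to/image.jpg")   # shape (1, N)
            >>> batch = model.embed_image(["a.jpg", "b.jpg"])     # shape (2, N)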
        """
        t1 = perf_counter()

        if isinstance(image, list):
            if len(image) > CLIP_MAX_BATCH_SIZE:
                raise ValueError(
                    f"The maximum number of images that can be embedded at once is {CLIP_MAX_BATCH_SIZE}"
                )
            imgs = [self.preproc_image(i) for i in image]
            img_in = np.concatenate(imgs, axis=0)
        else:
            img_in = self.preproc_image(image)

        onnx_input_image = {self.visual_onnx_session.get_inputs()[0].name: img_in}
        embeddings = self.visual_onnx_session.run(None, onnx_input_image)[0]

        return embeddings

    def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]:
        """Runs the visual ONNX session on a preprocessed image batch and returns the embeddings as a one-element tuple."""
        onnx_input_image = {self.visual_onnx_session.get_inputs()[0].name: img_in}
        embeddings = self.visual_onnx_session.run(None, onnx_input_image)[0]
        return (embeddings,)

    def make_embed_image_response(
        self, embeddings: np.ndarray
    ) -> ClipEmbeddingResponse:
        """
        Converts the given embeddings into a ClipEmbeddingResponse object.

        Args:
            embeddings (np.ndarray): A numpy array containing the embeddings for an image or images.

        Returns:
            ClipEmbeddingResponse: An instance of the ClipEmbeddingResponse with the provided embeddings converted to a list.

        Example:
            >>> embeddings_array = np.array([[0.5, 0.3, 0.2], [0.1, 0.9, 0.0]])
            >>> make_embed_image_response(embeddings_array)
            ClipEmbeddingResponse(embeddings=[[0.5, 0.3, 0.2], [0.1, 0.9, 0.0]])
        """
        response = ClipEmbeddingResponse(embeddings=embeddings.tolist())
        return response

    def embed_text(
        self,
        text: Union[str, List[str]],
        **kwargs,
    ) -> np.ndarray:
        """
        Embeds a text or a list of texts using the Clip model.

        Args:
            text (Union[str, List[str]]): The text string or list of text strings to be embedded.
            **kwargs: Additional keyword arguments.

        Returns:
            np.ndarray: The embeddings of the text or texts as a numpy array.

        Raises:
            ValueError: If the number of text strings in the list exceeds the maximum batch size.

        Notes:
            The function utilizes an ONNX session to compute embeddings and measures the embedding time with perf_counter.
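
        Example:
            Illustrative sketch, assuming an instantiated model bound to the name
            `model`; N is the model's embedding dimension:

            >>> single = model.embed_text("a photo of a dog")            # shape (1, N)
            >>> batch = model.embed_text(["a dog", "a cat", "a horse"])  # shape (3, N)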
        """
        t1 = perf_counter()

        if isinstance(text, list):
            if len(text) > CLIP_MAX_BATCH_SIZE:
                raise ValueError(
                    f"The maximum number of text strings that can be embedded at once is {CLIP_MAX_BATCH_SIZE}"
                )
            texts = text
        else:
            texts = [text]

        texts = clip.tokenize(texts).numpy().astype(np.int32)

        onnx_input_text = {self.textual_onnx_session.get_inputs()[0].name: texts}
        embeddings = self.textual_onnx_session.run(None, onnx_input_text)[0]

        return embeddings

    def make_embed_text_response(self, embeddings: np.ndarray) -> ClipEmbeddingResponse:
        """
        Converts the given text embeddings into a ClipEmbeddingResponse object.

        Args:
            embeddings (np.ndarray): A numpy array containing the embeddings for a text or texts.

        Returns:
            ClipEmbeddingResponse: An instance of the ClipEmbeddingResponse with the provided embeddings converted to a list.

        Example:
            >>> embeddings_array = np.array([[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
            >>> make_embed_text_response(embeddings_array)
            ClipEmbeddingResponse(embeddings=[[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
        """
        response = ClipEmbeddingResponse(embeddings=embeddings.tolist())
        return response

    def get_infer_bucket_file_list(self) -> List[str]:
        """Gets the list of files required for inference.

        Returns:
            List[str]: The list of file names.
        """
        return ["textual.onnx", "visual.onnx"]

    def infer_from_request(
        self, request: ClipInferenceRequest
    ) -> Union[ClipEmbeddingResponse, ClipCompareResponse]:
        """Routes the request to the appropriate inference function.

        Args:
            request (ClipInferenceRequest): The request object containing the inference details.

        Returns:
            Union[ClipEmbeddingResponse, ClipCompareResponse]: A ClipEmbeddingResponse for embedding requests, or a ClipCompareResponse for compare requests.
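
        Example:
            Illustrative sketch, assuming an instantiated model bound to the name
            `model` and that ClipTextEmbeddingRequest (imported above) accepts a
            `text` field:

            >>> request = ClipTextEmbeddingRequest(text=["a dog", "a cat"])
            >>> response = model.infer_from_request(request)  # ClipEmbeddingResponse with two embeddings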
        """
        t1 = perf_counter()
        if isinstance(request, ClipImageEmbeddingRequest):
            infer_func = self.embed_image
            make_response_func = self.make_embed_image_response
        elif isinstance(request, ClipTextEmbeddingRequest):
            infer_func = self.embed_text
            make_response_func = self.make_embed_text_response
        elif isinstance(request, ClipCompareRequest):
            infer_func = self.compare
            make_response_func = self.make_compare_response
        else:
            raise ValueError(
                f"Request type {type(request)} is not a valid ClipInferenceRequest"
            )
        data = infer_func(**request.dict())
        response = make_response_func(data)
        response.time = perf_counter() - t1
        return response

    def make_response(self, embeddings, *args, **kwargs) -> List[InferenceResponse]:
        """Wraps image embeddings in a single-element list of embedding responses."""
        return [self.make_embed_image_response(embeddings)]

    def postprocess(
        self,
        predictions: Tuple[np.ndarray],
        preprocess_return_metadata: PreprocessReturnMetadata,
        **kwargs,
    ) -> Any:
        """Converts raw visual embeddings into a single-element list of embedding responses."""
        return [self.make_embed_image_response(predictions[0])]

    def infer(self, image: Any, **kwargs) -> Any:
        """Embeds an image"""
        return super().infer(image, **kwargs)

    def preproc_image(self, image: InferenceRequestImage) -> np.ndarray:
        """Preprocesses an inference request image.

        Args:
            image (InferenceRequestImage): The object containing information necessary to load the image for inference.

        Returns:
            np.ndarray: A numpy array of the preprocessed image pixel data.
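
        Example:
            Illustrative sketch, assuming an instantiated model bound to the name
            `model`; the accepted `type` values for InferenceRequestImage depend on
            what load_image_rgb understands (a local file path is shown here):

            >>> image = InferenceRequestImage(type="file", value="path/to/image.jpg")
            >>> img_in = model.preproc_image(image)
            >>> img_in.shape  # (1, 3, model.resolution, model.resolution)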
        """
        pil_image = Image.fromarray(load_image_rgb(image))
        preprocessed_image = self.clip_preprocess(pil_image)

        img_in = np.expand_dims(preprocessed_image, axis=0)

        return img_in.astype(np.float32)

    def preprocess(
        self, image: Any, **kwargs
    ) -> Tuple[np.ndarray, PreprocessReturnMetadata]:
        """Preprocesses a single image and returns it along with empty preprocessing metadata."""
        return self.preproc_image(image), PreprocessReturnMetadata({})