from typing import List, Optional, Union

import httpx

from litellm.llms.base_llm.chat.transformation import AllMessageValues, BaseLLMException
from litellm.llms.base_llm.embedding.transformation import (
    BaseEmbeddingConfig,
    LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllEmbeddingInputValues
from litellm.types.utils import EmbeddingResponse

from ..common_utils import TritonError


class TritonEmbeddingConfig(BaseEmbeddingConfig):
    """
    Transformations for the Triton /embeddings endpoint (TensorRT-LLM models).
    """

    def __init__(self) -> None:
        pass

    def get_supported_openai_params(self, model: str) -> list:
        return []

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map OpenAI params to Triton embedding params.

        No OpenAI embedding params are supported for Triton, so the optional
        params are passed through unchanged.
        """
        return optional_params

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        # No provider-specific headers are added for Triton.
        return {}

    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        return {
            "inputs": [
                {
                    "name": "input_text",
                    "shape": [len(input)],
                    "datatype": "BYTES",
                    "data": input,
                }
            ]
        }
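
    # Example: a sketch of the request body built above, which follows
    # Triton's KServe v2 inference protocol. Assuming the caller passes
    # input=["hello", "world"], the payload would be:
    #
    #   {
    #       "inputs": [
    #           {
    #               "name": "input_text",
    #               "shape": [2],
    #               "datatype": "BYTES",
    #               "data": ["hello", "world"],
    #           }
    #       ]
    #   }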

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> EmbeddingResponse:
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise TritonError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        _embedding_output = []

        _outputs = raw_response_json["outputs"]
        for output in _outputs:
            _shape = output["shape"]
            _data = output["data"]
            # Triton returns each output as a flat list of floats; split it
            # into one embedding vector per input using the reported shape.
            _split_output_data = self.split_embedding_by_shape(_data, _shape)
            for idx, embedding in enumerate(_split_output_data):
                _embedding_output.append(
                    {
                        "object": "embedding",
                        "index": idx,
                        "embedding": embedding,
                    }
                )

        model_response.model = raw_response_json.get("model_name", "None")
        model_response.data = _embedding_output
        return model_response
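
    # Example: a sketch of the raw Triton response parsed above, assuming a
    # hypothetical model that returns 4-dimensional embeddings for two inputs
    # (the model and tensor names below are illustrative, not fixed by this
    # module):
    #
    #   {
    #       "model_name": "my_embedding_model",
    #       "outputs": [
    #           {
    #               "name": "output_embedding",
    #               "shape": [2, 4],
    #               "datatype": "FP32",
    #               "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
    #           }
    #       ]
    #   }
    #
    # split_embedding_by_shape turns the flat "data" list into
    # [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]].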

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return TritonError(
            message=error_message, status_code=status_code, headers=headers
        )

    @staticmethod
    def split_embedding_by_shape(
        data: List[float], shape: List[int]
    ) -> List[List[float]]:
        # Must be a staticmethod: it is invoked as
        # self.split_embedding_by_shape(...) above, and without the decorator
        # `self` would be passed in as `data`.
        if len(shape) != 2:
            raise ValueError("Shape must be of length 2.")
        embedding_size = shape[1]
        return [
            data[i * embedding_size : (i + 1) * embedding_size] for i in range(shape[0])
        ]
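
# Example: a minimal usage sketch of split_embedding_by_shape, assuming a
# flat buffer holding two 3-dimensional embeddings:
#
#   TritonEmbeddingConfig.split_embedding_by_shape(
#       data=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3]
#   )
#   # -> [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]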