Spaces:
Runtime error
Runtime error
from typing import Optional, Dict, List, Tuple, Any | |
from phi.embedder.base import Embedder | |
from phi.utils.log import logger | |
try: | |
from mistralai.client import MistralClient | |
from mistralai.models.embeddings import EmbeddingResponse | |
except ImportError: | |
raise ImportError("`openai` not installed") | |
class MistralEmbedder(Embedder): | |
model: str = "mistral-embed" | |
dimensions: int = 1024 | |
# -*- Request parameters | |
request_params: Optional[Dict[str, Any]] = None | |
# -*- Client parameters | |
api_key: Optional[str] = None | |
endpoint: Optional[str] = None | |
max_retries: Optional[int] = None | |
timeout: Optional[int] = None | |
client_params: Optional[Dict[str, Any]] = None | |
# -*- Provide the MistralClient manually | |
mistral_client: Optional[MistralClient] = None | |
def client(self) -> MistralClient: | |
if self.mistral_client: | |
return self.mistral_client | |
_client_params: Dict[str, Any] = {} | |
if self.api_key: | |
_client_params["api_key"] = self.api_key | |
if self.endpoint: | |
_client_params["endpoint"] = self.endpoint | |
if self.max_retries: | |
_client_params["max_retries"] = self.max_retries | |
if self.timeout: | |
_client_params["timeout"] = self.timeout | |
if self.client_params: | |
_client_params.update(self.client_params) | |
return MistralClient(**_client_params) | |
def _response(self, text: str) -> EmbeddingResponse: | |
_request_params: Dict[str, Any] = { | |
"input": text, | |
"model": self.model, | |
} | |
if self.request_params: | |
_request_params.update(self.request_params) | |
return self.client.embeddings(**_request_params) | |
def get_embedding(self, text: str) -> List[float]: | |
response: EmbeddingResponse = self._response(text=text) | |
try: | |
return response.data[0].embedding | |
except Exception as e: | |
logger.warning(e) | |
return [] | |
def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict]]: | |
response: EmbeddingResponse = self._response(text=text) | |
embedding = response.data[0].embedding | |
usage = response.usage | |
return embedding, usage.model_dump() | |