AmmarFahmy
adding all files
105b369
from typing import Optional, Dict, List, Tuple, Any
from phi.embedder.base import Embedder
from phi.utils.log import logger
try:
from mistralai.client import MistralClient
from mistralai.models.embeddings import EmbeddingResponse
except ImportError:
raise ImportError("`openai` not installed")
class MistralEmbedder(Embedder):
model: str = "mistral-embed"
dimensions: int = 1024
# -*- Request parameters
request_params: Optional[Dict[str, Any]] = None
# -*- Client parameters
api_key: Optional[str] = None
endpoint: Optional[str] = None
max_retries: Optional[int] = None
timeout: Optional[int] = None
client_params: Optional[Dict[str, Any]] = None
# -*- Provide the MistralClient manually
mistral_client: Optional[MistralClient] = None
@property
def client(self) -> MistralClient:
if self.mistral_client:
return self.mistral_client
_client_params: Dict[str, Any] = {}
if self.api_key:
_client_params["api_key"] = self.api_key
if self.endpoint:
_client_params["endpoint"] = self.endpoint
if self.max_retries:
_client_params["max_retries"] = self.max_retries
if self.timeout:
_client_params["timeout"] = self.timeout
if self.client_params:
_client_params.update(self.client_params)
return MistralClient(**_client_params)
def _response(self, text: str) -> EmbeddingResponse:
_request_params: Dict[str, Any] = {
"input": text,
"model": self.model,
}
if self.request_params:
_request_params.update(self.request_params)
return self.client.embeddings(**_request_params)
def get_embedding(self, text: str) -> List[float]:
response: EmbeddingResponse = self._response(text=text)
try:
return response.data[0].embedding
except Exception as e:
logger.warning(e)
return []
def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict]]:
response: EmbeddingResponse = self._response(text=text)
embedding = response.data[0].embedding
usage = response.usage
return embedding, usage.model_dump()