"""MistralAI embeddings file.""" from typing import Any, List, Optional from llama_index.core.base.embeddings.base import ( DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding, ) from llama_index.core.bridge.pydantic import PrivateAttr from llama_index.core.callbacks.base import CallbackManager from llama_index.core.base.llms.generic_utils import get_from_param_or_env from mistralai.async_client import MistralAsyncClient from mistralai.client import MistralClient class MistralAIEmbedding(BaseEmbedding): """Class for MistralAI embeddings. Args: model_name (str): Model for embedding. Defaults to "mistral-embed". api_key (Optional[str]): API key to access the model. Defaults to None. """ # Instance variables initialized via Pydantic's mechanism _mistralai_client: Any = PrivateAttr() _mistralai_async_client: Any = PrivateAttr() def __init__( self, model_name: str = "mistral-embed", api_key: Optional[str] = None, embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE, callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ): api_key = get_from_param_or_env("api_key", api_key, "MISTRAL_API_KEY", "") if not api_key: raise ValueError( "You must provide an API key to use mistralai. " "You can either pass it in as an argument or set it `MISTRAL_API_KEY`." ) self._mistralai_client = MistralClient(api_key=api_key) self._mistralai_async_client = MistralAsyncClient(api_key=api_key) super().__init__( model_name=model_name, embed_batch_size=embed_batch_size, callback_manager=callback_manager, **kwargs, ) @classmethod def class_name(cls) -> str: return "MistralAIEmbedding" def _get_query_embedding(self, query: str) -> List[float]: """Get query embedding.""" return ( self._mistralai_client.embeddings(model=self.model_name, input=[query]) .data[0] .embedding ) async def _aget_query_embedding(self, query: str) -> List[float]: """The asynchronous version of _get_query_embedding.""" return ( ( await self._mistralai_async_client.embeddings( model=self.model_name, input=[query] ) ) .data[0] .embedding ) def _get_text_embedding(self, text: str) -> List[float]: """Get text embedding.""" return ( self._mistralai_client.embeddings(model=self.model_name, input=[text]) .data[0] .embedding ) async def _aget_text_embedding(self, text: str) -> List[float]: """Asynchronously get text embedding.""" return ( ( await self._mistralai_async_client.embeddings( model=self.model_name, input=[text] ) ) .data[0] .embedding ) def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: """Get text embeddings.""" embedding_response = self._mistralai_client.embeddings( model=self.model_name, input=texts ).data return [embed.embedding for embed in embedding_response] async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]: """Asynchronously get text embeddings.""" embedding_response = await self._mistralai_async_client.embeddings( model=self.model_name, input=texts ) return [embed.embedding for embed in embedding_response.data]