"""MistralAI embeddings file."""
from typing import Any, List, Optional
from llama_index.core.base.embeddings.base import (
DEFAULT_EMBED_BATCH_SIZE,
BaseEmbedding,
)
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.base.llms.generic_utils import get_from_param_or_env
from mistralai.async_client import MistralAsyncClient
from mistralai.client import MistralClient
class MistralAIEmbedding(BaseEmbedding):
"""Class for MistralAI embeddings.
Args:
model_name (str): Model for embedding.
Defaults to "mistral-embed".
api_key (Optional[str]): API key to access the model. Defaults to None.
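
    Example:
        A minimal usage sketch; assumes this module is installed as the
        ``llama-index-embeddings-mistralai`` package, that the
        ``MISTRAL_API_KEY`` environment variable is set, and that
        ``get_query_embedding`` is the public wrapper inherited from
        ``BaseEmbedding``::

            from llama_index.embeddings.mistralai import MistralAIEmbedding

            embed_model = MistralAIEmbedding(model_name="mistral-embed")
            vector = embed_model.get_query_embedding("Hello, world!")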
"""
# Instance variables initialized via Pydantic's mechanism
_mistralai_client: Any = PrivateAttr()
_mistralai_async_client: Any = PrivateAttr()
def __init__(
self,
model_name: str = "mistral-embed",
api_key: Optional[str] = None,
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
callback_manager: Optional[CallbackManager] = None,
**kwargs: Any,
):
        # Resolve the API key from the explicit argument or the environment.
        api_key = get_from_param_or_env("api_key", api_key, "MISTRAL_API_KEY", "")

        if not api_key:
            raise ValueError(
                "You must provide an API key to use MistralAI. "
                "Either pass it in as an argument or set the "
                "`MISTRAL_API_KEY` environment variable."
            )

        # The sync and async clients share the same resolved API key.
        self._mistralai_client = MistralClient(api_key=api_key)
        self._mistralai_async_client = MistralAsyncClient(api_key=api_key)

        super().__init__(
model_name=model_name,
embed_batch_size=embed_batch_size,
callback_manager=callback_manager,
**kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        return "MistralAIEmbedding"

    def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
return (
self._mistralai_client.embeddings(model=self.model_name, input=[query])
.data[0]
.embedding
        )

    async def _aget_query_embedding(self, query: str) -> List[float]:
"""The asynchronous version of _get_query_embedding."""
return (
(
await self._mistralai_async_client.embeddings(
model=self.model_name, input=[query]
)
)
.data[0]
.embedding
        )

    def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
return (
self._mistralai_client.embeddings(model=self.model_name, input=[text])
.data[0]
.embedding
        )

    async def _aget_text_embedding(self, text: str) -> List[float]:
"""Asynchronously get text embedding."""
return (
(
await self._mistralai_async_client.embeddings(
model=self.model_name, input=[text]
)
)
.data[0]
.embedding
        )

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get text embeddings."""
embedding_response = self._mistralai_client.embeddings(
model=self.model_name, input=texts
).data
        return [embed.embedding for embed in embedding_response]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Asynchronously get text embeddings."""
embedding_response = await self._mistralai_async_client.embeddings(
model=self.model_name, input=texts
)
return [embed.embedding for embed in embedding_response.data]
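

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library module proper. Assumes
    # the `MISTRAL_API_KEY` environment variable is set, and that
    # `get_query_embedding`, `get_text_embedding_batch`, and
    # `aget_query_embedding` are the public wrappers inherited from
    # `BaseEmbedding`.
    import asyncio

    embed_model = MistralAIEmbedding(model_name="mistral-embed")

    # Single query embedding (sync path).
    query_vec = embed_model.get_query_embedding("What is an embedding?")
    print(f"query vector dimension: {len(query_vec)}")

    # Batched text embeddings (sync path); batching over `embed_batch_size`
    # is handled by the `BaseEmbedding` wrapper.
    text_vecs = embed_model.get_text_embedding_batch(
        ["MistralAI offers an embeddings endpoint.", "Vectors power search."]
    )
    print(f"embedded {len(text_vecs)} texts")

    # Async variant of the query embedding call.
    async_vec = asyncio.run(embed_model.aget_query_embedding("What is an embedding?"))
    print(f"async vector dimension: {len(async_vec)}")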