# AmmarFahmy
# adding all files
# 105b369
from typing import Optional, Dict, List, Tuple, Any
from typing_extensions import Literal
from phi.embedder.base import Embedder
from phi.utils.log import logger
try:
from openai import OpenAI as OpenAIClient
from openai.types.create_embedding_response import CreateEmbeddingResponse
except ImportError:
raise ImportError("`openai` not installed")
class OpenAIEmbedder(Embedder):
    """Embedder that obtains text embeddings from the OpenAI embeddings endpoint.

    Requests are issued through an ``openai.OpenAI`` client. A caller may
    inject a ready-made client via ``openai_client``; otherwise one is built
    from ``api_key``, ``organization``, ``base_url`` and ``client_params``
    each time the ``client`` property is read.
    """

    # Model name sent as the ``model`` field of every embedding request.
    model: str = "text-embedding-ada-002"
    # Requested vector size. Only forwarded to the API when ``model`` starts
    # with "text-embedding-3" (see ``_response``).
    dimensions: int = 1536
    # Wire format requested for the returned embedding.
    encoding_format: Literal["float", "base64"] = "float"
    # Optional end-user identifier forwarded as the ``user`` request field.
    user: Optional[str] = None
    # Client construction options; falsy values are not passed to the client.
    api_key: Optional[str] = None
    organization: Optional[str] = None
    base_url: Optional[str] = None
    # Extra keyword arguments merged last into the request, so they override
    # the fields assembled by ``_response``.
    request_params: Optional[Dict[str, Any]] = None
    # Extra keyword arguments merged last into the client constructor call.
    client_params: Optional[Dict[str, Any]] = None
    # Pre-built client; when set it is used as-is and the options above are ignored.
    openai_client: Optional[OpenAIClient] = None

    @property
    def client(self) -> OpenAIClient:
        """Return the OpenAI client used for embedding requests.

        Returns:
            ``openai_client`` when one was supplied, otherwise a new client
            built from the configured options. The new client is not stored,
            so each access without ``openai_client`` constructs another one.
        """
        if self.openai_client:
            return self.openai_client

        _client_params: Dict[str, Any] = {}
        if self.api_key:
            _client_params["api_key"] = self.api_key
        if self.organization:
            _client_params["organization"] = self.organization
        if self.base_url:
            _client_params["base_url"] = self.base_url
        if self.client_params:
            # Applied last so explicit client_params win over the fields above.
            _client_params.update(self.client_params)
        return OpenAIClient(**_client_params)

    def _response(self, text: str) -> CreateEmbeddingResponse:
        """Send one embedding request for ``text`` and return the raw response.

        Args:
            text: The input passed as the ``input`` request field.

        Returns:
            The response object returned by ``client.embeddings.create``.
        """
        _request_params: Dict[str, Any] = {
            "input": text,
            "model": self.model,
            "encoding_format": self.encoding_format,
        }
        if self.user is not None:
            _request_params["user"] = self.user
        # ``dimensions`` is only sent for the "text-embedding-3" model family;
        # presumably other models reject the parameter -- confirm against the API docs.
        if self.model.startswith("text-embedding-3"):
            _request_params["dimensions"] = self.dimensions
        if self.request_params:
            # Applied last so explicit request_params win over the fields above.
            _request_params.update(self.request_params)
        return self.client.embeddings.create(**_request_params)

    def get_embedding(self, text: str) -> List[float]:
        """Return the embedding vector for ``text``.

        Errors raised by the request itself propagate to the caller. Only a
        failure to read the first embedding out of the response is handled:
        it is logged as a warning and an empty list is returned.

        Args:
            text: The text to embed.

        Returns:
            The first embedding in the response, or ``[]`` if it could not be read.
        """
        response: CreateEmbeddingResponse = self._response(text=text)
        try:
            return response.data[0].embedding
        except Exception as e:
            logger.warning(e)
            return []

    def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict]]:
        """Return the embedding vector for ``text`` together with usage data.

        Args:
            text: The text to embed.

        Returns:
            A ``(embedding, usage)`` tuple where ``usage`` is the response's
            usage object dumped to a dict, or ``None`` when the response
            carries no usage information.
        """
        response: CreateEmbeddingResponse = self._response(text=text)

        embedding = response.data[0].embedding
        usage = response.usage
        # The declared return type allows a missing usage; guard it so a
        # response without usage yields None instead of an AttributeError.
        return embedding, (usage.model_dump() if usage is not None else None)