import logging
import os
from typing import Any, List, Optional

import torch
from langchain_core.embeddings import Embeddings
from langchain_huggingface import (
    HuggingFaceEmbeddings,
    HuggingFaceEndpointEmbeddings,
)
from pydantic import BaseModel, Field


class CustomEmbedding(BaseModel, Embeddings):
    """Embedding backend with a hosted/primary model and a CPU fallback.

    Embedding vectors are optionally truncated to ``matryoshka_dim``
    components (Matryoshka-style dimensionality reduction). The model to
    load is taken from the ``HF_MODEL`` environment variable; ``IS_APP``
    selects query-time vs. index-time instruction prefixes.
    """

    # Populated in __init__; Optional so the model validates before assignment.
    # NOTE(review): the original annotated hosted_embedding as
    # HuggingFaceEndpointEmbeddings, but __init__ only ever assigns a
    # HuggingFaceEmbeddings (the endpoint class is deprecated).
    hosted_embedding: Optional[HuggingFaceEmbeddings] = Field(default=None)
    cpu_embedding: Optional[HuggingFaceEmbeddings] = Field(default=None)
    matryoshka_dim: int = Field(default=256)

    def get_instruction(self) -> str:
        """Return the instruction/prefix string for the configured model.

        Nomic models use task prefixes; other (BGE-style) models use a
        retrieval instruction at query time and no prefix at index time.

        Returns:
            str: The instruction string to pass as the encoding prompt.
        """
        # Default to "" so an unset HF_MODEL does not raise TypeError on `in`.
        model_name = os.getenv("HF_MODEL", "")
        is_query_time = os.getenv("IS_APP", "0") == "1"
        if "nomic" in model_name:
            # NOTE(review): original returned "query" here; nomic embed
            # models document the "search_query: " prefix — confirm with
            # the model card before relying on query-side embeddings.
            return "search_query: " if is_query_time else "search_document: "
        return (
            "Represent this sentence for searching relevant passages"
            if is_query_time
            else ""
        )

    def get_hf_embedd(self) -> HuggingFaceEmbeddings:
        """Build a HuggingFaceEmbeddings instance for the configured model.

        Returns:
            HuggingFaceEmbeddings: Normalized-embedding encoder on CUDA when
            available, otherwise CPU.
        """
        return HuggingFaceEmbeddings(
            model_name=os.getenv("HF_MODEL"),  # any HF model id works here
            model_kwargs={
                "device": "cuda" if torch.cuda.is_available() else "cpu",
                "trust_remote_code": True,
            },
            encode_kwargs={
                "normalize_embeddings": True,
                "prompt": self.get_instruction(),
            },
        )

    def __init__(self, matryoshka_dim: int = 256, **kwargs: Any):
        """Initialize both embedding backends.

        Args:
            matryoshka_dim (int): Number of leading embedding components to
                keep; 0/None disables truncation.
            **kwargs: Forwarded to the pydantic ``BaseModel`` constructor.
        """
        super().__init__(**kwargs)
        self.matryoshka_dim = matryoshka_dim
        logging.info(
            "CUDA is %savailable", "" if torch.cuda.is_available() else "not "
        )
        # HuggingFaceEndpointEmbeddings is deprecated, so the "hosted"
        # backend is the same local HF model as the CPU fallback; both
        # CUDA branches of the original collapsed to this single path.
        self.hosted_embedding = self.get_hf_embedd()
        self.cpu_embedding = self.hosted_embedding

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents, falling back to CPU on failure.

        Args:
            texts (List[str]): Document texts to embed.

        Returns:
            List[List[float]]: One vector per input, truncated to
            ``matryoshka_dim`` components when that is truthy.
        """
        try:
            embed = self.hosted_embedding.embed_documents(texts)
        except Exception as e:
            logging.warning(
                "Issue with batch hosted embedding, moving to CPU: %s", e
            )
            embed = self.cpu_embedding.embed_documents(texts)
        return (
            [e[: self.matryoshka_dim] for e in embed]
            if self.matryoshka_dim
            else embed
        )

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query, falling back to CPU on failure.

        Args:
            text (str): The query text to embed.

        Returns:
            List[float]: The query vector, truncated to ``matryoshka_dim``
            components when that is truthy.
        """
        try:
            # NOTE(review): logs raw query text — consider redacting if
            # queries may contain sensitive data.
            logging.info(text)
            embed = self.hosted_embedding.embed_query(text)
        except Exception as e:
            logging.warning("Issue with hosted embedding, moving to CPU: %s", e)
            embed = self.cpu_embedding.embed_query(text)
        return embed[: self.matryoshka_dim] if self.matryoshka_dim else embed