# medivocate / src/utilities/embedding.py
# Last commit by alexneakameni: "Using HF Embedding directly" (1c0d4cb, verified)
import logging
import os
from typing import Any, List, Optional, Union

import torch
from langchain_core.embeddings import Embeddings
from langchain_huggingface import (
    HuggingFaceEmbeddings,
    HuggingFaceEndpointEmbeddings,
)
from pydantic import BaseModel, Field
class CustomEmbedding(BaseModel, Embeddings):
    """
    LangChain ``Embeddings`` implementation backed by a local HuggingFace model.

    The model id is read from the ``HF_MODEL`` environment variable, and the
    prompt prepended to every input is chosen by ``get_instruction`` (driven by
    ``HF_MODEL`` and ``IS_APP``). Returned vectors are truncated to their first
    ``matryoshka_dim`` components when that value is truthy.

    ``hosted_embedding`` is tried first; if it raises, the call is retried with
    ``cpu_embedding``. Both are currently the same local model instance, so
    the fallback amounts to a single retry of the same model.
    """

    # Primary backend; always (re)assigned in __init__.
    hosted_embedding: Optional[
        Union[HuggingFaceEmbeddings, HuggingFaceEndpointEmbeddings]
    ] = Field(default=None)
    # Fallback backend used when the primary call raises; assigned in __init__.
    cpu_embedding: Optional[HuggingFaceEmbeddings] = Field(default=None)
    # Number of leading vector components to keep; None or 0 disables truncation.
    matryoshka_dim: Optional[int] = Field(default=256)

    def get_instruction(self) -> str:
        """
        Return the prompt prepended to every text before embedding.

        Behaviour depends on two environment variables:

        * ``HF_MODEL``: if the model id contains ``"nomic"``, nomic task
          prefixes are used; otherwise the generic
          "Represent this sentence for searching relevant passages" retrieval
          instruction is used for queries.
        * ``IS_APP``: ``"1"`` means the process embeds user queries (the app);
          any other value means documents are being embedded for indexing.

        The prompt is handed to sentence-transformers as a literal prefix, so it
        must carry its own separator (``": "``).

        Returns:
            str: The prompt to prepend (may be empty).
        """
        # `or ""` keeps an unset HF_MODEL from raising TypeError on `in`.
        model_name = os.getenv("HF_MODEL") or ""
        is_app = os.getenv("IS_APP", "0") == "1"
        if "nomic" in model_name:
            return "search_query: " if is_app else "search_document: "
        return (
            "Represent this sentence for searching relevant passages: "
            if is_app
            else ""
        )

    def get_hf_embedd(self) -> HuggingFaceEmbeddings:
        """
        Build the local ``HuggingFaceEmbeddings`` model named by ``HF_MODEL``.

        The model runs on CUDA when available, otherwise on CPU. It is
        configured to return normalised embeddings, and every input is prefixed
        with ``get_instruction()``.

        Returns:
            HuggingFaceEmbeddings: The initialised embedding model.

        Raises:
            ValueError: If ``HF_MODEL`` is unset or empty.
        """
        model_name = os.getenv("HF_MODEL")
        if not model_name:
            # Fail early with an actionable message instead of an opaque
            # error from deep inside model loading.
            raise ValueError(
                "The HF_MODEL environment variable must be set to a "
                "HuggingFace model id."
            )
        return HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs={
                "device": "cuda" if torch.cuda.is_available() else "cpu",
                # Needed for models that ship custom modelling code (e.g. nomic).
                "trust_remote_code": True,
            },
            encode_kwargs={
                "normalize_embeddings": True,
                "prompt": self.get_instruction(),
            },
        )

    def __init__(self, matryoshka_dim: Optional[int] = 256, **kwargs: Any):
        """
        Initialise the wrapper and load the embedding model.

        Args:
            matryoshka_dim (Optional[int]): Number of leading components kept
                from each embedding; ``None`` or ``0`` keeps the full vector.
            **kwargs: Additional pydantic field values. Any ``hosted_embedding``
                or ``cpu_embedding`` passed here is replaced by the locally
                loaded model.

        Raises:
            ValueError: If ``HF_MODEL`` is unset or empty.
        """
        # Route matryoshka_dim through pydantic so it is validated like any field.
        super().__init__(matryoshka_dim=matryoshka_dim, **kwargs)
        if torch.cuda.is_available():
            logging.info("CUDA is available")
        else:
            logging.info("CUDA is not available")
        # The HuggingFaceEndpointEmbeddings (inference endpoint) client is no
        # longer used; the model is loaded locally, so both backends share one
        # instance.
        self.hosted_embedding = self.get_hf_embedd()
        self.cpu_embedding = self.hosted_embedding

    def _truncate(self, vector: List[float]) -> List[float]:
        """Keep the first ``matryoshka_dim`` components (all of them if falsy)."""
        # NOTE(review): vectors are normalised before slicing, so truncated
        # vectors are no longer unit-length. Re-normalising here would change
        # already-indexed vectors, so it is deliberately left as-is.
        return vector[: self.matryoshka_dim] if self.matryoshka_dim else vector

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Embed a batch of documents.

        Args:
            texts (List[str]): Document texts to embed.

        Returns:
            List[List[float]]: One (possibly truncated) vector per input text.
        """
        try:
            vectors = self.hosted_embedding.embed_documents(texts)
        except Exception as e:
            # Deliberate best-effort fallback: log and retry with the secondary
            # backend; if that also fails, the exception propagates.
            logging.warning(
                "Issue with batch hosted embedding, moving to CPU: %s", e
            )
            vectors = self.cpu_embedding.embed_documents(texts)
        return [self._truncate(vector) for vector in vectors]

    def embed_query(self, text: str) -> List[float]:
        """
        Embed a single query.

        Args:
            text (str): The query text to embed.

        Returns:
            List[float]: The (possibly truncated) query vector.
        """
        # Queries are user input: keep them out of INFO/WARNING logs.
        logging.debug("Embedding query: %s", text)
        try:
            vector = self.hosted_embedding.embed_query(text)
        except Exception as e:
            # Same best-effort retry as embed_documents.
            logging.warning("Issue with hosted embedding, moving to CPU: %s", e)
            vector = self.cpu_embedding.embed_query(text)
        return self._truncate(vector)