Spaces:
Sleeping
Sleeping
File size: 4,509 Bytes
15aea1e 1c0d4cb 15aea1e 1c0d4cb 15aea1e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import logging
import os
from typing import Any, List
import torch
from langchain_core.embeddings import Embeddings
from langchain_huggingface import (
HuggingFaceEmbeddings,
HuggingFaceEndpointEmbeddings,
)
from pydantic import BaseModel, Field
class CustomEmbedding(BaseModel, Embeddings):
    """
    Embedding backend with automatic CPU fallback.

    Delegates to ``hosted_embedding`` first and falls back to
    ``cpu_embedding`` when the primary backend raises. Resulting vectors
    are truncated to the leading ``matryoshka_dim`` components when that
    value is truthy (Matryoshka-style dimensionality reduction).
    """

    # Populated in __init__; declared with None defaults so pydantic does
    # not require them at construction time.
    hosted_embedding: HuggingFaceEndpointEmbeddings = Field(
        default_factory=lambda: None
    )
    cpu_embedding: HuggingFaceEmbeddings = Field(default_factory=lambda: None)
    # Number of leading dimensions kept from each vector; 0/None disables
    # truncation.
    matryoshka_dim: int = Field(default=256)

    def get_instruction(self) -> str:
        """
        Build the instruction/prompt prefix for the embedding model.

        The prefix depends on the model family (``HF_MODEL``) and on
        whether we are embedding queries (``IS_APP=1``) or documents.

        Returns:
            str: The instruction string (possibly empty).
        """
        # os.getenv returns None when HF_MODEL is unset; default to "" so
        # the membership test below cannot raise TypeError.
        model_name = os.getenv("HF_MODEL") or ""
        is_app = os.getenv("IS_APP", "0") == "1"
        if "nomic" in model_name:
            # NOTE(review): nomic model cards use "search_query: " as the
            # query prefix; "query" is kept as-is to preserve existing
            # behavior — confirm against the model card.
            return "query" if is_app else "search_document: "
        return (
            "Represent this sentence for searching relevant passages"
            if is_app
            else ""
        )

    def get_hf_embedd(self) -> HuggingFaceEmbeddings:
        """
        Build a local HuggingFaceEmbeddings instance for ``HF_MODEL``.

        Uses CUDA when available, normalizes embeddings, and applies the
        instruction prefix from :meth:`get_instruction`.

        Returns:
            HuggingFaceEmbeddings: The configured embedding model.
        """
        return HuggingFaceEmbeddings(
            model_name=os.getenv("HF_MODEL"),  # You can replace with any HF model
            model_kwargs={
                "device": "cuda" if torch.cuda.is_available() else "cpu",
                "trust_remote_code": True,
            },
            encode_kwargs={
                "normalize_embeddings": True,
                "prompt": self.get_instruction(),
            },
        )

    def __init__(self, matryoshka_dim: int = 256, **kwargs: Any):
        """
        Initialize both embedding backends.

        Args:
            matryoshka_dim (int): Number of leading embedding dimensions
                to keep; a falsy value disables truncation.
            **kwargs: Forwarded to the pydantic BaseModel constructor.
        """
        super().__init__(**kwargs)
        self.matryoshka_dim = matryoshka_dim
        if torch.cuda.is_available():
            logging.info("CUDA is available")
        else:
            logging.info("CUDA is not available")
        # HuggingFaceEndpointEmbeddings is deprecated, so the same local
        # model serves as both the primary and the fallback backend.
        self.hosted_embedding = self.get_hf_embedd()
        self.cpu_embedding = self.hosted_embedding

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Embed a batch of documents, falling back to CPU on failure.

        Args:
            texts (List[str]): Document texts to embed.

        Returns:
            List[List[float]]: One vector per input text, truncated to
            ``matryoshka_dim`` components when that value is truthy.
        """
        try:
            embed = self.hosted_embedding.embed_documents(texts)
        except Exception as e:
            # Best-effort fallback: any failure of the primary backend
            # routes the whole batch to the CPU model.
            logging.warning(
                "Issue with batch hosted embedding, moving to CPU: %s", e
            )
            embed = self.cpu_embedding.embed_documents(texts)
        if not self.matryoshka_dim:
            return embed
        return [vec[: self.matryoshka_dim] for vec in embed]

    def embed_query(self, text: str) -> List[float]:
        """
        Embed a single query, falling back to CPU on failure.

        Args:
            text (str): The query text to embed.

        Returns:
            List[float]: The embedded query vector, truncated to
            ``matryoshka_dim`` components when that value is truthy.
        """
        try:
            embed = self.hosted_embedding.embed_query(text)
        except Exception as e:
            # Raw query text is deliberately NOT logged (avoids leaking
            # user input into logs); only the failure reason is recorded.
            logging.warning("Issue with hosted embedding, moving to CPU: %s", e)
            embed = self.cpu_embedding.embed_query(text)
        return embed[: self.matryoshka_dim] if self.matryoshka_dim else embed
|