File size: 4,509 Bytes
15aea1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c0d4cb
 
 
 
15aea1e
 
 
 
 
 
 
 
1c0d4cb
 
15aea1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import logging
import os
from typing import Any, List

import torch
from langchain_core.embeddings import Embeddings
from langchain_huggingface import (
    HuggingFaceEmbeddings,
    HuggingFaceEndpointEmbeddings,
)
from pydantic import BaseModel, Field


class CustomEmbedding(BaseModel, Embeddings):
    """
    Embedding wrapper that embeds with a primary backend and falls back to a
    CPU backend when the primary raises, optionally truncating each vector to
    a Matryoshka dimension.
    """

    # NOTE(review): fields are annotated with concrete classes but default to
    # None; pydantic does not validate defaults, so this passes, but the
    # attributes may be None until __init__ populates them.
    hosted_embedding: HuggingFaceEndpointEmbeddings = Field(
        default_factory=lambda: None
    )
    cpu_embedding: HuggingFaceEmbeddings = Field(default_factory=lambda: None)
    # Keep only the first `matryoshka_dim` components of each vector; 0 disables.
    matryoshka_dim: int = Field(default=256)

    def get_instruction(self) -> str:
        """
        Build the instruction/prompt prefix for the embedding model from the
        HF_MODEL and IS_APP environment variables.

        Returns:
            str: The instruction string ("" when no prefix applies).
        """
        # Default to "" so an unset HF_MODEL cannot raise
        # `TypeError: argument of type 'NoneType' is not iterable`.
        model_name = os.getenv("HF_MODEL", "")
        is_app = os.getenv("IS_APP", "0") == "1"
        if "nomic" in model_name:
            # NOTE(review): asymmetric prefixes — "query" (no colon) vs
            # "search_document: "; nomic's documented query prefix is
            # "search_query: ". Preserved as-is; confirm intended.
            return "query" if is_app else "search_document: "
        return (
            "Represent this sentence for searching relevant passages"
            if is_app
            else ""
        )

    def get_hf_embedd(self) -> HuggingFaceEmbeddings:
        """
        Build a local HuggingFaceEmbeddings for the HF_MODEL model, on CUDA
        when available, with normalized outputs and the instruction prompt.

        Returns:
            HuggingFaceEmbeddings: The initialized embedding backend.
        """
        return HuggingFaceEmbeddings(
            model_name=os.getenv("HF_MODEL"),  # any HF model name works here
            model_kwargs={
                "device": "cuda" if torch.cuda.is_available() else "cpu",
                "trust_remote_code": True,
            },
            encode_kwargs={
                "normalize_embeddings": True,
                "prompt": self.get_instruction(),
            },
        )

    def __init__(self, matryoshka_dim: int = 256, **kwargs: Any):
        """
        Initialize both embedding backends.

        Args:
            matryoshka_dim (int): Number of leading dimensions to keep from
                each embedding vector (0 disables truncation).
            **kwargs: Forwarded to the pydantic BaseModel constructor.
        """
        super().__init__(**kwargs)
        self.matryoshka_dim = matryoshka_dim
        # Both CUDA branches previously built the exact same objects (the
        # deprecated HuggingFaceEndpointEmbeddings path was commented out),
        # so only the log message depended on the check.
        if torch.cuda.is_available():
            logging.info("CUDA is available")
        else:
            logging.info("CUDA is not available")
        self.hosted_embedding = self.get_hf_embedd()
        # Same object serves as the "CPU" fallback until a real hosted
        # endpoint backend is reinstated.
        self.cpu_embedding = self.hosted_embedding

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Embed a batch of documents, falling back to the CPU backend when the
        primary backend raises.

        Args:
            texts (List[str]): Document texts to embed.

        Returns:
            List[List[float]]: One vector per input text, truncated to
            ``matryoshka_dim`` components when that is non-zero.
        """
        try:
            embed = self.hosted_embedding.embed_documents(texts)
        except Exception as e:
            # Best-effort fallback: any primary-backend failure (network,
            # OOM, ...) reroutes the whole batch to the CPU model.
            logging.warning("Issue with batch hosted embedding, moving to CPU: %s", e)
            embed = self.cpu_embedding.embed_documents(texts)
        if self.matryoshka_dim:
            return [vec[: self.matryoshka_dim] for vec in embed]
        return embed

    def embed_query(self, text: str) -> List[float]:
        """
        Embed a single query, falling back to the CPU backend when the
        primary backend raises.

        Args:
            text (str): The query text to embed.

        Returns:
            List[float]: The embedded query vector, truncated to
            ``matryoshka_dim`` components when that is non-zero.
        """
        try:
            embed = self.hosted_embedding.embed_query(text)
        except Exception as e:
            logging.warning("Issue with hosted embedding, moving to CPU: %s", e)
            embed = self.cpu_embedding.embed_query(text)
        return embed[: self.matryoshka_dim] if self.matryoshka_dim else embed