alexneakameni committed on
Commit
5e2d740
·
verified ·
1 Parent(s): 50d55b1

Compute embedding using local resource directly

Browse files
Files changed (1) hide show
  1. src/utilities/embedding.py +10 -32
src/utilities/embedding.py CHANGED
@@ -4,21 +4,14 @@ from typing import Any, List
4
 
5
  import torch
6
  from langchain_core.embeddings import Embeddings
7
- from langchain_huggingface import (
8
- HuggingFaceEmbeddings,
9
- HuggingFaceEndpointEmbeddings,
10
- )
11
  from pydantic import BaseModel, Field
12
 
13
-
14
  class CustomEmbedding(BaseModel, Embeddings):
15
  """
16
- Custom embedding class that supports both hosted and CPU embeddings.
17
  """
18
 
19
- hosted_embedding: HuggingFaceEndpointEmbeddings = Field(
20
- default_factory=lambda: None
21
- )
22
  cpu_embedding: HuggingFaceEmbeddings = Field(default_factory=lambda: None)
23
  matryoshka_dim: int = Field(default=256)
24
 
@@ -71,23 +64,8 @@ class CustomEmbedding(BaseModel, Embeddings):
71
  super().__init__(**kwargs)
72
  query_instruction = self.get_instruction()
73
  self.matryoshka_dim = matryoshka_dim
74
- if torch.cuda.is_available():
75
- logging.info("CUDA is available")
76
- self.hosted_embedding = self.get_hf_embedd()
77
- self.cpu_embedding = self.hosted_embedding
78
- else:
79
- logging.info("CUDA is not available")
80
- self.hosted_embedding = HuggingFaceEndpointEmbeddings(
81
- model=os.getenv("HF_MODEL"),
82
- model_kwargs={
83
- "encode_kwargs": {
84
- "normalize_embeddings": True,
85
- "prompt": query_instruction,
86
- }
87
- },
88
- huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
89
- )
90
- self.cpu_embedding = self.get_hf_embedd()
91
 
92
  def embed_documents(self, texts: List[str]) -> List[List[float]]:
93
  """
@@ -100,10 +78,10 @@ class CustomEmbedding(BaseModel, Embeddings):
100
  List[List[float]]: List of embedded document vectors.
101
  """
102
  try:
103
- embed = self.hosted_embedding.embed_documents(texts)
104
- except Exception as e:
105
- logging.warning(f"Issue with batch hosted embedding, moving to CPU: {e}")
106
  embed = self.cpu_embedding.embed_documents(texts)
 
 
 
107
  return (
108
  [e[: self.matryoshka_dim] for e in embed] if self.matryoshka_dim else embed
109
  )
@@ -119,9 +97,9 @@ class CustomEmbedding(BaseModel, Embeddings):
119
  List[float]: The embedded query vector.
120
  """
121
  try:
122
- embed = self.hosted_embedding.embed_query(text)
123
- except Exception as e:
124
- logging.warning(f"Issue with hosted embedding, moving to CPU: {e}")
125
  embed = self.cpu_embedding.embed_query(text)
 
 
 
126
  logging.warning(text)
127
  return embed[: self.matryoshka_dim] if self.matryoshka_dim else embed
 
4
 
5
  import torch
6
  from langchain_core.embeddings import Embeddings
7
+ from langchain_huggingface import HuggingFaceEmbeddings
 
 
 
8
  from pydantic import BaseModel, Field
9
 
 
10
  class CustomEmbedding(BaseModel, Embeddings):
11
  """
12
+ Custom embedding class that supports CPU embeddings.
13
  """
14
 
 
 
 
15
  cpu_embedding: HuggingFaceEmbeddings = Field(default_factory=lambda: None)
16
  matryoshka_dim: int = Field(default=256)
17
 
 
64
  super().__init__(**kwargs)
65
  query_instruction = self.get_instruction()
66
  self.matryoshka_dim = matryoshka_dim
67
+ logging.info("Initializing CPU embedding")
68
+ self.cpu_embedding = self.get_hf_embedd()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  def embed_documents(self, texts: List[str]) -> List[List[float]]:
71
  """
 
78
  List[List[float]]: List of embedded document vectors.
79
  """
80
  try:
 
 
 
81
  embed = self.cpu_embedding.embed_documents(texts)
82
+ except Exception as e:
83
+ logging.warning(f"Issue with CPU embedding: {e}")
84
+ embed = []
85
  return (
86
  [e[: self.matryoshka_dim] for e in embed] if self.matryoshka_dim else embed
87
  )
 
97
  List[float]: The embedded query vector.
98
  """
99
  try:
 
 
 
100
  embed = self.cpu_embedding.embed_query(text)
101
+ except Exception as e:
102
+ logging.warning(f"Issue with CPU embedding: {e}")
103
+ embed = []
104
  logging.warning(text)
105
  return embed[: self.matryoshka_dim] if self.matryoshka_dim else embed