HoangNB committed
Commit 133f1d4 · 1 Parent(s): c62049f

Add embedding service and preprocessor; integrate with Gradio interface

- Introduced `EmbeddingService` for generating text embeddings using sentence-transformers.
- Added `TextPreprocessor` for cleaning and tokenizing input text.
- Created a new endpoint for obtaining embeddings and integrated it into the Gradio interface.
- Updated `requirements.txt` with the new dependencies (`gradio`, `sentence-transformers`, `python-dotenv`, `transformers`).
- Added configuration settings in `config.py` for model and server parameters.

app.py CHANGED
```diff
@@ -1,11 +1,17 @@
 import gradio as gr
 from huggingface_hub import InferenceClient
+from app.services.embedding_service import EmbeddingService
+from app.config import EMBEDDING_MODEL  # Import from config
+from app.services.preprocessor import TextPreprocessor
 
 """
 For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 """
 client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
+# Initialize EmbeddingService
+embedding_service = EmbeddingService(model_name=EMBEDDING_MODEL, preprocessor=TextPreprocessor())
+
 
 def respond(
     message,
@@ -39,6 +45,26 @@ def respond(
         response += token
         yield response
 
+def get_embedding(text: str) -> list[float]:
+    """
+    Endpoint to get the embedding of a text.
+    """
+    try:
+        return embedding_service.get_embedding(text)
+    except ValueError as e:
+        # Handle the case where the input text is too long
+        return f"Error: {str(e)}"
+    except Exception as e:
+        return f"Error: {str(e)}"
+
+# Create a separate Gradio interface for the embedding endpoint
+embedding_iface = gr.Interface(
+    fn=get_embedding,
+    inputs=gr.Textbox(placeholder="Enter text here...", label="Input Text"),
+    outputs=gr.JSON(label="Embedding"),  # Use JSON output for the embedding vector
+    title="Embedding Service",
+    description="Get the embedding of a text using the Vietnamese Bi-Encoder.",
+)
 
 """
 For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
@@ -59,6 +85,9 @@ demo = gr.ChatInterface(
     ],
 )
 
+# Combine the interfaces
+demo = gr.TabbedInterface([demo, embedding_iface], ["Chatbot", "Embedding"])
+
 
 if __name__ == "__main__":
     demo.launch()
```
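
Since the two apps are combined into a `gr.TabbedInterface`, the embedding endpoint can also be called programmatically. A minimal sketch using `gradio_client`; the URL and the `api_name` are assumptions, because Gradio auto-generates endpoint names for combined interfaces, and `Client.view_api()` lists the real ones:

```python
from gradio_client import Client

# Assumption: the app is running locally on Gradio's default port.
client = Client("http://127.0.0.1:7860")
client.view_api()  # prints the auto-generated endpoint names

# Assumption: the embedding endpoint was registered as "/predict_1";
# adjust to whatever view_api() reports.
embedding = client.predict("Xin chào Việt Nam", api_name="/predict_1")
print(len(embedding))  # the model's embedding dimension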
app/config.py ADDED
```diff
@@ -0,0 +1,25 @@
+import os
+from dotenv import load_dotenv
+
+# Load environment variables from .env file
+load_dotenv(override=True)
+
+# Application settings
+APP_NAME = "Vietnamese RAG"
+DEBUG = os.getenv("DEBUG", "False").lower() in ("true", "1", "t")
+API_PREFIX = "/api"
+
+# Model settings
+EMBEDDING_MODEL = "bkai-foundation-models/vietnamese-bi-encoder"
+MAX_TOKEN_LIMIT = 128
+DEFAULT_CHUNK_SIZE = 110  # Safe margin below MAX_TOKEN_LIMIT
+DEFAULT_CHUNK_OVERLAP = 20
+DEFAULT_TOP_K = 5
+
+# Server settings
+HOST = os.getenv("HOST", "0.0.0.0")
+PORT = int(os.getenv("PORT", "8000"))
+
+# Cache settings
+ENABLE_CACHE = os.getenv("ENABLE_CACHE", "True").lower() in ("true", "1", "t")
+CACHE_SIZE = int(os.getenv("CACHE_SIZE", "1000"))
```
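
These values are read once, at import time. A small sketch of how overrides behave (the values are hypothetical; note that `load_dotenv(override=True)` lets a `.env` file win over variables already set in the process):

```python
import os

# Hypothetical overrides: must be set before app.config is first imported,
# and a .env file defining the same keys would still take precedence,
# because load_dotenv(override=True) favors .env values.
os.environ["PORT"] = "7860"
os.environ["ENABLE_CACHE"] = "false"

from app.config import PORT, ENABLE_CACHE
print(PORT)          # 7860
print(ENABLE_CACHE)  # False, since "false" is not in ("true", "1", "t")
```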
app/services/embedding_service.py ADDED
```diff
@@ -0,0 +1,194 @@
+import logging
+from typing import List, Dict, Any, Optional, Union
+import numpy as np
+from functools import lru_cache
+from sentence_transformers import SentenceTransformer
+
+from app.config import (
+    EMBEDDING_MODEL,
+    MAX_TOKEN_LIMIT,
+    ENABLE_CACHE,
+    CACHE_SIZE
+)
+from app.services.preprocessor import TextPreprocessor
+
+logger = logging.getLogger(__name__)
+
+
+class EmbeddingService:
+    """Service for generating embeddings for text using sentence-transformers."""
+
+    def __init__(self, model_name: str = EMBEDDING_MODEL, preprocessor=None):
+        """
+        Initialize the embedding service.
+
+        Args:
+            model_name: Name of the sentence-transformers model to use
+            preprocessor: Optional TextPreprocessor instance
+        """
+        logger.info(f"Loading embedding model: {model_name}")
+        self.model = SentenceTransformer(model_name)
+        self.model_dim = self.model.get_sentence_embedding_dimension()
+        logger.info(f"Model loaded. Embedding dimension: {self.model_dim}")
+
+        # Use provided preprocessor or create one
+        self.preprocessor = preprocessor or TextPreprocessor()
+
+        # Set up caching if enabled
+        if ENABLE_CACHE:
+            self.get_embedding = lru_cache(maxsize=CACHE_SIZE)(self._get_embedding)
+        else:
+            self.get_embedding = self._get_embedding
+
+    def _get_embedding(self, text: str) -> List[float]:
+        """
+        Generate embedding for a text string.
+
+        Args:
+            text: Text to generate embedding for
+
+        Returns:
+            List of floats representing the embedding vector
+        """
+        if not text or not isinstance(text, str):
+            logger.warning("Empty or invalid text provided for embedding generation")
+            return [0.0] * self.model_dim
+
+        # Use preprocessor for token counting only
+        token_count = self.preprocessor.count_tokens(text)
+
+        # Check against token limit
+        if token_count > MAX_TOKEN_LIMIT:
+            logger.error(
+                f"Text exceeds max token limit ({token_count} > {MAX_TOKEN_LIMIT}). "
+                f"Please chunk your text before encoding."
+            )
+            raise ValueError(f"Text exceeds max token limit ({token_count} > {MAX_TOKEN_LIMIT})")
+
+        try:
+            # Directly encode the text string
+            embedding = self.model.encode(text).tolist()
+            return embedding
+        except Exception as e:
+            logger.error(f"Error generating embedding: {str(e)}")
+            return [0.0] * self.model_dim
+
+    def get_embeddings_batch(self, texts: List[str]) -> List[List[float]]:
+        """
+        Generate embeddings for a batch of texts.
+
+        Args:
+            texts: List of texts to generate embeddings for
+
+        Returns:
+            List of embedding vectors
+        """
+        if not texts:
+            return []
+
+        # Validate texts are within token limit
+        for i, text in enumerate(texts):
+            if not text or not isinstance(text, str):
+                logger.warning(f"Empty or invalid text at index {i}")
+                continue
+
+            # Check token count
+            token_count = self.preprocessor.count_tokens(text)
+            if token_count > MAX_TOKEN_LIMIT:
+                logger.error(
+                    f"Text at index {i} exceeds max token limit ({token_count} > {MAX_TOKEN_LIMIT}). "
+                    f"Please chunk your text before encoding."
+                )
+                raise ValueError(f"Text at index {i} exceeds max token limit ({token_count} > {MAX_TOKEN_LIMIT})")
+
+        try:
+            # Let the model handle the batch encoding directly
+            embeddings = self.model.encode(texts).tolist()
+            return embeddings
+        except Exception as e:
+            logger.error(f"Error generating batch embeddings: {str(e)}")
+            return [[0.0] * self.model_dim] * len(texts)
+
+    def embed_chunks(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """
+        Generate embeddings for a list of text chunks.
+
+        Args:
+            chunks: List of chunk dictionaries with text and metadata
+
+        Returns:
+            List of chunk dictionaries with added embeddings
+        """
+        if not chunks:
+            return []
+
+        # Extract texts from chunks
+        texts = [chunk["text"] for chunk in chunks]
+
+        # Generate embeddings
+        embeddings = self.get_embeddings_batch(texts)
+
+        # Add embeddings to chunks
+        result_chunks = []
+        for chunk, embedding in zip(chunks, embeddings):
+            chunk_with_embedding = chunk.copy()
+            chunk_with_embedding["embedding"] = embedding
+            result_chunks.append(chunk_with_embedding)
+
+        return result_chunks
+
+    def similarity_search(
+        self,
+        query: str,
+        embeddings: List[List[float]],
+        texts: List[str],
+        metadata: Optional[List[Dict[str, Any]]] = None,
+        top_k: int = 5
+    ) -> List[Dict[str, Any]]:
+        """
+        Find the most similar texts to a query.
+
+        Args:
+            query: Query text
+            embeddings: List of embedding vectors to search
+            texts: List of texts corresponding to the embeddings
+            metadata: Optional list of metadata for each text
+            top_k: Number of top matches to return
+
+        Returns:
+            List of matches with text, score, and metadata
+        """
+        if not query or not embeddings or not texts:
+            return []
+
+        if metadata is None:
+            metadata = [{} for _ in range(len(texts))]
+
+        # Generate query embedding
+        query_embedding = self.get_embedding(query)
+
+        # Convert to numpy arrays for efficient computation
+        query_embedding_np = np.array(query_embedding)
+        embeddings_np = np.array(embeddings)
+
+        # Compute cosine similarity
+        similarity_scores = np.dot(embeddings_np, query_embedding_np) / (
+            np.linalg.norm(embeddings_np, axis=1) * np.linalg.norm(query_embedding_np)
+        )
+
+        # Get top-k indices
+        if top_k > len(texts):
+            top_k = len(texts)
+
+        top_indices = np.argsort(similarity_scores)[-top_k:][::-1]
+
+        # Prepare results
+        results = []
+        for idx in top_indices:
+            results.append({
+                "text": texts[idx],
+                "score": float(similarity_scores[idx]),
+                "metadata": metadata[idx]
+            })
+
+        return results
```
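
A short usage sketch of the service defined above (the chunks and query are made-up examples; running it downloads the model weights on first use):

```python
from app.services.embedding_service import EmbeddingService

service = EmbeddingService()  # defaults to EMBEDDING_MODEL from app.config

# Hypothetical chunks; each text must stay under MAX_TOKEN_LIMIT tokens,
# otherwise get_embeddings_batch raises ValueError.
chunks = [
    {"text": "Hà Nội là thủ đô của Việt Nam.", "source": "doc1"},
    {"text": "Phở là một món ăn truyền thống của Việt Nam.", "source": "doc2"},
]
embedded = service.embed_chunks(chunks)

results = service.similarity_search(
    query="Thủ đô của Việt Nam là gì?",
    embeddings=[c["embedding"] for c in embedded],
    texts=[c["text"] for c in embedded],
    metadata=[{"source": c["source"]} for c in embedded],
    top_k=1,
)
print(results[0]["text"], results[0]["score"])
```

Note that when `ENABLE_CACHE` is on, `get_embedding` is wrapped in `functools.lru_cache`, so repeated queries for the same string skip the model entirely.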
app/services/preprocessor.py ADDED
```diff
@@ -0,0 +1,38 @@
+import re
+from transformers import AutoTokenizer
+from app.config import EMBEDDING_MODEL
+
+class TextPreprocessor:
+    """
+    A simple text preprocessor for cleaning and tokenizing text.
+    """
+
+    def __init__(self, model_name: str = EMBEDDING_MODEL):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    def clean_text(self, text: str) -> str:
+        """
+        Remove extra whitespace and control characters from text.
+
+        Args:
+            text: The text to clean.
+
+        Returns:
+            The cleaned text.
+        """
+        text = re.sub(r"[\s\t\n]+", " ", text)  # Normalize whitespace
+        text = re.sub(r"[\x00-\x1F\x7F]", "", text)  # Remove control characters
+        return text.strip()
+
+    def count_tokens(self, text: str) -> int:
+        """
+        Count the number of tokens in the text using a tokenizer.
+
+        Args:
+            text: The text to tokenize.
+
+        Returns:
+            The number of tokens.
+        """
+        # Tokenize the text and return the length of the input IDs
+        return len(self.tokenizer(text).input_ids)
```
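
A quick sketch of the preprocessor in isolation. Two details worth noting: `\t` and `\n` are already matched by `\s`, so the first character class is redundant but harmless; and `count_tokens` includes the tokenizer's special tokens (the begin/end-of-sequence markers), so counts run slightly above the raw word-piece count:

```python
from app.services.preprocessor import TextPreprocessor

pre = TextPreprocessor()  # loads the tokenizer for EMBEDDING_MODEL

raw = "Xin  chào\t Việt Nam\n"
print(pre.clean_text(raw))    # "Xin chào Việt Nam"
print(pre.count_tokens(raw))  # word-piece count plus special tokens
```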
requirements.txt CHANGED
```diff
@@ -1 +1,5 @@
-huggingface_hub==0.25.2
+huggingface_hub==0.25.2
+gradio
+sentence-transformers
+python-dotenv
+transformers
```