lamhieu committed on
Commit
ddd02b3
·
1 Parent(s): 676899f

fix: update model encoding flow

Browse files
lightweight_embeddings/router.py CHANGED
@@ -21,8 +21,7 @@ from __future__ import annotations
21
 
22
  import logging
23
  import os
24
- from typing import Dict, Any, List, Union
25
- from enum import Enum
26
  from datetime import datetime
27
 
28
  from fastapi import APIRouter, BackgroundTasks, HTTPException
@@ -32,8 +31,9 @@ from .analytics import Analytics
32
  from .service import (
33
  ModelConfig,
34
  TextModelType,
35
- ImageModelType,
36
  EmbeddingsService,
 
 
37
  )
38
 
39
  logger = logging.getLogger(__name__)
@@ -44,28 +44,6 @@ router = APIRouter(
44
  )
45
 
46
 
47
- class ModelKind(str, Enum):
48
- TEXT = "text"
49
- IMAGE = "image"
50
-
51
-
52
- def detect_model_kind(model_id: str) -> ModelKind:
53
- """
54
- Detect whether model_id is for a text or an image model.
55
- Raises ValueError if unrecognized.
56
- """
57
- if model_id in [m.value for m in TextModelType]:
58
- return ModelKind.TEXT
59
- elif model_id in [m.value for m in ImageModelType]:
60
- return ModelKind.IMAGE
61
- else:
62
- raise ValueError(
63
- f"Unrecognized model ID: {model_id}.\n"
64
- f"Valid text: {[m.value for m in TextModelType]}\n"
65
- f"Valid image: {[m.value for m in ImageModelType]}"
66
- )
67
-
68
-
69
  class EmbeddingRequest(BaseModel):
70
  """
71
  Input to /v1/embeddings
@@ -147,7 +125,7 @@ embeddings_service = EmbeddingsService(config=service_config)
147
  analytics = Analytics(
148
  url=os.environ.get("REDIS_URL", "redis://localhost:6379/0"),
149
  token=os.environ.get("REDIS_TOKEN", "***"),
150
- sync_interval=5 * 60, # 5 minutes
151
  )
152
 
153
 
@@ -159,23 +137,15 @@ async def create_embeddings(
159
  Generates embeddings for the given input (text or image).
160
  """
161
  try:
162
- # 1) Determine if it's text or image
163
- mkind = detect_model_kind(request.model)
164
-
165
- # 2) Update global service config so it uses the correct model
166
- if mkind == ModelKind.TEXT:
167
- service_config.text_model_type = TextModelType(request.model)
168
- else:
169
- service_config.image_model_type = ImageModelType(request.model)
170
-
171
- # 3) Generate
172
  embeddings = await embeddings_service.generate_embeddings(
173
- input_data=request.input, modality=mkind.value
 
174
  )
175
 
176
- # 4) Estimate tokens for text only
177
  total_tokens = 0
178
- if mkind == ModelKind.TEXT:
179
  total_tokens = embeddings_service.estimate_tokens(request.input)
180
 
181
  resp = {
@@ -218,17 +188,10 @@ async def rank_candidates(request: RankRequest, background_tasks: BackgroundTask
218
  Ranks candidate texts against the given queries (which can be text or image).
219
  """
220
  try:
221
- mkind = detect_model_kind(request.model)
222
-
223
- if mkind == ModelKind.TEXT:
224
- service_config.text_model_type = TextModelType(request.model)
225
- else:
226
- service_config.image_model_type = ImageModelType(request.model)
227
-
228
  results = await embeddings_service.rank(
 
229
  queries=request.queries,
230
  candidates=request.candidates,
231
- modality=mkind.value,
232
  )
233
 
234
  background_tasks.add_task(
 
21
 
22
  import logging
23
  import os
24
+ from typing import Dict, List, Union
 
25
  from datetime import datetime
26
 
27
  from fastapi import APIRouter, BackgroundTasks, HTTPException
 
31
  from .service import (
32
  ModelConfig,
33
  TextModelType,
 
34
  EmbeddingsService,
35
+ ModelKind,
36
+ detect_model_kind,
37
  )
38
 
39
  logger = logging.getLogger(__name__)
 
44
  )
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  class EmbeddingRequest(BaseModel):
48
  """
49
  Input to /v1/embeddings
 
125
  analytics = Analytics(
126
  url=os.environ.get("REDIS_URL", "redis://localhost:6379/0"),
127
  token=os.environ.get("REDIS_TOKEN", "***"),
128
+ sync_interval=5 * 60, # 5 minutes
129
  )
130
 
131
 
 
137
  Generates embeddings for the given input (text or image).
138
  """
139
  try:
140
+ modality = detect_model_kind(request.model)
 
 
 
 
 
 
 
 
 
141
  embeddings = await embeddings_service.generate_embeddings(
142
+ inputs=request.input,
143
+ model=request.model,
144
  )
145
 
146
+ # Estimate tokens for text only
147
  total_tokens = 0
148
+ if modality == ModelKind.TEXT:
149
  total_tokens = embeddings_service.estimate_tokens(request.input)
150
 
151
  resp = {
 
188
  Ranks candidate texts against the given queries (which can be text or image).
189
  """
190
  try:
 
 
 
 
 
 
 
191
  results = await embeddings_service.rank(
192
+ model=request.model,
193
  queries=request.queries,
194
  candidates=request.candidates,
 
195
  )
196
 
197
  background_tasks.add_task(
lightweight_embeddings/service.py CHANGED
@@ -28,7 +28,7 @@ from __future__ import annotations
28
 
29
  import logging
30
  from enum import Enum
31
- from typing import List, Union, Literal, Dict, Optional, NamedTuple, Any
32
  from dataclasses import dataclass
33
  from pathlib import Path
34
  from io import BytesIO
@@ -149,6 +149,28 @@ class ModelConfig:
149
  return image_configs[self.image_model_type]
150
 
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  class EmbeddingsService:
153
  """
154
  Service for generating text/image embeddings and performing ranking.
@@ -264,7 +286,11 @@ class EmbeddingsService:
264
  except Exception as e:
265
  raise ValueError(f"Error processing image '{path_or_url}': {str(e)}") from e
266
 
267
- def _generate_text_embeddings(self, texts: List[str]) -> np.ndarray:
 
 
 
 
268
  """
269
  Generate text embeddings using the currently configured text model
270
  with an LRU cache for single-text requests.
@@ -274,7 +300,7 @@ class EmbeddingsService:
274
  key = md5(texts[0].encode("utf-8")).hexdigest()
275
  if key in self.lru_cache:
276
  return self.lru_cache[key]
277
- model = self.text_models[self.config.text_model_type]
278
  embeddings = model.encode(texts)
279
 
280
  if len(texts) == 1:
@@ -287,6 +313,7 @@ class EmbeddingsService:
287
 
288
  def _generate_image_embeddings(
289
  self,
 
290
  images: Union[str, List[str]],
291
  batch_size: Optional[int] = None,
292
  ) -> np.ndarray:
@@ -295,7 +322,7 @@ class EmbeddingsService:
295
  If `batch_size` is None, all images are processed at once.
296
  """
297
  try:
298
- model = self.image_models[self.config.image_model_type]
299
 
300
  # Single image
301
  if isinstance(images, str):
@@ -341,36 +368,57 @@ class EmbeddingsService:
341
 
342
  async def generate_embeddings(
343
  self,
344
- input_data: Union[str, List[str]],
345
- modality: Literal["text", "image"],
346
  batch_size: Optional[int] = None,
347
  ) -> np.ndarray:
348
  """
349
  Asynchronously generate embeddings for text or image.
350
  """
 
 
 
 
 
 
 
 
351
  self._validate_modality(modality)
352
- if modality == "text":
353
- text_list = self._validate_text_input(input_data)
354
- return self._generate_text_embeddings(text_list)
355
- else:
356
- return self._generate_image_embeddings(input_data, batch_size=batch_size)
 
 
357
 
358
  async def rank(
359
  self,
 
360
  queries: Union[str, List[str]],
361
  candidates: List[str],
362
- modality: Literal["text", "image"],
363
  batch_size: Optional[int] = None,
364
  ) -> Dict[str, Any]:
365
  """
366
  Rank candidates (always text) against the queries, which may be text or image.
367
  Returns dict of { probabilities, cosine_similarities, usage }.
368
  """
 
 
 
 
 
 
 
369
 
370
  # 1) Generate embeddings for queries
371
- query_embeds = await self.generate_embeddings(queries, modality, batch_size)
 
 
372
  # 2) Generate embeddings for text candidates
373
- candidate_embeds = await self.generate_embeddings(candidates, "text")
 
 
374
 
375
  # 3) Compute cosine similarity
376
  sim_matrix = self.cosine_similarity(query_embeds, candidate_embeds)
 
28
 
29
  import logging
30
  from enum import Enum
31
+ from typing import List, Union, Dict, Optional, NamedTuple, Any
32
  from dataclasses import dataclass
33
  from pathlib import Path
34
  from io import BytesIO
 
149
  return image_configs[self.image_model_type]
150
 
151
 
152
+ class ModelKind(str, Enum):
153
+ TEXT = "text"
154
+ IMAGE = "image"
155
+
156
+
157
+ def detect_model_kind(model_id: str) -> ModelKind:
158
+ """
159
+ Detect whether model_id is for a text or an image model.
160
+ Raises ValueError if unrecognized.
161
+ """
162
+ if model_id in [m.value for m in TextModelType]:
163
+ return ModelKind.TEXT
164
+ elif model_id in [m.value for m in ImageModelType]:
165
+ return ModelKind.IMAGE
166
+ else:
167
+ raise ValueError(
168
+ f"Unrecognized model ID: {model_id}.\n"
169
+ f"Valid text: {[m.value for m in TextModelType]}\n"
170
+ f"Valid image: {[m.value for m in ImageModelType]}"
171
+ )
172
+
173
+
174
  class EmbeddingsService:
175
  """
176
  Service for generating text/image embeddings and performing ranking.
 
286
  except Exception as e:
287
  raise ValueError(f"Error processing image '{path_or_url}': {str(e)}") from e
288
 
289
+ def _generate_text_embeddings(
290
+ self,
291
+ model_id: TextModelType,
292
+ texts: List[str],
293
+ ) -> np.ndarray:
294
  """
295
  Generate text embeddings using the currently configured text model
296
  with an LRU cache for single-text requests.
 
300
  key = md5(texts[0].encode("utf-8")).hexdigest()
301
  if key in self.lru_cache:
302
  return self.lru_cache[key]
303
+ model = self.text_models[model_id]
304
  embeddings = model.encode(texts)
305
 
306
  if len(texts) == 1:
 
313
 
314
  def _generate_image_embeddings(
315
  self,
316
+ model_id: ImageModelType,
317
  images: Union[str, List[str]],
318
  batch_size: Optional[int] = None,
319
  ) -> np.ndarray:
 
322
  If `batch_size` is None, all images are processed at once.
323
  """
324
  try:
325
+ model = self.image_models[model_id]
326
 
327
  # Single image
328
  if isinstance(images, str):
 
368
 
369
  async def generate_embeddings(
370
  self,
371
+ model: str,
372
+ inputs: Union[str, List[str]],
373
  batch_size: Optional[int] = None,
374
  ) -> np.ndarray:
375
  """
376
  Asynchronously generate embeddings for text or image.
377
  """
378
+ # Determine if it's text or image
379
+ modality = detect_model_kind(model)
380
+ model_id = (
381
+ TextModelType(model)
382
+ if modality == ModelKind.TEXT
383
+ else ImageModelType(model)
384
+ )
385
+
386
  self._validate_modality(modality)
387
+ if modality == "text" and isinstance(model_id, TextModelType):
388
+ text_list = self._validate_text_input(inputs)
389
+ return self._generate_text_embeddings(model_id=model_id, texts=text_list)
390
+ elif modality == "image" and isinstance(model_id, ImageModelType):
391
+ return self._generate_image_embeddings(
392
+ model_id=model_id, images=inputs, batch_size=batch_size
393
+ )
394
 
395
  async def rank(
396
  self,
397
+ model: str,
398
  queries: Union[str, List[str]],
399
  candidates: List[str],
 
400
  batch_size: Optional[int] = None,
401
  ) -> Dict[str, Any]:
402
  """
403
  Rank candidates (always text) against the queries, which may be text or image.
404
  Returns dict of { probabilities, cosine_similarities, usage }.
405
  """
406
+ # Determine if it's text or image
407
+ modality = detect_model_kind(model)
408
+ model_id = (
409
+ TextModelType(model)
410
+ if modality == ModelKind.TEXT
411
+ else ImageModelType(model)
412
+ )
413
 
414
  # 1) Generate embeddings for queries
415
+ query_embeds = await self.generate_embeddings(
416
+ model=model_id, inputs=queries, batch_size=batch_size
417
+ )
418
  # 2) Generate embeddings for text candidates
419
+ candidate_embeds = await self.generate_embeddings(
420
+ model=model_id, inputs=candidates, batch_size=batch_size
421
+ )
422
 
423
  # 3) Compute cosine similarity
424
  sim_matrix = self.cosine_similarity(query_embeds, candidate_embeds)