nabilcheikh1 commited on
Commit
b1341dd
·
1 Parent(s): b777c8f

reset embeddings effective

Browse files
Files changed (2) hide show
  1. src/api/models/embedding_models.py +1 -1
  2. src/main.py +24 -6
src/api/models/embedding_models.py CHANGED
@@ -68,4 +68,4 @@ class SearchEmbeddingRequest(BaseModel):
68
 
69
 
70
  class ResetEmbeddingsRequest(BaseModel):
71
- dataset_name: str
 
68
 
69
 
70
  class ResetEmbeddingsRequest(BaseModel):
71
+ dataset_name: str = "re-mind/product_type_embedding"
src/main.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  from fastapi import FastAPI, Depends, HTTPException
3
  from fastapi.responses import JSONResponse, RedirectResponse
4
  from fastapi.middleware.gzip import GZipMiddleware
 
5
  from pydantic import BaseModel
6
  from typing import List, Dict
7
  from datasets import Dataset
@@ -18,6 +19,7 @@ from src.api.models.embedding_models import (
18
  from src.api.database import get_db, Database, QueryExecutionError, HealthCheckError, get_db_from_url
19
  from src.api.services.embedding_service import EmbeddingService
20
  from src.api.services.huggingface_service import HuggingFaceService
 
21
  from src.api.exceptions import DatasetNotFoundError, DatasetPushError, OpenAIError
22
 
23
  import logging
@@ -309,16 +311,32 @@ async def search_embedding(
309
  @app.post("/reset_embeddings")
310
  async def reset_embeddings(
311
  request: ResetEmbeddingsRequest,
312
- db: Database = Depends(get_db_from_url)
 
 
313
  ):
314
  """
315
  Reset embeddings from a Hugging Face dataset by deleting them, then reloading them
316
  using the actual database
317
  """
 
 
318
  try:
319
- is_healthy = db.health_check()
320
- if not is_healthy:
321
- raise HTTPException(status_code=500, detail="Database is unhealthy")
322
- return {"status": "healthy"}
323
- except HealthCheckError as e:
 
 
 
 
 
 
 
 
 
 
 
 
324
  raise HTTPException(status_code=500, detail=str(e))
 
2
  from fastapi import FastAPI, Depends, HTTPException
3
  from fastapi.responses import JSONResponse, RedirectResponse
4
  from fastapi.middleware.gzip import GZipMiddleware
5
+ from pg8000 import DatabaseError
6
  from pydantic import BaseModel
7
  from typing import List, Dict
8
  from datasets import Dataset
 
19
  from src.api.database import get_db, Database, QueryExecutionError, HealthCheckError, get_db_from_url
20
  from src.api.services.embedding_service import EmbeddingService
21
  from src.api.services.huggingface_service import HuggingFaceService
22
+ from src.api.services.postgresql_service import PostgresqlService
23
  from src.api.exceptions import DatasetNotFoundError, DatasetPushError, OpenAIError
24
 
25
  import logging
 
311
  @app.post("/reset_embeddings")
312
  async def reset_embeddings(
313
  request: ResetEmbeddingsRequest,
314
+ db: Database = Depends(get_db_from_url),
315
+ embedding_service: EmbeddingService = Depends(get_embedding_service),
316
+ huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
317
  ):
318
  """
319
  Reset embeddings from a Hugging Face dataset by deleting them, then reloading them
320
  using the actual database
321
  """
322
+ postgresql_service = PostgresqlService(db)
323
+
324
  try:
325
+ # List of rows from database
326
+ results = await postgresql_service.get_db_rows_from_dataset_name(request.dataset_name)
327
+
328
+ # Generation of embeddings for each row
329
+ dataset = Dataset.from_dict(results)
330
+ target_column = "type" if request.dataset_name == "re-mind/product_type_embedding" else "name"
331
+ dataset_embedded = await embedding_service.create_embeddings(dataset, target_column, "embedding")
332
+ # Embeddings up-to-date with database will overwrite old dataset
333
+ await huggingface_service.push_to_hub(dataset_embedded, request.dataset_name)
334
+
335
+ return {
336
+ "message": "Dataset updated succesfully with up-to-date rows from database",
337
+ "dataset_name": request.dataset_name,
338
+ "num_rows": len(dataset_embedded)
339
+ }
340
+
341
+ except DatabaseError as e:
342
  raise HTTPException(status_code=500, detail=str(e))