Spaces:
Running
Running
Commit
·
b1341dd
1
Parent(s):
b777c8f
reset embeddings effective
Browse files
- src/api/models/embedding_models.py +1 -1
- src/main.py +24 -6
src/api/models/embedding_models.py
CHANGED
@@ -68,4 +68,4 @@ class SearchEmbeddingRequest(BaseModel):
|
|
68 |
|
69 |
|
70 |
class ResetEmbeddingsRequest(BaseModel):
|
71 |
-
dataset_name: str
|
|
|
68 |
|
69 |
|
70 |
class ResetEmbeddingsRequest(BaseModel):
|
71 |
+
dataset_name: str = "re-mind/product_type_embedding"
|
src/main.py
CHANGED
@@ -2,6 +2,7 @@ import os
|
|
2 |
from fastapi import FastAPI, Depends, HTTPException
|
3 |
from fastapi.responses import JSONResponse, RedirectResponse
|
4 |
from fastapi.middleware.gzip import GZipMiddleware
|
|
|
5 |
from pydantic import BaseModel
|
6 |
from typing import List, Dict
|
7 |
from datasets import Dataset
|
@@ -18,6 +19,7 @@ from src.api.models.embedding_models import (
|
|
18 |
from src.api.database import get_db, Database, QueryExecutionError, HealthCheckError, get_db_from_url
|
19 |
from src.api.services.embedding_service import EmbeddingService
|
20 |
from src.api.services.huggingface_service import HuggingFaceService
|
|
|
21 |
from src.api.exceptions import DatasetNotFoundError, DatasetPushError, OpenAIError
|
22 |
|
23 |
import logging
|
@@ -309,16 +311,32 @@ async def search_embedding(
|
|
309 |
@app.post("/reset_embeddings")
|
310 |
async def reset_embeddings(
|
311 |
request: ResetEmbeddingsRequest,
|
312 |
-
db: Database = Depends(get_db_from_url)
|
|
|
|
|
313 |
):
|
314 |
"""
|
315 |
Reset embeddings from a Hugging Face dataset by deleting them, then reloading them
|
316 |
using the actual database
|
317 |
"""
|
|
|
|
|
318 |
try:
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
324 |
raise HTTPException(status_code=500, detail=str(e))
|
|
|
2 |
from fastapi import FastAPI, Depends, HTTPException
|
3 |
from fastapi.responses import JSONResponse, RedirectResponse
|
4 |
from fastapi.middleware.gzip import GZipMiddleware
|
5 |
+
from pg8000 import DatabaseError
|
6 |
from pydantic import BaseModel
|
7 |
from typing import List, Dict
|
8 |
from datasets import Dataset
|
|
|
19 |
from src.api.database import get_db, Database, QueryExecutionError, HealthCheckError, get_db_from_url
|
20 |
from src.api.services.embedding_service import EmbeddingService
|
21 |
from src.api.services.huggingface_service import HuggingFaceService
|
22 |
+
from src.api.services.postgresql_service import PostgresqlService
|
23 |
from src.api.exceptions import DatasetNotFoundError, DatasetPushError, OpenAIError
|
24 |
|
25 |
import logging
|
|
|
311 |
@app.post("/reset_embeddings")
|
312 |
async def reset_embeddings(
|
313 |
request: ResetEmbeddingsRequest,
|
314 |
+
db: Database = Depends(get_db_from_url),
|
315 |
+
embedding_service: EmbeddingService = Depends(get_embedding_service),
|
316 |
+
huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
|
317 |
):
|
318 |
"""
|
319 |
Reset embeddings from a Hugging Face dataset by deleting them, then reloading them
|
320 |
using the actual database
|
321 |
"""
|
322 |
+
postgresql_service = PostgresqlService(db)
|
323 |
+
|
324 |
try:
|
325 |
+
# List of rows from database
|
326 |
+
results = await postgresql_service.get_db_rows_from_dataset_name(request.dataset_name)
|
327 |
+
|
328 |
+
# Generation of embeddings for each row
|
329 |
+
dataset = Dataset.from_dict(results)
|
330 |
+
target_column = "type" if request.dataset_name == "re-mind/product_type_embedding" else "name"
|
331 |
+
dataset_embedded = await embedding_service.create_embeddings(dataset, target_column, "embedding")
|
332 |
+
# Embeddings up-to-date with database will overwrite old dataset
|
333 |
+
await huggingface_service.push_to_hub(dataset_embedded, request.dataset_name)
|
334 |
+
|
335 |
+
return {
|
336 |
+
"message": "Dataset updated succesfully with up-to-date rows from database",
|
337 |
+
"dataset_name": request.dataset_name,
|
338 |
+
"num_rows": len(dataset_embedded)
|
339 |
+
}
|
340 |
+
|
341 |
+
except DatabaseError as e:
|
342 |
raise HTTPException(status_code=500, detail=str(e))
|