Spaces:
Running
Running
Commit
·
b777c8f
1
Parent(s):
fdc0b3d
add new endpoint
Browse files- src/api/database.py +10 -0
- src/api/models/embedding_models.py +4 -0
- src/main.py +27 -1
src/api/database.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import logging
|
|
|
2 |
from typing import AsyncGenerator, List, Optional, Dict
|
3 |
from pydantic_settings import BaseSettings
|
4 |
from pydantic import PostgresDsn
|
@@ -191,6 +192,15 @@ class Database:
|
|
191 |
raise HealthCheckError("Failed to perform health check.") from e
|
192 |
|
193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
# Dependency to get the database instance
|
195 |
async def get_db() -> AsyncGenerator[Database, None]:
|
196 |
settings = DatabaseSettings()
|
|
|
1 |
import logging
|
2 |
+
import os
|
3 |
from typing import AsyncGenerator, List, Optional, Dict
|
4 |
from pydantic_settings import BaseSettings
|
5 |
from pydantic import PostgresDsn
|
|
|
192 |
raise HealthCheckError("Failed to perform health check.") from e
|
193 |
|
194 |
|
195 |
+
async def get_db_from_url() -> AsyncGenerator[Database, None]:
|
196 |
+
db = Database(db_url=os.getenv("DB_URL"), pool_size=5)
|
197 |
+
await db.connect()
|
198 |
+
try:
|
199 |
+
yield db
|
200 |
+
finally:
|
201 |
+
await db.disconnect()
|
202 |
+
|
203 |
+
|
204 |
# Dependency to get the database instance
|
205 |
async def get_db() -> AsyncGenerator[Database, None]:
|
206 |
settings = DatabaseSettings()
|
src/api/models/embedding_models.py
CHANGED
@@ -65,3 +65,7 @@ class SearchEmbeddingRequest(BaseModel):
|
|
65 |
additional_columns: Optional[List[str]] = (
|
66 |
None # Optional list of additional columns to include in the results
|
67 |
)
|
|
|
|
|
|
|
|
|
|
65 |
additional_columns: Optional[List[str]] = (
|
66 |
None # Optional list of additional columns to include in the results
|
67 |
)
|
68 |
+
|
69 |
+
|
70 |
+
class ResetEmbeddingsRequest(BaseModel):
|
71 |
+
dataset_name: str
|
src/main.py
CHANGED
@@ -13,8 +13,9 @@ from src.api.models.embedding_models import (
|
|
13 |
DeleteEmbeddingRequest,
|
14 |
EmbedRequest,
|
15 |
SearchEmbeddingRequest,
|
|
|
16 |
)
|
17 |
-
from src.api.database import get_db, Database, QueryExecutionError, HealthCheckError
|
18 |
from src.api.services.embedding_service import EmbeddingService
|
19 |
from src.api.services.huggingface_service import HuggingFaceService
|
20 |
from src.api.exceptions import DatasetNotFoundError, DatasetPushError, OpenAIError
|
@@ -46,12 +47,19 @@ app = FastAPI(
|
|
46 |
|
47 |
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
48 |
|
|
|
|
|
|
|
49 |
|
50 |
# Dependency to get EmbeddingService
|
51 |
def get_embedding_service() -> EmbeddingService:
|
52 |
return EmbeddingService(openai_api_key=os.getenv("OPENAI_API_KEY"))
|
53 |
|
54 |
|
|
|
|
|
|
|
|
|
55 |
# Dependency to get HuggingFaceService
|
56 |
def get_huggingface_service() -> HuggingFaceService:
|
57 |
return HuggingFaceService()
|
@@ -296,3 +304,21 @@ async def search_embedding(
|
|
296 |
except Exception as e:
|
297 |
logger.error(f"An error occurred: {e}")
|
298 |
raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
DeleteEmbeddingRequest,
|
14 |
EmbedRequest,
|
15 |
SearchEmbeddingRequest,
|
16 |
+
ResetEmbeddingsRequest,
|
17 |
)
|
18 |
+
from src.api.database import get_db, Database, QueryExecutionError, HealthCheckError, get_db_from_url
|
19 |
from src.api.services.embedding_service import EmbeddingService
|
20 |
from src.api.services.huggingface_service import HuggingFaceService
|
21 |
from src.api.exceptions import DatasetNotFoundError, DatasetPushError, OpenAIError
|
|
|
47 |
|
48 |
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
49 |
|
50 |
+
# def get_database_service() -> Database:
|
51 |
+
# return Database(db_url=os.getenv("DB_URL"))
|
52 |
+
|
53 |
|
54 |
# Dependency to get EmbeddingService
|
55 |
def get_embedding_service() -> EmbeddingService:
|
56 |
return EmbeddingService(openai_api_key=os.getenv("OPENAI_API_KEY"))
|
57 |
|
58 |
|
59 |
+
# def get_db_from_env():
|
60 |
+
# return get_db_from_url(os.getenv("DB_URL"))
|
61 |
+
|
62 |
+
|
63 |
# Dependency to get HuggingFaceService
|
64 |
def get_huggingface_service() -> HuggingFaceService:
|
65 |
return HuggingFaceService()
|
|
|
304 |
except Exception as e:
|
305 |
logger.error(f"An error occurred: {e}")
|
306 |
raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
|
307 |
+
|
308 |
+
|
309 |
+
@app.post("/reset_embeddings")
|
310 |
+
async def reset_embeddings(
|
311 |
+
request: ResetEmbeddingsRequest,
|
312 |
+
db: Database = Depends(get_db_from_url)
|
313 |
+
):
|
314 |
+
"""
|
315 |
+
Reset embeddings from a Hugging Face dataset by deleting them, then reloading them
|
316 |
+
using the actual database
|
317 |
+
"""
|
318 |
+
try:
|
319 |
+
is_healthy = db.health_check()
|
320 |
+
if not is_healthy:
|
321 |
+
raise HTTPException(status_code=500, detail="Database is unhealthy")
|
322 |
+
return {"status": "healthy"}
|
323 |
+
except HealthCheckError as e:
|
324 |
+
raise HTTPException(status_code=500, detail=str(e))
|