Spaces:
Sleeping
Sleeping
Commit
·
b777c8f
1
Parent(s):
fdc0b3d
add new endpoint
Browse files- src/api/database.py +10 -0
- src/api/models/embedding_models.py +4 -0
- src/main.py +27 -1
src/api/database.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import logging
|
|
|
|
| 2 |
from typing import AsyncGenerator, List, Optional, Dict
|
| 3 |
from pydantic_settings import BaseSettings
|
| 4 |
from pydantic import PostgresDsn
|
|
@@ -191,6 +192,15 @@ class Database:
|
|
| 191 |
raise HealthCheckError("Failed to perform health check.") from e
|
| 192 |
|
| 193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
# Dependency to get the database instance
|
| 195 |
async def get_db() -> AsyncGenerator[Database, None]:
|
| 196 |
settings = DatabaseSettings()
|
|
|
|
| 1 |
import logging
|
| 2 |
+
import os
|
| 3 |
from typing import AsyncGenerator, List, Optional, Dict
|
| 4 |
from pydantic_settings import BaseSettings
|
| 5 |
from pydantic import PostgresDsn
|
|
|
|
| 192 |
raise HealthCheckError("Failed to perform health check.") from e
|
| 193 |
|
| 194 |
|
| 195 |
+
async def get_db_from_url() -> AsyncGenerator[Database, None]:
|
| 196 |
+
db = Database(db_url=os.getenv("DB_URL"), pool_size=5)
|
| 197 |
+
await db.connect()
|
| 198 |
+
try:
|
| 199 |
+
yield db
|
| 200 |
+
finally:
|
| 201 |
+
await db.disconnect()
|
| 202 |
+
|
| 203 |
+
|
| 204 |
# Dependency to get the database instance
|
| 205 |
async def get_db() -> AsyncGenerator[Database, None]:
|
| 206 |
settings = DatabaseSettings()
|
src/api/models/embedding_models.py
CHANGED
|
@@ -65,3 +65,7 @@ class SearchEmbeddingRequest(BaseModel):
|
|
| 65 |
additional_columns: Optional[List[str]] = (
|
| 66 |
None # Optional list of additional columns to include in the results
|
| 67 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
additional_columns: Optional[List[str]] = (
|
| 66 |
None # Optional list of additional columns to include in the results
|
| 67 |
)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class ResetEmbeddingsRequest(BaseModel):
|
| 71 |
+
dataset_name: str
|
src/main.py
CHANGED
|
@@ -13,8 +13,9 @@ from src.api.models.embedding_models import (
|
|
| 13 |
DeleteEmbeddingRequest,
|
| 14 |
EmbedRequest,
|
| 15 |
SearchEmbeddingRequest,
|
|
|
|
| 16 |
)
|
| 17 |
-
from src.api.database import get_db, Database, QueryExecutionError, HealthCheckError
|
| 18 |
from src.api.services.embedding_service import EmbeddingService
|
| 19 |
from src.api.services.huggingface_service import HuggingFaceService
|
| 20 |
from src.api.exceptions import DatasetNotFoundError, DatasetPushError, OpenAIError
|
|
@@ -46,12 +47,19 @@ app = FastAPI(
|
|
| 46 |
|
| 47 |
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
| 48 |
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
# Dependency to get EmbeddingService
|
| 51 |
def get_embedding_service() -> EmbeddingService:
|
| 52 |
return EmbeddingService(openai_api_key=os.getenv("OPENAI_API_KEY"))
|
| 53 |
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
# Dependency to get HuggingFaceService
|
| 56 |
def get_huggingface_service() -> HuggingFaceService:
|
| 57 |
return HuggingFaceService()
|
|
@@ -296,3 +304,21 @@ async def search_embedding(
|
|
| 296 |
except Exception as e:
|
| 297 |
logger.error(f"An error occurred: {e}")
|
| 298 |
raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
DeleteEmbeddingRequest,
|
| 14 |
EmbedRequest,
|
| 15 |
SearchEmbeddingRequest,
|
| 16 |
+
ResetEmbeddingsRequest,
|
| 17 |
)
|
| 18 |
+
from src.api.database import get_db, Database, QueryExecutionError, HealthCheckError, get_db_from_url
|
| 19 |
from src.api.services.embedding_service import EmbeddingService
|
| 20 |
from src.api.services.huggingface_service import HuggingFaceService
|
| 21 |
from src.api.exceptions import DatasetNotFoundError, DatasetPushError, OpenAIError
|
|
|
|
| 47 |
|
| 48 |
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
| 49 |
|
| 50 |
+
# def get_database_service() -> Database:
|
| 51 |
+
# return Database(db_url=os.getenv("DB_URL"))
|
| 52 |
+
|
| 53 |
|
| 54 |
# Dependency to get EmbeddingService
|
| 55 |
def get_embedding_service() -> EmbeddingService:
|
| 56 |
return EmbeddingService(openai_api_key=os.getenv("OPENAI_API_KEY"))
|
| 57 |
|
| 58 |
|
| 59 |
+
# def get_db_from_env():
|
| 60 |
+
# return get_db_from_url(os.getenv("DB_URL"))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
# Dependency to get HuggingFaceService
|
| 64 |
def get_huggingface_service() -> HuggingFaceService:
|
| 65 |
return HuggingFaceService()
|
|
|
|
| 304 |
except Exception as e:
|
| 305 |
logger.error(f"An error occurred: {e}")
|
| 306 |
raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
@app.post("/reset_embeddings")
|
| 310 |
+
async def reset_embeddings(
|
| 311 |
+
request: ResetEmbeddingsRequest,
|
| 312 |
+
db: Database = Depends(get_db_from_url)
|
| 313 |
+
):
|
| 314 |
+
"""
|
| 315 |
+
Reset embeddings from a Hugging Face dataset by deleting them, then reloading them
|
| 316 |
+
using the actual database
|
| 317 |
+
"""
|
| 318 |
+
try:
|
| 319 |
+
is_healthy = db.health_check()
|
| 320 |
+
if not is_healthy:
|
| 321 |
+
raise HTTPException(status_code=500, detail="Database is unhealthy")
|
| 322 |
+
return {"status": "healthy"}
|
| 323 |
+
except HealthCheckError as e:
|
| 324 |
+
raise HTTPException(status_code=500, detail=str(e))
|