nabilcheikh1 committed on
Commit
b777c8f
·
1 Parent(s): fdc0b3d

add new endpoint

Browse files
src/api/database.py CHANGED
@@ -1,4 +1,5 @@
1
  import logging
 
2
  from typing import AsyncGenerator, List, Optional, Dict
3
  from pydantic_settings import BaseSettings
4
  from pydantic import PostgresDsn
@@ -191,6 +192,15 @@ class Database:
191
  raise HealthCheckError("Failed to perform health check.") from e
192
 
193
 
 
 
 
 
 
 
 
 
 
194
  # Dependency to get the database instance
195
  async def get_db() -> AsyncGenerator[Database, None]:
196
  settings = DatabaseSettings()
 
1
  import logging
2
+ import os
3
  from typing import AsyncGenerator, List, Optional, Dict
4
  from pydantic_settings import BaseSettings
5
  from pydantic import PostgresDsn
 
192
  raise HealthCheckError("Failed to perform health check.") from e
193
 
194
 
195
async def get_db_from_url() -> AsyncGenerator[Database, None]:
    """Yield a connected ``Database`` built from the ``DB_URL`` environment variable.

    The instance is created with ``pool_size=5``, connected before it is
    yielded, and disconnected once the consumer is done with it, including
    when the consumer raises.

    NOTE(review): ``os.getenv`` returns ``None`` when ``DB_URL`` is unset;
    how ``Database`` reacts to a ``None`` URL is not visible here -- confirm.
    """
    url_from_env = os.getenv("DB_URL")
    database = Database(db_url=url_from_env, pool_size=5)
    await database.connect()
    try:
        yield database
    finally:
        # Runs regardless of how the consumer exited.
        await database.disconnect()
202
+
203
+
204
  # Dependency to get the database instance
205
  async def get_db() -> AsyncGenerator[Database, None]:
206
  settings = DatabaseSettings()
src/api/models/embedding_models.py CHANGED
@@ -65,3 +65,7 @@ class SearchEmbeddingRequest(BaseModel):
65
  additional_columns: Optional[List[str]] = (
66
  None # Optional list of additional columns to include in the results
67
  )
 
 
 
 
 
65
  additional_columns: Optional[List[str]] = (
66
  None # Optional list of additional columns to include in the results
67
  )
68
+
69
+
70
class ResetEmbeddingsRequest(BaseModel):
    """Request body accepted by the ``/reset_embeddings`` endpoint."""

    # Presumably the Hugging Face dataset whose embeddings should be reset;
    # verify against the endpoint implementation.
    dataset_name: str
src/main.py CHANGED
@@ -13,8 +13,9 @@ from src.api.models.embedding_models import (
13
  DeleteEmbeddingRequest,
14
  EmbedRequest,
15
  SearchEmbeddingRequest,
 
16
  )
17
- from src.api.database import get_db, Database, QueryExecutionError, HealthCheckError
18
  from src.api.services.embedding_service import EmbeddingService
19
  from src.api.services.huggingface_service import HuggingFaceService
20
  from src.api.exceptions import DatasetNotFoundError, DatasetPushError, OpenAIError
@@ -46,12 +47,19 @@ app = FastAPI(
46
 
47
  app.add_middleware(GZipMiddleware, minimum_size=1000)
48
 
 
 
 
49
 
50
  # Dependency to get EmbeddingService
51
  def get_embedding_service() -> EmbeddingService:
52
  return EmbeddingService(openai_api_key=os.getenv("OPENAI_API_KEY"))
53
 
54
 
 
 
 
 
55
  # Dependency to get HuggingFaceService
56
  def get_huggingface_service() -> HuggingFaceService:
57
  return HuggingFaceService()
@@ -296,3 +304,21 @@ async def search_embedding(
296
  except Exception as e:
297
  logger.error(f"An error occurred: {e}")
298
  raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  DeleteEmbeddingRequest,
14
  EmbedRequest,
15
  SearchEmbeddingRequest,
16
+ ResetEmbeddingsRequest,
17
  )
18
+ from src.api.database import get_db, Database, QueryExecutionError, HealthCheckError, get_db_from_url
19
  from src.api.services.embedding_service import EmbeddingService
20
  from src.api.services.huggingface_service import HuggingFaceService
21
  from src.api.exceptions import DatasetNotFoundError, DatasetPushError, OpenAIError
 
47
 
48
  app.add_middleware(GZipMiddleware, minimum_size=1000)
49
 
50
+ # def get_database_service() -> Database:
51
+ # return Database(db_url=os.getenv("DB_URL"))
52
+
53
 
54
  # Dependency to get EmbeddingService
55
  def get_embedding_service() -> EmbeddingService:
56
  return EmbeddingService(openai_api_key=os.getenv("OPENAI_API_KEY"))
57
 
58
 
59
+ # def get_db_from_env():
60
+ # return get_db_from_url(os.getenv("DB_URL"))
61
+
62
+
63
  # Dependency to get HuggingFaceService
64
  def get_huggingface_service() -> HuggingFaceService:
65
  return HuggingFaceService()
 
304
  except Exception as e:
305
  logger.error(f"An error occurred: {e}")
306
  raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
307
+
308
+
309
@app.post("/reset_embeddings")
async def reset_embeddings(
    request: ResetEmbeddingsRequest,
    db: Database = Depends(get_db_from_url),
):
    """
    Check the health of the database obtained from ``get_db_from_url``.

    NOTE(review): despite the route name, this handler does not yet delete
    or reload any embeddings. ``request`` is validated by FastAPI but is
    otherwise unused. TODO: implement the actual reset.

    Args:
        request: Validated request body carrying ``dataset_name``.
        db: Database instance injected by the ``get_db_from_url`` dependency.

    Returns:
        ``{"status": "healthy"}`` when the health check reports success.

    Raises:
        HTTPException: 500 when the health check reports an unhealthy
            database or raises ``HealthCheckError``.
    """
    # Local import so this block does not depend on the file's import header.
    import inspect

    try:
        is_healthy = db.health_check()
        # The other Database calls in this codebase (connect/disconnect) are
        # awaited, so health_check is most likely a coroutine function. An
        # un-awaited coroutine object is always truthy, which would make this
        # endpoint report "healthy" unconditionally and would prevent
        # HealthCheckError from ever reaching the handler below.
        if inspect.isawaitable(is_healthy):
            is_healthy = await is_healthy
        if not is_healthy:
            raise HTTPException(status_code=500, detail="Database is unhealthy")
        return {"status": "healthy"}
    except HealthCheckError as e:
        # Chain the cause so the original failure stays visible in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e