davanstrien HF staff commited on
Commit
2057a2c
·
1 Parent(s): 2d844f4

chore: Refactor main.py for improved readability and maintainability

Browse files
Files changed (1) hide show
  1. main.py +96 -0
main.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, List
2
+ from contextlib import asynccontextmanager
3
+ from fastapi import FastAPI, HTTPException, Query
4
+ from pydantic import BaseModel
5
+ import chromadb
6
+ import logging
7
+ from load_data import get_save_path, refresh_data
8
+ from cashews import cache
9
+
10
+ # Set up logging
11
+ logging.basicConfig(
12
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
13
+ )
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # Set up caching
17
+ cache.setup("mem://?check_interval=10&size=10000")
18
+
19
+ # Initialize Chroma client
20
+ SAVE_PATH = get_save_path()
21
+ client = chromadb.PersistentClient(path=SAVE_PATH)
22
+ collection = client.get_collection("dataset_cards")
23
+
24
+
25
+ class QueryResult(BaseModel):
26
+ dataset_id: str
27
+ similarity: float
28
+
29
+
30
+ class QueryResponse(BaseModel):
31
+ results: List[QueryResult]
32
+
33
+
34
+ @asynccontextmanager
35
+ async def lifespan(app: FastAPI):
36
+ # Startup: refresh data
37
+ logger.info("Starting up the application")
38
+ try:
39
+ refresh_data()
40
+ logger.info("Data refresh completed successfully")
41
+ except Exception as e:
42
+ logger.error(f"Error during data refresh: {str(e)}")
43
+
44
+ yield # Here the app is running and handling requests
45
+
46
+ # Shutdown: perform any cleanup
47
+ logger.info("Shutting down the application")
48
+ # Add any cleanup code here if needed
49
+
50
+
51
+ app = FastAPI(lifespan=lifespan)
52
+
53
+
54
+ @app.get("/query", response_model=Optional[QueryResponse])
55
+ @cache(ttl="1h")
56
+ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
57
+ try:
58
+ logger.info(f"Querying dataset: {dataset_id}")
59
+ # Get the embedding for the given dataset_id
60
+ result = collection.get(ids=[dataset_id], include=["embeddings"])
61
+
62
+ if not result["embeddings"]:
63
+ logger.info(f"Dataset not found: {dataset_id}")
64
+ raise HTTPException(status_code=404, detail="Dataset not found")
65
+
66
+ embedding = result["embeddings"][0]
67
+
68
+ # Query the collection for similar datasets
69
+ query_result = collection.query(
70
+ query_embeddings=[embedding], n_results=n, include=["distances"]
71
+ )
72
+
73
+ if not query_result["ids"]:
74
+ logger.info(f"No similar datasets found for: {dataset_id}")
75
+ return None
76
+
77
+ # Prepare the response
78
+ results = [
79
+ QueryResult(dataset_id=id, similarity=1 - distance)
80
+ for id, distance in zip(
81
+ query_result["ids"][0], query_result["distances"][0]
82
+ )
83
+ ]
84
+
85
+ logger.info(f"Found {len(results)} similar datasets for: {dataset_id}")
86
+ return QueryResponse(results=results)
87
+
88
+ except Exception as e:
89
+ logger.error(f"Error querying dataset {dataset_id}: {str(e)}")
90
+ raise HTTPException(status_code=500, detail=str(e))
91
+
92
+
93
+ if __name__ == "__main__":
94
+ import uvicorn
95
+
96
+ uvicorn.run(app, host="0.0.0.0", port=8000)