Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 2,839 Bytes
2057a2c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
from typing import Optional, List
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
import chromadb
import logging
from load_data import get_save_path, refresh_data
from cashews import cache
# Logging: process-wide INFO-level records prefixed with time and level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Caching: cashews in-memory backend ("mem://") backing the @cache decorator
# used on the query endpoint.
cache.setup("mem://?check_interval=10&size=10000")

# Chroma: open the persistent store at the path supplied by load_data and
# look up the collection the query endpoint searches.
# NOTE(review): this runs at import time, i.e. before lifespan() calls
# refresh_data(). If the collection does not exist yet (e.g. first boot),
# get_collection fails here and the app never starts; if refresh_data()
# recreates the collection, this handle may go stale — confirm against
# load_data.
SAVE_PATH = get_save_path()
client = chromadb.PersistentClient(path=SAVE_PATH)
collection = client.get_collection("dataset_cards")
class QueryResult(BaseModel):
    """One neighbour returned by the ``/query`` endpoint."""

    # Chroma record id of the matching dataset.
    dataset_id: str
    # Score computed by the endpoint as ``1 - distance``.
    similarity: float
class QueryResponse(BaseModel):
    """Response body of ``/query``: the neighbours in the order Chroma returned them."""

    results: List[QueryResult]
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup and shutdown work around the application's lifetime.

    On startup this calls ``refresh_data()`` from ``load_data``. A failure
    there is logged with its full traceback and then swallowed, so the server
    still starts and serves whatever the already-opened collection contains.

    Args:
        app: The FastAPI application (required by the lifespan protocol;
            not used here).

    Yields:
        Control back to FastAPI while the application handles requests.
    """
    logger.info("Starting up the application")
    try:
        refresh_data()
    except Exception as e:
        # Deliberately best-effort: keep the server up, but record the
        # traceback instead of only str(e).
        logger.exception("Error during data refresh: %s", e)
    else:
        logger.info("Data refresh completed successfully")

    try:
        yield  # The app is running and handling requests here.
    finally:
        # Shutdown: runs even if the server stops because of an error.
        logger.info("Shutting down the application")
        # Add any cleanup code here if needed.


app = FastAPI(lifespan=lifespan)
@app.get("/query", response_model=Optional[QueryResponse])
@cache(ttl="1h")
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
    """Return the ``n`` records nearest to ``dataset_id``'s stored embedding.

    Responses are cached for one hour per argument set via cashews.

    Args:
        dataset_id: Chroma record id whose embedding is used as the query.
        n: Number of neighbours to request (validated to 1..100).

    Returns:
        A ``QueryResponse`` listing neighbours in the order Chroma returned
        them, or ``None`` if the query produced no hits. The queried dataset
        itself is not filtered out and will normally appear among the results.

    Raises:
        HTTPException: 404 if ``dataset_id`` is not in the collection;
            500 for any other failure.
    """
    logger.info("Querying dataset: %s", dataset_id)
    try:
        # Look up the stored embedding for the requested dataset.
        result = collection.get(ids=[dataset_id], include=["embeddings"])
        embeddings = result.get("embeddings")
        # Chroma may return embeddings as a NumPy array, whose truth value is
        # ambiguous, so check for absence/emptiness explicitly.
        if embeddings is None or len(embeddings) == 0:
            logger.info("Dataset not found: %s", dataset_id)
            raise HTTPException(status_code=404, detail="Dataset not found")
        embedding = embeddings[0]

        # Search the collection for neighbours of that embedding.
        query_result = collection.query(
            query_embeddings=[embedding], n_results=n, include=["distances"]
        )
        # query() returns one list per query embedding; we sent exactly one,
        # so the hits live in the first inner list (which may itself be empty).
        hit_ids = query_result["ids"][0] if query_result["ids"] else []
        if not hit_ids:
            logger.info("No similar datasets found for: %s", dataset_id)
            return None
        hit_distances = query_result["distances"][0]

        # NOTE(review): 1 - distance is a meaningful similarity only if the
        # collection uses cosine space (hnsw:space="cosine"); confirm how
        # load_data creates "dataset_cards".
        results = [
            QueryResult(dataset_id=hit_id, similarity=1 - distance)
            for hit_id, distance in zip(hit_ids, hit_distances)
        ]
        logger.info("Found %d similar datasets for: %s", len(results), dataset_id)
        return QueryResponse(results=results)
    except HTTPException:
        # Propagate deliberate HTTP errors (e.g. the 404 above) unchanged
        # instead of re-wrapping them as a 500 below.
        raise
    except Exception as e:
        logger.exception("Error querying dataset %s: %s", dataset_id, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # When executed as a script, serve the app with uvicorn on all
    # interfaces, port 8000.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=8000)
|