from typing import Optional, List
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
import chromadb
import logging
from load_data import get_save_path, refresh_data
from cashews import cache

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Set up caching
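# In-memory backend: size caps the number of cached entries and
# check_interval (seconds) sets how often expired entries are swept.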
cache.setup("mem://?check_interval=10&size=10000")

# Initialize Chroma client
SAVE_PATH = get_save_path()
client = chromadb.PersistentClient(path=SAVE_PATH)
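# get_collection raises if "dataset_cards" does not exist yet; the collection
# is assumed to have been created by load_data (e.g. an earlier refresh_data run).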
collection = client.get_collection("dataset_cards")


class QueryResult(BaseModel):
    dataset_id: str
    similarity: float


class QueryResponse(BaseModel):
    results: List[QueryResult]


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: refresh data
    logger.info("Starting up the application")
    try:
        refresh_data()
        logger.info("Data refresh completed successfully")
    except Exception as e:
        logger.error(f"Error during data refresh: {str(e)}")

    yield  # Here the app is running and handling requests

    # Shutdown: perform any cleanup
    logger.info("Shutting down the application")
    # Add any cleanup code here if needed


app = FastAPI(lifespan=lifespan)


@app.get("/query", response_model=Optional[QueryResponse])
@cache(ttl="1h")
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
    try:
        logger.info(f"Querying dataset: {dataset_id}")
        # Get the embedding for the given dataset_id
        result = collection.get(ids=[dataset_id], include=["embeddings"])

        # Length check instead of truthiness: newer Chroma versions return
        # numpy arrays here, whose truth value is ambiguous
        if result["embeddings"] is None or len(result["embeddings"]) == 0:
            logger.info(f"Dataset not found: {dataset_id}")
            raise HTTPException(status_code=404, detail="Dataset not found")

        embedding = result["embeddings"][0]

        # Query the collection for similar datasets
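        # (the queried dataset is normally returned as its own nearest neighbour)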
        query_result = collection.query(
            query_embeddings=[embedding], n_results=n, include=["distances"]
        )

        # ids is nested per query embedding, so check the inner list too
        if not query_result["ids"] or not query_result["ids"][0]:
            logger.info(f"No similar datasets found for: {dataset_id}")
            return None

        # Prepare the response
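        # 1 - distance equals cosine similarity only when the collection is
        # built with cosine distance, which is assumed here.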
        results = [
            QueryResult(dataset_id=ds_id, similarity=1 - distance)
            for ds_id, distance in zip(
                query_result["ids"][0], query_result["distances"][0]
            )
        ]

        logger.info(f"Found {len(results)} similar datasets for: {dataset_id}")
        return QueryResponse(results=results)

    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 404 above) pass through
        # instead of being masked as 500s by the handler below
        raise
    except Exception as e:
        logger.error(f"Error querying dataset {dataset_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
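
# Example request once the server is running (the dataset id is illustrative):
#   curl "http://localhost:8000/query?dataset_id=imdb&n=5"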