import logging
from contextlib import asynccontextmanager
from typing import List, Optional

import chromadb
from cashews import cache
from fastapi import FastAPI, HTTPException, Query
from httpx import AsyncClient
from huggingface_hub import DatasetCard
from pydantic import BaseModel
from starlette.responses import RedirectResponse
from starlette.status import (
    HTTP_403_FORBIDDEN,
    HTTP_404_NOT_FOUND,
    HTTP_500_INTERNAL_SERVER_ERROR,
)

from load_card_data import get_embedding_function, get_save_path, refresh_card_data

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Set up caching (in-memory cashews backend; options are passed via the URL).
cache.setup("mem://?check_interval=10&size=1000")

# Initialize Chroma client
SAVE_PATH = get_save_path()
client = chromadb.PersistentClient(path=SAVE_PATH)
# Assigned during application startup in ``lifespan``.
collection = None

# Shared HTTP client used to fetch dataset cards from the Hub; closed on shutdown.
async_client = AsyncClient(
    follow_redirects=True,
)

# Hub card tag for datasets this service refuses to index or serve.
NOT_FOR_ALL_AUDIENCES_TAG = "not-for-all-audiences"


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Set up the Chroma collection on startup and release resources on shutdown.

    Startup opens (or creates) the ``dataset_cards`` collection with the
    project's embedding function, then calls ``refresh_card_data()``. Any
    startup failure is logged with its traceback and re-raised so the app
    does not run half-initialised. Shutdown closes the shared HTTP client.
    """
    global collection
    logger.info("Starting up the application")
    try:
        embedding_function = get_embedding_function()
        collection = client.get_or_create_collection(
            name="dataset_cards", embedding_function=embedding_function
        )
        logger.info("Collection initialized successfully")

        refresh_card_data()
        logger.info("Data refresh completed successfully")
    except Exception:
        logger.exception("Error during startup")
        raise

    try:
        yield  # The app is running and handling requests here.
    finally:
        logger.info("Shutting down the application")
        await async_client.aclose()


app = FastAPI(lifespan=lifespan)


@app.get("/", include_in_schema=False)
def root():
    """Redirect the bare root URL to the interactive API docs."""
    return RedirectResponse(url="/docs")


async def _fetch_dataset_card(hub_id: str) -> Optional[DatasetCard]:
    """Fetch and parse the ``README.md`` dataset card for ``hub_id`` from the Hub.

    Returns None when the response is not HTTP 200 or when fetching/parsing
    raises; errors are logged rather than propagated.
    """
    try:
        response = await async_client.get(
            f"https://huggingface.co/datasets/{hub_id}/raw/main/README.md"
        )
        if response.status_code != 200:
            logger.info(
                "No card available for hub_id %s (HTTP %s)",
                hub_id,
                response.status_code,
            )
            return None
        return DatasetCard(response.text)
    except Exception:
        logger.exception("Error fetching card for hub_id %s", hub_id)
        return None


async def try_get_card(hub_id: str) -> Optional[str]:
    """Return the body text of ``hub_id``'s dataset card, or None if unavailable."""
    card = await _fetch_dataset_card(hub_id)
    return None if card is None else card.text


def _is_not_for_all_audiences(card: DatasetCard) -> bool:
    """Return True if the card's metadata tags include the not-for-all-audiences tag."""
    tags = getattr(card.data, "tags", None) or []
    return NOT_FOR_ALL_AUDIENCES_TAG in tags


class QueryResult(BaseModel):
    dataset_id: str
    similarity: float


class QueryResponse(BaseModel):
    results: List[QueryResult]


class DatasetCardNotFoundError(HTTPException):
    def __init__(self, dataset_id: str):
        super().__init__(
            status_code=HTTP_404_NOT_FOUND,
            detail=f"No dataset card available for dataset: {dataset_id}",
        )


class DatasetNotForAllAudiencesError(HTTPException):
    def __init__(self, dataset_id: str):
        super().__init__(
            status_code=HTTP_403_FORBIDDEN,
            detail=f"Dataset {dataset_id} is not for all audiences and not supported in this service.",
        )


def _has_embedding(get_result: dict) -> bool:
    """Return True if a Chroma ``get`` result holds at least one embedding.

    Uses an explicit None/length test instead of truthiness, so it works
    whether Chroma returns embeddings as a list or as a numpy array.
    """
    embeddings = get_result.get("embeddings")
    return embeddings is not None and len(embeddings) > 0


def _has_results(query_result: dict) -> bool:
    """Return True if a single-query Chroma ``query`` result has any matches.

    Chroma nests results one list per query, so the outer ``ids`` list is
    non-empty even when there are no matches; check the first inner list.
    """
    ids = query_result.get("ids")
    return bool(ids) and len(ids[0]) > 0


def _to_query_response(query_result: dict) -> QueryResponse:
    """Convert a single-query Chroma ``query`` result into a ``QueryResponse``.

    Similarity is reported as ``1 - distance`` for each returned id.
    """
    ids = query_result["ids"][0]
    distances = query_result["distances"][0]
    return QueryResponse(
        results=[
            QueryResult(dataset_id=str(dataset_id), similarity=float(1 - distance))
            for dataset_id, distance in zip(ids, distances)
        ]
    )


async def _add_dataset_to_collection(dataset_id: str) -> dict:
    """Fetch, screen, embed and upsert ``dataset_id``'s card; return its ``get`` result.

    Raises:
        DatasetCardNotFoundError: no card could be fetched, or embedding/upsert failed.
        DatasetNotForAllAudiencesError: the card is tagged not-for-all-audiences.
    """
    try:
        card = await _fetch_dataset_card(dataset_id)
        if card is None:
            raise DatasetCardNotFoundError(dataset_id)
        # Screen before upserting so restricted datasets never enter the collection.
        if _is_not_for_all_audiences(card):
            raise DatasetNotForAllAudiencesError(dataset_id)

        embedding_function = get_embedding_function()
        # Chroma embedding functions take a list of documents; a bare str would
        # be iterated character-by-character by list-based implementations.
        embeddings = embedding_function([card.text])
        collection.upsert(ids=[dataset_id], embeddings=[embeddings[0]])
        logger.info("Dataset %s added to collection", dataset_id)
        return collection.get(ids=[dataset_id], include=["embeddings"])
    except (DatasetCardNotFoundError, DatasetNotForAllAudiencesError):
        raise
    except Exception as e:
        logger.exception("Error adding dataset %s to collection", dataset_id)
        raise DatasetCardNotFoundError(dataset_id) from e


@app.get("/similar", response_model=QueryResponse)
@cache(ttl="1h")
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
    """Return up to ``n`` datasets whose cards are most similar to ``dataset_id``'s.

    If ``dataset_id`` is not yet in the collection, its card is fetched from
    the Hub, screened, embedded and added first.

    Raises:
        HTTPException: 404 if no card exists or no similar datasets are found,
            403 for not-for-all-audiences datasets, 500 on unexpected errors.
    """
    try:
        logger.info("Querying dataset: %s", dataset_id)

        result = collection.get(ids=[dataset_id], include=["embeddings"])
        if not _has_embedding(result):
            logger.info("Dataset not found: %s", dataset_id)
            result = await _add_dataset_to_collection(dataset_id)

        embedding = result["embeddings"][0]
        query_result = collection.query(
            query_embeddings=[embedding], n_results=n, include=["distances"]
        )
        if not _has_results(query_result):
            logger.info("No similar datasets found for: %s", dataset_id)
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
            )

        response = _to_query_response(query_result)
        logger.info(
            "Found %d similar datasets for: %s", len(response.results), dataset_id
        )
        return response
    except HTTPException:
        # Covers DatasetCardNotFoundError / DatasetNotForAllAudiencesError too.
        raise
    except Exception as e:
        logger.exception("Error querying dataset %s", dataset_id)
        raise HTTPException(
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred.",
        ) from e


@app.post("/similar_by_text", response_model=QueryResponse)
@cache(ttl="1h")
async def api_query_by_text(query: str, n: int = Query(default=10, ge=1, le=100)):
    """Return up to ``n`` datasets whose cards are most similar to free-text ``query``.

    Raises:
        HTTPException: 404 if no similar datasets are found, 500 on unexpected errors.
    """
    try:
        logger.info("Querying datasets by text: %s", query)
        cards_collection = client.get_collection(
            name="dataset_cards", embedding_function=get_embedding_function()
        )
        query_result = cards_collection.query(
            query_texts=[query], n_results=n, include=["distances"]
        )
        if not _has_results(query_result):
            logger.info("No similar datasets found for query: %s", query)
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
            )

        response = _to_query_response(query_result)
        logger.info(
            "Found %d similar datasets for query: %s", len(response.results), query
        )
        return response
    except HTTPException:
        # Let intentional HTTP errors (e.g. the 404 above) through unchanged.
        raise
    except Exception as e:
        logger.exception("Error querying datasets by text %s", query)
        raise HTTPException(
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred.",
        ) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)