import logging
import os
import sys
from contextlib import asynccontextmanager
from typing import List

import chromadb
import polars as pl
import torch
from cashews import cache
from chromadb.utils import embedding_functions
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import HfApi
from pydantic import BaseModel
from transformers import AutoTokenizer

# Configuration constants
MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
EMBEDDING_MODEL = "nomic-ai/modernbert-embed-base"
BATCH_SIZE = 2000
CACHE_TTL = "60"

# Pick the best available device for the embedding model
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

hf_api = HfApi()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # turn on HF_TRANSFER

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use a local data directory on macOS, the mounted /data volume otherwise
LOCAL = sys.platform == "darwin"
DATA_DIR = "data" if LOCAL else "/data"

# Configure cache
cache.setup("mem://", size_limit="5gb")

# Initialize ChromaDB client
client = chromadb.PersistentClient(path=f"{DATA_DIR}/chroma")


# Lifespan handler: populate the collections on startup, close the cache on shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Setup
    setup_database()
    yield
    # Cleanup
    await cache.close()


# Initialize FastAPI app
app = FastAPI(lifespan=lifespan)

# Add CORS middleware. CORSMiddleware matches allow_origins entries literally,
# so the wildcard Hugging Face subdomains go through allow_origin_regex instead.
app.add_middleware(
    CORSMiddleware,
    allow_origin_regex=r"https://.*\.(hf\.space|huggingface\.co)",  # Allow Hugging Face Spaces and domains
    allow_origins=[
        "http://localhost:5500",  # Allow localhost:5500  # TODO remove before prod
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Define the embedding function at module level
def get_embedding_function():
    logger.info(f"Using device: {DEVICE}")
    return embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=EMBEDDING_MODEL, device=DEVICE
    )


def setup_database():
    try:
        embedding_function = get_embedding_function()

        # Create dataset collection
        dataset_collection = client.get_or_create_collection(
            embedding_function=embedding_function,
            name="dataset_cards",
            metadata={"hnsw:space": "cosine"},
        )

        # Create model collection
        model_collection = client.get_or_create_collection(
            embedding_function=embedding_function,
            name="model_cards",
            metadata={"hnsw:space": "cosine"},
        )

        # Load dataset data
        df = pl.scan_parquet(
            "hf://datasets/davanstrien/datasets_with_metadata_and_summaries/data/train-*.parquet"
        )
        df = df.filter(
            pl.col("datasetId").str.contains_any(["open-llm-leaderboard-old/"]).not_()
        )
        row_count = df.select(pl.len()).collect().item()
        logger.info(f"Row count of dataset data: {row_count}")

        # Check if we need to update the collection
        current_count = dataset_collection.count()
        logger.info(f"Current dataset collection count: {current_count}")

        if current_count < row_count:
            logger.info(
                f"Updating dataset collection with {row_count - current_count} new records"
            )
            # Load parquet files and upsert into ChromaDB
            df = df.select(
                ["datasetId", "summary", "likes", "downloads", "last_modified"]
            )
            df = df.collect()
            total_rows = len(df)

            for i in range(0, total_rows, BATCH_SIZE):
                batch_df = df.slice(i, min(BATCH_SIZE, total_rows - i))
                dataset_collection.upsert(
                    ids=batch_df.select(["datasetId"]).to_series().to_list(),
                    documents=batch_df.select(["summary"]).to_series().to_list(),
                    metadatas=[
                        {
                            "likes": int(likes),
                            "downloads": int(downloads),
                            "last_modified": str(last_modified),
                        }
                        for likes, downloads, last_modified in zip(
                            batch_df.select(["likes"]).to_series().to_list(),
                            batch_df.select(["downloads"]).to_series().to_list(),
                            batch_df.select(["last_modified"]).to_series().to_list(),
                        )
                    ],
                )
                logger.info(f"Processed {i + len(batch_df):,} / {total_rows:,} rows")

            logger.info(
                f"Database initialized with {dataset_collection.count():,} rows"
            )

        # Load model data
        model_df = pl.scan_parquet(
            "hf://datasets/davanstrien/models_with_metadata_and_summaries/data/train-*.parquet"
        )
        model_row_count = model_df.select(pl.len()).collect().item()
        logger.info(f"Row count of new model data: {model_row_count}")

        if model_collection.count() < model_row_count:
            model_df = model_df.select(
                ["modelId", "summary", "likes", "downloads", "last_modified"]
            )
            model_df = model_df.collect()
            total_rows = len(model_df)

            for i in range(0, total_rows, BATCH_SIZE):
                batch_df = model_df.slice(i, min(BATCH_SIZE, total_rows - i))
                model_collection.upsert(
                    ids=batch_df.select(["modelId"]).to_series().to_list(),
                    documents=batch_df.select(["summary"]).to_series().to_list(),
                    metadatas=[
                        {
                            "likes": int(likes),
                            "downloads": int(downloads),
                            "last_modified": str(last_modified),
                        }
                        for likes, downloads, last_modified in zip(
                            batch_df.select(["likes"]).to_series().to_list(),
                            batch_df.select(["downloads"]).to_series().to_list(),
                            batch_df.select(["last_modified"]).to_series().to_list(),
                        )
                    ],
                )
                logger.info(
                    f"Processed {i + len(batch_df):,} / {total_rows:,} model rows"
                )

            logger.info(
                f"Model database initialized with {model_collection.count():,} rows"
            )

    except Exception as e:
        logger.error(f"Setup error: {e}")


# Run setup on startup
setup_database()


class QueryResult(BaseModel):
    dataset_id: str
    similarity: float
    summary: str
    likes: int
    downloads: int


class QueryResponse(BaseModel):
    results: List[QueryResult]


class ModelQueryResult(BaseModel):
    model_id: str
    similarity: float
    summary: str
    likes: int
    downloads: int


class ModelQueryResponse(BaseModel):
    results: List[ModelQueryResult]


@app.get("/")
async def redirect_to_docs():
    from fastapi.responses import RedirectResponse

    return RedirectResponse(url="/docs")


@app.get("/search/datasets", response_model=QueryResponse)
@cache(ttl=CACHE_TTL)
async def search_datasets(
    query: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    try:
        # Get collection with proper embedding function
        collection = client.get_collection(
            name="dataset_cards", embedding_function=get_embedding_function()
        )

        # Query ChromaDB
        results = collection.query(
            query_texts=[f"search_query: {query}"],
            n_results=k * 4 if sort_by != "similarity" else k,
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )

        # Process results
        query_results = process_search_results(results, "dataset", k, sort_by)
        return QueryResponse(results=query_results)

    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Search failed")


@app.get("/similarity/datasets", response_model=QueryResponse)
@cache(ttl=CACHE_TTL)
async def find_similar_datasets(
    dataset_id: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    try:
        collection = client.get_collection("dataset_cards")

        # Get the reference document
        results = collection.get(ids=[dataset_id], include=["embeddings"])
        if not results["ids"]:
            raise HTTPException(
                status_code=404, detail=f"Dataset ID '{dataset_id}' not found"
            )

        # Query using the embedding
        results = collection.query(
            query_embeddings=[results["embeddings"][0]],
            n_results=k * 4
            if sort_by != "similarity"
            else k + 1,  # +1 to account for self-match
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )

        # Process results (excluding the query dataset itself)
        query_results = process_search_results(
            results, "dataset", k, sort_by, dataset_id
        )
        return QueryResponse(results=query_results)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Similarity search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Similarity search failed")


@app.get("/search/models", response_model=ModelQueryResponse)
@cache(ttl=CACHE_TTL)
async def search_models(
    query: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    try:
        collection = client.get_collection(
            name="model_cards", embedding_function=get_embedding_function()
        )
        results = collection.query(
            query_texts=[f"search_query: {query}"],
            n_results=k * 4 if sort_by != "similarity" else k,
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )
        query_results = process_search_results(results, "model", k, sort_by)
        return ModelQueryResponse(results=query_results)

    except Exception as e:
        logger.error(f"Model search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Model search failed")


@app.get("/similarity/models", response_model=ModelQueryResponse)
@cache(ttl=CACHE_TTL)
async def find_similar_models(
    model_id: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    try:
        collection = client.get_collection("model_cards")
        results = collection.get(ids=[model_id], include=["embeddings"])
        if not results["ids"]:
            raise HTTPException(
                status_code=404, detail=f"Model ID '{model_id}' not found"
            )

        results = collection.query(
            query_embeddings=[results["embeddings"][0]],
            n_results=k * 4 if sort_by != "similarity" else k + 1,
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )
        query_results = process_search_results(results, "model", k, sort_by, model_id)
        return ModelQueryResponse(results=query_results)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Model similarity search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Model similarity search failed")


def process_search_results(results, id_field, k, sort_by, exclude_id=None):
    """Process search results into a standardized format."""
    query_results = []
    for i in range(len(results["ids"][0])):
        current_id = results["ids"][0][i]
        if exclude_id and current_id == exclude_id:
            continue

        result = {
            f"{id_field}_id": current_id,
            # ChromaDB returns cosine distance; convert it to a similarity score
            "similarity": 1 - float(results["distances"][0][i]),
            "summary": results["documents"][0][i],
            "likes": results["metadatas"][0][i]["likes"],
            "downloads": results["metadatas"][0][i]["downloads"],
        }
        if id_field == "dataset":
            query_results.append(QueryResult(**result))
        else:
            query_results.append(ModelQueryResult(**result))

    if sort_by != "similarity":
        query_results.sort(key=lambda x: getattr(x, sort_by), reverse=True)
        query_results = query_results[:k]
    elif exclude_id:
        # We fetched extra for the similarity + exclude_id case
        query_results = query_results[:k]

    return query_results


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
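

# Example client calls (a sketch, not part of the app itself): assumes the
# server is running locally on port 8000 and that `httpx` is installed; the
# query text and the "stanfordnlp/imdb" dataset ID are purely illustrative.
#
#     import httpx
#
#     # Free-text semantic search over dataset cards
#     resp = httpx.get(
#         "http://localhost:8000/search/datasets",
#         params={"query": "sentiment analysis of movie reviews", "k": 5},
#     )
#     print(resp.json()["results"])
#
#     # Find datasets similar to a known one, ranked by likes
#     resp = httpx.get(
#         "http://localhost:8000/similarity/datasets",
#         params={"dataset_id": "stanfordnlp/imdb", "k": 5, "sort_by": "likes"},
#     )
#     print(resp.json()["results"])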