import asyncio
import logging
import os
import sys
from contextlib import asynccontextmanager
from datetime import datetime
from typing import List

import chromadb
import dateutil.parser
import httpx
import polars as pl
import torch
from cashews import cache
from chromadb.utils import embedding_functions
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer

# Configuration constants
MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
EMBEDDING_MODEL = "nomic-ai/modernbert-embed-base"
BATCH_SIZE = 2000
CACHE_TTL = "24h"
TRENDING_CACHE_TTL = "1h"  # 1 hour cache for trending data

if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # turn on hf_transfer for faster Hub downloads

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Store data under ./data when running locally (macOS); use /data otherwise
LOCAL = False
if sys.platform == "darwin":
    LOCAL = True
DATA_DIR = "data" if LOCAL else "/data"

# Configure cache
cache.setup("mem://", size_limit="8gb")

# Initialize ChromaDB client
client = chromadb.PersistentClient(path=f"{DATA_DIR}/chroma")


# Initialize FastAPI app
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Setup
    setup_database()
    yield
    # Cleanup
    await cache.close()


app = FastAPI(lifespan=lifespan)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "https://*.hf.space",  # Allow all Hugging Face Spaces
        "https://*.huggingface.co",  # Allow all Hugging Face domains
        # "http://localhost:5500",  # Allow localhost:5500  # TODO remove before prod
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Define the embedding function at module level
def get_embedding_function():
    logger.info(f"Using device: {DEVICE}")
    return embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=EMBEDDING_MODEL, device=DEVICE
    )
def setup_database():
    try:
        embedding_function = get_embedding_function()
        dataset_collection = client.get_or_create_collection(
            embedding_function=embedding_function,
            name="dataset_cards",
            metadata={"hnsw:space": "cosine"},
        )
        model_collection = client.get_or_create_collection(
            embedding_function=embedding_function,
            name="model_cards",
            metadata={"hnsw:space": "cosine"},
        )

        # Load dataset data
        df = pl.scan_parquet(
            "hf://datasets/davanstrien/datasets_with_metadata_and_summaries/data/train-*.parquet"
        )
        df = df.filter(
            pl.col("datasetId").str.contains_any(["open-llm-leaderboard-old/"]).not_()
        )
        df = df.filter(
            pl.col("datasetId")
            .str.contains_any(
                ["gemma-2-2B-it-thinking-function_calling-V0"]
            )  # course-exercise repos that aren't useful for retrieval
            .not_()
        )

        # Get the most recent last_modified date from the collection
        latest_update = None
        if dataset_collection.count() > 0:
            metadata = dataset_collection.get(include=["metadatas"]).get("metadatas")
            logger.info(f"Found {len(metadata)} existing records in collection")
            last_modifieds = [
                dateutil.parser.parse(m.get("last_modified")) for m in metadata
            ]
            latest_update = max(last_modifieds)
            logger.info(f"Most recent record in DB from: {latest_update}")
            logger.info(f"Oldest record in DB from: {min(last_modifieds)}")

        # Filter and process only newer records
        df = df.select(["datasetId", "summary", "likes", "downloads", "last_modified"])

        # Log some stats about the incoming data
        sample_dates = df.select("last_modified").limit(5).collect()
        logger.info(f"Sample of incoming dates: {sample_dates}")
        total_incoming = df.select(pl.len()).collect().item()
        logger.info(f"Total incoming records: {total_incoming}")

        if latest_update:
            logger.info(f"Filtering records newer than {latest_update}")
            df = df.filter(pl.col("last_modified") > latest_update)
            filtered_count = df.select(pl.len()).collect().item()
            logger.info(f"Found {filtered_count} records to update after filtering")

        df = df.collect()
        total_rows = len(df)

        if total_rows > 0:
            logger.info(f"Updating dataset collection with {total_rows} new records")
            logger.info(
                f"Date range of updates: {df['last_modified'].min()} to {df['last_modified'].max()}"
            )

            for i in range(0, total_rows, BATCH_SIZE):
                batch_df = df.slice(i, min(BATCH_SIZE, total_rows - i))
                batch_size = len(batch_df)
                logger.info(
                    f"Processing batch {i // BATCH_SIZE + 1}: {batch_size} records "
                    f"({batch_df['last_modified'].min()} to {batch_df['last_modified'].max()})"
                )
                dataset_collection.upsert(
                    ids=batch_df.select(["datasetId"]).to_series().to_list(),
                    documents=batch_df.select(["summary"]).to_series().to_list(),
                    metadatas=[
                        {
                            "likes": int(likes),
                            "downloads": int(downloads),
                            "last_modified": str(last_modified),
                        }
                        for likes, downloads, last_modified in zip(
                            batch_df.select(["likes"]).to_series().to_list(),
                            batch_df.select(["downloads"]).to_series().to_list(),
                            batch_df.select(["last_modified"]).to_series().to_list(),
                        )
                    ],
                )
                logger.info(f"Processed {i + batch_size:,} / {total_rows:,} records")

        logger.info(
            f"Database initialized with {dataset_collection.count():,} total rows"
        )

        # Load model data
        model_df = pl.scan_parquet(
            "hf://datasets/davanstrien/models_with_metadata_and_summaries/data/train-*.parquet"
        )
        model_row_count = model_df.select(pl.len()).collect().item()
        logger.info(f"Row count of new model data: {model_row_count}")

        if model_collection.count() < model_row_count:
            model_df = model_df.select(
                ["modelId", "summary", "likes", "downloads", "last_modified"]
            )
            model_df = model_df.collect()
            total_rows = len(model_df)

            for i in range(0, total_rows, BATCH_SIZE):
                batch_df = model_df.slice(i, min(BATCH_SIZE, total_rows - i))
                model_collection.upsert(
                    ids=batch_df.select(["modelId"]).to_series().to_list(),
                    documents=batch_df.select(["summary"]).to_series().to_list(),
                    metadatas=[
                        {
                            "likes": int(likes),
                            "downloads": int(downloads),
                            "last_modified": str(last_modified),
                        }
                        for likes, downloads, last_modified in zip(
                            batch_df.select(["likes"]).to_series().to_list(),
                            batch_df.select(["downloads"]).to_series().to_list(),
                            batch_df.select(["last_modified"]).to_series().to_list(),
                        )
                    ],
                )
                logger.info(
                    f"Processed {i + len(batch_df):,} / {total_rows:,} model rows"
                )

        logger.info(
            f"Model database initialized with {model_collection.count():,} rows"
        )

    except Exception as e:
        logger.error(f"Setup error: {e}")


# Run setup on startup
setup_database()


class QueryResult(BaseModel):
    dataset_id: str
    similarity: float
    summary: str
    likes: int
    downloads: int


class QueryResponse(BaseModel):
    results: List[QueryResult]


class ModelQueryResult(BaseModel):
    model_id: str
    similarity: float
    summary: str
    likes: int
    downloads: int


class ModelQueryResponse(BaseModel):
    results: List[ModelQueryResult]


@app.get("/")
async def redirect_to_docs():
    from fastapi.responses import RedirectResponse

    return RedirectResponse(url="/docs")


@app.get("/search/datasets", response_model=QueryResponse)
@cache(ttl=CACHE_TTL)
async def search_datasets(
    query: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads", "trending"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    try:
        collection = client.get_collection(
            name="dataset_cards", embedding_function=get_embedding_function()
        )
        # The nomic-style embedding model expects a "search_query: " prefix on queries
        results = collection.query(
            query_texts=[f"search_query: {query}"],
            n_results=k * 4 if sort_by != "similarity" else k,
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )
        query_results = await process_search_results(results, "dataset", k, sort_by)
        return QueryResponse(results=query_results)
    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Search failed")
@app.get("/similarity/datasets", response_model=QueryResponse)
@cache(ttl=CACHE_TTL)
async def find_similar_datasets(
    dataset_id: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads", "trending"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    try:
        collection = client.get_collection("dataset_cards")
        results = collection.get(ids=[dataset_id], include=["embeddings"])
        if not results["ids"]:
            raise HTTPException(
                status_code=404, detail=f"Dataset ID '{dataset_id}' not found"
            )

        results = collection.query(
            query_embeddings=[results["embeddings"][0]],
            n_results=k * 4 if sort_by != "similarity" else k + 1,
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )
        query_results = await process_search_results(
            results, "dataset", k, sort_by, dataset_id
        )
        return QueryResponse(results=query_results)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Similarity search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Similarity search failed")


@app.get("/search/models", response_model=ModelQueryResponse)
@cache(ttl=CACHE_TTL)
async def search_models(
    query: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads", "trending"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    try:
        collection = client.get_collection(
            name="model_cards", embedding_function=get_embedding_function()
        )
        results = collection.query(
            query_texts=[f"search_query: {query}"],
            n_results=k * 4 if sort_by != "similarity" else k,
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )
        query_results = await process_search_results(results, "model", k, sort_by)
        return ModelQueryResponse(results=query_results)
    except Exception as e:
        logger.error(f"Model search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Model search failed")


@app.get("/similarity/models", response_model=ModelQueryResponse)
@cache(ttl=CACHE_TTL)
async def find_similar_models(
    model_id: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads", "trending"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    try:
        collection = client.get_collection("model_cards")
        results = collection.get(ids=[model_id], include=["embeddings"])
        if not results["ids"]:
            raise HTTPException(
                status_code=404, detail=f"Model ID '{model_id}' not found"
            )

        results = collection.query(
            query_embeddings=[results["embeddings"][0]],
            n_results=k * 4 if sort_by != "similarity" else k + 1,
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )
        query_results = await process_search_results(
            results, "model", k, sort_by, model_id
        )
        return ModelQueryResponse(results=query_results)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Model similarity search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Model similarity search failed")


@cache(ttl="1h")
async def get_trending_score(item_id: str, item_type: str) -> float:
    """Fetch trending score for a model or dataset from HuggingFace API"""
    try:
        async with httpx.AsyncClient() as client:
            endpoint = "models" if item_type == "model" else "datasets"
            response = await client.get(
                f"https://huggingface.co/api/{endpoint}/{item_id}?expand=trendingScore"
            )
            response.raise_for_status()
            return response.json().get("trendingScore", 0)
    except Exception as e:
        logger.error(
            f"Error fetching trending score for {item_type} {item_id}: {str(e)}"
        )
        return 0
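# For reference, process_search_results below assumes the dict returned by
# collection.query(...) is shaped roughly like the following (values are purely
# illustrative, not real data):
#
#     {
#         "ids": [["user/repo-a", "user/repo-b"]],
#         "distances": [[0.12, 0.34]],
#         "documents": [["Summary of repo a ...", "Summary of repo b ..."]],
#         "metadatas": [[{"likes": 3, "downloads": 120, "last_modified": "..."}, ...]],
#     }
#
# There is one inner list per query text/embedding, which is why results are indexed
# as results["ids"][0][i], results["distances"][0][i], and so on.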
async def process_search_results(results, id_field, k, sort_by, exclude_id=None):
    """Process search results into a standardized format."""
    query_results = []

    # Create base results
    for i in range(len(results["ids"][0])):
        current_id = results["ids"][0][i]
        if exclude_id and current_id == exclude_id:
            continue

        result = {
            f"{id_field}_id": current_id,
            "similarity": float(results["distances"][0][i]),
            "summary": results["documents"][0][i],
            "likes": results["metadatas"][0][i]["likes"],
            "downloads": results["metadatas"][0][i]["downloads"],
        }

        if id_field == "dataset":
            query_results.append(QueryResult(**result))
        else:
            query_results.append(ModelQueryResult(**result))

    # Handle sorting
    if sort_by == "trending":
        # Fetch trending scores for all results concurrently
        # (each get_trending_score call manages its own HTTP client)
        tasks = [
            get_trending_score(
                getattr(result, f"{id_field}_id"),
                "model" if id_field == "model" else "dataset",
            )
            for result in query_results
        ]
        scores = await asyncio.gather(*tasks)
        trending_scores = {
            getattr(result, f"{id_field}_id"): score
            for result, score in zip(query_results, scores)
        }

        # Sort by trending score
        query_results.sort(
            key=lambda x: trending_scores.get(getattr(x, f"{id_field}_id"), 0),
            reverse=True,
        )
        query_results = query_results[:k]
    elif sort_by != "similarity":
        query_results.sort(key=lambda x: getattr(x, sort_by), reverse=True)
        query_results = query_results[:k]
    elif exclude_id:
        # We fetched extra results for the similarity + exclude_id case
        query_results = query_results[:k]

    return query_results


async def fetch_trending_models():
    """Fetch trending models from HuggingFace API"""
    async with httpx.AsyncClient() as client:
        response = await client.get("https://huggingface.co/api/models")
        response.raise_for_status()
        return response.json()
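# The listing endpoint above returns a JSON array. Based on the fields accessed in the
# functions below, each entry is expected to carry at least "modelId", "likes",
# "downloads", and "trendingScore" (missing fields fall back to 0 via .get()).
# A trimmed, illustrative entry:
#
#     {"modelId": "org/model-name", "likes": 42, "downloads": 10000, "trendingScore": 7}
#
# The dataset listing handled further down is keyed by "id" rather than "modelId".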
@cache(ttl=TRENDING_CACHE_TTL)
async def get_trending_models_with_summaries(
    limit: int = 10,
    min_likes: int = 0,
    min_downloads: int = 0,
) -> List[ModelQueryResult]:
    """Fetch trending models and combine with summaries from database"""
    try:
        # Fetch trending models
        trending_models = await fetch_trending_models()

        # Filter by minimum likes/downloads
        trending_models = [
            model
            for model in trending_models
            if model.get("likes", 0) >= min_likes
            and model.get("downloads", 0) >= min_downloads
        ]

        # Sort by trending score and limit
        trending_models = sorted(
            trending_models, key=lambda x: x.get("trendingScore", 0), reverse=True
        )[:limit]

        # Get model IDs
        model_ids = [model["modelId"] for model in trending_models]

        # Fetch summaries from ChromaDB
        collection = client.get_collection("model_cards")
        summaries = collection.get(ids=model_ids, include=["documents"])

        # Create mapping of model_id to summary
        id_to_summary = dict(zip(summaries["ids"], summaries["documents"]))

        # Combine data
        results = []
        for model in trending_models:
            if model["modelId"] in id_to_summary:
                result = ModelQueryResult(
                    model_id=model["modelId"],
                    similarity=1.0,  # Not applicable for trending
                    summary=id_to_summary[model["modelId"]],
                    likes=model.get("likes", 0),
                    downloads=model.get("downloads", 0),
                )
                results.append(result)

        return results

    except Exception as e:
        logger.error(f"Error fetching trending models: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to fetch trending models")


@app.get("/trending/models", response_model=ModelQueryResponse)
async def get_trending_models(
    limit: int = Query(default=10, ge=1, le=100),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    """Get trending models with their summaries"""
    results = await get_trending_models_with_summaries(
        limit=limit, min_likes=min_likes, min_downloads=min_downloads
    )
    return ModelQueryResponse(results=results)


async def fetch_trending_datasets():
    """Fetch trending datasets from HuggingFace API"""
    async with httpx.AsyncClient() as client:
        response = await client.get("https://huggingface.co/api/datasets")
        response.raise_for_status()
        return response.json()


@cache(ttl=TRENDING_CACHE_TTL)
async def get_trending_datasets_with_summaries(
    limit: int = 10,
    min_likes: int = 0,
    min_downloads: int = 0,
) -> List[QueryResult]:
    """Fetch trending datasets and combine with summaries from database"""
    try:
        # Fetch trending datasets
        trending_datasets = await fetch_trending_datasets()

        # Filter by minimum likes/downloads
        trending_datasets = [
            dataset
            for dataset in trending_datasets
            if dataset.get("likes", 0) >= min_likes
            and dataset.get("downloads", 0) >= min_downloads
        ]

        # Sort by trending score and limit
        trending_datasets = sorted(
            trending_datasets, key=lambda x: x.get("trendingScore", 0), reverse=True
        )[:limit]

        # Get dataset IDs
        dataset_ids = [dataset["id"] for dataset in trending_datasets]

        # Fetch summaries from ChromaDB
        collection = client.get_collection("dataset_cards")
        summaries = collection.get(ids=dataset_ids, include=["documents"])

        # Create mapping of dataset_id to summary
        id_to_summary = dict(zip(summaries["ids"], summaries["documents"]))

        # Combine data
        results = []
        for dataset in trending_datasets:
            if dataset["id"] in id_to_summary:
                result = QueryResult(
                    dataset_id=dataset["id"],
                    similarity=1.0,  # Not applicable for trending
                    summary=id_to_summary[dataset["id"]],
                    likes=dataset.get("likes", 0),
                    downloads=dataset.get("downloads", 0),
                )
                results.append(result)

        return results

    except Exception as e:
        logger.error(f"Error fetching trending datasets: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to fetch trending datasets")


@app.get("/trending/datasets", response_model=QueryResponse)
async def get_trending_datasets(
    limit: int = Query(default=10, ge=1, le=100),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    """Get trending datasets with their summaries"""
    results = await get_trending_datasets_with_summaries(
        limit=limit, min_likes=min_likes, min_downloads=min_downloads
    )
    return QueryResponse(results=results)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
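# Example client usage (illustrative sketch only; assumes the API is running locally on
# port 8000 and the query/ID values shown are stand-ins):
#
#     import httpx
#
#     resp = httpx.get(
#         "http://localhost:8000/search/datasets",
#         params={"query": "multilingual instruction tuning", "k": 5},
#     )
#     for hit in resp.json()["results"]:
#         print(hit["dataset_id"], hit["similarity"], hit["summary"])
#
#     resp = httpx.get(
#         "http://localhost:8000/similarity/models",
#         params={"model_id": "<model-id-in-the-index>", "k": 5, "sort_by": "likes"},
#     )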