import asyncio
import logging
import os
import sys
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import Generator, List, Literal, Optional

import chromadb
from chromadb.utils import embedding_functions
from cashews import cache
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import polars as pl
from huggingface_hub import hf_hub_url, DatasetCard, ModelCard, HfApi
from huggingface_hub import ModelInfo, DatasetInfo
import stamina
from huggingface_hub import dataset_info
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer
from tqdm.contrib.concurrent import thread_map
from datasets import Dataset, Value, Sequence
import datasets
from dotenv import load_dotenv
from huggingface_hub import get_inference_endpoint
from huggingface_hub import AsyncInferenceClient

hf_api = HfApi()
tokenizer = AutoTokenizer.from_pretrained(
    "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
)
# NOTE(review): huggingface_hub appears to read HF_HUB_ENABLE_HF_TRANSFER when
# it is first imported (above), so this assignment likely has no effect on
# downloads made in this process. Moving it above the imports would actually
# enable hf_transfer (which requires the `hf_transfer` package) -- confirm
# before moving it.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # turn on HF_TRANSFER

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# macOS is treated as a local dev machine (relative ./data); anything else
# uses the absolute /data directory.
LOCAL = sys.platform == "darwin"
DATA_DIR = "data" if LOCAL else "/data"

# Configure the in-memory response cache used by the endpoints below.
cache.setup("mem://", size_limit="4gb")

# Initialize the on-disk ChromaDB client
client = chromadb.PersistentClient(path=f"{DATA_DIR}/chroma")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialise the vector store on startup and close the cache on shutdown."""
    setup_database()
    yield
    await cache.close()


app = FastAPI(lifespan=lifespan)

# Starlette compares `allow_origins` entries literally, so glob-style entries
# such as "https://*.hf.space" never match a real origin. Use an anchored
# regex to allow any subdomain of hf.space or huggingface.co instead.
app.add_middleware(
    CORSMiddleware,
    allow_origin_regex=r"^https://([a-z0-9-]+\.)+(hf\.space|huggingface\.co)$",
    # For local testing, also pass allow_origins=["http://localhost:5500"].
    # TODO remove before prod
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def get_embedding_function():
    """Return the sentence-transformers embedding function used for the collection."""
    return embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="nomic-ai/modernbert-embed-base"
    )


# Becomes True once setup_database() has completed successfully, so the
# import-time call and the lifespan call do not both scan the Hub data.
_database_ready = False


def setup_database():
    """Populate the ``dataset_cards`` Chroma collection from the Hub parquet dump.

    Scans ``davanstrien/datasets_with_metadata_and_summaries`` and drops any
    ``datasetId`` containing ``open-llm-leaderboard-old/``. When the collection
    holds fewer rows than the source, upserts every row in batches of 1000.
    Each row stores its summary as the document and likes/downloads/
    last_modified as metadata.

    Runs at most once successfully per process. Any exception is logged with
    its traceback and swallowed, so a failed setup does not stop the app from
    starting; a later call will retry.
    """
    global _database_ready
    if _database_ready:
        return
    try:
        embedding_function = get_embedding_function()
        dataset_collection = client.get_or_create_collection(
            embedding_function=embedding_function,
            name="dataset_cards",
            metadata={"hnsw:space": "cosine"},
        )
        # TODO incremental updates
        lazy_df = pl.scan_parquet(
            "hf://datasets/davanstrien/datasets_with_metadata_and_summaries/data/train-*.parquet"
        ).filter(
            pl.col("datasetId").str.contains_any(["open-llm-leaderboard-old/"]).not_()
        )
        row_count = lazy_df.select(pl.len()).collect().item()
        logger.info(f"Row count of new data: {row_count}")

        # Only a row-count comparison: rows changed in place are not refreshed.
        if dataset_collection.count() < row_count:
            df = lazy_df.select(
                ["datasetId", "summary", "likes", "downloads", "last_modified"]
            ).collect()
            BATCH_SIZE = 1000
            total_rows = len(df)
            for start in range(0, total_rows, BATCH_SIZE):
                # polars truncates slices that run past the end of the frame.
                batch_df = df.slice(start, BATCH_SIZE)
                dataset_collection.upsert(
                    ids=batch_df["datasetId"].to_list(),
                    documents=batch_df["summary"].to_list(),
                    metadatas=[
                        {
                            "likes": int(likes),
                            "downloads": int(downloads),
                            "last_modified": str(last_modified),
                        }
                        for likes, downloads, last_modified in zip(
                            batch_df["likes"].to_list(),
                            batch_df["downloads"].to_list(),
                            batch_df["last_modified"].to_list(),
                        )
                    ],
                )
                logger.info(
                    f"Processed {start + len(batch_df):,} / {total_rows:,} rows"
                )
        logger.info(f"Database initialized with {dataset_collection.count():,} rows")
        # TODO: add a `model_cards` collection built the same way from
        # librarian-bots/model_cards_with_metadata.
        _database_ready = True
    except Exception as e:
        logger.exception("Setup error: %s", e)


# Run setup on import; the lifespan hook calls it again, which is then a no-op.
setup_database()


class QueryResult(BaseModel):
    """A single hit returned by the search and similarity endpoints."""

    dataset_id: str
    # NOTE(review): filled from Chroma's `distances` on a collection created
    # with "hnsw:space": "cosine", so this is a distance (lower = closer),
    # not a similarity. Left as-is for API compatibility -- confirm with clients.
    similarity: float
    summary: str
    likes: int
    downloads: int


class QueryResponse(BaseModel):
    """Response envelope wrapping an ordered list of QueryResult hits."""

    results: List[QueryResult]


# Allowed values for `sort_by`. Typing the parameter as a Literal makes
# FastAPI reject other values with a 422; `Query(enum=...)` only documents them.
SortBy = Literal["similarity", "likes", "downloads"]


def _build_where_filter(min_likes: int, min_downloads: int) -> Optional[dict]:
    """Return a Chroma metadata filter, or None when neither minimum is set."""
    if min_likes <= 0 and min_downloads <= 0:
        return None
    return {
        "$and": [
            {"likes": {"$gte": min_likes}},
            {"downloads": {"$gte": min_downloads}},
        ]
    }


def _to_query_results(results, exclude_id: Optional[str] = None) -> List[QueryResult]:
    """Convert the first query's hits of a Chroma `query` response, skipping `exclude_id`."""
    return [
        QueryResult(
            dataset_id=dataset_id,
            similarity=float(distance),
            summary=summary,
            likes=metadata["likes"],
            downloads=metadata["downloads"],
        )
        for dataset_id, distance, summary, metadata in zip(
            results["ids"][0],
            results["distances"][0],
            results["documents"][0],
            results["metadatas"][0],
        )
        if dataset_id != exclude_id
    ]


def _rank(query_results: List[QueryResult], sort_by: str, k: int) -> List[QueryResult]:
    """Keep Chroma's nearest-first order, or re-sort by a popularity field descending, then take k."""
    if sort_by != "similarity":
        query_results = sorted(
            query_results, key=lambda r: getattr(r, sort_by), reverse=True
        )
    return query_results[:k]


@app.get("/")
async def redirect_to_docs():
    """Redirect the root path to the interactive API docs."""
    from fastapi.responses import RedirectResponse

    return RedirectResponse(url="/docs")


@app.get("/search/datasets", response_model=QueryResponse)
@cache(ttl="10m")
async def search_datasets(
    query: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: SortBy = Query(default="similarity"),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    """Semantic search over dataset summaries.

    Returns up to ``k`` datasets nearest to ``query``, optionally filtered by
    minimum likes/downloads and optionally re-ranked by likes or downloads.

    Raises:
        HTTPException: 500 when the Chroma lookup or result conversion fails.
    """
    try:
        collection = client.get_collection(
            name="dataset_cards", embedding_function=get_embedding_function()
        )
        # When re-ranking by popularity, over-fetch so the top-k by likes or
        # downloads is chosen from a wider pool of semantic candidates.
        n_results = k if sort_by == "similarity" else k * 4
        # NOTE(review): the query gets a "search_query: " task prefix, but the
        # documents are upserted without a matching prefix -- confirm intended.
        results = collection.query(
            query_texts=[f"search_query: {query}"],
            n_results=n_results,
            where=_build_where_filter(min_likes, min_downloads),
        )
        return QueryResponse(results=_rank(_to_query_results(results), sort_by, k))
    except Exception as e:
        logger.exception("Search error: %s", e)
        raise HTTPException(status_code=500, detail="Search failed") from e


@app.get("/similarity/datasets", response_model=QueryResponse)
@cache(ttl="10m")
async def find_similar_datasets(
    dataset_id: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: SortBy = Query(default="similarity"),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    """Find up to ``k`` datasets whose stored embedding is nearest to ``dataset_id``'s.

    The reference dataset itself is excluded from the results.

    Raises:
        HTTPException: 404 if ``dataset_id`` is not in the collection, 500 on
            any other failure.
    """
    try:
        collection = client.get_collection("dataset_cards")
        reference = collection.get(ids=[dataset_id], include=["embeddings"])
        if not reference["ids"]:
            raise HTTPException(
                status_code=404, detail=f"Dataset ID '{dataset_id}' not found"
            )
        # +1 because the reference dataset normally comes back as its own
        # nearest neighbour and is dropped below.
        n_results = k + 1 if sort_by == "similarity" else k * 4
        results = collection.query(
            query_embeddings=[reference["embeddings"][0]],
            n_results=n_results,
            where=_build_where_filter(min_likes, min_downloads),
        )
        query_results = _to_query_results(results, exclude_id=dataset_id)
        return QueryResponse(results=_rank(query_results, sort_by, k))
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Similarity search error: %s", e)
        raise HTTPException(
            status_code=500, detail="Similarity search failed"
        ) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)