import logging import os from typing import List import sys import duckdb from cashews import cache # Add this import from fastapi import FastAPI, HTTPException, Query from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from sentence_transformers import SentenceTransformer from contextlib import asynccontextmanager os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" # turn on HF_TRANSFER # Set up logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) LOCAL = False if sys.platform == "darwin": LOCAL = True DATA_DIR = "data" if LOCAL else "/data" # Configure cache cache.setup("mem://", size_limit="4gb") # Initialize FastAPI app @asynccontextmanager async def lifespan(app: FastAPI): # Startup: nothing special needed here since model and DB are initialized at module level yield # Cleanup await cache.close() con.close() app = FastAPI(lifespan=lifespan) # Add CORS middleware app.add_middleware( CORSMiddleware, allow_origins=[ "https://*.hf.space", # Allow all Hugging Face Spaces "https://*.huggingface.co", # Allow all Hugging Face domains # "http://localhost:5500", # Allow localhost:5500 # TODO remove before prod ], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Initialize model and DuckDB model = SentenceTransformer("nomic-ai/modernbert-embed-base", device="cpu") embedding_dim = model.get_sentence_embedding_dimension() # Database setup with fallback db_path = f"{DATA_DIR}/vector_store.db" try: # Create directory if it doesn't exist os.makedirs(os.path.dirname(db_path), exist_ok=True) con = duckdb.connect(db_path) logger.info(f"Connected to persistent database at {db_path}") except (OSError, PermissionError) as e: logger.warning( f"Could not create/access {db_path}. Falling back to in-memory database. Error: {e}" ) con = duckdb.connect(":memory:") # Initialize VSS extension con.sql("INSTALL vss; LOAD vss;") con.sql("SET hnsw_enable_experimental_persistence=true;") def setup_database(): try: # Create table with properly typed embeddings con.sql(f""" CREATE TABLE IF NOT EXISTS model_cards AS SELECT *, embeddings::FLOAT[{embedding_dim}] as embeddings_float FROM 'hf://datasets/davanstrien/outputs-embeddings/**/*.parquet'; """) # Check if index exists index_exists = ( con.sql(""" SELECT COUNT(*) as count FROM duckdb_indexes WHERE index_name = 'my_hnsw_index'; """).fetchone()[0] > 0 ) if index_exists: # Drop existing index con.sql("DROP INDEX my_hnsw_index;") logger.info("Dropped existing HNSW index") # Create/Recreate HNSW index con.sql(""" CREATE INDEX my_hnsw_index ON model_cards USING HNSW (embeddings_float) WITH (metric = 'cosine'); """) logger.info("Created/Recreated HNSW index") # Log the number of rows in the database row_count = con.sql("SELECT COUNT(*) as count FROM model_cards").fetchone()[0] logger.info(f"Database initialized with {row_count:,} rows") except Exception as e: logger.error(f"Setup error: {e}") # Run setup on startup setup_database() class QueryResult(BaseModel): dataset_id: str similarity: float summary: str likes: int downloads: int class QueryResponse(BaseModel): results: List[QueryResult] @app.get("/") async def redirect_to_docs(): from fastapi.responses import RedirectResponse return RedirectResponse(url="/docs") @app.get("/search/datasets", response_model=QueryResponse) @cache(ttl="10m") async def search_datasets(query: str, k: int = Query(default=5, ge=1, le=100)): try: query_embedding = model.encode(f"search_query: {query}").tolist() # Updated SQL query to include likes and downloads result = con.sql(f""" SELECT datasetId as dataset_id, 1 - array_cosine_distance( embeddings_float::FLOAT[{embedding_dim}], {query_embedding}::FLOAT[{embedding_dim}] ) as similarity, summary, likes, downloads FROM model_cards ORDER BY similarity DESC LIMIT {k}; """).df() # Updated result conversion results = [ QueryResult( dataset_id=row["dataset_id"], similarity=float(row["similarity"]), summary=row["summary"], likes=int(row["likes"]), downloads=int(row["downloads"]), ) for _, row in result.iterrows() ] return QueryResponse(results=results) except Exception as e: logger.error(f"Search error: {str(e)}") raise HTTPException(status_code=500, detail="Search failed") @app.get("/similarity/datasets", response_model=QueryResponse) @cache(ttl="10m") async def find_similar_datasets( dataset_id: str, k: int = Query(default=5, ge=1, le=100) ): try: # First, get the embedding for the input dataset_id reference_embedding = con.sql(f""" SELECT embeddings_float FROM model_cards WHERE datasetId = '{dataset_id}' LIMIT 1; """).df() if reference_embedding.empty: raise HTTPException( status_code=404, detail=f"Dataset ID '{dataset_id}' not found" ) # Updated similarity search query to include likes and downloads result = con.sql(f""" SELECT datasetId as dataset_id, 1 - array_cosine_distance( embeddings_float::FLOAT[{embedding_dim}], (SELECT embeddings_float FROM model_cards WHERE datasetId = '{dataset_id}' LIMIT 1) ) as similarity, summary, likes, downloads FROM model_cards WHERE datasetId != '{dataset_id}' ORDER BY similarity DESC LIMIT {k}; """).df() # Updated result conversion results = [ QueryResult( dataset_id=row["dataset_id"], similarity=float(row["similarity"]), summary=row["summary"], likes=int(row["likes"]), downloads=int(row["downloads"]), ) for _, row in result.iterrows() ] return QueryResponse(results=results) except HTTPException: raise except Exception as e: logger.error(f"Similarity search error: {str(e)}") raise HTTPException(status_code=500, detail="Similarity search failed") if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000)