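"""FastAPI service for semantic search over Hugging Face dataset summaries.

Queries are embedded with nomic-ai/modernbert-embed-base and matched against a
DuckDB vector store (VSS extension, HNSW index) built from pre-computed
embeddings in hf://datasets/davanstrien/outputs-embeddings.
"""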
import logging
import os

# HF_HUB_ENABLE_HF_TRANSFER must be set before huggingface_hub is imported
# (sentence_transformers pulls it in below), otherwise the flag is ignored.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import sys
from contextlib import asynccontextmanager
from typing import List

import duckdb
from cashews import cache
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use a local data directory when developing on macOS, /data otherwise
LOCAL = sys.platform == "darwin"
DATA_DIR = "data" if LOCAL else "/data"

# Configure cache
cache.setup("mem://", size_limit="4gb")
# Lifespan handler: model and DB are initialized at module level, so startup
# needs nothing extra; shutdown closes the cache and the DuckDB connection.
@asynccontextmanager
async def lifespan(app: FastAPI):
    yield
    # Cleanup
    await cache.close()
    con.close()


# Initialize FastAPI app
app = FastAPI(lifespan=lifespan)
# Add CORS middleware. Starlette matches allow_origins entries as exact
# strings, so wildcard subdomains must go through allow_origin_regex.
app.add_middleware(
    CORSMiddleware,
    allow_origin_regex=r"https://.*\.(hf\.space|huggingface\.co)",  # all Hugging Face Spaces and domains
    # allow_origins=["http://localhost:5500"],  # TODO remove before prod
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize model and DuckDB
model = SentenceTransformer("nomic-ai/modernbert-embed-base", device="cpu")
embedding_dim = model.get_sentence_embedding_dimension()

# Database setup with fallback
db_path = f"{DATA_DIR}/vector_store.db"
try:
    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(db_path), exist_ok=True)
    con = duckdb.connect(db_path)
    logger.info(f"Connected to persistent database at {db_path}")
except (OSError, PermissionError) as e:
    logger.warning(
        f"Could not create/access {db_path}. Falling back to in-memory database. Error: {e}"
    )
    con = duckdb.connect(":memory:")
# Initialize VSS extension
con.sql("INSTALL vss; LOAD vss;")
con.sql("SET hnsw_enable_experimental_persistence=true;") | |
def setup_database():
    try:
        # Create table with properly typed embeddings
        con.sql(f"""
            CREATE TABLE IF NOT EXISTS model_cards AS
            SELECT *, embeddings::FLOAT[{embedding_dim}] as embeddings_float
            FROM 'hf://datasets/davanstrien/outputs-embeddings/**/*.parquet';
        """)
        # Check if index exists
        index_exists = (
            con.sql("""
                SELECT COUNT(*) as count
                FROM duckdb_indexes
                WHERE index_name = 'my_hnsw_index';
            """).fetchone()[0]
            > 0
        )
        if index_exists:
            # Drop existing index
            con.sql("DROP INDEX my_hnsw_index;")
            logger.info("Dropped existing HNSW index")
        # Create/Recreate HNSW index
        con.sql("""
            CREATE INDEX my_hnsw_index ON model_cards
            USING HNSW (embeddings_float) WITH (metric = 'cosine');
        """)
        logger.info("Created/Recreated HNSW index")
        # Log the number of rows in the database
        row_count = con.sql("SELECT COUNT(*) as count FROM model_cards").fetchone()[0]
        logger.info(f"Database initialized with {row_count:,} rows")
    except Exception as e:
        logger.error(f"Setup error: {e}")


# Run setup on startup
setup_database()
class QueryResult(BaseModel):
    dataset_id: str
    similarity: float
    summary: str
    likes: int
    downloads: int


class QueryResponse(BaseModel):
    results: List[QueryResult]
# NOTE: the route decorators on the endpoints below were missing from this
# listing; the paths and cache TTLs are reasonable reconstructions, not
# confirmed values.
@app.get("/")
async def redirect_to_docs():
    return RedirectResponse(url="/docs")
@app.get("/search/datasets", response_model=QueryResponse)
@cache(ttl="1h")
async def search_datasets(query: str, k: int = Query(default=5, ge=1, le=100)):
    try:
        # nomic embedding models expect a task prefix on query text
        query_embedding = model.encode(f"search_query: {query}").tolist()
        # The embedding literal is floats only and k is validated by Query,
        # so interpolating them into the SQL is safe
        result = con.sql(f"""
            SELECT
                datasetId as dataset_id,
                1 - array_cosine_distance(
                    embeddings_float::FLOAT[{embedding_dim}],
                    {query_embedding}::FLOAT[{embedding_dim}]
                ) as similarity,
                summary,
                likes,
                downloads
            FROM model_cards
            ORDER BY similarity DESC
            LIMIT {k};
        """).df()
        results = [
            QueryResult(
                dataset_id=row["dataset_id"],
                similarity=float(row["similarity"]),
                summary=row["summary"],
                likes=int(row["likes"]),
                downloads=int(row["downloads"]),
            )
            for _, row in result.iterrows()
        ]
        return QueryResponse(results=results)
    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Search failed")
@app.get("/similarity/datasets", response_model=QueryResponse)
@cache(ttl="1h")
async def find_similar_datasets(
    dataset_id: str, k: int = Query(default=5, ge=1, le=100)
):
    try:
        # Check the reference dataset exists; bind dataset_id as a parameter
        # rather than interpolating it, so it cannot inject SQL
        reference_embedding = con.execute(
            """
            SELECT embeddings_float
            FROM model_cards
            WHERE datasetId = ?
            LIMIT 1;
            """,
            [dataset_id],
        ).df()
        if reference_embedding.empty:
            raise HTTPException(
                status_code=404, detail=f"Dataset ID '{dataset_id}' not found"
            )
        # Rank all other rows by cosine similarity to the reference embedding
        result = con.execute(
            f"""
            SELECT
                datasetId as dataset_id,
                1 - array_cosine_distance(
                    embeddings_float::FLOAT[{embedding_dim}],
                    (SELECT embeddings_float FROM model_cards WHERE datasetId = ? LIMIT 1)
                ) as similarity,
                summary,
                likes,
                downloads
            FROM model_cards
            WHERE datasetId != ?
            ORDER BY similarity DESC
            LIMIT {k};
            """,
            [dataset_id, dataset_id],
        ).df()
        results = [
            QueryResult(
                dataset_id=row["dataset_id"],
                similarity=float(row["similarity"]),
                summary=row["summary"],
                likes=int(row["likes"]),
                downloads=int(row["downloads"]),
            )
            for _, row in result.iterrows()
        ]
        return QueryResponse(results=results)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Similarity search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Similarity search failed")
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
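# Example requests (routes are the reconstructions above; adjust if the
# deployed Space uses different paths):
#   curl "http://localhost:8000/search/datasets?query=ocr&k=5"
#   curl "http://localhost:8000/similarity/datasets?dataset_id=<dataset id>&k=5"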