import logging
import os
import sys
from contextlib import asynccontextmanager
from typing import List

# Enable hf_transfer for faster Hub downloads. Set this before importing
# sentence_transformers (which pulls in huggingface_hub), since the flag is
# read when huggingface_hub is imported.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import duckdb
from cashews import cache
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Hugging Face Spaces mounts persistent storage at /data; use a local data
# directory when developing on macOS.
LOCAL = sys.platform == "darwin"
DATA_DIR = "data" if LOCAL else "/data"
# Configure cache
cache.setup("mem://", size_limit="4gb")
# Lifespan handler: resources are created at module level; release them on shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup: nothing special needed here since model and DB are initialized at module level
yield
# Cleanup
await cache.close()
con.close()
app = FastAPI(lifespan=lifespan)
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    # Starlette's allow_origins only matches exact strings, so wildcard
    # subdomains (Spaces and huggingface.co domains) need allow_origin_regex.
    allow_origin_regex=r"https://.*\.(hf\.space|huggingface\.co)$",
    # allow_origins=["http://localhost:5500"],  # TODO remove before prod
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize model and DuckDB
model = SentenceTransformer("nomic-ai/modernbert-embed-base", device="cpu")
embedding_dim = model.get_sentence_embedding_dimension()
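# embedding_dim (taken from the model) sizes the fixed-width FLOAT[n] arrays used in the SQL below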
# Database setup with fallback
db_path = f"{DATA_DIR}/vector_store.db"
try:
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(db_path), exist_ok=True)
con = duckdb.connect(db_path)
logger.info(f"Connected to persistent database at {db_path}")
except (OSError, PermissionError) as e:
logger.warning(
f"Could not create/access {db_path}. Falling back to in-memory database. Error: {e}"
)
con = duckdb.connect(":memory:")
# Initialize VSS extension
con.sql("INSTALL vss; LOAD vss;")
con.sql("SET hnsw_enable_experimental_persistence=true;")
def setup_database():
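    """Load the embeddings dataset from the Hub into DuckDB and (re)build the HNSW index."""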
try:
# Create table with properly typed embeddings
con.sql(f"""
CREATE TABLE IF NOT EXISTS model_cards AS
SELECT *, embeddings::FLOAT[{embedding_dim}] as embeddings_float
FROM 'hf://datasets/davanstrien/outputs-embeddings/**/*.parquet';
""")
# Check if index exists
index_exists = (
con.sql("""
SELECT COUNT(*) as count
FROM duckdb_indexes
WHERE index_name = 'my_hnsw_index';
""").fetchone()[0]
> 0
)
if index_exists:
# Drop existing index
con.sql("DROP INDEX my_hnsw_index;")
logger.info("Dropped existing HNSW index")
# Create/Recreate HNSW index
con.sql("""
CREATE INDEX my_hnsw_index ON model_cards
USING HNSW (embeddings_float) WITH (metric = 'cosine');
""")
logger.info("Created/Recreated HNSW index")
# Log the number of rows in the database
row_count = con.sql("SELECT COUNT(*) as count FROM model_cards").fetchone()[0]
logger.info(f"Database initialized with {row_count:,} rows")
except Exception as e:
logger.error(f"Setup error: {e}")
# Run setup on startup
setup_database()
class QueryResult(BaseModel):
dataset_id: str
similarity: float
summary: str
likes: int
downloads: int
class QueryResponse(BaseModel):
results: List[QueryResult]
@app.get("/")
async def redirect_to_docs():
from fastapi.responses import RedirectResponse
return RedirectResponse(url="/docs")
@app.get("/search/datasets", response_model=QueryResponse)
@cache(ttl="10m")
async def search_datasets(query: str, k: int = Query(default=5, ge=1, le=100)):
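    """Embed the query with the 'search_query:' prefix and return the k most similar datasets."""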
try:
query_embedding = model.encode(f"search_query: {query}").tolist()
        # Bind the embedding and limit as prepared-statement parameters rather
        # than interpolating them into the SQL string.
        result = con.execute(
            f"""
            SELECT
                datasetId as dataset_id,
                1 - array_cosine_distance(
                    embeddings_float::FLOAT[{embedding_dim}],
                    ?::FLOAT[{embedding_dim}]
                ) as similarity,
                summary,
                likes,
                downloads
            FROM model_cards
            ORDER BY similarity DESC
            LIMIT ?;
            """,
            [query_embedding, k],
        ).df()
        # Convert result rows into response models
results = [
QueryResult(
dataset_id=row["dataset_id"],
similarity=float(row["similarity"]),
summary=row["summary"],
likes=int(row["likes"]),
downloads=int(row["downloads"]),
)
for _, row in result.iterrows()
]
return QueryResponse(results=results)
except Exception as e:
logger.error(f"Search error: {str(e)}")
raise HTTPException(status_code=500, detail="Search failed")
@app.get("/similarity/datasets", response_model=QueryResponse)
@cache(ttl="10m")
async def find_similar_datasets(
dataset_id: str, k: int = Query(default=5, ge=1, le=100)
):
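    """Return the k datasets whose embeddings are closest to the given dataset's embedding."""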
try:
        # Look up the reference embedding; dataset_id is user-supplied, so bind
        # it as a parameter instead of interpolating it into the SQL string.
        reference_embedding = con.execute(
            """
            SELECT embeddings_float
            FROM model_cards
            WHERE datasetId = ?
            LIMIT 1;
            """,
            [dataset_id],
        ).df()
        if reference_embedding.empty:
            raise HTTPException(
                status_code=404, detail=f"Dataset ID '{dataset_id}' not found"
            )
        # Rank all other datasets by cosine similarity to the reference embedding
        result = con.execute(
            f"""
            SELECT
                datasetId as dataset_id,
                1 - array_cosine_distance(
                    embeddings_float::FLOAT[{embedding_dim}],
                    (SELECT embeddings_float FROM model_cards WHERE datasetId = ? LIMIT 1)
                ) as similarity,
                summary,
                likes,
                downloads
            FROM model_cards
            WHERE datasetId != ?
            ORDER BY similarity DESC
            LIMIT ?;
            """,
            [dataset_id, dataset_id, k],
        ).df()
        # Convert result rows into response models
results = [
QueryResult(
dataset_id=row["dataset_id"],
similarity=float(row["similarity"]),
summary=row["summary"],
likes=int(row["likes"]),
downloads=int(row["downloads"]),
)
for _, row in result.iterrows()
]
return QueryResponse(results=results)
except HTTPException:
raise
except Exception as e:
logger.error(f"Similarity search error: {str(e)}")
raise HTTPException(status_code=500, detail="Similarity search failed")
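# Example requests once the server is running (illustrative values):
#   curl "http://localhost:8000/search/datasets?query=historical+newspapers&k=5"
#   curl "http://localhost:8000/similarity/datasets?dataset_id=<dataset_id>&k=5"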
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)