# Hugging Face Space source by davanstrien (HF Staff).
# Last commit: 7cf16e2 — "switch to chromadb" (file size: 11.1 kB).
import logging
import os
from typing import List
import sys
import chromadb
from chromadb.utils import embedding_functions
from cashews import cache
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from contextlib import asynccontextmanager
import polars as pl
from huggingface_hub import hf_hub_url, DatasetCard, ModelCard, HfApi
from datetime import datetime, timedelta
from typing import Generator
from huggingface_hub import ModelInfo, DatasetInfo
import stamina
import logging
import polars as pl
from huggingface_hub import dataset_info
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer
import stamina
from tqdm.contrib.concurrent import thread_map
from datasets import Dataset, Value, Sequence
import datasets
import os
from dotenv import load_dotenv
from huggingface_hub import get_inference_endpoint
from huggingface_hub import AsyncInferenceClient
import asyncio
from typing import List
hf_api = HfApi()

# NOTE(review): huggingface_hub may read HF_HUB_ENABLE_HF_TRANSFER when it is
# first imported (check huggingface_hub.constants); if so, setting it here,
# after the imports above, has no effect. To be certain it applies, set it
# before any huggingface_hub / transformers import. It is at least set before
# the tokenizer download below now, instead of after it.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # turn on HF_TRANSFER

# Set up logging before the first slow/fallible operation (the download below).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# NOTE(review): `tokenizer` (and `hf_api`) are not referenced elsewhere in this
# file as shown; kept in case other modules import them — confirm before removing.
tokenizer = AutoTokenizer.from_pretrained(
    "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
)

# macOS is treated as a local development machine; anywhere else the app
# presumably runs as the deployed Space with storage at /data.
LOCAL = sys.platform == "darwin"
DATA_DIR = "data" if LOCAL else "/data"

# Configure cache: in-memory backend used by the @cache(...) endpoint decorators.
cache.setup("mem://", size_limit="4gb")

# Initialize ChromaDB client; the index is persisted on disk under DATA_DIR.
client = chromadb.PersistentClient(path=f"{DATA_DIR}/chroma")
# Initialize FastAPI app
@asynccontextmanager
async def lifespan(app: FastAPI):
# Setup
setup_database()
yield
# Cleanup
await cache.close()
app = FastAPI(lifespan=lifespan)

# Add CORS middleware.
# Starlette's ``allow_origins`` compares origins exactly (only a bare "*" acts
# as a wildcard), so entries like "https://*.hf.space" never matched any real
# origin. Subdomain wildcards have to go through ``allow_origin_regex``.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        # "http://localhost:5500",  # Allow localhost:5500 # TODO remove before prod
    ],
    # HTTPS origins on any subdomain of hf.space (all Hugging Face Spaces) or
    # huggingface.co — the origins the former wildcard entries were meant to allow.
    allow_origin_regex=r"https://([a-z0-9-]+\.)+(hf\.space|huggingface\.co)",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def get_embedding_function():
    """Build the ChromaDB embedding function used for dataset summaries.

    Both the indexing path (``setup_database``) and the text search endpoint
    call this, so stored vectors and query vectors come from the same model.
    A new wrapper object is constructed on every call.

    NOTE(review): queries are embedded as "search_query: <text>" while documents
    are upserted without a "search_document: " prefix; if the model expects
    symmetric task prefixes, confirm this asymmetry is intended.
    """
    model_name = "nomic-ai/modernbert-embed-base"
    return embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=model_name,
    )
def setup_database():
    """Create (or open) the ``dataset_cards`` collection and fill it from the Hub.

    Lazily scans the ``davanstrien/datasets_with_metadata_and_summaries`` parquet
    shards and drops ids containing ``open-llm-leaderboard-old/``. When the
    collection holds fewer records than the filtered source, it upserts every row
    in batches. Each record uses ``datasetId`` as the id and ``summary`` as the
    document, with ``likes``/``downloads``/``last_modified`` as metadata.

    Any exception is logged with its traceback instead of being raised, so a
    failed load does not prevent the module from importing or the app from
    starting.
    """
    try:
        # Create collection with embedding function (cosine space for HNSW).
        dataset_collection = client.get_or_create_collection(
            embedding_function=get_embedding_function(),
            name="dataset_cards",
            metadata={"hnsw:space": "cosine"},
        )

        # TODO incremental updates
        df = pl.scan_parquet(
            "hf://datasets/davanstrien/datasets_with_metadata_and_summaries/data/train-*.parquet"
        ).filter(
            pl.col("datasetId").str.contains_any(["open-llm-leaderboard-old/"]).not_()
        )
        row_count = df.select(pl.len()).collect().item()
        logger.info(f"Row count of new data: {row_count}")

        # Only (re)load when the source has more rows than the collection. Upsert
        # makes a repeat run safe: existing ids are overwritten, not duplicated.
        if dataset_collection.count() < row_count:
            df = df.select(
                ["datasetId", "summary", "likes", "downloads", "last_modified"]
            ).collect()
            BATCH_SIZE = 1000
            total_rows = len(df)
            processed = 0
            for batch_df in df.iter_slices(n_rows=BATCH_SIZE):
                # NOTE(review): a null likes/downloads (int(None)) or a null
                # summary would raise here and abort the remaining batches —
                # confirm the source guarantees non-null values.
                dataset_collection.upsert(
                    ids=batch_df["datasetId"].to_list(),
                    documents=batch_df["summary"].to_list(),
                    metadatas=[
                        {
                            "likes": int(likes),
                            "downloads": int(downloads),
                            "last_modified": str(last_modified),
                        }
                        for likes, downloads, last_modified in zip(
                            batch_df["likes"].to_list(),
                            batch_df["downloads"].to_list(),
                            batch_df["last_modified"].to_list(),
                        )
                    ],
                )
                processed += batch_df.height
                logger.info(f"Processed {processed:,} / {total_rows:,} rows")

        logger.info(f"Database initialized with {dataset_collection.count():,} rows")

        # TODO: also index model cards (librarian-bots/model_cards_with_metadata)
        # into a "model_cards" collection. The earlier commented-out draft selected
        # only modelId/likes/downloads yet read a "summary" column, so it needs a
        # summary source before it can be revived.
    except Exception as e:
        # Deliberately best-effort: keep startup alive, but keep the traceback.
        logger.exception(f"Setup error: {e}")
# Run setup on startup (i.e. at module import time).
# NOTE(review): `lifespan` also calls setup_database() when the ASGI app starts,
# so under a server this runs twice. The second pass still loads the embedding
# function and re-counts the remote parquet rows, though it only upserts when
# the collection is behind. Consider keeping just one of the two call sites.
setup_database()
class QueryResult(BaseModel):
    """A single dataset hit returned by the search/similarity endpoints."""

    dataset_id: str  # Hub dataset id; also the ChromaDB record id
    # NOTE(review): the endpoints fill this from ChromaDB "distances". The
    # dataset_cards collection is created with "hnsw:space": "cosine", so this
    # holds a distance (lower = closer), not a similarity — confirm intended.
    similarity: float
    summary: str  # the stored document text for the dataset
    likes: int  # from the record's metadata
    downloads: int  # from the record's metadata
class QueryResponse(BaseModel):
    """Response body for the search endpoints: an ordered list of hits."""

    results: List[QueryResult]  # at most `k` entries, in the requested order
@app.get("/")
async def redirect_to_docs():
    """Send requests for the root URL to the interactive API docs at ``/docs``."""
    from fastapi.responses import RedirectResponse

    target = "/docs"
    return RedirectResponse(url=target)
@app.get("/search/datasets", response_model=QueryResponse)
@cache(ttl="10m")
async def search_datasets(
    query: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    """Semantic search over dataset summaries.

    Args:
        query: Free-text query, embedded as ``"search_query: <query>"``.
        k: Number of results to return (1-100).
        sort_by: ``"similarity"`` keeps ChromaDB's nearest-first order.
            ``"likes"``/``"downloads"`` re-rank ``4 * k`` candidates by that
            field, descending, and keep the top ``k``.
        min_likes: Lower bound on the ``likes`` metadata.
        min_downloads: Lower bound on the ``downloads`` metadata. A ``where``
            filter is sent only when at least one bound is positive.

    Returns:
        QueryResponse with at most ``k`` results.

    Raises:
        HTTPException: 422 for an unsupported ``sort_by``; 500 if the query fails.
    """
    # `enum=` on Query only documents the choices in the OpenAPI schema; it does
    # not validate them. Without this check any value reached getattr() below
    # (e.g. sorting by "summary", or a 500 for an unknown attribute name).
    if sort_by not in ("similarity", "likes", "downloads"):
        raise HTTPException(
            status_code=422,
            detail="sort_by must be one of: similarity, likes, downloads",
        )
    try:
        # Get collection with the same embedding function used at indexing time.
        collection = client.get_collection(
            name="dataset_cards", embedding_function=get_embedding_function()
        )
        # Oversample when re-ranking by popularity so the top-k by likes/downloads
        # is drawn from a larger pool of semantically relevant candidates.
        # NOTE(review): query() is synchronous and embeds the query text, which
        # blocks the event loop while it runs.
        results = collection.query(
            query_texts=[f"search_query: {query}"],
            n_results=k * 4 if sort_by != "similarity" else k,
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )
        # Results are grouped per query text; exactly one was sent, hence [0].
        query_results = [
            QueryResult(
                dataset_id=hit_id,
                similarity=float(distance),
                summary=document,
                likes=metadata["likes"],
                downloads=metadata["downloads"],
            )
            for hit_id, distance, document, metadata in zip(
                results["ids"][0],
                results["distances"][0],
                results["documents"][0],
                results["metadatas"][0],
            )
        ]
        # Sort results if needed
        if sort_by != "similarity":
            query_results.sort(key=lambda r: getattr(r, sort_by), reverse=True)
        return QueryResponse(results=query_results[:k])
    except Exception as e:
        logger.exception(f"Search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Search failed") from e
@app.get("/similarity/datasets", response_model=QueryResponse)
@cache(ttl="10m")
async def find_similar_datasets(
    dataset_id: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: str = Query(
        default="similarity", enum=["similarity", "likes", "downloads"]
    ),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    """Find datasets whose stored summary embedding is closest to ``dataset_id``'s.

    Args:
        dataset_id: Id of an indexed dataset to use as the reference.
        k: Number of results to return (1-100); the reference itself is excluded.
        sort_by: ``"similarity"`` keeps ChromaDB's nearest-first order.
            ``"likes"``/``"downloads"`` re-rank ``4 * k`` candidates by that
            field, descending, and keep the top ``k``.
        min_likes: Lower bound on the ``likes`` metadata.
        min_downloads: Lower bound on the ``downloads`` metadata. A ``where``
            filter is sent only when at least one bound is positive.

    Returns:
        QueryResponse with at most ``k`` results.

    Raises:
        HTTPException: 422 for an unsupported ``sort_by``; 404 if ``dataset_id``
            is not in the collection; 500 for any other failure.
    """
    # `enum=` on Query is schema documentation only; validate explicitly so an
    # arbitrary value cannot reach getattr() below.
    if sort_by not in ("similarity", "likes", "downloads"):
        raise HTTPException(
            status_code=422,
            detail="sort_by must be one of: similarity, likes, downloads",
        )
    try:
        # No embedding function needed: we query with a stored vector, not text.
        collection = client.get_collection("dataset_cards")

        # Get the reference document's embedding
        reference = collection.get(ids=[dataset_id], include=["embeddings"])
        if not reference["ids"]:
            raise HTTPException(
                status_code=404, detail=f"Dataset ID '{dataset_id}' not found"
            )

        # Query using the embedding
        results = collection.query(
            query_embeddings=[reference["embeddings"][0]],
            # +1 in the similarity case accounts for the reference matching itself.
            n_results=k * 4 if sort_by != "similarity" else k + 1,
            where={
                "$and": [
                    {"likes": {"$gte": min_likes}},
                    {"downloads": {"$gte": min_downloads}},
                ]
            }
            if min_likes > 0 or min_downloads > 0
            else None,
        )

        # Process results (excluding the query dataset itself). The reference may
        # also be absent (e.g. filtered out by `where`), hence the final [:k].
        query_results = [
            QueryResult(
                dataset_id=hit_id,
                similarity=float(distance),
                summary=document,
                likes=metadata["likes"],
                downloads=metadata["downloads"],
            )
            for hit_id, distance, document, metadata in zip(
                results["ids"][0],
                results["distances"][0],
                results["documents"][0],
                results["metadatas"][0],
            )
            if hit_id != dataset_id
        ]
        # Sort results if needed
        if sort_by != "similarity":
            query_results.sort(key=lambda r: getattr(r, sort_by), reverse=True)
        return QueryResponse(results=query_results[:k])
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Similarity search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Similarity search failed") from e
if __name__ == "__main__":
    # Serve the API with uvicorn when this file is executed directly as a script.
    from uvicorn import run

    run(app, host="0.0.0.0", port=8000)