Commit 97ab261 · Parent(s): d574b22
add trending sorting option and fetch trending scores for datasets and models

main.py CHANGED
```diff
@@ -1,21 +1,23 @@
+import asyncio
 import logging
 import os
-from typing import List
 import sys
+from contextlib import asynccontextmanager
+from datetime import datetime
+from typing import List
+
 import chromadb
-
+import dateutil.parser
+import httpx
+import polars as pl
+import torch
 from cashews import cache
+from chromadb.utils import embedding_functions
 from fastapi import FastAPI, HTTPException, Query
 from fastapi.middleware.cors import CORSMiddleware
+from huggingface_hub import HfApi, model_info
 from pydantic import BaseModel
-from contextlib import asynccontextmanager
-import polars as pl
-from huggingface_hub import HfApi
 from transformers import AutoTokenizer
-import torch
-import dateutil.parser
-import httpx
-from datetime import datetime
 
 # Configuration constants
 MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
```
```diff
@@ -272,18 +274,16 @@ async def search_datasets(
     query: str,
     k: int = Query(default=5, ge=1, le=100),
     sort_by: str = Query(
-        default="similarity", enum=["similarity", "likes", "downloads"]
+        default="similarity", enum=["similarity", "likes", "downloads", "trending"]
     ),
     min_likes: int = Query(default=0, ge=0),
     min_downloads: int = Query(default=0, ge=0),
 ):
     try:
-        # Get collection with proper embedding function
         collection = client.get_collection(
             name="dataset_cards", embedding_function=get_embedding_function()
         )
 
-        # Query ChromaDB
         results = collection.query(
             query_texts=[f"search_query: {query}"],
             n_results=k * 4 if sort_by != "similarity" else k,
@@ -297,8 +297,7 @@ async def search_datasets(
             else None,
         )
 
-
-        query_results = process_search_results(results, "dataset", k, sort_by)
+        query_results = await process_search_results(results, "dataset", k, sort_by)
 
         return QueryResponse(results=query_results)
 
```
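One thing worth flagging in these signatures: FastAPI's `Query()` has no `enum=` validation parameter; extra keyword arguments like this are only passed through to the generated OpenAPI schema, so a request with `sort_by=bogus` would still reach the handler. If request-time validation is wanted, a `Literal` annotation does the job. A minimal sketch of that alternative (a hypothetical refactor, not part of this commit):

```python
from typing import Literal

from fastapi import FastAPI, Query

app = FastAPI()

SortBy = Literal["similarity", "likes", "downloads", "trending"]


@app.get("/search/datasets")
async def search_datasets(
    query: str,
    k: int = Query(default=5, ge=1, le=100),
    # Literal makes FastAPI reject unknown values with a 422;
    # enum=[...] merely documents the choices in the schema.
    sort_by: SortBy = Query(default="similarity"),
):
    return {"query": query, "k": k, "sort_by": sort_by}
```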
```diff
@@ -313,7 +312,7 @@ async def find_similar_datasets(
     dataset_id: str,
     k: int = Query(default=5, ge=1, le=100),
     sort_by: str = Query(
-        default="similarity", enum=["similarity", "likes", "downloads"]
+        default="similarity", enum=["similarity", "likes", "downloads", "trending"]
     ),
     min_likes: int = Query(default=0, ge=0),
     min_downloads: int = Query(default=0, ge=0),
@@ -321,7 +320,6 @@ async def find_similar_datasets(
     try:
         collection = client.get_collection("dataset_cards")
 
-        # Get the reference document
         results = collection.get(ids=[dataset_id], include=["embeddings"])
 
         if not results["ids"]:
@@ -329,12 +327,9 @@ async def find_similar_datasets(
                 status_code=404, detail=f"Dataset ID '{dataset_id}' not found"
             )
 
-        # Query using the embedding
         results = collection.query(
             query_embeddings=[results["embeddings"][0]],
-            n_results=k * 4
-            if sort_by != "similarity"
-            else k + 1,  # +1 to account for self-match
+            n_results=k * 4 if sort_by != "similarity" else k + 1,
             where={
                 "$and": [
                     {"likes": {"$gte": min_likes}},
@@ -345,8 +340,7 @@ async def find_similar_datasets(
             else None,
         )
 
-
-        query_results = process_search_results(
+        query_results = await process_search_results(
             results, "dataset", k, sort_by, dataset_id
         )
 
```
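The `n_results` arithmetic above encodes two small tricks. For non-similarity sorts the query over-fetches `k * 4` candidates, since the nearest neighbors by embedding distance are not necessarily the top items by likes, downloads, or trending score; the wider pool is re-ranked and truncated to `k` downstream in `process_search_results`. For similarity sorts in `find_similar_datasets`, `k + 1` results are requested because an item's own embedding is its own nearest neighbor, and that self-match is dropped via `exclude_id`. A toy sketch of the rerank-and-truncate step, with made-up records:

```python
# Illustrative records only; the real candidates come from ChromaDB.
candidates = [
    {"dataset_id": "a", "likes": 3},
    {"dataset_id": "b", "likes": 9},
    {"dataset_id": "self", "likes": 7},  # the query item matching itself
    {"dataset_id": "c", "likes": 5},
]
k, exclude_id = 2, "self"

# Drop the self-match, re-rank by the requested key, keep the top k.
top_k = sorted(
    (c for c in candidates if c["dataset_id"] != exclude_id),
    key=lambda c: c["likes"],
    reverse=True,
)[:k]
print(top_k)  # [{'dataset_id': 'b', 'likes': 9}, {'dataset_id': 'c', 'likes': 5}]
```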
```diff
@@ -365,7 +359,7 @@ async def search_models(
     query: str,
     k: int = Query(default=5, ge=1, le=100),
     sort_by: str = Query(
-        default="similarity", enum=["similarity", "likes", "downloads"]
+        default="similarity", enum=["similarity", "likes", "downloads", "trending"]
     ),
     min_likes: int = Query(default=0, ge=0),
     min_downloads: int = Query(default=0, ge=0),
@@ -388,7 +382,7 @@ async def search_models(
             else None,
         )
 
-        query_results = process_search_results(results, "model", k, sort_by)
+        query_results = await process_search_results(results, "model", k, sort_by)
 
         return ModelQueryResponse(results=query_results)
 
@@ -431,7 +425,9 @@ async def find_similar_models(
             else None,
         )
 
-        query_results = process_search_results(results, "model", k, sort_by, model_id)
+        query_results = await process_search_results(
+            results, "model", k, sort_by, model_id
+        )
 
         return ModelQueryResponse(results=query_results)
 
```
```diff
@@ -442,9 +438,29 @@ async def find_similar_models(
         raise HTTPException(status_code=500, detail="Model similarity search failed")
 
 
-def process_search_results(results, id_field, k, sort_by, exclude_id=None):
+@cache(ttl="1h")
+async def get_trending_score(item_id: str, item_type: str) -> float:
+    """Fetch trending score for a model or dataset from HuggingFace API"""
+    try:
+        async with httpx.AsyncClient() as client:
+            endpoint = "models" if item_type == "model" else "datasets"
+            response = await client.get(
+                f"https://huggingface.co/api/{endpoint}/{item_id}?expand=trendingScore"
+            )
+            response.raise_for_status()
+            return response.json().get("trendingScore", 0)
+    except Exception as e:
+        logger.error(
+            f"Error fetching trending score for {item_type} {item_id}: {str(e)}"
+        )
+        return 0
+
+
+async def process_search_results(results, id_field, k, sort_by, exclude_id=None):
     """Process search results into a standardized format."""
     query_results = []
+
+    # Create base results
    for i in range(len(results["ids"][0])):
         current_id = results["ids"][0][i]
         if exclude_id and current_id == exclude_id:
```
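The new helper leans on two libraries the app already imports: `httpx` for the async HTTP call and cashews' `@cache(ttl="1h")`, which memoizes the result per `(item_id, item_type)` for an hour so repeated sorts don't hammer the Hub API (cashews accepts human-readable TTL strings; the app is assumed to configure a backend via `cache.setup(...)` elsewhere). The endpoint itself can be exercised standalone; a minimal sketch against a public model id, mirroring the helper's request shape:

```python
import asyncio

import httpx


async def main() -> None:
    # Same endpoint shape as get_trending_score; any public repo id works.
    async with httpx.AsyncClient() as client:
        resp = await client.get(
            "https://huggingface.co/api/models/bert-base-uncased",
            params={"expand": "trendingScore"},
        )
        resp.raise_for_status()
        # Fall back to 0 when the field is absent, as the helper does.
        print(resp.json().get("trendingScore", 0))


asyncio.run(main())
```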
```diff
@@ -463,7 +479,31 @@ def process_search_results(results, id_field, k, sort_by, exclude_id=None):
         else:
             query_results.append(ModelQueryResult(**result))
 
-    if sort_by != "similarity":
+    # Handle sorting
+    if sort_by == "trending":
+        # Fetch trending scores for all results
+        trending_scores = {}
+        async with httpx.AsyncClient() as client:
+            tasks = [
+                get_trending_score(
+                    getattr(result, f"{id_field}_id"),
+                    "model" if id_field == "model" else "dataset",
+                )
+                for result in query_results
+            ]
+            scores = await asyncio.gather(*tasks)
+            trending_scores = {
+                getattr(result, f"{id_field}_id"): score
+                for result, score in zip(query_results, scores)
+            }
+
+        # Sort by trending score
+        query_results.sort(
+            key=lambda x: trending_scores.get(getattr(x, f"{id_field}_id"), 0),
+            reverse=True,
+        )
+        query_results = query_results[:k]
+    elif sort_by != "similarity":
         query_results.sort(key=lambda x: getattr(x, sort_by), reverse=True)
         query_results = query_results[:k]
     elif exclude_id:  # We fetched extra for similarity + exclude_id case
```
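Two details of the trending branch stand out. First, `asyncio.gather` fires one request per candidate; with the `k * 4` over-fetch that can be a few dozen concurrent calls, though the one-hour cache absorbs repeats. Second, the `async with httpx.AsyncClient() as client:` opened here is never actually used, because `get_trending_score` creates its own client per call, so connections aren't reused (and `"model" if id_field == "model" else "dataset"` is a no-op, since `id_field` is already one of those strings). A hypothetical variant (the names `fetch_score` and `fetch_all` are illustrative, not from main.py) that shares one client and bounds concurrency with a semaphore:

```python
import asyncio

import httpx


async def fetch_score(client: httpx.AsyncClient, item_id: str, endpoint: str) -> float:
    # endpoint is "models" or "datasets", as in get_trending_score.
    resp = await client.get(
        f"https://huggingface.co/api/{endpoint}/{item_id}",
        params={"expand": "trendingScore"},
    )
    resp.raise_for_status()
    return resp.json().get("trendingScore", 0)


async def fetch_all(ids: list[str], endpoint: str, limit: int = 8) -> dict[str, float]:
    sem = asyncio.Semaphore(limit)  # at most `limit` requests in flight
    async with httpx.AsyncClient() as client:

        async def bounded(item_id: str) -> float:
            async with sem:
                return await fetch_score(client, item_id, endpoint)

        scores = await asyncio.gather(*(bounded(i) for i in ids))
    return dict(zip(ids, scores))
```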