davanstrien committed
Commit 97ab261 · Parent(s): d574b22

add trending sorting option and fetch trending scores for datasets and models

Files changed (1):
  1. main.py +67 -27
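
This commit lets callers pass "trending" as a sort_by value on the search and similarity endpoints. A minimal usage sketch, assuming the service runs locally on port 8000 and exposes the dataset search handler at /search/datasets (host and route are assumptions, not shown in this diff):

import httpx

# Hypothetical request against a local instance; base URL and route are
# assumptions. The query parameters match the handler signature in the diff.
response = httpx.get(
    "http://localhost:8000/search/datasets",
    params={"query": "medical imaging", "k": 5, "sort_by": "trending"},
)
response.raise_for_status()
for result in response.json()["results"]:
    print(result)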
main.py CHANGED
@@ -1,21 +1,23 @@
+import asyncio
 import logging
 import os
-from typing import List
 import sys
+from contextlib import asynccontextmanager
+from datetime import datetime
+from typing import List
+
 import chromadb
-from chromadb.utils import embedding_functions
+import dateutil.parser
+import httpx
+import polars as pl
+import torch
 from cashews import cache
+from chromadb.utils import embedding_functions
 from fastapi import FastAPI, HTTPException, Query
 from fastapi.middleware.cors import CORSMiddleware
+from huggingface_hub import HfApi, model_info
 from pydantic import BaseModel
-from contextlib import asynccontextmanager
-import polars as pl
-from huggingface_hub import HfApi
 from transformers import AutoTokenizer
-import torch
-import dateutil.parser
-import httpx
-from datetime import datetime
 
 # Configuration constants
 MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
@@ -272,18 +274,16 @@ async def search_datasets(
     query: str,
     k: int = Query(default=5, ge=1, le=100),
     sort_by: str = Query(
-        default="similarity", enum=["similarity", "likes", "downloads"]
+        default="similarity", enum=["similarity", "likes", "downloads", "trending"]
     ),
     min_likes: int = Query(default=0, ge=0),
     min_downloads: int = Query(default=0, ge=0),
 ):
     try:
-        # Get collection with proper embedding function
         collection = client.get_collection(
             name="dataset_cards", embedding_function=get_embedding_function()
         )
 
-        # Query ChromaDB
         results = collection.query(
             query_texts=[f"search_query: {query}"],
             n_results=k * 4 if sort_by != "similarity" else k,
@@ -297,8 +297,7 @@ async def search_datasets(
             else None,
         )
 
-        # Process results
-        query_results = process_search_results(results, "dataset", k, sort_by)
+        query_results = await process_search_results(results, "dataset", k, sort_by)
 
         return QueryResponse(results=query_results)
 
@@ -313,7 +312,7 @@ async def find_similar_datasets(
     dataset_id: str,
     k: int = Query(default=5, ge=1, le=100),
     sort_by: str = Query(
-        default="similarity", enum=["similarity", "likes", "downloads"]
+        default="similarity", enum=["similarity", "likes", "downloads", "trending"]
     ),
     min_likes: int = Query(default=0, ge=0),
     min_downloads: int = Query(default=0, ge=0),
@@ -321,7 +320,6 @@ async def find_similar_datasets(
     try:
         collection = client.get_collection("dataset_cards")
 
-        # Get the reference document
         results = collection.get(ids=[dataset_id], include=["embeddings"])
 
         if not results["ids"]:
@@ -329,12 +327,9 @@ async def find_similar_datasets(
                 status_code=404, detail=f"Dataset ID '{dataset_id}' not found"
             )
 
-        # Query using the embedding
         results = collection.query(
             query_embeddings=[results["embeddings"][0]],
-            n_results=k * 4
-            if sort_by != "similarity"
-            else k + 1,  # +1 to account for self-match
+            n_results=k * 4 if sort_by != "similarity" else k + 1,
             where={
                 "$and": [
                     {"likes": {"$gte": min_likes}},
@@ -345,8 +340,7 @@ async def find_similar_datasets(
             else None,
         )
 
-        # Process results (excluding the query dataset itself)
-        query_results = process_search_results(
+        query_results = await process_search_results(
             results, "dataset", k, sort_by, dataset_id
         )
 
@@ -365,7 +359,7 @@ async def search_models(
     query: str,
     k: int = Query(default=5, ge=1, le=100),
     sort_by: str = Query(
-        default="similarity", enum=["similarity", "likes", "downloads"]
+        default="similarity", enum=["similarity", "likes", "downloads", "trending"]
     ),
     min_likes: int = Query(default=0, ge=0),
     min_downloads: int = Query(default=0, ge=0),
@@ -388,7 +382,7 @@ async def search_models(
             else None,
         )
 
-        query_results = process_search_results(results, "model", k, sort_by)
+        query_results = await process_search_results(results, "model", k, sort_by)
 
         return ModelQueryResponse(results=query_results)
 
@@ -431,7 +425,9 @@ async def find_similar_models(
             else None,
        )
 
-        query_results = process_search_results(results, "model", k, sort_by, model_id)
+        query_results = await process_search_results(
+            results, "model", k, sort_by, model_id
+        )
 
         return ModelQueryResponse(results=query_results)
 
@@ -442,9 +438,29 @@ async def find_similar_models(
         raise HTTPException(status_code=500, detail="Model similarity search failed")
 
 
-def process_search_results(results, id_field, k, sort_by, exclude_id=None):
+@cache(ttl="1h")
+async def get_trending_score(item_id: str, item_type: str) -> float:
+    """Fetch trending score for a model or dataset from HuggingFace API"""
+    try:
+        async with httpx.AsyncClient() as client:
+            endpoint = "models" if item_type == "model" else "datasets"
+            response = await client.get(
+                f"https://huggingface.co/api/{endpoint}/{item_id}?expand=trendingScore"
+            )
+            response.raise_for_status()
+            return response.json().get("trendingScore", 0)
+    except Exception as e:
+        logger.error(
+            f"Error fetching trending score for {item_type} {item_id}: {str(e)}"
+        )
+        return 0
+
+
+async def process_search_results(results, id_field, k, sort_by, exclude_id=None):
     """Process search results into a standardized format."""
     query_results = []
+
+    # Create base results
     for i in range(len(results["ids"][0])):
         current_id = results["ids"][0][i]
         if exclude_id and current_id == exclude_id:
@@ -463,7 +479,31 @@ def process_search_results(results, id_field, k, sort_by, exclude_id=None):
         else:
             query_results.append(ModelQueryResult(**result))
 
-    if sort_by != "similarity":
+    # Handle sorting
+    if sort_by == "trending":
+        # Fetch trending scores for all results
+        trending_scores = {}
+        async with httpx.AsyncClient() as client:
+            tasks = [
+                get_trending_score(
+                    getattr(result, f"{id_field}_id"),
+                    "model" if id_field == "model" else "dataset",
+                )
+                for result in query_results
+            ]
+            scores = await asyncio.gather(*tasks)
+            trending_scores = {
+                getattr(result, f"{id_field}_id"): score
+                for result, score in zip(query_results, scores)
+            }
+
+        # Sort by trending score
+        query_results.sort(
+            key=lambda x: trending_scores.get(getattr(x, f"{id_field}_id"), 0),
+            reverse=True,
+        )
+        query_results = query_results[:k]
+    elif sort_by != "similarity":
         query_results.sort(key=lambda x: getattr(x, sort_by), reverse=True)
         query_results = query_results[:k]
     elif exclude_id:  # We fetched extra for similarity + exclude_id case
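
The new get_trending_score helper hits the Hub API with expand=trendingScore; because it is wrapped in @cache(ttl="1h"), repeated trending sorts within an hour reuse cached scores instead of re-fetching. The same request can be made standalone; a quick sketch (the dataset id is only an example):

import httpx

# Mirror the request get_trending_score makes, for a single dataset;
# "squad" is just an example repo id.
resp = httpx.get(
    "https://huggingface.co/api/datasets/squad",
    params={"expand": "trendingScore"},
)
resp.raise_for_status()
print(resp.json().get("trendingScore", 0))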