davanstrien HF Staff committed
Commit d849643 · 1 Parent(s): a0c28a9

add trending models and datasets fetching endpoints with summaries

Files changed (1): main.py (+153, -0)
main.py CHANGED
@@ -14,12 +14,15 @@ from huggingface_hub import HfApi
 from transformers import AutoTokenizer
 import torch
 import dateutil.parser
+import httpx
+from datetime import datetime
 
 # Configuration constants
 MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
 EMBEDDING_MODEL = "nomic-ai/modernbert-embed-base"
 BATCH_SIZE = 2000
 CACHE_TTL = "60"
+TRENDING_CACHE_TTL = "900"  # 15 minutes cache for trending data
 
 if torch.cuda.is_available():
     DEVICE = "cuda"
@@ -463,6 +466,156 @@ def process_search_results(results, id_field, k, sort_by, exclude_id=None):
     return query_results
 
 
 
465
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
  if __name__ == "__main__":
467
  import uvicorn
468
 
 
14
  from transformers import AutoTokenizer
15
  import torch
16
  import dateutil.parser
17
+ import httpx
18
+ from datetime import datetime
19
 
20
  # Configuration constants
21
  MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
22
  EMBEDDING_MODEL = "nomic-ai/modernbert-embed-base"
23
  BATCH_SIZE = 2000
24
  CACHE_TTL = "60"
25
+ TRENDING_CACHE_TTL = "900" # 15 minutes cache for trending data
26
 
27
  if torch.cuda.is_available():
28
  DEVICE = "cuda"
 
466
  return query_results
467
 
468
 
469
+async def fetch_trending_models():
+    """Fetch trending models from HuggingFace API"""
+    async with httpx.AsyncClient() as client:
+        response = await client.get("https://huggingface.co/api/models")
+        response.raise_for_status()
+        return response.json()
+
+
+@cache(ttl=TRENDING_CACHE_TTL)
+async def get_trending_models_with_summaries(
+    limit: int = 10,
+    min_likes: int = 0,
+    min_downloads: int = 0,
+) -> List[ModelQueryResult]:
+    """Fetch trending models and combine with summaries from database"""
+    try:
+        # Fetch trending models
+        trending_models = await fetch_trending_models()
+
+        # Filter by minimum likes/downloads
+        trending_models = [
+            model
+            for model in trending_models
+            if model.get("likes", 0) >= min_likes
+            and model.get("downloads", 0) >= min_downloads
+        ]
+
+        # Sort by trending score and limit
+        trending_models = sorted(
+            trending_models, key=lambda x: x.get("trendingScore", 0), reverse=True
+        )[:limit]
+
+        # Get model IDs
+        model_ids = [model["modelId"] for model in trending_models]
+
+        # Fetch summaries from ChromaDB
+        collection = client.get_collection("model_cards")
+        summaries = collection.get(ids=model_ids, include=["documents"])
+
+        # Create mapping of model_id to summary
+        id_to_summary = dict(zip(summaries["ids"], summaries["documents"]))
+
+        # Combine data
+        results = []
+        for model in trending_models:
+            if model["modelId"] in id_to_summary:
+                result = ModelQueryResult(
+                    model_id=model["modelId"],
+                    similarity=1.0,  # Not applicable for trending
+                    summary=id_to_summary[model["modelId"]],
+                    likes=model.get("likes", 0),
+                    downloads=model.get("downloads", 0),
+                )
+                results.append(result)
+
+        return results
+
+    except Exception as e:
+        logger.error(f"Error fetching trending models: {str(e)}")
+        raise HTTPException(status_code=500, detail="Failed to fetch trending models")
+
+
+@app.get("/trending/models", response_model=ModelQueryResponse)
+async def get_trending_models(
+    limit: int = Query(default=10, ge=1, le=100),
+    min_likes: int = Query(default=0, ge=0),
+    min_downloads: int = Query(default=0, ge=0),
+):
+    """Get trending models with their summaries"""
+    results = await get_trending_models_with_summaries(
+        limit=limit, min_likes=min_likes, min_downloads=min_downloads
+    )
+    return ModelQueryResponse(results=results)
+
+
+async def fetch_trending_datasets():
+    """Fetch trending datasets from HuggingFace API"""
+    async with httpx.AsyncClient() as client:
+        response = await client.get("https://huggingface.co/api/datasets")
+        response.raise_for_status()
+        return response.json()
+
+
+@cache(ttl=TRENDING_CACHE_TTL)
+async def get_trending_datasets_with_summaries(
+    limit: int = 10,
+    min_likes: int = 0,
+    min_downloads: int = 0,
+) -> List[QueryResult]:
+    """Fetch trending datasets and combine with summaries from database"""
+    try:
+        # Fetch trending datasets
+        trending_datasets = await fetch_trending_datasets()
+
+        # Filter by minimum likes/downloads
+        trending_datasets = [
+            dataset
+            for dataset in trending_datasets
+            if dataset.get("likes", 0) >= min_likes
+            and dataset.get("downloads", 0) >= min_downloads
+        ]
+
+        # Sort by trending score and limit
+        trending_datasets = sorted(
+            trending_datasets, key=lambda x: x.get("trendingScore", 0), reverse=True
+        )[:limit]
+
+        # Get dataset IDs
+        dataset_ids = [dataset["id"] for dataset in trending_datasets]
+
+        # Fetch summaries from ChromaDB
+        collection = client.get_collection("dataset_cards")
+        summaries = collection.get(ids=dataset_ids, include=["documents"])
+
+        # Create mapping of dataset_id to summary
+        id_to_summary = dict(zip(summaries["ids"], summaries["documents"]))
+
+        # Combine data
+        results = []
+        for dataset in trending_datasets:
+            if dataset["id"] in id_to_summary:
+                result = QueryResult(
+                    dataset_id=dataset["id"],
+                    similarity=1.0,  # Not applicable for trending
+                    summary=id_to_summary[dataset["id"]],
+                    likes=dataset.get("likes", 0),
+                    downloads=dataset.get("downloads", 0),
+                )
+                results.append(result)
+
+        return results
+
+    except Exception as e:
+        logger.error(f"Error fetching trending datasets: {str(e)}")
+        raise HTTPException(status_code=500, detail="Failed to fetch trending datasets")
+
+
+@app.get("/trending/datasets", response_model=QueryResponse)
+async def get_trending_datasets(
+    limit: int = Query(default=10, ge=1, le=100),
+    min_likes: int = Query(default=0, ge=0),
+    min_downloads: int = Query(default=0, ge=0),
+):
+    """Get trending datasets with their summaries"""
+    results = await get_trending_datasets_with_summaries(
+        limit=limit, min_likes=min_likes, min_downloads=min_downloads
+    )
+    return QueryResponse(results=results)
+
+
 if __name__ == "__main__":
     import uvicorn
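
For reference, the two new routes can be exercised with a small httpx client. The sketch below is illustrative rather than part of the commit: it assumes the app is served locally by uvicorn on the default port (8000), and that ModelQueryResult and QueryResult serialize with the field names passed to their constructors above (model_id / dataset_id, summary, likes, downloads).

# Illustrative client for the new endpoints (not part of this commit).
# Assumes the service is running at http://localhost:8000.
import asyncio

import httpx


async def demo():
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        # Top five trending models with at least 100 likes
        resp = await client.get(
            "/trending/models", params={"limit": 5, "min_likes": 100}
        )
        resp.raise_for_status()
        for model in resp.json()["results"]:
            print(model["model_id"], "->", model["summary"])

        # Top five trending datasets, unfiltered
        resp = await client.get("/trending/datasets", params={"limit": 5})
        resp.raise_for_status()
        for dataset in resp.json()["results"]:
            print(dataset["dataset_id"], "->", dataset["summary"])


if __name__ == "__main__":
    asyncio.run(demo())

Because the summary-joining helpers are wrapped in @cache(ttl=TRENDING_CACHE_TTL), repeated calls within the 15-minute window should return the same snapshot of trending results.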