davanstrien HF Staff commited on
Commit
2085332
·
1 Parent(s): 61fc1ff

Add script to load base model data

Browse files
Files changed (1) hide show
  1. load_base_model_data.py +89 -0
load_base_model_data.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from functools import lru_cache
3
+ from typing import List, Optional
4
+
5
+ from datasets import Dataset
6
+ from huggingface_hub import list_models
7
+ from pydantic import BaseModel, field_validator
8
+ from toolz import groupby
9
+ from tqdm.auto import tqdm
10
+
11
+
12
+ @lru_cache()
13
+ def get_all_models():
14
+ models = list(
15
+ tqdm(
16
+ iter(
17
+ list_models(
18
+ cardData=True, full=True, limit=None, sort="downloads", direction=-1
19
+ )
20
+ )
21
+ )
22
+ )
23
+ return [model for model in models if model is not None]
24
+
25
+
26
+ def has_base_model_info(model):
27
+ return bool(hasattr(model.card_data, "base_model"))
28
+
29
+
30
+ class HubModel(BaseModel):
31
+ author: Optional[str] = None
32
+ last_modified: Optional[datetime] = None
33
+ createdAt: Optional[datetime] = None
34
+ downloads: Optional[int] = None
35
+ likes: Optional[int] = None
36
+ library_name: Optional[str] = None
37
+ modelId: Optional[str] = None
38
+ datasets: Optional[List[str]] = None
39
+ language: Optional[List[str]] = None
40
+ base_model: Optional[str] = None
41
+
42
+ @field_validator("language", "datasets", mode="before")
43
+ def ensure_list(cls, v):
44
+ return [v] if isinstance(v, str) else v
45
+
46
+ @classmethod
47
+ def from_original(cls, original_data: dict) -> "HubModel":
48
+ card_data = original_data.get("card_data", {})
49
+ if card_data is None:
50
+ card_data = {}
51
+ if not isinstance(card_data, dict):
52
+ card_data = card_data.__dict__
53
+
54
+ return cls(
55
+ author=original_data.get("author"),
56
+ last_modified=original_data.get("last_modified"),
57
+ createdAt=original_data.get("createdAt"),
58
+ downloads=original_data.get("downloads"),
59
+ likes=original_data.get("likes"),
60
+ library_name=original_data.get("library_name"),
61
+ modelId=original_data.get("modelId"),
62
+ datasets=card_data.get("datasets"),
63
+ language=card_data.get("language"),
64
+ base_model=card_data.get("base_model"),
65
+ )
66
+
67
+
68
+ def load_data():
69
+ grouped_by_has_base_model_info = groupby(has_base_model_info, get_all_models())
70
+ models_with_base_model_info = grouped_by_has_base_model_info.get(True)
71
+ models_without_base_models = grouped_by_has_base_model_info.get(False)
72
+
73
+ parsed_models = [
74
+ HubModel.from_original(model.__dict__).model_dump()
75
+ for model in models_with_base_model_info
76
+ ]
77
+
78
+ base_models = {model["base_model"] for model in parsed_models}
79
+ base_models = [
80
+ model for model in tqdm(models_without_base_models) if model.id in base_models
81
+ ]
82
+ base_models = [model for model in base_models if model is not None]
83
+ parsed_base_models = [
84
+ HubModel.from_original(model.__dict__).model_dump() for model in base_models
85
+ ]
86
+ ds = Dataset.from_list(parsed_models + parsed_base_models)
87
+ ds.push_to_hub("davanstrien/hub_models_with_base_model_info")
88
+ print("Pushed to hub")
89
+ return ds