File size: 3,050 Bytes
a7be586
2085332
 
 
 
 
 
 
 
 
55f6c76
 
2085332
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42f30e0
2085332
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
from datetime import datetime
from functools import lru_cache
from typing import List, Optional

from datasets import Dataset
from huggingface_hub import list_models
from pydantic import BaseModel, field_validator
from toolz import groupby
from tqdm.auto import tqdm

# Hugging Face API token read from the environment; None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")


@lru_cache()
def get_all_models():
    """Fetch metadata for every model on the Hugging Face Hub.

    Lists all models sorted by downloads (descending) with full card data,
    showing a progress bar while the listing streams in. The result is
    memoized with ``lru_cache`` because a full Hub listing is expensive.

    Returns:
        list: ``ModelInfo`` objects; any ``None`` entries are dropped
        defensively.
    """
    # list_models already returns an iterator, so it can be consumed by
    # tqdm/list directly — no extra iter() wrapper is needed.
    models = list(
        tqdm(
            list_models(
                cardData=True, full=True, limit=None, sort="downloads", direction=-1
            )
        )
    )
    return [model for model in models if model is not None]


def has_base_model_info(model):
    """Return True if *model*'s card data exposes a ``base_model`` attribute.

    Also safe when ``model.card_data`` is ``None`` (models uploaded without
    a card): ``hasattr(None, ...)`` is simply ``False``.
    """
    # hasattr already returns a bool; the original bool() wrapper was redundant.
    return hasattr(model.card_data, "base_model")


class HubModel(BaseModel):
    """Flattened, validated view of a Hub model's metadata.

    Combines top-level ``ModelInfo`` fields with selected entries pulled
    out of the model card (``datasets``, ``language``, ``base_model``).
    Every field is optional because Hub metadata is sparse.
    """

    author: Optional[str] = None
    last_modified: Optional[datetime] = None
    createdAt: Optional[datetime] = None
    downloads: Optional[int] = None
    likes: Optional[int] = None
    library_name: Optional[str] = None
    modelId: Optional[str] = None
    datasets: Optional[List[str]] = None
    language: Optional[List[str]] = None
    base_model: Optional[str] = None

    @field_validator("language", "datasets", mode="before")
    def ensure_list(cls, v):
        # Model cards sometimes store a single string where a list is
        # expected; normalize to a one-element list.
        if isinstance(v, str):
            return [v]
        return v

    @classmethod
    def from_original(cls, original_data: dict) -> "HubModel":
        """Build a ``HubModel`` from a raw ``ModelInfo.__dict__`` mapping."""
        raw_card = original_data.get("card_data", {})
        if raw_card is None:
            raw_card = {}
        # CardData objects are not plain dicts; read their attribute dict.
        card = raw_card if isinstance(raw_card, dict) else raw_card.__dict__

        top_level_fields = (
            "author",
            "last_modified",
            "createdAt",
            "downloads",
            "likes",
            "library_name",
            "modelId",
        )
        kwargs = {name: original_data.get(name) for name in top_level_fields}
        kwargs.update(
            datasets=card.get("datasets"),
            language=card.get("language"),
            base_model=card.get("base_model"),
        )
        return cls(**kwargs)


def load_data():
    """Build and publish a dataset of Hub models annotated with base-model info.

    Partitions all Hub models by whether their card declares a
    ``base_model``, parses the declaring models, then also includes the
    referenced base models themselves, and pushes the combined records to
    the ``librarian-bots/hub_models_with_base_model_info`` dataset repo.

    Returns:
        datasets.Dataset: the dataset that was pushed.
    """
    grouped = groupby(has_base_model_info, get_all_models())
    # Default to [] so an empty partition doesn't yield None (iterating
    # None would crash).
    with_base_info = grouped.get(True, [])
    without_base_info = grouped.get(False, [])

    parsed_models = [
        HubModel.from_original(model.__dict__).model_dump()
        for model in with_base_info
    ]

    # Set membership gives O(1) lookup while scanning the remaining models.
    base_model_ids = {model["base_model"] for model in parsed_models}
    base_models = [
        model for model in tqdm(without_base_info) if model.id in base_model_ids
    ]
    parsed_base_models = [
        HubModel.from_original(model.__dict__).model_dump() for model in base_models
    ]

    ds = Dataset.from_list(parsed_models + parsed_base_models)
    ds.push_to_hub("librarian-bots/hub_models_with_base_model_info", token=HF_TOKEN)
    print("Pushed to hub")
    return ds