File size: 3,506 Bytes
182c1d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7a5300
 
182c1d0
 
 
 
 
 
 
 
 
3533641
 
538d051
3533641
538d051
3533641
 
538d051
3533641
538d051
182c1d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3533641
 
e6fb6ec
3533641
 
182c1d0
 
 
 
 
538d051
3533641
538d051
3533641
 
538d051
 
 
 
4653f13
 
3533641
182c1d0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from huggingface_hub import list_models
from cachetools import cached, TTLCache
from toolz import groupby, valmap
import gradio as gr
from tqdm.auto import tqdm
import pandas as pd


@cached(TTLCache(maxsize=10, ttl=60 * 60 * 3))
def get_all_models():
    """Return every model listed on the Hub (with card data), minus ``None`` entries.

    Memoised in a TTL cache with ttl = 60 * 60 * 3 seconds (three hours), so
    repeated calls inside that window return the same list object instead of
    listing the Hub again.
    """
    # tqdm only draws a progress bar while the listing is consumed.
    fetched = list(tqdm(iter(list_models(cardData=True))))
    return list(filter(lambda candidate: candidate is not None, fetched))


def has_base_model_info(model):
    """Tell whether ``model``'s card declares its base model as a non-empty string.

    Returns ``False`` when the model has no ``cardData`` attribute, its card
    data is empty, the card data has no ``.get`` method, ``base_model`` is
    missing/empty, or ``base_model`` is not a ``str`` (e.g. a list).
    """
    try:
        card = model.cardData
        if not card:
            return False
        declared = card.get("base_model")
        # Only a truthy string counts; other types (lists, dicts, ...) do not.
        return bool(declared) and isinstance(declared, str)
    except AttributeError:
        # No cardData attribute, or card data lacking a .get method.
        return False


# Split every Hub model by whether its card declares a single string base model.
# toolz.groupby returns a dict keyed by the predicate's result; a key is absent
# entirely when no model falls into that group, hence the ``[]`` defaults below
# (``.get`` without a default would return None and ``len(None)`` would raise).
grouped_by_has_base_model_info = groupby(has_base_model_info, get_all_models())
print(valmap(len, grouped_by_has_base_model_info))

models_with_base_model_info = grouped_by_has_base_model_info.get(True, [])
models_without_base_model_info = grouped_by_has_base_model_info.get(False, [])
# get_all_models is TTL-cached, so calling it again here reuses the same listing.
n_models = len(get_all_models())
# Avoid ZeroDivisionError if the listing ever comes back empty.
share_with_base_model_info = (
    round(len(models_with_base_model_info) / n_models * 100, 2) if n_models else 0.0
)

summary = f"""{len(models_with_base_model_info):,} models have base model info. 
            {len(models_without_base_model_info):,} models don't have base model info.
            Currently {share_with_base_model_info}% of models have base model info."""

base_models = [
    model.cardData.get("base_model") for model in models_with_base_model_info
]
# One row per distinct base model with the number of models declaring it, most
# declared first. Naming the index and the count column explicitly yields the
# columns ["base_model", "count"] on pandas 1.x as well as 2.x (only pandas 2
# names the value_counts result "count" on its own); later code reads "count".
df = (
    pd.Series(base_models, name="base_model", dtype=object)
    .value_counts()
    .rename_axis("base_model")
    .reset_index(name="count")
)
df_with_org = df.copy(deep=True)


def parse_org(hub_id):
    """Return the namespace part of a ``"namespace/name"`` Hub repo id.

    Ids that do not contain exactly one ``/`` (e.g. un-namespaced ids such as
    ``"bert-base-cased"``, or ids with several slashes) yield ``None``.
    """
    namespace, *remainder = hub_id.split("/")
    if len(remainder) != 1:
        return None
    return namespace


# Tag each base model with its namespace; rows whose id has no namespace
# (parse_org returned None) are dropped from the per-organisation table.
df_with_org = df_with_org.assign(
    org=df_with_org["base_model"].apply(parse_org)
).dropna(subset=["org"])

# Index child models by the base model their card declares (used by the dropdown).
grouped_by_base_model = groupby(
    lambda child: child.cardData.get("base_model"),
    models_with_base_model_info,
)

# Dropdown choices, in the same row order as ``df``.
all_base_models = df["base_model"].tolist()


def return_models_for_base_model(base_model):
    """Render a Markdown listing of the models that declare ``base_model``.

    Args:
        base_model: The base-model id selected in the dropdown. May be ``None``
            or empty when the dropdown is cleared.

    Returns:
        A Markdown string: a heading, the number of child models, then one line
        per child linking to its Hub page with its download count, most
        downloaded first. An empty string when no base model is selected.
    """
    if not base_model:
        # Cleared dropdown: nothing to show (previously sorted(None) raised).
        return ""
    # Ids with no recorded children get an empty listing instead of a crash.
    models = grouped_by_base_model.get(base_model, [])
    # ``downloads`` can be None on some entries; rank those as 0 rather than
    # letting the int/None comparison raise TypeError.
    models = sorted(models, key=lambda model: model.downloads or 0, reverse=True)
    header = f"## {base_model} children\n\n{base_model} has {len(models)} children\n\n"
    entries = (
        f"[{model.modelId}](https://huggingface.co/{model.modelId})"
        f" | number of downloads {model.downloads}\n\n"
        for model in models
    )
    return header + "".join(entries)


# Top 50 namespaces (organisations/users) ranked by the summed number of models
# declaring one of their base models. A single descending sort suffices; the
# former second sort after head(50) was redundant.
org_popularity = (
    df_with_org.groupby("org")["count"]
    .sum()
    .sort_values(ascending=False)
    .head(50)
    .reset_index()
)

with gr.Blocks() as demo:
    gr.Markdown("# Base model explorer")
    gr.Markdown(
        """When sharing models to the Hub it is possible to specify a base model in the model card, e.g. that your model is a fine-tuned version of [bert-base-cased](https://huggingface.co/bert-base-cased). 
                This Space allows you to find children models for a given base model and view the popularity of models for fine-tuning."""
    )
    gr.Markdown(summary)
    gr.Markdown("### Find all models trained from a base model")
    base_model = gr.Dropdown(all_base_models, label="Base Model")
    results = gr.Markdown()
    # Re-render the children listing whenever the selected base model changes.
    base_model.change(return_models_for_base_model, base_model, results)
    with gr.Accordion("Base model popularity ranking", open=False):
        gr.DataFrame(df.head(50))
    with gr.Accordion("Base model popularity ranking by organisation", open=False):
        gr.DataFrame(org_popularity)


# Start the server only when executed as a script, so importing this module
# does not block on demo.launch().
if __name__ == "__main__":
    demo.launch()