File size: 2,833 Bytes
88faaa4
885c364
af674e3
 
885c364
 
88faaa4
 
 
 
 
 
 
885c364
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88faaa4
885c364
 
 
 
 
88faaa4
 
885c364
 
 
 
 
88faaa4
885c364
 
88faaa4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import altair as alt
import gradio as gr
import pandas as pd

from datasets import load_dataset

model_id = "ybelkada/model_cards_correct_tag"
# `Dataset.to_pandas()` already returns a pandas DataFrame, so no extra
# `pd.DataFrame(...)` wrapper is needed.
dataset = load_dataset(model_id, split="train").to_pandas()

# Work on a copy, parse commit_dates into datetimes (so Altair can use a temporal
# axis), and sort chronologically.
df = dataset.copy()
df["commit_dates"] = pd.to_datetime(df["commit_dates"])
df = df.sort_values(by="commit_dates")

# Long format for the "counts" plot: one row per (commit_dates, type) with the
# count in the `value` column.
melted_df = pd.melt(
    df,
    id_vars=["commit_dates"],
    value_vars=["total_transformers_model", "missing_library_name"],
    var_name="type",
)

# Percentage of transformers models that are NOT missing the library name.
# A zero total would otherwise produce -inf / NaN ratios, and an infinite value
# would end up in the y-axis domain make_plot derives from min()/max(); mapping
# zero totals to NaN makes those rows drop out of the min/max instead.
total = df["total_transformers_model"].where(df["total_transformers_model"] != 0)
df["ratio"] = (1 - df["missing_library_name"] / total) * 100
ratio_df = df[["commit_dates", "ratio"]].copy()

def _hover_line_chart(data, y_field, y_title, highlight_field, color=None):
    """Build a layered points + lines chart of ``y_field`` over ``commit_dates``.

    A mouseover selection (``nearest=True``) keyed on ``highlight_field`` drives the
    line width: 1 when not selected, 3 when selected.

    Args:
        data: DataFrame with a ``commit_dates`` column and a numeric ``y_field`` column.
        y_field: Column plotted on the quantitative y axis.
        y_title: Title shown on the y axis.
        highlight_field: Field the mouseover selection is keyed on.
        color: Optional Altair color shorthand (e.g. ``'type:N'``); no color
            encoding is added when None.

    Returns:
        The points layer combined with the lines layer (``points + lines``).
    """
    # NOTE(review): `alt.selection(type='single')` / `.add_selection` are the
    # Altair 4 API and are deprecated in Altair 5 (`alt.selection_point` /
    # `.add_params`); kept unchanged to match whatever Altair version is pinned.
    highlight = alt.selection(type='single', on='mouseover',
                              fields=[highlight_field], nearest=True)

    encodings = {
        "x": alt.X('commit_dates:T', title='Date'),
        # Pin the y domain to the observed data range.
        "y": alt.Y(
            f'{y_field}:Q',
            scale=alt.Scale(domain=(data[y_field].min(), data[y_field].max())),
            title=y_title,
        ),
    }
    if color is not None:
        encodings["color"] = color
    base = alt.Chart(data).encode(**encodings)

    points = base.mark_circle().encode(
        opacity=alt.value(1),
    ).add_selection(
        highlight
    ).properties(
        width=1200,
        height=800,
    )

    lines = base.mark_line().encode(
        size=alt.condition(~highlight, alt.value(1), alt.value(3))
    )

    return points + lines


def make_plot(plot_type):
    """Return the Altair chart for the option selected in the "Plot type" radio.

    Args:
        plot_type: The selected radio label. "Total models with missing
            'transformers' tag" plots the raw counts from ``melted_df`` (one
            colored line per ``type``); any other value plots the percentage of
            correctly tagged models from ``ratio_df``.

    Returns:
        A layered Altair chart (points + lines).
    """
    if plot_type == "Total models with missing 'transformers' tag":
        return _hover_line_chart(
            melted_df,
            y_field='value',
            y_title="Count",
            highlight_field='type',
            color='type:N',
        )
    return _hover_line_chart(
        ratio_df,
        y_field='ratio',
        y_title="(1 - missing_library_name / total_transformers_model) * 100 - Higher is better",
        highlight_field='ratio',
    )

# Radio labels; make_plot branches on the first one, which is also the initial selection.
_PLOT_CHOICES = [
    "Total models with missing 'transformers' tag",
    "Proportion of models correctly tagged with 'transformers' tag",
]

with gr.Blocks() as demo:
    button = gr.Radio(
        choices=_PLOT_CHOICES,
        value=_PLOT_CHOICES[0],
        label="Plot type",
    )
    plot = gr.Plot(label="Plot")
    # Redraw the chart whenever the radio changes, and once when the page loads.
    for register in (button.change, demo.load):
        register(make_plot, inputs=[button], outputs=[plot])

if __name__ == "__main__":
    demo.launch()