File size: 3,527 Bytes
1d040cb
 
 
 
fbcd930
1d040cb
 
963c6da
1d040cb
fbcd930
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
963c6da
 
fbcd930
 
 
 
 
 
 
 
963c6da
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#import streamlit as st
#from src.load_data import load_dataframe, sort_by
#from src.plot import plot_radar_chart_index, plot_radar_chart_name
#from st_aggrid import GridOptionsBuilder, AgGrid
from st_aggrid import GridOptionsBuilder, AgGrid
import streamlit as st
from .load_data import load_dataframe, sort_by
from .plot import plot_radar_chart_name, plot_radar_chart_rows


# Benchmark score columns that are rendered as percentages in the grid.
_SCORE_COLUMNS = ["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K", "MMLU", "Average"]


def _filter_by_name(dataframe, name):
    """Return the rows of *dataframe* whose ``model_name`` contains *name*.

    The query is matched as a literal, case-sensitive substring. An empty
    query returns *dataframe* unchanged; a query that matches nothing returns
    an empty frame.
    """
    if not name:
        return dataframe
    # regex=False: the text comes straight from a search box, so characters
    # such as "(", "[" or "+" must not be compiled as a regular expression.
    # na=False: rows with a missing model name never match and never put NaN
    # into the boolean mask.
    mask = dataframe["model_name"].str.contains(name, regex=False, na=False)
    return dataframe[mask]


def _with_percentage_scores(dataframe):
    """Return a copy of *dataframe* with the score columns scaled by 100 and rounded to 2 decimals."""
    display = dataframe.copy()
    display[_SCORE_COLUMNS] = (dataframe[_SCORE_COLUMNS].astype(float) * 100).round(2)
    return display


def _first_selected_model_name(selected_rows):
    """Return the ``model_name`` of the first selected grid row.

    *selected_rows* must be non-empty. It may be a list of row dicts or a
    DataFrame.
    """
    # NOTE(review): newer st_aggrid releases appear to return a DataFrame here
    # instead of a list of dicts -- confirm against the pinned version.
    if hasattr(selected_rows, "iloc"):
        return selected_rows.iloc[0]["model_name"]
    return selected_rows[0]["model_name"]


def display_app():
    """Render the Open LLM Leaderboard visualization page.

    Shows a sortable, searchable grid of models next to a radar chart. The
    chart and the "Model name" caption describe the row selected in the grid,
    or the first row of the current (filtered and sorted) table when nothing
    is selected.
    """
    st.markdown("# Open LLM Leaderboard Viz")
    st.markdown("This is a visualization of the results in [open-llm-leaderboard/results](https://huggingface.co/datasets/open-llm-leaderboard/results)")
    st.markdown("To select a model, click on the checkbox beside its name.")

    dataframe = load_dataframe()

    sort_selection = st.selectbox(label="Sort by:", options=list(dataframe.columns))
    if sort_selection is None:
        sort_selection = "model_name"
    # Names sort ascending; every other column sorts descending.
    ascending = sort_selection == "model_name"

    name = st.text_input(label=":mag: Search by name")
    dataframe = _filter_by_name(dataframe, name)

    dataframe = sort_by(dataframe=dataframe, column_name=sort_selection, ascending=ascending)
    dataframe_display = _with_percentage_scores(dataframe)

    # Infer basic colDefs from dataframe types
    gb = GridOptionsBuilder.from_dataframe(dataframe_display)
    gb.configure_selection(selection_mode="single", use_checkbox=True)
    gb.configure_grid_options(domLayout='normal')
    grid_options = gb.build()

    # The narrow middle column is never written to; it only separates the two.
    table_column, _spacer, chart_column = st.columns([0.26, 0.05, 0.69], gap="small")

    with table_column:
        grid_response = AgGrid(
            dataframe_display,
            gridOptions=grid_options,
            height=300,
            width='40%',
        )

    selected_rows = grid_response['selected_rows']
    has_selection = selected_rows is not None and len(selected_rows) > 0

    # Fallback model: the first row of the current table, if there is one.
    subdata = dataframe.head(1)
    has_default = len(subdata) > 0
    model_name = subdata["model_name"].values[0] if has_default else ""

    with chart_column:
        if has_selection:
            figure = plot_radar_chart_rows(rows=selected_rows)
            st.plotly_chart(figure, use_container_width=True)
        elif has_default:
            figure = plot_radar_chart_name(dataframe=subdata, model_name=model_name)
            st.plotly_chart(figure, use_container_width=True)

    shown_name = _first_selected_model_name(selected_rows) if has_selection else model_name
    st.markdown(f"**Model name:**   {shown_name}")