|
|
|
|
|
|
|
|
|
from st_aggrid import GridOptionsBuilder, AgGrid |
|
import streamlit as st |
|
from .load_data import load_dataframe, sort_by, show_dataframe_top, search_by_name |
|
from .plot import plot_radar_chart_name, plot_radar_chart_rows |
|
|
|
|
|
# Benchmark score columns that are rescaled (x100) and rounded for display.
_METRIC_COLUMNS = ["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K", "MMLU", "Average"]

# At most this many selected models are plotted and listed by name.
_MAX_PLOTTED_MODELS = 3

# Position of the column pre-selected in the "Sort by" box.
# NOTE(review): presumably the "Average" column -- confirm against load_dataframe().
_DEFAULT_SORT_INDEX = 7


def _selected_model_names(selected_rows, limit=_MAX_PLOTTED_MODELS):
    """Return the ``model_name`` of the first ``limit`` selected grid rows.

    Args:
        selected_rows: The grid's ``selected_rows`` value. Handled shapes are
            ``None``, a sequence of row mappings, or a pandas-like frame
            (detected by the presence of ``.iloc``) with a ``model_name``
            column.
        limit: Maximum number of names to return.

    Returns:
        A list of model names; empty when nothing is selected.
    """
    if selected_rows is None or len(selected_rows) == 0:
        return []
    if hasattr(selected_rows, "iloc"):
        # Frame-like selection: integer indexing would look up a column, so
        # go through the column and slice positionally instead.
        return selected_rows["model_name"].iloc[:limit].tolist()
    return [row["model_name"] for row in selected_rows[:limit]]


def display_app():
    """Render the Open LLM Leaderboard visualization page.

    Loads the leaderboard, lets the user sort it, search it by model name and
    choose how many top rows to show, displays the result in an AgGrid table
    and draws a radar chart for the selected model(s). When nothing is
    selected, the first row of the sorted data is plotted instead.
    """
    st.markdown("# Open LLM Leaderboard Viz")
    st.markdown("This is a visualization of the results in [open-llm-leaderboard/results](https://huggingface.co/datasets/open-llm-leaderboard/results)")
    st.markdown("To select a model, click on the checkbox beside its name.")
    st.markdown("This displays the top 100 models by default, but you can change that using the number input in the sidebar.")
    st.markdown("By default as well, the maximum number of row you can display is 500, it is due to the problem with st_aggrid component loading.")
    st.markdown("If your model doesn't show up, please search it by its name.")

    dataframe = load_dataframe()

    # Clamp the default index so a frame with fewer columns does not make the
    # select box raise.
    column_names = list(dataframe.columns)
    default_sort_index = max(0, min(_DEFAULT_SORT_INDEX, len(column_names) - 1))
    sort_selection = st.selectbox(label="Sort by:", options=column_names, index=default_sort_index)
    number_of_row = st.sidebar.number_input("Number of top rows to display", min_value=100, max_value=500, value="min", step=100)

    # Names sort A->Z; every other column sorts descending.
    if sort_selection is None:
        sort_selection = "model_name"
    ascending = sort_selection == "model_name"

    search_text = st.text_input(label=":mag: Search by name")
    selection_mode = st.sidebar.radio(label="Selection mode for the rows", options=["single", "multiple"], index=0)
    st.sidebar.write("In multiple mode, you can select up to three models. If you select more than three models, only the first three will be displayed and plotted.")

    # Only replace the leaderboard with the search result when the search
    # actually matched something; otherwise keep the already loaded frame.
    search_hit = False
    if search_text:
        dataframe_by_search = search_by_name(search_text)
        if len(dataframe_by_search) > 0:
            dataframe = dataframe_by_search
            search_hit = True

    dataframe = sort_by(dataframe=dataframe, column_name=sort_selection, ascending=ascending)
    dataframe_display = dataframe.copy()

    # Cap the number of rows whenever the full leaderboard is shown, including
    # the fallback after a search without matches, so the grid never receives
    # more rows than the sidebar limit allows.
    if not search_hit:
        dataframe_display = show_dataframe_top(number_of_row, dataframe_display)

    # Convert the scores of the displayed rows themselves (rather than
    # re-reading them from ``dataframe``) so the result does not depend on
    # index alignment between the two frames.
    dataframe_display[_METRIC_COLUMNS] = (
        dataframe_display[_METRIC_COLUMNS].astype(float) * 100
    ).round(2)

    gb = GridOptionsBuilder.from_dataframe(dataframe_display)
    gb.configure_selection(selection_mode=selection_mode, use_checkbox=True)
    gb.configure_grid_options(domLayout='normal')
    grid_options = gb.build()

    # The narrow middle column only acts as a spacer between table and chart.
    column1, _spacer, column2 = st.columns([0.26, 0.05, 0.69], gap="small")

    with column1:
        grid_response = AgGrid(
            dataframe_display,
            gridOptions=grid_options,
            height=300,
            width='40%'
        )

    # Fallback model: first row of the sorted (unscaled) data.
    subdata = dataframe.head(1)
    model_name = subdata["model_name"].values[0] if len(subdata) > 0 else ""

    selected_rows = grid_response['selected_rows']
    has_selection = selected_rows is not None and len(selected_rows) > 0

    with column2:
        if has_selection:
            figure = plot_radar_chart_rows(rows=selected_rows[:_MAX_PLOTTED_MODELS])
            st.plotly_chart(figure, use_container_width=False)
        elif len(subdata) > 0:
            figure = plot_radar_chart_name(dataframe=subdata, model_name=model_name)
            st.plotly_chart(figure, use_container_width=True)

    selected_names = _selected_model_names(selected_rows, _MAX_PLOTTED_MODELS)
    if len(selected_names) > 1:
        st.markdown("## Models")
        for column, selected_name in zip(st.columns(len(selected_names)), selected_names):
            with column:
                st.markdown("**Model name:** %s" % selected_name)
    else:
        # With exactly one selected row the chart shows that row, so the
        # caption must name it too instead of the top-ranked fallback model.
        if selected_names:
            model_name = selected_names[0]
        st.markdown("**Model name:** %s" % model_name)