Spaces:
Sleeping
Sleeping
import streamlit as st
from st_aggrid import AgGrid, GridOptionsBuilder
from streamlit_searchbox import st_searchbox

from .load_data import (
    load_dataframe,
    search_by_name,
    show_dataframe_top,
    sort_by,
    validate_categories,
)
from .plot import plot_radar_chart_name, plot_radar_chart_rows
def display_app(): | |
st.markdown("# Open LLM Leaderboard Viz") | |
st.markdown("## Some explanations") | |
st.markdown("This is a visualization of the results in [open-llm-leaderboard/results](https://huggingface.co/datasets/open-llm-leaderboard/results)") | |
st.markdown("To select a model, click on the checkbox beside its name, or search it by its name in the search boxes **Model 1, Model 2, or Model 3** bellow.") | |
st.markdown("You can select up to three models using the search boxes and/or the checkboxes.") | |
st.markdown("""In the case you use both the search boxes and the checkboxes, the search boxes will take precedence over the checkboxes, | |
i.e. the models searched using the search boxes will be prioritized over the ones selected using the checkboxes. | |
Please, search models using the search boxes first, and then use the checkboxes. | |
""") | |
st.markdown("This app displays the top 100 models by default, but you can change that using the number input in the sidebar.") | |
st.markdown("By default as well, the maximum number of row you can display is 500, it is due to the problem with st_aggrid component loading.") | |
st.markdown("If your model doesn't show up, please search it by its name.") | |
dataframe = load_dataframe() | |
categories_display = ["ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU", "Average"] | |
st.markdown("## Leaderboard") | |
sort_selection = st.selectbox(label = "Sort by:", options = list(dataframe.columns.difference(["model_dtype"])), index = 1) | |
d_type_options = ["all", "torch.bfloat16", "torch.float16", "4bit", "8bit"] | |
d_type = st.radio(label = "Filter by dtype", options = d_type_options, index = 0, horizontal = True) | |
number_of_row = st.sidebar.number_input("Number of top rows to display", min_value=100, max_value=500, value="min", step=100) | |
ascending = True | |
if sort_selection is None: | |
sort_selection = "model_name" | |
ascending = True | |
elif sort_selection == "model_name": | |
ascending = True | |
else: | |
ascending = False | |
# Dynamic search boxes | |
def search_model(model_name: str): | |
model_list = None | |
if model_name is not None or model_name != "": | |
models = dataframe["model_name"].str.contains(model_name) | |
model_list = dataframe["model_name"][models] | |
else: | |
model_list = dataframe["model_name"] | |
return model_list | |
model_list = [] | |
#Sidebar configurations | |
selection_mode = st.sidebar.radio(label= "Selection mode for the rows", options = ["single", "multiple"], index=1) | |
st.sidebar.write("In multiple mode, you can select up to three models. If you select more than three models, only the first three will be displayed and plotted.") | |
ordering_metrics = st.sidebar.text_input(label = "Order of the metrics on the circle, counter-clock wise, beginning at 3 o'clock.", | |
placeholder = "ARC, GSM8K, TruthfulQA, Winogrande, HellaSwag, MMLU") | |
ordering_metrics = ordering_metrics.replace(" ", "") | |
ordering_metrics = ordering_metrics.split(",") | |
st.sidebar.markdown(""" | |
As a reminder, here are the different metrics: | |
* ARC | |
* GSM8K | |
* TruthfulQA | |
* Winogrande | |
* HellaSwag | |
* MMLU | |
""") | |
st.sidebar.markdown(""" | |
If there are **typos** in the name of the metrics, or the number of metrics | |
is **different of six**, there will be no effect on the chart and the | |
default ordering will be used. | |
""") | |
valid_categories = validate_categories(ordering_metrics) | |
dataframe = sort_by(dataframe=dataframe, column_name=sort_selection, ascending= ascending) | |
if d_type != "all": | |
dataframe = dataframe[dataframe["model_dtype"] == d_type] | |
dataframe_display = dataframe.copy() | |
dataframe_display = show_dataframe_top(number_of_row,dataframe_display) | |
dataframe_display[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]] = dataframe[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]].astype(float) | |
dataframe_display[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]] = dataframe_display[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]] *100 | |
dataframe_display[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]] = dataframe_display[["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K" ,"MMLU", "Average"]].round(2) | |
#Infer basic colDefs from dataframe types | |
gb = GridOptionsBuilder.from_dataframe(dataframe_display) | |
gb.configure_selection(selection_mode = selection_mode, use_checkbox=True) | |
gb.configure_grid_options(domLayout='normal') | |
gridOptions = gb.build() | |
column1,col3, column2 = st.columns([0.26, 0.05, 0.69], gap = "small") | |
with column1: | |
grid_response = AgGrid( | |
dataframe_display, | |
gridOptions=gridOptions, | |
height=300, | |
width='40%' | |
) | |
model_one = st_searchbox(label = "Model 1", search_function = search_model, key = "model_1", default= None) | |
model_two = st_searchbox(label = "Model 2", search_function = search_model, key = "model_2", default= None) | |
model_three = st_searchbox(label = "Model 3", search_function = search_model, key = "model_3", default= None) | |
if model_one is not None: | |
row = dataframe[dataframe["model_name"] == model_one] | |
row[categories_display] = row[categories_display]*100 | |
model_list.append(row.to_dict("records")[0]) | |
if model_two is not None: | |
row = dataframe[dataframe["model_name"] == model_two] | |
row[categories_display] = row[categories_display]*100 | |
model_list.append(row.to_dict("records")[0]) | |
if model_three is not None: | |
row = dataframe[dataframe["model_name"] == model_three] | |
row[categories_display] = row[categories_display]*100 | |
model_list.append(row.to_dict("records")[0]) | |
subdata = dataframe.head(1) | |
if len(subdata) > 0: | |
model_name = subdata["model_name"].values[0] | |
else: | |
model_name = "" | |
with column2: | |
if grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) > 0: | |
figure = None | |
#grid_response is now a Pandas dataframe, we need to | |
# convert to dict in order to merge with the searchboxes' results | |
model_list += grid_response['selected_rows'].to_dict("records") | |
model_list = model_list[:3] | |
model_list = sorted(model_list, key = lambda x: x["Average"], reverse = True) | |
if valid_categories: | |
figure = plot_radar_chart_rows(rows=model_list, categories = ordering_metrics) | |
else: | |
figure = plot_radar_chart_rows(rows=model_list) | |
st.plotly_chart(figure, use_container_width=False) | |
elif len(model_list) > 0: | |
figure = None | |
model_list = sorted(model_list, key = lambda x: x["Average"], reverse = True) | |
if valid_categories: | |
figure = plot_radar_chart_rows(rows=model_list, categories = ordering_metrics) | |
else: | |
figure = plot_radar_chart_rows(rows=model_list) | |
st.plotly_chart(figure, use_container_width=False) | |
else: | |
if len(subdata)>0: | |
figure = None | |
if valid_categories: | |
figure = plot_radar_chart_name(dataframe=subdata, categories = ordering_metrics, model_name=model_name) | |
else: | |
figure = plot_radar_chart_name(dataframe=subdata, model_name=model_name) | |
st.plotly_chart(figure, use_container_width=True) | |
if len(model_list) > 1: | |
n_col = len(model_list) if len(model_list) <=3 else 3 | |
st.markdown("## Models") | |
columns = st.columns(n_col) | |
for i in range(n_col): | |
with columns[i]: | |
st.markdown("**Model name:** [%s](https://huggingface.co/%s)" % (model_list[i]["model_name"] , model_list[i]["model_name"])) | |
st.markdown("**Results:**") | |
st.markdown(""" | |
* Average: %s | |
* ARC: %s | |
* GSM8K: %s | |
* TruthfulQA: %s | |
* Winogrande: %s | |
* HellaSwag: %s | |
* MMLU: %s | |
""" % (round(model_list[i]["Average"],2), | |
round(model_list[i]["ARC"],2), | |
round(model_list[i]["GSM8K"],2), | |
round(model_list[i]["TruthfulQA"],2), | |
round(model_list[i]["Winogrande"],2), | |
round(model_list[i]["HellaSwag"],2), | |
round(model_list[i]["MMLU"],2) | |
)) | |
st.markdown("**dtype:** %s" % model_list[i]["model_dtype"]) | |
elif len(model_list) == 1: | |
st.markdown("**Model name:** [%s](https://huggingface.co/%s)" % (model_list[0]["model_name"] , model_list[0]["model_name"])) | |
st.markdown("**Results:**") | |
st.markdown(""" | |
* Average: %s | |
* ARC: %s | |
* GSM8K: %s | |
* TruthfulQA: %s | |
* Winogrande: %s | |
* HellaSwag: %s | |
* MMLU: %s | |
""" % (round(model_list[0]["Average"],2), | |
round(model_list[0]["ARC"],2), | |
round(model_list[0]["GSM8K"],2), | |
round(model_list[0]["TruthfulQA"],2), | |
round(model_list[0]["Winogrande"],2), | |
round(model_list[0]["HellaSwag"],2), | |
round(model_list[0]["MMLU"],2) | |
)) | |
st.markdown("**dtype:** %s" % model_list[0]["model_dtype"]) | |
st.markdown("For more details, hover over the radar chart.") | |
else: | |
st.markdown("**Model name:** %s" % model_name) | |
st.markdown("For more details, select the first model in the list/leaderboard.") | |