File size: 10,916 Bytes
ff0c7de fbcd930 a7f9f33 1d040cb ff0c7de 963c6da 1d040cb fbcd930 a7f9f33 fbcd930 a7f9f33 adbb181 12f938b fbcd930 a7f9f33 fbcd930 a7f9f33 864cb6d adbb181 fbcd930 12f938b fbcd930 12f938b a7f9f33 ff0c7de a7f9f33 adbb181 ff0c7de fbcd930 864cb6d a7f9f33 fbcd930 adbb181 fbcd930 a7f9f33 fbcd930 ff0c7de a7f9f33 ff0c7de a7f9f33 ff0c7de a7f9f33 adbb181 ff0c7de a7f9f33 fbcd930 ff0c7de fbcd930 a7f9f33 1f586be a7f9f33 864cb6d a7f9f33 59592ea a7f9f33 864cb6d a7f9f33 fbcd930 a7f9f33 864cb6d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 |
from st_aggrid import GridOptionsBuilder, AgGrid
from streamlit_searchbox import st_searchbox
import streamlit as st
from .load_data import load_dataframe, sort_by, show_dataframe_top, search_by_name, validate_categories
from .plot import plot_radar_chart_name, plot_radar_chart_rows
# Score columns shown as percentages in the grid, the radar chart and the summaries.
_SCORE_COLUMNS = ["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K", "MMLU", "Average"]

# Order in which scores are listed in a model's "Results" summary.
_RESULT_ORDER = ["Average", "ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU"]

# At most this many models are compared at once (searchbox picks first, then checkboxes).
_MAX_MODELS = 3


def _to_percent(frame):
    """Return a copy of ``frame`` with the score columns scaled to percentages.

    Scores are cast to float, multiplied by 100 and rounded to two decimals.
    Working on a copy keeps pandas from emitting ``SettingWithCopyWarning``
    when ``frame`` is a filtered slice of a larger DataFrame.
    """
    frame = frame.copy()
    frame[_SCORE_COLUMNS] = (frame[_SCORE_COLUMNS].astype(float) * 100).round(2)
    return frame


def _lookup_model(dataframe, model_name):
    """Return the first row whose ``model_name`` equals ``model_name`` as a dict.

    Scores go through ``_to_percent`` so the record matches rows selected in
    the grid. Returns ``None`` when ``model_name`` is ``None`` or is not
    present in ``dataframe``.
    """
    if model_name is None:
        return None
    rows = dataframe[dataframe["model_name"] == model_name]
    if rows.empty:
        # A searchbox keeps its value across reruns, so a model picked earlier can
        # be filtered out later (e.g. by the dtype radio); skip it instead of
        # raising IndexError.
        return None
    return _to_percent(rows.head(1)).to_dict("records")[0]


def _selected_rows_as_records(selected_rows):
    """Normalise AgGrid's ``selected_rows`` value to a list of row dicts.

    NOTE(review): depending on the st_aggrid version this value is a list of
    dicts, a DataFrame, or None -- confirm against the installed version.
    Extending a list with a DataFrame would add its column names, not its rows.
    """
    if selected_rows is None:
        return []
    if hasattr(selected_rows, "to_dict"):
        return selected_rows.to_dict("records")
    return list(selected_rows)


def _render_model_details(model):
    """Write one model's Hub link, scores rounded to 2 decimals, and dtype as markdown."""
    name = model["model_name"]
    st.markdown("**Model name:** [%s](https://huggingface.co/%s)" % (name, name))
    st.markdown("**Results:**")
    st.markdown("\n".join("* %s: %s" % (metric, round(model[metric], 2)) for metric in _RESULT_ORDER))
    st.markdown("**dtype:** %s" % model["model_dtype"])


def display_app():
    """Render the Open LLM Leaderboard visualisation page.

    Builds, in order: the explanatory text, the sort and dtype controls, the
    sidebar options, a two-column area (leaderboard grid plus three model
    searchboxes on the left, radar chart on the right), and per-model score
    summaries. Up to ``_MAX_MODELS`` models are compared. Searchbox picks come
    before grid checkbox selections; when nothing is picked, the top row of the
    sorted/filtered leaderboard is plotted.
    """
    st.markdown("# Open LLM Leaderboard Viz")
    st.markdown("## Some explanations")
    st.markdown("This is a visualization of the results in [open-llm-leaderboard/results](https://huggingface.co/datasets/open-llm-leaderboard/results)")
    st.markdown("To select a model, click on the checkbox beside its name, or search it by its name in the search boxes **Model 1, Model 2, or Model 3** bellow.")
    st.markdown("You can select up to three models using the search boxes and/or the checkboxes.")
    st.markdown("""In the case you use both the search boxes and the checkboxes, the search boxes will take precedence over the checkboxes,
i.e. the models searched using the search boxes will be prioritized over the ones selected using the checkboxes.
Please, search models using the search boxes first, and then use the checkboxes.
""")
    st.markdown("This app displays the top 100 models by default, but you can change that using the number input in the sidebar.")
    st.markdown("By default as well, the maximum number of row you can display is 500, it is due to the problem with st_aggrid component loading.")
    st.markdown("If your model doesn't show up, please search it by its name.")
    dataframe = load_dataframe()

    st.markdown("## Leaderboard")
    sort_selection = st.selectbox(label="Sort by:", options=list(dataframe.columns.difference(["model_dtype"])), index=1)
    d_type_options = ["all", "torch.bfloat16", "torch.float16", "4bit", "8bit"]
    d_type = st.radio(label="Filter by dtype", options=d_type_options, index=0, horizontal=True)
    number_of_row = st.sidebar.number_input("Number of top rows to display", min_value=100, max_value=500, value="min", step=100)

    if sort_selection is None:
        sort_selection = "model_name"
    # Names sort A->Z; every score column sorts highest-first.
    ascending = sort_selection == "model_name"

    def search_model(model_name: str):
        """Return names in the current ``dataframe`` that contain ``model_name``.

        Used as the ``st_searchbox`` search function. ``dataframe`` is read at
        call time, so matches come from the sorted/dtype-filtered frame. The
        match is a plain substring test (``regex=False``) so user input with
        regex metacharacters cannot raise ``re.error``.
        """
        if model_name is None:
            return []
        names = dataframe["model_name"]
        return names[names.str.contains(model_name, regex=False, na=False)].tolist()

    # Sidebar configurations
    selection_mode = st.sidebar.radio(label="Selection mode for the rows", options=["single", "multiple"], index=1)
    st.sidebar.write("In multiple mode, you can select up to three models. If you select more than three models, only the first three will be displayed and plotted.")
    ordering_metrics = st.sidebar.text_input(label="Order of the metrics on the circle, counter-clock wise, beginning at 3 o'clock.",
                                             placeholder="ARC, GSM8K, TruthfulQA, Winogrande, HellaSwag, MMLU")
    ordering_metrics = ordering_metrics.replace(" ", "").split(",")
    st.sidebar.markdown("""
As a reminder, here are the different metrics:
* ARC
* GSM8K
* TruthfulQA
* Winogrande
* HellaSwag
* MMLU
""")
    st.sidebar.markdown("""
If there are **typos** in the name of the metrics, or the number of metrics
is **different of six**, there will be no effect on the chart and the
default ordering will be used.
""")
    valid_categories = validate_categories(ordering_metrics)

    dataframe = sort_by(dataframe=dataframe, column_name=sort_selection, ascending=ascending)
    if d_type != "all":
        dataframe = dataframe[dataframe["model_dtype"] == d_type]
    # Convert from the display frame itself rather than relying on index alignment
    # with ``dataframe``, which breaks if show_dataframe_top resets the index.
    dataframe_display = _to_percent(show_dataframe_top(number_of_row, dataframe.copy()))

    # Infer basic colDefs from dataframe types
    gb = GridOptionsBuilder.from_dataframe(dataframe_display)
    gb.configure_selection(selection_mode=selection_mode, use_checkbox=True)
    gb.configure_grid_options(domLayout='normal')
    grid_options = gb.build()

    column1, _spacer, column2 = st.columns([0.26, 0.05, 0.69], gap="small")
    with column1:
        grid_response = AgGrid(
            dataframe_display,
            gridOptions=grid_options,
            height=300,
            width='40%'
        )
        model_one = st_searchbox(label="Model 1", search_function=search_model, key="model_1", default=None)
        model_two = st_searchbox(label="Model 2", search_function=search_model, key="model_2", default=None)
        model_three = st_searchbox(label="Model 3", search_function=search_model, key="model_3", default=None)

    # Searchbox picks first, then grid selections; keep at most _MAX_MODELS, best Average first.
    model_list = []
    for searched in (model_one, model_two, model_three):
        record = _lookup_model(dataframe, searched)
        if record is not None:
            model_list.append(record)
    model_list += _selected_rows_as_records(grid_response['selected_rows'])
    model_list = sorted(model_list[:_MAX_MODELS], key=lambda row: row["Average"], reverse=True)

    # Fallback when nothing is selected: the top row of the sorted/filtered leaderboard.
    subdata = dataframe.head(1)
    model_name = subdata["model_name"].values[0] if len(subdata) > 0 else ""

    with column2:
        if model_list:
            if valid_categories:
                figure = plot_radar_chart_rows(rows=model_list, categories=ordering_metrics)
            else:
                figure = plot_radar_chart_rows(rows=model_list)
            st.plotly_chart(figure, use_container_width=False)
        elif len(subdata) > 0:
            if valid_categories:
                figure = plot_radar_chart_name(dataframe=subdata, categories=ordering_metrics, model_name=model_name)
            else:
                figure = plot_radar_chart_name(dataframe=subdata, model_name=model_name)
            st.plotly_chart(figure, use_container_width=True)

    if len(model_list) > 1:
        st.markdown("## Models")
        for column, model in zip(st.columns(len(model_list)), model_list):
            with column:
                _render_model_details(model)
    elif len(model_list) == 1:
        _render_model_details(model_list[0])
        st.markdown("For more details, hover over the radar chart.")
    else:
        st.markdown("**Model name:** %s" % model_name)
        st.markdown("For more details, select the first model in the list/leaderboard.")