Spaces:
Sleeping
Sleeping
#import streamlit as st | |
#from src.load_data import load_dataframe, sort_by | |
#from src.plot import plot_radar_chart_index, plot_radar_chart_name | |
#from st_aggrid import GridOptionsBuilder, AgGrid | |
from st_aggrid import GridOptionsBuilder, AgGrid | |
import streamlit as st | |
from .load_data import load_dataframe, sort_by | |
from .plot import plot_radar_chart_name, plot_radar_chart_rows | |
# Leaderboard score columns: converted from fractions to percentages for display.
_SCORE_COLUMNS = ["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K", "MMLU", "Average"]


def _resolve_sort(sort_selection):
    """Return ``(column, ascending)`` for the user's sort choice.

    Falls back to ``"model_name"`` when nothing is selected. The name column
    sorts ascending; every other column sorts descending.
    """
    if sort_selection is None:
        sort_selection = "model_name"
    return sort_selection, sort_selection == "model_name"


def _filter_by_name(dataframe, name):
    """Return the rows whose ``model_name`` contains *name* as a literal substring.

    The match is case-sensitive and does not interpret regex metacharacters.
    When *name* is empty or matches no row, *dataframe* is returned unfiltered.
    """
    if not name:
        return dataframe
    # regex=False: characters like "(" typed by the user must not be compiled
    # as a pattern (that raised re.error). na=False: a missing name is a non-match
    # instead of a NaN that would break boolean indexing.
    mask = dataframe["model_name"].str.contains(name, regex=False, na=False)
    # Test for any hit; the mask's length is always the frame's length.
    return dataframe[mask] if mask.any() else dataframe


def _to_display(dataframe):
    """Return a copy of *dataframe* with score columns as percentages rounded to 2 dp."""
    dataframe_display = dataframe.copy()
    dataframe_display[_SCORE_COLUMNS] = (dataframe[_SCORE_COLUMNS].astype(float) * 100).round(2)
    return dataframe_display


def _has_selection(selected_rows):
    """Return True when the grid reports at least one selected row."""
    return selected_rows is not None and len(selected_rows) > 0


def _first_selected_name(selected_rows):
    """Return the ``model_name`` of the first selected grid row.

    Accepts either a list of row dicts or a DataFrame of rows. Which one
    ``AgGrid`` returns appears to depend on the st_aggrid version (TODO confirm
    against the pinned version).
    """
    first_row = selected_rows.iloc[0] if hasattr(selected_rows, "iloc") else selected_rows[0]
    return first_row["model_name"]


def display_app():
    """Render the Open LLM Leaderboard visualization page.

    Shows a sortable, name-searchable grid of leaderboard scores (as
    percentages) beside a radar chart. The chart plots the row ticked in the
    grid. When nothing is ticked, it plots the first row of the filtered and
    sorted table.
    """
    st.markdown("# Open LLM Leaderboard Viz")
    st.markdown("This is a visualization of the results in [open-llm-leaderboard/results](https://huggingface.co/datasets/open-llm-leaderboard/results)")
    st.markdown("To select a model, click on the checkbox beside its name.")

    dataframe = load_dataframe()
    sort_selection = st.selectbox(label="Sort by:", options=list(dataframe.columns))
    sort_column, ascending = _resolve_sort(sort_selection)
    name = st.text_input(label=":mag: Search by name")
    dataframe = _filter_by_name(dataframe, name)
    dataframe = sort_by(dataframe=dataframe, column_name=sort_column, ascending=ascending)
    dataframe_display = _to_display(dataframe)

    # Infer basic column definitions from the dataframe dtypes.
    gb = GridOptionsBuilder.from_dataframe(dataframe_display)
    gb.configure_selection(selection_mode="single", use_checkbox=True)
    gb.configure_grid_options(domLayout="normal")
    grid_options = gb.build()

    # The middle column is an unused spacer between the grid and the chart.
    grid_column, _spacer, chart_column = st.columns([0.26, 0.05, 0.69], gap="small")
    with grid_column:
        grid_response = AgGrid(
            dataframe_display,
            gridOptions=grid_options,
            height=300,
            width="40%",
        )

    # Default model when no row is ticked: the first row after filtering and sorting.
    subdata = dataframe.head(1)
    model_name = subdata["model_name"].iloc[0] if len(subdata) > 0 else ""

    selected_rows = grid_response["selected_rows"]
    has_selection = _has_selection(selected_rows)
    with chart_column:
        if has_selection:
            figure = plot_radar_chart_rows(rows=selected_rows)
            st.plotly_chart(figure, use_container_width=True)
        elif len(subdata) > 0:
            figure = plot_radar_chart_name(dataframe=subdata, model_name=model_name)
            st.plotly_chart(figure, use_container_width=True)

    shown_name = _first_selected_name(selected_rows) if has_selection else model_name
    st.markdown(f"**Model name:** {shown_name}")