from st_aggrid import GridOptionsBuilder, AgGrid
import streamlit as st
from .load_data import load_dataframe, sort_by, show_dataframe_top, search_by_name
from .plot import plot_radar_chart_name, plot_radar_chart_rows

# Benchmark score columns. They are shown in the grid as float percentages
# rounded to two decimals.
SCORE_COLUMNS = ["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K", "MMLU", "Average"]

# Position (in the dataframe's column list) of the default "Sort by" column.
DEFAULT_SORT_INDEX = 7


def _has_selection(selected_rows):
    """Return True when the grid reports at least one selected row."""
    return selected_rows is not None and len(selected_rows) > 0


def _first_selected_model_name(selected_rows):
    """Return the "model_name" value of the first selected grid row.

    Depending on the st_aggrid version, ``selected_rows`` may be a list of row
    dicts or a pandas DataFrame; both shapes are handled here.
    NOTE(review): confirm against the pinned st_aggrid version.
    """
    if hasattr(selected_rows, "iloc"):
        return selected_rows.iloc[0]["model_name"]
    return selected_rows[0]["model_name"]


def display_app():
    """Render the Open LLM Leaderboard visualization page.

    The page has two parts:

    - A sortable grid of models, either the top N rows or the models matching
      a name search.
    - A radar chart of the selected rows. When nothing is selected, it shows
      the first row of the sorted data instead.
    """
    st.markdown("# Open LLM Leaderboard Viz")
    st.markdown("This is a visualization of the results in [open-llm-leaderboard/results](https://huggingface.co/datasets/open-llm-leaderboard/results)")
    st.markdown("To select a model, click on the checkbox beside its name.")
    st.markdown("This displays the top 100 models by default, but you can change that using the number input below.")
    st.markdown("By default, the maximum number of rows you can display is 500 due to a problem with loading the st_aggrid component.")
    st.markdown("If your model doesn't show up, please search it by its name.")

    dataframe = load_dataframe()
    columns = list(dataframe.columns)

    # Clamp the default index so a shorter column list cannot make the selectbox fail.
    default_sort_index = min(DEFAULT_SORT_INDEX, max(len(columns) - 1, 0))
    sort_selection = st.selectbox(label="Sort by:", options=columns, index=default_sort_index)
    number_of_row = st.number_input("Number of top rows to display", min_value=100, max_value=500, value="min", step=100)

    # Model names sort alphabetically (A-Z); every other column is sorted descending.
    if sort_selection is None:
        sort_selection = "model_name"
    ascending = sort_selection == "model_name"

    name = st.text_input(label=":mag: Search by name")
    search_matched = False
    if name:
        dataframe_by_search = search_by_name(name)
        if len(dataframe_by_search) > 0:
            dataframe = dataframe_by_search
            search_matched = True
        else:
            # Keep the already-loaded data and tell the user nothing matched.
            st.warning(f"No model found matching '{name}'.")

    dataframe = sort_by(dataframe=dataframe, column_name=sort_selection, ascending=ascending)

    dataframe_display = dataframe.copy()
    if not search_matched:
        # Cap the rendered rows (the page text cites st_aggrid loading problems).
        # The cap also applies when a search matched nothing.
        dataframe_display = show_dataframe_top(number_of_row, dataframe_display)
    # Convert only the rows that will actually be displayed.
    dataframe_display[SCORE_COLUMNS] = (dataframe_display[SCORE_COLUMNS].astype(float) * 100).round(2)

    # Infer basic colDefs from the dataframe types.
    gb = GridOptionsBuilder.from_dataframe(dataframe_display)
    gb.configure_selection(selection_mode="single", use_checkbox=True)
    gb.configure_grid_options(domLayout='normal')
    grid_options = gb.build()

    # The middle column is only a visual spacer between the grid and the chart.
    column1, _spacer, column2 = st.columns([0.26, 0.05, 0.69], gap="small")

    with column1:
        grid_response = AgGrid(
            dataframe_display,
            gridOptions=grid_options,
            height=300,
            width='40%',
        )

    # Fallback when nothing is selected: the first row of the sorted data.
    subdata = dataframe.head(1)
    model_name = subdata["model_name"].values[0] if len(subdata) > 0 else ""

    selected_rows = grid_response['selected_rows']
    has_selection = _has_selection(selected_rows)

    with column2:
        if has_selection:
            figure = plot_radar_chart_rows(rows=selected_rows)
            st.plotly_chart(figure, use_container_width=True)
        elif len(subdata) > 0:
            figure = plot_radar_chart_name(dataframe=subdata, model_name=model_name)
            st.plotly_chart(figure, use_container_width=True)

        if has_selection:
            st.markdown(f"**Model name:** {_first_selected_model_name(selected_rows)}")
        else:
            st.markdown(f"**Model name:** {model_name}")