File size: 5,302 Bytes
101c142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import streamlit as st
import pandas as pd
import numpy as np
from streamlit_echarts import st_echarts
from streamlit.components.v1 import html
# from PIL import Image 
from app.show_examples import *
import pandas as pd
from typing import List

from model_information import get_dataframe

# Model metadata table, loaded once when this module is imported.
# This module reads its 'Original Name', 'Proper Display Name' and 'Link'
# columns to map raw model identifiers to display names and model links.
info_df = get_dataframe()

# Human-readable description of each supported metric, keyed by metric name.
# The same key is used as the sub-folder name under ./results/ and the text
# is shown to the user as the "Evaluation Method" line above each table.
metrics_info = {
    'wer': 'Word Error Rate (WER), a common metric for ASR evaluation. (The lower, the better)',
    'llama3_70b_judge_binary': 'Binary evaluation using the LLAMA3-70B model, for tasks requiring a binary outcome. (0-100 based on score 0-1)',
    'llama3_70b_judge': 'General evaluation using the LLAMA3-70B model, typically scoring based on subjective judgments. (0-100 based on score 0-5)',
    'meteor': 'METEOR, a metric used for evaluating text generation, often used in translation or summarization tasks. (Sensitive to output length)',
    'bleu': 'BLEU (Bilingual Evaluation Understudy), another text generation evaluation metric commonly used in machine translation. (Sensitive to output length)',
}

def _highlight_top_average(table):
    """Return a CSS style frame that highlights the top row's Average cell.

    Intended for ``Styler.apply(..., axis=None)``: it receives the whole
    table and returns a frame of the same shape holding CSS strings.
    """
    styles = pd.DataFrame('', index=table.index, columns=table.columns)
    # Row 0 is the best model after sorting; column 1 is 'Average'
    # (column 0 holds the model display name).
    if styles.shape[0] > 0 and styles.shape[1] > 1:
        styles.iloc[0, 1] = 'background-color: #b0c1d7; color: white'
    return styles


def sum_table_mulit_metrix(task_name, metrics_lists: List[str]) -> None:
    """Render one leaderboard table per metric for ``task_name``.

    For every metric in ``metrics_lists`` the scores are read from
    ``./results/<metric>/<task_name lower-cased>.csv``. An ``Average``
    column (row-wise mean of every column except ``Model``) is added, the
    user chooses which models to show through a multiselect, and the
    remaining rows are displayed sorted by ``Average`` (ascending for
    ``'wer'``, descending for every other metric) with the top ``Average``
    cell highlighted.

    Args:
        task_name: Task identifier; its lower-cased form is the CSV file name.
        metrics_lists: Metric names; each one is a sub-folder of ``./results``.

    Returns:
        None. All output is written to the Streamlit page.
    """
    # Style the multiselect chips. The rule is identical for every metric,
    # so it is emitted once instead of once per loop iteration.
    st.markdown("""
                    <style>
                    .stMultiSelect [data-baseweb=select] span {
                        max-width: 800px;
                        font-size: 0.9rem;
                        background-color: #3C6478 !important; /* Background color for selected items */
                        color: white; /* Change text color */
                    }
                    </style>
                    """, unsafe_allow_html=True)

    # Lookup tables derived from the model metadata; they do not depend on
    # the metric, so build them once up front.
    display_model_names = {
        key.strip(): val.strip()
        for key, val in zip(info_df['Original Name'], info_df['Proper Display Name'])
    }
    model_link = {
        key.strip(): val
        for key, val in zip(info_df['Proper Display Name'], info_df['Link'])
    }

    for metrics in metrics_lists:
        data_path = f"./results/{metrics}/{task_name.lower()}.csv"

        chart_data = pd.read_csv(data_path).round(3)
        score_columns = [col for col in chart_data.columns if col != 'Model']
        chart_data['Average'] = chart_data[score_columns].mean(axis=1)

        # Remap raw model identifiers to display names; unknown models keep
        # their raw identifier.
        chart_data['model_show'] = chart_data['Model'].map(
            lambda name: display_model_names.get(name, name)
        )

        model_options = sorted(chart_data['model_show'].tolist())
        models = st.multiselect(
            "Please choose the model",
            model_options,
            default=model_options,
            key=f"multiselect_{task_name}_{metrics}",
        )

        # Keep only the selected models and drop rows with any missing score.
        chart_data = chart_data[chart_data['model_show'].isin(models)].dropna(axis=0)

        # Nothing to show for this metric: skip it, but still render the
        # remaining metrics instead of aborting the whole function.
        if chart_data.empty:
            continue

        # ---- Show table --------------------------------------------------
        with st.container():
            st.markdown('#### Overal Evaluation Results')
            # Fall back to the raw metric name when no description is registered.
            st.markdown(f'###### Evaluation Method: {metrics_info.get(metrics, metrics)}')

            chart_data['model_link'] = chart_data['model_show'].map(model_link)

            # Column order: display name, Average, then every remaining column.
            table_columns = [
                col for col in chart_data.columns if col not in ('Model', 'model_show')
            ]
            ordered_columns = ['model_show', 'Average'] + [
                col for col in table_columns if col != 'Average'
            ]
            # Explicit copy so the assignment below does not write into a
            # slice of ``chart_data`` (avoids SettingWithCopyWarning).
            chart_data_table = chart_data[ordered_columns].copy()

            # The mean can reintroduce long fractions; round Average to 3 decimals.
            average_column = chart_data_table.columns[1]
            chart_data_table[average_column] = chart_data_table[average_column].apply(
                lambda value: round(float(value), 3)
            )

            # WER is an error rate (lower is better); all other metrics are
            # scores where higher is better.
            ascend = metrics == 'wer'
            chart_data_table = chart_data_table.sort_values(
                by=['Average'],
                ascending=ascend,
            ).reset_index(drop=True)

            styled_df = chart_data_table.style.apply(_highlight_top_average, axis=None)

            st.dataframe(
                styled_df,
                column_config={
                    'model_show': 'Model',
                    average_column: {'alignment': 'left'},
                    "model_link": st.column_config.LinkColumn(
                        "Model Link",
                    ),
                },
                hide_index=True,
                use_container_width=True,
            )