File size: 8,261 Bytes
b0e6781
 
 
6a1e601
 
b0e6781
 
1d32376
f7d283c
fb2bc19
b0e6781
4c1d731
 
 
b0e6781
6a1e601
 
 
 
 
32e2641
6a1e601
 
 
 
b0e6781
4c1d731
b0e6781
4c1d731
b0e6781
fb2bc19
4c1d731
 
32e2641
 
6a1e601
f3cadf1
101c142
 
32e2641
bd0c4d1
101c142
 
6a1e601
5792938
 
 
 
 
 
 
 
9224fab
 
 
 
 
 
 
 
f7d283c
 
 
 
 
 
 
 
 
 
 
 
 
61dd7eb
 
 
2e7bc8b
 
101c142
32e2641
 
 
2e7bc8b
 
32e2641
 
 
 
101c142
f92272f
32e2641
 
 
 
101c142
c751340
f3cadf1
32e2641
 
 
 
 
 
 
 
 
 
2e7bc8b
4c1d731
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a1e601
4c1d731
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a1e601
4c1d731
 
 
 
 
 
 
 
 
6a1e601
4c1d731
6a1e601
4c1d731
 
 
 
 
 
 
 
 
 
 
b0e6781
4c1d731
b0e6781
4c1d731
 
 
 
 
 
 
 
 
1d32376
4c1d731
1d32376
 
 
 
 
4c1d731
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import streamlit as st
import pandas as pd
import numpy as np
import json

from streamlit_echarts import st_echarts
from app.show_examples import *
from app.content import *

import pandas as pd

from model_information import get_dataframe
# Model metadata table, loaded once when this module is imported.
# draw_table() reads its 'Original Name', 'Proper Display Name' and 'Link'
# columns to translate raw model names and attach a link per model.
info_df = get_dataframe()


# Datasets whose score column is sorted ascending, so the smallest value is
# ranked first and highlighted. Every other dataset is sorted descending.
# NOTE(review): these look like ASR benchmarks scored with an error rate
# (lower is better) -- confirm against the metric definitions.
_ASCENDING_SORT_DATASETS = frozenset({
    'LibriSpeech-Clean',
    'LibriSpeech-Other',
    'CommonVoice-15-EN',
    'Peoples-Speech',
    'GigaSpeech-1',
    'Earnings-21',
    'Earnings-22',
    'TED-LIUM-3',
    'TED-LIUM-3-LongForm',
    'AISHELL-ASR-ZH',
    'MNSC-PART1-ASR',
    'MNSC-PART2-ASR',
    'MNSC-PART3-ASR',
    'MNSC-PART4-ASR',
    'MNSC-PART5-ASR',
    'MNSC-PART6-ASR',
    'CNA',
    'IDPC',
    'Parliament',
    'UKUS-News',
    'Mediacorp',
    'IDPC-Short',
    'Parliament-Short',
    'UKUS-News-Short',
    'Mediacorp-Short',
    'YTB-ASR-Batch1',
    'YTB-ASR-Batch2',
    'SEAME-Dev-Man',
    'SEAME-Dev-Sge',
    'GigaSpeech2-Indo',
    'GigaSpeech2-Thai',
    'GigaSpeech2-Viet',
})


def _highlight_first_element(x):
    """Return a CSS frame that highlights the score cell of the first row.

    Intended for ``Styler.apply(..., axis=None)``, which passes the whole
    DataFrame and expects a same-shaped DataFrame of CSS strings back.
    """
    df_style = pd.DataFrame('', index=x.index, columns=x.columns)
    # Guard the positional write: an empty table has no cell [0, 1], and an
    # unconditional assignment would raise IndexError.
    if df_style.shape[0] > 0 and df_style.shape[1] > 1:
        df_style.iloc[0, 1] = 'background-color: #b0c1d7'
    return df_style


def _chart_axis_bounds(data_values):
    """Return ``(min, max)`` y-axis limits, or ``(None, None)`` if undefined.

    Values outside ``[Q1 - 1.5*IQR, Q3 + 1.5*IQR]`` are ignored so a single
    extreme score does not flatten the remaining bars. The limits are the
    remaining min/max widened by 10 % of their magnitude.
    """
    q1 = data_values.quantile(0.25)
    q3 = data_values.quantile(0.75)
    iqr = q3 - q1

    # 1.5 * IQR is the common outlier threshold.
    lower_bound = q1 - 1.5 * iqr
    upper_bound = q3 + 1.5 * iqr
    filtered_data = data_values[(data_values >= lower_bound) & (data_values <= upper_bound)]

    # No usable values (empty or all-NaN column): let the chart auto-scale
    # instead of sending NaN limits to the front end.
    if filtered_data.empty:
        return None, None

    lowest = filtered_data.min()
    highest = filtered_data.max()
    if pd.isna(lowest) or pd.isna(highest):
        return None, None

    # Pad by the magnitude so the margin widens the range for negative values
    # too; for non-negative values this equals min * 0.9 and max * 1.1.
    min_value = round(float(lowest - 0.1 * abs(lowest)), 3)
    max_value = round(float(highest + 0.1 * abs(highest)), 3)
    return min_value, max_value


def draw_table(dataset_displayname: str, metrics: str) -> None:
    """Render the leaderboard for one dataset/metric in the Streamlit app.

    Reads ``organize_model_results.json`` (path relative to the current
    working directory), selects the results for the dataset and metric, and
    renders three sections:

    * a table sorted so the best model is first, with its score highlighted;
    * a bar chart, toggled by the "Show Chart" button;
    * an examples placeholder, toggled by the "Show Examples" button.

    Args:
        dataset_displayname: Display name of the dataset. It is looked up in
            ``displayname2datasetname`` (presumably provided by one of the
            star imports -- verify) and also used as the score column header.
        metrics: Key of the metric inside the dataset's entry in the JSON
            file.

    Raises:
        FileNotFoundError: If ``organize_model_results.json`` is missing.
        KeyError: If the dataset or metric is not present in the mappings.
    """
    # JSON is UTF-8 by specification; do not rely on the platform default.
    with open('organize_model_results.json', 'r', encoding='utf-8') as f:
        organize_model_results = json.load(f)

    dataset_nickname   = displayname2datasetname[dataset_displayname]
    model_results      = organize_model_results[dataset_nickname][metrics]
    # Translate raw model names to display names; unknown names are kept as-is.
    model_name_mapping = {key.strip(): val for key, val in zip(info_df['Original Name'], info_df['Proper Display Name'])}
    model_results      = {model_name_mapping.get(key, key): val for key, val in model_results.items()}


    # = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = =
    # The bare strings below are section markers kept exactly as written.
    # NOTE(review): if this file were ever run as the main Streamlit script,
    # Streamlit "magic" may render such strings -- confirm before removing.
    '''
    Show Table
    '''
    with st.container():
        st.markdown('##### TABLE')

        model_link_mapping             = {key.strip(): val for key, val in zip(info_df['Proper Display Name'], info_df['Link'])}
        chart_data_table               = pd.DataFrame(list(model_results.items()), columns=["model_show", dataset_displayname])
        chart_data_table["model_link"] = chart_data_table["model_show"].map(model_link_mapping)

        # Sort so the best score ends up in row 0, which is the row highlighted.
        score_column     = chart_data_table.columns[1]
        sort_ascending   = dataset_displayname in _ASCENDING_SORT_DATASETS
        chart_data_table = chart_data_table.sort_values(
                                    by        = score_column,
                                    ascending = sort_ascending
                                ).reset_index(drop=True)

        styled_df = chart_data_table.style.format(
                                    {score_column: "{:.3f}"}
                                ).apply(
                                    _highlight_first_element, axis=None
                                )

        st.dataframe(
                        styled_df,
                        column_config={
                            'model_show'               : 'Model',
                            score_column               : {'alignment': 'left'},
                            "model_link"               : st.column_config.LinkColumn("Model Link"),
                        },
                        hide_index=True,
                        use_container_width=True
                    )


    # = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = =
    '''
    Show Chart
    '''
    # Initialize a session state variable for toggling the chart visibility
    if "show_chart" not in st.session_state:
        st.session_state.show_chart = False

    # Create a button to toggle visibility
    if st.button("Show Chart"):
        st.session_state.show_chart = not st.session_state.show_chart

    if st.session_state.show_chart:

        with st.container():
            st.markdown('##### CHART')

            # Outlier-robust axis limits; (None, None) means "auto-scale".
            min_value, max_value = _chart_axis_bounds(chart_data_table.iloc[:, 1])

            y_axis = {"type": "value"}
            if min_value is not None and max_value is not None:
                y_axis["min"] = min_value
                y_axis["max"] = max_value
            y_axis["boundaryGap"] = True
            # "splitNumber": 10

            options = {
                # "title": {"text": f"{dataset_name}"},
                "tooltip": {
                    "trigger": "axis",
                    "axisPointer": {"type": "cross", "label": {"backgroundColor": "#6a7985"}},
                    "triggerOn": 'mousemove',
                },
                # NOTE(review): this legend entry does not match the series
                # name below, so the legend item is probably not drawn --
                # confirm the intended label.
                "legend": {"data": ['Overall Accuracy']},
                "toolbox": {"feature": {"saveAsImage": {}}},
                "grid": {"left": "3%", "right": "4%", "bottom": "3%", "containLabel": True},
                "xAxis": [
                    {
                        "type": "category",
                        "boundaryGap": True,
                        "triggerEvent": True,
                        "data":  chart_data_table['model_show'].tolist(),
                    }
                ],
                "yAxis": [y_axis],
                "series": [{
                        "name": f"{dataset_nickname}",
                        "type": "bar",
                        "data": chart_data_table[dataset_displayname].tolist(),
                    }],
            }

            events = {
                "click": "function(params) { return params.value }"
            }

            # The clicked value returned by the component is not used.
            st_echarts(options=options, events=events, height="500px")


    # = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = =
    '''
    Show Examples
    '''
    # Initialize a session state variable for toggling the examples visibility
    if "show_examples" not in st.session_state:
        st.session_state.show_examples = False

    # Create a button to toggle visibility
    if st.button("Show Examples"):
        st.session_state.show_examples = not st.session_state.show_examples

    if st.session_state.show_examples:
        st.markdown('To be implemented')

        # # if dataset_name in ['Earnings21-Test', 'Earnings22-Test', 'Tedlium3-Test', 'Tedlium3-Long-form-Test']:
        # if dataset_name in []:
        #     pass
        # else:
        #     show_examples(category_name, dataset_name, chart_data['Model'].tolist(), display_model_names)