File size: 3,882 Bytes
f1e1ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8403841
f1e1ac2
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import json
import os

import gradio as gr
import plotly.graph_objects as go

from assets.constant import DELIMITER
from assets.path import SEASON

# Number of cascading "Level" keypoint dropdowns that create_detail builds and wires up.
DEEPEST = 4


def build_plot(category_result, columns):
    """Build the keypoint and difficulty accuracy bar charts for ``columns``.

    Args:
        category_result: Mapping from a DELIMITER-joined keypoint path to its
            statistics. For every plotted column this function reads the
            ``"acc"``, ``"hit"`` and ``"all"`` entries plus a ``"difficulty"``
            mapping of ``{difficulty: {"hit": ..., "all": ...}}``.
        columns: Keypoint paths (keys of ``category_result``) to plot.

    Returns:
        A tuple ``(k_fig, d_fig, k_x)`` where ``k_fig`` is the per-keypoint
        accuracy bar chart, ``d_fig`` is the accuracy bar chart aggregated per
        difficulty over all ``columns``, and ``k_x`` is the list of x-axis
        labels (the last path segment of each column).

    Raises:
        KeyError: If a column is not present in ``category_result``.
    """
    k_x, k_y, k_text, k_color = [], [], [], []
    d_xy = {}
    for c in columns:
        result = category_result[c]
        k_x.append(c.split(DELIMITER)[-1])
        k_y.append(round(result.get("acc"), 4))
        # Count descendants only: requiring the trailing DELIMITER stops a
        # sibling such as "mathematics" from being counted under "math".
        child_prefix = c + DELIMITER
        sub_count = sum(1 for k in category_result if k.startswith(child_prefix))
        k_text.append(
            f'hit:{result.get("hit")} sub_count:{sub_count}')
        # Accumulate hit/all per difficulty across every plotted column.
        for d, v in result['difficulty'].items():
            bucket = d_xy.setdefault(d, {"hit": 0, "all": 0})
            bucket['hit'] += v['hit']
            bucket['all'] += v['all']
        k_color.append(result.get("all"))

    d_x = sorted(d_xy, reverse=True)
    d_y, d_text, d_color = [], [], []
    for d in d_x:
        v = d_xy[d]
        # An empty difficulty bucket has no accuracy; plot it as 0 rather
        # than failing with ZeroDivisionError.
        d_y.append(v['hit'] / v['all'] if v['all'] else 0.0)
        d_text.append(f'hit/total:{v["hit"]}/{v["all"]}')
        d_color.append(v['all'])

    # Bars are coloured by sample count so sparsely populated bars stand out.
    k_fig = go.Figure([go.Bar(x=k_x, y=k_y, hovertext=k_text,
                              marker={"color": k_color, "colorscale": "Viridis",
                                      "colorbar": {"title": "Total"}})])
    k_fig.update_layout(yaxis=dict(range=[0, 1]))
    d_fig = go.Figure([go.Bar(x=d_x, y=d_y, hovertext=d_text,
                              marker={"color": d_color, "colorscale": "Cividis",
                                      "colorbar": {"title": "Total"}})])
    d_fig.update_layout(yaxis=dict(range=[0, 1]))
    return k_fig, d_fig, k_x


def create_detail(top_components):
    """Create the per-model detail view: model picker, keypoint dropdowns, plots.

    Builds a model dropdown listing the entries under
    ``results/<SEASON["latest"]>/details``, ``DEEPEST`` cascading keypoint
    dropdowns, and two plots. Selecting a model loads its
    ``category_result.json`` and plots the top-level keypoints; selecting a
    keypoint at level *i* plots its direct children (or the keypoint itself
    when it has none) and exposes the next-level dropdown.

    Args:
        top_components: Not used by this function; kept for interface
            compatibility with callers.
    """
    # os.listdir returns entries in arbitrary order; sort for a stable UI.
    models = sorted(os.listdir(os.path.join("results", SEASON["latest"], "details")))
    model_dropdown = gr.Dropdown(choices=models, label="Select Model")

    category_result = gr.State()
    with gr.Row():
        keypoint_dropdowns = [gr.Dropdown([], visible=False, label=f"Level{i + 1}") for i in range(DEEPEST)]
    keypoint_plot = gr.Plot(label="Keypoint Acc")
    difficulty_plot = gr.Plot(label="Difficulty Acc")

    def keypoint_dropdown_func(x, *args):
        """Re-plot for the keypoint path selected in the first ``len(args)`` dropdowns.

        ``x`` is the loaded category result; ``args`` are the values of the
        dropdowns from Level1 up to the one the user just changed.
        """
        keypoints = DELIMITER.join(args)
        # Direct children only. The trailing DELIMITER prevents "A|BC|D" from
        # matching the selection "A|B".
        child_prefix = keypoints + DELIMITER
        columns = [k for k in x if k.startswith(child_prefix) and k.count(DELIMITER) == len(args)]
        sub = True
        if not columns:
            # Leaf keypoint: plot the keypoint itself and hide the next level.
            columns = [keypoints]
            sub = False
        k_fig, d_fig, choices = build_plot(x, columns)
        updates = list(args)
        # Next level gets the children as choices; any previous selection
        # there is stale, so clear it.
        updates.append(gr.update(choices=choices, value=None, visible=sub))
        updates.extend(gr.update(choices=[], value=None, visible=False)
                       for _ in range(DEEPEST - len(updates)))
        # The deepest dropdown has no "next level" output; without trimming,
        # it would return one value more than there are output components.
        updates = updates[:DEEPEST]
        return gr.update(value=k_fig), gr.update(value=d_fig), *updates

    # The callback depends only on how many dropdown values it receives, so
    # one function serves every level.
    for i in range(DEEPEST):
        keypoint_dropdowns[i].input(keypoint_dropdown_func, [category_result, *keypoint_dropdowns[:i + 1]],
                                    [keypoint_plot, difficulty_plot, *keypoint_dropdowns])

    def model_dropdown_func(x):
        """Load the selected model's category result and plot its top-level keypoints."""
        model_dir = os.path.join("results", SEASON["latest"], "details", x)
        # Use a context manager so the file handle is closed deterministically.
        with open(os.path.join(model_dir, "category_result.json"), encoding="utf-8") as f:
            new_category_result = json.load(f)
        # Top-level keypoints (no DELIMITER), largest sample count first.
        columns = sorted([k for k in new_category_result if k.count(DELIMITER) == 0],
                         key=lambda c: new_category_result[c]['all'], reverse=True)
        k_fig, d_fig, choices = build_plot(new_category_result, columns)
        # Reset Level1 as well: a selection made for the previous model may
        # not exist in the new model's choices.
        first_level = gr.update(choices=choices, value=None, visible=True)
        deeper_levels = [gr.update(value=None, visible=False) for _ in range(DEEPEST - 1)]
        return new_category_result, gr.update(value=k_fig), gr.update(value=d_fig), first_level, *deeper_levels

    model_dropdown.change(model_dropdown_func, model_dropdown, [category_result, keypoint_plot, difficulty_plot,
                                                                *keypoint_dropdowns])