File size: 4,721 Bytes
7e401f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def show_segmentation(ti=0, target_dir: str = "./seg"):
    """Plot ECG segmentation predictions against reference labels for one test row.

    Draws two rows per algorithm in ``ALG_ORDER``:
    - the raw prediction thresholded at the per-algorithm cutoff ("ori");
    - the post-processed prediction ("pp").

    Each row overlays three things:
    - the ECG lead, in black;
    - the P/QRS/T prediction masks, drawn at or above zero;
    - the reference labels and unlabeled regions from ``origin_seg``,
      mirrored below zero.

    The figure is saved as a PNG in ``target_dir``.

    Relies on module-level names defined outside this function: ``test_path``,
    ``Alg``, ``ALG_ORDER``, ``output``, ``pp_output``, ``cutoff``,
    ``origin_seg`` and ``plt``. Their exact shapes are not visible here.
    Presumably ``output[alg][ti][k]`` and ``origin_seg[ti][k]`` are per-sample
    arrays for channel k (0=P, 1=QRS, 2=T) — TODO confirm.

    Args:
        ti: Row index into the test CSV at ``test_path``.
        target_dir: Directory the PNG is written to. It is created if missing;
            an empty string means the current directory.
    """
    import os

    test_df = pd.read_csv(test_path)
    row = test_df.iloc[ti]
    lead_type = ast.literal_eval(row["lead_type"])[0]
    seq = row["seq"]
    object_id: str = row["objectid"]
    ecg_data_path = f"/bfai/data/ecg_data/{object_id[18:22]}/{object_id}.json"

    with open(ecg_data_path, encoding="utf-8") as ecg_data_file:
        ecg_json = json.load(ecg_data_file)
    ecg_data = (
        np.array(ecg_json["waveform"]["data"][lead_type])
        * ecg_json["study"]["mv_unit"]
    )

    if seq != "1/1":
        # seq looks like "<index>/<count>".
        # The <count> windows of 5000 samples (500 * 10) are assumed to be
        # centred in the recording.
        _seq_idx, seq_range = (int(part) for part in seq.split("/"))
        use_lead_length = 5000  # (500 * 10)
        total_use_length = use_lead_length * seq_range
        # Clamp to 0 for recordings shorter than the expected total.
        # A negative start index would otherwise slice from the end.
        front_idx = max(0, (len(ecg_data) - total_use_length) // 2)
        # NOTE(review): the window index is unused. Every segment of a
        # multi-segment record is therefore drawn from front_idx to the end
        # of the recording, not from its own 5000-sample window.
        # Confirm whether ecg_data[front_idx + (idx - 1) * 5000 :
        # front_idx + idx * 5000] was intended.
        ecg_data = ecg_data[front_idx:]

    # Columns: (channel index, name, prediction colour, label colour, amplitude).
    # P and T are drawn at half height, presumably to tell them apart from QRS.
    channels = (
        (0, "P", "red", "salmon", 0.5),
        (1, "QRS", "green", "seagreen", 1.0),
        (2, "T", "blue", "darkslateblue", 0.5),
    )

    fig = plt.figure(figsize=(15, 2 * len(Alg) * 2))
    try:
        fig.subplots_adjust(hspace=0, wspace=0.1)

        for alg, alg_idx in ALG_ORDER.items():
            for pp_type_idx, pp_type in enumerate(["ori", "pp"]):
                sub_fig = fig.add_subplot(
                    len(Alg) * 2, 1, 2 * alg_idx + pp_type_idx + 1
                )
                sub_fig.text(
                    0.02,
                    0.5,
                    f"{alg} {pp_type}\n{lead_type}",
                    fontsize=9,
                    fontweight="bold",
                    ha="center",
                    va="center",
                    rotation=90,
                    transform=sub_fig.transAxes,
                )
                sub_fig.set_xticks([])
                sub_fig.set_yticks([])

                sub_fig.plot(
                    range(len(ecg_data)), ecg_data, color="black", linewidth=1.0
                )

                # Predictions are either the raw scores thresholded at the
                # per-algorithm cutoff ("ori") or the post-processed mask ("pp").
                # The x-axis is derived from the array actually plotted.
                for ch, name, color, _, scale in channels:
                    if pp_type_idx == 0:
                        mask = output[alg][ti][ch] >= cutoff[alg][ch]
                    else:
                        mask = pp_output[alg][ti][ch]
                    pred = mask.astype(int) * scale
                    sub_fig.plot(
                        range(len(pred)),
                        pred,
                        label=name,
                        color=color,
                        linewidth=0.7,
                    )

                # Reference labels are the positive entries of origin_seg,
                # mirrored below zero.
                for ch, name, _, label_color, scale in channels:
                    labelled = (origin_seg[ti][ch] > 0).astype(int) * (-scale)
                    sub_fig.plot(
                        range(len(labelled)),
                        labelled,
                        label=f"{name} Label",
                        color=label_color,
                        linewidth=0.7,
                    )

                # Unlabeled regions are the negative entries of origin_seg.
                # They are drawn dotted at full depth and are not scaled,
                # unlike the labelled lines above.
                for ch, name, _, label_color, _ in channels:
                    unlabelled = (origin_seg[ti][ch] < 0).astype(int) * (-1)
                    sub_fig.plot(
                        range(len(unlabelled)),
                        unlabelled,
                        label=f"{name} UnLabeled",
                        linestyle=":",
                        color=label_color,
                        linewidth=0.5,
                    )

        # os.path.join respects absolute target_dir values.
        # The original "./{target_dir}" prefix turned them into relative paths.
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        file_name = (
            f"{row['lead_cnt']}_{object_id}_{lead_type}_{seq.replace('/', '_')}.png"
        )
        fig.savefig(os.path.join(target_dir, file_name), dpi=150)
    finally:
        # Close this specific figure, even when plotting fails, to avoid
        # leaking figures across calls.
        plt.close(fig)