# Source: Spaces file view (status: Sleeping) — file size 4,721 bytes, commit 7e401f6.
def show_segmentation(ti=0, target_dir: str = "./seg"):
    """Plot segmentation predictions vs. reference labels for one test ECG.

    Row ``ti`` (positional) of the CSV at the module-level ``test_path``
    identifies an ECG record (``objectid``), the lead to show (first element
    of the ``lead_type`` literal) and a ``seq`` string such as ``"1/1"``. The
    lead is loaded from its JSON file, scaled by ``study.mv_unit``, and drawn
    in two stacked panels per algorithm in ``ALG_ORDER``: ``"ori"`` (raw
    scores in ``output`` thresholded with ``cutoff``) and ``"pp"``
    (``pp_output``). Each panel overlays predicted and reference masks.

    Depends on module-level names defined elsewhere in the file: ``plt``,
    ``pd``, ``np``, ``ast``, ``json``, ``Alg``, ``ALG_ORDER``, ``test_path``,
    ``output``, ``pp_output``, ``cutoff`` and ``origin_seg``.

    Args:
        ti: Positional row index into the test CSV; also used to index
            ``output[alg]``, ``pp_output[alg]`` and ``origin_seg``.
        target_dir: Directory the PNG is written to (created if missing).
            The file is named
            ``{lead_cnt}_{objectid}_{lead_type}_{seq with "/" -> "_"}.png``.

    Raises:
        FileNotFoundError: If the ECG JSON file does not exist.
    """
    import os

    test_df = pd.read_csv(test_path)
    row = test_df.iloc[ti]
    lead_type = ast.literal_eval(row["lead_type"])[0]
    seq = row["seq"]
    object_id: str = row["objectid"]
    # NOTE(review): object_id[18:22] selects the sub-directory holding the
    # JSON (presumably a date/shard component of the id) — confirm.
    ecg_data_path = f"/bfai/data/ecg_data/{object_id[18:22]}/{object_id}.json"
    with open(ecg_data_path, encoding="utf-8") as ecg_data_file:
        ecg_json = json.load(ecg_data_file)
    ecg_data = (
        np.array(ecg_json["waveform"]["data"][lead_type])
        * ecg_json["study"]["mv_unit"]
    )

    if seq != "1/1":
        # "k/n": the usable part of the lead is n windows of 5000 samples,
        # centred in the recording; drop the leading margin.
        # NOTE(review): the window index k is parsed but unused, so every
        # sample from the first window to the end is plotted — confirm that
        # this (rather than only window k) is intended.
        _seq_idx, seq_range = (int(str_seq) for str_seq in seq.split("/"))
        use_lead_length = 5000  # (500 * 10)
        total_use_length = use_lead_length * seq_range
        # Clamp at 0: for a recording shorter than total_use_length the
        # offset would be negative, and a negative slice start counts from
        # the end, silently plotting only the tail of the trace.
        front_idx = max(0, (len(ecg_data) - total_use_length) // 2)
        ecg_data = ecg_data[front_idx:]

    num_rows = len(Alg) * 2  # one "ori" and one "pp" panel per algorithm
    # Create the figure only after all data is loaded so a failed read
    # cannot leave an orphaned figure behind.
    fig = plt.figure(figsize=(15, num_rows * 2))
    fig.subplots_adjust(hspace=0, wspace=0.1)
    try:
        for alg, alg_idx in ALG_ORDER.items():
            for pp_type_idx, pp_type in enumerate(["ori", "pp"]):
                sub_fig = fig.add_subplot(
                    num_rows, 1, 2 * alg_idx + pp_type_idx + 1
                )
                _draw_segmentation_panel(
                    sub_fig, ecg_data, alg, ti, pp_type_idx, pp_type, lead_type
                )
        os.makedirs(target_dir, exist_ok=True)
        file_name = (
            f"{row['lead_cnt']}_{object_id}_{lead_type}_"
            f"{seq.replace('/', '_')}.png"
        )
        # os.path.join keeps an absolute target_dir absolute; prefixing it
        # with "./" would turn "/abs/dir" into the relative ".//abs/dir".
        fig.savefig(os.path.join(target_dir, file_name), dpi=150)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def _draw_segmentation_panel(
    sub_fig, ecg_data, alg, ti, pp_type_idx, pp_type, lead_type
):
    """Draw one ``show_segmentation`` panel: the ECG trace plus mask overlays.

    Predicted masks are drawn above the baseline (solid lines), reference
    labels from ``origin_seg`` (entries > 0) below it (solid lines), and
    ``origin_seg`` entries < 0 as dotted lines at -1. ``pp_type_idx == 0``
    uses raw scores from ``output`` thresholded with ``cutoff``; otherwise
    the masks come from ``pp_output``. Curves carry ``label=`` values so a
    legend can be added with ``sub_fig.legend()``.
    """
    # Per segmentation channel 0 / 1 / 2: legend name, prediction colour,
    # reference colour, and the height used for the prediction (+) and
    # reference-label (-) masks.
    channels = (
        ("P", "red", "salmon", 0.5),
        ("QRS", "green", "seagreen", 1),
        ("T", "blue", "darkslateblue", 0.5),
    )

    sub_fig.text(
        0.02,
        0.5,
        f"{alg} {pp_type}\n{lead_type}",
        fontsize=9,
        fontweight="bold",
        ha="center",
        va="center",
        rotation=90,
        transform=sub_fig.transAxes,
    )
    sub_fig.set_xticks([])
    sub_fig.set_yticks([])
    sub_fig.plot(range(len(ecg_data)), ecg_data, color="black", linewidth=1.0)

    # Draw predictions, then labels, then unlabeled regions, so later layers
    # are stacked on top of earlier ones.
    for ch, (name, pred_color, _, height) in enumerate(channels):
        if pp_type_idx == 0:
            mask = output[alg][ti][ch] >= cutoff[alg][ch]
        else:
            mask = pp_output[alg][ti][ch]
        sub_fig.plot(
            range(len(mask)),
            mask.astype(int) * height,
            label=name,
            color=pred_color,
            linewidth=0.7,
        )
    for ch, (name, _, ref_color, height) in enumerate(channels):
        labeled = origin_seg[ti][ch] > 0
        sub_fig.plot(
            range(len(labeled)),
            labeled.astype(int) * -height,
            label=f"{name} Label",
            color=ref_color,
            linewidth=0.7,
        )
    for ch, (name, _, ref_color, _) in enumerate(channels):
        unlabeled = origin_seg[ti][ch] < 0
        sub_fig.plot(
            range(len(unlabeled)),
            unlabeled.astype(int) * -1,
            label=f"{name} UnLabeled",
            linestyle=":",
            color=ref_color,
            linewidth=0.5,
        )