ECG_Delineation / utils.py
wogh2012's picture
feat: add gradio app for demonstrating segmentations by hrnet
7e401f6
raw
history blame
4.72 kB
def _load_lead_signal(object_id: str, lead_type, seq: str):
    """Load one lead of a stored ECG recording, scaled by the study's ``mv_unit``.

    Args:
        object_id: Recording identifier. Characters ``[18:22]`` select the
            sub-directory under ``/bfai/data/ecg_data`` (meaning of that slice
            is not visible here — presumably a shard/date key, TODO confirm).
        lead_type: Key into ``ecg_json["waveform"]["data"]``.
        seq: Segment descriptor of the form ``"<index>/<count>"``; ``"1/1"``
            leaves the signal untouched.

    Returns:
        numpy array of the lead samples multiplied by ``mv_unit``. For any
        ``seq`` other than ``"1/1"``, the samples before the centred
        ``count * 5000``-sample span are dropped (the tail is kept).
    """
    ecg_data_path = f"/bfai/data/ecg_data/{object_id[18:22]}/{object_id}.json"
    with open(ecg_data_path, encoding="utf-8") as ecg_data_file:
        ecg_json = json.load(ecg_data_file)
    ecg_data = (
        np.array(ecg_json["waveform"]["data"][lead_type])
        * ecg_json["study"]["mv_unit"]
    )
    if seq != "1/1":
        # NOTE(review): the segment index is parsed but never used, so the
        # plotted signal always starts at the first window of the centred span
        # regardless of `seq`. Confirm this matches how `output` / `origin_seg`
        # were produced before changing it to a per-segment slice.
        _seq_idx, seq_range = [int(str_seq) for str_seq in seq.split("/")]
        use_lead_length = 5000  # (500 * 10) — presumably 500 Hz x 10 s; confirm
        total_use_length = use_lead_length * seq_range
        # Offset that centres the used span in the recording.
        # NOTE(review): negative when the recording is shorter than
        # total_use_length, in which case the slice below keeps only the tail.
        front_idx = int((len(ecg_data) - total_use_length) / 2)
        ecg_data = ecg_data[front_idx:]
    return ecg_data


def _plot_segmentation_panel(sub_fig, alg, ti: int, use_pp: bool, ecg_data) -> None:
    """Draw the ECG trace plus predicted and reference masks on one axes.

    Predicted masks are drawn at or above zero (P and T at height 0.5, QRS at
    height 1). When ``use_pp`` is False the raw ``output`` scores are
    thresholded with ``cutoff[alg]``; otherwise ``pp_output`` is cast to int
    and plotted directly. Positive entries of ``origin_seg[ti]`` are mirrored
    below zero with the same heights. Negative entries are drawn dotted at -1
    and labelled "UnLabeled".
    """
    # (wave name, channel index into the per-row arrays, height, prediction
    # colour, reference colour)
    waves = (
        ("P", 0, 0.5, "red", "salmon"),
        ("QRS", 1, 1, "green", "seagreen"),
        ("T", 2, 0.5, "blue", "darkslateblue"),
    )

    sub_fig.plot(range(len(ecg_data)), ecg_data, color="black", linewidth=1.0)

    # Model predictions, above the baseline.
    for name, ch, height, pred_color, _ in waves:
        if use_pp:
            mask = pp_output[alg][ti][ch].astype(int)
        else:
            mask = (output[alg][ti][ch] >= cutoff[alg][ch]).astype(int)
        # x matches the plotted array itself, so raw and post-processed traces
        # are handled the same way.
        sub_fig.plot(
            range(len(mask)),
            mask * height,
            label=name,
            color=pred_color,
            linewidth=0.7,
        )

    # Reference annotations (positive entries), mirrored below the baseline.
    for name, ch, height, _, ref_color in waves:
        ref = origin_seg[ti][ch]
        sub_fig.plot(
            range(len(ref)),
            -(ref > 0).astype(int) * height,
            label=f"{name} Label",
            color=ref_color,
            linewidth=0.7,
        )

    # Negative entries of the reference arrays, drawn dotted at -1.
    for name, ch, _, _, ref_color in waves:
        ref = origin_seg[ti][ch]
        sub_fig.plot(
            range(len(ref)),
            -(ref < 0).astype(int),
            label=f"{name} UnLabeled",
            linestyle=":",
            color=ref_color,
            linewidth=0.5,
        )


def show_segmentation(ti: int = 0, target_dir: str = "./seg") -> None:
    """Render and save a segmentation comparison figure for one test row.

    Reads row ``ti`` of the CSV at module-level ``test_path``, loads the
    matching ECG lead, and stacks two panels per algorithm in ``ALG_ORDER``:
    "ori" (raw scores thresholded with ``cutoff``) and "pp"
    (``pp_output``). Each panel overlays predictions and the
    ``origin_seg`` reference on the ECG trace.

    Args:
        ti: Row index into the test CSV. The same index is used for
            ``output[alg]``, ``pp_output[alg]`` and ``origin_seg``.
        target_dir: Directory to write the PNG into. It is created if missing
            and may be relative or absolute.

    The file is saved at 150 dpi as
    ``{lead_cnt}_{objectid}_{lead_type}_{seq with '/' -> '_'}.png``.
    """
    import os

    test_df = pd.read_csv(test_path)
    row = test_df.iloc[ti]
    # The lead_type column holds a Python-literal sequence; only its first
    # element is used.
    lead_type = ast.literal_eval(row["lead_type"])[0]
    seq = row["seq"]
    object_id: str = row["objectid"]
    ecg_data = _load_lead_signal(object_id, lead_type, seq)

    fig = plt.figure(figsize=(15, 2 * len(Alg) * 2))
    try:
        fig.subplots_adjust(hspace=0, wspace=0.1)
        n_rows = len(Alg) * 2
        for alg, alg_idx in ALG_ORDER.items():
            for pp_type_idx, pp_type in enumerate(["ori", "pp"]):
                sub_fig = fig.add_subplot(n_rows, 1, 2 * alg_idx + pp_type_idx + 1)
                # Row caption, drawn vertically at the left edge of the axes.
                sub_fig.text(
                    0.02,
                    0.5,
                    f"{alg} {pp_type}\n{lead_type}",
                    fontsize=9,
                    fontweight="bold",
                    ha="center",
                    va="center",
                    rotation=90,
                    transform=sub_fig.transAxes,
                )
                sub_fig.set_xticks([])
                sub_fig.set_yticks([])
                _plot_segmentation_panel(
                    sub_fig, alg, ti, pp_type == "pp", ecg_data
                )

        file_name = (
            f"{row['lead_cnt']}_{object_id}_{lead_type}_{seq.replace('/', '_')}.png"
        )
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        # os.path.join keeps absolute target_dir values absolute.
        fig.savefig(os.path.join(target_dir, file_name), dpi=150)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)