ECG_Delineation / app.py
wogh2012's picture
feat: add gradio app to demonstrating segmentations by hrnet
7e401f6
raw
history blame
7.56 kB
# from io import BytesIO
import functools
import os
import tempfile
import traceback as tb

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from aitiautils.model_loader import ModelLoader
# When true, always regenerate images, even if the files already exist in the
# temp directory.
ALWAYS_RECREATE_IMAGE = os.environ.get("ALWAYS_RECREATE_IMAGE", "False").lower() == "true"

# Columns of the public LUDB split CSVs that are shown in the UI tables.
selected_columns = ["subject_id", "no_p", "Rhythm", "Electric axis of the heart", "Etc"]


def _load_public_split(csv_path):
    """Read one public split CSV, keep one row per subject, project to ``selected_columns``."""
    frame = pd.read_csv(csv_path)
    frame = frame.drop_duplicates(subset=["subject_id"])
    return frame[selected_columns]


train_df = _load_public_split("./res/ludb/dataset/train_for_public.csv")
valid_df = _load_public_split("./res/ludb/dataset/valid_for_public.csv")
test_df = _load_public_split("./res/ludb/dataset/test_for_public.csv")

# Per-channel decision thresholds applied to the network output in gen_seg
# (output[:, i, :] >= cutoffs[i]). Channel order is P, QRS, T, matching how
# gen_each_image reads the resulting masks.
cutoffs = [0.001163482666015625, 0.15087890625, -0.587890625]

# Display names for the 12 plotted leads, in row order.
lead_names = ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
@functools.lru_cache(maxsize=1)
def _load_network():
    """Load the HRNetV2 checkpoint once and reuse the network across requests."""
    # NOTE(review): assumes get_network() returns the model ready for inference
    # (e.g. eval mode) — behavior here is unchanged from the per-call load.
    return ModelLoader("./res/models/hrnetv2/checkpoint.pth").get_network()


def gen_seg(subject_id):
    """Run the delineation network on one LUDB record and threshold its outputs.

    Args:
        subject_id: Record identifier; the signal is read from
            ``./res/ludb/ecg_np/{subject_id}.npy``.

    Returns:
        ``(ecg, seg)``: ``ecg`` is the array loaded from disk, and ``seg`` is an
        int 0/1 array built by comparing output channel ``i`` against
        ``cutoffs[i]`` and stacking the results along axis 1 (channel order
        P, QRS, T, as consumed by gen_each_image).
    """
    ecg = np.load(f"./res/ludb/ecg_np/{subject_id}.npy")
    # Loading the checkpoint is expensive; it used to happen on every click.
    network = _load_network()
    # Inference only: no need to build an autograd graph.
    with torch.no_grad():
        output = network(torch.from_numpy(ecg)).detach().cpu().numpy()
    seg = [(output[:, i, :] >= cutoff).astype(int) for i, cutoff in enumerate(cutoffs)]
    return ecg, np.stack(seg, axis=1)
def concat_short_interval(seg, th):
    """Fill short gaps (runs of 0) that sit between two active runs (runs of 1).

    The mask is split into runs of equal values, e.g.
    ``[0, 0, 1, 1, 0, 1, 1, 1, 1] -> [[0, 0], [1, 1], [0], [1, 1, 1, 1]]``.
    Every interior run of zeros whose length is at most ``th`` (``<= th``) is
    turned into ones. The first and last runs are never changed, because a gap
    there is not bounded on both sides.

    Args:
        seg: 1-D 0/1 mask.
        th: Maximum gap length (in elements) to fill.

    Returns:
        A new array of the same length as ``seg``.
    """
    change_points = np.flatnonzero(np.diff(seg)) + 1
    runs = np.split(seg, change_points)
    interior = runs[1:-1]
    for position, run in enumerate(interior, start=1):
        is_gap = np.all(run == 0)
        if is_gap and len(run) <= th:
            runs[position] = np.ones_like(run)  # 0 -> 1
    return np.concatenate(runs)
def remove_short_duration(seg, th):
    """Erase active segments (runs of 1) whose length is at most ``th``.

    Every run of ones, including runs touching either end of ``seg``, is
    replaced with zeros when its length is ``<= th``. Runs of zeros are kept
    as they are.

    Args:
        seg: 1-D 0/1 mask.
        th: Maximum segment length (in elements) to erase.

    Returns:
        A new array of the same length as ``seg``.
    """
    boundaries = np.flatnonzero(np.diff(seg)) + 1
    runs = np.split(seg, boundaries)
    cleaned = [
        np.zeros_like(run) if (np.all(run == 1) and len(run) <= th) else run  # 1 -> 0
        for run in runs
    ]
    return np.concatenate(cleaned)
def gen_each_image(input, seg, image_path, ths, pp=False):
    """Plot every lead with its P/QRS/T masks overlaid and save the figure.

    Args:
        input: Per-lead signals; subplot ``idx`` draws ``input[idx][0]``.
        seg: Per-lead masks; ``seg[idx]`` holds the P, QRS and T masks, in
            that order.
        image_path: Destination file passed to ``Figure.savefig``.
        ths: Six thresholds compared against run lengths (element counts),
            used only when ``pp`` is true, in the order
            (P gap, P duration, QRS gap, QRS duration, T gap, T duration).
        pp: If true, each mask first has short gaps filled
            (``concat_short_interval``) and then short segments removed
            (``remove_short_duration``) before plotting.
    """
    fig = plt.figure(figsize=(15, 18))
    try:
        # Operate on this Figure explicitly instead of pyplot's implicit
        # "current figure", which is global state shared by all requests.
        fig.subplots_adjust(left=0.02, right=0.98, top=0.98, bottom=0.02, hspace=0.2)
        for idx, (in_by_lead, seg_by_lead) in enumerate(zip(input, seg)):
            sub_fig = fig.add_subplot(12, 1, idx + 1)
            sub_fig.text(
                0.02,
                0.5,
                f"{lead_names[idx]}",
                fontsize=9,
                fontweight="bold",
                ha="center",
                va="center",
                rotation=90,
                transform=sub_fig.transAxes,
            )
            sub_fig.set_xticks([])
            sub_fig.set_yticks([])
            signal = in_by_lead[0]
            sub_fig.plot(range(len(signal)), signal, color="black", linewidth=1.0)

            p_seg, qrs_seg, t_seg = seg_by_lead[0], seg_by_lead[1], seg_by_lead[2]
            if pp:
                p_seg = remove_short_duration(concat_short_interval(p_seg, ths[0]), ths[1])
                qrs_seg = remove_short_duration(
                    concat_short_interval(qrs_seg, ths[2]), ths[3]
                )
                t_seg = remove_short_duration(concat_short_interval(t_seg, ths[4]), ths[5])

            # P and T masks are drawn at half height and QRS at full height,
            # presumably so the three overlays stay distinguishable.
            sub_fig.plot(
                range(len(p_seg)), p_seg / 2, label="P", color="red", linewidth=0.7
            )
            sub_fig.plot(
                range(len(qrs_seg)), qrs_seg, label="QRS", color="green", linewidth=0.7
            )
            sub_fig.plot(
                range(len(t_seg)), t_seg / 2, label="T", color="blue", linewidth=0.7
            )
        fig.savefig(image_path, dpi=150)
    finally:
        # Always release the figure, even when plotting fails, so that figures
        # do not accumulate in pyplot's registry.
        plt.close(fig)
def gen_image(subject_id, image_path, pp_image_path, ths):
    """Render the raw and post-processed plots for one subject.

    Args:
        subject_id: Record identifier forwarded to gen_seg.
        image_path: Output path for the plot without post-processing.
        pp_image_path: Output path for the post-processed plot.
        ths: Post-processing thresholds forwarded to gen_each_image.

    Returns:
        True if both images were written. False if any step raised; in that
        case the traceback is printed and the exception is not propagated.
    """
    try:
        ecg, seg = gen_seg(subject_id)
        for target_path, post_process in ((image_path, False), (pp_image_path, True)):
            gen_each_image(ecg, seg, target_path, ths, post_process)
    except Exception:
        print(tb.format_exc())
        return False
    return True
with gr.Blocks() as demo:
    with gr.Tab("App"):
        with gr.Row():
            gr.Textbox(
                "Welcome to visit ECG-Delineation space",
                label="Information",
                lines=1,
            )

        # One read-only table per public split; selecting a row renders that subject.
        gr_dfs = []
        for split_name, split_df in (
            ("train", train_df),
            ("valid", valid_df),
            ("test", test_df),
        ):
            with gr.Row():
                gr_dfs.append(
                    gr.Dataframe(
                        value=split_df,
                        interactive=False,
                        max_height=250,
                        label=f"our {split_name} dataset. (source: ./res/ludb/dataset/{split_name}_for_public.csv)",
                    )
                )

        # Post-processing thresholds in ms, in the order gen_each_image consumes
        # them: (gap to fill, minimum duration) for P, then QRS, then T.
        threshold_fields = [
            ("Interval Threshold of P (ms)", "10"),
            ("Duration Threshold of P (ms)", "50"),
            ("Interval Threshold of QRS (ms)", "50"),
            ("Duration Threshold of QRS (ms)", "50"),
            ("Interval Threshold of T (ms)", "30"),
            ("Duration Threshold of T (ms)", "50"),
        ]
        with gr.Row():
            gr_ths = [
                gr.Textbox(label=label, lines=1, value=default)
                for label, default in threshold_fields
            ]

        with gr.Row():
            gr_image = gr.Image(type="filepath", label="Output")
            gr_pp_image = gr.Image(type="filepath", label="PostProcessed Output")

        def show_image(df: pd.DataFrame, evt: gr.SelectData, *ths):
            """Render (or reuse cached) raw and post-processed plots for the clicked row.

            Args:
                df: Value of the clicked Dataframe. It is unused, but Gradio passes
                    it positionally because the component is listed in ``inputs``.
                evt: Selection event; ``evt.row_value[0]`` is the row's subject_id.
                *ths: The six threshold strings (ms) from the textboxes.

            Returns:
                ``[image_path, pp_image_path]`` for the two image outputs.

            Raises:
                gr.Error: if a threshold is not a number or image generation fails.
            """
            subject_id = evt.row_value[0]
            try:
                # The ms values are halved before use. NOTE(review): presumably
                # this converts ms to samples (2 ms per sample); confirm against
                # the ECG sampling rate.
                parsed_ths = [float(th) / 2 for th in ths]
            except ValueError as err:
                raise gr.Error(f"Thresholds must be numbers (ms): {err}") from err

            tmp_dir = tempfile.gettempdir()
            image_path = os.path.join(tmp_dir, f"ludb_{subject_id}.png")
            # The post-processed image depends on the thresholds, so they are part
            # of its cache key; otherwise a changed threshold would be ignored.
            ths_tag = "_".join(f"{th:g}" for th in parsed_ths)
            pp_image_path = os.path.join(tmp_dir, f"ludb_{subject_id}_pp_{ths_tag}.png")

            if (
                not ALWAYS_RECREATE_IMAGE
                and os.path.exists(image_path)
                and os.path.exists(pp_image_path)
            ):
                return [image_path, pp_image_path]
            if not gen_image(subject_id, image_path, pp_image_path, parsed_ths):
                raise gr.Error(f"Failed to generate images for subject {subject_id}.")
            return [image_path, pp_image_path]

        for gr_df in gr_dfs:
            gr_df.select(
                fn=show_image,
                inputs=[gr_df, *gr_ths],
                outputs=[gr_image, gr_pp_image],
            )
demo.launch()