# from io import BytesIO
import gradio as gr
import pandas as pd
import numpy as np
import torch
import os
from aitiautils.model_loader import ModelLoader
import tempfile
import matplotlib.pyplot as plt
import traceback as tb

# If True, always generate new images, regardless of whether the files already
# exist in the tmp directory.
ALWAYS_RECREATE_IMAGE = os.getenv("ALWAYS_RECREATE_IMAGE", "False").lower() == "true"

selected_columns = ["subject_id", "no_p", "Rhythm", "Electric axis of the heart", "Etc"]

train_df = pd.read_csv("./res/ludb/dataset/train_for_public.csv").drop_duplicates(
    subset=["subject_id"]
)[selected_columns]
valid_df = pd.read_csv("./res/ludb/dataset/valid_for_public.csv").drop_duplicates(
    subset=["subject_id"]
)[selected_columns]
test_df = pd.read_csv("./res/ludb/dataset/test_for_public.csv").drop_duplicates(
    subset=["subject_id"]
)[selected_columns]

cutoffs = [0.001163482666015625, 0.15087890625, -0.587890625]
lead_names = ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]


def gen_seg(subject_id):
    input = np.load(f"./res/ludb/ecg_np/{subject_id}.npy")
    network = ModelLoader("./res/models/hrnetv2/checkpoint.pth").get_network()
    output: np.ndarray = network(torch.from_numpy(input)).detach().numpy()
    seg = [(output[:, i, :] >= cutoffs[i]).astype(int) for i in range(len(cutoffs))]
    return input, np.stack(seg, axis=1)


def concat_short_interval(seg, th):
    """Fill gaps (0) of length <= th between segments (1) in seg. (0 -> 1)"""
    # Group consecutive identical values in seg.
    # ex: seg = [0, 0, 1, 1, 0, 1, 1, 1, 1] -> seg_groups = [[0, 0], [1, 1], [0], [1, 1, 1, 1]]
    seg_groups = np.split(seg, np.where(np.diff(seg) != 0)[0] + 1)
    for i in range(1, len(seg_groups) - 1):  # skip the first and last groups
        group = seg_groups[i]
        if len(group) <= th and np.all(group == 0):
            seg_groups[i] = np.ones_like(group)  # 0 -> 1
    return np.concatenate(seg_groups)


def remove_short_duration(seg, th):
    """Remove segments (1) of seg whose length is <= th. (1 -> 0)"""
    seg_groups = np.split(seg, np.where(np.diff(seg) != 0)[0] + 1)
    for i, group in enumerate(seg_groups):
        if len(group) <= th and np.all(group == 1):
            seg_groups[i] = np.zeros_like(group)  # 1 -> 0
    return np.concatenate(seg_groups)


def gen_each_image(input, seg, image_path, ths, pp=False):
    fig = plt.figure(figsize=(15, 18))
    plt.subplots_adjust(left=0.02, right=0.98, top=0.98, bottom=0.02, hspace=0.2)
    for idx, (in_by_lead, seg_by_lead) in enumerate(zip(input, seg)):
        sub_fig = fig.add_subplot(12, 1, idx + 1)
        sub_fig.text(
            0.02,
            0.5,
            f"{lead_names[idx]}",
            fontsize=9,
            fontweight="bold",
            ha="center",
            va="center",
            rotation=90,
            transform=sub_fig.transAxes,
        )
        sub_fig.set_xticks([])
        sub_fig.set_yticks([])
        sub_fig.plot(
            range(len(in_by_lead[0])), in_by_lead[0], color="black", linewidth=1.0
        )
        p_seg = seg_by_lead[0]
        qrs_seg = seg_by_lead[1]
        t_seg = seg_by_lead[2]
        if pp:
            p_seg = remove_short_duration(concat_short_interval(p_seg, ths[0]), ths[1])
            qrs_seg = remove_short_duration(
                concat_short_interval(qrs_seg, ths[2]), ths[3]
            )
            t_seg = remove_short_duration(concat_short_interval(t_seg, ths[4]), ths[5])
        sub_fig.plot(
            range(len(p_seg)), p_seg / 2, label="P", color="red", linewidth=0.7
        )
        sub_fig.plot(
            range(len(qrs_seg)), qrs_seg, label="QRS", color="green", linewidth=0.7
        )
        sub_fig.plot(
            range(len(t_seg)), t_seg / 2, label="T", color="blue", linewidth=0.7
        )
    plt.savefig(image_path, dpi=150)
    plt.close()


def gen_image(subject_id, image_path, pp_image_path, ths):
    try:
        input, seg = gen_seg(subject_id)
        gen_each_image(input, seg, image_path, ths)
        gen_each_image(input, seg, pp_image_path, ths, True)
        return True
    except Exception:
        print(tb.format_exc())
        return False
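
# A minimal sketch (not part of the app's runtime path) of how the two
# post-processing helpers above behave, written as a commented doctest so it is
# not executed at import time. Gaps of zeros no longer than the interval
# threshold are merged into the surrounding wave, then runs of ones no longer
# than the duration threshold are discarded. The array values below are
# hypothetical.
#
#   >>> raw = np.array([1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0])
#   >>> concat_short_interval(raw, th=1)
#   array([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0])
#   >>> remove_short_duration(concat_short_interval(raw, th=1), th=2)
#   array([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
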
with gr.Blocks() as demo:
    with gr.Tab("App"):
        with gr.Row():
            gr.Textbox(
                "Welcome to visit ECG-Delineation space",
                label="Information",
                lines=1,
            )
        gr_dfs = []
        with gr.Row():
            gr_dfs.append(
                gr.Dataframe(
                    value=train_df,
                    interactive=False,
                    max_height=250,
                    label="our train dataset. (source: ./res/ludb/dataset/train_for_public.csv)",
                )
            )
        with gr.Row():
            gr_dfs.append(
                gr.Dataframe(
                    value=valid_df,
                    interactive=False,
                    max_height=250,
                    label="our valid dataset. (source: ./res/ludb/dataset/valid_for_public.csv)",
                )
            )
        with gr.Row():
            gr_dfs.append(
                gr.Dataframe(
                    value=test_df,
                    interactive=False,
                    max_height=250,
                    label="our test dataset. (source: ./res/ludb/dataset/test_for_public.csv)",
                )
            )
        with gr.Row():
            gr_ths = [
                gr.Textbox(
                    label="Interval Threshold of P (ms)",
                    lines=1,
                    value="10",
                ),
                gr.Textbox(
                    label="Duration Threshold of P (ms)",
                    lines=1,
                    value="50",
                ),
                gr.Textbox(
                    label="Interval Threshold of QRS (ms)",
                    lines=1,
                    value="50",
                ),
                gr.Textbox(
                    label="Duration Threshold of QRS (ms)",
                    lines=1,
                    value="50",
                ),
                gr.Textbox(
                    label="Interval Threshold of T (ms)",
                    lines=1,
                    value="30",
                ),
                gr.Textbox(
                    label="Duration Threshold of T (ms)",
                    lines=1,
                    value="50",
                ),
            ]
        with gr.Row():
            gr_image = gr.Image(type="filepath", label="Output")
            gr_pp_image = gr.Image(type="filepath", label="PostProcessed Output")

        def show_image(df: pd.DataFrame, evt: gr.SelectData, *ths):
            subject_id = evt.row_value[0]
            image_path = f"{tempfile.gettempdir()}/ludb_{subject_id}.png"
            pp_image_path = f"{tempfile.gettempdir()}/ludb_{subject_id}_pp.png"
            if not ALWAYS_RECREATE_IMAGE and (
                os.path.exists(image_path) and os.path.exists(pp_image_path)
            ):
                return [image_path, pp_image_path]
            gen_image(
                subject_id, image_path, pp_image_path, [int(th) / 2 for th in ths]
            )
            return [image_path, pp_image_path]

        for gr_df in gr_dfs:
            gr_df.select(
                fn=show_image,
                inputs=[gr_df, *gr_ths],
                outputs=[gr_image, gr_pp_image],
            )

demo.launch()
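
# Usage note (assumption: this script is the Space's entry point, e.g. app.py).
# Rendered plots are cached as ludb_<subject_id>.png in the tmp directory and
# reused on later selections; setting the environment variable forces the
# segmentation network to be re-run and the plots to be regenerated every time:
#
#   ALWAYS_RECREATE_IMAGE=true python app.py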