# Per-wave plotting style: (name, prediction color, label color, mask height).
# P and T masks are drawn at half the height of QRS, both for predictions
# (+) and for positive reference labels (-).
_SEG_WAVES = (
    ("P", "red", "salmon", 0.5),
    ("QRS", "green", "seagreen", 1),
    ("T", "blue", "darkslateblue", 0.5),
)


def _load_lead_signal(object_id: str, lead_type):
    """Load one lead of an ECG record from its JSON file, scaled by ``mv_unit``.

    The record is read from
    ``/bfai/data/ecg_data/<object_id[18:22]>/<object_id>.json``; the samples at
    ``["waveform"]["data"][lead_type]`` are returned as a NumPy array
    multiplied by ``["study"]["mv_unit"]``.
    """
    ecg_data_path = f"/bfai/data/ecg_data/{object_id[18:22]}/{object_id}.json"
    with open(ecg_data_path) as ecg_data_file:
        ecg_json = json.load(ecg_data_file)
    return np.array(ecg_json["waveform"]["data"][lead_type]) * ecg_json["study"]["mv_unit"]


def _trim_leading_samples(ecg_data, seq: str, use_lead_length: int = 5000):
    """Drop the front half of the samples beyond ``use_lead_length * seq_range``.

    ``seq`` has the form ``"<seq_idx>/<seq_range>"``. Only leading samples are
    removed; the tail is left untouched. If the record is not longer than
    ``use_lead_length * seq_range``, nothing is removed.

    ``use_lead_length`` defaults to 5000 (= 500 * 10; presumably 500 Hz for
    10 s — confirm).
    """
    _seq_idx, seq_range = (int(part) for part in seq.split("/"))
    total_use_length = use_lead_length * seq_range
    # Clamp at 0: a negative offset would make ``ecg_data[front_idx:]`` keep
    # only the last |front_idx| samples instead of the whole record.
    front_idx = max(0, (len(ecg_data) - total_use_length) // 2)
    # NOTE(review): seq_idx is not used, so everything after the offset is
    # returned rather than the seq_idx-th window — confirm this is intended.
    return ecg_data[front_idx:]


def _draw_segmentation_panel(ax, title, ecg_data, pred_masks, true_seg):
    """Draw the ECG signal with predicted and reference P/QRS/T masks on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        title: Text written vertically near the left edge of the axes.
        ecg_data: 1-D signal, plotted in black against its sample index.
        pred_masks: Three per-sample arrays (P, QRS, T); each is converted
            with ``astype(int)`` and drawn at the wave's height (>= 0).
        true_seg: Reference channels indexed 0/1/2 (P, QRS, T). Values > 0
            are drawn as solid "<wave> Label" traces at minus the wave's
            height; values < 0 as dotted "<wave> UnLabeled" traces at -1.
    """
    ax.text(
        0.02,
        0.5,
        title,
        fontsize=9,
        fontweight="bold",
        ha="center",
        va="center",
        rotation=90,
        transform=ax.transAxes,
    )
    ax.set_xticks([])
    ax.set_yticks([])
    ax.plot(range(len(ecg_data)), ecg_data, color="black", linewidth=1.0)

    # Predictions. The x-range comes from the array actually plotted so x and
    # y lengths always agree (also for post-processed outputs).
    for (name, pred_color, _, height), mask in zip(_SEG_WAVES, pred_masks):
        ax.plot(
            range(len(mask)),
            mask.astype(int) * height,
            label=name,
            color=pred_color,
            linewidth=0.7,
        )

    # Reference labels (positive values), mirrored below zero.
    for k, (name, _, label_color, height) in enumerate(_SEG_WAVES):
        channel = true_seg[k]
        ax.plot(
            range(len(channel)),
            (channel > 0).astype(int) * (-1) * height,
            label=f"{name} Label",
            color=label_color,
            linewidth=0.7,
        )

    # Negative reference values ("unlabeled"), dotted.
    # NOTE(review): these are drawn at -1 for every wave, unlike the labeled
    # P/T traces at -0.5 — confirm the asymmetry is intended.
    for k, (name, _, label_color, _height) in enumerate(_SEG_WAVES):
        channel = true_seg[k]
        ax.plot(
            range(len(channel)),
            (channel < 0).astype(int) * (-1),
            label=f"{name} UnLabeled",
            linestyle=":",
            color=label_color,
            linewidth=0.5,
        )


def show_segmentation(ti=0, target_dir: str = "./seg"):
    """Plot one test record's ECG lead with predicted and labeled segmentations.

    For row ``ti`` of the CSV at ``test_path``:

    1. The lead named first in that row's ``lead_type`` is loaded from the
       record's JSON file.
    2. When ``seq`` is not ``"1/1"``, leading samples are trimmed.
    3. One subplot is drawn per (algorithm, variant) pair, where the variants
       are "ori" and "pp". Each subplot shows:

       * the signal;
       * the P/QRS/T predictions — ``output`` thresholded with ``cutoff[alg]``
         for "ori", ``pp_output`` for "pp";
       * the reference masks from ``origin_seg``.

    The figure is saved as a 150-dpi PNG in ``target_dir``, named
    ``<lead_cnt>_<objectid>_<lead_type>_<seq with '/'→'_'>.png``.

    Relies on module-level names not defined in this function: ``plt``,
    ``pd``, ``np``, ``ast``, ``json``, ``Alg``, ``ALG_ORDER``, ``test_path``,
    ``output``, ``pp_output``, ``cutoff`` and ``origin_seg``.

    Args:
        ti: Row index into the test CSV, also used to index ``output``,
            ``pp_output`` and ``origin_seg``.
        target_dir: Output directory (relative or absolute); created if it
            does not exist.
    """
    import os  # local import: this file's top-level imports are outside this block

    test_df = pd.read_csv(test_path)
    row = test_df.iloc[ti]
    lead_type = ast.literal_eval(row["lead_type"])[0]
    seq = row["seq"]
    object_id: str = row["objectid"]

    ecg_data = _load_lead_signal(object_id, lead_type)
    if seq != "1/1":
        ecg_data = _trim_leading_samples(ecg_data, seq)

    fig = plt.figure(figsize=(15, 2 * len(Alg) * 2))
    try:
        fig.subplots_adjust(hspace=0, wspace=0.1)
        n_rows = len(Alg) * 2
        for alg, alg_idx in ALG_ORDER.items():
            raw = output[alg][ti]
            post = pp_output[alg][ti]
            variants = (
                ("ori", [raw[k] >= cutoff[alg][k] for k in range(3)]),
                ("pp", [post[k] for k in range(3)]),
            )
            for pp_type_idx, (pp_type, pred_masks) in enumerate(variants):
                ax = fig.add_subplot(n_rows, 1, 2 * alg_idx + pp_type_idx + 1)
                _draw_segmentation_panel(
                    ax,
                    f"{alg} {pp_type}\n{lead_type}",
                    ecg_data,
                    pred_masks,
                    origin_seg[ti],
                )

        file_name = f"{row['lead_cnt']}_{object_id}_{lead_type}_{seq.replace('/', '_')}.png"
        # os.path.join keeps absolute target dirs absolute (the previous
        # f"./{target_dir}/..." turned "/tmp/x" into the relative ".//tmp/x").
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        fig.savefig(os.path.join(target_dir, file_name), dpi=150)
    finally:
        # Always release the figure, even when plotting or saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)