import numpy as np from scipy.io import savemat import h5py from bcgunet import bcgunet import platform import os import time import os import gradio as gr import matplotlib import matplotlib.pyplot as plt matplotlib.use("agg") dir = os.path.dirname(os.path.realpath(__file__)) + "/tmp" os.makedirs(dir, exist_ok=True) def run( files: list[bytes], lr: float, winsec: int, iters: int, onecycle: bool, ecg: str, bce: str, eeg: str, ) -> tuple[list[str], str]: task = os.path.join(dir, str(int(time.time()))) os.makedirs(task) outputs = [] for i, file in enumerate(files): input = os.path.join(task, str(i) + ".mat") with open(input, "wb") as o: o.write(file) output = os.path.join(task, str(i) + "_clean.mat") mat = h5py.File(input, "r") ECG = np.array(mat[ecg]).flatten() EEG = np.array(mat[bce]).T EEG_unet = bcgunet.run( EEG, ECG, iter_num=iters, winsize_sec=winsec, lr=lr, onecycle=onecycle, ) result = dict() result[eeg] = EEG_unet savemat(output, result, do_compression=True) outputs.append(output) if i == 0: plt.figure(figsize=(12, 6), dpi=300) plt.plot(EEG[19, :10000], "b.-", label="Orig EEG") plt.plot(EEG_unet[19, :10000], "g.-", label="U-Net") plt.legend() plt.title("BCG Unet") plt.xlabel("Time (samples)") plot = os.path.join(task, str(i) + ".png") plt.savefig(plot) return outputs, plot def main(): app = gr.Interface( title="BCG Unet", description="BCGunet: Suppressing BCG artifacts on EEG collected inside an MRI scanner", fn=run, inputs=[ gr.File( label="Input Files (.mat)", type="binary", file_types=["mat"], file_count=["multiple", "directory"], ), gr.Slider( label="Learning Rate", minimum=1e-5, maximum=1e-1, step=1e-5, value=1e-3 ), gr.Slider( label="Window Size (seconds)", minimum=1, maximum=10, step=1, value=2 ), gr.Slider( label="Number of Iterations", minimum=1000, maximum=10000, step=1000, value=5000, ), gr.Checkbox( label="One Cycle Scheduler", value=True, ), gr.Textbox( label="Variable name for ECG (input)", value="ECG", ), gr.Textbox( label="Variable name for BCG corropted EEG (input)", value="EEG_before_bcg", ), gr.Textbox( label="Variable name for clean EEG (output)", value="EEG_clean", ), ], outputs=[ gr.File(label="Output File", file_count="multiple"), gr.Image(label="Output Image", type="filepath"), ], allow_flagging="never", ) app.launch() if __name__ == "__main__": main() if platform.system() == "Windows": os.system("pause")