File size: 3,343 Bytes
29b2705 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import numpy as np
from scipy.io import savemat
import h5py
from bcgunet import bcgunet
import platform
import os
import time
import os
import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use("agg")
# Scratch area next to this script; run() creates one work folder per request inside it.
dir = f"{os.path.dirname(os.path.realpath(__file__))}/tmp"
os.makedirs(name=dir, exist_ok=True)
def _read_variable(mat, name: str, index: int):
    """Return dataset *name* from an open .mat (HDF5) file, or raise gr.Error.

    Turns a bare KeyError into a message Gradio can show, listing the
    variables that are actually present in upload number *index*.
    """
    try:
        return mat[name]
    except KeyError:
        available = ", ".join(sorted(mat.keys()))
        raise gr.Error(
            f"Variable {name!r} not found in uploaded file #{index}; "
            f"available variables: {available}"
        ) from None


def _save_preview_plot(orig, clean, path: str, channel: int = 19, n_samples: int = 10000):
    """Plot one row of the original vs. cleaned EEG and save it as an image.

    Args:
        orig: EEG before cleaning; rows are indexed as channels (as in run()).
        clean: EEG returned by bcgunet.run, indexed the same way.
        path: Destination image path.
        channel: Preferred row to plot; clamped to the last available row.
        n_samples: Maximum number of leading samples to draw.
    """
    # Recordings with fewer than channel+1 rows would otherwise raise IndexError.
    row = min(channel, orig.shape[0] - 1, clean.shape[0] - 1)
    fig, ax = plt.subplots(figsize=(12, 6), dpi=300)
    try:
        ax.plot(orig[row, :n_samples], "b.-", label="Orig EEG")
        ax.plot(clean[row, :n_samples], "g.-", label="U-Net")
        ax.legend()
        ax.set_title("BCG Unet")
        ax.set_xlabel("Time (samples)")
        fig.savefig(path)
    finally:
        # pyplot keeps every figure alive until it is closed; without this a
        # long-running server leaks one large figure per request.
        plt.close(fig)


def run(
    files: list[bytes],
    lr: float,
    winsec: int,
    iters: int,
    onecycle: bool,
    ecg: str,
    bce: str,
    eeg: str,
) -> tuple[list[str], str]:
    """Remove BCG artifacts from every uploaded .mat file with BCGunet.

    Each upload is written into a fresh per-request directory under the
    module-level ``dir``. Its ECG and BCG-corrupted EEG variables are read
    with h5py and cleaned by ``bcgunet.run``, and the result is saved as
    ``<i>_clean.mat``. A preview plot is rendered for the first file.

    Args:
        files: Raw bytes of the uploaded files (Gradio ``type="binary"``).
        lr: Learning rate forwarded to ``bcgunet.run``.
        winsec: Window size in seconds, forwarded as ``winsize_sec``.
        iters: Iteration count, forwarded as ``iter_num``.
        onecycle: Forwarded as ``onecycle`` (one-cycle scheduler toggle).
        ecg: Name of the ECG variable inside each input file.
        bce: Name of the BCG-corrupted EEG variable inside each input file.
        eeg: Variable name under which the cleaned EEG is saved.

    Returns:
        The output .mat paths (one per input, in upload order) and the path
        of the preview PNG.

    Raises:
        gr.Error: If nothing was uploaded or a requested variable is missing.
    """
    import tempfile

    if not files:
        raise gr.Error("Please upload at least one .mat file.")

    # mkdtemp guarantees a unique folder; a bare second-resolution timestamp
    # made concurrent requests collide with FileExistsError.
    task = tempfile.mkdtemp(prefix=f"{int(time.time())}_", dir=dir)

    outputs = []
    plot = None
    for i, data in enumerate(files):
        src = os.path.join(task, f"{i}.mat")
        with open(src, "wb") as fh:
            fh.write(data)

        # Copy the arrays out, then release the HDF5 handle immediately.
        with h5py.File(src, "r") as mat:
            ecg_data = np.array(_read_variable(mat, ecg, i)).flatten()
            # NOTE(review): .T presumably undoes the reversed axis order that
            # h5py shows for MATLAB v7.3 arrays, leaving channels on rows.
            eeg_raw = np.array(_read_variable(mat, bce, i)).T

        # Sliders may deliver floats; the annotations (and bcgunet) expect ints.
        eeg_clean = bcgunet.run(
            eeg_raw,
            ecg_data,
            iter_num=int(iters),
            winsize_sec=int(winsec),
            lr=lr,
            onecycle=onecycle,
        )

        dst = os.path.join(task, f"{i}_clean.mat")
        savemat(dst, {eeg: eeg_clean}, do_compression=True)
        outputs.append(dst)

        if i == 0:
            plot = os.path.join(task, f"{i}.png")
            _save_preview_plot(eeg_raw, eeg_clean, plot)

    return outputs, plot
def main():
    """Build the Gradio interface around run() and launch it.

    The input components are listed in the same order as run()'s
    parameters because Gradio passes their values positionally.
    """
    app = gr.Interface(
        title="BCG Unet",
        description="BCGunet: Suppressing BCG artifacts on EEG collected inside an MRI scanner",
        fn=run,
        inputs=[
            gr.File(
                label="Input Files (.mat)",
                type="binary",
                # Gradio expects extensions with a leading dot (".mat").
                file_types=[".mat"],
                # file_count takes one string; a list matches neither
                # "multiple" nor "directory" in the upload widget.
                file_count="multiple",
            ),
            gr.Slider(
                label="Learning Rate", minimum=1e-5, maximum=1e-1, step=1e-5, value=1e-3
            ),
            gr.Slider(
                label="Window Size (seconds)", minimum=1, maximum=10, step=1, value=2
            ),
            gr.Slider(
                label="Number of Iterations",
                minimum=1000,
                maximum=10000,
                step=1000,
                value=5000,
            ),
            gr.Checkbox(
                label="One Cycle Scheduler",
                value=True,
            ),
            gr.Textbox(
                label="Variable name for ECG (input)",
                value="ECG",
            ),
            gr.Textbox(
                label="Variable name for BCG corrupted EEG (input)",
                value="EEG_before_bcg",
            ),
            gr.Textbox(
                label="Variable name for clean EEG (output)",
                value="EEG_clean",
            ),
        ],
        outputs=[
            gr.File(label="Output File", file_count="multiple"),
            gr.Image(label="Output Image", type="filepath"),
        ],
        # NOTE(review): newer Gradio releases rename this to flagging_mode;
        # confirm against the pinned Gradio version.
        allow_flagging="never",
    )
    app.launch()
# Script entry point: only runs when executed directly, not when imported.
if __name__ == "__main__":
    # Build and launch the Gradio app.
    main()
    # On Windows, wait for a key press before exiting so a console window
    # opened by double-clicking the script stays visible.
    # NOTE(review): the original indentation was lost; this pause is assumed
    # to belong inside the __main__ guard — confirm against upstream.
    if platform.system() == "Windows":
        os.system("pause")
|