import gradio as gr
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
import scipy
from scipy import signal
import pickle
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split, GridSearchCV
# Global variable to store the uploaded data
global_data = None
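# Expected upload (an assumption inferred from the preprocessing below): a CSV
# whose first column is a row index and whose second column holds raw
# single-channel EEG samples recorded at 512 Hz, e.g. (hypothetical values):
#
#   ,eeg
#   0,-12.4
#   1,-11.8
#   2,-13.1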
def get_data_preview(file):
    global global_data
    global_data = pd.read_csv(file.name)
    global_data['label'] = np.nan  # Initialize a label column
    global_data['label'] = global_data['label'].astype(object)  # Ensure the label column can hold different types
    return global_data.head()
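# Apply user-entered (start, end, label) row ranges to the label column.
# Note that .loc slicing is inclusive, so rows start..end all receive the label.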
def label_data(ranges):
    global global_data
    print("Ranges received for labeling:", ranges)
    for start, end, label in ranges.values:
        start = int(start)
        end = int(end)
        if start < 0 or start >= len(global_data):
            continue
        if end >= len(global_data):
            end = len(global_data) - 1
        global_data.loc[start:end, 'label'] = label
    return global_data.tail()
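# Preprocessing pipeline: drop the index column, notch-filter 50 Hz mains
# interference, band-pass 0.5-30 Hz, then slice the signal into 512-sample
# windows and reduce each window to a small set of Welch-PSD features.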
def preprocess_data():
    global global_data
    try:
        global_data.drop(columns=global_data.columns[0], inplace=True)
        global_data.columns = ['raw_eeg', 'label']
        raw_data = global_data['raw_eeg']
        labels_old = global_data['label']
        sampling_rate = 512
        notch_freq = 50.0
        lowcut, highcut = 0.5, 30.0
        nyquist = 0.5 * sampling_rate
        # With fs given, iirnotch expects the notch frequency in Hz, not a
        # normalized frequency; Q=30 gives a narrow notch centered on 50 Hz.
        b_notch, a_notch = signal.iirnotch(notch_freq, Q=30.0, fs=sampling_rate)
        lowcut_normalized = lowcut / nyquist
        highcut_normalized = highcut / nyquist
        b_bandpass, a_bandpass = signal.butter(4, [lowcut_normalized, highcut_normalized], btype='band')
        features = []
        labels = []
        def calculate_psd_features(segment, sampling_rate):
            f, psd_values = scipy.signal.welch(segment, fs=sampling_rate, nperseg=len(segment))
            alpha_indices = np.where((f >= 8) & (f <= 13))
            beta_indices = np.where((f >= 14) & (f <= 30))
            theta_indices = np.where((f >= 4) & (f <= 7))
            delta_indices = np.where((f >= 0.5) & (f <= 3))
            energy_alpha = np.sum(psd_values[alpha_indices])
            energy_beta = np.sum(psd_values[beta_indices])
            energy_theta = np.sum(psd_values[theta_indices])
            energy_delta = np.sum(psd_values[delta_indices])
            alpha_beta_ratio = energy_alpha / energy_beta
            return {
                'E_alpha': energy_alpha,
                'E_beta': energy_beta,
                'E_theta': energy_theta,
                'E_delta': energy_delta,
                'alpha_beta_ratio': alpha_beta_ratio
            }
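        # Complementary spectral descriptors: the dominant frequency, the
        # power-weighted mean frequency (spectral centroid), and the slope of
        # a linear fit to the log-log PSD.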
        def calculate_additional_features(segment, sampling_rate):
            f, psd = scipy.signal.welch(segment, fs=sampling_rate, nperseg=len(segment))
            peak_frequency = f[np.argmax(psd)]
            spectral_centroid = np.sum(f * psd) / np.sum(psd)
            log_f = np.log(f[1:])  # skip the DC bin to avoid log(0)
            log_psd = np.log(psd[1:])
            spectral_slope = np.polyfit(log_f, log_psd, 1)[0]
            return {
                'peak_frequency': peak_frequency,
                'spectral_centroid': spectral_centroid,
                'spectral_slope': spectral_slope
            }
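        # Slide a 512-sample window (one second at 512 Hz) over the recording
        # with a 256-sample hop, i.e. 50% overlap between consecutive windows.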
        for i in range(0, len(raw_data) - 512, 256):
            if pd.isna(labels_old.iloc[i]):
                continue  # skip windows that were never labeled
            print(f"Processing segment {i} to {i + 512}")
            segment = raw_data.iloc[i:i + 512]  # iloc: exactly 512 samples (loc would include the end point)
            segment = pd.to_numeric(segment, errors='coerce')
            segment = signal.filtfilt(b_notch, a_notch, segment)
            segment = signal.filtfilt(b_bandpass, a_bandpass, segment)
            segment_features = calculate_psd_features(segment, 512)
            additional_features = calculate_additional_features(segment, 512)
            segment_features = {**segment_features, **additional_features}
            features.append(segment_features)
            labels.append(labels_old.iloc[i])
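        # Z-score the feature matrix and persist both the scaled features and
        # the fitted scaler, so the same scaling can be reapplied at inference time.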
        columns = ['E_alpha', 'E_beta', 'E_theta', 'E_delta', 'alpha_beta_ratio', 'peak_frequency', 'spectral_centroid', 'spectral_slope']
        df_features = pd.DataFrame(features, columns=columns)
        df_features['label'] = labels
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(df_features.drop('label', axis=1))
        df_scaled = pd.DataFrame(X_scaled, columns=columns)
        df_scaled['label'] = df_features['label']
        processed_data_filename = 'processed_data.csv'
        df_scaled.to_csv(processed_data_filename, index=False)
        scaler_filename = 'scaler.pkl'
        with open(scaler_filename, 'wb') as file:
            pickle.dump(scaler, file)
        return "Data preprocessing complete! Download the processed data and scaler below.", processed_data_filename, scaler_filename
    except Exception as e:
        return f"An error occurred during preprocessing: {e}", None, None
def train_model():
    try:
        preprocess_status, processed_data_filename, scaler_filename = preprocess_data()
        if processed_data_filename is None:
            return preprocess_status, None, None
        df_scaled = pd.read_csv(processed_data_filename)
        X = df_scaled.drop('label', axis=1)
        y = df_scaled['label']
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        param_grid = {'C': [0.1, 1, 10, 100], 'gamma': ['scale', 'auto', 0.1, 0.01, 0.001, 0.0001], 'kernel': ['rbf']}
        svc = SVC(probability=True)
        grid_search = GridSearchCV(estimator=svc, param_grid=param_grid, cv=5, verbose=2, n_jobs=-1)
        grid_search.fit(X_train, y_train)
        model = grid_search.best_estimator_
        model_filename = 'model.pkl'
        with open(model_filename, 'wb') as file:
            pickle.dump(model, file)
        return "Training complete! Download the model and scaler below.", model_filename, scaler_filename
    except Exception as e:
        print(f"An error occurred during training: {e}")
        return f"An error occurred during training: {e}", None, None
with gr.Blocks() as demo:
    file_input = gr.File(label="Upload CSV File")
    data_preview = gr.Dataframe(label="Data Preview", interactive=False)
    ranges_input = gr.Dataframe(headers=["Start Index", "End Index", "Label"], label="Ranges for Labeling")
    labeled_data_preview = gr.Dataframe(label="Labeled Data Preview", interactive=False)
    training_status = gr.Textbox(label="Training Status")
    model_file = gr.File(label="Download Trained Model")
    scaler_file = gr.File(label="Download Scaler")
    file_input.upload(get_data_preview, inputs=file_input, outputs=data_preview)
    label_button = gr.Button("Label Data")
    label_button.click(label_data, inputs=[ranges_input], outputs=labeled_data_preview, queue=True)
    train_button = gr.Button("Train Model")
    train_button.click(train_model, outputs=[training_status, model_file, scaler_file])
demo.launch()