Umang-Bansal committed on
Commit
2c8fa83
·
verified ·
1 Parent(s): 5ebcf6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -56
app.py CHANGED
@@ -1,14 +1,12 @@
1
  import gradio as gr
2
  import pandas as pd
3
  import numpy as np
4
- from sklearn.model_selection import train_test_split, GridSearchCV
5
- from sklearn.svm import SVC
6
  from sklearn.preprocessing import StandardScaler
7
  import scipy
8
  from scipy import signal
9
  import pickle
10
- import asyncio
11
 
 
12
  global_data = None
13
 
14
  def get_data_preview(file):
@@ -18,37 +16,29 @@ def get_data_preview(file):
18
 
19
  def label_data(ranges):
20
  global global_data
21
- for start, end, label in ranges.values:
22
  global_data.loc[start:end, 'label'] = label
23
- labeled_data = pd.concat([global_data.head(), global_data.tail()])
24
- return labeled_data
25
-
26
-
27
- #def label_data(ranges):
28
- #global global_data
29
- #for start, end, label in ranges.values:
30
- # global_data.loc[start:end, 'label'] = label
31
- #return global_data
32
-
33
-
34
- def preprocess_data(data):
35
- data.drop(columns=data.columns[0], axis=1, inplace=True)
36
- data.columns = ['raw_eeg', 'label']
37
- raw_data = data['raw_eeg']
38
- labels_old = data['label']
39
 
 
 
 
 
 
 
 
40
  sampling_rate = 512
41
  notch_freq = 50.0
42
  lowcut, highcut = 0.5, 30.0
43
-
44
  nyquist = (0.5 * sampling_rate)
45
  notch_freq_normalized = notch_freq / nyquist
46
  b_notch, a_notch = signal.iirnotch(notch_freq_normalized, Q=0.05, fs=sampling_rate)
47
-
48
  lowcut_normalized = lowcut / nyquist
49
  highcut_normalized = highcut / nyquist
50
  b_bandpass, a_bandpass = signal.butter(4, [lowcut_normalized, highcut_normalized], btype='band')
51
-
52
  features = []
53
  labels = []
54
 
@@ -94,56 +84,38 @@ def preprocess_data(data):
94
  segment_features = {**segment_features, **additional_features}
95
  features.append(segment_features)
96
  labels.append(labels_old[i])
97
-
98
  columns = ['E_alpha', 'E_beta', 'E_theta', 'E_delta', 'alpha_beta_ratio', 'peak_frequency', 'spectral_centroid', 'spectral_slope']
99
  df_features = pd.DataFrame(features, columns=columns)
100
  df_features['label'] = labels
101
- return df_features
102
-
103
- def train_model():
104
- global global_data
105
- data = preprocess_data(global_data)
106
  scaler = StandardScaler()
107
- X = data.drop('label', axis=1)
108
- y = data['label']
109
- X_scaled = scaler.fit_transform(X)
110
- X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
111
 
112
- param_grid = {'C': [0.1, 1, 10, 100], 'gamma': ['scale', 'auto', 0.1, 0.01, 0.001, 0.0001], 'kernel': ['rbf']}
113
- svc = SVC(probability=True)
114
- grid_search = GridSearchCV(estimator=svc, param_grid=param_grid, cv=5, verbose=2, n_jobs=-1)
115
- grid_search.fit(X_train, y_train)
116
 
117
- model = grid_search.best_estimator_
118
- model_filename = 'model.pkl'
119
  scaler_filename = 'scaler.pkl'
120
-
121
- with open(model_filename, 'wb') as file:
122
- pickle.dump(model, file)
123
-
124
  with open(scaler_filename, 'wb') as file:
125
  pickle.dump(scaler, file)
126
 
127
- return "Training complete! Model and scaler saved.", model_filename, scaler_filename
128
-
129
 
130
  with gr.Blocks() as demo:
131
  file_input = gr.File(label="Upload CSV File")
132
  data_preview = gr.Dataframe(label="Data Preview", interactive=False)
133
  ranges_input = gr.Dataframe(headers=["Start Index", "End Index", "Label"], label="Ranges for Labeling")
134
-
135
- #start_input = gr.Number(label="Start Index", value=0)
136
- #end_input = gr.Number(label="End Index", value=100)
137
- #label_input = gr.Number(label="Label Value", value=1)
138
  labeled_data_preview = gr.Dataframe(label="Labeled Data Preview", interactive=False)
139
- training_status = gr.Textbox(label="Training Status")
140
- model_file = gr.File(label="Download Trained Model")
141
  scaler_file = gr.File(label="Download Scaler")
142
-
143
  file_input.upload(get_data_preview, inputs=file_input, outputs=data_preview)
144
  label_button = gr.Button("Label Data")
145
- label_button.click(label_data, inputs=[ranges_input], outputs=labeled_data_preview, queue=True)
146
- train_button = gr.Button("Train Model")
147
- train_button.click(train_model, outputs=[training_status, model_file, scaler_file])
148
 
149
- demo.launch()
 
1
  import gradio as gr
2
  import pandas as pd
3
  import numpy as np
 
 
4
  from sklearn.preprocessing import StandardScaler
5
  import scipy
6
  from scipy import signal
7
  import pickle
 
8
 
9
+ # Global variable to store the uploaded data
10
  global_data = None
11
 
12
  def get_data_preview(file):
 
16
 
17
  def label_data(ranges):
18
  global global_data
19
+ for start, end, label in ranges:
20
  global_data.loc[start:end, 'label'] = label
21
+ return global_data.head()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ def preprocess_data():
24
+ global global_data
25
+ global_data.drop(columns=global_data.columns[0], axis=1, inplace=True)
26
+ global_data.columns = ['raw_eeg', 'label']
27
+ raw_data = global_data['raw_eeg']
28
+ labels_old = global_data['label']
29
+
30
  sampling_rate = 512
31
  notch_freq = 50.0
32
  lowcut, highcut = 0.5, 30.0
33
+
34
  nyquist = (0.5 * sampling_rate)
35
  notch_freq_normalized = notch_freq / nyquist
36
  b_notch, a_notch = signal.iirnotch(notch_freq_normalized, Q=0.05, fs=sampling_rate)
37
+
38
  lowcut_normalized = lowcut / nyquist
39
  highcut_normalized = highcut / nyquist
40
  b_bandpass, a_bandpass = signal.butter(4, [lowcut_normalized, highcut_normalized], btype='band')
41
+
42
  features = []
43
  labels = []
44
 
 
84
  segment_features = {**segment_features, **additional_features}
85
  features.append(segment_features)
86
  labels.append(labels_old[i])
87
+
88
  columns = ['E_alpha', 'E_beta', 'E_theta', 'E_delta', 'alpha_beta_ratio', 'peak_frequency', 'spectral_centroid', 'spectral_slope']
89
  df_features = pd.DataFrame(features, columns=columns)
90
  df_features['label'] = labels
91
+
 
 
 
 
92
  scaler = StandardScaler()
93
+ X_scaled = scaler.fit_transform(df_features.drop('label', axis=1))
94
+ df_scaled = pd.DataFrame(X_scaled, columns=columns)
95
+ df_scaled['label'] = df_features['label']
 
96
 
97
+ processed_data_filename = 'processed_data.csv'
98
+ df_scaled.to_csv(processed_data_filename, index=False)
 
 
99
 
 
 
100
  scaler_filename = 'scaler.pkl'
 
 
 
 
101
  with open(scaler_filename, 'wb') as file:
102
  pickle.dump(scaler, file)
103
 
104
+ return "Data preprocessing complete! Download the processed data and scaler below.", processed_data_filename, scaler_filename
 
105
 
106
  with gr.Blocks() as demo:
107
  file_input = gr.File(label="Upload CSV File")
108
  data_preview = gr.Dataframe(label="Data Preview", interactive=False)
109
  ranges_input = gr.Dataframe(headers=["Start Index", "End Index", "Label"], label="Ranges for Labeling")
 
 
 
 
110
  labeled_data_preview = gr.Dataframe(label="Labeled Data Preview", interactive=False)
111
+ preprocessing_status = gr.Textbox(label="Preprocessing Status")
112
+ processed_data_file = gr.File(label="Download Processed Data")
113
  scaler_file = gr.File(label="Download Scaler")
114
+
115
  file_input.upload(get_data_preview, inputs=file_input, outputs=data_preview)
116
  label_button = gr.Button("Label Data")
117
+ label_button.click(label_data, inputs=[ranges_input], outputs=labeled_data_preview)
118
+ preprocess_button = gr.Button("Preprocess Data")
119
+ preprocess_button.click(preprocess_data, outputs=[preprocessing_status, processed_data_file, scaler_file])
120
 
121
+ demo.launch()