Umang-Bansal commited on
Commit
bd3d9f9
·
verified ·
1 Parent(s): 5c8bfc0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -4
app.py CHANGED
@@ -24,9 +24,12 @@ def label_data(ranges):
24
  start = int(start)
25
  end = int(end)
26
  print(f"Processing range {i}: start={start}, end={end}, label={label}")
27
- if start < 0 or end >= len(global_data):
28
  print(f"Invalid range: start={start}, end={end}, label={label}")
29
  continue
 
 
 
30
  global_data.loc[start:end, 'label'] = label
31
  print("Data after labeling:\n", global_data.tail())
32
  return global_data.tail()
@@ -34,8 +37,7 @@ def label_data(ranges):
34
  def preprocess_data():
35
  global global_data
36
  try:
37
- if 'Unnamed: 0' in global_data.columns:
38
- global_data.drop(columns='Unnamed: 0', axis=1, inplace=True)
39
  global_data.columns = ['raw_eeg', 'label']
40
  raw_data = global_data['raw_eeg']
41
  labels_old = global_data['label']
@@ -120,16 +122,43 @@ def preprocess_data():
120
  except Exception as e:
121
  print(f"An error occurred during preprocessing: {e}")
122
  return f"An error occurred during preprocessing: {e}", None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  with gr.Blocks() as demo:
125
  file_input = gr.File(label="Upload CSV File")
126
  data_preview = gr.Dataframe(label="Data Preview", interactive=False)
127
  ranges_input = gr.Dataframe(headers=["Start Index", "End Index", "Label"], label="Ranges for Labeling")
128
  labeled_data_preview = gr.Dataframe(label="Labeled Data Preview", interactive=False)
 
129
  preprocessing_status = gr.Textbox(label="Preprocessing Status")
130
  processed_data_file = gr.File(label="Download Processed Data")
131
  scaler_file = gr.File(label="Download Scaler")
132
-
133
  file_input.upload(get_data_preview, inputs=file_input, outputs=data_preview)
134
  label_button = gr.Button("Label Data")
135
  label_button.click(label_data, inputs=[ranges_input], outputs=labeled_data_preview)
 
24
  start = int(start)
25
  end = int(end)
26
  print(f"Processing range {i}: start={start}, end={end}, label={label}")
27
+ if start < 0 or start >= len(global_data):
28
  print(f"Invalid range: start={start}, end={end}, label={label}")
29
  continue
30
+ if end >= len(global_data):
31
+ print(f"End index {end} exceeds data length {len(global_data)}. Adjusting to {len(global_data) - 1}.")
32
+ end = len(global_data) - 1
33
  global_data.loc[start:end, 'label'] = label
34
  print("Data after labeling:\n", global_data.tail())
35
  return global_data.tail()
 
37
  def preprocess_data():
38
  global global_data
39
  try:
40
+ global_data.drop(columns=global_data.columns[0], axis=1, inplace=True)
 
41
  global_data.columns = ['raw_eeg', 'label']
42
  raw_data = global_data['raw_eeg']
43
  labels_old = global_data['label']
 
122
  except Exception as e:
123
  print(f"An error occurred during preprocessing: {e}")
124
  return f"An error occurred during preprocessing: {e}", None, None
125
+
126
+ def train_model():
127
+ global global_data
128
+ data = preprocess_data(global_data)
129
+ scaler = StandardScaler()
130
+ X = data.drop('label', axis=1)
131
+ y = data['label']
132
+ X_scaled = scaler.fit_transform(X)
133
+ X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
134
+
135
+ param_grid = {'C': [0.1, 1, 10, 100], 'gamma': ['scale', 'auto', 0.1, 0.01, 0.001, 0.0001], 'kernel': ['rbf']}
136
+ svc = SVC(probability=True)
137
+ grid_search = GridSearchCV(estimator=svc, param_grid=param_grid, cv=5, verbose=2, n_jobs=-1)
138
+ grid_search.fit(X_train, y_train)
139
+
140
+ model = grid_search.best_estimator_
141
+ model_filename = 'model.pkl'
142
+ scaler_filename = 'scaler.pkl'
143
+
144
+ with open(model_filename, 'wb') as file:
145
+ pickle.dump(model, file)
146
+
147
+ with open(scaler_filename, 'wb') as file:
148
+ pickle.dump(scaler, file)
149
+
150
+ return "Training complete! Model and scaler saved.", model_filename, scaler_filename
151
 
152
  with gr.Blocks() as demo:
153
  file_input = gr.File(label="Upload CSV File")
154
  data_preview = gr.Dataframe(label="Data Preview", interactive=False)
155
  ranges_input = gr.Dataframe(headers=["Start Index", "End Index", "Label"], label="Ranges for Labeling")
156
  labeled_data_preview = gr.Dataframe(label="Labeled Data Preview", interactive=False)
157
+
158
  preprocessing_status = gr.Textbox(label="Preprocessing Status")
159
  processed_data_file = gr.File(label="Download Processed Data")
160
  scaler_file = gr.File(label="Download Scaler")
161
+
162
  file_input.upload(get_data_preview, inputs=file_input, outputs=data_preview)
163
  label_button = gr.Button("Label Data")
164
  label_button.click(label_data, inputs=[ranges_input], outputs=labeled_data_preview)