CosmickVisions commited on
Commit
e4d7511
Β·
verified Β·
1 Parent(s): 8453100

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -62
app.py CHANGED
@@ -1,5 +1,7 @@
1
  import streamlit as st
2
  import pandas as pd
 
 
3
  import numpy as np
4
  import plotly.express as px
5
  import plotly.graph_objects as go
@@ -1335,69 +1337,78 @@ elif app_mode == "Model Training":
1335
  num_trials = st.number_input("Number of Trials", 1, 100, 10, help="Number of trials for hyperparameter search.")
1336
 
1337
  # ----- [5. Training & Monitoring] -----
1338
- st.subheader("🎯 Training Configuration")
1339
-
1340
- class LiveMetrics(Callback):
1341
- def on_epoch_end(self, epoch, logs=None):
1342
- if 'metrics' not in st.session_state:
1343
- st.session_state.metrics = []
1344
- st.session_state.metrics.append(logs)
1345
- self.update_chart()
1346
-
1347
- def update_chart(self):
1348
- df = pd.DataFrame(st.session_state.metrics)
1349
- fig = px.line(df, y=['loss', 'val_loss'],
1350
- title="Training Progress")
1351
- loss_chart.plotly_chart(fig)
1352
-
1353
- if st.button("πŸš€ Start Training"):
1354
- try:
1355
- model = tf.keras.Sequential()
1356
-
1357
- # Add layers with regularization
1358
- for layer in st.session_state.layers:
1359
- layer_class = {
1360
- "Dense": Dense,
1361
- "Conv2D": Conv2D,
1362
- "LSTM": LSTM
1363
- }[layer['type']]
1364
-
1365
- # Add regularization
1366
- if l2_reg > 0:
1367
- layer['kernel_regularizer'] = tf.keras.regularizers.l2(l2_reg)
1368
-
1369
- model.add(layer_class(**layer))
1370
-
1371
- # Add batch norm after each layer
1372
- if batch_norm:
1373
- model.add(BatchNormalization())
1374
-
1375
- # Add global dropout
1376
- if dropout > 0:
1377
- model.add(Dropout(dropout))
1378
-
1379
- model.compile(
1380
- optimizer=optimizer,
1381
- loss=loss,
1382
- metrics=metrics
1383
- )
1384
 
1385
- # Show model summary
1386
- st.subheader("Model Architecture")
1387
- with tempfile.NamedTemporaryFile(suffix='.png') as tmp:
1388
- plot_model(model, to_file=tmp.name, show_shapes=True)
1389
- st.image(tmp.name)
1390
 
1391
- # Start training
1392
- st.subheader("Live Training Metrics")
1393
- loss_chart = st.empty()
1394
- model.fit(X_train, y_train,
1395
- epochs=10,
1396
- validation_data=(X_val, y_val),
1397
- callbacks=[LiveMetrics()])
1398
 
1399
- except Exception as e:
1400
- st.error(f"Training failed: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1401
 
1402
 
1403
  # ----- [6. Export & Deployment] -----
@@ -1501,13 +1512,17 @@ if app_mode == "Predictions":
1501
 
1502
  if problem_type == "Classification":
1503
  st.metric("Predicted Class", str(prediction))
 
 
 
 
1504
  else:
1505
  st.metric("Predicted Value", f"{prediction:.2f}")
1506
 
1507
  # 8. Feature Explanation (SHAP)
1508
  enhance_section_title("Insights", "πŸ’‘")
1509
 
1510
- if problem_type == "Classification":
1511
  explainer = shap.TreeExplainer(model)
1512
  shap_values = explainer.shap_values(scaled_input)
1513
  fig = shap.force_plot(explainer.expected_value[1], shap_values[1], input_df, matplotlib=False, link="logit")
@@ -1534,4 +1549,4 @@ if app_mode == "Predictions":
1534
  st.warning(f"Could not calculate permutation feature importance: {e}")
1535
 
1536
  except Exception as e:
1537
- st.error(f"Prediction failed: {str(e)}")
 
1
  import streamlit as st
2
  import pandas as pd
3
+ import seaborn as sns
4
+ from scipy.stats import boxcox
5
  import numpy as np
6
  import plotly.express as px
7
  import plotly.graph_objects as go
 
1337
  num_trials = st.number_input("Number of Trials", 1, 100, 10, help="Number of trials for hyperparameter search.")
1338
 
1339
  # ----- [5. Training & Monitoring] -----
1340
+ st.subheader("🎯 Training Configuration")
1341
+
1342
+ import shap # Ensure SHAP is installed: pip install shap
1343
+
1344
+ class LiveMetrics(Callback):
1345
+ def on_epoch_end(self, epoch, logs=None):
1346
+ if 'metrics' not in st.session_state:
1347
+ st.session_state.metrics = []
1348
+ st.session_state.metrics.append(logs)
1349
+ self.update_chart()
1350
+
1351
+ def update_chart(self):
1352
+ df = pd.DataFrame(st.session_state.metrics)
1353
+ fig = px.line(df, y=['loss', 'val_loss'], title="Training Progress")
1354
+ loss_chart.plotly_chart(fig)
1355
+
1356
+ if st.button("πŸš€ Start Training"):
1357
+ try:
1358
+ model = tf.keras.Sequential()
1359
+
1360
+ # Add layers with regularization
1361
+ for layer in st.session_state.layers:
1362
+ layer_class = {
1363
+ "Dense": Dense,
1364
+ "Conv2D": Conv2D,
1365
+ "LSTM": LSTM
1366
+ }[layer['type']]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1367
 
1368
+ # Add regularization
1369
+ if l2_reg > 0:
1370
+ layer['kernel_regularizer'] = tf.keras.regularizers.l2(l2_reg)
 
 
1371
 
1372
+ model.add(layer_class(**layer))
 
 
 
 
 
 
1373
 
1374
+ # Add batch norm after each layer
1375
+ if batch_norm:
1376
+ model.add(BatchNormalization())
1377
+
1378
+ # Add global dropout
1379
+ if dropout > 0:
1380
+ model.add(Dropout(dropout))
1381
+
1382
+ model.compile(
1383
+ optimizer=optimizer,
1384
+ loss=loss,
1385
+ metrics=metrics
1386
+ )
1387
+
1388
+ # Show model summary
1389
+ st.subheader("Model Architecture")
1390
+ with tempfile.NamedTemporaryFile(suffix='.png') as tmp:
1391
+ plot_model(model, to_file=tmp.name, show_shapes=True)
1392
+ st.image(tmp.name)
1393
+
1394
+ # Start training
1395
+ st.subheader("Live Training Metrics")
1396
+ loss_chart = st.empty()
1397
+ model.fit(X_train, y_train,
1398
+ epochs=10,
1399
+ validation_data=(X_val, y_val),
1400
+ callbacks=[LiveMetrics()])
1401
+
1402
+ # SHAP explanations
1403
+ st.subheader("SHAP Explanations")
1404
+ explainer = shap.KernelExplainer(model.predict, X_train[:100])
1405
+ shap_values = explainer.shap_values(X_train[:100])
1406
+ shap.summary_plot(shap_values, X_train[:100], plot_type="bar")
1407
+ st.pyplot(bbox_inches='tight')
1408
+
1409
+ except Exception as e:
1410
+ st.error(f"Training failed: {str(e)}")
1411
+
1412
 
1413
 
1414
  # ----- [6. Export & Deployment] -----
 
1512
 
1513
  if problem_type == "Classification":
1514
  st.metric("Predicted Class", str(prediction))
1515
+ elif problem_type == "Binary Classification":
1516
+ st.metric("Predicted Probability", f"{prediction:.2f}")
1517
+ binary_class = "Yes" if prediction >= 0.5 else "No"
1518
+ st.metric("Binary Class", binary_class)
1519
  else:
1520
  st.metric("Predicted Value", f"{prediction:.2f}")
1521
 
1522
  # 8. Feature Explanation (SHAP)
1523
  enhance_section_title("Insights", "πŸ’‘")
1524
 
1525
+ if problem_type == "Classification" or problem_type == "Binary Classification":
1526
  explainer = shap.TreeExplainer(model)
1527
  shap_values = explainer.shap_values(scaled_input)
1528
  fig = shap.force_plot(explainer.expected_value[1], shap_values[1], input_df, matplotlib=False, link="logit")
 
1549
  st.warning(f"Could not calculate permutation feature importance: {e}")
1550
 
1551
  except Exception as e:
1552
+ st.error(f"Prediction failed: {str(e)}")