Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
|
|
|
|
3 |
import numpy as np
|
4 |
import plotly.express as px
|
5 |
import plotly.graph_objects as go
|
@@ -1335,69 +1337,78 @@ elif app_mode == "Model Training":
|
|
1335 |
num_trials = st.number_input("Number of Trials", 1, 100, 10, help="Number of trials for hyperparameter search.")
|
1336 |
|
1337 |
# ----- [5. Training & Monitoring] -----
|
1338 |
-
|
1339 |
-
|
1340 |
-
|
1341 |
-
|
1342 |
-
|
1343 |
-
|
1344 |
-
|
1345 |
-
|
1346 |
-
|
1347 |
-
|
1348 |
-
|
1349 |
-
|
1350 |
-
|
1351 |
-
|
1352 |
-
|
1353 |
-
|
1354 |
-
|
1355 |
-
|
1356 |
-
|
1357 |
-
|
1358 |
-
|
1359 |
-
|
1360 |
-
|
1361 |
-
|
1362 |
-
|
1363 |
-
|
1364 |
-
|
1365 |
-
# Add regularization
|
1366 |
-
if l2_reg > 0:
|
1367 |
-
layer['kernel_regularizer'] = tf.keras.regularizers.l2(l2_reg)
|
1368 |
-
|
1369 |
-
model.add(layer_class(**layer))
|
1370 |
-
|
1371 |
-
# Add batch norm after each layer
|
1372 |
-
if batch_norm:
|
1373 |
-
model.add(BatchNormalization())
|
1374 |
-
|
1375 |
-
# Add global dropout
|
1376 |
-
if dropout > 0:
|
1377 |
-
model.add(Dropout(dropout))
|
1378 |
-
|
1379 |
-
model.compile(
|
1380 |
-
optimizer=optimizer,
|
1381 |
-
loss=loss,
|
1382 |
-
metrics=metrics
|
1383 |
-
)
|
1384 |
|
1385 |
-
#
|
1386 |
-
|
1387 |
-
|
1388 |
-
plot_model(model, to_file=tmp.name, show_shapes=True)
|
1389 |
-
st.image(tmp.name)
|
1390 |
|
1391 |
-
|
1392 |
-
st.subheader("Live Training Metrics")
|
1393 |
-
loss_chart = st.empty()
|
1394 |
-
model.fit(X_train, y_train,
|
1395 |
-
epochs=10,
|
1396 |
-
validation_data=(X_val, y_val),
|
1397 |
-
callbacks=[LiveMetrics()])
|
1398 |
|
1399 |
-
|
1400 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1401 |
|
1402 |
|
1403 |
# ----- [6. Export & Deployment] -----
|
@@ -1501,13 +1512,17 @@ if app_mode == "Predictions":
|
|
1501 |
|
1502 |
if problem_type == "Classification":
|
1503 |
st.metric("Predicted Class", str(prediction))
|
|
|
|
|
|
|
|
|
1504 |
else:
|
1505 |
st.metric("Predicted Value", f"{prediction:.2f}")
|
1506 |
|
1507 |
# 8. Feature Explanation (SHAP)
|
1508 |
enhance_section_title("Insights", "π‘")
|
1509 |
|
1510 |
-
if problem_type == "Classification":
|
1511 |
explainer = shap.TreeExplainer(model)
|
1512 |
shap_values = explainer.shap_values(scaled_input)
|
1513 |
fig = shap.force_plot(explainer.expected_value[1], shap_values[1], input_df, matplotlib=False, link="logit")
|
@@ -1534,4 +1549,4 @@ if app_mode == "Predictions":
|
|
1534 |
st.warning(f"Could not calculate permutation feature importance: {e}")
|
1535 |
|
1536 |
except Exception as e:
|
1537 |
-
st.error(f"Prediction failed: {str(e)}")
|
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
+
import seaborn as sns
|
4 |
+
from scipy.stats import boxcox
|
5 |
import numpy as np
|
6 |
import plotly.express as px
|
7 |
import plotly.graph_objects as go
|
|
|
1337 |
num_trials = st.number_input("Number of Trials", 1, 100, 10, help="Number of trials for hyperparameter search.")
|
1338 |
|
1339 |
# ----- [5. Training & Monitoring] -----
|
1340 |
+
st.subheader("π― Training Configuration")
|
1341 |
+
|
1342 |
+
import shap # Ensure SHAP is installed: pip install shap
|
1343 |
+
|
1344 |
+
class LiveMetrics(Callback):
|
1345 |
+
def on_epoch_end(self, epoch, logs=None):
|
1346 |
+
if 'metrics' not in st.session_state:
|
1347 |
+
st.session_state.metrics = []
|
1348 |
+
st.session_state.metrics.append(logs)
|
1349 |
+
self.update_chart()
|
1350 |
+
|
1351 |
+
def update_chart(self):
|
1352 |
+
df = pd.DataFrame(st.session_state.metrics)
|
1353 |
+
fig = px.line(df, y=['loss', 'val_loss'], title="Training Progress")
|
1354 |
+
loss_chart.plotly_chart(fig)
|
1355 |
+
|
1356 |
+
if st.button("π Start Training"):
|
1357 |
+
try:
|
1358 |
+
model = tf.keras.Sequential()
|
1359 |
+
|
1360 |
+
# Add layers with regularization
|
1361 |
+
for layer in st.session_state.layers:
|
1362 |
+
layer_class = {
|
1363 |
+
"Dense": Dense,
|
1364 |
+
"Conv2D": Conv2D,
|
1365 |
+
"LSTM": LSTM
|
1366 |
+
}[layer['type']]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1367 |
|
1368 |
+
# Add regularization
|
1369 |
+
if l2_reg > 0:
|
1370 |
+
layer['kernel_regularizer'] = tf.keras.regularizers.l2(l2_reg)
|
|
|
|
|
1371 |
|
1372 |
+
model.add(layer_class(**layer))
|
|
|
|
|
|
|
|
|
|
|
|
|
1373 |
|
1374 |
+
# Add batch norm after each layer
|
1375 |
+
if batch_norm:
|
1376 |
+
model.add(BatchNormalization())
|
1377 |
+
|
1378 |
+
# Add global dropout
|
1379 |
+
if dropout > 0:
|
1380 |
+
model.add(Dropout(dropout))
|
1381 |
+
|
1382 |
+
model.compile(
|
1383 |
+
optimizer=optimizer,
|
1384 |
+
loss=loss,
|
1385 |
+
metrics=metrics
|
1386 |
+
)
|
1387 |
+
|
1388 |
+
# Show model summary
|
1389 |
+
st.subheader("Model Architecture")
|
1390 |
+
with tempfile.NamedTemporaryFile(suffix='.png') as tmp:
|
1391 |
+
plot_model(model, to_file=tmp.name, show_shapes=True)
|
1392 |
+
st.image(tmp.name)
|
1393 |
+
|
1394 |
+
# Start training
|
1395 |
+
st.subheader("Live Training Metrics")
|
1396 |
+
loss_chart = st.empty()
|
1397 |
+
model.fit(X_train, y_train,
|
1398 |
+
epochs=10,
|
1399 |
+
validation_data=(X_val, y_val),
|
1400 |
+
callbacks=[LiveMetrics()])
|
1401 |
+
|
1402 |
+
# SHAP explanations
|
1403 |
+
st.subheader("SHAP Explanations")
|
1404 |
+
explainer = shap.KernelExplainer(model.predict, X_train[:100])
|
1405 |
+
shap_values = explainer.shap_values(X_train[:100])
|
1406 |
+
shap.summary_plot(shap_values, X_train[:100], plot_type="bar")
|
1407 |
+
st.pyplot(bbox_inches='tight')
|
1408 |
+
|
1409 |
+
except Exception as e:
|
1410 |
+
st.error(f"Training failed: {str(e)}")
|
1411 |
+
|
1412 |
|
1413 |
|
1414 |
# ----- [6. Export & Deployment] -----
|
|
|
1512 |
|
1513 |
if problem_type == "Classification":
|
1514 |
st.metric("Predicted Class", str(prediction))
|
1515 |
+
elif problem_type == "Binary Classification":
|
1516 |
+
st.metric("Predicted Probability", f"{prediction:.2f}")
|
1517 |
+
binary_class = "Yes" if prediction >= 0.5 else "No"
|
1518 |
+
st.metric("Binary Class", binary_class)
|
1519 |
else:
|
1520 |
st.metric("Predicted Value", f"{prediction:.2f}")
|
1521 |
|
1522 |
# 8. Feature Explanation (SHAP)
|
1523 |
enhance_section_title("Insights", "π‘")
|
1524 |
|
1525 |
+
if problem_type == "Classification" or problem_type == "Binary Classification":
|
1526 |
explainer = shap.TreeExplainer(model)
|
1527 |
shap_values = explainer.shap_values(scaled_input)
|
1528 |
fig = shap.force_plot(explainer.expected_value[1], shap_values[1], input_df, matplotlib=False, link="logit")
|
|
|
1549 |
st.warning(f"Could not calculate permutation feature importance: {e}")
|
1550 |
|
1551 |
except Exception as e:
|
1552 |
+
st.error(f"Prediction failed: {str(e)}")
|