Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -23,6 +23,16 @@ from ydata_profiling import ProfileReport
|
|
23 |
from streamlit_pandas_profiling import st_profile_report
|
24 |
import joblib # For saving and loading models
|
25 |
import os # For file directory
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
import shap
|
27 |
from datetime import datetime
|
28 |
from stqdm import stqdm
|
@@ -941,217 +951,224 @@ elif app_mode == "EDA":
|
|
941 |
except Exception as e:
|
942 |
st.error(f"Could not generate analysis report. Ensure pandas-profiling is installed correctly.")
|
943 |
|
944 |
-
# Streamlit App
|
945 |
elif app_mode == "Model Training":
|
946 |
-
st.title("
|
947 |
-
|
948 |
-
#
|
949 |
-
|
950 |
-
st.
|
951 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
952 |
col1, col2 = st.columns(2)
|
953 |
with col1:
|
954 |
-
|
955 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
956 |
st.experimental_rerun()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
957 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
958 |
with col2:
|
959 |
-
|
960 |
-
|
961 |
-
|
962 |
-
|
963 |
-
|
964 |
-
|
965 |
-
|
966 |
-
|
967 |
-
|
968 |
-
|
969 |
-
|
970 |
-
|
971 |
-
|
|
|
|
|
|
|
|
|
972 |
|
973 |
-
#
|
974 |
-
|
975 |
-
|
976 |
-
|
977 |
-
|
978 |
-
|
979 |
-
|
980 |
-
|
981 |
-
|
982 |
-
|
983 |
-
|
984 |
-
|
985 |
-
|
986 |
-
|
987 |
-
|
988 |
-
|
989 |
-
|
990 |
-
model_type = st.selectbox("Select Multiclass Model", ["Logistic Regression", "Support Vector Machine", "Random Forest"]) # Added SVM and Logistic Regression
|
991 |
-
else:
|
992 |
-
model_type = None # handle this
|
993 |
-
|
994 |
-
# Hyperparameter Configuration - Dynamic based on Model Type
|
995 |
-
st.subheader("Hyperparameter Configuration")
|
996 |
-
model_params = {}
|
997 |
-
|
998 |
-
if model_type == "Neural Network": # Add options for NN parameters
|
999 |
-
hidden_layers = st.text_input("Hidden Layer Sizes (e.g., 50,50 for two layers of 50 neurons)", "50,50")
|
1000 |
-
activation = st.selectbox("Activation Function", ["relu", "tanh", "logistic"])
|
1001 |
-
alpha = st.number_input("L2 Regularization (Alpha)", value=0.0001)
|
1002 |
-
|
1003 |
-
# Process the hidden layers string to a tuple of ints
|
1004 |
try:
|
1005 |
-
|
1006 |
-
|
1007 |
-
|
1008 |
-
|
1009 |
-
|
1010 |
-
|
1011 |
-
|
1012 |
-
|
1013 |
-
|
1014 |
-
|
1015 |
-
|
1016 |
-
|
1017 |
-
|
1018 |
-
|
1019 |
-
|
1020 |
-
|
1021 |
-
|
1022 |
-
|
1023 |
-
|
1024 |
-
|
1025 |
-
|
1026 |
-
|
1027 |
-
|
1028 |
-
|
1029 |
-
|
1030 |
-
|
1031 |
-
|
1032 |
-
|
1033 |
-
|
1034 |
-
|
1035 |
-
|
1036 |
-
|
1037 |
-
|
1038 |
-
|
1039 |
-
|
1040 |
-
|
1041 |
-
|
1042 |
-
|
1043 |
-
|
1044 |
-
|
1045 |
-
|
1046 |
-
|
1047 |
-
|
1048 |
-
|
1049 |
-
|
1050 |
-
)
|
1051 |
-
|
1052 |
-
if model: # Only proceed if training was successful
|
1053 |
-
st.success("Model trained successfully!")
|
1054 |
-
|
1055 |
-
# Display Metrics
|
1056 |
-
st.subheader("Model Evaluation Metrics")
|
1057 |
-
if problem_type in ["Classification", "Multiclass"]: # Combined here
|
1058 |
-
st.metric("Accuracy", f"{metrics['accuracy']:.2%}")
|
1059 |
-
|
1060 |
-
# Confusion Matrix Visualization
|
1061 |
-
st.subheader("Confusion Matrix")
|
1062 |
-
cm = metrics['confusion_matrix']
|
1063 |
-
class_names = [str(i) for i in np.unique(df[target])] # Get original class names
|
1064 |
-
fig_cm = px.imshow(cm,
|
1065 |
-
labels=dict(x="Predicted", y="Actual"),
|
1066 |
-
x=class_names,
|
1067 |
-
y=class_names,
|
1068 |
-
color_continuous_scale="Viridis")
|
1069 |
-
st.plotly_chart(fig_cm, use_container_width=True)
|
1070 |
-
|
1071 |
-
# Classification Report
|
1072 |
-
st.subheader("Classification Report")
|
1073 |
-
report = metrics['classification_report']
|
1074 |
-
report_df = pd.DataFrame(report).transpose()
|
1075 |
-
st.dataframe(report_df)
|
1076 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1077 |
else:
|
1078 |
-
|
1079 |
-
|
1080 |
-
|
1081 |
-
|
1082 |
-
|
1083 |
-
|
1084 |
-
|
1085 |
-
try:
|
1086 |
-
fig_importance = px.bar(
|
1087 |
-
x=importance,
|
1088 |
-
y=column_order, # Use stored column order
|
1089 |
-
orientation='h',
|
1090 |
-
title="Feature Importance"
|
1091 |
)
|
1092 |
-
st.plotly_chart(fig_importance, use_container_width=True)
|
1093 |
-
except Exception as e:
|
1094 |
-
st.warning(f"Could not display feature importance: {e}")
|
1095 |
-
|
1096 |
-
# Explainable AI (Placeholder)
|
1097 |
-
st.subheader("Explainable AI (XAI)")
|
1098 |
-
st.write("Future implementation will include model explanations using techniques like SHAP or LIME.") #To be implemented
|
1099 |
-
if st.checkbox("Show a random model explanation (example)"): #Example of a feature, to be implemented
|
1100 |
-
st.write("This feature is important because...")
|
1101 |
-
|
1102 |
-
# Save Model
|
1103 |
-
st.subheader("Save Model")
|
1104 |
-
model_name = st.text_input("Enter model name (without extension)", "my_model")
|
1105 |
-
if st.button("Save Model"):
|
1106 |
-
try:
|
1107 |
-
model_path = f"{model_name}.joblib"
|
1108 |
-
joblib.dump({
|
1109 |
-
'model': model,
|
1110 |
-
'scaler': scaler,
|
1111 |
-
'label_encoder': label_encoder,
|
1112 |
-
'imputer_numerical': imputer_numerical,
|
1113 |
-
'column_order': column_order,
|
1114 |
-
'features': features,
|
1115 |
-
'target': target,
|
1116 |
-
'problem_type': problem_type,
|
1117 |
-
'model_type': model_type,
|
1118 |
-
'model_params': model_params,
|
1119 |
-
'X_train': X_train, # Store X_train
|
1120 |
-
'y_train': y_train # Store y_train
|
1121 |
-
}, model_path)
|
1122 |
-
st.success(f"Model saved as {model_path}")
|
1123 |
-
except Exception as e:
|
1124 |
-
st.error(f"Error saving model: {e}")
|
1125 |
-
|
1126 |
-
# Model Validation Section
|
1127 |
-
st.header("Model Validation")
|
1128 |
-
model_path_validate = st.text_input("Enter path to saved model for validation", "my_model.joblib")
|
1129 |
-
if st.button("Validate Model"):
|
1130 |
-
if not os.path.exists(model_path_validate):
|
1131 |
-
st.error("Model file not found.")
|
1132 |
-
else:
|
1133 |
-
validation_metrics, problem_type = validate_model(model_path_validate, df.copy(), target, features, test_size) #Pass a copy of the dataframe
|
1134 |
-
if validation_metrics:
|
1135 |
-
st.subheader("Validation Metrics")
|
1136 |
-
if problem_type in ["Classification", "Multiclass"]: #Combined here
|
1137 |
-
st.metric("Accuracy", f"{validation_metrics['accuracy']:.2%}")
|
1138 |
-
st.subheader("Confusion Matrix")
|
1139 |
-
cm = validation_metrics['confusion_matrix']
|
1140 |
-
class_names = [str(i) for i in np.unique(df[target])] #Get original class names
|
1141 |
-
fig_cm = px.imshow(cm,
|
1142 |
-
labels=dict(x="Predicted", y="Actual"),
|
1143 |
-
x=class_names,
|
1144 |
-
y=class_names,
|
1145 |
-
color_continuous_scale="Viridis")
|
1146 |
-
st.plotly_chart(fig_cm, use_container_width=True)
|
1147 |
-
st.subheader("Classification Report")
|
1148 |
-
report = validation_metrics['classification_report']
|
1149 |
-
report_df = pd.DataFrame(report).transpose()
|
1150 |
-
st.dataframe(report_df)
|
1151 |
-
|
1152 |
-
else:
|
1153 |
-
st.metric("MSE", f"{validation_metrics['mse']:.2f}")
|
1154 |
-
st.metric("R2", f"{validation_metrics['r2']:.2f}")
|
1155 |
|
1156 |
# Predictions Section (Fixed)
|
1157 |
if app_mode == "Predictions":
|
|
|
23 |
from streamlit_pandas_profiling import st_profile_report
|
24 |
import joblib # For saving and loading models
|
25 |
import os # For file directory
|
26 |
+
# Advanced
|
27 |
+
from transformers import TFBertModel
|
28 |
+
import tensorflow as tf
|
29 |
+
from tensorflow.keras.models import Sequential
|
30 |
+
from tensorflow.keras.layers import Dense, Conv2D, LSTM, Embedding, Dropout, Flatten, MaxPooling2D, BatchNormalization
|
31 |
+
from tensorflow.keras.applications import MobileNetV2, ResNet50
|
32 |
+
from tensorflow.keras.utils import plot_model
|
33 |
+
from tensorflow.keras.callbacks import Callback
|
34 |
+
import tf2onnx
|
35 |
+
import onnx
|
36 |
import shap
|
37 |
from datetime import datetime
|
38 |
from stqdm import stqdm
|
|
|
951 |
except Exception as e:
|
952 |
st.error(f"Could not generate analysis report. Ensure pandas-profiling is installed correctly.")
|
953 |
|
|
|
954 |
elif app_mode == "Model Training":
|
955 |
+
st.title("🧠 Advanced Model Architect")
|
956 |
+
|
957 |
+
# ----- [1. Preset Selection] -----
|
958 |
+
# Sidebar control for loading a predefined layer stack into the builder.
with st.sidebar.expander("🚀 Quick Start", expanded=True):
    presets = st.selectbox("Load Preset", [
        "None",
        "CNN-MNIST",
        "LSTM-Text",
        "ResNet-Lite",
        "Transformer-NLP"
    ])

# Layer templates for the presets that have one. "ResNet-Lite" and
# "Transformer-NLP" are offered in the selectbox but have no template here,
# so selecting them leaves the current layer list untouched.
preset_layers = {
    "CNN-MNIST": [
        {"type": "Conv2D", "filters": 32, "kernel_size": 3},
        {"type": "MaxPooling2D", "pool_size": 2},
        {"type": "Flatten"},
        {"type": "Dense", "units": 10},
    ],
    "LSTM-Text": [
        {"type": "Embedding", "input_dim": 10000, "output_dim": 128},
        {"type": "LSTM", "units": 64},
        {"type": "Dense", "units": 1, "activation": "sigmoid"},
    ],
}

# Apply a preset only when the selection actually changes. Streamlit re-executes
# the whole script on every interaction, so an unconditional assign-and-rerun
# would loop forever and would also wipe any layers the user added or removed.
if presets != st.session_state.get("applied_preset"):
    st.session_state.applied_preset = presets
    if presets in preset_layers:
        # Copy each spec so later edits to the session list never touch the template.
        st.session_state.layers = [dict(spec) for spec in preset_layers[presets]]
        st.experimental_rerun()
|
982 |
+
|
983 |
+
# ----- [2. Base Model & Transfer Learning] -----
|
984 |
+
# Optional pretrained backbone selection.
with st.expander("🏗️ Transfer Learning", expanded=False):
    col1, col2 = st.columns(2)
    with col1:
        base_model = st.selectbox("Base Model", [
            "None",
            "MobileNetV2",
            "ResNet50",
            "BERT"
        ])

    with col2:
        if base_model != "None":
            freeze_layers = st.checkbox("Freeze Base Layers", True)
            custom_input = st.checkbox("Custom Input Shape", False)

            # NOTE(review): only MobileNetV2 is actually loaded in this section;
            # "ResNet50" and "BERT" are selectable but have no loader here.
            if base_model == "MobileNetV2":
                model = tf.keras.applications.MobileNetV2(
                    include_top=False,
                    weights='imagenet',
                    input_shape=(224, 224, 3) if custom_input else None
                )
                # Honour the "Freeze Base Layers" checkbox; previously its value
                # was read but never applied to the loaded network.
                model.trainable = not freeze_layers
                st.info(f"Loaded {base_model} with {len(model.layers)} layers")
|
1006 |
+
|
1007 |
+
# ----- [3. Layer Configuration] -----
|
1008 |
+
st.subheader("🏗️ Network Architecture")
|
1009 |
+
|
1010 |
+
# Dynamic layer builder
|
1011 |
+
# Layer kinds offered by the "add layer" selector.
layer_types = ["Dense", "Conv2D", "LSTM", "Dropout", "BatchNorm", "Flatten"]

# The layer stack lives in session state so it survives Streamlit reruns.
if 'layers' not in st.session_state:
    st.session_state.layers = []

# One row per configured layer: position label, spec summary, remove button.
for i, layer in enumerate(st.session_state.layers):
    label_col, spec_col, action_col = st.columns([1, 3, 2])
    layer_options = {k: v for k, v in layer.items() if k != 'type'}

    label_col.markdown(f"**Layer {i+1}**")
    spec_col.code(f"{layer['type']}: {layer_options}")

    remove_clicked = action_col.button(f"❌ Remove {i+1}", key=f"remove_{i}")
    if remove_clicked:
        st.session_state.layers.pop(i)
        # Rerun immediately so the list is redrawn without the removed row.
        st.experimental_rerun()
|
1029 |
+
|
1030 |
+
# Add new layer controls
|
1031 |
+
# Form for appending one layer spec to st.session_state.layers.
with st.expander("➕ Add New Layer", expanded=True):
    new_layer_type = st.selectbox("Layer Type", layer_types)
    new_layer_params = {}

    if new_layer_type == "Dense":
        new_layer_params["units"] = st.number_input("Units", 1, 1024, 128)
        new_layer_params["activation"] = st.selectbox(
            "Activation", ["relu", "sigmoid", "tanh"]
        )

    elif new_layer_type == "Conv2D":
        new_layer_params["filters"] = st.number_input("Filters", 1, 256, 32)
        new_layer_params["kernel_size"] = st.number_input("Kernel Size", 1, 9, 3)

    elif new_layer_type == "LSTM":
        # The Keras LSTM constructor requires `units`; without this input an
        # LSTM entry could be added that cannot be instantiated.
        new_layer_params["units"] = st.number_input("Units", 1, 1024, 64)

    elif new_layer_type == "Dropout":
        # The Keras Dropout constructor requires `rate`.
        new_layer_params["rate"] = st.number_input("Rate", 0.0, 0.9, 0.2)

    # "BatchNorm" and "Flatten" take no required arguments, so no inputs are shown.

    if st.button("Add Layer"):
        st.session_state.layers.append({
            "type": new_layer_type,
            **new_layer_params
        })
        # Rerun so the layer list rendered above picks up the new entry.
        st.experimental_rerun()
|
1051 |
|
1052 |
+
# ----- [4. Regularization & Advanced Options] -----
|
1053 |
+
# Regularization and optimizer settings read by the training step.
with st.expander("⚙️ Advanced Configuration", expanded=False):
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Regularization")
        # Explicit step/format: the default float format ("%0.2f") would show
        # the 0.001 default as "0.00", and the default 0.01 step is too coarse
        # for a 0.0-0.1 range.
        l2_reg = st.number_input(
            "L2 Regularization", 0.0, 0.1, 0.001,
            step=0.001, format="%.4f"
        )
        dropout = st.number_input("Global Dropout", 0.0, 0.5, 0.2)
        batch_norm = st.checkbox("Batch Normalization")

    with col2:
        st.subheader("Optimization")
        optimizer = st.selectbox("Optimizer", [
            "adam", "sgd", "rmsprop",
            "nadam", "adamax"
        ])

        loss = st.selectbox("Loss Function", [
            "categorical_crossentropy",
            "binary_crossentropy",
            "mse",
            "mae"
        ])

        # Passed straight to model.compile(metrics=...).
        metrics = st.multiselect("Metrics", [
            "accuracy", "precision",
            "recall", "auc"
        ])
|
1080 |
|
1081 |
+
# ----- [5. Training & Monitoring] -----
|
1082 |
+
st.subheader("🎯 Training Configuration")
|
1083 |
+
|
1084 |
+
class LiveMetrics(Callback):
    """Keras callback that records per-epoch logs and redraws a loss chart.

    Each epoch's ``logs`` dict is appended to ``st.session_state.metrics`` and
    the accumulated history is plotted into a Streamlit placeholder.

    Args:
        chart: Optional Streamlit placeholder (e.g. the result of
            ``st.empty()``) to draw into. When omitted, the module-level name
            ``loss_chart`` is looked up at draw time, so ``LiveMetrics()``
            keeps working as long as that name is bound before training.
    """

    def __init__(self, chart=None):
        super().__init__()
        self._chart = chart

    def on_epoch_end(self, epoch, logs=None):
        """Store this epoch's logs and refresh the chart."""
        if 'metrics' not in st.session_state:
            st.session_state.metrics = []
        # `logs` defaults to None; store an (empty) dict copy instead so the
        # history always converts cleanly to a DataFrame.
        st.session_state.metrics.append(dict(logs or {}))
        self.update_chart()

    def update_chart(self):
        """Plot whichever of loss / val_loss have been recorded so far."""
        history = pd.DataFrame(st.session_state.metrics)
        # val_loss only exists when validation data was supplied to fit();
        # plotting a missing column would raise.
        curves = [name for name in ('loss', 'val_loss') if name in history.columns]
        if not curves:
            return
        fig = px.line(history, y=curves,
                      title="Training Progress")
        placeholder = self._chart if self._chart is not None else loss_chart
        placeholder.plotly_chart(fig)
|
1096 |
+
|
1097 |
+
# Build, compile and fit a Sequential model from the configured layer stack.
if st.button("🚀 Start Training"):
    # Local import: tempfile is not among the imports visible in this file.
    import tempfile

    # Constructors for every layer kind the presets and the layer builder can emit.
    layer_classes = {
        "Dense": Dense,
        "Conv2D": Conv2D,
        "LSTM": LSTM,
        "Dropout": Dropout,
        "BatchNorm": BatchNormalization,
        "Flatten": Flatten,
        "MaxPooling2D": MaxPooling2D,
        "Embedding": Embedding,
    }
    # Only these layer kinds are given a `kernel_regularizer`.
    regularizable = {"Dense", "Conv2D", "LSTM"}

    try:
        model = tf.keras.Sequential()

        for layer in st.session_state.layers:
            layer_kind = layer['type']
            # Build constructor kwargs on a copy: 'type' is our own bookkeeping
            # key (not a Keras argument), and the regularizer object must not
            # be written back into the session-state spec.
            layer_kwargs = {k: v for k, v in layer.items() if k != 'type'}

            if l2_reg > 0 and layer_kind in regularizable:
                layer_kwargs['kernel_regularizer'] = tf.keras.regularizers.l2(l2_reg)

            if layer_kind == "Dropout":
                # Dropout needs a rate; fall back to the global setting.
                layer_kwargs.setdefault('rate', dropout)

            model.add(layer_classes[layer_kind](**layer_kwargs))

            # Add batch norm after each layer
            if batch_norm:
                model.add(BatchNormalization())

        # Add global dropout
        # NOTE(review): the original indentation is lost in this view; dropout is
        # applied once after the layer stack here - confirm that is the intent.
        model.add(Dropout(dropout))

        model.compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics
        )

        # Show model diagram. Rendering is isolated so that a diagram failure
        # does not abort training.
        st.subheader("Model Architecture")
        try:
            with tempfile.TemporaryDirectory() as plot_dir:
                plot_path = os.path.join(plot_dir, "model.png")
                plot_model(model, to_file=plot_path, show_shapes=True)
                st.image(plot_path)
        except Exception as plot_error:
            st.warning(f"Could not render model diagram: {plot_error}")

        # Start training
        st.subheader("Live Training Metrics")
        # Start each run with a clean history so the chart does not mix runs.
        st.session_state.metrics = []
        loss_chart = st.empty()
        # NOTE(review): X_train / y_train / X_val / y_val are not defined in the
        # visible part of the file - confirm they are bound before this point.
        model.fit(X_train, y_train,
                  epochs=10,
                  validation_data=(X_val, y_val),
                  callbacks=[LiveMetrics()])

        # Keep the fitted model across Streamlit reruns; every widget
        # interaction re-executes the script, so the local `model` is lost.
        st.session_state.trained_model = model

    except Exception as e:
        st.error(f"Training failed: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1144 |
|
1145 |
+
# ----- [6. Export & Deployment] -----
|
1146 |
+
st.subheader("💾 Export Model")
|
1147 |
+
|
1148 |
+
# Let the user pick a serialization format and download the model in it.
export_format = st.radio("Format", [
    "TensorFlow SavedModel",
    "HDF5",
    "ONNX"
])

if st.button("Export"):
    # Local imports: neither module is among the imports visible in this file.
    import shutil
    import tempfile

    # Prefer a model kept in session state under 'trained_model', if present;
    # otherwise fall back to a `model` variable bound earlier in this script
    # run. Streamlit re-executes the script on every click, so `model` may not
    # exist at all here.
    export_model = st.session_state.get("trained_model")
    if export_model is None:
        try:
            export_model = model
        except NameError:
            export_model = None

    if export_model is None:
        st.warning("No trained model available to export. Train a model first.")
    else:
        # Write into a throwaway directory and read back the exact file that
        # was written; the directory is removed when the block exits.
        with tempfile.TemporaryDirectory() as export_dir:
            if export_format == "HDF5":
                download_name = "model.h5"
                artifact_path = os.path.join(export_dir, download_name)
                export_model.save(artifact_path)
            elif export_format == "ONNX":
                import tf2onnx
                download_name = "model.onnx"
                artifact_path = os.path.join(export_dir, download_name)
                model_proto, _ = tf2onnx.convert.from_keras(export_model)
                with open(artifact_path, 'wb') as f:
                    f.write(model_proto.SerializeToString())
            else:
                # A SavedModel is a directory tree, so ship it as a zip archive.
                download_name = "model.zip"
                saved_model_dir = os.path.join(export_dir, "saved_model")
                tf.saved_model.save(export_model, saved_model_dir)
                artifact_path = shutil.make_archive(
                    os.path.join(export_dir, "model"), "zip", saved_model_dir
                )

            with open(artifact_path, 'rb') as f:
                payload = f.read()

        st.download_button(
            "Download Model",
            payload,
            file_name=download_name
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1172 |
|
1173 |
# Predictions Section (Fixed)
|
1174 |
if app_mode == "Predictions":
|