Update app.py
app.py
CHANGED
@@ -1,370 +1,130 @@
 import streamlit as st
 import pandas as pd
-import
-from
-from
-from
-from
-import
-import
-from typing import Optional
 from datasets import load_dataset
-from huggingface_hub import list_datasets
-import traceback

-

-st.set_page_config(
-    page_title="ML Pipeline for Purple Teaming",
-    page_icon="🛡️",
-    layout="wide"
-)

-
-"
-if
-
-
-    return sanitized
-
-def load_hf_dataset(dataset_name: str, config_name: Optional[str] = None) -> pd.DataFrame:
-    """Load a dataset from Hugging Face and convert to pandas DataFrame"""
-    try:
-        if config_name:
-            dataset = load_dataset(dataset_name, config_name)
         else:
-            dataset = load_dataset(dataset_name)
-
-        # Convert to pandas DataFrame (using first split, usually 'train')
-        split_name = list(dataset.keys())[0]
-        df = dataset[split_name].to_pandas()
-        return df
-    except Exception as e:
-        raise Exception(f"Error loading dataset from Hugging Face: {str(e)}\n{traceback.format_exc()}")
-
-def main():
-    st.title("🛡️ ML Pipeline for Cybersecurity Purple Teaming")
-
-    # Initialize default values for feature engineering
-    if 'poly_degree' not in st.session_state:
-        st.session_state.poly_degree = 2
-    if 'k_best_features' not in st.session_state:
-        st.session_state.k_best_features = 10
-    if 'n_components' not in st.session_state:
-        st.session_state.n_components = 0.95
-
-    # Sidebar
-    st.sidebar.header("Pipeline Configuration")
-
-    # Data Input Tabs
-    data_input_tab = st.radio(
-        "Choose Data Source",
-        ["Upload File", "Load from Hugging Face"]
-    )

-
-
-    if data_input_tab == "Upload File":
-        uploaded_file = st.file_uploader(
-            "Upload Dataset (CSV/JSON)",
-            type=['csv', 'json']
-        )
-        if uploaded_file is not None:
-            try:
-                df = load_data(uploaded_file)
-            except Exception as e:
-                st.error(f"Error loading file: {str(e)}")
-    else:
-        # Hugging Face Dataset Loading
-        st.markdown("### Load Dataset from Hugging Face")
-        dataset_name = st.text_input(
-            "Dataset Name",
-            help="Enter the Hugging Face dataset name (e.g., 'username/dataset-name')"
-        )
-        config_name = st.text_input(
-            "Configuration Name (Optional)",
-            help="Enter the specific configuration name if the dataset has multiple configurations"
-        )
-
-        if dataset_name:
-            try:
-                with st.spinner("Loading dataset from Hugging Face..."):
-                    df = load_hf_dataset(
-                        dataset_name,
-                        config_name if config_name else None
-                    )
-                st.success(f"Successfully loaded dataset: {dataset_name}")
-            except Exception as e:
-                st.error(str(e))
-
-    if df is not None:
         try:
-
-
-
-            return
-
-            if df.shape[1] < 2:
-                st.error("Dataset must contain at least two columns (features and target).")
-                return
-
-            # Check for numeric columns
-            numeric_cols = df.select_dtypes(include=[np.number]).columns
-            if len(numeric_cols) == 0:
-                st.error("Dataset must contain at least one numeric column for analysis.")
-                return
-
-            # Initialize components
-            processor = DataProcessor()
-            trainer = ModelTrainer()
-            visualizer = Visualizer()
-
-            # Data Processing Section
-            st.header("1. Data Processing")
-            col1, col2 = st.columns(2)
-
-            with col1:
-                st.subheader("Dataset Overview")
-                st.write(f"Shape: {df.shape}")
-                st.write("Sample Data:")
-                st.dataframe(df.head())
-
-            with col2:
-                st.subheader("Data Statistics")
-                st.write(df.describe())
-
-            # Feature Engineering Configuration
-            st.header("2. Feature Engineering")
-            col3, col4 = st.columns(2)
-
-            with col3:
-                # Basic preprocessing
-                handling_strategy = st.selectbox(
-                    "Missing Values Strategy",
-                    ["mean", "median", "most_frequent", "constant"]
-                )
-                scaling_method = st.selectbox(
-                    "Scaling Method",
-                    ["standard", "minmax", "robust"]
-                )
-
-                # Advanced Feature Engineering
-                st.subheader("Advanced Features")
-                use_polynomial = st.checkbox("Use Polynomial Features")
-                if use_polynomial:
-                    st.session_state.poly_degree = st.slider("Polynomial Degree", 2, 5, st.session_state.poly_degree)
-
-                use_feature_selection = st.checkbox("Use Feature Selection")
-                if use_feature_selection:
-                    max_features = min(50, df.shape[1])  # Limit k_best_features to number of columns
-                    st.session_state.k_best_features = st.slider(
-                        "Number of Best Features",
-                        2,  # Minimum 2 features required
-                        max_features,
-                        min(st.session_state.k_best_features, max_features),
-                        help="Select the number of most important features to use"
-                    )
-
-            with col4:
-                use_pca = st.checkbox("Use PCA")
-                if use_pca:
-                    st.session_state.n_components = st.slider(
-                        "PCA Components (%)",
-                        1, 100,
-                        int(st.session_state.n_components * 100),
-                        help="Percentage of variance to preserve"
-                    ) / 100.0
-
-                add_cyber_features = st.checkbox("Add Cybersecurity Features")
-
-            numeric_features = df.select_dtypes(include=[np.number]).columns.tolist()
-            if not numeric_features:
-                st.error("No numeric features found in the dataset.")
-                return
-
-            feature_cols = st.multiselect(
-                "Select Features",
-                numeric_features,
-                default=numeric_features,
-                help="Select the features to use for training"
-            )
-
-            if not feature_cols:
-                st.error("Please select at least one feature column")
-                return
-
-            categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
-            target_col = st.selectbox(
-                "Select Target Column",
-                [col for col in categorical_cols if col not in feature_cols],
-                help="Select the target variable to predict"
-            )
-
-            if target_col is None:
-                st.error("No suitable target column found. Target should be categorical.")
-                return
-
-            # Create feature engineering config
-            feature_engineering_config = {
-                'use_polynomial': use_polynomial,
-                'poly_degree': st.session_state.poly_degree if use_polynomial else None,
-                'use_feature_selection': use_feature_selection,
-                'k_best_features': st.session_state.k_best_features if use_feature_selection else None,
-                'use_pca': use_pca,
-                'n_components': st.session_state.n_components if use_pca else None,
-                'add_cyber_features': add_cyber_features
-            }
-
-            # Model Configuration Section
-            st.header("3. Model Configuration")
-            col5, col6 = st.columns(2)
-
-            with col5:
-                n_estimators = st.slider(
-                    "Number of Trees",
-                    min_value=10,
-                    max_value=500,
-                    value=100
-                )
-                max_depth = st.slider(
-                    "Max Depth",
-                    min_value=1,
-                    max_value=50,
-                    value=10
-                )
-
-            with col6:
-                min_samples_split = st.slider(
-                    "Min Samples Split",
-                    min_value=2,
-                    max_value=20,
-                    value=2
-                )
-                min_samples_leaf = st.slider(
-                    "Min Samples Leaf",
-                    min_value=1,
-                    max_value=10,
-                    value=1
-                )
-
-            if st.button("Train Model"):
-                with st.spinner("Processing data and training model..."):
-                    # Process data with feature engineering
-                    X_train, X_test, y_train, y_test = processor.process_data(
-                        df,
-                        feature_cols,
-                        target_col,
-                        handling_strategy,
-                        scaling_method,
-                        feature_engineering_config
-                    )
-
-                    # Train model
-                    model, metrics = trainer.train_model(
-                        X_train, X_test, y_train, y_test,
-                        n_estimators=n_estimators,
-                        max_depth=max_depth,
-                        min_samples_split=min_samples_split,
-                        min_samples_leaf=min_samples_leaf
-                    )
-
-                    # Results Section
-                    st.header("4. Results and Visualizations")
-                    col7, col8 = st.columns(2)
-
-                    with col7:
-                        st.subheader("Model Performance Metrics")
-                        for metric, value in metrics.items():
-                            st.metric(metric, f"{value:.4f}")
-
-                        # Add model export section with improved validation
-                        st.subheader("Export Model")
-                        model_name = st.text_input(
-                            "Model Name (optional)",
-                            help="Enter a name for your model (alphanumeric and underscores only)"
-                        )
-
-                        if st.button("Save Model"):
-                            try:
-                                # Validate and sanitize model name
-                                sanitized_name = validate_model_name(model_name)
-
-                                if sanitized_name != model_name:
-                                    st.warning(f"Model name was sanitized to: {sanitized_name}")
-
-                                # Save model and metadata
-                                preprocessing_params = {
-                                    'feature_engineering_config': feature_engineering_config,
-                                    'handling_strategy': handling_strategy,
-                                    'scaling_method': scaling_method,
-                                    'feature_columns': feature_cols,
-                                    'target_column': target_col
-                                }
-
-                                model_path, metadata_path = save_model(
-                                    model,
-                                    feature_cols,
-                                    preprocessing_params,
-                                    metrics,
-                                    sanitized_name
-                                )
-
-                                st.success(f"Model saved successfully!\nFiles:\n- {model_path}\n- {metadata_path}")
-                            except Exception as e:
-                                st.error(f"Error saving model: {str(e)}")
-                                st.error("Please ensure you have proper permissions and sufficient disk space.")
-
-                    with col8:
-                        if not use_pca:  # Skip feature importance for PCA
-                            st.subheader("Feature Importance")
-                            fig_importance = visualizer.plot_feature_importance(
-                                model,
-                                feature_cols if not use_polynomial else [f"Feature_{i}" for i in range(X_train.shape[1])]
-                            )
-                            st.pyplot(fig_importance)
-
-                        # Confusion Matrix
-                        st.subheader("Confusion Matrix")
-                        fig_cm = visualizer.plot_confusion_matrix(
-                            y_test,
-                            model.predict(X_test)
-                        )
-                        st.pyplot(fig_cm)
-
-                        # ROC Curve
-                        st.subheader("ROC Curve")
-                        fig_roc = visualizer.plot_roc_curve(
-                            model,
-                            X_test,
-                            y_test
-                        )
-                        st.pyplot(fig_roc)
-
         except Exception as e:
-            st.error(f"
-
     else:
-
-        st.info("Please upload a dataset to begin.")
-    else:
-        st.info("Please enter a Hugging Face dataset name to begin.")

-
-    st.
-    try:
-        saved_models = list_saved_models()
-        if saved_models:
-            for model_info in saved_models:
-                with st.expander(f"Model: {model_info['name']}"):
-                    st.write(f"Type: {model_info['type']}")
-                    st.write(f"Created: {model_info['created_at']}")
-                    st.write("Performance Metrics:")
-                    for metric, value in model_info['metrics'].items():
-                        st.metric(metric, f"{value:.4f}")
-        else:
-            st.info("No saved models found.")
-    except Exception as e:
-        st.error(f"Error loading saved models: {str(e)}")

-
-
 import streamlit as st
 import pandas as pd
+import seaborn as sns
+import matplotlib.pyplot as plt  # needed so the correlation heatmap below can be passed to st.pyplot()
+from sklearn.model_selection import train_test_split
+from sklearn.linear_model import LogisticRegression
+from sklearn.tree import DecisionTreeClassifier
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.metrics import classification_report
+from sklearn.preprocessing import LabelEncoder, StandardScaler, MinMaxScaler
 from datasets import load_dataset
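+
+# The app is a single linear Streamlit script: 1) load data, 2) explore,
+# 3) preprocess, 4) train and evaluate. Each section checks that `df` exists.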

+# 1. Load Dataset
+st.header("1. Load Dataset")

+data_source = st.radio("Choose data source:", ["Upload File", "Hugging Face", "Sample Dataset"])

+if data_source == "Upload File":
+    uploaded_file = st.file_uploader("Upload your dataset (CSV, Excel, or Parquet)", type=["csv", "xlsx", "parquet"])
+    if uploaded_file:
+        if uploaded_file.name.endswith(".csv"):
+            df = pd.read_csv(uploaded_file)
+        elif uploaded_file.name.endswith(".parquet"):
+            df = pd.read_parquet(uploaded_file)  # parquet needs its own reader, not read_excel
         else:
+            df = pd.read_excel(uploaded_file)
+        st.success(f"Successfully loaded {uploaded_file.name}")

+elif data_source == "Hugging Face":
+    hf_dataset_name = st.text_input("Enter Hugging Face dataset name:")
+    if hf_dataset_name:
         try:
+            dataset = load_dataset(hf_dataset_name)
+            # load_dataset returns a DatasetDict keyed by split; convert the first split
+            df = dataset[list(dataset.keys())[0]].to_pandas()
+            st.success(f"Loaded dataset: {hf_dataset_name}")
         except Exception as e:
+            st.error(f"Error loading dataset: {str(e)}")
+
+else:  # Sample Dataset
+    # "iris", "penguins", and "titanic" all ship with seaborn; "wine" does not
+    sample_data = st.selectbox("Select a sample dataset:", ["Iris", "Penguins", "Titanic"])
+    df = sns.load_dataset(sample_data.lower())
+    st.success(f"Loaded sample dataset: {sample_data}")
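+
+# Streamlit reruns this script from the top on every interaction, so `df`
+# only exists below once one of the loading branches above has succeeded.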
+if 'df' in locals():
+    st.dataframe(df.head())
+
+# 2. Explore Dataset
+st.header("2. Explore Dataset")
+
+if 'df' in locals():
+    st.subheader("Dataset Overview")
+    st.write(f"Shape: {df.shape}")
+    st.write("Column Information:")
+    st.dataframe(df.dtypes)
+
+    if st.checkbox("Show Missing Values"):
+        missing = df.isnull().sum()
+        st.bar_chart(missing[missing > 0])
+
+    st.subheader("Summary Statistics")
+    st.write(df.describe())
+
+    if st.checkbox("Generate Correlation Matrix"):
+        # numeric_only avoids a TypeError when string columns are present
+        corr_matrix = df.corr(numeric_only=True)
+        # recent Streamlit versions require an explicit figure for st.pyplot
+        fig, ax = plt.subplots()
+        sns.heatmap(corr_matrix, annot=True, cmap="coolwarm", ax=ax)
+        st.pyplot(fig)
+else:
+    st.warning("Load a dataset to explore.")
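+
+# Note: the steps below reassign `df`, so sections 3 and 4 operate on the
+# transformed frame for the remainder of the script run.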
+
+# 3. Preprocess Dataset
+st.header("3. Preprocess Dataset")
+
+if 'df' in locals():
+    st.subheader("Handle Missing Values")
+    missing_option = st.radio("Choose missing value strategy:", ["None", "Fill with Mean", "Drop Rows"])
+    if missing_option == "Fill with Mean":
+        # numeric_only restricts the means to numeric columns
+        df = df.fillna(df.mean(numeric_only=True))
+    elif missing_option == "Drop Rows":
+        df = df.dropna()
+
+    st.subheader("Encode Categorical Variables")
+    encoding_method = st.radio("Encoding Method:", ["None", "One-Hot Encoding", "Label Encoding"])
+    if encoding_method == "One-Hot Encoding":
+        df = pd.get_dummies(df)
+    elif encoding_method == "Label Encoding":
+        le = LabelEncoder()
+        for col in df.select_dtypes(include="object").columns:
+            df[col] = le.fit_transform(df[col])
+
+    st.subheader("Feature Scaling")
+    scaling_method = st.radio("Scaling Method:", ["None", "Standardization", "Normalization"])
+    if scaling_method != "None":
+        scaler = StandardScaler() if scaling_method == "Standardization" else MinMaxScaler()
+        numeric_cols = df.select_dtypes(include="number").columns
+        df[numeric_cols] = scaler.fit_transform(df[numeric_cols])
+
+    st.success("Preprocessing complete!")
+    st.dataframe(df.head())
+else:
+    st.warning("Load a dataset to preprocess.")
+
+# 4. Train Model
+st.header("4. Train Model")
+
+if 'df' in locals():
+    st.subheader("Select Target Column")
+    # the models below are classifiers, so the chosen target should be categorical
+    target_col = st.selectbox("Choose the target column:", df.columns)
+    features = [col for col in df.columns if col != target_col]
+
+    st.subheader("Train/Test Split")
+    test_size = st.slider("Test size (percentage):", 10, 50, 20) / 100
+    X_train, X_test, y_train, y_test = train_test_split(
+        df[features], df[target_col], test_size=test_size, random_state=42
+    )
+
+    st.subheader("Select and Train Model")
+    model_type = st.selectbox("Choose a model:", ["Logistic Regression", "Decision Tree", "Random Forest"])
+    if model_type == "Logistic Regression":
+        model = LogisticRegression(max_iter=1000)  # higher max_iter helps convergence on unscaled data
+    elif model_type == "Decision Tree":
+        model = DecisionTreeClassifier()
     else:
+        model = RandomForestClassifier()

+    model.fit(X_train, y_train)
+    st.success("Model trained successfully!")

+    st.subheader("Model Performance")
+    y_pred = model.predict(X_test)
+    # classification_report gives per-class precision, recall, and F1 on the test split
+    report = classification_report(y_test, y_pred, output_dict=True)
+    st.dataframe(pd.DataFrame(report).transpose())
+else:
+    st.warning("Load and preprocess a dataset to train a model.")