CosmickVisions committed (verified)
Commit 8453100 Β· Parent: fb1cc2e

Update app.py

Files changed (1)
app.py (+335 -77)
app.py CHANGED
@@ -522,19 +522,31 @@ elif app_mode == "Data Cleaning":
  # --------------------------
  # Label Encoding
  # --------------------------
- enhance_section_title("Label Encoding", "πŸ”’")
- with st.expander("πŸ”’ Label Encoding"):
-     data_to_encode = st.multiselect("Select categorical columns to encode", df.select_dtypes(include='object').columns)
-     if data_to_encode:
-         if st.button("Apply Label Encoding (Encoding)"):
-             new_df = df.copy()
              label_encoders = {}
              for col in data_to_encode:
                  le = LabelEncoder()
                  new_df[col] = le.fit_transform(new_df[col].astype(str))
                  label_encoders[col] = le
-             update_cleaned_data(new_df)
-             st.rerun()  # Force re-run after apply

  # --------------------------
  # StandardScaler
@@ -574,16 +586,139 @@ elif app_mode == "Data Cleaning":
  # --------------------------
  # Bulk Operations
  # --------------------------
- enhance_section_title("Bulk Actions", "πŸš€")
- with st.expander("πŸš€ Bulk Actions"):
-     if st.button("Auto-Clean Common Issues (Cleaning)"):
          new_df = df.copy()
          new_df = new_df.dropna(axis=1, how='all')  # Remove empty cols
          new_df = new_df.convert_dtypes()  # Better type inference
          text_cols = new_df.select_dtypes(include='object').columns
          new_df[text_cols] = new_df[text_cols].apply(lambda x: x.str.strip())
          update_cleaned_data(new_df)
-         st.rerun()  # Force re-run after apply

  # --------------------------
  # Cleaned Data Preview
@@ -681,7 +816,6 @@ elif app_mode == "EDA":
      "Swarm Plot",  # YData library plots
      "Ridge Plot",
      "Bubble Plot",
-     "Barh Plot",
      "Density Plot",
      "Count Plot",
      "Lollipop Chart",
@@ -803,13 +937,6 @@ elif app_mode == "EDA":
          hover_name=size_col,  # hover name, to show value
          title=f"Bubble Plot of {x_axis} vs. {y_axis} Colored by {size_col}"
      )
- elif plot_type == "Barh Plot":
-     if x_axis and y_axis:
-         fig = px.bar(df, y=x_axis, x=y_axis,
-                      color=color_by if color_by != "None" else None, orientation='h',  # swap x onto the y-axis
-                      title=f"Horizontal Bar Plot of {y_axis} vs {x_axis}")

  elif plot_type == "Density Plot":  # kernel density estimates with px
@@ -910,24 +1037,83 @@ elif app_mode == "EDA":
          st.write("There is no statistically significant association between the two categorical variables.")

  with tab2:
      st.subheader("Pattern Discovery")
      explore_col = st.selectbox("Column to analyze", df.columns)
      if pd.api.types.is_string_dtype(df[explore_col]):
          pattern = st.text_input("Regex pattern")
          if pattern:
-             matches = df[explore_col].str.contains(pattern).sum()
-             st.write(f"Found {matches} matches")

  with tab3:
      st.subheader("Data Transformation")
-     transform_col = st.selectbox("Column to transform", numeric_cols)
-     transform_type = st.selectbox("Transformation", ["Log", "Square Root", "Z-score"])
      if transform_type == "Log":
          df[transform_col] = np.log1p(df[transform_col])
      elif transform_type == "Square Root":
          df[transform_col] = np.sqrt(df[transform_col])
      elif transform_type == "Z-score":
          df[transform_col] = (df[transform_col] - df[transform_col].mean()) / df[transform_col].std()

  # --------------------------
  # Export & Save
@@ -956,6 +1142,12 @@ elif app_mode == "Model Training":

  # ----- [1. Preset Selection] -----
  with st.sidebar.expander("πŸš€ Quick Start", expanded=True):
      presets = st.selectbox("Load Preset", [
          "None",
          "CNN-MNIST",
@@ -979,7 +1171,7 @@ elif app_mode == "Model Training":
          {"type": "Dense", "units": 1, "activation": "sigmoid"}
      ]
      st.experimental_rerun()
-
  # ----- [2. Base Model & Transfer Learning] -----
  with st.expander("πŸ—οΈ Transfer Learning", expanded=False):
      col1, col2 = st.columns(2)
@@ -988,7 +1180,9 @@ elif app_mode == "Model Training":
          "None",
          "MobileNetV2",
          "ResNet50",
-         "BERT"
      ])

      with col2:
@@ -1002,81 +1196,143 @@ elif app_mode == "Model Training":
              weights='imagenet',
              input_shape=(224, 224, 3) if custom_input else None
          )
-         st.info(f"Loaded {base_model} with {len(model.layers)} layers")
-
-     # ----- [3. Layer Configuration] -----
-     st.subheader("πŸ—οΈ Network Architecture")
-
-     # Dynamic layer builder
-     layer_types = [
-         "Dense", "Conv2D", "LSTM",
-         "Dropout", "BatchNorm", "Flatten"
-     ]
-
-     if 'layers' not in st.session_state:
-         st.session_state.layers = []

-     for i, layer in enumerate(st.session_state.layers):
-         cols = st.columns([1, 3, 2])
-         with cols[0]:
-             st.markdown(f"**Layer {i+1}**")
-         with cols[1]:
-             st.code(f"{layer['type']}: {dict((k, v) for k, v in layer.items() if k != 'type')}")
-         with cols[2]:
-             if st.button(f"❌ Remove {i+1}", key=f"remove_{i}"):
-                 del st.session_state.layers[i]
                  st.experimental_rerun()
-
-     # Add new layer controls
-     with st.expander("βž• Add New Layer", expanded=True):
-         new_layer_type = st.selectbox("Layer Type", layer_types)
-         new_layer_params = {}

-         if new_layer_type == "Dense":
-             new_layer_params["units"] = st.number_input("Units", 1, 1024, 128)
-             new_layer_params["activation"] = st.selectbox(
-                 "Activation", ["relu", "sigmoid", "tanh"]
-             )
-
-         elif new_layer_type == "Conv2D":
-             new_layer_params["filters"] = st.number_input("Filters", 1, 256, 32)
-             new_layer_params["kernel_size"] = st.number_input("Kernel Size", 1, 9, 3)
-
-         if st.button("Add Layer"):
-             st.session_state.layers.append({
-                 "type": new_layer_type,
-                 **new_layer_params
-             })
-             st.experimental_rerun()

  # ----- [4. Regularization & Advanced Options] -----
  with st.expander("βš™οΈ Advanced Configuration", expanded=False):
      col1, col2 = st.columns(2)

      with col1:
          st.subheader("Regularization")
-         l2_reg = st.number_input("L2 Regularization", 0.0, 0.1, 0.001)
-         dropout = st.number_input("Global Dropout", 0.0, 0.5, 0.2)
-         batch_norm = st.checkbox("Batch Normalization")

      with col2:
          st.subheader("Optimization")
          optimizer = st.selectbox("Optimizer", [
              "adam", "sgd", "rmsprop",
              "nadam", "adamax"
-         ])

          loss = st.selectbox("Loss Function", [
              "categorical_crossentropy",
              "binary_crossentropy",
              "mse",
              "mae"
-         ])

          metrics = st.multiselect("Metrics", [
              "accuracy", "precision",
              "recall", "auc"
-         ])

  # ----- [5. Training & Monitoring] -----
  st.subheader("🎯 Training Configuration")
@@ -1091,7 +1347,7 @@ elif app_mode == "Model Training":
  def update_chart(self):
      df = pd.DataFrame(st.session_state.metrics)
      fig = px.line(df, y=['loss', 'val_loss'],
-                   title="Training Progress")
      loss_chart.plotly_chart(fig)

  if st.button("πŸš€ Start Training"):
@@ -1117,7 +1373,8 @@ elif app_mode == "Model Training":
      model.add(BatchNormalization())

      # Add global dropout
-     model.add(Dropout(dropout))

      model.compile(
          optimizer=optimizer,
@@ -1142,6 +1399,7 @@ elif app_mode == "Model Training":
  except Exception as e:
      st.error(f"Training failed: {str(e)}")

  # ----- [6. Export & Deployment] -----
  st.subheader("πŸ’Ύ Export Model")
 
app.py (updated sections):

  # --------------------------
  # Label Encoding
  # --------------------------
+ # --------------------------
+ # Label/One-Hot Encoding
+ # --------------------------
+ enhance_section_title("Encoding Options", "πŸ”’")
+ with st.expander("πŸ”’ Encoding Options"):
+     encoding_method = st.radio("Select Encoding Method", ("Label Encoding", "One-Hot Encoding"))
+
+     data_to_encode = st.multiselect("Select categorical columns to encode", df.select_dtypes(include='object').columns)
+     if data_to_encode:
+         if st.button("Apply Encoding"):
+             try:
+                 new_df = df.copy()
+                 if encoding_method == "Label Encoding":
                      label_encoders = {}
                      for col in data_to_encode:
                          le = LabelEncoder()
                          new_df[col] = le.fit_transform(new_df[col].astype(str))
                          label_encoders[col] = le
+                 elif encoding_method == "One-Hot Encoding":
+                     new_df = pd.get_dummies(new_df, columns=data_to_encode, drop_first=True)
+
+                 update_cleaned_data(new_df)
+                 st.rerun()  # Force re-run after apply
+             except Exception as e:
+                 st.error(f"Error: {str(e)}")
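For reference, the two encoding paths above behave like this minimal standalone sketch (assumes pandas and scikit-learn; the toy frame is illustrative, not the app's data):

    import pandas as pd
    from sklearn.preprocessing import LabelEncoder

    toy = pd.DataFrame({"color": ["red", "green", "red"], "size": ["S", "M", "L"]})

    # Label encoding: one integer code per category, reversible via inverse_transform
    le = LabelEncoder()
    toy["color"] = le.fit_transform(toy["color"].astype(str))
    print(le.inverse_transform(toy["color"]))   # ['red' 'green' 'red']

    # One-hot encoding: one 0/1 column per category; drop_first avoids collinearity
    toy = pd.get_dummies(toy, columns=["size"], drop_first=True)
    print(toy.columns.tolist())                 # ['color', 'size_M', 'size_S']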
 
  # --------------------------
  # StandardScaler
 
  # --------------------------
  # Bulk Operations
  # --------------------------
+ enhance_section_title("Bulk Actions", "πŸš€")
+ with st.expander("πŸš€ Bulk Actions"):
+     bulk_action = st.selectbox("Select Bulk Action", [
+         "Auto-Clean Common Issues",
+         "Drop All Missing Values",
+         "Fill Missing Values",
+         "One-Hot Encode All Categorical Columns",
+         "Min-Max Scaling",
+         "Remove Outliers",
+         "Tokenize Text Columns",
+         "Vectorize Text Columns (TF-IDF)",
+         "Extract Date Features",
+         "Target Encoding",
+         "Principal Component Analysis (PCA)"
+     ])
+
+     if bulk_action == "Auto-Clean Common Issues":
+         if st.button("Apply Auto-Clean"):
              new_df = df.copy()
              new_df = new_df.dropna(axis=1, how='all')  # Remove empty cols
              new_df = new_df.convert_dtypes()  # Better type inference
              text_cols = new_df.select_dtypes(include='object').columns
              new_df[text_cols] = new_df[text_cols].apply(lambda x: x.str.strip())
              update_cleaned_data(new_df)
+             st.rerun()  # Force re-run after apply
+
+     if bulk_action == "Drop All Missing Values":
+         if st.button("Apply Drop All Missing"):
+             new_df = df.copy()
+             new_df = new_df.dropna()  # Drop rows with any missing values
+             update_cleaned_data(new_df)
+             st.rerun()  # Force re-run after apply
+
+     if bulk_action == "Fill Missing Values":
+         fill_value = st.text_input("Fill Value (e.g., 0, mean, median)")
+         if st.button("Apply Fill Missing"):
+             new_df = df.copy()
+             if fill_value.lower() == "mean":
+                 new_df = new_df.fillna(new_df.mean(numeric_only=True))
+             elif fill_value.lower() == "median":
+                 new_df = new_df.fillna(new_df.median(numeric_only=True))
+             else:
+                 new_df = new_df.fillna(fill_value)
+             update_cleaned_data(new_df)
+             st.rerun()  # Force re-run after apply
+
+     if bulk_action == "One-Hot Encode All Categorical Columns":
+         if st.button("Apply One-Hot Encoding"):
+             new_df = df.copy()
+             categorical_cols = new_df.select_dtypes(include='object').columns
+             new_df = pd.get_dummies(new_df, columns=categorical_cols, drop_first=True)
+             update_cleaned_data(new_df)
+             st.rerun()  # Force re-run after apply
+
+     if bulk_action == "Min-Max Scaling":
+         if st.button("Apply Min-Max Scaling"):
+             new_df = df.copy()
+             scaler = MinMaxScaler()
+             numerical_cols = new_df.select_dtypes(include=np.number).columns
+             new_df[numerical_cols] = scaler.fit_transform(new_df[numerical_cols])
+             update_cleaned_data(new_df)
+             st.rerun()  # Force re-run after apply
+
+     if bulk_action == "Remove Outliers":
+         if st.button("Apply Remove Outliers"):
+             new_df = df.copy()
+             z_scores = np.abs(stats.zscore(new_df.select_dtypes(include=np.number)))
+             new_df = new_df[(z_scores < 3).all(axis=1)]  # Keep rows whose z-scores are all below 3
+             update_cleaned_data(new_df)
+             st.rerun()  # Force re-run after apply
+
+     if bulk_action == "Tokenize Text Columns":
+         text_cols = st.multiselect("Select text columns to tokenize", df.select_dtypes(include='object').columns)
+         if text_cols:
+             if st.button("Apply Tokenization"):
+                 tokenizer = Tokenizer()
+                 new_df = df.copy()
+                 for col in text_cols:
+                     tokenizer.fit_on_texts(new_df[col])
+                     new_df[col] = tokenizer.texts_to_sequences(new_df[col])
+                 update_cleaned_data(new_df)
+                 st.rerun()  # Force re-run after apply
+
+     if bulk_action == "Vectorize Text Columns (TF-IDF)":
+         text_cols = st.multiselect("Select text columns to vectorize", df.select_dtypes(include='object').columns)
+         if text_cols:
+             if st.button("Apply TF-IDF Vectorization"):
+                 tfidf = TfidfVectorizer()
+                 new_df = df.copy()
+                 for col in text_cols:
+                     new_col = tfidf.fit_transform(new_df[col]).toarray()
+                     new_df = new_df.drop(columns=[col])
+                     new_df = new_df.join(pd.DataFrame(new_col, columns=[f'{col}_{i}' for i in range(new_col.shape[1])]))
+                 update_cleaned_data(new_df)
+                 st.rerun()  # Force re-run after apply
+
+     if bulk_action == "Extract Date Features":
+         date_cols = st.multiselect("Select date columns to extract features from", df.select_dtypes(include='datetime').columns)
+         if date_cols:
+             if st.button("Apply Date Feature Extraction"):
+                 new_df = df.copy()
+                 for col in date_cols:
+                     new_df[f'{col}_year'] = new_df[col].dt.year
+                     new_df[f'{col}_month'] = new_df[col].dt.month
+                     new_df[f'{col}_day'] = new_df[col].dt.day
+                     new_df[f'{col}_weekday'] = new_df[col].dt.weekday
+                     new_df[f'{col}_hour'] = new_df[col].dt.hour
+                 update_cleaned_data(new_df)
+                 st.rerun()  # Force re-run after apply
+
+     if bulk_action == "Target Encoding":
+         target_col = st.selectbox("Select target column", df.columns)
+         cat_cols = st.multiselect("Select categorical columns to encode", df.select_dtypes(include='object').columns)
+         if cat_cols:
+             if st.button("Apply Target Encoding"):
+                 new_df = df.copy()
+                 for col in cat_cols:
+                     target_mean = new_df.groupby(col)[target_col].mean()
+                     new_df[col] = new_df[col].map(target_mean)
+                 update_cleaned_data(new_df)
+                 st.rerun()  # Force re-run after apply
+
+     if bulk_action == "Principal Component Analysis (PCA)":
+         n_components = st.slider("Number of components", min_value=1, max_value=min(df.shape[1], 10), value=2)
+         if st.button("Apply PCA"):
+             new_df = df.copy()
+             pca = PCA(n_components=n_components)
+             pca_result = pca.fit_transform(new_df.select_dtypes(include=np.number))
+             new_df = pd.DataFrame(pca_result, columns=[f'PC{i+1}' for i in range(n_components)])
+             update_cleaned_data(new_df)
+             st.rerun()  # Force re-run after apply
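One subtlety in the TF-IDF action above: DataFrame.join aligns on index labels, so a frame with a filtered or non-default index can silently pick up NaNs. A standalone sketch of the same column-expansion pattern with an explicit positional reset (assumes scikit-learn and pandas; the `review` column is illustrative):

    import pandas as pd
    from sklearn.feature_extraction.text import TfidfVectorizer

    frame = pd.DataFrame({"review": ["good movie", "bad movie", "good plot"]},
                         index=[10, 20, 30])  # non-default index, as after row filtering

    tfidf = TfidfVectorizer()
    mat = tfidf.fit_transform(frame["review"]).toarray()   # dense matrix, shape (n_rows, n_terms)
    cols = [f"review_{i}" for i in range(mat.shape[1])]

    # Reset to a positional index before joining so rows line up by position
    out = frame.drop(columns=["review"]).reset_index(drop=True).join(
        pd.DataFrame(mat, columns=cols))
    print(out.shape)                                        # (3, n_terms)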

  # --------------------------
  # Cleaned Data Preview
 
      "Swarm Plot",  # YData library plots
      "Ridge Plot",
      "Bubble Plot",
      "Density Plot",
      "Count Plot",
      "Lollipop Chart",
 
      hover_name=size_col,  # hover name, to show value
      title=f"Bubble Plot of {x_axis} vs. {y_axis} Colored by {size_col}"
  )

  elif plot_type == "Density Plot":  # kernel density estimates with px
 
          st.write("There is no statistically significant association between the two categorical variables.")

  with tab2:
+     # Pattern Discovery --------------------------
      st.subheader("Pattern Discovery")
      explore_col = st.selectbox("Column to analyze", df.columns)
+
      if pd.api.types.is_string_dtype(df[explore_col]):
          pattern = st.text_input("Regex pattern")
          if pattern:
+             # Perform regex matching
+             matches = df[explore_col].str.contains(pattern, regex=True, na=False)
+             num_matches = matches.sum()
+             st.write(f"Found {num_matches} matches")
+
+             # Display matching rows
+             if num_matches > 0:
+                 st.write("Matching rows:")
+                 st.dataframe(df[matches].head(), use_container_width=True)
+
+             # Provide regex syntax help
+             with st.expander("Regex Syntax Help"):
+                 st.markdown("""
+                 **Basic Syntax:**
+                 - `.`: Any single character
+                 - `*`: 0 or more repetitions
+                 - `+`: 1 or more repetitions
+                 - `?`: 0 or 1 repetition
+                 - `[]`: Any character within the brackets
+                 - `|`: Either or
+
+                 For more details, visit [Regex101](https://regex101.com/)
+                 """)
+     else:
+         st.warning("Please select a string column for pattern discovery.")

  with tab3:
      st.subheader("Data Transformation")
+     transform_col = st.selectbox("Column to transform", df.select_dtypes(include=[np.number]).columns)
+     transform_type = st.selectbox("Transformation", ["Log", "Square Root", "Z-score", "Standardization", "Normalization", "Box-Cox", "Inverse"])
+
+     fig, ax = plt.subplots(1, 2, figsize=(12, 5))
+
+     if transform_col:
+         sns.histplot(df[transform_col], bins=30, kde=True, ax=ax[0])
+         ax[0].set_title('Before Transformation')
+
          if transform_type == "Log":
              df[transform_col] = np.log1p(df[transform_col])
          elif transform_type == "Square Root":
              df[transform_col] = np.sqrt(df[transform_col])
          elif transform_type == "Z-score":
              df[transform_col] = (df[transform_col] - df[transform_col].mean()) / df[transform_col].std()
+         elif transform_type == "Standardization":
+             df[transform_col] = (df[transform_col] - df[transform_col].mean()) / df[transform_col].std()
+         elif transform_type == "Normalization":
+             df[transform_col] = (df[transform_col] - df[transform_col].min()) / (df[transform_col].max() - df[transform_col].min())
+         elif transform_type == "Box-Cox":
+             df[transform_col], _ = boxcox(df[transform_col] + 1)  # Box-Cox needs strictly positive input; shift by 1
+         elif transform_type == "Inverse":
+             df[transform_col] = 1 / df[transform_col]
+
+         sns.histplot(df[transform_col], bins=30, kde=True, ax=ax[1])
+         ax[1].set_title('After Transformation')
+
+         st.pyplot(fig)
+     else:
+         st.warning("Please select a column for transformation.")
+
+     # Error handling for invalid transformations
+     try:
+         if transform_type == "Log" and (df[transform_col] <= 0).any():
+             st.error("Log transformation is not applicable to non-positive values.")
+         elif transform_type == "Box-Cox" and (df[transform_col] <= 0).any():
+             st.error("Box-Cox transformation requires all values to be positive.")
+     except Exception as e:
+         st.error(f"Transformation failed: {str(e)}")
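Box-Cox accepts only strictly positive input, which is why the branch above shifts the column by 1 before transforming. A quick standalone illustration of its effect on a skewed sample (assumes NumPy and SciPy):

    import numpy as np
    from scipy.stats import boxcox, skew

    rng = np.random.default_rng(0)
    x = rng.lognormal(mean=0.0, sigma=1.0, size=1000)  # right-skewed, all values > 0

    transformed, lmbda = boxcox(x)  # lmbda is the fitted power parameter
    print(f"skew before: {skew(x):.2f}, after: {skew(transformed):.2f}, lambda: {lmbda:.2f}")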
 
  # --------------------------
  # Export & Save
 
  # ----- [1. Preset Selection] -----
  with st.sidebar.expander("πŸš€ Quick Start", expanded=True):
+     col1, col2 = st.columns(2)
+     with col1:
+         st.image("cnn_mnist.png", caption="CNN-MNIST", use_column_width=True)
+     with col2:
+         st.image("lstm_text.png", caption="LSTM-Text", use_column_width=True)
+
      presets = st.selectbox("Load Preset", [
          "None",
          "CNN-MNIST",
 
          {"type": "Dense", "units": 1, "activation": "sigmoid"}
      ]
      st.experimental_rerun()
+
  # ----- [2. Base Model & Transfer Learning] -----
  with st.expander("πŸ—οΈ Transfer Learning", expanded=False):
      col1, col2 = st.columns(2)
 
          "None",
          "MobileNetV2",
          "ResNet50",
+         "BERT",
+         "InceptionV3",
+         "VGG16"
      ])

      with col2:
 
              weights='imagenet',
              input_shape=(224, 224, 3) if custom_input else None
          )
+     elif base_model == "InceptionV3":
+         model = tf.keras.applications.InceptionV3(
+             include_top=False,
+             weights='imagenet',
+             input_shape=(299, 299, 3) if custom_input else None
+         )
+     elif base_model == "VGG16":
+         model = tf.keras.applications.VGG16(
+             include_top=False,
+             weights='imagenet',
+             input_shape=(224, 224, 3) if custom_input else None
+         )
+
+     st.info(f"Loaded {base_model} with {len(model.layers)} layers")

+     # Visualize model architecture (parameter count per layer; output_shape is a tuple and cannot be plotted)
+     fig, ax = plt.subplots(figsize=(10, 5))
+     sns.barplot(x=[layer.name for layer in model.layers], y=[layer.count_params() for layer in model.layers], ax=ax)
+     ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
+     st.pyplot(fig)
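For context, a frozen base like the ones loaded above is normally extended with a small trainable head before compiling; a minimal sketch assuming TensorFlow 2.x (the 10-class head is illustrative, and note that "BERT" has no tf.keras.applications loader, so that selectbox option is presumably handled elsewhere):

    import tensorflow as tf

    base = tf.keras.applications.MobileNetV2(
        include_top=False, weights='imagenet', input_shape=(224, 224, 3))
    base.trainable = False  # freeze the pretrained weights

    model = tf.keras.Sequential([
        base,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(10, activation='softmax'),  # illustrative head
    ])
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])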
+
+     # ----- [3. Layer Configuration] -----
+     st.subheader("πŸ—οΈ Network Architecture")
+
+     # Dynamic layer builder
+     layer_types = [
+         "Dense", "Conv2D", "LSTM",
+         "Dropout", "BatchNorm", "Flatten"
+     ]
+
+     if 'layers' not in st.session_state:
+         st.session_state.layers = []
+
+     def render_layer(layer, index):
+         cols = st.columns([1, 3, 2])
+         with cols[0]:
+             st.markdown(f"**Layer {index + 1}**")
+         with cols[1]:
+             st.code(f"{layer['type']}: {dict((k, v) for k, v in layer.items() if k != 'type')}")
+         with cols[2]:
+             if st.button(f"❌ Remove {index + 1}", key=f"remove_{index}"):
+                 del st.session_state.layers[index]
+                 st.experimental_rerun()
+
+     for i, layer in enumerate(st.session_state.layers):
+         render_layer(layer, i)
+
+     # Add new layer controls
+     with st.expander("βž• Add New Layer", expanded=True):
+         new_layer_type = st.selectbox("Layer Type", layer_types)
+         new_layer_params = {}
+
+         if new_layer_type == "Dense":
+             new_layer_params["units"] = st.number_input("Units", 1, 1024, 128, help="Number of neurons in the layer.")
+             new_layer_params["activation"] = st.selectbox(
+                 "Activation", ["relu", "sigmoid", "tanh"], help="Activation function to use."
+             )
+         elif new_layer_type == "Conv2D":
+             new_layer_params["filters"] = st.number_input("Filters", 1, 256, 32, help="Number of filters in the convolution.")
+             new_layer_params["kernel_size"] = st.number_input("Kernel Size", 1, 9, 3, help="Size of the convolution kernel.")
+         elif new_layer_type == "LSTM":
+             new_layer_params["units"] = st.number_input("Units", 1, 512, 64, help="Number of units in the LSTM layer.")
+         elif new_layer_type == "Dropout":
+             new_layer_params["rate"] = st.number_input("Rate", 0.0, 1.0, 0.5, help="Dropout rate to use.")
+         elif new_layer_type == "BatchNorm":
+             pass  # BatchNorm has no parameters
+         elif new_layer_type == "Flatten":
+             pass  # Flatten has no parameters
+
+         if st.button("Add Layer"):
+             st.session_state.layers.append({
+                 "type": new_layer_type,
+                 **new_layer_params
+             })
              st.experimental_rerun()
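The layer dictionaries collected in `st.session_state.layers` presumably feed a model builder elsewhere in the script; a sketch of how such a spec list can be mapped onto a Keras Sequential model (an assumed type-to-class mapping, not the app's actual code):

    import tensorflow as tf

    def build_from_spec(specs, input_shape):
        # Map UI layer-spec dicts onto Keras layer classes
        model = tf.keras.Sequential([tf.keras.Input(shape=input_shape)])
        classes = {
            "Dense": tf.keras.layers.Dense,
            "Conv2D": tf.keras.layers.Conv2D,
            "LSTM": tf.keras.layers.LSTM,
            "Dropout": tf.keras.layers.Dropout,
            "BatchNorm": tf.keras.layers.BatchNormalization,
            "Flatten": tf.keras.layers.Flatten,
        }
        for spec in specs:
            params = {k: v for k, v in spec.items() if k != "type"}
            model.add(classes[spec["type"]](**params))
        return model

    # e.g. the Dense spec produced by the "Add Layer" controls above
    m = build_from_spec([{"type": "Dense", "units": 128, "activation": "relu"}], (32,))
    m.summary()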

+     # Visualize the model architecture
+     st.subheader("Model Visualization")
+     fig, ax = plt.subplots(figsize=(10, 6))
+     layer_names = [layer['type'] for layer in st.session_state.layers]  # distinct name, so the layer_types options list above is not shadowed
+     layer_counts = {name: layer_names.count(name) for name in layer_names}
+     ax.bar(layer_counts.keys(), layer_counts.values())
+     ax.set_xlabel("Layer Types")
+     ax.set_ylabel("Count")
+     st.pyplot(fig)
+
+     # Tooltip information
+     st.info("""
+     **Tooltips:**
+     - **Units**: Number of neurons/units in the layer.
+     - **Activation**: Function used to activate the neurons.
+     - **Filters**: Number of filters in the convolution layer.
+     - **Kernel Size**: Size of the kernel used in the convolution layer.
+     - **Rate**: Dropout rate used to drop neurons during training to prevent overfitting.
+     """)

  # ----- [4. Regularization & Advanced Options] -----
  with st.expander("βš™οΈ Advanced Configuration", expanded=False):
      col1, col2 = st.columns(2)

      with col1:
          st.subheader("Regularization")
+         l2_reg = st.number_input("L2 Regularization", 0.0, 0.1, 0.001, help="Regularization to prevent overfitting.")
+         dropout = st.number_input("Global Dropout", 0.0, 0.5, 0.2, help="Dropout rate for neurons during training.")
+         batch_norm = st.checkbox("Batch Normalization", help="Add batch normalization after each layer.")

      with col2:
          st.subheader("Optimization")
          optimizer = st.selectbox("Optimizer", [
              "adam", "sgd", "rmsprop",
              "nadam", "adamax"
+         ], help="Optimizer for model training.")

          loss = st.selectbox("Loss Function", [
              "categorical_crossentropy",
              "binary_crossentropy",
              "mse",
              "mae"
+         ], help="Loss function to minimize during training.")

          metrics = st.multiselect("Metrics", [
              "accuracy", "precision",
              "recall", "auc"
+         ], help="Evaluation metrics to track during training.")
+
+ # Additional Configuration for Validation and Hyperparameter Tuning
+ with st.expander("πŸ”§ Additional Configuration", expanded=False):
+     st.subheader("Validation Settings")
+     val_split = st.slider("Validation Split", 0.0, 0.5, 0.2, help="Proportion of data to use for validation.")
+
+     st.subheader("Hyperparameter Tuning")
+     tuning = st.checkbox("Enable Hyperparameter Tuning", help="Enable automated hyperparameter tuning.")
+     if tuning:
+         tuning_method = st.selectbox("Tuning Method", ["Grid Search", "Random Search"])
+         num_trials = st.number_input("Number of Trials", 1, 100, 10, help="Number of trials for hyperparameter search.")
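One caveat on the metrics multiselect above: "accuracy" is a stable string alias, but "precision", "recall", and "auc" are safer passed as metric objects, since string aliases for these vary across TensorFlow versions. A self-contained sketch (the tiny model is illustrative):

    import tensorflow as tf

    # Map UI metric names to Keras metric objects
    metric_objects = {
        "accuracy": "accuracy",
        "precision": tf.keras.metrics.Precision(),
        "recall": tf.keras.metrics.Recall(),
        "auc": tf.keras.metrics.AUC(),
    }
    selected = ["accuracy", "auc"]  # e.g. the multiselect result

    model = tf.keras.Sequential([
        tf.keras.Input(shape=(20,)),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])
    model.compile(optimizer="adam", loss="binary_crossentropy",
                  metrics=[metric_objects[m] for m in selected])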
 
  # ----- [5. Training & Monitoring] -----
  st.subheader("🎯 Training Configuration")

  def update_chart(self):
      df = pd.DataFrame(st.session_state.metrics)
      fig = px.line(df, y=['loss', 'val_loss'],
+                   title="Training Progress")
      loss_chart.plotly_chart(fig)

  if st.button("πŸš€ Start Training"):
 
      model.add(BatchNormalization())

      # Add global dropout
+     if dropout > 0:
+         model.add(Dropout(dropout))

      model.compile(
          optimizer=optimizer,
 
  except Exception as e:
      st.error(f"Training failed: {str(e)}")

+
  # ----- [6. Export & Deployment] -----
  st.subheader("πŸ’Ύ Export Model")
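The body of the export section falls outside this diff; for orientation, a common Streamlit pattern for downloading a trained Keras model looks like the sketch below (an assumption, not the app's actual export code; the .keras format needs TensorFlow 2.12 or newer):

    import pathlib
    import tempfile

    import tensorflow as tf

    def model_to_bytes(model: tf.keras.Model) -> bytes:
        # Serialize via a temp file in the native .keras format, then read back as bytes
        with tempfile.TemporaryDirectory() as tmp:
            path = pathlib.Path(tmp) / "model.keras"
            model.save(path)
            return path.read_bytes()

    # Inside the app, assuming `model` holds the trained network:
    # st.download_button("⬇️ Download model", model_to_bytes(model), file_name="model.keras")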