andromeda01111 commited on
Commit
b2fe1f7
·
verified ·
1 Parent(s): b1666ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -43
app.py CHANGED
@@ -9,14 +9,11 @@ from sklearn.preprocessing import StandardScaler
9
  import joblib
10
  import os
11
 
12
- # Force CPU usage for TensorFlow to avoid CUDA issues
13
- os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
14
-
15
  # Set device
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
18
  # Load trained ViT model (PyTorch)
19
- vit_model = models.vit_b_16(weights=None)
20
  vit_model.heads = torch.nn.Linear(in_features=768, out_features=2) # Binary classification
21
 
22
  # Load ViT model weights
@@ -37,15 +34,11 @@ transform = transforms.Compose([
37
  class_names = ["Benign", "Malignant"]
38
 
39
  # Load trained Neural Network model (TensorFlow/Keras)
40
- nn_model_path = "my_NN_BC_model.keras"
41
  nn_model = tf.keras.models.load_model(nn_model_path) if os.path.exists(nn_model_path) else None
42
 
43
- # Load trained CNN model (TensorFlow/Keras)
44
- cnn_model_path = "my_CNN_BC_model.keras"
45
- cnn_model = tf.keras.models.load_model(cnn_model_path) if os.path.exists(cnn_model_path) else None
46
-
47
  # Load scaler for feature normalization
48
- scaler_path = "nn_bc_scaler.pkl"
49
  scaler = joblib.load(scaler_path) if os.path.exists(scaler_path) else None
50
 
51
  # Feature names
@@ -58,8 +51,17 @@ feature_names = [
58
  "Worst Compactness", "Worst Concavity", "Worst Concave Points", "Worst Symmetry", "Worst Fractal Dimension"
59
  ]
60
 
61
- def classify(model_choice, image=None, file=None, *features):
62
- """Classify using ViT (image), NN, or CNN (features from manual input or file)."""
 
 
 
 
 
 
 
 
 
63
  if model_choice == "ViT":
64
  if image is None:
65
  return "Please upload an image for ViT classification."
@@ -72,43 +74,37 @@ def classify(model_choice, image=None, file=None, *features):
72
 
73
  return class_names[predicted_class]
74
 
75
- elif model_choice in ["Neural Network", "CNN"]:
76
- if file is not None:
77
- try:
78
- file_data = file.read().decode("utf-8").strip().split(',')
79
- features = [float(x) for x in file_data]
80
- except Exception as e:
81
- return f"Error reading file: {e}"
82
-
83
- if len(features) != len(feature_names):
84
- return "Incorrect number of features provided. Please provide exactly 30 numerical values."
85
 
86
  input_data = np.array(features).reshape(1, -1)
87
  input_data_std = scaler.transform(input_data) if scaler else input_data
88
-
89
- if model_choice == "Neural Network":
90
- prediction = nn_model.predict(input_data_std) if nn_model else [[0, 1]]
91
- else: # CNN
92
- prediction = cnn_model.predict(input_data_std) if cnn_model else [[0, 1]]
93
-
94
  predicted_class = np.argmax(prediction)
 
95
  return class_names[predicted_class]
96
 
97
  # Gradio UI
98
- with gr.Blocks() as iface:
99
- gr.Markdown("# Breast Cancer Classification")
100
- model_selector = gr.Radio(["ViT", "Neural Network", "CNN"], label="Choose Model")
101
-
102
- with gr.Row():
103
- image_input = gr.Image(type="pil", label="Upload Mammogram Image")
104
- file_input = gr.File(label="Upload Text File (for NN/CNN model)")
105
-
106
- with gr.Column():
107
- feature_inputs = [gr.Number(label=feature) for feature in feature_names]
108
-
109
- submit_button = gr.Button("Classify")
110
- output_text = gr.Textbox(label="Prediction Result")
111
-
112
- submit_button.click(fn=classify, inputs=[model_selector, image_input, file_input] + feature_inputs, outputs=output_text)
 
 
 
 
 
113
 
114
  iface.launch()
 
9
  import joblib
10
  import os
11
 
 
 
 
12
  # Set device
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
15
  # Load trained ViT model (PyTorch)
16
+ vit_model = models.vit_b_16(pretrained=False)
17
  vit_model.heads = torch.nn.Linear(in_features=768, out_features=2) # Binary classification
18
 
19
  # Load ViT model weights
 
34
  class_names = ["Benign", "Malignant"]
35
 
36
  # Load trained Neural Network model (TensorFlow/Keras)
37
+ nn_model_path = "my_NN_BC_model.keras" # Update with uploaded model path
38
  nn_model = tf.keras.models.load_model(nn_model_path) if os.path.exists(nn_model_path) else None
39
 
 
 
 
 
40
  # Load scaler for feature normalization
41
+ scaler_path = "nn_bc_scaler.pkl" # Update with uploaded model path
42
  scaler = joblib.load(scaler_path) if os.path.exists(scaler_path) else None
43
 
44
  # Feature names
 
51
  "Worst Compactness", "Worst Concavity", "Worst Concave Points", "Worst Symmetry", "Worst Fractal Dimension"
52
  ]
53
 
54
+ # Example inputs
55
+ benign_example = [13.54, 14.36, 87.46, 566.3, 0.09779, 0.08129, 0.06664, 0.04781, 0.1885, 0.05766,
56
+ 0.2699, 0.7886, 2.058, 23.56, 0.008462, 0.0146, 0.02387, 0.01315, 0.0198, 0.0023,
57
+ 15.11, 19.26, 99.7, 711.2, 0.144, 0.1773, 0.239, 0.1288, 0.2977, 0.07259]
58
+
59
+ malignant_example = [17.99, 10.38, 122.8, 1001.0, 0.1184, 0.2776, 0.3001, 0.1471, 0.2419, 0.07871,
60
+ 1.095, 0.9053, 8.589, 153.4, 0.006399, 0.04904, 0.05373, 0.01587, 0.03003, 0.006193,
61
+ 25.38, 17.33, 184.6, 2019.0, 0.1622, 0.6656, 0.7119, 0.2654, 0.4601, 0.1189]
62
+
63
+ def classify(model_choice, image=None, *features):
64
+ """Classify using ViT (image) or NN (features)."""
65
  if model_choice == "ViT":
66
  if image is None:
67
  return "Please upload an image for ViT classification."
 
74
 
75
  return class_names[predicted_class]
76
 
77
+ elif model_choice == "Neural Network":
78
+ if any(f is None for f in features):
79
+ return "Please enter all 30 numerical features."
 
 
 
 
 
 
 
80
 
81
  input_data = np.array(features).reshape(1, -1)
82
  input_data_std = scaler.transform(input_data) if scaler else input_data
83
+ prediction = nn_model.predict(input_data_std) if nn_model else [[0, 1]]
 
 
 
 
 
84
  predicted_class = np.argmax(prediction)
85
+
86
  return class_names[predicted_class]
87
 
88
  # Gradio UI
89
+ model_selector = gr.Radio(["ViT", "Neural Network"], label="Choose Model")
90
+ image_input = gr.Image(type="pil", label="Upload Mammogram Image")
91
+ feature_inputs = [gr.Number(label=feature) for feature in feature_names]
92
+
93
+ # Example buttons
94
+ def fill_example(example):
95
+ return {feature_inputs[i]: example[i] for i in range(len(feature_inputs))}
96
+
97
+ example_buttons = gr.Radio(["Benign Example", "Malignant Example"], label="Select Example Input")
98
+ def update_example(choice):
99
+ return fill_example(benign_example if choice == "Benign Example" else malignant_example)
100
+
101
+ iface = gr.Interface(
102
+ fn=classify,
103
+ inputs=[model_selector, image_input, example_buttons] + feature_inputs,
104
+ outputs="text",
105
+ title="Breast Cancer Classification",
106
+ description="Choose between ViT (image-based) and Neural Network (feature-based) classification.",
107
+ live=True
108
+ )
109
 
110
  iface.launch()