Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -9,11 +9,14 @@ from sklearn.preprocessing import StandardScaler
|
|
9 |
import joblib
|
10 |
import os
|
11 |
|
|
|
|
|
|
|
12 |
# Set device
|
13 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
14 |
|
15 |
# Load trained ViT model (PyTorch)
|
16 |
-
vit_model = models.vit_b_16(
|
17 |
vit_model.heads = torch.nn.Linear(in_features=768, out_features=2) # Binary classification
|
18 |
|
19 |
# Load ViT model weights
|
@@ -34,11 +37,15 @@ transform = transforms.Compose([
|
|
34 |
class_names = ["Benign", "Malignant"]
|
35 |
|
36 |
# Load trained Neural Network model (TensorFlow/Keras)
|
37 |
-
nn_model_path = "my_NN_BC_model.keras"
|
38 |
nn_model = tf.keras.models.load_model(nn_model_path) if os.path.exists(nn_model_path) else None
|
39 |
|
|
|
|
|
|
|
|
|
40 |
# Load scaler for feature normalization
|
41 |
-
scaler_path = "nn_bc_scaler.pkl"
|
42 |
scaler = joblib.load(scaler_path) if os.path.exists(scaler_path) else None
|
43 |
|
44 |
# Feature names
|
@@ -52,7 +59,7 @@ feature_names = [
|
|
52 |
]
|
53 |
|
54 |
def classify(model_choice, image=None, file=None, *features):
|
55 |
-
"""Classify using ViT (image) or
|
56 |
if model_choice == "ViT":
|
57 |
if image is None:
|
58 |
return "Please upload an image for ViT classification."
|
@@ -65,7 +72,7 @@ def classify(model_choice, image=None, file=None, *features):
|
|
65 |
|
66 |
return class_names[predicted_class]
|
67 |
|
68 |
-
elif model_choice
|
69 |
if file is not None:
|
70 |
try:
|
71 |
file_data = file.read().decode("utf-8").strip().split(',')
|
@@ -74,42 +81,34 @@ def classify(model_choice, image=None, file=None, *features):
|
|
74 |
return f"Error reading file: {e}"
|
75 |
|
76 |
if len(features) != len(feature_names):
|
77 |
-
return "Please provide
|
78 |
|
79 |
input_data = np.array(features).reshape(1, -1)
|
80 |
input_data_std = scaler.transform(input_data) if scaler else input_data
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
82 |
predicted_class = np.argmax(prediction)
|
83 |
-
|
84 |
return class_names[predicted_class]
|
85 |
|
|
|
86 |
with gr.Blocks() as iface:
|
87 |
gr.Markdown("# Breast Cancer Classification")
|
88 |
-
gr.
|
89 |
-
|
90 |
-
with gr.Row():
|
91 |
-
model_selector = gr.Radio(["ViT", "Neural Network"], label="Choose Model", interactive=True)
|
92 |
-
|
93 |
-
with gr.Row():
|
94 |
-
with gr.Column():
|
95 |
-
image_input = gr.Image(type="pil", label="Upload Image")
|
96 |
-
file_input = gr.File(label="Upload Text File (for NN model)")
|
97 |
-
|
98 |
-
with gr.Row():
|
99 |
-
gr.Markdown("### Enter Features (For Neural Network Model)")
|
100 |
|
101 |
-
num_cols = 3
|
102 |
-
feature_inputs = [gr.Number(label=feature) for feature in feature_names]
|
103 |
-
feature_inputs_split = [feature_inputs[i::num_cols] for i in range(num_cols)]
|
104 |
-
|
105 |
with gr.Row():
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
|
|
|
|
114 |
|
115 |
iface.launch()
|
|
|
9 |
import joblib
|
10 |
import os
|
11 |
|
12 |
+
# Force CPU usage for TensorFlow to avoid CUDA issues
|
13 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
14 |
+
|
15 |
# Set device
|
16 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
17 |
|
18 |
# Load trained ViT model (PyTorch)
|
19 |
+
vit_model = models.vit_b_16(weights=None)
|
20 |
vit_model.heads = torch.nn.Linear(in_features=768, out_features=2) # Binary classification
|
21 |
|
22 |
# Load ViT model weights
|
|
|
37 |
class_names = ["Benign", "Malignant"]
|
38 |
|
39 |
# Load trained Neural Network model (TensorFlow/Keras)
|
40 |
+
nn_model_path = "my_NN_BC_model.keras"
|
41 |
nn_model = tf.keras.models.load_model(nn_model_path) if os.path.exists(nn_model_path) else None
|
42 |
|
43 |
+
# Load trained CNN model (TensorFlow/Keras)
|
44 |
+
cnn_model_path = "my_CNN_BC_model.keras"
|
45 |
+
cnn_model = tf.keras.models.load_model(cnn_model_path) if os.path.exists(cnn_model_path) else None
|
46 |
+
|
47 |
# Load scaler for feature normalization
|
48 |
+
scaler_path = "nn_bc_scaler.pkl"
|
49 |
scaler = joblib.load(scaler_path) if os.path.exists(scaler_path) else None
|
50 |
|
51 |
# Feature names
|
|
|
59 |
]
|
60 |
|
61 |
def classify(model_choice, image=None, file=None, *features):
|
62 |
+
"""Classify using ViT (image), NN, or CNN (features from manual input or file)."""
|
63 |
if model_choice == "ViT":
|
64 |
if image is None:
|
65 |
return "Please upload an image for ViT classification."
|
|
|
72 |
|
73 |
return class_names[predicted_class]
|
74 |
|
75 |
+
elif model_choice in ["Neural Network", "CNN"]:
|
76 |
if file is not None:
|
77 |
try:
|
78 |
file_data = file.read().decode("utf-8").strip().split(',')
|
|
|
81 |
return f"Error reading file: {e}"
|
82 |
|
83 |
if len(features) != len(feature_names):
|
84 |
+
return "Incorrect number of features provided. Please provide exactly 30 numerical values."
|
85 |
|
86 |
input_data = np.array(features).reshape(1, -1)
|
87 |
input_data_std = scaler.transform(input_data) if scaler else input_data
|
88 |
+
|
89 |
+
if model_choice == "Neural Network":
|
90 |
+
prediction = nn_model.predict(input_data_std) if nn_model else [[0, 1]]
|
91 |
+
else: # CNN
|
92 |
+
prediction = cnn_model.predict(input_data_std) if cnn_model else [[0, 1]]
|
93 |
+
|
94 |
predicted_class = np.argmax(prediction)
|
|
|
95 |
return class_names[predicted_class]
|
96 |
|
97 |
+
# Gradio UI
|
98 |
with gr.Blocks() as iface:
|
99 |
gr.Markdown("# Breast Cancer Classification")
|
100 |
+
model_selector = gr.Radio(["ViT", "Neural Network", "CNN"], label="Choose Model")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
|
|
|
|
|
|
|
|
|
102 |
with gr.Row():
|
103 |
+
image_input = gr.Image(type="pil", label="Upload Mammogram Image")
|
104 |
+
file_input = gr.File(label="Upload Text File (for NN/CNN model)")
|
105 |
+
|
106 |
+
with gr.Column():
|
107 |
+
feature_inputs = [gr.Number(label=feature) for feature in feature_names]
|
108 |
+
|
109 |
+
submit_button = gr.Button("Classify")
|
110 |
+
output_text = gr.Textbox(label="Prediction Result")
|
111 |
+
|
112 |
+
submit_button.click(fn=classify, inputs=[model_selector, image_input, file_input] + feature_inputs, outputs=output_text)
|
113 |
|
114 |
iface.launch()
|