Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -9,14 +9,11 @@ from sklearn.preprocessing import StandardScaler
|
|
9 |
import joblib
|
10 |
import os
|
11 |
|
12 |
-
# Force CPU usage for TensorFlow to avoid CUDA issues
|
13 |
-
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
14 |
-
|
15 |
# Set device
|
16 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
17 |
|
18 |
# Load trained ViT model (PyTorch)
|
19 |
-
vit_model = models.vit_b_16(
|
20 |
vit_model.heads = torch.nn.Linear(in_features=768, out_features=2) # Binary classification
|
21 |
|
22 |
# Load ViT model weights
|
@@ -37,15 +34,11 @@ transform = transforms.Compose([
|
|
37 |
class_names = ["Benign", "Malignant"]
|
38 |
|
39 |
# Load trained Neural Network model (TensorFlow/Keras)
|
40 |
-
nn_model_path = "my_NN_BC_model.keras"
|
41 |
nn_model = tf.keras.models.load_model(nn_model_path) if os.path.exists(nn_model_path) else None
|
42 |
|
43 |
-
# Load trained CNN model (TensorFlow/Keras)
|
44 |
-
cnn_model_path = "my_CNN_BC_model.keras"
|
45 |
-
cnn_model = tf.keras.models.load_model(cnn_model_path) if os.path.exists(cnn_model_path) else None
|
46 |
-
|
47 |
# Load scaler for feature normalization
|
48 |
-
scaler_path = "nn_bc_scaler.pkl"
|
49 |
scaler = joblib.load(scaler_path) if os.path.exists(scaler_path) else None
|
50 |
|
51 |
# Feature names
|
@@ -58,8 +51,17 @@ feature_names = [
|
|
58 |
"Worst Compactness", "Worst Concavity", "Worst Concave Points", "Worst Symmetry", "Worst Fractal Dimension"
|
59 |
]
|
60 |
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
if model_choice == "ViT":
|
64 |
if image is None:
|
65 |
return "Please upload an image for ViT classification."
|
@@ -72,43 +74,37 @@ def classify(model_choice, image=None, file=None, *features):
|
|
72 |
|
73 |
return class_names[predicted_class]
|
74 |
|
75 |
-
elif model_choice
|
76 |
-
if
|
77 |
-
|
78 |
-
file_data = file.read().decode("utf-8").strip().split(',')
|
79 |
-
features = [float(x) for x in file_data]
|
80 |
-
except Exception as e:
|
81 |
-
return f"Error reading file: {e}"
|
82 |
-
|
83 |
-
if len(features) != len(feature_names):
|
84 |
-
return "Incorrect number of features provided. Please provide exactly 30 numerical values."
|
85 |
|
86 |
input_data = np.array(features).reshape(1, -1)
|
87 |
input_data_std = scaler.transform(input_data) if scaler else input_data
|
88 |
-
|
89 |
-
if model_choice == "Neural Network":
|
90 |
-
prediction = nn_model.predict(input_data_std) if nn_model else [[0, 1]]
|
91 |
-
else: # CNN
|
92 |
-
prediction = cnn_model.predict(input_data_std) if cnn_model else [[0, 1]]
|
93 |
-
|
94 |
predicted_class = np.argmax(prediction)
|
|
|
95 |
return class_names[predicted_class]
|
96 |
|
97 |
# Gradio UI
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
iface.launch()
|
|
|
9 |
import joblib
|
10 |
import os
|
11 |
|
|
|
|
|
|
|
12 |
# Set device
|
13 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
14 |
|
15 |
# Load trained ViT model (PyTorch)
|
16 |
+
vit_model = models.vit_b_16(pretrained=False)
|
17 |
vit_model.heads = torch.nn.Linear(in_features=768, out_features=2) # Binary classification
|
18 |
|
19 |
# Load ViT model weights
|
|
|
34 |
class_names = ["Benign", "Malignant"]
|
35 |
|
36 |
# Load trained Neural Network model (TensorFlow/Keras)
|
37 |
+
nn_model_path = "my_NN_BC_model.keras" # Update with uploaded model path
|
38 |
nn_model = tf.keras.models.load_model(nn_model_path) if os.path.exists(nn_model_path) else None
|
39 |
|
|
|
|
|
|
|
|
|
40 |
# Load scaler for feature normalization
|
41 |
+
scaler_path = "nn_bc_scaler.pkl" # Update with uploaded model path
|
42 |
scaler = joblib.load(scaler_path) if os.path.exists(scaler_path) else None
|
43 |
|
44 |
# Feature names
|
|
|
51 |
"Worst Compactness", "Worst Concavity", "Worst Concave Points", "Worst Symmetry", "Worst Fractal Dimension"
|
52 |
]
|
53 |
|
54 |
+
# Example inputs
|
55 |
+
benign_example = [13.54, 14.36, 87.46, 566.3, 0.09779, 0.08129, 0.06664, 0.04781, 0.1885, 0.05766,
|
56 |
+
0.2699, 0.7886, 2.058, 23.56, 0.008462, 0.0146, 0.02387, 0.01315, 0.0198, 0.0023,
|
57 |
+
15.11, 19.26, 99.7, 711.2, 0.144, 0.1773, 0.239, 0.1288, 0.2977, 0.07259]
|
58 |
+
|
59 |
+
malignant_example = [17.99, 10.38, 122.8, 1001.0, 0.1184, 0.2776, 0.3001, 0.1471, 0.2419, 0.07871,
|
60 |
+
1.095, 0.9053, 8.589, 153.4, 0.006399, 0.04904, 0.05373, 0.01587, 0.03003, 0.006193,
|
61 |
+
25.38, 17.33, 184.6, 2019.0, 0.1622, 0.6656, 0.7119, 0.2654, 0.4601, 0.1189]
|
62 |
+
|
63 |
+
def classify(model_choice, image=None, *features):
|
64 |
+
"""Classify using ViT (image) or NN (features)."""
|
65 |
if model_choice == "ViT":
|
66 |
if image is None:
|
67 |
return "Please upload an image for ViT classification."
|
|
|
74 |
|
75 |
return class_names[predicted_class]
|
76 |
|
77 |
+
elif model_choice == "Neural Network":
|
78 |
+
if any(f is None for f in features):
|
79 |
+
return "Please enter all 30 numerical features."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
input_data = np.array(features).reshape(1, -1)
|
82 |
input_data_std = scaler.transform(input_data) if scaler else input_data
|
83 |
+
prediction = nn_model.predict(input_data_std) if nn_model else [[0, 1]]
|
|
|
|
|
|
|
|
|
|
|
84 |
predicted_class = np.argmax(prediction)
|
85 |
+
|
86 |
return class_names[predicted_class]
|
87 |
|
88 |
# Gradio UI
|
89 |
+
model_selector = gr.Radio(["ViT", "Neural Network"], label="Choose Model")
|
90 |
+
image_input = gr.Image(type="pil", label="Upload Mammogram Image")
|
91 |
+
feature_inputs = [gr.Number(label=feature) for feature in feature_names]
|
92 |
+
|
93 |
+
# Example buttons
|
94 |
+
def fill_example(example):
|
95 |
+
return {feature_inputs[i]: example[i] for i in range(len(feature_inputs))}
|
96 |
+
|
97 |
+
example_buttons = gr.Radio(["Benign Example", "Malignant Example"], label="Select Example Input")
|
98 |
+
def update_example(choice):
|
99 |
+
return fill_example(benign_example if choice == "Benign Example" else malignant_example)
|
100 |
+
|
101 |
+
iface = gr.Interface(
|
102 |
+
fn=classify,
|
103 |
+
inputs=[model_selector, image_input, example_buttons] + feature_inputs,
|
104 |
+
outputs="text",
|
105 |
+
title="Breast Cancer Classification",
|
106 |
+
description="Choose between ViT (image-based) and Neural Network (feature-based) classification.",
|
107 |
+
live=True
|
108 |
+
)
|
109 |
|
110 |
iface.launch()
|