Update app.py
app.py CHANGED
@@ -1,4 +1,3 @@
-import os
 import torch
 import torchvision.transforms as transforms
 import torchvision.models as models
@@ -8,43 +7,51 @@ import tensorflow as tf
 from PIL import Image
 from sklearn.preprocessing import StandardScaler
 import joblib
+import os
 
-# Disable TensorFlow
+# Disable GPU for TensorFlow to avoid CUDA conflicts
 os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+
+# Set PyTorch device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# Load ViT model
-vit_model = models.vit_b_16(weights="DEFAULT")
-vit_model.heads = torch.nn.Linear(in_features=768, out_features=2)
+# Load trained ViT model (PyTorch)
+vit_model = models.vit_b_16(weights="DEFAULT")  # Fixed deprecated 'pretrained'
+vit_model.heads = torch.nn.Linear(in_features=768, out_features=2)  # Binary classification
+
+# Load ViT model weights (if available)
 vit_model_path = "vit_bc.pth"
 if os.path.exists(vit_model_path):
     vit_model.load_state_dict(torch.load(vit_model_path, map_location=device))
-vit_model.to(device)
+vit_model.to(device)
+vit_model.eval()
+
+# Define image transformations for ViT
+transform = transforms.Compose([
+    transforms.Resize((224, 224)),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+])
 
-#
+# Class labels
+class_names = ["Benign", "Malignant"]
+
+# Load trained Neural Network model (TensorFlow/Keras)
 nn_model_path = "my_NN_BC_model.keras"
-
+
+nn_model = tf.keras.models.load_model(nn_model_path)
+
 if os.path.exists(nn_model_path):
     try:
         nn_model = tf.keras.models.load_model(nn_model_path)
     except Exception as e:
-        print(f"…
+        print(f"Error loading NN model: {e}")
 
-# Load scaler
+# Load scaler for feature normalization
 scaler_path = "nn_bc_scaler.pkl"
 scaler = joblib.load(scaler_path) if os.path.exists(scaler_path) else None
 
-#
-transform = transforms.Compose([
-    transforms.Resize((224, 224)),
-    transforms.ToTensor(),
-    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-])
-
-# Labels
-class_names = ["Benign", "Malignant"]
-
-# Feature names (30)
+# Feature names
 feature_names = [
     "Mean Radius", "Mean Texture", "Mean Perimeter", "Mean Area", "Mean Smoothness",
    "Mean Compactness", "Mean Concavity", "Mean Concave Points", "Mean Symmetry", "Mean Fractal Dimension",
@@ -54,88 +61,117 @@ feature_names = [
     "Worst Compactness", "Worst Concavity", "Worst Concave Points", "Worst Symmetry", "Worst Fractal Dimension"
 ]
 
-#
-benign_example = [9.504, …
-…
-…
+# Example inputs
+benign_example = [9.504,12.44,60.34,273.9,0.1024,0.06492,0.02956,0.02076,0.1815,0.06905,0.2773,0.9768,
+                  1.909,15.7,0.009606,0.01432,0.01985,0.01421,0.02027,0.002968,10.23,15.66,65.13,314.9,
+                  0.1324,0.1148,0.08867,0.06227,0.245,0.07773]
 
-malignant_example = [11.42, …
-…
-…
+malignant_example = [11.42,20.38,77.58,386.1,0.1425,0.2839,0.2414,0.1052,0.2597,0.09744,0.4956,1.156,
+                     3.445,27.23,0.00911,0.07458,0.05661,0.01867,0.05963,0.009208,14.91,26.5,98.87,567.7,
+                     0.2098,0.8663,0.6869,0.2575,0.6638,0.173]
 
-
-# --- Classification Function ---
 def classify(model_choice, image=None, *features):
+    """Classify using ViT (image) or NN (features)."""
     if model_choice == "ViT":
         if image is None:
-            return "❌ Please upload an image."
+            return "❌ Please upload an image for ViT classification."
         image = image.convert("RGB")
         input_tensor = transform(image).unsqueeze(0).to(device)
+
         with torch.no_grad():
             output = vit_model(input_tensor)
-        pred_class = torch.argmax(output, dim=1).item()
-        return class_names[pred_class]
+        predicted_class = torch.argmax(output, dim=1).item()
+
+        return class_names[predicted_class]
 
     elif model_choice == "Neural Network":
         if any(f is None for f in features):
-            return "❌ …
+            return "❌ Please enter all 30 numerical features."
+
         input_data = np.array(features).reshape(1, -1)
         input_data_std = scaler.transform(input_data) if scaler else input_data
         prediction = nn_model.predict(input_data_std) if nn_model else [[0, 1]]
-        pred_class = np.argmax(prediction)
-        return class_names[pred_class]
+        predicted_class = np.argmax(prediction)
 
-
-def extract_features_from_file(file):
-    try:
-        content = file.read().decode("utf-8").strip()
-        values = [float(x) for x in content.replace(",", " ").split()]
-        if len(values) != 30:
-            raise ValueError("Expected 30 values, got {}".format(len(values)))
-        return values  # must return a list of 30 floats
-    except Exception as e:
-        print(f"Error reading file: {e}")
-        return [0.0] * 30  # return empty values for safety
+        return class_names[predicted_class]
 
-#
+# Gradio UI
 with gr.Blocks() as demo:
-    gr.Markdown("## …
+    gr.Markdown("## 🩺 Breast Cancer Classification Model")
+    gr.Markdown("Select a model and provide input data to classify breast cancer as **Benign** or **Malignant**.")
 
-    …
+    with gr.Row():
+        model_selector = gr.Radio(["ViT", "Neural Network"], label="🔬 Choose Model", value="ViT")
 
-
-    image_input = gr.Image(label="Upload Image", type="pil")
-    # vit_example = gr.Examples(
-    #     examples=["images/benign (1)_aug_0.png", "images/malignant (1)_aug_0.png"],  # Add scan images
-    #     inputs=[image_input],
-    # )
+    image_input = gr.Image(type="pil", label="📷 Upload Image (for ViT)", visible=True)
 
-
-    file_input = gr.File(label="📂 Upload 30-feature TXT file", file_types=[".txt"])
-    feature_inputs = [gr.Number(label=name, visible=True, elem_id=f"feature_{i}") for i, name in enumerate(feature_names)]
+    feature_inputs = [gr.Number(label=feature) for feature in feature_names]
 
-    …
-    …
-    …
+    # Organizing feature inputs into rows of 3 columns
+    with gr.Row():
+        with gr.Column():
+            for i in range(0, len(feature_inputs), 3):
+                gr.Row([feature_inputs[j] for j in range(i, min(i+3, len(feature_inputs)))])
 
-    …
+    # Example buttons
+    def fill_example(example):
+        """Pre-fills example inputs."""
+        return {feature_inputs[i]: example[i] for i in range(len(feature_inputs))}
 
-    …
+    with gr.Row():
+        example_btn_1 = gr.Button("🔴 Malignant Example")
+        example_btn_2 = gr.Button("🔵 Benign Example")
 
-    …
+    output_text = gr.Textbox(label="📊 Model Prediction", interactive=False)
+
+    def extract_features_from_file(file):
+        """Reads a text file and extracts numerical features."""
+        if file is None:
+            return "❌ Please upload a valid feature file."
+
+        try:
+            # Read and process file contents
+            content = file.read().decode("utf-8").strip()
+            values = [float(x) for x in content.replace(",", " ").split()]
+
+            # Check if we have exactly 30 features
+            if len(values) != 30:
+                return "❌ The file must contain exactly 30 numerical values."
+
+            return {feature_inputs[i]: values[i] for i in range(30)}
+
+        except Exception as e:
+            return f"❌ Error processing file: {e}"
+
+    # Add file upload component
+    file_input = gr.File(label="📂 Upload Feature File (for NN)", type="binary", visible=False)
+
+    # Update UI logic to show file input for NN model
     def toggle_inputs(choice):
-        …
-        …
-        …
-    )
+        image_visibility = choice == "ViT"
+        feature_visibility = choice == "Neural Network"
+        file_visibility = choice == "Neural Network"
+        return [gr.update(visible=image_visibility)] + [gr.update(visible=feature_visibility)] * len(feature_inputs) + [gr.update(visible=file_visibility)]
 
-    model_selector.change(toggle_inputs, model_selector, […
+    model_selector.change(toggle_inputs, model_selector, [image_input, *feature_inputs, file_input])
+
+    # Process uploaded file and populate feature fields
+    file_input.change(extract_features_from_file, file_input, feature_inputs)
+
+
+    # Toggle input fields based on model selection
+    """Toggle visibility of inputs based on model selection."""
+    def toggle_inputs(choice):
+        image_visibility = choice == "ViT"
+        feature_visibility = choice == "Neural Network"
+        return [gr.update(visible=image_visibility)] + [gr.update(visible=feature_visibility)] * len(feature_inputs)
 
-    …
-    nn_example_btn_2.click(lambda: benign_example, None, feature_inputs)
+    model_selector.change(toggle_inputs, model_selector, [image_input, *feature_inputs])
 
-    …
+    example_btn_1.click(lambda: fill_example(benign_example), None, feature_inputs)
+    example_btn_2.click(lambda: fill_example(malignant_example), None, feature_inputs)
 
-    …
+    classify_button = gr.Button("🔍 Classify")
+    classify_button.click(classify, [model_selector, image_input] + feature_inputs, output_text)
 
 demo.launch()
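
For reference, the committed extract_features_from_file accepts both comma- and whitespace-separated values and insists on exactly 30 floats. Below is a minimal sketch of producing a compatible input file from scikit-learn's bundled WDBC data, assuming scikit-learn is installed and that the app's 30 features follow the standard WDBC ordering; the file name features.txt is illustrative, not part of the commit.

    # Illustrative helper, not part of app.py: writes one WDBC sample in the
    # comma-separated format that extract_features_from_file parses.
    from sklearn.datasets import load_breast_cancer

    data = load_breast_cancer()
    row = data.data[0]  # 30 features: 10 means, 10 standard errors, 10 "worst" values
    with open("features.txt", "w") as f:
        f.write(",".join(str(v) for v in row))

Uploading such a file through file_input should populate all 30 gr.Number fields, since the handler returns a dict mapping each component to its parsed value.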
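
Both fill_example and the file handler lean on Gradio's dict-return convention: a Blocks event handler may return a dictionary keyed by output components instead of a positional tuple, and Gradio routes each value to the matching component listed in outputs. A self-contained sketch of that pattern (the component names and the two values, taken from benign_example, are illustrative):

    import gradio as gr

    with gr.Blocks() as demo:
        a = gr.Number(label="Mean Radius")
        b = gr.Number(label="Mean Texture")
        fill = gr.Button("Fill example")
        # Returning {component: value} updates exactly the components in `outputs`
        fill.click(lambda: {a: 9.504, b: 12.44}, inputs=None, outputs=[a, b])

    demo.launch()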