Update app.py
app.py
CHANGED
@@ -17,7 +17,7 @@ vit_model = models.vit_b_16(pretrained=False)
 vit_model.heads = torch.nn.Linear(in_features=768, out_features=2)  # Binary classification
 
 # Load ViT model weights
-vit_model_path = "vit_bc.pth"
+vit_model_path = "vit_bc.pth"
 if os.path.exists(vit_model_path):
     vit_model.load_state_dict(torch.load(vit_model_path, map_location=device))
     vit_model.to(device)
@@ -34,11 +34,11 @@ transform = transforms.Compose([
 class_names = ["Benign", "Malignant"]
 
 # Load trained Neural Network model (TensorFlow/Keras)
-nn_model_path = "my_NN_BC_model.keras"
+nn_model_path = "my_NN_BC_model.keras"
 nn_model = tf.keras.models.load_model(nn_model_path) if os.path.exists(nn_model_path) else None
 
 # Load scaler for feature normalization
-scaler_path = "nn_bc_scaler.pkl"
+scaler_path = "nn_bc_scaler.pkl"
 scaler = joblib.load(scaler_path) if os.path.exists(scaler_path) else None
 
 # Feature names
@@ -85,35 +85,26 @@ def classify(model_choice, image=None, *features):
 
     return class_names[predicted_class]
 
-# Gradio UI
-model_selector = gr.Radio(["ViT", "Neural Network"], label="Choose Model")
-image_input = gr.Image(type="pil", label="Upload Image")
-
-# feature_inputs = [gr.Number(label=feature, scale=0.5) for feature in feature_names]
-num_columns = 3  # Change to 4 for a 4-column layout
-
-feature_inputs = []
 # Gradio UI
 with gr.Blocks() as demo:
     gr.Markdown("# Breast Cancer Classification")
     gr.Markdown("Choose between ViT (image-based) and Neural Network (feature-based) classification.")
 
     model_selector = gr.Radio(["ViT", "Neural Network"], label="Choose Model")
-
     image_input = gr.Image(type="pil", label="Upload Image")
 
     # Arrange feature inputs in a matrix layout (3 columns)
-    num_columns = 3
+    num_columns = 3
     feature_inputs = []
 
     with gr.Row():
-        for
-
-
-
+        columns = [gr.Column() for _ in range(num_columns)]
+        for i, feature in enumerate(feature_names):
+            with columns[i % num_columns]:
+                feature_inputs.append(gr.Number(label=feature, scale=1))
 
-    # Example buttons
     def fill_example(example):
+        """Pre-fills example inputs."""
         return {feature_inputs[i]: example[i] for i in range(len(feature_inputs))}
 
     examples = [
@@ -127,6 +118,6 @@ with gr.Blocks() as demo:
         outputs="text",
         examples=examples,
         live=True
-    )
+    ).render()
 
 demo.launch()
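
For reference, the core of this commit is the column-cycling layout for the feature inputs, replacing the truncated "for" left over from the previous revision. Below is a minimal standalone sketch of the same pattern; the short feature_names list is a hypothetical stand-in for the app's real list, which is defined earlier in app.py.

# Minimal sketch of the row/column layout introduced above.
# NOTE: feature_names here is a hypothetical placeholder; app.py
# defines its own, longer list.
import gradio as gr

feature_names = ["radius_mean", "texture_mean", "perimeter_mean", "area_mean"]
num_columns = 3

with gr.Blocks() as demo:
    feature_inputs = []
    with gr.Row():
        # Create one gr.Column per visual column, then re-enter each
        # column round-robin so consecutive inputs fill left to right.
        columns = [gr.Column() for _ in range(num_columns)]
        for i, feature in enumerate(feature_names):
            with columns[i % num_columns]:
                feature_inputs.append(gr.Number(label=feature))

demo.launch()

One caveat: cycling with i % num_columns gives a row-major visual order (inputs 0, 1, 2 sit side by side); appending a whole slice of inputs to each column would give column-major order instead.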
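
A note on the final hunk: gr.Interface is itself a Blocks app, so merely constructing one inside a with gr.Blocks() context does not place it on the page; the added .render() call is what embeds it into the surrounding layout. A minimal sketch, with greet as a placeholder function:

# Embedding a gr.Interface inside a gr.Blocks page via .render(),
# mirroring the ).render() change above. greet() is a placeholder.
import gradio as gr

def greet(name):
    return f"Hello, {name}!"

with gr.Blocks() as demo:
    gr.Markdown("# Wrapper page")
    # Without .render(), the Interface would remain a separate app
    # and nothing would appear inside `demo`.
    gr.Interface(fn=greet, inputs="text", outputs="text").render()

demo.launch()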