Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -17,7 +17,7 @@ vit_model = models.vit_b_16(pretrained=False)
|
|
17 |
vit_model.heads = torch.nn.Linear(in_features=768, out_features=2) # Binary classification
|
18 |
|
19 |
# Load ViT model weights
|
20 |
-
vit_model_path = "vit_bc.pth"
|
21 |
if os.path.exists(vit_model_path):
|
22 |
vit_model.load_state_dict(torch.load(vit_model_path, map_location=device))
|
23 |
vit_model.to(device)
|
@@ -87,35 +87,46 @@ def classify(model_choice, image=None, *features):
|
|
87 |
|
88 |
# Gradio UI
|
89 |
model_selector = gr.Radio(["ViT", "Neural Network"], label="Choose Model")
|
90 |
-
image_input = gr.Image(type="pil", label="Upload
|
91 |
|
92 |
# feature_inputs = [gr.Number(label=feature, scale=0.5) for feature in feature_names]
|
93 |
num_columns = 3 # Change to 4 for a 4-column layout
|
94 |
|
95 |
feature_inputs = []
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
vit_model.heads = torch.nn.Linear(in_features=768, out_features=2) # Binary classification
|
18 |
|
19 |
# Load ViT model weights
|
20 |
+
vit_model_path = "vit_bc.pth"
|
21 |
if os.path.exists(vit_model_path):
|
22 |
vit_model.load_state_dict(torch.load(vit_model_path, map_location=device))
|
23 |
vit_model.to(device)
|
|
|
87 |
|
88 |
# Gradio UI
|
89 |
model_selector = gr.Radio(["ViT", "Neural Network"], label="Choose Model")
|
90 |
+
image_input = gr.Image(type="pil", label="Upload Image")
|
91 |
|
92 |
# feature_inputs = [gr.Number(label=feature, scale=0.5) for feature in feature_names]
|
93 |
num_columns = 3 # Change to 4 for a 4-column layout
|
94 |
|
95 |
feature_inputs = []
|
96 |
+
# Gradio UI
|
97 |
+
with gr.Blocks() as demo:
|
98 |
+
gr.Markdown("# Breast Cancer Classification")
|
99 |
+
gr.Markdown("Choose between ViT (image-based) and Neural Network (feature-based) classification.")
|
100 |
+
|
101 |
+
model_selector = gr.Radio(["ViT", "Neural Network"], label="Choose Model")
|
102 |
+
|
103 |
+
image_input = gr.Image(type="pil", label="Upload Image")
|
104 |
+
|
105 |
+
# Arrange feature inputs in a matrix layout (3 columns)
|
106 |
+
num_columns = 3 # Change to 4 for a 4-column layout
|
107 |
+
feature_inputs = []
|
108 |
+
|
109 |
+
with gr.Row():
|
110 |
+
for i in range(0, len(feature_names), num_columns):
|
111 |
+
with gr.Column():
|
112 |
+
for feature in feature_names[i:i+num_columns]:
|
113 |
+
feature_inputs.append(gr.Number(label=feature, scale=1))
|
114 |
+
|
115 |
+
# Example buttons
|
116 |
+
def fill_example(example):
|
117 |
+
return {feature_inputs[i]: example[i] for i in range(len(feature_inputs))}
|
118 |
+
|
119 |
+
examples = [
|
120 |
+
["Neural Network", None] + benign_example,
|
121 |
+
["Neural Network", None] + malignant_example
|
122 |
+
]
|
123 |
+
|
124 |
+
gr.Interface(
|
125 |
+
fn=classify,
|
126 |
+
inputs=[model_selector, image_input] + feature_inputs,
|
127 |
+
outputs="text",
|
128 |
+
examples=examples,
|
129 |
+
live=True
|
130 |
+
)
|
131 |
+
|
132 |
+
demo.launch()
|