andromeda01111 commited on
Commit
db0a0ae
·
verified ·
1 Parent(s): a3a28f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -19
app.py CHANGED
@@ -17,7 +17,7 @@ vit_model = models.vit_b_16(pretrained=False)
17
  vit_model.heads = torch.nn.Linear(in_features=768, out_features=2) # Binary classification
18
 
19
  # Load ViT model weights
20
- vit_model_path = "vit_bc.pth"
21
  if os.path.exists(vit_model_path):
22
  vit_model.load_state_dict(torch.load(vit_model_path, map_location=device))
23
  vit_model.to(device)
@@ -34,11 +34,11 @@ transform = transforms.Compose([
34
  class_names = ["Benign", "Malignant"]
35
 
36
  # Load trained Neural Network model (TensorFlow/Keras)
37
- nn_model_path = "my_NN_BC_model.keras" # Update with uploaded model path
38
  nn_model = tf.keras.models.load_model(nn_model_path) if os.path.exists(nn_model_path) else None
39
 
40
  # Load scaler for feature normalization
41
- scaler_path = "nn_bc_scaler.pkl" # Update with uploaded model path
42
  scaler = joblib.load(scaler_path) if os.path.exists(scaler_path) else None
43
 
44
  # Feature names
@@ -85,35 +85,26 @@ def classify(model_choice, image=None, *features):
85
 
86
  return class_names[predicted_class]
87
 
88
- # Gradio UI
89
- model_selector = gr.Radio(["ViT", "Neural Network"], label="Choose Model")
90
- image_input = gr.Image(type="pil", label="Upload Image")
91
-
92
- # feature_inputs = [gr.Number(label=feature, scale=0.5) for feature in feature_names]
93
- num_columns = 3 # Change to 4 for a 4-column layout
94
-
95
- feature_inputs = []
96
  # Gradio UI
97
  with gr.Blocks() as demo:
98
  gr.Markdown("# Breast Cancer Classification")
99
  gr.Markdown("Choose between ViT (image-based) and Neural Network (feature-based) classification.")
100
 
101
  model_selector = gr.Radio(["ViT", "Neural Network"], label="Choose Model")
102
-
103
  image_input = gr.Image(type="pil", label="Upload Image")
104
 
105
  # Arrange feature inputs in a matrix layout (3 columns)
106
- num_columns = 3 # Change to 4 for a 4-column layout
107
  feature_inputs = []
108
 
109
  with gr.Row():
110
- for i in range(0, len(feature_names), num_columns):
111
- with gr.Column():
112
- for feature in feature_names[i:i+num_columns]:
113
- feature_inputs.append(gr.Number(label=feature, scale=1))
114
 
115
- # Example buttons
116
  def fill_example(example):
 
117
  return {feature_inputs[i]: example[i] for i in range(len(feature_inputs))}
118
 
119
  examples = [
@@ -127,6 +118,6 @@ with gr.Blocks() as demo:
127
  outputs="text",
128
  examples=examples,
129
  live=True
130
- )
131
 
132
  demo.launch()
 
17
  vit_model.heads = torch.nn.Linear(in_features=768, out_features=2) # Binary classification
18
 
19
  # Load ViT model weights
20
+ vit_model_path = "vit_bc.pth"
21
  if os.path.exists(vit_model_path):
22
  vit_model.load_state_dict(torch.load(vit_model_path, map_location=device))
23
  vit_model.to(device)
 
34
  class_names = ["Benign", "Malignant"]
35
 
36
  # Load trained Neural Network model (TensorFlow/Keras)
37
+ nn_model_path = "my_NN_BC_model.keras"
38
  nn_model = tf.keras.models.load_model(nn_model_path) if os.path.exists(nn_model_path) else None
39
 
40
  # Load scaler for feature normalization
41
+ scaler_path = "nn_bc_scaler.pkl"
42
  scaler = joblib.load(scaler_path) if os.path.exists(scaler_path) else None
43
 
44
  # Feature names
 
85
 
86
  return class_names[predicted_class]
87
 
 
 
 
 
 
 
 
 
88
  # Gradio UI
89
  with gr.Blocks() as demo:
90
  gr.Markdown("# Breast Cancer Classification")
91
  gr.Markdown("Choose between ViT (image-based) and Neural Network (feature-based) classification.")
92
 
93
  model_selector = gr.Radio(["ViT", "Neural Network"], label="Choose Model")
 
94
  image_input = gr.Image(type="pil", label="Upload Image")
95
 
96
  # Arrange feature inputs in a matrix layout (3 columns)
97
+ num_columns = 3
98
  feature_inputs = []
99
 
100
  with gr.Row():
101
+ columns = [gr.Column() for _ in range(num_columns)]
102
+ for i, feature in enumerate(feature_names):
103
+ with columns[i % num_columns]:
104
+ feature_inputs.append(gr.Number(label=feature, scale=1))
105
 
 
106
  def fill_example(example):
107
+ """Pre-fills example inputs."""
108
  return {feature_inputs[i]: example[i] for i in range(len(feature_inputs))}
109
 
110
  examples = [
 
118
  outputs="text",
119
  examples=examples,
120
  live=True
121
+ ).render()
122
 
123
  demo.launch()