andromeda01111 commited on
Commit
9e95ef5
Β·
verified Β·
1 Parent(s): 398f0a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -74
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import os
2
  import torch
3
  import torchvision.transforms as transforms
4
  import torchvision.models as models
@@ -8,43 +7,51 @@ import tensorflow as tf
8
  from PIL import Image
9
  from sklearn.preprocessing import StandardScaler
10
  import joblib
 
11
 
12
- # Disable TensorFlow GPU to avoid CUDA conflicts
13
  os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
 
 
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
- # Load ViT model
17
- vit_model = models.vit_b_16(weights="DEFAULT")
18
- vit_model.heads = torch.nn.Linear(in_features=768, out_features=2)
 
 
19
  vit_model_path = "vit_bc.pth"
20
  if os.path.exists(vit_model_path):
21
  vit_model.load_state_dict(torch.load(vit_model_path, map_location=device))
22
- vit_model.to(device).eval()
 
 
 
 
 
 
 
 
23
 
24
- # Load Neural Net model
 
 
 
25
  nn_model_path = "my_NN_BC_model.keras"
26
- nn_model = None
 
 
27
  if os.path.exists(nn_model_path):
28
  try:
29
  nn_model = tf.keras.models.load_model(nn_model_path)
30
  except Exception as e:
31
- print(f"⚠️ Error loading NN model: {e}")
32
 
33
- # Load scaler
34
  scaler_path = "nn_bc_scaler.pkl"
35
  scaler = joblib.load(scaler_path) if os.path.exists(scaler_path) else None
36
 
37
- # Transforms
38
- transform = transforms.Compose([
39
- transforms.Resize((224, 224)),
40
- transforms.ToTensor(),
41
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
42
- ])
43
-
44
- # Labels
45
- class_names = ["Benign", "Malignant"]
46
-
47
- # Feature names (30)
48
  feature_names = [
49
  "Mean Radius", "Mean Texture", "Mean Perimeter", "Mean Area", "Mean Smoothness",
50
  "Mean Compactness", "Mean Concavity", "Mean Concave Points", "Mean Symmetry", "Mean Fractal Dimension",
@@ -54,88 +61,117 @@ feature_names = [
54
  "Worst Compactness", "Worst Concavity", "Worst Concave Points", "Worst Symmetry", "Worst Fractal Dimension"
55
  ]
56
 
57
- # Sample Inputs
58
- benign_example = [9.504, 12.44, 60.34, 273.9, 0.1024, 0.06492, 0.02956, 0.02076, 0.1815, 0.06905,
59
- 0.2773, 0.9768, 1.909, 15.7, 0.009606, 0.01432, 0.01985, 0.01421, 0.02027, 0.002968,
60
- 10.23, 15.66, 65.13, 314.9, 0.1324, 0.1148, 0.08867, 0.06227, 0.245, 0.07773]
61
 
62
- malignant_example = [11.42, 20.38, 77.58, 386.1, 0.1425, 0.2839, 0.2414, 0.1052, 0.2597, 0.09744,
63
- 0.4956, 1.156, 3.445, 27.23, 0.00911, 0.07458, 0.05661, 0.01867, 0.05963, 0.009208,
64
- 14.91, 26.5, 98.87, 567.7, 0.2098, 0.8663, 0.6869, 0.2575, 0.6638, 0.173]
65
 
66
-
67
- # --- Classification Function ---
68
  def classify(model_choice, image=None, *features):
 
69
  if model_choice == "ViT":
70
  if image is None:
71
- return "❌ Please upload an image."
72
  image = image.convert("RGB")
73
  input_tensor = transform(image).unsqueeze(0).to(device)
 
74
  with torch.no_grad():
75
  output = vit_model(input_tensor)
76
- pred_class = torch.argmax(output, dim=1).item()
77
- return class_names[pred_class]
 
78
 
79
  elif model_choice == "Neural Network":
80
  if any(f is None for f in features):
81
- return "❌ All 30 numerical features are required."
 
82
  input_data = np.array(features).reshape(1, -1)
83
  input_data_std = scaler.transform(input_data) if scaler else input_data
84
  prediction = nn_model.predict(input_data_std) if nn_model else [[0, 1]]
85
- pred_class = np.argmax(prediction)
86
- return class_names[pred_class]
87
 
88
- # --- File parser for NN features ---
89
- def extract_features_from_file(file):
90
- try:
91
- content = file.read().decode("utf-8").strip()
92
- values = [float(x) for x in content.replace(",", " ").split()]
93
- if len(values) != 30:
94
- raise ValueError("Expected 30 values, got {}".format(len(values)))
95
- return values # must return a list of 30 floats
96
- except Exception as e:
97
- print(f"Error reading file: {e}")
98
- return [0.0] * 30 # return empty values for safety
99
 
100
- # --- UI with Gradio ---
101
  with gr.Blocks() as demo:
102
- gr.Markdown("## 🧠 Breast Cancer Classifier (ViT & Neural Net)")
 
103
 
104
- model_selector = gr.Radio(["ViT", "Neural Network"], label="Choose Model", value="ViT")
 
105
 
106
- with gr.Column(visible=True) as vit_section:
107
- image_input = gr.Image(label="Upload Image", type="pil")
108
- # vit_example = gr.Examples(
109
- # examples=["images/benign (1)_aug_0.png", "images/malignant (1)_aug_0.png"], # Add scan images
110
- # inputs=[image_input],
111
- # )
112
 
113
- with gr.Column(visible=False) as nn_section:
114
- file_input = gr.File(label="πŸ“‚ Upload 30-feature TXT file", file_types=[".txt"])
115
- feature_inputs = [gr.Number(label=name, visible=True, elem_id=f"feature_{i}") for i, name in enumerate(feature_names)]
116
 
117
- with gr.Row():
118
- nn_example_btn_1 = gr.Button("πŸ”΄ Malignant Example")
119
- nn_example_btn_2 = gr.Button("πŸ”΅ Benign Example")
 
 
120
 
121
- output = gr.Textbox(label="Prediction")
 
 
 
122
 
123
- classify_btn = gr.Button("πŸš€ Classify")
 
 
124
 
125
- # Event logic
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  def toggle_inputs(choice):
127
- return (
128
- gr.update(visible=choice == "ViT"),
129
- gr.update(visible=choice == "Neural Network"),
130
- )
131
 
132
- model_selector.change(toggle_inputs, model_selector, [vit_section, nn_section])
 
 
 
 
 
 
 
 
 
 
 
133
 
134
- nn_example_btn_1.click(lambda: malignant_example, None, feature_inputs)
135
- nn_example_btn_2.click(lambda: benign_example, None, feature_inputs)
136
 
137
- file_input.change(extract_features_from_file, file_input, outputs=feature_inputs)
 
138
 
139
- classify_btn.click(classify, inputs=[model_selector, image_input] + feature_inputs, outputs=output)
 
140
 
141
  demo.launch()
 
 
1
  import torch
2
  import torchvision.transforms as transforms
3
  import torchvision.models as models
 
7
  from PIL import Image
8
  from sklearn.preprocessing import StandardScaler
9
  import joblib
10
+ import os
11
 
12
+ # Disable GPU for TensorFlow to avoid CUDA conflicts
13
  os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
14
+
15
+ # Set PyTorch device
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
18
+ # Load trained ViT model (PyTorch)
19
+ vit_model = models.vit_b_16(weights="DEFAULT") # Fixed deprecated 'pretrained'
20
+ vit_model.heads = torch.nn.Linear(in_features=768, out_features=2) # Binary classification
21
+
22
+ # Load ViT model weights (if available)
23
  vit_model_path = "vit_bc.pth"
24
  if os.path.exists(vit_model_path):
25
  vit_model.load_state_dict(torch.load(vit_model_path, map_location=device))
26
+ vit_model.to(device)
27
+ vit_model.eval()
28
+
29
+ # Define image transformations for ViT
30
+ transform = transforms.Compose([
31
+ transforms.Resize((224, 224)),
32
+ transforms.ToTensor(),
33
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
34
+ ])
35
 
36
+ # Class labels
37
+ class_names = ["Benign", "Malignant"]
38
+
39
+ # Load trained Neural Network model (TensorFlow/Keras)
40
  nn_model_path = "my_NN_BC_model.keras"
41
+
42
+ nn_model = tf.keras.models.load_model(nn_model_path)
43
+
44
  if os.path.exists(nn_model_path):
45
  try:
46
  nn_model = tf.keras.models.load_model(nn_model_path)
47
  except Exception as e:
48
+ print(f"Error loading NN model: {e}")
49
 
50
+ # Load scaler for feature normalization
51
  scaler_path = "nn_bc_scaler.pkl"
52
  scaler = joblib.load(scaler_path) if os.path.exists(scaler_path) else None
53
 
54
+ # Feature names
 
 
 
 
 
 
 
 
 
 
55
  feature_names = [
56
  "Mean Radius", "Mean Texture", "Mean Perimeter", "Mean Area", "Mean Smoothness",
57
  "Mean Compactness", "Mean Concavity", "Mean Concave Points", "Mean Symmetry", "Mean Fractal Dimension",
 
61
  "Worst Compactness", "Worst Concavity", "Worst Concave Points", "Worst Symmetry", "Worst Fractal Dimension"
62
  ]
63
 
64
+ # Example inputs
65
+ benign_example = [9.504,12.44,60.34,273.9,0.1024,0.06492,0.02956,0.02076,0.1815,0.06905,0.2773,0.9768,
66
+ 1.909,15.7,0.009606,0.01432,0.01985,0.01421,0.02027,0.002968,10.23,15.66,65.13,314.9,
67
+ 0.1324,0.1148,0.08867,0.06227,0.245,0.07773]
68
 
69
+ malignant_example = [11.42,20.38,77.58,386.1,0.1425,0.2839,0.2414,0.1052,0.2597,0.09744,0.4956,1.156,
70
+ 3.445,27.23,0.00911,0.07458,0.05661,0.01867,0.05963,0.009208,14.91,26.5,98.87,567.7,
71
+ 0.2098,0.8663,0.6869,0.2575,0.6638,0.173]
72
 
 
 
73
  def classify(model_choice, image=None, *features):
74
+ """Classify using ViT (image) or NN (features)."""
75
  if model_choice == "ViT":
76
  if image is None:
77
+ return "❌ Please upload an image for ViT classification."
78
  image = image.convert("RGB")
79
  input_tensor = transform(image).unsqueeze(0).to(device)
80
+
81
  with torch.no_grad():
82
  output = vit_model(input_tensor)
83
+ predicted_class = torch.argmax(output, dim=1).item()
84
+
85
+ return class_names[predicted_class]
86
 
87
  elif model_choice == "Neural Network":
88
  if any(f is None for f in features):
89
+ return "❌ Please enter all 30 numerical features."
90
+
91
  input_data = np.array(features).reshape(1, -1)
92
  input_data_std = scaler.transform(input_data) if scaler else input_data
93
  prediction = nn_model.predict(input_data_std) if nn_model else [[0, 1]]
94
+ predicted_class = np.argmax(prediction)
 
95
 
96
+ return class_names[predicted_class]
 
 
 
 
 
 
 
 
 
 
97
 
98
+ # Gradio UI
99
  with gr.Blocks() as demo:
100
+ gr.Markdown("## 🩺 Breast Cancer Classification Model")
101
+ gr.Markdown("Select a model and provide input data to classify breast cancer as **Benign** or **Malignant**.")
102
 
103
+ with gr.Row():
104
+ model_selector = gr.Radio(["ViT", "Neural Network"], label="πŸ”¬ Choose Model", value="ViT")
105
 
106
+ image_input = gr.Image(type="pil", label="πŸ“· Upload Image (for ViT)", visible=True)
 
 
 
 
 
107
 
108
+ feature_inputs = [gr.Number(label=feature) for feature in feature_names]
 
 
109
 
110
+ # Organizing feature inputs into rows of 3 columns
111
+ with gr.Row():
112
+ with gr.Column():
113
+ for i in range(0, len(feature_inputs), 3):
114
+ gr.Row([feature_inputs[j] for j in range(i, min(i+3, len(feature_inputs)))])
115
 
116
+ # Example buttons
117
+ def fill_example(example):
118
+ """Pre-fills example inputs."""
119
+ return {feature_inputs[i]: example[i] for i in range(len(feature_inputs))}
120
 
121
+ with gr.Row():
122
+ example_btn_1 = gr.Button("πŸ”΄ Malignant Example")
123
+ example_btn_2 = gr.Button("πŸ”΅ Benign Example")
124
 
125
+ output_text = gr.Textbox(label="πŸ” Model Prediction", interactive=False)
126
+
127
+ def extract_features_from_file(file):
128
+ """Reads a text file and extracts numerical features."""
129
+ if file is None:
130
+ return "❌ Please upload a valid feature file."
131
+
132
+ try:
133
+ # Read and process file contents
134
+ content = file.read().decode("utf-8").strip()
135
+ values = [float(x) for x in content.replace(",", " ").split()]
136
+
137
+ # Check if we have exactly 30 features
138
+ if len(values) != 30:
139
+ return "❌ The file must contain exactly 30 numerical values."
140
+
141
+ return {feature_inputs[i]: values[i] for i in range(30)}
142
+
143
+ except Exception as e:
144
+ return f"❌ Error processing file: {e}"
145
+
146
+ # Add file upload component
147
+ file_input = gr.File(label="πŸ“‚ Upload Feature File (for NN)", type="binary", visible=False)
148
+
149
+ # Update UI logic to show file input for NN model
150
  def toggle_inputs(choice):
151
+ image_visibility = choice == "ViT"
152
+ feature_visibility = choice == "Neural Network"
153
+ file_visibility = choice == "Neural Network"
154
+ return [gr.update(visible=image_visibility)] + [gr.update(visible=feature_visibility)] * len(feature_inputs) + [gr.update(visible=file_visibility)]
155
 
156
+ model_selector.change(toggle_inputs, model_selector, [image_input, *feature_inputs, file_input])
157
+
158
+ # Process uploaded file and populate feature fields
159
+ file_input.change(extract_features_from_file, file_input, feature_inputs)
160
+
161
+
162
+ # Toggle input fields based on model selection
163
+ """Toggle visibility of inputs based on model selection."""
164
+ def toggle_inputs(choice):
165
+ image_visibility = choice == "ViT"
166
+ feature_visibility = choice == "Neural Network"
167
+ return [gr.update(visible=image_visibility)] + [gr.update(visible=feature_visibility)] * len(feature_inputs)
168
 
169
+ model_selector.change(toggle_inputs, model_selector, [image_input, *feature_inputs])
 
170
 
171
+ example_btn_1.click(lambda: fill_example(benign_example), None, feature_inputs)
172
+ example_btn_2.click(lambda: fill_example(malignant_example), None, feature_inputs)
173
 
174
+ classify_button = gr.Button("πŸš€ Classify")
175
+ classify_button.click(classify, [model_selector, image_input] + feature_inputs, output_text)
176
 
177
  demo.launch()