ArchCoder commited on
Commit
71461a8
Β·
verified Β·
1 Parent(s): fbfd1a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -92
app.py CHANGED
@@ -8,32 +8,35 @@ import io
8
  import base64
9
  from torchvision import transforms
10
  import torch.nn.functional as F
 
 
 
 
 
 
11
 
12
- # Load the pretrained model
13
- @gr.utils.cache
14
  def load_model():
15
  """Load the pretrained brain segmentation model"""
16
- try:
17
- model = torch.hub.load(
18
- 'mateuszbuda/brain-segmentation-pytorch',
19
- 'unet',
20
- in_channels=3,
21
- out_channels=1,
22
- init_features=32,
23
- pretrained=True,
24
- force_reload=False
25
- )
26
- model.eval()
27
- return model
28
- except Exception as e:
29
- print(f"Error loading model: {e}")
30
- return None
31
-
32
- # Initialize model
33
- model = load_model()
34
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
35
- if model:
36
- model = model.to(device)
37
 
38
  def preprocess_image(image):
39
  """Preprocess the input image for the model"""
@@ -45,7 +48,12 @@ def preprocess_image(image):
45
  image = image.convert('RGB')
46
 
47
  # Resize to 256x256 (model's expected input size)
48
- image = image.resize((256, 256), Image.Resampling.LANCZOS)
 
 
 
 
 
49
 
50
  # Convert to tensor and normalize
51
  transform = transforms.Compose([
@@ -73,20 +81,24 @@ def create_overlay_visualization(original_img, mask, alpha=0.6):
73
 
74
  def predict_tumor(image):
75
  """Main prediction function"""
76
- if model is None:
77
- return None, "❌ Model failed to load. Please try again."
 
 
 
78
 
79
  if image is None:
80
  return None, "⚠️ Please upload an image first."
81
 
82
  try:
 
83
  # Preprocess the image
84
  input_tensor, original_img = preprocess_image(image)
85
  input_tensor = input_tensor.to(device)
86
 
87
  # Make prediction
88
  with torch.no_grad():
89
- prediction = model(input_tensor)
90
  # Apply sigmoid to get probability map
91
  prediction = torch.sigmoid(prediction)
92
  # Convert to numpy
@@ -109,24 +121,25 @@ def predict_tumor(image):
109
 
110
  # 4. Side-by-side comparison
111
  fig, axes = plt.subplots(1, 3, figsize=(15, 5))
 
112
 
113
  axes[0].imshow(original_array)
114
- axes[0].set_title('Original Image', fontsize=14, fontweight='bold')
115
  axes[0].axis('off')
116
 
117
  axes[1].imshow(mask_colored)
118
- axes[1].set_title('Tumor Segmentation', fontsize=14, fontweight='bold')
119
  axes[1].axis('off')
120
 
121
  axes[2].imshow(overlay)
122
- axes[2].set_title('Overlay (Red = Tumor)', fontsize=14, fontweight='bold')
123
  axes[2].axis('off')
124
 
125
  plt.tight_layout()
126
 
127
  # Save plot to bytes
128
  buf = io.BytesIO()
129
- plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
130
  buf.seek(0)
131
  plt.close()
132
 
@@ -140,38 +153,41 @@ def predict_tumor(image):
140
 
141
  # Create analysis report
142
  analysis_text = f"""
143
- ## 🧠 Brain Tumor Segmentation Analysis
144
-
145
- **πŸ“Š Tumor Statistics:**
146
- - Total pixels analyzed: {total_pixels:,}
147
- - Tumor pixels detected: {tumor_pixels:,}
148
- - Tumor area percentage: {tumor_percentage:.2f}%
149
-
150
- **🎯 Model Performance:**
151
- - Model: U-Net with attention mechanism
152
- - Input resolution: 256Γ—256 pixels
153
- - Detection threshold: {threshold}
154
-
155
- **⚠️ Medical Disclaimer:**
156
- This is an AI tool for research purposes only.
157
- Always consult qualified medical professionals for diagnosis.
 
158
  """
159
 
 
160
  return result_image, analysis_text
161
 
162
  except Exception as e:
163
  error_msg = f"❌ Error during prediction: {str(e)}"
 
164
  return None, error_msg
165
 
166
  def clear_all():
167
  """Clear all inputs and outputs"""
168
- return None, None, ""
169
 
170
  # Custom CSS for better styling
171
  css = """
172
- #main-container {
173
- max-width: 1200px;
174
- margin: 0 auto;
175
  }
176
  #title {
177
  text-align: center;
@@ -181,21 +197,21 @@ css = """
181
  border-radius: 10px;
182
  margin-bottom: 20px;
183
  }
184
- #upload-box {
185
- border: 2px dashed #ccc;
186
- border-radius: 10px;
187
- padding: 20px;
188
- text-align: center;
189
- margin: 10px 0;
190
- }
191
  .output-image {
192
  border-radius: 10px;
193
  box-shadow: 0 4px 8px rgba(0,0,0,0.1);
194
  }
 
 
 
 
 
 
 
195
  """
196
 
197
  # Create Gradio interface
198
- with gr.Blocks(css=css, title="Brain Tumor Segmentation") as app:
199
 
200
  # Header
201
  gr.HTML("""
@@ -207,34 +223,34 @@ with gr.Blocks(css=css, title="Brain Tumor Segmentation") as app:
207
 
208
  with gr.Row():
209
  with gr.Column(scale=1):
210
- gr.HTML("<h3>πŸ“€ Input Image</h3>")
211
 
212
  # Image input with camera option
213
  image_input = gr.Image(
214
  label="Upload Brain MRI Scan",
215
  type="pil",
216
- sources=["upload", "webcam"], # Allow both upload and camera
217
  height=300
218
  )
219
 
220
  with gr.Row():
221
- predict_btn = gr.Button("πŸ” Analyze Image", variant="primary", size="lg")
222
- clear_btn = gr.Button("πŸ—‘οΈ Clear All", variant="secondary")
223
 
224
  gr.HTML("""
225
- <div style="margin-top: 20px; padding: 15px; background-color: #f0f8ff; border-radius: 8px;">
226
  <h4>πŸ“‹ Instructions:</h4>
227
- <ul>
228
  <li>Upload a brain MRI scan image</li>
229
  <li>Supported formats: PNG, JPG, JPEG</li>
230
  <li>For best results, use clear, high-contrast MRI images</li>
231
- <li>You can also use the camera to capture an image from your device</li>
232
  </ul>
233
  </div>
234
  """)
235
 
236
  with gr.Column(scale=2):
237
- gr.HTML("<h3>πŸ“Š Segmentation Results</h3>")
238
 
239
  # Output image
240
  output_image = gr.Image(
@@ -246,26 +262,31 @@ with gr.Blocks(css=css, title="Brain Tumor Segmentation") as app:
246
 
247
  # Analysis text
248
  analysis_output = gr.Markdown(
249
- label="Analysis Report",
250
- value="Upload an image and click 'Analyze Image' to see results."
251
  )
252
 
253
  # Add footer with information
254
  gr.HTML("""
255
- <div style="margin-top: 30px; padding: 20px; background-color: #f9f9f9; border-radius: 10px;">
256
- <h4>πŸ”¬ About This Tool</h4>
257
- <p><strong>Model:</strong> Pre-trained U-Net architecture optimized for brain tumor segmentation</p>
258
- <p><strong>Technology:</strong> PyTorch, Deep Learning, Computer Vision</p>
259
- <p><strong>Dataset:</strong> Trained on medical MRI brain scans</p>
260
-
261
- <h4>⚠️ Important Medical Disclaimer</h4>
262
- <p style="color: #d73027; font-weight: bold;">
263
- This AI tool is for research and educational purposes only. It should NOT be used for medical diagnosis.
264
- Always consult qualified healthcare professionals for medical advice and diagnosis.
265
- </p>
266
-
267
- <p style="text-align: center; margin-top: 20px; color: #666;">
268
- Made with ❀️ using Gradio β€’ Powered by PyTorch β€’ Hosted on πŸ€— Hugging Face Spaces
 
 
 
 
 
269
  </p>
270
  </div>
271
  """)
@@ -274,26 +295,22 @@ with gr.Blocks(css=css, title="Brain Tumor Segmentation") as app:
274
  predict_btn.click(
275
  fn=predict_tumor,
276
  inputs=[image_input],
277
- outputs=[output_image, analysis_output]
 
278
  )
279
 
280
  clear_btn.click(
281
  fn=clear_all,
 
282
  outputs=[image_input, output_image, analysis_output]
283
  )
284
 
285
- # Auto-predict when image is uploaded
286
- image_input.change(
287
- fn=predict_tumor,
288
- inputs=[image_input],
289
- outputs=[output_image, analysis_output]
290
- )
291
-
292
  # Launch the app
293
  if __name__ == "__main__":
 
294
  app.launch(
295
- share=True,
296
  server_name="0.0.0.0",
297
  server_port=7860,
298
- show_error=True
 
299
  )
 
8
  import base64
9
  from torchvision import transforms
10
  import torch.nn.functional as F
11
+ import warnings
12
+ warnings.filterwarnings("ignore")
13
+
14
+ # Global variable to store model
15
+ model = None
16
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
 
 
 
18
  def load_model():
19
  """Load the pretrained brain segmentation model"""
20
+ global model
21
+ if model is None:
22
+ try:
23
+ print("Loading brain segmentation model...")
24
+ model = torch.hub.load(
25
+ 'mateuszbuda/brain-segmentation-pytorch',
26
+ 'unet',
27
+ in_channels=3,
28
+ out_channels=1,
29
+ init_features=32,
30
+ pretrained=True,
31
+ force_reload=False
32
+ )
33
+ model.eval()
34
+ model = model.to(device)
35
+ print("Model loaded successfully!")
36
+ except Exception as e:
37
+ print(f"Error loading model: {e}")
38
+ model = None
39
+ return model
 
40
 
41
  def preprocess_image(image):
42
  """Preprocess the input image for the model"""
 
48
  image = image.convert('RGB')
49
 
50
  # Resize to 256x256 (model's expected input size)
51
+ # Use LANCZOS if available, otherwise use BILINEAR
52
+ try:
53
+ image = image.resize((256, 256), Image.Resampling.LANCZOS)
54
+ except AttributeError:
55
+ # For older PIL versions
56
+ image = image.resize((256, 256), Image.LANCZOS)
57
 
58
  # Convert to tensor and normalize
59
  transform = transforms.Compose([
 
81
 
82
  def predict_tumor(image):
83
  """Main prediction function"""
84
+ # Load model if not loaded
85
+ current_model = load_model()
86
+
87
+ if current_model is None:
88
+ return None, "❌ Model failed to load. Please try again later."
89
 
90
  if image is None:
91
  return None, "⚠️ Please upload an image first."
92
 
93
  try:
94
+ print("Processing image...")
95
  # Preprocess the image
96
  input_tensor, original_img = preprocess_image(image)
97
  input_tensor = input_tensor.to(device)
98
 
99
  # Make prediction
100
  with torch.no_grad():
101
+ prediction = current_model(input_tensor)
102
  # Apply sigmoid to get probability map
103
  prediction = torch.sigmoid(prediction)
104
  # Convert to numpy
 
121
 
122
  # 4. Side-by-side comparison
123
  fig, axes = plt.subplots(1, 3, figsize=(15, 5))
124
+ fig.suptitle('Brain Tumor Segmentation Results', fontsize=16, fontweight='bold')
125
 
126
  axes[0].imshow(original_array)
127
+ axes[0].set_title('Original Image', fontsize=12, fontweight='bold')
128
  axes[0].axis('off')
129
 
130
  axes[1].imshow(mask_colored)
131
+ axes[1].set_title('Tumor Segmentation', fontsize=12, fontweight='bold')
132
  axes[1].axis('off')
133
 
134
  axes[2].imshow(overlay)
135
+ axes[2].set_title('Overlay (Red = Tumor)', fontsize=12, fontweight='bold')
136
  axes[2].axis('off')
137
 
138
  plt.tight_layout()
139
 
140
  # Save plot to bytes
141
  buf = io.BytesIO()
142
+ plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
143
  buf.seek(0)
144
  plt.close()
145
 
 
153
 
154
  # Create analysis report
155
  analysis_text = f"""
156
+ ## 🧠 Brain Tumor Segmentation Analysis
157
+
158
+ **πŸ“Š Tumor Statistics:**
159
+ - Total pixels analyzed: {total_pixels:,}
160
+ - Tumor pixels detected: {tumor_pixels:,}
161
+ - Tumor area percentage: {tumor_percentage:.2f}%
162
+
163
+ **🎯 Model Information:**
164
+ - Model: Pre-trained U-Net for brain segmentation
165
+ - Input resolution: 256Γ—256 pixels
166
+ - Detection threshold: {threshold}
167
+ - Device: {device.type.upper()}
168
+
169
+ **⚠️ Medical Disclaimer:**
170
+ This is an AI tool for research and educational purposes only.
171
+ Always consult qualified medical professionals for diagnosis.
172
  """
173
 
174
+ print("Processing completed successfully!")
175
  return result_image, analysis_text
176
 
177
  except Exception as e:
178
  error_msg = f"❌ Error during prediction: {str(e)}"
179
+ print(error_msg)
180
  return None, error_msg
181
 
182
  def clear_all():
183
  """Clear all inputs and outputs"""
184
+ return None, None, "Upload an image and click 'Analyze Image' to see results."
185
 
186
  # Custom CSS for better styling
187
  css = """
188
+ .gradio-container {
189
+ max-width: 1200px !important;
190
+ margin: auto !important;
191
  }
192
  #title {
193
  text-align: center;
 
197
  border-radius: 10px;
198
  margin-bottom: 20px;
199
  }
 
 
 
 
 
 
 
200
  .output-image {
201
  border-radius: 10px;
202
  box-shadow: 0 4px 8px rgba(0,0,0,0.1);
203
  }
204
+ button {
205
+ border-radius: 8px;
206
+ font-weight: 500;
207
+ }
208
+ .progress-bar {
209
+ background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
210
+ }
211
  """
212
 
213
  # Create Gradio interface
214
+ with gr.Blocks(css=css, title="🧠 Brain Tumor Segmentation AI", theme=gr.themes.Soft()) as app:
215
 
216
  # Header
217
  gr.HTML("""
 
223
 
224
  with gr.Row():
225
  with gr.Column(scale=1):
226
+ gr.Markdown("### πŸ“€ Input Image")
227
 
228
  # Image input with camera option
229
  image_input = gr.Image(
230
  label="Upload Brain MRI Scan",
231
  type="pil",
232
+ sources=["upload", "webcam"],
233
  height=300
234
  )
235
 
236
  with gr.Row():
237
+ predict_btn = gr.Button("πŸ” Analyze Image", variant="primary", scale=2)
238
+ clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary", scale=1)
239
 
240
  gr.HTML("""
241
+ <div style="margin-top: 20px; padding: 15px; background-color: #f0f8ff; border-radius: 8px; border-left: 4px solid #667eea;">
242
  <h4>πŸ“‹ Instructions:</h4>
243
+ <ul style="margin: 10px 0; padding-left: 20px;">
244
  <li>Upload a brain MRI scan image</li>
245
  <li>Supported formats: PNG, JPG, JPEG</li>
246
  <li>For best results, use clear, high-contrast MRI images</li>
247
+ <li>Camera option available for mobile devices</li>
248
  </ul>
249
  </div>
250
  """)
251
 
252
  with gr.Column(scale=2):
253
+ gr.Markdown("### πŸ“Š Segmentation Results")
254
 
255
  # Output image
256
  output_image = gr.Image(
 
262
 
263
  # Analysis text
264
  analysis_output = gr.Markdown(
265
+ value="Upload an image and click 'Analyze Image' to see results.",
266
+ elem_id="analysis"
267
  )
268
 
269
  # Add footer with information
270
  gr.HTML("""
271
+ <div style="margin-top: 30px; padding: 20px; background-color: #f9f9f9; border-radius: 10px; border: 1px solid #e1e4e8;">
272
+ <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 20px;">
273
+ <div>
274
+ <h4 style="color: #667eea; margin-bottom: 10px;">πŸ”¬ About This Tool</h4>
275
+ <p><strong>Model:</strong> Pre-trained U-Net for brain segmentation</p>
276
+ <p><strong>Technology:</strong> PyTorch + Deep Learning</p>
277
+ <p><strong>Purpose:</strong> Research & Educational Use</p>
278
+ </div>
279
+ <div>
280
+ <h4 style="color: #d73027; margin-bottom: 10px;">⚠️ Medical Disclaimer</h4>
281
+ <p style="color: #d73027; font-weight: 500;">
282
+ This AI tool is for research and educational purposes only.<br>
283
+ <strong>NOT for medical diagnosis.</strong> Always consult healthcare professionals.
284
+ </p>
285
+ </div>
286
+ </div>
287
+ <hr style="margin: 20px 0; border: none; border-top: 1px solid #e1e4e8;">
288
+ <p style="text-align: center; color: #666; margin: 10px 0;">
289
+ Made with ❀️ using Gradio β€’ Powered by PyTorch β€’ Hosted on πŸ€— Hugging Face Spaces
290
  </p>
291
  </div>
292
  """)
 
295
  predict_btn.click(
296
  fn=predict_tumor,
297
  inputs=[image_input],
298
+ outputs=[output_image, analysis_output],
299
+ show_progress=True
300
  )
301
 
302
  clear_btn.click(
303
  fn=clear_all,
304
+ inputs=[],
305
  outputs=[image_input, output_image, analysis_output]
306
  )
307
 
 
 
 
 
 
 
 
308
  # Launch the app
309
  if __name__ == "__main__":
310
+ print("Starting Brain Tumor Segmentation App...")
311
  app.launch(
 
312
  server_name="0.0.0.0",
313
  server_port=7860,
314
+ show_error=True,
315
+ share=False
316
  )