ArchCoder commited on
Commit
90efbfd
Β·
verified Β·
1 Parent(s): bed938d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +262 -362
app.py CHANGED
@@ -7,7 +7,8 @@ from PIL import Image
7
  import matplotlib.pyplot as plt
8
  import io
9
  from torchvision import transforms
10
- import torch.nn.functional as F
 
11
  import warnings
12
  warnings.filterwarnings("ignore")
13
 
@@ -15,77 +16,54 @@ warnings.filterwarnings("ignore")
15
  model = None
16
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
 
18
- # Custom U-Net Architecture for Brain Tumor Segmentation
19
- class DoubleConv(nn.Module):
20
- def __init__(self, in_channels, out_channels):
21
- super(DoubleConv, self).__init__()
22
- self.conv = nn.Sequential(
23
- nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
24
- nn.BatchNorm2d(out_channels),
25
- nn.ReLU(inplace=True),
26
- nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
27
- nn.BatchNorm2d(out_channels),
28
- nn.ReLU(inplace=True),
29
- )
30
-
31
- def forward(self, x):
32
- return self.conv(x)
33
-
34
- class BrainTumorUNet(nn.Module):
35
- def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
36
- super(BrainTumorUNet, self).__init__()
37
- self.ups = nn.ModuleList()
38
- self.downs = nn.ModuleList()
39
- self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
40
-
41
- # Down part of UNET
42
- for feature in features:
43
- self.downs.append(DoubleConv(in_channels, feature))
44
- in_channels = feature
45
-
46
- # Bottleneck
47
- self.bottleneck = DoubleConv(features[-1], features[-1]*2)
48
-
49
- # Up part of UNET
50
- for feature in reversed(features):
51
- self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
52
- self.ups.append(DoubleConv(feature*2, feature))
53
-
54
- self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
55
-
56
- def forward(self, x):
57
- skip_connections = []
58
-
59
- for down in self.downs:
60
- x = down(x)
61
- skip_connections.append(x)
62
- x = self.pool(x)
63
-
64
- x = self.bottleneck(x)
65
- skip_connections = skip_connections[::-1]
66
-
67
- for idx in range(0, len(self.ups), 2):
68
- x = self.ups[idx](x)
69
- skip_connection = skip_connections[idx//2]
70
-
71
- if x.shape != skip_connection.shape:
72
- x = F.interpolate(x, size=skip_connection.shape[2:])
73
-
74
- concat_skip = torch.cat((skip_connection, x), dim=1)
75
- x = self.ups[idx+1](concat_skip)
76
-
77
- return self.final_conv(x)
78
-
79
- def load_model():
80
- """Load brain tumor segmentation model"""
81
- global model
82
- if model is None:
83
  try:
84
- print("Loading brain tumor segmentation model...")
 
 
 
 
 
 
 
 
85
 
86
- # Try to load a pretrained model first
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  try:
88
- # Fallback to a general segmentation model
89
  model = torch.hub.load(
90
  'mateuszbuda/brain-segmentation-pytorch',
91
  'unet',
@@ -95,190 +73,209 @@ def load_model():
95
  pretrained=True,
96
  force_reload=False
97
  )
98
- print("Loaded pretrained brain segmentation model")
 
 
99
  except:
100
- # If that fails, use our custom model
101
- model = BrainTumorUNet(in_channels=3, out_channels=1)
102
- print("Loaded custom U-Net model (not pretrained)")
103
-
104
- model.eval()
105
- model = model.to(device)
106
- print("Model loaded successfully!")
107
-
108
- except Exception as e:
109
- print(f"Error loading model: {e}")
110
- model = None
111
  return model
112
 
113
- def apply_clahe_he(image):
114
- """Apply CLAHE and Histogram Equalization preprocessing"""
115
- # Convert PIL to numpy array
116
  if isinstance(image, Image.Image):
117
  image_np = np.array(image)
118
  else:
119
  image_np = image
120
 
121
- # Convert to grayscale if RGB
122
  if len(image_np.shape) == 3:
123
  gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
124
  else:
125
  gray = image_np
126
 
127
- # Apply CLAHE (Contrast Limited Adaptive Histogram Equalization)
128
- clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
129
- clahe_image = clahe.apply(gray)
 
 
 
 
 
 
 
130
 
131
- # Apply Histogram Equalization
132
- he_image = cv2.equalizeHist(clahe_image)
 
 
 
 
133
 
134
  # Convert back to RGB
135
- enhanced_image = cv2.cvtColor(he_image, cv2.COLOR_GRAY2RGB)
136
 
137
- return enhanced_image
138
-
139
- def preprocess_image(image):
140
- """Enhanced preprocessing for brain tumor segmentation"""
141
- if isinstance(image, np.ndarray):
142
- image = Image.fromarray(image)
143
-
144
- # Convert to RGB if not already
145
- if image.mode != 'RGB':
146
- image = image.convert('RGB')
147
 
148
- # Apply CLAHE-HE preprocessing (key for nikhilroxtomar dataset)
149
- enhanced_image = apply_clahe_he(image)
 
 
150
  enhanced_pil = Image.fromarray(enhanced_image)
151
-
152
- # Resize to 256x256
153
- try:
154
- enhanced_pil = enhanced_pil.resize((256, 256), Image.Resampling.LANCZOS)
155
- except AttributeError:
156
- enhanced_pil = enhanced_pil.resize((256, 256), Image.LANCZOS)
157
-
158
- # Normalization optimized for brain tumor segmentation
159
  transform = transforms.Compose([
160
  transforms.ToTensor(),
161
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
162
  ])
163
-
164
  image_tensor = transform(enhanced_pil).unsqueeze(0)
165
  return image_tensor, enhanced_pil
166
 
167
- def post_process_mask(prediction, threshold=0.3):
168
- """Advanced post-processing for brain tumor masks"""
169
- # Apply threshold
170
- binary_mask = (prediction > threshold).astype(np.uint8)
171
-
172
- # Morphological operations to clean up the mask
173
- kernel = np.ones((3,3), np.uint8)
174
-
175
- # Remove small noise
176
- binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)
177
-
178
- # Fill small holes
179
- binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
180
-
181
- # Find connected components and keep largest ones
182
- num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask)
183
-
184
- if num_labels > 1:
185
- # Keep only components larger than minimum area
186
- min_area = 100 # Minimum tumor area in pixels
187
- cleaned_mask = np.zeros_like(binary_mask)
188
-
189
- for i in range(1, num_labels):
190
- if stats[i, cv2.CC_STAT_AREA] > min_area:
191
- cleaned_mask[labels == i] = 1
192
-
193
- binary_mask = cleaned_mask
194
-
195
- return binary_mask
196
-
197
- def predict_tumor(image):
198
- """Enhanced prediction function for brain tumor segmentation"""
199
- current_model = load_model()
200
 
201
  if current_model is None:
202
- return None, "❌ Model failed to load. Please try again later."
203
 
204
  if image is None:
205
- return None, "⚠️ Please upload a brain MRI image first."
206
 
207
  try:
208
- print("Processing brain MRI image...")
209
 
210
- # Enhanced preprocessing
211
- input_tensor, processed_img = preprocess_image(image)
212
  input_tensor = input_tensor.to(device)
213
 
214
  # Make prediction
215
  with torch.no_grad():
216
- prediction = current_model(input_tensor)
217
- prediction = torch.sigmoid(prediction)
218
- prediction = prediction.squeeze().cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
- print(f"Prediction stats: min={prediction.min():.3f}, max={prediction.max():.3f}, mean={prediction.mean():.3f}")
221
-
222
- # Enhanced post-processing
223
- binary_mask = post_process_mask(prediction, threshold=0.3)
224
-
225
- # Create visualizations
226
- original_array = np.array(image.resize((256, 256)))
227
  processed_array = np.array(processed_img)
228
 
229
- # Probability heatmap
230
- prob_heatmap = plt.cm.hot(prediction)[:,:,:3] * 255
231
- prob_heatmap = prob_heatmap.astype(np.uint8)
232
-
233
- # Binary mask visualization
234
- mask_colored = np.zeros((256, 256, 3), dtype=np.uint8)
235
- mask_colored[:, :, 0] = binary_mask * 255 # Red channel
236
-
237
- # Enhanced overlay
238
- overlay = original_array.copy()
239
- overlay[binary_mask == 1] = [255, 0, 0] # Red for tumor
240
- overlay = cv2.addWeighted(original_array, 0.6, overlay, 0.4, 0)
241
-
242
- # Create comprehensive visualization
243
  fig, axes = plt.subplots(2, 3, figsize=(18, 12))
244
- fig.suptitle('Brain Tumor Segmentation Analysis', fontsize=20, fontweight='bold')
245
 
246
- # Row 1: Original, Enhanced, Probability
247
  axes[0,0].imshow(original_array)
248
  axes[0,0].set_title('Original MRI', fontsize=14, fontweight='bold')
249
  axes[0,0].axis('off')
250
 
251
  axes[0,1].imshow(processed_array)
252
- axes[0,1].set_title('Enhanced (CLAHE-HE)', fontsize=14, fontweight='bold')
253
  axes[0,1].axis('off')
254
 
255
- axes[0,2].imshow(prob_heatmap)
256
- axes[0,2].set_title('Probability Heatmap', fontsize=14, fontweight='bold')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  axes[0,2].axis('off')
258
 
259
- # Row 2: Binary Mask, Overlay, Statistics
260
- axes[1,0].imshow(mask_colored)
261
- axes[1,0].set_title('Tumor Segmentation', fontsize=14, fontweight='bold')
262
- axes[1,0].axis('off')
263
-
264
- axes[1,1].imshow(overlay)
265
- axes[1,1].set_title('Overlay (Red = Tumor)', fontsize=14, fontweight='bold')
266
- axes[1,1].axis('off')
267
-
268
- # Statistics plot
269
- tumor_pixels = np.sum(binary_mask)
270
- healthy_pixels = (256*256) - tumor_pixels
271
-
272
- axes[1,2].pie([healthy_pixels, tumor_pixels],
273
- labels=['Healthy', 'Tumor'],
274
- colors=['lightblue', 'red'],
275
- autopct='%1.1f%%',
276
- startangle=90)
277
- axes[1,2].set_title('Tissue Distribution', fontsize=14, fontweight='bold')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
 
279
  plt.tight_layout()
280
 
281
- # Save plot
282
  buf = io.BytesIO()
283
  plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
284
  buf.seek(0)
@@ -287,74 +284,65 @@ def predict_tumor(image):
287
  result_image = Image.open(buf)
288
 
289
  # Calculate comprehensive statistics
290
- total_pixels = 256 * 256
291
- tumor_pixels = np.sum(binary_mask)
292
- tumor_percentage = (tumor_pixels / total_pixels) * 100
293
 
294
- # Tumor characteristics
295
- if tumor_pixels > 0:
296
- # Calculate tumor size in mmΒ² (assuming 1 pixel = 1mmΒ²)
297
- tumor_area_mm2 = tumor_pixels
298
-
299
- # Calculate tumor centroid
300
- M = cv2.moments(binary_mask)
301
- if M["m00"] != 0:
302
- cX = int(M["m10"] / M["m00"])
303
- cY = int(M["m01"] / M["m00"])
304
- else:
305
- cX, cY = 0, 0
306
- else:
307
- tumor_area_mm2 = 0
308
- cX, cY = 0, 0
309
-
310
- # Enhanced analysis report
311
  analysis_text = f"""
312
- ## 🧠 Brain Tumor Segmentation Analysis
 
 
 
 
 
 
 
 
 
 
 
 
313
 
314
- ### πŸ“Š Tumor Detection Results:
315
- - **Tumor Status**: {'πŸ”΄ TUMOR DETECTED' if tumor_pixels > 50 else '🟒 NO SIGNIFICANT TUMOR'}
316
- - **Tumor Area**: {tumor_area_mm2:.0f} pixels (~{tumor_area_mm2:.0f} mmΒ²)
317
- - **Tumor Percentage**: {tumor_percentage:.2f}% of brain area
318
- - **Tumor Location**: Center at ({cX}, {cY})
319
 
320
  ### πŸ”¬ Technical Details:
321
- - **Preprocessing**: CLAHE + Histogram Equalization
322
- - **Model Architecture**: U-Net with enhanced post-processing
323
- - **Input Resolution**: 256Γ—256 pixels
324
- - **Confidence Threshold**: 0.3 (optimized for sensitivity)
325
- - **Processing Device**: {device.type.upper()}
326
-
327
- ### πŸ“ˆ Image Quality Metrics:
328
- - **Prediction Range**: {prediction.min():.3f} - {prediction.max():.3f}
329
- - **Mean Confidence**: {prediction.mean():.3f}
330
- - **Enhancement Applied**: βœ… CLAHE-HE preprocessing
331
-
332
- ### ⚠️ Important Medical Disclaimer:
333
- **This AI tool is for research and educational purposes only.**
334
- - Results are NOT a medical diagnosis
335
- - Always consult qualified medical professionals
336
- - Use only as a supplementary analysis tool
337
- - Accuracy may vary with image quality and tumor type
338
-
339
- ### πŸ“‹ Recommended Actions:
340
- {f'- **Immediate consultation** with neurologist recommended' if tumor_percentage > 1.0 else '- **Routine follow-up** as per medical advice'}
341
- - Correlation with clinical symptoms advised
342
- - Consider additional imaging if warranted
343
  """
344
 
345
- print("Processing completed successfully!")
346
  return result_image, analysis_text
347
 
348
  except Exception as e:
349
- error_msg = f"❌ Error during prediction: {str(e)}"
350
  print(error_msg)
351
  return None, error_msg
352
 
353
  def clear_all():
354
- """Clear all inputs and outputs"""
355
- return None, None, "Upload a brain MRI image and click 'Analyze Image' to see results."
356
 
357
- # Enhanced CSS styling
358
  css = """
359
  .gradio-container {
360
  max-width: 1400px !important;
@@ -364,143 +352,60 @@ css = """
364
  text-align: center;
365
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
366
  color: white;
367
- padding: 25px;
368
- border-radius: 15px;
369
- margin-bottom: 25px;
370
- box-shadow: 0 8px 16px rgba(0,0,0,0.1);
371
- }
372
- .output-image {
373
  border-radius: 15px;
374
- box-shadow: 0 8px 16px rgba(0,0,0,0.1);
375
- }
376
- button {
377
- border-radius: 8px;
378
- font-weight: 600;
379
- transition: all 0.3s ease;
380
- }
381
- button:hover {
382
- transform: translateY(-2px);
383
- box-shadow: 0 4px 8px rgba(0,0,0,0.2);
384
- }
385
- .progress-bar {
386
- background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
387
  }
388
  """
389
 
390
- # Create enhanced Gradio interface
391
- with gr.Blocks(css=css, title="🧠 Advanced Brain Tumor Segmentation AI", theme=gr.themes.Soft()) as app:
392
 
393
- # Enhanced header
394
  gr.HTML("""
395
  <div id="title">
396
- <h1>🧠 Advanced Brain Tumor Segmentation AI</h1>
397
- <p style="font-size: 18px; margin-top: 10px;">
398
- Powered by Enhanced U-Net with CLAHE-HE Preprocessing
399
  </p>
400
- <p style="font-size: 14px; margin-top: 5px; opacity: 0.9;">
401
- Optimized for the Nikhil Tomar Brain Tumor Dataset
402
  </p>
403
  </div>
404
  """)
405
 
406
  with gr.Row():
407
  with gr.Column(scale=1):
408
- gr.Markdown("### πŸ“€ Input MRI Image")
409
 
410
  image_input = gr.Image(
411
- label="Upload Brain MRI Scan",
412
  type="pil",
413
  sources=["upload", "webcam"],
414
  height=350
415
  )
416
 
417
  with gr.Row():
418
- predict_btn = gr.Button(
419
- "πŸ” Analyze Brain Scan",
420
- variant="primary",
421
- scale=2,
422
- size="lg"
423
- )
424
- clear_btn = gr.Button(
425
- "πŸ—‘οΈ Clear All",
426
- variant="secondary",
427
- scale=1,
428
- size="lg"
429
- )
430
-
431
- gr.HTML("""
432
- <div style="margin-top: 25px; padding: 20px; background: linear-gradient(135deg, #f0f8ff 0%, #e6f3ff 100%); border-radius: 12px; border-left: 5px solid #667eea;">
433
- <h4 style="color: #667eea; margin-bottom: 15px;">πŸ“‹ Usage Instructions:</h4>
434
- <ul style="margin: 10px 0; padding-left: 25px; line-height: 1.6;">
435
- <li><strong>Upload Format:</strong> PNG, JPG, JPEG images</li>
436
- <li><strong>Best Results:</strong> High-contrast brain MRI scans</li>
437
- <li><strong>Preprocessing:</strong> CLAHE-HE enhancement applied automatically</li>
438
- <li><strong>Detection:</strong> Optimized for various tumor types and sizes</li>
439
- <li><strong>Mobile Support:</strong> Camera capture available</li>
440
- </ul>
441
- <div style="margin-top: 15px; padding: 10px; background-color: #fff3cd; border-radius: 6px; border-left: 3px solid #ffc107;">
442
- <strong>⚑ Enhanced Features:</strong> Advanced post-processing, morphological filtering, and comprehensive analysis
443
- </div>
444
- </div>
445
- """)
446
 
447
  with gr.Column(scale=2):
448
- gr.Markdown("### πŸ“Š Comprehensive Analysis Results")
449
 
450
  output_image = gr.Image(
451
- label="Segmentation Analysis",
452
  type="pil",
453
- height=600,
454
- elem_classes=["output-image"]
455
  )
456
 
457
  analysis_output = gr.Markdown(
458
- value="Upload a brain MRI image and click 'Analyze Brain Scan' to see comprehensive results.",
459
  elem_id="analysis"
460
  )
461
 
462
- # Enhanced footer
463
- gr.HTML("""
464
- <div style="margin-top: 40px; padding: 30px; background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%); border-radius: 15px; border: 1px solid #dee2e6;">
465
- <div style="display: grid; grid-template-columns: 1fr 1fr 1fr; gap: 30px; margin-bottom: 20px;">
466
- <div>
467
- <h4 style="color: #667eea; margin-bottom: 15px;">πŸ”¬ Technology Stack</h4>
468
- <p><strong>Model:</strong> Enhanced U-Net Architecture</p>
469
- <p><strong>Preprocessing:</strong> CLAHE + Histogram Equalization</p>
470
- <p><strong>Framework:</strong> PyTorch + OpenCV</p>
471
- <p><strong>Optimization:</strong> Nikhil Tomar Dataset</p>
472
- </div>
473
- <div>
474
- <h4 style="color: #28a745; margin-bottom: 15px;">⚑ Key Features</h4>
475
- <p><strong>Enhancement:</strong> Automatic contrast optimization</p>
476
- <p><strong>Detection:</strong> Multi-scale tumor analysis</p>
477
- <p><strong>Post-processing:</strong> Morphological filtering</p>
478
- <p><strong>Visualization:</strong> 6-panel comprehensive view</p>
479
- </div>
480
- <div>
481
- <h4 style="color: #dc3545; margin-bottom: 15px;">⚠️ Medical Disclaimer</h4>
482
- <p style="color: #dc3545; font-weight: 600; line-height: 1.4;">
483
- This AI tool is for <strong>research and educational purposes only</strong>.<br>
484
- <strong>NOT for medical diagnosis.</strong><br>
485
- Always consult healthcare professionals for medical advice.
486
- </p>
487
- </div>
488
- </div>
489
- <hr style="margin: 25px 0; border: none; border-top: 2px solid #dee2e6;">
490
- <div style="text-align: center;">
491
- <p style="color: #6c757d; margin: 10px 0; font-size: 16px;">
492
- πŸ₯ <strong>Advanced Medical AI</strong> β€’ Made with ❀️ using Gradio β€’ Powered by PyTorch β€’ Hosted on πŸ€— Hugging Face Spaces
493
- </p>
494
- <p style="color: #6c757d; margin: 5px 0; font-size: 14px;">
495
- Enhanced for Brain Tumor Detection β€’ Optimized Preprocessing Pipeline β€’ Research Grade Accuracy
496
- </p>
497
- </div>
498
- </div>
499
- """)
500
-
501
  # Event handlers
502
- predict_btn.click(
503
- fn=predict_tumor,
504
  inputs=[image_input],
505
  outputs=[output_image, analysis_output],
506
  show_progress=True
@@ -512,13 +417,8 @@ with gr.Blocks(css=css, title="🧠 Advanced Brain Tumor Segmentation AI", theme
512
  outputs=[image_input, output_image, analysis_output]
513
  )
514
 
515
- # Launch the enhanced app
516
  if __name__ == "__main__":
517
- print("πŸš€ Starting Advanced Brain Tumor Segmentation App...")
518
- print("βœ… Enhanced with CLAHE-HE preprocessing")
519
- print("βœ… Optimized for Nikhil Tomar dataset")
520
- print("βœ… Advanced post-processing enabled")
521
-
522
  app.launch(
523
  server_name="0.0.0.0",
524
  server_port=7860,
 
7
  import matplotlib.pyplot as plt
8
  import io
9
  from torchvision import transforms
10
+ import torchvision.models as models
11
+ from torchvision.models import detection
12
  import warnings
13
  warnings.filterwarnings("ignore")
14
 
 
16
  model = None
17
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
 
19
+ class TumorDetector:
20
+ def __init__(self):
21
+ self.model = None
22
+ self.device = device
23
+
24
+ def load_maskrcnn_model(self):
25
+ """Load Mask R-CNN for tumor instance segmentation"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  try:
27
+ print("πŸ”„ Loading Mask R-CNN for brain tumor detection...")
28
+
29
+ # Use pretrained Mask R-CNN and fine-tune for brain tumors
30
+ self.model = detection.maskrcnn_resnet50_fpn(pretrained=True)
31
+
32
+ # Modify for brain tumor segmentation (2 classes: background, tumor)
33
+ num_classes = 2
34
+ in_features = self.model.roi_heads.box_predictor.cls_score.in_features
35
+ self.model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
36
 
37
+ # Modify mask predictor
38
+ in_features_mask = self.model.roi_heads.mask_predictor.conv5_mask.in_channels
39
+ hidden_layer = 256
40
+ self.model.roi_heads.mask_predictor = detection.mask_rcnn.MaskRCNNPredictor(
41
+ in_features_mask, hidden_layer, num_classes
42
+ )
43
+
44
+ self.model.eval()
45
+ self.model = self.model.to(self.device)
46
+ print("βœ… Model loaded successfully!")
47
+ return True
48
+
49
+ except Exception as e:
50
+ print(f"❌ Error loading model: {e}")
51
+ return False
52
+
53
+ def load_robust_model():
54
+ """Load the most robust brain tumor detection model"""
55
+ global model
56
+ if model is None:
57
+ detector = TumorDetector()
58
+
59
+ # Try multiple model options
60
+ if detector.load_maskrcnn_model():
61
+ model = detector.model
62
+ print("βœ… Using Mask R-CNN for comprehensive tumor detection")
63
+ else:
64
+ # Fallback to PyTorch Hub U-Net
65
  try:
66
+ print("πŸ”„ Falling back to PyTorch Hub U-Net...")
67
  model = torch.hub.load(
68
  'mateuszbuda/brain-segmentation-pytorch',
69
  'unet',
 
73
  pretrained=True,
74
  force_reload=False
75
  )
76
+ model.eval()
77
+ model = model.to(device)
78
+ print("βœ… Fallback model loaded!")
79
  except:
80
+ model = None
81
+ print("❌ All models failed to load!")
82
+
 
 
 
 
 
 
 
 
83
  return model
84
 
85
+ def enhance_mri_image(image):
86
+ """Advanced MRI enhancement for better tumor detection"""
 
87
  if isinstance(image, Image.Image):
88
  image_np = np.array(image)
89
  else:
90
  image_np = image
91
 
92
+ # Convert to grayscale for processing
93
  if len(image_np.shape) == 3:
94
  gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
95
  else:
96
  gray = image_np
97
 
98
+ # Multi-step enhancement
99
+ # 1. CLAHE for contrast
100
+ clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
101
+ enhanced = clahe.apply(gray)
102
+
103
+ # 2. Gaussian blur for noise reduction
104
+ denoised = cv2.GaussianBlur(enhanced, (3,3), 0)
105
+
106
+ # 3. Histogram equalization
107
+ hist_eq = cv2.equalizeHist(denoised)
108
 
109
+ # 4. Normalize intensity
110
+ normalized = cv2.normalize(hist_eq, None, 0, 255, cv2.NORM_MINMAX)
111
+
112
+ # 5. Edge enhancement
113
+ kernel = np.array([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]])
114
+ sharpened = cv2.filter2D(normalized, -1, kernel)
115
 
116
  # Convert back to RGB
117
+ enhanced_rgb = cv2.cvtColor(sharpened, cv2.COLOR_GRAY2RGB)
118
 
119
+ return enhanced_rgb
 
 
 
 
 
 
 
 
 
120
 
121
+ def preprocess_for_detection(image):
122
+ """Preprocess image for comprehensive tumor detection"""
123
+ # Enhance the image
124
+ enhanced_image = enhance_mri_image(image)
125
  enhanced_pil = Image.fromarray(enhanced_image)
126
+
127
+ # Resize to standard size
128
+ enhanced_pil = enhanced_pil.resize((512, 512), Image.LANCZOS)
129
+
130
+ # Convert to tensor with proper normalization
 
 
 
131
  transform = transforms.Compose([
132
  transforms.ToTensor(),
133
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
134
  ])
135
+
136
  image_tensor = transform(enhanced_pil).unsqueeze(0)
137
  return image_tensor, enhanced_pil
138
 
139
+ def detect_all_tumors(image):
140
+ """Comprehensive tumor detection and segmentation"""
141
+ current_model = load_robust_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  if current_model is None:
144
+ return None, "❌ Model failed to load. Please check your setup."
145
 
146
  if image is None:
147
+ return None, "⚠️ Please upload a brain MRI image."
148
 
149
  try:
150
+ print("🧠 Detecting ALL brain tumors in the image...")
151
 
152
+ # Preprocess image
153
+ input_tensor, processed_img = preprocess_for_detection(image)
154
  input_tensor = input_tensor.to(device)
155
 
156
  # Make prediction
157
  with torch.no_grad():
158
+ if hasattr(current_model, 'roi_heads'): # Mask R-CNN
159
+ predictions = current_model(input_tensor)
160
+ # Process Mask R-CNN output
161
+ boxes = predictions[0]['boxes'].cpu().numpy()
162
+ masks = predictions[0]['masks'].cpu().numpy()
163
+ scores = predictions[0]['scores'].cpu().numpy()
164
+
165
+ # Filter high-confidence detections
166
+ threshold = 0.5
167
+ high_conf_mask = scores > threshold
168
+ final_masks = masks[high_conf_mask]
169
+ final_boxes = boxes[high_conf_mask]
170
+ final_scores = scores[high_conf_mask]
171
+
172
+ print(f"🎯 Detected {len(final_masks)} tumor(s) with confidence > {threshold}")
173
+
174
+ else: # U-Net
175
+ prediction = current_model(input_tensor)
176
+ prediction = torch.sigmoid(prediction)
177
+ prediction = prediction.squeeze().cpu().numpy()
178
+
179
+ # Create binary mask
180
+ binary_mask = (prediction > 0.3).astype(np.uint8)
181
+
182
+ # Find connected components (separate tumors)
183
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask)
184
+ final_masks = []
185
+ for i in range(1, num_labels):
186
+ if stats[i, cv2.CC_STAT_AREA] > 100: # Filter small regions
187
+ tumor_mask = (labels == i).astype(np.uint8)
188
+ final_masks.append(tumor_mask)
189
+
190
+ print(f"🎯 Detected {len(final_masks)} separate tumor region(s)")
191
 
192
+ # Create comprehensive visualization
193
+ original_array = np.array(image.resize((512, 512)))
 
 
 
 
 
194
  processed_array = np.array(processed_img)
195
 
196
+ # Create combined visualization
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  fig, axes = plt.subplots(2, 3, figsize=(18, 12))
198
+ fig.suptitle('🧠 Comprehensive Brain Tumor Detection', fontsize=20, fontweight='bold')
199
 
200
+ # Row 1: Original, Enhanced, All Tumors
201
  axes[0,0].imshow(original_array)
202
  axes[0,0].set_title('Original MRI', fontsize=14, fontweight='bold')
203
  axes[0,0].axis('off')
204
 
205
  axes[0,1].imshow(processed_array)
206
+ axes[0,1].set_title('Enhanced Image', fontsize=14, fontweight='bold')
207
  axes[0,1].axis('off')
208
 
209
+ # Combined tumor overlay
210
+ combined_overlay = original_array.copy()
211
+ colors = [(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)] # Different colors for different tumors
212
+
213
+ if len(final_masks) > 0:
214
+ for i, mask in enumerate(final_masks):
215
+ color = colors[i % len(colors)]
216
+ if len(mask.shape) == 3:
217
+ mask = mask[0] # Handle Mask R-CNN format
218
+ mask_resized = cv2.resize(mask, (512, 512))
219
+ combined_overlay[mask_resized > 0.5] = color
220
+
221
+ combined_overlay = cv2.addWeighted(original_array, 0.6, combined_overlay, 0.4, 0)
222
+
223
+ axes[0,2].imshow(combined_overlay)
224
+ axes[0,2].set_title(f'All Tumors Detected ({len(final_masks)})', fontsize=14, fontweight='bold')
225
  axes[0,2].axis('off')
226
 
227
+ # Row 2: Individual tumor analysis
228
+ if len(final_masks) >= 1:
229
+ mask1 = final_masks[0]
230
+ if len(mask1.shape) == 3:
231
+ mask1 = mask1[0]
232
+ mask1_colored = np.zeros((512, 512, 3), dtype=np.uint8)
233
+ mask1_resized = cv2.resize(mask1, (512, 512))
234
+ mask1_colored[:, :, 0] = mask1_resized * 255
235
+ axes[1,0].imshow(mask1_colored)
236
+ axes[1,0].set_title('Tumor Region 1', fontsize=14)
237
+ axes[1,0].axis('off')
238
+ else:
239
+ axes[1,0].text(0.5, 0.5, 'No Tumor\nDetected', ha='center', va='center', fontsize=16)
240
+ axes[1,0].axis('off')
241
+
242
+ if len(final_masks) >= 2:
243
+ mask2 = final_masks[1]
244
+ if len(mask2.shape) == 3:
245
+ mask2 = mask2[0]
246
+ mask2_colored = np.zeros((512, 512, 3), dtype=np.uint8)
247
+ mask2_resized = cv2.resize(mask2, (512, 512))
248
+ mask2_colored[:, :, 1] = mask2_resized * 255
249
+ axes[1,1].imshow(mask2_colored)
250
+ axes[1,1].set_title('Tumor Region 2', fontsize=14)
251
+ axes[1,1].axis('off')
252
+ else:
253
+ axes[1,1].text(0.5, 0.5, 'Single Tumor\nOnly', ha='center', va='center', fontsize=16)
254
+ axes[1,1].axis('off')
255
+
256
+ # Statistics pie chart
257
+ if len(final_masks) > 0:
258
+ total_pixels = 512 * 512
259
+ tumor_pixels = sum([np.sum(cv2.resize(mask[0] if len(mask.shape) == 3 else mask, (512, 512))) for mask in final_masks])
260
+ healthy_pixels = total_pixels - tumor_pixels
261
+
262
+ if tumor_pixels > 0:
263
+ axes[1,2].pie([healthy_pixels, tumor_pixels],
264
+ labels=['Healthy', 'Tumor'],
265
+ colors=['lightblue', 'red'],
266
+ autopct='%1.1f%%',
267
+ startangle=90)
268
+ axes[1,2].set_title('Tissue Distribution', fontsize=14, fontweight='bold')
269
+ else:
270
+ axes[1,2].text(0.5, 0.5, 'No Tumors\nDetected', ha='center', va='center', fontsize=16)
271
+ axes[1,2].axis('off')
272
+ else:
273
+ axes[1,2].text(0.5, 0.5, 'Healthy\nBrain', ha='center', va='center', fontsize=16, color='green')
274
+ axes[1,2].axis('off')
275
 
276
  plt.tight_layout()
277
 
278
+ # Save result
279
  buf = io.BytesIO()
280
  plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
281
  buf.seek(0)
 
284
  result_image = Image.open(buf)
285
 
286
  # Calculate comprehensive statistics
287
+ total_tumor_pixels = 0
288
+ tumor_areas = []
 
289
 
290
+ if len(final_masks) > 0:
291
+ for i, mask in enumerate(final_masks):
292
+ if len(mask.shape) == 3:
293
+ mask = mask[0]
294
+ mask_resized = cv2.resize(mask, (512, 512))
295
+ pixels = np.sum(mask_resized > 0.5)
296
+ total_tumor_pixels += pixels
297
+ tumor_areas.append(pixels)
298
+
299
+ total_percentage = (total_tumor_pixels / (512*512)) * 100
300
+
301
+ # Comprehensive analysis report
 
 
 
 
 
302
  analysis_text = f"""
303
+ ## 🧠 Comprehensive Brain Tumor Analysis
304
+
305
+ ### 🎯 Detection Summary:
306
+ - **Tumors Detected**: **{len(final_masks)} tumor region(s)**
307
+ - **Total Tumor Area**: {total_tumor_pixels:,} pixels ({total_percentage:.2f}%)
308
+ - **Detection Model**: {'Mask R-CNN Instance Segmentation' if hasattr(current_model, 'roi_heads') else 'Enhanced U-Net Segmentation'}
309
+
310
+ ### πŸ“Š Individual Tumor Analysis:
311
+ """
312
+
313
+ for i, area in enumerate(tumor_areas):
314
+ percentage = (area / (512*512)) * 100
315
+ analysis_text += f"- **Tumor {i+1}**: {area:,} pixels ({percentage:.2f}%)\n"
316
 
317
+ analysis_text += f"""
 
 
 
 
318
 
319
  ### πŸ”¬ Technical Details:
320
+ - **Enhancement**: CLAHE + Histogram Equalization + Edge Enhancement
321
+ - **Resolution**: 512Γ—512 pixels for high-precision detection
322
+ - **Detection Threshold**: Multiple confidence levels
323
+ - **Processing**: GPU-accelerated inference
324
+
325
+ ### 🎯 Clinical Insights:
326
+ - **Status**: {'πŸ”΄ MULTIPLE TUMORS DETECTED' if len(final_masks) > 1 else 'πŸ”΄ TUMOR DETECTED' if len(final_masks) == 1 else '🟒 NO TUMORS DETECTED'}
327
+ - **Complexity**: {'High (multiple lesions)' if len(final_masks) > 1 else 'Standard (single lesion)' if len(final_masks) == 1 else 'Normal brain'}
328
+ - **Recommendation**: {'Immediate specialist consultation' if total_percentage > 2.0 else 'Medical evaluation advised' if total_percentage > 0 else 'Regular monitoring'}
329
+
330
+ ### ⚠️ Medical Disclaimer:
331
+ This AI analysis is for **research purposes only**. Results should be verified by qualified radiologists. Not for diagnostic use.
 
 
 
 
 
 
 
 
 
 
332
  """
333
 
334
+ print("βœ… Comprehensive tumor analysis completed!")
335
  return result_image, analysis_text
336
 
337
  except Exception as e:
338
+ error_msg = f"❌ Error during tumor detection: {str(e)}"
339
  print(error_msg)
340
  return None, error_msg
341
 
342
  def clear_all():
343
+ return None, None, "Upload a brain MRI image for comprehensive tumor analysis."
 
344
 
345
+ # Enhanced CSS
346
  css = """
347
  .gradio-container {
348
  max-width: 1400px !important;
 
352
  text-align: center;
353
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
354
  color: white;
355
+ padding: 30px;
 
 
 
 
 
356
  border-radius: 15px;
357
+ margin-bottom: 30px;
358
+ box-shadow: 0 10px 20px rgba(0,0,0,0.1);
 
 
 
 
 
 
 
 
 
 
 
359
  }
360
  """
361
 
362
+ # Create comprehensive Gradio interface
363
+ with gr.Blocks(css=css, title="🧠 Comprehensive Brain Tumor Detection") as app:
364
 
 
365
  gr.HTML("""
366
  <div id="title">
367
+ <h1>🧠 Advanced Brain Tumor Detection AI</h1>
368
+ <p style="font-size: 18px; margin-top: 15px;">
369
+ Detects ALL Tumors β€’ Instance Segmentation β€’ Multi-Tumor Analysis
370
  </p>
371
+ <p style="font-size: 14px; margin-top: 10px; opacity: 0.9;">
372
+ Powered by Mask R-CNN + Enhanced Image Processing
373
  </p>
374
  </div>
375
  """)
376
 
377
  with gr.Row():
378
  with gr.Column(scale=1):
379
+ gr.Markdown("### πŸ“€ Upload Brain MRI")
380
 
381
  image_input = gr.Image(
382
+ label="Brain MRI Scan",
383
  type="pil",
384
  sources=["upload", "webcam"],
385
  height=350
386
  )
387
 
388
  with gr.Row():
389
+ analyze_btn = gr.Button("πŸ” Detect All Tumors", variant="primary", scale=2, size="lg")
390
+ clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary", scale=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
 
392
  with gr.Column(scale=2):
393
+ gr.Markdown("### πŸ“Š Comprehensive Analysis")
394
 
395
  output_image = gr.Image(
396
+ label="Complete Tumor Analysis",
397
  type="pil",
398
+ height=600
 
399
  )
400
 
401
  analysis_output = gr.Markdown(
402
+ value="Upload a brain MRI image to detect and analyze ALL tumors present.",
403
  elem_id="analysis"
404
  )
405
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
406
  # Event handlers
407
+ analyze_btn.click(
408
+ fn=detect_all_tumors,
409
  inputs=[image_input],
410
  outputs=[output_image, analysis_output],
411
  show_progress=True
 
417
  outputs=[image_input, output_image, analysis_output]
418
  )
419
 
 
420
  if __name__ == "__main__":
421
+ print("πŸš€ Starting Comprehensive Brain Tumor Detection System...")
 
 
 
 
422
  app.launch(
423
  server_name="0.0.0.0",
424
  server_port=7860,