google-labs-jules[bot] commited on
Commit
6477883
Β·
1 Parent(s): 2f2ad4d

Create a complete brain tumor segmentation application using Gradio.

Browse files

This commit includes the following files as specified:
- `app.py`: The main Gradio application.
- `requirements.txt`: Project dependencies.
- `.gitignore`: Standard gitignore for a Python project.
- `README.md`: Documentation for the Hugging Face Space.

Files changed (4) hide show
  1. .gitignore +44 -0
  2. README.md +44 -0
  3. app.py +299 -0
  4. requirements.txt +8 -0
.gitignore ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.py[cod]
3
+ *$py.class
4
+ *.so
5
+ .Python
6
+ build/
7
+ develop-eggs/
8
+ dist/
9
+ downloads/
10
+ eggs/
11
+ .eggs/
12
+ lib/
13
+ lib64/
14
+ parts/
15
+ sdist/
16
+ var/
17
+ wheels/
18
+ *.egg-info/
19
+ .installed.cfg
20
+ *.egg
21
+ MANIFEST
22
+
23
+ # PyTorch
24
+ *.pth
25
+ *.pt
26
+
27
+ # Jupyter Notebook
28
+ .ipynb_checkpoints
29
+
30
+ # Environment
31
+ .env
32
+ .venv
33
+ env/
34
+ venv/
35
+
36
+ # IDE
37
+ .vscode/
38
+ .idea/
39
+ *.swp
40
+ *.swo
41
+
42
+ # OS
43
+ .DS_Store
44
+ Thumbs.db
README.md ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Brain Tumor Segmentation AI
3
+ emoji: 🧠
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 4.0.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ # 🧠 Brain Tumor Segmentation AI
14
+
15
+ An advanced deep learning application for automatic brain tumor detection and segmentation in MRI images.
16
+
17
+ ## Features
18
+
19
+ - πŸ“€ **Easy Upload**: Support for image upload and camera capture
20
+ - 🎯 **Accurate Segmentation**: Uses pre-trained U-Net model for precise tumor detection
21
+ - πŸ“Š **Detailed Analysis**: Provides tumor statistics and visual overlays
22
+ - 🌐 **Web-based Interface**: No installation required, runs in browser
23
+ - πŸ“± **Mobile Friendly**: Responsive design works on all devices
24
+
25
+ ## How to Use
26
+
27
+ 1. Upload an MRI brain scan image or use your camera
28
+ 2. Click "Analyze Image" or wait for auto-processing
29
+ 3. View the segmentation results and analysis report
30
+
31
+ ## Technology
32
+
33
+ - **Model**: Pre-trained U-Net architecture
34
+ - **Framework**: PyTorch
35
+ - **Interface**: Gradio
36
+ - **Hosting**: Hugging Face Spaces
37
+
38
+ ## Medical Disclaimer
39
+
40
+ ⚠️ **Important**: This tool is for research and educational purposes only. Do not use for medical diagnosis. Always consult qualified healthcare professionals for medical advice.
41
+
42
+ ## License
43
+
44
+ MIT License - see LICENSE file for details.
app.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import cv2
5
+ from PIL import Image
6
+ import matplotlib.pyplot as plt
7
+ import io
8
+ import base64
9
+ from torchvision import transforms
10
+ import torch.nn.functional as F
11
+
12
+ # Load the pretrained model
13
+ @gr.utils.cache
14
+ def load_model():
15
+ """Load the pretrained brain segmentation model"""
16
+ try:
17
+ model = torch.hub.load(
18
+ 'mateuszbuda/brain-segmentation-pytorch',
19
+ 'unet',
20
+ in_channels=3,
21
+ out_channels=1,
22
+ init_features=32,
23
+ pretrained=True,
24
+ force_reload=False
25
+ )
26
+ model.eval()
27
+ return model
28
+ except Exception as e:
29
+ print(f"Error loading model: {e}")
30
+ return None
31
+
32
+ # Initialize model
33
+ model = load_model()
34
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
35
+ if model:
36
+ model = model.to(device)
37
+
38
+ def preprocess_image(image):
39
+ """Preprocess the input image for the model"""
40
+ if isinstance(image, np.ndarray):
41
+ image = Image.fromarray(image)
42
+
43
+ # Convert to RGB if not already
44
+ if image.mode != 'RGB':
45
+ image = image.convert('RGB')
46
+
47
+ # Resize to 256x256 (model's expected input size)
48
+ image = image.resize((256, 256), Image.Resampling.LANCZOS)
49
+
50
+ # Convert to tensor and normalize
51
+ transform = transforms.Compose([
52
+ transforms.ToTensor(),
53
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
54
+ std=[0.229, 0.224, 0.225])
55
+ ])
56
+
57
+ image_tensor = transform(image).unsqueeze(0) # Add batch dimension
58
+ return image_tensor, image
59
+
60
+ def create_overlay_visualization(original_img, mask, alpha=0.6):
61
+ """Create an overlay visualization of the segmentation"""
62
+ # Convert original image to numpy array
63
+ original_np = np.array(original_img)
64
+
65
+ # Create colored mask (red for tumor regions)
66
+ colored_mask = np.zeros_like(original_np)
67
+ colored_mask[:, :, 0] = mask * 255 # Red channel for tumor
68
+
69
+ # Create overlay
70
+ overlay = cv2.addWeighted(original_np, 1-alpha, colored_mask, alpha, 0)
71
+
72
+ return overlay
73
+
74
+ def predict_tumor(image):
75
+ """Main prediction function"""
76
+ if model is None:
77
+ return None, "❌ Model failed to load. Please try again."
78
+
79
+ if image is None:
80
+ return None, "⚠️ Please upload an image first."
81
+
82
+ try:
83
+ # Preprocess the image
84
+ input_tensor, original_img = preprocess_image(image)
85
+ input_tensor = input_tensor.to(device)
86
+
87
+ # Make prediction
88
+ with torch.no_grad():
89
+ prediction = model(input_tensor)
90
+ # Apply sigmoid to get probability map
91
+ prediction = torch.sigmoid(prediction)
92
+ # Convert to numpy
93
+ prediction = prediction.squeeze().cpu().numpy()
94
+
95
+ # Threshold the prediction (you can adjust this threshold)
96
+ threshold = 0.5
97
+ binary_mask = (prediction > threshold).astype(np.uint8)
98
+
99
+ # Create visualizations
100
+ # 1. Original image
101
+ original_array = np.array(original_img)
102
+
103
+ # 2. Segmentation mask
104
+ mask_colored = np.zeros((256, 256, 3), dtype=np.uint8)
105
+ mask_colored[:, :, 0] = binary_mask * 255 # Red channel
106
+
107
+ # 3. Overlay
108
+ overlay = create_overlay_visualization(original_img, binary_mask, alpha=0.4)
109
+
110
+ # 4. Side-by-side comparison
111
+ fig, axes = plt.subplots(1, 3, figsize=(15, 5))
112
+
113
+ axes[0].imshow(original_array)
114
+ axes[0].set_title('Original Image', fontsize=14, fontweight='bold')
115
+ axes[0].axis('off')
116
+
117
+ axes[1].imshow(mask_colored)
118
+ axes[1].set_title('Tumor Segmentation', fontsize=14, fontweight='bold')
119
+ axes[1].axis('off')
120
+
121
+ axes[2].imshow(overlay)
122
+ axes[2].set_title('Overlay (Red = Tumor)', fontsize=14, fontweight='bold')
123
+ axes[2].axis('off')
124
+
125
+ plt.tight_layout()
126
+
127
+ # Save plot to bytes
128
+ buf = io.BytesIO()
129
+ plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
130
+ buf.seek(0)
131
+ plt.close()
132
+
133
+ # Convert to PIL Image
134
+ result_image = Image.open(buf)
135
+
136
+ # Calculate tumor statistics
137
+ total_pixels = 256 * 256
138
+ tumor_pixels = np.sum(binary_mask)
139
+ tumor_percentage = (tumor_pixels / total_pixels) * 100
140
+
141
+ # Create analysis report
142
+ analysis_text = f"""
143
+ ## 🧠 Brain Tumor Segmentation Analysis
144
+
145
+ **πŸ“Š Tumor Statistics:**
146
+ - Total pixels analyzed: {total_pixels:,}
147
+ - Tumor pixels detected: {tumor_pixels:,}
148
+ - Tumor area percentage: {tumor_percentage:.2f}%
149
+
150
+ **🎯 Model Performance:**
151
+ - Model: U-Net with attention mechanism
152
+ - Input resolution: 256Γ—256 pixels
153
+ - Detection threshold: {threshold}
154
+
155
+ **⚠️ Medical Disclaimer:**
156
+ This is an AI tool for research purposes only.
157
+ Always consult qualified medical professionals for diagnosis.
158
+ """
159
+
160
+ return result_image, analysis_text
161
+
162
+ except Exception as e:
163
+ error_msg = f"❌ Error during prediction: {str(e)}"
164
+ return None, error_msg
165
+
166
+ def clear_all():
167
+ """Clear all inputs and outputs"""
168
+ return None, None, ""
169
+
170
+ # Custom CSS for better styling
171
+ css = """
172
+ #main-container {
173
+ max-width: 1200px;
174
+ margin: 0 auto;
175
+ }
176
+ #title {
177
+ text-align: center;
178
+ background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
179
+ color: white;
180
+ padding: 20px;
181
+ border-radius: 10px;
182
+ margin-bottom: 20px;
183
+ }
184
+ #upload-box {
185
+ border: 2px dashed #ccc;
186
+ border-radius: 10px;
187
+ padding: 20px;
188
+ text-align: center;
189
+ margin: 10px 0;
190
+ }
191
+ .output-image {
192
+ border-radius: 10px;
193
+ box-shadow: 0 4px 8px rgba(0,0,0,0.1);
194
+ }
195
+ """
196
+
197
+ # Create Gradio interface
198
+ with gr.Blocks(css=css, title="Brain Tumor Segmentation") as app:
199
+
200
+ # Header
201
+ gr.HTML("""
202
+ <div id="title">
203
+ <h1>🧠 Brain Tumor Segmentation AI</h1>
204
+ <p>Upload an MRI brain scan to detect and visualize tumor regions using deep learning</p>
205
+ </div>
206
+ """)
207
+
208
+ with gr.Row():
209
+ with gr.Column(scale=1):
210
+ gr.HTML("<h3>πŸ“€ Input Image</h3>")
211
+
212
+ # Image input with camera option
213
+ image_input = gr.Image(
214
+ label="Upload Brain MRI Scan",
215
+ type="pil",
216
+ sources=["upload", "webcam"], # Allow both upload and camera
217
+ height=300
218
+ )
219
+
220
+ with gr.Row():
221
+ predict_btn = gr.Button("πŸ” Analyze Image", variant="primary", size="lg")
222
+ clear_btn = gr.Button("πŸ—‘οΈ Clear All", variant="secondary")
223
+
224
+ gr.HTML("""
225
+ <div style="margin-top: 20px; padding: 15px; background-color: #f0f8ff; border-radius: 8px;">
226
+ <h4>πŸ“‹ Instructions:</h4>
227
+ <ul>
228
+ <li>Upload a brain MRI scan image</li>
229
+ <li>Supported formats: PNG, JPG, JPEG</li>
230
+ <li>For best results, use clear, high-contrast MRI images</li>
231
+ <li>You can also use the camera to capture an image from your device</li>
232
+ </ul>
233
+ </div>
234
+ """)
235
+
236
+ with gr.Column(scale=2):
237
+ gr.HTML("<h3>πŸ“Š Segmentation Results</h3>")
238
+
239
+ # Output image
240
+ output_image = gr.Image(
241
+ label="Segmentation Results",
242
+ type="pil",
243
+ height=400,
244
+ elem_classes=["output-image"]
245
+ )
246
+
247
+ # Analysis text
248
+ analysis_output = gr.Markdown(
249
+ label="Analysis Report",
250
+ value="Upload an image and click 'Analyze Image' to see results."
251
+ )
252
+
253
+ # Add footer with information
254
+ gr.HTML("""
255
+ <div style="margin-top: 30px; padding: 20px; background-color: #f9f9f9; border-radius: 10px;">
256
+ <h4>πŸ”¬ About This Tool</h4>
257
+ <p><strong>Model:</strong> Pre-trained U-Net architecture optimized for brain tumor segmentation</p>
258
+ <p><strong>Technology:</strong> PyTorch, Deep Learning, Computer Vision</p>
259
+ <p><strong>Dataset:</strong> Trained on medical MRI brain scans</p>
260
+
261
+ <h4>⚠️ Important Medical Disclaimer</h4>
262
+ <p style="color: #d73027; font-weight: bold;">
263
+ This AI tool is for research and educational purposes only. It should NOT be used for medical diagnosis.
264
+ Always consult qualified healthcare professionals for medical advice and diagnosis.
265
+ </p>
266
+
267
+ <p style="text-align: center; margin-top: 20px; color: #666;">
268
+ Made with ❀️ using Gradio β€’ Powered by PyTorch β€’ Hosted on πŸ€— Hugging Face Spaces
269
+ </p>
270
+ </div>
271
+ """)
272
+
273
+ # Event handlers
274
+ predict_btn.click(
275
+ fn=predict_tumor,
276
+ inputs=[image_input],
277
+ outputs=[output_image, analysis_output]
278
+ )
279
+
280
+ clear_btn.click(
281
+ fn=clear_all,
282
+ outputs=[image_input, output_image, analysis_output]
283
+ )
284
+
285
+ # Auto-predict when image is uploaded
286
+ image_input.change(
287
+ fn=predict_tumor,
288
+ inputs=[image_input],
289
+ outputs=[output_image, analysis_output]
290
+ )
291
+
292
+ # Launch the app
293
+ if __name__ == "__main__":
294
+ app.launch(
295
+ share=True,
296
+ server_name="0.0.0.0",
297
+ server_port=7860,
298
+ show_error=True
299
+ )
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch>=1.9.0
2
+ torchvision>=0.10.0
3
+ gradio>=4.0.0
4
+ opencv-python>=4.5.0
5
+ Pillow>=8.0.0
6
+ numpy>=1.21.0
7
+ matplotlib>=3.3.0
8
+ scikit-image>=0.18.0