PranayChamala committed
Commit 8dcd1f3 · Parent: 5e4b22f

initialized the first deployment

41598_2023_41576_Fig1_HTML.jpg ADDED
Dockerfile ADDED
@@ -0,0 +1,20 @@
+ # Use the official Python 3.9 image
+ FROM python:3.9
+
+ # Create a non-root user and switch to it
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ # Set working directory
+ WORKDIR /app
+
+ # Copy requirements and install
+ COPY --chown=user ./requirements.txt requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ # Copy the rest of the app
+ COPY --chown=user . /app
+
+ # Run the app with uvicorn on port 7860
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
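As a quick sanity check after building and running the image (for example `docker build -t uspark .` followed by `docker run -p 7860:7860 uspark`), a minimal Python sketch like the one below can confirm the FastAPI app is reachable. The image tag and localhost URL are assumptions; FastAPI serves its interactive docs at /docs by default:

import requests

# Hypothetical local smoke test against the container's mapped port.
resp = requests.get("http://localhost:7860/docs", timeout=10)
print(resp.status_code)  # expect 200 once uvicorn has started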
__pycache__/chatbot.cpython-313.pyc ADDED
Binary file (1.54 kB).
 
__pycache__/mediseg.cpython-313.pyc ADDED
Binary file (26.8 kB).
 
app.py ADDED
@@ -0,0 +1,392 @@
+ # app.py
+ import os, io, base64, cv2, torch, numpy as np
+ from PIL import Image
+ from flask import Flask, request, render_template, jsonify
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.models as models
+ import torchvision.transforms as transforms
+ from monai.transforms import EnsureChannelFirst, ScaleIntensity, Resize, ToTensor
+
+ # Enable debug logging
+ import logging
+ logging.basicConfig(level=logging.DEBUG)
+
+ # -------------------------------
+ # Global Setup
+ # -------------------------------
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ def pil_to_base64(pil_img):
+     buff = io.BytesIO()
+     pil_img.save(buff, format="JPEG")
+     return base64.b64encode(buff.getvalue()).decode("utf-8")
+
+ # -------------------------------
+ # 1. CLASSIFIER MODULE (DenseNet121 via MONAI)
+ # -------------------------------
+ CLASS_NAMES = ['AbdomenCT', 'BreastMRI', 'Chest Xray', 'ChestCT',
+                'Endoscopy', 'Hand Xray', 'HeadCT', 'HeadMRI']
+ from monai.networks.nets import DenseNet121
+ def load_classifier_model(model_path):
+     model = DenseNet121(
+         spatial_dims=2,
+         in_channels=3,
+         out_channels=len(CLASS_NAMES)
+     ).to(device)
+     state_dict = torch.load(model_path, map_location=device)
+     if isinstance(state_dict, dict) and "state_dict" in state_dict:
+         state_dict = state_dict["state_dict"]
+     model.load_state_dict(state_dict, strict=False)
+     model.eval()
+     return model
+
+ def load_and_preprocess_image_classifier(image_path):
+     image_path = image_path.strip()
+     if image_path.lower().endswith((".jpg", ".jpeg", ".png")):
+         image = Image.open(image_path).convert("RGB")
+         image = np.array(image)
+     elif image_path.lower().endswith((".nii", ".nii.gz")):
+         import nibabel as nib
+         image = nib.load(image_path).get_fdata()
+         image = np.squeeze(image)
+         if len(image.shape) == 4:
+             image = image[..., 0]
+         if len(image.shape) == 3:
+             # Take the middle slice of the volume
+             image = image[:, :, image.shape[2] // 2]
+         if len(image.shape) == 2:
+             image = np.stack([image] * 3, axis=-1)
+     elif image_path.lower().endswith(".dcm"):
+         import pydicom
+         dicom_data = pydicom.dcmread(image_path)
+         image = dicom_data.pixel_array
+         if len(image.shape) == 2:
+             image = np.stack([image] * 3, axis=-1)
+     else:
+         raise ValueError("Unsupported file format!")
+     if len(image.shape) == 3 and image.shape[-1] == 3:
+         image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
+     else:
+         raise ValueError(f"Unexpected image shape: {image.shape}")
+     image = torch.tensor(image, dtype=torch.float32)
+     image = ScaleIntensity()(image)
+     image = Resize((224, 224))(image)
+     image = image.unsqueeze(0)
+     return image.to(device)
+
+ def classify_medical_image(image_path, classifier_model):
+     image_tensor = load_and_preprocess_image_classifier(image_path)
+     with torch.no_grad():
+         output = classifier_model(image_tensor)
+         pred_class = torch.argmax(output, dim=1).item()
+     return CLASS_NAMES[pred_class]
+
+ # -------------------------------
+ # 2. BRAIN TUMOR SEGMENTATION MODULE (UNetMulti)
+ # -------------------------------
+ class DoubleConvUNet(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(DoubleConvUNet, self).__init__()
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, 3, padding=1),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(out_channels, out_channels, 3, padding=1),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(inplace=True)
+         )
+     def forward(self, x):
+         return self.conv(x)
+
+ class UNetMulti(nn.Module):
+     def __init__(self, in_channels=3, out_channels=4):
+         super(UNetMulti, self).__init__()
+         self.down1 = DoubleConvUNet(in_channels, 64)
+         self.pool1 = nn.MaxPool2d(2)
+         self.down2 = DoubleConvUNet(64, 128)
+         self.pool2 = nn.MaxPool2d(2)
+         self.down3 = DoubleConvUNet(128, 256)
+         self.pool3 = nn.MaxPool2d(2)
+         self.down4 = DoubleConvUNet(256, 512)
+         self.pool4 = nn.MaxPool2d(2)
+         self.bottleneck = DoubleConvUNet(512, 1024)
+         self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
+         self.conv4 = DoubleConvUNet(1024, 512)
+         self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
+         self.conv3 = DoubleConvUNet(512, 256)
+         self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
+         self.conv2 = DoubleConvUNet(256, 128)
+         self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
+         self.conv1 = DoubleConvUNet(128, 64)
+         self.final_conv = nn.Conv2d(64, out_channels, 1)
+     def forward(self, x):
+         c1 = self.down1(x)
+         p1 = self.pool1(c1)
+         c2 = self.down2(p1)
+         p2 = self.pool2(c2)
+         c3 = self.down3(p2)
+         p3 = self.pool3(c3)
+         c4 = self.down4(p3)
+         p4 = self.pool4(c4)
+         bn = self.bottleneck(p4)
+         u4 = self.up4(bn)
+         merge4 = torch.cat([u4, c4], dim=1)
+         c5 = self.conv4(merge4)
+         u3 = self.up3(c5)
+         merge3 = torch.cat([u3, c3], dim=1)
+         c6 = self.conv3(merge3)
+         u2 = self.up2(c6)
+         merge2 = torch.cat([u2, c2], dim=1)
+         c7 = self.conv2(merge2)
+         u1 = self.up1(c7)
+         merge1 = torch.cat([u1, c1], dim=1)
+         c8 = self.conv1(merge1)
+         return self.final_conv(c8)
+
+ def process_brain_tumor_return(image, model_path="models/brain_tumor_unet_multiclass.pth"):
+     logging.debug("Processing brain tumor segmentation")
+     model = UNetMulti(in_channels=3, out_channels=4).to(device)
+     model.load_state_dict(torch.load(model_path, map_location=device))
+     model.eval()
+     transform_img = transforms.Compose([
+         transforms.Resize((256, 256)),
+         transforms.ToTensor()
+     ])
+     input_tensor = transform_img(image).unsqueeze(0).to(device)
+     with torch.no_grad():
+         output = model(input_tensor)
+         preds = torch.argmax(output, dim=1).squeeze().cpu().numpy()
+     image_np = transform_img(image).permute(1, 2, 0).cpu().numpy()
+     # Normalize class indices to [0, 255]; the epsilon guards against an all-zero mask
+     overlay = cv2.applyColorMap(np.uint8(255 * preds / (np.max(preds) + 1e-8)), cv2.COLORMAP_JET)
+     overlay = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
+     blended = cv2.addWeighted(np.uint8(image_np * 255), 0.6, overlay, 0.4, 0)
+     orig_pil = Image.fromarray((image_np * 255).astype(np.uint8))
+     mask_pil = Image.fromarray(overlay)
+     overlay_pil = Image.fromarray(blended)
+     return {
+         "original": pil_to_base64(orig_pil),
+         "mask": pil_to_base64(mask_pil),
+         "overlay": pil_to_base64(overlay_pil)
+     }
+
+ # -------------------------------
+ # 3. ENDOSCOPY POLYP DETECTION MODULE (Binary UNet)
+ # -------------------------------
+ class UNetBinary(nn.Module):
+     def __init__(self, in_channels=3, out_channels=1):
+         super(UNetBinary, self).__init__()
+         self.down1 = DoubleConvUNet(in_channels, 64)
+         self.pool1 = nn.MaxPool2d(2)
+         self.down2 = DoubleConvUNet(64, 128)
+         self.pool2 = nn.MaxPool2d(2)
+         self.down3 = DoubleConvUNet(128, 256)
+         self.pool3 = nn.MaxPool2d(2)
+         self.down4 = DoubleConvUNet(256, 512)
+         self.pool4 = nn.MaxPool2d(2)
+         self.bottleneck = DoubleConvUNet(512, 1024)
+         self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
+         self.conv4 = DoubleConvUNet(1024, 512)
+         self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
+         self.conv3 = DoubleConvUNet(512, 256)
+         self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
+         self.conv2 = DoubleConvUNet(256, 128)
+         self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
+         self.conv1 = DoubleConvUNet(128, 64)
+         self.final_conv = nn.Conv2d(64, out_channels, 1)
+     def forward(self, x):
+         c1 = self.down1(x)
+         p1 = self.pool1(c1)
+         c2 = self.down2(p1)
+         p2 = self.pool2(c2)
+         c3 = self.down3(p2)
+         p3 = self.pool3(c3)
+         c4 = self.down4(p3)
+         p4 = self.pool4(c4)
+         bn = self.bottleneck(p4)
+         u4 = self.up4(bn)
+         merge4 = torch.cat([u4, c4], dim=1)
+         c5 = self.conv4(merge4)
+         u3 = self.up3(c5)
+         merge3 = torch.cat([u3, c3], dim=1)
+         c6 = self.conv3(merge3)
+         u2 = self.up2(c6)
+         merge2 = torch.cat([u2, c2], dim=1)
+         c7 = self.conv2(merge2)
+         u1 = self.up1(c7)
+         merge1 = torch.cat([u1, c1], dim=1)
+         c8 = self.conv1(merge1)
+         return self.final_conv(c8)
+
+ def process_endoscopy_return(image, model_path="models/endoscopy_unet.pth"):
+     model = UNetBinary(in_channels=3, out_channels=1).to(device)
+     model.load_state_dict(torch.load(model_path, map_location=device))
+     model.eval()
+     transform_img = transforms.Compose([
+         transforms.Resize((256, 256)),
+         transforms.ToTensor()
+     ])
+     input_tensor = transform_img(image).unsqueeze(0).to(device)
+     with torch.no_grad():
+         output = model(input_tensor)
+         prob = torch.sigmoid(output)
+         mask = (prob > 0.5).float().squeeze().cpu().numpy()
+     image_np = transform_img(image).permute(1, 2, 0).cpu().numpy()
+     overlay = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
+     overlay = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
+     blended = cv2.addWeighted(np.uint8(image_np * 255), 0.6, overlay, 0.4, 0)
+     orig_pil = Image.fromarray((image_np * 255).astype(np.uint8))
+     mask_pil = Image.fromarray(overlay)
+     overlay_pil = Image.fromarray(blended)
+     return {
+         "original": pil_to_base64(orig_pil),
+         "mask": pil_to_base64(mask_pil),
+         "overlay": pil_to_base64(overlay_pil)
+     }
+
+ # -------------------------------
+ # 4. PNEUMONIA DETECTION MODULE (Grad-CAM on ResNet18)
+ # -------------------------------
+ class GradCAM_Pneumonia:
+     def __init__(self, model, target_layer):
+         self.model = model
+         self.target_layer = target_layer
+         self.gradients = None
+         self.activations = None
+         self.hook_handles = []
+         self._register_hooks()
+     def _register_hooks(self):
+         def forward_hook(module, input, output):
+             self.activations = output.detach()
+         def backward_hook(module, grad_in, grad_out):
+             self.gradients = grad_out[0].detach()
+         handle1 = self.target_layer.register_forward_hook(forward_hook)
+         handle2 = self.target_layer.register_backward_hook(backward_hook)
+         self.hook_handles.extend([handle1, handle2])
+     def remove_hooks(self):
+         for handle in self.hook_handles:
+             handle.remove()
+     def generate(self, input_image, target_class=None):
+         output = self.model(input_image)
+         if target_class is None:
+             target_class = output.argmax(dim=1).item()
+         self.model.zero_grad()
+         one_hot = torch.zeros_like(output)
+         one_hot[0, target_class] = 1
+         with torch.enable_grad():
+             output.backward(gradient=one_hot, retain_graph=True)
+         # Global-average-pool the gradients to weight the activation maps
+         weights = self.gradients.mean(dim=(2, 3), keepdim=True)
+         cam = (weights * self.activations).sum(dim=1, keepdim=True)
+         cam = F.relu(cam)
+         cam = cam.squeeze().cpu().numpy()
+         _, _, H, W = input_image.shape
+         cam = cv2.resize(cam, (W, H))
+         cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam) + 1e-8)
+         return cam, output
+
+ def process_pneumonia_return(image, model_path="models/pneumonia_resnet18.pth"):
+     model = models.resnet18(weights=None)  # randomly initialized; trained weights loaded below
+     num_ftrs = model.fc.in_features
+     model.fc = nn.Linear(num_ftrs, 2)  # 2 classes: normal and pneumonia
+     model.load_state_dict(torch.load(model_path, map_location=device))
+     model.to(device)
+     model.eval()
+     grad_cam = GradCAM_Pneumonia(model, model.layer4)
+
+     transform_img = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                              std=[0.229, 0.224, 0.225])
+     ])
+     input_tensor = transform_img(image).unsqueeze(0).to(device)
+     # Enable gradient tracking for the input tensor
+     input_tensor.requires_grad_()
+     # Do NOT wrap the following call with torch.no_grad()
+     cam, output = grad_cam.generate(input_tensor)
+     predicted_class = output.argmax(dim=1).item()
+
+     label_text = "Pneumonia" if predicted_class == 1 else "Normal"
+
+     def get_bounding_box(heatmap, thresh=0.5, min_area=100):
+         heat_uint8 = np.uint8(255 * heatmap)
+         ret, binary = cv2.threshold(heat_uint8, int(thresh * 255), 255, cv2.THRESH_BINARY)
+         contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+         if len(contours) == 0:
+             return None
+         largest = max(contours, key=cv2.contourArea)
+         if cv2.contourArea(largest) < min_area:
+             return None
+         x, y, w, h = cv2.boundingRect(largest)
+         return (x, y, w, h)
+
+     bbox = None
+     if predicted_class == 1:
+         bbox = get_bounding_box(cam, thresh=0.5, min_area=100)
+
+     resized_image = image.resize((224, 224))
+     image_np = np.array(resized_image)
+     overlay = image_np.copy()
+     if bbox is not None:
+         x, y, w, h = bbox
+         cv2.rectangle(overlay, (x, y), (x + w, y + h), (255, 0, 0), 2)
+     cv2.putText(overlay, label_text, (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 0), 2)
+
+     heatmap_color = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
+     heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
+
+     orig_pil = Image.fromarray(image_np)
+     heatmap_pil = Image.fromarray(heatmap_color)
+     overlay_pil = Image.fromarray(overlay)
+     grad_cam.remove_hooks()
+     return {
+         "original": pil_to_base64(orig_pil),
+         "mask": pil_to_base64(heatmap_pil),
+         "overlay": pil_to_base64(overlay_pil)
+     }
+
+ # -------------------------------
+ # 5. COMPLETE PIPELINE FUNCTION
+ # -------------------------------
+ def complete_pipeline(image_path):
+     classifier_model = load_classifier_model("models/best_metric_model (4).pth")
+     predicted_modality = classify_medical_image(image_path, classifier_model)
+     logging.debug(f"Detected modality: {predicted_modality}")
+     original_image = Image.open(image_path).convert("RGB")
+     results = {"predicted_modality": predicted_modality}
+     if predicted_modality in ["HeadCT", "HeadMRI"]:
+         results["specialized"] = process_brain_tumor_return(original_image, "models/brain_tumor_unet_multiclass.pth")
+     elif predicted_modality == "Endoscopy":
+         results["specialized"] = process_endoscopy_return(original_image, "models/endoscopy_unet.pth")
+     elif predicted_modality == "Chest Xray":
+         results["specialized"] = process_pneumonia_return(original_image, "models/pneumonia_resnet18.pth")
+     else:
+         results["message"] = f"No specialized processing for modality: {predicted_modality}"
+         # The template renders the original image for this branch
+         results["original"] = pil_to_base64(original_image)
+     return results
+
+ # -------------------------------
+ # 6. FLASK API SETUP
+ # -------------------------------
+ # Flask was already imported at the top of the file
+ app = Flask(__name__)
+
+ @app.route('/', methods=['GET'])
+ def index():
+     return render_template("index.html", result=None)
+
+ @app.route('/predict', methods=['POST'])
+ def predict():
+     if 'file' not in request.files:
+         return render_template("index.html", result={"error": "No file part in the request."})
+     file = request.files['file']
+     if file.filename == '':
+         return render_template("index.html", result={"error": "No file selected."})
+     temp_path = "temp_input.jpg"
+     file.save(temp_path)
+     try:
+         result = complete_pipeline(temp_path)
+     except Exception as e:
+         result = {"error": str(e)}
+     finally:
+         # Clean up the temporary upload even if the pipeline raised
+         os.remove(temp_path)
+     return render_template("index.html", result=result)
+
+ if __name__ == '__main__':
+     app.run(host='0.0.0.0', port=5000, debug=True)
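For reference, a hedged sketch of exercising the Flask `/predict` route from a client; the sample file name and localhost URL are assumptions (the app above listens on port 5000):

import requests

# Post a sample image as multipart form data, matching request.files['file'] above.
with open("sample_scan.jpg", "rb") as f:
    resp = requests.post("http://localhost:5000/predict",
                         files={"file": ("sample_scan.jpg", f, "image/jpeg")})
print(resp.status_code)  # the route renders index.html with the pipeline result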
app/__init__.py ADDED
File without changes
app/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (151 Bytes).
 
app/__pycache__/chatbot.cpython-313.pyc ADDED
Binary file (8.13 kB).
 
app/__pycache__/database.cpython-313.pyc ADDED
Binary file (956 Bytes).
 
app/__pycache__/main.cpython-313.pyc ADDED
Binary file (3.36 kB).
 
app/__pycache__/mediseg.cpython-313.pyc ADDED
Binary file (24.6 kB).
 
app/chatbot.py ADDED
@@ -0,0 +1,149 @@
+ import pandas as pd
+ import json
+ import numpy as np
+ import faiss
+ from sklearn.feature_extraction.text import TfidfVectorizer
+ from transformers import pipeline
+
+ # -------------------------------
+ # Load disease data and preprocess
+ # -------------------------------
+ def load_disease_data(csv_path):
+     df = pd.read_csv(csv_path)
+     df.columns = df.columns.str.strip().str.lower()
+     df = df.fillna("")
+     disease_symptoms = {}
+     disease_precautions = {}
+     for _, row in df.iterrows():
+         disease = row["disease"].strip()
+         symptoms = [s.strip().lower() for s in row["symptoms"].split(",") if s.strip()]
+         precautions = [p.strip() for p in row["precautions"].split(",") if p.strip()]
+         disease_symptoms[disease] = symptoms
+         disease_precautions[disease] = precautions
+     return disease_symptoms, disease_precautions
+
+ # Load CSV data (ensure this CSV file is in the repository root)
+ disease_symptoms, disease_precautions = load_disease_data("disease_sympts_prec_full.csv")
+ known_symptoms = set()
+ for syms in disease_symptoms.values():
+     known_symptoms.update(syms)
+
+ # -------------------------------
+ # Build symptom vectorizer and FAISS index
+ # -------------------------------
+ vectorizer = TfidfVectorizer()
+ symptom_texts = [" ".join(symptoms) for symptoms in disease_symptoms.values()]
+ tfidf_matrix = vectorizer.fit_transform(symptom_texts).toarray()
+ index = faiss.IndexFlatL2(tfidf_matrix.shape[1])
+ index.add(np.array(tfidf_matrix, dtype=np.float32))
+ disease_list = list(disease_symptoms.keys())
+
+ def find_closest_disease(user_symptoms):
+     if not user_symptoms:
+         return None
+     user_vector = vectorizer.transform([" ".join(user_symptoms)]).toarray().astype("float32")
+     distances, indices = index.search(user_vector, k=1)
+     return disease_list[indices[0][0]]
+
+ # -------------------------------
+ # Load Medical NER model for symptom extraction
+ # -------------------------------
+ medical_ner = pipeline(
+     "ner",
+     model="blaze999/Medical-NER",
+     tokenizer="blaze999/Medical-NER",
+     aggregation_strategy="simple"
+ )
+
+ def extract_symptoms_ner(text):
+     results = medical_ner(text)
+     extracted = []
+     for r in results:
+         if "SIGN_SYMPTOM" in r["entity_group"]:
+             extracted.append(r["word"].lower())
+     return list(set(extracted))
+
+ def is_affirmative(answer):
+     answer_lower = answer.lower()
+     return any(word in answer_lower for word in ["yes", "yeah", "yep", "certainly", "sometimes", "a little"])
+
+ # -------------------------------
+ # Chatbot session class
+ # -------------------------------
+ class ChatbotSession:
+     def __init__(self):
+         self.conversation_history = []
+         self.reported_symptoms = set()
+         self.asked_missing = set()
+         self.awaiting_followup = None
+         self.state = "symptom_collection"  # states: symptom_collection, pain, medications
+         # Initial greeting
+         greeting = "Doctor: Hello, I am your virtual doctor. What brought you in today?"
+         self.conversation_history.append(greeting)
+         self.finished = False
+
+     def process_message(self, message: str) -> str:
+         # State: collecting symptoms
+         if self.state == "symptom_collection":
+             if message.lower() in ["exit", "quit", "no"]:
+                 self.state = "pain"
+                 prompt = "Doctor: Do you experience any pain or aches? Please rate the pain on a scale of 1 to 10 (or type 'no' if none):"
+                 self.conversation_history.append(prompt)
+                 return prompt
+             # If we are waiting on a follow-up about a specific symptom
+             if self.awaiting_followup:
+                 if is_affirmative(message):
+                     self.reported_symptoms.add(self.awaiting_followup)
+                 # Either way, the question has been asked; don't ask it again
+                 self.asked_missing.add(self.awaiting_followup)
+                 self.awaiting_followup = None
+             else:
+                 # Extract symptoms from message text
+                 ner_results = extract_symptoms_ner(message)
+                 for sym in ner_results:
+                     if sym not in self.reported_symptoms:
+                         self.reported_symptoms.add(sym)
+             # Update predicted disease
+             predicted_disease = find_closest_disease(list(self.reported_symptoms)) if self.reported_symptoms else None
+             # Check for missing symptoms if a disease is predicted
+             if predicted_disease:
+                 expected = set(disease_symptoms.get(predicted_disease, []))
+                 missing = expected - self.reported_symptoms
+                 not_asked = missing - self.asked_missing
+                 if not_asked:
+                     symptom_to_ask = list(not_asked)[0]
+                     followup = f"Are you also experiencing {symptom_to_ask}?"
+                     self.conversation_history.append("Doctor: " + followup)
+                     self.awaiting_followup = symptom_to_ask
+                     return followup
+             prompt = "Doctor: Do you have any other symptoms you'd like to mention?"
+             self.conversation_history.append(prompt)
+             return prompt
+
+         # State: asking about pain
+         elif self.state == "pain":
+             try:
+                 self.pain_level = int(message)
+             except ValueError:
+                 self.pain_level = message
+             self.state = "medications"
+             prompt = "Doctor: Have you taken any medications recently? If yes, please specify (or type 'no' if none):"
+             self.conversation_history.append(prompt)
+             return prompt
+
+         # State: asking about medications
+         elif self.state == "medications":
+             self.medications = message if message.lower() not in ["no", "none"] else "None"
+             closing = "Doctor: Thank you for providing all the information."
+             self.conversation_history.append(closing)
+             self.finished = True
+             return closing
+
+         return "Doctor: I'm sorry, I didn't understand that."
+
+     def get_data(self):
+         return {
+             "conversation": self.conversation_history,
+             "symptoms": list(self.reported_symptoms),
+             "pain_level": getattr(self, "pain_level", None),
+             "medications": getattr(self, "medications", None)
+         }
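A short sketch of driving ChatbotSession directly, outside FastAPI (requires the CSV and NER model above to be available; the scripted user replies are invented for illustration):

from app.chatbot import ChatbotSession

session = ChatbotSession()
print(session.conversation_history[0])                         # initial greeting
print(session.process_message("I have a fever and a cough"))   # symptom collection / possible follow-up
print(session.process_message("no"))                           # advances the state machine to 'pain'
print(session.process_message("4"))                            # pain level; advances to 'medications'
print(session.process_message("ibuprofen"))                    # closing message; session.finished is now True
print(session.get_data())                                      # conversation, symptoms, pain level, medications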
app/database.py ADDED
@@ -0,0 +1,14 @@
+ import os
+ from pymongo import MongoClient
+
+ # Read the connection string from the environment; replace the fallback's
+ # username and password with your own (avoid committing real credentials).
+ MONGO_URI = os.environ.get(
+     "MONGO_URI",
+     "mongodb+srv://root:[email protected]/uspark_db?retryWrites=true&w=majority"
+ )
+
+ client = MongoClient(MONGO_URI)
+ db = client["uspark_db"]
+
+ def save_chat_session(session_id: str, conversation_data: dict):
+     db.chatbot.insert_one({"session_id": session_id, **conversation_data})
+
+ def save_medseg_result(result_data: dict):
+     db.medseg.insert_one(result_data)
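Since MONGO_URI is now read from the environment, a hedged usage sketch (the local MongoDB instance and the session payload are assumptions; pymongo connects lazily, so the import alone does not require a live server):

import os
os.environ.setdefault("MONGO_URI", "mongodb://localhost:27017")  # assumed local instance

from app.database import save_chat_session

# Persist a minimal, invented session payload to the 'chatbot' collection.
save_chat_session("demo-session-id", {"conversation": [], "symptoms": ["fever"]})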
app/main.py ADDED
@@ -0,0 +1,66 @@
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Form, Depends
+ from uuid import uuid4
+ import io
+ from PIL import Image
+ from pydantic import BaseModel
+
+ # Import modules from the Uspark package
+ from app.chatbot import ChatbotSession
+ from app.mediseg import complete_pipeline_image
+ from app.database import save_chat_session, save_medseg_result
+
+ app = FastAPI(title="Uspark API")
+
+ # Ensure models are loaded from the 'models' directory within 'Uspark'
+ import sys
+ import os
+ sys.path.append(os.path.join(os.path.dirname(__file__), "../models"))
+
+ class ChatMessage(BaseModel):
+     session_id: str
+     message: str
+
+ # In-memory session store (for demo purposes; consider persistent storage for production)
+ sessions = {}
+
+ @app.post("/chat/start")
+ def start_chat():
+     session_id = str(uuid4())
+     session = ChatbotSession()
+     sessions[session_id] = session
+     return {"session_id": session_id, "message": session.conversation_history[0]}
+
+ @app.post("/chat/message")
+ def chat_message(chat: ChatMessage):
+     if chat.session_id not in sessions:
+         raise HTTPException(status_code=404, detail="Invalid session_id")
+
+     session = sessions[chat.session_id]
+     response = session.process_message(chat.message)
+
+     # If the session has finished (after pain & medication), save to MongoDB and remove from memory.
+     if session.finished:
+         save_chat_session(chat.session_id, session.get_data())
+         del sessions[chat.session_id]
+
+     return {"response": response, "conversation": session.conversation_history}
+
+ @app.post("/medseg")
+ async def medseg_endpoint(file: UploadFile = File(...)):
+     try:
+         contents = await file.read()
+         image = Image.open(io.BytesIO(contents)).convert("RGB")
+     except Exception:
+         raise HTTPException(status_code=400, detail="Invalid image file")
+
+     # Process image through the complete pipeline (classification + segmentation)
+     result = complete_pipeline_image(image)
+
+     # Save result to MongoDB
+     result_record = {
+         "filename": file.filename,
+         "result": result  # Contains predicted modality and base64 image(s)
+     }
+     save_medseg_result(result_record)
+
+     return result
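A client-side sketch for the three endpoints, assuming uvicorn is serving app.main:app on localhost:7860 as in the Dockerfile (the sample file name is an assumption):

import requests

BASE = "http://localhost:7860"

# Start a chat session, then send one message tied to the returned session_id.
start = requests.post(f"{BASE}/chat/start").json()
sid = start["session_id"]
reply = requests.post(f"{BASE}/chat/message",
                      json={"session_id": sid, "message": "I have a headache"}).json()
print(reply["response"])

# Upload an image to the medseg pipeline.
with open("sample_scan.jpg", "rb") as f:
    result = requests.post(f"{BASE}/medseg", files={"file": f}).json()
print(result["predicted_modality"])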
app/mediseg.py ADDED
@@ -0,0 +1,372 @@
+ import os
+ import torch
+ import numpy as np
+ import nibabel as nib
+ import pydicom
+ import cv2
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.models as models
+ from torchvision import transforms
+ from monai.transforms import EnsureChannelFirst, ScaleIntensity, Resize, ToTensor
+ from io import BytesIO
+ import base64
+
+ # Set device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # -------------------------------
+ # CLASSIFIER MODULE (Medical Classifier)
+ # -------------------------------
+ class_names = ['AbdomenCT', 'BreastMRI', 'Chest Xray', 'ChestCT',
+                'Endoscopy', 'Hand Xray', 'HeadCT', 'HeadMRI']
+
+ # Update model path to load from models folder
+ model_path_classifier = os.path.join("models", "best_metric_model (4).pth")
+
+ from monai.networks.nets import DenseNet121
+ classifier_model = DenseNet121(
+     spatial_dims=2,
+     in_channels=3,
+     out_channels=len(class_names)
+ ).to(device)
+
+ state_dict = torch.load(model_path_classifier, map_location=device)
+ classifier_model.load_state_dict(state_dict, strict=False)
+ classifier_model.eval()
+
+ # A simple transform for classification from a PIL image
+ def classify_medical_image_pil(image: Image.Image) -> str:
+     transform = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Resize((224, 224))
+     ])
+     image_tensor = transform(image).unsqueeze(0).to(device)
+     with torch.no_grad():
+         output = classifier_model(image_tensor)
+         pred_class = torch.argmax(output, dim=1).item()
+     return class_names[pred_class]
+
+ # -------------------------------
+ # SPECIALIZED MODULES
+ # -------------------------------
+
+ # --- A. Brain Tumor Segmentation Module (for HeadCT/HeadMRI) ---
+ class DoubleConvUNet(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(DoubleConvUNet, self).__init__()
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(inplace=True)
+         )
+     def forward(self, x):
+         return self.conv(x)
+
+ class UNetMulti(nn.Module):
+     def __init__(self, in_channels=3, out_channels=4):
+         super(UNetMulti, self).__init__()
+         self.down1 = DoubleConvUNet(in_channels, 64)
+         self.pool1 = nn.MaxPool2d(2)
+         self.down2 = DoubleConvUNet(64, 128)
+         self.pool2 = nn.MaxPool2d(2)
+         self.down3 = DoubleConvUNet(128, 256)
+         self.pool3 = nn.MaxPool2d(2)
+         self.down4 = DoubleConvUNet(256, 512)
+         self.pool4 = nn.MaxPool2d(2)
+         self.bottleneck = DoubleConvUNet(512, 1024)
+         self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
+         self.conv4 = DoubleConvUNet(1024, 512)
+         self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
+         self.conv3 = DoubleConvUNet(512, 256)
+         self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
+         self.conv2 = DoubleConvUNet(256, 128)
+         self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
+         self.conv1 = DoubleConvUNet(128, 64)
+         self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
+
+     def forward(self, x):
+         c1 = self.down1(x)
+         p1 = self.pool1(c1)
+         c2 = self.down2(p1)
+         p2 = self.pool2(c2)
+         c3 = self.down3(p2)
+         p3 = self.pool3(c3)
+         c4 = self.down4(p3)
+         p4 = self.pool4(c4)
+         bn = self.bottleneck(p4)
+         u4 = self.up4(bn)
+         merge4 = torch.cat([u4, c4], dim=1)
+         c5 = self.conv4(merge4)
+         u3 = self.up3(c5)
+         merge3 = torch.cat([u3, c3], dim=1)
+         c6 = self.conv3(merge3)
+         u2 = self.up2(c6)
+         merge2 = torch.cat([u2, c2], dim=1)
+         c7 = self.conv2(merge2)
+         u1 = self.up1(c7)
+         merge1 = torch.cat([u1, c1], dim=1)
+         c8 = self.conv1(merge1)
+         output = self.final_conv(c8)
+         return output
+
+ def process_brain_tumor(image: Image.Image, model_path=os.path.join("models", "brain_tumor_unet_multiclass.pth")) -> str:
+     model = UNetMulti(in_channels=3, out_channels=4).to(device)
+     model.load_state_dict(torch.load(model_path, map_location=device))
+     model.eval()
+
+     transform_img = transforms.Compose([
+         transforms.Resize((256, 256)),
+         transforms.ToTensor()
+     ])
+     input_tensor = transform_img(image).unsqueeze(0).to(device)
+     with torch.no_grad():
+         output = model(input_tensor)
+         preds = torch.argmax(output, dim=1).squeeze().cpu().numpy()
+
+     image_np = np.array(image.resize((256, 256)))
+     # Create overlay and blended image; the epsilon guards against an all-zero mask
+     overlay = cv2.applyColorMap(np.uint8(255 * preds / (np.max(preds) + 1e-8)), cv2.COLORMAP_JET)
+     overlay = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
+     blended = cv2.addWeighted(np.uint8(image_np), 0.6, overlay, 0.4, 0)
+
+     # Create a figure with subplots
+     fig, ax = plt.subplots(1, 3, figsize=(18, 6))
+     ax[0].imshow(image_np)
+     ax[0].set_title("Original Image")
+     ax[0].axis("off")
+     ax[1].imshow(preds, cmap='jet')
+     ax[1].set_title("Segmentation Mask")
+     ax[1].axis("off")
+     ax[2].imshow(blended)
+     ax[2].set_title("Overlay")
+     ax[2].axis("off")
+
+     buf = BytesIO()
+     fig.savefig(buf, format="png")
+     buf.seek(0)
+     img_base64 = base64.b64encode(buf.read()).decode("utf-8")
+     plt.close(fig)
+     return img_base64
+
+ # --- B. Endoscopy Polyp Detection Module (Binary UNet) ---
+ class UNetBinary(nn.Module):
+     def __init__(self, in_channels=3, out_channels=1):
+         super(UNetBinary, self).__init__()
+         self.down1 = DoubleConvUNet(in_channels, 64)
+         self.pool1 = nn.MaxPool2d(2)
+         self.down2 = DoubleConvUNet(64, 128)
+         self.pool2 = nn.MaxPool2d(2)
+         self.down3 = DoubleConvUNet(128, 256)
+         self.pool3 = nn.MaxPool2d(2)
+         self.down4 = DoubleConvUNet(256, 512)  # down3 outputs 256 channels
+         self.pool4 = nn.MaxPool2d(2)
+         self.bottleneck = DoubleConvUNet(512, 1024)
+         self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
+         self.conv4 = DoubleConvUNet(1024, 512)
+         self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
+         self.conv3 = DoubleConvUNet(512, 256)
+         self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
+         self.conv2 = DoubleConvUNet(256, 128)
+         self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
+         self.conv1 = DoubleConvUNet(128, 64)
+         self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
+
+     def forward(self, x):
+         c1 = self.down1(x)
+         p1 = self.pool1(c1)
+         c2 = self.down2(p1)
+         p2 = self.pool2(c2)
+         c3 = self.down3(p2)
+         p3 = self.pool3(c3)
+         c4 = self.down4(p3)
+         p4 = self.pool4(c4)
+         bn = self.bottleneck(p4)
+         u4 = self.up4(bn)
+         merge4 = torch.cat([u4, c4], dim=1)
+         c5 = self.conv4(merge4)
+         u3 = self.up3(c5)
+         merge3 = torch.cat([u3, c3], dim=1)
+         c6 = self.conv3(merge3)
+         u2 = self.up2(c6)
+         merge2 = torch.cat([u2, c2], dim=1)
+         c7 = self.conv2(merge2)
+         u1 = self.up1(c7)
+         merge1 = torch.cat([u1, c1], dim=1)
+         c8 = self.conv1(merge1)
+         output = self.final_conv(c8)
+         return output
+
+ def process_endoscopy(image: Image.Image, model_path=os.path.join("models", "endoscopy_unet.pth")) -> str:
+     model = UNetBinary(in_channels=3, out_channels=1).to(device)
+     model.load_state_dict(torch.load(model_path, map_location=device))
+     model.eval()
+
+     transform_img = transforms.Compose([
+         transforms.Resize((256, 256)),
+         transforms.ToTensor()
+     ])
+     input_tensor = transform_img(image).unsqueeze(0).to(device)
+     with torch.no_grad():
+         output = model(input_tensor)
+         prob = torch.sigmoid(output)
+         mask = (prob > 0.5).float().squeeze().cpu().numpy()
+
+     image_np = np.array(image.resize((256, 256)))
+     overlay = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
+     overlay = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
+     blended = cv2.addWeighted(np.uint8(image_np), 0.6, overlay, 0.4, 0)
+
+     fig, ax = plt.subplots(1, 3, figsize=(18, 6))
+     ax[0].imshow(image_np)
+     ax[0].set_title("Actual Image")
+     ax[0].axis("off")
+     ax[1].imshow(mask, cmap='gray')
+     ax[1].set_title("Segmentation Mask")
+     ax[1].axis("off")
+     ax[2].imshow(blended)
+     ax[2].set_title("Overlay")
+     ax[2].axis("off")
+
+     buf = BytesIO()
+     fig.savefig(buf, format="png")
+     buf.seek(0)
+     img_base64 = base64.b64encode(buf.read()).decode("utf-8")
+     plt.close(fig)
+     return img_base64
+
+ # --- C. Pneumonia Detection Module (Using Grad-CAM on ResNet18) ---
+ class GradCAM_Pneumonia:
+     def __init__(self, model, target_layer):
+         self.model = model
+         self.target_layer = target_layer
+         self.gradients = None
+         self.activations = None
+         self.hook_handles = []
+         self._register_hooks()
+
+     def _register_hooks(self):
+         def forward_hook(module, input, output):
+             self.activations = output.detach()
+         def backward_hook(module, grad_in, grad_out):
+             self.gradients = grad_out[0].detach()
+         handle1 = self.target_layer.register_forward_hook(forward_hook)
+         handle2 = self.target_layer.register_backward_hook(backward_hook)
+         self.hook_handles.extend([handle1, handle2])
+
+     def remove_hooks(self):
+         for handle in self.hook_handles:
+             handle.remove()
+
+     def generate(self, input_image, target_class=None):
+         output = self.model(input_image)
+         if target_class is None:
+             target_class = output.argmax(dim=1).item()
+         self.model.zero_grad()
+         one_hot = torch.zeros_like(output)
+         one_hot[0, target_class] = 1
+         with torch.enable_grad():
+             output.backward(gradient=one_hot, retain_graph=True)
+         weights = self.gradients.mean(dim=(2, 3), keepdim=True)
+         cam = (weights * self.activations).sum(dim=1, keepdim=True)
+         cam = F.relu(cam)
+         cam = cam.squeeze().cpu().numpy()
+         _, _, H, W = input_image.shape
+         cam = cv2.resize(cam, (W, H))
+         cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam) + 1e-8)
+         return cam, output
+
+ def process_pneumonia(image: Image.Image, model_path=os.path.join("models", "pneumonia_resnet18.pth")) -> str:
+     model = models.resnet18(weights=None)  # randomly initialized; trained weights loaded below
+     num_ftrs = model.fc.in_features
+     model.fc = nn.Linear(num_ftrs, 2)  # 2 classes: normal and pneumonia
+     model.load_state_dict(torch.load(model_path, map_location=device))
+     model.to(device)
+     model.eval()
+
+     grad_cam = GradCAM_Pneumonia(model, model.layer4)
+
+     transform_img = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                              std=[0.229, 0.224, 0.225])
+     ])
+     input_tensor = transform_img(image).unsqueeze(0).to(device)
+     # Grad-CAM backpropagates to the target layer, so this must NOT run under torch.no_grad()
+     cam, output = grad_cam.generate(input_tensor)
+     predicted_class = output.argmax(dim=1).item()
+     label_text = "Pneumonia" if predicted_class == 1 else "Normal"
+
+     def get_bounding_box(heatmap, thresh=0.5, min_area=100):
+         heat_uint8 = np.uint8(255 * heatmap)
+         ret, binary = cv2.threshold(heat_uint8, int(thresh * 255), 255, cv2.THRESH_BINARY)
+         contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+         if len(contours) == 0:
+             return None
+         largest = max(contours, key=cv2.contourArea)
+         if cv2.contourArea(largest) < min_area:
+             return None
+         x, y, w, h = cv2.boundingRect(largest)
+         return (x, y, w, h)
+
+     bbox = None
+     if predicted_class == 1:
+         bbox = get_bounding_box(cam, thresh=0.5, min_area=100)
+
+     resized_image = image.resize((224, 224))
+     image_np = np.array(resized_image)
+     overlay = image_np.copy()
+     if bbox is not None:
+         x, y, w, h = bbox
+         cv2.rectangle(overlay, (x, y), (x + w, y + h), (255, 0, 0), 2)
+     cv2.putText(overlay, label_text, (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 0), 2)
+
+     heatmap_color = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
+     heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
+
+     fig, ax = plt.subplots(1, 3, figsize=(18, 6))
+     ax[0].imshow(image_np)
+     ax[0].set_title("Actual Image")
+     ax[0].axis("off")
+     ax[1].imshow(heatmap_color)
+     ax[1].set_title("Detected Output (Heatmap)")
+     ax[1].axis("off")
+     ax[2].imshow(overlay)
+     ax[2].set_title("Boxed Overlay")
+     ax[2].axis("off")
+
+     buf = BytesIO()
+     fig.savefig(buf, format="png")
+     buf.seek(0)
+     img_base64 = base64.b64encode(buf.read()).decode("utf-8")
+     plt.close(fig)
+     grad_cam.remove_hooks()
+     return img_base64
+
+ # -------------------------------
+ # COMPLETE PIPELINE FUNCTION
+ # -------------------------------
+ def complete_pipeline_image(image: Image.Image) -> dict:
+     predicted_modality = classify_medical_image_pil(image)
+     result = {"predicted_modality": predicted_modality}
+
+     if predicted_modality in ["HeadCT", "HeadMRI"]:
+         result_overlay = process_brain_tumor(image)
+         result["segmentation_result"] = result_overlay
+     elif predicted_modality == "Endoscopy":
+         result_overlay = process_endoscopy(image)
+         result["segmentation_result"] = result_overlay
+     elif predicted_modality == "Chest Xray":
+         result_overlay = process_pneumonia(image)
+         result["segmentation_result"] = result_overlay
+     else:
+         # For modalities without specialized processing, return the original image as base64
+         buf = BytesIO()
+         image.save(buf, format="PNG")
+         result["segmentation_result"] = base64.b64encode(buf.getvalue()).decode("utf-8")
+     return result
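To run the pipeline without the API, a sketch that decodes the returned base64 figure back to a PNG (the input file name is an assumption):

import base64
from PIL import Image
from app.mediseg import complete_pipeline_image

image = Image.open("sample_scan.jpg").convert("RGB")
result = complete_pipeline_image(image)
print(result["predicted_modality"])

# segmentation_result is a base64-encoded PNG figure; write it to disk.
with open("pipeline_output.png", "wb") as out:
    out.write(base64.b64decode(result["segmentation_result"]))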
disease_sympts_prec_full.csv ADDED
The diff for this file is too large to render.
 
images.jpg ADDED
models/best_metric_model (4).pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9a5e578d9e93b089eed0bdbdaa50237209fce830497136593664b16d4df720ee
+ size 28471314
models/brain_tumor_unet_multiclass.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b11bc3ebaa317154e9530f8151ce6bf7efa407abeedc62454502099889dbfe42
+ size 124269778
models/endoscopy_unet.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cac633092591e1c515da87ed7bd54de35ca4c50fec41a646b8cfef5b1e15afde
+ size 124267126
models/pneumonia_resnet18.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0a9bca18323c658623f0e62207e6c3331836ad1825a8f9c9d3aa118709d2614a
+ size 44790376
oligodendroglioma-banner.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ fastapi
+ uvicorn[standard]
+ pymongo
+ pandas
+ faiss-cpu
+ numpy
+ scikit-learn
+ transformers
+ torch
+ torchvision
+ monai
+ pydicom
+ nibabel
+ opencv-python
+ Pillow
+ matplotlib
symptom_assessment.py ADDED
@@ -0,0 +1,31 @@
+ from sklearn.feature_extraction.text import TfidfVectorizer
+
+ class SymptomAssessment:
+     def __init__(self):
+         # Example disease-symptom mapping dictionary.
+         # In practice, replace this with a robust dataset.
+         self.disease_symptoms = {
+             "Flu": ["fever", "cough", "sore throat", "fatigue"],
+             "Migraine": ["headache", "nausea", "sensitivity to light"],
+             "COVID-19": ["fever", "cough", "shortness of breath", "loss of taste"]
+         }
+         # Prepare vector space for diseases
+         self.vectorizer = TfidfVectorizer()
+         self.diseases = list(self.disease_symptoms.keys())
+         symptom_texts = [" ".join(self.disease_symptoms[d]) for d in self.diseases]
+         self.vectors = self.vectorizer.fit_transform(symptom_texts)
+
+     def assess(self, symptoms_list):
+         """
+         Given a list of reported symptoms, determine the best matching disease
+         and identify which expected symptoms are missing.
+         """
+         input_text = " ".join(symptoms_list)
+         input_vector = self.vectorizer.transform([input_text])
+         # TfidfVectorizer rows are L2-normalized, so this dot product is cosine similarity
+         similarities = (self.vectors * input_vector.T).toarray().flatten()
+         best_match_index = similarities.argmax()
+         best_disease = self.diseases[best_match_index]
+         missing_symptoms = list(set(self.disease_symptoms[best_disease]) - set(symptoms_list))
+         assessment = (f"Based on the input symptoms, {best_disease} is suspected. "
+                       f"Missing symptoms for improved diagnosis: {missing_symptoms}")
+         return missing_symptoms, assessment
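Example use of the class above with its built-in toy mapping (which disease wins depends on the tf-idf weighting):

from symptom_assessment import SymptomAssessment

assessment = SymptomAssessment()
missing, summary = assessment.assess(["fever", "cough"])
print(summary)   # e.g. Flu or COVID-19 suspected, with the symptoms not yet reported
print(missing)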
templates/index.html ADDED
@@ -0,0 +1,51 @@
+ <!DOCTYPE html>
+ <html>
+ <head>
+     <meta charset="utf-8">
+     <title>Medical Image Processing Pipeline</title>
+     <style>
+         body { font-family: Arial, sans-serif; margin: 20px; }
+         .container { max-width: 800px; margin: auto; }
+         .result img { max-width: 250px; margin: 10px; }
+         .result { display: flex; flex-wrap: wrap; }
+     </style>
+ </head>
+ <body>
+     <div class="container">
+         <h1>Medical Image Processing Pipeline</h1>
+         <form action="/predict" method="POST" enctype="multipart/form-data">
+             <input type="file" name="file" accept="image/*,.nii,.nii.gz,.dcm" required>
+             <button type="submit">Upload and Process</button>
+         </form>
+         {% if result %}
+             <hr>
+             {% if result.error %}
+                 <h3>Error: {{ result.error }}</h3>
+             {% else %}
+                 <h3>Predicted Modality: {{ result.predicted_modality }}</h3>
+                 {% if result.specialized %}
+                     <div class="result">
+                         <div>
+                             <h4>Original Image</h4>
+                             <img src="data:image/jpeg;base64,{{ result.specialized.original }}" alt="Original">
+                         </div>
+                         <div>
+                             <h4>Mask</h4>
+                             <img src="data:image/jpeg;base64,{{ result.specialized.mask }}" alt="Mask">
+                         </div>
+                         <div>
+                             <h4>Overlay</h4>
+                             <img src="data:image/jpeg;base64,{{ result.specialized.overlay }}" alt="Overlay">
+                         </div>
+                     </div>
+                 {% elif result.message %}
+                     <h4>{{ result.message }}</h4>
+                     <div class="result">
+                         <img src="data:image/jpeg;base64,{{ result.original }}" alt="Original">
+                     </div>
+                 {% endif %}
+             {% endif %}
+         {% endif %}
+     </div>
+ </body>
+ </html>