import os
import base64
from io import BytesIO

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from monai.networks.nets import DenseNet121

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------------
# CLASSIFIER MODULE (Medical Classifier)
# -------------------------------
class_names = ['AbdomenCT', 'BreastMRI', 'Chest Xray', 'ChestCT',
               'Endoscopy', 'Hand Xray', 'HeadCT', 'HeadMRI']

# Load the classifier weights from the models folder.
model_path_classifier = os.path.join("models", "best_metric_model (4).pth")

classifier_model = DenseNet121(
    spatial_dims=2,
    in_channels=3,
    out_channels=len(class_names)
).to(device)

state_dict = torch.load(model_path_classifier, map_location=device)
classifier_model.load_state_dict(state_dict, strict=False)
classifier_model.eval()


def classify_medical_image_pil(image: Image.Image) -> str:
    """Classify a PIL image into one of the supported modalities."""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
    ])
    # The classifier expects 3 channels; convert grayscale inputs to RGB.
    image_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        output = classifier_model(image_tensor)
    pred_class = torch.argmax(output, dim=1).item()
    return class_names[pred_class]


# -------------------------------
# SPECIALIZED MODULES
# -------------------------------

# --- A. Brain Tumor Segmentation Module (for HeadCT/HeadMRI) ---
class DoubleConvUNet(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) blocks, the basic U-Net building block."""

    def __init__(self, in_channels, out_channels):
        super(DoubleConvUNet, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class UNetMulti(nn.Module):
    """Four-level U-Net with a 4-channel output head for multiclass segmentation."""

    def __init__(self, in_channels=3, out_channels=4):
        super(UNetMulti, self).__init__()
        self.down1 = DoubleConvUNet(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = DoubleConvUNet(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = DoubleConvUNet(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.down4 = DoubleConvUNet(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.bottleneck = DoubleConvUNet(512, 1024)
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.conv4 = DoubleConvUNet(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv3 = DoubleConvUNet(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv2 = DoubleConvUNet(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv1 = DoubleConvUNet(128, 64)
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder: each level doubles the channels and halves the resolution.
        c1 = self.down1(x)
        p1 = self.pool1(c1)
        c2 = self.down2(p1)
        p2 = self.pool2(c2)
        c3 = self.down3(p2)
        p3 = self.pool3(c3)
        c4 = self.down4(p3)
        p4 = self.pool4(c4)
        bn = self.bottleneck(p4)
        # Decoder: upsample and concatenate the matching encoder features.
        u4 = self.up4(bn)
        c5 = self.conv4(torch.cat([u4, c4], dim=1))
        u3 = self.up3(c5)
        c6 = self.conv3(torch.cat([u3, c3], dim=1))
        u2 = self.up2(c6)
        c7 = self.conv2(torch.cat([u2, c2], dim=1))
        u1 = self.up1(c7)
        c8 = self.conv1(torch.cat([u1, c1], dim=1))
        return self.final_conv(c8)
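
# Quick shape sanity check for the U-Net above (a minimal sketch; the
# function name is ours, not part of the original pipeline). 256x256 matches
# what process_brain_tumor feeds the model below; because the encoder pools
# four times, input height/width must be divisible by 16.
def _sanity_check_unet_multi():
    net = UNetMulti(in_channels=3, out_channels=4).to(device)
    dummy = torch.randn(1, 3, 256, 256, device=device)
    with torch.no_grad():
        out = net(dummy)
    # One logit map per class at the input resolution.
    assert out.shape == (1, 4, 256, 256)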

def process_brain_tumor(image: Image.Image,
                        model_path=os.path.join("models", "brain_tumor_unet_multiclass.pth")) -> str:
    """Segment a head scan with the multiclass U-Net and return a base64 PNG figure."""
    model = UNetMulti(in_channels=3, out_channels=4).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    transform_img = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    # The network expects 3 channels; convert grayscale inputs to RGB.
    image = image.convert("RGB")
    input_tensor = transform_img(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
    preds = torch.argmax(output, dim=1).squeeze().cpu().numpy()

    image_np = np.array(image.resize((256, 256)))

    # Create overlay and blended image.
    overlay = cv2.applyColorMap(
        np.uint8(255 * preds / (np.max(preds) + 1e-8)), cv2.COLORMAP_JET)
    overlay = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
    blended = cv2.addWeighted(np.uint8(image_np), 0.6, overlay, 0.4, 0)

    # Plot original, mask, and overlay side by side.
    fig, ax = plt.subplots(1, 3, figsize=(18, 6))
    ax[0].imshow(image_np)
    ax[0].set_title("Original Image")
    ax[0].axis("off")
    ax[1].imshow(preds, cmap='jet')
    ax[1].set_title("Segmentation Mask")
    ax[1].axis("off")
    ax[2].imshow(blended)
    ax[2].set_title("Overlay")
    ax[2].axis("off")

    # Serialize the figure to a base64-encoded PNG.
    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    img_base64 = base64.b64encode(buf.read()).decode("utf-8")
    plt.close(fig)
    return img_base64


# --- B. Endoscopy Polyp Detection Module (Binary UNet) ---
class UNetBinary(nn.Module):
    """Same four-level U-Net, but with a single-channel (binary) output head."""

    def __init__(self, in_channels=3, out_channels=1):
        super(UNetBinary, self).__init__()
        self.down1 = DoubleConvUNet(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = DoubleConvUNet(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = DoubleConvUNet(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        # down3 outputs 256 channels, so down4 must take 256 as input.
        self.down4 = DoubleConvUNet(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.bottleneck = DoubleConvUNet(512, 1024)
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.conv4 = DoubleConvUNet(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv3 = DoubleConvUNet(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv2 = DoubleConvUNet(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv1 = DoubleConvUNet(128, 64)
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        c1 = self.down1(x)
        p1 = self.pool1(c1)
        c2 = self.down2(p1)
        p2 = self.pool2(c2)
        c3 = self.down3(p2)
        p3 = self.pool3(c3)
        c4 = self.down4(p3)
        p4 = self.pool4(c4)
        bn = self.bottleneck(p4)
        u4 = self.up4(bn)
        c5 = self.conv4(torch.cat([u4, c4], dim=1))
        u3 = self.up3(c5)
        c6 = self.conv3(torch.cat([u3, c3], dim=1))
        u2 = self.up2(c6)
        c7 = self.conv2(torch.cat([u2, c2], dim=1))
        u1 = self.up1(c7)
        c8 = self.conv1(torch.cat([u1, c1], dim=1))
        return self.final_conv(c8)
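
# The three process_* functions in this module all end with the same
# figure -> PNG -> base64 conversion. A helper one could factor out is
# sketched here (the name _fig_to_base64 is ours; the functions keep their
# inline versions unchanged):
def _fig_to_base64(fig) -> str:
    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    encoded = base64.b64encode(buf.read()).decode("utf-8")
    plt.close(fig)  # free the figure so repeated calls don't leak memory
    return encoded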

def process_endoscopy(image: Image.Image,
                      model_path=os.path.join("models", "endoscopy_unet.pth")) -> str:
    """Segment polyps in an endoscopy frame and return a base64 PNG figure."""
    model = UNetBinary(in_channels=3, out_channels=1).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    transform_img = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    # The network expects 3 channels; convert grayscale inputs to RGB.
    image = image.convert("RGB")
    input_tensor = transform_img(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
        # Sigmoid turns the logits into probabilities; threshold at 0.5.
        prob = torch.sigmoid(output)
        mask = (prob > 0.5).float().squeeze().cpu().numpy()

    image_np = np.array(image.resize((256, 256)))
    overlay = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    overlay = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
    blended = cv2.addWeighted(np.uint8(image_np), 0.6, overlay, 0.4, 0)

    fig, ax = plt.subplots(1, 3, figsize=(18, 6))
    ax[0].imshow(image_np)
    ax[0].set_title("Actual Image")
    ax[0].axis("off")
    ax[1].imshow(mask, cmap='gray')
    ax[1].set_title("Segmentation Mask")
    ax[1].axis("off")
    ax[2].imshow(blended)
    ax[2].set_title("Overlay")
    ax[2].axis("off")

    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    img_base64 = base64.b64encode(buf.read()).decode("utf-8")
    plt.close(fig)
    return img_base64


# --- C. Pneumonia Detection Module (Using Grad-CAM on ResNet18) ---
class GradCAM_Pneumonia:
    """Grad-CAM over a chosen layer of a classification network."""

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.hook_handles = []
        self._register_hooks()

    def _register_hooks(self):
        def forward_hook(module, input, output):
            # Cache the target layer's activations on the forward pass.
            self.activations = output.detach()

        def backward_hook(module, grad_in, grad_out):
            # Cache the gradient of the score w.r.t. those activations.
            self.gradients = grad_out[0].detach()

        handle1 = self.target_layer.register_forward_hook(forward_hook)
        handle2 = self.target_layer.register_full_backward_hook(backward_hook)
        self.hook_handles.extend([handle1, handle2])

    def remove_hooks(self):
        for handle in self.hook_handles:
            handle.remove()

    def generate(self, input_image, target_class=None):
        output = self.model(input_image)
        if target_class is None:
            target_class = output.argmax(dim=1).item()
        self.model.zero_grad()
        # Backpropagate a one-hot signal for the target class only.
        one_hot = torch.zeros_like(output)
        one_hot[0, target_class] = 1
        output.backward(gradient=one_hot, retain_graph=True)
        # Channel weights are the spatial mean of the gradients.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = F.relu(cam)
        cam = cam.squeeze().cpu().numpy()
        # Upsample the heatmap to input resolution and normalize to [0, 1].
        _, _, H, W = input_image.shape
        cam = cv2.resize(cam, (W, H))
        cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam) + 1e-8)
        return cam, output
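
# Grad-CAM recap, for reference: with A^k the target-layer activations and
# alpha_k the spatial mean of dY_c/dA^k, the class-c heatmap is
# L_c = ReLU(sum_k alpha_k * A^k), which is what generate() computes above.
# A minimal usage sketch (`net` is assumed to be a torchvision ResNet and
# `x` a normalized 1x3x224x224 tensor; both names are illustrative):
def _gradcam_demo(net, x):
    cam_extractor = GradCAM_Pneumonia(net, net.layer4)
    cam, logits = cam_extractor.generate(x)
    cam_extractor.remove_hooks()  # always detach hooks when done
    return cam, logits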

def process_pneumonia(image: Image.Image,
                      model_path=os.path.join("models", "pneumonia_resnet18.pth")) -> str:
    """Classify a chest X-ray and localize suspicious regions via Grad-CAM."""
    model = models.resnet18(weights=None)  # weights come from the checkpoint below
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 2)  # 2 classes: normal and pneumonia
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    grad_cam = GradCAM_Pneumonia(model, model.layer4)

    transform_img = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # The network expects 3 channels; convert grayscale inputs to RGB.
    image = image.convert("RGB")
    input_tensor = transform_img(image).unsqueeze(0).to(device)

    cam, output = grad_cam.generate(input_tensor)
    predicted_class = output.argmax(dim=1).item()
    label_text = "Pneumonia" if predicted_class == 1 else "Normal"

    def get_bounding_box(heatmap, thresh=0.5, min_area=100):
        """Return the bounding box of the largest hot region, or None."""
        heat_uint8 = np.uint8(255 * heatmap)
        _, binary = cv2.threshold(heat_uint8, int(thresh * 255), 255,
                                  cv2.THRESH_BINARY)
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL,
                                       cv2.CHAIN_APPROX_SIMPLE)
        if len(contours) == 0:
            return None
        largest = max(contours, key=cv2.contourArea)
        if cv2.contourArea(largest) < min_area:
            return None
        x, y, w, h = cv2.boundingRect(largest)
        return (x, y, w, h)

    # Only draw a box when the model actually predicts pneumonia.
    bbox = None
    if predicted_class == 1:
        bbox = get_bounding_box(cam, thresh=0.5, min_area=100)

    resized_image = image.resize((224, 224))
    image_np = np.array(resized_image)
    overlay = image_np.copy()
    if bbox is not None:
        x, y, w, h = bbox
        cv2.rectangle(overlay, (x, y), (x + w, y + h), (255, 0, 0), 2)
    cv2.putText(overlay, label_text, (10, 25), cv2.FONT_HERSHEY_SIMPLEX,
                0.8, (255, 255, 0), 2)

    heatmap_color = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)

    fig, ax = plt.subplots(1, 3, figsize=(18, 6))
    ax[0].imshow(image_np)
    ax[0].set_title("Actual Image")
    ax[0].axis("off")
    ax[1].imshow(heatmap_color)
    ax[1].set_title("Detected Output (Heatmap)")
    ax[1].axis("off")
    ax[2].imshow(overlay)
    ax[2].set_title("Boxed Overlay")
    ax[2].axis("off")

    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    img_base64 = base64.b64encode(buf.read()).decode("utf-8")
    plt.close(fig)

    grad_cam.remove_hooks()
    return img_base64


# -------------------------------
# COMPLETE PIPELINE FUNCTION
# -------------------------------
def complete_pipeline_image(image: Image.Image) -> dict:
    """Classify the modality, then dispatch to the matching specialized module."""
    predicted_modality = classify_medical_image_pil(image)
    result = {"predicted_modality": predicted_modality}
    if predicted_modality in ["HeadCT", "HeadMRI"]:
        result["segmentation_result"] = process_brain_tumor(image)
    elif predicted_modality == "Endoscopy":
        result["segmentation_result"] = process_endoscopy(image)
    elif predicted_modality == "Chest Xray":
        result["segmentation_result"] = process_pneumonia(image)
    else:
        # For modalities without specialized processing, return the
        # original image as base64.
        buf = BytesIO()
        image.save(buf, format="PNG")
        result["segmentation_result"] = base64.b64encode(buf.getvalue()).decode("utf-8")
    return result
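
# Minimal smoke test for the full pipeline (a sketch: "sample.png" is a
# placeholder path, and all checkpoints must exist under models/ to run).
if __name__ == "__main__":
    sample = Image.open("sample.png").convert("RGB")
    outcome = complete_pipeline_image(sample)
    print("Predicted modality:", outcome["predicted_modality"])
    print("Result is a base64 PNG of length", len(outcome["segmentation_result"]))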