# app.py
import os
import io
import base64
import logging

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, request, render_template
from monai.transforms import ScaleIntensity, Resize
from monai.networks.nets import DenseNet121

# Enable debug logging
logging.basicConfig(level=logging.DEBUG)

# -------------------------------
# Global Setup
# -------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def pil_to_base64(pil_img):
    """Serialize a PIL image to a base64-encoded JPEG string."""
    buff = io.BytesIO()
    pil_img.save(buff, format="JPEG")
    return base64.b64encode(buff.getvalue()).decode("utf-8")


# -------------------------------
# 1. CLASSIFIER MODULE (DenseNet121 via MONAI)
# -------------------------------
CLASS_NAMES = ['AbdomenCT', 'BreastMRI', 'Chest Xray', 'ChestCT',
               'Endoscopy', 'Hand Xray', 'HeadCT', 'HeadMRI']


def load_classifier_model(model_path):
    model = DenseNet121(
        spatial_dims=2,
        in_channels=3,
        out_channels=len(CLASS_NAMES)
    ).to(device)
    state_dict = torch.load(model_path, map_location=device)
    # Some checkpoints wrap the weights in a {"state_dict": ...} dict.
    if isinstance(state_dict, dict) and "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    return model


def load_and_preprocess_image_classifier(image_path):
    image_path = image_path.strip()
    if image_path.lower().endswith((".jpg", ".jpeg", ".png")):
        image = Image.open(image_path).convert("RGB")
        image = np.array(image)
    elif image_path.lower().endswith((".nii", ".nii.gz")):
        import nibabel as nib
        image = nib.load(image_path).get_fdata()
        image = np.squeeze(image)
        if len(image.shape) == 4:      # drop a trailing time/channel axis
            image = image[..., 0]
        if len(image.shape) == 3:      # take the middle axial slice
            image = image[:, :, image.shape[2] // 2]
        if len(image.shape) == 2:      # replicate grayscale to 3 channels
            image = np.stack([image] * 3, axis=-1)
    elif image_path.lower().endswith(".dcm"):
        import pydicom
        dicom_data = pydicom.dcmread(image_path)
        image = dicom_data.pixel_array
        if len(image.shape) == 2:
            image = np.stack([image] * 3, axis=-1)
    else:
        raise ValueError("Unsupported file format!")

    if len(image.shape) == 3 and image.shape[-1] == 3:
        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW for MONAI transforms
    else:
        raise ValueError(f"Unexpected image shape: {image.shape}")

    image = torch.tensor(image, dtype=torch.float32)
    image = ScaleIntensity()(image)    # min-max normalize to [0, 1]
    image = Resize((224, 224))(image)
    image = image.unsqueeze(0)         # add batch dimension
    return image.to(device)


def classify_medical_image(image_path, classifier_model):
    image_tensor = load_and_preprocess_image_classifier(image_path)
    with torch.no_grad():
        output = classifier_model(image_tensor)
        pred_class = torch.argmax(output, dim=1).item()
    return CLASS_NAMES[pred_class]
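# Usage sketch (illustrative; the sample filename below is an assumption,
# not a file shipped with this repo):
#
#   clf = load_classifier_model("models/best_metric_model (4).pth")
#   modality = classify_medical_image("sample_scan.png", clf)
#   print(modality)  # one of CLASS_NAMES, e.g. "Chest Xray"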
# -------------------------------
# 2. BRAIN TUMOR SEGMENTATION MODULE (UNetMulti)
# -------------------------------
class DoubleConvUNet(nn.Module):
    """Two 3x3 conv + BatchNorm + ReLU blocks, the basic U-Net building block."""

    def __init__(self, in_channels, out_channels):
        super(DoubleConvUNet, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class UNetMulti(nn.Module):
    """Standard 4-level U-Net with skip connections; emits `out_channels` logits per pixel."""

    def __init__(self, in_channels=3, out_channels=4):
        super(UNetMulti, self).__init__()
        # Encoder
        self.down1 = DoubleConvUNet(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = DoubleConvUNet(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = DoubleConvUNet(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.down4 = DoubleConvUNet(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.bottleneck = DoubleConvUNet(512, 1024)
        # Decoder (each up-conv halves the channels; the skip concat doubles them again)
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv4 = DoubleConvUNet(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv3 = DoubleConvUNet(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv2 = DoubleConvUNet(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv1 = DoubleConvUNet(128, 64)
        self.final_conv = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        c1 = self.down1(x)
        p1 = self.pool1(c1)
        c2 = self.down2(p1)
        p2 = self.pool2(c2)
        c3 = self.down3(p2)
        p3 = self.pool3(c3)
        c4 = self.down4(p3)
        p4 = self.pool4(c4)
        bn = self.bottleneck(p4)
        u4 = self.up4(bn)
        c5 = self.conv4(torch.cat([u4, c4], dim=1))
        u3 = self.up3(c5)
        c6 = self.conv3(torch.cat([u3, c3], dim=1))
        u2 = self.up2(c6)
        c7 = self.conv2(torch.cat([u2, c2], dim=1))
        u1 = self.up1(c7)
        c8 = self.conv1(torch.cat([u1, c1], dim=1))
        return self.final_conv(c8)


def process_brain_tumor_return(image, model_path="models/brain_tumor_unet_multiclass.pth"):
    logging.debug("Processing brain tumor segmentation")
    model = UNetMulti(in_channels=3, out_channels=4).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    transform_img = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    input_tensor = transform_img(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
        preds = torch.argmax(output, dim=1).squeeze().cpu().numpy()
    image_np = transform_img(image).permute(1, 2, 0).cpu().numpy()
    # Map class indices to a color heatmap; the epsilon guards an all-background mask.
    overlay = cv2.applyColorMap(np.uint8(255 * preds / (np.max(preds) + 1e-8)), cv2.COLORMAP_JET)
    overlay = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
    blended = cv2.addWeighted(np.uint8(image_np * 255), 0.6, overlay, 0.4, 0)
    orig_pil = Image.fromarray((image_np * 255).astype(np.uint8))
    mask_pil = Image.fromarray(overlay)
    overlay_pil = Image.fromarray(blended)
    return {
        "original": pil_to_base64(orig_pil),
        "mask": pil_to_base64(mask_pil),
        "overlay": pil_to_base64(overlay_pil)
    }
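# Shape sanity check (illustrative; no trained weights needed): a 256x256
# RGB input should come back as per-pixel logits for 4 classes.
#
#   net = UNetMulti(in_channels=3, out_channels=4)
#   x = torch.randn(1, 3, 256, 256)
#   assert net(x).shape == (1, 4, 256, 256)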
# -------------------------------
# 3. ENDOSCOPY POLYP DETECTION MODULE (Binary UNet)
# -------------------------------
class UNetBinary(UNetMulti):
    """Same U-Net topology with a single output channel. Layer names are
    inherited unchanged from UNetMulti, so existing checkpoints still load."""

    def __init__(self, in_channels=3, out_channels=1):
        super(UNetBinary, self).__init__(in_channels=in_channels, out_channels=out_channels)


def process_endoscopy_return(image, model_path="models/endoscopy_unet.pth"):
    model = UNetBinary(in_channels=3, out_channels=1).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    transform_img = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    input_tensor = transform_img(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
        prob = torch.sigmoid(output)
        mask = (prob > 0.5).float().squeeze().cpu().numpy()
    image_np = transform_img(image).permute(1, 2, 0).cpu().numpy()
    overlay = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    overlay = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
    blended = cv2.addWeighted(np.uint8(image_np * 255), 0.6, overlay, 0.4, 0)
    orig_pil = Image.fromarray((image_np * 255).astype(np.uint8))
    mask_pil = Image.fromarray(overlay)
    overlay_pil = Image.fromarray(blended)
    return {
        "original": pil_to_base64(orig_pil),
        "mask": pil_to_base64(mask_pil),
        "overlay": pil_to_base64(overlay_pil)
    }
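# Round-trip sketch (illustrative): each returned value is a base64 JPEG
# produced by pil_to_base64, so it can be decoded back into a PIL image.
#
#   data = process_endoscopy_return(img)["overlay"]
#   overlay_img = Image.open(io.BytesIO(base64.b64decode(data)))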
# -------------------------------
# 4. PNEUMONIA DETECTION MODULE (Grad-CAM on ResNet18)
# -------------------------------
class GradCAM_Pneumonia:
    """Grad-CAM: weight the target layer's activations by the spatially
    averaged gradients of the target class score, then ReLU and normalize."""

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.hook_handles = []
        self._register_hooks()

    def _register_hooks(self):
        def forward_hook(module, input, output):
            self.activations = output.detach()

        def backward_hook(module, grad_in, grad_out):
            self.gradients = grad_out[0].detach()

        handle1 = self.target_layer.register_forward_hook(forward_hook)
        # register_backward_hook is deprecated; register_full_backward_hook
        # delivers the same grad_out for this use case.
        handle2 = self.target_layer.register_full_backward_hook(backward_hook)
        self.hook_handles.extend([handle1, handle2])

    def remove_hooks(self):
        for handle in self.hook_handles:
            handle.remove()

    def generate(self, input_image, target_class=None):
        output = self.model(input_image)
        if target_class is None:
            target_class = output.argmax(dim=1).item()
        self.model.zero_grad()
        one_hot = torch.zeros_like(output)
        one_hot[0, target_class] = 1
        with torch.enable_grad():
            output.backward(gradient=one_hot, retain_graph=True)
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)  # global-average-pooled grads
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = F.relu(cam)
        cam = cam.squeeze().cpu().numpy()
        _, _, H, W = input_image.shape
        cam = cv2.resize(cam, (W, H))
        cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam) + 1e-8)
        return cam, output


def process_pneumonia_return(image, model_path="models/pneumonia_resnet18.pth"):
    model = models.resnet18(weights=None)  # replaces the deprecated pretrained=False
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 2)  # 2 classes: normal and pneumonia
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    grad_cam = GradCAM_Pneumonia(model, model.layer4)
    transform_img = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    input_tensor = transform_img(image).unsqueeze(0).to(device)
    # Enable gradient tracking for the input tensor. Do NOT wrap the
    # following call in torch.no_grad(): Grad-CAM needs the backward pass.
    input_tensor.requires_grad_()
    cam, output = grad_cam.generate(input_tensor)
    predicted_class = output.argmax(dim=1).item()
    label_text = "Pneumonia" if predicted_class == 1 else "Normal"

    def get_bounding_box(heatmap, thresh=0.5, min_area=100):
        """Bounding box of the largest above-threshold Grad-CAM region, if any."""
        heat_uint8 = np.uint8(255 * heatmap)
        ret, binary = cv2.threshold(heat_uint8, int(thresh * 255), 255, cv2.THRESH_BINARY)
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if len(contours) == 0:
            return None
        largest = max(contours, key=cv2.contourArea)
        if cv2.contourArea(largest) < min_area:
            return None
        x, y, w, h = cv2.boundingRect(largest)
        return (x, y, w, h)

    bbox = None
    if predicted_class == 1:
        bbox = get_bounding_box(cam, thresh=0.5, min_area=100)

    resized_image = image.resize((224, 224))
    image_np = np.array(resized_image)
    overlay = image_np.copy()
    if bbox is not None:
        x, y, w, h = bbox
        cv2.rectangle(overlay, (x, y), (x + w, y + h), (255, 0, 0), 2)
    cv2.putText(overlay, label_text, (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 0), 2)
    heatmap_color = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
    orig_pil = Image.fromarray(image_np)
    heatmap_pil = Image.fromarray(heatmap_color)
    overlay_pil = Image.fromarray(overlay)
    grad_cam.remove_hooks()
    return {
        "original": pil_to_base64(orig_pil),
        "mask": pil_to_base64(heatmap_pil),
        "overlay": pil_to_base64(overlay_pil)
    }
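# Usage sketch (illustrative; the sample path is an assumption):
#
#   img = Image.open("samples/chest_xray.jpg").convert("RGB")
#   out = process_pneumonia_return(img)
#   # out["overlay"] carries the class label and, for a pneumonia
#   # prediction, a bounding box around the strongest Grad-CAM region.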
# -------------------------------
# 5. COMPLETE PIPELINE FUNCTION
# -------------------------------
def complete_pipeline(image_path):
    classifier_model = load_classifier_model("models/best_metric_model (4).pth")
    predicted_modality = classify_medical_image(image_path, classifier_model)
    logging.info(f"Detected modality: {predicted_modality}")
    original_image = Image.open(image_path).convert("RGB")
    results = {"predicted_modality": predicted_modality}
    if predicted_modality in ["HeadCT", "HeadMRI"]:
        results["specialized"] = process_brain_tumor_return(
            original_image, "models/brain_tumor_unet_multiclass.pth")
    elif predicted_modality == "Endoscopy":
        results["specialized"] = process_endoscopy_return(
            original_image, "models/endoscopy_unet.pth")
    elif predicted_modality == "Chest Xray":
        results["specialized"] = process_pneumonia_return(
            original_image, "models/pneumonia_resnet18.pth")
    else:
        results["message"] = f"No specialized processing for modality: {predicted_modality}"
    return results


# -------------------------------
# 6. FLASK API SETUP
# -------------------------------
app = Flask(__name__)


@app.route('/', methods=['GET'])
def index():
    return render_template("index.html", result=None)


@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return render_template("index.html", result={"error": "No file part in the request."})
    file = request.files['file']
    if file.filename == '':
        return render_template("index.html", result={"error": "No file selected."})
    # Keep the original extension so NIfTI/DICOM uploads are routed correctly
    # by the classifier's extension checks (.nii.gz needs special handling).
    filename = file.filename.lower()
    if filename.endswith(".nii.gz"):
        ext = ".nii.gz"
    else:
        ext = os.path.splitext(filename)[1] or ".jpg"
    temp_path = f"temp_input{ext}"
    file.save(temp_path)
    try:
        result = complete_pipeline(temp_path)
    except Exception as e:
        result = {"error": str(e)}
    finally:
        os.remove(temp_path)  # clean up even if the pipeline raises
    return render_template("index.html", result=result)


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, debug=True)
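# Example request (illustrative; assumes the server is running locally and
# a sample file exists at the given path):
#
#   curl -X POST -F "file=@sample_scan.jpg" http://localhost:5000/predict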