# app.py
import os, io, base64, cv2, torch, numpy as np
from PIL import Image
from flask import Flask, request, render_template
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from monai.transforms import ScaleIntensity, Resize

# Enable debug logging
import logging
logging.basicConfig(level=logging.DEBUG)
# -------------------------------
# Global Setup
# -------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def pil_to_base64(pil_img):
    """Encode a PIL image as a base64 JPEG string for embedding in HTML."""
    buff = io.BytesIO()
    pil_img.save(buff, format="JPEG")
    return base64.b64encode(buff.getvalue()).decode("utf-8")
# -------------------------------
# 1. CLASSIFIER MODULE (DenseNet121 via MONAI)
# -------------------------------
CLASS_NAMES = ['AbdomenCT', 'BreastMRI', 'Chest Xray', 'ChestCT',
               'Endoscopy', 'Hand Xray', 'HeadCT', 'HeadMRI']

from monai.networks.nets import DenseNet121

def load_classifier_model(model_path):
    model = DenseNet121(
        spatial_dims=2,
        in_channels=3,
        out_channels=len(CLASS_NAMES)
    ).to(device)
    state_dict = torch.load(model_path, map_location=device)
    # Some checkpoints wrap the weights under a "state_dict" key
    if isinstance(state_dict, dict) and "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    return model
def load_and_preprocess_image_classifier(image_path):
    image_path = image_path.strip()
    if image_path.lower().endswith((".jpg", ".jpeg", ".png")):
        image = Image.open(image_path).convert("RGB")
        image = np.array(image)
    elif image_path.lower().endswith((".nii", ".nii.gz")):
        import nibabel as nib
        image = nib.load(image_path).get_fdata()
        image = np.squeeze(image)
        if len(image.shape) == 4:
            image = image[..., 0]  # drop the extra (e.g. time) dimension
        if len(image.shape) == 3:
            image = image[:, :, image.shape[2] // 2]  # take the middle slice
        if len(image.shape) == 2:
            image = np.stack([image] * 3, axis=-1)  # grayscale -> 3 channels
    elif image_path.lower().endswith(".dcm"):
        import pydicom
        dicom_data = pydicom.dcmread(image_path)
        image = dicom_data.pixel_array
        if len(image.shape) == 2:
            image = np.stack([image] * 3, axis=-1)
    else:
        raise ValueError("Unsupported file format!")
    if len(image.shape) == 3 and image.shape[-1] == 3:
        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW for MONAI transforms
    else:
        raise ValueError(f"Unexpected image shape: {image.shape}")
    image = torch.tensor(image, dtype=torch.float32)
    image = ScaleIntensity()(image)
    image = Resize((224, 224))(image)
    image = image.unsqueeze(0)
    return image.to(device)
def classify_medical_image(image_path, classifier_model):
    image_tensor = load_and_preprocess_image_classifier(image_path)
    with torch.no_grad():
        output = classifier_model(image_tensor)
    pred_class = torch.argmax(output, dim=1).item()
    return CLASS_NAMES[pred_class]
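# Minimal smoke test for the classifier, kept commented out so the app does
# not load weights at import time. "sample_scan.png" is an assumed local file
# for illustration, not one bundled with this app.
# clf = load_classifier_model("models/best_metric_model (4).pth")
# print(classify_medical_image("sample_scan.png", clf))  # e.g. "HeadCT"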
# -------------------------------
# 2. BRAIN TUMOR SEGMENTATION MODULE (UNetMulti)
# -------------------------------
class DoubleConvUNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConvUNet, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class UNetMulti(nn.Module):
    def __init__(self, in_channels=3, out_channels=4):
        super(UNetMulti, self).__init__()
        self.down1 = DoubleConvUNet(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = DoubleConvUNet(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = DoubleConvUNet(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.down4 = DoubleConvUNet(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.bottleneck = DoubleConvUNet(512, 1024)
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv4 = DoubleConvUNet(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv3 = DoubleConvUNet(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv2 = DoubleConvUNet(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv1 = DoubleConvUNet(128, 64)
        self.final_conv = nn.Conv2d(64, out_channels, 1)
    def forward(self, x):
        c1 = self.down1(x)
        p1 = self.pool1(c1)
        c2 = self.down2(p1)
        p2 = self.pool2(c2)
        c3 = self.down3(p2)
        p3 = self.pool3(c3)
        c4 = self.down4(p3)
        p4 = self.pool4(c4)
        bn = self.bottleneck(p4)
        u4 = self.up4(bn)
        merge4 = torch.cat([u4, c4], dim=1)
        c5 = self.conv4(merge4)
        u3 = self.up3(c5)
        merge3 = torch.cat([u3, c3], dim=1)
        c6 = self.conv3(merge3)
        u2 = self.up2(c6)
        merge2 = torch.cat([u2, c2], dim=1)
        c7 = self.conv2(merge2)
        u1 = self.up1(c7)
        merge1 = torch.cat([u1, c1], dim=1)
        c8 = self.conv1(merge1)
        return self.final_conv(c8)
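# Shape sanity check (illustrative; commented out so nothing heavy runs at
# import time). Four pool/upsample stages mean inputs whose sides are
# multiples of 16 come back at full resolution, with one logit map per class:
# _unet = UNetMulti()
# assert _unet(torch.zeros(1, 3, 256, 256)).shape == (1, 4, 256, 256)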
def process_brain_tumor_return(image, model_path="models/brain_tumor_unet_multiclass.pth"):
    logging.debug("Processing brain tumor segmentation")
    model = UNetMulti(in_channels=3, out_channels=4).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    transform_img = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    input_tensor = transform_img(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
        preds = torch.argmax(output, dim=1).squeeze().cpu().numpy()
    image_np = transform_img(image).permute(1, 2, 0).cpu().numpy()
    # Normalize class indices to [0, 1] before colormapping; the epsilon
    # guards against division by zero when no tumor class is predicted.
    preds_norm = preds.astype(np.float32) / (preds.max() + 1e-8)
    overlay = cv2.applyColorMap(np.uint8(255 * preds_norm), cv2.COLORMAP_JET)
    overlay = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
    blended = cv2.addWeighted(np.uint8(image_np * 255), 0.6, overlay, 0.4, 0)
    orig_pil = Image.fromarray((image_np * 255).astype(np.uint8))
    mask_pil = Image.fromarray(overlay)
    overlay_pil = Image.fromarray(blended)
    return {
        "original": pil_to_base64(orig_pil),
        "mask": pil_to_base64(mask_pil),
        "overlay": pil_to_base64(overlay_pil)
    }
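# Client-side decoding sketch (illustrative): each value in the returned dict
# is a base64-encoded JPEG, so a caller could recover a PIL image like this.
# The file name and variable names are assumptions for the example only.
# result = process_brain_tumor_return(Image.open("mri_slice.jpg").convert("RGB"))
# overlay_img = Image.open(io.BytesIO(base64.b64decode(result["overlay"])))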
# -------------------------------
# 3. ENDOSCOPY POLYP DETECTION MODULE (Binary UNet)
# -------------------------------
class UNetBinary(UNetMulti):
    """UNetMulti with a single output channel for binary polyp masks; the
    layer names are unchanged, so existing checkpoints load as-is."""
    def __init__(self, in_channels=3, out_channels=1):
        super(UNetBinary, self).__init__(in_channels, out_channels)
def process_endoscopy_return(image, model_path="models/endoscopy_unet.pth"):
    model = UNetBinary(in_channels=3, out_channels=1).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    transform_img = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    input_tensor = transform_img(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
        prob = torch.sigmoid(output)
        mask = (prob > 0.5).float().squeeze().cpu().numpy()
    image_np = transform_img(image).permute(1, 2, 0).cpu().numpy()
    overlay = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    overlay = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
    blended = cv2.addWeighted(np.uint8(image_np * 255), 0.6, overlay, 0.4, 0)
    orig_pil = Image.fromarray((image_np * 255).astype(np.uint8))
    mask_pil = Image.fromarray(overlay)
    overlay_pil = Image.fromarray(blended)
    return {
        "original": pil_to_base64(orig_pil),
        "mask": pil_to_base64(mask_pil),
        "overlay": pil_to_base64(overlay_pil)
    }
# -------------------------------
# 4. PNEUMONIA DETECTION MODULE (Grad-CAM on ResNet18)
# -------------------------------
class GradCAM_Pneumonia:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.hook_handles = []
        self._register_hooks()

    def _register_hooks(self):
        def forward_hook(module, input, output):
            self.activations = output.detach()

        def backward_hook(module, grad_in, grad_out):
            self.gradients = grad_out[0].detach()

        handle1 = self.target_layer.register_forward_hook(forward_hook)
        # register_full_backward_hook replaces the deprecated
        # register_backward_hook, which can report wrong gradients on
        # composite modules such as ResNet's layer4.
        handle2 = self.target_layer.register_full_backward_hook(backward_hook)
        self.hook_handles.extend([handle1, handle2])

    def remove_hooks(self):
        for handle in self.hook_handles:
            handle.remove()
    def generate(self, input_image, target_class=None):
        output = self.model(input_image)
        if target_class is None:
            target_class = output.argmax(dim=1).item()
        self.model.zero_grad()
        # Backpropagate a one-hot vector for the target class so the backward
        # hook captures the gradients flowing through the target layer.
        one_hot = torch.zeros_like(output)
        one_hot[0, target_class] = 1
        output.backward(gradient=one_hot, retain_graph=True)
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = F.relu(cam)
        cam = cam.squeeze().cpu().numpy()
        _, _, H, W = input_image.shape
        cam = cv2.resize(cam, (W, H))
        cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam) + 1e-8)
        return cam, output
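# The map built in generate() is standard Grad-CAM (Selvaraju et al., 2017):
# the channel weights alpha_k are the spatial means of d(score)/dA^k, and the
# heatmap is ReLU(sum_k alpha_k * A^k), upsampled to the input size and
# min-max normalized to [0, 1].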
def process_pneumonia_return(image, model_path="models/pneumonia_resnet18.pth"):
    # weights=None is the current torchvision spelling of pretrained=False;
    # the checkpoint below supplies all weights anyway.
    model = models.resnet18(weights=None)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 2)  # 2 classes: normal and pneumonia
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    grad_cam = GradCAM_Pneumonia(model, model.layer4)
    transform_img = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    input_tensor = transform_img(image).unsqueeze(0).to(device)
    # Enable gradient tracking for the input tensor; do NOT wrap the
    # Grad-CAM call in torch.no_grad(), or backward() will fail.
    input_tensor.requires_grad_()
    cam, output = grad_cam.generate(input_tensor)
    predicted_class = output.argmax(dim=1).item()
    label_text = "Pneumonia" if predicted_class == 1 else "Normal"

    def get_bounding_box(heatmap, thresh=0.5, min_area=100):
        heat_uint8 = np.uint8(255 * heatmap)
        _, binary = cv2.threshold(heat_uint8, int(thresh * 255), 255, cv2.THRESH_BINARY)
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if len(contours) == 0:
            return None
        largest = max(contours, key=cv2.contourArea)
        if cv2.contourArea(largest) < min_area:
            return None
        x, y, w, h = cv2.boundingRect(largest)
        return (x, y, w, h)

    bbox = None
    if predicted_class == 1:
        bbox = get_bounding_box(cam, thresh=0.5, min_area=100)
    resized_image = image.resize((224, 224))
    image_np = np.array(resized_image)
    overlay = image_np.copy()
    if bbox is not None:
        x, y, w, h = bbox
        cv2.rectangle(overlay, (x, y), (x + w, y + h), (255, 0, 0), 2)
    cv2.putText(overlay, label_text, (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 0), 2)
    heatmap_color = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
    orig_pil = Image.fromarray(image_np)
    heatmap_pil = Image.fromarray(heatmap_color)
    overlay_pil = Image.fromarray(overlay)
    grad_cam.remove_hooks()
    return {
        "original": pil_to_base64(orig_pil),
        "mask": pil_to_base64(heatmap_pil),
        "overlay": pil_to_base64(overlay_pil)
    }
# -------------------------------
# 5. COMPLETE PIPELINE FUNCTION
# -------------------------------
def complete_pipeline(image_path):
    # Models are loaded on every call; acceptable for a demo, though a
    # production app would cache them at startup.
    classifier_model = load_classifier_model("models/best_metric_model (4).pth")
    predicted_modality = classify_medical_image(image_path, classifier_model)
    logging.info(f"Detected modality: {predicted_modality}")
    original_image = Image.open(image_path).convert("RGB")
    results = {"predicted_modality": predicted_modality}
    if predicted_modality in ["HeadCT", "HeadMRI"]:
        results["specialized"] = process_brain_tumor_return(original_image, "models/brain_tumor_unet_multiclass.pth")
    elif predicted_modality == "Endoscopy":
        results["specialized"] = process_endoscopy_return(original_image, "models/endoscopy_unet.pth")
    elif predicted_modality == "Chest Xray":
        results["specialized"] = process_pneumonia_return(original_image, "models/pneumonia_resnet18.pth")
    else:
        results["message"] = f"No specialized processing for modality: {predicted_modality}"
    return results
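# Pipeline usage sketch (illustrative, commented out; "chest.jpg" is an
# assumed local file, not one bundled with the app):
# out = complete_pipeline("chest.jpg")
# print(out["predicted_modality"], "specialized" in out)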
# -------------------------------
# 6. FLASK API SETUP
# -------------------------------
app = Flask(__name__)

# Assumed routes: "/" serves the upload form and "/predict" receives the
# POSTed file from templates/index.html.
@app.route('/')
def index():
    return render_template("index.html", result=None)

@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return render_template("index.html", result={"error": "No file part in the request."})
    file = request.files['file']
    if file.filename == '':
        return render_template("index.html", result={"error": "No file selected."})
    # Keep the upload's extension: the classifier branches on it, so saving
    # every file as .jpg would break NIfTI/DICOM inputs.
    ext = os.path.splitext(file.filename)[1].lower() or ".jpg"
    temp_path = "temp_input" + ext
    file.save(temp_path)
    try:
        result = complete_pipeline(temp_path)
    except Exception as e:
        result = {"error": str(e)}
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)
    return render_template("index.html", result=result)
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, debug=True)
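# Example client call (illustrative; assumes the server is running locally
# and that "xray.jpg" exists, neither of which is part of this app):
# import requests
# with open("xray.jpg", "rb") as f:
#     resp = requests.post("http://localhost:5000/predict", files={"file": f})
# print(resp.status_code)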