# app.py
import os, io, base64, cv2, torch, numpy as np
from PIL import Image
from flask import Flask, request, render_template, jsonify
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from monai.transforms import ScaleIntensity, Resize
# Enable debug logging
import logging
logging.basicConfig(level=logging.DEBUG)
# -------------------------------
# Global Setup
# -------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def pil_to_base64(pil_img):
    """Encode a PIL image as a base64 JPEG string for inline embedding in HTML."""
    buff = io.BytesIO()
    pil_img.save(buff, format="JPEG")
    return base64.b64encode(buff.getvalue()).decode("utf-8")
# -------------------------------
# 1. CLASSIFIER MODULE (DenseNet121 via MONAI)
# -------------------------------
CLASS_NAMES = ['AbdomenCT', 'BreastMRI', 'Chest Xray', 'ChestCT',
'Endoscopy', 'Hand Xray', 'HeadCT', 'HeadMRI']
from monai.networks.nets import DenseNet121
def load_classifier_model(model_path):
model = DenseNet121(
spatial_dims=2,
in_channels=3,
out_channels=len(CLASS_NAMES)
).to(device)
    state_dict = torch.load(model_path, map_location=device)
    # Some training checkpoints wrap the weights under a "state_dict" key.
    if isinstance(state_dict, dict) and "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    # strict=False tolerates minor key mismatches between checkpoint and model.
    model.load_state_dict(state_dict, strict=False)
model.eval()
return model
def load_and_preprocess_image_classifier(image_path):
    """Load a JPG/PNG, NIfTI, or DICOM file and return a (1, 3, 224, 224) tensor."""
    image_path = image_path.strip()
    if image_path.lower().endswith((".jpg", ".jpeg", ".png")):
        image = Image.open(image_path).convert("RGB")
        image = np.array(image)
    elif image_path.lower().endswith((".nii", ".nii.gz")):
        import nibabel as nib
        image = nib.load(image_path).get_fdata()
        image = np.squeeze(image)
        if len(image.shape) == 4:
            image = image[..., 0]  # keep the first volume of a 4D series
        if len(image.shape) == 3:
            image = image[:, :, image.shape[2] // 2]  # take the middle slice
        if len(image.shape) == 2:
            image = np.stack([image] * 3, axis=-1)  # replicate grayscale to 3 channels
    elif image_path.lower().endswith(".dcm"):
        import pydicom
        dicom_data = pydicom.dcmread(image_path)
        image = dicom_data.pixel_array
        if len(image.shape) == 2:
            image = np.stack([image] * 3, axis=-1)
    else:
        raise ValueError("Unsupported file format!")
    if len(image.shape) == 3 and image.shape[-1] == 3:
        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW for the MONAI transforms
    else:
        raise ValueError(f"Unexpected image shape: {image.shape}")
    image = torch.tensor(image, dtype=torch.float32)
    image = ScaleIntensity()(image)    # rescale intensities to [0, 1]
    image = Resize((224, 224))(image)  # match the classifier's expected input size
    image = image.unsqueeze(0)         # add batch dimension
    return image.to(device)
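# The returned tensor is always (1, 3, 224, 224) on `device`, whatever the source
# format, e.g.:
#   t = load_and_preprocess_image_classifier("scan.png")
#   assert t.shape == (1, 3, 224, 224)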
def classify_medical_image(image_path, classifier_model):
image_tensor = load_and_preprocess_image_classifier(image_path)
with torch.no_grad():
output = classifier_model(image_tensor)
pred_class = torch.argmax(output, dim=1).item()
return CLASS_NAMES[pred_class]
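# Example usage (a sketch; the checkpoint path matches complete_pipeline below):
#   classifier = load_classifier_model("models/best_metric_model (4).pth")
#   modality = classify_medical_image("chest.jpg", classifier)  # one of CLASS_NAMES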
# -------------------------------
# 2. BRAIN TUMOR SEGMENTATION MODULE (UNetMulti)
# -------------------------------
class DoubleConvUNet(nn.Module):
    """Two (Conv3x3 -> BatchNorm -> ReLU) blocks, the standard U-Net building block."""
    def __init__(self, in_channels, out_channels):
        super(DoubleConvUNet, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.conv(x)
class UNetMulti(nn.Module):
    """Four-level U-Net with skip connections; one output logit map per class."""
    def __init__(self, in_channels=3, out_channels=4):
        super(UNetMulti, self).__init__()
self.down1 = DoubleConvUNet(in_channels, 64)
self.pool1 = nn.MaxPool2d(2)
self.down2 = DoubleConvUNet(64, 128)
self.pool2 = nn.MaxPool2d(2)
self.down3 = DoubleConvUNet(128, 256)
self.pool3 = nn.MaxPool2d(2)
self.down4 = DoubleConvUNet(256, 512)
self.pool4 = nn.MaxPool2d(2)
self.bottleneck = DoubleConvUNet(512, 1024)
self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
self.conv4 = DoubleConvUNet(1024, 512)
self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
self.conv3 = DoubleConvUNet(512, 256)
self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
self.conv2 = DoubleConvUNet(256, 128)
self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
self.conv1 = DoubleConvUNet(128, 64)
self.final_conv = nn.Conv2d(64, out_channels, 1)
def forward(self, x):
c1 = self.down1(x)
p1 = self.pool1(c1)
c2 = self.down2(p1)
p2 = self.pool2(c2)
c3 = self.down3(p2)
p3 = self.pool3(c3)
c4 = self.down4(p3)
p4 = self.pool4(c4)
bn = self.bottleneck(p4)
u4 = self.up4(bn)
merge4 = torch.cat([u4, c4], dim=1)
c5 = self.conv4(merge4)
u3 = self.up3(c5)
merge3 = torch.cat([u3, c3], dim=1)
c6 = self.conv3(merge3)
u2 = self.up2(c6)
merge2 = torch.cat([u2, c2], dim=1)
c7 = self.conv2(merge2)
u1 = self.up1(c7)
merge1 = torch.cat([u1, c1], dim=1)
c8 = self.conv1(merge1)
return self.final_conv(c8)
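# Shape sanity check (a sketch): every 2x max-pool is undone by a matching
# transposed convolution, so spatial size is preserved for inputs divisible by 16:
#   x = torch.randn(1, 3, 256, 256)
#   assert UNetMulti()(x).shape == (1, 4, 256, 256)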
def process_brain_tumor_return(image, model_path="models/brain_tumor_unet_multiclass.pth"):
    """Run multiclass tumor segmentation and return base64 original/mask/overlay images."""
    logging.debug("Processing brain tumor segmentation")
model = UNetMulti(in_channels=3, out_channels=4).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
transform_img = transforms.Compose([
transforms.Resize((256,256)),
transforms.ToTensor()
])
input_tensor = transform_img(image).unsqueeze(0).to(device)
with torch.no_grad():
output = model(input_tensor)
    preds = torch.argmax(output, dim=1).squeeze().cpu().numpy()
    image_np = transform_img(image).permute(1, 2, 0).cpu().numpy()
    # Normalize class indices to [0, 1] before colormapping; the epsilon guards
    # against division by zero when the mask is all background.
    overlay = cv2.applyColorMap(np.uint8(255 * preds / (preds.max() + 1e-8)), cv2.COLORMAP_JET)
    overlay = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
    blended = cv2.addWeighted(np.uint8(image_np * 255), 0.6, overlay, 0.4, 0)
orig_pil = Image.fromarray((image_np*255).astype(np.uint8))
mask_pil = Image.fromarray(overlay)
overlay_pil = Image.fromarray(blended)
return {
"original": pil_to_base64(orig_pil),
"mask": pil_to_base64(mask_pil),
"overlay": pil_to_base64(overlay_pil)
}
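# All three process_* helpers return the same payload, so the template can render
# any modality uniformly:
#   {"original": <base64 JPEG>, "mask": <base64 JPEG>, "overlay": <base64 JPEG>}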
# -------------------------------
# 3. ENDOSCOPY POLYP DETECTION MODULE (Binary UNet)
# -------------------------------
class UNetBinary(UNetMulti):
    """Identical architecture to UNetMulti, but with a single output channel for
    binary (polyp / background) masks. Submodule names match UNetMulti, so
    checkpoints trained against the original duplicated class still load."""
    def __init__(self, in_channels=3, out_channels=1):
        super(UNetBinary, self).__init__(in_channels=in_channels, out_channels=out_channels)
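# Binary-head sanity check (a sketch): one logit channel, thresholded downstream
# at 0.5 after a sigmoid:
#   x = torch.randn(1, 3, 256, 256)
#   assert UNetBinary()(x).shape == (1, 1, 256, 256)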
def process_endoscopy_return(image, model_path="models/endoscopy_unet.pth"):
    """Run binary polyp segmentation and return base64 original/mask/overlay images."""
    model = UNetBinary(in_channels=3, out_channels=1).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
transform_img = transforms.Compose([
transforms.Resize((256,256)),
transforms.ToTensor()
])
input_tensor = transform_img(image).unsqueeze(0).to(device)
with torch.no_grad():
output = model(input_tensor)
prob = torch.sigmoid(output)
mask = (prob > 0.5).float().squeeze().cpu().numpy()
image_np = transform_img(image).permute(1,2,0).cpu().numpy()
overlay = cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET)
overlay = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
blended = cv2.addWeighted(np.uint8(image_np*255), 0.6, overlay, 0.4, 0)
orig_pil = Image.fromarray((image_np*255).astype(np.uint8))
mask_pil = Image.fromarray(overlay)
overlay_pil = Image.fromarray(blended)
return {
"original": pil_to_base64(orig_pil),
"mask": pil_to_base64(mask_pil),
"overlay": pil_to_base64(overlay_pil)
}
# -------------------------------
# 4. PNEUMONIA DETECTION MODULE (Grad-CAM on ResNet18)
# -------------------------------
class GradCAM_Pneumonia:
    """Grad-CAM: hooks on a target conv layer capture activations (forward) and
    gradients (backward) to localize class-discriminative regions."""
    def __init__(self, model, target_layer):
self.model = model
self.target_layer = target_layer
self.gradients = None
self.activations = None
self.hook_handles = []
self._register_hooks()
    def _register_hooks(self):
        def forward_hook(module, input, output):
            self.activations = output.detach()
        def backward_hook(module, grad_in, grad_out):
            self.gradients = grad_out[0].detach()
        handle1 = self.target_layer.register_forward_hook(forward_hook)
        # register_backward_hook is deprecated; the full variant takes the same
        # (module, grad_input, grad_output) signature here.
        handle2 = self.target_layer.register_full_backward_hook(backward_hook)
        self.hook_handles.extend([handle1, handle2])
def remove_hooks(self):
for handle in self.hook_handles:
handle.remove()
    def generate(self, input_image, target_class=None):
        output = self.model(input_image)
        if target_class is None:
            target_class = output.argmax(dim=1).item()
        self.model.zero_grad()
        # Backpropagate a one-hot signal for the target class to populate the hooks.
        one_hot = torch.zeros_like(output)
        one_hot[0, target_class] = 1
        output.backward(gradient=one_hot, retain_graph=True)
        # Grad-CAM: weight each activation map by its spatially averaged gradient,
        # sum over channels, and keep only positive influence.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = F.relu(cam)
        cam = cam.squeeze().cpu().numpy()
        _, _, H, W = input_image.shape
        cam = cv2.resize(cam, (W, H))  # cv2.resize takes (width, height)
        cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam) + 1e-8)
        return cam, output
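# Minimal usage sketch (hooks should be removed once the heatmap is computed):
#   cam_tool = GradCAM_Pneumonia(model, model.layer4)
#   cam, logits = cam_tool.generate(input_tensor)  # cam: HxW array in [0, 1]
#   cam_tool.remove_hooks()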
def process_pneumonia_return(image, model_path="models/pneumonia_resnet18.pth"):
    """Classify a chest X-ray with Grad-CAM and return base64 original/heatmap/overlay images."""
    model = models.resnet18(pretrained=False)  # on torchvision >= 0.13, prefer weights=None
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 2)  # 2 classes: normal and pneumonia
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()
grad_cam = GradCAM_Pneumonia(model, model.layer4)
transform_img = transforms.Compose([
transforms.Resize((224,224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485,0.456,0.406],
std=[0.229,0.224,0.225])
])
input_tensor = transform_img(image).unsqueeze(0).to(device)
# Enable gradient tracking for the input tensor
input_tensor.requires_grad_()
# Do NOT wrap the following call with torch.no_grad()
cam, output = grad_cam.generate(input_tensor)
predicted_class = output.argmax(dim=1).item()
label_text = "Pneumonia" if predicted_class == 1 else "Normal"
def get_bounding_box(heatmap, thresh=0.5, min_area=100):
heat_uint8 = np.uint8(255 * heatmap)
ret, binary = cv2.threshold(heat_uint8, int(thresh*255), 255, cv2.THRESH_BINARY)
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if len(contours)==0:
return None
largest = max(contours, key=cv2.contourArea)
if cv2.contourArea(largest) < min_area:
return None
x, y, w, h = cv2.boundingRect(largest)
return (x, y, w, h)
bbox = None
if predicted_class == 1:
bbox = get_bounding_box(cam, thresh=0.5, min_area=100)
resized_image = image.resize((224,224))
image_np = np.array(resized_image)
overlay = image_np.copy()
if bbox is not None:
x, y, w, h = bbox
cv2.rectangle(overlay, (x, y), (x+w, y+h), (255,0,0), 2)
cv2.putText(overlay, label_text, (10,25), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255,255,0),2)
heatmap_color = cv2.applyColorMap(np.uint8(255*cam), cv2.COLORMAP_JET)
heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
orig_pil = Image.fromarray(image_np)
heatmap_pil = Image.fromarray(heatmap_color)
overlay_pil = Image.fromarray(overlay)
grad_cam.remove_hooks()
return {
"original": pil_to_base64(orig_pil),
"mask": pil_to_base64(heatmap_pil),
"overlay": pil_to_base64(overlay_pil)
}
# -------------------------------
# 5. COMPLETE PIPELINE FUNCTION
# -------------------------------
def complete_pipeline(image_path):
    classifier_model = load_classifier_model("models/best_metric_model (4).pth")
    predicted_modality = classify_medical_image(image_path, classifier_model)
    logging.info(f"Detected modality: {predicted_modality}")
    # PIL handles the JPEG/PNG uploads this app expects; a NIfTI/DICOM file would
    # raise here and surface as an error message in the /predict route.
    original_image = Image.open(image_path).convert("RGB")
    results = {"predicted_modality": predicted_modality}
if predicted_modality in ["HeadCT", "HeadMRI"]:
results["specialized"] = process_brain_tumor_return(original_image, "models/brain_tumor_unet_multiclass.pth")
elif predicted_modality == "Endoscopy":
results["specialized"] = process_endoscopy_return(original_image, "models/endoscopy_unet.pth")
elif predicted_modality == "Chest Xray":
results["specialized"] = process_pneumonia_return(original_image, "models/pneumonia_resnet18.pth")
else:
results["message"] = f"No specialized processing for modality: {predicted_modality}"
return results
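# Note: every request reloads all checkpoints from disk. A simple cache is one way
# to avoid that (a sketch, assuming single-process serving):
#   from functools import lru_cache
#   @lru_cache(maxsize=1)
#   def get_classifier():
#       return load_classifier_model("models/best_metric_model (4).pth")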
# -------------------------------
# 6. FLASK API SETUP
# -------------------------------
app = Flask(__name__)
@app.route('/', methods=['GET'])
def index():
return render_template("index.html", result=None)
@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return render_template("index.html", result={"error": "No file part in the request."})
    file = request.files['file']
    if file.filename == '':
        return render_template("index.html", result={"error": "No file selected."})
    # Preserve the upload's extension so the classifier can route NIfTI/DICOM files
    # by suffix; default to .jpg when the name has none.
    name = file.filename.lower()
    ext = ".nii.gz" if name.endswith(".nii.gz") else (os.path.splitext(name)[1] or ".jpg")
    temp_path = f"temp_input{ext}"
    file.save(temp_path)
    try:
        result = complete_pipeline(temp_path)
    except Exception as e:
        result = {"error": str(e)}
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)
    return render_template("index.html", result=result)
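# Example request against a local run (sketch):
#   curl -F "file=@chest_xray.jpg" http://localhost:5000/predict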
if __name__ == '__main__':
    # debug=True is convenient for local development; disable it when deploying.
    app.run(host='0.0.0.0', port=5000, debug=True)