import math

import gradio as gr
import numpy as np
import torch
from PIL import Image

from unet import EnhancedUNet

def initialize_model(model_path):
    """Load an EnhancedUNet checkpoint onto the GPU if available, else the CPU."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EnhancedUNet(n_channels=1, n_classes=4).to(device)
    # map_location ensures GPU-trained checkpoints also load on CPU-only hosts
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model, device

def extract_patches(image, patch_size=256):
    """Split the input image into non-overlapping patch_size x patch_size tiles,
    padding the right and bottom edges so the grid divides evenly."""
    width, height = image.size
    patches = []
    positions = []

    # Number of patches in each dimension
    n_cols = math.ceil(width / patch_size)
    n_rows = math.ceil(height / patch_size)

    # Pad the image with black so its dimensions are multiples of patch_size
    padded_width = n_cols * patch_size
    padded_height = n_rows * patch_size
    padded_image = Image.new("L", (padded_width, padded_height), 0)
    padded_image.paste(image.convert("L"), (0, 0))

    # Extract patches row by row
    for i in range(n_rows):
        for j in range(n_cols):
            left = j * patch_size
            top = i * patch_size
            right = left + patch_size
            bottom = top + patch_size
            patch = padded_image.crop((left, top, right, bottom))
            patches.append(patch)
            positions.append((left, top, right, bottom))

    return patches, positions, (padded_width, padded_height), (width, height)
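
# Worked example: a 600x400 input with patch_size=256 gives
# n_cols = ceil(600/256) = 3 and n_rows = ceil(400/256) = 2, so the image is
# padded to 768x512 and split into 6 patches; stitch_patches crops the
# padding away again after inference.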

def stitch_patches(patches, positions, padded_size, original_size):
    """Stitch predicted patch masks back together into a full-size mask."""
    full_mask = np.zeros((padded_size[1], padded_size[0]), dtype=np.uint8)
    for patch, (left, top, right, bottom) in zip(patches, positions):
        full_mask[top:bottom, left:right] = patch
    # Crop away the padding to restore the original size
    full_mask = full_mask[:original_size[1], :original_size[0]]
    return full_mask

def process_patch(patch, device):
    """Convert a PIL patch into a normalized float tensor on the target device."""
    # Convert to grayscale if it is not already
    patch_gray = patch.convert("L")
    # Convert to a float array scaled to [0, 1]
    patch_np = np.array(patch_gray, dtype=np.float32) / 255.0
    # Add batch and channel dimensions
    patch_tensor = torch.from_numpy(patch_np).unsqueeze(0).unsqueeze(0)
    return patch_tensor.to(device)
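
# Shape flow: a 256x256 patch becomes an array of shape (256, 256), then a
# tensor of shape (1, 1, 256, 256) matching the model's (N, C, H, W) input.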

def create_overlay(original_image, mask, alpha=0.5):
    """Blend a color-coded class mask over the original image."""
    # One RGB color per class: background, then classes 1-3
    colors = [(0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255)]
    mask_rgb = np.zeros((*mask.shape, 3), dtype=np.uint8)
    for i, color in enumerate(colors):
        mask_rgb[mask == i] = color

    # Resize the original image to match the mask dimensions
    original_image = original_image.resize((mask.shape[1], mask.shape[0]))
    original_array = np.array(original_image.convert("RGB"))

    # Alpha-blend the color mask over the image
    overlay = (alpha * mask_rgb + (1 - alpha) * original_array).astype(np.uint8)
    return Image.fromarray(overlay)
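
# The blend is a per-pixel convex combination, alpha * mask + (1 - alpha) * image,
# so the default alpha=0.5 weights the class colors and the raw image equally.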

def predict(input_image, model_choice):
    """Run patch-wise segmentation and return the mask and overlay images."""
    if input_image is None:
        return None, None

    model = models[model_choice]
    patch_size = 256

    # Split the input into tiles
    patches, positions, padded_size, original_size = extract_patches(input_image, patch_size)

    # Run inference on each patch
    predicted_patches = []
    for patch in patches:
        patch_tensor = process_patch(patch, device)
        with torch.no_grad():
            output = model(patch_tensor)
        # Per-pixel class index for this patch
        patch_mask = torch.argmax(output, dim=1).cpu().numpy()[0]
        predicted_patches.append(patch_mask)

    # Stitch the patch predictions back into a full-size mask
    full_mask = stitch_patches(predicted_patches, positions, padded_size, original_size)

    # Scale class indices 0-3 to 0-189 so the mask is visible as a grayscale image
    mask_image = Image.fromarray((full_mask * 63).astype(np.uint8))
    overlay_image = create_overlay(input_image, full_mask)
    return mask_image, overlay_image
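
# Patches are pushed through the model one at a time; batching them into a
# single forward pass would be faster on GPU at the cost of more memory.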

# Load the models once at startup rather than on every request
w_noise_model_path = "./models/best_model_w_noise.pth"
wo_noise_model_path = "./models/best_model_wo_noise.pth"
w_noise_model_v2_path = "./models/best_model_w_noise_v2.pth"

# All three calls select the same device, so sharing the `device` global is safe
w_noise_model, device = initialize_model(w_noise_model_path)
wo_noise_model, device = initialize_model(wo_noise_model_path)
w_noise_model_v2, device = initialize_model(w_noise_model_v2_path)

models = {
    "Without Noise": wo_noise_model,
    "With Noise": w_noise_model,
    "With Noise V2": w_noise_model_v2,
}

# Create the Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(
            choices=["Without Noise", "With Noise", "With Noise V2"],
            value="With Noise V2",
            label="Model",
        ),
    ],
    outputs=[
        gr.Image(type="pil", label="Segmentation Mask"),
        gr.Image(type="pil", label="Overlay"),
    ],
    title="MoS2 Image Segmentation",
    description="Upload an image to get the segmentation mask and overlay visualization.",
    examples=[
        ["./examples/image_000003.png", "With Noise"],
        ["./examples/image_000005.png", "Without Noise"],
    ],
)

# Launch the interface
iface.launch(share=True)
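# share=True also creates a temporary public URL via Gradio's tunnel service;
# set share=False to keep the demo local-only.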