ibrahim313's picture
Update app.py
1cd6596 verified
raw
history blame
4.31 kB
import cv2
import numpy as np
import pandas as pd
import torch
import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
# SAM (Segment Anything) configuration. The checkpoint file must already be
# present at SAM_CHECKPOINT; this module does not download it.
SAM_CHECKPOINT = "sam_vit_h.pth"  # Update path if needed
MODEL_TYPE = "vit_h"  # registry key; must match the checkpoint's architecture

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the SAM model once at import time so every request reuses it.
# Only the checkpoint load is guarded: that is the call that reads the file.
try:
    sam = sam_model_registry[MODEL_TYPE](checkpoint=SAM_CHECKPOINT).to(DEVICE)
except FileNotFoundError as err:
    # Re-raise with an actionable message while keeping the original cause
    # (and its traceback) attached for debugging.
    raise FileNotFoundError(
        f"Checkpoint file '{SAM_CHECKPOINT}' not found. Download it from: "
        "https://github.com/facebookresearch/segment-anything"
    ) from err
mask_generator = SamAutomaticMaskGenerator(sam)
def preprocess_image(image, *, block_size=11, c=2):
    """Return a cleaned binary mask of dark, cell-like regions in *image*.

    Steps:
      1. Reduce the image to a single grayscale channel.
      2. Apply inverted Gaussian adaptive thresholding. A pixel becomes 255
         when it is at or below its Gaussian-weighted ``block_size`` x
         ``block_size`` neighbourhood mean minus ``c``; otherwise it becomes 0.
      3. Run a morphological close, then an open (3x3 kernel, 2 iterations
         each), to fill small holes and then remove speckle noise.

    NOTE(review): this helper appears unused by the rest of the app (SAM does
    the segmentation). Confirm before relying on it.

    Args:
        image: NumPy image. Accepted shapes are 2-D grayscale, ``(H, W, 1)``,
            or multi-channel with channels assumed in RGB order.
            ``cv2.adaptiveThreshold`` needs 8-bit data, so ``uint8`` is
            expected.
        block_size: Odd neighbourhood size (> 1) for the adaptive threshold.
            Defaults to 11.
        c: Constant subtracted from the neighbourhood mean. Defaults to 2.

    Returns:
        A ``uint8`` mask of shape ``(H, W)`` with values 0 or 255.
    """
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[:, :, 0]  # (H, W, 1) -> (H, W)

    if image.ndim == 2:
        # Already single-channel. Converting to RGB and back would return the
        # same pixels, so skip the conversion.
        gray = image
    else:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)  # assumes RGB order

    # Inverted threshold: dark cells on a brighter background become foreground.
    adaptive_thresh = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, block_size, c
    )

    # Close first to fill pinholes inside cells, then open to drop isolated
    # noise pixels.
    kernel = np.ones((3, 3), np.uint8)
    clean_mask = cv2.morphologyEx(adaptive_thresh, cv2.MORPH_CLOSE, kernel, iterations=2)
    clean_mask = cv2.morphologyEx(clean_mask, cv2.MORPH_OPEN, kernel, iterations=2)
    return clean_mask
def _ensure_rgb(image):
    """Return *image* as a 3-channel RGB array, converting grayscale or RGBA input."""
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[:, :, 0]  # (H, W, 1) -> (H, W)
    if image.ndim == 2:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    return image


def detect_blood_cells(image, *, min_area=100, max_area=5000, min_circularity=0.7):
    """Detect roughly round blood cells with SAM and annotate them on a copy.

    Each mask from the module-level SAM ``mask_generator`` is converted into
    its external contours. A contour counts as a cell when all of these hold:

      * ``min_area < area < max_area`` (area in pixels);
      * its circularity ``4*pi*area / perimeter**2`` (1.0 for a perfect
        circle) exceeds ``min_circularity``;
      * its zeroth moment is non-zero, so a centroid exists.

    NOTE(review): SAM's automatic generator may return overlapping masks for
    the same object at different scales. Such overlaps would be counted more
    than once here; consider de-duplicating by IoU or centroid distance.

    Args:
        image: NumPy image, expected ``(H, W, 3)`` RGB. Grayscale
            ``(H, W)`` / ``(H, W, 1)`` and RGBA ``(H, W, 4)`` inputs are
            converted to RGB first. Presumably ``uint8``; TODO confirm for
            all callers.
        min_area: Exclusive lower bound on contour area. Defaults to 100.
        max_area: Exclusive upper bound on contour area. Defaults to 5000.
        min_circularity: Exclusive lower bound on circularity. Defaults to 0.7.

    Returns:
        A tuple ``(processed_image, features)``:
          * ``processed_image`` is an RGB copy of the input with kept contours
            outlined and numbered at their centroids;
          * ``features`` is a list of dicts with keys ``label`` (1-based),
            ``area``, ``perimeter``, ``circularity``, ``centroid_x`` and
            ``centroid_y``.
    """
    image = _ensure_rgb(image)
    masks = mask_generator.generate(image)

    features = []
    processed_image = image.copy()  # annotate a copy; leave the input untouched
    for mask in masks:
        # Presumably a boolean H x W array (SAM's binary-mask output);
        # findContours needs an 8-bit image.
        mask_binary = mask["segmentation"].astype(np.uint8) * 255
        # [-2] selects the contour list under both OpenCV 3.x (three return
        # values) and 4.x (two return values).
        contours = cv2.findContours(
            mask_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )[-2]
        for contour in contours:
            area = cv2.contourArea(contour)
            perimeter = cv2.arcLength(contour, True)
            circularity = 4 * np.pi * area / (perimeter * perimeter) if perimeter > 0 else 0

            # Reject specks, oversized blobs and irregular (non-round) shapes.
            if not (min_area < area < max_area and circularity > min_circularity):
                continue
            moments = cv2.moments(contour)
            if moments["m00"] == 0:
                continue  # degenerate contour: centroid undefined

            cx = int(moments["m10"] / moments["m00"])
            cy = int(moments["m01"] / moments["m00"])
            label = len(features) + 1
            features.append(
                {
                    "label": label,
                    "area": area,
                    "perimeter": perimeter,
                    "circularity": circularity,
                    "centroid_x": cx,
                    "centroid_y": cy,
                }
            )

            # In RGB channel order: green outline, blue label number.
            cv2.drawContours(processed_image, [contour], -1, (0, 255, 0), 2)
            cv2.putText(
                processed_image,
                str(label),
                (cx, cy),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (0, 0, 255),
                1,
            )
    return processed_image, features
def process_image(image):
    """Detect cells in *image* and tabulate their per-cell features.

    Args:
        image: NumPy image from the Gradio input, or ``None`` when the input
            is empty.

    Returns:
        A 2-tuple ``(processed_image, features_df)``, always of length 2 so
        callers can unpack it unconditionally:
          * ``processed_image`` is the annotated image, or ``None`` when
            *image* is ``None``;
          * ``features_df`` always has the columns ``label``, ``area``,
            ``perimeter``, ``circularity``, ``centroid_x`` and ``centroid_y``,
            even when no cells were detected.
    """
    columns = ["label", "area", "perimeter", "circularity", "centroid_x", "centroid_y"]
    if image is None:
        # Previously this branch returned a 4-tuple, which broke callers that
        # unpack two values.
        return None, pd.DataFrame(columns=columns)

    processed_img, features = detect_blood_cells(image)
    # Passing `columns` keeps the schema stable when `features` is empty.
    return processed_img, pd.DataFrame(features, columns=columns)
def analyze(image):
    """Gradio callback: detect cells in *image* and build summary plots.

    Args:
        image: NumPy image from the Gradio input, or ``None`` when cleared.

    Returns:
        A tuple ``(annotated_image, figure, features_df)``, matching the three
        interface outputs:
          * ``figure`` holds a histogram of cell areas and an
            area-vs-circularity scatter plot. When no cells are found, both
            panels read "No cells detected".
          * When *image* is ``None``, the result is
            ``(None, None, <empty DataFrame>)``.
    """
    if image is None:
        return None, None, pd.DataFrame()

    processed_img, df = process_image(image)

    # Scope the dark style to this figure instead of mutating matplotlib's
    # global rcParams on every call. Build the Figure directly rather than via
    # pyplot, so figures are not registered with (and never accumulate in)
    # pyplot's figure manager across requests.
    with plt.style.context("dark_background"):
        fig = Figure(figsize=(12, 5))
        hist_ax, scatter_ax = fig.subplots(1, 2)

        if df is not None and not df.empty:
            hist_ax.hist(df["area"], bins=20, color="cyan", edgecolor="black")
            scatter_ax.scatter(df["area"], df["circularity"], alpha=0.6, c="magenta")
        else:
            for ax in (hist_ax, scatter_ax):
                ax.text(
                    0.5,
                    0.5,
                    "No cells detected",
                    ha="center",
                    va="center",
                    transform=ax.transAxes,
                )

        hist_ax.set_title("Cell Size Distribution")
        hist_ax.set_xlabel("Area (px²)")
        hist_ax.set_ylabel("Count")
        scatter_ax.set_title("Area vs Circularity")
        scatter_ax.set_xlabel("Area (px²)")
        scatter_ax.set_ylabel("Circularity")

    return processed_img, fig, df
# Gradio interface: one image in. Three outputs, in the order `analyze`
# returns them: the annotated image, the summary plots and the per-cell
# feature table.
demo = gr.Interface(
    fn=analyze,
    inputs=gr.Image(type="numpy", label="Blood smear image"),
    outputs=[
        gr.Image(label="Detected cells"),
        gr.Plot(label="Cell statistics"),
        gr.Dataframe(label="Cell features"),
    ],
    title="Blood Cell Detection",
    description="Detect and analyze blood cells using SAM segmentation & contour analysis.",
)

# Start the server only when run as a script, not when the module is imported
# (e.g. by tests or by tooling that just needs `demo`).
if __name__ == "__main__":
    demo.launch()