"""Gradio app that detects and measures blood cells with Segment Anything (SAM).

Each uploaded image is segmented by SAM's automatic mask generator; the outer
contours of every mask are filtered by area and circularity, outlined and
numbered on a copy of the image, and summarised in a table and two plots.
"""

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# Path to the SAM checkpoint file; update if it lives elsewhere.
SAM_CHECKPOINT = "sam_vit_h.pth"
MODEL_TYPE = "vit_h"

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Column order of the per-cell feature table. Used to build the DataFrame so
# the table keeps its headers even when no cells are detected.
FEATURE_COLUMNS = [
    "label",
    "area",
    "perimeter",
    "circularity",
    "centroid_x",
    "centroid_y",
]

# Load the SAM model once at import time so every request reuses it.
# Only the checkpoint load can raise FileNotFoundError, so only it is guarded.
try:
    sam = sam_model_registry[MODEL_TYPE](checkpoint=SAM_CHECKPOINT).to(DEVICE)
except FileNotFoundError as err:
    raise FileNotFoundError(
        f"Checkpoint file '{SAM_CHECKPOINT}' not found. "
        "Download it from: https://github.com/facebookresearch/segment-anything"
    ) from err
mask_generator = SamAutomaticMaskGenerator(sam)


def _to_rgb(image):
    """Return *image* as a 3-channel array, converting grayscale and RGBA input.

    2-D arrays and single-channel (H, W, 1) arrays are expanded with
    GRAY2RGB; 4-channel arrays have their alpha channel dropped with
    RGBA2RGB. Anything else is returned unchanged.
    """
    if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.ndim == 3 and image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    return image


def preprocess_image(image):
    """Convert image to grayscale and apply adaptive thresholding for better cell detection.

    Args:
        image: Grayscale, RGB or RGBA image array (as accepted by ``_to_rgb``).

    Returns:
        A single-channel binary mask (foreground 255, background 0) after an
        inverted Gaussian adaptive threshold followed by morphological close
        and open with a 3x3 kernel (2 iterations each).
    """
    gray = cv2.cvtColor(_to_rgb(image), cv2.COLOR_RGB2GRAY)

    # Inverted threshold: pixels darker than their local Gaussian-weighted
    # mean (11x11 neighbourhood, offset 2) become foreground.
    adaptive_thresh = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2
    )

    # Close fills small holes, then open removes small specks of noise.
    kernel = np.ones((3, 3), np.uint8)
    clean_mask = cv2.morphologyEx(adaptive_thresh, cv2.MORPH_CLOSE, kernel, iterations=2)
    clean_mask = cv2.morphologyEx(clean_mask, cv2.MORPH_OPEN, kernel, iterations=2)
    return clean_mask


def detect_blood_cells(image, min_area=100, max_area=5000, min_circularity=0.7):
    """Detect blood cells using SAM segmentation + contour analysis.

    Every SAM mask is turned into its external contours; a contour is kept
    when ``min_area < area < max_area`` and its circularity
    (4*pi*area / perimeter**2, 1.0 for a perfect circle) exceeds
    ``min_circularity``. Kept contours are outlined and numbered on a copy
    of the image.

    Args:
        image: Grayscale, RGB or RGBA image array; it is converted to 3
            channels before being passed to SAM.
        min_area: Exclusive lower bound on contour area, in pixels.
        max_area: Exclusive upper bound on contour area, in pixels.
        min_circularity: Exclusive lower bound on circularity.

    Returns:
        ``(processed_image, features)``: the annotated 3-channel copy of the
        image and a list of dicts with keys ``FEATURE_COLUMNS``. ``label``
        matches the number drawn on the image.
    """
    image = _to_rgb(image)
    masks = mask_generator.generate(image)

    features = []
    processed_image = image.copy()

    # NOTE(review): if SAM returns several overlapping masks for the same
    # cell, that cell is counted once per mask; consider deduplicating by
    # centroid distance or mask IoU.
    for mask in masks:
        # Boolean segmentation -> 0/255 uint8, as cv2.findContours expects.
        mask_binary = mask["segmentation"].astype(np.uint8) * 255
        contours, _ = cv2.findContours(
            mask_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        for contour in contours:
            area = cv2.contourArea(contour)
            perimeter = cv2.arcLength(contour, True)
            circularity = (
                4 * np.pi * area / (perimeter * perimeter) if perimeter > 0 else 0
            )

            # Filter small, oversized or irregular shapes.
            if not (min_area < area < max_area and circularity > min_circularity):
                continue

            moments = cv2.moments(contour)
            if moments["m00"] == 0:  # degenerate contour: no defined centroid
                continue
            cx = int(moments["m10"] / moments["m00"])
            cy = int(moments["m01"] / moments["m00"])

            features.append(
                {
                    "label": len(features) + 1,
                    "area": area,
                    "perimeter": perimeter,
                    "circularity": circularity,
                    "centroid_x": cx,
                    "centroid_y": cy,
                }
            )

            # Colours are interpreted in the array's channel order (RGB
            # here), so the outline is green and the label (0, 0, 255) is blue.
            cv2.drawContours(processed_image, [contour], -1, (0, 255, 0), 2)
            cv2.putText(
                processed_image,
                str(len(features)),
                (cx, cy),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (0, 0, 255),
                1,
            )

    return processed_image, features


def process_image(image):
    """Run detection on *image* and tabulate the per-cell features.

    Returns:
        ``(processed_image, df)``. When *image* is ``None`` (e.g. the input
        was cleared) this returns ``(None, empty DataFrame)`` instead of
        failing. ``df`` always has the columns in ``FEATURE_COLUMNS``, even
        when no cells were found.
    """
    if image is None:
        return None, pd.DataFrame(columns=FEATURE_COLUMNS)

    processed_img, features = detect_blood_cells(image)
    return processed_img, pd.DataFrame(features, columns=FEATURE_COLUMNS)


def analyze(image):
    """Gradio callback: annotated image, summary plots, and feature table.

    Returns:
        ``(processed_img, fig, df)`` matching the Interface outputs
        ``[gr.Image(), gr.Plot(), gr.Dataframe()]``. The figure shows a
        histogram of cell areas and an area-vs-circularity scatter, or a
        "No cells detected" message when the table is empty.
    """
    processed_img, df = process_image(image)

    # Scope the dark style to this figure instead of mutating global
    # matplotlib state on every request.
    with plt.style.context("dark_background"):
        fig, (ax_hist, ax_scatter) = plt.subplots(1, 2, figsize=(12, 5))
        if df.empty:
            for ax in (ax_hist, ax_scatter):
                ax.text(
                    0.5,
                    0.5,
                    "No cells detected",
                    ha="center",
                    va="center",
                    transform=ax.transAxes,
                )
        else:
            ax_hist.hist(df["area"], bins=20, color="cyan", edgecolor="black")
            ax_hist.set_title("Cell Size Distribution")
            ax_hist.set_xlabel("Area (px^2)")
            ax_hist.set_ylabel("Count")

            ax_scatter.scatter(df["area"], df["circularity"], alpha=0.6, c="magenta")
            ax_scatter.set_title("Area vs Circularity")
            ax_scatter.set_xlabel("Area (px^2)")
            ax_scatter.set_ylabel("Circularity")

    # Deregister from pyplot so figures don't accumulate across requests; the
    # Figure object itself can still be rendered by gr.Plot.
    plt.close(fig)
    return processed_img, fig, df


# Gradio Interface
demo = gr.Interface(
    fn=analyze,
    inputs=gr.Image(type="numpy"),
    outputs=[gr.Image(), gr.Plot(), gr.Dataframe()],
    title="Blood Cell Detection",
    description="Detect and analyze blood cells using SAM segmentation & contour analysis.",
)

if __name__ == "__main__":
    demo.launch()