import cv2
import numpy as np
import pandas as pd
import torch
import gradio as gr
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

# Path to the SAM ViT-H checkpoint (the official weights file is distributed as
# "sam_vit_h_4b8939.pth"); update the path/name to match your download.
SAM_CHECKPOINT = "sam_vit_h.pth"
MODEL_TYPE = "vit_h"

# Check if CUDA is available for GPU processing
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the SAM model
try:
    sam = sam_model_registry[MODEL_TYPE](checkpoint=SAM_CHECKPOINT).to(DEVICE)
    mask_generator = SamAutomaticMaskGenerator(sam)
except FileNotFoundError:
    raise FileNotFoundError(f"Checkpoint file '{SAM_CHECKPOINT}' not found. Download it from: https://github.com/facebookresearch/segment-anything")
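
# The automatic mask generator can optionally be tuned for small, densely packed
# objects such as blood cells; the values below are illustrative starting points,
# not validated settings:
#
# mask_generator = SamAutomaticMaskGenerator(
#     sam,
#     points_per_side=32,            # denser sampling grid helps find smaller cells
#     pred_iou_thresh=0.88,          # drop low-quality masks
#     stability_score_thresh=0.95,   # drop unstable masks
#     min_mask_region_area=100,      # remove tiny noise regions (in pixels)
# )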

def preprocess_image(image):
    """Convert image to grayscale and apply adaptive thresholding for better cell detection."""
    if len(image.shape) == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # Apply adaptive thresholding for better contrast
    adaptive_thresh = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2
    )

    # Morphological operations to remove noise
    kernel = np.ones((3, 3), np.uint8)
    clean_mask = cv2.morphologyEx(adaptive_thresh, cv2.MORPH_CLOSE, kernel, iterations=2)
    clean_mask = cv2.morphologyEx(clean_mask, cv2.MORPH_OPEN, kernel, iterations=2)

    return clean_mask
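
# Note: preprocess_image() is not wired into the SAM pipeline below; it is kept as an
# optional classical-CV path. A hypothetical contour-only fallback would look like:
#
# clean_mask = preprocess_image(image)
# contours, _ = cv2.findContours(clean_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)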

def detect_blood_cells(image):
    """Detect blood cells using SAM segmentation + contour analysis."""
    # Generate masks using SAM
    masks = mask_generator.generate(image)

    features = []
    processed_image = image.copy()

    for mask in masks:
        mask_binary = mask["segmentation"].astype(np.uint8) * 255  # Convert to binary
        contours, _ = cv2.findContours(mask_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        for contour in contours:
            area = cv2.contourArea(contour)
            perimeter = cv2.arcLength(contour, True)
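            # Circularity is 1.0 for a perfect circle and falls toward 0 for irregular shapes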
            circularity = 4 * np.pi * area / (perimeter * perimeter) if perimeter > 0 else 0

            # Filter small or irregular shapes
            if 100 < area < 5000 and circularity > 0.7:
                M = cv2.moments(contour)
                if M["m00"] != 0:
                    cx = int(M["m10"] / M["m00"])
                    cy = int(M["m01"] / M["m00"])
                    features.append(
                        {
                            "label": len(features) + 1,
                            "area": area,
                            "perimeter": perimeter,
                            "circularity": circularity,
                            "centroid_x": cx,
                            "centroid_y": cy,
                        }
                    )

                    # Draw detected cell on image
                    cv2.drawContours(processed_image, [contour], -1, (0, 255, 0), 2)
                    cv2.putText(
                        processed_image,
                        str(len(features)),
                        (cx, cy),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.5,
                        (0, 0, 255),
                        1,
                    )

    return processed_image, features

def process_image(image):
    """Run cell detection and package the per-cell measurements as a DataFrame."""
    if image is None:
        # Match the (image, dataframe) return shape expected by analyze()
        return None, None

    processed_img, features = detect_blood_cells(image)
    df = pd.DataFrame(features)

    return processed_img, df

def analyze(image):
    """Gradio callback: return the annotated image, summary plots, and the feature table."""
    processed_img, df = process_image(image)

    plt.style.use("dark_background")
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    if df is not None and not df.empty:
        axes[0].hist(df["area"], bins=20, color="cyan", edgecolor="black")
        axes[0].set_title("Cell Size Distribution")

        axes[1].scatter(df["area"], df["circularity"], alpha=0.6, c="magenta")
        axes[1].set_title("Area vs Circularity")

    return processed_img, fig, df

# Gradio Interface
demo = gr.Interface(
    fn=analyze,
    inputs=gr.Image(type="numpy"),
    outputs=[gr.Image(), gr.Plot(), gr.Dataframe()],
    title="Blood Cell Detection",
    description="Detect and analyze blood cells using SAM segmentation & contour analysis.",
)

demo.launch()
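
# Optional launch settings (standard Gradio options, shown commented out):
# demo.launch(share=True)                                  # temporary public URL
# demo.launch(server_name="0.0.0.0", server_port=7860)     # bind to all interfaces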