# app.py by ibrahim313 (commit 6e20c51, "Update app.py", 3.22 kB).
from segment_anything import sam_model_registry, SamPredictor
import torch
import cv2
import numpy as np
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
# --- SAM model setup -------------------------------------------------------
# The ViT-H checkpoint must be downloaded from Meta AI beforehand; the path
# below is relative to the current working directory.
model_type = "vit_h"
sam_checkpoint = "sam_vit_h.pth"

# Prefer the GPU when one is visible to torch, otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam = sam.to(device)

# One predictor shared by every request.
# NOTE(review): set_image mutates this shared object; confirm that Gradio's
# concurrency settings prevent two requests from interleaving on it.
predictor = SamPredictor(sam)
def preprocess_image(image):
    """Return *image* as a 3-channel (H, W, 3) array suitable for SAM.

    - A 2-D grayscale array (H, W) or a single-channel array (H, W, 1) has its
      channel replicated three times. This gives the same values as
      ``cv2.cvtColor(..., cv2.COLOR_GRAY2RGB)``, for any dtype.
    - A 4-channel array (H, W, 4), assumed to be RGBA, has its last (alpha)
      channel dropped and is returned as a C-contiguous copy.
    - Any other array (e.g. already RGB) is returned unchanged, as the very
      same object.

    Args:
        image: Image as a numpy array.

    Returns:
        A numpy array with the same dtype as the input.
    """
    if image.ndim == 2:
        # Give grayscale an explicit channel axis so one branch handles both forms.
        image = image[:, :, np.newaxis]
    if image.ndim == 3 and image.shape[2] == 1:
        return np.repeat(image, 3, axis=2)
    if image.ndim == 3 and image.shape[2] == 4:
        # Drop alpha; make contiguous because OpenCV drawing calls reject
        # strided views.
        return np.ascontiguousarray(image[:, :, :3])
    return image
def detect_blood_cells(image, min_area=100, max_area=5000, min_circularity=0.7):
    """Segment *image* with SAM and keep the contours that look like blood cells.

    The image is converted with ``preprocess_image`` and loaded into the
    module-level ``predictor``. ``predictor.predict`` is then called with no
    point, box or mask prompts and ``multimask_output=True``. This is a single
    unprompted query, NOT SAM's automatic "segment everything" mode (that is
    ``SamAutomaticMaskGenerator``). The returned masks may overlap, so the same
    cell can be reported more than once.

    For every mask, the external contours are extracted. A contour is kept
    when all of the following hold:

    - ``min_area < area < max_area``;
    - ``circularity > min_circularity``, where circularity is
      4*pi*area / perimeter**2 (1.0 for a perfect circle, 0 when the
      perimeter is 0);
    - its zeroth image moment is non-zero, so a centroid exists.

    Args:
        image: Input image as a numpy array.
        min_area: Exclusive lower area bound, in pixels^2.
        max_area: Exclusive upper area bound, in pixels^2.
        min_circularity: Exclusive lower circularity bound.

    Returns:
        A tuple ``(contours_list, features, masks)``.

        - ``contours_list`` and ``features`` are built in lockstep:
          ``contours_list[k]`` is the contour described by ``features[k]``.
        - Each feature dict has the keys ``label``, ``area``, ``perimeter``,
          ``circularity``, ``centroid_x`` and ``centroid_y``.
        - ``label`` is ``f"{i}-{j}"``. ``i`` is the 0-based mask index. ``j``
          is the 1-based position of the contour among ALL external
          contours of that mask, rejected ones included, so ``j`` is NOT an
          index into ``contours_list``.
        - ``masks`` is the mask array returned by ``predictor.predict``.
    """
    image = preprocess_image(image)
    predictor.set_image(image)
    # Unprompted query: SAM returns several candidate masks (multimask_output).
    masks, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        multimask_output=True,
    )

    contours_list = []
    features = []
    for i, mask in enumerate(masks):
        # findContours needs an 8-bit single-channel image, not a bool array.
        mask_u8 = mask.astype(np.uint8) * 255
        contours, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for j, contour in enumerate(contours, 1):
            area = cv2.contourArea(contour)
            perimeter = cv2.arcLength(contour, True)
            circularity = 4 * np.pi * area / (perimeter * perimeter) if perimeter > 0 else 0
            if not (min_area < area < max_area and circularity > min_circularity):
                continue
            moments = cv2.moments(contour)
            if moments["m00"] == 0:
                # Degenerate contour: no centroid can be computed.
                continue
            cx = int(moments["m10"] / moments["m00"])
            cy = int(moments["m01"] / moments["m00"])
            features.append({
                'label': f"{i}-{j}", 'area': area, 'perimeter': perimeter,
                'circularity': circularity, 'centroid_x': cx, 'centroid_y': cy
            })
            # Appended right after the feature so both lists stay index-aligned.
            contours_list.append(contour)
    return contours_list, features, masks
def process_image(image):
    """Run blood-cell detection and build the visual and tabular outputs.

    Args:
        image: Input image as a numpy array, or None when nothing was provided.

    Returns:
        A tuple ``(vis_img, mask, df)``.

        - ``vis_img`` is an RGB copy of the input. Each accepted contour is
          drawn on it in colour (0, 255, 0), and its label is written in
          (0, 0, 255) at the contour's centroid.
        - ``mask`` is the first mask returned by the detector.
        - ``df`` is a DataFrame with one row per detected cell.

        When ``image`` is None, returns ``(None, None, <empty DataFrame>)``.
        The count matches the normal return so callers can always unpack
        three values.
    """
    if image is None:
        return None, None, pd.DataFrame()
    contours, features, masks = detect_blood_cells(image)
    # Draw on the same RGB conversion the detector uses so grayscale inputs
    # still get coloured overlays; copy so the caller's array is untouched.
    vis_img = preprocess_image(image).copy()
    # detect_blood_cells appends to ``features`` and ``contours`` in lockstep,
    # so pair them by position. The second number in a label counts every
    # contour of its mask (rejected ones too) and is not an index into
    # ``contours``.
    for feature, contour in zip(features, contours):
        cv2.drawContours(vis_img, [contour], -1, (0, 255, 0), 2)
        cv2.putText(vis_img, str(feature['label']), (feature['centroid_x'], feature['centroid_y']),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
    df = pd.DataFrame(features)
    return vis_img, masks[0], df
def analyze(image):
    """Gradio callback: detect cells and return the overlay, mask, plots and table.

    Args:
        image: Input image as a numpy array, or None when nothing was uploaded.

    Returns:
        ``(vis_img, mask, fig, df)``, matching the four Interface outputs.
        ``fig`` holds two panels: a histogram of cell areas, and a scatter of
        area against circularity. Both are left empty when no cells were
        found. If ``image`` is None, returns four Nones so every output is
        cleared.
    """
    if image is None:
        # Short-circuit so the empty case does not depend on what
        # process_image returns for a missing image.
        return None, None, None, None
    vis_img, mask, df = process_image(image)
    plt.style.use('dark_background')
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    if df is not None and not df.empty:
        axes[0].hist(df['area'], bins=20, color='cyan', edgecolor='black')
        axes[0].set_title('Cell Size Distribution')
        axes[0].set_xlabel('Area (px^2)')
        axes[0].set_ylabel('Count')
        axes[1].scatter(df['area'], df['circularity'], alpha=0.6, c='magenta')
        axes[1].set_title('Area vs Circularity')
        axes[1].set_xlabel('Area (px^2)')
        axes[1].set_ylabel('Circularity')
    # Remove the figure from pyplot's global registry. The Figure object can
    # still be rendered by gr.Plot, but repeated requests no longer pile up
    # open figures (and their memory) in the long-running server.
    plt.close(fig)
    return vis_img, mask, fig, df
# Gradio interface: one uploaded image in; cell overlay, first SAM mask,
# statistics plots and the per-cell feature table out.
demo = gr.Interface(
    fn=analyze,
    inputs=gr.Image(type="numpy"),
    outputs=[
        gr.Image(label="Detected cells"),
        gr.Image(label="SAM mask"),
        gr.Plot(label="Statistics"),
        gr.Dataframe(label="Cell features"),
    ],
)

# Start the server only when executed as a script (e.g. `python app.py`),
# so merely importing this module does not launch it.
if __name__ == "__main__":
    demo.launch()