ibrahim313 commited on
Commit
1cd6596
·
verified ·
1 Parent(s): b4e915f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -50
app.py CHANGED
@@ -1,91 +1,123 @@
1
- from segment_anything import sam_model_registry, SamPredictor
2
- import torch
3
  import cv2
4
  import numpy as np
5
- import gradio as gr
6
  import pandas as pd
 
 
7
  import matplotlib.pyplot as plt
 
 
 
 
 
8
 
9
- # Load SAM model
10
- sam_checkpoint = "sam_vit_h.pth" # Checkpoint file (download it from Meta AI)
11
- device = "cuda" if torch.cuda.is_available() else "cpu"
12
- model_type = "vit_h"
13
- sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device)
14
- predictor = SamPredictor(sam)
 
 
 
15
 
16
  def preprocess_image(image):
17
- """Convert image to RGB format for SAM."""
18
  if len(image.shape) == 2:
19
  image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
20
- return image
21
 
22
- def detect_blood_cells(image):
23
- """Detect blood cells using SAM."""
24
- image = preprocess_image(image)
25
- predictor.set_image(image)
26
-
27
- # Generate automatic masks (SAM can also take prompts for guided segmentation)
28
- masks, _, _ = predictor.predict(
29
- point_coords=None,
30
- point_labels=None,
31
- multimask_output=True
32
  )
33
 
34
- contours_list = []
 
 
 
 
 
 
 
 
 
 
 
35
  features = []
 
 
36
  for i, mask in enumerate(masks):
37
- mask = mask.astype(np.uint8) * 255 # Convert boolean mask to uint8
38
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
39
 
40
- for j, contour in enumerate(contours, 1):
41
  area = cv2.contourArea(contour)
42
  perimeter = cv2.arcLength(contour, True)
43
  circularity = 4 * np.pi * area / (perimeter * perimeter) if perimeter > 0 else 0
44
 
 
45
  if 100 < area < 5000 and circularity > 0.7:
46
  M = cv2.moments(contour)
47
  if M["m00"] != 0:
48
  cx = int(M["m10"] / M["m00"])
49
  cy = int(M["m01"] / M["m00"])
50
- features.append({
51
- 'label': f"{i}-{j}", 'area': area, 'perimeter': perimeter,
52
- 'circularity': circularity, 'centroid_x': cx, 'centroid_y': cy
53
- })
54
- contours_list.append(contour)
55
-
56
- return contours_list, features, masks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  def process_image(image):
59
  if image is None:
60
  return None, None, None, None
61
 
62
- contours, features, masks = detect_blood_cells(image)
63
- vis_img = image.copy()
64
-
65
- for feature in features:
66
- contour = contours[int(feature['label'].split('-')[1]) - 1]
67
- cv2.drawContours(vis_img, [contour], -1, (0, 255, 0), 2)
68
- cv2.putText(vis_img, str(feature['label']), (feature['centroid_x'], feature['centroid_y']),
69
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
70
-
71
  df = pd.DataFrame(features)
72
- return vis_img, masks[0], df
 
73
 
74
  def analyze(image):
75
- vis_img, mask, df = process_image(image)
76
 
77
- plt.style.use('dark_background')
78
  fig, axes = plt.subplots(1, 2, figsize=(12, 5))
79
 
80
  if not df.empty:
81
- axes[0].hist(df['area'], bins=20, color='cyan', edgecolor='black')
82
- axes[0].set_title('Cell Size Distribution')
83
 
84
- axes[1].scatter(df['area'], df['circularity'], alpha=0.6, c='magenta')
85
- axes[1].set_title('Area vs Circularity')
86
 
87
- return vis_img, mask, fig, df
88
 
89
  # Gradio Interface
90
- demo = gr.Interface(fn=analyze, inputs=gr.Image(type="numpy"), outputs=[gr.Image(), gr.Image(), gr.Plot(), gr.Dataframe()])
 
 
 
 
 
 
 
91
  demo.launch()
 
 
 
1
  import cv2
2
  import numpy as np
 
3
  import pandas as pd
4
+ import torch
5
+ import gradio as gr
6
  import matplotlib.pyplot as plt
7
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
8
+
9
+ # Ensure the SAM model checkpoint is downloaded
10
+ SAM_CHECKPOINT = "sam_vit_h.pth" # Update path if needed
11
+ MODEL_TYPE = "vit_h"
12
 
13
+ # Check if CUDA is available for GPU processing
14
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
+
16
+ # Load the SAM model
17
+ try:
18
+ sam = sam_model_registry[MODEL_TYPE](checkpoint=SAM_CHECKPOINT).to(DEVICE)
19
+ mask_generator = SamAutomaticMaskGenerator(sam)
20
+ except FileNotFoundError:
21
+ raise FileNotFoundError(f"Checkpoint file '{SAM_CHECKPOINT}' not found. Download it from: https://github.com/facebookresearch/segment-anything")
22
 
23
  def preprocess_image(image):
24
+ """Convert image to grayscale and apply adaptive thresholding for better cell detection."""
25
  if len(image.shape) == 2:
26
  image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
 
27
 
28
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
29
+
30
+ # Apply adaptive thresholding for better contrast
31
+ adaptive_thresh = cv2.adaptiveThreshold(
32
+ gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2
 
 
 
 
 
33
  )
34
 
35
+ # Morphological operations to remove noise
36
+ kernel = np.ones((3, 3), np.uint8)
37
+ clean_mask = cv2.morphologyEx(adaptive_thresh, cv2.MORPH_CLOSE, kernel, iterations=2)
38
+ clean_mask = cv2.morphologyEx(clean_mask, cv2.MORPH_OPEN, kernel, iterations=2)
39
+
40
+ return clean_mask
41
+
42
+ def detect_blood_cells(image):
43
+ """Detect blood cells using SAM segmentation + contour analysis."""
44
+ # Generate masks using SAM
45
+ masks = mask_generator.generate(image)
46
+
47
  features = []
48
+ processed_image = image.copy()
49
+
50
  for i, mask in enumerate(masks):
51
+ mask_binary = mask["segmentation"].astype(np.uint8) * 255 # Convert to binary
52
+ contours, _ = cv2.findContours(mask_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
53
 
54
+ for contour in contours:
55
  area = cv2.contourArea(contour)
56
  perimeter = cv2.arcLength(contour, True)
57
  circularity = 4 * np.pi * area / (perimeter * perimeter) if perimeter > 0 else 0
58
 
59
+ # Filter small or irregular shapes
60
  if 100 < area < 5000 and circularity > 0.7:
61
  M = cv2.moments(contour)
62
  if M["m00"] != 0:
63
  cx = int(M["m10"] / M["m00"])
64
  cy = int(M["m01"] / M["m00"])
65
+ features.append(
66
+ {
67
+ "label": len(features) + 1,
68
+ "area": area,
69
+ "perimeter": perimeter,
70
+ "circularity": circularity,
71
+ "centroid_x": cx,
72
+ "centroid_y": cy,
73
+ }
74
+ )
75
+
76
+ # Draw detected cell on image
77
+ cv2.drawContours(processed_image, [contour], -1, (0, 255, 0), 2)
78
+ cv2.putText(
79
+ processed_image,
80
+ str(len(features)),
81
+ (cx, cy),
82
+ cv2.FONT_HERSHEY_SIMPLEX,
83
+ 0.5,
84
+ (0, 0, 255),
85
+ 1,
86
+ )
87
+
88
+ return processed_image, features
89
 
90
  def process_image(image):
91
  if image is None:
92
  return None, None, None, None
93
 
94
+ processed_img, features = detect_blood_cells(image)
 
 
 
 
 
 
 
 
95
  df = pd.DataFrame(features)
96
+
97
+ return processed_img, df
98
 
99
  def analyze(image):
100
+ processed_img, df = process_image(image)
101
 
102
+ plt.style.use("dark_background")
103
  fig, axes = plt.subplots(1, 2, figsize=(12, 5))
104
 
105
  if not df.empty:
106
+ axes[0].hist(df["area"], bins=20, color="cyan", edgecolor="black")
107
+ axes[0].set_title("Cell Size Distribution")
108
 
109
+ axes[1].scatter(df["area"], df["circularity"], alpha=0.6, c="magenta")
110
+ axes[1].set_title("Area vs Circularity")
111
 
112
+ return processed_img, fig, df
113
 
114
  # Gradio Interface
115
+ demo = gr.Interface(
116
+ fn=analyze,
117
+ inputs=gr.Image(type="numpy"),
118
+ outputs=[gr.Image(), gr.Plot(), gr.Dataframe()],
119
+ title="Blood Cell Detection",
120
+ description="Detect and analyze blood cells using SAM segmentation & contour analysis.",
121
+ )
122
+
123
  demo.launch()