Bono93 commited on
Commit
146111c
·
1 Parent(s): 131f195

feat: add medsam

Browse files
app.py CHANGED
@@ -5,12 +5,21 @@ import torch
5
  import cv2
6
  from segment_anything import SamPredictor, sam_model_registry
7
 
8
- CHECKPOINT = "./models/sam_vit_h_4b8939.pth"
9
- MODEL_TYPE = "vit_h"
 
 
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
- SAM = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT)
 
 
 
12
  SAM.to(device=DEVICE)
13
  SAM_PREDICTOR = SamPredictor(SAM)
 
 
 
 
14
 
15
 
16
  def draw_contour(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
@@ -23,17 +32,20 @@ def draw_contour(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
23
  return contour_image, contours
24
 
25
 
26
- def inference(image: np.ndarray, coord_y: int, coord_x: int) -> np.ndarray:
 
 
27
  """Inference."""
28
- SAM_PREDICTOR.set_image(image)
29
 
30
  input_point = np.array([[coord_y, coord_x]])
31
  input_label = np.array([1])
32
- mask, _, _ = SAM_PREDICTOR.predict(
33
  point_coords=input_point,
34
  point_labels=input_label,
35
  multimask_output=False,
36
  )
 
37
  h, w = mask.shape[-2:]
38
  mask = mask.reshape(h, w, 1)
39
  mask = (mask * 255).astype(np.uint8)
@@ -63,7 +75,7 @@ with gr.Blocks() as demo:
63
  )
64
 
65
  # Segment image
66
- with gr.Tab(label="Image"):
67
  with gr.Row().style(equal_height=True):
68
  with gr.Column(label="Input Image"):
69
  # input image
@@ -80,7 +92,12 @@ with gr.Blocks() as demo:
80
  input_image.select(get_coords, None, [coord_h, coord_w])
81
  gr.Examples(
82
  examples=[
83
- [os.path.join(os.path.dirname(__file__), "samples/bears.jpg"), 1300, 950]
 
 
 
 
 
84
  ],
85
  inputs=[input_image, coord_h, coord_w],
86
  outputs=output,
 
5
  import cv2
6
  from segment_anything import SamPredictor, sam_model_registry
7
 
8
+ # Global variables
9
+ OFFICIAL_CHECKPOINT = "./models/sam_vit_b_01ec64.pth"
10
+ MEDSAM_CHECKPOINT = "./models/medsam_vitb.pth"
11
+ MODEL_TYPE = "vit_b"
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
+
14
+ # Model
15
+ ## OFFICIAL SAM
16
+ SAM = sam_model_registry[MODEL_TYPE](checkpoint=OFFICIAL_CHECKPOINT)
17
  SAM.to(device=DEVICE)
18
  SAM_PREDICTOR = SamPredictor(SAM)
19
+ ## MEDSAM
20
+ MEDSAM = sam_model_registry[MODEL_TYPE](checkpoint=MEDSAM_CHECKPOINT)
21
+ MEDSAM.to(device=DEVICE)
22
+ MEDSAM_PREDICTOR = SamPredictor(MEDSAM)
23
 
24
 
25
  def draw_contour(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
 
32
  return contour_image, contours
33
 
34
 
35
+ def inference(
36
+ predictor: SamPredictor, image: np.ndarray, coord_y: int, coord_x: int
37
+ ) -> np.ndarray:
38
  """Inference."""
39
+ predictor.set_image(image)
40
 
41
  input_point = np.array([[coord_y, coord_x]])
42
  input_label = np.array([1])
43
+ mask, _, _ = predictor.predict(
44
  point_coords=input_point,
45
  point_labels=input_label,
46
  multimask_output=False,
47
  )
48
+
49
  h, w = mask.shape[-2:]
50
  mask = mask.reshape(h, w, 1)
51
  mask = (mask * 255).astype(np.uint8)
 
75
  )
76
 
77
  # Segment image
78
+ with gr.Tab(label="SAM Inference"):
79
  with gr.Row().style(equal_height=True):
80
  with gr.Column(label="Input Image"):
81
  # input image
 
92
  input_image.select(get_coords, None, [coord_h, coord_w])
93
  gr.Examples(
94
  examples=[
95
+ [os.path.join(os.path.dirname(__file__), "samples/bears.jpg"), 1300, 950],
96
+ [
97
+ os.path.join(os.path.dirname(__file__), "samples/breast_cancer.png"),
98
+ 125,
99
+ 60,
100
+ ],
101
  ],
102
  inputs=[input_image, coord_h, coord_w],
103
  outputs=output,
models/{sam_vit_h_4b8939.pth → medsam_vitb.pth} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
3
- size 2564550879
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912
3
+ size 375042383
models/sam_vit_b_01ec64.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912
3
+ size 375042383
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
  opencv-python
2
  matplotlib
3
  gradio
 
4
  torch
5
  torchvision
 
1
  opencv-python
2
  matplotlib
3
  gradio
4
+ transformers
5
  torch
6
  torchvision
samples/breast_cancer.png ADDED
scripts/example.py CHANGED
@@ -16,6 +16,7 @@ device = "cpu"
16
  # Load model
17
  sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
18
  sam.to(device=device)
 
19
  predictor = SamPredictor(sam)
20
 
21
  # Preprocessing the image
 
16
  # Load model
17
  sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
18
  sam.to(device=device)
19
+
20
  predictor = SamPredictor(sam)
21
 
22
  # Preprocessing the image