haritsahm commited on
Commit
84be203
·
1 Parent(s): 563d5aa

replace preprocess specific to medsam

Browse files
Files changed (1) hide show
  1. utils/utils.py +25 -2
utils/utils.py CHANGED
@@ -1,23 +1,46 @@
 
 
1
  import numpy as np
2
  import streamlit as st
3
  import torch
4
  from distinctipy import distinctipy
5
  from segment_anything import (SamAutomaticMaskGenerator, SamPredictor,
6
  sam_model_registry)
 
7
 
8
 
9
  def get_color():
10
  return distinctipy.get_colors(200)
11
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  @st.cache_resource
14
  def get_model(checkpoint='checkpoint/sam_vit_b_01ec64.pth'):
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
- build_sam = sam_model_registry['vit_b']
17
- model = build_sam(checkpoint=checkpoint)
 
 
 
 
 
18
  model = model.to(device)
19
  if torch.cuda.is_available():
20
  torch.cuda.empty_cache()
 
21
  predictor = SamPredictor(model)
22
  mask_generator = SamAutomaticMaskGenerator(model)
23
  return predictor, mask_generator
 
1
+ import types
2
+
3
  import numpy as np
4
  import streamlit as st
5
  import torch
6
  from distinctipy import distinctipy
7
  from segment_anything import (SamAutomaticMaskGenerator, SamPredictor,
8
  sam_model_registry)
9
+ from torch.nn import functional as F
10
 
11
 
12
  def get_color():
13
  return distinctipy.get_colors(200)
14
 
15
 
16
+ def medsam_preprocess(self, x: torch.Tensor) -> torch.Tensor:
17
+ """Normalize pixel values and pad to a square input."""
18
+ # Normalize colors
19
+ x = (x - x.min()) / torch.clip(
20
+ x.max() - x.min(), min=1e-8, max=None) # normalize to [0, 1], (H, W, 3)
21
+
22
+ # Pad
23
+ h, w = x.shape[-2:]
24
+ padh = self.image_encoder.img_size - h
25
+ padw = self.image_encoder.img_size - w
26
+ x = F.pad(x, (0, padw, 0, padh))
27
+ return x
28
+
29
+
30
  @st.cache_resource
31
  def get_model(checkpoint='checkpoint/sam_vit_b_01ec64.pth'):
32
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+ model = sam_model_registry['vit_b'](checkpoint=checkpoint)
34
+
35
+ # Replace preprocess function
36
+ funcType = types.MethodType
37
+ model.preprocess = funcType(medsam_preprocess, model)
38
+ model.mask_threshold = 0.5
39
+
40
  model = model.to(device)
41
  if torch.cuda.is_available():
42
  torch.cuda.empty_cache()
43
+
44
  predictor = SamPredictor(model)
45
  mask_generator = SamAutomaticMaskGenerator(model)
46
  return predictor, mask_generator