Gladiator commited on
Commit
51f706a
·
1 Parent(s): 271c748

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -5,20 +5,16 @@ from PIL import Image
5
  from cellpose import models
6
 
7
 
8
- @st.cache()
9
- def load_model(model_path):
10
  inf_model = models.CellposeModel(gpu=False, pretrained_model=model_path)
11
- return inf_model
 
12
 
13
 
14
  if __name__ == "__main__":
15
 
16
  st.title("Sartorius Neuronal Cell Segmentation")
17
 
18
- inf_model = load_model(
19
- model_path="./cellpose_residual_on_style_on_concatenation_off_fold1_ep_649_cv_0.2834"
20
- )
21
-
22
  uploaded_img = st.file_uploader(label="Upload neuronal cell image")
23
 
24
  with st.expander("View input image"):
@@ -40,14 +36,18 @@ if __name__ == "__main__":
40
  "resample": True,
41
  }
42
  with st.spinner("Performing segmentation. This might take a while..."):
43
- preds, flows, _ = inf_model.eval([img], **model_params)
 
 
 
 
44
 
45
  fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
46
  ax1.axis("off")
47
  ax2.axis("off")
48
  ax3.axis("off")
49
  ax1.set_title("Original Image")
50
- ax1.imshow(img)
51
  ax2.set_title("Segmented image")
52
  ax2.imshow(preds[0])
53
  ax3.set_title("Image flows")
 
5
  from cellpose import models
6
 
7
 
8
+ def inference(img, model_path, **model_params):
 
9
  inf_model = models.CellposeModel(gpu=False, pretrained_model=model_path)
10
+ preds, flows, _ = inf_model.eval([img], **model_params)
11
+ return preds, flows
12
 
13
 
14
  if __name__ == "__main__":
15
 
16
  st.title("Sartorius Neuronal Cell Segmentation")
17
 
 
 
 
 
18
  uploaded_img = st.file_uploader(label="Upload neuronal cell image")
19
 
20
  with st.expander("View input image"):
 
36
  "resample": True,
37
  }
38
  with st.spinner("Performing segmentation. This might take a while..."):
39
+ preds, flows = inference(
40
+ img=img,
41
+ model_path="./cellpose_residual_on_style_on_concatenation_off_fold1_ep_649_cv_0.2834",
42
+ **model_params
43
+ )
44
 
45
  fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
46
  ax1.axis("off")
47
  ax2.axis("off")
48
  ax3.axis("off")
49
  ax1.set_title("Original Image")
50
+ ax1.imshow(img, cmap="gray")
51
  ax2.set_title("Segmented image")
52
  ax2.imshow(preds[0])
53
  ax3.set_title("Image flows")