Nguyen Thai Thao Uyen commited on
Commit
7fee2e3
·
1 Parent(s): 26112df

Update file format

Browse files
Files changed (2) hide show
  1. app.py +3 -7
  2. predictor.py +53 -0
app.py CHANGED
@@ -5,7 +5,7 @@ import matplotlib.pyplot as plt
5
  import pandas as pd
6
  import seaborn as sns
7
  import shinyswatch
8
- import run
9
  import PIL
10
 
11
  from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
@@ -22,11 +22,7 @@ app_ui = ui.page_fillable(
22
  ui.sidebar(
23
  ui.input_file("image_input", "Upload image: ", multiple=True),
24
  ),
25
- ui.output_image("image"),
26
- # ui.output_image("image_output"),
27
- ui.output_plot("plot"),
28
- # ui.output_image("prediction"),
29
- # ui.output_image("prob")
30
  ),
31
  )
32
 
@@ -48,7 +44,7 @@ def server(input: Inputs, output: Outputs, session: Session):
48
  if input.image_input():
49
  new_image = input.image_input()[0]['datapath']
50
 
51
- pred_prob, pred_prediction = run.pred(new_image)
52
 
53
  print("plotting...")
54
  fig, axes = plt.subplots(1, 2, figsize=(15, 5))
 
5
  import pandas as pd
6
  import seaborn as sns
7
  import shinyswatch
8
+ import predictor
9
  import PIL
10
 
11
  from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
 
22
  ui.sidebar(
23
  ui.input_file("image_input", "Upload image: ", multiple=True),
24
  ),
25
+ ui.output_plot("plot")
 
 
 
 
26
  ),
27
  )
28
 
 
44
  if input.image_input():
45
  new_image = input.image_input()[0]['datapath']
46
 
47
+ pred_prob, pred_prediction = predictor.pred(new_image)
48
 
49
  print("plotting...")
50
  fig, axes = plt.subplots(1, 2, figsize=(15, 5))
predictor.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import SamModel, SamConfig, SamProcessor
2
+ import torch
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import app
6
+ import os
7
+ import json
8
+ from PIL import Image
9
+
10
+ def pred(src):
11
+ # -- cache
12
+ cache_dir = "/code/cache"
13
+
14
+ # -- load model configuration
15
+ MODEL_FILE = "sam_model.pth"
16
+ model_config = SamConfig.from_pretrained("facebook/sam-vit-base", cache_dir=cache_dir)
17
+
18
+ model = SamModel(config=model_config)
19
+ model.load_state_dict(torch.load(MODEL_FILE, map_location=torch.device('cpu')))
20
+
21
+ with open("sam-config.json", "r") as f: # modified config json file
22
+ modified_config_dict = json.load(f)
23
+
24
+ processor = SamProcessor.from_pretrained("facebook/sam-vit-base",
25
+ **modified_config_dict,
26
+ cache_dir=cache_dir)
27
+
28
+ # -- process image
29
+ image = Image.open(src)
30
+ rgbim = image.convert("RGB")
31
+ new_image = np.array(rgbim)
32
+ print()
33
+ print("image shape:",new_image.shape)
34
+
35
+ inputs = processor(new_image, return_tensors="pt")
36
+ model.eval()
37
+
38
+ # forward pass
39
+ print("predicting...")
40
+ with torch.no_grad():
41
+ outputs = model(pixel_values=inputs["pixel_values"],
42
+ multimask_output=False)
43
+
44
+ # apply sigmoid
45
+ print("apply sigmoid...")
46
+ pred_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
47
+
48
+ # convert soft mask to hard mask
49
+ PROBABILITY_THRES = 0.30
50
+ pred_prob = pred_prob.cpu().numpy().squeeze()
51
+ pred_prediction = (pred_prob > PROBABILITY_THRES).astype(np.uint8)
52
+
53
+ return pred_prob, pred_prediction