aagoluoglu commited on
Commit
ace1b10
·
verified ·
1 Parent(s): 9170044

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -15
app.py CHANGED
@@ -9,10 +9,12 @@ import shinyswatch
9
  from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
10
 
11
  import os
12
- # os.environ["TRANSFORMERS_CACHE"] = "./hf_cache"
13
  from transformers import SamModel, SamConfig, SamProcessor
14
  import torch
15
 
 
 
 
16
  sns.set_theme()
17
 
18
  dir = Path(__file__).resolve().parent
@@ -59,22 +61,16 @@ app_ui = ui.page_fillable(
59
  ),
60
  )
61
 
 
 
 
 
 
 
 
 
62
 
63
  def server(input: Inputs, output: Outputs, session: Session):
64
- # Load the model configuration
65
- model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
66
- processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
67
-
68
- # Create an instance of the model architecture with the loaded configuration
69
- model = SamModel(config=model_config)
70
- # Update the model by loading the weights from saved file
71
- model_state_dict = torch.load(str(dir / "checkpoint.pth"), map_location=torch.device('cpu'))
72
- model.load_state_dict(model_state_dict)
73
-
74
- # set the device to cuda if available, otherwise use cpu
75
- device = "cuda" if torch.cuda.is_available() else "cpu"
76
- model.to(device)
77
-
78
  @reactive.Calc
79
  def uploaded_image_path() -> str:
80
  """Returns the path to the uploaded image"""
@@ -92,6 +88,58 @@ def server(input: Inputs, output: Outputs, session: Session):
92
  return img
93
  else:
94
  return None # Return an empty string if no image is uploaded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  @reactive.Calc
97
  def filtered_df() -> pd.DataFrame:
 
9
  from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
10
 
11
  import os
 
12
  from transformers import SamModel, SamConfig, SamProcessor
13
  import torch
14
 
15
+ from PIL import Image
16
+ import io
17
+
18
  sns.set_theme()
19
 
20
  dir = Path(__file__).resolve().parent
 
61
  ),
62
  )
63
 
64
+ def tif_bytes_to_pil_image(tif_bytes):
65
+ # Create a BytesIO object from the TIFF bytes
66
+ bytes_io = io.BytesIO(tif_bytes)
67
+
68
+ # Open the BytesIO object as an Image
69
+ image = Image.open(bytes_io)
70
+
71
+ return image
72
 
73
  def server(input: Inputs, output: Outputs, session: Session):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  @reactive.Calc
75
  def uploaded_image_path() -> str:
76
  """Returns the path to the uploaded image"""
 
88
  return img
89
  else:
90
  return None # Return an empty string if no image is uploaded
91
+
92
+ def process_image():
93
+ """Processes the uploaded image, loads the model, and evaluates to get predictions"""
94
+ # Load the uploaded image
95
+ uploaded_image_bytes = input.tile_image()[0].read()
96
+
97
+ # Convert the uploaded TIFF bytes to a PIL Image object
98
+ uploaded_image = tif_bytes_to_pil_image(uploaded_image_bytes)
99
+
100
+ # Perform any preprocessing steps on the image as needed
101
+
102
+ # Example: Convert the image to the required input format for the model
103
+ # image_array = preprocess_image(uploaded_image)
104
+
105
+ # Load the model configuration
106
+ model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
107
+ processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
108
+
109
+ # Create an instance of the model architecture with the loaded configuration
110
+ model = SamModel(config=model_config)
111
+ # Update the model by loading the weights from saved file
112
+ model_state_dict = torch.load(str(dir / "checkpoint.pth"), map_location=torch.device('cpu'))
113
+ model.load_state_dict(model_state_dict)
114
+
115
+ # set the device to cuda if available, otherwise use cpu
116
+ device = "cuda" if torch.cuda.is_available() else "cpu"
117
+ model.to(device)
118
+
119
+ # Evaluate the image with the model
120
+ # Example: predictions = model.predict(image_array)
121
+
122
+ # Return the processed result (replace 'result' with the actual processed result)
123
+ return "Processed result"
124
+
125
+ @reactive.Calc
126
+ def processed_result():
127
+ """Processes the image when uploaded"""
128
+ if input.tile_image() is not None:
129
+ return process_image()
130
+ else:
131
+ return None
132
+
133
+ @output
134
+ @render.text
135
+ def processed_output():
136
+ """Displays the predictions of the uploaded image"""
137
+ return processed_result()
138
+
139
+
140
+
141
+
142
+
143
 
144
  @reactive.Calc
145
  def filtered_df() -> pd.DataFrame: