aagoluoglu committed on
Commit d5f906d · verified · 1 Parent(s): ace1b10

Update app.py

Files changed (1)
  1. app.py +82 -18
app.py CHANGED
@@ -51,6 +51,7 @@ app_ui = ui.page_fillable(
        ui.input_switch("show_margins", "Show marginal plots", value=True),
    ),
    ui.output_image("uploaded_image"),  # display the uploaded TIFF sidewalk tile image
+   ui.output_text("processed_output"),
    ui.output_ui("value_boxes"),
    ui.output_plot("scatter", fill=True),
    ui.help_text(
@@ -70,7 +71,29 @@ def tif_bytes_to_pil_image(tif_bytes):

    return image

+def load_model():
+   """Load the SAM model and processor, restoring fine-tuned weights from the checkpoint."""
+   # Load the model configuration
+   model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
+   processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
+
+   # Create an instance of the model architecture with the loaded configuration
+   model = SamModel(config=model_config)
+   # Update the model by loading the weights from the saved file
+   model_state_dict = torch.load(str(dir / "checkpoint.pth"), map_location=torch.device('cpu'))
+   model.load_state_dict(model_state_dict)
+
+   # Set the device to CUDA if available, otherwise use the CPU
+   device = "cuda" if torch.cuda.is_available() else "cpu"
+   model.to(device)
+
+   return model, processor
+
 def server(input: Inputs, output: Outputs, session: Session):
+
+   # Load the model and processor once
+   model, processor = load_model()
+
    @reactive.Calc
    def uploaded_image_path() -> str:
        """Returns the path to the uploaded image"""
@@ -88,34 +111,75 @@ def server(input: Inputs, output: Outputs, session: Session):
            return img
        else:
            return None  # Return None if no image is uploaded
-
+
+   @reactive.Calc
+   def generate_input_points():
+       """
+       input_points (torch.FloatTensor of shape (batch_size, num_points, 2)):
+       input 2D spatial points, used by the prompt encoder to encode the prompt;
+       generally yields much better results. The points can be obtained by passing a
+       list of list of lists to the processor, which will create corresponding torch
+       tensors of dimension 4. The first dimension is the image batch size, the second
+       dimension is the point batch size (i.e. how many segmentation masks the model
+       should predict per input point), the third dimension is the number of points per
+       segmentation mask (it is possible to pass multiple points for a single mask), and
+       the last dimension is the x (vertical) and y (horizontal) coordinates of the point.
+       If a different number of points is passed either for each image or for each mask,
+       the processor will create "PAD" points corresponding to the (0, 0) coordinate, and
+       the computation of the embedding will be skipped for these points using the labels.
+       """
+       # Define the size of the image array
+       array_size = 256
+
+       # Define the size of the grid
+       grid_size = 10
+
+       # Generate the grid points
+       x = np.linspace(0, array_size - 1, grid_size)
+       y = np.linspace(0, array_size - 1, grid_size)
+
+       # Generate a grid of coordinates
+       xv, yv = np.meshgrid(x, y)
+
+       # Convert the numpy arrays to lists
+       xv_list = xv.tolist()
+       yv_list = yv.tolist()
+
+       # Combine the x and y coordinates into a list of lists of lists
+       input_points = [[[int(x), int(y)] for x, y in zip(x_row, y_row)] for x_row, y_row in zip(xv_list, yv_list)]
+
+       # Reshape the n x n grid to the expected shape of the input_points tensor:
+       # (batch_size, point_batch_size, num_points_per_image, 2),
+       # where the last dimension of 2 holds the x and y coordinates of each point.
+       # batch_size: the number of images processed at once.
+       # point_batch_size: the number of point sets per image.
+       # num_points_per_image: the number of points in each set.
+       input_points = torch.tensor(input_points).view(1, 1, grid_size * grid_size, 2)
+
+       return input_points
+
    def process_image():
        """Processes the uploaded image, loads the model, and evaluates to get predictions"""
+
+       """ Get Image """
        # Load the uploaded image
        uploaded_image_bytes = input.tile_image()[0].read()

        # Convert the uploaded TIFF bytes to a PIL Image object
        uploaded_image = tif_bytes_to_pil_image(uploaded_image_bytes)

-       # Perform any preprocessing steps on the image as needed
-
-       # Example: Convert the image to the required input format for the model
-       # image_array = preprocess_image(uploaded_image)
-
-       # Load the model configuration
-       model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
-       processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
+       """ Prepare Inputs """
+       # Get the input points prompt (a grid of points)
+       input_points = generate_input_points()

-       # Create an instance of the model architecture with the loaded configuration
-       model = SamModel(config=model_config)
-       # Update the model by loading the weights from the saved file
-       model_state_dict = torch.load(str(dir / "checkpoint.pth"), map_location=torch.device('cpu'))
-       model.load_state_dict(model_state_dict)
-
-       # set the device to cuda if available, otherwise use cpu
-       device = "cuda" if torch.cuda.is_available() else "cpu"
-       model.to(device)
+       # Prepare the image and prompt for the model
+       inputs = processor(uploaded_image, input_points=input_points, return_tensors="pt")

+       # Remove the batch dimension which the processor adds by default
+       inputs = {k: v.squeeze(0) for k, v in inputs.items()}
+
+       """ Get Predictions """
        # Evaluate the image with the model
        # Example: predictions = model.predict(image_array)
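Note: the new ui.output_text("processed_output") slot needs a matching renderer registered inside server(), and none appears in the hunks above. A minimal sketch of what that renderer could look like, assuming render is imported from shiny and that process_image() is extended to return something printable (both are assumptions, not part of this commit):

    @render.text
    def processed_output() -> str:
        # Hypothetical renderer: the function name must match the output id.
        # Assumes process_image() returns a printable result.
        if not input.tile_image():
            return "Upload a TIFF tile to run the model."
        return f"Model output: {process_image()}"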
 
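The body of tif_bytes_to_pil_image() is elided above (only its return statement survives as diff context). A plausible reconstruction, assuming it simply decodes the bytes with Pillow through an in-memory buffer:

import io
from PIL import Image

def tif_bytes_to_pil_image(tif_bytes):
    """Decode raw TIFF bytes into a PIL Image (assumed implementation)."""
    image = Image.open(io.BytesIO(tif_bytes))
    # Convert to RGB so downstream processing sees 3 channels (assumption)
    image = image.convert("RGB")
    return image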
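load_model() reads its weights from dir / "checkpoint.pth", so dir is presumably a pathlib.Path bound elsewhere in app.py (it shadows the dir builtin, which works but is easy to trip over). A sketch of the kind of module-level definition this implies, with the name and location purely hypothetical:

from pathlib import Path

# Hypothetical definition; the real app.py must bind `dir` to the folder
# that contains checkpoint.pth for load_model() to succeed.
dir = Path(__file__).parent

Also worth noting: server() runs once per user session in Shiny, so calling load_model() there reloads the weights for every new connection; hoisting the call to module scope would load them once per process.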
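As a sanity check on the grid prompt built by generate_input_points(), the same logic run outside the reactive context should produce one batch containing a single set of 100 points spanning the 256 x 256 tile:

import numpy as np
import torch

array_size, grid_size = 256, 10
x = np.linspace(0, array_size - 1, grid_size)
y = np.linspace(0, array_size - 1, grid_size)
xv, yv = np.meshgrid(x, y)

points = [[[int(px), int(py)] for px, py in zip(x_row, y_row)]
          for x_row, y_row in zip(xv.tolist(), yv.tolist())]
input_points = torch.tensor(points).view(1, 1, grid_size * grid_size, 2)

print(input_points.shape)      # torch.Size([1, 1, 100, 2])
print(input_points[0, 0, 0])   # tensor([0, 0]), the first grid point
print(input_points[0, 0, -1])  # tensor([255, 255]), the last grid point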
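The """ Get Predictions """ step is still a placeholder at the end of this commit. A sketch of how it might be completed inside process_image() with transformers' SamModel, assuming the batch dimension stripped above is restored before the forward pass; the sigmoid-and-threshold post-processing follows the usual SAM fine-tuning recipe, and mask_prob / binary_mask are hypothetical names:

        # Restore the batch dimension and move tensors to the model's device
        pixel_values = inputs["pixel_values"].unsqueeze(0).to(model.device)
        input_points = inputs["input_points"].unsqueeze(0).to(model.device)

        model.eval()
        with torch.no_grad():
            # multimask_output=False requests a single mask per point prompt
            outputs = model(pixel_values=pixel_values,
                            input_points=input_points,
                            multimask_output=False)

        # Turn the predicted mask logits into probabilities, then a binary mask
        mask_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
        binary_mask = (mask_prob > 0.5).squeeze().cpu().numpy().astype(np.uint8)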