aagoluoglu commited on
Commit
65d4d46
·
verified ·
1 Parent(s): b4db21f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -8
app.py CHANGED
@@ -116,7 +116,7 @@ def server(input: Inputs, output: Outputs, session: Session):
116
  return None # Return an empty string if no image is uploaded
117
 
118
  @reactive.Calc
119
- def generate_input_points(array_size = 256):
120
  """
121
  input_points (torch.FloatTensor of shape (batch_size, num_points, 2)) —
122
  Input 2D spatial points, this is used by the prompt encoder to encode the prompt.
@@ -132,11 +132,9 @@ def server(input: Inputs, output: Outputs, session: Session):
132
  embedding will be skipped for these points using the labels.
133
 
134
  """
135
- # # Define the size of your array
136
- # array_size = 256
137
-
138
- # Define the size of your grid
139
- grid_size = 10
140
 
141
  # Generate the grid points
142
  x = np.linspace(0, array_size-1, grid_size)
@@ -177,8 +175,7 @@ def server(input: Inputs, output: Outputs, session: Session):
177
 
178
  """ Prepare Inputs """
179
  # get input points prompt (grid of points)
180
- array_size = max(image.size)
181
- input_points = generate_input_points(array_size)
182
 
183
  # prepare image and prompt for the model
184
  inputs = processor(image, input_points=input_points, return_tensors="pt")
 
116
  return None # Return an empty string if no image is uploaded
117
 
118
  @reactive.Calc
119
+ def generate_input_points(image, grid_size=10):
120
  """
121
  input_points (torch.FloatTensor of shape (batch_size, num_points, 2)) —
122
  Input 2D spatial points, this is used by the prompt encoder to encode the prompt.
 
132
  embedding will be skipped for these points using the labels.
133
 
134
  """
135
+
136
+ # Get the dimensions of the image
137
+ array_size = max(image.width, image.height)
 
 
138
 
139
  # Generate the grid points
140
  x = np.linspace(0, array_size-1, grid_size)
 
175
 
176
  """ Prepare Inputs """
177
  # get input points prompt (grid of points)
178
+ input_points = generate_input_points(image)
 
179
 
180
  # prepare image and prompt for the model
181
  inputs = processor(image, input_points=input_points, return_tensors="pt")