aagoluoglu commited on
Commit
6fc6444
·
verified ·
1 Parent(s): 8f831ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -116,7 +116,7 @@ def server(input: Inputs, output: Outputs, session: Session):
116
  return None # Return an empty string if no image is uploaded
117
 
118
  @reactive.Calc
119
- def generate_input_points():
120
  """
121
  input_points (torch.FloatTensor of shape (batch_size, num_points, 2)) —
122
  Input 2D spatial points, this is used by the prompt encoder to encode the prompt.
@@ -132,8 +132,8 @@ def server(input: Inputs, output: Outputs, session: Session):
132
  embedding will be skipped for these points using the labels.
133
 
134
  """
135
- # Define the size of your array
136
- array_size = 256
137
 
138
  # Define the size of your grid
139
  grid_size = 10
@@ -177,7 +177,8 @@ def server(input: Inputs, output: Outputs, session: Session):
177
 
178
  """ Prepare Inputs """
179
  # get input points prompt (grid of points)
180
- input_points = generate_input_points(image)
 
181
 
182
  # prepare image and prompt for the model
183
  inputs = processor(image, input_points=input_points, return_tensors="pt")
 
116
  return None # Return an empty string if no image is uploaded
117
 
118
  @reactive.Calc
119
+ def generate_input_points(array_size = 256):
120
  """
121
  input_points (torch.FloatTensor of shape (batch_size, num_points, 2)) —
122
  Input 2D spatial points, this is used by the prompt encoder to encode the prompt.
 
132
  embedding will be skipped for these points using the labels.
133
 
134
  """
135
+ # # Define the size of your array
136
+ # array_size = 256
137
 
138
  # Define the size of your grid
139
  grid_size = 10
 
177
 
178
  """ Prepare Inputs """
179
  # get input points prompt (grid of points)
180
+ array_size = max(image.shape[0], image.shape[1])
181
+ input_points = generate_input_points(array_size)
182
 
183
  # prepare image and prompt for the model
184
  inputs = processor(image, input_points=input_points, return_tensors="pt")