Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -116,7 +116,7 @@ def server(input: Inputs, output: Outputs, session: Session):
|
|
116 |
return None # Return an empty string if no image is uploaded
|
117 |
|
118 |
@reactive.Calc
|
119 |
-
def generate_input_points():
|
120 |
"""
|
121 |
input_points (torch.FloatTensor of shape (batch_size, num_points, 2)) —
|
122 |
Input 2D spatial points, this is used by the prompt encoder to encode the prompt.
|
@@ -132,8 +132,8 @@ def server(input: Inputs, output: Outputs, session: Session):
|
|
132 |
embedding will be skipped for these points using the labels.
|
133 |
|
134 |
"""
|
135 |
-
# Define the size of your array
|
136 |
-
array_size = 256
|
137 |
|
138 |
# Define the size of your grid
|
139 |
grid_size = 10
|
@@ -177,7 +177,8 @@ def server(input: Inputs, output: Outputs, session: Session):
|
|
177 |
|
178 |
""" Prepare Inputs """
|
179 |
# get input points prompt (grid of points)
|
180 |
-
|
|
|
181 |
|
182 |
# prepare image and prompt for the model
|
183 |
inputs = processor(image, input_points=input_points, return_tensors="pt")
|
|
|
116 |
return None # Return an empty string if no image is uploaded
|
117 |
|
118 |
@reactive.Calc
|
119 |
+
def generate_input_points(array_size = 256):
|
120 |
"""
|
121 |
input_points (torch.FloatTensor of shape (batch_size, num_points, 2)) —
|
122 |
Input 2D spatial points, this is used by the prompt encoder to encode the prompt.
|
|
|
132 |
embedding will be skipped for these points using the labels.
|
133 |
|
134 |
"""
|
135 |
+
# # Define the size of your array
|
136 |
+
# array_size = 256
|
137 |
|
138 |
# Define the size of your grid
|
139 |
grid_size = 10
|
|
|
177 |
|
178 |
""" Prepare Inputs """
|
179 |
# get input points prompt (grid of points)
|
180 |
+
array_size = max(image.shape[0], image.shape[1])
|
181 |
+
input_points = generate_input_points(array_size)
|
182 |
|
183 |
# prepare image and prompt for the model
|
184 |
inputs = processor(image, input_points=input_points, return_tensors="pt")
|