rahulvenkk
commited on
Commit
·
4d601e2
1
Parent(s):
ee3d5e7
annot
Browse files- app.py +39 -5
- assets/intervention_test_images/annot.json +1 -0
app.py
CHANGED
@@ -17,6 +17,7 @@ dot_radius = 7 # Radius for the dots
|
|
17 |
dot_thickness = -1 # Thickness for solid circle (-1 fills the circle)
|
18 |
from PIL import Image
|
19 |
import torch
|
|
|
20 |
#load model
|
21 |
from cwm.model.model_factory import model_factory
|
22 |
|
@@ -141,9 +142,42 @@ with gr.Blocks() as demo:
|
|
141 |
def load_img(evt: gr.SelectData):
|
142 |
img_path = evt.value['image']['path']
|
143 |
img = np.array(Image.open(img_path))
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
# print(f"Image uploaded with shape: {input.shape}")
|
145 |
resized_img = resize_to_square(img)
|
146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
|
148 |
|
149 |
def store_img(img):
|
@@ -154,7 +188,7 @@ with gr.Blocks() as demo:
|
|
154 |
|
155 |
with gr.Row():
|
156 |
with gr.Column():
|
157 |
-
gallery = gr.Gallery( ["./assets/
|
158 |
# examples = gr.Examples(
|
159 |
# examples=[
|
160 |
# ["./assets/desk_1.jpg", "./assets/desk_1.jpg"],
|
@@ -228,13 +262,13 @@ with gr.Blocks() as demo:
|
|
228 |
# Draw arrow
|
229 |
|
230 |
# Draw dots at start and end points
|
231 |
-
cv2.circle(temp, start_point, dot_radius, color, dot_thickness)
|
232 |
-
cv2.circle(temp, end_point, dot_radius, color, dot_thickness)
|
233 |
|
234 |
# If there is an odd number of points (e.g., only a start point), draw a dot for it
|
235 |
if len(sel_pix) == 1:
|
236 |
start_point = sel_pix[0]
|
237 |
-
cv2.circle(temp, start_point, dot_radius, dot_color, dot_thickness)
|
238 |
|
239 |
return temp if isinstance(temp, np.ndarray) else np.array(temp)
|
240 |
|
|
|
17 |
dot_thickness = -1 # Thickness for solid circle (-1 fills the circle)
|
18 |
from PIL import Image
|
19 |
import torch
|
20 |
+
import json
|
21 |
#load model
|
22 |
from cwm.model.model_factory import model_factory
|
23 |
|
|
|
142 |
def load_img(evt: gr.SelectData):
|
143 |
img_path = evt.value['image']['path']
|
144 |
img = np.array(Image.open(img_path))
|
145 |
+
# print(f"Image uploaded with shape: {input.shape}")
|
146 |
+
with open('./assets/intervention_test_images/annot.json', 'r') as f:
|
147 |
+
points_json = json.load(f)
|
148 |
+
|
149 |
+
points_json = points_json[os.path.basename(img_path)]
|
150 |
+
|
151 |
# print(f"Image uploaded with shape: {input.shape}")
|
152 |
resized_img = resize_to_square(img)
|
153 |
+
|
154 |
+
temp = resized_img.copy()
|
155 |
+
|
156 |
+
# Redraw all remaining arrows and dots
|
157 |
+
for i in range(0, len(points_json), 2):
|
158 |
+
start_point = points_json[i]
|
159 |
+
end_point = points_json[i + 1]
|
160 |
+
if start_point == end_point:
|
161 |
+
# Zero-length vector: Draw a dot
|
162 |
+
color = dot_color_fixed
|
163 |
+
else:
|
164 |
+
cv2.arrowedLine(temp, start_point, end_point, arrow_color, thickness, tipLength=tip_length,
|
165 |
+
line_type=cv2.LINE_AA)
|
166 |
+
color = arrow_color
|
167 |
+
# Draw arrow
|
168 |
+
|
169 |
+
# Draw dots at start and end points
|
170 |
+
cv2.circle(temp, start_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)
|
171 |
+
cv2.circle(temp, end_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)
|
172 |
+
|
173 |
+
# If there is an odd number of points (e.g., only a start point), draw a dot for it
|
174 |
+
if len(points_json) == 1:
|
175 |
+
start_point = points_json[0]
|
176 |
+
cv2.circle(temp, start_point, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)
|
177 |
+
|
178 |
+
|
179 |
+
|
180 |
+
return temp, resized_img, img, points_json
|
181 |
|
182 |
|
183 |
def store_img(img):
|
|
|
188 |
|
189 |
with gr.Row():
|
190 |
with gr.Column():
|
191 |
+
gallery = gr.Gallery( ["./assets/ducks.jpg", "./assets/robot_arm.jpg", "./assets/bread.jpg", "./assets/bird.jpg", "./assets/desk_1.jpg", "./assets/glasses.jpg", "./assets/watering_pot.jpg"], columns=5, allow_preview=False, label="Select an example image to test")
|
192 |
# examples = gr.Examples(
|
193 |
# examples=[
|
194 |
# ["./assets/desk_1.jpg", "./assets/desk_1.jpg"],
|
|
|
262 |
# Draw arrow
|
263 |
|
264 |
# Draw dots at start and end points
|
265 |
+
cv2.circle(temp, start_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)
|
266 |
+
cv2.circle(temp, end_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)
|
267 |
|
268 |
# If there is an odd number of points (e.g., only a start point), draw a dot for it
|
269 |
if len(sel_pix) == 1:
|
270 |
start_point = sel_pix[0]
|
271 |
+
cv2.circle(temp, start_point, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)
|
272 |
|
273 |
return temp if isinstance(temp, np.ndarray) else np.array(temp)
|
274 |
|
assets/intervention_test_images/annot.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"bread.jpg": [[120, 257], [175, 269], [328, 375], [266, 353], [410, 217], [341, 248], [228, 149], [248, 211], [152, 129], [152, 129], [108, 51], [108, 51], [342, 39], [342, 39], [479, 93], [479, 93], [477, 390], [477, 390], [229, 486], [229, 486], [58, 442], [58, 442]]}
|