move inference_states out of gr.State
app.py
CHANGED
@@ -73,6 +73,7 @@ OBJ_ID = 0
 sam2_checkpoint = "checkpoints/edgetam.pt"
 model_cfg = "edgetam.yaml"
 predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
+global_inference_states = {}
 
 
 def get_video_fps(video_path):
@@ -89,15 +90,17 @@ def get_video_fps(video_path):
     return fps
 
 
 def reset(session_state):
     predictor.to("cpu")
     session_state["input_points"] = []
     session_state["input_labels"] = []
-    if session_state["inference_state"] is not None:
-        predictor.reset_state(session_state["inference_state"])
+
+    session_id = id(session_state)
+    if global_inference_states[session_id] is not None:
+        predictor.reset_state(global_inference_states[session_id])
     session_state["first_frame"] = None
     session_state["all_frames"] = None
-    session_state["inference_state"] = None
+    global_inference_states[session_id] = None
     return (
         None,
         gr.update(open=True),
@@ -112,8 +115,9 @@ def clear_points(session_state):
     predictor.to("cpu")
     session_state["input_points"] = []
     session_state["input_labels"] = []
-    if session_state["inference_state"]["tracking_has_started"]:
-        predictor.reset_state(session_state["inference_state"])
+    session_id = id(session_state)
+    if global_inference_states[session_id]["tracking_has_started"]:
+        predictor.reset_state(global_inference_states[session_id])
     return (
         session_state["first_frame"],
         None,
@@ -168,7 +172,9 @@ def preprocess_video_in(video_path, session_state):
     session_state["first_frame"] = copy.deepcopy(first_frame)
     session_state["all_frames"] = all_frames
 
-    session_state["inference_state"] = predictor.init_state(video_path=video_path)
+    session_id = id(session_state)
+    global_inference_states[session_id] = predictor.init_state(video_path=video_path)
+
     session_state["input_points"] = []
     session_state["input_labels"] = []
 
@@ -230,8 +236,9 @@ def segment_with_points(
     points = np.array(session_state["input_points"], dtype=np.float32)
     # for labels, `1` means positive click and `0` means negative click
     labels = np.array(session_state["input_labels"], np.int32)
+    session_id = id(session_state)
     _, _, out_mask_logits = predictor.add_new_points(
-        inference_state=session_state["inference_state"],
+        inference_state=global_inference_states[session_id],
         frame_idx=0,
         obj_id=OBJ_ID,
         points=points,
@@ -270,10 +277,11 @@ def propagate_to_all(
     torch.backends.cuda.matmul.allow_tf32 = True
     torch.backends.cudnn.allow_tf32 = True
     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        session_id = id(session_state)
         if (
             len(session_state["input_points"]) == 0
             or video_in is None
-            or session_state["inference_state"] is None
+            or global_inference_states[session_id] is None
         ):
             return (
                 None,
@@ -286,7 +294,7 @@ def propagate_to_all(
         )  # video_segments contains the per-frame segmentation results
         print("starting propagate_in_video")
         for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
-            session_state["inference_state"]
+            global_inference_states[session_id]
         ):
             video_segments[out_frame_idx] = {
                 out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
@@ -340,7 +348,6 @@ with gr.Blocks() as demo:
             "all_frames": None,
             "input_points": [],
             "input_labels": [],
-            "inference_state": None,
         }
     )
 
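For reference, the pattern this diff adopts can be reduced to a short sketch: the heavyweight object (here the SAM2 inference state) lives in a module-level dict keyed by id() of the per-session gr.State value, while gr.State itself keeps only small entries such as click lists. The names global_heavy_states, init_session, and step below are illustrative and not part of app.py; the sketch also assumes, as the diff does, that Gradio passes the same per-session state object to every event handler, so id(session_state) is a stable session key.

# Illustrative sketch only -- global_heavy_states, init_session, and step are
# stand-in names, not taken from app.py.
import gradio as gr

global_heavy_states = {}  # plays the role of global_inference_states in the diff


def init_session(session_state):
    # Stand-in for predictor.init_state(...): any expensive, hard-to-copy object.
    global_heavy_states[id(session_state)] = {"frames_seen": 0}
    return "initialized"


def step(session_state):
    heavy = global_heavy_states.get(id(session_state))
    if heavy is None:
        return "click init first"
    heavy["frames_seen"] += 1
    return f"frames_seen={heavy['frames_seen']}"


with gr.Blocks() as demo:
    # Lightweight, copyable per-session data stays in gr.State, as in app.py.
    session_state = gr.State({"input_points": [], "input_labels": []})
    status = gr.Textbox(label="status")
    gr.Button("init").click(init_session, inputs=[session_state], outputs=[status])
    gr.Button("step").click(step, inputs=[session_state], outputs=[status])

if __name__ == "__main__":
    demo.launch()

Each browser session gets its own copy of the gr.State value, so each session also gets its own id() and therefore its own slot in the module-level dict. Note that the diff sets a session's entry to None in reset but adds no cleanup path when a session ends, so entries persist until the process restarts.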