Spaces:
Runtime error
Runtime error
Commit
·
94dd0a9
1
Parent(s):
39c26fe
add duplicate space
Browse files
app.py
CHANGED
@@ -13,9 +13,13 @@ import requests
|
|
13 |
import json
|
14 |
import torchvision
|
15 |
import torch
|
|
|
|
|
16 |
from tools.painter import mask_painter
|
17 |
-
|
18 |
-
|
|
|
|
|
19 |
|
20 |
# download checkpoints
|
21 |
def download_checkpoint(url, folder, filename):
|
@@ -202,6 +206,7 @@ def show_mask(video_state, interactive_state, mask_dropdown):
|
|
202 |
|
203 |
# tracking vos
|
204 |
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
|
|
205 |
model.xmem.clear_memory()
|
206 |
if interactive_state["track_end_number"]:
|
207 |
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
|
@@ -221,6 +226,8 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
|
221 |
template_mask = video_state["masks"][video_state["select_frame_number"]]
|
222 |
fps = video_state["fps"]
|
223 |
masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
|
|
|
|
|
224 |
|
225 |
if interactive_state["track_end_number"]:
|
226 |
video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
|
@@ -260,6 +267,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
|
260 |
|
261 |
# inpaint
|
262 |
def inpaint_video(video_state, interactive_state, mask_dropdown):
|
|
|
263 |
frames = np.asarray(video_state["origin_images"])
|
264 |
fps = video_state["fps"]
|
265 |
inpaint_masks = np.asarray(video_state["masks"])
|
@@ -306,27 +314,44 @@ def generate_video_from_frames(frames, output_path, fps=30):
|
|
306 |
torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
|
307 |
return output_path
|
308 |
|
|
|
|
|
|
|
|
|
309 |
# check and download checkpoints if needed
|
310 |
-
|
311 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
312 |
xmem_checkpoint = "XMem-s012.pth"
|
313 |
xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
|
314 |
e2fgvi_checkpoint = "E2FGVI-HQ-CVPR22.pth"
|
315 |
e2fgvi_checkpoint_id = "10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3"
|
316 |
|
|
|
317 |
folder ="./checkpoints"
|
318 |
-
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder,
|
319 |
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
|
320 |
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
|
321 |
-
|
322 |
-
args = parse_augment()
|
323 |
-
# args.port = 12315
|
324 |
-
# args.device = "cuda:2"
|
325 |
-
# args.mask_save = True
|
326 |
|
327 |
# initialize sam, xmem, e2fgvi models
|
328 |
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)
|
329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
330 |
with gr.Blocks() as iface:
|
331 |
"""
|
332 |
state for
|
@@ -358,7 +383,8 @@ with gr.Blocks() as iface:
|
|
358 |
"fps": 30
|
359 |
}
|
360 |
)
|
361 |
-
|
|
|
362 |
with gr.Row():
|
363 |
|
364 |
# for user video input
|
@@ -367,7 +393,7 @@ with gr.Blocks() as iface:
|
|
367 |
video_input = gr.Video(autosize=True)
|
368 |
with gr.Column():
|
369 |
video_info = gr.Textbox()
|
370 |
-
|
371 |
Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.")
|
372 |
resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
|
373 |
|
|
|
13 |
import json
|
14 |
import torchvision
|
15 |
import torch
|
16 |
+
from tools.interact_tools import SamControler
|
17 |
+
from tracker.base_tracker import BaseTracker
|
18 |
from tools.painter import mask_painter
|
19 |
+
try:
|
20 |
+
from mmcv.cnn import ConvModule
|
21 |
+
except:
|
22 |
+
os.system("mim install mmcv")
|
23 |
|
24 |
# download checkpoints
|
25 |
def download_checkpoint(url, folder, filename):
|
|
|
206 |
|
207 |
# tracking vos
|
208 |
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
209 |
+
|
210 |
model.xmem.clear_memory()
|
211 |
if interactive_state["track_end_number"]:
|
212 |
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
|
|
|
226 |
template_mask = video_state["masks"][video_state["select_frame_number"]]
|
227 |
fps = video_state["fps"]
|
228 |
masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
|
229 |
+
# clear GPU memory
|
230 |
+
model.xmem.clear_memory()
|
231 |
|
232 |
if interactive_state["track_end_number"]:
|
233 |
video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
|
|
|
267 |
|
268 |
# inpaint
|
269 |
def inpaint_video(video_state, interactive_state, mask_dropdown):
|
270 |
+
|
271 |
frames = np.asarray(video_state["origin_images"])
|
272 |
fps = video_state["fps"]
|
273 |
inpaint_masks = np.asarray(video_state["masks"])
|
|
|
314 |
torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
|
315 |
return output_path
|
316 |
|
317 |
+
|
318 |
+
# args, defined in track_anything.py
|
319 |
+
args = parse_augment()
|
320 |
+
|
321 |
# check and download checkpoints if needed
|
322 |
+
SAM_checkpoint_dict = {
|
323 |
+
'vit_h': "sam_vit_h_4b8939.pth",
|
324 |
+
'vit_l': "sam_vit_l_0b3195.pth",
|
325 |
+
"vit_b": "sam_vit_b_01ec64.pth"
|
326 |
+
}
|
327 |
+
SAM_checkpoint_url_dict = {
|
328 |
+
'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
|
329 |
+
'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
|
330 |
+
'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
|
331 |
+
}
|
332 |
+
sam_checkpoint = SAM_checkpoint_dict[args.sam_model_type]
|
333 |
+
sam_checkpoint_url = SAM_checkpoint_url_dict[args.sam_model_type]
|
334 |
xmem_checkpoint = "XMem-s012.pth"
|
335 |
xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
|
336 |
e2fgvi_checkpoint = "E2FGVI-HQ-CVPR22.pth"
|
337 |
e2fgvi_checkpoint_id = "10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3"
|
338 |
|
339 |
+
|
340 |
folder ="./checkpoints"
|
341 |
+
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
|
342 |
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
|
343 |
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
|
344 |
+
|
|
|
|
|
|
|
|
|
345 |
|
346 |
# initialize sam, xmem, e2fgvi models
|
347 |
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)
|
348 |
|
349 |
+
|
350 |
+
title = """<p><h1 align="center">Track-Anything</h1></p>
|
351 |
+
"""
|
352 |
+
description = """<p>Gradio demo for Track Anything, a flexible and interactive tool for video object tracking, segmentation, and inpainting. I To use it, simply upload your video, or click one of the examples to load them. Code: <a href="https://github.com/gaomingqi/Track-Anything">https://github.com/gaomingqi/Track-Anything</a> <a href="https://huggingface.co/spaces/watchtowerss/Track-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
|
353 |
+
|
354 |
+
|
355 |
with gr.Blocks() as iface:
|
356 |
"""
|
357 |
state for
|
|
|
383 |
"fps": 30
|
384 |
}
|
385 |
)
|
386 |
+
gr.Markdown(title)
|
387 |
+
gr.Markdown(description)
|
388 |
with gr.Row():
|
389 |
|
390 |
# for user video input
|
|
|
393 |
video_input = gr.Video(autosize=True)
|
394 |
with gr.Column():
|
395 |
video_info = gr.Textbox()
|
396 |
+
resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to download and use a machine with more VRAM locally. \
|
397 |
Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.")
|
398 |
resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
|
399 |
|