Spaces:
Runtime error
Runtime error
watchtowerss
commited on
Commit
·
53a8438
1
Parent(s):
9b9cd68
requirement fix
Browse files- app.py +38 -8
- assets/avengers.gif +2 -2
- inpainter/.DS_Store +0 -0
- inpainter/base_inpainter.py +6 -4
- inpainter/model/modules/tfocal_transformer_hq.py +2 -0
- requirements.txt +2 -5
- track_anything.py +10 -6
- tracker/.DS_Store +0 -0
- tracker/base_tracker.py +1 -0
app.py
CHANGED
@@ -13,7 +13,13 @@ import requests
|
|
13 |
import json
|
14 |
import torchvision
|
15 |
import torch
|
|
|
|
|
16 |
from tools.painter import mask_painter
|
|
|
|
|
|
|
|
|
17 |
|
18 |
# download checkpoints
|
19 |
def download_checkpoint(url, folder, filename):
|
@@ -200,6 +206,7 @@ def show_mask(video_state, interactive_state, mask_dropdown):
|
|
200 |
|
201 |
# tracking vos
|
202 |
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
|
|
203 |
model.xmem.clear_memory()
|
204 |
if interactive_state["track_end_number"]:
|
205 |
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
|
@@ -219,6 +226,8 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
|
219 |
template_mask = video_state["masks"][video_state["select_frame_number"]]
|
220 |
fps = video_state["fps"]
|
221 |
masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
|
|
|
|
|
222 |
|
223 |
if interactive_state["track_end_number"]:
|
224 |
video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
|
@@ -258,6 +267,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
|
258 |
|
259 |
# inpaint
|
260 |
def inpaint_video(video_state, interactive_state, mask_dropdown):
|
|
|
261 |
frames = np.asarray(video_state["origin_images"])
|
262 |
fps = video_state["fps"]
|
263 |
inpaint_masks = np.asarray(video_state["masks"])
|
@@ -304,20 +314,33 @@ def generate_video_from_frames(frames, output_path, fps=30):
|
|
304 |
torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
|
305 |
return output_path
|
306 |
|
|
|
|
|
|
|
|
|
307 |
# check and download checkpoints if needed
|
308 |
-
|
309 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
310 |
xmem_checkpoint = "XMem-s012.pth"
|
311 |
xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
|
312 |
e2fgvi_checkpoint = "E2FGVI-HQ-CVPR22.pth"
|
313 |
e2fgvi_checkpoint_id = "10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3"
|
314 |
|
|
|
315 |
folder ="./checkpoints"
|
316 |
-
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder,
|
317 |
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
|
318 |
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
|
319 |
-
# args, defined in track_anything.py
|
320 |
-
args = parse_augment()
|
321 |
# args.port = 12315
|
322 |
# args.device = "cuda:2"
|
323 |
# args.mask_save = True
|
@@ -325,6 +348,12 @@ args = parse_augment()
|
|
325 |
# initialize sam, xmem, e2fgvi models
|
326 |
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)
|
327 |
|
|
|
|
|
|
|
|
|
|
|
|
|
328 |
with gr.Blocks() as iface:
|
329 |
"""
|
330 |
state for
|
@@ -356,7 +385,8 @@ with gr.Blocks() as iface:
|
|
356 |
"fps": 30
|
357 |
}
|
358 |
)
|
359 |
-
|
|
|
360 |
with gr.Row():
|
361 |
|
362 |
# for user video input
|
@@ -365,7 +395,7 @@ with gr.Blocks() as iface:
|
|
365 |
video_input = gr.Video(autosize=True)
|
366 |
with gr.Column():
|
367 |
video_info = gr.Textbox()
|
368 |
-
|
369 |
Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.")
|
370 |
resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
|
371 |
|
@@ -534,7 +564,7 @@ with gr.Blocks() as iface:
|
|
534 |
# cache_examples=True,
|
535 |
)
|
536 |
iface.queue(concurrency_count=1)
|
537 |
-
iface.launch(debug=True
|
538 |
|
539 |
|
540 |
|
|
|
13 |
import json
|
14 |
import torchvision
|
15 |
import torch
|
16 |
+
from tools.interact_tools import SamControler
|
17 |
+
from tracker.base_tracker import BaseTracker
|
18 |
from tools.painter import mask_painter
|
19 |
+
try:
|
20 |
+
from mmcv.cnn import ConvModule
|
21 |
+
except:
|
22 |
+
os.system("mim install mmcv")
|
23 |
|
24 |
# download checkpoints
|
25 |
def download_checkpoint(url, folder, filename):
|
|
|
206 |
|
207 |
# tracking vos
|
208 |
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
209 |
+
|
210 |
model.xmem.clear_memory()
|
211 |
if interactive_state["track_end_number"]:
|
212 |
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
|
|
|
226 |
template_mask = video_state["masks"][video_state["select_frame_number"]]
|
227 |
fps = video_state["fps"]
|
228 |
masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
|
229 |
+
# clear GPU memory
|
230 |
+
model.xmem.clear_memory()
|
231 |
|
232 |
if interactive_state["track_end_number"]:
|
233 |
video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
|
|
|
267 |
|
268 |
# inpaint
|
269 |
def inpaint_video(video_state, interactive_state, mask_dropdown):
|
270 |
+
|
271 |
frames = np.asarray(video_state["origin_images"])
|
272 |
fps = video_state["fps"]
|
273 |
inpaint_masks = np.asarray(video_state["masks"])
|
|
|
314 |
torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
|
315 |
return output_path
|
316 |
|
317 |
+
|
318 |
+
# args, defined in track_anything.py
|
319 |
+
args = parse_augment()
|
320 |
+
|
321 |
# check and download checkpoints if needed
|
322 |
+
SAM_checkpoint_dict = {
|
323 |
+
'vit_h': "sam_vit_h_4b8939.pth",
|
324 |
+
'vit_l': "sam_vit_l_0b3195.pth",
|
325 |
+
"vit_b": "sam_vit_b_01ec64.pth"
|
326 |
+
}
|
327 |
+
SAM_checkpoint_url_dict = {
|
328 |
+
'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
|
329 |
+
'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
|
330 |
+
'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
|
331 |
+
}
|
332 |
+
sam_checkpoint = SAM_checkpoint_dict[args.sam_model_type]
|
333 |
+
sam_checkpoint_url = SAM_checkpoint_url_dict[args.sam_model_type]
|
334 |
xmem_checkpoint = "XMem-s012.pth"
|
335 |
xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
|
336 |
e2fgvi_checkpoint = "E2FGVI-HQ-CVPR22.pth"
|
337 |
e2fgvi_checkpoint_id = "10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3"
|
338 |
|
339 |
+
|
340 |
folder ="./checkpoints"
|
341 |
+
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
|
342 |
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
|
343 |
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
|
|
|
|
|
344 |
# args.port = 12315
|
345 |
# args.device = "cuda:2"
|
346 |
# args.mask_save = True
|
|
|
348 |
# initialize sam, xmem, e2fgvi models
|
349 |
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)
|
350 |
|
351 |
+
|
352 |
+
title = """<p><h1 align="center">Track-Anything</h1></p>
|
353 |
+
"""
|
354 |
+
description = """<p>Gradio demo for Track Anything, a flexible and interactive tool for video object tracking, segmentation, and inpainting. I To use it, simply upload your video, or click one of the examples to load them. Code: <a href="https://github.com/gaomingqi/Track-Anything">https://github.com/gaomingqi/Track-Anything</a> <a href="https://huggingface.co/spaces/watchtowerss/Track-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
|
355 |
+
|
356 |
+
|
357 |
with gr.Blocks() as iface:
|
358 |
"""
|
359 |
state for
|
|
|
385 |
"fps": 30
|
386 |
}
|
387 |
)
|
388 |
+
gr.Markdown(title)
|
389 |
+
gr.Markdown(description)
|
390 |
with gr.Row():
|
391 |
|
392 |
# for user video input
|
|
|
395 |
video_input = gr.Video(autosize=True)
|
396 |
with gr.Column():
|
397 |
video_info = gr.Textbox()
|
398 |
+
resize_info = gr.Textbox(value="Due to server restrictions, please upload a video that is no longer than 2 minutes. If you want to use the inpaint function, it is best to download and use a machine with more VRAM locally. \
|
399 |
Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.")
|
400 |
resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
|
401 |
|
|
|
564 |
# cache_examples=True,
|
565 |
)
|
566 |
iface.queue(concurrency_count=1)
|
567 |
+
iface.launch(debug=True)
|
568 |
|
569 |
|
570 |
|
assets/avengers.gif
CHANGED
Git LFS Details
|
Git LFS Details
|
inpainter/.DS_Store
CHANGED
Binary files a/inpainter/.DS_Store and b/inpainter/.DS_Store differ
|
|
inpainter/base_inpainter.py
CHANGED
@@ -7,6 +7,8 @@ import yaml
|
|
7 |
import cv2
|
8 |
import importlib
|
9 |
import numpy as np
|
|
|
|
|
10 |
from inpainter.util.tensor_util import resize_frames, resize_masks
|
11 |
|
12 |
|
@@ -66,15 +68,15 @@ class BaseInpainter:
|
|
66 |
if ratio == 1:
|
67 |
size = None
|
68 |
else:
|
69 |
-
size =
|
70 |
if size[0] % 2 > 0:
|
71 |
size[0] += 1
|
72 |
if size[1] % 2 > 0:
|
73 |
size[1] += 1
|
74 |
|
75 |
masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1
|
76 |
-
binary_masks = resize_masks(masks, size)
|
77 |
-
frames = resize_frames(frames, size) # T, H, W, 3
|
78 |
# frames and binary_masks are numpy arrays
|
79 |
|
80 |
h, w = frames.shape[1:3]
|
@@ -87,7 +89,7 @@ class BaseInpainter:
|
|
87 |
imgs, masks = imgs.to(self.device), masks.to(self.device)
|
88 |
comp_frames = [None] * video_length
|
89 |
|
90 |
-
for f in range(0, video_length, self.neighbor_stride):
|
91 |
neighbor_ids = [
|
92 |
i for i in range(max(0, f - self.neighbor_stride),
|
93 |
min(video_length, f + self.neighbor_stride + 1))
|
|
|
7 |
import cv2
|
8 |
import importlib
|
9 |
import numpy as np
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
from inpainter.util.tensor_util import resize_frames, resize_masks
|
13 |
|
14 |
|
|
|
68 |
if ratio == 1:
|
69 |
size = None
|
70 |
else:
|
71 |
+
size = [int(W*ratio), int(H*ratio)]
|
72 |
if size[0] % 2 > 0:
|
73 |
size[0] += 1
|
74 |
if size[1] % 2 > 0:
|
75 |
size[1] += 1
|
76 |
|
77 |
masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1
|
78 |
+
binary_masks = resize_masks(masks, tuple(size))
|
79 |
+
frames = resize_frames(frames, tuple(size)) # T, H, W, 3
|
80 |
# frames and binary_masks are numpy arrays
|
81 |
|
82 |
h, w = frames.shape[1:3]
|
|
|
89 |
imgs, masks = imgs.to(self.device), masks.to(self.device)
|
90 |
comp_frames = [None] * video_length
|
91 |
|
92 |
+
for f in tqdm(range(0, video_length, self.neighbor_stride), desc='Inpainting image'):
|
93 |
neighbor_ids = [
|
94 |
i for i in range(max(0, f - self.neighbor_stride),
|
95 |
min(video_length, f + self.neighbor_stride + 1))
|
inpainter/model/modules/tfocal_transformer_hq.py
CHANGED
@@ -128,8 +128,10 @@ def window_partition(x, window_size):
|
|
128 |
windows: (B*num_windows, T*window_size*window_size, C)
|
129 |
"""
|
130 |
B, T, H, W, C = x.shape
|
|
|
131 |
x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
|
132 |
window_size[1], C)
|
|
|
133 |
windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
|
134 |
-1, T * window_size[0] * window_size[1], C)
|
135 |
return windows
|
|
|
128 |
windows: (B*num_windows, T*window_size*window_size, C)
|
129 |
"""
|
130 |
B, T, H, W, C = x.shape
|
131 |
+
|
132 |
x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
|
133 |
window_size[1], C)
|
134 |
+
|
135 |
windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
|
136 |
-1, T * window_size[0] * window_size[1], C)
|
137 |
return windows
|
requirements.txt
CHANGED
@@ -10,10 +10,7 @@ gradio==3.25.0
|
|
10 |
opencv-python
|
11 |
pycocotools
|
12 |
matplotlib
|
13 |
-
onnxruntime
|
14 |
-
onnx
|
15 |
-
metaseg==0.6.1
|
16 |
pyyaml
|
17 |
av
|
18 |
-
|
19 |
-
|
|
|
10 |
opencv-python
|
11 |
pycocotools
|
12 |
matplotlib
|
|
|
|
|
|
|
13 |
pyyaml
|
14 |
av
|
15 |
+
openmim
|
16 |
+
tqdm
|
track_anything.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
-
import PIL
|
|
|
|
|
2 |
from tools.interact_tools import SamControler
|
3 |
from tracker.base_tracker import BaseTracker
|
4 |
from inpainter.base_inpainter import BaseInpainter
|
@@ -10,9 +12,12 @@ import argparse
|
|
10 |
class TrackingAnything():
|
11 |
def __init__(self, sam_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, args):
|
12 |
self.args = args
|
13 |
-
self.
|
14 |
-
self.
|
15 |
-
self.
|
|
|
|
|
|
|
16 |
# def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
|
17 |
# same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
18 |
# if first_flag:
|
@@ -39,7 +44,7 @@ class TrackingAnything():
|
|
39 |
masks = []
|
40 |
logits = []
|
41 |
painted_images = []
|
42 |
-
for i in range(len(images)):
|
43 |
if i ==0:
|
44 |
mask, logit, painted_image = self.xmem.track(images[i], template_mask)
|
45 |
masks.append(mask)
|
@@ -51,7 +56,6 @@ class TrackingAnything():
|
|
51 |
masks.append(mask)
|
52 |
logits.append(logit)
|
53 |
painted_images.append(painted_image)
|
54 |
-
print("tracking image {}".format(i))
|
55 |
return masks, logits, painted_images
|
56 |
|
57 |
|
|
|
1 |
+
import PIL
|
2 |
+
from tqdm import tqdm
|
3 |
+
|
4 |
from tools.interact_tools import SamControler
|
5 |
from tracker.base_tracker import BaseTracker
|
6 |
from inpainter.base_inpainter import BaseInpainter
|
|
|
12 |
class TrackingAnything():
|
13 |
def __init__(self, sam_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, args):
|
14 |
self.args = args
|
15 |
+
self.sam_checkpoint = sam_checkpoint
|
16 |
+
self.xmem_checkpoint = xmem_checkpoint
|
17 |
+
self.e2fgvi_checkpoint = e2fgvi_checkpoint
|
18 |
+
self.samcontroler = SamControler(self.sam_checkpoint, args.sam_model_type, args.device)
|
19 |
+
self.xmem = BaseTracker(self.xmem_checkpoint, device=args.device)
|
20 |
+
self.baseinpainter = BaseInpainter(self.e2fgvi_checkpoint, args.device)
|
21 |
# def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
|
22 |
# same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
23 |
# if first_flag:
|
|
|
44 |
masks = []
|
45 |
logits = []
|
46 |
painted_images = []
|
47 |
+
for i in tqdm(range(len(images)), desc="Tracking image"):
|
48 |
if i ==0:
|
49 |
mask, logit, painted_image = self.xmem.track(images[i], template_mask)
|
50 |
masks.append(mask)
|
|
|
56 |
masks.append(mask)
|
57 |
logits.append(logit)
|
58 |
painted_images.append(painted_image)
|
|
|
59 |
return masks, logits, painted_images
|
60 |
|
61 |
|
tracker/.DS_Store
CHANGED
Binary files a/tracker/.DS_Store and b/tracker/.DS_Store differ
|
|
tracker/base_tracker.py
CHANGED
@@ -126,6 +126,7 @@ class BaseTracker:
|
|
126 |
def clear_memory(self):
|
127 |
self.tracker.clear_memory()
|
128 |
self.mapper.clear_labels()
|
|
|
129 |
|
130 |
|
131 |
## how to use:
|
|
|
126 |
def clear_memory(self):
|
127 |
self.tracker.clear_memory()
|
128 |
self.mapper.clear_labels()
|
129 |
+
torch.cuda.empty_cache()
|
130 |
|
131 |
|
132 |
## how to use:
|