watchtowerss commited on
Commit
53a8438
·
1 Parent(s): 9b9cd68

requirement fix

Browse files
app.py CHANGED
@@ -13,7 +13,13 @@ import requests
13
  import json
14
  import torchvision
15
  import torch
 
 
16
  from tools.painter import mask_painter
 
 
 
 
17
 
18
  # download checkpoints
19
  def download_checkpoint(url, folder, filename):
@@ -200,6 +206,7 @@ def show_mask(video_state, interactive_state, mask_dropdown):
200
 
201
  # tracking vos
202
  def vos_tracking_video(video_state, interactive_state, mask_dropdown):
 
203
  model.xmem.clear_memory()
204
  if interactive_state["track_end_number"]:
205
  following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
@@ -219,6 +226,8 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
219
  template_mask = video_state["masks"][video_state["select_frame_number"]]
220
  fps = video_state["fps"]
221
  masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
 
 
222
 
223
  if interactive_state["track_end_number"]:
224
  video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
@@ -258,6 +267,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
258
 
259
  # inpaint
260
  def inpaint_video(video_state, interactive_state, mask_dropdown):
 
261
  frames = np.asarray(video_state["origin_images"])
262
  fps = video_state["fps"]
263
  inpaint_masks = np.asarray(video_state["masks"])
@@ -304,20 +314,33 @@ def generate_video_from_frames(frames, output_path, fps=30):
304
  torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
305
  return output_path
306
 
 
 
 
 
307
  # check and download checkpoints if needed
308
- SAM_checkpoint = "sam_vit_h_4b8939.pth"
309
- sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
 
 
 
 
 
 
 
 
 
 
310
  xmem_checkpoint = "XMem-s012.pth"
311
  xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
312
  e2fgvi_checkpoint = "E2FGVI-HQ-CVPR22.pth"
313
  e2fgvi_checkpoint_id = "10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3"
314
 
 
315
  folder ="./checkpoints"
316
- SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, SAM_checkpoint)
317
  xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
318
  e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
319
- # args, defined in track_anything.py
320
- args = parse_augment()
321
  # args.port = 12315
322
  # args.device = "cuda:2"
323
  # args.mask_save = True
@@ -325,6 +348,12 @@ args = parse_augment()
325
  # initialize sam, xmem, e2fgvi models
326
  model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)
327
 
 
 
 
 
 
 
328
  with gr.Blocks() as iface:
329
  """
330
  state for
@@ -356,7 +385,8 @@ with gr.Blocks() as iface:
356
  "fps": 30
357
  }
358
  )
359
-
 
360
  with gr.Row():
361
 
362
  # for user video input
@@ -365,7 +395,7 @@ with gr.Blocks() as iface:
365
  video_input = gr.Video(autosize=True)
366
  with gr.Column():
367
  video_info = gr.Textbox()
368
- video_info = gr.Textbox(value="Due to server restrictions, please upload a video that is no longer than 2 minutes. If you want to use the inpaint function, it is best to download and use a machine with more VRAM locally. \
369
  Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.")
370
  resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
371
 
@@ -534,7 +564,7 @@ with gr.Blocks() as iface:
534
  # cache_examples=True,
535
  )
536
  iface.queue(concurrency_count=1)
537
- iface.launch(debug=True, enable_queue=True)
538
 
539
 
540
 
 
13
  import json
14
  import torchvision
15
  import torch
16
+ from tools.interact_tools import SamControler
17
+ from tracker.base_tracker import BaseTracker
18
  from tools.painter import mask_painter
19
+ try:
20
+ from mmcv.cnn import ConvModule
21
+ except:
22
+ os.system("mim install mmcv")
23
 
24
  # download checkpoints
25
  def download_checkpoint(url, folder, filename):
 
206
 
207
  # tracking vos
208
  def vos_tracking_video(video_state, interactive_state, mask_dropdown):
209
+
210
  model.xmem.clear_memory()
211
  if interactive_state["track_end_number"]:
212
  following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
 
226
  template_mask = video_state["masks"][video_state["select_frame_number"]]
227
  fps = video_state["fps"]
228
  masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
229
+ # clear GPU memory
230
+ model.xmem.clear_memory()
231
 
232
  if interactive_state["track_end_number"]:
233
  video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
 
267
 
268
  # inpaint
269
  def inpaint_video(video_state, interactive_state, mask_dropdown):
270
+
271
  frames = np.asarray(video_state["origin_images"])
272
  fps = video_state["fps"]
273
  inpaint_masks = np.asarray(video_state["masks"])
 
314
  torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
315
  return output_path
316
 
317
+
318
+ # args, defined in track_anything.py
319
+ args = parse_augment()
320
+
321
  # check and download checkpoints if needed
322
+ SAM_checkpoint_dict = {
323
+ 'vit_h': "sam_vit_h_4b8939.pth",
324
+ 'vit_l': "sam_vit_l_0b3195.pth",
325
+ "vit_b": "sam_vit_b_01ec64.pth"
326
+ }
327
+ SAM_checkpoint_url_dict = {
328
+ 'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
329
+ 'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
330
+ 'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
331
+ }
332
+ sam_checkpoint = SAM_checkpoint_dict[args.sam_model_type]
333
+ sam_checkpoint_url = SAM_checkpoint_url_dict[args.sam_model_type]
334
  xmem_checkpoint = "XMem-s012.pth"
335
  xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
336
  e2fgvi_checkpoint = "E2FGVI-HQ-CVPR22.pth"
337
  e2fgvi_checkpoint_id = "10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3"
338
 
339
+
340
  folder ="./checkpoints"
341
+ SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
342
  xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
343
  e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
 
 
344
  # args.port = 12315
345
  # args.device = "cuda:2"
346
  # args.mask_save = True
 
348
  # initialize sam, xmem, e2fgvi models
349
  model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)
350
 
351
+
352
+ title = """<p><h1 align="center">Track-Anything</h1></p>
353
+ """
354
+ description = """<p>Gradio demo for Track Anything, a flexible and interactive tool for video object tracking, segmentation, and inpainting. I To use it, simply upload your video, or click one of the examples to load them. Code: <a href="https://github.com/gaomingqi/Track-Anything">https://github.com/gaomingqi/Track-Anything</a> <a href="https://huggingface.co/spaces/watchtowerss/Track-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
355
+
356
+
357
  with gr.Blocks() as iface:
358
  """
359
  state for
 
385
  "fps": 30
386
  }
387
  )
388
+ gr.Markdown(title)
389
+ gr.Markdown(description)
390
  with gr.Row():
391
 
392
  # for user video input
 
395
  video_input = gr.Video(autosize=True)
396
  with gr.Column():
397
  video_info = gr.Textbox()
398
+ resize_info = gr.Textbox(value="Due to server restrictions, please upload a video that is no longer than 2 minutes. If you want to use the inpaint function, it is best to download and use a machine with more VRAM locally. \
399
  Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.")
400
  resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
401
 
 
564
  # cache_examples=True,
565
  )
566
  iface.queue(concurrency_count=1)
567
+ iface.launch(debug=True)
568
 
569
 
570
 
assets/avengers.gif CHANGED

Git LFS Details

  • SHA256: 9193a028c2e968ff7a7ee222ccc27166a5fbbe40a4d971cee13eba519134c5cf
  • Pointer size: 133 Bytes
  • Size of remote file: 99.2 MB

Git LFS Details

  • SHA256: 5e07b86ee4cf002b3481c71e2038c03f4420883c3be78220dafbc4b59abfb32d
  • Pointer size: 133 Bytes
  • Size of remote file: 30 MB
inpainter/.DS_Store CHANGED
Binary files a/inpainter/.DS_Store and b/inpainter/.DS_Store differ
 
inpainter/base_inpainter.py CHANGED
@@ -7,6 +7,8 @@ import yaml
7
  import cv2
8
  import importlib
9
  import numpy as np
 
 
10
  from inpainter.util.tensor_util import resize_frames, resize_masks
11
 
12
 
@@ -66,15 +68,15 @@ class BaseInpainter:
66
  if ratio == 1:
67
  size = None
68
  else:
69
- size = (int(W*ratio), int(H*ratio))
70
  if size[0] % 2 > 0:
71
  size[0] += 1
72
  if size[1] % 2 > 0:
73
  size[1] += 1
74
 
75
  masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1
76
- binary_masks = resize_masks(masks, size)
77
- frames = resize_frames(frames, size) # T, H, W, 3
78
  # frames and binary_masks are numpy arrays
79
 
80
  h, w = frames.shape[1:3]
@@ -87,7 +89,7 @@ class BaseInpainter:
87
  imgs, masks = imgs.to(self.device), masks.to(self.device)
88
  comp_frames = [None] * video_length
89
 
90
- for f in range(0, video_length, self.neighbor_stride):
91
  neighbor_ids = [
92
  i for i in range(max(0, f - self.neighbor_stride),
93
  min(video_length, f + self.neighbor_stride + 1))
 
7
  import cv2
8
  import importlib
9
  import numpy as np
10
+ from tqdm import tqdm
11
+
12
  from inpainter.util.tensor_util import resize_frames, resize_masks
13
 
14
 
 
68
  if ratio == 1:
69
  size = None
70
  else:
71
+ size = [int(W*ratio), int(H*ratio)]
72
  if size[0] % 2 > 0:
73
  size[0] += 1
74
  if size[1] % 2 > 0:
75
  size[1] += 1
76
 
77
  masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1
78
+ binary_masks = resize_masks(masks, tuple(size))
79
+ frames = resize_frames(frames, tuple(size)) # T, H, W, 3
80
  # frames and binary_masks are numpy arrays
81
 
82
  h, w = frames.shape[1:3]
 
89
  imgs, masks = imgs.to(self.device), masks.to(self.device)
90
  comp_frames = [None] * video_length
91
 
92
+ for f in tqdm(range(0, video_length, self.neighbor_stride), desc='Inpainting image'):
93
  neighbor_ids = [
94
  i for i in range(max(0, f - self.neighbor_stride),
95
  min(video_length, f + self.neighbor_stride + 1))
inpainter/model/modules/tfocal_transformer_hq.py CHANGED
@@ -128,8 +128,10 @@ def window_partition(x, window_size):
128
  windows: (B*num_windows, T*window_size*window_size, C)
129
  """
130
  B, T, H, W, C = x.shape
 
131
  x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
132
  window_size[1], C)
 
133
  windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
134
  -1, T * window_size[0] * window_size[1], C)
135
  return windows
 
128
  windows: (B*num_windows, T*window_size*window_size, C)
129
  """
130
  B, T, H, W, C = x.shape
131
+
132
  x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
133
  window_size[1], C)
134
+
135
  windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
136
  -1, T * window_size[0] * window_size[1], C)
137
  return windows
requirements.txt CHANGED
@@ -10,10 +10,7 @@ gradio==3.25.0
10
  opencv-python
11
  pycocotools
12
  matplotlib
13
- onnxruntime
14
- onnx
15
- metaseg==0.6.1
16
  pyyaml
17
  av
18
- mmcv-full
19
- mmengine
 
10
  opencv-python
11
  pycocotools
12
  matplotlib
 
 
 
13
  pyyaml
14
  av
15
+ openmim
16
+ tqdm
track_anything.py CHANGED
@@ -1,4 +1,6 @@
1
- import PIL
 
 
2
  from tools.interact_tools import SamControler
3
  from tracker.base_tracker import BaseTracker
4
  from inpainter.base_inpainter import BaseInpainter
@@ -10,9 +12,12 @@ import argparse
10
  class TrackingAnything():
11
  def __init__(self, sam_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, args):
12
  self.args = args
13
- self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)
14
- self.xmem = BaseTracker(xmem_checkpoint, device=args.device)
15
- self.baseinpainter = BaseInpainter(e2fgvi_checkpoint, args.device)
 
 
 
16
  # def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
17
  # same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
18
  # if first_flag:
@@ -39,7 +44,7 @@ class TrackingAnything():
39
  masks = []
40
  logits = []
41
  painted_images = []
42
- for i in range(len(images)):
43
  if i ==0:
44
  mask, logit, painted_image = self.xmem.track(images[i], template_mask)
45
  masks.append(mask)
@@ -51,7 +56,6 @@ class TrackingAnything():
51
  masks.append(mask)
52
  logits.append(logit)
53
  painted_images.append(painted_image)
54
- print("tracking image {}".format(i))
55
  return masks, logits, painted_images
56
 
57
 
 
1
+ import PIL
2
+ from tqdm import tqdm
3
+
4
  from tools.interact_tools import SamControler
5
  from tracker.base_tracker import BaseTracker
6
  from inpainter.base_inpainter import BaseInpainter
 
12
  class TrackingAnything():
13
  def __init__(self, sam_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, args):
14
  self.args = args
15
+ self.sam_checkpoint = sam_checkpoint
16
+ self.xmem_checkpoint = xmem_checkpoint
17
+ self.e2fgvi_checkpoint = e2fgvi_checkpoint
18
+ self.samcontroler = SamControler(self.sam_checkpoint, args.sam_model_type, args.device)
19
+ self.xmem = BaseTracker(self.xmem_checkpoint, device=args.device)
20
+ self.baseinpainter = BaseInpainter(self.e2fgvi_checkpoint, args.device)
21
  # def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
22
  # same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
23
  # if first_flag:
 
44
  masks = []
45
  logits = []
46
  painted_images = []
47
+ for i in tqdm(range(len(images)), desc="Tracking image"):
48
  if i ==0:
49
  mask, logit, painted_image = self.xmem.track(images[i], template_mask)
50
  masks.append(mask)
 
56
  masks.append(mask)
57
  logits.append(logit)
58
  painted_images.append(painted_image)
 
59
  return masks, logits, painted_images
60
 
61
 
tracker/.DS_Store CHANGED
Binary files a/tracker/.DS_Store and b/tracker/.DS_Store differ
 
tracker/base_tracker.py CHANGED
@@ -126,6 +126,7 @@ class BaseTracker:
126
  def clear_memory(self):
127
  self.tracker.clear_memory()
128
  self.mapper.clear_labels()
 
129
 
130
 
131
  ## how to use:
 
126
  def clear_memory(self):
127
  self.tracker.clear_memory()
128
  self.mapper.clear_labels()
129
+ torch.cuda.empty_cache()
130
 
131
 
132
  ## how to use: