watchtowerss committed on
Commit
94dd0a9
·
1 Parent(s): 39c26fe

add duplicate space

Browse files
Files changed (1) hide show
  1. app.py +38 -12
app.py CHANGED
@@ -13,9 +13,13 @@ import requests
13
  import json
14
  import torchvision
15
  import torch
 
 
16
  from tools.painter import mask_painter
17
-
18
- os.system("mim install mmcv")
 
 
19
 
20
  # download checkpoints
21
  def download_checkpoint(url, folder, filename):
@@ -202,6 +206,7 @@ def show_mask(video_state, interactive_state, mask_dropdown):
202
 
203
  # tracking vos
204
  def vos_tracking_video(video_state, interactive_state, mask_dropdown):
 
205
  model.xmem.clear_memory()
206
  if interactive_state["track_end_number"]:
207
  following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
@@ -221,6 +226,8 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
221
  template_mask = video_state["masks"][video_state["select_frame_number"]]
222
  fps = video_state["fps"]
223
  masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
 
 
224
 
225
  if interactive_state["track_end_number"]:
226
  video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
@@ -260,6 +267,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
260
 
261
  # inpaint
262
  def inpaint_video(video_state, interactive_state, mask_dropdown):
 
263
  frames = np.asarray(video_state["origin_images"])
264
  fps = video_state["fps"]
265
  inpaint_masks = np.asarray(video_state["masks"])
@@ -306,27 +314,44 @@ def generate_video_from_frames(frames, output_path, fps=30):
306
  torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
307
  return output_path
308
 
 
 
 
 
309
  # check and download checkpoints if needed
310
- SAM_checkpoint = "sam_vit_h_4b8939.pth"
311
- sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
 
 
 
 
 
 
 
 
 
 
312
  xmem_checkpoint = "XMem-s012.pth"
313
  xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
314
  e2fgvi_checkpoint = "E2FGVI-HQ-CVPR22.pth"
315
  e2fgvi_checkpoint_id = "10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3"
316
 
 
317
  folder ="./checkpoints"
318
- SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, SAM_checkpoint)
319
  xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
320
  e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
321
- # args, defined in track_anything.py
322
- args = parse_augment()
323
- # args.port = 12315
324
- # args.device = "cuda:2"
325
- # args.mask_save = True
326
 
327
  # initialize sam, xmem, e2fgvi models
328
  model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)
329
 
 
 
 
 
 
 
330
  with gr.Blocks() as iface:
331
  """
332
  state for
@@ -358,7 +383,8 @@ with gr.Blocks() as iface:
358
  "fps": 30
359
  }
360
  )
361
-
 
362
  with gr.Row():
363
 
364
  # for user video input
@@ -367,7 +393,7 @@ with gr.Blocks() as iface:
367
  video_input = gr.Video(autosize=True)
368
  with gr.Column():
369
  video_info = gr.Textbox()
370
- inpaint_info = gr.Textbox(value="If you want to use the inpaint function, it is best to download and use a machine with more VRAM locally. \
371
  Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.")
372
  resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
373
 
 
13
  import json
14
  import torchvision
15
  import torch
16
+ from tools.interact_tools import SamControler
17
+ from tracker.base_tracker import BaseTracker
18
  from tools.painter import mask_painter
19
+ try:
20
+ from mmcv.cnn import ConvModule
21
+ except:
22
+ os.system("mim install mmcv")
23
 
24
  # download checkpoints
25
  def download_checkpoint(url, folder, filename):
 
206
 
207
  # tracking vos
208
  def vos_tracking_video(video_state, interactive_state, mask_dropdown):
209
+
210
  model.xmem.clear_memory()
211
  if interactive_state["track_end_number"]:
212
  following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
 
226
  template_mask = video_state["masks"][video_state["select_frame_number"]]
227
  fps = video_state["fps"]
228
  masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
229
+ # clear GPU memory
230
+ model.xmem.clear_memory()
231
 
232
  if interactive_state["track_end_number"]:
233
  video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
 
267
 
268
  # inpaint
269
  def inpaint_video(video_state, interactive_state, mask_dropdown):
270
+
271
  frames = np.asarray(video_state["origin_images"])
272
  fps = video_state["fps"]
273
  inpaint_masks = np.asarray(video_state["masks"])
 
314
  torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
315
  return output_path
316
 
317
+
318
+ # args, defined in track_anything.py
319
+ args = parse_augment()
320
+
321
  # check and download checkpoints if needed
322
+ SAM_checkpoint_dict = {
323
+ 'vit_h': "sam_vit_h_4b8939.pth",
324
+ 'vit_l': "sam_vit_l_0b3195.pth",
325
+ "vit_b": "sam_vit_b_01ec64.pth"
326
+ }
327
+ SAM_checkpoint_url_dict = {
328
+ 'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
329
+ 'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
330
+ 'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
331
+ }
332
+ sam_checkpoint = SAM_checkpoint_dict[args.sam_model_type]
333
+ sam_checkpoint_url = SAM_checkpoint_url_dict[args.sam_model_type]
334
  xmem_checkpoint = "XMem-s012.pth"
335
  xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
336
  e2fgvi_checkpoint = "E2FGVI-HQ-CVPR22.pth"
337
  e2fgvi_checkpoint_id = "10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3"
338
 
339
+
340
  folder ="./checkpoints"
341
+ SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
342
  xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
343
  e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
344
+
 
 
 
 
345
 
346
  # initialize sam, xmem, e2fgvi models
347
  model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)
348
 
349
+
350
+ title = """<p><h1 align="center">Track-Anything</h1></p>
351
+ """
352
+ description = """<p>Gradio demo for Track Anything, a flexible and interactive tool for video object tracking, segmentation, and inpainting. I To use it, simply upload your video, or click one of the examples to load them. Code: <a href="https://github.com/gaomingqi/Track-Anything">https://github.com/gaomingqi/Track-Anything</a> <a href="https://huggingface.co/spaces/watchtowerss/Track-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
353
+
354
+
355
  with gr.Blocks() as iface:
356
  """
357
  state for
 
383
  "fps": 30
384
  }
385
  )
386
+ gr.Markdown(title)
387
+ gr.Markdown(description)
388
  with gr.Row():
389
 
390
  # for user video input
 
393
  video_input = gr.Video(autosize=True)
394
  with gr.Column():
395
  video_info = gr.Textbox()
396
+ resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to download and use a machine with more VRAM locally. \
397
  Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.")
398
  resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
399