File size: 2,723 Bytes
123489f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
from tqdm import tqdm

from utils.interact_tools import SamControler
from tracker.base_tracker import BaseTracker
import numpy as np
import argparse
import cv2

from typing import Optional


class TrackingAnything:
    """Bundle a click-driven segmenter (``SamControler``) with a mask tracker.

    The segmenter produces a mask for a single frame from user clicks; the
    tracker (``BaseTracker``) then propagates a template mask through a
    sequence of frames.
    """

    def __init__(self, sam_pt_checkpoint, sam_onnx_checkpoint, xmem_checkpoint, args):
        """Remember the checkpoint locations and build both underlying models.

        Args:
            sam_pt_checkpoint: First checkpoint argument forwarded to
                ``SamControler``.
            sam_onnx_checkpoint: Second checkpoint argument forwarded to
                ``SamControler``.
            xmem_checkpoint: Checkpoint argument forwarded to ``BaseTracker``.
            args: Object exposing ``sam_model_type`` and ``device`` attributes
                (for example the namespace returned by ``parse_augment``).
        """
        self.args = args
        self.sam_pt_checkpoint = sam_pt_checkpoint
        self.sam_onnx_checkpoint = sam_onnx_checkpoint
        self.xmem_checkpoint = xmem_checkpoint
        self.samcontroler = SamControler(
            self.sam_pt_checkpoint, self.sam_onnx_checkpoint, args.sam_model_type, args.device
        )
        self.xmem = BaseTracker(self.xmem_checkpoint, device=args.device)

    def first_frame_click(
        self, image: np.ndarray, points: np.ndarray, labels: np.ndarray, multimask=True
    ):
        """Segment one frame from click prompts.

        Delegates to ``self.samcontroler.first_frame_click`` with the same
        arguments and returns its three results unchanged.

        Returns:
            A ``(mask, logit, painted_image)`` tuple.
        """
        mask, logit, painted_image = self.samcontroler.first_frame_click(
            image, points, labels, multimask
        )
        return mask, logit, painted_image

    def generator(
        self,
        images: list,
        template_mask: np.ndarray,
        write: Optional[bool] = False,
        fps: Optional[int] = 30,
        output_path: Optional[str] = "tracking.mp4",
    ):
        """Track ``template_mask`` through ``images`` and collect the results.

        The first frame is passed to the tracker together with
        ``template_mask``; every later frame is passed on its own.

        Args:
            images: Sequence of frames. When ``write`` is true the first frame
                must expose ``.shape`` (assumed ``(height, width, ...)`` --
                TODO confirm against callers).
            template_mask: Mask handed to the tracker with the first frame.
            write: When true, painted frames are written to ``output_path``
                as an ``mp4v`` video instead of being returned.
            fps: Frame rate of the written video. Numeric strings such as
                ``"30"`` are accepted; ``None`` falls back to 30.
            output_path: Destination video file; a missing parent directory
                is created.

        Returns:
            ``(masks, logits, painted_images)``. ``painted_images`` is empty
            when ``write`` is true. For an empty ``images`` sequence three
            empty lists are returned and no video file is created.
        """
        # Nothing to track: avoid indexing images[0] and creating an empty video.
        if len(images) == 0:
            return [], [], []

        masks = []
        logits = []
        painted_images = []

        writer = None
        if write:
            # cv2.VideoWriter takes the frame size as (width, height).
            size = images[0].shape[:2][::-1]
            out_dir = os.path.dirname(output_path)
            # A bare file name has an empty dirname, and os.makedirs("") raises.
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            # VideoWriter needs a numeric frame rate, so coerce strings here.
            frame_rate = 30.0 if fps is None else float(fps)
            writer = cv2.VideoWriter(
                output_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, size
            )

        try:
            for index, image in enumerate(tqdm(images, desc="Tracking image")):
                if index == 0:
                    mask, logit, painted_image = self.xmem.track(image, template_mask)
                else:
                    mask, logit, painted_image = self.xmem.track(image)

                masks.append(mask)
                logits.append(logit)

                if writer is not None:
                    # Reverse the channel axis before writing (presumably
                    # RGB -> BGR for OpenCV -- confirm with the tracker output).
                    writer.write(painted_image[:, :, ::-1])
                else:
                    painted_images.append(painted_image)
        finally:
            # Always finalize the video file, even if tracking fails mid-way.
            if writer is not None:
                writer.release()

        return masks, logits, painted_images


def _str_to_bool(value):
    """Convert a command-line token such as ``"False"`` or ``"1"`` to a bool."""
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_augment(argv=None):
    """Parse the command-line options used by the tracking application.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        ``argparse.Namespace`` with ``device``, ``sam_model_type``, ``port``,
        ``debug`` and ``mask_save`` attributes. The namespace is printed when
        ``--debug`` is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--sam_model_type", type=str, default="vit_t")
    parser.add_argument(
        "--port",
        type=int,
        default=6080,
        help="only useful when running gradio applications",
    )
    parser.add_argument("--debug", action="store_true")
    # Without a type, "--mask_save False" would be stored as the truthy string
    # "False". Parse it as a real boolean; a bare "--mask_save" means True.
    parser.add_argument(
        "--mask_save",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
    )
    args = parser.parse_args(argv)

    if args.debug:
        print(args)
    return args