File size: 5,769 Bytes
59b2a81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os, shutil, sys
import argparse
import gdown
import cv2
import numpy as np
import os
import sys
import requests
import json
import torchvision
import torch 
import psutil
import time
# Probe for mmcv and try to install it through `mim` when the import fails.
# `except Exception` (instead of a bare `except:`) keeps the best-effort
# behaviour for any import-time failure while letting Ctrl-C / SystemExit
# propagate instead of triggering an install.
try:
    from mmcv.cnn import ConvModule
except Exception:
    # NOTE(review): ConvModule is not re-imported after the install, and it is
    # not referenced in the visible part of this file -- the import appears to
    # serve only as an availability check; confirm before relying on the name.
    os.system("mim install mmcv")


# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from track_anything_code.model import TrackingAnything
from track_anything_code.track_anything_module import get_frames_from_video, download_checkpoint, parse_augment, sam_refine, vos_tracking_video
from scripts.compress_videos import compress_video




def parse_point_line(line):
    """Parse one ``data.txt`` record into an integer ``[x, y]`` click position.

    The record must hold exactly three whitespace-separated fields,
    ``frame_idx horizontal vertical``. Splitting on arbitrary whitespace makes
    the parser independent of the line terminator ("\\n", "\\r\\n", a trailing
    space, or none at all) instead of blindly dropping the last two characters.

    Args:
        line: Raw text of the record, with or without its line terminator.

    Returns:
        ``[int(float(horizontal)), int(float(vertical))]`` -- the coordinates
        truncated toward zero. The frame index field is ignored.

    Raises:
        ValueError: If the line does not contain exactly three fields, or a
            coordinate field is not a valid number.
    """
    fields = line.split()
    if len(fields) != 3:
        raise ValueError(
            "expected 'frame_idx horizontal vertical', got " + repr(line)
        )
    _frame_idx, horizontal, vertical = fields
    return [int(float(horizontal)), int(float(vertical))]


if __name__ == "__main__":
    dataset_path = "Bridge_v1_TT14"
    video_name = "combined.mp4"
    verbose = True      # If True, every tracked mask is additionally written out as a PNG


    ################################################## Model setup ####################################################
    # check and download checkpoints if needed
    sam_checkpoint = "sam_vit_h_4b8939.pth"
    sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    xmem_checkpoint = "XMem-s012.pth"
    xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"


    folder ="./pretrained"
    SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
    xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)

    # argument
    args = parse_augment()
    args.device = "cuda"      # Any GPU is ok

    # Initialize the Track model
    track_model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
    ###################################################################################################################


    # Iterate all sub-folders under the dataset folder (sorted for a deterministic order)
    for sub_folder_name in sorted(os.listdir(dataset_path)):

        ################################################## Setting ####################################################
        sub_folder_path = os.path.join(dataset_path, sub_folder_name)

        # os.listdir also returns plain files; only directories can hold a sample
        if not os.path.isdir(sub_folder_path):
            continue

        # Fresh click / interaction state for every sample so nothing leaks between folders
        click_state = [[],[]]
        interactive_state = {
                                "inference_times": 0,
                                "negative_click_times" : 0,
                                "positive_click_times": 0,
                                "mask_save": args.mask_save,
                                "multi_mask": {
                                    "mask_names": [],
                                    "masks": []
                                },
                                "track_end_number": None,
                                "resize_ratio": 1
                            }
        ###################################################################################################################


        video_path = os.path.join(sub_folder_path, video_name)
        if not os.path.exists(video_path):
            print("We cannot find the path of the ", video_path, " and we will compress one")
            status = compress_video(sub_folder_path, video_name)
            if not status:
                print("We still cannot generate a video")
                continue

        # Read video state
        video_state = {
                        "user_name": "",
                        "video_name": "",
                        "origin_images": None,
                        "painted_images": None,
                        "masks": None,
                        "inpaint_masks": None,
                        "logits": None,
                        "select_frame_number": 0,
                        "fps": 30
                    }
        video_state, template_frame = get_frames_from_video(video_path, video_state, track_model)



        ########################################################## Get the sam point based on the data.txt ###########################################################
        data_txt_path = os.path.join(sub_folder_path, "data.txt")
        if not os.path.exists(data_txt_path):
            print("We cannot find data.txt in this folder")
            continue

        # Only the first point is used; the context manager closes the file
        # even when parsing fails, so no handle is leaked per folder.
        with open(data_txt_path, 'r') as data_file:
            first_line = data_file.readline()

        try:
            point_cord = parse_point_line(first_line)
        except ValueError as error:
            # An empty or malformed data.txt skips this sample instead of aborting the whole batch
            print("We cannot parse the first point of ", data_txt_path, ": ", error)
            continue

        # Process by SAM
        track_model.samcontroler.sam_controler.reset_image() # Reset the image to clean history
        painted_image, video_state, interactive_state, operation_log = sam_refine(track_model, video_state, "Positive", click_state, interactive_state, point_cord)
        ################################################################################################################################################################



        ######################################################### Get the tracking output ########################################################################

        # Track the video for processing
        segment_output_path = os.path.join(sub_folder_path, "segment_output.gif")
        video_state = vos_tracking_video(track_model, segment_output_path, video_state, interactive_state, mask_dropdown=[])[0]   # mask_dropdown is empty now

        # Extract the mask needed by us for further point calculating
        masks = video_state["masks"]        # In the range [0, 1]

        if verbose:
            # Scale by 255 so the saved PNGs are visible; files are named mask<idx>.png
            for idx, mask in enumerate(masks):
                cv2.imwrite(os.path.join(sub_folder_path, "mask"+str(idx)+".png"), mask*255)

        ##############################################################################################################################################################