File size: 5,670 Bytes
dba3ac4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# Bootstrap: library imports, then make the florence-sam checkout both the
# working directory and importable, so the project-local `utils.*` modules
# imported at the end of this section resolve. Statement order matters here:
# the chdir / sys.path changes must run before the `utils` imports.
# Notebook-magic equivalent of the os.chdir below (kept for reference):
#%cd /content/florence-sam
import os
from typing import Tuple, Optional  # NOTE(review): not used in the visible script
import shutil
import os  # NOTE(review): duplicate of the `import os` above; harmless
import cv2
import numpy as np
import spaces  # NOTE(review): never referenced below; presumably imported for its side effects (it precedes `torch`) — confirm before removing
import supervision as sv
import torch
from PIL import Image
from tqdm import tqdm
import sys
import json  # NOTE(review): not used in the visible script
import pickle
os.chdir("/content/florence-sam")
sys.path.append('/content/florence-sam')
from utils.video import generate_unique_name, create_directory, delete_directory
from utils.sam import load_sam_image_model, run_sam_inference, load_sam_video_model

# Run everything on the first visible CUDA device. Fail with an explicit
# message when no GPU is present, rather than an IndexError from indexing an
# empty device list.
if torch.cuda.device_count() == 0:
    raise RuntimeError("No CUDA device available; this script requires a GPU.")
DEVICE = torch.device("cuda:0")

# Enter a bfloat16 autocast context and deliberately never exit it, so all
# subsequent CUDA work in this script runs under mixed precision.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

# Allow TF32 matmul / cuDNN kernels on compute capability >= 8 GPUs. Query
# the device actually in use rather than a hard-coded index.
if torch.cuda.get_device_properties(DEVICE).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Image model: segments the first frame; video model: propagates the masks.
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
SAM_VIDEO_MODEL = load_sam_video_model(device=DEVICE)

# Multiplier applied to frame width/height before frames are saved (1 = native).
VIDEO_SCALE_FACTOR = 1
# Root directory for the extracted-frame folder and the .mp4 output path.
VIDEO_TARGET_DIRECTORY = "/content/"
create_directory(directory_path=VIDEO_TARGET_DIRECTORY)

# Path of the source video, pickled by an earlier step of the pipeline.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles this pipeline wrote itself.
with open('/content/output_video.pkl', 'rb') as file:
    output_video = pickle.load(file)
print(output_video)
video_input = output_video

# Read only the first frame (converted BGR -> RGB PIL image) for the
# image-level SAM pass. Guard against an empty video: a bare next() would
# leak StopIteration at module level with no useful message.
frame_generator = sv.get_video_frames_generator(video_input)
frame = next(frame_generator, None)
if frame is None:
    raise RuntimeError(f"Video {video_input!r} contains no frames.")
frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))


# First-frame detections, pickled by an earlier step of the pipeline.
# NOTE(review): pickle.load can execute arbitrary code; load trusted files only.
with open('/content/detections_list.pkl', 'rb') as file:
    detections_list = pickle.load(file)
print(detections_list)

# Fold the per-prompt detections into one Detections object, then run the SAM
# image model on the first frame. The result presumably carries the masks in
# `detections.mask`, which the tracking code reads — confirm in utils.sam.
merged_detections = sv.Detections.merge(detections_list)
detections = run_sam_inference(SAM_IMAGE_MODEL, frame, merged_detections)

# Stop early when the first-frame SAM pass produced no masks: without them no
# objects are registered with the video model, so there is nothing to track.
# The old message was not an f-string and printed "{text_input}" literally
# (that name is not defined in this script), and execution carried on anyway.
# `mask` may also be None, which would make len() raise TypeError.
if detections.mask is None or len(detections.mask) == 0:
    sys.exit(
        "No objects found in the first frame of the video. "
        "Trim the video to make the object appear in the first frame or try a "
        "different text prompt."
    )


# Dump every frame of the input video, scaled by VIDEO_SCALE_FACTOR, into a
# fresh uniquely named directory using the "{:05d}.jpeg" naming pattern; the
# SAM video model is initialised from this folder.
name = generate_unique_name()
frame_directory_path = os.path.join(VIDEO_TARGET_DIRECTORY, name)
frames_sink = sv.ImageSink(
    target_dir_path=frame_directory_path,
    image_name_pattern="{:05d}.jpeg",
)

# Record the output geometry at the scaled resolution.
video_info = sv.VideoInfo.from_video_path(video_input)
scaled_width = int(video_info.width * VIDEO_SCALE_FACTOR)
scaled_height = int(video_info.height * VIDEO_SCALE_FACTOR)
video_info.width, video_info.height = scaled_width, scaled_height

frames_generator = sv.get_video_frames_generator(video_input)
with frames_sink:
    progress = tqdm(
        frames_generator,
        total=video_info.total_frames,
        desc="splitting video into frames",
    )
    for frame in progress:
        frame = sv.scale_image(frame, VIDEO_SCALE_FACTOR)
        frames_sink.save_image(frame)

# Build the video-model state from the extracted frame directory.
inference_state = SAM_VIDEO_MODEL.init_state(video_path=frame_directory_path, device=DEVICE)

# Seed the tracker on frame 0: every first-frame mask becomes its own object,
# identified by its position in detections.mask.
for obj_id, first_frame_mask in enumerate(detections.mask):
    add_result = SAM_VIDEO_MODEL.add_new_mask(
        inference_state=inference_state,
        frame_idx=0,
        obj_id=obj_id,
        mask=first_frame_mask,
    )
    _, object_ids, mask_logits = add_result

# Path of the .mp4 that the VideoSink in the frame loop opens.
video_path = os.path.join(VIDEO_TARGET_DIRECTORY, f"{name}.mp4")
# Fresh iterators: original video frames, and the per-frame masks propagated
# by the SAM video model; they are consumed in lockstep.
frames_generator = sv.get_video_frames_generator(video_input)
masks_generator = SAM_VIDEO_MODEL.propagate_in_video(inference_state)
# 1-based counter used to name the output JPEGs.
n = 1

# Start from an empty output directory so stale frames from a previous run are
# never mixed with this run's results. (The old makedirs-before-rmtree and the
# repeated `import shutil` were redundant; shutil is imported at file top.)
out_dir = "/content/output"
shutil.rmtree(out_dir, ignore_errors=True)
os.makedirs(out_dir, exist_ok=True)
# Morphology kernels, built once instead of once per frame: a small erosion to
# drop thin/isolated mask pixels, then a larger dilation to grow the mask back.
erode_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
dilate_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (10, 10))

# Render every frame: darken the background, keep bright pixels and everything
# inside the tracked-object mask at full intensity, and save it as
# out_dir/<n>.jpeg. The temporary frame directory is removed even if this fails.
try:
    with sv.VideoSink(video_path, video_info=video_info) as sink:
        # NOTE(review): nothing is ever written to `sink`, so video_path ends
        # up as an empty video; the results are the JPEGs in out_dir. Confirm
        # whether sink.write_frame(frame) was intended here.
        for frame, (_, tracker_ids, mask_logits) in zip(frames_generator, masks_generator):
            frame = sv.scale_image(frame, VIDEO_SCALE_FACTOR)

            # Positive logits are foreground; OR all objects into one mask,
            # then convert it to a 0/255 uint8 image squeezed to 2-D.
            masks = (mask_logits > 0.0).cpu().numpy().astype(bool)
            all_mask_bool = np.logical_or.reduce(masks, axis=0)
            all_mask = (all_mask_bool * 255).astype(np.uint8).squeeze()

            # Clean up the mask: erode, then dilate with the larger kernel.
            all_mask = cv2.erode(all_mask, erode_kernel, iterations=1)
            all_mask = cv2.dilate(all_mask, dilate_kernel, iterations=1)

            # Original pixels where the mask is set, black elsewhere.
            masked_image = cv2.bitwise_and(frame, frame, mask=all_mask)

            # Dim the whole frame to 10% brightness (blend 0.1 frame + 0.9 black).
            black_background = np.zeros_like(frame)
            transparent_frame = cv2.addWeighted(frame, 0.1, black_background, 0.9, 0)

            # Keep bright pixels (every channel >= 130) at full intensity.
            white_mask = (frame >= 130).all(axis=-1)
            frame = np.where(white_mask[:, :, None], frame, transparent_frame)

            # Paste the masked region back over the dimmed frame.
            frame = np.where(all_mask[:, :, None] > 0, masked_image, frame)

            out_path = os.path.join(out_dir, f"{n}.jpeg")
            # cv2.imwrite signals failure via its return value, not an exception.
            if not cv2.imwrite(out_path, frame):
                raise OSError(f"Failed to write frame to {out_path}")
            n += 1
finally:
    delete_directory(frame_directory_path)