Spaces:
Runtime error
Runtime error
supersolar
commited on
Create sam2-colab.py
Browse files- sam2-colab.py +154 -0
sam2-colab.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%cd /content/florence-sam
|
2 |
+
import os
|
3 |
+
from typing import Tuple, Optional
|
4 |
+
import shutil
|
5 |
+
import os
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
import spaces
|
9 |
+
import supervision as sv
|
10 |
+
import torch
|
11 |
+
from PIL import Image
|
12 |
+
from tqdm import tqdm
|
13 |
+
import sys
|
14 |
+
import json
|
15 |
+
import pickle
|
16 |
+
os.chdir("/content/florence-sam")
|
17 |
+
sys.path.append('/content/florence-sam')
|
18 |
+
from utils.video import generate_unique_name, create_directory, delete_directory
|
19 |
+
from utils.sam import load_sam_image_model, run_sam_inference, load_sam_video_model
|
20 |
+
|
21 |
+
DEVICE = torch.device("cuda")
|
22 |
+
DEVICE = [torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())][-1]
|
23 |
+
DEVICE = [torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())][0]
|
24 |
+
# DEVICE = torch.device("cpu")
|
25 |
+
|
26 |
+
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
27 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
28 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
29 |
+
torch.backends.cudnn.allow_tf32 = True
|
30 |
+
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
|
31 |
+
SAM_VIDEO_MODEL = load_sam_video_model(device=DEVICE)
|
32 |
+
|
33 |
+
VIDEO_SCALE_FACTOR = 1
|
34 |
+
VIDEO_TARGET_DIRECTORY = "/content/"
|
35 |
+
create_directory(directory_path=VIDEO_TARGET_DIRECTORY)
|
36 |
+
with open('/content/output_video.pkl', 'rb') as file:
|
37 |
+
output_video = pickle.load(file)
|
38 |
+
print(output_video)
|
39 |
+
video_input= output_video
|
40 |
+
|
41 |
+
|
42 |
+
frame_generator = sv.get_video_frames_generator(video_input)
|
43 |
+
frame = next(frame_generator)
|
44 |
+
frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
45 |
+
|
46 |
+
|
47 |
+
with open('/content/detections_list.pkl', 'rb') as file:
|
48 |
+
detections_list = pickle.load(file)
|
49 |
+
print(detections_list)
|
50 |
+
detections = sv.Detections.merge(detections_list)
|
51 |
+
detections = run_sam_inference(SAM_IMAGE_MODEL, frame, detections)
|
52 |
+
|
53 |
+
if len(detections.mask) == 0:
|
54 |
+
print(
|
55 |
+
"No objects of class {text_input} found in the first frame of the video. "
|
56 |
+
"Trim the video to make the object appear in the first frame or try a "
|
57 |
+
"different text prompt."
|
58 |
+
)
|
59 |
+
|
60 |
+
|
61 |
+
name = generate_unique_name()
|
62 |
+
frame_directory_path = os.path.join(VIDEO_TARGET_DIRECTORY, name)
|
63 |
+
frames_sink = sv.ImageSink(
|
64 |
+
target_dir_path=frame_directory_path,
|
65 |
+
image_name_pattern="{:05d}.jpeg"
|
66 |
+
)
|
67 |
+
|
68 |
+
video_info = sv.VideoInfo.from_video_path(video_input)
|
69 |
+
video_info.width = int(video_info.width * VIDEO_SCALE_FACTOR)
|
70 |
+
video_info.height = int(video_info.height * VIDEO_SCALE_FACTOR)
|
71 |
+
|
72 |
+
frames_generator = sv.get_video_frames_generator(video_input)
|
73 |
+
with frames_sink:
|
74 |
+
for frame in tqdm(
|
75 |
+
frames_generator,
|
76 |
+
total=video_info.total_frames,
|
77 |
+
desc="splitting video into frames"
|
78 |
+
):
|
79 |
+
frame = sv.scale_image(frame, VIDEO_SCALE_FACTOR)
|
80 |
+
frames_sink.save_image(frame)
|
81 |
+
|
82 |
+
inference_state = SAM_VIDEO_MODEL.init_state(
|
83 |
+
video_path=frame_directory_path,
|
84 |
+
device=DEVICE
|
85 |
+
)
|
86 |
+
|
87 |
+
|
88 |
+
|
89 |
+
|
90 |
+
for mask_index, mask in enumerate(detections.mask):
|
91 |
+
_, object_ids, mask_logits = SAM_VIDEO_MODEL.add_new_mask(
|
92 |
+
inference_state=inference_state,
|
93 |
+
frame_idx=0,
|
94 |
+
obj_id=mask_index,
|
95 |
+
mask=mask
|
96 |
+
)
|
97 |
+
|
98 |
+
video_path = os.path.join(VIDEO_TARGET_DIRECTORY, f"{name}.mp4")
|
99 |
+
frames_generator = sv.get_video_frames_generator(video_input)
|
100 |
+
masks_generator = SAM_VIDEO_MODEL.propagate_in_video(inference_state)
|
101 |
+
n = 1
|
102 |
+
|
103 |
+
out_dir = "/content/output"
|
104 |
+
os.makedirs(out_dir, exist_ok=True)
|
105 |
+
import shutil
|
106 |
+
shutil.rmtree('/content/output', ignore_errors=True)
|
107 |
+
os.makedirs('/content/output', exist_ok=True)
|
108 |
+
with sv.VideoSink(video_path, video_info=video_info) as sink:
|
109 |
+
for frame, (_, tracker_ids, mask_logits) in zip(frames_generator, masks_generator):
|
110 |
+
frame = sv.scale_image(frame, VIDEO_SCALE_FACTOR)
|
111 |
+
#print(cv2.imwrite('/content/te111.jpeg', frame))
|
112 |
+
#print(mask_logits.dtype)
|
113 |
+
masks = (mask_logits > 0.0).cpu().numpy().astype(bool)
|
114 |
+
# 将布尔掩码转换为 uint8 类型
|
115 |
+
# 将布尔掩码合并成一个遮罩 all_mask
|
116 |
+
all_mask_bool = np.logical_or.reduce(masks, axis=0)
|
117 |
+
# 将布尔掩码转换为 uint8 类型
|
118 |
+
all_mask_uint8 = (all_mask_bool * 255).astype(np.uint8)
|
119 |
+
#print(cv2.imwrite('/content/tem444.jpg', all_mask_uint8.squeeze()))
|
120 |
+
#all_mask = masks_uint8[0].squeeze()
|
121 |
+
all_mask = all_mask_uint8.squeeze()
|
122 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
|
123 |
+
# 对 all_mask 进行腐蚀操作
|
124 |
+
eroded_mask = cv2.erode(all_mask, kernel, iterations=1)
|
125 |
+
all_mask = eroded_mask
|
126 |
+
# 对 all_mask 进行膨胀操作
|
127 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (10, 10))
|
128 |
+
dilated_mask = cv2.dilate(all_mask, kernel, iterations=1)
|
129 |
+
all_mask = dilated_mask
|
130 |
+
#image = cv2.imread('/content/tem.jpg')
|
131 |
+
masked_image = cv2.bitwise_and(frame, frame, mask=all_mask)
|
132 |
+
# 创建一个全黑的背景图像
|
133 |
+
black_background = np.zeros_like(frame)
|
134 |
+
|
135 |
+
# 创建一个全白的背景图像
|
136 |
+
white_background = np.ones_like(frame) * 255
|
137 |
+
|
138 |
+
# 将 frame 的透明度设置为 50%
|
139 |
+
transparent_frame = cv2.addWeighted(frame, 0.1, black_background, 0.9, 0)
|
140 |
+
|
141 |
+
# 检查每个像素是否为白色,如果是白色,则保持其不透明
|
142 |
+
#white_mask = (frame >= 100).all(axis=-1)
|
143 |
+
white_mask = (frame >= 130).all(axis=-1)
|
144 |
+
frame = np.where(white_mask[:, :, None], frame, transparent_frame)
|
145 |
+
|
146 |
+
# 将提取的部分区域叠加到透明后的图片上
|
147 |
+
frame = np.where(all_mask[:, :, None] > 0, masked_image, frame)
|
148 |
+
# result 即为只保留 all_mask 遮罩内容的图像
|
149 |
+
#frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
150 |
+
cv2.imwrite("/content/output/" + str(n) + ".jpeg", frame)
|
151 |
+
n = n + 1
|
152 |
+
|
153 |
+
|
154 |
+
delete_directory(frame_directory_path)
|