supersolar commited on
Commit
dba3ac4
·
verified ·
1 Parent(s): 369648b

Create sam2-colab.py

Browse files
Files changed (1) hide show
  1. sam2-colab.py +154 -0
sam2-colab.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #%cd /content/florence-sam
2
+ import os
3
+ from typing import Tuple, Optional
4
+ import shutil
5
+ import os
6
+ import cv2
7
+ import numpy as np
8
+ import spaces
9
+ import supervision as sv
10
+ import torch
11
+ from PIL import Image
12
+ from tqdm import tqdm
13
+ import sys
14
+ import json
15
+ import pickle
16
+ os.chdir("/content/florence-sam")
17
+ sys.path.append('/content/florence-sam')
18
+ from utils.video import generate_unique_name, create_directory, delete_directory
19
+ from utils.sam import load_sam_image_model, run_sam_inference, load_sam_video_model
20
+
21
+ DEVICE = torch.device("cuda")
22
+ DEVICE = [torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())][-1]
23
+ DEVICE = [torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())][0]
24
+ # DEVICE = torch.device("cpu")
25
+
26
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
27
+ if torch.cuda.get_device_properties(0).major >= 8:
28
+ torch.backends.cuda.matmul.allow_tf32 = True
29
+ torch.backends.cudnn.allow_tf32 = True
30
+ SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
31
+ SAM_VIDEO_MODEL = load_sam_video_model(device=DEVICE)
32
+
33
+ VIDEO_SCALE_FACTOR = 1
34
+ VIDEO_TARGET_DIRECTORY = "/content/"
35
+ create_directory(directory_path=VIDEO_TARGET_DIRECTORY)
36
+ with open('/content/output_video.pkl', 'rb') as file:
37
+ output_video = pickle.load(file)
38
+ print(output_video)
39
+ video_input= output_video
40
+
41
+
42
+ frame_generator = sv.get_video_frames_generator(video_input)
43
+ frame = next(frame_generator)
44
+ frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
45
+
46
+
47
+ with open('/content/detections_list.pkl', 'rb') as file:
48
+ detections_list = pickle.load(file)
49
+ print(detections_list)
50
+ detections = sv.Detections.merge(detections_list)
51
+ detections = run_sam_inference(SAM_IMAGE_MODEL, frame, detections)
52
+
53
+ if len(detections.mask) == 0:
54
+ print(
55
+ "No objects of class {text_input} found in the first frame of the video. "
56
+ "Trim the video to make the object appear in the first frame or try a "
57
+ "different text prompt."
58
+ )
59
+
60
+
61
+ name = generate_unique_name()
62
+ frame_directory_path = os.path.join(VIDEO_TARGET_DIRECTORY, name)
63
+ frames_sink = sv.ImageSink(
64
+ target_dir_path=frame_directory_path,
65
+ image_name_pattern="{:05d}.jpeg"
66
+ )
67
+
68
+ video_info = sv.VideoInfo.from_video_path(video_input)
69
+ video_info.width = int(video_info.width * VIDEO_SCALE_FACTOR)
70
+ video_info.height = int(video_info.height * VIDEO_SCALE_FACTOR)
71
+
72
+ frames_generator = sv.get_video_frames_generator(video_input)
73
+ with frames_sink:
74
+ for frame in tqdm(
75
+ frames_generator,
76
+ total=video_info.total_frames,
77
+ desc="splitting video into frames"
78
+ ):
79
+ frame = sv.scale_image(frame, VIDEO_SCALE_FACTOR)
80
+ frames_sink.save_image(frame)
81
+
82
+ inference_state = SAM_VIDEO_MODEL.init_state(
83
+ video_path=frame_directory_path,
84
+ device=DEVICE
85
+ )
86
+
87
+
88
+
89
+
90
+ for mask_index, mask in enumerate(detections.mask):
91
+ _, object_ids, mask_logits = SAM_VIDEO_MODEL.add_new_mask(
92
+ inference_state=inference_state,
93
+ frame_idx=0,
94
+ obj_id=mask_index,
95
+ mask=mask
96
+ )
97
+
98
+ video_path = os.path.join(VIDEO_TARGET_DIRECTORY, f"{name}.mp4")
99
+ frames_generator = sv.get_video_frames_generator(video_input)
100
+ masks_generator = SAM_VIDEO_MODEL.propagate_in_video(inference_state)
101
+ n = 1
102
+
103
+ out_dir = "/content/output"
104
+ os.makedirs(out_dir, exist_ok=True)
105
+ import shutil
106
+ shutil.rmtree('/content/output', ignore_errors=True)
107
+ os.makedirs('/content/output', exist_ok=True)
108
+ with sv.VideoSink(video_path, video_info=video_info) as sink:
109
+ for frame, (_, tracker_ids, mask_logits) in zip(frames_generator, masks_generator):
110
+ frame = sv.scale_image(frame, VIDEO_SCALE_FACTOR)
111
+ #print(cv2.imwrite('/content/te111.jpeg', frame))
112
+ #print(mask_logits.dtype)
113
+ masks = (mask_logits > 0.0).cpu().numpy().astype(bool)
114
+ # 将布尔掩码转换为 uint8 类型
115
+ # 将布尔掩码合并成一个遮罩 all_mask
116
+ all_mask_bool = np.logical_or.reduce(masks, axis=0)
117
+ # 将布尔掩码转换为 uint8 类型
118
+ all_mask_uint8 = (all_mask_bool * 255).astype(np.uint8)
119
+ #print(cv2.imwrite('/content/tem444.jpg', all_mask_uint8.squeeze()))
120
+ #all_mask = masks_uint8[0].squeeze()
121
+ all_mask = all_mask_uint8.squeeze()
122
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
123
+ # 对 all_mask 进行腐蚀操作
124
+ eroded_mask = cv2.erode(all_mask, kernel, iterations=1)
125
+ all_mask = eroded_mask
126
+ # 对 all_mask 进行膨胀操作
127
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (10, 10))
128
+ dilated_mask = cv2.dilate(all_mask, kernel, iterations=1)
129
+ all_mask = dilated_mask
130
+ #image = cv2.imread('/content/tem.jpg')
131
+ masked_image = cv2.bitwise_and(frame, frame, mask=all_mask)
132
+ # 创建一个全黑的背景图像
133
+ black_background = np.zeros_like(frame)
134
+
135
+ # 创建一个全白的背景图像
136
+ white_background = np.ones_like(frame) * 255
137
+
138
+ # 将 frame 的透明度设置为 50%
139
+ transparent_frame = cv2.addWeighted(frame, 0.1, black_background, 0.9, 0)
140
+
141
+ # 检查每个像素是否为白色,如果是白色,则保持其不透明
142
+ #white_mask = (frame >= 100).all(axis=-1)
143
+ white_mask = (frame >= 130).all(axis=-1)
144
+ frame = np.where(white_mask[:, :, None], frame, transparent_frame)
145
+
146
+ # 将提取的部分区域叠加到透明后的图片上
147
+ frame = np.where(all_mask[:, :, None] > 0, masked_image, frame)
148
+ # result 即为只保留 all_mask 遮罩内容的图像
149
+ #frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
150
+ cv2.imwrite("/content/output/" + str(n) + ".jpeg", frame)
151
+ n = n + 1
152
+
153
+
154
+ delete_directory(frame_directory_path)