supersolar committed on
Commit
5cc538f
·
verified ·
1 Parent(s): 9953f50

Create kaggle_sam2_gpu_2.py

Browse files
Files changed (1) hide show
  1. kaggle_sam2_gpu_2.py +153 -0
kaggle_sam2_gpu_2.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Tuple, Optional
3
+ import shutil
4
+ import os
5
+ import cv2
6
+ import numpy as np
7
+ import spaces
8
+ import supervision as sv
9
+ import torch
10
+ from PIL import Image
11
+ from tqdm import tqdm
12
+ import sys
13
+ import json
14
+ import pickle
15
+ os.chdir("/kaggle/florence-sam-kaggle")
16
+ sys.path.append("/kaggle/florence-sam-kaggle")
17
+ from utils.video import generate_unique_name, create_directory, delete_directory
18
+ from utils.sam import load_sam_image_model, run_sam_inference, load_sam_video_model
19
+
# Run the models on the last visible CUDA device.
# NOTE(review): presumably chosen so cuda:0 stays free for another job on
# multi-GPU Kaggle sessions — confirm.
if torch.cuda.device_count() == 0:
    raise RuntimeError("This script requires at least one CUDA device.")
DEVICE = torch.device(f"cuda:{torch.cuda.device_count() - 1}")

# Enter bfloat16 autocast for the remainder of the process; the context is
# deliberately never exited.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

# Enable TF32 kernels only when the device that actually runs the models
# reports compute capability 8.x or newer (query DEVICE, not device 0).
if torch.cuda.get_device_properties(DEVICE).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
SAM_VIDEO_MODEL = load_sam_video_model(device=DEVICE)
31
+
VIDEO_SCALE_FACTOR = 1
VIDEO_TARGET_DIRECTORY = "/kaggle/"
create_directory(directory_path=VIDEO_TARGET_DIRECTORY)

# NOTE(security): pickle.load can execute arbitrary code embedded in the
# file. Only feed this script pickles written by a trusted earlier stage.
with open('/kaggle/output_video2.pkl', 'rb') as file:
    output_video = pickle.load(file)
print(output_video)
video_input = output_video

# Take the first frame of the video and convert it from OpenCV's BGR layout
# to an RGB PIL image for the SAM image model.
frame_generator = sv.get_video_frames_generator(video_input)
frame = next(frame_generator)
frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

# NOTE(security): same pickle caveat as above.
with open('/kaggle/detections_list2.pkl', 'rb') as file:
    detections_list = pickle.load(file)
print(detections_list)

# Merge the per-prompt detections and let SAM attach masks to them.
detections = sv.Detections.merge(detections_list)
detections = run_sam_inference(SAM_IMAGE_MODEL, frame, detections)

# Without at least one first-frame mask there is nothing to track, so stop
# here instead of splitting the video and failing later in propagation.
if detections.mask is None or len(detections.mask) == 0:
    raise SystemExit(
        "No objects found in the first frame of the video. "
        "Trim the video to make the object appear in the first frame or try a "
        "different text prompt."
    )
58
+
59
+
# Write every frame of the input video as a numbered JPEG into a directory
# named by generate_unique_name(), then initialise the SAM video model's
# inference state from that directory.
name = generate_unique_name()
frame_directory_path = os.path.join(VIDEO_TARGET_DIRECTORY, name)
frames_sink = sv.ImageSink(
    target_dir_path=frame_directory_path,
    image_name_pattern="{:05d}.jpeg",
)

# Output dimensions follow the same scale factor applied to each frame below.
video_info = sv.VideoInfo.from_video_path(video_input)
video_info.width = int(video_info.width * VIDEO_SCALE_FACTOR)
video_info.height = int(video_info.height * VIDEO_SCALE_FACTOR)

frames_generator = sv.get_video_frames_generator(video_input)
frames_with_progress = tqdm(
    frames_generator,
    total=video_info.total_frames,
    desc="splitting video into frames",
)
with frames_sink:
    for frame in frames_with_progress:
        frame = sv.scale_image(frame, VIDEO_SCALE_FACTOR)
        frames_sink.save_image(frame)

inference_state = SAM_VIDEO_MODEL.init_state(
    video_path=frame_directory_path,
    device=DEVICE,
)
85
+
86
+
87
+
88
+
# Register each first-frame mask with the video model as a separate object;
# a mask's position in detections.mask is used as its object id.
for mask_index, mask in enumerate(detections.mask):
    seed_kwargs = dict(
        inference_state=inference_state,
        frame_idx=0,
        obj_id=mask_index,
        mask=mask,
    )
    _, object_ids, mask_logits = SAM_VIDEO_MODEL.add_new_mask(**seed_kwargs)
96
+
# Propagate the seeded masks through the video and render every frame with
# the tracked objects at full brightness on a dimmed background.
video_path = os.path.join(VIDEO_TARGET_DIRECTORY, f"{name}.mp4")
frames_generator = sv.get_video_frames_generator(video_input)
masks_generator = SAM_VIDEO_MODEL.propagate_in_video(inference_state)

# Start from an empty output directory so frames from an earlier run
# cannot linger next to the new ones.
out_dir = "/kaggle/output2"
shutil.rmtree(out_dir, ignore_errors=True)
os.makedirs(out_dir, exist_ok=True)

# Loop-invariant structuring elements: a small erosion removes speckle,
# then a larger dilation grows the mask slightly past the object edge.
erode_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
dilate_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (10, 10))

try:
    with sv.VideoSink(video_path, video_info=video_info) as sink:
        frame_mask_pairs = zip(frames_generator, masks_generator)
        for n, (frame, (_, _, mask_logits)) in enumerate(frame_mask_pairs, start=1):
            frame = sv.scale_image(frame, VIDEO_SCALE_FACTOR)

            # Positive logits mark object pixels.
            # assumes mask_logits is (num_objects, 1, H, W) — TODO confirm
            masks = (mask_logits > 0.0).cpu().numpy().astype(bool)

            # Union of all object masks as a single-channel uint8 image
            # (0 or 255), cleaned up with erosion followed by dilation.
            all_mask = (np.logical_or.reduce(masks, axis=0) * 255).astype(np.uint8)
            all_mask = all_mask.squeeze()
            all_mask = cv2.erode(all_mask, erode_kernel, iterations=1)
            all_mask = cv2.dilate(all_mask, dilate_kernel, iterations=1)

            # Background: darken the frame to 10% brightness, but leave
            # pixels whose three channels are all >= 130 untouched.
            dimmed_frame = cv2.addWeighted(frame, 0.1, np.zeros_like(frame), 0.9, 0)
            bright_pixels = (frame >= 130).all(axis=-1)
            background = np.where(bright_pixels[:, :, None], frame, dimmed_frame)

            # Foreground: original pixels wherever the combined mask is set.
            frame = np.where(all_mask[:, :, None] > 0, frame, background)

            cv2.imwrite(os.path.join(out_dir, f"{n}.jpeg"), frame)
            # Also append the frame to the mp4 so the sink does not leave
            # an empty video file behind.
            sink.write_frame(frame)
finally:
    # Remove the temporary per-frame JPEG directory even if rendering fails.
    delete_directory(frame_directory_path)