# florence-sam-tencent / sam2-colab.py
#%cd /content/florence-sam
import os
import pickle
import shutil
import sys

import cv2
import numpy as np
import supervision as sv
import torch
from PIL import Image
from tqdm import tqdm

os.chdir("/content/florence-sam")
sys.path.append('/content/florence-sam')
from utils.video import generate_unique_name, create_directory, delete_directory
from utils.sam import load_sam_image_model, run_sam_inference, load_sam_video_model
# Pick the inference device: the first CUDA GPU. (The original also sketched
# last-GPU and CPU variants, kept below for reference.)
DEVICE = torch.device("cuda:0")
# DEVICE = [torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())][-1]
# DEVICE = torch.device("cpu")
# Enter bfloat16 autocast once for the whole script (intentionally never exited).
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    # Ampere (compute capability 8.x) or newer: enable TF32 matmuls/convolutions.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
SAM_VIDEO_MODEL = load_sam_video_model(device=DEVICE)
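# The image model refines first-frame box detections into masks; the video
# model then propagates those masks through the remaining frames.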
VIDEO_SCALE_FACTOR = 1
VIDEO_TARGET_DIRECTORY = "/content/"
create_directory(directory_path=VIDEO_TARGET_DIRECTORY)
# The input video path is produced by an earlier stage of the pipeline and
# handed over via pickle.
with open('/content/output_video.pkl', 'rb') as file:
    output_video = pickle.load(file)
print(output_video)
video_input = output_video
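# Grab the first frame: supervision yields BGR numpy frames, while the image
# SAM wrapper expects an RGB PIL image.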
frame_generator = sv.get_video_frames_generator(video_input)
frame = next(frame_generator)
frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
with open('/content/detections_list.pkl', 'rb') as file:
    detections_list = pickle.load(file)
print(detections_list)
# Merge the per-prompt detections, then refine the merged boxes into masks
# with the image-level SAM model.
detections = sv.Detections.merge(detections_list)
detections = run_sam_inference(SAM_IMAGE_MODEL, frame, detections)
if detections.mask is None or len(detections.mask) == 0:
    print(
        "No objects found in the first frame of the video. "
        "Trim the video so the object appears in the first frame, or try a "
        "different text prompt."
    )
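# Note that execution falls through even when nothing was detected; a stricter
# guard (an assumption, not part of the original flow) could abort here instead:
# raise SystemExit("no first-frame masks; aborting video propagation")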
name = generate_unique_name()
frame_directory_path = os.path.join(VIDEO_TARGET_DIRECTORY, name)
frames_sink = sv.ImageSink(
    target_dir_path=frame_directory_path,
    image_name_pattern="{:05d}.jpeg"
)
video_info = sv.VideoInfo.from_video_path(video_input)
video_info.width = int(video_info.width * VIDEO_SCALE_FACTOR)
video_info.height = int(video_info.height * VIDEO_SCALE_FACTOR)
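# SAM2's video predictor reads its input as a directory of numbered JPEG
# frames, so the video is split to disk first.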
frames_generator = sv.get_video_frames_generator(video_input)
with frames_sink:
    for frame in tqdm(
        frames_generator,
        total=video_info.total_frames,
        desc="splitting video into frames"
    ):
        frame = sv.scale_image(frame, VIDEO_SCALE_FACTOR)
        frames_sink.save_image(frame)
inference_state = SAM_VIDEO_MODEL.init_state(
    video_path=frame_directory_path,
    device=DEVICE
)
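# init_state loads the extracted frames and prepares the per-video tracking
# state that propagate_in_video consumes below.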
# Seed the video predictor with each first-frame mask as a separate object.
for mask_index, mask in enumerate(detections.mask):
    _, object_ids, mask_logits = SAM_VIDEO_MODEL.add_new_mask(
        inference_state=inference_state,
        frame_idx=0,
        obj_id=mask_index,
        mask=mask
    )
video_path = os.path.join(VIDEO_TARGET_DIRECTORY, f"{name}.mp4")
frames_generator = sv.get_video_frames_generator(video_input)
masks_generator = SAM_VIDEO_MODEL.propagate_in_video(inference_state)
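# propagate_in_video yields (frame_idx, object_ids, mask_logits) per frame;
# it is consumed in lock-step with the raw frames in the loop below.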
n = 1
# Start with a fresh directory for the composited output frames.
out_dir = "/content/output"
shutil.rmtree(out_dir, ignore_errors=True)
os.makedirs(out_dir, exist_ok=True)
with sv.VideoSink(video_path, video_info=video_info) as sink:
    for frame, (_, tracker_ids, mask_logits) in zip(frames_generator, masks_generator):
        frame = sv.scale_image(frame, VIDEO_SCALE_FACTOR)
        # Threshold the logits into boolean masks, one per tracked object.
        masks = (mask_logits > 0.0).cpu().numpy().astype(bool)
        # Merge the per-object boolean masks into a single mask, then convert
        # it to uint8 (0/255) so OpenCV can use it.
        all_mask_bool = np.logical_or.reduce(masks, axis=0)
        all_mask_uint8 = (all_mask_bool * 255).astype(np.uint8)
        all_mask = all_mask_uint8.squeeze()
        # Erode with a small kernel to strip ragged mask edges...
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
        eroded_mask = cv2.erode(all_mask, kernel, iterations=1)
        all_mask = eroded_mask
        # ...then dilate with a larger kernel to grow the mask back out.
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (10, 10))
        dilated_mask = cv2.dilate(all_mask, kernel, iterations=1)
        all_mask = dilated_mask
        # Keep only the pixels inside the combined mask.
        masked_image = cv2.bitwise_and(frame, frame, mask=all_mask)
        # Dim the whole frame to ~10% brightness over a black background.
        black_background = np.zeros_like(frame)
        transparent_frame = cv2.addWeighted(frame, 0.1, black_background, 0.9, 0)
        # Pixels that are bright in every channel (>= 130) keep full brightness.
        white_mask = (frame >= 130).all(axis=-1)
        frame = np.where(white_mask[:, :, None], frame, transparent_frame)
        # Composite the masked (tracked) region back on top of the dimmed frame.
        frame = np.where(all_mask[:, :, None] > 0, masked_image, frame)
        # Write the composited frame to the output video and save it as a JPEG.
        sink.write_frame(frame)
        cv2.imwrite(os.path.join(out_dir, f"{n}.jpeg"), frame)
        n = n + 1
# Remove the temporary directory of extracted frames.
delete_directory(frame_directory_path)
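# Outputs land in VIDEO_TARGET_DIRECTORY ({name}.mp4) and /content/output
# (numbered JPEGs). A minimal, commented-out sketch for re-encoding the JPEG
# frames with ffmpeg, assuming the 1-based "{n}.jpeg" naming used above:
# import subprocess
# subprocess.run([
#     "ffmpeg", "-y", "-framerate", str(video_info.fps),
#     "-start_number", "1", "-i", "/content/output/%d.jpeg",
#     "-c:v", "libx264", "-pix_fmt", "yuv420p",
#     "/content/output_reencoded.mp4",
# ], check=True)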