# Scraped from a Hugging Face Space page (Space status shown as "Runtime error").
# Notebook equivalent of the os.chdir below: %cd /content/florence-sam
"""Track pickled first-frame detections through a video with SAM.

Loads a video path and a list of first-frame detections from pickle files
under /content, segments the detections on the first frame with the SAM image
model, propagates the masks through the video with the SAM video model, and
writes one JPEG per frame to /content/output in which everything outside the
tracked objects is darkened (bright pixels are kept).
"""
import json
import os
import pickle
import shutil
import sys
from typing import Optional, Tuple

import cv2
import numpy as np
import spaces
import supervision as sv
import torch
from PIL import Image
from tqdm import tqdm

# Root of the florence-sam checkout. chdir mirrors the notebook's `%cd`, and
# the sys.path entry makes the project's `utils` package importable below.
PROJECT_DIR = "/content/florence-sam"
os.chdir(PROJECT_DIR)
sys.path.append(PROJECT_DIR)

from utils.video import generate_unique_name, create_directory, delete_directory  # noqa: E402
from utils.sam import load_sam_image_model, run_sam_inference, load_sam_video_model  # noqa: E402
# Everything runs on the first CUDA device. (The previous version assigned
# DEVICE three times; only the final `cuda:0` assignment ever took effect, and
# it failed with a bare IndexError when no GPU was present.) CPU execution is
# not supported: the autocast context below is CUDA-only.
if torch.cuda.device_count() == 0:
    raise RuntimeError("This script requires a CUDA device, but none is available.")
DEVICE = torch.device("cuda:0")

# Enter a bfloat16 autocast context for the rest of the run and never exit it,
# so subsequent model calls in this thread execute under bf16 autocast.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

# Allow TF32 matmuls/convolutions on GPUs with compute capability >= 8.
if torch.cuda.get_device_properties(DEVICE).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Image model segments the first frame; video model propagates the masks.
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
SAM_VIDEO_MODEL = load_sam_video_model(device=DEVICE)
# Factor applied to every frame before tracking and rendering (1 = native size).
VIDEO_SCALE_FACTOR = 1
# Parent directory for the temporary per-run frame folder and the output video.
VIDEO_TARGET_DIRECTORY = "/content/"
create_directory(directory_path=VIDEO_TARGET_DIRECTORY)

# NOTE(security): pickle.load can execute arbitrary code. These files are
# presumably produced by an earlier step of this pipeline; never point them at
# data from an untrusted source.
with open('/content/output_video.pkl', 'rb') as file:
    output_video = pickle.load(file)
print(output_video)
# Path/source of the video to process, as accepted by supervision's readers.
video_input = output_video

# Read only the first frame: SAM is prompted on frame 0 and tracking starts there.
frame_generator = sv.get_video_frames_generator(video_input)
frame = next(frame_generator, None)
if frame is None:
    # An empty/unreadable video previously escaped as a bare StopIteration.
    raise ValueError(f"Could not read any frame from video {video_input!r}.")
# supervision yields BGR arrays (OpenCV convention); PIL expects RGB.
frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

# NOTE(security): same pickle caveat as above.
with open('/content/detections_list.pkl', 'rb') as file:
    detections_list = pickle.load(file)
print(detections_list)
# Combine the per-prompt detections into one Detections object, then let the
# SAM image model produce masks for them on the first frame
# (presumably filling `detections.mask` — see utils.sam.run_sam_inference).
detections = sv.Detections.merge(detections_list)
detections = run_sam_inference(SAM_IMAGE_MODEL, frame, detections)
# Stop early when the first frame produced no masks: nothing would be handed to
# the video tracker below, so there is nothing to propagate. The old message
# printed a literal "{text_input}" (not an f-string, undefined name) and then
# carried on regardless.
if detections.mask is None or len(detections.mask) == 0:
    print(
        "No objects found in the first frame of the video. "
        "Trim the video to make the object appear in the first frame or try a "
        "different text prompt."
    )
    sys.exit(1)
# Dump every (rescaled) frame as a zero-padded JPEG into a fresh per-run
# directory; the SAM video model is then initialised from that directory.
name = generate_unique_name()
frame_directory_path = os.path.join(VIDEO_TARGET_DIRECTORY, name)
frames_sink = sv.ImageSink(
    target_dir_path=frame_directory_path,
    image_name_pattern="{:05d}.jpeg",
)

# Output video metadata must describe the rescaled frames, not the source.
video_info = sv.VideoInfo.from_video_path(video_input)
video_info.width, video_info.height = (
    int(video_info.width * VIDEO_SCALE_FACTOR),
    int(video_info.height * VIDEO_SCALE_FACTOR),
)

with frames_sink:
    frame_iter = tqdm(
        sv.get_video_frames_generator(video_input),
        total=video_info.total_frames,
        desc="splitting video into frames",
    )
    for raw_frame in frame_iter:
        frames_sink.save_image(sv.scale_image(raw_frame, VIDEO_SCALE_FACTOR))

inference_state = SAM_VIDEO_MODEL.init_state(
    video_path=frame_directory_path,
    device=DEVICE,
)
# Seed the tracker on frame 0: each first-frame mask becomes its own tracked
# object, whose id is the mask's position within detections.mask.
for obj_index, seed_mask in enumerate(detections.mask):
    _, object_ids, mask_logits = SAM_VIDEO_MODEL.add_new_mask(
        inference_state=inference_state,
        frame_idx=0,
        obj_id=obj_index,
        mask=seed_mask,
    )
# Destination of the rendered video.
video_path = os.path.join(VIDEO_TARGET_DIRECTORY, f"{name}.mp4")
# Fresh generators for the rendering pass: original frames and the tracker's
# per-frame mask logits, consumed in lockstep by the render loop.
frames_generator = sv.get_video_frames_generator(video_input)
masks_generator = SAM_VIDEO_MODEL.propagate_in_video(inference_state)
# 1-based index used to name the next JPEG written to out_dir.
n = 1
out_dir = "/content/output"
# Start from an empty output directory so frames from a previous run don't mix
# in. (The former makedirs-before-rmtree and the repeated `import shutil` were
# redundant; shutil is imported at the top of the file.)
shutil.rmtree(out_dir, ignore_errors=True)
os.makedirs(out_dir, exist_ok=True)
# Morphology kernels for cleaning the combined mask; they don't depend on the
# frame, so build them once instead of on every iteration.
erode_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
dilate_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (10, 10))

try:
    with sv.VideoSink(video_path, video_info=video_info) as sink:
        for frame, (_, _, mask_logits) in zip(frames_generator, masks_generator):
            frame = sv.scale_image(frame, VIDEO_SCALE_FACTOR)

            # Union of all tracked objects (logit > 0 => inside the object).
            # Assumes mask_logits is (num_objects, 1, H, W), so reducing over
            # objects and squeezing yields an (H, W) mask — TODO confirm.
            masks = (mask_logits > 0.0).cpu().numpy().astype(bool)
            all_mask_bool = np.logical_or.reduce(masks, axis=0)
            all_mask = (all_mask_bool * 255).astype(np.uint8).squeeze()

            # Erode with a small kernel (presumably to drop thin speckle), then
            # dilate with a larger one so the kept region grows past the object.
            all_mask = cv2.erode(all_mask, erode_kernel, iterations=1)
            all_mask = cv2.dilate(all_mask, dilate_kernel, iterations=1)

            # Original pixels inside the cleaned mask, zero elsewhere.
            masked_image = cv2.bitwise_and(frame, frame, mask=all_mask)

            # Darken the whole frame to 10% brightness ...
            black_background = np.zeros_like(frame)
            transparent_frame = cv2.addWeighted(frame, 0.1, black_background, 0.9, 0)
            # ... but keep bright pixels (every channel >= 130) untouched.
            white_mask = (frame >= 130).all(axis=-1)
            frame = np.where(white_mask[:, :, None], frame, transparent_frame)
            # Paste the original pixels inside the object mask back on top.
            frame = np.where(all_mask[:, :, None] > 0, masked_image, frame)

            # frame is BGR, as both cv2.imwrite and VideoSink expect.
            out_path = os.path.join(out_dir, f"{n}.jpeg")
            if not cv2.imwrite(out_path, frame):
                # cv2.imwrite reports failure via its return value, not an exception.
                raise OSError(f"Failed to write frame to {out_path}")
            # Previously the sink was opened but never written, leaving an
            # empty video file at video_path.
            sink.write_frame(frame)
            n = n + 1
finally:
    # Remove the temporary extracted-frames directory even if rendering fails.
    delete_directory(frame_directory_path)