Spaces:
Runtime error
Runtime error
File size: 5,670 Bytes
dba3ac4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
#%cd /content/florence-sam
import os
from typing import Tuple, Optional
import shutil
import os
import cv2
import numpy as np
import spaces
import supervision as sv
import torch
from PIL import Image
from tqdm import tqdm
import sys
import json
import pickle
# The project-local `utils` package is imported right below, so switch into
# the repository checkout and put it on the module search path first.
_REPO_ROOT = "/content/florence-sam"
os.chdir(_REPO_ROOT)
sys.path.append(_REPO_ROOT)
from utils.video import create_directory, delete_directory, generate_unique_name
from utils.sam import load_sam_image_model, load_sam_video_model, run_sam_inference
# --- Runtime setup: device selection, numeric modes, models, constants ---

# The pipeline runs on the first CUDA device. Fail with an explicit message
# instead of an opaque IndexError when no GPU is visible.
if not torch.cuda.is_available() or torch.cuda.device_count() == 0:
    raise RuntimeError("A CUDA-capable GPU is required to run this script.")
DEVICE = torch.device("cuda:0")
# DEVICE = torch.device("cpu")

# Enter bfloat16 autocast for CUDA ops. The context is entered by hand and
# never exited in this script, so it stays active until the process ends.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

# Allow TF32 matmul/cuDNN kernels when the selected device reports a compute
# capability major version of 8 or higher.
if torch.cuda.get_device_properties(DEVICE).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Image model segments the first frame; video model propagates the masks.
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
SAM_VIDEO_MODEL = load_sam_video_model(device=DEVICE)

# Frames are resized by this factor before being stored and processed.
VIDEO_SCALE_FACTOR = 1
# Root directory for the scratch frame folder and the rendered video.
VIDEO_TARGET_DIRECTORY = "/content/"
create_directory(directory_path=VIDEO_TARGET_DIRECTORY)
# --- Load the inputs produced by an earlier step and segment frame 0 ---

# NOTE(review): pickle.load can execute arbitrary code embedded in the file.
# Only point these paths at pickles produced by a trusted earlier step.
with open('/content/output_video.pkl', 'rb') as pkl_file:
    output_video = pickle.load(pkl_file)
print(output_video)
video_input = output_video

# Grab the first frame of the video; the generator yields BGR arrays, which
# are converted to an RGB PIL image for the SAM image model.
frame_generator = sv.get_video_frames_generator(video_input)
frame = next(frame_generator, None)
if frame is None:
    # An empty or unreadable video would otherwise surface as a bare
    # StopIteration with no context.
    sys.exit(f"Could not read any frame from video: {video_input}")
frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

with open('/content/detections_list.pkl', 'rb') as pkl_file:
    detections_list = pickle.load(pkl_file)
print(detections_list)

# Merge the stored detections into one object and let SAM compute a mask for
# each of them on the first frame.
detections = sv.Detections.merge(detections_list)
detections = run_sam_inference(SAM_IMAGE_MODEL, frame, detections)

# Without at least one first-frame mask there is nothing to track, so stop
# here instead of continuing into the video-propagation stage.
if detections.mask is None or len(detections.mask) == 0:
    sys.exit(
        "No objects found in the first frame of the video. "
        "Trim the video to make the object appear in the first frame or try a "
        "different text prompt."
    )
# --- Dump the (scaled) video frames into a uniquely named scratch folder ---

# The unique name is shared by the scratch folder and the output video file.
name = generate_unique_name()
frame_directory_path = os.path.join(VIDEO_TARGET_DIRECTORY, name)

# Keep the video metadata in sync with the scaled frame size.
video_info = sv.VideoInfo.from_video_path(video_input)
video_info.width, video_info.height = (
    int(video_info.width * VIDEO_SCALE_FACTOR),
    int(video_info.height * VIDEO_SCALE_FACTOR),
)

frames_generator = sv.get_video_frames_generator(video_input)
with sv.ImageSink(
    target_dir_path=frame_directory_path,
    image_name_pattern="{:05d}.jpeg",
) as frames_sink:
    progress = tqdm(
        frames_generator,
        total=video_info.total_frames,
        desc="splitting video into frames",
    )
    for raw_frame in progress:
        # Each frame is resized first, then written as a numbered JPEG.
        frames_sink.save_image(sv.scale_image(raw_frame, VIDEO_SCALE_FACTOR))
# --- Initialise the SAM video model and seed it with the first-frame masks ---

# The tracker state is built over the folder of extracted frames.
inference_state = SAM_VIDEO_MODEL.init_state(
    video_path=frame_directory_path, device=DEVICE
)

# Register every first-frame mask on frame 0; a mask's position in
# `detections.mask` doubles as its object id.
for object_id, object_mask in enumerate(detections.mask):
    SAM_VIDEO_MODEL.add_new_mask(
        inference_state=inference_state,
        frame_idx=0,
        obj_id=object_id,
        mask=object_mask,
    )
# --- Propagate the masks through the video and render the result ---

video_path = os.path.join(VIDEO_TARGET_DIRECTORY, f"{name}.mp4")
out_dir = "/content/output"
# Start from an empty output directory so JPEGs from an earlier run do not
# get mixed with the frames written below.
shutil.rmtree(out_dir, ignore_errors=True)
os.makedirs(out_dir, exist_ok=True)

# Structuring elements do not depend on the frame, so build them once.
ERODE_KERNEL = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
DILATE_KERNEL = cv2.getStructuringElement(cv2.MORPH_RECT, (10, 10))
# Pixels whose every channel is at least this value keep full brightness.
BRIGHT_PIXEL_THRESHOLD = 130
# Weight applied to the remaining pixels (blended against black).
BACKGROUND_DIM_WEIGHT = 0.1


def _spotlight_tracked_objects(frame, mask_logits):
    """Return `frame` with everything outside the tracked objects dimmed.

    Args:
        frame: Scaled video frame as a uint8 image array.
        mask_logits: Per-object mask logits yielded by the video model for
            this frame; values above 0 count as foreground.

    Returns:
        An array shaped like `frame`. Pixels inside the combined object mask,
        and pixels whose every channel is >= BRIGHT_PIXEL_THRESHOLD, keep
        their original value; all other pixels are scaled down to
        BACKGROUND_DIM_WEIGHT of their brightness.
    """
    masks = (mask_logits > 0.0).cpu().numpy().astype(bool)
    # Union of all per-object masks, converted to a 0/255 uint8 image.
    # assumes the union squeezes to a 2-D (height, width) array — TODO confirm
    region = (np.logical_or.reduce(masks, axis=0) * 255).astype(np.uint8).squeeze()
    # Erode with the 3x3 kernel, then dilate with the 10x10 kernel.
    region = cv2.erode(region, ERODE_KERNEL, iterations=1)
    region = cv2.dilate(region, DILATE_KERNEL, iterations=1)

    # Blend the frame against black to darken it.
    dimmed = cv2.addWeighted(
        frame,
        BACKGROUND_DIM_WEIGHT,
        np.zeros_like(frame),
        1 - BACKGROUND_DIM_WEIGHT,
        0,
    )
    bright = (frame >= BRIGHT_PIXEL_THRESHOLD).all(axis=-1)
    keep_original = bright | (region > 0)
    return np.where(keep_original[:, :, None], frame, dimmed)


frames_generator = sv.get_video_frames_generator(video_input)
masks_generator = SAM_VIDEO_MODEL.propagate_in_video(inference_state)
try:
    with sv.VideoSink(video_path, video_info=video_info) as sink:
        for n, (frame, (_, _, mask_logits)) in enumerate(
            zip(frames_generator, masks_generator), start=1
        ):
            frame = sv.scale_image(frame, VIDEO_SCALE_FACTOR)
            frame = _spotlight_tracked_objects(frame, mask_logits)
            # Frames are stored as 1.jpeg, 2.jpeg, ... in the output directory.
            cv2.imwrite(os.path.join(out_dir, f"{n}.jpeg"), frame)
            # Also append the frame to the video; the sink was previously
            # opened but never written to, which left an empty file.
            sink.write_frame(frame)
finally:
    # Remove the scratch frame directory even if rendering fails part-way.
    delete_directory(frame_directory_path)