Spaces:
Sleeping
Sleeping
File size: 6,107 Bytes
6c0d568 a4f8a15 6c0d568 1eecd17 eee8ee3 6c0d568 eee8ee3 6c0d568 eee8ee3 ce91e93 eee8ee3 aafe80a 6c0d568 eee8ee3 6c0d568 6eee3ae 5010485 6c0d568 eee8ee3 a37e840 6c0d568 01bbb6c 6c0d568 eee8ee3 6c0d568 eee8ee3 2556971 eee8ee3 2556971 eee8ee3 01bbb6c eee8ee3 6c0d568 eee8ee3 6c0d568 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import gradio as gr
import os
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
def preprocess_image(image):
    """Reset the point-selection state when a new image is uploaded.

    This is the handler of ``input_image.upload``, whose outputs are, in
    order: ``first_frame_path``, ``tracking_points``,
    ``trackings_input_label`` and ``points_map``. One value must therefore be
    returned for each of those four outputs.

    Args:
        image: The uploaded image as delivered by the ``input_image``
            component (configured with ``type="filepath"``).

    Returns:
        A 4-tuple ``(image, points_state, labels_state, image)``: the image
        is stored as the first-frame path and also shown in the points map,
        while both selection states are reset to empty lists.
    """
    # The selections are wrapped in fresh gr.State objects because get_point
    # reads and mutates them through their ``.value`` attribute.
    # The image is returned a second time so that ``points_map`` displays it
    # and the user has something to click on.
    return image, gr.State([]), gr.State([]), image
def get_point(tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
    """Record a clicked point and redraw every selected point on the image.

    Args:
        tracking_points: Object whose ``.value`` list collects clicked points.
        trackings_input_label: Object whose ``.value`` list collects one
            label per point; every click is stored with label ``1``.
        first_frame_path: Path of the image the markers are drawn over.
        evt: Gradio selection event; ``evt.index`` is appended as the point.

    Returns:
        ``(tracking_points, trackings_input_label, composite)`` where
        ``composite`` is the RGBA image with a filled marker on each point.
    """
    print(f"You selected {evt.value} at {evt.index} from {evt.target}")

    clicked = tracking_points.value
    clicked.append(evt.index)
    print(f"TRACKING POINT: {tracking_points.value}")

    point_labels = trackings_input_label.value
    point_labels.append(1)
    print(f"TRACKING INPUT LABEL: {trackings_input_label.value}")

    base = Image.open(first_frame_path).convert('RGBA')
    width, height = base.size

    # Draw all markers on a fully transparent canvas of the same size, then
    # composite that canvas over the original picture.
    canvas = np.zeros((height, width, 4))
    for point in clicked:
        cv2.circle(canvas, point, 5, (255, 0, 0, 255), -1)
    overlay = Image.fromarray(canvas.astype(np.uint8))

    composite = Image.alpha_composite(base, overlay)
    return tracking_points, trackings_input_label, composite
# Enter a CUDA bfloat16 autocast context at import time. The context manager
# is entered manually and never exited.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

# Allow TF32 kernels when device 0 reports a major compute capability >= 8
# (Ampere GPUs, see
# https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices).
device_props = torch.cuda.get_device_properties(0)
if device_props.major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
def show_mask(mask, ax, random_color=False, borders=True):
    """Render a mask as a translucent colored overlay on ``ax``.

    Args:
        mask: Array whose last two dimensions give the overlay height and
            width; it is cast to ``uint8`` before use.
        ax: Object exposing ``imshow`` (e.g. a matplotlib axes).
        random_color: Use a random RGB color instead of the fixed blue;
            alpha is 0.6 in both cases.
        borders: Also draw the external contours of the mask on the overlay.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])

    height, width = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    # Broadcast (h, w, 1) * (1, 1, 4) -> (h, w, 4): color where mask is set.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)

    if borders:
        import cv2
        found, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        smoothed = []
        for contour in found:
            smoothed.append(cv2.approxPolyDP(contour, epsilon=0.01, closed=True))
        overlay = cv2.drawContours(overlay, smoothed, -1, (1, 1, 1, 0.5), thickness=2)

    ax.imshow(overlay)
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 in green, label 0 in red.

    Args:
        coords: 2-D array indexed as ``coords[:, 0]`` (x) and ``coords[:, 1]`` (y).
        labels: Array compared element-wise against 1 and 0 to select rows.
        ax: Object exposing ``scatter`` (e.g. a matplotlib axes).
        marker_size: Value passed as the ``s`` argument of ``scatter``.
    """
    # Positive points are drawn first, then negative ones.
    for label_value, color in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=color,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )
def show_box(box, ax):
    """Draw ``box`` as an unfilled green rectangle on ``ax``.

    Args:
        box: Indexable as ``(x0, y0, x1, y1)``; width and height are derived
            as ``x1 - x0`` and ``y1 - y0``.
        ax: Object exposing ``add_patch`` (e.g. a matplotlib axes).
    """
    origin = (box[0], box[1])
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle(origin, width, height, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(outline)
def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
    """Render each mask over ``image`` and save one JPG per mask.

    Args:
        image: Image shown as the background of every figure.
        masks: Iterable of masks, paired element-wise with ``scores``.
        scores: Per-mask scores; shown in the title when there is more than one.
        point_coords: Optional prompt points to draw; requires ``input_labels``.
        box_coords: Optional box to draw.
        input_labels: Labels matching ``point_coords``.
        borders: Forwarded to ``show_mask``.

    Returns:
        List of written file names, ``masked_image_<k>.jpg`` with ``k``
        starting at 1, relative to the current working directory.
    """
    saved_files = []
    for rank, (mask, score) in enumerate(zip(masks, scores), start=1):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        axes = plt.gca()

        show_mask(mask, axes, borders=borders)
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, axes)
        if box_coords is not None:
            # boxes
            show_box(box_coords, axes)
        if len(scores) > 1:
            plt.title(f"Mask {rank}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')

        # Write the figure to disk instead of displaying it.
        out_name = f"masked_image_{rank}.jpg"
        plt.savefig(out_name, format='jpg', bbox_inches='tight')
        saved_files.append(out_name)

        # Close the figure to free up memory
        plt.close()
    return saved_files
def sam_process(input_image, tracking_points, trackings_input_label):
    """Segment the uploaded image with SAM2, prompted by the clicked points.

    Args:
        input_image: Filesystem path of the image to segment.
        tracking_points: Selected points, either as a plain list or as an
            object exposing the list through ``.value``.
        trackings_input_label: Labels matching ``tracking_points``, in the
            same plain-list or ``.value`` form.

    Returns:
        The list of image file names produced by ``show_masks``.

    Raises:
        gr.Error: If no image was provided or no point was selected.
    """
    # Unwrap the selections whether they arrive wrapped (``.value``) or as
    # plain lists, so a state that was never reset by an upload also works.
    points = getattr(tracking_points, "value", tracking_points)
    labels = getattr(trackings_input_label, "value", trackings_input_label)

    # Fail early with a readable message rather than an opaque error from
    # PIL or the predictor.
    if not input_image:
        raise gr.Error("Please upload an image first.")
    if not points:
        raise gr.Error("Please select at least one point on the points map.")

    # Close the file handle as soon as the pixels are copied into the array.
    with Image.open(input_image) as pil_image:
        image = np.array(pil_image.convert("RGB"))

    # NOTE(review): the model is rebuilt on every call; consider caching it
    # if load time matters and the GPU memory can stay allocated.
    sam2_checkpoint = "./checkpoints/sam2_hiera_tiny.pt"
    model_cfg = "sam2_hiera_t.yaml"
    sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
    predictor = SAM2ImagePredictor(sam2_model)
    predictor.set_image(image)

    input_point = np.array(points)
    input_label = np.array(labels)
    print(predictor._features["image_embed"].shape, predictor._features["image_embed"][-1].shape)

    # The logits returned by predict are not used afterwards.
    masks, scores, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )

    # Order the masks from highest to lowest score.
    sorted_ind = np.argsort(scores)[::-1]
    masks = masks[sorted_ind]
    scores = scores[sorted_ind]
    print(masks.shape)

    results = show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label, borders=True)
    print(results)
    return results
# Gradio UI: upload an image, click points on its copy, then run SAM2.
# NOTE(review): the original indentation was lost, so the nesting of the
# Row/Column containers below is reconstructed from statement order —
# confirm against the intended layout.
with gr.Blocks() as demo:
    # Per-session state: image path, clicked points and their labels.
    first_frame_path = gr.State()
    tracking_points = gr.State([])
    trackings_input_label = gr.State([])

    with gr.Column():
        gr.Markdown("# SAM2 Image Predictor")
        with gr.Row():
            input_image = gr.Image(label="input image", interactive=True, type="filepath")
            with gr.Column():
                points_map = gr.Image(label="points map", interactive=False)
                submit_btn = gr.Button("Submit")
            output_result = gr.Gallery()

    # A new upload stores the image path, resets the selections and fills the
    # points map (four outputs, in this order).
    input_image.upload(preprocess_image, input_image, [first_frame_path, tracking_points, trackings_input_label, points_map])

    # Each click on the points map records a point and redraws the markers.
    points_map.select(get_point, [tracking_points, trackings_input_label, first_frame_path], [tracking_points, trackings_input_label, points_map])

    # Submit runs the segmentation and shows the saved images in the gallery.
    submit_btn.click(
        fn = sam_process,
        inputs = [input_image, tracking_points, trackings_input_label],
        outputs = [output_result]
    )

demo.launch()