import copy
import sys

import cv2
import gradio as gr
import numpy as np
import torch
from gradio_image_prompter import ImagePrompter

sys.path.append("./")
from models import sam_model_registry
from models.grasp_mods import add_inference_method, modify_forward
from models.utils.transforms import ResizeLongestSide
from structures.grasp_box import GraspCoder

img_resize = ResizeLongestSide(1024)
device = "cuda" if torch.cuda.is_available() else "cpu"
model_type = "vit_b"
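
# Per-channel pixel mean / std used to normalize the input image before it goes into the model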
mean = np.array([103.53, 116.28, 123.675])[:, np.newaxis, np.newaxis]
std = np.array([57.375, 57.12, 58.395])[:, np.newaxis, np.newaxis]
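
# Build the SAM backbone and attach the grasp-detection forward / inference methods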
sam = sam_model_registry[model_type]()
sam.to(device=device)
sam.forward = modify_forward(sam)
sam.infer = add_inference_method(sam)
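
# Load fine-tuned grasp-detection weights from the checkpoint, if one is provided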
pretrained_model_path = "./epoch_39_step_415131.pth"
if pretrained_model_path != "":
    sd = torch.load(pretrained_model_path, map_location="cpu")
    # strip prefix "module." from keys
    new_sd = {}
    for k, v in sd.items():
        if k.startswith("module."):
            k = k[7:]
        new_sd[k] = v
    sam.load_state_dict(new_sd)
sam.eval()
def predict(input, topk):
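    """Predict grasps for the box prompts drawn in the ImagePrompter input.

    `input` is the ImagePrompter payload ({"image": HxWx3 array, "points": prompt list});
    `topk` is the number of grasps to draw per box prompt. Returns the input image with
    the predicted masks and grasp rectangles drawn on it.
    """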
    np_image = input["image"]
    points = input["points"]
    topk = int(topk)  # slider values may arrive as floats; the slicing below needs an int
    orig_size = np_image.shape[:2]
    # normalize image
    np_image = np_image.transpose(2, 0, 1)
    image = (np_image - mean) / std
    image = torch.tensor(image).float().to(device)
    image = image.unsqueeze(0)
    t_image = img_resize.apply_image_torch(image)
    t_orig_size = t_image.shape[-2:]
    # pad to 1024x1024
    pixel_mask = torch.ones(1, t_orig_size[0], t_orig_size[1], device=device)
    t_image = torch.nn.functional.pad(t_image, (0, 1024 - t_image.shape[-1], 0, 1024 - t_image.shape[-2]))
    pixel_mask = torch.nn.functional.pad(pixel_mask, (0, 1024 - t_orig_size[1], 0, 1024 - t_orig_size[0]))
    # get box prompt (ImagePrompter marks box corners with point labels 2 and 3)
    valid_boxes = []
    for point in points:
        x1, y1, type1, x2, y2, type2 = point
        if type1 == 2 and type2 == 3:
            valid_boxes.append([x1, y1, x2, y2])
    if len(valid_boxes) == 0:
        # no box prompt was drawn: return the input image unchanged
        return input["image"]
    t_boxes = np.array(valid_boxes)
    t_boxes = img_resize.apply_boxes(t_boxes, orig_size)
    box_torch = torch.as_tensor(t_boxes, dtype=torch.float, device=device)
    batched_inputs = [{"image": t_image[0], "boxes": box_torch, "pixel_mask": pixel_mask}]
    with torch.no_grad():
        outputs = sam.infer(batched_inputs, multimask_output=False)
    # recover the normalized image for visualization
    recovered_img = batched_inputs[0]['image'].cpu().numpy()
    recovered_img = recovered_img * std + mean
    recovered_img = recovered_img.transpose(1, 2, 0).clip(0, 255).astype(np.uint8)
    for i in range(len(outputs.pred_masks)):
        # get predicted mask
        pred_mask = outputs.pred_masks[i].detach().sigmoid().cpu().numpy() > 0.5
        pred_mask = pred_mask.transpose(1, 2, 0).repeat(3, axis=2)
        # get predicted grasp
        pred_logits = outputs.logits[i].detach().cpu().numpy()
        top_ind = pred_logits[:, 0].argsort()[-topk:][::-1]
        pred_grasp = outputs.pred_boxes[i].detach().cpu().numpy()[top_ind]
        coded_grasp = GraspCoder(t_orig_size[0], t_orig_size[1], None, grasp_annos_reformat=pred_grasp)
        _ = coded_grasp.decode()
        decoded_grasp = copy.deepcopy(coded_grasp.grasp_annos)
        # draw mask
        mask_color = np.array([0, 255, 0])[None, None, :]
        recovered_img[pred_mask] = recovered_img[pred_mask] * 0.5 + (pred_mask * mask_color)[pred_mask] * 0.5
        # draw grasp
        recovered_img = np.ascontiguousarray(recovered_img)
        for grasp in decoded_grasp:
            grasp = grasp.astype(int)
            cv2.line(recovered_img, tuple(grasp[0:2]), tuple(grasp[2:4]), (255, 0, 0), 1)
            cv2.line(recovered_img, tuple(grasp[4:6]), tuple(grasp[6:8]), (255, 0, 0), 1)
            cv2.line(recovered_img, tuple(grasp[2:4]), tuple(grasp[4:6]), (0, 0, 255), 2)
            cv2.line(recovered_img, tuple(grasp[6:8]), tuple(grasp[0:2]), (0, 0, 255), 2)
    recovered_img = recovered_img[:t_orig_size[0], :t_orig_size[1]]
    # resize to original size
    recovered_img = cv2.resize(recovered_img, (orig_size[1], orig_size[0]))
    return recovered_img
if __name__ == "__main__":
    app = gr.Blocks(title="GraspAnything")
    with app:
        gr.Markdown(
            """
            # GraspAnything <br>
            Upload an image and draw a box around the object you want to grasp. Set Top K to the number of grasps to predict for each object.
            """
        )
        with gr.Column():
            prompter = ImagePrompter(show_label=False)
            top_k = gr.Slider(minimum=1, maximum=20, step=1, value=3, label="Top K Grasps")
        with gr.Column():
            image_output = gr.Image()
        btn = gr.Button("Generate!")
        btn.click(predict,
                  inputs=[prompter, top_k],
                  outputs=[image_output])
    app.launch()