# GraspAnything / app.py
# Source file page metadata: uploaded by Plachta ("Upload 2 files"), commit a5b2d18 (verified), 4.72 kB.
import copy
import numpy as np
import torch
import sys
sys.path.append("./")
from models import sam_model_registry
from models.grasp_mods import modify_forward
from models.utils.transforms import ResizeLongestSide
from gradio_image_prompter import ImagePrompter
from structures.grasp_box import GraspCoder
img_resize = ResizeLongestSide(1024)
import cv2
import gradio as gr
from models.grasp_mods import add_inference_method
# Inference device: CUDA when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_type = "vit_b"

# Per-channel pixel mean / std on the 0-255 scale, shaped (3, 1, 1) so they
# broadcast over a channels-first (C, H, W) image.
# NOTE(review): these are the ImageNet statistics listed in B, G, R order,
# while the Gradio prompter presumably delivers RGB pixels — confirm this
# matches the channel order used at training time.
mean = np.array([103.53, 116.28, 123.675])[:, np.newaxis, np.newaxis]
std = np.array([57.375, 57.12, 58.395])[:, np.newaxis, np.newaxis]

# Build the SAM model and patch in the forward / infer methods from models.grasp_mods.
sam = sam_model_registry[model_type]()
sam.to(device=device)
sam.forward = modify_forward(sam)
sam.infer = add_inference_method(sam)

pretrained_model_path = "./epoch_9_step_535390.pth"
if pretrained_model_path:
    # Load onto the CPU: a checkpoint saved from a GPU cannot otherwise be
    # deserialized on a CPU-only host; load_state_dict copies the tensors
    # into the parameters on whatever device the model already lives on.
    sd = torch.load(pretrained_model_path, map_location="cpu")
    # Strip the "module." prefix that (Distributed)DataParallel wrappers
    # typically add to every parameter name.
    prefix = "module."
    new_sd = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in sd.items()
    }
    sam.load_state_dict(new_sd)
    # The weights now live in `sam`; drop the temporary copies so they are
    # not kept alive for the lifetime of the app.
    del sd, new_sd

# Switch to evaluation mode for inference.
sam.eval()
def _extract_boxes(points):
    """Collect the box prompts from an ImagePrompter ``points`` list.

    Each prompt record unpacks as ``(x1, y1, type1, x2, y2, type2)``; only
    records with ``type1 == 2`` and ``type2 == 3`` are kept, returned as
    ``[x1, y1, x2, y2]`` lists. ``None`` or an empty list yields ``[]``.
    """
    return [
        [x1, y1, x2, y2]
        for x1, y1, type1, x2, y2, type2 in (points or [])
        if type1 == 2 and type2 == 3
    ]


def _draw_grasp(canvas, grasp):
    """Draw one decoded grasp (four (x, y) corners, flattened) onto ``canvas`` in place.

    Sides corner0-corner1 and corner2-corner3 are drawn 1 px wide in
    (255, 0, 0); sides corner1-corner2 and corner3-corner0 are drawn 2 px
    wide in (0, 0, 255).
    """
    # Convert to plain Python ints: some OpenCV versions reject numpy integer
    # scalars as point coordinates.
    corners = np.asarray(grasp).astype(int)[:8].reshape(4, 2).tolist()
    c0, c1, c2, c3 = (tuple(c) for c in corners)
    cv2.line(canvas, c0, c1, (255, 0, 0), 1)
    cv2.line(canvas, c2, c3, (255, 0, 0), 1)
    cv2.line(canvas, c1, c2, (0, 0, 255), 2)
    cv2.line(canvas, c3, c0, (0, 0, 255), 2)


def predict(input, topk):
    """Predict grasps for each box drawn on the image and return a visualization.

    Args:
        input: value of the ImagePrompter component, a mapping with
            ``"image"`` (H x W x 3 array) and ``"points"`` (prompt records;
            see ``_extract_boxes``).
        topk: number of highest-scoring grasps to draw per box.

    Returns:
        An H x W x 3 uint8 array: the image with each box's predicted mask
        blended in green and its top-``topk`` grasps drawn as rectangles.
        If no box was drawn, the uploaded image is returned unchanged.

    Raises:
        gr.Error: if no image has been uploaded.
    """
    if input is None or input["image"] is None:
        raise gr.Error("Please upload an image first.")
    src_image = input["image"]

    boxes = _extract_boxes(input["points"])
    if not boxes:
        # Nothing to predict on: return the image as uploaded (H x W x C).
        # The channels-first copy made for the model is not displayable.
        return src_image

    orig_size = src_image.shape[:2]
    orig_h, orig_w = orig_size

    # Normalize in channels-first layout, add a batch dim, resize with
    # img_resize (longest side 1024) and zero-pad bottom/right to 1024 x 1024.
    chw = src_image.transpose(2, 0, 1)
    image = torch.tensor((chw - mean) / std).float().to(device).unsqueeze(0)
    t_image = img_resize.apply_image_torch(image)
    resized_h, resized_w = t_image.shape[-2:]
    t_image = torch.nn.functional.pad(
        t_image, (0, 1024 - resized_w, 0, 1024 - resized_h)
    )

    # Map the box prompts into the resized image's coordinate frame.
    t_boxes = img_resize.apply_boxes(np.array(boxes), orig_size)
    box_torch = torch.as_tensor(t_boxes, dtype=torch.float, device=device)
    batched_inputs = [{"image": t_image[0], "boxes": box_torch}]
    with torch.no_grad():
        outputs = sam.infer(batched_inputs, multimask_output=False)

    # Undo the normalization on the padded model input to get a drawable canvas.
    canvas = batched_inputs[0]["image"].cpu().numpy() * std + mean
    # Clip BEFORE the uint8 cast: casting out-of-range floats does not
    # saturate. cv2 drawing needs a C-contiguous buffer, which the transpose breaks.
    canvas = np.ascontiguousarray(
        canvas.transpose(1, 2, 0).clip(0, 255).astype(np.uint8)
    )

    k = int(topk)  # guard against a float slider value; slicing needs an int
    mask_color = np.array([0, 255, 0])[None, None, :]
    for i in range(len(outputs.pred_masks)):
        # Threshold the mask at sigmoid > 0.5 and replicate it across 3 channels.
        pred_mask = outputs.pred_masks[i].detach().sigmoid().cpu().numpy() > 0.5
        pred_mask = pred_mask.transpose(1, 2, 0).repeat(3, axis=2)
        # Blend masked pixels 50/50 with mask_color.
        canvas[pred_mask] = (
            canvas[pred_mask] * 0.5 + (pred_mask * mask_color)[pred_mask] * 0.5
        )

        # Keep the k grasps with the highest first-column logit, best first.
        pred_logits = outputs.logits[i].detach().cpu().numpy()
        top_ind = pred_logits[:, 0].argsort()[-k:][::-1]
        pred_grasp = outputs.pred_boxes[i].detach().cpu().numpy()[top_ind]
        coded_grasp = GraspCoder(1024, 1024, None, grasp_annos_reformat=pred_grasp)
        coded_grasp.decode()  # decoded corners are read back from grasp_annos
        for grasp in coded_grasp.grasp_annos:
            _draw_grasp(canvas, grasp)

    # Crop away the padding, then scale back to the uploaded resolution.
    # cv2.resize takes dsize as (width, height), not (rows, cols).
    canvas = canvas[:resized_h, :resized_w]
    return cv2.resize(canvas, (orig_w, orig_h))
if __name__ == "__main__":
    # Markdown header rendered at the top of the page.
    intro_md = (
        "\n"
        "# GraspAnything <br>\n"
        "Upload an image and draw a box around the object you want to grasp. "
        "Set top k to be the number of grasps you want to predict for each object.\n"
    )

    with gr.Blocks(title="GraspAnything") as app:
        gr.Markdown(intro_md)

        # Inputs: the image with box prompts and the per-box grasp count.
        with gr.Column():
            prompter = ImagePrompter(show_label=False)
            top_k = gr.Slider(
                minimum=1, maximum=20, step=1, value=3, label="Top K Grasps"
            )

        # Outputs: the rendered prediction plus the button that triggers it.
        with gr.Column():
            image_output = gr.Image()
            btn = gr.Button("Generate!")

        btn.click(fn=predict, inputs=[prompter, top_k], outputs=[image_output])

    app.launch()