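"""Gradio demo for GraspAnything.

Draw a box around an object and a SAM-based model predicts its segmentation mask
together with the top-k grasp rectangles. Assumes the repository's `models/` and
`structures/` packages are importable and that a trained checkpoint
(./epoch_9_step_535390.pth) is available locally.
"""
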
import copy
import sys

import cv2
import gradio as gr
import numpy as np
import torch
from gradio_image_prompter import ImagePrompter

# Local repository imports
sys.path.append("./")
from models import sam_model_registry
from models.grasp_mods import add_inference_method, modify_forward
from models.utils.transforms import ResizeLongestSide
from structures.grasp_box import GraspCoder

# SAM-style preprocessing: resize images so the longest side is 1024 pixels
img_resize = ResizeLongestSide(1024)

device = "cuda" if torch.cuda.is_available() else "cpu"
model_type = "vit_b"

# Per-channel pixel mean and std (channel-first broadcast shape) used to normalize images
mean = np.array([103.53, 116.28, 123.675])[:, np.newaxis, np.newaxis]
std = np.array([57.375, 57.12, 58.395])[:, np.newaxis, np.newaxis]

# Build the base SAM model and patch in the grasp-detection forward pass and inference method
sam = sam_model_registry[model_type]()
sam.to(device=device)

sam.forward = modify_forward(sam)
sam.infer = add_inference_method(sam)

pretrained_model_path = "./epoch_9_step_535390.pth"

if pretrained_model_path != "":
    sd = torch.load(pretrained_model_path, map_location=device)
    # Strip the "module." prefix left on keys by DataParallel/DistributedDataParallel checkpoints
    new_sd = {}
    for k, v in sd.items():
        if k.startswith("module."):
            k = k[len("module."):]
        new_sd[k] = v
    sam.load_state_dict(new_sd)

sam.eval()

def predict(input, topk):
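    """Predict grasps for each box prompt drawn on the image.

    `input` is the ImagePrompter payload: {"image": HxWx3 uint8 array, "points": prompt list}.
    `topk` is the number of grasp candidates to draw per prompted object.
    Returns the image with predicted masks and grasp rectangles drawn on it.
    """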
    np_image = input["image"]
    points = input["points"]
    topk = int(topk)  # the Gradio slider value may arrive as a float
    orig_size = np_image.shape[:2]
    # normalize image
    np_image = np_image.transpose(2, 0, 1)

    image = (np_image - mean) / std
    image = torch.tensor(image).float().to(device)
    image = image.unsqueeze(0)
    t_image = img_resize.apply_image_torch(image)
    t_orig_size = t_image.shape[-2:]
    # pad to 1024x1024
    t_image = torch.nn.functional.pad(t_image, (0, 1024 - t_image.shape[-1], 0, 1024 - t_image.shape[-2]))

    # get box prompt
    valid_boxes = []
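    # ImagePrompter encodes each prompt as [x1, y1, label1, x2, y2, label2];
    # a drawn box appears to use the label pair (2, 3) for its two corners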
    for point in points:
        x1, y1, type1, x2, y2, type2 = point
        if type1 == 2 and type2 == 3:
            valid_boxes.append([x1, y1, x2, y2])
    if len(valid_boxes) == 0:
        # No complete box prompt was drawn; return the input image unchanged
        return input["image"]
    t_boxes = np.array(valid_boxes)
    t_boxes = img_resize.apply_boxes(t_boxes, orig_size)
    box_torch = torch.as_tensor(t_boxes, dtype=torch.float, device=device)
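    # SAM-style batched input: one dict per image holding the padded image tensor and its box prompts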
    batched_inputs = [{"image": t_image[0], "boxes": box_torch}]
    with torch.no_grad():
        outputs = sam.infer(batched_inputs, multimask_output=False)
    # visualize and post on tensorboard
    # recover image
    recovered_img = batched_inputs[0]['image'].cpu().numpy()
    recovered_img = recovered_img * std + mean
    recovered_img = recovered_img.transpose(1, 2, 0).clip(0, 255).astype(np.uint8)

    for i in range(len(outputs.pred_masks)):
        # get predicted mask
        pred_mask = outputs.pred_masks[i].detach().sigmoid().cpu().numpy() > 0.5
        pred_mask = pred_mask.transpose(1, 2, 0).repeat(3, axis=2)

        # get predicted grasp
        pred_logits = outputs.logits[i].detach().cpu().numpy()
        top_ind = pred_logits[:, 0].argsort()[-topk:][::-1]
        pred_grasp = outputs.pred_boxes[i].detach().cpu().numpy()[top_ind]
        coded_grasp = GraspCoder(1024, 1024, None, grasp_annos_reformat=pred_grasp)
        _ = coded_grasp.decode()
        decoded_grasp = copy.deepcopy(coded_grasp.grasp_annos)

        # draw mask
        mask_color = np.array([0, 255, 0])[None, None, :]
        recovered_img[pred_mask] = recovered_img[pred_mask] * 0.5 + (pred_mask * mask_color)[pred_mask] * 0.5

        # draw grasp
        recovered_img = np.ascontiguousarray(recovered_img)
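        # each decoded grasp is 8 values: the four (x, y) corners of an oriented rectangle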
        for grasp in decoded_grasp:
            grasp = grasp.astype(int)
            cv2.line(recovered_img, tuple(grasp[0:2]), tuple(grasp[2:4]), (255, 0, 0), 1)
            cv2.line(recovered_img, tuple(grasp[4:6]), tuple(grasp[6:8]), (255, 0, 0), 1)
            cv2.line(recovered_img, tuple(grasp[2:4]), tuple(grasp[4:6]), (0, 0, 255), 2)
            cv2.line(recovered_img, tuple(grasp[6:8]), tuple(grasp[0:2]), (0, 0, 255), 2)

    recovered_img = recovered_img[:t_orig_size[0], :t_orig_size[1]]
    # resize back to the original size (cv2.resize expects (width, height))
    recovered_img = cv2.resize(recovered_img, (orig_size[1], orig_size[0]))
    return recovered_img

if __name__ == "__main__":
    app = gr.Blocks(title="GraspAnything")
    with app:
        gr.Markdown("""
        # GraspAnything
        Upload an image and draw a box around the object you want to grasp. Set "Top K Grasps" to the number of grasp candidates to predict per object.
        """)
        with gr.Column():
            prompter = ImagePrompter(show_label=False)
            top_k = gr.Slider(minimum=1, maximum=20, step=1, value=3, label="Top K Grasps")
        with gr.Column():
            image_output = gr.Image()
        btn = gr.Button("Generate!")
        btn.click(predict,
                  inputs=[prompter, top_k],
                  outputs=[image_output])
    app.launch()