File size: 3,983 Bytes
49f816b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry


def show_mask(mask, ax, random_color=False):
    """Overlay a mask on ``ax`` as a translucent RGBA image.

    Args:
        mask: Array whose last two dimensions are (height, width). It is
            reshaped to (height, width, 1), so it must hold exactly
            height * width elements.
        ax: Matplotlib axes to draw on.
        random_color: If True, use a random RGB colour; otherwise a fixed
            blue. The alpha channel is 0.6 in both cases.
    """
    rgba = (
        np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        if random_color
        else np.array([30/255, 144/255, 255/255, 0.6])
    )
    height, width = mask.shape[-2:]
    # Broadcast the per-pixel mask value against the RGBA colour.
    ax.imshow(mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1))
    
    
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax`` as white-edged stars.

    Points labelled 1 are drawn in green, then points labelled 0 in red;
    points with any other label are not drawn.

    Args:
        coords: (N, 2) array of (x, y) coordinates.
        labels: Length-N array of point labels, used as a boolean index.
        ax: Matplotlib axes to draw on.
        marker_size: Marker area passed to ``scatter`` as ``s``.
    """
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        chosen = coords[labels == wanted_label]
        ax.scatter(
            chosen[:, 0],
            chosen[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )
    
    
def show_box(box, ax):
    """Draw an unfilled green rectangle for ``box`` on ``ax``.

    Args:
        box: Sequence whose first four entries are (x0, y0, x1, y1).
        ax: Matplotlib axes the rectangle patch is added to.
    """
    x0, y0 = box[0], box[1]
    outline = plt.Rectangle(
        (x0, y0),
        box[2] - x0,
        box[3] - y0,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)
    

def merge_bounding_boxes(bbox1, bbox2):
    """Return the smallest box that encloses both input boxes.

    Args:
        bbox1, bbox2: Four-element boxes (xmin, ymin, xmax, ymax).

    Returns:
        np.ndarray ``[xmin, ymin, xmax, ymax]`` of the union box.
    """
    left_a, top_a, right_a, bottom_a = bbox1
    left_b, top_b, right_b, bottom_b = bbox2
    return np.array([
        min(left_a, left_b),
        min(top_a, top_b),
        max(right_a, right_b),
        max(bottom_a, bottom_b),
    ])


def init_sam(
    device="cuda",
    ckpt_path='/users/kchen157/scratch/weights/SAM/sam_vit_h_4b8939.pth',
    model_type='vit_h',
    ):
    """Load a SAM checkpoint and wrap it in a ``SamPredictor``.

    Args:
        device: Device the model is moved to (e.g. ``"cuda"`` or ``"cpu"``).
        ckpt_path: Path to the SAM checkpoint. The default is a
            machine-specific path; override it elsewhere.
        model_type: Key into ``sam_model_registry`` selecting the
            architecture. It must match the checkpoint; the default
            checkpoint is a ViT-H one, hence ``'vit_h'``.

    Returns:
        A ``SamPredictor`` wrapping the loaded model.
    """
    sam = sam_model_registry[model_type](checkpoint=ckpt_path)
    sam.to(device=device)
    return SamPredictor(sam)


def _hoi_prompts(hand_kpts, box_shift_ratio, box_size_factor):
    """Build the SAM box and point prompts for every hand in ``hand_kpts``.

    For ``'right'`` then ``'left'`` (when present), the keypoints' bounding box
    is scaled by ``box_size_factor`` about the anchor
    ``min * box_shift_ratio + max * (1 - box_shift_ratio)``. The returned box
    is the union of those per-hand boxes as ``[x0, y0, x1, y1]``. The returned
    points stack keypoint 0 of each hand in the same order.

    Raises:
        KeyError: If neither ``'right'`` nor ``'left'`` is in ``hand_kpts``.
    """
    hand_boxes, hand_points = [], []
    for hand_type in ('right', 'left'):
        if hand_type not in hand_kpts:
            continue
        kpts = np.asarray(hand_kpts[hand_type])
        corners = np.stack([kpts.min(axis=0), kpts.max(axis=0)])
        # With ratio < 0.5 the anchor sits nearer the max corner, so scaling
        # about it grows the box mostly toward the min corner (and vice versa).
        anchor = corners[0] * box_shift_ratio + corners[1] * (1 - box_shift_ratio)
        hand_boxes.append(((corners - anchor) * box_size_factor + anchor).reshape(-1))
        hand_points.append(kpts[0])
    if not hand_boxes:
        raise KeyError("hand_kpts must contain a 'right' and/or 'left' entry")
    hand_boxes = np.stack(hand_boxes)
    union_box = np.concatenate(
        [hand_boxes[:, :2].min(axis=0), hand_boxes[:, 2:].max(axis=0)]
    )
    return union_box, np.stack(hand_points)


def segment_hand_and_object(
    predictor,
    image,
    hand_kpts,
    hand_mask=None,
    box_shift_ratio = 0.3,
    box_size_factor = 2.,
    area_threshold = 0.2,
    overlap_threshold = 200):
    """Segment the hand(s) and the object they interact with via SAM prompts.

    The object prompt has two parts. The first is a box enclosing the enlarged
    keypoint box of every present hand. The second is keypoint 0 of each hand
    as a *negative* point, which discourages the hand from being included. If
    ``hand_mask`` is not given, the hand is first segmented from the same
    points used as *positive* prompts.

    Args:
        predictor: SAM predictor (e.g. from ``init_sam``). Only ``set_image``
            and ``predict`` are used.
        image: Image passed to ``predictor.set_image``. For ``SamPredictor``
            this is an HxWx3 uint8 RGB array.
        hand_kpts: Mapping with a ``'right'`` and/or ``'left'`` entry, each a
            (K, 2) array of (x, y) keypoints. Keypoint 0 is used as the point
            prompt and is assumed to be the wrist.
        hand_mask: Optional precomputed hand mask. When given, the hand
            prediction is skipped and this mask is returned unchanged.
        box_shift_ratio: Position of the scaling anchor between the keypoint
            box's max corner (0.0) and min corner (1.0).
        box_size_factor: Scale applied to each keypoint box about its anchor.
        area_threshold: Reject the object mask if its pixel count exceeds this
            fraction of the prompt-box area (probably background).
        overlap_threshold: Reject the object mask if it shares more than this
            many pixels with the hand mask (probably the hand itself).

    Returns:
        ``(object_mask, hand_mask)``. A rejected object mask is all zeros.

    Raises:
        KeyError: If ``hand_kpts`` has neither a ``'right'`` nor ``'left'`` entry.
    """
    input_box, input_point = _hoi_prompts(hand_kpts, box_shift_ratio, box_size_factor)
    positive_labels = np.ones(len(input_point), dtype=int)
    box_area = (input_box[2] - input_box[0]) * (input_box[3] - input_box[1])

    predictor.set_image(image)
    if hand_mask is None:
        # Segment the hand from its keypoint-0 prompts.
        masks, _, _ = predictor.predict(
            point_coords=input_point,
            point_labels=positive_labels,
            multimask_output=False,
        )
        hand_mask = masks[0]

    # Segment the object inside the box, with the hand points as negatives.
    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=np.zeros_like(positive_labels),
        box=input_box[None, :],
        multimask_output=False,
    )
    object_mask = masks[0]

    object_pixels = object_mask.astype(int)
    if (object_pixels * hand_mask).sum() > overlap_threshold:
        # False positive: the mask overlaps the hand.
        object_mask = np.zeros_like(object_mask)
    elif box_area <= 0 or object_pixels.sum() / box_area > area_threshold:
        # False positive: the mask is large relative to the box (probably the
        # background), or the box is degenerate and the ratio is undefined.
        object_mask = np.zeros_like(object_mask)

    return object_mask, hand_mask