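"""Single-scene inference demo for PQ3D (Query3DUnified).

Given a ScanNet scan id and a text prompt, the script loads the preprocessed
point cloud and per-object image/voxel/point features from `assets/`, builds a
one-sample batch, runs the checkpoint in `checkpoint/pytorch_model.bin`, and
prints the grounded object id together with the generated text response.
"""
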
import json
import os
import torch

import numpy as np


from pq3d.model import Query3DUnified
from pq3d.utils import LabelConverter, convert_pc_to_box, obj_processing_post, pad_sequence
from torch.utils.data import default_collate
from transformers import AutoTokenizer

ASSET_DIR = os.path.join(os.getcwd(), 'assets')
CKPT_DIR = os.path.join(os.getcwd(), 'checkpoint')
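# ScanNet category names and raw-label mapping, used to identify and filter objects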
int2cat = json.load(open(os.path.join(ASSET_DIR, "meta/scannetv2_raw_categories.json"), 'r', encoding="utf-8"))
cat2int = {w: i for i, w in enumerate(int2cat)}
label_converter = LabelConverter(os.path.join(ASSET_DIR, "meta/scannetv2-labels.combined.tsv"))

def load_data(scan_id):
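    """Load one preprocessed ScanNet scan and convert it into a PQ3D input dict.

    Expects `assets/inputs/<scan_id>/` to contain the point cloud, instance-to-label
    mapping, and precomputed per-object image/voxel/point features.
    """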
    one_scan = {}
    # load scan
    pcd_data = torch.load(os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_pcd.pth'))
    inst_to_label = torch.load(os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_inst.pth')) 
    points, colors, instance_labels = pcd_data[0], pcd_data[1], pcd_data[-1]
    colors = colors / 127.5 - 1  # normalize rgb from [0, 255] to [-1, 1]
    pcds = np.concatenate([points, colors], 1)
    one_scan['pcds'] = pcds
    one_scan['instance_labels'] = instance_labels
    one_scan['inst_to_label'] = inst_to_label
    # convert to gt object
    obj_pcds = []
    inst_ids = []
    inst_labels = []
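    # background mask: starts all-True; points belonging to foreground objects are cleared below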
    bg_indices = np.full((points.shape[0], ), 1, dtype=np.bool_)
    for inst_id in inst_to_label.keys():
        if inst_to_label[inst_id] in cat2int.keys():
            mask = instance_labels == inst_id
            if np.sum(mask) == 0:
                continue
            obj_pcds.append(pcds[mask])
            inst_ids.append(inst_id)
            inst_labels.append(cat2int[inst_to_label[inst_id]])
            if inst_to_label[inst_id] not in ['wall', 'floor', 'ceiling']:
                bg_indices[mask] = False
    one_scan['obj_pcds'] = obj_pcds
    one_scan['inst_labels'] = inst_labels
    one_scan['inst_ids'] = inst_ids
    one_scan['bg_pcds'] = pcds[bg_indices]
    # calculate box for matching
    obj_center = []
    obj_box_size = []
    for obj_pcd in obj_pcds:
        _c, _b = convert_pc_to_box(obj_pcd)
        obj_center.append(_c)
        obj_box_size.append(_b)
    one_scan['obj_loc'] = obj_center
    one_scan['obj_box'] = obj_box_size
    # load precomputed per-object features (image / voxel / point), keyed by instance id
    n_obj = len(one_scan['inst_ids'])
    for feat_key, suffix in [('image_obj_feat_gt', 'img_gt'),
                             ('voxel_obj_feat_gt', 'voxel_gt'),
                             ('pc_obj_feat_gt', 'point_gt')]:
        feat_pth = os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_{suffix}.pth')
        feat_dict = torch.load(feat_pth)
        feat_dim = next(iter(feat_dict.values())).shape[0]
        feat = torch.zeros((n_obj, feat_dim), dtype=torch.float32)
        for i, cid in enumerate(one_scan['inst_ids']):
            if cid in feat_dict:
                feat[i] = feat_dict[cid]
        one_scan[feat_key] = feat
    #  convert to pq3d input
    obj_labels = one_scan['inst_labels'] # N
    obj_pcds = one_scan['obj_pcds']
    obj_ids = one_scan['inst_ids'] 
    # object filter: drop background categories; no sentence-based filtering in this demo
    excluded_labels = ['wall', 'floor', 'ceiling']
    def keep_obj(obj_label):
        category = int2cat[obj_label]
        return category not in excluded_labels
    selected_obj_idxs = [i for i, obj_label in enumerate(obj_labels) if keep_obj(obj_label)]
    obj_labels = [obj_labels[i] for i in selected_obj_idxs]
    obj_pcds = [obj_pcds[i] for i in selected_obj_idxs]
    # subsample points
    obj_pcds = np.array([obj_pcd[np.random.choice(len(obj_pcd), size=1024,
                                        replace=len(obj_pcd) < 1024)] for obj_pcd in obj_pcds])
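    # each object is now a fixed (1024, 6) array of xyz + normalized rgb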
    obj_fts, obj_locs, obj_boxes, rot_matrix = obj_processing_post(obj_pcds, rot_aug=False)
    data_dict = {
        "scan_id": scan_id,
        "obj_fts": obj_fts.float(),
        "obj_locs": obj_locs.float(),
        "obj_labels": torch.LongTensor(obj_labels),
        "obj_boxes": obj_boxes, 
        "obj_pad_masks": torch.ones((len(obj_locs)), dtype=torch.bool), # used for padding in collate
        "obj_ids": [obj_ids[i] for i in selected_obj_idxs]
    }
    # gather per-modality object features for the selected objects, plus all-ones pad masks
    for src_key, fts_key, mask_key in [('image_obj_feat_gt', 'mv_seg_fts', 'mv_seg_pad_masks'),
                                       ('voxel_obj_feat_gt', 'voxel_seg_fts', 'voxel_seg_pad_masks'),
                                       ('pc_obj_feat_gt', 'pc_seg_fts', 'pc_seg_pad_masks')]:
        data_dict[fts_key] = one_scan[src_key][selected_obj_idxs]
        data_dict[mask_key] = torch.ones(len(data_dict[fts_key]), dtype=torch.bool)
    # build other
    data_dict['query_locs'] = data_dict['obj_locs'].clone()
    data_dict['query_pad_masks'] = data_dict['obj_pad_masks'].clone()
    data_dict['seg_center'] = obj_locs.float()
    data_dict['seg_pad_masks'] = data_dict['obj_pad_masks']
    return data_dict

def form_batch(data_dict):
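    """Wrap a single-sample data_dict into a batch of size one, padding variable-length tensors."""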
    batch = [data_dict]
    new_batch = {}
    # merge list keys
    list_keys = [k for k, v in batch[0].items() if isinstance(v, list)]
    for k in list_keys:
        new_batch[k] = [sample.pop(k) for sample in batch] 
    # merge tensor
    padding_keys = [k for k, v in batch[0].items() if isinstance(v, torch.Tensor) and v.ndim > 0]
    for k in padding_keys:
        tensors = [sample.pop(k) for sample in batch]
        padding_value = -100 if k == 'obj_labels' else 0
        padded_tensor = pad_sequence(tensors, pad=padding_value)
        new_batch[k] = padded_tensor
    # others
    new_batch.update(default_collate(batch))
    return new_batch
    
def tokenize_txt(text):
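    """Tokenize the text prompt with the CLIP tokenizer and build the prompt fields for the model."""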
    tokenizer_name = 'openai/clip-vit-large-patch14'
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    encoded_input = tokenizer([text], add_special_tokens=True, truncation=True)
    data_dict = {}
    data_dict['prompt'] = torch.FloatTensor(encoded_input.input_ids[0])
    data_dict['prompt_pad_masks'] = torch.ones((len(data_dict['prompt']))).bool()
    data_dict['prompt_type'] = 1 # txt
    return data_dict
    
def inference(scan_id, text):
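    """Run grounding and response generation for one scan and one text prompt."""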
    data_dict = load_data(scan_id)
    data_dict.update(tokenize_txt(text))
    data_dict = form_batch(data_dict)
    model = Query3DUnified()
    load_msg = model.load_state_dict(torch.load(os.path.join(CKPT_DIR, 'pytorch_model.bin'), map_location='cpu'), strict=False)
    model.eval()
    with torch.no_grad():
        data_dict = model(data_dict)
    # grounding result: the object id with the highest grounding logit
    result_id = data_dict['obj_ids'][0][torch.argmax(data_dict['og3d_logits'][0]).item()]
    print(f"finished inference, result id is {result_id}")
    # decode the generated language response
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    response_pred = tokenizer.batch_decode(data_dict['generation_logits'], skip_special_tokens=True)[0]
    print(f"text response is {response_pred}")
    return result_id, response_pred
    
if __name__ == '__main__':
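    # the scan id must match a directory under assets/inputs/; the prompt is free-form text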
    inference("scene0050_00", "chair")