import json
import os
import torch
import numpy as np
from pq3d.model import Query3DUnified
from pq3d.utils import LabelConverter, convert_pc_to_box, obj_processing_post, pad_sequence
from torch.utils.data import default_collate
from transformers import AutoTokenizer
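
# Expected local layout (all paths below are relative to the working directory):
#   assets/meta/scannetv2_raw_categories.json
#   assets/meta/scannetv2-labels.combined.tsv
#   assets/inputs/<scan_id>/<scan_id>_{pcd,inst,img_gt,voxel_gt,point_gt}.pth
#   checkpoint/pytorch_model.bin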
ASSET_DIR = os.path.join(os.getcwd(), 'assets')
CKPT_DIR = os.path.join(os.getcwd(), 'checkpoint')
int2cat = json.load(open(os.path.join(ASSET_DIR, "meta/scannetv2_raw_categories.json"), 'r', encoding="utf-8"))
cat2int = {w: i for i, w in enumerate(int2cat)}
label_converter = LabelConverter(os.path.join(ASSET_DIR, "meta/scannetv2-labels.combined.tsv"))
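
# load_data: read the preprocessed scan (point cloud, instance labels, cached
# per-object image/voxel/point features) and pack it into a PQ3D input dict.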
def load_data(scan_id):
    one_scan = {}
    # load scan
    pcd_data = torch.load(os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_pcd.pth'))
    inst_to_label = torch.load(os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_inst.pth'))
    points, colors, instance_labels = pcd_data[0], pcd_data[1], pcd_data[-1]
    colors = colors / 127.5 - 1
    pcds = np.concatenate([points, colors], 1)
    one_scan['pcds'] = pcds
    one_scan['instance_labels'] = instance_labels
    one_scan['inst_to_label'] = inst_to_label
    # convert to gt object
    obj_pcds = []
    inst_ids = []
    inst_labels = []
    bg_indices = np.full((points.shape[0], ), 1, dtype=np.bool_)
    for inst_id in inst_to_label.keys():
        if inst_to_label[inst_id] in cat2int.keys():
            mask = instance_labels == inst_id
            if np.sum(mask) == 0:
                continue
            obj_pcds.append(pcds[mask])
            inst_ids.append(inst_id)
            inst_labels.append(cat2int[inst_to_label[inst_id]])
            if inst_to_label[inst_id] not in ['wall', 'floor', 'ceiling']:
                bg_indices[mask] = False
    one_scan['obj_pcds'] = obj_pcds
    one_scan['inst_labels'] = inst_labels
    one_scan['inst_ids'] = inst_ids
    one_scan['bg_pcds'] = pcds[bg_indices]
    # calculate box for matching
    obj_center = []
    obj_box_size = []
    for obj_pcd in obj_pcds:
        _c, _b = convert_pc_to_box(obj_pcd)
        obj_center.append(_c)
        obj_box_size.append(_b)
    one_scan['obj_loc'] = obj_center
    one_scan['obj_box'] = obj_box_size
    # load image feat
    feat_pth = os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_img_gt.pth')
    feat_dict = torch.load(feat_pth)
    feat_dim = next(iter(feat_dict.values())).shape[0]
    n_obj = len(one_scan['inst_ids'])
    feat = torch.zeros((n_obj, feat_dim), dtype=torch.float32)  # instances missing from the cache keep zero features
    for i, cid in enumerate(one_scan['inst_ids']):
        if cid in feat_dict.keys():
            feat[i] = feat_dict[cid]
    one_scan['image_obj_feat_gt'] = feat
    # load voxel feat
    feat_pth = os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_voxel_gt.pth')
    feat_dict = torch.load(feat_pth)
    feat_dim = next(iter(feat_dict.values())).shape[0]
    n_obj = len(one_scan['inst_ids'])
    feat = torch.zeros((n_obj, feat_dim), dtype=torch.float32)  # instances missing from the cache keep zero features
    for i, cid in enumerate(one_scan['inst_ids']):
        if cid in feat_dict.keys():
            feat[i] = feat_dict[cid]
    one_scan['voxel_obj_feat_gt'] = feat
    # load point feat
    feat_pth = os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_point_gt.pth')
    feat_dict = torch.load(feat_pth)
    feat_dim = next(iter(feat_dict.values())).shape[0]
    n_obj = len(one_scan['inst_ids'])
    feat = torch.zeros((n_obj, feat_dim), dtype=torch.float32)  # instances missing from the cache keep zero features
    for i, cid in enumerate(one_scan['inst_ids']):
        if cid in feat_dict.keys():
            feat[i] = feat_dict[cid]
    one_scan['pc_obj_feat_gt'] = feat
    # convert to pq3d input
    obj_labels = one_scan['inst_labels']  # N
    obj_pcds = one_scan['obj_pcds']
    obj_ids = one_scan['inst_ids']
    # object filter
    excluded_labels = ['wall', 'floor', 'ceiling']
    def keep_obj(i, obj_label):
        category = int2cat[obj_label]
        # filter out background
        if category in excluded_labels:
            return False
        # filter out objects not mentioned in the sentence
        return True
    selected_obj_idxs = [i for i, obj_label in enumerate(obj_labels) if keep_obj(i, obj_label)]
    obj_labels = [obj_labels[i] for i in selected_obj_idxs]
    obj_pcds = [obj_pcds[i] for i in selected_obj_idxs]
    # subsample points
    obj_pcds = np.array([obj_pcd[np.random.choice(len(obj_pcd), size=1024,
                                                  replace=len(obj_pcd) < 1024)] for obj_pcd in obj_pcds])
    obj_fts, obj_locs, obj_boxes, rot_matrix = obj_processing_post(obj_pcds, rot_aug=False)
    data_dict = {
        "scan_id": scan_id,
        "obj_fts": obj_fts.float(),
        "obj_locs": obj_locs.float(),
        "obj_labels": torch.LongTensor(obj_labels),
        "obj_boxes": obj_boxes,
        "obj_pad_masks": torch.ones((len(obj_locs)), dtype=torch.bool),  # used for padding in collate
        "obj_ids": [obj_ids[i] for i in selected_obj_idxs]
    }
    # convert image feature
    feats = one_scan['image_obj_feat_gt']
    data_dict['mv_seg_fts'] = feats[selected_obj_idxs]
    data_dict['mv_seg_pad_masks'] = torch.ones(len(data_dict['mv_seg_fts']), dtype=torch.bool)
    # convert voxel feature
    feats = one_scan['voxel_obj_feat_gt']
    data_dict['voxel_seg_fts'] = feats[selected_obj_idxs]
    data_dict['voxel_seg_pad_masks'] = torch.ones(len(data_dict['voxel_seg_fts']), dtype=torch.bool)
    # convert point feature
    feats = one_scan['pc_obj_feat_gt']
    data_dict['pc_seg_fts'] = feats[selected_obj_idxs]
    data_dict['pc_seg_pad_masks'] = torch.ones(len(data_dict['pc_seg_fts']), dtype=torch.bool)
    # build query inputs: one query per object, initialized from the object centers and masks
    data_dict['query_locs'] = data_dict['obj_locs'].clone()
    data_dict['query_pad_masks'] = data_dict['obj_pad_masks'].clone()
    data_dict['seg_center'] = obj_locs.float()
    data_dict['seg_pad_masks'] = data_dict['obj_pad_masks']
    return data_dict
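
# form_batch: wrap a single sample into a batch of size 1, padding variable-length
# tensors (-100 for labels, 0 otherwise) and default-collating the remaining fields.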
def form_batch(data_dict):
    batch = [data_dict]
    new_batch = {}
    # merge list keys
    list_keys = [k for k, v in batch[0].items() if isinstance(v, list)]
    for k in list_keys:
        new_batch[k] = [sample.pop(k) for sample in batch]
    # merge tensor keys with padding
    padding_keys = [k for k, v in batch[0].items() if isinstance(v, torch.Tensor) and v.ndim > 0]
    for k in padding_keys:
        tensors = [sample.pop(k) for sample in batch]
        padding_value = -100 if k == 'obj_labels' else 0
        padded_tensor = pad_sequence(tensors, pad=padding_value)
        new_batch[k] = padded_tensor
    # default-collate the rest (scalars, strings)
    new_batch.update(default_collate(batch))
    return new_batch
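
# tokenize_txt: encode the text prompt with the CLIP tokenizer used by the model.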
def tokenize_txt(text):
    tokenizer_name = 'openai/clip-vit-large-patch14'
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    encoded_input = tokenizer([text], add_special_tokens=True, truncation=True)
    data_dict = {}
    data_dict['prompt'] = torch.FloatTensor(encoded_input.input_ids[0])
    data_dict['prompt_pad_masks'] = torch.ones((len(data_dict['prompt']))).bool()
    data_dict['prompt_type'] = 1  # txt
    return data_dict
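
# inference: run the model on one scan and one text prompt, returning the id of the
# grounded object (argmax over the grounding logits) and the generated text response.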
def inference(scan_id, text):
    data_dict = load_data(scan_id)
    data_dict.update(tokenize_txt(text))
    data_dict = form_batch(data_dict)
    model = Query3DUnified()
    load_msg = model.load_state_dict(torch.load(os.path.join(CKPT_DIR, 'pytorch_model.bin'), map_location='cpu'), strict=False)
    data_dict = model(data_dict)
    # calculate result id
    result_id = data_dict['obj_ids'][0][torch.argmax(data_dict['og3d_logits'][0]).item()]
    print(f"finished inference, result id is {result_id}")
    # decode language response
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    response_pred = tokenizer.batch_decode(data_dict['generation_logits'], skip_special_tokens=True)[0]
    print(f"text response is {response_pred}")
    return result_id, response_pred

if __name__ == '__main__':
    inference("scene0050_00", "chair")