# PQ3D-Demo: pq3d/inference.py
import json
import os
import torch
import numpy as np
from pq3d.model import Query3DUnified
from pq3d.utils import LabelConverter, convert_pc_to_box, obj_processing_post, pad_sequence
from torch.utils.data import default_collate
from transformers import AutoTokenizer
ASSET_DIR = os.path.join(os.getcwd(), 'assets')
CKPT_DIR = os.path.join(os.getcwd(), 'checkpoint')
int2cat = json.load(open(os.path.join(ASSET_DIR, "meta/scannetv2_raw_categories.json"), 'r', encoding="utf-8"))
cat2int = {w: i for i, w in enumerate(int2cat)}
label_converter = LabelConverter(os.path.join(ASSET_DIR, "meta/scannetv2-labels.combined.tsv"))
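# int2cat maps a label index to its raw ScanNet category name; cat2int is the
# inverse lookup. LabelConverter (from pq3d.utils) is assumed here to map
# between raw ScanNet labels and the benchmark label space.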
def load_data(scan_id):
one_scan = {}
# load scan
pcd_data = torch.load(os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_pcd.pth'))
inst_to_label = torch.load(os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_inst.pth'))
points, colors, instance_labels = pcd_data[0], pcd_data[1], pcd_data[-1]
colors = colors / 127.5 - 1
pcds = np.concatenate([points, colors], 1)
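    # pcds is (N, 6): xyz coordinates followed by RGB normalized to [-1, 1]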
one_scan['pcds'] = pcds
one_scan['instance_labels'] = instance_labels
one_scan['inst_to_label'] = inst_to_label
# convert to gt object
obj_pcds = []
inst_ids = []
inst_labels = []
    bg_indices = np.full((points.shape[0],), True, dtype=np.bool_)
    for inst_id, inst_label in inst_to_label.items():
        if inst_label not in cat2int:
            continue
        mask = instance_labels == inst_id
        if np.sum(mask) == 0:
            continue
        obj_pcds.append(pcds[mask])
        inst_ids.append(inst_id)
        inst_labels.append(cat2int[inst_label])
        # points of foreground objects are excluded from the background
        if inst_label not in ['wall', 'floor', 'ceiling']:
            bg_indices[mask] = False
one_scan['obj_pcds'] = obj_pcds
one_scan['inst_labels'] = inst_labels
one_scan['inst_ids'] = inst_ids
one_scan['bg_pcds'] = pcds[bg_indices]
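    # bg_pcds keeps everything still marked as background: wall/floor/ceiling
    # points plus any points not covered by a labeled instance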
# calculate box for matching
obj_center = []
obj_box_size = []
for obj_pcd in obj_pcds:
_c, _b = convert_pc_to_box(obj_pcd)
obj_center.append(_c)
obj_box_size.append(_b)
one_scan['obj_loc'] = obj_center
one_scan['obj_box'] = obj_box_size
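    # convert_pc_to_box is assumed to return the (center, size) of each
    # object's axis-aligned bounding box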
    # load per-object GT features for each modality (multi-view image, voxel, point)
    for feat_key, file_suffix in [('image_obj_feat_gt', 'img_gt'),
                                  ('voxel_obj_feat_gt', 'voxel_gt'),
                                  ('pc_obj_feat_gt', 'point_gt')]:
        feat_pth = os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_{file_suffix}.pth')
        feat_dict = torch.load(feat_pth)
        feat_dim = next(iter(feat_dict.values())).shape[0]
        feat = torch.zeros((len(one_scan['inst_ids']), feat_dim), dtype=torch.float32)
        for i, cid in enumerate(one_scan['inst_ids']):
            if cid in feat_dict:
                feat[i] = feat_dict[cid]
        one_scan[feat_key] = feat
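    # (assumed on-disk format: each *_gt.pth file stores a dict mapping
    # instance id -> 1-D feature tensor; objects absent from the dict
    # keep an all-zero feature row)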
# convert to pq3d input
obj_labels = one_scan['inst_labels'] # N
obj_pcds = one_scan['obj_pcds']
obj_ids = one_scan['inst_ids']
    # object filter: drop background categories (wall / floor / ceiling)
    excluded_labels = ['wall', 'floor', 'ceiling']
    selected_obj_idxs = [i for i, obj_label in enumerate(obj_labels)
                         if int2cat[obj_label] not in excluded_labels]
obj_labels = [obj_labels[i] for i in selected_obj_idxs]
obj_pcds = [obj_pcds[i] for i in selected_obj_idxs]
# subsample points
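    # each object is resampled to exactly 1024 points; sampling is done with
    # replacement only when the object has fewer than 1024 points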
obj_pcds = np.array([obj_pcd[np.random.choice(len(obj_pcd), size=1024,
replace=len(obj_pcd) < 1024)] for obj_pcd in obj_pcds])
obj_fts, obj_locs, obj_boxes, rot_matrix = obj_processing_post(obj_pcds, rot_aug=False)
data_dict = {
"scan_id": scan_id,
"obj_fts": obj_fts.float(),
"obj_locs": obj_locs.float(),
"obj_labels": torch.LongTensor(obj_labels),
"obj_boxes": obj_boxes,
"obj_pad_masks": torch.ones((len(obj_locs)), dtype=torch.bool), # used for padding in collate
"obj_ids": [obj_ids[i] for i in selected_obj_idxs]
}
    # copy per-modality segment features for the selected (foreground) objects
    for src_key, fts_key, mask_key in [
            ('image_obj_feat_gt', 'mv_seg_fts', 'mv_seg_pad_masks'),
            ('voxel_obj_feat_gt', 'voxel_seg_fts', 'voxel_seg_pad_masks'),
            ('pc_obj_feat_gt', 'pc_seg_fts', 'pc_seg_pad_masks')]:
        data_dict[fts_key] = one_scan[src_key][selected_obj_idxs]
        data_dict[mask_key] = torch.ones(len(data_dict[fts_key]), dtype=torch.bool)
    # build query- and segment-level inputs
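    # one query per GT object: query locations and masks mirror the
    # object-level ones, since this demo runs on ground-truth segments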
data_dict['query_locs'] = data_dict['obj_locs'].clone()
data_dict['query_pad_masks'] = data_dict['obj_pad_masks'].clone()
data_dict['seg_center'] = obj_locs.float()
data_dict['seg_pad_masks'] = data_dict['obj_pad_masks']
return data_dict
def form_batch(data_dict):
batch = [data_dict]
new_batch = {}
# merge list keys
list_keys = [k for k, v in batch[0].items() if isinstance(v, list)]
for k in list_keys:
new_batch[k] = [sample.pop(k) for sample in batch]
# merge tensor
padding_keys = [k for k, v in batch[0].items() if isinstance(v, torch.Tensor) and v.ndim > 0]
for k in padding_keys:
tensors = [sample.pop(k) for sample in batch]
padding_value = -100 if k == 'obj_labels' else 0
padded_tensor = pad_sequence(tensors, pad=padding_value)
new_batch[k] = padded_tensor
# others
new_batch.update(default_collate(batch))
return new_batch
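# Sketch of form_batch's output for this single-sample demo (assuming
# pq3d.utils.pad_sequence behaves like torch's pad_sequence with
# batch_first=True): every tensor just gains a leading batch dim of 1,
# e.g. obj_locs: (N, ...) -> (1, N, ...); real padding with 0 (or -100 for
# obj_labels) only matters once samples of different lengths are batched.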
def tokenize_txt(text):
tokenizer_name = 'openai/clip-vit-large-patch14'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
encoded_input = tokenizer([text], add_special_tokens=True, truncation=True)
data_dict = {}
data_dict['prompt'] = torch.FloatTensor(encoded_input.input_ids[0])
data_dict['prompt_pad_masks'] = torch.ones((len(data_dict['prompt']))).bool()
data_dict['prompt_type'] = 1 # txt
return data_dict
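# NOTE: tokenize_txt returns the CLIP BPE ids wrapped in start/end special
# tokens. The ids are stored as a float tensor above (kept as in the original
# demo); the model is assumed to cast them back to integer ids internally.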
def inference(scan_id, text):
data_dict = load_data(scan_id)
data_dict.update(tokenize_txt(text))
data_dict = form_batch(data_dict)
    model = Query3DUnified()
    load_msg = model.load_state_dict(torch.load(os.path.join(CKPT_DIR, 'pytorch_model.bin'), map_location='cpu'), strict=False)
    model.eval()
    with torch.no_grad():
        data_dict = model(data_dict)
    # calculate result id: pick the object with the highest grounding logit
    result_id = data_dict['obj_ids'][0][torch.argmax(data_dict['og3d_logits'][0]).item()]
    print(f"finished inference; result id is {result_id}")
    # decode the generated language response
tokenizer = AutoTokenizer.from_pretrained("t5-small")
response_pred = tokenizer.batch_decode(data_dict['generation_logits'], skip_special_tokens=True)[0]
print(f"text response is {response_pred}")
return result_id, response_pred
if __name__ == '__main__':
inference("scene0050_00", "chair")