import json
import os

import numpy as np
import torch
from torch.utils.data import default_collate
from transformers import AutoTokenizer

from pq3d.model import Query3DUnified
from pq3d.utils import LabelConverter, convert_pc_to_box, obj_processing_post, pad_sequence

ASSET_DIR = os.path.join(os.getcwd(), 'assets')
CKPT_DIR = os.path.join(os.getcwd(), 'checkpoint')

# ScanNet category names and their integer ids.
int2cat = json.load(open(os.path.join(ASSET_DIR, "meta/scannetv2_raw_categories.json"), 'r', encoding="utf-8"))
cat2int = {w: i for i, w in enumerate(int2cat)}
label_converter = LabelConverter(os.path.join(ASSET_DIR, "meta/scannetv2-labels.combined.tsv"))


def load_gt_feats(scan_id, suffix, inst_ids):
    # Load per-object features keyed by instance id; instances without a
    # stored feature keep an all-zero vector.
    feat_pth = os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_{suffix}.pth')
    feat_dict = torch.load(feat_pth)
    feat_dim = next(iter(feat_dict.values())).shape[0]
    feat = torch.zeros((len(inst_ids), feat_dim), dtype=torch.float32)
    for i, cid in enumerate(inst_ids):
        if cid in feat_dict:
            feat[i] = feat_dict[cid]
    return feat


def load_data(scan_id):
    one_scan = {}
    # load scan
    pcd_data = torch.load(os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_pcd.pth'))
    inst_to_label = torch.load(os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_inst.pth'))
    points, colors, instance_labels = pcd_data[0], pcd_data[1], pcd_data[-1]
    colors = colors / 127.5 - 1  # normalize RGB from [0, 255] to [-1, 1]
    pcds = np.concatenate([points, colors], 1)
    one_scan['pcds'] = pcds
    one_scan['instance_labels'] = instance_labels
    one_scan['inst_to_label'] = inst_to_label

    # convert to gt objects: one point cloud per labeled instance
    obj_pcds = []
    inst_ids = []
    inst_labels = []
    bg_indices = np.full((points.shape[0],), True, dtype=np.bool_)
    for inst_id, label in inst_to_label.items():
        if label not in cat2int:
            continue
        mask = instance_labels == inst_id
        if np.sum(mask) == 0:
            continue
        obj_pcds.append(pcds[mask])
        inst_ids.append(inst_id)
        inst_labels.append(cat2int[label])
        if label not in ['wall', 'floor', 'ceiling']:
            bg_indices[mask] = False
    one_scan['obj_pcds'] = obj_pcds
    one_scan['inst_labels'] = inst_labels
    one_scan['inst_ids'] = inst_ids
    one_scan['bg_pcds'] = pcds[bg_indices]

    # calculate boxes for matching
    obj_center = []
    obj_box_size = []
    for obj_pcd in obj_pcds:
        _c, _b = convert_pc_to_box(obj_pcd)
        obj_center.append(_c)
        obj_box_size.append(_b)
    one_scan['obj_loc'] = obj_center
    one_scan['obj_box'] = obj_box_size

    # load per-object image, voxel, and point features (GT segmentation)
    one_scan['image_obj_feat_gt'] = load_gt_feats(scan_id, 'img_gt', inst_ids)
    one_scan['voxel_obj_feat_gt'] = load_gt_feats(scan_id, 'voxel_gt', inst_ids)
    one_scan['pc_obj_feat_gt'] = load_gt_feats(scan_id, 'point_gt', inst_ids)

    # convert to pq3d input
    obj_labels = one_scan['inst_labels']  # N
    obj_pcds = one_scan['obj_pcds']
    obj_ids = one_scan['inst_ids']

    # object filter: drop background categories
    excluded_labels = ['wall', 'floor', 'ceiling']
    selected_obj_idxs = [i for i, obj_label in enumerate(obj_labels)
                         if int2cat[obj_label] not in excluded_labels]
    obj_labels = [obj_labels[i] for i in selected_obj_idxs]
    obj_pcds = [obj_pcds[i] for i in selected_obj_idxs]

    # subsample each object to 1024 points (with replacement if it has fewer)
    obj_pcds = np.array([obj_pcd[np.random.choice(len(obj_pcd), size=1024, replace=len(obj_pcd) < 1024)]
                         for obj_pcd in obj_pcds])
    obj_fts, obj_locs, obj_boxes, rot_matrix = obj_processing_post(obj_pcds, rot_aug=False)

    data_dict = {
        "scan_id": scan_id,
        "obj_fts": obj_fts.float(),
        "obj_locs": obj_locs.float(),
        "obj_labels": torch.LongTensor(obj_labels),
        "obj_boxes": obj_boxes,
        "obj_pad_masks": torch.ones((len(obj_locs)), dtype=torch.bool),  # used for padding in collate
        "obj_ids": [obj_ids[i] for i in selected_obj_idxs],
    }

    # per-modality object features, restricted to the kept objects
    data_dict['mv_seg_fts'] = one_scan['image_obj_feat_gt'][selected_obj_idxs]
    data_dict['mv_seg_pad_masks'] = torch.ones(len(data_dict['mv_seg_fts']), dtype=torch.bool)
    data_dict['voxel_seg_fts'] = one_scan['voxel_obj_feat_gt'][selected_obj_idxs]
    data_dict['voxel_seg_pad_masks'] = torch.ones(len(data_dict['voxel_seg_fts']), dtype=torch.bool)
    data_dict['pc_seg_fts'] = one_scan['pc_obj_feat_gt'][selected_obj_idxs]
    data_dict['pc_seg_pad_masks'] = torch.ones(len(data_dict['pc_seg_fts']), dtype=torch.bool)

    # queries are initialized at the object locations
    data_dict['query_locs'] = data_dict['obj_locs'].clone()
    data_dict['query_pad_masks'] = data_dict['obj_pad_masks'].clone()
    data_dict['seg_center'] = obj_locs.float()
    data_dict['seg_pad_masks'] = data_dict['obj_pad_masks']
    return data_dict


def form_batch(data_dict):
    # Wrap a single sample into a batch of size one.
    batch = [data_dict]
    new_batch = {}
    # merge list keys
    list_keys = [k for k, v in batch[0].items() if isinstance(v, list)]
    for k in list_keys:
        new_batch[k] = [sample.pop(k) for sample in batch]
    # merge tensor keys with padding
    padding_keys = [k for k, v in batch[0].items() if isinstance(v, torch.Tensor) and v.ndim > 0]
    for k in padding_keys:
        tensors = [sample.pop(k) for sample in batch]
        padding_value = -100 if k == 'obj_labels' else 0  # -100 is ignored by the loss
        new_batch[k] = pad_sequence(tensors, pad=padding_value)
    # collate the remaining scalar / string keys
    new_batch.update(default_collate(batch))
    return new_batch


def tokenize_txt(text):
    tokenizer = AutoTokenizer.from_pretrained('openai/clip-vit-large-patch14')
    encoded_input = tokenizer([text], add_special_tokens=True, truncation=True)
    data_dict = {}
    data_dict['prompt'] = torch.LongTensor(encoded_input.input_ids[0])  # token ids are integers
    data_dict['prompt_pad_masks'] = torch.ones((len(data_dict['prompt']))).bool()
    data_dict['prompt_type'] = 1  # 1 == text prompt
    return data_dict


def inference(scan_id, text):
    data_dict = load_data(scan_id)
    data_dict.update(tokenize_txt(text))
    data_dict = form_batch(data_dict)
    model = Query3DUnified()
    load_msg = model.load_state_dict(
        torch.load(os.path.join(CKPT_DIR, 'pytorch_model.bin'), map_location='cpu'),
        strict=False)
    model.eval()
    with torch.no_grad():
        data_dict = model(data_dict)
    # pick the object id whose grounding logit is highest
    result_id = data_dict['obj_ids'][0][torch.argmax(data_dict['og3d_logits'][0]).item()]
    print(f"finished inference, result id is {result_id}")
    # decode the language response (generated token ids, despite the key name)
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    response_pred = tokenizer.batch_decode(data_dict['generation_logits'], skip_special_tokens=True)[0]
    print(f"text response is {response_pred}")
    return result_id, response_pred


if __name__ == '__main__':
    inference("scene0050_00", "chair")
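    # Further example queries (hypothetical prompts; any scan placed under
    # assets/inputs/<scan_id>/ with the same file layout should work):
    # inference("scene0050_00", "the chair closest to the table")
    # inference("scene0050_00", "where is the trash can?")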