import json
import os

import numpy as np
import torch
from torch.utils.data import default_collate
from transformers import AutoTokenizer

from pq3d.model import Query3DUnified
from pq3d.utils import LabelConverter, convert_pc_to_box, obj_processing_post, pad_sequence

ASSET_DIR = os.path.join(os.getcwd(), 'assets')
CKPT_DIR = os.path.join(os.getcwd(), 'checkpoint')

int2cat = json.load(open(os.path.join(ASSET_DIR, "meta/scannetv2_raw_categories.json"), 'r', encoding="utf-8"))
cat2int = {w: i for i, w in enumerate(int2cat)}
label_converter = LabelConverter(os.path.join(ASSET_DIR, "meta/scannetv2-labels.combined.tsv"))
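
# Expected on-disk layout, as implied by the paths used below:
#   assets/meta/scannetv2_raw_categories.json
#   assets/meta/scannetv2-labels.combined.tsv
#   assets/inputs/{scan_id}/{scan_id}_pcd.pth      (points, colors, ..., instance labels)
#   assets/inputs/{scan_id}/{scan_id}_inst.pth     (instance id -> category name)
#   assets/inputs/{scan_id}/{scan_id}_{img,voxel,point}_gt.pth  (per-object features)
#   checkpoint/pytorch_model.bin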


def load_data(scan_id):
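    """Assemble the model inputs for a single ScanNet scan.

    Loads the scene point cloud, splits it into per-instance object point
    clouds, attaches the precomputed multi-view/voxel/point features for
    each object, and returns a flat dict of tensors ready for batching.
    """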
    one_scan = {}

    # Load the scene point cloud and the per-point instance annotations.
    pcd_data = torch.load(os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_pcd.pth'))
    inst_to_label = torch.load(os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_inst.pth'))
    points, colors, instance_labels = pcd_data[0], pcd_data[1], pcd_data[-1]
    colors = colors / 127.5 - 1  # normalize RGB from [0, 255] to [-1, 1]
    pcds = np.concatenate([points, colors], 1)
    one_scan['pcds'] = pcds
    one_scan['instance_labels'] = instance_labels
    one_scan['inst_to_label'] = inst_to_label

    # Split the scene into per-instance point clouds. Points that belong to
    # wall/floor/ceiling (or to no labeled instance) are kept as background.
    obj_pcds = []
    inst_ids = []
    inst_labels = []
    bg_indices = np.ones(points.shape[0], dtype=bool)
    for inst_id, label in inst_to_label.items():
        if label in cat2int:
            mask = instance_labels == inst_id
            if np.sum(mask) == 0:
                continue
            obj_pcds.append(pcds[mask])
            inst_ids.append(inst_id)
            inst_labels.append(cat2int[label])
            if label not in ('wall', 'floor', 'ceiling'):
                bg_indices[mask] = False
    one_scan['obj_pcds'] = obj_pcds
    one_scan['inst_labels'] = inst_labels
    one_scan['inst_ids'] = inst_ids
    one_scan['bg_pcds'] = pcds[bg_indices]

    # Axis-aligned bounding box (center, size) for every object.
    obj_center = []
    obj_box_size = []
    for obj_pcd in obj_pcds:
        _c, _b = convert_pc_to_box(obj_pcd)
        obj_center.append(_c)
        obj_box_size.append(_b)
    one_scan['obj_loc'] = obj_center
    one_scan['obj_box'] = obj_box_size

    # Gather the precomputed per-object features for each modality
    # (multi-view image, voxel, point cloud). Objects that have no entry in
    # a feature file keep a zero vector.
    for modality, key in [('img', 'image_obj_feat_gt'),
                          ('voxel', 'voxel_obj_feat_gt'),
                          ('point', 'pc_obj_feat_gt')]:
        feat_pth = os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_{modality}_gt.pth')
        feat_dict = torch.load(feat_pth)
        feat_dim = next(iter(feat_dict.values())).shape[0]
        feat = torch.zeros((len(one_scan['inst_ids']), feat_dim), dtype=torch.float32)
        for i, cid in enumerate(one_scan['inst_ids']):
            if cid in feat_dict:
                feat[i] = feat_dict[cid]
        one_scan[key] = feat

    obj_labels = one_scan['inst_labels']
    obj_pcds = one_scan['obj_pcds']
    obj_ids = one_scan['inst_ids']

    # Keep only non-structural objects as grounding candidates.
    excluded_labels = ('wall', 'floor', 'ceiling')
    selected_obj_idxs = [i for i, obj_label in enumerate(obj_labels)
                         if int2cat[obj_label] not in excluded_labels]
    obj_labels = [obj_labels[i] for i in selected_obj_idxs]
    obj_pcds = [obj_pcds[i] for i in selected_obj_idxs]

    # Resample every object to exactly 1024 points, sampling with
    # replacement when the object has fewer points than that.
    obj_pcds = np.array([obj_pcd[np.random.choice(len(obj_pcd), size=1024,
                                                  replace=len(obj_pcd) < 1024)]
                         for obj_pcd in obj_pcds])
    obj_fts, obj_locs, obj_boxes, rot_matrix = obj_processing_post(obj_pcds, rot_aug=False)
    data_dict = {
        "scan_id": scan_id,
        "obj_fts": obj_fts.float(),
        "obj_locs": obj_locs.float(),
        "obj_labels": torch.LongTensor(obj_labels),
        "obj_boxes": obj_boxes,
        "obj_pad_masks": torch.ones(len(obj_locs), dtype=torch.bool),
        "obj_ids": [obj_ids[i] for i in selected_obj_idxs],
    }

    # Select each modality's features for the kept objects; the pad masks
    # are all ones because this single sample is not padded.
    for src_key, fts_key, mask_key in [
            ('image_obj_feat_gt', 'mv_seg_fts', 'mv_seg_pad_masks'),
            ('voxel_obj_feat_gt', 'voxel_seg_fts', 'voxel_seg_pad_masks'),
            ('pc_obj_feat_gt', 'pc_seg_fts', 'pc_seg_pad_masks')]:
        data_dict[fts_key] = one_scan[src_key][selected_obj_idxs]
        data_dict[mask_key] = torch.ones(len(data_dict[fts_key]), dtype=torch.bool)

    # Queries are initialized at the candidate object locations.
    data_dict['query_locs'] = data_dict['obj_locs'].clone()
    data_dict['query_pad_masks'] = data_dict['obj_pad_masks'].clone()
    data_dict['seg_center'] = obj_locs.float()
    data_dict['seg_pad_masks'] = data_dict['obj_pad_masks']
    return data_dict


def form_batch(data_dict):
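    """Collate a single-sample dict into a batch of size one.

    Lists are nested into an outer per-sample list, variable-length tensors
    are padded (obj_labels with -100 so padded entries are ignored by the
    loss, everything else with 0), and the remaining scalar entries go
    through default_collate.
    """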
    batch = [data_dict]
    new_batch = {}

    # Collect list-valued entries as-is, one list per sample.
    list_keys = [k for k, v in batch[0].items() if isinstance(v, list)]
    for k in list_keys:
        new_batch[k] = [sample.pop(k) for sample in batch]

    # Pad variable-length tensors to a common shape across the batch.
    padding_keys = [k for k, v in batch[0].items() if isinstance(v, torch.Tensor) and v.ndim > 0]
    for k in padding_keys:
        tensors = [sample.pop(k) for sample in batch]
        padding_value = -100 if k == 'obj_labels' else 0
        new_batch[k] = pad_sequence(tensors, pad=padding_value)

    # Whatever remains (scalars, strings) is handled by default_collate.
    new_batch.update(default_collate(batch))
    return new_batch


def tokenize_txt(text):
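    """Tokenize the prompt with the CLIP tokenizer used by the model."""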
    tokenizer = AutoTokenizer.from_pretrained('openai/clip-vit-large-patch14')
    encoded_input = tokenizer([text], add_special_tokens=True, truncation=True)
    data_dict = {}
    # Token ids are integer vocabulary indices, so store them as a LongTensor.
    data_dict['prompt'] = torch.LongTensor(encoded_input.input_ids[0])
    data_dict['prompt_pad_masks'] = torch.ones(len(data_dict['prompt']), dtype=torch.bool)
    data_dict['prompt_type'] = 1
    return data_dict


def inference(scan_id, text):
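    """Ground the text prompt in the scan and decode the text response.

    Returns the instance id of the highest-scoring candidate object
    together with the generated answer string.
    """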
    data_dict = load_data(scan_id)
    data_dict.update(tokenize_txt(text))
    data_dict = form_batch(data_dict)

    model = Query3DUnified()
    load_msg = model.load_state_dict(torch.load(os.path.join(CKPT_DIR, 'pytorch_model.bin'),
                                                map_location='cpu'), strict=False)
    model.eval()  # disable dropout and other training-only behavior
    with torch.no_grad():
        data_dict = model(data_dict)

    # The grounded object is the argmax over the grounding (og3d) logits.
    result_id = data_dict['obj_ids'][0][torch.argmax(data_dict['og3d_logits'][0]).item()]
    print(f"finished inference; result id is {result_id}")

    # Decode the generated response with the T5 tokenizer.
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    response_pred = tokenizer.batch_decode(data_dict['generation_logits'], skip_special_tokens=True)[0]
    print(f"text response is {response_pred}")
    return result_id, response_pred


if __name__ == '__main__':
    inference("scene0050_00", "chair")