"""Inference helpers for running LEO's SequentialGrounder on a ScanNet-style scan.

Loads a scan's point cloud, instance labels and precomputed object features from
``ASSET_DIR/inputs/<scan_id>``, builds the prompt pieces for a task, batches the
single sample and runs :class:`leo.model.SequentialGrounder`.
"""
import json
import logging
import os

import numpy as np
import torch
from torch.utils.data import default_collate

from leo.model import SequentialGrounder
from leo.utils import LabelConverter, convert_pc_to_box, obj_processing_post, pad_sequence

logger = logging.getLogger(__name__)

ASSET_DIR = os.path.join(os.getcwd(), 'assets')
CKPT_DIR = os.path.join(os.getcwd(), 'checkpoint/leo')

# int2cat: category names indexed by integer label; cat2int is the inverse mapping.
with open(os.path.join(ASSET_DIR, "meta/scannetv2_raw_categories.json"), 'r', encoding="utf-8") as _categories_file:
    int2cat = json.load(_categories_file)
cat2int = {w: i for i, w in enumerate(int2cat)}
label_converter = LabelConverter(os.path.join(ASSET_DIR, "meta/scannetv2-labels.combined.tsv"))

role_prompt = (
    "You are an AI visual assistant situated in a 3D scene. "
    "You can perceive (1) an ego-view image (accessible when necessary) and (2) the objects (including yourself) in the scene (always accessible). "
    "You should properly respond to the USER's instruction according to the given visual information. "
)
egoview_prompt = "Ego-view image:"
objects_prompt = "Objects (including you) in the scene:"
task_prompt = "USER: {instruction} ASSISTANT:"

# Categories that are never passed to the model as candidate objects.
_BACKGROUND_LABELS = ('wall', 'floor', 'ceiling')


def get_prompt(instruction):
    """Return the prompt pieces that surround the visual tokens for ``instruction``."""
    return {
        'prompt_before_obj': role_prompt,
        'prompt_middle_1': egoview_prompt,
        'prompt_middle_2': objects_prompt,
        'prompt_after_obj': task_prompt.format(instruction=instruction),
    }


def get_lang(task_item):
    """Build the language part of a sample from a task dict.

    Args:
        task_item: dict with a ``'task_description'`` string and, optionally,
            ``'action_steps'``: a list of dicts each holding an ``'action'`` string.

    Returns:
        The :func:`get_prompt` dict for the task description, plus ``'output_gt'``
        when action steps are given.
    """
    data_dict = get_prompt(task_item['task_description'])
    if 'action_steps' in task_item:
        # Every action gets a trailing space and the pieces are joined with a space,
        # producing "step1.  step2. " (double spaces, trailing space). Kept exactly as-is;
        # presumably the format the checkpoint saw in training — confirm before changing.
        data_dict['output_gt'] = ' '.join(action['action'] + ' ' for action in task_item['action_steps'])
    return data_dict


def load_data(scan_id, num_points=1024):
    """Load one scan and convert it into the per-object inputs LEO expects.

    Reads from ``ASSET_DIR/inputs/<scan_id>``:
      * ``<scan_id>_pcd.pth``: indexed as [0] points, [1] colors, [-1] instance labels;
      * ``<scan_id>_inst.pth``: mapping instance id -> category name;
      * ``obj_feats.pth``: precomputed object features.

    An instance becomes an object when its category is in ``cat2int``, is not a
    background category (wall/floor/ceiling) and it owns at least one point. Each
    object's points are randomly resampled to ``num_points`` (with replacement only
    when it has fewer points), so results vary with the NumPy RNG state.

    Args:
        scan_id: scan name, e.g. ``"scene0050_00"``.
        num_points: points sampled per object (1024 was the original fixed value).

    Returns:
        dict with keys ``scan_id``, ``obj_fts``, ``obj_locs``, ``obj_labels``, ``obj_ids``,
        ``obj_feats``, ``img_fts``, ``img_masks``, ``anchor_locs``, ``anchor_orientation``,
        ``obj_masks``.

    Raises:
        ValueError: if no usable foreground object is found in the scan.
    """
    scan_dir = os.path.join(ASSET_DIR, f'inputs/{scan_id}')
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted asset files.
    pcd_data = torch.load(os.path.join(scan_dir, f'{scan_id}_pcd.pth'))
    inst_to_label = torch.load(os.path.join(scan_dir, f'{scan_id}_inst.pth'))
    obj_feats = torch.load(os.path.join(scan_dir, 'obj_feats.pth'), map_location='cpu')

    points, colors, instance_labels = pcd_data[0], pcd_data[1], pcd_data[-1]
    # Maps 0 -> -1 and 255 -> 1 (presumably 8-bit RGB input).
    colors = colors / 127.5 - 1
    pcds = np.concatenate([points, colors], 1)

    obj_pcds, obj_labels, obj_ids = [], [], []
    for inst_id, category in inst_to_label.items():
        if category not in cat2int or category in _BACKGROUND_LABELS:
            continue
        mask = instance_labels == inst_id
        if np.sum(mask) == 0:
            continue
        obj_pcds.append(pcds[mask])
        obj_labels.append(cat2int[category])
        obj_ids.append(inst_id)

    if not obj_pcds:
        raise ValueError(f"scan {scan_id!r} contains no foreground objects to ground")

    # TODO: crop objects to a max_obj_len and reorganize ids.
    obj_pcds = np.array([
        obj_pcd[np.random.choice(len(obj_pcd), size=num_points, replace=len(obj_pcd) < num_points)]
        for obj_pcd in obj_pcds
    ])
    obj_fts, obj_locs, _obj_boxes, _rot_matrix = obj_processing_post(obj_pcds, rot_aug=False)

    return {
        'scan_id': scan_id,
        'obj_fts': obj_fts.float(),
        'obj_locs': obj_locs.float(),
        'obj_labels': torch.LongTensor(obj_labels),
        'obj_ids': torch.LongTensor(obj_ids),
        # NOTE(review): obj_feats is used as stored, not re-indexed by the object filter
        # above; presumably it was precomputed for the same objects in the same order.
        'obj_feats': obj_feats.squeeze(0),
        # Extra inputs required by LEO: an all-zero placeholder ego-view image with
        # img_masks False, an origin anchor and an identity orientation quaternion (xyzw).
        'img_fts': torch.zeros(3, 224, 224),
        'img_masks': torch.tensor([False]),
        'anchor_locs': torch.zeros(3),
        'anchor_orientation': torch.tensor([0., 0., 0., 1.]),
        # All objects are real (none are padding) before batching.
        'obj_masks': torch.ones(len(obj_locs), dtype=torch.bool),
    }


def form_batch(data_dict):
    """Turn a single sample into a batch of size 1.

    Per-object tensors (``obj_fts``, ``obj_locs``, ``obj_masks``, ``obj_labels``, ``obj_ids``)
    are padded with ``pad_sequence``; every other entry goes through ``default_collate``.
    Note that the padded keys are popped from ``data_dict``, so the argument is mutated.
    """
    batch = [data_dict]
    new_batch = {}
    for key in ('obj_fts', 'obj_locs', 'obj_masks', 'obj_labels', 'obj_ids'):
        new_batch[key] = pad_sequence([sample.pop(key) for sample in batch], pad=0)
    new_batch.update(default_collate(batch))
    return new_batch


def inference(scan_id, task, predict_mode=False):
    """Run SequentialGrounder on one scan/task pair.

    Args:
        scan_id: scan name passed to :func:`load_data`.
        task: task dict passed to :func:`get_lang`.
        predict_mode: forwarded to ``SequentialGrounder``; also selects the output format.

    Returns:
        If ``predict_mode`` is falsy: one entry of the first sample's ``obj_ids`` per row of
        ``ground_logits`` (the arg-max object; 0-d tensors on the model device).
        Otherwise: a list with one dict per batch sample, holding ``'predict_object_id'``
        (list of ints) and ``'pred_plan_text'`` (the sample's ``output_txt``).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_dict = load_data(scan_id)
    data_dict.update(get_lang(task))
    data_dict = form_batch(data_dict)
    data_dict = {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in data_dict.items()
    }

    # The model and its weights are rebuilt on every call.
    model = SequentialGrounder(predict_mode)
    state_dict = torch.load(os.path.join(CKPT_DIR, 'pytorch_model.bin'), map_location='cpu')
    load_msg = model.load_state_dict(state_dict, strict=False)
    logger.debug("Checkpoint load result (strict=False): %s", load_msg)
    model.to(device)
    # Inference only: disable dropout-style training behaviour and gradient tracking.
    model.eval()
    with torch.no_grad():
        data_dict = model(data_dict)

    if not predict_mode:
        candidate_ids = data_dict['obj_ids'][0]
        return [candidate_ids[torch.argmax(logits).item()] for logits in data_dict['ground_logits']]

    ground_logits = data_dict['ground_logits']
    if ground_logits is None:
        og_pred, grd_batch_ind_list = [], []
    else:
        og_pred = torch.argmax(ground_logits, dim=1)
        # Batch index of each grounding prediction.
        grd_batch_ind_list = data_dict['grd_batch_ind_list']

    obj_ids = data_dict['obj_ids']
    response_pred = []
    for i in range(len(obj_ids)):  # batch size; form_batch always produces 1
        predict_sequence = [
            pred.item() for pred, batch_idx in zip(og_pred, grd_batch_ind_list) if batch_idx == i
        ]
        response_pred.append({
            'predict_object_id': [obj_ids[i][o].item() for o in predict_sequence],
            'pred_plan_text': data_dict['output_txt'][i],
        })
    return response_pred


if __name__ == '__main__':
    demo_task = {
        "task_description": "Find the chair and move it to the table.",
        "action_steps": [
            {"target_id": "1", "label": "chair", "action": "Find the chair."},
            {"target_id": "2", "label": "table", "action": "Move the chair to the table."},
        ],
        "scan_id": "scene0050_00",
    }
    print(inference("scene0050_00", demo_task, predict_mode=True))