import contextlib

import clip
import torch
import torch.nn as nn
from einops import rearrange
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer

from leo.img_encoder import GridFeatureExtractor2D
from leo.pcd_encoder import OSE3D
from leo.grounding_head import SequentialGroundHead
from leo.utils import get_mlp_head


def maybe_autocast(model, dtype='bf16', enabled=True):
    """Return an autocast context manager suited to the device `model` lives on.

    Args:
        model: object exposing a `.device` attribute.
        dtype: 'bf16' selects torch.bfloat16, 'fp16' selects torch.float16,
            any other value selects torch.float32.
        enabled: forwarded to `torch.cuda.amp.autocast`.

    Returns:
        `contextlib.nullcontext()` when the model is on CPU, otherwise a
        `torch.cuda.amp.autocast` context with the resolved dtype.
    """
    if dtype == 'bf16':
        dtype = torch.bfloat16
    elif dtype == 'fp16':
        # was `dtype == torch.float16` (a no-op comparison), which left the
        # string 'fp16' to be passed to autocast
        dtype = torch.float16
    else:
        dtype = torch.float32

    # if on cpu, don't use autocast
    if model.device == torch.device('cpu'):
        return contextlib.nullcontext()
    return torch.cuda.amp.autocast(dtype=dtype, enabled=enabled)


def disabled_train(self, mode=True):
    """
    Overwrite model.train with this function to make sure train/eval mode does not change anymore
    """
    return self


class SequentialGrounder(torch.nn.Module):
    """LLM-based grounder that consumes text, image tokens and 3D object tokens.

    The LLM weights are frozen (optionally wrapped with LoRA adapters); image and
    point-cloud encoders are projected into the LLM embedding space, and a
    grounding head scores objects against hidden states of grounding tokens.
    """

    def __init__(self, predict_mode=False):
        """Build the LLM, the vision encoders, the projections and the heads.

        Args:
            predict_mode: when True, `forward` dispatches to `generate`.
        """
        super().__init__()
        cfg = {
            "launch_mode": "hf",
            "model": {
                "llm": {
                    "name": "Vicuna7B",
                    "cfg_path": "/scratch/generalvision/vicuna-7b",
                    "hf_cfg_path": "huangjy-pku/vicuna-7b",
                    "truncation_side": "right",
                    "max_context_len": 256,
                    "max_out_len": 256,
                    "lora": {
                        "flag": True,
                        "rank": 16,
                        "alpha": 16,
                        "dropout": 0.0,
                        "target_modules": [
                            'q_proj', 'k_proj', 'v_proj', 'o_proj',
                            'gate_proj', 'up_proj', 'down_proj',
                        ],
                    },
                },
                "clip_txt_guidance": {
                    "flag": False,
                    "clip_out_dim": 1024,
                },
            },
        }
        llm_cfg = cfg['model']['llm']
        lora_cfg = llm_cfg['lora']
        clip_cfg = cfg['model']['clip_txt_guidance']

        self.predict_mode = predict_mode

        # LLM
        llm_name = llm_cfg['name']
        if cfg['launch_mode'] == 'hf':
            llm_cfg_path = llm_cfg['hf_cfg_path']
        else:
            llm_cfg_path = llm_cfg['cfg_path']
        llm_truncation_side = 'right'

        if 'vicuna' in llm_name.lower():
            self.llm_tokenizer = LlamaTokenizer.from_pretrained(
                llm_cfg_path, truncation_side=llm_truncation_side
            )
            self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            self.llm_model = LlamaForCausalLM.from_pretrained(llm_cfg_path, torch_dtype=torch.float16)
            # the pad token was just added, so the embedding table must grow with it
            self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))
        else:
            self.llm_tokenizer = AutoTokenizer.from_pretrained(
                llm_cfg_path, truncation_side=llm_truncation_side
            )
            self.llm_model = AutoModelForCausalLM.from_pretrained(llm_cfg_path, torch_dtype=torch.float16)

        # freeze the LLM and pin it to eval mode
        for param in self.llm_model.parameters():
            param.requires_grad = False
        self.llm_model.eval()
        self.llm_model.train = disabled_train

        # 2D vision
        self.img_encoder = GridFeatureExtractor2D()
        self.img_proj = nn.Linear(
            self.img_encoder.out_channels, self.llm_model.config.hidden_size
        )

        # 3D vision
        self.pcd_encoder = OSE3D()
        self.pcd_proj = nn.Linear(256, self.llm_model.config.hidden_size)

        # type embedding
        # self.img_type_embed = nn.Parameter(torch.zeros(self.llm_model.config.hidden_size), requires_grad=True)
        # self.pcd_type_embed = nn.Parameter(torch.zeros(self.llm_model.config.hidden_size), requires_grad=True)

        # LoRA
        if lora_cfg['flag']:
            lora_config = LoraConfig(
                r=lora_cfg['rank'],
                lora_alpha=lora_cfg['alpha'],
                target_modules=lora_cfg['target_modules'],
                lora_dropout=lora_cfg['dropout'],
                bias='none',
                modules_to_save=[],
            )
            self.llm_model = get_peft_model(self.llm_model, peft_config=lora_config)

        self.max_context_len = 256
        self.max_out_len = 256

        # additional text x multi-modal tokens fusion
        self.clip_txt_guidance = clip_cfg['flag']
        if self.clip_txt_guidance:
            self.clip_model = clip.load('RN50')[0]
            for param in self.clip_model.parameters():
                param.requires_grad = False
            self.clip_model.eval()
            self.clip_model.train = disabled_train
            # `clip_txt_guidance` lives under cfg['model']; the previous lookup
            # cfg['clip_txt_guidance'] raised KeyError whenever the flag was on
            self.clip_proj = nn.Linear(clip_cfg['clip_out_dim'], self.llm_model.config.hidden_size)

        # grounding head
        self.ground_head = SequentialGroundHead()
        self.obj_cls_head = get_mlp_head(4096, 768, 607, 0.3)
        self.pre_grounding = True

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    def _encode_multimodal_tokens(self, data_dict):
        """Record the batch size and fill `obj_tokens` / `img_tokens` in `data_dict`.

        Runs the point-cloud encoder only when `obj_tokens` is not already
        present, then projects object and image features to the LLM hidden size.
        Returns the (possibly replaced) `data_dict`.
        """
        device = self.device
        data_dict['bs'] = len(data_dict['prompt_after_obj'])

        if 'obj_tokens' not in data_dict:
            # obtain obj tokens
            data_dict = self.pcd_encoder(data_dict)

        # TO CHANGE FOR DEBUG
        # self.llm_model.float()
        # data_dict['obj_tokens'] = torch.zeros((data_dict['obj_locs'].shape[0], data_dict['obj_locs'].shape[1], 256)).to(device=device)

        data_dict['obj_tokens'] = self.pcd_proj(data_dict['obj_tokens'].to(device))
        # data_dict['obj_tokens'] = data_dict['obj_tokens'] + self.pcd_type_embed

        data_dict['img_tokens'] = self.img_proj(self.img_encoder(data_dict['img_fts']))
        # data_dict['img_tokens'] = data_dict['img_tokens'] + self.img_type_embed
        return data_dict

    def _find_grounding_tokens(self, token_ids):
        """Return `(batch_indices, token_indices)` of grounding tokens in `token_ids`.

        Column 0 is ignored so that the leading start-of-sequence token is
        never reported as a grounding position.
        """
        # record position for special token [SOS]
        # NOTE(review): the token literal below is an empty string; it looks as
        # if an angle-bracket token was stripped from the source — confirm the
        # intended token before relying on this lookup.
        grd_token_id = self.llm_tokenizer.convert_tokens_to_ids([''])[0]
        ids_without_first = token_ids.clone()
        ids_without_first[:, 0] = -100
        return (ids_without_first == grd_token_id).nonzero(as_tuple=True)

    def build_right_justified_sequence(self, data_dict):
        """
        Concat six sequences: `prompt_before_obj`, `prompt_middle_1`, `img_tokens`,
        `prompt_middle_2`, `obj_tokens`, `prompt_after_obj`.

        `prompt_after_obj` is tokenized with right padding; for every sample its
        padding is moved to the front of the row, so each row is laid out as
        [post padding, prompt_before_obj, prompt_middle_1, img_tokens,
        prompt_middle_2, obj_tokens, prompt_after_obj tokens] and the real
        tokens end at the last position (right justified for a causal LM).

        Returns:
            inputs_embeds: (B, l1+l2+l3, D) embeddings.
            attention_mask: (B, l1+l2+l3) mask, with the dtype of `obj_masks`.
            (l1, l2, l3): lengths of the pre / middle / post chunks.
        """
        device = self.device
        bs = len(data_dict['prompt_before_obj'])
        embed_tokens = self.llm_model.get_input_embeddings()

        self.llm_tokenizer.padding_side = 'left'
        text_input_tokens_pre = self.llm_tokenizer(
            data_dict['prompt_before_obj'],
            return_tensors='pt',
            padding='longest',
        ).to(device)   # [PAD, BOS, tokens], (B, T1)

        text_input_tokens_mid1 = self.llm_tokenizer(
            data_dict['prompt_middle_1'],
            return_tensors='pt',
            padding='longest',
        ).to(device)

        img_tokens = data_dict['img_tokens'].to(device)
        img_masks = data_dict['img_masks'].to(device)
        # broadcast the per-sample image flag over all image tokens
        img_masks = img_masks.reshape(-1, 1).repeat(1, img_tokens.size(1))

        text_input_tokens_mid2 = self.llm_tokenizer(
            data_dict['prompt_middle_2'],
            return_tensors='pt',
            padding='longest',
        ).to(device)

        obj_tokens = data_dict['obj_tokens'].to(device)
        obj_masks = data_dict['obj_masks'].to(device)

        # additional clip fusion
        if self.clip_txt_guidance:
            with torch.no_grad():
                clip_fts = self.clip_model.encode_text(
                    clip.tokenize(data_dict['prompt_after_obj'], truncate=True).to(device)
                )
            clip_fts = self.clip_proj(clip_fts)
            # B, N, C
            img_tokens = torch.einsum('bnc,bc->bnc', img_tokens, clip_fts)
            obj_tokens = torch.einsum('bnc,bc->bnc', obj_tokens, clip_fts)

        self.llm_tokenizer.padding_side = 'right'   # no need to be 'left', as padding tokens will be shifted
        self.llm_tokenizer.truncation_side = 'left'   # truncate history
        text_input_tokens_post = self.llm_tokenizer(
            data_dict['prompt_after_obj'],
            return_tensors='pt',
            padding='longest',
            truncation=True,
            max_length=self.max_context_len,
        ).to(device)   # [BOS, tokens, PAD], (B, T3)

        assert text_input_tokens_mid1.attention_mask.all() and text_input_tokens_mid2.attention_mask.all(), \
            "prompt_middle should be the same and thus no padding"

        # remove bos, make "tokenize subseq and concat" equivalent to "tokenize the whole seq"
        mid1_ids = text_input_tokens_mid1.input_ids[:, 1:]
        mid1_mask = text_input_tokens_mid1.attention_mask[:, 1:]
        mid2_ids = text_input_tokens_mid2.input_ids[:, 1:]
        mid2_mask = text_input_tokens_mid2.attention_mask[:, 1:]
        post_ids = text_input_tokens_post.input_ids[:, 1:]
        post_mask = text_input_tokens_post.attention_mask[:, 1:]

        for i in range(bs):
            if not img_masks[i].any():
                # no image input, also mask the text prompt for image tokens
                mid1_mask[i].fill_(0)

        inputs_embeds_pre = embed_tokens(text_input_tokens_pre.input_ids)
        inputs_embeds_mid1 = embed_tokens(mid1_ids)
        inputs_embeds_mid2 = embed_tokens(mid2_ids)
        inputs_embeds_post = embed_tokens(post_ids)

        # since img_tokens, prompt_mid, obj_tokens are fixed length without padding, we concat them first
        inputs_embeds_mid = torch.cat([inputs_embeds_mid1, img_tokens, inputs_embeds_mid2, obj_tokens], dim=1)
        attn_mask_mid = torch.cat([mid1_mask, img_masks, mid2_mask, obj_masks], dim=1)

        # number of padding positions per sample in the post prompt; pulled to
        # Python ints once instead of synchronising on every loop iteration
        post_pad_lengths = torch.logical_not(post_mask).sum(-1).tolist()

        bs, l1, hidden_dim = inputs_embeds_pre.shape
        _, l2, _ = inputs_embeds_mid.shape
        _, l3, _ = inputs_embeds_post.shape

        inputs_embeds = torch.zeros(bs, l1 + l2 + l3, hidden_dim, dtype=inputs_embeds_pre.dtype, device=device)
        attention_mask = torch.zeros(bs, l1 + l2 + l3, dtype=obj_masks.dtype, device=device)

        # assign by chunks
        for i, post_pad_len in enumerate(post_pad_lengths):
            post_valid_len = l3 - post_pad_len
            pre_start = post_pad_len
            mid_start = pre_start + l1
            post_start = mid_start + l2

            # padding of the post prompt goes to the very front; its mask stays 0.
            # Slicing by `post_valid_len` (rather than negative indices) also
            # covers post_pad_len == 0 and l3 == 0 without a special case.
            inputs_embeds[i, :post_pad_len] = inputs_embeds_post[i, post_valid_len:]
            inputs_embeds[i, post_start:] = inputs_embeds_post[i, :post_valid_len]
            attention_mask[i, post_start:] = 1

            inputs_embeds[i, pre_start:mid_start] = inputs_embeds_pre[i]
            attention_mask[i, pre_start:mid_start] = text_input_tokens_pre.attention_mask[i]

            inputs_embeds[i, mid_start:post_start] = inputs_embeds_mid[i]
            attention_mask[i, mid_start:post_start] = attn_mask_mid[i]

        return inputs_embeds, attention_mask, (l1, l2, l3)

    def forward(self, data_dict):
        """
        data_dict requires keys:
        # input
        prompt_before_obj: list of str, (B,)
        prompt_middle_1: list of str, (B,)
        prompt_middle_2: list of str, (B,)
        prompt_after_obj: list of str, (B,)
        obj_fts: (B, N, P, 6), xyz + rgb
        obj_masks: (B, N), 1 valid and 0 masked
        obj_locs: (B, N, 6), xyz + whd
        anchor_locs: (B, 3)
        anchor_orientation: (B, C)
        img_fts: (B, 3, H, W), rgb
        img_masks: (B, 1), 1 valid and 0 masked
        # output
        output_gt: list of str, (B,)

        When `self.predict_mode` is set, this delegates to `generate`.
        Otherwise `data_dict` is returned with `llm_logits`, `llm_labels`,
        `num_tokens_for_loss`, `ground_logits` (None when the targets contain
        no grounding token) and `obj_cls_post_logits` added.
        """
        if self.predict_mode:
            return self.generate(data_dict=data_dict)

        device = self.device
        data_dict = self._encode_multimodal_tokens(data_dict)

        # build input embdes and record prompt position
        inputs_embeds, attention_mask, input_length = self.build_right_justified_sequence(data_dict=data_dict)
        # (B, T1+O+T2, D), (B, T1+O+T2)
        prompt_length = sum(input_length)
        obj_token_length = data_dict['obj_masks'].shape[1]

        self.llm_tokenizer.padding_side = 'right'
        self.llm_tokenizer.truncation_side = 'right'
        text_output_tokens = self.llm_tokenizer(
            [t + self.llm_tokenizer.eos_token for t in data_dict['output_gt']],
            return_tensors='pt',
            padding='longest',
            truncation=True,
            max_length=self.max_out_len,
        ).to(device)

        grd_ind_0, grd_ind_1 = self._find_grounding_tokens(text_output_tokens.input_ids)

        text_output_embeds = self.llm_model.get_input_embeddings()(text_output_tokens.input_ids)   # (B, T3, D)
        inputs_embeds = torch.cat([inputs_embeds, text_output_embeds], dim=1)   # (B, T1+O+T2+T3, D)
        attention_mask = torch.cat([attention_mask, text_output_tokens.attention_mask], dim=1)   # (B, T1+O+T2+T3)

        # construct targets
        targets = torch.zeros_like(attention_mask).long().fill_(-100)   # (B, T1+O+T2+T3)

        # only apply loss to answer tokens
        targets_idx = text_output_tokens.attention_mask.bool()
        targets[:, -targets_idx.shape[1]:][targets_idx] = text_output_tokens.input_ids[targets_idx]

        # do not predict bos token, regard it as condition instead
        targets[:, -targets_idx.shape[1]] = -100

        with maybe_autocast(self):
            outputs = self.llm_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                output_hidden_states=True,
            )

        logits = outputs.logits.float()
        last_hidden_state = outputs.hidden_states[-1]

        # different from the loss inside `llm_model.forward`, here we take mean of each sequence instead of sum
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = targets[..., 1:].contiguous()
        num_tokens_for_loss = (shift_labels >= 0).int().sum(1)   # (B,)

        shift_logits = rearrange(shift_logits, 'b t v -> (b t) v')
        shift_labels = rearrange(shift_labels, 'b t -> (b t)')
        shift_labels = shift_labels.to(shift_logits.device)

        # record for llm loss
        data_dict['llm_logits'] = shift_logits
        data_dict['llm_labels'] = shift_labels
        data_dict['num_tokens_for_loss'] = num_tokens_for_loss

        # record for grounding loss
        grd_list = []
        obj_list = []
        mask_list = []
        obj_end = input_length[0] + input_length[1]
        obj_start = obj_end - obj_token_length
        for batch_ind, grd_token_ind in zip(grd_ind_0, grd_ind_1):
            if self.pre_grounding:
                # ground against the projected object tokens fed into the LLM
                output_obj_tokens = data_dict['obj_tokens'][batch_ind]
            else:
                # ground against the LLM's hidden states at the object positions
                output_obj_tokens = last_hidden_state[batch_ind, obj_start:obj_end, :]
            grd_pos = prompt_length + grd_token_ind
            output_grd_tokens = last_hidden_state[batch_ind, grd_pos:grd_pos + 1, :]
            grd_list.append(output_grd_tokens)
            obj_list.append(output_obj_tokens)
            mask_list.append(data_dict['obj_masks'][batch_ind])

        if grd_list:
            output_obj = torch.stack(obj_list).float()
            output_grd = torch.stack(grd_list).float()
            data_dict['ground_logits'] = self.ground_head(output_obj, output_grd, torch.stack(mask_list))
        else:
            # torch.stack fails on an empty list; mirror `generate` instead
            data_dict['ground_logits'] = None
        # data_dict['ground_label'] = torch.concat(data_dict['tgt_object_id'], dim=0)

        # record for cls loss
        # obj_cls_post_embeds = last_hidden_state[:, input_length[0] + input_length[1] - obj_token_length : input_length[0] + input_length[1], :].float()
        obj_cls_post_embeds = data_dict['obj_tokens'].float()
        data_dict['obj_cls_post_logits'] = self.obj_cls_head(obj_cls_post_embeds)

        return data_dict

    @torch.no_grad()
    def generate(
        self,
        data_dict,
        use_nucleus_sampling=False,
        num_beams=5,
        max_length=256,
        min_length=1,
        top_p=0.9,
        repetition_penalty=6.0,
        length_penalty=1,
        num_captions=1,
        temperature=1,
    ):
        """
        data_dict requires the same keys as forward() except output_gt

        The sampling arguments are forwarded to `self.llm_model.generate`.
        Returns `data_dict` with `ground_logits` (None when no grounding token
        was generated), `grd_batch_ind_list` and `output_txt` added.
        """
        device = self.device
        data_dict = self._encode_multimodal_tokens(data_dict)
        bs = data_dict['bs']

        inputs_embeds, attention_mask, _ = self.build_right_justified_sequence(data_dict=data_dict)

        # give bos token as condition
        bos_tokens = self.llm_tokenizer(
            [self.llm_tokenizer.bos_token] * bs,
            return_tensors='pt',
        ).to(device)
        bos_tokens_ids = bos_tokens.input_ids[:, 0:1]   # (B, 1)
        bos_tokens_attn = bos_tokens.attention_mask[:, 0:1]   # (B, 1)

        # prepare a `bos_token`
        bos_embeds = self.llm_model.get_input_embeddings()(bos_tokens_ids)   # (B, 1, D)
        inputs_embeds = torch.cat([inputs_embeds, bos_embeds], dim=1)   # (B, T1+O+T2+1, D)
        attention_mask = torch.cat([attention_mask, bos_tokens_attn], dim=1)   # (B, T1+O+T2+1)

        with maybe_autocast(self):
            outputs = self.llm_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                do_sample=use_nucleus_sampling,
                top_p=top_p,
                temperature=temperature,
                num_beams=num_beams,
                max_length=max_length,
                min_length=min_length,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                num_return_sequences=num_captions,
                return_dict_in_generate=True,
                output_hidden_states=True,
                output_scores=True,
            )

        # note output_ids_idx - 1 = step idx, because we do not predict [BOS]
        beam_indices = outputs.beam_indices   # bs x step, beam indices range (bsxbeam)
        hidden_states = outputs.hidden_states   # step x layer x (bs x beam) x token_num x hidden_dim
        outputs = outputs.sequences   # bs x output_ids

        outputs[outputs == self.llm_tokenizer.unk_token_id] = self.llm_tokenizer.eos_token_id
        # data_dict['output_tokens'] = outputs   # unable to gather variable-length tensors

        # record for grounding
        grd_ind_0, grd_ind_1 = self._find_grounding_tokens(outputs)
        grd_list = []
        grd_batch_ind_list = []
        obj_list = []
        mask_list = []
        for batch_ind, grd_token_ind in zip(grd_ind_0, grd_ind_1):
            step = grd_token_ind - 1   # grd_token_ind - 1 because first token is sos
            output_obj_tokens = data_dict['obj_tokens'][batch_ind]
            # last layer, the beam that produced this token, last position
            output_grd_tokens = hidden_states[step][-1][beam_indices[batch_ind, step]][-1].unsqueeze(0)
            grd_list.append(output_grd_tokens)
            grd_batch_ind_list.append(batch_ind)
            obj_list.append(output_obj_tokens)
            mask_list.append(data_dict['obj_masks'][batch_ind])

        if grd_list:
            output_obj = torch.stack(obj_list).float()
            output_grd = torch.stack(grd_list).float()
            data_dict['ground_logits'] = self.ground_head(output_obj, output_grd, torch.stack(mask_list))
        else:
            data_dict['ground_logits'] = None
        # data_dict['ground_label'] = torch.concat(data_dict['tgt_object_id'], dim=0)
        data_dict['grd_batch_ind_list'] = grd_batch_ind_list

        output_txt = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True)
        output_txt = [txt.strip() for txt in output_txt]
        data_dict['output_txt'] = output_txt
        return data_dict