Spaces:
Runtime error
Runtime error
import contextlib | |
import clip | |
import torch | |
import torch.nn as nn | |
from einops import rearrange | |
from peft import LoraConfig, get_peft_model | |
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer | |
from leo.img_encoder import GridFeatureExtractor2D | |
from leo.pcd_encoder import OSE3D | |
from leo.grounding_head import SequentialGroundHead | |
from leo.utils import get_mlp_head | |
def maybe_autocast(model, dtype='bf16', enabled=True): | |
# if on cpu, don't use autocast | |
# if on gpu, use autocast with dtype if provided, otherwise use torch.float16 | |
enable_autocast = model.device != torch.device('cpu') | |
if dtype == 'bf16': | |
dtype = torch.bfloat16 | |
elif dtype == 'fp16': | |
dtype == torch.float16 | |
else: | |
dtype = torch.float32 | |
if enable_autocast: | |
return torch.cuda.amp.autocast(dtype=dtype, enabled=enabled) | |
else: | |
return contextlib.nullcontext() | |
def disabled_train(self, mode=True): | |
""" | |
Overwrite model.train with this function to make sure train/eval mode does not change anymore | |
""" | |
return self | |
class SequentialGrounder(torch.nn.Module): | |
def __init__(self,predict_mode=False): | |
super().__init__() | |
cfg = { | |
"launch_mode": "hf", | |
"model": { | |
"llm": { | |
"name": "Vicuna7B", | |
"cfg_path": "/scratch/generalvision/vicuna-7b", | |
"hf_cfg_path": "huangjy-pku/vicuna-7b", | |
"truncation_side": "right", | |
"max_context_len": 256, | |
"max_out_len": 256, | |
"lora": { | |
"flag": True, | |
"rank": 16, | |
"alpha": 16, | |
"dropout": 0.0, | |
"target_modules": ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'], | |
}, | |
}, | |
"clip_txt_guidance": { | |
"flag": False, | |
"clip_out_dim": 1024, | |
}, | |
}, | |
} | |
self.predict_mode = predict_mode | |
# LLM | |
llm_name = cfg['model']['llm']['name'] | |
if cfg['launch_mode'] == 'hf': | |
llm_cfg_path = cfg['model']['llm']['hf_cfg_path'] | |
else: | |
llm_cfg_path = cfg['model']['llm']['cfg_path'] | |
llm_truncation_side = 'right' | |
if 'vicuna' in llm_name.lower(): | |
self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_cfg_path, truncation_side=llm_truncation_side) | |
self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'}) | |
self.llm_model = LlamaForCausalLM.from_pretrained(llm_cfg_path, torch_dtype=torch.float16) | |
self.llm_model.resize_token_embeddings(len(self.llm_tokenizer)) | |
else: | |
self.llm_tokenizer = AutoTokenizer.from_pretrained(llm_cfg_path, truncation_side=llm_truncation_side) | |
self.llm_model = AutoModelForCausalLM.from_pretrained(llm_cfg_path, torch_dtype=torch.float16) | |
for param in self.llm_model.parameters(): | |
param.requires_grad = False | |
self.llm_model.eval() | |
self.llm_model.train = disabled_train | |
# 2D vision | |
self.img_encoder = GridFeatureExtractor2D() | |
self.img_proj = nn.Linear( | |
self.img_encoder.out_channels, self.llm_model.config.hidden_size | |
) | |
# 3D vision | |
self.pcd_encoder = OSE3D() | |
self.pcd_proj = nn.Linear(256, self.llm_model.config.hidden_size) | |
# type embedding | |
# self.img_type_embed = nn.Parameter(torch.zeros(self.llm_model.config.hidden_size), requires_grad=True) | |
# self.pcd_type_embed = nn.Parameter(torch.zeros(self.llm_model.config.hidden_size), requires_grad=True) | |
# LoRA | |
if cfg['model']['llm']['lora']['flag']: | |
lora_config = LoraConfig( | |
r=cfg['model']['llm']['lora']['rank'], | |
lora_alpha=cfg['model']['llm']['lora']['alpha'], | |
target_modules=cfg['model']['llm']['lora']['target_modules'], | |
lora_dropout=cfg['model']['llm']['lora']['dropout'], | |
bias='none', | |
modules_to_save=[], | |
) | |
self.llm_model = get_peft_model(self.llm_model, peft_config=lora_config) | |
self.max_context_len = 256 | |
self.max_out_len = 256 | |
# additional text x multi-modal tokens fusion | |
self.clip_txt_guidance = cfg['model']['clip_txt_guidance']['flag'] | |
if self.clip_txt_guidance: | |
self.clip_model = clip.load('RN50')[0] | |
for param in self.clip_model.parameters(): | |
param.requires_grad = False | |
self.clip_model.eval() | |
self.clip_model.train = disabled_train | |
self.clip_proj = nn.Linear(cfg['clip_txt_guidance']['clip_out_dim'], self.llm_model.config.hidden_size) | |
# grounding head | |
self.ground_head = SequentialGroundHead() | |
self.obj_cls_head = get_mlp_head(4096, 768, 607, 0.3) | |
self.pre_grounding = True | |
def device(self): | |
return list(self.parameters())[0].device | |
def build_right_justified_sequence(self, data_dict): | |
""" | |
Concat six sequences: `prompt_before_obj`, `prompt_middle_1`, `img_tokens`, `prompt_middle_2`, `obj_tokens`, `prompt_after_obj`. | |
Return right justified sequence for causal LM: <pad>, <role/situation>, <img>, <objs>, <instruction>. | |
""" | |
device = self.device | |
bs = len(data_dict['prompt_before_obj']) | |
self.llm_tokenizer.padding_side = 'left' | |
text_input_tokens_pre = self.llm_tokenizer( | |
data_dict['prompt_before_obj'], | |
return_tensors='pt', | |
padding='longest' | |
).to(device) # [PAD, BOS, tokens], (B, T1) | |
text_input_tokens_mid1 = self.llm_tokenizer( | |
data_dict['prompt_middle_1'], | |
return_tensors='pt', | |
padding='longest' | |
).to(device) | |
img_tokens = data_dict['img_tokens'].to(device) | |
img_masks = data_dict['img_masks'].to(device) | |
img_masks = img_masks.reshape(-1, 1).repeat(1, img_tokens.size(1)) | |
text_input_tokens_mid2 = self.llm_tokenizer( | |
data_dict['prompt_middle_2'], | |
return_tensors='pt', | |
padding='longest' | |
).to(device) | |
obj_tokens = data_dict['obj_tokens'].to(device) | |
obj_masks = data_dict['obj_masks'].to(device) | |
# additional clip fusion | |
if self.clip_txt_guidance: | |
with torch.no_grad(): | |
clip_fts = self.clip_model.encode_text( | |
clip.tokenize(data_dict['prompt_after_obj'], truncate=True).to(device) | |
) | |
clip_fts = self.clip_proj(clip_fts) | |
# B, N, C | |
img_tokens = torch.einsum('bnc,bc->bnc', img_tokens, clip_fts) | |
obj_tokens = torch.einsum('bnc,bc->bnc', obj_tokens, clip_fts) | |
self.llm_tokenizer.padding_side = 'right' # no need to be 'left', as padding tokens will be shifted | |
self.llm_tokenizer.truncation_side = 'left' # truncate history | |
text_input_tokens_post = self.llm_tokenizer( | |
data_dict['prompt_after_obj'], | |
return_tensors='pt', | |
padding='longest', | |
truncation=True, | |
max_length=self.max_context_len, | |
).to(device) # [BOS, tokens, PAD], (B, T3) | |
assert text_input_tokens_mid1.attention_mask.all() and text_input_tokens_mid2.attention_mask.all(), \ | |
"prompt_middle should be the same and thus no padding" | |
# remove bos, make "tokenize subseq and concat" equivalent to "tokenize the whole seq" | |
text_input_tokens_mid1.input_ids = text_input_tokens_mid1.input_ids[:, 1:] | |
text_input_tokens_mid1.attention_mask = text_input_tokens_mid1.attention_mask[:, 1:] | |
text_input_tokens_mid2.input_ids = text_input_tokens_mid2.input_ids[:, 1:] | |
text_input_tokens_mid2.attention_mask = text_input_tokens_mid2.attention_mask[:, 1:] | |
text_input_tokens_post.input_ids = text_input_tokens_post.input_ids[:, 1:] | |
text_input_tokens_post.attention_mask = text_input_tokens_post.attention_mask[:, 1:] | |
for i in range(bs): | |
if not img_masks[i].any(): | |
# no image input, also mask the text prompt for image tokens | |
text_input_tokens_mid1.attention_mask[i].fill_(0) | |
inputs_embeds_pre = self.llm_model.get_input_embeddings()(text_input_tokens_pre.input_ids) | |
inputs_embeds_mid1 = self.llm_model.get_input_embeddings()(text_input_tokens_mid1.input_ids) | |
inputs_embeds_mid2 = self.llm_model.get_input_embeddings()(text_input_tokens_mid2.input_ids) | |
inputs_embeds_post = self.llm_model.get_input_embeddings()(text_input_tokens_post.input_ids) | |
# since img_tokens, prompt_mid, obj_tokens are fixed length without padding, we concat them first | |
inputs_embeds_mid = torch.cat([inputs_embeds_mid1, img_tokens, inputs_embeds_mid2, obj_tokens], dim=1) | |
attn_mask_mid = torch.cat( | |
[text_input_tokens_mid1.attention_mask, img_masks, text_input_tokens_mid2.attention_mask, obj_masks], | |
dim=1, | |
) | |
post_pad_length = torch.logical_not(text_input_tokens_post.attention_mask).sum(-1) | |
bs, l1, hidden_dim = inputs_embeds_pre.shape | |
_, l2, _ = inputs_embeds_mid.shape | |
_, l3, _ = inputs_embeds_post.shape | |
inputs_embeds = torch.zeros(bs, l1+l2+l3, hidden_dim).type(inputs_embeds_pre.dtype).to(device) | |
attention_mask = torch.zeros(bs, l1+l2+l3).type(obj_masks.dtype).to(device) | |
# assign by chunks | |
for i in range(bs): | |
post_pad_len = post_pad_length[i] | |
if post_pad_len > 0: | |
inputs_embeds[i, :post_pad_len] = inputs_embeds_post[i, -post_pad_len:] | |
attention_mask[i, :post_pad_len] = 0 | |
inputs_embeds[i, post_pad_len+l1+l2:] = inputs_embeds_post[i, :-post_pad_len] | |
attention_mask[i, post_pad_len+l1+l2:] = 1 | |
else: | |
# no padding | |
inputs_embeds[i, -l3:] = inputs_embeds_post[i] | |
attention_mask[i, -l3:] = 1 | |
inputs_embeds[i, post_pad_len: post_pad_len+l1] = inputs_embeds_pre[i] | |
attention_mask[i, post_pad_len: post_pad_len+l1] = text_input_tokens_pre.attention_mask[i] | |
inputs_embeds[i, post_pad_len+l1: post_pad_len+l1+l2] = inputs_embeds_mid[i] | |
attention_mask[i, post_pad_len+l1: post_pad_len+l1+l2] = attn_mask_mid[i] | |
return inputs_embeds, attention_mask, (l1, l2, l3) | |
def forward(self, data_dict): | |
if self.predict_mode: | |
return self.generate(data_dict=data_dict) | |
""" | |
data_dict requires keys: | |
# input | |
prompt_before_obj: list of str, (B,) | |
prompt_middle_1: list of str, (B,) | |
prompt_middle_2: list of str, (B,) | |
prompt_after_obj: list of str, (B,) | |
obj_fts: (B, N, P, 6), xyz + rgb | |
obj_masks: (B, N), 1 valid and 0 masked | |
obj_locs: (B, N, 6), xyz + whd | |
anchor_locs: (B, 3) | |
anchor_orientation: (B, C) | |
img_fts: (B, 3, H, W), rgb | |
img_masks: (B, 1), 1 valid and 0 masked | |
# output | |
output_gt: list of str, (B,) | |
""" | |
device = self.device | |
bs = len(data_dict['prompt_after_obj']) | |
data_dict['bs'] = bs | |
if 'obj_tokens' not in data_dict: | |
# obtain obj tokens | |
data_dict = self.pcd_encoder(data_dict) | |
# TO CHANGE FOR DEBUG | |
#self.llm_model.float() | |
#data_dict['obj_tokens'] = torch.zeros((data_dict['obj_locs'].shape[0], data_dict['obj_locs'].shape[1], 256)).to(device=device) | |
data_dict['obj_tokens'] = self.pcd_proj(data_dict['obj_tokens'].to(device)) | |
# data_dict['obj_tokens'] = data_dict['obj_tokens'] + self.pcd_type_embed | |
data_dict['img_tokens'] = self.img_proj(self.img_encoder(data_dict['img_fts'])) | |
# data_dict['img_tokens'] = data_dict['img_tokens'] + self.img_type_embed | |
# build input embdes and record prompt position | |
inputs_embeds, attention_mask, input_length = self.build_right_justified_sequence(data_dict=data_dict) | |
obj_token_length = data_dict['obj_masks'].shape[1] | |
# (B, T1+O+T2, D), (B, T1+O+T2) | |
self.llm_tokenizer.padding_side = 'right' | |
self.llm_tokenizer.truncation_side = 'right' | |
text_output_tokens = self.llm_tokenizer( | |
[t + self.llm_tokenizer.eos_token for t in data_dict['output_gt']], | |
return_tensors='pt', | |
padding='longest', | |
truncation=True, | |
max_length=self.max_out_len, | |
).to(device) | |
# record position for special token [SOS] | |
grd_token_id = self.llm_tokenizer.convert_tokens_to_ids(['<s>'])[0] | |
out_input_ids_remove_first_sos = text_output_tokens.input_ids.clone() | |
out_input_ids_remove_first_sos[:, 0] = -100 | |
grd_ind_0, grd_ind_1 = (out_input_ids_remove_first_sos == grd_token_id).nonzero(as_tuple=True) | |
text_output_embeds = self.llm_model.get_input_embeddings()(text_output_tokens.input_ids) # (B, T3, D) | |
inputs_embeds = torch.cat([inputs_embeds, text_output_embeds], dim=1) # (B, T1+O+T2+T3, D) | |
attention_mask = torch.cat([attention_mask, text_output_tokens.attention_mask], dim=1) # (B, T1+O+T2+T3) | |
# construct targets | |
targets = torch.zeros_like(attention_mask).long().fill_(-100) # (B, T1+O+T2+T3) | |
# only apply loss to answer tokens | |
targets_idx = text_output_tokens.attention_mask.bool() | |
targets[:, -targets_idx.shape[1]:][targets_idx] = text_output_tokens.input_ids[targets_idx] | |
# do not predict bos token, regard it as condition instead | |
targets[:, -targets_idx.shape[1]] = -100 | |
with maybe_autocast(self): | |
outputs = self.llm_model( | |
inputs_embeds=inputs_embeds, | |
attention_mask=attention_mask, | |
return_dict=True, | |
output_hidden_states=True, | |
) | |
logits = outputs.logits.float() | |
last_hidden_state = outputs.hidden_states[-1] | |
# different from the loss inside `llm_model.forward`, here we take mean of each sequence instead of sum | |
shift_logits = logits[..., :-1, :].contiguous() | |
shift_labels = targets[..., 1:].contiguous() | |
num_tokens_for_loss = (shift_labels >= 0).int().sum(1) # (B,) | |
shift_logits = rearrange(shift_logits, 'b t v -> (b t) v') | |
shift_labels = rearrange(shift_labels, 'b t -> (b t)') | |
shift_labels = shift_labels.to(shift_logits.device) | |
# record for llm loss | |
data_dict['llm_logits'] = shift_logits | |
data_dict['llm_labels'] = shift_labels | |
data_dict['num_tokens_for_loss'] = num_tokens_for_loss | |
# record for grounding loss | |
grd_list = [] | |
obj_list = [] | |
mask_list = [] | |
for step in range(len(grd_ind_0)): | |
batch_ind = grd_ind_0[step] | |
grd_token_ind = grd_ind_1[step] | |
if self.pre_grounding: | |
output_obj_tokens = data_dict['obj_tokens'][batch_ind] | |
else: | |
output_obj_tokens = last_hidden_state[batch_ind, input_length[0] + input_length[1] - obj_token_length : input_length[0] + input_length[1], :] | |
output_grd_tokens = last_hidden_state[batch_ind, sum(input_length) + grd_token_ind:sum(input_length) + grd_token_ind + 1, :] | |
grd_list.append(output_grd_tokens) | |
obj_list.append(output_obj_tokens) | |
mask_list.append(data_dict['obj_masks'][batch_ind]) | |
output_obj = torch.stack(obj_list).float() | |
output_grd = torch.stack(grd_list).float() | |
data_dict['ground_logits'] = self.ground_head(output_obj, output_grd, torch.stack(mask_list)) | |
# data_dict['ground_label'] = torch.concat(data_dict['tgt_object_id'], dim=0) | |
# record for cls loss | |
#obj_cls_post_embeds = last_hidden_state[:, input_length[0] + input_length[1] - obj_token_length : input_length[0] + input_length[1], :].float() | |
obj_cls_post_embeds = data_dict['obj_tokens'].float() | |
data_dict['obj_cls_post_logits'] = self.obj_cls_head(obj_cls_post_embeds) | |
return data_dict | |
def generate( | |
self, | |
data_dict, | |
use_nucleus_sampling=False, | |
num_beams=5, | |
max_length=256, | |
min_length=1, | |
top_p=0.9, | |
repetition_penalty=6.0, | |
length_penalty=1, | |
num_captions=1, | |
temperature=1, | |
): | |
""" | |
data_dict requires the same keys as forward() except output_gt | |
""" | |
device = self.device | |
bs = len(data_dict['prompt_after_obj']) | |
data_dict['bs'] = bs | |
if 'obj_tokens' not in data_dict: | |
# obtain obj tokens | |
data_dict = self.pcd_encoder(data_dict) | |
# TO CHANGE FOR DEBUG | |
#self.llm_model.float() | |
#data_dict['obj_tokens'] = torch.zeros((data_dict['obj_locs'].shape[0], data_dict['obj_locs'].shape[1], 256)).to(device=device) | |
data_dict['obj_tokens'] = self.pcd_proj(data_dict['obj_tokens'].to(device)) | |
# data_dict['obj_tokens'] = data_dict['obj_tokens'] + self.pcd_type_embed | |
data_dict['img_tokens'] = self.img_proj(self.img_encoder(data_dict['img_fts'])) | |
# data_dict['img_tokens'] = data_dict['img_tokens'] + self.img_type_embed | |
inputs_embeds, attention_mask, input_length = self.build_right_justified_sequence(data_dict=data_dict) | |
obj_token_length = data_dict['obj_masks'].shape[1] | |
# give bos token as condition | |
bos_tokens = self.llm_tokenizer( | |
[self.llm_tokenizer.bos_token] * bs, | |
return_tensors='pt', | |
).to(device) | |
bos_tokens_ids = bos_tokens.input_ids[:, 0:1] # (B, 1) | |
bos_tokens_attn = bos_tokens.attention_mask[:, 0:1] # (B, 1) | |
# prepare a `bos_token` | |
bos_embeds = self.llm_model.get_input_embeddings()(bos_tokens_ids) # (B, 1, D) | |
inputs_embeds = torch.cat([inputs_embeds, bos_embeds], dim=1) # (B, T1+O+T2+1, D) | |
attention_mask = torch.cat([attention_mask, bos_tokens_attn], dim=1) # (B, T1+O+T2+1) | |
with maybe_autocast(self): | |
outputs = self.llm_model.generate( | |
inputs_embeds=inputs_embeds, | |
attention_mask=attention_mask, | |
do_sample=use_nucleus_sampling, | |
top_p=top_p, | |
temperature=temperature, | |
num_beams=num_beams, | |
max_length=max_length, | |
min_length=min_length, | |
repetition_penalty=repetition_penalty, | |
length_penalty=length_penalty, | |
num_return_sequences=num_captions, | |
return_dict_in_generate=True, | |
output_hidden_states=True, | |
output_scores=True | |
) | |
# note output_ids_idx - 1 = step idx, because we do not preduct [BOS] | |
beam_indices = outputs.beam_indices # bs x step, beam indices range (bsxbeam) | |
scores = outputs.scores # step x (bs x beam) x vocab | |
hidden_states = outputs.hidden_states # step x layer x (bs x beam) x token_num x hidden_dim | |
outputs = outputs.sequences # bs x output_ids | |
outputs[outputs == self.llm_tokenizer.unk_token_id] = self.llm_tokenizer.eos_token_id | |
# data_dict['output_tokens'] = outputs # unable to gather variable-length tensors | |
# record for grounding | |
grd_token_id = self.llm_tokenizer.convert_tokens_to_ids(['<s>'])[0] | |
out_input_ids_remove_first_sos = outputs.clone() | |
out_input_ids_remove_first_sos[:, 0] = -100 | |
grd_ind_0, grd_ind_1 = (out_input_ids_remove_first_sos == grd_token_id).nonzero(as_tuple=True) | |
grd_list = [] | |
grd_batch_ind_list = [] | |
obj_list = [] | |
mask_list = [] | |
if len(grd_ind_0) > 0: | |
for step in range(len(grd_ind_0)): | |
batch_ind = grd_ind_0[step] | |
grd_token_ind = grd_ind_1[step] | |
#output_obj_tokens = last_hidden_state[batch_ind, input_length[0] + input_length[1] - obj_token_length : input_length[0] + input_length[1], :] | |
output_obj_tokens = data_dict['obj_tokens'][batch_ind] | |
output_grd_tokens = hidden_states[grd_token_ind-1][-1][beam_indices[batch_ind, grd_token_ind-1]][-1].unsqueeze(0) # grd_token_ind - 1 because first token is sos | |
grd_list.append(output_grd_tokens) | |
grd_batch_ind_list.append(batch_ind) | |
obj_list.append(output_obj_tokens) | |
mask_list.append(data_dict['obj_masks'][batch_ind]) | |
output_obj = torch.stack(obj_list).float() | |
output_grd = torch.stack(grd_list).float() | |
data_dict['ground_logits'] = self.ground_head(output_obj, output_grd, torch.stack(mask_list)) | |
else: | |
data_dict['ground_logits'] = None | |
# data_dict['ground_label'] = torch.concat(data_dict['tgt_object_id'], dim=0) | |
data_dict['grd_batch_ind_list'] = grd_batch_ind_list | |
output_txt = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True) | |
output_txt = [txt.strip() for txt in output_txt] | |
data_dict['output_txt'] = output_txt | |
return data_dict | |