FantasticGNU committed on
Commit
e2c6ac2
·
1 Parent(s): 4c92e71

Update model/openllama.py

Browse files
Files changed (1) hide show
  1. model/openllama.py +2 -2
model/openllama.py CHANGED
@@ -225,14 +225,14 @@ class OpenLLAMAPEFTModel(nn.Module):
225
  # self.llama_model.load_state_dict(delta_ckpt, strict=False)
226
  self.llama_model.print_trainable_parameters()
227
 
228
- self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, torch_dtype=torch.float16)
229
  self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
230
  self.llama_tokenizer.padding_side = "right"
231
  print ('Language decoder initialized.')
232
 
233
  self.llama_proj = nn.Linear(
234
  self.visual_hidden_size, self.llama_model.config.hidden_size
235
- ).to(self.device)
236
 
237
  self.max_tgt_len = max_tgt_len
238
 
 
225
  # self.llama_model.load_state_dict(delta_ckpt, strict=False)
226
  self.llama_model.print_trainable_parameters()
227
 
228
+ self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, torch_dtype=torch.float16).to(self.device)
229
  self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
230
  self.llama_tokenizer.padding_side = "right"
231
  print ('Language decoder initialized.')
232
 
233
  self.llama_proj = nn.Linear(
234
  self.visual_hidden_size, self.llama_model.config.hidden_size
235
+ ).to(torch.float16).to(self.device)
236
 
237
  self.max_tgt_len = max_tgt_len
238