FantasticGNU committed
Commit a7fa8fe · 1 Parent(s): b0c3916

Update model/openllama.py

Files changed (1): model/openllama.py (+7 -1)
model/openllama.py CHANGED
@@ -10,6 +10,8 @@ import kornia as K
 
 import torch
 from torch.nn.utils import rnn
+from transformers import AutoConfig, AutoModelForCausalLM
+from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 
 CLASS_NAMES = ['bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut', 'leather', 'metal nut', 'pill', 'screw', 'tile', 'toothbrush', 'transistor', 'wood', 'zipper', 'object',
 'candle', 'cashew', 'chewinggum', 'fryum', 'macaroni', 'pcb', 'pipe fryum']
@@ -203,7 +205,11 @@ class OpenLLAMAPEFTModel(nn.Module):
             target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
         )
 
-        self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path, torch_dtype=torch.float16, device_map='auto', offload_folder="offload", offload_state_dict = True)
+        config = AutoConfig.from_pretrained(vicuna_ckpt_path)
+        with init_empty_weights():
+            self.llama_model = AutoModelForCausalLM.from_config(config)
+        self.llama_model = load_checkpoint_and_dispatch(self.llama_model, vicuna_ckpt_path, device_map="auto", no_split_module_classes=["OPTDecoderLayer"])
+        # self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path, torch_dtype=torch.float16, device_map='auto', offload_folder="offload", offload_state_dict = True)
         self.llama_model = get_peft_model(self.llama_model, peft_config)
         self.llama_model.print_trainable_parameters()
 
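Note: the new loading path follows Accelerate's big-model-inference pattern: build the model skeleton on the meta device (no weight memory allocated), then let load_checkpoint_and_dispatch shard the checkpoint across available devices. The snippet below is a minimal standalone sketch of that pattern, not part of the commit; the checkpoint path is hypothetical, and for a LLaMA/Vicuna checkpoint the no-split class would typically be "LlamaDecoderLayer" (the diff's "OPTDecoderLayer" matches OPT-style models).

# Minimal sketch of the loading pattern introduced above (assumptions: a local
# LLaMA/Vicuna checkpoint directory; path below is hypothetical).
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

vicuna_ckpt_path = "./pretrained_ckpt/vicuna_ckpt/7b_v0"  # hypothetical path

# Instantiate the architecture on the meta device: no weights are materialized yet.
config = AutoConfig.from_pretrained(vicuna_ckpt_path)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# Load the real weights and let Accelerate place modules across GPUs/CPU,
# keeping each decoder layer on a single device.
model = load_checkpoint_and_dispatch(
    model,
    vicuna_ckpt_path,
    device_map="auto",
    no_split_module_classes=["LlamaDecoderLayer"],
    dtype=torch.float16,
)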