Spaces: Runtime error
Commit
·
a7fa8fe
1
Parent(s):
b0c3916
Update model/openllama.py
Browse files
- model/openllama.py +7 -1
model/openllama.py
CHANGED
@@ -10,6 +10,8 @@ import kornia as K
|
|
10 |
|
11 |
import torch
|
12 |
from torch.nn.utils import rnn
|
|
|
|
|
13 |
|
14 |
CLASS_NAMES = ['bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut', 'leather', 'metal nut', 'pill', 'screw', 'tile', 'toothbrush', 'transistor', 'wood', 'zipper', 'object',
|
15 |
'candle', 'cashew', 'chewinggum', 'fryum', 'macaroni', 'pcb', 'pipe fryum']
|
@@ -203,7 +205,11 @@ class OpenLLAMAPEFTModel(nn.Module):
|
|
203 |
target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
|
204 |
)
|
205 |
|
206 |
-
|
|
|
|
|
|
|
|
|
207 |
self.llama_model = get_peft_model(self.llama_model, peft_config)
|
208 |
self.llama_model.print_trainable_parameters()
|
209 |
|
|
|
10 |
|
11 |
import torch
|
12 |
from torch.nn.utils import rnn
|
13 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
14 |
+
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
|
15 |
|
16 |
CLASS_NAMES = ['bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut', 'leather', 'metal nut', 'pill', 'screw', 'tile', 'toothbrush', 'transistor', 'wood', 'zipper', 'object',
|
17 |
'candle', 'cashew', 'chewinggum', 'fryum', 'macaroni', 'pcb', 'pipe fryum']
|
|
|
205 |
target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
|
206 |
)
|
207 |
|
208 |
+
config = AutoConfig.from_pretrained(vicuna_ckpt_path)
|
209 |
+
with init_empty_weights():
|
210 |
+
self.llama_model = AutoModelForCausalLM.from_config(config)
|
211 |
+
self.llama_model = load_checkpoint_and_dispatch(self.llama_model, vicuna_ckpt_path, device_map="auto", no_split_module_classes=["OPTDecoderLayer"])
|
212 |
+
# self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path, torch_dtype=torch.float16, device_map='auto', offload_folder="offload", offload_state_dict = True)
|
213 |
self.llama_model = get_peft_model(self.llama_model, peft_config)
|
214 |
self.llama_model.print_trainable_parameters()
|
215 |
|