Spaces:
Runtime error
Runtime error
gmftbyGMFTBY
commited on
Commit
·
2073756
1
Parent(s):
0c797c5
update codes
Browse files- model/openllama.py +2 -1
model/openllama.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
from header import *
|
|
|
2 |
import torch.nn.functional as F
|
3 |
from .ImageBind import *
|
4 |
from .ImageBind import data
|
@@ -101,7 +102,7 @@ class OpenLLAMAPEFTModel(nn.Module):
|
|
101 |
target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
|
102 |
)
|
103 |
|
104 |
-
self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path)
|
105 |
self.llama_model = get_peft_model(self.llama_model, peft_config)
|
106 |
self.llama_model.print_trainable_parameters()
|
107 |
|
|
|
1 |
from header import *
|
2 |
+
import os
|
3 |
import torch.nn.functional as F
|
4 |
from .ImageBind import *
|
5 |
from .ImageBind import data
|
|
|
102 |
target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
|
103 |
)
|
104 |
|
105 |
+
self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path, use_auth_token=os.environ['API_TOKEN'])
|
106 |
self.llama_model = get_peft_model(self.llama_model, peft_config)
|
107 |
self.llama_model.print_trainable_parameters()
|
108 |
|