Spaces:
Runtime error
Runtime error
Commit
·
2f64416
1
Parent(s):
cdd6d8b
Update model/openllama.py
Browse files- model/openllama.py +3 -1
model/openllama.py
CHANGED
@@ -208,7 +208,9 @@ class OpenLLAMAPEFTModel(nn.Module):
|
|
208 |
config = AutoConfig.from_pretrained(vicuna_ckpt_path)
|
209 |
with init_empty_weights():
|
210 |
self.llama_model = AutoModelForCausalLM.from_config(config)
|
211 |
-
|
|
|
|
|
212 |
self.llama_model.to(torch.float16)
|
213 |
# try:
|
214 |
# self.llama_model = AutoModelForCausalLM.from_pretrained(vicuna_ckpt_path, torch_dtype=torch.float16, device_map='auto', offload_folder="offload", offload_state_dict = True)
|
|
|
208 |
config = AutoConfig.from_pretrained(vicuna_ckpt_path)
|
209 |
with init_empty_weights():
|
210 |
self.llama_model = AutoModelForCausalLM.from_config(config)
|
211 |
+
|
212 |
+
device_map = infer_auto_device_map(self.llama_model, no_split_module_classes=["OPTDecoderLayer"], dtype="float16")
|
213 |
+
self.llama_model = load_checkpoint_and_dispatch(self.llama_model, vicuna_ckpt_path, device_map=device_map, offload_folder="offload", offload_state_dict = True)
|
214 |
self.llama_model.to(torch.float16)
|
215 |
# try:
|
216 |
# self.llama_model = AutoModelForCausalLM.from_pretrained(vicuna_ckpt_path, torch_dtype=torch.float16, device_map='auto', offload_folder="offload", offload_state_dict = True)
|