Commit
·
52925b5
1
Parent(s):
714d948
Update Perceptrix/engine.py
Browse files- Perceptrix/engine.py +3 -3
Perceptrix/engine.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from transformers import
|
2 |
from utils import setup_device
|
3 |
import torch
|
4 |
import tqdm
|
@@ -17,7 +17,7 @@ bnb_config = BitsAndBytesConfig(
|
|
17 |
bnb_4bit_compute_dtype=torch.float32 if device == "cpu" else torch.bfloat16
|
18 |
)
|
19 |
|
20 |
-
model =
|
21 |
model_name,
|
22 |
torch_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
|
23 |
device_map="auto",
|
@@ -27,7 +27,7 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
27 |
quantization_config=bnb_config if str(device) != "cpu" else None,
|
28 |
)
|
29 |
|
30 |
-
tokenizer =
|
31 |
model_name,
|
32 |
trust_remote_code=True,
|
33 |
use_fast=True,
|
|
|
1 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig, GenerationConfig
|
2 |
from utils import setup_device
|
3 |
import torch
|
4 |
import tqdm
|
|
|
17 |
bnb_4bit_compute_dtype=torch.float32 if device == "cpu" else torch.bfloat16
|
18 |
)
|
19 |
|
20 |
+
model = LlamaForCausalLM.from_pretrained(
|
21 |
model_name,
|
22 |
torch_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
|
23 |
device_map="auto",
|
|
|
27 |
quantization_config=bnb_config if str(device) != "cpu" else None,
|
28 |
)
|
29 |
|
30 |
+
tokenizer = LlamaTokenizer.from_pretrained(
|
31 |
model_name,
|
32 |
trust_remote_code=True,
|
33 |
use_fast=True,
|