crystal-technologies committed
Commit 52925b5 · 1 Parent(s): 714d948

Update Perceptrix/engine.py

Files changed (1)
  1. Perceptrix/engine.py +3 -3
Perceptrix/engine.py CHANGED
@@ -1,4 +1,4 @@
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
+from transformers import LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig, GenerationConfig
 from utils import setup_device
 import torch
 import tqdm
@@ -17,7 +17,7 @@ bnb_config = BitsAndBytesConfig(
     bnb_4bit_compute_dtype=torch.float32 if device == "cpu" else torch.bfloat16
 )
 
-model = AutoModelForCausalLM.from_pretrained(
+model = LlamaForCausalLM.from_pretrained(
     model_name,
     torch_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
     device_map="auto",
@@ -27,7 +27,7 @@ model = AutoModelForCausalLM.from_pretrained(
     quantization_config=bnb_config if str(device) != "cpu" else None,
 )
 
-tokenizer = AutoTokenizer.from_pretrained(
+tokenizer = LlamaTokenizer.from_pretrained(
     model_name,
     trust_remote_code=True,
     use_fast=True,
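
For context, a minimal self-contained sketch of how the loading code looks after this change. The checkpoint name, the device stand-in for setup_device(), the load_in_4bit flag, and the generation call at the end are illustrative assumptions, not part of the committed file.

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig, GenerationConfig

# Stand-in for setup_device() from utils, which is not shown in this diff.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Hypothetical LLaMA-compatible checkpoint; the real model_name is defined elsewhere in engine.py.
model_name = "meta-llama/Llama-2-7b-chat-hf"

# 4-bit quantization only makes sense on GPU; fall back to full precision on CPU.
# load_in_4bit is an assumption; the diff only shows bnb_4bit_compute_dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
)

# The commit swaps the Auto* classes for the LLaMA-specific ones; the kwargs are unchanged.
model = LlamaForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config if device != "cpu" else None,
)
tokenizer = LlamaTokenizer.from_pretrained(model_name)  # requires sentencepiece

# Example use of the GenerationConfig import from the diff.
inputs = tokenizer("Hello, world", return_tensors="pt").to(model.device)
output = model.generate(**inputs, generation_config=GenerationConfig(max_new_tokens=32))
print(tokenizer.decode(output[0], skip_special_tokens=True))

Note that LlamaTokenizer is the slow SentencePiece tokenizer, so the use_fast=True argument kept in the diff has no effect with this class; it was only meaningful with AutoTokenizer.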