Commit 6bac629
1 Parent(s): e3efad5

Update README.md
README.md CHANGED
````diff
@@ -26,34 +26,8 @@ base_model: meta-llama/llama-2-7b-hf
 ## Instruction format
 
 ```python
+from transformers import AutoModelForCausalLM, LlamaTokenizer, BitsAndBytesConfig, TextStreamer, StoppingCriteria, StoppingCriteriaList
 import torch
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    TextStreamer,
-    StoppingCriteria,
-    StoppingCriteriaList,
-    BitsAndBytesConfig,
-)
-
-device = "cuda" # the device to load the model onto
-model_name = "willnguyen/lacda-2-7B-chat-v0.1"
-
-bnb_config = BitsAndBytesConfig(
-    load_in_4bit=True,
-    bnb_4bit_quant_type="nf4",
-    bnb_4bit_compute_dtype=torch.float16,
-    bnb_4bit_use_double_quant=True,
-)
-model = AutoModelForCausalLM.from_pretrained(
-    model_name, load_in_4bit=True, torch_dtype=torch.float16, quantization_config=bnb_config, device_map="auto"
-)
-
-tokenizer = AutoTokenizer.from_pretrained(
-    model_name, cache_dir=None, use_fast=False, padding_side="right", tokenizer_type="llama"
-)
-tokenizer.pad_token_id = 0
-
 
 class StopTokenCriteria(StoppingCriteria):
     def __init__(self, stop_tokens, tokenizer, prompt_length):
@@ -69,20 +43,33 @@ class StopTokenCriteria(StoppingCriteria):
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
         is_done = False
-        tokens = tokenizer.decode(input_ids[0])[self.prompt_length
+        tokens = tokenizer.decode(input_ids[0])[self.prompt_length:]
         for st in self.stop_tokens:
             if st in tokens:
                 is_done = True
                 break
         return is_done
 
+model_name = "willnguyen/lacda-2-7B-chat-v0.1"
+tokenizer = LlamaTokenizer.from_pretrained(
+    model_name,
+    use_fast=False,
+    padding_side="right",
+    tokenizer_type='llama',
+)
+tokenizer.pad_token_id = 0
 
-
-
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    device_map="auto",
+    torch_dtype=torch.float16,
+)
+
+prompt = "<s> [INST] who is Hồ Chí Minh [/INST]"
 
 stopping_criteria = StoppingCriteriaList([StopTokenCriteria(["[INST]", "[/INST]"], tokenizer, len(prompt))])
 with torch.inference_mode():
-    input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
+    input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.to('cuda')
     streamer = TextStreamer(tokenizer)
     _ = model.generate(
         input_ids=input_ids,
@@ -94,7 +81,6 @@ with torch.inference_mode():
         repetition_penalty=1.0,
         use_cache=True,
         streamer=streamer,
-        stopping_criteria=stopping_criteria
+        stopping_criteria=stopping_criteria
     )
-
 ```
````
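Applying the hunks in order, the post-commit example reconstructs as the sketch below. Everything is taken verbatim from the added and unchanged lines, except two spots the diff context elides: the `__init__` body (assumed here to simply store its arguments) and a handful of `model.generate` arguments, which are left as a comment placeholder rather than guessed.

```python
from transformers import AutoModelForCausalLM, LlamaTokenizer, BitsAndBytesConfig, TextStreamer, StoppingCriteria, StoppingCriteriaList
import torch

class StopTokenCriteria(StoppingCriteria):
    def __init__(self, stop_tokens, tokenizer, prompt_length):
        # Body elided in the diff; assumed to store the constructor arguments.
        self.stop_tokens = stop_tokens
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        is_done = False
        # Decode the full sequence so far and drop the prompt prefix;
        # prompt_length counts characters, since it is len(prompt).
        tokens = tokenizer.decode(input_ids[0])[self.prompt_length:]
        for st in self.stop_tokens:
            if st in tokens:
                is_done = True
                break
        return is_done

model_name = "willnguyen/lacda-2-7B-chat-v0.1"
tokenizer = LlamaTokenizer.from_pretrained(
    model_name,
    use_fast=False,
    padding_side="right",
    tokenizer_type='llama',
)
tokenizer.pad_token_id = 0

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
)

prompt = "<s> [INST] who is Hồ Chí Minh [/INST]"

# Stop as soon as a new [INST]/[/INST] marker appears after the prompt.
stopping_criteria = StoppingCriteriaList([StopTokenCriteria(["[INST]", "[/INST]"], tokenizer, len(prompt))])
with torch.inference_mode():
    input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.to('cuda')
    streamer = TextStreamer(tokenizer)
    _ = model.generate(
        input_ids=input_ids,
        # (several generation arguments are elided by the diff context here)
        repetition_penalty=1.0,
        use_cache=True,
        streamer=streamer,
        stopping_criteria=stopping_criteria
    )
```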
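One side effect of the commit: `BitsAndBytesConfig` is still imported but no longer used, since the 4-bit loading path was removed from the example. If quantized loading is wanted back, the deleted configuration can be reattached to the new `from_pretrained` call. A minimal sketch based on the removed lines; the old example's extra `load_in_4bit=True` kwarg is dropped here as redundant next to `quantization_config`:

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

# 4-bit NF4 quantization, as configured in the pre-commit example.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "willnguyen/lacda-2-7B-chat-v0.1",
    quantization_config=bnb_config,
    device_map="auto",
)
```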