Add documentation for updated chat template
Browse files
README.md
CHANGED
@@ -21,13 +21,37 @@ should probably proofread and complete it, then remove this comment. -->
|
|
21 |
|
22 |
This model is a fine-tuned version of [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) on the [devanshamin/gem-viggo-function-calling](https://huggingface.co/datasets/devanshamin/gem-viggo-function-calling) dataset.
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
## Basic Usage
|
25 |
|
26 |
```python
|
27 |
import torch
|
28 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
29 |
|
30 |
-
# Load the model and the tokenizer
|
31 |
model_id = "Qwen2-1.5B-Instruct-Function-Calling-v1"
|
32 |
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, device_map="auto")
|
33 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
@@ -39,11 +63,7 @@ def inference(prompt: str) -> str:
|
|
39 |
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
40 |
return response
|
41 |
|
42 |
-
|
43 |
-
messages = [
|
44 |
-
{"role": "system", "content": "You are a helpful assistant."},
|
45 |
-
{"role": "user", "content": prompt}
|
46 |
-
]
|
47 |
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
48 |
response = inference(prompt)
|
49 |
print(response)
|
@@ -55,14 +75,17 @@ print(response)
|
|
55 |
|
56 |
```python
|
57 |
import json
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
66 |
return prompt
|
67 |
|
68 |
tool = {
|
@@ -87,7 +110,7 @@ tool = {
|
|
87 |
}
|
88 |
}
|
89 |
input_text = "Founded in 2021, Pluto raised $4 million across multiple seed funding rounds, valuing the company at $12 million (pre-money), according to PitchBook. The startup was backed by investors including Switch Ventures, Caffeinated Capital and Maxime Seguineau."
|
90 |
-
prompt = get_prompt(
|
91 |
response = inference(prompt)
|
92 |
print(response)
|
93 |
# ```json
|
@@ -100,7 +123,7 @@ print(response)
|
|
100 |
# "Caffeinated Capital",
|
101 |
# "Maxime Seguineau"
|
102 |
# ],
|
103 |
-
# "valuation": "
|
104 |
# "source": "PitchBook"
|
105 |
# }
|
106 |
# }
|
@@ -127,7 +150,7 @@ class Classification(BaseModel):
|
|
127 |
function_definition = openai_schema(Classification).openai_schema
|
128 |
tool = dict(type='function', function=function_definition)
|
129 |
input_text = "1,25-dihydroxyvitamin D(3) (1,25(OH)(2)D(3)), the biologically active form of vitamin D, is widely recognized as a modulator of the immune system as well as a regulator of mineral metabolism. The objective of this study was to determine the effects of vitamin D status and treatment with 1,25(OH)(2)D(3) on diabetes onset in non-obese diabetic (NOD) mice, a murine model of human type I diabetes. We have found that vitamin D-deficiency increases the incidence of diabetes in female mice from 46% (n=13) to 88% (n=8) and from 0% (n=10) to 44% (n=9) in male mice as of 200 days of age when compared to vitamin D-sufficient animals. Addition of 50 ng of 1,25(OH)(2)D(3)/day to the diet prevented disease onset as of 200 days and caused a significant rise in serum calcium levels, regardless of gender or vitamin D status. Our results indicate that vitamin D status is a determining factor of disease susceptibility and oral administration of 1,25(OH)(2)D(3) prevents diabetes onset in NOD mice through 200 days of age."
|
130 |
-
prompt = get_prompt(
|
131 |
output = inference(prompt)
|
132 |
print(output)
|
133 |
# ```json
|
|
|
21 |
|
22 |
This model is a fine-tuned version of [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) on the [devanshamin/gem-viggo-function-calling](https://huggingface.co/datasets/devanshamin/gem-viggo-function-calling) dataset.
|
23 |
|
24 |
+
## Updated Chat Template
|
25 |
+
> Note: The template supports multiple tools, but the model is fine-tuned on a dataset consisting of a single tool.
|
26 |
+
|
27 |
+
- The chat template has been added to the [tokenizer_config.json](https://huggingface.co/devanshamin/Qwen2-1.5B-Instruct-Function-Calling-v1/blob/7ee7c020cefdb0101939469de608acc2afa7809e/tokenizer_config.json#L34).
|
28 |
+
- Supports prompts with and without tools.
|
29 |
+
|
30 |
+
```python
|
31 |
+
chat_template = (
|
32 |
+
"{% for message in messages %}"
|
33 |
+
"{% if loop.first and messages[0]['role'] != 'system' %}"
|
34 |
+
"{% if tools %}"
|
35 |
+
"<|im_start|>system\nYou are a helpful assistant with access to the following tools. Use them if required - \n"
|
36 |
+
"```json\n{{ tools | tojson }}\n```<|im_end|>\n"
|
37 |
+
"{% else %}"
|
38 |
+
"<|im_start|>system\nYou are a helpful assistant.\n<|im_end|>\n"
|
39 |
+
"{% endif %}"
|
40 |
+
"{% endif %}"
|
41 |
+
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
|
42 |
+
"{% endfor %}"
|
43 |
+
"{% if add_generation_prompt %}"
|
44 |
+
"{{ '<|im_start|>assistant\n' }}"
|
45 |
+
"{% endif %}"
|
46 |
+
)
|
47 |
+
```
|
48 |
+
|
49 |
## Basic Usage
|
50 |
|
51 |
```python
|
52 |
import torch
|
53 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
54 |
|
|
|
55 |
model_id = "Qwen2-1.5B-Instruct-Function-Calling-v1"
|
56 |
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, device_map="auto")
|
57 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
|
|
63 |
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
64 |
return response
|
65 |
|
66 |
+
messages = [{"role": "user", "content": "What is the speed of light?"}]
|
|
|
|
|
|
|
|
|
67 |
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
68 |
response = inference(prompt)
|
69 |
print(response)
|
|
|
75 |
|
76 |
```python
|
77 |
import json
|
78 |
+
from typing import List, Dict
|
79 |
+
|
80 |
+
def get_prompt(user_input: str, tools: List[Dict] | None = None):
|
81 |
+
prompt = 'Extract the information from the following - \n{}'.format(user_input)
|
82 |
+
messages = [{"role": "user", "content": prompt}]
|
83 |
+
prompt = tokenizer.apply_chat_template(
|
84 |
+
messages,
|
85 |
+
tokenize=False,
|
86 |
+
add_generation_prompt=True,
|
87 |
+
tools=tools
|
88 |
+
)
|
89 |
return prompt
|
90 |
|
91 |
tool = {
|
|
|
110 |
}
|
111 |
}
|
112 |
input_text = "Founded in 2021, Pluto raised $4 million across multiple seed funding rounds, valuing the company at $12 million (pre-money), according to PitchBook. The startup was backed by investors including Switch Ventures, Caffeinated Capital and Maxime Seguineau."
|
113 |
+
prompt = get_prompt(input_text, tools=[tool])
|
114 |
response = inference(prompt)
|
115 |
print(response)
|
116 |
# ```json
|
|
|
123 |
# "Caffeinated Capital",
|
124 |
# "Maxime Seguineau"
|
125 |
# ],
|
126 |
+
# "valuation": "$12 million",
|
127 |
# "source": "PitchBook"
|
128 |
# }
|
129 |
# }
|
|
|
150 |
function_definition = openai_schema(Classification).openai_schema
|
151 |
tool = dict(type='function', function=function_definition)
|
152 |
input_text = "1,25-dihydroxyvitamin D(3) (1,25(OH)(2)D(3)), the biologically active form of vitamin D, is widely recognized as a modulator of the immune system as well as a regulator of mineral metabolism. The objective of this study was to determine the effects of vitamin D status and treatment with 1,25(OH)(2)D(3) on diabetes onset in non-obese diabetic (NOD) mice, a murine model of human type I diabetes. We have found that vitamin D-deficiency increases the incidence of diabetes in female mice from 46% (n=13) to 88% (n=8) and from 0% (n=10) to 44% (n=9) in male mice as of 200 days of age when compared to vitamin D-sufficient animals. Addition of 50 ng of 1,25(OH)(2)D(3)/day to the diet prevented disease onset as of 200 days and caused a significant rise in serum calcium levels, regardless of gender or vitamin D status. Our results indicate that vitamin D status is a determining factor of disease susceptibility and oral administration of 1,25(OH)(2)D(3) prevents diabetes onset in NOD mice through 200 days of age."
|
153 |
+
prompt = get_prompt(input_text, tools=[tool])
|
154 |
output = inference(prompt)
|
155 |
print(output)
|
156 |
# ```json
|