Update README.md
Browse files
README.md
CHANGED
@@ -200,28 +200,24 @@ def run_model(messages, functions):
|
|
200 |
## Format messages in Rubra's format
|
201 |
formatted_msgs = preprocess_input(msgs=messages, tools=functions)
|
202 |
|
203 |
-
|
204 |
formatted_msgs,
|
205 |
-
|
206 |
-
|
207 |
-
)
|
|
|
208 |
|
209 |
-
|
210 |
-
|
211 |
-
|
|
|
|
|
|
|
212 |
]
|
213 |
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
eos_token_id=terminators,
|
218 |
-
do_sample=True,
|
219 |
-
temperature=0.1,
|
220 |
-
top_p=0.9,
|
221 |
-
)
|
222 |
-
response = outputs[0][input_ids.shape[-1]:]
|
223 |
-
raw_output = tokenizer.decode(response, skip_special_tokens=True)
|
224 |
-
return raw_output
|
225 |
|
226 |
raw_output = run_model(messages, functions)
|
227 |
# Check if there's a function call
|
@@ -245,9 +241,10 @@ if function_call:
|
|
245 |
messages.append({"role": "assistant", "tool_calls": function_call})
|
246 |
# append the result of the tool call in openai format, in this case, the value of add 6 to 4 is 10.
|
247 |
messages.append({'role': 'tool', 'tool_call_id': function_call[0]["id"], 'name': function_call[0]["function"]["name"], 'content': '10'})
|
248 |
-
|
249 |
# Check if there's a function call
|
250 |
-
|
|
|
251 |
if function_call:
|
252 |
print(function_call)
|
253 |
else:
|
|
|
200 |
## Format messages in Rubra's format
|
201 |
formatted_msgs = preprocess_input(msgs=messages, tools=functions)
|
202 |
|
203 |
+
text = tokenizer.apply_chat_template(
|
204 |
formatted_msgs,
|
205 |
+
tokenize=False,
|
206 |
+
add_generation_prompt=True
|
207 |
+
)
|
208 |
+
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
209 |
|
210 |
+
generated_ids = model.generate(
|
211 |
+
model_inputs.input_ids,
|
212 |
+
max_new_tokens=512
|
213 |
+
)
|
214 |
+
generated_ids = [
|
215 |
+
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
216 |
]
|
217 |
|
218 |
+
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
219 |
+
return response
|
220 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
|
222 |
raw_output = run_model(messages, functions)
|
223 |
# Check if there's a function call
|
|
|
241 |
messages.append({"role": "assistant", "tool_calls": function_call})
|
242 |
# append the result of the tool call in openai format, in this case, the value of add 6 to 4 is 10.
|
243 |
messages.append({'role': 'tool', 'tool_call_id': function_call[0]["id"], 'name': function_call[0]["function"]["name"], 'content': '10'})
|
244 |
+
raw_output1 = run_model(messages, functions)
|
245 |
# Check if there's a function call
|
246 |
+
|
247 |
+
function_call = postprocess_output(raw_output1)
|
248 |
if function_call:
|
249 |
print(function_call)
|
250 |
else:
|