yingbei committed on
Commit
53537bd
·
verified ·
1 Parent(s): 4ddae95

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +17 -20
README.md CHANGED
@@ -200,28 +200,24 @@ def run_model(messages, functions):
200
  ## Format messages in Rubra's format
201
  formatted_msgs = preprocess_input(msgs=messages, tools=functions)
202
 
203
- input_ids = tokenizer.apply_chat_template(
204
  formatted_msgs,
205
- add_generation_prompt=True,
206
- return_tensors="pt"
207
- ).to(model.device)
 
208
 
209
- terminators = [
210
- tokenizer.eos_token_id,
211
- tokenizer.convert_tokens_to_ids("<|eot_id|>")
 
 
 
212
  ]
213
 
214
- outputs = model.generate(
215
- input_ids,
216
- max_new_tokens=1000,
217
- eos_token_id=terminators,
218
- do_sample=True,
219
- temperature=0.1,
220
- top_p=0.9,
221
- )
222
- response = outputs[0][input_ids.shape[-1]:]
223
- raw_output = tokenizer.decode(response, skip_special_tokens=True)
224
- return raw_output
225
 
226
  raw_output = run_model(messages, functions)
227
  # Check if there's a function call
@@ -245,9 +241,10 @@ if function_call:
245
  messages.append({"role": "assistant", "tool_calls": function_call})
246
  # append the result of the tool call in openai format, in this case, the value of add 6 to 4 is 10.
247
  messages.append({'role': 'tool', 'tool_call_id': function_call[0]["id"], 'name': function_call[0]["function"]["name"], 'content': '10'})
248
- raw_output = run_model(messages, functions)
249
  # Check if there's a function call
250
- function_call = postprocess_output(raw_output)
 
251
  if function_call:
252
  print(function_call)
253
  else:
 
200
  ## Format messages in Rubra's format
201
  formatted_msgs = preprocess_input(msgs=messages, tools=functions)
202
 
203
+ text = tokenizer.apply_chat_template(
204
  formatted_msgs,
205
+ tokenize=False,
206
+ add_generation_prompt=True
207
+ )
208
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
209
 
210
+ generated_ids = model.generate(
211
+ model_inputs.input_ids,
212
+ max_new_tokens=512
213
+ )
214
+ generated_ids = [
215
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
216
  ]
217
 
218
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
219
+ return response
220
+
 
 
 
 
 
 
 
 
221
 
222
  raw_output = run_model(messages, functions)
223
  # Check if there's a function call
 
241
  messages.append({"role": "assistant", "tool_calls": function_call})
242
  # append the result of the tool call in openai format, in this case, the value of add 6 to 4 is 10.
243
  messages.append({'role': 'tool', 'tool_call_id': function_call[0]["id"], 'name': function_call[0]["function"]["name"], 'content': '10'})
244
+ raw_output1 = run_model(messages, functions)
245
  # Check if there's a function call
246
+
247
+ function_call = postprocess_output(raw_output1)
248
  if function_call:
249
  print(function_call)
250
  else: