Adding tool call support in chat template

#13

Added the tool calling feature to the chat template. Kindly review it.

For now I am testing single tool calls.

Still testing; I will update this with some testing scripts.

This version uses a merge of the latest small model with TabbyAPI.
Not sure about tool_call, but the original template failed with an error:
ERROR: jinja2.exceptions.UndefinedError: 'strftime_now' is undefined
ERROR: Sent to request: TemplateError: 'strftime_now' is undefined

  "chat_template": "{%- set default_system_message = \"You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.\\nYour knowledge base was last updated on 2023-10-01.\\n\\nWhen you're not sure about some information, you say that you don't have the information and don't make up anything.\\nIf the user's question is not clear, ambiguous, or does not provide enough context, you ask for clarification.\" %} {%- if not tools is defined %}{%- set tools = none %}{%- endif %} {%- if messages[0]['role'] == 'system' %}{%- set system_message = messages[0]['content'] %}{%- set loop_messages = messages[1:] %}{%- else %}{%- set system_message = default_system_message %}{%- set loop_messages = messages %}{%- endif %} {{ bos_token }}[SYSTEM_PROMPT]{{ system_message }}[/SYSTEM_PROMPT] {%- set ns = namespace() %}{%- set ns.index = 0 %}{%- for message in loop_messages %}{%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}{%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}{{ raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}{%- endif %}{%- set ns.index = ns.index + 1 %}{%- endif %}{%- endfor %} {%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %} {%- for message in loop_messages %}{%- if message[\"role\"] == \"user\" %}{%- if tools is not none and (message == user_messages[-1]) %}[AVAILABLE_TOOLS] [{%- for tool_wrapper in tools %}{%- set tool = tool_wrapper.function %}{\"type\": \"function\", \"function\": { {%- for key, val in tool.items() if key != \"return\" %}{%- if val is string %}\"{{ key }}\": \"{{ val }}\"{%- else %}\"{{ key }}\": {{ val|tojson }}{%- endif %}{%- if not loop.last %}, {% endif %}{%- endfor %}}}{%- if not loop.last %}, {% endif %}{%- endfor %}] [/AVAILABLE_TOOLS]{%- endif %} [INST]{{ message[\"content\"] }}[/INST] {%- elif message.tool_calls is defined and message.tool_calls is not none %}[TOOL_CALLS] [{%- for tool_call in message.tool_calls %}{%- set out = tool_call.function | tojson %}{{ out[:-1] }}, \"id\": \"{{ tool_call.id }}\"}{%- if not loop.last %}, {% else %}]{{ eos_token }}{% endif %}{%- endfor %} {%- elif message[\"role\"] == \"assistant\" %}{{ message[\"content\"] }}{{ eos_token }} {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}{%- if message.content is defined and message.content.content is defined %}{%- set content = message.content.content %}{%- else %}{%- set content = message.content %}{%- endif %}[TOOL_RESULTS] {\"content\": {{ content|string }}, \"call_id\": \"{{ message.tool_call_id }}\"}[/TOOL_RESULTS] {%- elif message[\"role\"] == \"system\" %}[SYSTEM_PROMPT]{{ message[\"content\"] }}[/SYSTEM_PROMPT] {%- else %}{{ raise_exception(\"Only user, system, assistant, tool_calls, tool_results, and tool roles are supported!\") }}{%- endif %}{%- endfor %}",

I have updated it once again and it's running fine for me.

message_0 = [
    {"role": "system", "content": "You're a friendly and helpful assistant."},
    {"role": "user", "content": "Hello! How are you today?"}
]
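
For this plain conversation the template should render roughly as:

<s>[SYSTEM_PROMPT]You're a friendly and helpful assistant.[/SYSTEM_PROMPT][INST]Hello! How are you today?[/INST]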

message_1 = [
    {"role": "system", "content": "You're a weather bot that provides real-time weather updates."},
    {"role": "user", "content": "What's the temperature in New York?"},
    {"role": "assistant", "tool_calls": [
        {
            "function": {
                "name": "get_weather",
                "arguments": {"location": "New York"}
            }
        }
    ]},
    {"role": "user", "content": "What's the temperature in New delhi."}
]


message_2 = [
    {"role": "system", "content": "You're a weather bot that provides real-time weather updates."},
    {"role": "user", "content": "What's the temperature in New York?"},
    {"role": "assistant", "tool_calls": [
        {
            "function": {
                "name": "get_weather",
                "arguments": {"location": "New York"}
            }
        }
    ]},
    {"role": "tool", "content": "It's currently 5°C in New York."},
    {"role": "user", "content": "What's the temperature in New delhi."}
]
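
Note that the template also emits tool_call.id and message.tool_call_id (rendered as "id" and "call_id"); the messages above omit both, so those fields may come out empty. A variant of the tool turns with explicit ids (the id value is illustrative) would be:

    {"role": "assistant", "tool_calls": [
        {
            "id": "abc123def",
            "function": {
                "name": "get_weather",
                "arguments": {"location": "New York"}
            }
        }
    ]},
    {"role": "tool", "tool_call_id": "abc123def", "content": "It's currently 5°C in New York."},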


## Parallel tool calls

message_3 = [
    {"role": "system", "content": "You're a helpful assistant."},
    {"role": "user", "content": "What is the weather in Paris and what time does the flight to Tokyo depart?"},
    {"role": "assistant", "tool_calls": [
        {"function": {"name": "get_weather", "arguments": {"location": "Paris"}}},
        {"function": {"name": "get_flight_schedule", "arguments": {"destination": "Tokyo"}}}
    ]},
    {"role": "user", "content": "What is the weather in california and what time does the flight to las vegas depart? "},
]

message_4 = [
    {"role": "system", "content": "You're a helpful assistant."},
    {"role": "user", "content": "What is the weather in Paris and what time does the flight to Tokyo depart?"},
    {"role": "assistant", "tool_calls": [
        {"function": {"name": "get_weather", "arguments": {"location": "Paris"}}},
        {"function": {"name": "get_flight_schedule", "arguments": {"destination": "Tokyo"}}}
    ]},
    {"role": "tool", "content": "22°C"},
    {"role": "tool", "content": "Flight departs at 10:00 AM."},
    {"role": "user", "content": "What is the weather in california and what time does the flight to las vegas depart? "},
]
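
Per the template's loop over message.tool_calls, both parallel calls are serialized into a single [TOOL_CALLS] JSON array, roughly (ids empty because the messages above don't set them):

[TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "Paris"}, "id": ""}, {"name": "get_flight_schedule", "arguments": {"destination": "Tokyo"}, "id": ""}]</s>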

Tried these messages using the Hugging Face transformers code below:

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
import torch

model_name = "mistral_testing/Mistral-Small-24B-Instruct-2501"

# 4-bit NF4 quantization so the 24B model fits on a single GPU
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quant_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)


def chat_with_bot(message):
    # Apply the chat template to format the conversation history
    prompt = tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=False)

    # Convert the prompt into input_ids
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.cuda()

    # Length of the input in tokens, for trimming the prompt from the output
    input_length = input_ids.shape[-1]

    # Generate the model output
    output_encode = model.generate(
        input_ids=input_ids,
        pad_token_id=tokenizer.pad_token_id,
        max_length=1024,
    )

    # Decode the full output, keeping special tokens so template markers stay visible
    full_output = tokenizer.decode(output_encode[0], skip_special_tokens=False)

    # Extract the response by dropping everything up to the end of the prompt text
    response = full_output[len(prompt):]

    return full_output, response.strip()


# Get the model's response
full_output, response = chat_with_bot(message_4)

# Output the assistant's response
print(f"Bot: {response}")
print("\nAll output", full_output)

@patrickvonplaten, kindly have a look; I have been working on this for the past 4-5 hours. If you want any other changes, I am happy to make them.

Earlier version response for message_1:
[screenshot]

Current version response for message_1:
[screenshot]

Ran one more test:

import requests
import json

url = "http://0.0.0.0:8000/v1/chat/completions"
headers = {"Content-Type": "application/json", "Authorization": "Bearer token"}

model = "/home/jupyter-navanit/mistral_testing/Mistral-Small-24B-Instruct-2501"
# Placeholder; the official reference example defines its own SYSTEM_PROMPT.
SYSTEM_PROMPT = "You are a helpful assistant."
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for, e.g. 'San Francisco'",
                    },
                    "state": {
                        "type": "string",
                        "description": "The state abbreviation, e.g. 'CA' for California",
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit for temperature",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city", "state", "unit"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "rewrite",
            "description": "Rewrite a given text for improved clarity",
            "parameters": {
                "type": "object",
                "properties": {
                    "text": {
                        "type": "string",
                        "description": "The input text to rewrite",
                    }
                },
            },
        },
    },
]

messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {
        "role": "user",
        "content": "Could you please make the below article more concise?\n\nOpenAI is an artificial intelligence research laboratory consisting of the non-profit OpenAI Incorporated and its for-profit subsidiary corporation OpenAI Limited Partnership.",
    },
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "id": "bbc5b7ede",
                "type": "function",
                "function": {
                    "name": "rewrite",
                    "arguments": '{"text": "OpenAI is an artificial intelligence research laboratory consisting of the non-profit OpenAI Incorporated and its for-profit subsidiary corporation OpenAI Limited Partnership."}',
                },
            }
        ],
    },
    {
        "role": "tool",
        "content": '{"action":"rewrite","outcome":"OpenAI is a FOR-profit company."}',
        "tool_call_id": "bbc5b7ede",
        "name": "rewrite",
    },
    {
        "role": "assistant",
        "content": "---\n\nOpenAI is a FOR-profit company.",
    },
    {
        "role": "user",
        "content": "Can you tell me what the temperature will be in Dallas, in Fahrenheit?",
    },
]

data = {"model": model, "messages": messages, "tools": tools}

response = requests.post(url, headers=headers, data=json.dumps(data))
print(response.json()["choices"][0]["message"]["tool_calls"])
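
To sanity-check the whole assistant message (not just tool_calls), dumping it is handy:

# Inspect the complete assistant message, including content and tool_calls
print(json.dumps(response.json()["choices"][0]["message"], indent=2))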

Output reply:
[screenshot]

@patrickvonplaten Kindly review if you can

If this PR works, it would be awesome to merge it!

@LHC88 Yes, I have tested it on some queries.
@patrickvonplaten, kindly review it.

Mistral AI org

Thanks a lot for working on it! cc @cyrilvallez any chance this could be tested on your side?

Mistral AI org

cc @Navanit-AI any chance you could run the official reference code with vLLM: https://huggingface.co/mistralai/Mistral-Small-24B-Instruct-2501#function-calling and make sure things match 1-to-1?

Sure, give me a moment.
Just FYI, I have an A6000 and am using quantization to run this code on vLLM.
So only the LLM serving part has to be changed, and I will share the output.
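
For anyone reproducing this, the vLLM serve invocation from the linked model card looks roughly like the following (flags as I recall them from the card; double-check there):

vllm serve mistralai/Mistral-Small-24B-Instruct-2501 \
    --tokenizer_mode mistral --config_format mistral --load_format mistral \
    --tool-call-parser mistral --enable-auto-tool-choice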

[screenshot of the official reference run]

@patrickvonplaten I ran the official reference code; looks good to me.

strftime_now breaks the text-generation-inference setup.

https://huggingface.co/mistralai/Mistral-Small-24B-Instruct-2501/discussions/17#67a07da06934f9aa1c937967

As mentioned there, it is likely a TGI error, since the template works in vLLM and transformers.

ExLlamaV2 breaks too.

Yes, TGI is being patched to support the template!
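
For context: strftime_now is not standard Jinja; transformers registers it as an extra callable in its template environment, which is why engines that render the template with vanilla Jinja2 fail. A minimal shim for such engines, assuming you render the template yourself (chat_template_str stands in for the template text above), would be:

from datetime import datetime
import jinja2

env = jinja2.Environment()
# transformers exposes strftime_now to chat templates; plain Jinja2 does not
env.globals["strftime_now"] = lambda fmt: datetime.now().strftime(fmt)
template = env.from_string(chat_template_str)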

@Navanit-AI what template type has to be set? llama3_json or mistral?

@LHC88 Sorry, I didn't quite follow. If you mean the chat template type, it should be mistral, since I tested with the HF transformers library, where apply_chat_template uses the template shipped with the model.
