# KITT — kitt/core/model.py
# chore: A new more advanced method (commit 60ee11d, 8.6 kB)
import json
import re
import uuid
from langchain.memory import ChatMessageHistory
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langchain_core.utils.function_calling import convert_to_openai_function
import ollama
from pydantic import BaseModel
from loguru import logger
from kitt.skills import vehicle_status
class FunctionCall(BaseModel):
    """Pydantic schema for a single model-emitted tool call (name + arguments).

    Its JSON schema is embedded in the system prompt so the model knows the
    exact shape each <tool_call> payload must follow.
    """
    arguments: dict
    """
    The arguments to call the function with, as generated by the model in JSON
    format. Note that the model does not always generate valid JSON, and may
    hallucinate parameters not defined by your function schema. Validate the
    arguments in your code before calling your function.
    """
    name: str
    """The name of the function to call."""
# JSON schema of FunctionCall as a plain dict; interpolated into the system
# prompt ({schema}) so the model knows the expected tool-call payload shape.
schema_json = json.loads(FunctionCall.schema_json())
# System prompt template (Hermes-2-Pro chat format). Placeholders filled by
# get_prompt(): {car_status}, {tools}, {schema}. Literal JSON braces are
# escaped as {{ }} so str.format leaves them intact.
HRMS_SYSTEM_PROMPT = """<|begin_of_text|>
<|im_start|>system
You are a function calling AI agent with self-recursion.
You can call only one function at a time and analyse data you get from function response.
You are provided with function signatures within <tools></tools> XML tags.
{car_status}
You may use agentic frameworks for reasoning and planning to help with user query.
Please call a function and wait for function results to be provided to you in the next iteration.
Don't make assumptions about what values to plug into function arguments.
Once you have called a function, results will be fed back to you within <tool_response></tool_response> XML tags.
Don't make assumptions about tool results if <tool_response> XML tags are not present since function hasn't been executed yet.
Analyze the data once you get the results and call another function.
At each iteration please continue adding your analysis to previous summary.
Your final response should directly answer the user query.
Here are the available tools:
<tools> {tools} </tools>
If the provided function signatures don't have the function you must call, you may write executable python code in markdown syntax and call code_interpreter() function as follows:
<tool_call>
{{"arguments": {{"code_markdown": <python-code>, "name": "code_interpreter"}}}}
</tool_call>
Make sure that the json object above with code markdown block is parseable with json.loads() and the XML block with XML ElementTree.
When using tools, ensure to only use the tools provided and not make up any data and do not provide any explanation as to which tool you are using and why.
When asked for the weather, lookup the weather for the current location of the car. Unless the user provides a location, then use that location.
If asked about points of interest, use the tools available to you. Do not make up points of interest.
Use the following pydantic model json schema for each tool call you will make:
{schema}
At the very first turn you don't have <tool_response> so you shouldn't make up the results.
Please keep a running summary with analysis of previous function results and summaries from previous iterations.
Do not stop calling functions until the task has been accomplished or you've reached max iteration of 10.
If you plan to continue with analysis, always call another function.
For each function call return a valid json object (using double quotes) with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{{"arguments": <args-dict>, "name": <function-name>}}
</tool_call>
<|im_end|>"""
# Opens the assistant turn; appended to the fully rendered prompt right before
# generation so the model continues as the assistant (raw/completion mode).
AI_PREAMBLE = """
<|im_start|>assistant
"""
# One user turn in Hermes chat format; used by append_message() for "human" messages.
HRMS_TEMPLATE_USER = """
<|im_start|>user
{user_input}<|im_end|>"""
# One completed assistant turn; used by append_message() for "ai" messages.
HRMS_TEMPLATE_ASSISTANT = """
<|im_start|>assistant
{assistant_response}<|im_end|>"""
# One tool-result turn; used by append_message() for "tool" messages.
HRMS_TEMPLATE_TOOL_RESULT = """
<|im_start|>tool
{result}
<|im_end|>"""
def append_message(prompt, h):
    """Append one chat-history message *h* to *prompt* using the template
    matching its type ("human", "ai", or "tool").

    Messages of any other type leave the prompt unchanged.
    """
    if h.type == "human":
        return prompt + HRMS_TEMPLATE_USER.format(user_input=h.content)
    if h.type == "ai":
        return prompt + HRMS_TEMPLATE_ASSISTANT.format(assistant_response=h.content)
    if h.type == "tool":
        return prompt + HRMS_TEMPLATE_TOOL_RESULT.format(result=h.content)
    return prompt
def get_prompt(template, history, tools, schema, car_status=None):
    """Render the full model prompt: system template plus the chat history.

    Args:
        template: Format string with {car_status}, {tools}, {schema} slots.
        history: ChatMessageHistory whose messages are appended in order.
        tools: Tool descriptions interpolated into the {tools} slot.
        schema: Tool-call JSON schema interpolated into the {schema} slot.
        car_status: Optional vehicle state; when falsy, the live state from
            vehicle_status() is used instead.

    Returns:
        The complete prompt string, ready to prepend AI_PREAMBLE to.
    """
    if not car_status:
        # Fall back to the live vehicle state when the caller didn't supply one.
        car_status = vehicle_status()[0]
    # str.format already unescapes {{ }}; the extra replace also unescapes any
    # doubled braces that arrived via the interpolated values themselves.
    prompt = template.format(
        history=history, schema=schema, tools=tools, car_status=car_status
    ).replace("{{", "{").replace("}}", "}")
    if history:
        for message in history.messages:
            prompt = append_message(prompt, message)
    return prompt
def use_tool(tool_call, tools):
    """Invoke the tool named in *tool_call* with its arguments.

    Returns the tool's result, or None when no tool in *tools* matches the
    requested name.
    """
    target = tool_call["name"]
    args = tool_call["arguments"]
    tool = next((t for t in tools if t.name == target), None)
    return tool.invoke(input=args) if tool is not None else None
def parse_tool_calls(text):
    """Extract JSON tool-call payloads from <tool_call> blocks in *text*.

    Returns:
        (tool_calls, errors): parsed payload dicts, and human-readable error
        strings for blocks whose content was not valid JSON. Both lists are
        empty when the response does not begin with a <tool_call> tag.
    """
    logger.debug(f"Start parsing tool_calls: {text}")
    # A response is only treated as tool calls when it leads with the tag;
    # lstrip() tolerates the leading whitespace/newlines models often emit.
    if not text.lstrip().startswith("<tool_call>"):
        return [], []
    pattern = r"<tool_call>\s*(\{.*?\})\s*</tool_call>"
    tool_calls = []
    errors = []
    for match in re.findall(pattern, text, re.DOTALL):
        try:
            tool_calls.append(json.loads(match))
        except json.JSONDecodeError as e:
            errors.append(f"Invalid JSON in tool call: {e}")
    logger.debug(f"Tool calls: {tool_calls}, errors: {errors}")
    return tool_calls, errors
def process_response(user_query, res, history, tools, depth):
    """Execute any tool calls in the model response *res* and record results.

    Invokes each parsed tool call, wraps the outputs in <tool_response> tags,
    and appends them to *history* as a single ToolMessage.

    Returns:
        True if the response contained tool calls (loop should continue),
        False otherwise (res is the final answer).
    """
    logger.debug(f"Processing response: {res}")
    tool_calls, errors = parse_tool_calls(res)
    # TODO: surface parse errors back to the model instead of dropping them.
    if not tool_calls:
        return False
    tool_results = f"Agent iteration {depth} to assist with user query: {user_query}\n"
    for tool_call in tool_calls:
        # TODO: validate tool_call against the FunctionCall schema before invoking.
        try:
            result = use_tool(tool_call, tools)
            # Some skills return a (speech_text, data) pair; keep only the data.
            if isinstance(result, tuple):
                result = result[1]
            tool_results += f"<tool_response>\n{result}\n</tool_response>\n"
        except Exception as e:
            # Best-effort: a failing tool must not abort the agent loop.
            logger.error(f"Error invoking tool {tool_call.get('name')}: {e}")
    tool_results = tool_results.strip()
    print(f"Tool results: {tool_results}")
    # Synthetic id to mimic OpenAI's tool-message format; also usable for
    # tracking individual function calls.
    tool_call_id = uuid.uuid4().hex
    history.add_message(ToolMessage(content=tool_results, tool_call_id=tool_call_id))
    return True
def run_inference_step(history, tools, schema_json, dry_run=False):
    """Render the prompt from the conversation so far and run one generation.

    Args:
        history: ChatMessageHistory with the conversation so far.
        tools: Tools advertised to the model (converted to OpenAI
            function-calling format for the prompt).
        schema_json: Tool-call JSON schema embedded in the system prompt.
        dry_run: When True, only print the prompt and skip the model call.

    Returns:
        The raw model response text (or a placeholder string on dry_run).
    """
    openai_tools = [convert_to_openai_function(tool) for tool in tools]
    prompt = get_prompt(HRMS_SYSTEM_PROMPT, history, openai_tools, schema_json)
    print(f"Prompt is:{prompt + AI_PREAMBLE}\n------------------\n")
    if dry_run:
        # Prompt was already printed above; skip the model call entirely.
        return "Didn't really run it."
    data = {
        # "raw" mode: the prompt already carries the full chat template, so
        # ollama must not re-apply its own template.
        "prompt": prompt + AI_PREAMBLE,
        "model": "interstellarninja/hermes-2-pro-llama-3-8b",
        "raw": True,
        "options": {
            "temperature": 0.8,
            "num_predict": 1500,  # cap on generated tokens per step
        },
    }
    out = ollama.generate(**data)
    return out["response"]
def process_query(user_query: str, history: ChatMessageHistory, tools, max_iterations: int = 10):
    """Run the agentic tool-calling loop for one user query.

    Each iteration generates a model response; if it contains tool calls they
    are executed and their results appended to *history*, otherwise the
    response is the final answer.

    Args:
        user_query: The user's question; also recorded in *history*.
        history: Mutable chat history, updated in place with every turn.
        tools: Tools the model may call.
        max_iterations: Upper bound on agent iterations before giving up
            (the system prompt also announces this budget to the model).

    Returns:
        The final answer text, or an apology when the budget is exhausted.
    """
    history.add_message(HumanMessage(content=user_query))
    for depth in range(max_iterations):
        out = run_inference_step(history, tools, schema_json)
        print(f"Inference step result:\n{out}\n------------------\n")
        history.add_message(AIMessage(content=out))
        if not process_response(user_query, out, history, tools, depth):
            print(f"This is the answer, no more iterations: {out}")
            return out
        # Tool results are already in history; continue with the next iteration.
    # Iteration budget exhausted without reaching a final answer.
    fallback = "Sorry, I am not sure how to help you with that."
    history.add_message(AIMessage(content=fallback))
    return fallback