# LLM-ADE-dev/src/functioncall.py
import argparse
import json
from typing import Dict, List

import torch  # retained from the original import list; not referenced directly below
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import StrOutputParser
from transformers import AutoTokenizer

import functions
from config import config
from logger import logger
from prompter import PromptManager
from utils import get_chat_template, validate_and_extract_tool_calls
from validator import validate_function_call_schema

class ModelInference:
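    """Function-calling inference wrapper around a local Ollama chat model.

    Prompts the model with tool signatures, validates and executes the tool
    calls it emits, and feeds the results back until a final answer is produced.
    """
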
    def __init__(self, chat_template: str):
        self.prompter = PromptManager()
        self.model = ChatOllama(model=config.ollama_model, temperature=0.0, format='json')
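        # NOTE: hard-coded ChatML system prompt advertising a single tool
        # (get_stock_fundamentals), kept verbatim. The declared input variable
        # "question" never appears in the template body, so this template
        # cannot be formatted with a question as-is.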
        template = PromptTemplate(
            template="""<|im_start|>system
You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools> {"type": "function", "function": {"name": "get_stock_fundamentals", "description": "get_stock_fundamentals(symbol: str) -> dict - Get fundamental data for a given stock symbol using yfinance API.\\n\\n Args:\\n symbol (str): The stock symbol.\\n\\n Returns:\\n dict: A dictionary containing fundamental data.\\n Keys:\\n - 'symbol': The stock symbol.\\n - 'company_name': The long name of the company.\\n - 'sector': The sector to which the company belongs.\\n - 'industry': The industry to which the company belongs.\\n - 'market_cap': The market capitalization of the company.\\n - 'pe_ratio': The forward price-to-earnings ratio.\\n - 'pb_ratio': The price-to-book ratio.\\n - 'dividend_yield': The dividend yield.\\n - 'eps': The trailing earnings per share.\\n - 'beta': The beta value of the stock.\\n - '52_week_high': The 52-week high price of the stock.\\n - '52_week_low': The 52-week low price of the stock.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} </tools> Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"} For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{"arguments": <args-dict>, "name": <function-name>}
</tool_call><|im_end|>
""",
            input_variables=["question"],
        )
        # Currently unused: run_inference() invokes self.model directly.
        chain = template | self.model | StrOutputParser()
        self.tokenizer = AutoTokenizer.from_pretrained(config.hf_model, trust_remote_code=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"
        if self.tokenizer.chat_template is None:
            logger.info("No chat template defined, getting chat_template...")
            self.tokenizer.chat_template = get_chat_template(chat_template)
        logger.info(f"Model loaded: {self.model}")

    def process_completion_and_validate(self, completion, chat_template):
        """Extract and validate <tool_call> payloads from a model completion."""
        if completion:
            validation, tool_calls, error_message = validate_and_extract_tool_calls(completion)
            if validation:
                logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")
                return tool_calls, completion, error_message
            else:
                tool_calls = None
                return tool_calls, completion, error_message
        else:
            logger.warning("Assistant message is None")
            raise ValueError("Assistant message is None")

    def execute_function_call(self, tool_call):
        """Look up the requested function in the functions module and run it.

        Expects tool_call of the form:
            {"name": "get_stock_fundamentals", "arguments": {"symbol": "AAPL"}}
        """
        config.status.update(label=":mag: Gathering information..")
        function_name = tool_call.get("name")
        function_to_call = getattr(functions, function_name, None)
        if function_to_call is None:
            raise ValueError(f"No function named {function_name!r} in functions module")
        function_args = tool_call.get("arguments", {})
        logger.info(f"Invoking function call {function_name} ...")
        # Pass arguments by keyword so their order in the model's JSON is irrelevant.
        function_response = function_to_call(**function_args)
        results_dict = f'{{"name": "{function_name}", "content": {function_response}}}'
        return results_dict

    def run_inference(self, prompt: List[Dict[str, str]]):
        """Render the chat through the tokenizer's chat template and query the model."""
        inputs = self.tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True,
            tokenize=False,
        )
        completion = self.model.invoke(inputs, format='json')
        return completion.content
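
    # A well-formed completion follows the format requested in the system
    # prompt, e.g.:
    #   <tool_call>
    #   {"arguments": {"symbol": "AAPL"}, "name": "get_stock_fundamentals"}
    #   </tool_call>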

    def generate_function_call(self, query, chat_template, num_fewshot, max_depth=5):
        """Agentic loop: prompt the model, execute any validated tool calls,
        feed the results back, and repeat until the model produces a final
        answer or max_depth iterations are reached."""
        try:
            depth = 0
            user_message = f"{query}\nThis is the first turn and you don't have <tool_results> to analyze yet"
            chat = [{"role": "user", "content": user_message}]
            tools = functions.get_openai_tools()
            prompt = self.prompter.generate_prompt(chat, tools, num_fewshot)
            config.status.update(label=":brain: Thinking..")
            completion = self.run_inference(prompt)

            def recursive_loop(prompt, completion, depth):
                tool_calls, assistant_message, error_message = self.process_completion_and_validate(completion, chat_template)
                prompt.append({"role": "assistant", "content": assistant_message})
                tool_message = f"Agent iteration {depth} to assist with user query: {query}\n"
                logger.info(f"Found tool calls: {tool_calls}")
                if tool_calls:
                    logger.info(f"Assistant Message:\n{assistant_message}")
                    for tool_call in tool_calls:
                        validation, message = validate_function_call_schema(tool_call, tools)
                        if validation:
                            try:
                                function_response = self.execute_function_call(tool_call)
                                tool_message += f"<tool_response>\n{function_response}\n</tool_response>\n"
                                logger.info(f"Here's the response from the function call: {tool_call.get('name')}\n{function_response}")
                            except Exception as e:
                                logger.info(f"Could not execute function: {e}")
                                tool_message += f"<tool_response>\nThere was an error when executing the function: {tool_call.get('name')}\nHere's the error traceback: {e}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
                        else:
                            logger.info(message)
                            tool_message += f"<tool_response>\nThere was an error validating function call against function signature: {tool_call.get('name')}\nHere's the error traceback: {message}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
                    prompt.append({"role": "tool", "content": tool_message})
                    depth += 1
                    if depth >= max_depth:
                        logger.warning(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
                        completion = self.run_inference(prompt)
                        return completion
                    config.status.update(label=":brain: Analysing information..")
                    completion = self.run_inference(prompt)
                    return recursive_loop(prompt, completion, depth)
                elif error_message:
                    logger.info(f"Assistant Message:\n{assistant_message}")
                    tool_message += f"<tool_response>\nThere was an error parsing function calls\nHere's the error stack trace: {error_message}\nPlease call the function again with correct syntax\n</tool_response>\n"
                    prompt.append({"role": "tool", "content": tool_message})
                    depth += 1
                    if depth >= max_depth:
                        logger.warning(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
                        return completion
                    completion = self.run_inference(prompt)
                    return recursive_loop(prompt, completion, depth)
                else:
                    logger.info(f"Assistant Message:\n{assistant_message}")
                    return assistant_message

            return recursive_loop(prompt, completion, depth)
        except Exception as e:
            logger.error(f"Exception occurred: {e}")
            raise
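

if __name__ == "__main__":
    # Hypothetical CLI entry point: a minimal sketch, not part of the original
    # module. It assumes "chatml" is a template name understood by
    # get_chat_template and that config.status is already initialized (the
    # status-bar updates above will fail outside that context).
    parser = argparse.ArgumentParser(description="Run a function-calling query against the local model.")
    parser.add_argument("--query", type=str, required=True, help="User query to answer")
    parser.add_argument("--chat-template", type=str, default="chatml", help="Chat template name")
    parser.add_argument("--num-fewshot", type=int, default=None, help="Number of few-shot examples in the prompt")
    args = parser.parse_args()

    inference = ModelInference(chat_template=args.chat_template)
    answer = inference.generate_function_call(args.query, args.chat_template, args.num_fewshot)
    logger.info(f"Final answer:\n{answer}")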