# LLM-ADE-dev / src/functioncall.py
# Last commit: "Fix recursion bug" (b401ec2) by WilliamGazeley
# File size: 7.61 kB
import argparse
import torch
import json
from config import config
from typing import List, Dict
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig
)
import functions
from prompter import PromptManager
from validator import validate_function_call_schema
from utils import (
inference_logger,
get_assistant_message,
get_chat_template,
validate_and_extract_tool_calls
)
class ModelInference:
    """Drive a tool-calling chat loop on top of the model named by ``config.hf_model``.

    The class loads the causal LM and its tokenizer once, then
    ``generate_function_call`` alternates between generating a completion,
    executing any tool calls found in it, and feeding the results back to the
    model until the model answers without tool calls or ``max_depth`` is hit.
    """

    def __init__(self, chat_template: str, load_in_4bit: bool = False):
        """Load the model, tokenizer and prompt manager.

        Args:
            chat_template: Passed to ``get_chat_template`` when the tokenizer
                ships without a chat template of its own.
            load_in_4bit: Enable NF4 double-quantized loading through
                ``BitsAndBytesConfig``. Both a real bool and the legacy
                string form (``"True"``) are accepted. The original author
                marked this path "Never use this", so treat it as
                unsupported.
        """
        self.prompter = PromptManager()
        self.bnb_config = None
        # The original check was `load_in_4bit == "True"`, which silently
        # ignored a boolean True even though the parameter is annotated bool.
        if self._flag_enabled(load_in_4bit):
            self.bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            )
        self.model = AutoModelForCausalLM.from_pretrained(
            config.hf_model,
            trust_remote_code=True,
            return_dict=True,
            quantization_config=self.bnb_config,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.hf_model, trust_remote_code=True)
        # Reuse EOS as the padding token and pad on the left.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"
        if self.tokenizer.chat_template is None:
            print("No chat template defined, getting chat_template...")
            self.tokenizer.chat_template = get_chat_template(chat_template)
        inference_logger.info(self.model.config)
        inference_logger.info(self.model.generation_config)
        inference_logger.info(self.tokenizer.special_tokens_map)

    @staticmethod
    def _flag_enabled(flag) -> bool:
        """Return True for a truthy flag given as a bool or as the string "True"/"true"."""
        if isinstance(flag, str):
            return flag.strip().lower() == "true"
        return bool(flag)

    @staticmethod
    def _parse_error_response(error_message) -> str:
        """Build the tool-response block that reports a tool-call parsing failure."""
        # The closing tag used to be written as "<tool_response>" (no slash),
        # leaving the block unterminated in the prompt.
        return (
            "<tool_response>\nThere was an error parsing function calls\n"
            f" Here's the error stack trace: {error_message}\n"
            "Please call the function again with correct syntax</tool_response>"
        )

    def process_completion_and_validate(self, completion, chat_template):
        """Extract the assistant message from ``completion`` and parse its tool calls.

        Returns:
            A ``(tool_calls, assistant_message, error_message)`` tuple.
            ``tool_calls`` is ``None`` when validation fails.

        Raises:
            ValueError: If no assistant message could be extracted.
        """
        assistant_message = get_assistant_message(completion, chat_template, self.tokenizer.eos_token)
        if not assistant_message:
            inference_logger.warning("Assistant message is None")
            raise ValueError("Assistant message is None")

        validation, tool_calls, error_message = validate_and_extract_tool_calls(assistant_message)
        if validation:
            inference_logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")
        else:
            tool_calls = None
        return tool_calls, assistant_message, error_message

    def execute_function_call(self, tool_call):
        """Look up ``tool_call["name"]`` in the ``functions`` module and invoke it.

        Arguments are passed positionally in the order they appear in
        ``tool_call["arguments"]``.

        Returns:
            A string of the form ``{"name": "<name>", "content": <response>}``.
            The response is interpolated as-is, so the result is not
            guaranteed to be valid JSON.

        Raises:
            TypeError: If the name does not resolve to a callable in ``functions``.
        """
        function_name = tool_call.get("name")
        function_to_call = (
            getattr(functions, function_name, None) if isinstance(function_name, str) else None
        )
        if not callable(function_to_call):
            # Previously this surfaced as "'NoneType' object is not callable".
            raise TypeError(f"No callable tool named {function_name!r} in the functions module")
        # Treat an explicit null "arguments" value the same as a missing one.
        function_args = tool_call.get("arguments") or {}
        inference_logger.info(f"Invoking function call {function_name} ...")
        function_response = function_to_call(*function_args.values())
        results_dict = f'{{"name": "{function_name}", "content": {function_response}}}'
        return results_dict

    def run_inference(self, prompt: List[Dict[str, str]]):
        """Render ``prompt`` with the chat template, sample a completion and decode it.

        Returns:
            The first generated sequence, decoded with special tokens kept.
        """
        inputs = self.tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True,
            return_tensors='pt'
        )
        tokens = self.model.generate(
            inputs.to(self.model.device),
            max_new_tokens=1500,
            temperature=0.8,
            repetition_penalty=1.2,
            do_sample=True,
            eos_token_id=self.tokenizer.eos_token_id
        )
        # The keyword is `clean_up_tokenization_spaces` (plural); the
        # misspelled singular form was silently ignored.
        completion = self.tokenizer.decode(
            tokens[0],
            skip_special_tokens=False,
            clean_up_tokenization_spaces=True,
        )
        return completion

    def _run_tool_calls(self, tool_calls, tools) -> str:
        """Validate and execute each tool call; return the concatenated tool-response blocks."""
        responses = []
        for tool_call in tool_calls:
            validation, message = validate_function_call_schema(tool_call, tools)
            if not validation:
                inference_logger.info(message)
                responses.append(
                    f"<tool_response>\nThere was an error validating function call against function signature: {tool_call.get('name')}\nHere's the error traceback: {message}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
                )
                continue
            try:
                function_response = self.execute_function_call(tool_call)
                responses.append(f"<tool_response>\n{function_response}\n</tool_response>\n")
                inference_logger.info(f"Here's the response from the function call: {tool_call.get('name')}\n{function_response}")
            except Exception as e:
                # Best effort: report the failure to the model so it can retry.
                inference_logger.info(f"Could not execute function: {e}")
                responses.append(
                    f"<tool_response>\nThere was an error when executing the function: {tool_call.get('name')}\nHere's the error traceback: {e}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
                )
        return "".join(responses)

    def generate_function_call(self, query, chat_template, num_fewshot, max_depth=5):
        """Answer ``query``, letting the model call tools for up to ``max_depth`` rounds.

        Args:
            query: The user's question.
            chat_template: Forwarded to ``process_completion_and_validate``.
            num_fewshot: Forwarded to ``PromptManager.generate_prompt``.
            max_depth: Maximum number of tool / error feedback rounds.

        Returns:
            The assistant message once the model replies without tool calls
            and without a parse error. When ``max_depth`` is reached the raw
            decoded completion is returned instead.
            NOTE(review): that raw completion is not passed through
            ``get_assistant_message``; confirm callers expect this.

        Raises:
            Exception: Anything raised while prompting or generating is
                logged and re-raised unchanged.
        """
        try:
            depth = 0
            user_message = f"{query}\nThis is the first turn and you don't have <tool_results> to analyze yet"
            chat = [{"role": "user", "content": user_message}]
            tools = functions.get_openai_tools()
            prompt = self.prompter.generate_prompt(chat, tools, num_fewshot)
            completion = self.run_inference(prompt)

            # Iterative form of the original tail recursion.
            while True:
                tool_calls, assistant_message, error_message = self.process_completion_and_validate(completion, chat_template)
                prompt.append({"role": "assistant", "content": assistant_message})
                inference_logger.info(f"Assistant Message:\n{assistant_message}")

                if not tool_calls and not error_message:
                    # Plain answer: the conversation is finished.
                    return assistant_message

                tool_message = f"Agent iteration {depth} to assist with user query: {query}\n"
                if tool_calls:
                    tool_message += self._run_tool_calls(tool_calls, tools)
                else:
                    tool_message += self._parse_error_response(error_message)
                prompt.append({"role": "tool", "content": tool_message})

                depth += 1
                if depth >= max_depth:
                    print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
                    if tool_calls:
                        # Give the model one last turn to use the tool results.
                        completion = self.run_inference(prompt)
                    return completion
                completion = self.run_inference(prompt)
        except Exception as e:
            inference_logger.error(f"Exception occurred: {e}")
            raise