import argparse
import json

import torch
from typing import List, Dict
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig
)

import functions
from config import config
from prompter import PromptManager
from validator import validate_function_call_schema
from utils import (
    inference_logger,
    get_assistant_message,
    get_chat_template,
    validate_and_extract_tool_calls
)


class ModelInference:
    def __init__(self, chat_template: str, load_in_4bit: bool = False):
        self.prompter = PromptManager()
        self.bnb_config = None

        if load_in_4bit:  # Never use this
            self.bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            )

        # Load the model in fp16 with FlashAttention 2, sharded across available devices
        self.model = AutoModelForCausalLM.from_pretrained(
            config.hf_model,
            trust_remote_code=True,
            return_dict=True,
            quantization_config=self.bnb_config,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )

        self.tokenizer = AutoTokenizer.from_pretrained(config.hf_model, trust_remote_code=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"

        # Fall back to the requested template if the tokenizer ships without one
        if self.tokenizer.chat_template is None:
            print("No chat template defined, getting chat_template...")
            self.tokenizer.chat_template = get_chat_template(chat_template)

        inference_logger.info(self.model.config)
        inference_logger.info(self.model.generation_config)
        inference_logger.info(self.tokenizer.special_tokens_map)

    def process_completion_and_validate(self, completion, chat_template):
        assistant_message = get_assistant_message(completion, chat_template, self.tokenizer.eos_token)

        if assistant_message:
            validation, tool_calls, error_message = validate_and_extract_tool_calls(assistant_message)

            if validation:
                inference_logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")
                return tool_calls, assistant_message, error_message
            else:
                tool_calls = None
                return tool_calls, assistant_message, error_message
        else:
            inference_logger.warning("Assistant message is None")
            raise ValueError("Assistant message is None")

    def execute_function_call(self, tool_call):
        function_name = tool_call.get("name")
        function_to_call = getattr(functions, function_name, None)
        function_args = tool_call.get("arguments", {})

        if function_to_call is None:
            raise ValueError(f"Function {function_name} is not defined in the functions module")

        inference_logger.info(f"Invoking function call {function_name} ...")
        # Arguments are passed positionally, in the order the model emitted them
        function_response = function_to_call(*function_args.values())
        results_dict = f'{{"name": "{function_name}", "content": {function_response}}}'
        return results_dict

    def run_inference(self, prompt: List[Dict[str, str]]):
        inputs = self.tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True,
            return_tensors='pt'
        )

        tokens = self.model.generate(
            inputs.to(self.model.device),
            max_new_tokens=1500,
            temperature=0.8,
            repetition_penalty=1.2,
            do_sample=True,
            eos_token_id=self.tokenizer.eos_token_id
        )
        completion = self.tokenizer.decode(tokens[0], skip_special_tokens=False, clean_up_tokenization_spaces=True)
        return completion

    def generate_function_call(self, query, chat_template, num_fewshot, max_depth=5):
        try:
            depth = 0
            user_message = f"{query}\nThis is the first turn and you don't have to analyze yet"
            chat = [{"role": "user", "content": user_message}]
            tools = functions.get_openai_tools()
            prompt = self.prompter.generate_prompt(chat, tools, num_fewshot)
            completion = self.run_inference(prompt)

            def recursive_loop(prompt, completion, depth):
                nonlocal max_depth
                tool_calls, assistant_message, error_message = self.process_completion_and_validate(completion, chat_template)
                prompt.append({"role": "assistant", "content": assistant_message})
tool_message = f"Agent iteration {depth} to assist with user query: {query}\n" if tool_calls: inference_logger.info(f"Assistant Message:\n{assistant_message}") for tool_call in tool_calls: validation, message = validate_function_call_schema(tool_call, tools) if validation: try: function_response = self.execute_function_call(tool_call) tool_message += f"\n{function_response}\n\n" inference_logger.info(f"Here's the response from the function call: {tool_call.get('name')}\n{function_response}") except Exception as e: inference_logger.info(f"Could not execute function: {e}") tool_message += f"\nThere was an error when executing the function: {tool_call.get('name')}\nHere's the error traceback: {e}\nPlease call this function again with correct arguments within XML tags \n\n" else: inference_logger.info(message) tool_message += f"\nThere was an error validating function call against function signature: {tool_call.get('name')}\nHere's the error traceback: {message}\nPlease call this function again with correct arguments within XML tags \n\n" prompt.append({"role": "tool", "content": tool_message}) depth += 1 if depth >= max_depth: print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.") completion = self.run_inference(prompt) return completion completion = self.run_inference(prompt) return recursive_loop(prompt, completion, depth) elif error_message: inference_logger.info(f"Assistant Message:\n{assistant_message}") tool_message += f"\nThere was an error parsing function calls\n Here's the error stack trace: {error_message}\nPlease call the function again with correct syntax" prompt.append({"role": "tool", "content": tool_message}) depth += 1 if depth >= max_depth: print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.") return completion completion = self.run_inference(prompt) return recursive_loop(prompt, completion, depth) else: inference_logger.info(f"Assistant Message:\n{assistant_message}") return assistant_message return recursive_loop(prompt, completion, depth) except Exception as e: inference_logger.error(f"Exception occurred: {e}") raise e