import argparse
import torch
import json
from config import config
from typing import List, Dict
from logger import logger
from transformers import AutoTokenizer
import functions
from prompter import PromptManager
from validator import validate_function_call_schema
from langchain_community.chat_models import ChatOllama
from langchain_community.llms import Ollama
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from utils import (
    get_chat_template,
    validate_and_extract_tool_calls
)


class ModelInference:
    """Run a function-calling agent loop against an Ollama model.

    The Hugging Face tokenizer named by ``config.hf_model`` is used only to
    render chat messages into a prompt string via its chat template; text
    generation is done by the Ollama model named by ``config.ollama_model``.
    """

    def __init__(self, chat_template: str):
        """Load the prompt manager, the Ollama model and the tokenizer.

        Args:
            chat_template: Name handed to ``get_chat_template`` when the
                tokenizer does not ship a chat template of its own.
        """
        self.prompter = PromptManager()
        # Temperature 0 and JSON output format, as configured for Ollama.
        self.model = Ollama(model=config.ollama_model, temperature=0.0, format='json')

        self.tokenizer = AutoTokenizer.from_pretrained(config.hf_model, trust_remote_code=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"
        if self.tokenizer.chat_template is None:
            logger.info("No chat template defined, getting chat_template...")
            self.tokenizer.chat_template = get_chat_template(chat_template)

        logger.info(f"Model loaded: {self.model}")

    def process_completion_and_validate(self, completion, chat_template):
        """Extract and validate tool calls from a raw model completion.

        Args:
            completion: Raw text produced by the model.
            chat_template: Unused; kept for interface compatibility.

        Returns:
            A tuple ``(tool_calls, completion, error_message)``. ``tool_calls``
            is what ``validate_and_extract_tool_calls`` parsed (callers iterate
            it as dicts), or ``None`` when validation failed. ``completion`` is
            returned unchanged; ``error_message`` is whatever the validator
            reported.

        Raises:
            ValueError: If ``completion`` is empty or ``None``.
        """
        if not completion:
            logger.warning("Assistant message is None")
            raise ValueError("Assistant message is None")

        validation, tool_calls, error_message = validate_and_extract_tool_calls(completion)
        if not validation:
            return None, completion, error_message

        logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")
        return tool_calls, completion, error_message

    def execute_function_call(self, tool_call):
        """Invoke the function named in ``tool_call`` from the ``functions`` module.

        Args:
            tool_call: Mapping with a ``"name"`` key and an optional
                ``"arguments"`` mapping of keyword arguments.

        Returns:
            A string of the form ``{"name": "<name>", "content": <result>}``
            where ``<result>`` is the ``str()`` of the function's return value.

        Raises:
            ValueError: If ``functions`` has no callable with that name.
            Exception: Anything raised by the called function propagates.
        """
        function_name = tool_call.get("name")
        function_to_call = (
            getattr(functions, function_name, None)
            if isinstance(function_name, str)
            else None
        )
        if not callable(function_to_call):
            raise ValueError(f"Unknown function: {function_name}")
        function_args = tool_call.get("arguments") or {}

        logger.info(f"Invoking function call {function_name} ...")
        # Pass by keyword: the key order of the model's JSON arguments is not
        # guaranteed to match the function's positional parameter order.
        function_response = function_to_call(**function_args)
        return f'{{"name": "{function_name}", "content": {function_response}}}'

    def run_inference(self, prompt: List[Dict[str, str]]):
        """Render ``prompt`` with the tokenizer's chat template and query the model.

        Args:
            prompt: Chat messages as ``{"role": ..., "content": ...}`` dicts.

        Returns:
            The generated text.
        """
        inputs = self.tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True,
            tokenize=False,
        )
        # Hotfix: the chat template emits a literal "<|begin_of_text|>"; strip it.
        inputs = inputs.replace("<|begin_of_text|>", "")
        completion = self.model.invoke(inputs, format='json')
        # ``Ollama`` (a plain LLM) returns ``str``; chat models return a message
        # object with ``.content``. Accept either.
        return getattr(completion, "content", completion)

    def _run_tool_calls(self, tool_calls, tools):
        """Validate and execute each tool call; return the text to add to the tool message."""
        parts = []
        for tool_call in tool_calls:
            validation, message = validate_function_call_schema(tool_call, tools)
            if not validation:
                logger.info(message)
                parts.append(
                    f"\nThere was an error validating function call against function signature: {tool_call.get('name')}\nHere's the error traceback: {message}\nPlease call this function again with correct arguments within XML tags \n\n"
                )
                continue
            try:
                function_response = self.execute_function_call(tool_call)
            except Exception as e:
                # Reported back to the model so it can retry with fixed arguments.
                logger.info(f"Could not execute function: {e}")
                parts.append(
                    f"\nThere was an error when executing the function: {tool_call.get('name')}\nHere's the error traceback: {e}\nPlease call this function again with correct arguments within XML tags \n\n"
                )
            else:
                parts.append(f"\n{function_response}\n\n")
                logger.info(f"Here's the response from the function call: {tool_call.get('name')}\n{function_response}")
        return "".join(parts)

    def generate_function_call(self, query, chat_template, num_fewshot, max_depth=5):
        """Answer ``query``, letting the model call tools for up to ``max_depth`` rounds.

        Each round, the model's completion is parsed for tool calls. Valid
        calls are executed and their results are appended to the conversation
        as a ``"tool"`` message. Validation, execution and parse errors are
        reported the same way so the model can retry.

        Args:
            query: The user's question.
            chat_template: Forwarded to ``process_completion_and_validate``.
            num_fewshot: Forwarded to ``PromptManager.generate_prompt``.
            max_depth: Maximum number of tool/error rounds.

        Returns:
            The assistant message of the first completion that contains neither
            tool calls nor a parse error. If ``max_depth`` is reached after a
            tool round, one more completion is generated and returned. If it is
            reached after a parse error, the last unparsed completion is
            returned.

        Raises:
            Exception: Any error from prompting, inference or parsing is logged
                and re-raised.
        """
        try:
            depth = 0
            user_message = f"{query}\nThis is the first turn and you don't have to analyze yet"
            chat = [{"role": "user", "content": user_message}]
            tools = functions.get_openai_tools()
            prompt = self.prompter.generate_prompt(chat, tools, num_fewshot)
            completion = self.run_inference(prompt)

            while True:
                tool_calls, assistant_message, error_message = self.process_completion_and_validate(completion, chat_template)
                prompt.append({"role": "assistant", "content": assistant_message})
                tool_message = f"Agent iteration {depth} to assist with user query: {query}\n"
                logger.info(f"Found tool calls: {tool_calls}")
                logger.info(f"Assistant Message:\n{assistant_message}")

                if tool_calls:
                    tool_message += self._run_tool_calls(tool_calls, tools)
                elif error_message:
                    tool_message += f"\nThere was an error parsing function calls\n Here's the error stack trace: {error_message}\nPlease call the function again with correct syntax"
                else:
                    # No tool calls and no parse error: this is the final answer.
                    return assistant_message

                prompt.append({"role": "tool", "content": tool_message})
                depth += 1
                if depth >= max_depth:
                    logger.warning(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
                    if tool_calls:
                        # Let the model answer using the last tool results.
                        return self.run_inference(prompt)
                    return completion
                completion = self.run_inference(prompt)
        except Exception as e:
            logger.error(f"Exception occurred: {e}")
            raise