File size: 9,046 Bytes
c124df1
 
 
 
5894c9b
691fc98
c124df1
691fc98
c124df1
 
 
 
9e2a95f
691fc98
 
c124df1
 
 
 
 
 
 
9e2a95f
c124df1
9e2a95f
691fc98
 
 
9e2a95f
d5870e6
c124df1
 
 
 
 
 
 
691fc98
c124df1
9e2a95f
 
691fc98
 
9e2a95f
c124df1
 
691fc98
9e2a95f
c124df1
 
9e2a95f
c124df1
691fc98
c124df1
 
 
7aab4a8
c124df1
 
 
 
691fc98
c124df1
 
 
 
5894c9b
 
 
 
9e2a95f
5894c9b
9e2a95f
 
c124df1
 
 
 
 
 
 
 
47c54d0
c124df1
 
 
 
 
 
 
 
691fc98
c124df1
691fc98
c124df1
 
 
 
 
 
 
691fc98
c124df1
691fc98
c124df1
 
691fc98
c124df1
 
 
 
 
 
b401ec2
 
c124df1
47c54d0
c124df1
5894c9b
c124df1
691fc98
c124df1
 
 
 
 
 
b401ec2
c124df1
 
5894c9b
c124df1
691fc98
5894c9b
c124df1
5894c9b
c124df1
 
691fc98
c124df1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import argparse
import torch
import json
from config import config
from typing import List, Dict
from logger import logger

from transformers import AutoTokenizer

import functions
from prompter import PromptManager
from validator import validate_function_call_schema
from langchain_community.chat_models import ChatOllama
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

from utils import (
    get_chat_template,
    validate_and_extract_tool_calls
)

class ModelInference:
    """Drive an Ollama-served chat model through a function-calling agent loop.

    The conversation is rendered with a Hugging Face tokenizer's chat template,
    sent to ``ChatOllama`` in JSON mode, and any tool calls found in the reply
    are validated against the tool schemas and executed from the ``functions``
    module. Tool results are fed back to the model until it answers without
    calling a tool, or until ``max_depth`` rounds have been spent.
    """

    def __init__(self, chat_template: str):
        """Create the prompt manager, the Ollama chat model and the tokenizer.

        Args:
            chat_template: Name passed to ``get_chat_template`` when the
                tokenizer loaded from ``config.hf_model`` ships without one.
        """
        self.prompter = PromptManager()

        # Deterministic decoding; format='json' asks Ollama for JSON output.
        self.model = ChatOllama(model=config.ollama_model, temperature=0.0, format='json')

        # The tokenizer is only used to render chat messages into a prompt
        # string (see run_inference); generation itself happens in Ollama.
        self.tokenizer = AutoTokenizer.from_pretrained(config.hf_model, trust_remote_code=True)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"

        if self.tokenizer.chat_template is None:
            logger.info("No chat template defined, getting chat_template...")
            self.tokenizer.chat_template = get_chat_template(chat_template)

        logger.info(f"Model loaded: {self.model}")

    def process_completion_and_validate(self, completion, chat_template):
        """Extract tool calls from a model completion.

        Args:
            completion: Raw text returned by the model.
            chat_template: Unused; kept for backward compatibility with callers.

        Returns:
            ``(tool_calls, completion, error_message)``. ``tool_calls`` is
            ``None`` when ``validate_and_extract_tool_calls`` reports failure.

        Raises:
            ValueError: If ``completion`` is empty or ``None``.
        """
        if not completion:
            logger.warning("Assistant message is None")
            raise ValueError("Assistant message is None")

        validation, tool_calls, error_message = validate_and_extract_tool_calls(completion)
        if not validation:
            return None, completion, error_message

        logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")
        return tool_calls, completion, error_message

    def execute_function_call(self, tool_call):
        """Look up and call the function named in ``tool_call``.

        Args:
            tool_call: Mapping with a ``"name"`` key and an optional
                ``"arguments"`` mapping of keyword arguments.

        Returns:
            A string of the form ``{"name": <name>, "content": <result>}``.
            The result is interpolated with ``str()``, so a dict result
            appears as its Python repr rather than strict JSON.

        Raises:
            ValueError: If ``functions`` has no callable with that name.
            Exception: Whatever the invoked function raises.
        """
        config.status.update(label=":mag: Gathering information..")
        function_name = tool_call.get("name")
        # Guard against non-string names, which getattr would reject with TypeError.
        function_to_call = getattr(functions, function_name, None) if isinstance(function_name, str) else None
        if not callable(function_to_call):
            raise ValueError(f"Function '{function_name}' is not defined in the functions module")
        # The model may emit "arguments": null; treat that as no arguments.
        function_args = tool_call.get("arguments") or {}

        logger.info(f"Invoking function call {function_name} ...")
        # Pass by keyword: the key order of LLM-generated JSON is not reliable,
        # so positional unpacking could bind values to the wrong parameters.
        function_response = function_to_call(**function_args)
        results_dict = f'{{"name": "{function_name}", "content": {function_response}}}'
        return results_dict

    def run_inference(self, prompt: List[Dict[str, str]]):
        """Render ``prompt`` with the tokenizer chat template and return the model's text reply."""
        inputs = self.tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True,
            tokenize=False,
        )
        completion = self.model.invoke(inputs, format='json')
        return completion.content

    def _collect_tool_responses(self, tool_calls, tools):
        """Validate and run each tool call; return the concatenated ``<tool_response>`` blocks.

        Validation and execution failures do not raise. Each one is turned into
        a ``<tool_response>`` that asks the model to retry the call.
        """
        parts = []
        for tool_call in tool_calls:
            validation, message = validate_function_call_schema(tool_call, tools)
            if not validation:
                logger.info(message)
                parts.append(f"<tool_response>\nThere was an error validating function call against function signature: {tool_call.get('name')}\nHere's the error traceback: {message}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n")
                continue
            try:
                function_response = self.execute_function_call(tool_call)
            except Exception as e:
                # Report the failure back to the model instead of aborting the agent loop.
                logger.info(f"Could not execute function: {e}")
                parts.append(f"<tool_response>\nThere was an error when executing the function: {tool_call.get('name')}\nHere's the error traceback: {e}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n")
                continue
            parts.append(f"<tool_response>\n{function_response}\n</tool_response>\n")
            logger.info(f"Here's the response from the function call: {tool_call.get('name')}\n{function_response}")
        return "".join(parts)

    def generate_function_call(self, query, chat_template, num_fewshot, max_depth=5):
        """Answer ``query``, letting the model call tools for up to ``max_depth`` rounds.

        Args:
            query: The user's question.
            chat_template: Forwarded to ``process_completion_and_validate``.
            num_fewshot: Forwarded to ``PromptManager.generate_prompt``.
            max_depth: Maximum number of tool/feedback rounds before stopping.

        Returns:
            The model's final text. This is the reply that contained no tool
            calls and no parse error. If the depth limit is hit, it is instead
            the last completion: a fresh one after tool results, or the
            unparseable one after a parse error.

        Raises:
            Any exception from prompting, inference or parsing is logged and re-raised.
        """
        try:
            depth = 0
            user_message = f"{query}\nThis is the first turn and you don't have <tool_results> to analyze yet"
            chat = [{"role": "user", "content": user_message}]
            tools = functions.get_openai_tools()
            prompt = self.prompter.generate_prompt(chat, tools, num_fewshot)
            config.status.update(label=":brain: Thinking..")
            completion = self.run_inference(prompt)

            while True:
                tool_calls, assistant_message, error_message = self.process_completion_and_validate(completion, chat_template)
                prompt.append({"role": "assistant", "content": assistant_message})

                tool_message = f"Agent iteration {depth} to assist with user query: {query}\n"
                logger.info(f"Found tool calls: {tool_calls}")
                logger.info(f"Assistant Message:\n{assistant_message}")

                if tool_calls:
                    tool_message += self._collect_tool_responses(tool_calls, tools)
                elif error_message:
                    tool_message += f"<tool_response>\nThere was an error parsing function calls\nHere's the error stack trace: {error_message}\nPlease call the function again with correct syntax\n</tool_response>\n"
                else:
                    # No tool calls and no parse error: this is the final answer.
                    return assistant_message

                prompt.append({"role": "tool", "content": tool_message})

                depth += 1
                if depth >= max_depth:
                    logger.warning(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
                    if tool_calls:
                        # Give the model one last look at the tool results.
                        return self.run_inference(prompt)
                    return completion

                if tool_calls:
                    config.status.update(label=":brain: Analysing information..")
                completion = self.run_inference(prompt)

        except Exception as e:
            logger.error(f"Exception occurred: {e}")
            raise