WilliamGazeley committed
Commit ce65c0f
Parents: 08238aa 5894c9b

Merge branch 'simple-rag'

.gitignore ADDED
@@ -0,0 +1,7 @@
+ .env
+
+ # Python
+ __pycache__/
+
+ # vLLM
+ inference_logs/
app.py CHANGED
@@ -1,43 +1,43 @@
  import os
  import huggingface_hub
  import streamlit as st
- from vllm import LLM, SamplingParams
+ from config import config
+ from utils import get_assistant_message
+ from functioncall import ModelInference
+ from prompter import PromptManager
  
- sys_msg = """#Context:
- You are an expert financial advisor named IRAI. You have a comprehensive understanding of finance and investing with experience and expertise in all areas of finance.
- #Objective:
- Please answer questions as best as possible given your current knowledge. You do not have access to up-to-date current market data. Try to demonstrate analytical depth and showcase ability to integrate complex data into practical advice, but answer the question directly.
- #Style and tone:
- Answer in a friendly and engaging manner representing a top female investment professional working at a leading investment bank.
- #Audience:
- The questions will be asked by top executives and managers of successful startups. Assume the audience is composed of 40 year old males with high wealth and income, high risk appetite with high threshold for volatility.
- #Response:
- Direct answer to question, concise yet insightful."""
+
  
  @st.cache_resource(show_spinner="Loading model..")
  def init_llm():
-     huggingface_hub.login(token=os.getenv("HF_TOKEN"))
-     llm = LLM(model="InvestmentResearchAI/LLM-ADE-dev")
-     tok = llm.get_tokenizer()
-     tok.eos_token = '<|im_end|>' # Override to use turns
+     huggingface_hub.login(token=config.hf_token, new_session=False)
+     llm = ModelInference(chat_template=config.chat_template)
      return llm
  
  def get_response(prompt):
      try:
+         return llm.generate_function_call(
+             prompt,
+             config.chat_template,
+             config.num_fewshot,
+             config.max_depth
+         )
+     except Exception as e:
+         return f"An error occurred: {str(e)}"
+
+ def get_output(context, user_input):
+     try:
+         prompt_schema = llm.prompter.read_yaml_file("prompt_assets/output_sys_prompt.yml")
+         sys_prompt = llm.prompter.format_yaml_prompt(prompt_schema, dict()) + \
+             f"Information:\n{context}"
          convo = [
-             {"role": "system", "content": sys_msg},
-             {"role": "user", "content": prompt},
+             {"role": "system", "content": sys_prompt},
+             {"role": "user", "content": user_input},
          ]
-         llm = init_llm()
-         prompts = [llm.get_tokenizer().apply_chat_template(convo, tokenize=False)]
-         sampling_params = SamplingParams(temperature=0.3, top_p=0.95, max_tokens=500, stop_token_ids=[128009])
-         outputs = llm.generate(prompts, sampling_params)
-         for output in outputs:
-             return output.outputs[0].text
+         response = llm.run_inference(convo)
+         return get_assistant_message(response, config.chat_template, llm.tokenizer.eos_token)
      except Exception as e:
          return f"An error occurred: {str(e)}"
  
- 
  def main():
      st.title("LLM-ADE 9B Demo")
  
@@ -46,13 +46,21 @@ def main():
      if st.button("Generate"):
          if input_text:
              with st.spinner('Generating response...'):
-                 response_text = get_response(input_text)
-                 st.write(response_text)
+                 agent_resp = get_response(input_text)
+                 st.write(get_output(agent_resp, input_text))
          else:
              st.warning("Please enter some text to generate a response.")
  
      llm = init_llm()
  
- if __name__ == "__main__":
-     main()
+ def main_headless():
+     while True:
+         input_text = input("Enter your text here: ")
+         agent_resp = get_response(input_text)
+         print('\033[94m' + get_output(agent_resp, input_text) + '\033[0m')
+
+ if __name__ == "__main__":
+     if config.headless:
+         main_headless()
+     else:
+         main()
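
The rewritten app splits each query into two passes: get_response() runs the recursive tool-calling agent, and get_output() feeds the gathered tool results back in as "Information:" context for a plain-language answer. A minimal sketch of the flow (the question is made up, and `llm` must already be initialized as in main()):

    # Two-pass flow used by both main() and main_headless():
    question = "What is AAPL's current P/E ratio?"   # illustrative query
    agent_resp = get_response(question)              # pass 1: recursive tool calls
    answer = get_output(agent_resp, question)        # pass 2: synthesize the final answer
    print(answer)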
config.py ADDED
@@ -0,0 +1,14 @@
+ from pydantic import Field
+ from pydantic_settings import BaseSettings
+
+ class Config(BaseSettings):
+     hf_token: str = Field(...)
+     model_path: str = Field("InvestmentResearchAI/LLM-ADE-dev")
+     headless: bool = Field(False, description="Run in headless mode.")
+
+     chat_template: str = Field("chatml", description="Chat template for prompt formatting")
+     num_fewshot: int | None = Field(None, description="Option to use json mode examples")
+     load_in_4bit: str = Field("False", description="Option to load in 4bit with bitsandbytes")
+     max_depth: int = Field(5, description="Maximum number of recursive iterations")
+
+ config = Config(_env_file=".env")
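
Config is a pydantic BaseSettings, so values come from the environment or the .env file (which the new .gitignore keeps out of the repo); pydantic-settings matches field names case-insensitively. A minimal .env sketch with placeholder values:

    # .env (placeholder values)
    HF_TOKEN=hf_xxxxxxxxxxxxxxxx
    MODEL_PATH=InvestmentResearchAI/LLM-ADE-dev
    HEADLESS=false
    NUM_FEWSHOT=2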
functioncall.py ADDED
@@ -0,0 +1,163 @@
+ import argparse
+ import torch
+ import json
+ from config import config
+ from typing import List, Dict
+ from vllm import LLM, SamplingParams
+
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     BitsAndBytesConfig
+ )
+
+ import functions
+ from prompter import PromptManager
+ from validator import validate_function_call_schema
+
+ from utils import (
+     inference_logger,
+     get_assistant_message,
+     get_chat_template,
+     validate_and_extract_tool_calls
+ )
+
+ class ModelInference:
+     def __init__(self, chat_template: str, load_in_4bit: bool = False):
+         self.prompter = PromptManager()
+         self.bnb_config = None
+
+         if load_in_4bit == "True":  # Never use this
+             self.bnb_config = BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_quant_type="nf4",
+                 bnb_4bit_use_double_quant=True,
+             )
+         self.model = AutoModelForCausalLM.from_pretrained(
+             config.model_path,
+             trust_remote_code=True,
+             return_dict=True,
+             quantization_config=self.bnb_config,
+             torch_dtype=torch.float16,
+             attn_implementation="flash_attention_2",
+             device_map="auto",
+         )
+
+         self.tokenizer = AutoTokenizer.from_pretrained(config.model_path, trust_remote_code=True)
+         self.tokenizer.pad_token = self.tokenizer.eos_token
+         self.tokenizer.padding_side = "left"
+
+         if self.tokenizer.chat_template is None:
+             print("No chat template defined, getting chat_template...")
+             self.tokenizer.chat_template = get_chat_template(chat_template)
+
+         inference_logger.info(self.model.config)
+         inference_logger.info(self.model.generation_config)
+         inference_logger.info(self.tokenizer.special_tokens_map)
+
+     def process_completion_and_validate(self, completion, chat_template):
+         assistant_message = get_assistant_message(completion, chat_template, self.tokenizer.eos_token)
+
+         if assistant_message:
+             validation, tool_calls, error_message = validate_and_extract_tool_calls(assistant_message)
+
+             if validation:
+                 inference_logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")
+                 return tool_calls, assistant_message, error_message
+             else:
+                 tool_calls = None
+                 return tool_calls, assistant_message, error_message
+         else:
+             inference_logger.warning("Assistant message is None")
+             raise ValueError("Assistant message is None")
+
+     def execute_function_call(self, tool_call):
+         function_name = tool_call.get("name")
+         function_to_call = getattr(functions, function_name, None)
+         function_args = tool_call.get("arguments", {})
+
+         inference_logger.info(f"Invoking function call {function_name} ...")
+         function_response = function_to_call(*function_args.values())
+         results_dict = f'{{"name": "{function_name}", "content": {function_response}}}'
+         return results_dict
+
+     def run_inference(self, prompt: List[Dict[str, str]]):
+         inputs = self.tokenizer.apply_chat_template(
+             prompt,
+             add_generation_prompt=True,
+             return_tensors='pt'
+         )
+
+         tokens = self.model.generate(
+             inputs.to(self.model.device),
+             max_new_tokens=1500,
+             temperature=0.8,
+             repetition_penalty=1.1,
+             do_sample=True,
+             eos_token_id=self.tokenizer.eos_token_id
+         )
+         completion = self.tokenizer.decode(tokens[0], skip_special_tokens=False, clean_up_tokenization_spaces=True)
+         return completion
+
+     def generate_function_call(self, query, chat_template, num_fewshot, max_depth=5):
+         try:
+             depth = 0
+             user_message = f"{query}\nThis is the first turn and you don't have <tool_results> to analyze yet"
+             chat = [{"role": "user", "content": user_message}]
+             tools = functions.get_openai_tools()
+             prompt = self.prompter.generate_prompt(chat, tools, num_fewshot)
+             completion = self.run_inference(prompt)
+
+             def recursive_loop(prompt, completion, depth):
+                 nonlocal max_depth
+                 tool_calls, assistant_message, error_message = self.process_completion_and_validate(completion, chat_template)
+                 prompt.append({"role": "assistant", "content": assistant_message})
+
+                 tool_message = f"Agent iteration {depth} to assist with user query: {query}\n"
+                 if tool_calls:
+                     inference_logger.info(f"Assistant Message:\n{assistant_message}")
+
+                     for tool_call in tool_calls:
+                         validation, message = validate_function_call_schema(tool_call, tools)
+                         if validation:
+                             try:
+                                 function_response = self.execute_function_call(tool_call)
+                                 tool_message += f"<tool_response>\n{function_response}\n</tool_response>\n"
+                                 inference_logger.info(f"Here's the response from the function call: {tool_call.get('name')}\n{function_response}")
+                             except Exception as e:
+                                 inference_logger.info(f"Could not execute function: {e}")
+                                 tool_message += f"<tool_response>\nThere was an error when executing the function: {tool_call.get('name')}\nHere's the error traceback: {e}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
+                         else:
+                             inference_logger.info(message)
+                             tool_message += f"<tool_response>\nThere was an error validating function call against function signature: {tool_call.get('name')}\nHere's the error traceback: {message}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
+                     prompt.append({"role": "tool", "content": tool_message})
+
+                     depth += 1
+                     if depth >= max_depth:
+                         print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
+                         return
+
+                     completion = self.run_inference(prompt)
+                     return recursive_loop(prompt, completion, depth)
+                 elif error_message:
+                     inference_logger.info(f"Assistant Message:\n{assistant_message}")
+                     tool_message += f"<tool_response>\nThere was an error parsing function calls\nHere's the error stack trace: {error_message}\nPlease call the function again with correct syntax\n</tool_response>"
+                     prompt.append({"role": "tool", "content": tool_message})
+
+                     depth += 1
+                     if depth >= max_depth:
+                         print(f"Maximum recursion depth reached ({max_depth}). Stopping recursion.")
+                         return
+
+                     completion = self.run_inference(prompt)
+                     return recursive_loop(prompt, completion, depth)
+                 else:
+                     inference_logger.info(f"Assistant Message:\n{assistant_message}")
+                     return assistant_message
+
+             return recursive_loop(prompt, completion, depth)
+
+         except Exception as e:
+             inference_logger.error(f"Exception occurred: {e}")
+             raise e
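
generate_function_call() is the entry point app.py uses. A standalone sketch of driving it directly (this assumes a CUDA machine that can load the model, since __init__ calls from_pretrained with device_map="auto"; the query is made up):

    from config import config
    from functioncall import ModelInference

    llm = ModelInference(chat_template=config.chat_template)
    answer = llm.generate_function_call(
        "Summarize NVDA's fundamentals",  # illustrative query
        config.chat_template,
        config.num_fewshot,
        config.max_depth,
    )
    print(answer)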
functions.py ADDED
@@ -0,0 +1,262 @@
+ import re
+ import inspect
+ import requests
+ import pandas as pd
+ import yfinance as yf
+ import concurrent.futures
+
+ from typing import List
+ from bs4 import BeautifulSoup
+ from utils import inference_logger
+ from langchain.tools import tool
+ from langchain_core.utils.function_calling import convert_to_openai_tool
+
+ @tool
+ def google_search_and_scrape(query: str) -> list:
+     """
+     Performs a Google search for the given query, retrieves the top search result URLs,
+     and scrapes the text content and table data from those pages in parallel.
+
+     Args:
+         query (str): The search query.
+     Returns:
+         list: A list of dictionaries containing the URL, text content, and table data for each scraped page.
+     """
+     num_results = 2
+     url = 'https://www.google.com/search'
+     params = {'q': query, 'num': num_results}
+     headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/94.0.4606.61 Safari/537.3'}
+
+     inference_logger.info(f"Performing google search with query: {query}\nplease wait...")
+     response = requests.get(url, params=params, headers=headers)
+     soup = BeautifulSoup(response.text, 'html.parser')
+     urls = [result.find('a')['href'] for result in soup.find_all('div', class_='tF2Cxc')]
+
+     inference_logger.info("Scraping text from urls, please wait...")
+     [inference_logger.info(url) for url in urls]
+     with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
+         futures = [executor.submit(lambda url: (url, requests.get(url, headers=headers).text if isinstance(url, str) else None), url) for url in urls[:num_results] if isinstance(url, str)]
+         results = []
+         for future in concurrent.futures.as_completed(futures):
+             url, html = future.result()
+             soup = BeautifulSoup(html, 'html.parser')
+             paragraphs = [p.text.strip() for p in soup.find_all('p') if p.text.strip()]
+             text_content = ' '.join(paragraphs)
+             text_content = re.sub(r'\s+', ' ', text_content)
+             table_data = [[cell.get_text(strip=True) for cell in row.find_all('td')] for table in soup.find_all('table') for row in table.find_all('tr')]
+             if text_content or table_data:
+                 results.append({'url': url, 'content': text_content, 'tables': table_data})
+     return results
+
+ @tool
+ def get_current_stock_price(symbol: str) -> float:
+     """
+     Get the current stock price for a given symbol.
+
+     Args:
+         symbol (str): The stock symbol.
+
+     Returns:
+         float: The current stock price, or None if an error occurs.
+     """
+     try:
+         stock = yf.Ticker(symbol)
+         # Use "regularMarketPrice" for regular market hours, or "currentPrice" for pre/post market
+         current_price = stock.info.get("regularMarketPrice", stock.info.get("currentPrice"))
+         return current_price if current_price else None
+     except Exception as e:
+         print(f"Error fetching current price for {symbol}: {e}")
+         return None
+
+ @tool
+ def get_stock_fundamentals(symbol: str) -> dict:
+     """
+     Get fundamental data for a given stock symbol using yfinance API.
+
+     Args:
+         symbol (str): The stock symbol.
+
+     Returns:
+         dict: A dictionary containing fundamental data.
+         Keys:
+             - 'symbol': The stock symbol.
+             - 'company_name': The long name of the company.
+             - 'sector': The sector to which the company belongs.
+             - 'industry': The industry to which the company belongs.
+             - 'market_cap': The market capitalization of the company.
+             - 'pe_ratio': The forward price-to-earnings ratio.
+             - 'pb_ratio': The price-to-book ratio.
+             - 'dividend_yield': The dividend yield.
+             - 'eps': The trailing earnings per share.
+             - 'beta': The beta value of the stock.
+             - '52_week_high': The 52-week high price of the stock.
+             - '52_week_low': The 52-week low price of the stock.
+     """
+     try:
+         stock = yf.Ticker(symbol)
+         info = stock.info
+         fundamentals = {
+             'symbol': symbol,
+             'company_name': info.get('longName', ''),
+             'sector': info.get('sector', ''),
+             'industry': info.get('industry', ''),
+             'market_cap': info.get('marketCap', None),
+             'pe_ratio': info.get('forwardPE', None),
+             'pb_ratio': info.get('priceToBook', None),
+             'dividend_yield': info.get('dividendYield', None),
+             'eps': info.get('trailingEps', None),
+             'beta': info.get('beta', None),
+             '52_week_high': info.get('fiftyTwoWeekHigh', None),
+             '52_week_low': info.get('fiftyTwoWeekLow', None)
+         }
+         return fundamentals
+     except Exception as e:
+         print(f"Error getting fundamentals for {symbol}: {e}")
+         return {}
+
+ @tool
+ def get_financial_statements(symbol: str) -> dict:
+     """
+     Get financial statements for a given stock symbol.
+
+     Args:
+         symbol (str): The stock symbol.
+
+     Returns:
+         dict: Dictionary containing financial statements (income statement, balance sheet, cash flow statement).
+     """
+     try:
+         stock = yf.Ticker(symbol)
+         financials = stock.financials
+         return financials
+     except Exception as e:
+         print(f"Error fetching financial statements for {symbol}: {e}")
+         return {}
+
+ @tool
+ def get_key_financial_ratios(symbol: str) -> dict:
+     """
+     Get key financial ratios for a given stock symbol.
+
+     Args:
+         symbol (str): The stock symbol.
+
+     Returns:
+         dict: Dictionary containing key financial ratios.
+     """
+     try:
+         stock = yf.Ticker(symbol)
+         key_ratios = stock.info
+         return key_ratios
+     except Exception as e:
+         print(f"Error fetching key financial ratios for {symbol}: {e}")
+         return {}
+
+ @tool
+ def get_analyst_recommendations(symbol: str) -> pd.DataFrame:
+     """
+     Get analyst recommendations for a given stock symbol.
+
+     Args:
+         symbol (str): The stock symbol.
+
+     Returns:
+         pd.DataFrame: DataFrame containing analyst recommendations.
+     """
+     try:
+         stock = yf.Ticker(symbol)
+         recommendations = stock.recommendations
+         return recommendations
+     except Exception as e:
+         print(f"Error fetching analyst recommendations for {symbol}: {e}")
+         return pd.DataFrame()
+
+ @tool
+ def get_dividend_data(symbol: str) -> pd.DataFrame:
+     """
+     Get dividend data for a given stock symbol.
+
+     Args:
+         symbol (str): The stock symbol.
+
+     Returns:
+         pd.DataFrame: DataFrame containing dividend data.
+     """
+     try:
+         stock = yf.Ticker(symbol)
+         dividends = stock.dividends
+         return dividends
+     except Exception as e:
+         print(f"Error fetching dividend data for {symbol}: {e}")
+         return pd.DataFrame()
+
+ @tool
+ def get_company_news(symbol: str) -> pd.DataFrame:
+     """
+     Get company news and press releases for a given stock symbol.
+
+     Args:
+         symbol (str): The stock symbol.
+
+     Returns:
+         pd.DataFrame: DataFrame containing company news and press releases.
+     """
+     try:
+         news = yf.Ticker(symbol).news
+         return news
+     except Exception as e:
+         print(f"Error fetching company news for {symbol}: {e}")
+         return pd.DataFrame()
+
+ @tool
+ def get_technical_indicators(symbol: str) -> pd.DataFrame:
+     """
+     Get technical indicators for a given stock symbol.
+
+     Args:
+         symbol (str): The stock symbol.
+
+     Returns:
+         pd.DataFrame: DataFrame containing technical indicators.
+     """
+     try:
+         indicators = yf.Ticker(symbol).history(period="max")
+         return indicators
+     except Exception as e:
+         print(f"Error fetching technical indicators for {symbol}: {e}")
+         return pd.DataFrame()
+
+ @tool
+ def get_company_profile(symbol: str) -> dict:
+     """
+     Get company profile and overview for a given stock symbol.
+
+     Args:
+         symbol (str): The stock symbol.
+
+     Returns:
+         dict: Dictionary containing company profile and overview.
+     """
+     try:
+         profile = yf.Ticker(symbol).info
+         return profile
+     except Exception as e:
+         print(f"Error fetching company profile for {symbol}: {e}")
+         return {}
+
+ def get_openai_tools() -> List[dict]:
+     functions = [
+         google_search_and_scrape,
+         get_current_stock_price,
+         get_company_news,
+         get_company_profile,
+         get_stock_fundamentals,
+         get_financial_statements,
+         get_key_financial_ratios,
+         get_analyst_recommendations,
+         get_dividend_data,
+         get_technical_indicators
+     ]
+
+     tools = [convert_to_openai_tool(f) for f in functions]
+     return tools
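
Each @tool-decorated function is a LangChain tool, and convert_to_openai_tool() wraps its signature and docstring into an OpenAI-style schema of the form {"type": "function", "function": {...}}. A quick sketch (the ticker is illustrative, and the yfinance tools need network access):

    from functions import get_current_stock_price, get_openai_tools

    price = get_current_stock_price.invoke({"symbol": "AAPL"})  # call through the tool wrapper
    tools = get_openai_tools()
    print(tools[1]["function"]["name"])  # "get_current_stock_price", per the list order above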
prompt_assets/few_shot.json ADDED
@@ -0,0 +1,8 @@
+ [
+     {
+         "example": "```\nSYSTEM: You are a helpful assistant who has access to functions. Use them if required\n<tools>[\n {\n \"name\": \"calculate_distance\",\n \"description\": \"Calculate the distance between two locations\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"origin\": {\n \"type\": \"string\",\n \"description\": \"The starting location\"\n },\n \"destination\": {\n \"type\": \"string\",\n \"description\": \"The destination location\"\n },\n \"mode\": {\n \"type\": \"string\",\n \"description\": \"The mode of transportation\"\n }\n },\n \"required\": [\n \"origin\",\n \"destination\",\n \"mode\"\n ]\n }\n },\n {\n \"name\": \"generate_password\",\n \"description\": \"Generate a random password\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"length\": {\n \"type\": \"integer\",\n \"description\": \"The length of the password\"\n }\n },\n \"required\": [\n \"length\"\n ]\n }\n }\n]\n\n</tools>\nUSER: Hi, I need to know the distance from New York to Los Angeles by car.\nASSISTANT:\n<tool_call>\n{\"arguments\": {\"origin\": \"New York\",\n \"destination\": \"Los Angeles\", \"mode\": \"car\"}, \"name\": \"calculate_distance\"}\n</tool_call>\n```\n"
+     },
+     {
+         "example": "```\nSYSTEM: You are a helpful assistant with access to functions. Use them if required\n<tools>[\n {\n \"name\": \"calculate_distance\",\n \"description\": \"Calculate the distance between two locations\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"origin\": {\n \"type\": \"string\",\n \"description\": \"The starting location\"\n },\n \"destination\": {\n \"type\": \"string\",\n \"description\": \"The destination location\"\n },\n \"mode\": {\n \"type\": \"string\",\n \"description\": \"The mode of transportation\"\n }\n },\n \"required\": [\n \"origin\",\n \"destination\",\n \"mode\"\n ]\n }\n },\n {\n \"name\": \"generate_password\",\n \"description\": \"Generate a random password\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"length\": {\n \"type\": \"integer\",\n \"description\": \"The length of the password\"\n }\n },\n \"required\": [\n \"length\"\n ]\n }\n }\n]\n\n</tools>\nUSER: Can you help me generate a random password with a length of 8 characters?\nASSISTANT:\n<tool_call>\n{\"arguments\": {\"length\": 8}, \"name\": \"generate_password\"}\n</tool_call>\n```"
+     }
+ ]
prompt_assets/output_sys_prompt.yml ADDED
@@ -0,0 +1,10 @@
+ Role: |
+   You are an expert financial advisor named IRAI.
+   You have a comprehensive understanding of finance and investing with experience and expertise in all areas of finance.
+   You can use information given to you, but do not mention function calls.
+ Objective: |
+   Answer questions accurately and truthfully given your current knowledge. Answer the question directly.
+ Instructions: |
+   The questions will be asked by top technology executives and CFOs of large fintech companies and successful startups.
+   Answer in a friendly and engaging manner representing a top female investment professional working at a leading investment bank.
+   Give a direct answer to the question, concise yet insightful.
prompt_assets/sys_prompt.yml ADDED
@@ -0,0 +1,38 @@
+ Role: |
+   You are a function calling AI agent with self-recursion.
+   You can call only one function at a time and analyse the data you get from the function response.
+   You are provided with function signatures within <tools></tools> XML tags.
+   The current date is: {date}.
+ Objective: |
+   You may use agentic frameworks for reasoning and planning to help with the user query.
+   Please call a function and wait for function results to be provided to you in the next iteration.
+   Don't make assumptions about what values to plug into function arguments.
+   Once you have called a function, results will be fed back to you within <tool_response></tool_response> XML tags.
+   Don't make assumptions about tool results if <tool_response> XML tags are not present since the function hasn't been executed yet.
+   Analyze the data once you get the results and call another function.
+   At each iteration please continue adding your analysis to the previous summary.
+   Your final response should directly answer the user query with an analysis or summary of the results of the function calls.
+ Tools: |
+   Here are the available tools:
+   <tools> {tools} </tools>
+   If the provided function signatures don't include the function you need to call, you may write executable python code in markdown syntax and call the code_interpreter() function as follows:
+   <tool_call>
+   {{"arguments": {{"code_markdown": <python-code>, "name": "code_interpreter"}}}}
+   </tool_call>
+   Make sure that the json object above with the code markdown block is parseable with json.loads() and the XML block with XML ElementTree.
+ Examples: |
+   Here are some example usages of the functions:
+   {examples}
+ Schema: |
+   Use the following pydantic model json schema for each tool call you will make:
+   {schema}
+ Instructions: |
+   At the very first turn you don't have <tool_results>, so you shouldn't make up the results.
+   Please keep a running summary with analysis of previous function results and summaries from previous iterations.
+   Do not stop calling functions until the task has been accomplished or you've reached the max iteration limit of 10.
+   Calling multiple functions at once can overload the system and increase cost, so please call one function at a time.
+   If you plan to continue with analysis, always call another function.
+   For each function call return a valid json object (using double quotes) with the function name and arguments within <tool_call></tool_call> XML tags as follows:
+   <tool_call>
+   {{"arguments": <args-dict>, "name": <function-name>}}
+   </tool_call>
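
Concretely, a completion that follows this prompt and the FunctionCall schema would contain a block like the following (the function name and argument are illustrative, drawn from functions.py):

    <tool_call>
    {"arguments": {"symbol": "MSFT"}, "name": "get_stock_fundamentals"}
    </tool_call>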
prompter.py ADDED
@@ -0,0 +1,76 @@
+ import datetime
+ from pydantic import BaseModel
+ from typing import Dict
+ from schema import FunctionCall
+ from utils import (
+     get_fewshot_examples
+ )
+ import yaml
+ import json
+ import os
+
+ class PromptSchema(BaseModel):
+     Role: str
+     Objective: str
+     Tools: str
+     Examples: str
+     Schema: str
+     Instructions: str
+
+ class PromptManager:
+     def __init__(self):
+         self.script_dir = os.path.dirname(os.path.abspath(__file__))
+
+     def format_yaml_prompt(self, prompt_schema: PromptSchema, variables: Dict) -> str:
+         formatted_prompt = ""
+         for field, value in prompt_schema.dict().items():
+             if field == "Examples" and variables.get("examples") is None:
+                 continue
+             formatted_value = value.format(**variables)
+             if field == "Instructions":
+                 formatted_prompt += f"{formatted_value}"
+             else:
+                 formatted_value = formatted_value.replace("\n", " ")
+                 formatted_prompt += f"{formatted_value}"
+         return formatted_prompt
+
+     def read_yaml_file(self, file_path: str) -> PromptSchema:
+         with open(file_path, 'r') as file:
+             yaml_content = yaml.safe_load(file)
+
+         prompt_schema = PromptSchema(
+             Role=yaml_content.get('Role', ''),
+             Objective=yaml_content.get('Objective', ''),
+             Tools=yaml_content.get('Tools', ''),
+             Examples=yaml_content.get('Examples', ''),
+             Schema=yaml_content.get('Schema', ''),
+             Instructions=yaml_content.get('Instructions', ''),
+         )
+         return prompt_schema
+
+     def generate_prompt(self, user_prompt, tools, num_fewshot=None):
+         prompt_path = os.path.join(self.script_dir, 'prompt_assets', 'sys_prompt.yml')
+         prompt_schema = self.read_yaml_file(prompt_path)
+
+         if num_fewshot is not None:
+             examples = get_fewshot_examples(num_fewshot)
+         else:
+             examples = None
+
+         schema_json = json.loads(FunctionCall.schema_json())
+
+         variables = {
+             "date": datetime.date.today(),
+             "tools": tools,
+             "examples": examples,
+             "schema": schema_json
+         }
+         sys_prompt = self.format_yaml_prompt(prompt_schema, variables)
+
+         prompt = [
+             {'content': sys_prompt, 'role': 'system'}
+         ]
+         prompt.extend(user_prompt)
+         return prompt
+
+
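
A sketch of how functioncall.py uses PromptManager: generate_prompt() prepends the formatted system prompt to the chat turns (the query is made up):

    from prompter import PromptManager
    from functions import get_openai_tools

    prompter = PromptManager()
    chat = [{"role": "user", "content": "Any dividend news for KO?"}]  # illustrative query
    prompt = prompter.generate_prompt(chat, get_openai_tools(), num_fewshot=None)
    # prompt[0] is the system message; prompt[1:] are the user turns.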
requirements.txt CHANGED
@@ -1,6 +1,131 @@
- streamlit
- transformers
- torch
- vllm
+ aiohttp==3.9.5
+ aioprometheus==23.12.0
+ aiosignal==1.3.1
+ altair==5.3.0
+ annotated-types==0.6.0
+ anyio==4.3.0
+ appdirs==1.4.4
+ async-timeout==4.0.3
+ attrs==23.2.0
+ beautifulsoup4==4.12.3
+ blinker==1.8.2
+ cachetools==5.3.3
+ certifi==2024.2.2
+ charset-normalizer==3.3.2
+ click==8.1.7
+ dataclasses-json==0.6.5
+ dnspython==2.6.1
+ email_validator==2.1.1
+ exceptiongroup==1.2.1
+ fastapi==0.111.0
+ fastapi-cli==0.0.3
+ filelock==3.14.0
+ frozendict==2.4.4
+ frozenlist==1.4.1
+ fsspec==2024.3.1
+ gitdb==4.0.11
+ GitPython==3.1.43
+ greenlet==3.0.3
+ h11==0.14.0
+ html5lib==1.1
+ httpcore==1.0.5
+ httptools==0.6.1
+ httpx==0.27.0
+ huggingface-hub==0.23.0
+ idna==3.7
+ Jinja2==3.1.4
+ jsonpatch==1.33
+ jsonpointer==2.4
+ jsonschema==4.22.0
+ jsonschema-specifications==2023.12.1
+ langchain==0.1.17
+ langchain-community==0.0.37
+ langchain-core==0.1.52
+ langchain-text-splitters==0.0.1
+ langsmith==0.1.54
+ lxml==5.2.1
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ marshmallow==3.21.2
+ mdurl==0.1.2
+ mpmath==1.3.0
+ msgpack==1.0.8
+ multidict==6.0.5
+ multitasking==0.0.11
+ mypy-extensions==1.0.0
+ networkx==3.3
+ ninja==1.11.1.1
+ numpy==1.26.4
+ nvidia-cublas-cu12==12.1.3.1
+ nvidia-cuda-cupti-cu12==12.1.105
+ nvidia-cuda-nvrtc-cu12==12.1.105
+ nvidia-cuda-runtime-cu12==12.1.105
+ nvidia-cudnn-cu12==8.9.2.26
+ nvidia-cufft-cu12==11.0.2.54
+ nvidia-curand-cu12==10.3.2.106
+ nvidia-cusolver-cu12==11.4.5.107
+ nvidia-cusparse-cu12==12.1.0.106
+ nvidia-nccl-cu12==2.18.1
+ nvidia-nvjitlink-cu12==12.4.127
+ nvidia-nvtx-cu12==12.1.105
+ orjson==3.10.3
+ packaging==23.2
+ pandas==2.2.2
+ peewee==3.17.3
+ pillow==10.3.0
+ protobuf==4.25.3
+ psutil==5.9.8
+ pyarrow==16.0.0
+ pydantic==2.7.1
+ pydantic-settings==2.2.1
+ pydantic_core==2.18.2
+ pydeck==0.9.0
+ Pygments==2.18.0
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.0.1
+ python-multipart==0.0.9
+ pytz==2024.1
+ PyYAML==6.0.1
+ quantile-python==1.1
+ ray==2.20.0
+ referencing==0.35.1
+ regex==2024.4.28
+ requests==2.31.0
+ rich==13.7.1
+ rpds-py==0.18.1
+ safetensors==0.4.3
+ sentencepiece==0.2.0
+ shellingham==1.5.4
+ six==1.16.0
+ smmap==5.0.1
+ sniffio==1.3.1
+ soupsieve==2.5
+ SQLAlchemy==2.0.30
+ starlette==0.37.2
+ streamlit==1.34.0
+ sympy==1.12
+ tenacity==8.3.0
+ tokenizers==0.19.1
+ toml==0.10.2
+ toolz==0.12.1
+ torch==2.1.1
+ tornado==6.4
+ tqdm==4.66.4
+ transformers==4.40.2
+ triton==2.1.0
+ typer==0.12.3
+ typing-inspect==0.9.0
+ typing_extensions==4.11.0
+ tzdata==2024.1
+ ujson==5.9.0
+ urllib3==2.2.1
+ uvicorn==0.29.0
+ uvloop==0.19.0
+ vllm==0.2.5
+ watchdog==4.0.0
+ watchfiles==0.21.0
+ webencodings==0.5.1
+ websockets==12.0
  xformers==0.0.23
-
+ yarl==1.9.4
+ yfinance==0.2.38
schema.py ADDED
@@ -0,0 +1,23 @@
+ from pydantic import BaseModel
+ from typing import List, Dict, Literal, Optional
+
+ class FunctionCall(BaseModel):
+     arguments: dict
+     """
+     The arguments to call the function with, as generated by the model in JSON
+     format. Note that the model does not always generate valid JSON, and may
+     hallucinate parameters not defined by your function schema. Validate the
+     arguments in your code before calling your function.
+     """
+
+     name: str
+     """The name of the function to call."""
+
+ class FunctionDefinition(BaseModel):
+     name: str
+     description: Optional[str] = None
+     parameters: Optional[Dict[str, object]] = None
+
+ class FunctionSignature(BaseModel):
+     function: FunctionDefinition
+     type: Literal["function"]
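
A small sketch of how validator.py exercises these models (the call dict is illustrative):

    from schema import FunctionCall

    call = FunctionCall(**{"name": "get_dividend_data", "arguments": {"symbol": "KO"}})
    # A malformed dict (e.g. "arguments" given as a string) raises pydantic.ValidationError.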
utils.py ADDED
@@ -0,0 +1,149 @@
+ import ast
+ import os
+ import re
+ import json
+ import logging
+ import datetime
+ import xml.etree.ElementTree as ET
+
+ from logging.handlers import RotatingFileHandler
+
+ logging.basicConfig(
+     format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
+     datefmt="%Y-%m-%d:%H:%M:%S",
+     level=logging.INFO,
+ )
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+ now = datetime.datetime.now()
+ log_folder = os.path.join(script_dir, "inference_logs")
+ os.makedirs(log_folder, exist_ok=True)
+ log_file_path = os.path.join(
+     log_folder, f"function-calling-inference_{now.strftime('%Y-%m-%d_%H-%M-%S')}.log"
+ )
+ # Use RotatingFileHandler from the logging.handlers module
+ file_handler = RotatingFileHandler(log_file_path, maxBytes=0, backupCount=0)
+ file_handler.setLevel(logging.INFO)
+
+ formatter = logging.Formatter("%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", datefmt="%Y-%m-%d:%H:%M:%S")
+ file_handler.setFormatter(formatter)
+
+ inference_logger = logging.getLogger("function-calling-inference")
+ inference_logger.addHandler(file_handler)
+
+ def get_fewshot_examples(num_fewshot):
+     """return a list of few shot examples"""
+     example_path = os.path.join(script_dir, 'prompt_assets', 'few_shot.json')
+     with open(example_path, 'r') as file:
+         examples = json.load(file)  # Use json.load with the file object, not the file path
+     if num_fewshot > len(examples):
+         raise ValueError(f"Not enough examples (got {num_fewshot}, but there are only {len(examples)} examples).")
+     return examples[:num_fewshot]
+
+ def get_chat_template(chat_template):
+     """read chat template from jinja file"""
+     template_path = os.path.join(script_dir, 'chat_templates', f"{chat_template}.j2")
+
+     if not os.path.exists(template_path):
+         inference_logger.error(f"Template file not found: {chat_template}")
+         return None
+     try:
+         with open(template_path, 'r') as file:
+             template = file.read()
+         return template
+     except Exception as e:
+         print(f"Error loading template: {e}")
+         return None
+
+ def get_assistant_message(completion, chat_template, eos_token):
+     """define and match pattern to find the assistant message"""
+     completion = completion.strip()
+
+     if chat_template == "zephyr":
+         assistant_pattern = re.compile(r'<\|assistant\|>((?:(?!<\|assistant\|>).)*)$', re.DOTALL)
+     elif chat_template == "chatml":
+         assistant_pattern = re.compile(r'<\|im_start\|>\s*assistant((?:(?!<\|im_start\|>\s*assistant).)*)$', re.DOTALL)
+     elif chat_template == "vicuna":
+         assistant_pattern = re.compile(r'ASSISTANT:\s*((?:(?!ASSISTANT:).)*)$', re.DOTALL)
+     else:
+         raise NotImplementedError(f"Handling for chat_template '{chat_template}' is not implemented.")
+
+     assistant_match = assistant_pattern.search(completion)
+     if assistant_match:
+         assistant_content = assistant_match.group(1).strip()
+         if chat_template == "vicuna":
+             eos_token = f"</s>{eos_token}"
+         return assistant_content.replace(eos_token, "")
+     else:
+         assistant_content = None
+         inference_logger.info("No match found for the assistant pattern")
+         return assistant_content
+
+ def validate_and_extract_tool_calls(assistant_content):
+     validation_result = False
+     tool_calls = []
+     error_message = None
+
+     try:
+         # wrap content in root element
+         xml_root_element = f"<root>{assistant_content}</root>"
+         root = ET.fromstring(xml_root_element)
+
+         # extract JSON data
+         for element in root.findall(".//tool_call"):
+             json_data = None
+             try:
+                 json_text = element.text.strip()
+
+                 try:
+                     # Prioritize json.loads for better error handling
+                     json_data = json.loads(json_text)
+                 except json.JSONDecodeError as json_err:
+                     try:
+                         # Fallback to ast.literal_eval if json.loads fails
+                         json_data = ast.literal_eval(json_text)
+                     except (SyntaxError, ValueError) as eval_err:
+                         error_message = f"JSON parsing failed with both json.loads and ast.literal_eval:\n"\
+                                         f"- JSON Decode Error: {json_err}\n"\
+                                         f"- Fallback Syntax/Value Error: {eval_err}\n"\
+                                         f"- Problematic JSON text: {json_text}"
+                         inference_logger.error(error_message)
+                         continue
+             except Exception as e:
+                 error_message = f"Cannot strip text: {e}"
+                 inference_logger.error(error_message)
+
+             if json_data is not None:
+                 tool_calls.append(json_data)
+                 validation_result = True
+
+     except ET.ParseError as err:
+         error_message = f"XML Parse Error: {err}"
+         inference_logger.error(f"XML Parse Error: {err}")
+
+     # Return default values if no valid data is extracted
+     return validation_result, tool_calls, error_message
+
+ def extract_json_from_markdown(text):
+     """
+     Extracts the JSON string from the given text using a regular expression pattern.
+
+     Args:
+         text (str): The input text containing the JSON string.
+
+     Returns:
+         dict: The JSON data loaded from the extracted string, or None if the JSON string is not found.
+     """
+     json_pattern = r'```json\r?\n(.*?)\r?\n```'
+     match = re.search(json_pattern, text, re.DOTALL)
+     if match:
+         json_string = match.group(1)
+         try:
+             data = json.loads(json_string)
+             return data
+         except json.JSONDecodeError as e:
+             print(f"Error decoding JSON string: {e}")
+     else:
+         print("JSON string not found in the text.")
+     return None
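
A sketch of validate_and_extract_tool_calls() on a hand-written assistant message (the string is illustrative):

    from utils import validate_and_extract_tool_calls

    msg = '<tool_call>\n{"arguments": {"symbol": "AAPL"}, "name": "get_current_stock_price"}\n</tool_call>'
    ok, calls, err = validate_and_extract_tool_calls(msg)
    # ok is True, calls == [{"arguments": {"symbol": "AAPL"}, "name": "get_current_stock_price"}], err is None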
validator.py ADDED
@@ -0,0 +1,132 @@
+ import ast
+ import json
+ from jsonschema import validate, ValidationError as SchemaValidationError
+ from pydantic import ValidationError
+ from utils import inference_logger, extract_json_from_markdown
+ from schema import FunctionCall, FunctionSignature
+
+ def validate_function_call_schema(call, signatures):
+     try:
+         call_data = FunctionCall(**call)
+     except ValidationError as e:
+         return False, str(e)
+
+     for signature in signatures:
+         try:
+             signature_data = FunctionSignature(**signature)
+             if signature_data.function.name == call_data.name:
+                 # Validate types in function arguments
+                 for arg_name, arg_schema in signature_data.function.parameters.get('properties', {}).items():
+                     if arg_name in call_data.arguments:
+                         call_arg_value = call_data.arguments[arg_name]
+                         if call_arg_value:
+                             try:
+                                 validate_argument_type(arg_name, call_arg_value, arg_schema)
+                             except Exception as arg_validation_error:
+                                 return False, str(arg_validation_error)
+
+                 # Check if all required arguments are present
+                 required_arguments = signature_data.function.parameters.get('required', [])
+                 result, missing_arguments = check_required_arguments(call_data.arguments, required_arguments)
+                 if not result:
+                     return False, f"Missing required arguments: {missing_arguments}"
+
+                 return True, None
+         except Exception as e:
+             # Handle validation errors for the function signature
+             return False, str(e)
+
+     # No matching function signature found
+     return False, f"No matching function signature found for function: {call_data.name}"
+
+ def check_required_arguments(call_arguments, required_arguments):
+     missing_arguments = [arg for arg in required_arguments if arg not in call_arguments]
+     return not bool(missing_arguments), missing_arguments
+
+ def validate_enum_value(arg_name, arg_value, enum_values):
+     if arg_value not in enum_values:
+         raise Exception(
+             f"Invalid value '{arg_value}' for parameter {arg_name}. Expected one of {', '.join(map(str, enum_values))}"
+         )
+
+ def validate_argument_type(arg_name, arg_value, arg_schema):
+     arg_type = arg_schema.get('type', None)
+     if arg_type:
+         if arg_type == 'string' and 'enum' in arg_schema:
+             enum_values = arg_schema['enum']
+             if None not in enum_values and enum_values != []:
+                 try:
+                     validate_enum_value(arg_name, arg_value, enum_values)
+                 except Exception as e:
+                     # Propagate the validation error message
+                     raise Exception(f"Error validating function call: {e}")
+
+         python_type = get_python_type(arg_type)
+         if not isinstance(arg_value, python_type):
+             raise Exception(f"Type mismatch for parameter {arg_name}. Expected: {arg_type}, Got: {type(arg_value)}")
+
+ def get_python_type(json_type):
+     type_mapping = {
+         'string': str,
+         'number': (int, float),
+         'integer': int,
+         'boolean': bool,
+         'array': list,
+         'object': dict,
+         'null': type(None),
+     }
+     return type_mapping[json_type]
+
+ def validate_json_data(json_object, json_schema):
+     valid = False
+     error_message = None
+     result_json = None
+
+     try:
+         # Attempt to load JSON using json.loads
+         try:
+             result_json = json.loads(json_object)
+         except json.decoder.JSONDecodeError:
+             # If json.loads fails, try ast.literal_eval
+             try:
+                 result_json = ast.literal_eval(json_object)
+             except (SyntaxError, ValueError) as e:
+                 try:
+                     result_json = extract_json_from_markdown(json_object)
+                 except Exception as e:
+                     error_message = f"JSON decoding error: {e}"
+                     inference_logger.info(f"Validation failed for JSON data: {error_message}")
+                     return valid, result_json, error_message
+
+         # Return early if both json.loads and ast.literal_eval fail
+         if result_json is None:
+             error_message = "Failed to decode JSON data"
+             inference_logger.info(f"Validation failed for JSON data: {error_message}")
+             return valid, result_json, error_message
+
+         # Validate each item in the list against schema if it's a list
+         if isinstance(result_json, list):
+             for index, item in enumerate(result_json):
+                 try:
+                     validate(instance=item, schema=json_schema)
+                     inference_logger.info(f"Item {index+1} is valid against the schema.")
+                 except SchemaValidationError as e:  # jsonschema's error, not pydantic's
+                     error_message = f"Validation failed for item {index+1}: {e}"
+                     break
+         else:
+             # Default to validation without list
+             try:
+                 validate(instance=result_json, schema=json_schema)
+             except SchemaValidationError as e:
+                 error_message = f"Validation failed: {e}"
+
+     except Exception as e:
+         error_message = f"Error occurred: {e}"
+
+     if error_message is None:
+         valid = True
+         inference_logger.info("JSON data is valid against the schema.")
+     else:
+         inference_logger.info(f"Validation failed for JSON data: {error_message}")
+
+     return valid, result_json, error_message
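
A sketch of validate_function_call_schema() with a call dict shaped like the extracted tool calls (the names come from functions.py):

    from functions import get_openai_tools
    from validator import validate_function_call_schema

    ok, err = validate_function_call_schema(
        {"name": "get_stock_fundamentals", "arguments": {"symbol": "AAPL"}},
        get_openai_tools(),
    )
    # ok is True when the arguments match the tool's schema; otherwise err explains why.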