WilliamGazeley committed on
Commit
691fc98
1 Parent(s): 9e2a95f

Migrate to loguru

Browse files
Files changed (9) hide show
  1. .gitattributes +1 -0
  2. .gitignore +2 -0
  3. requirements.txt +1 -1
  4. src/app.py +0 -1
  5. src/functioncall.py +21 -19
  6. src/functions.py +6 -24
  7. src/logger.py +13 -0
  8. src/utils.py +5 -32
  9. src/validator.py +38 -21
.gitattributes CHANGED
@@ -1,3 +1,4 @@
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
1
+ *.log
2
  *.7z filter=lfs diff=lfs merge=lfs -text
3
  *.arrow filter=lfs diff=lfs merge=lfs -text
4
  *.bin filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -5,3 +5,5 @@ __pycache__/
5
 
6
  # vLLM
7
  inference_logs/
 
 
 
5
 
6
  # vLLM
7
  inference_logs/
8
+
9
+ logs/*
requirements.txt CHANGED
@@ -25,4 +25,4 @@ langchain==0.1.9
25
  accelerate==0.27.2
26
  azure-search-documents==11.6.0b1
27
  azure-identity==1.16.0
28
-
 
25
  accelerate==0.27.2
26
  azure-search-documents==11.6.0b1
27
  azure-identity==1.16.0
28
+ loguru==0.7.2
src/app.py CHANGED
@@ -3,7 +3,6 @@ from time import time
3
  import huggingface_hub
4
  import streamlit as st
5
  from config import config
6
- from utils import get_assistant_message
7
  from functioncall import ModelInference
8
 
9
 
 
3
  import huggingface_hub
4
  import streamlit as st
5
  from config import config
 
6
  from functioncall import ModelInference
7
 
8
 
src/functioncall.py CHANGED
@@ -3,21 +3,18 @@ import torch
3
  import json
4
  from config import config
5
  from typing import List, Dict
 
6
 
7
- from transformers import (
8
- AutoModelForCausalLM,
9
- AutoTokenizer,
10
- BitsAndBytesConfig
11
- )
12
 
13
  import functions
14
  from prompter import PromptManager
15
  from validator import validate_function_call_schema
16
  from langchain_community.chat_models import ChatOllama
 
 
17
 
18
  from utils import (
19
- inference_logger,
20
- get_assistant_message,
21
  get_chat_template,
22
  validate_and_extract_tool_calls
23
  )
@@ -26,8 +23,9 @@ class ModelInference:
26
  def __init__(self, chat_template: str):
27
  self.prompter = PromptManager()
28
 
29
- self.model = ChatOllama(model=config.ollama_model,
30
- temperature=0.0, format='json')
 
31
 
32
  self.tokenizer = AutoTokenizer.from_pretrained(config.hf_model, trust_remote_code=True)
33
  self.tokenizer.pad_token = self.tokenizer.eos_token
@@ -37,19 +35,22 @@ class ModelInference:
37
  print("No chat template defined, getting chat_template...")
38
  self.tokenizer.chat_template = get_chat_template(chat_template)
39
 
 
40
 
41
  def process_completion_and_validate(self, completion, chat_template):
42
  if completion:
 
 
43
  validation, tool_calls, error_message = validate_and_extract_tool_calls(completion)
44
 
45
  if validation:
46
- inference_logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")
47
  return tool_calls, completion, error_message
48
  else:
49
  tool_calls = None
50
  return tool_calls, completion, error_message
51
  else:
52
- inference_logger.warning("Assistant message is None")
53
  raise ValueError("Assistant message is None")
54
 
55
  def execute_function_call(self, tool_call):
@@ -58,7 +59,7 @@ class ModelInference:
58
  function_to_call = getattr(functions, function_name, None)
59
  function_args = tool_call.get("arguments", {})
60
 
61
- inference_logger.info(f"Invoking function call {function_name} ...")
62
  function_response = function_to_call(*function_args.values())
63
  results_dict = f'{{"name": "{function_name}", "content": {function_response}}}'
64
  return results_dict
@@ -88,8 +89,9 @@ class ModelInference:
88
  prompt.append({"role": "assistant", "content": assistant_message})
89
 
90
  tool_message = f"Agent iteration {depth} to assist with user query: {query}\n"
 
91
  if tool_calls:
92
- inference_logger.info(f"Assistant Message:\n{assistant_message}")
93
 
94
  for tool_call in tool_calls:
95
  validation, message = validate_function_call_schema(tool_call, tools)
@@ -97,12 +99,12 @@ class ModelInference:
97
  try:
98
  function_response = self.execute_function_call(tool_call)
99
  tool_message += f"<tool_response>\n{function_response}\n</tool_response>\n"
100
- inference_logger.info(f"Here's the response from the function call: {tool_call.get('name')}\n{function_response}")
101
  except Exception as e:
102
- inference_logger.info(f"Could not execute function: {e}")
103
  tool_message += f"<tool_response>\nThere was an error when executing the function: {tool_call.get('name')}\nHere's the error traceback: {e}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
104
  else:
105
- inference_logger.info(message)
106
  tool_message += f"<tool_response>\nThere was an error validating function call against function signature: {tool_call.get('name')}\nHere's the error traceback: {message}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
107
  prompt.append({"role": "tool", "content": tool_message})
108
 
@@ -116,7 +118,7 @@ class ModelInference:
116
  completion = self.run_inference(prompt)
117
  return recursive_loop(prompt, completion, depth)
118
  elif error_message:
119
- inference_logger.info(f"Assistant Message:\n{assistant_message}")
120
  tool_message += f"<tool_response>\nThere was an error parsing function calls\n Here's the error stack trace: {error_message}\nPlease call the function again with correct syntax<tool_response>"
121
  prompt.append({"role": "tool", "content": tool_message})
122
 
@@ -128,11 +130,11 @@ class ModelInference:
128
  completion = self.run_inference(prompt)
129
  return recursive_loop(prompt, completion, depth)
130
  else:
131
- inference_logger.info(f"Assistant Message:\n{assistant_message}")
132
  return assistant_message
133
 
134
  return recursive_loop(prompt, completion, depth)
135
 
136
  except Exception as e:
137
- inference_logger.error(f"Exception occurred: {e}")
138
  raise e
 
3
  import json
4
  from config import config
5
  from typing import List, Dict
6
+ from logger import logger
7
 
8
+ from transformers import AutoTokenizer
 
 
 
 
9
 
10
  import functions
11
  from prompter import PromptManager
12
  from validator import validate_function_call_schema
13
  from langchain_community.chat_models import ChatOllama
14
+ from langchain.prompts import PromptTemplate
15
+ from langchain_core.output_parsers import StrOutputParser
16
 
17
  from utils import (
 
 
18
  get_chat_template,
19
  validate_and_extract_tool_calls
20
  )
 
23
  def __init__(self, chat_template: str):
24
  self.prompter = PromptManager()
25
 
26
+ self.model = ChatOllama(model=config.ollama_model, temperature=0.0, format='json')
27
+ template = PromptTemplate(template="""<|im_start|>system\nYou are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools> {"type": "function", "function": {"name": "get_stock_fundamentals", "description": "get_stock_fundamentals(symbol: str) -> dict - Get fundamental data for a given stock symbol using yfinance API.\\n\\n Args:\\n symbol (str): The stock symbol.\\n\\n Returns:\\n dict: A dictionary containing fundamental data.\\n Keys:\\n - \'symbol\': The stock symbol.\\n - \'company_name\': The long name of the company.\\n - \'sector\': The sector to which the company belongs.\\n - \'industry\': The industry to which the company belongs.\\n - \'market_cap\': The market capitalization of the company.\\n - \'pe_ratio\': The forward price-to-earnings ratio.\\n - \'pb_ratio\': The price-to-book ratio.\\n - \'dividend_yield\': The dividend yield.\\n - \'eps\': The trailing earnings per share.\\n - \'beta\': The beta value of the stock.\\n - \'52_week_high\': The 52-week high price of the stock.\\n - \'52_week_low\': The 52-week low price of the stock.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}} </tools> Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"} For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:\n<tool_call>\n{"arguments": <args-dict>, "name": <function-name>}\n</tool_call><|im_end|>\n""", input_variables=["question"])
28
+ chain = template | self.model | StrOutputParser()
29
 
30
  self.tokenizer = AutoTokenizer.from_pretrained(config.hf_model, trust_remote_code=True)
31
  self.tokenizer.pad_token = self.tokenizer.eos_token
 
35
  print("No chat template defined, getting chat_template...")
36
  self.tokenizer.chat_template = get_chat_template(chat_template)
37
 
38
+ logger.info(f"Model loaded: {self.model}")
39
 
40
  def process_completion_and_validate(self, completion, chat_template):
41
  if completion:
42
+ # completion = f"<tool_call>\n{completion}\n</tool_call>"
43
+ breakpoint()
44
  validation, tool_calls, error_message = validate_and_extract_tool_calls(completion)
45
 
46
  if validation:
47
+ logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")
48
  return tool_calls, completion, error_message
49
  else:
50
  tool_calls = None
51
  return tool_calls, completion, error_message
52
  else:
53
+ logger.warning("Assistant message is None")
54
  raise ValueError("Assistant message is None")
55
 
56
  def execute_function_call(self, tool_call):
 
59
  function_to_call = getattr(functions, function_name, None)
60
  function_args = tool_call.get("arguments", {})
61
 
62
+ logger.info(f"Invoking function call {function_name} ...")
63
  function_response = function_to_call(*function_args.values())
64
  results_dict = f'{{"name": "{function_name}", "content": {function_response}}}'
65
  return results_dict
 
89
  prompt.append({"role": "assistant", "content": assistant_message})
90
 
91
  tool_message = f"Agent iteration {depth} to assist with user query: {query}\n"
92
+ logger.info(f"Found tool calls: {tool_calls}")
93
  if tool_calls:
94
+ logger.info(f"Assistant Message:\n{assistant_message}")
95
 
96
  for tool_call in tool_calls:
97
  validation, message = validate_function_call_schema(tool_call, tools)
 
99
  try:
100
  function_response = self.execute_function_call(tool_call)
101
  tool_message += f"<tool_response>\n{function_response}\n</tool_response>\n"
102
+ logger.info(f"Here's the response from the function call: {tool_call.get('name')}\n{function_response}")
103
  except Exception as e:
104
+ logger.info(f"Could not execute function: {e}")
105
  tool_message += f"<tool_response>\nThere was an error when executing the function: {tool_call.get('name')}\nHere's the error traceback: {e}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
106
  else:
107
+ logger.info(message)
108
  tool_message += f"<tool_response>\nThere was an error validating function call against function signature: {tool_call.get('name')}\nHere's the error traceback: {message}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
109
  prompt.append({"role": "tool", "content": tool_message})
110
 
 
118
  completion = self.run_inference(prompt)
119
  return recursive_loop(prompt, completion, depth)
120
  elif error_message:
121
+ logger.info(f"Assistant Message:\n{assistant_message}")
122
  tool_message += f"<tool_response>\nThere was an error parsing function calls\n Here's the error stack trace: {error_message}\nPlease call the function again with correct syntax<tool_response>"
123
  prompt.append({"role": "tool", "content": tool_message})
124
 
 
130
  completion = self.run_inference(prompt)
131
  return recursive_loop(prompt, completion, depth)
132
  else:
133
+ logger.info(f"Assistant Message:\n{assistant_message}")
134
  return assistant_message
135
 
136
  return recursive_loop(prompt, completion, depth)
137
 
138
  except Exception as e:
139
+ logger.error(f"Exception occurred: {e}")
140
  raise e
src/functions.py CHANGED
@@ -8,7 +8,7 @@ from datetime import datetime
8
 
9
  from typing import List
10
  from bs4 import BeautifulSoup
11
- from utils import inference_logger
12
  from langchain.tools import tool
13
  from langchain_core.utils.function_calling import convert_to_openai_tool
14
  from config import config
@@ -69,13 +69,13 @@ def google_search_and_scrape(query: str) -> dict:
69
  params = {'q': query, 'num': num_results}
70
  headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/94.0.4606.61 Safari/537.3'}
71
 
72
- inference_logger.info(f"Performing google search with query: {query}\nplease wait...")
73
  response = requests.get(url, params=params, headers=headers)
74
  soup = BeautifulSoup(response.text, 'html.parser')
75
  urls = [result.find('a')['href'] for result in soup.find_all('div', class_='tF2Cxc')]
76
 
77
- inference_logger.info(f"Scraping text from urls, please wait...")
78
- [inference_logger.info(url) for url in urls]
79
  with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
80
  futures = [executor.submit(lambda url: (url, requests.get(url, headers=headers).text if isinstance(url, str) else None), url) for url in urls[:num_results] if isinstance(url, str)]
81
  results = []
@@ -196,25 +196,6 @@ def get_key_financial_ratios(symbol: str) -> dict:
196
  print(f"Error fetching key financial ratios for {symbol}: {e}")
197
  return {}
198
 
199
- @tool
200
- def get_analyst_recommendations(symbol: str) -> pd.DataFrame:
201
- """
202
- Get analyst recommendations for a given stock symbol.
203
-
204
- Args:
205
- symbol (str): The stock symbol.
206
-
207
- Returns:
208
- pd.DataFrame: DataFrame containing analyst recommendations.
209
- """
210
- try:
211
- stock = yf.Ticker(symbol)
212
- recommendations = stock.recommendations
213
- return recommendations
214
- except Exception as e:
215
- print(f"Error fetching analyst recommendations for {symbol}: {e}")
216
- return pd.DataFrame()
217
-
218
  @tool
219
  def get_dividend_data(symbol: str) -> pd.DataFrame:
220
  """
@@ -245,6 +226,7 @@ def get_company_news(symbol: str) -> pd.DataFrame:
245
  Returns:
246
  pd.DataFrame: DataFrame containing company news and press releases.
247
  """
 
248
  try:
249
  news = yf.Ticker(symbol).news
250
  return news
@@ -293,7 +275,7 @@ def get_openai_tools() -> List[dict]:
293
  get_analysis,
294
  # google_search_and_scrape,
295
  get_current_stock_price,
296
- # get_company_news,
297
  # get_company_profile,
298
  # get_stock_fundamentals,
299
  # get_financial_statements,
 
8
 
9
  from typing import List
10
  from bs4 import BeautifulSoup
11
+ from logger import logger
12
  from langchain.tools import tool
13
  from langchain_core.utils.function_calling import convert_to_openai_tool
14
  from config import config
 
69
  params = {'q': query, 'num': num_results}
70
  headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/94.0.4606.61 Safari/537.3'}
71
 
72
+ logger.info(f"Performing google search with query: {query}\nplease wait...")
73
  response = requests.get(url, params=params, headers=headers)
74
  soup = BeautifulSoup(response.text, 'html.parser')
75
  urls = [result.find('a')['href'] for result in soup.find_all('div', class_='tF2Cxc')]
76
 
77
+ logger.info(f"Scraping text from urls, please wait...")
78
+ [logger.info(url) for url in urls]
79
  with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
80
  futures = [executor.submit(lambda url: (url, requests.get(url, headers=headers).text if isinstance(url, str) else None), url) for url in urls[:num_results] if isinstance(url, str)]
81
  results = []
 
196
  print(f"Error fetching key financial ratios for {symbol}: {e}")
197
  return {}
198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  @tool
200
  def get_dividend_data(symbol: str) -> pd.DataFrame:
201
  """
 
226
  Returns:
227
  pd.DataFrame: DataFrame containing company news and press releases.
228
  """
229
+ config.status.update(label=":newspaper: Getting news")
230
  try:
231
  news = yf.Ticker(symbol).news
232
  return news
 
275
  get_analysis,
276
  # google_search_and_scrape,
277
  get_current_stock_price,
278
+ get_company_news,
279
  # get_company_profile,
280
  # get_stock_fundamentals,
281
  # get_financial_statements,
src/logger.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from time import time
3
+ from loguru import logger
4
+ from pathlib import Path
5
+
6
+ log_dir = Path("logs")
7
+ log_dir.mkdir(exist_ok=True)
8
+
9
+ logger.remove() # Remove the default logger configuration
10
+
11
+ # Configure the logger to write logs to both files and stdout
12
+ logger.add(sys.stdout, format="{time} - {file} - {line} - {message}", backtrace=True)
13
+ logger.add(log_dir / f"{time()}.log", format="{time} - {file} - {line} - {message}", backtrace=True)
src/utils.py CHANGED
@@ -5,6 +5,7 @@ import json
5
  import logging
6
  import datetime
7
  import xml.etree.ElementTree as ET
 
8
 
9
  from logging.handlers import RotatingFileHandler
10
 
@@ -27,9 +28,6 @@ file_handler.setLevel(logging.INFO)
27
  formatter = logging.Formatter("%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", datefmt="%Y-%m-%d:%H:%M:%S")
28
  file_handler.setFormatter(formatter)
29
 
30
- inference_logger = logging.getLogger("function-calling-inference")
31
- inference_logger.addHandler(file_handler)
32
-
33
  def get_fewshot_examples(num_fewshot):
34
  """return a list of few shot examples"""
35
  example_path = os.path.join(script_dir, 'prompt_assets', 'few_shot.json')
@@ -45,7 +43,7 @@ def get_chat_template(chat_template):
45
 
46
  if not os.path.exists(template_path):
47
  print
48
- inference_logger.error(f"Template file not found: {chat_template}")
49
  return None
50
  try:
51
  with open(template_path, 'r') as file:
@@ -55,31 +53,6 @@ def get_chat_template(chat_template):
55
  print(f"Error loading template: {e}")
56
  return None
57
 
58
- def get_assistant_message(completion, chat_template, eos_token):
59
- """define and match pattern to find the assistant message"""
60
- completion = completion.strip()
61
-
62
- if chat_template == "zephyr":
63
- assistant_pattern = re.compile(r'<\|assistant\|>((?:(?!<\|assistant\|>).)*)$', re.DOTALL)
64
- elif chat_template == "chatml":
65
- assistant_pattern = re.compile(r'<\|im_start\|>\s*assistant((?:(?!<\|im_start\|>\s*assistant).)*)$', re.DOTALL)
66
-
67
- elif chat_template == "vicuna":
68
- assistant_pattern = re.compile(r'ASSISTANT:\s*((?:(?!ASSISTANT:).)*)$', re.DOTALL)
69
- else:
70
- raise NotImplementedError(f"Handling for chat_template '{chat_template}' is not implemented.")
71
-
72
- assistant_match = assistant_pattern.search(completion)
73
- if assistant_match:
74
- assistant_content = assistant_match.group(1).strip()
75
- if chat_template == "vicuna":
76
- eos_token = f"</s>{eos_token}"
77
- return assistant_content.replace(eos_token, "")
78
- else:
79
- assistant_content = None
80
- inference_logger.info("No match found for the assistant pattern")
81
- return assistant_content
82
-
83
  def validate_and_extract_tool_calls(assistant_content):
84
  validation_result = False
85
  tool_calls = []
@@ -108,11 +81,11 @@ def validate_and_extract_tool_calls(assistant_content):
108
  f"- JSON Decode Error: {json_err}\n"\
109
  f"- Fallback Syntax/Value Error: {eval_err}\n"\
110
  f"- Problematic JSON text: {json_text}"
111
- inference_logger.error(error_message)
112
  continue
113
  except Exception as e:
114
  error_message = f"Cannot strip text: {e}"
115
- inference_logger.error(error_message)
116
 
117
  if json_data is not None:
118
  tool_calls.append(json_data)
@@ -120,7 +93,7 @@ def validate_and_extract_tool_calls(assistant_content):
120
 
121
  except ET.ParseError as err:
122
  error_message = f"XML Parse Error: {err}"
123
- inference_logger.error(f"XML Parse Error: {err}")
124
 
125
  # Return default values if no valid data is extracted
126
  return validation_result, tool_calls, error_message
 
5
  import logging
6
  import datetime
7
  import xml.etree.ElementTree as ET
8
+ from logger import logger
9
 
10
  from logging.handlers import RotatingFileHandler
11
 
 
28
  formatter = logging.Formatter("%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", datefmt="%Y-%m-%d:%H:%M:%S")
29
  file_handler.setFormatter(formatter)
30
 
 
 
 
31
  def get_fewshot_examples(num_fewshot):
32
  """return a list of few shot examples"""
33
  example_path = os.path.join(script_dir, 'prompt_assets', 'few_shot.json')
 
43
 
44
  if not os.path.exists(template_path):
45
  print
46
+ logger.error(f"Template file not found: {chat_template}")
47
  return None
48
  try:
49
  with open(template_path, 'r') as file:
 
53
  print(f"Error loading template: {e}")
54
  return None
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  def validate_and_extract_tool_calls(assistant_content):
57
  validation_result = False
58
  tool_calls = []
 
81
  f"- JSON Decode Error: {json_err}\n"\
82
  f"- Fallback Syntax/Value Error: {eval_err}\n"\
83
  f"- Problematic JSON text: {json_text}"
84
+ logger.error(error_message)
85
  continue
86
  except Exception as e:
87
  error_message = f"Cannot strip text: {e}"
88
+ logger.error(error_message)
89
 
90
  if json_data is not None:
91
  tool_calls.append(json_data)
 
93
 
94
  except ET.ParseError as err:
95
  error_message = f"XML Parse Error: {err}"
96
+ logger.error(f"XML Parse Error: {err}")
97
 
98
  # Return default values if no valid data is extracted
99
  return validation_result, tool_calls, error_message
src/validator.py CHANGED
@@ -2,9 +2,11 @@ import ast
2
  import json
3
  from jsonschema import validate
4
  from pydantic import ValidationError
5
- from utils import inference_logger, extract_json_from_markdown
 
6
  from schema import FunctionCall, FunctionSignature
7
 
 
8
  def validate_function_call_schema(call, signatures):
9
  try:
10
  call_data = FunctionCall(**call)
@@ -16,18 +18,26 @@ def validate_function_call_schema(call, signatures):
16
  signature_data = FunctionSignature(**signature)
17
  if signature_data.function.name == call_data.name:
18
  # Validate types in function arguments
19
- for arg_name, arg_schema in signature_data.function.parameters.get('properties', {}).items():
 
 
20
  if arg_name in call_data.arguments:
21
  call_arg_value = call_data.arguments[arg_name]
22
  if call_arg_value:
23
  try:
24
- validate_argument_type(arg_name, call_arg_value, arg_schema)
 
 
25
  except Exception as arg_validation_error:
26
  return False, str(arg_validation_error)
27
 
28
  # Check if all required arguments are present
29
- required_arguments = signature_data.function.parameters.get('required', [])
30
- result, missing_arguments = check_required_arguments(call_data.arguments, required_arguments)
 
 
 
 
31
  if not result:
32
  return False, f"Missing required arguments: {missing_arguments}"
33
 
@@ -39,21 +49,24 @@ def validate_function_call_schema(call, signatures):
39
  # No matching function signature found
40
  return False, f"No matching function signature found for function: {call_data.name}"
41
 
 
42
  def check_required_arguments(call_arguments, required_arguments):
43
  missing_arguments = [arg for arg in required_arguments if arg not in call_arguments]
44
  return not bool(missing_arguments), missing_arguments
45
 
 
46
  def validate_enum_value(arg_name, arg_value, enum_values):
47
  if arg_value not in enum_values:
48
  raise Exception(
49
  f"Invalid value '{arg_value}' for parameter {arg_name}. Expected one of {', '.join(map(str, enum_values))}"
50
  )
51
 
 
52
  def validate_argument_type(arg_name, arg_value, arg_schema):
53
- arg_type = arg_schema.get('type', None)
54
  if arg_type:
55
- if arg_type == 'string' and 'enum' in arg_schema:
56
- enum_values = arg_schema['enum']
57
  if None not in enum_values and enum_values != []:
58
  try:
59
  validate_enum_value(arg_name, arg_value, enum_values)
@@ -63,20 +76,24 @@ def validate_argument_type(arg_name, arg_value, arg_schema):
63
 
64
  python_type = get_python_type(arg_type)
65
  if not isinstance(arg_value, python_type):
66
- raise Exception(f"Type mismatch for parameter {arg_name}. Expected: {arg_type}, Got: {type(arg_value)}")
 
 
 
67
 
68
  def get_python_type(json_type):
69
  type_mapping = {
70
- 'string': str,
71
- 'number': (int, float),
72
- 'integer': int,
73
- 'boolean': bool,
74
- 'array': list,
75
- 'object': dict,
76
- 'null': type(None),
77
  }
78
  return type_mapping[json_type]
79
 
 
80
  def validate_json_data(json_object, json_schema):
81
  valid = False
82
  error_message = None
@@ -95,13 +112,13 @@ def validate_json_data(json_object, json_schema):
95
  result_json = extract_json_from_markdown(json_object)
96
  except Exception as e:
97
  error_message = f"JSON decoding error: {e}"
98
- inference_logger.info(f"Validation failed for JSON data: {error_message}")
99
  return valid, result_json, error_message
100
 
101
  # Return early if both json.loads and ast.literal_eval fail
102
  if result_json is None:
103
  error_message = "Failed to decode JSON data"
104
- inference_logger.info(f"Validation failed for JSON data: {error_message}")
105
  return valid, result_json, error_message
106
 
107
  # Validate each item in the list against schema if it's a list
@@ -109,7 +126,7 @@ def validate_json_data(json_object, json_schema):
109
  for index, item in enumerate(result_json):
110
  try:
111
  validate(instance=item, schema=json_schema)
112
- inference_logger.info(f"Item {index+1} is valid against the schema.")
113
  except ValidationError as e:
114
  error_message = f"Validation failed for item {index+1}: {e}"
115
  break
@@ -125,8 +142,8 @@ def validate_json_data(json_object, json_schema):
125
 
126
  if error_message is None:
127
  valid = True
128
- inference_logger.info("JSON data is valid against the schema.")
129
  else:
130
- inference_logger.info(f"Validation failed for JSON data: {error_message}")
131
 
132
  return valid, result_json, error_message
 
2
  import json
3
  from jsonschema import validate
4
  from pydantic import ValidationError
5
+ from logger import logger
6
+ from utils import extract_json_from_markdown
7
  from schema import FunctionCall, FunctionSignature
8
 
9
+
10
  def validate_function_call_schema(call, signatures):
11
  try:
12
  call_data = FunctionCall(**call)
 
18
  signature_data = FunctionSignature(**signature)
19
  if signature_data.function.name == call_data.name:
20
  # Validate types in function arguments
21
+ for arg_name, arg_schema in signature_data.function.parameters.get(
22
+ "properties", {}
23
+ ).items():
24
  if arg_name in call_data.arguments:
25
  call_arg_value = call_data.arguments[arg_name]
26
  if call_arg_value:
27
  try:
28
+ validate_argument_type(
29
+ arg_name, call_arg_value, arg_schema
30
+ )
31
  except Exception as arg_validation_error:
32
  return False, str(arg_validation_error)
33
 
34
  # Check if all required arguments are present
35
+ required_arguments = signature_data.function.parameters.get(
36
+ "required", []
37
+ )
38
+ result, missing_arguments = check_required_arguments(
39
+ call_data.arguments, required_arguments
40
+ )
41
  if not result:
42
  return False, f"Missing required arguments: {missing_arguments}"
43
 
 
49
  # No matching function signature found
50
  return False, f"No matching function signature found for function: {call_data.name}"
51
 
52
+
53
  def check_required_arguments(call_arguments, required_arguments):
54
  missing_arguments = [arg for arg in required_arguments if arg not in call_arguments]
55
  return not bool(missing_arguments), missing_arguments
56
 
57
+
58
  def validate_enum_value(arg_name, arg_value, enum_values):
59
  if arg_value not in enum_values:
60
  raise Exception(
61
  f"Invalid value '{arg_value}' for parameter {arg_name}. Expected one of {', '.join(map(str, enum_values))}"
62
  )
63
 
64
+
65
  def validate_argument_type(arg_name, arg_value, arg_schema):
66
+ arg_type = arg_schema.get("type", None)
67
  if arg_type:
68
+ if arg_type == "string" and "enum" in arg_schema:
69
+ enum_values = arg_schema["enum"]
70
  if None not in enum_values and enum_values != []:
71
  try:
72
  validate_enum_value(arg_name, arg_value, enum_values)
 
76
 
77
  python_type = get_python_type(arg_type)
78
  if not isinstance(arg_value, python_type):
79
+ raise Exception(
80
+ f"Type mismatch for parameter {arg_name}. Expected: {arg_type}, Got: {type(arg_value)}"
81
+ )
82
+
83
 
84
  def get_python_type(json_type):
85
  type_mapping = {
86
+ "string": str,
87
+ "number": (int, float),
88
+ "integer": int,
89
+ "boolean": bool,
90
+ "array": list,
91
+ "object": dict,
92
+ "null": type(None),
93
  }
94
  return type_mapping[json_type]
95
 
96
+
97
  def validate_json_data(json_object, json_schema):
98
  valid = False
99
  error_message = None
 
112
  result_json = extract_json_from_markdown(json_object)
113
  except Exception as e:
114
  error_message = f"JSON decoding error: {e}"
115
+ logger.info(f"Validation failed for JSON data: {error_message}")
116
  return valid, result_json, error_message
117
 
118
  # Return early if both json.loads and ast.literal_eval fail
119
  if result_json is None:
120
  error_message = "Failed to decode JSON data"
121
+ logger.info(f"Validation failed for JSON data: {error_message}")
122
  return valid, result_json, error_message
123
 
124
  # Validate each item in the list against schema if it's a list
 
126
  for index, item in enumerate(result_json):
127
  try:
128
  validate(instance=item, schema=json_schema)
129
+ logger.info(f"Item {index+1} is valid against the schema.")
130
  except ValidationError as e:
131
  error_message = f"Validation failed for item {index+1}: {e}"
132
  break
 
142
 
143
  if error_message is None:
144
  valid = True
145
+ logger.info("JSON data is valid against the schema.")
146
  else:
147
+ logger.info(f"Validation failed for JSON data: {error_message}")
148
 
149
  return valid, result_json, error_message