"""Utilities for function-calling inference.

Sets up console and file logging, loads few-shot examples and chat
templates from disk, and parses assistant output (assistant turn,
``<tool_call>`` payloads, fenced JSON).
"""

import ast
import os
import re
import json
import logging
import datetime
import xml.etree.ElementTree as ET
from logging.handlers import RotatingFileHandler
from typing import Any, List, Optional, Tuple

# Shared by the root (console) logger and the per-run file handler so both
# emit identically formatted records.
_LOG_FORMAT = "%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s"
_LOG_DATEFMT = "%Y-%m-%d:%H:%M:%S"

logging.basicConfig(
    format=_LOG_FORMAT,
    datefmt=_LOG_DATEFMT,
    level=logging.INFO,
)

script_dir = os.path.dirname(os.path.abspath(__file__))
now = datetime.datetime.now()
log_folder = os.path.join(script_dir, "inference_logs")
os.makedirs(log_folder, exist_ok=True)
log_file_path = os.path.join(
    log_folder, f"function-calling-inference_{now.strftime('%Y-%m-%d_%H-%M-%S')}.log"
)

# One log file per import, named by timestamp. With maxBytes=0 the handler
# never rolls over, so it behaves like a plain file handler.
file_handler = RotatingFileHandler(log_file_path, maxBytes=0, backupCount=0)
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter(_LOG_FORMAT, datefmt=_LOG_DATEFMT)
file_handler.setFormatter(formatter)

inference_logger = logging.getLogger("function-calling-inference")
inference_logger.addHandler(file_handler)

# Each pattern is anchored at the end of the completion and uses a negative
# lookahead so that group(1) is the text after the LAST assistant marker.
_ASSISTANT_PATTERNS = {
    "zephyr": re.compile(r'<\|assistant\|>((?:(?!<\|assistant\|>).)*)$', re.DOTALL),
    "chatml": re.compile(r'<\|im_start\|>\s*assistant((?:(?!<\|im_start\|>\s*assistant).)*)$', re.DOTALL),
    "vicuna": re.compile(r'ASSISTANT:\s*((?:(?!ASSISTANT:).)*)$', re.DOTALL),
}

# Body of the first ```json fenced block (LF or CRLF line endings).
_JSON_FENCE_PATTERN = re.compile(r'```json\r?\n(.*?)\r?\n```', re.DOTALL)


def get_fewshot_examples(num_fewshot: int) -> List[Any]:
    """Return the first ``num_fewshot`` examples from ``prompt_assets/few_shot.json``.

    Args:
        num_fewshot: Number of examples to return.

    Returns:
        The leading ``num_fewshot`` entries of the JSON file's content.

    Raises:
        ValueError: If more examples are requested than the file contains.
        OSError: If the examples file cannot be opened.
    """
    example_path = os.path.join(script_dir, 'prompt_assets', 'few_shot.json')
    with open(example_path, 'r', encoding='utf-8') as file:
        examples = json.load(file)
    if num_fewshot > len(examples):
        raise ValueError(f"Not enough examples (got {num_fewshot}, but there are only {len(examples)} examples).")
    return examples[:num_fewshot]


def get_chat_template(chat_template: str) -> Optional[str]:
    """Read the Jinja chat template ``chat_templates/<chat_template>.j2``.

    Args:
        chat_template: Template name without the ``.j2`` extension.

    Returns:
        The template text, or ``None`` if the file is missing or unreadable
        (the failure is logged, not raised).
    """
    template_path = os.path.join(script_dir, 'chat_templates', f"{chat_template}.j2")

    try:
        with open(template_path, 'r', encoding='utf-8') as file:
            return file.read()
    except FileNotFoundError:
        inference_logger.error(f"Template file not found: {chat_template}")
    except (OSError, ValueError) as e:
        # ValueError also covers UnicodeDecodeError for non-UTF-8 files.
        inference_logger.error(f"Error loading template: {e}")
    return None


def get_assistant_message(completion: str, chat_template: str, eos_token: str) -> Optional[str]:
    """Extract the last assistant turn from a raw completion.

    Args:
        completion: Full generated text including chat-format markers.
        chat_template: One of ``"zephyr"``, ``"chatml"`` or ``"vicuna"``.
        eos_token: End-of-sequence token to remove from the extracted text.

    Returns:
        The assistant text with surrounding whitespace and every occurrence
        of ``eos_token`` removed, or ``None`` if no assistant marker is found.

    Raises:
        NotImplementedError: If ``chat_template`` is not a supported format.
    """
    completion = completion.strip()

    assistant_pattern = _ASSISTANT_PATTERNS.get(chat_template)
    if assistant_pattern is None:
        raise NotImplementedError(f"Handling for chat_template '{chat_template}' is not implemented.")

    assistant_match = assistant_pattern.search(completion)
    if not assistant_match:
        inference_logger.info("No match found for the assistant pattern")
        return None

    assistant_content = assistant_match.group(1).strip()
    if chat_template == "vicuna":
        # NOTE(review): for a str token this f-string is a no-op; it looks
        # like a literal prefix may have been lost here - confirm.
        eos_token = f"{eos_token}"
    return assistant_content.replace(eos_token, "")


def validate_and_extract_tool_calls(assistant_content: str) -> Tuple[bool, List[Any], Optional[str]]:
    """Parse every ``<tool_call>`` element found in the assistant output.

    The content is wrapped in a synthetic ``<root>`` element before XML
    parsing. Each payload is decoded with ``json.loads`` first and
    ``ast.literal_eval`` second, so Python-literal dicts (single quotes)
    are also accepted.

    Args:
        assistant_content: Assistant text that may contain tool calls.

    Returns:
        A tuple ``(validation_result, tool_calls, error_message)``:
        ``validation_result`` is ``True`` if at least one payload was
        decoded, ``tool_calls`` holds the decoded payloads in document
        order, and ``error_message`` is the most recent parse error or
        ``None``.
    """
    validation_result = False
    tool_calls: List[Any] = []
    error_message = None

    try:
        # Without a single enclosing element, ElementTree rejects multiple
        # sibling tool calls or surrounding text, and a lone <tool_call>
        # would be the root itself, which ".//tool_call" does not match.
        root = ET.fromstring(f"<root>{assistant_content}</root>")
    except ET.ParseError as err:
        error_message = f"XML Parse Error: {err}"
        inference_logger.error(error_message)
        return validation_result, tool_calls, error_message

    for element in root.findall(".//tool_call"):
        json_data = None
        try:
            json_text = element.text.strip()
            try:
                json_data = json.loads(json_text)
            except json.JSONDecodeError as json_err:
                try:
                    # literal_eval only evaluates literals, so it is safe
                    # to run on model output.
                    json_data = ast.literal_eval(json_text)
                except (SyntaxError, ValueError) as eval_err:
                    error_message = f"JSON parsing failed with both json.loads and ast.literal_eval:\n"\
                                    f"- JSON Decode Error: {json_err}\n"\
                                    f"- Fallback Syntax/Value Error: {eval_err}\n"\
                                    f"- Problematic JSON text: {json_text}"
                    inference_logger.error(error_message)
                    continue
        except Exception as e:
            # Best effort: e.g. element.text is None when the tag has no
            # leading text. Record the problem and move to the next call.
            error_message = f"Cannot strip text: {e}"
            inference_logger.error(error_message)

        if json_data is not None:
            tool_calls.append(json_data)
            validation_result = True

    # Defaults (False, [], None) are returned when nothing was extracted.
    return validation_result, tool_calls, error_message


def extract_json_from_markdown(text: str) -> Any:
    """Load the JSON found in the first ```json fenced block of ``text``.

    Args:
        text: Input text that may contain a fenced JSON block.

    Returns:
        The decoded JSON data, or ``None`` if no fenced block is found or
        its content is not valid JSON (a message is printed in both cases).
    """
    match = _JSON_FENCE_PATTERN.search(text)
    if not match:
        print("JSON string not found in the text.")
        return None

    try:
        return json.loads(match.group(1))
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON string: {e}")
    return None