import json
import logging
import re
from typing import Any, Dict, Iterator, List, Optional, Tuple

logger = logging.getLogger(__name__)


class ToolCallExtractor:
    """Extract tool / function calls from model output.

    Two input shapes are supported:

    * raw text holding one or more JSON tool calls separated by ``;``,
      optionally wrapped in ``<|python_tag|> ... <|eom_id|>`` markers
      (see :meth:`extract_tool_calls`);
    * a list of response parts exposing a ``function_call`` attribute
      (see :meth:`extract_function_call`).
    """

    # Scalar members of a ``Value``-style object that get flattened.
    _SCALAR_KINDS = ("string_value", "number_value", "bool_value")
    # Sentinel meaning "no scalar could be extracted from this value".
    _MISSING = object()
    # Shared decoder; ``raw_decode`` lets us parse a JSON document that is
    # followed by more text.
    _JSON_DECODER = json.JSONDecoder()

    def __init__(self):
        # Existing regex patterns (retain if needed for other formats).
        # complete: start tag ... end tag; partial: anything ... end tag.
        self.complete_pattern = re.compile(
            r'<\|python_tag\|>(.*?)<\|eom_id\|>', re.DOTALL
        )
        self.partial_pattern = re.compile(r'(.*?)<\|eom_id\|>', re.DOTALL)

    # ------------------------------------------------------------------
    # Structured (function_call) extraction
    # ------------------------------------------------------------------
    @staticmethod
    def _iter_arg_items(args: Any) -> Iterator[Tuple[Any, Any]]:
        """Yield ``(key, value)`` pairs from the supported argument containers.

        Accepted shapes, checked in this order:

        * ``args.fields`` is a mapping (has ``items()``);
        * ``args.fields`` is an iterable of entries with ``.key``/``.value``;
        * ``args`` itself is a mapping (has ``items()``) and has no ``fields``.

        Anything else yields nothing.
        """
        container = getattr(args, 'fields', None)
        if container is None and callable(getattr(args, 'items', None)):
            container = args
        if container is None:
            return
        if callable(getattr(container, 'items', None)):
            yield from container.items()
        else:
            for entry in container:
                yield entry.key, entry.value

    @classmethod
    def _unwrap_value(cls, value: Any) -> Any:
        """Return the scalar held by ``value`` or ``cls._MISSING``.

        Plain ``str``/``int``/``float``/``bool`` values are returned as-is.
        If the object offers ``WhichOneof`` it is asked which member of the
        ``kind`` oneof is set, because on such objects every ``*_value``
        attribute exists and ``hasattr`` cannot tell them apart. Otherwise
        the attributes are probed in the order string, number, bool.
        """
        if isinstance(value, (str, int, float, bool)):
            return value

        which_oneof = getattr(value, 'WhichOneof', None)
        if callable(which_oneof):
            kind = which_oneof('kind')
            if kind in cls._SCALAR_KINDS:
                return getattr(value, kind)
            return cls._MISSING

        if hasattr(value, 'string_value'):
            return value.string_value
        if hasattr(value, 'number_value'):
            return value.number_value
        if hasattr(value, 'bool_value') and value.bool_value is not None:
            return value.bool_value
        return cls._MISSING

    def _extract_function_args(self, args) -> Dict[str, Any]:
        """Flatten function-call arguments into a plain ``dict``.

        Only scalar (string / number / bool) values are kept; entries whose
        value is none of those are skipped.

        Args:
            args: The argument container; see :meth:`_iter_arg_items` for the
                accepted shapes. ``None`` or an unsupported object gives ``{}``.

        Returns:
            dict: Mapping of argument name to its scalar value. If an error
            occurs part-way through, the entries collected so far are
            returned.
        """
        flattened_args: Dict[str, Any] = {}
        try:
            for key, value in self._iter_arg_items(args):
                logger.debug(
                    "Function arg %r has type %s: %r", key, type(value), value
                )
                scalar = self._unwrap_value(value)
                if scalar is not self._MISSING:
                    flattened_args[key] = scalar
        except Exception as e:
            # Deliberately best-effort: an unexpected argument shape must not
            # abort the caller, so report it and return what we have.
            logger.warning("Error extracting function args: %s", e, exc_info=True)
        logger.debug("Final flattened args: %s", flattened_args)
        return flattened_args

    def extract_function_call(self, response_parts: List[Any]) -> Dict[str, Any]:
        """Extract function call details from the response parts.

        The first part that has a truthy ``function_call`` with a truthy
        ``name`` wins.

        Args:
            response_parts (list): The list of response parts from the chat
                model.

        Returns:
            dict: ``{"name": ..., "args": ...}`` with flattened arguments, or
            an empty dict when no part carries a named function call.
        """
        for part in response_parts:
            logger.debug("Examining part of type %s: %s", type(part), part)

            function_call = getattr(part, 'function_call', None)
            if not function_call:
                continue

            function_name = getattr(function_call, 'name', None)
            if not function_name:
                continue  # Skip if function name is missing

            # A call object without ``args`` is treated as having none.
            raw_args = getattr(function_call, 'args', None)
            logger.debug(
                "Function call %r (%s) args: %s",
                function_name, type(function_call), raw_args,
            )
            return {
                "name": function_name,
                "args": self._extract_function_args(raw_args),
            }
        return {}

    # ------------------------------------------------------------------
    # Text (JSON) extraction
    # ------------------------------------------------------------------
    def extract_tool_calls(self, input_string: str) -> List[Dict[str, Any]]:
        """Extract tool calls from input string, handling various formats.

        Strategies are tried in order and the first one whose pattern matches
        is used, even if it yields no parsable JSON:

        1. text between ``<|python_tag|>`` and ``<|eom_id|>``;
        2. text preceding ``<|eom_id|>``;
        3. the entire string.

        Args:
            input_string (str): The input string containing tool calls.

        Returns:
            list: The parsed tool calls. Items are whatever the JSON decoded
            to, so a top-level JSON array is returned as a list item.
        """
        for pattern in (self.complete_pattern, self.partial_pattern):
            matches = pattern.findall(input_string)
            if matches:
                tool_calls: List[Dict[str, Any]] = []
                for match in matches:
                    tool_calls.extend(self._extract_json_objects(match))
                return tool_calls

        # Fallback: attempt to parse the entire string.
        return self._extract_json_objects(input_string)

    def _decode_across_separator(
        self, text: str, start: int, segment: str
    ) -> Optional[Tuple[Any, int]]:
        """Try to parse a JSON document that a ``;`` split cut in two.

        Args:
            text: The full text being scanned.
            start: Index in ``text`` where ``segment`` begins.
            segment: The ``;``-delimited piece that failed to parse.

        Returns:
            ``(parsed, separator_index)`` where ``separator_index`` is the
            position of the ``;`` (or ``len(text)``) that really ends the
            document, or ``None`` if no document starts at the segment or it
            is followed by anything other than whitespace before that
            separator.
        """
        begin = start + len(segment) - len(segment.lstrip())
        if begin >= len(text) or text[begin] not in '{[':
            return None
        try:
            parsed, end = self._JSON_DECODER.raw_decode(text, begin)
        except json.JSONDecodeError:
            return None

        next_sep = text.find(';', end)
        if next_sep == -1:
            next_sep = len(text)
        # Stay as strict as a whole-segment parse: trailing junk disqualifies.
        if text[end:next_sep].strip():
            return None
        return parsed, next_sep

    def _extract_json_objects(self, text: str) -> List[Any]:
        """Extract and parse multiple ``;``-separated JSON documents.

        Each ``;``-delimited segment is parsed on its own. When a segment
        that starts with ``{`` or ``[`` fails, the ``;`` may have been inside
        a JSON string, so the document is re-parsed across the separator.
        Segments that cannot be parsed, and empty results (``{}``, ``[]``),
        are dropped.
        """
        json_objects: List[Any] = []
        pos, length = 0, len(text)
        while pos <= length:
            sep = text.find(';', pos)
            if sep == -1:
                sep = length
            segment = text[pos:sep]

            parsed_obj = self._clean_and_parse_json(segment)
            if parsed_obj is None:
                recovered = self._decode_across_separator(text, pos, segment)
                if recovered is not None:
                    parsed_obj, sep = recovered

            if parsed_obj:
                json_objects.append(parsed_obj)
            pos = sep + 1
        return json_objects

    def _clean_and_parse_json(self, json_str: str) -> Optional[Any]:
        """Strip and parse a JSON string.

        Returns:
            The decoded object or array when the stripped text starts with
            ``{`` or ``[`` and is valid JSON; ``None`` otherwise.
        """
        candidate = json_str.strip()
        if not candidate.startswith(('{', '[')):
            return None
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            return None

    def validate_tool_call(self, tool_call: Dict[str, Any]) -> bool:
        """Validate if a tool call has the required fields.

        Returns:
            bool: ``True`` only for a dict with a string ``'name'`` entry.
        """
        if not isinstance(tool_call, dict):
            return False
        return isinstance(tool_call.get('name'), str)