chat / app /utils /tool_call_extractor.py
ariansyahdedy's picture
Add prompt edit and api key config
8d2f9d4
raw
history blame
5.72 kB
import re
import json
from typing import List, Dict, Any, Optional
class ToolCallExtractor:
    """Extract tool / function calls from chat-model output.

    Two independent entry points are provided:

    * ``extract_tool_calls`` parses JSON tool calls embedded in a text
      completion, optionally wrapped in ``<|python_tag|>`` / ``<|eom_id|>``
      markers.
    * ``extract_function_call`` reads a structured function call from a list
      of response "part" objects (anything exposing a ``function_call``
      attribute with ``name`` and ``args``).
    """

    def __init__(self):
        # Payload enclosed by both markers: <|python_tag|> ... <|eom_id|>
        self.complete_pattern = re.compile(r'<\|python_tag\|>(.*?)<\|eom_id\|>', re.DOTALL)
        # Payload that is only terminated by <|eom_id|> (no opening tag)
        self.partial_pattern = re.compile(r'(.*?)<\|eom_id\|>', re.DOTALL)

    def _extract_function_args(self, args) -> Dict[str, Any]:
        """
        Flatten function-call arguments into a plain ``dict``.

        Two input shapes are understood:

        * an object with a ``fields`` attribute that iterates over entries
          exposing ``key`` and ``value``; each value contributes its
          ``string_value``, ``number_value`` or ``bool_value`` (checked in
          that order, first attribute present wins);
        * a mapping-like object exposing ``items()`` (for example a plain
          ``dict``), which is shallow-copied. Nested values are returned
          as-is.
          NOTE(review): SDK argument containers are assumed to be dict-like;
          confirm against the client library version in use.

        Extraction is best-effort: any exception is reported on stdout and
        whatever was collected so far is returned.

        Args:
            args: The ``args`` object taken from a function call.
        Returns:
            dict: Argument names mapped to their extracted values; empty when
            the input shape is not recognised.
        """
        flattened_args = {}
        try:
            # Explicitly check for fields
            if hasattr(args, 'fields'):
                # Each entry is expected to expose ``key`` and ``value``
                for field in args.fields:
                    key = field.key
                    value = field.value
                    # Additional debugging
                    print(f"Field key: {key}")
                    print(f"Field value type: {type(value)}")
                    print(f"Field value: {value}")
                    # Extract string value
                    if hasattr(value, 'string_value'):
                        flattened_args[key] = value.string_value
                        print(f"Extracted string value: {value.string_value}")
                    elif hasattr(value, 'number_value'):
                        flattened_args[key] = value.number_value
                    elif hasattr(value, 'bool_value') and value.bool_value is not None:
                        flattened_args[key] = value.bool_value
            elif callable(getattr(args, 'items', None)):
                # Mapping-style arguments: previously these fell through and
                # every argument was lost.
                flattened_args = dict(args.items())
            # Added additional debug information
            print(f"Final flattened args: {flattened_args}")
        except Exception as e:
            print(f"Error extracting function args: {e}")
        return flattened_args

    def extract_tool_calls(self, input_string: str) -> List[Dict[str, Any]]:
        """
        Extract tool calls from input string, handling various inconsistent formats.

        Strategies are tried in order and the first one whose pattern matches
        decides the result (even if it yields no parseable JSON):

        1. text between ``<|python_tag|>`` and ``<|eom_id|>``;
        2. text preceding ``<|eom_id|>``;
        3. the whole input string.

        Args:
            input_string (str): The input string containing tool calls.
                ``None`` or an empty string yields an empty list.
        Returns:
            list: A list of dictionaries representing the parsed tool calls.
        """
        tool_calls = []
        if not input_string:
            return tool_calls
        # Existing tag-based extraction (retain if needed)
        complete_matches = self.complete_pattern.findall(input_string)
        if complete_matches:
            for match in complete_matches:
                tool_calls.extend(self._extract_json_objects(match))
            return tool_calls
        partial_matches = self.partial_pattern.findall(input_string)
        if partial_matches:
            for match in partial_matches:
                tool_calls.extend(self._extract_json_objects(match))
            return tool_calls
        # Fallback: Attempt to parse the entire string
        tool_calls.extend(self._extract_json_objects(input_string))
        return tool_calls

    def _extract_json_objects(self, text: str) -> List[Dict[str, Any]]:
        """
        Extract and parse multiple ``;``-separated JSON values from a string.

        A value must start with ``{`` or ``[`` and must be followed (ignoring
        whitespace) by ``;`` or the end of the text; anything else in a
        segment is skipped up to the next ``;``. Empty objects/arrays are
        dropped. Semicolons that occur inside a JSON string value are part of
        that value and do not split it.

        Args:
            text (str): Text holding zero or more JSON values.
        Returns:
            list: The parsed, non-empty JSON values in order of appearance.
        """
        json_objects = []
        decoder = json.JSONDecoder()
        pos = 0
        length = len(text)
        while pos < length:
            char = text[pos]
            # Skip separators and surrounding whitespace between values.
            if char == ';' or char.isspace():
                pos += 1
                continue
            if char in '{[':
                try:
                    # raw_decode consumes exactly one JSON value, so a ';'
                    # inside a string literal cannot cut the value in half.
                    parsed_obj, end = decoder.raw_decode(text, pos)
                except json.JSONDecodeError:
                    parsed_obj, end = None, None
                if end is not None:
                    separator = text.find(';', end)
                    tail = text[end:] if separator == -1 else text[end:separator]
                    # Only accept a value that fills its whole segment.
                    if not tail.strip():
                        if parsed_obj:
                            json_objects.append(parsed_obj)
                        pos = length if separator == -1 else separator + 1
                        continue
            # Unparseable segment: resume after the next separator.
            separator = text.find(';', pos)
            if separator == -1:
                break
            pos = separator + 1
        return json_objects

    def _clean_and_parse_json(self, json_str: str) -> Optional[Dict[str, Any]]:
        """
        Clean and parse a JSON string, handling common formatting issues.

        Args:
            json_str (str): Candidate JSON text; surrounding whitespace is
                ignored.
        Returns:
            The decoded value when the text starts with ``{`` or ``[`` and is
            valid JSON, otherwise ``None``.
        """
        try:
            json_str = json_str.strip()
            if json_str.startswith(('{', '[')):
                return json.loads(json_str)
            return None
        except json.JSONDecodeError:
            return None

    def validate_tool_call(self, tool_call: Dict[str, Any]) -> bool:
        """
        Validate if a tool call has the required fields.

        Returns:
            bool: ``True`` when ``tool_call`` is a dict whose ``'name'`` entry
            is a string.
        """
        return (
            isinstance(tool_call, dict) and
            'name' in tool_call and
            isinstance(tool_call['name'], str)
        )

    def extract_function_call(self, response_parts: List[Any]) -> Dict[str, Any]:
        """
        Extract function call details from the response parts.

        The first part that carries a truthy ``function_call`` with a
        non-empty ``name`` wins; later parts are not inspected.

        Args:
            response_parts (list): The list of response parts from the chat model.
                ``None`` is treated as "no parts".
        Returns:
            dict: A dictionary containing the function name and flattened arguments,
            or an empty dict when no part holds a usable function call.
        """
        if response_parts is None:
            return {}
        for part in response_parts:
            # Debug print
            print(f"Examining part: {part}")
            print(f"Part type: {type(part)}")
            # Check for function_call attribute
            if hasattr(part, 'function_call') and part.function_call:
                function_call = part.function_call
                # Debug print
                print(f"Function call: {function_call}")
                print(f"Function call type: {type(function_call)}")
                print(f"Function args: {function_call.args}")
                # Extract function name
                function_name = getattr(function_call, 'name', None)
                if not function_name:
                    continue  # Skip if function name is missing
                # Extract function arguments
                function_args = self._extract_function_args(function_call.args)
                return {
                    "name": function_name,
                    "args": function_args
                }
        return {}