Spaces:
Build error
Build error
File size: 5,716 Bytes
8d2f9d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import json
import logging
import re
from typing import List, Dict, Any, Optional
class ToolCallExtractor:
    """Extract tool/function calls from model output.

    Two input shapes are supported:

    * raw text containing JSON tool calls, optionally wrapped in
      ``<|python_tag|>`` ... ``<|eom_id|>`` markers (``extract_tool_calls``);
    * structured response parts exposing a ``function_call`` attribute
      (``extract_function_call``).
    """

    # Marker that precedes a tool-call payload in tagged model output.
    _PYTHON_TAG = '<|python_tag|>'
    # Scalar members of the ``kind`` oneof, in the order the original
    # attribute-probing code checked them.
    _SCALAR_KINDS = ('string_value', 'number_value', 'bool_value')
    # Sentinel: "could not convert this value" (None is a legitimate value).
    _UNRECOGNIZED = object()
    # Diagnostics go through logging (debug level) instead of stdout.
    _logger = logging.getLogger(__name__)

    def __init__(self):
        # Existing regex patterns (retain if needed for other formats)
        self.complete_pattern = re.compile(r'<\|python_tag\|>(.*?)<\|eom_id\|>', re.DOTALL)
        self.partial_pattern = re.compile(r'(.*?)<\|eom_id\|>', re.DOTALL)

    def _to_python(self, value) -> Any:
        """Convert one argument value to a native Python value.

        Returns ``self._UNRECOGNIZED`` when the value has no recognizable
        shape, so that callers can tell "unknown" apart from ``None``.
        """
        # Already-native scalars (mapping-style args carry these directly).
        if value is None or isinstance(value, (str, bytes, bool, int, float)):
            return value
        # Nested dict-like containers.
        if hasattr(value, 'items'):
            return {k: self._native_or_raw(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [self._native_or_raw(v) for v in value]

        # Message-style values: every oneof member exists as an attribute,
        # so hasattr() cannot tell which one is set; WhichOneof can.
        which_oneof = getattr(value, 'WhichOneof', None)
        if callable(which_oneof):
            try:
                kind = which_oneof('kind')
            except ValueError:
                # No oneof called 'kind': not a Value-like message, fall
                # through to attribute probing below.
                pass
            else:
                if kind in self._SCALAR_KINDS:
                    return getattr(value, kind)
                if kind == 'null_value':
                    return None
                if kind == 'struct_value':
                    return self._extract_function_args(value.struct_value)
                if kind == 'list_value':
                    return [self._native_or_raw(v) for v in value.list_value.values]
                return self._UNRECOGNIZED  # nothing set

        # Fallback for simple objects without WhichOneof: probe attributes
        # in the original order (string, number, bool).
        if hasattr(value, 'string_value'):
            return value.string_value
        if hasattr(value, 'number_value'):
            return value.number_value
        if hasattr(value, 'bool_value') and value.bool_value is not None:
            return value.bool_value
        return self._UNRECOGNIZED

    def _native_or_raw(self, value) -> Any:
        """Convert ``value`` if possible, otherwise return it unchanged."""
        converted = self._to_python(value)
        return value if converted is self._UNRECOGNIZED else converted

    def _extract_function_args(self, args) -> Dict[str, Any]:
        """
        Flatten the nested function args structure for Google AI protobuf types.

        Accepted shapes (duck-typed):

        * an object with a ``fields`` attribute that is either a mapping of
          key -> value or an iterable of entries exposing ``.key``/``.value``;
        * a dict-like object exposing ``.items()``.

        Args:
            args: The function-call arguments object, or None.
        Returns:
            dict: Argument names mapped to native Python values. Extraction is
            best-effort: on an unexpected error the arguments collected so far
            are returned and the error is logged.
        """
        flattened_args: Dict[str, Any] = {}
        if args is None:
            return flattened_args
        try:
            fields = getattr(args, 'fields', None)
            if fields is not None:
                mapping_style = False
                if hasattr(fields, 'items'):
                    pairs = fields.items()
                else:
                    # Lazy, so entries before a malformed one are still kept.
                    pairs = ((field.key, field.value) for field in fields)
            elif hasattr(args, 'items'):
                mapping_style = True
                pairs = args.items()
            else:
                return flattened_args

            for key, value in pairs:
                converted = self._to_python(value)
                if converted is self._UNRECOGNIZED:
                    if not mapping_style:
                        # Field-style value with nothing set: skip it.
                        continue
                    # Dict-like args: keep unknown values rather than lose them.
                    converted = value
                flattened_args[key] = converted
            self._logger.debug("Final flattened args: %r", flattened_args)
        except Exception as e:  # deliberately broad: extraction is best-effort
            self._logger.warning("Error extracting function args: %s", e)
        return flattened_args

    def extract_tool_calls(self, input_string: str) -> List[Dict[str, Any]]:
        """
        Extract tool calls from input string, handling various inconsistent formats.

        Resolution order:

        1. payloads between ``<|python_tag|>`` and ``<|eom_id|>``;
        2. any text preceding ``<|eom_id|>``;
        3. the whole string; if that yields nothing and the string contains
           ``<|python_tag|>`` without a closing marker, the text after the tag.

        Args:
            input_string (str): The input string containing tool calls.
        Returns:
            list: A list of dictionaries representing the parsed tool calls.
            Empty or non-string input yields an empty list.
        """
        tool_calls: List[Dict[str, Any]] = []
        if not isinstance(input_string, str) or not input_string:
            return tool_calls

        # Existing tag-based extraction (retain if needed)
        complete_matches = self.complete_pattern.findall(input_string)
        if complete_matches:
            for match in complete_matches:
                tool_calls.extend(self._extract_json_objects(match))
            return tool_calls

        partial_matches = self.partial_pattern.findall(input_string)
        if partial_matches:
            for match in partial_matches:
                tool_calls.extend(self._extract_json_objects(match))
            return tool_calls

        # Fallback: Attempt to parse the entire string
        tool_calls.extend(self._extract_json_objects(input_string))
        if not tool_calls and self._PYTHON_TAG in input_string:
            # Opening tag without a closing marker (e.g. truncated output):
            # retry on the payload following the tag.
            _, _, payload = input_string.partition(self._PYTHON_TAG)
            tool_calls.extend(self._extract_json_objects(payload))
        return tool_calls

    def _extract_json_objects(self, text: str) -> List[Dict[str, Any]]:
        """
        Extract and parse multiple JSON objects from a string.

        Objects are separated by ``;``. The text is scanned with a JSON
        decoder rather than split on ``;`` so that semicolons inside JSON
        string values do not break an object apart. A segment is accepted
        only if it starts with ``{`` or ``[`` and nothing but whitespace
        follows the JSON value before the next ``;``. A top-level JSON array
        contributes each of its dict elements. Empty dicts are dropped.
        """
        json_objects: List[Dict[str, Any]] = []
        if not isinstance(text, str) or not text:
            return json_objects

        decoder = json.JSONDecoder()
        pos, end = 0, len(text)
        while pos < end:
            if text[pos] == ';' or text[pos].isspace():
                pos += 1
                continue

            resume_from = pos
            if text[pos] in '{[':
                try:
                    parsed, stop = decoder.raw_decode(text, pos)
                except ValueError:  # includes json.JSONDecodeError
                    pass
                else:
                    boundary = text.find(';', stop)
                    if boundary == -1:
                        boundary = end
                    if not text[stop:boundary].strip():
                        if isinstance(parsed, dict):
                            candidates = [parsed]
                        elif isinstance(parsed, list):
                            candidates = parsed
                        else:
                            candidates = []
                        json_objects.extend(
                            item for item in candidates
                            if isinstance(item, dict) and item
                        )
                        pos = boundary
                        continue
                    # Trailing garbage after the value: reject this segment
                    # and look for the separator after the decoded value so
                    # a ';' inside its strings is not mistaken for one.
                    resume_from = stop

            next_separator = text.find(';', resume_from)
            if next_separator == -1:
                break
            pos = next_separator + 1
        return json_objects

    def _clean_and_parse_json(self, json_str: str) -> Optional[Any]:
        """
        Clean and parse a JSON string, handling common formatting issues.

        Returns the decoded value (a dict or a list) when the stripped text
        starts with ``{`` or ``[`` and is valid JSON; otherwise None.
        """
        try:
            json_str = json_str.strip()
            if json_str.startswith(('{', '[')):
                return json.loads(json_str)
            return None
        except json.JSONDecodeError:
            return None

    def validate_tool_call(self, tool_call: Dict[str, Any]) -> bool:
        """
        Validate if a tool call has the required fields.

        A valid tool call is a dict with a string ``'name'`` entry.
        """
        return (
            isinstance(tool_call, dict) and
            'name' in tool_call and
            isinstance(tool_call['name'], str)
        )

    def extract_function_call(self, response_parts: List[Any]) -> Dict[str, Any]:
        """
        Extract function call details from the response parts.

        The first part with a truthy ``function_call`` that has a non-empty
        ``name`` wins.

        Args:
            response_parts (list): The list of response parts from the chat model.
        Returns:
            dict: A dictionary containing the function name and flattened arguments,
            or an empty dict when no part carries a named function call.
        """
        for part in response_parts or ():
            self._logger.debug("Examining part (%s): %r", type(part), part)

            function_call = getattr(part, 'function_call', None)
            if not function_call:
                continue

            function_name = getattr(function_call, 'name', None)
            if not function_name:
                continue  # Skip if function name is missing

            raw_args = getattr(function_call, 'args', None)
            self._logger.debug("Function call %r with args: %r", function_name, raw_args)
            return {
                "name": function_name,
                "args": self._extract_function_args(raw_args)
            }
        return {}
|