Spaces: Build error
import json
import logging
import re
from collections.abc import Mapping, Sequence
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)


class ToolCallExtractor:
    """
    Extract tool/function calls from model output.

    Two input styles are handled:

    * Text completions containing JSON tool calls, optionally wrapped in
      ``<|python_tag|> ... <|eom_id|>`` markers (see ``extract_tool_calls``).
    * Structured response parts exposing ``function_call.name`` and
      ``function_call.args`` (see ``extract_function_call``).
    """

    # Matches a possibly-empty run of whitespace starting at a given index.
    _WHITESPACE = re.compile(r'\s*')
    # Stateless decoder reused for incremental parsing with raw_decode().
    _JSON_DECODER = json.JSONDecoder()

    def __init__(self):
        # Payload between <|python_tag|> and <|eom_id|>.
        self.complete_pattern = re.compile(r'<\|python_tag\|>(.*?)<\|eom_id\|>', re.DOTALL)
        # Text terminated by <|eom_id|>; only consulted when complete_pattern
        # finds nothing.
        self.partial_pattern = re.compile(r'(.*?)<\|eom_id\|>', re.DOTALL)

    @staticmethod
    def _iter_arg_items(args) -> List[Any]:
        """
        Return the ``(key, value)`` pairs of a Struct-like object or a mapping.

        Objects exposing ``fields`` are treated as protobuf ``Struct``-like:
        ``fields`` may be a map (iterating a protobuf map yields only the keys,
        so ``items()`` is used) or an iterable of entries that have ``key`` and
        ``value`` attributes. Any other object yields no pairs.
        """
        if hasattr(args, 'fields'):
            fields = args.fields
            if hasattr(fields, 'items'):
                return list(fields.items())
            return [(entry.key, entry.value) for entry in fields]
        if isinstance(args, Mapping):
            return list(args.items())
        return []

    @classmethod
    def _convert_arg_value(cls, value) -> Any:
        """
        Convert one argument value to a plain Python value.

        Handles native scalars, nested Struct-likes / mappings, protobuf
        ``Value``-like messages (dispatching on ``WhichOneof('kind')``) and
        sequences. Other objects fall back to probing their ``string_value`` /
        ``number_value`` / ``bool_value`` attributes, else are returned as-is.
        """
        if value is None or isinstance(value, (str, bytes, bool, int, float)):
            return value
        if hasattr(value, 'fields') or isinstance(value, Mapping):
            return {key: cls._convert_arg_value(item)
                    for key, item in cls._iter_arg_items(value)}
        which_oneof = getattr(value, 'WhichOneof', None)
        if callable(which_oneof):
            # On a protobuf Value every oneof member is an attribute holding
            # its default, so hasattr() cannot tell which one is actually set.
            kind = which_oneof('kind')
            if kind is None or kind == 'null_value':
                return None
            if kind == 'struct_value':
                return cls._convert_arg_value(value.struct_value)
            if kind == 'list_value':
                return [cls._convert_arg_value(item) for item in value.list_value.values]
            return getattr(value, kind)
        if isinstance(value, Sequence):
            return [cls._convert_arg_value(item) for item in value]
        for attr in ('string_value', 'number_value', 'bool_value'):
            if hasattr(value, attr):
                return getattr(value, attr)
        return value

    def _extract_function_args(self, args) -> Dict[str, Any]:
        """
        Flatten function-call arguments into a plain ``dict``.

        Args:
            args: ``None``, a protobuf ``Struct``-like object exposing
                ``fields``, or any mapping (e.g. ``dict``). Nested structs,
                lists and protobuf ``Value`` wrappers are converted recursively.

        Returns:
            dict: Argument name -> plain Python value. Conversion is
            best-effort: on an unexpected error the problem is logged and the
            arguments converted so far are returned.
        """
        flattened_args: Dict[str, Any] = {}
        if args is None:
            return flattened_args
        try:
            for key, value in self._iter_arg_items(args):
                flattened_args[key] = self._convert_arg_value(value)
                logger.debug("Field %r -> %r", key, flattened_args[key])
        except Exception:
            # Deliberately best-effort (as before): malformed args must not
            # crash the caller, but the failure is logged with a traceback.
            logger.exception("Error extracting function args")
        logger.debug("Final flattened args: %s", flattened_args)
        return flattened_args

    def extract_tool_calls(self, input_string: str) -> List[Dict[str, Any]]:
        """
        Extract tool calls from model text output.

        The first strategy that finds any span wins:

        1. every ``<|python_tag|> ... <|eom_id|>`` payload;
        2. every span terminated by ``<|eom_id|>``;
        3. the whole string.

        Args:
            input_string (str): Raw model output; ``None`` or ``""`` yields ``[]``.

        Returns:
            list: JSON values parsed from the selected span(s); top-level JSON
            arrays are flattened into their elements.
        """
        if not input_string:
            return []
        segments = self.complete_pattern.findall(input_string)
        if not segments:
            segments = self.partial_pattern.findall(input_string)
        if not segments:
            # Fallback: attempt to parse the entire string.
            segments = [input_string]
        tool_calls: List[Dict[str, Any]] = []
        for segment in segments:
            tool_calls.extend(self._extract_json_objects(segment))
        return tool_calls

    def _parse_json_segment(self, text: str, pos: int):
        """
        Parse the ``;``-terminated segment of *text* that starts at *pos*.

        Returns:
            tuple: ``(value, segment_end)`` where *value* is the decoded JSON
            object/array, or ``None`` if the segment is not exactly one such
            value (surrounding whitespace ignored), and *segment_end* is the
            index of the terminating ``;`` or ``len(text)``.
        """
        length = len(text)
        start = self._WHITESPACE.match(text, pos).end()
        resume = start
        if start < length and text[start] in '{[':
            try:
                value, after = self._JSON_DECODER.raw_decode(text, start)
            except json.JSONDecodeError:
                pass
            else:
                after = self._WHITESPACE.match(text, after).end()
                if after == length or text[after] == ';':
                    return value, after
                # Trailing non-JSON text: the segment as a whole is invalid.
                resume = after
        semicolon = text.find(';', resume)
        return None, (length if semicolon == -1 else semicolon)

    def _extract_json_objects(self, text: str) -> List[Dict[str, Any]]:
        """
        Extract and parse the ``;``-separated JSON values in *text*.

        Unlike a naive ``text.split(';')``, a ``;`` inside a JSON string
        (e.g. ``{"q": "a;b"}``) does not split the value. Segments that are
        not exactly one JSON object/array are skipped, empty ``{}`` / ``[]``
        are dropped, and a top-level array contributes its non-empty elements.
        """
        json_objects: List[Dict[str, Any]] = []
        length = len(text)
        pos = 0
        while True:
            value, segment_end = self._parse_json_segment(text, pos)
            if isinstance(value, list):
                json_objects.extend(item for item in value if item)
            elif value:
                json_objects.append(value)
            if segment_end >= length:
                return json_objects
            pos = segment_end + 1  # step over the ';'

    def _clean_and_parse_json(self, json_str: str) -> Optional[Dict[str, Any]]:
        """
        Parse *json_str* if, once stripped, it starts with ``{`` or ``[``.

        Returns the decoded value, or ``None`` when the text does not look like
        a JSON object/array or fails to decode.
        """
        try:
            json_str = json_str.strip()
            if json_str.startswith(('{', '[')):
                return json.loads(json_str)
            return None
        except json.JSONDecodeError:
            return None

    def validate_tool_call(self, tool_call: Dict[str, Any]) -> bool:
        """
        Return True if *tool_call* is a dict with a string ``name`` field.
        """
        return (
            isinstance(tool_call, dict) and
            'name' in tool_call and
            isinstance(tool_call['name'], str)
        )

    def extract_function_call(self, response_parts: List[Any]) -> Dict[str, Any]:
        """
        Return the first usable function call found in *response_parts*.

        Args:
            response_parts (list): Response parts from the chat model; ``None``
                is treated as empty.

        Returns:
            dict: ``{"name": ..., "args": {...}}`` for the first part whose
            ``function_call`` is truthy and has a non-empty ``name``, or ``{}``
            when there is none. A missing ``args`` attribute yields ``{}`` args.
        """
        for part in response_parts or ():
            logger.debug("Examining part of type %s: %s", type(part), part)
            function_call = getattr(part, 'function_call', None)
            if not function_call:
                continue
            function_name = getattr(function_call, 'name', None)
            # Read args defensively: the old debug print raised when absent.
            function_args = getattr(function_call, 'args', None)
            logger.debug("Function call %r with args %s", function_name, function_args)
            if not function_name:
                continue  # Skip if function name is missing
            return {
                "name": function_name,
                "args": self._extract_function_args(function_args),
            }
        return {}