File size: 5,716 Bytes
8d2f9d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import json
import logging
import re
from collections.abc import Iterable, Mapping
from typing import List, Dict, Any, Optional

class ToolCallExtractor:
    """Parse tool/function calls out of model output.

    Two input shapes are supported:

    * Text completions in which JSON tool calls are framed by the
      ``<|python_tag|>`` ... ``<|eom_id|>`` special tokens and may be separated
      by ``;`` -- see :meth:`extract_tool_calls`.
    * Structured response parts exposing a ``function_call`` attribute with
      ``name`` and ``args`` -- see :meth:`extract_function_call`.
    """

    # Logger used instead of print() so diagnostics can be filtered/silenced.
    _logger = logging.getLogger(__name__)

    def __init__(self):
        # Tool calls framed by <|python_tag|> ... <|eom_id|>.
        self.complete_pattern = re.compile(r'<\|python_tag\|>(.*?)<\|eom_id\|>', re.DOTALL)
        # Tool calls terminated by <|eom_id|> but missing the opening tag.
        self.partial_pattern = re.compile(r'(.*?)<\|eom_id\|>', re.DOTALL)
        # Any <|...|> special token. Used by the fallback path so a stray
        # leading <|python_tag|> or trailing <|eot_id|> acts as a separator
        # instead of hiding the JSON next to it.
        self.special_token_pattern = re.compile(r'<\|[A-Za-z0-9_]+\|>')

    def _extract_function_args(self, args) -> Dict[str, Any]:
        """Convert a function-call ``args`` payload into a plain ``dict``.

        Accepted payloads:

        * a mapping (``dict`` or any dict-like wrapper),
        * a protobuf-``Struct``-like object whose ``fields`` is a map,
        * an object whose ``fields`` is an iterable of ``key``/``value`` entries.

        Nested values (strings, numbers, booleans, nulls, nested structs and
        lists) are converted recursively by :meth:`_proto_to_python`.

        Args:
            args: The ``args`` attribute of a function call; may be ``None``.

        Returns:
            A new ``dict`` mapping argument names to plain Python values, or an
            empty ``dict`` when ``args`` is ``None``, cannot be converted, or
            does not convert to a mapping. Failures are logged, never raised.
        """
        if args is None:
            return {}
        try:
            converted = self._proto_to_python(args)
        except Exception:  # best-effort boundary: unknown payloads must not crash callers
            self._logger.exception("Error extracting function args from %r", args)
            return {}
        if not isinstance(converted, dict):
            self._logger.warning("Function args did not convert to a mapping: %r", args)
            return {}
        self._logger.debug("Flattened function args: %s", converted)
        return converted

    @classmethod
    def _proto_to_python(cls, value: Any) -> Any:
        """Recursively convert a (protobuf-like) value into plain Python data.

        Probes, in order: plain scalars / ``None``; mappings; ``Struct``-like
        objects with a ``fields`` attribute; ``Value``-like messages with a
        ``kind`` oneof (dispatched through ``WhichOneof``); ``ListValue``-like
        objects with a ``values`` attribute; other iterables; and finally
        objects exposing ``string_value`` / ``number_value`` / ``bool_value``.
        Anything else is returned unchanged.
        """
        if value is None or isinstance(value, (str, bytes, bool, int, float)):
            return value
        if isinstance(value, Mapping):
            return {key: cls._proto_to_python(item) for key, item in value.items()}

        fields = getattr(value, 'fields', None)
        if fields is not None and not callable(fields):
            # Protobuf maps iterate over their keys, so prefer .items();
            # otherwise assume an iterable of entries with .key / .value.
            if hasattr(fields, 'items'):
                pairs = fields.items()
            else:
                pairs = ((entry.key, entry.value) for entry in fields)
            return {key: cls._proto_to_python(item) for key, item in pairs}

        which_oneof = getattr(value, 'WhichOneof', None)
        if callable(which_oneof):
            try:
                kind = which_oneof('kind')
            except ValueError:
                # Message has no "kind" oneof, so it is not a Value; keep probing.
                pass
            else:
                if kind is None or kind == 'null_value':
                    return None
                return cls._proto_to_python(getattr(value, kind))

        values = getattr(value, 'values', None)
        if values is not None and not callable(values):
            return [cls._proto_to_python(item) for item in values]
        if isinstance(value, Iterable):
            return [cls._proto_to_python(item) for item in value]

        # Duck-typed fallback that mirrors the attribute probing order used by
        # earlier versions of this class.
        for attr in ('string_value', 'number_value', 'bool_value'):
            if hasattr(value, attr):
                return getattr(value, attr)
        return value

    def extract_tool_calls(self, input_string: str) -> List[Dict[str, Any]]:
        """Extract tool calls from a model completion.

        Extraction strategy, first hit wins:

        1. every ``<|python_tag|>...<|eom_id|>`` span,
        2. every ``...<|eom_id|>`` span,
        3. the whole string, with any ``<|...|>`` special token treated as a
           ``;`` separator.

        Args:
            input_string: The raw completion text.

        Returns:
            The parsed JSON tool-call objects (see :meth:`_extract_json_objects`);
            an empty list for empty input or when nothing parses.
        """
        if not input_string:
            return []

        for pattern in (self.complete_pattern, self.partial_pattern):
            matches = pattern.findall(input_string)
            if matches:
                tool_calls: List[Dict[str, Any]] = []
                for match in matches:
                    tool_calls.extend(self._extract_json_objects(match))
                return tool_calls

        # Fallback: no <|eom_id|> framing. Neutralise special tokens so they
        # separate calls rather than prefixing (and hiding) the JSON.
        return self._extract_json_objects(self.special_token_pattern.sub(';', input_string))

    def _extract_json_objects(self, text: str) -> List[Dict[str, Any]]:
        """Parse every JSON tool call in a ``;``-separated string.

        Each candidate starting with ``{`` or ``[`` is decoded with
        :meth:`json.JSONDecoder.raw_decode`, which consumes exactly one JSON
        value, so a ``;`` inside a JSON string does not split the object.
        A segment that does not start with ``{``/``[`` or fails to decode is
        skipped up to the next ``;``. A decoded list contributes its non-empty
        ``dict`` elements; empty objects are dropped.
        """
        decoder = json.JSONDecoder()
        objects: List[Dict[str, Any]] = []
        pos, end = 0, len(text)
        while pos < end:
            # Skip separators and whitespace between calls.
            while pos < end and (text[pos] == ';' or text[pos].isspace()):
                pos += 1
            if pos >= end:
                break
            if text[pos] in '{[':
                try:
                    parsed, pos_after = decoder.raw_decode(text, pos)
                except json.JSONDecodeError:
                    pass
                else:
                    candidates = parsed if isinstance(parsed, list) else [parsed]
                    objects.extend(c for c in candidates if isinstance(c, dict) and c)
                    pos = pos_after
                    continue
            # Unparseable segment: drop it up to the next separator.
            next_sep = text.find(';', pos)
            if next_sep == -1:
                break
            pos = next_sep + 1
        return objects

    def _clean_and_parse_json(self, json_str: str) -> Optional[Any]:
        """Parse ``json_str`` after stripping surrounding whitespace.

        Returns:
            The decoded value (a ``dict`` or a ``list``) when the stripped text
            starts with ``{`` or ``[`` and is valid JSON; otherwise ``None``.
        """
        try:
            json_str = json_str.strip()
            if json_str.startswith('{') or json_str.startswith('['):
                return json.loads(json_str)
            return None
        except json.JSONDecodeError:
            return None

    def validate_tool_call(self, tool_call: Dict[str, Any]) -> bool:
        """Return True if ``tool_call`` is a dict with a string ``name`` field."""
        return (
            isinstance(tool_call, dict) and
            'name' in tool_call and
            isinstance(tool_call['name'], str)
        )

    def extract_function_call(self, response_parts: List[Any]) -> Dict[str, Any]:
        """Return the first named function call found in ``response_parts``.

        Args:
            response_parts: Response parts from the chat model. Parts without a
                truthy ``function_call`` attribute, or whose call has no
                ``name``, are skipped. ``None`` is treated as empty.

        Returns:
            ``{"name": <name>, "args": <dict>}`` for the first usable call
            (``args`` flattened by :meth:`_extract_function_args`), or ``{}``
            when no part carries one.
        """
        for part in response_parts or ():
            function_call = getattr(part, 'function_call', None)
            if not function_call:
                continue
            function_name = getattr(function_call, 'name', None)
            if not function_name:
                continue  # A call without a name cannot be dispatched.
            self._logger.debug(
                "Found function call %r in part of type %s",
                function_name, type(part).__name__,
            )
            return {
                "name": function_name,
                "args": self._extract_function_args(getattr(function_call, 'args', None)),
            }
        return {}