File size: 4,322 Bytes
447ebeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import ast
import os

IGNORE_FUNCTIONS = [
    "_format_type",
    "_remove_additional_properties",
    "_remove_strict_from_schema",
    "filter_schema_fields",
    "text_completion",
    "_check_for_os_environ_vars",
    "clean_message",
    "unpack_defs",
    "convert_anyof_null_to_nullable", # has a set max depth
    "add_object_type",
    "strip_field",
    "_transform_prompt",
    "mask_dict",
    "_serialize",  # we now set a max depth for this
    "_sanitize_request_body_for_spend_logs_payload", # testing added for circular reference
    "_sanitize_value", # testing added for circular reference
    "set_schema_property_ordering", # testing added for infinite recursion
    "process_items", # testing added for infinite recursion + max depth set.
    "_can_object_call_model", # max depth set.
    "encode_unserializable_types", # max depth set.
    "filter_value_from_dict", # max depth set.
]


class RecursiveFunctionFinder(ast.NodeVisitor):
    def __init__(self):
        self.recursive_functions = []
        self.ignored_recursive_functions = []

    def visit_FunctionDef(self, node):
        # Check if the function calls itself
        if any(self._is_recursive_call(node, call) for call in ast.walk(node)):
            if node.name in IGNORE_FUNCTIONS:
                self.ignored_recursive_functions.append(node.name)
            else:
                self.recursive_functions.append(node.name)
        self.generic_visit(node)

    def _is_recursive_call(self, func_node, call_node):
        # Check if the call node is a function call
        if not isinstance(call_node, ast.Call):
            return False

        # Case 1: Direct function call (e.g., my_func())
        if isinstance(call_node.func, ast.Name) and call_node.func.id == func_node.name:
            return True

        # Case 2: Method call with self (e.g., self.my_func())
        if isinstance(call_node.func, ast.Attribute) and isinstance(
            call_node.func.value, ast.Name
        ):
            return (
                call_node.func.value.id == "self"
                and call_node.func.attr == func_node.name
            )

        return False


def find_recursive_functions_in_file(file_path):
    with open(file_path, "r") as file:
        tree = ast.parse(file.read(), filename=file_path)
    finder = RecursiveFunctionFinder()
    finder.visit(tree)
    return finder.recursive_functions, finder.ignored_recursive_functions


def find_recursive_functions_in_directory(directory):
    recursive_functions = {}
    ignored_recursive_functions = {}
    for root, _, files in os.walk(directory):
        for file in files:
            print("file: ", file)
            if file.endswith(".py"):
                file_path = os.path.join(root, file)
                functions, ignored = find_recursive_functions_in_file(file_path)
                if functions:
                    recursive_functions[file_path] = functions
                if ignored:
                    ignored_recursive_functions[file_path] = ignored
    return recursive_functions, ignored_recursive_functions

if __name__ == "__main__":
    # Example usage
    # raise exception if any recursive functions are found, except for the ignored ones
    # this is used in the CI/CD pipeline to prevent recursive functions from being merged

    directory_path = "./litellm"
    recursive_functions, ignored_recursive_functions = (
        find_recursive_functions_in_directory(directory_path)
    )
    print("UNIGNORED RECURSIVE FUNCTIONS: ", recursive_functions)
    print("IGNORED RECURSIVE FUNCTIONS: ", ignored_recursive_functions)

    if len(recursive_functions) > 0:
        # raise exception if any recursive functions are found
        for file, functions in recursive_functions.items():
            print(
                    f"🚨 Unignored recursive functions found in {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary."
                )
        file, functions = list(recursive_functions.items())[0]
        raise Exception(
                f"🚨 Unignored recursive functions found include {file}: {functions}. THIS IS REALLY BAD, it has caused CPU Usage spikes in the past. Only keep this if it's ABSOLUTELY necessary."
            )