File size: 4,551 Bytes
447ebeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""
Static check: every cache call made inside an ``async def`` in
litellm/router_strategy/ must use the async variant of the cache method
(its name must contain "async").
"""

import os
import sys
from typing import Dict, List, Tuple
import ast

# NOTE(review): "../.." is resolved against the current working directory,
# not against this file's location, so the path inserted here depends on
# where the test is launched from. Nothing visible in this file imports
# project modules, so it has no observable effect on the check below.
sys.path.insert(
    0, os.path.abspath("../..")
)  # Prepends the directory two levels above the CWD to sys.path
import os


class AsyncCacheCallVisitor(ast.NodeVisitor):
    """Collect cache-related method calls made inside ``async def`` functions.

    After ``visit(tree)``, ``async_functions`` maps each async function name to
    a list of ``(dotted_call, line_number)`` tuples. There is one tuple per
    attribute call whose method name contains ``"cache"`` (case-insensitive),
    e.g. ``("self.router_cache.async_get_cache", 42)``.

    Scoping rules:
      * Calls are attributed to the *innermost* enclosing ``async def``. After
        a nested async function is visited, attribution returns to the outer
        one.
      * Calls inside a synchronous ``def`` are never attributed to an async
        function, even when that ``def`` is nested in one. A sync function
        cannot ``await``, so a sync cache call there is not a violation.
      * Async functions that share a name (e.g. the same method name in two
        classes of one module) share one entry. Their calls are accumulated,
        so the later definition does not discard the earlier one's calls.
    """

    def __init__(self):
        self.async_functions: Dict[str, List[Tuple[str, int]]] = {}
        # Name of the innermost enclosing ``async def``. It is None at module
        # level and inside synchronous functions.
        self.current_function = None

    def visit_AsyncFunctionDef(self, node):
        """Visit an async function and record the cache calls made inside it."""
        self.async_functions.setdefault(node.name, [])
        self._visit_scoped(node, node.name)

    def visit_FunctionDef(self, node):
        """Visit a sync function without attributing its calls to any async function."""
        self._visit_scoped(node, None)

    def _visit_scoped(self, node, function_name):
        """Visit ``node``'s children with ``current_function`` set, then restore it."""
        previous = self.current_function
        self.current_function = function_name
        try:
            self.generic_visit(node)
        finally:
            # Restore rather than reset to None, so calls that follow a nested
            # function are still attributed to the enclosing async function.
            self.current_function = previous

    def visit_Call(self, node):
        """Record ``<receiver>.<name containing 'cache'>(...)`` calls made in an async function."""
        if self.current_function is not None and isinstance(node.func, ast.Attribute):
            if "cache" in node.func.attr.lower():
                self.async_functions[self.current_function].append(
                    (self._dotted_name(node.func), node.lineno)
                )
        self.generic_visit(node)

    @staticmethod
    def _dotted_name(func):
        """Return the dotted path of an attribute chain, e.g. ``self.router_cache.get``.

        Only ``Attribute`` links and a root ``Name`` are included. If the root
        is something else (a call, a subscript, ...), it is omitted and only the
        attribute names are joined.
        """
        parts = []
        current = func
        while isinstance(current, ast.Attribute):
            parts.append(current.attr)
            current = current.value
        if isinstance(current, ast.Name):
            parts.append(current.id)
        return ".".join(reversed(parts))


def get_python_files(directory: str) -> List[str]:
    """Return the Python modules directly inside ``directory``, sorted by name.

    The scan is not recursive. Dunder files such as ``__init__.py`` are
    skipped, as is any entry that is not a regular file (e.g. a directory
    whose name happens to end in ``.py``, which ``open()`` could not read).

    Args:
        directory: Directory to scan (the router_strategy package in this test).

    Returns:
        Paths built with ``os.path.join(directory, name)``, in sorted order.
        ``os.listdir`` returns entries in arbitrary order, so sorting keeps the
        analysis and its printed output deterministic.
    """
    return [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if name.endswith(".py")
        and not name.startswith("__")
        and os.path.isfile(os.path.join(directory, name))
    ]


def analyze_file(file_path: str) -> Dict[str, List[Tuple[str, int]]]:
    """Parse a Python source file and collect the cache calls made in its async functions.

    Args:
        file_path: Path to the ``.py`` file to analyze.

    Returns:
        Mapping of async function name to ``(dotted_call, line_number)`` tuples,
        as collected by ``AsyncCacheCallVisitor``.

    Raises:
        SyntaxError: If the file is not valid Python. The error names the file.
    """
    # Read raw bytes so ast.parse applies Python's own source-encoding rules
    # (UTF-8 by default, honoring any PEP 263 coding cookie or BOM) instead of
    # the platform's default text encoding.
    with open(file_path, "rb") as file:
        source = file.read()
    tree = ast.parse(source, filename=file_path)

    visitor = AsyncCacheCallVisitor()
    visitor.visit(tree)
    return visitor.async_functions


def test_router_strategy_async_cache_calls():
    """Check that async functions in litellm/router_strategy only call async cache methods.

    A recorded cache call is compliant when its *method name* contains
    ``"async"``. The method name is the last segment of the dotted call, e.g.
    ``async_get_cache`` in ``self.router_cache.async_get_cache``. Every
    violation across all files is collected and reported in a single
    assertion failure. The test also fails if no async functions are found.
    """
    # Repository root = three dirname() steps up from this file's absolute path.
    strategy_dir = os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
        "litellm",
        "router_strategy",
    )

    # Get all Python files in the router_strategy directory
    python_files = get_python_files(strategy_dir)

    print("python files:", python_files)

    all_async_functions: Dict[str, Dict[str, List[Tuple[str, int]]]] = {}
    violations: List[str] = []

    for file_path in python_files:
        file_name = os.path.basename(file_path)
        async_functions = analyze_file(file_path)
        if not async_functions:
            continue

        all_async_functions[file_name] = async_functions
        print(f"\nAnalyzing {file_name}:")

        for func_name, cache_calls in async_functions.items():
            print(f"\nAsync function: {func_name}")
            print(f"Cache calls found: {cache_calls}")

            # Every recorded call already has "cache" in its method name, so
            # only the async requirement needs checking here.
            for call, line_number in cache_calls:
                # Judge only the method actually invoked. "async" in the
                # receiver path (e.g. self.async_cache.get_cache) does not make
                # the call itself async.
                method_name = call.rsplit(".", 1)[-1]
                if "async" not in method_name.lower():
                    violations.append(
                        f"VIOLATION: Cache call '{call}' in async function '{func_name}' should be async. file path: {file_path}, line number: {line_number}"
                    )

    # Fail once, listing every offending call so they can all be fixed in one pass.
    assert not violations, "\n".join(violations)

    # Assert we found async functions to analyze
    assert (
        len(all_async_functions) > 0
    ), "No async functions found in router_strategy directory"


if __name__ == "__main__":
    # Allow running the check directly as a script, without pytest.
    test_router_strategy_async_cache_calls()