File size: 3,728 Bytes
447ebeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import ast
import os
import re


def find_azure_files(base_dir):
    """
    Find all Python files in the Azure directory.
    """
    azure_files = []
    for root, _, files in os.walk(base_dir):
        for file in files:
            if file.endswith(".py"):
                azure_files.append(os.path.join(root, file))
    return azure_files


def check_direct_instantiation(file_path):
    """
    Check if a file directly instantiates AzureOpenAI or AsyncAzureOpenAI
    outside of the BaseAzureLLM class methods.
    """
    with open(file_path, "r") as file:
        content = file.read()

    # Parse the file
    tree = ast.parse(content)

    # Track issues found
    issues = []

    # Find all class definitions
    for node in ast.walk(tree):
        if isinstance(node, ast.ClassDef):
            class_name = node.name

            # Skip BaseAzureLLM class since it's allowed to define the client creation methods
            if class_name == "BaseAzureLLM":
                continue

            # Check method bodies for direct instantiation
            for method in node.body:
                if isinstance(method, ast.FunctionDef) or isinstance(
                    method, ast.AsyncFunctionDef
                ):
                    method_name = method.name

                    # Skip methods that are specifically for client creation
                    if method_name in [
                        "get_azure_openai_client",
                        "initialize_azure_sdk_client",
                    ]:
                        continue

                    # Look for direct instantiation in the method body
                    for subnode in ast.walk(method):
                        if isinstance(subnode, ast.Call):
                            if hasattr(subnode, "func") and hasattr(subnode.func, "id"):
                                if subnode.func.id in [
                                    "AzureOpenAI",
                                    "AsyncAzureOpenAI",
                                ]:
                                    issues.append(
                                        f"Direct instantiation of {subnode.func.id} in {class_name}.{method_name}"
                                    )
                            elif hasattr(subnode, "func") and hasattr(
                                subnode.func, "attr"
                            ):
                                if subnode.func.attr in [
                                    "AzureOpenAI",
                                    "AsyncAzureOpenAI",
                                ]:
                                    issues.append(
                                        f"Direct instantiation of {subnode.func.attr} in {class_name}.{method_name}"
                                    )

    return issues


def main():
    """
    Main function to run the test.
    """
    # local
    base_dir = "../../litellm/llms/azure"
    azure_files = find_azure_files(base_dir)
    print(f"Found {len(azure_files)} Azure Python files to check")

    all_issues = []

    for file_path in azure_files:
        issues = check_direct_instantiation(file_path)
        if issues:
            all_issues.extend([f"{file_path}: {issue}" for issue in issues])

    if all_issues:
        print("Found direct instantiations of AzureOpenAI or AsyncAzureOpenAI:")
        for issue in all_issues:
            print(f"  - {issue}")
        raise Exception(
            f"Found {len(all_issues)} direct instantiations of AzureOpenAI or AsyncAzureOpenAI classes. Use get_azure_openai_client instead."
        )
    else:
        print("All Azure modules are correctly using get_azure_openai_client!")


if __name__ == "__main__":
    main()