File size: 5,990 Bytes
447ebeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import ast
from typing import List, Set, Dict, Optional
import sys


class ConfigChecker(ast.NodeVisitor):
    """AST visitor that collects provider-related patterns and reports issues.

    While visiting a tree it records:

    * every class definition together with its directly named base classes,
    * the names of the objects on which ``map_openai_params`` is called,
    * the constant keys assigned into ``optional_params`` inside an
      ``if custom_llm_provider == <constant>:`` body.

    Call :meth:`check_patterns` after visiting to turn the collected data
    into a list of error/warning messages.
    """

    def __init__(self):
        # Messages accumulated by check_patterns().
        self.errors: List[str] = []
        # Provider whose ``if`` body is currently being visited (None outside).
        self.current_provider_block: Optional[str] = None
        # provider -> keys assigned via ``optional_params[<key>] = ...``.
        self.param_assignments: Dict[str, Set[str]] = {}
        # Names ``x`` seen in calls of the form ``x.map_openai_params(...)``.
        self.map_openai_calls: Set[str] = set()
        # class name -> list of base-class names (plain names only).
        self.class_inheritance: Dict[str, List[str]] = {}

    def get_full_name(self, node):
        """Return the dotted name for a ``Name``/``Attribute`` chain.

        Returns ``None`` when ``node`` (or the innermost value of the
        attribute chain) is any other kind of node, e.g. a call.
        """
        if isinstance(node, ast.Name):
            return node.id
        elif isinstance(node, ast.Attribute):
            base = self.get_full_name(node.value)
            if base:
                return f"{base}.{node.attr}"
        return None

    def visit_ClassDef(self, node: ast.ClassDef):
        """Record the class and its bases, then descend into the class body.

        Only bases written as plain names are recorded; bases written as
        attribute accesses or other expressions are skipped.
        """
        bases = [base.id for base in node.bases if isinstance(base, ast.Name)]
        print(f"Found class {node.name} with bases {bases}")
        self.class_inheritance[node.name] = bases
        self.generic_visit(node)

    def visit_Call(self, node: ast.Call):
        """Record ``<name>.map_openai_params(...)`` calls, then descend.

        Only calls whose receiver is a plain name are recorded.
        """
        func = node.func
        if (
            isinstance(func, ast.Attribute)
            and func.attr == "map_openai_params"
            and isinstance(func.value, ast.Name)
        ):
            self.map_openai_calls.add(func.value.id)
        self.generic_visit(node)

    def visit_If(self, node: ast.If):
        """Track which provider block is active while visiting an ``if``.

        The provider applies only to the ``if`` body. The ``elif``/``else``
        part is visited with the enclosing provider restored, so assignments
        in an ``else`` branch are not attributed to the provider tested by
        the preceding condition.
        """
        provider = self._extract_provider_from_if(node)
        if not provider:
            self.generic_visit(node)
            return

        self.visit(node.test)

        enclosing = self.current_provider_block
        self.current_provider_block = provider
        try:
            for stmt in node.body:
                self.visit(stmt)
        finally:
            # Restore even if a visitor raises, so state never leaks.
            self.current_provider_block = enclosing

        # ``elif`` arrives here as a nested If and sets its own provider.
        for stmt in node.orelse:
            self.visit(stmt)

    def visit_Assign(self, node: ast.Assign):
        """Record ``optional_params[<constant>] = ...`` inside a provider block.

        Only single-target assignments whose target is a subscript of the
        plain name ``optional_params`` with a constant key are recorded.
        """
        provider = self.current_provider_block
        if provider and len(node.targets) == 1:
            target = node.targets[0]
            if (
                isinstance(target, ast.Subscript)
                and isinstance(target.value, ast.Name)
                and target.value.id == "optional_params"
                and isinstance(target.slice, ast.Constant)
            ):
                self.param_assignments.setdefault(provider, set()).add(
                    target.slice.value
                )
        self.generic_visit(node)

    def _extract_provider_from_if(self, node: ast.If) -> Optional[str]:
        """Return the constant compared against ``custom_llm_provider``.

        Matches only a test of the exact shape
        ``custom_llm_provider == <constant>`` (name on the left, a single
        ``==``). Any other test yields ``None``.
        """
        test = node.test
        if (
            isinstance(test, ast.Compare)
            and len(test.ops) == 1
            and isinstance(test.ops[0], ast.Eq)
            and isinstance(test.left, ast.Name)
            and test.left.id == "custom_llm_provider"
            and isinstance(test.comparators[0], ast.Constant)
        ):
            return test.comparators[0].value
        return None

    def check_patterns(self) -> List[str]:
        """Evaluate the collected data and return the list of messages.

        Two checks are made:

        1. every name that calls ``map_openai_params`` must be a recorded
           class that lists ``BaseConfig`` among its bases;
        2. every key assigned into ``optional_params`` inside a provider
           block must be in that provider's allowed set.

        Messages are appended to ``self.errors`` (which is also the return
        value). Sets are iterated in sorted order so the output is
        reproducible from run to run.
        """
        for config_name in sorted(self.map_openai_calls):
            print(f"Checking config: {config_name}")
            if (
                config_name not in self.class_inheritance
                or "BaseConfig" not in self.class_inheritance[config_name]
            ):
                # Name a class that lists config_name as a base, if any.
                class_name = next(
                    (
                        cls
                        for cls, bases in self.class_inheritance.items()
                        if config_name in bases
                    ),
                    "Unknown Class",
                )
                self.errors.append(
                    f"Error: {config_name} calls map_openai_params but doesn't inherit from BaseConfig. "
                    f"It is used in the class: {class_name}"
                )

        for provider, params in self.param_assignments.items():
            # The allowed set depends only on the provider; compute it once.
            allowed = self._get_allowed_params(provider)
            # key=str keeps sorting safe if a non-string constant key slipped in.
            for param in sorted(params, key=str):
                if param not in allowed:
                    self.errors.append(
                        f"Warning: Parameter '{param}' is directly assigned in {provider} block. "
                        f"Consider using a config class instead."
                    )

        return self.errors

    def _get_allowed_params(self, provider: str) -> Set[str]:
        """Return the keys that may be assigned directly for ``provider``.

        The result is the common allow-list plus any provider-specific
        additions; unknown providers get only the common allow-list.
        """
        common_allowed = {"stream", "api_key", "api_base"}
        provider_specific = {
            "anthropic": {"api_version"},
            "openai": {"organization"},
            # Add more providers and their allowed params here
        }
        return common_allowed.union(provider_specific.get(provider, set()))


def check_file(file_path: str) -> List[str]:
    """Parse ``file_path`` and check its top-level ``get_optional_params``.

    Only the first top-level function named ``get_optional_params`` is
    visited; if the file has no such function, the checker runs on empty
    data.

    Args:
        file_path: Path of the Python source file to analyse.

    Returns:
        The messages produced by ``ConfigChecker.check_patterns()``.

    Raises:
        OSError: If the file cannot be opened.
        SyntaxError: If the file is not valid Python.
    """
    # Read as UTF-8 explicitly: the locale default may be a different codec.
    with open(file_path, "r", encoding="utf-8") as file:
        # Passing the filename makes syntax errors point at the real file.
        tree = ast.parse(file.read(), filename=file_path)

    checker = ConfigChecker()
    for node in tree.body:
        if isinstance(node, ast.FunctionDef) and node.name == "get_optional_params":
            checker.visit(node)
            break  # No need to visit other functions
    return checker.check_patterns()


def main():
    """Check the hard-coded utils file and exit with a status code.

    Exits with status 0 when no issues are reported; otherwise prints each
    issue and exits with status 1.
    """
    issues = check_file("../../litellm/utils.py")

    if not issues:
        print("No issues found!")
        sys.exit(0)

    print("\nFound the following issues:")
    for issue in issues:
        print(f"- {issue}")
    sys.exit(1)


# Run the check only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()