File size: 6,751 Bytes
447ebeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import argparse
import ast
import os
import re
import sys
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple, Union

# Prepend <cwd>/../.. to sys.path. The path is resolved against the current
# working directory (os.path.abspath), not this file's location, so a
# `litellm` package found there, if any, is imported ahead of an installed copy.
sys.path.insert(
    0, os.path.abspath("../..")
)  # Puts the directory two levels above the current working directory first on the import path
import litellm


@dataclass
class FunctionInfo:
    """Information collected about one function definition found by AST parsing."""

    # Function name as written in the ``def`` statement.
    name: str
    # Docstring text, or None when the function has no docstring.
    docstring: Optional[str]
    # (parameter_name, annotation_type_name) pairs for annotated parameters,
    # e.g. ("data", "NewUserRequest").
    parameters: Set[Tuple[str, str]]
    # Path of the source file the function was found in.
    file_path: str
    # Line number of the ``def`` statement (ast ``lineno``, 1-based).
    line_number: int


class FastAPIDocVisitor(ast.NodeVisitor):
    """AST visitor that collects ``FunctionInfo`` for a fixed set of function names.

    Both ``def`` and ``async def`` definitions are inspected. Results are keyed
    by function name in ``self.functions``. If the same name is seen more than
    once, the last definition visited wins.
    """

    def __init__(self, target_functions: Set[str]):
        # Names of the functions to collect; every other function is ignored.
        self.target_functions = target_functions
        # Collected results, keyed by function name.
        self.functions: Dict[str, FunctionInfo] = {}
        # Path recorded on each FunctionInfo; the caller sets it before visiting.
        self.current_file = ""

    @staticmethod
    def _annotated_parameters(arguments: ast.arguments) -> Set[Tuple[str, str]]:
        """Return ``(param_name, type_name)`` pairs for annotated parameters.

        Positional-only, regular and keyword-only parameters are all inspected.
        ``type_name`` is the bare name for ``X`` annotations and the outer name
        for subscripted ones (``List[X]`` -> ``"List"``). Any other annotation
        form (attribute access, string literal, ``X | Y``) is skipped.
        """
        parameters: Set[Tuple[str, str]] = set()
        for arg in [*arguments.posonlyargs, *arguments.args, *arguments.kwonlyargs]:
            annotation = arg.annotation
            if isinstance(annotation, ast.Subscript):
                # For generics record the outer name (e.g. "Optional", "List").
                annotation = annotation.value
            if isinstance(annotation, ast.Name):
                parameters.add((arg.arg, annotation.id))
        return parameters

    # typing.Union instead of ``A | B``: this annotation is evaluated when the
    # method is defined, and ``|`` between classes needs Python 3.10+.
    def visit_FunctionDef(
        self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
    ) -> None:
        """Record ``node`` in ``self.functions`` when its name is a target.

        Shared by ``visit_AsyncFunctionDef``. It does not descend into the
        function body, so functions nested inside other functions are not
        collected.
        """
        if node.name not in self.target_functions:
            return
        self.functions[node.name] = FunctionInfo(
            name=node.name,
            docstring=ast.get_docstring(node),
            parameters=self._annotated_parameters(node.args),
            file_path=self.current_file,
            line_number=node.lineno,
        )

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """Handle ``async def`` exactly like a regular ``def``."""
        return self.visit_FunctionDef(node)


def find_functions_in_file(
    file_path: str, target_functions: Set[str]
) -> Dict[str, FunctionInfo]:
    """Parse a Python file and collect ``FunctionInfo`` for the target functions.

    Args:
        file_path: Path of the ``.py`` file to parse.
        target_functions: Function names to look for.

    Returns:
        Mapping from function name to its ``FunctionInfo``. A file that cannot
        be read, decoded or parsed is reported on stderr and yields an empty
        dict, so one bad file does not abort the whole scan.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            content = f.read()
        # Passing the filename makes SyntaxError messages name the real file.
        tree = ast.parse(content, filename=file_path)
    except (OSError, UnicodeDecodeError, SyntaxError, ValueError) as e:
        # ValueError: source containing null bytes on older Python versions.
        print(f"Error parsing {file_path}: {e}", file=sys.stderr)
        return {}

    # Visiting happens outside the try so a bug in the visitor is not
    # misreported as a parse error.
    visitor = FastAPIDocVisitor(target_functions)
    visitor.current_file = file_path
    visitor.visit(tree)
    return visitor.functions


def extract_docstring_params(docstring: Optional[str]) -> Set[str]:
    """Collect identifiers written as ``name:`` or ``name (type):`` in a docstring.

    The pattern is not anchored to line starts, so any identifier directly
    followed by a colon counts. Examples: ``- user_id: ...``, ``user_id: ...``,
    ``team_id (str): ...``, and also section headers such as ``Parameters:``.

    Args:
        docstring: Text to scan; ``None`` or ``""`` yields an empty set.

    Returns:
        The set of matched identifiers.
    """
    if docstring:
        # Optional "-" bullet, identifier, optional "(...)" type, then ":".
        entry_pattern = re.compile(
            r"-?\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*(?:\([^)]*\))?\s*:"
        )
        return {hit.group(1) for hit in entry_pattern.finditer(docstring)}
    return set()


def analyze_function(func_info: FunctionInfo) -> Dict:
    """Compare a function's Pydantic model fields against its documented params.

    For every parameter annotated with a type whose name ends in ``Request`` or
    ``Response``, the same-named attribute of ``litellm.proxy._types`` is looked
    up and its ``model_fields`` keys are gathered. Those field names are then
    compared with the names ``extract_docstring_params`` finds in the docstring.

    Args:
        func_info: Function information collected from the AST.

    Returns:
        A dict with these keys:
          - ``function``, ``file_path``, ``line_number``.
          - ``has_docstring``.
          - ``pydantic_params``, ``documented_params``, ``missing_params``
            (sorted lists).
          - ``is_valid``: True when no model field is missing from the docstring.
    """
    # Import explicitly rather than relying on ``import litellm`` having loaded
    # the ``proxy._types`` submodule as a side effect.
    from litellm.proxy import _types as proxy_types

    docstring_params = extract_docstring_params(func_info.docstring)

    print(f"func_info.parameters: {func_info.parameters}")
    pydantic_params = set()

    for _, type_name in func_info.parameters:
        if not type_name.endswith(("Request", "Response")):
            continue
        model = getattr(proxy_types, type_name, None)
        # ``model_fields`` is Pydantic v2's field mapping. Anything without it
        # (missing name, non-Pydantic class) is skipped instead of raising
        # AttributeError.
        fields = getattr(model, "model_fields", None)
        if fields:
            pydantic_params.update(fields.keys())

    print(f"pydantic_params: {pydantic_params}")

    missing_params = pydantic_params - docstring_params

    return {
        "function": func_info.name,
        "file_path": func_info.file_path,
        "line_number": func_info.line_number,
        "has_docstring": bool(func_info.docstring),
        # Sorted so output and reports are stable across runs.
        "pydantic_params": sorted(pydantic_params),
        "documented_params": sorted(docstring_params),
        "missing_params": sorted(missing_params),
        "is_valid": not missing_params,
    }


def print_validation_results(results: Dict) -> None:
    """Write one validation result to stdout in a human-readable form.

    Prints a header (function name, ``file:line`` location, divider), then
    exactly one verdict, checked in this order:
      - the docstring is missing;
      - there are no Pydantic parameters;
      - every parameter is documented;
      - otherwise, the sorted list of undocumented parameters.
    """
    divider = "-" * 50
    print(f"\nChecking function: {results['function']}")
    print(f"File: {results['file_path']}:{results['line_number']}")
    print(divider)

    if not results["has_docstring"]:
        print("❌ No docstring found!")
    elif not results["pydantic_params"]:
        print("ℹ️  No Pydantic input models found.")
    elif results["is_valid"]:
        print("✅ All Pydantic parameters are documented!")
    else:
        print("❌ Missing documentation for parameters:")
        for undocumented in sorted(results["missing_params"]):
            print(f"  - {undocumented}")


def main(directory: str = "./litellm/proxy/management_endpoints"):
    """Check that the listed management endpoints document their Pydantic fields.

    Walks ``directory`` for ``.py`` files and locates each endpoint named in
    ``function_names``. For each one it compares the fields of its
    ``*Request``/``*Response`` parameter models against the parameters named in
    its docstring. Every failing endpoint is printed before the check fails.
    Endpoints that are not found are skipped.

    Args:
        directory: Root of the management endpoint sources. The default is
            relative to the current working directory, i.e. the repository
            root; a run from two directories deeper would pass
            ``"../../litellm/proxy/management_endpoints"``.

    Raises:
        FileNotFoundError: If ``directory`` does not exist. Otherwise the check
            would silently inspect nothing and pass.
        Exception: If any endpoint has undocumented parameters.
    """
    function_names = [
        "new_end_user",
        "end_user_info",
        "update_end_user",
        "delete_end_user",
        "generate_key_fn",
        "info_key_fn",
        "update_key_fn",
        "delete_key_fn",
        "new_user",
        "new_team",
        "team_info",
        "update_team",
        "delete_team",
        "new_organization",
        "update_organization",
        "delete_organization",
        "list_organization",
        "user_update",
        "new_budget",
        "info_budget",
        "update_budget",
        "delete_budget",
        "list_budget",
    ]

    # Set for fast membership checks inside the AST visitor.
    target_functions = set(function_names)

    # os.walk on a missing path yields nothing, which would make this check
    # pass without inspecting a single endpoint.
    if not os.path.isdir(directory):
        raise FileNotFoundError(
            f"Management endpoints directory not found: {os.path.abspath(directory)}"
        )

    found_functions: Dict[str, FunctionInfo] = {}
    for root, dirs, files in os.walk(directory):
        # Sort in place so traversal order, and therefore which duplicate
        # definition wins, is the same on every run.
        dirs.sort()
        for file in sorted(files):
            if file.endswith(".py"):
                found = find_functions_in_file(
                    os.path.join(root, file), target_functions
                )
                found_functions.update(found)

    # Report every failing endpoint, not just the first one.
    failed: List[str] = []
    for func_name in function_names:
        if func_name not in found_functions:
            continue
        result = analyze_function(found_functions[func_name])
        if not result["is_valid"]:
            print_validation_results(result)
            failed.append(func_name)

    if failed:
        # print_validation_results returns None, so the details are printed
        # above and the exception carries a summary.
        raise Exception(
            f"Missing docstring parameters in {len(failed)} function(s): "
            f"{', '.join(failed)}"
        )


if __name__ == "__main__":
    # Run the docstring check only when executed as a script, not on import.
    main()