Spaces:
Configuration error
Configuration error
File size: 6,751 Bytes
447ebeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 |
import ast
from typing import List, Dict, Set, Optional
import os
from dataclasses import dataclass
import argparse
import re
import sys
# Prepend the directory two levels above the current working directory
# ("../..", resolved against the CWD, not this file) to sys.path so that the
# `import litellm` below can pick up a local checkout.
# NOTE(review): presumably this is the repo root when run from this script's
# own folder — confirm against how CI invokes the script.
sys.path.insert(
    0, os.path.abspath("../..")
)
import litellm
@dataclass
class FunctionInfo:
    """Information collected about one function definition found in a source file."""

    # Function name as written in the ``def`` statement.
    name: str
    # The function's docstring, or None when it has none.
    docstring: Optional[str]
    # ``(argument_name, annotation_type_name)`` pairs for the annotated
    # arguments of the function (previously mis-annotated as ``Set[str]``).
    parameters: Set[tuple[str, str]]
    # Path of the file in which the function was found.
    file_path: str
    # Line number of the ``def`` statement (``ast`` line numbers are 1-based).
    line_number: int
class FastAPIDocVisitor(ast.NodeVisitor):
"""AST visitor to find FastAPI endpoint functions."""
def __init__(self, target_functions: Set[str]):
self.target_functions = target_functions
self.functions: Dict[str, FunctionInfo] = {}
self.current_file = ""
def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
"""Visit function definitions (both async and sync) and collect info if they match target functions."""
if node.name in self.target_functions:
# Extract docstring
docstring = ast.get_docstring(node)
# Extract parameters
parameters = set()
for arg in node.args.args:
if arg.annotation is not None:
# Get the parameter type from annotation
if isinstance(arg.annotation, ast.Name):
parameters.add((arg.arg, arg.annotation.id))
elif isinstance(arg.annotation, ast.Subscript):
if isinstance(arg.annotation.value, ast.Name):
parameters.add((arg.arg, arg.annotation.value.id))
self.functions[node.name] = FunctionInfo(
name=node.name,
docstring=docstring,
parameters=parameters,
file_path=self.current_file,
line_number=node.lineno,
)
# Also need to add this to handle async functions
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
"""Handle async functions by delegating to the regular function visitor."""
return self.visit_FunctionDef(node)
def find_functions_in_file(
    file_path: str, target_functions: Set[str]
) -> Dict[str, FunctionInfo]:
    """Parse ``file_path`` and return the target functions defined in it.

    Args:
        file_path: Path of the Python source file to scan.
        target_functions: Function names to look for.

    Returns:
        Mapping of function name to ``FunctionInfo`` for each target found.
        If the file cannot be read, decoded or parsed, the error is reported
        on stderr and an empty dict is returned so a directory scan can
        continue with the other files.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            content = f.read()
        # Passing the filename makes SyntaxError messages name the file.
        tree = ast.parse(content, filename=file_path)
    except (OSError, SyntaxError, ValueError) as e:
        # ValueError also covers UnicodeDecodeError and null bytes in source.
        # Only read/parse failures are tolerated; errors raised while walking
        # the tree are real bugs and must not be swallowed.
        print(f"Error parsing {file_path}: {e}", file=sys.stderr)
        return {}

    visitor = FastAPIDocVisitor(target_functions)
    visitor.current_file = file_path
    visitor.visit(tree)
    return visitor.functions
# A documented parameter is a docstring line that starts (after optional
# indentation and an optional "-", "*", "+" or "1." / "1)" bullet) with an
# identifier, an optional parenthesised type, and a colon, e.g.
# "- user_id: str - ..." or "    team_id (str): ...". Anchoring at the line
# start stops prose such as "see https://..." or "key: value" from counting.
_DOCSTRING_PARAM_RE = re.compile(
    r"^[ \t]*(?:[-*+]|\d+[.)])?[ \t]*([A-Za-z_][A-Za-z0-9_]*)[ \t]*(?:\([^)]*\))?[ \t]*:",
    re.MULTILINE,
)


def extract_docstring_params(docstring: Optional[str]) -> Set[str]:
    """Return the parameter names documented in ``docstring``.

    Recognised line formats:
        - parameter_name: description
        parameter_name (type): description

    Section headers such as ``Parameters:`` are also returned. They are
    harmless, because the result is only used to subtract from model field
    names.

    Args:
        docstring: The docstring text, or None.

    Returns:
        The set of names found; empty for a missing or empty docstring.
    """
    if not docstring:
        return set()
    return {match.group(1) for match in _DOCSTRING_PARAM_RE.finditer(docstring)}
def analyze_function(func_info: FunctionInfo) -> Dict:
    """Check that a function's docstring mentions every field of its Pydantic inputs.

    Each annotated parameter whose type name ends in ``Request`` or
    ``Response`` is looked up by name in ``litellm.proxy._types``. The fields
    (``model_fields``) of every model found there must appear in the
    docstring, as parsed by ``extract_docstring_params``.

    Args:
        func_info: The function to check.

    Returns:
        A dict with keys ``function``, ``file_path``, ``line_number``,
        ``has_docstring``, ``pydantic_params``, ``documented_params`` and
        ``missing_params``. The last three are sorted lists, so output is
        deterministic across runs. It also has ``is_valid``, which is True
        when no model field is missing from the docstring.
    """
    docstring_params = extract_docstring_params(func_info.docstring)
    print(f"func_info.parameters: {func_info.parameters}")

    pydantic_params = set()
    for _arg_name, type_name in func_info.parameters:
        if not type_name.endswith(("Request", "Response")):
            continue
        model = getattr(litellm.proxy._types, type_name, None)
        # A matching name is not guaranteed to be a Pydantic v2 model (e.g. if
        # _types exposes some other class ending in "Request"). Skip anything
        # without a ``model_fields`` dict instead of crashing with AttributeError.
        fields = getattr(model, "model_fields", None)
        if isinstance(fields, dict):
            pydantic_params.update(fields.keys())
    print(f"pydantic_params: {pydantic_params}")

    missing_params = pydantic_params - docstring_params
    return {
        "function": func_info.name,
        "file_path": func_info.file_path,
        "line_number": func_info.line_number,
        "has_docstring": bool(func_info.docstring),
        "pydantic_params": sorted(pydantic_params),
        "documented_params": sorted(docstring_params),
        "missing_params": sorted(missing_params),
        "is_valid": not missing_params,
    }
def print_validation_results(results: Dict) -> None:
    """Pretty-print one validation result dict.

    Prints a header with the function name and its ``file:line`` location,
    followed by exactly one verdict: no docstring, no Pydantic inputs, fully
    documented, or the sorted list of undocumented parameters.
    """
    print(f"\nChecking function: {results['function']}")
    print("File: {}:{}".format(results["file_path"], results["line_number"]))
    print("-" * 50)

    if not results["has_docstring"]:
        print("❌ No docstring found!")
    elif not results["pydantic_params"]:
        print("ℹ️ No Pydantic input models found.")
    elif results["is_valid"]:
        print("✅ All Pydantic parameters are documented!")
    else:
        print("❌ Missing documentation for parameters:")
        for undocumented in sorted(results["missing_params"]):
            print(f" - {undocumented}")
def main(directory: str = "./litellm/proxy/management_endpoints") -> None:
    """Check that the listed management endpoints document their Pydantic inputs.

    Scans every ``.py`` file under ``directory`` for the functions named in
    ``function_names``, validates each one found with ``analyze_function`` and
    prints a report for every function that fails validation.

    Args:
        directory: Root directory to scan. The default is relative to the
            current working directory; for a local run from this script's
            folder the original used ``"../../litellm/proxy/management_endpoints"``.

    Raises:
        Exception: If at least one function has undocumented Pydantic
            parameters. All failures are printed before raising, so a single
            run reports every problem.
    """
    function_names = [
        "new_end_user",
        "end_user_info",
        "update_end_user",
        "delete_end_user",
        "generate_key_fn",
        "info_key_fn",
        "update_key_fn",
        "delete_key_fn",
        "new_user",
        "new_team",
        "team_info",
        "update_team",
        "delete_team",
        "new_organization",
        "update_organization",
        "delete_organization",
        "list_organization",
        "user_update",
        "new_budget",
        "info_budget",
        "update_budget",
        "delete_budget",
        "list_budget",
    ]
    # Set for O(1) membership tests inside the AST visitor.
    target_functions = set(function_names)
    found_functions: Dict[str, FunctionInfo] = {}

    for root, dirs, files in os.walk(directory):
        # Sort so traversal order (and therefore which definition wins when a
        # name appears in several files) is reproducible.
        dirs.sort()
        for file in sorted(files):
            if file.endswith(".py"):
                file_path = os.path.join(root, file)
                found_functions.update(find_functions_in_file(file_path, target_functions))

    failures = []
    for func_name in function_names:
        func_info = found_functions.get(func_name)
        if func_info is None:
            # A renamed or moved endpoint would otherwise drop out of the check
            # without anyone noticing.
            print(f"⚠️ Function not found: {func_name}", file=sys.stderr)
            continue
        result = analyze_function(func_info)
        if not result["is_valid"]:
            print_validation_results(result)
            failures.append(func_name)

    if failures:
        raise Exception(
            f"Missing docstring parameters in {len(failures)} function(s): "
            + ", ".join(failures)
        )
# Run the docstring check only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|