Spaces:
Configuration error
Configuration error
import ast | |
from typing import List, Set, Dict, Optional | |
import sys | |
class ConfigChecker(ast.NodeVisitor): | |
def __init__(self): | |
self.errors: List[str] = [] | |
self.current_provider_block: Optional[str] = None | |
self.param_assignments: Dict[str, Set[str]] = {} | |
self.map_openai_calls: Set[str] = set() | |
self.class_inheritance: Dict[str, List[str]] = {} | |
def get_full_name(self, node): | |
"""Recursively extract the full name from a node.""" | |
if isinstance(node, ast.Name): | |
return node.id | |
elif isinstance(node, ast.Attribute): | |
base = self.get_full_name(node.value) | |
if base: | |
return f"{base}.{node.attr}" | |
return None | |
def visit_ClassDef(self, node: ast.ClassDef): | |
# Record class inheritance | |
bases = [base.id for base in node.bases if isinstance(base, ast.Name)] | |
print(f"Found class {node.name} with bases {bases}") | |
self.class_inheritance[node.name] = bases | |
self.generic_visit(node) | |
def visit_Call(self, node: ast.Call): | |
# Check for map_openai_params calls | |
if ( | |
isinstance(node.func, ast.Attribute) | |
and node.func.attr == "map_openai_params" | |
): | |
if isinstance(node.func.value, ast.Name): | |
config_name = node.func.value.id | |
self.map_openai_calls.add(config_name) | |
self.generic_visit(node) | |
def visit_If(self, node: ast.If): | |
# Detect custom_llm_provider blocks | |
provider = self._extract_provider_from_if(node) | |
if provider: | |
old_provider = self.current_provider_block | |
self.current_provider_block = provider | |
self.generic_visit(node) | |
self.current_provider_block = old_provider | |
else: | |
self.generic_visit(node) | |
def visit_Assign(self, node: ast.Assign): | |
# Track assignments to optional_params | |
if self.current_provider_block and len(node.targets) == 1: | |
target = node.targets[0] | |
if isinstance(target, ast.Subscript) and isinstance(target.value, ast.Name): | |
if target.value.id == "optional_params": | |
if isinstance(target.slice, ast.Constant): | |
key = target.slice.value | |
if self.current_provider_block not in self.param_assignments: | |
self.param_assignments[self.current_provider_block] = set() | |
self.param_assignments[self.current_provider_block].add(key) | |
self.generic_visit(node) | |
def _extract_provider_from_if(self, node: ast.If) -> Optional[str]: | |
"""Extract the provider name from an if condition checking custom_llm_provider""" | |
if isinstance(node.test, ast.Compare): | |
if len(node.test.ops) == 1 and isinstance(node.test.ops[0], ast.Eq): | |
if ( | |
isinstance(node.test.left, ast.Name) | |
and node.test.left.id == "custom_llm_provider" | |
): | |
if isinstance(node.test.comparators[0], ast.Constant): | |
return node.test.comparators[0].value | |
return None | |
def check_patterns(self) -> List[str]: | |
# Check if all configs using map_openai_params inherit from BaseConfig | |
for config_name in self.map_openai_calls: | |
print(f"Checking config: {config_name}") | |
if ( | |
config_name not in self.class_inheritance | |
or "BaseConfig" not in self.class_inheritance[config_name] | |
): | |
# Retrieve the associated class name, if any | |
class_name = next( | |
( | |
cls | |
for cls, bases in self.class_inheritance.items() | |
if config_name in bases | |
), | |
"Unknown Class", | |
) | |
self.errors.append( | |
f"Error: {config_name} calls map_openai_params but doesn't inherit from BaseConfig. " | |
f"It is used in the class: {class_name}" | |
) | |
# Check for parameter assignments in provider blocks | |
for provider, params in self.param_assignments.items(): | |
# You can customize which parameters should raise warnings for each provider | |
for param in params: | |
if param not in self._get_allowed_params(provider): | |
self.errors.append( | |
f"Warning: Parameter '{param}' is directly assigned in {provider} block. " | |
f"Consider using a config class instead." | |
) | |
return self.errors | |
def _get_allowed_params(self, provider: str) -> Set[str]: | |
"""Define allowed direct parameter assignments for each provider""" | |
# You can customize this based on your requirements | |
common_allowed = {"stream", "api_key", "api_base"} | |
provider_specific = { | |
"anthropic": {"api_version"}, | |
"openai": {"organization"}, | |
# Add more providers and their allowed params here | |
} | |
return common_allowed.union(provider_specific.get(provider, set())) | |
def check_file(file_path: str) -> List[str]:
    """Parse ``file_path`` and run ``ConfigChecker`` over its ``get_optional_params``.

    Only the first top-level function named ``get_optional_params`` is
    visited; the rest of the module is ignored.

    Args:
        file_path: Path of the Python source file to inspect.

    Returns:
        The messages reported by ``ConfigChecker.check_patterns``. If the
        module defines no top-level ``get_optional_params``, a single error
        message is returned instead, so a rename or move cannot make the check
        pass silently.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the locale.
    with open(file_path, "r", encoding="utf-8") as file:
        source = file.read()
    # filename= makes SyntaxError messages point at the inspected file.
    tree = ast.parse(source, filename=file_path)

    target = next(
        (
            node
            for node in tree.body
            if isinstance(node, ast.FunctionDef)
            and node.name == "get_optional_params"
        ),
        None,
    )
    if target is None:
        return [f"Error: function 'get_optional_params' not found in {file_path}"]

    checker = ConfigChecker()
    checker.visit(target)
    return checker.check_patterns()
def main():
    """Check litellm's utils module and exit with a status reflecting the result.

    Every message returned by ``check_file`` is printed. The process exits
    with status 1 if there were any, otherwise it prints a success line and
    exits with status 0.
    """
    # Relative to the current working directory, not to this script.
    target_path = "../../litellm/utils.py"
    issues = check_file(target_path)
    if not issues:
        print("No issues found!")
        sys.exit(0)

    print("\nFound the following issues:")
    for issue in issues:
        print(f"- {issue}")
    sys.exit(1)
if __name__ == "__main__":
    # Run the check only when executed as a script, not when imported.
    main()