Spaces:
Configuration error
Configuration error
import ast | |
import os | |
class EnterpriseImportFinder(ast.NodeVisitor):
    """Collect enterprise imports that are not guarded by a ``try``/``except``.

    An import is treated as an enterprise import when the imported module
    name contains the substring ``"enterprise"`` (which also covers
    ``litellm_enterprise``).

    An import counts as guarded only when it sits in the *body* of a ``try``
    (or ``try``/``except*``) statement that has at least one ``except``
    handler, directly or through nested statements. Imports placed in an
    ``except``, ``else`` or ``finally`` clause are not protected by that
    ``try`` statement and are reported unless an enclosing ``try`` body
    guards them. The exception types named by the handlers are not inspected.

    Attributes:
        unsafe_imports: List of findings; each is a dict with the keys
            ``"file"``, ``"line"``, ``"import"`` and ``"context"``.
        current_file: Label stored in each finding's ``"file"`` entry; the
            caller sets it before calling ``visit``.
        in_try_block: True while the visitor is inside a guarding ``try``
            body.
        try_blocks: Stack of the ``try`` nodes whose bodies are currently
            being visited and that have at least one handler.
    """

    def __init__(self):
        self.unsafe_imports = []
        self.current_file = None
        self.in_try_block = False
        self.try_blocks = []

    def visit_Try(self, node):
        """Visit a ``try`` statement, guarding only its body."""
        # A try/finally without handlers catches nothing, so it is no guard.
        guards_body = bool(node.handlers)
        if guards_body:
            self.try_blocks.append(node)
            self.in_try_block = True
        try:
            for statement in node.body:
                self.visit(statement)
        finally:
            # Restore the outer state before visiting the other clauses,
            # even if visiting the body raised.
            if guards_body:
                self.try_blocks.pop()
                self.in_try_block = bool(self.try_blocks)

        # Exceptions raised in these clauses are not caught by this try
        # statement, so they are visited with the enclosing guard state.
        for handler in node.handlers:
            for statement in handler.body:
                self.visit(statement)
        for statement in node.orelse:
            self.visit(statement)
        for statement in node.finalbody:
            self.visit(statement)

    # ``try``/``except*`` (Python 3.11+) has the same fields as ``try``.
    visit_TryStar = visit_Try

    def _record(self, node, description, context):
        """Append one finding for ``node`` to ``unsafe_imports``."""
        self.unsafe_imports.append({
            "file": self.current_file,
            "line": node.lineno,
            "import": description,
            "context": context,
        })

    def visit_Import(self, node):
        """Report unguarded ``import <enterprise module>`` statements."""
        for alias in node.names:
            if "enterprise" in alias.name and not self.in_try_block:
                self._record(node, alias.name, "direct import")
        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        """Report unguarded ``from <enterprise module> import ...`` statements."""
        # ``node.module`` is None for ``from . import name``.
        module = node.module
        if module and "enterprise" in module and not self.in_try_block:
            self._record(node, f"from {module}", "from import")
        self.generic_visit(node)
def find_unsafe_enterprise_imports_in_file(file_path):
    """Return the unguarded enterprise imports found in one Python file.

    Args:
        file_path: Path of the Python source file to parse. It is also
            stored as the ``"file"`` entry of every finding.

    Returns:
        The ``unsafe_imports`` list collected by ``EnterpriseImportFinder``
        for this file (empty when nothing was found).

    Raises:
        OSError: If the file cannot be opened or read.
        SyntaxError: If the file cannot be parsed as Python source.
    """
    # Read raw bytes so the parser applies Python's own source-encoding
    # rules (UTF-8 by default, BOM and PEP 263 declarations honoured)
    # instead of the platform's default text encoding.
    with open(file_path, "rb") as source_file:
        tree = ast.parse(source_file.read(), filename=file_path)

    finder = EnterpriseImportFinder()
    finder.current_file = file_path
    finder.visit(tree)
    return finder.unsafe_imports
def find_unsafe_enterprise_imports_in_directory(directory):
    """Scan every ``.py`` file under ``directory`` for unguarded enterprise imports.

    Args:
        directory: Root directory, walked recursively with ``os.walk``.

    Returns:
        One list holding the findings of all scanned files, in the order the
        files are encountered during the walk.
    """
    findings = []
    for dirpath, _dirnames, filenames in os.walk(directory):
        python_files = (name for name in filenames if name.endswith(".py"))
        for name in python_files:
            findings += find_unsafe_enterprise_imports_in_file(
                os.path.join(dirpath, name)
            )
    return findings
if __name__ == "__main__":
    # Check for unsafe enterprise imports in the litellm directory
    directory_path = "./litellm"

    # os.walk yields nothing for a missing directory, which would make the
    # check pass without scanning a single file; fail loudly instead.
    if not os.path.isdir(directory_path):
        raise FileNotFoundError(
            f"Directory to scan not found: {directory_path} "
            "(run this script from the repository root)."
        )

    unsafe_imports = find_unsafe_enterprise_imports_in_directory(directory_path)
    if unsafe_imports:
        print("🚨 UNSAFE ENTERPRISE IMPORTS FOUND (not in try-except blocks):")
        for imp in unsafe_imports:
            print(f"File: {imp['file']}")
            print(f"Line: {imp['line']}")
            print(f"Import: {imp['import']}")
            print(f"Context: {imp['context']}")
            print("---")
        # Raise exception to fail CI/CD
        raise Exception(
            "🚨 Unsafe enterprise imports found. All enterprise imports must be wrapped in try-except blocks."
        )
    else:
        print("✅ No unsafe enterprise imports found.")