Spaces:
Configuration error
Configuration error
File size: 3,529 Bytes
447ebeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import ast
import os
import sys
class EnterpriseImportFinder(ast.NodeVisitor):
    """AST visitor that records enterprise imports not guarded by try/except.

    An import counts as guarded only when it sits (directly or nested) in the
    *body* of a ``try`` statement that has at least one ``except`` handler.
    Imports in a try's ``except``/``else``/``finally`` clauses execute outside
    that try's protection, so they are reported unless an *enclosing* try body
    guards them. A handler-less ``try/finally`` guards nothing.

    The caller should assign ``current_file`` before visiting so that every
    finding records its source path.
    """

    def __init__(self):
        # Findings: dicts with keys "file", "line", "import", "context".
        self.unsafe_imports = []
        # Path stamped onto each finding; assigned by the caller.
        self.current_file = None
        # True while visiting code protected by an enclosing try/except body.
        self.in_try_block = False
        # Stack of try statements currently being visited (outermost first).
        self.try_blocks = []

    @staticmethod
    def _is_enterprise_module(module_name):
        """Return True if *module_name* refers to an enterprise module."""
        # "litellm_enterprise" already contains "enterprise", so one
        # substring test covers both names the check is meant to catch.
        return "enterprise" in module_name

    def visit_Try(self, node):
        """Visit a try statement, treating only its guarded body as safe."""
        self.try_blocks.append(node)
        outer_guarded = self.in_try_block
        # The body is protected only if something can catch the error; a
        # bare try/finally lets an ImportError propagate.
        # NOTE(review): handler exception types are not inspected, so e.g.
        # ``except KeyError`` still counts as a guard.
        self.in_try_block = outer_guarded or bool(node.handlers)
        for stmt in node.body:
            self.visit(stmt)
        # except/else/finally run outside this try's protection; they are
        # safe only if an enclosing try body already covered them.
        self.in_try_block = outer_guarded
        for handler in node.handlers:
            for stmt in handler.body:
                self.visit(stmt)
        for stmt in node.orelse:
            self.visit(stmt)
        for stmt in node.finalbody:
            self.visit(stmt)
        self.try_blocks.pop()
        self.in_try_block = outer_guarded

    # ``try/except*`` (ast.TryStar, Python 3.11+) has the same field layout
    # and guards its body the same way; older Pythons never call this.
    visit_TryStar = visit_Try

    def visit_Import(self, node):
        """Record ``import ...enterprise...`` statements that are unguarded."""
        for alias in node.names:
            if self._is_enterprise_module(alias.name) and not self.in_try_block:
                self.unsafe_imports.append({
                    "file": self.current_file,
                    "line": node.lineno,
                    "import": alias.name,
                    "context": "direct import",
                })
        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        """Record ``from ...enterprise... import`` statements that are unguarded."""
        # node.module is None for ``from . import x``; such imports are not checked.
        if node.module and self._is_enterprise_module(node.module):
            if not self.in_try_block:
                self.unsafe_imports.append({
                    "file": self.current_file,
                    "line": node.lineno,
                    "import": f"from {node.module}",
                    "context": "from import",
                })
        self.generic_visit(node)
def find_unsafe_enterprise_imports_in_file(file_path):
    """Return the unguarded enterprise imports found in one Python file.

    Args:
        file_path: Path of the Python source file to parse.

    Returns:
        The list of finding dicts (keys "file", "line", "import", "context")
        collected by ``EnterpriseImportFinder``; empty if none were found.

    Raises:
        OSError: If the file cannot be opened or read.
        SyntaxError: If the file is not valid Python. Parse failures are not
            skipped, so an unparsable file fails the scan loudly.
    """
    # Read raw bytes so ast.parse decodes them per PEP 263 (UTF-8 by default,
    # or the file's coding cookie) rather than the platform locale encoding.
    with open(file_path, "rb") as source_file:
        source = source_file.read()
    tree = ast.parse(source, filename=file_path)
    finder = EnterpriseImportFinder()
    finder.current_file = file_path
    finder.visit(tree)
    return finder.unsafe_imports
def find_unsafe_enterprise_imports_in_directory(directory):
    """Scan every ``.py`` file under *directory*, recursively.

    Args:
        directory: Root directory to walk.

    Returns:
        All findings from ``find_unsafe_enterprise_imports_in_file``,
        concatenated in a reproducible order: directories are walked
        top-down in sorted order, and files are scanned sorted by name.
        ``os.walk`` ignores errors by default, so a nonexistent *directory*
        yields an empty list; callers must check for that themselves.
    """
    unsafe_imports = []
    for root, dirnames, filenames in os.walk(directory):
        # os.walk lists entries in arbitrary filesystem order. Sorting
        # dirnames in place steers the top-down walk, so the report order is
        # stable across machines and runs.
        dirnames.sort()
        for filename in sorted(filenames):
            if filename.endswith(".py"):
                file_path = os.path.join(root, filename)
                unsafe_imports.extend(
                    find_unsafe_enterprise_imports_in_file(file_path)
                )
    return unsafe_imports
if __name__ == "__main__":
    # Directory to scan: optional first CLI argument. The default is the
    # litellm package relative to the current working directory, which was
    # the original hard-coded behavior.
    directory_path = sys.argv[1] if len(sys.argv) > 1 else "./litellm"
    # os.walk silently yields nothing for a missing directory. Without this
    # check, running from the wrong working directory would report "no unsafe
    # imports" and pass CI without scanning anything.
    if not os.path.isdir(directory_path):
        raise Exception(
            f"Directory to scan not found: {directory_path!r}. "
            "Run from the repository root or pass the path as an argument."
        )
    unsafe_imports = find_unsafe_enterprise_imports_in_directory(directory_path)
    if unsafe_imports:
        print("🚨 UNSAFE ENTERPRISE IMPORTS FOUND (not in try-except blocks):")
        for imp in unsafe_imports:
            print(f"File: {imp['file']}")
            print(f"Line: {imp['line']}")
            print(f"Import: {imp['import']}")
            print(f"Context: {imp['context']}")
            print("---")
        # Raise exception to fail CI/CD
        raise Exception(
            "🚨 Unsafe enterprise imports found. All enterprise imports must be wrapped in try-except blocks."
        )
    else:
        print("✅ No unsafe enterprise imports found.")
|