Spaces:
Configuration error
Configuration error
File size: 4,067 Bytes
447ebeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import ast
import os
import tokenize
# Files that are permitted to mutate litellm's callback lists directly.
# Each file is listed under both path prefixes the scan may produce:
# "../.." when run locally from the test directory, "." when run from the
# repository root on CI.
# NOTE(review): "litellm.logging_callback_manager.py" is an unusual module
# filename — confirm it matches the file that defines LoggingCallbackManager.
_CALLBACK_OWNER_FILES = (
    "litellm/litellm_core_utils/litellm.logging_callback_manager.py",
    "litellm/proxy/common_utils/callback_utils.py",
)
ALLOWED_FILES = [
    f"{prefix}/{relative_path}"
    for prefix in ("../..", ".")
    for relative_path in _CALLBACK_OWNER_FILES
]

# Appended to every violation message produced by the checker.
warning_msg = (
    "this is a serious violation. Callbacks must only be modified through LoggingCallbackManager"
)
def _attribute_chain(func):
    """Return the dotted name of a call target as a list, root first.

    ``litellm.callbacks.append`` -> ``["litellm", "callbacks", "append"]``.
    When the chain does not bottom out in a plain name (e.g. ``f().x.append``),
    only the attribute parts are returned.
    """
    parts = []
    current = func
    while isinstance(current, ast.Attribute):
        parts.append(current.attr)
        current = current.value
    if isinstance(current, ast.Name):
        parts.append(current.id)
    parts.reverse()
    return parts


def check_for_callback_modifications(file_path):
    """
    Checks if any direct modifications to specific litellm callback lists are made in the given file.
    Also prints the violating line of code.

    Only calls of the form ``litellm.<protected_list>.<append|extend|insert>(...)``
    are detected; assignments, ``+=`` and aliased references are not.

    Args:
        file_path: Path of the Python file to inspect. Files listed in
            ALLOWED_FILES (compared after path normalisation) are skipped.

    Returns:
        A list of human-readable violation messages; empty when the file is
        allowed, clean, or cannot be read/parsed as Python.
    """
    print("..checking file=", file_path)
    # Normalise both sides so equivalent spellings of the same path
    # (e.g. different separators or redundant "./") still match.
    allowed = {os.path.normpath(path) for path in ALLOWED_FILES}
    if os.path.normpath(file_path) in allowed:
        return []

    violations = []
    try:
        # tokenize.open honours PEP 263 coding cookies and a UTF-8 BOM,
        # whereas a plain open() would use the platform locale encoding.
        with tokenize.open(file_path) as file:
            lines = file.readlines()
        tree = ast.parse("".join(lines))
    except SyntaxError:
        print(f"Warning: Syntax error in file {file_path}")
        return violations
    except ValueError as err:
        # Covers UnicodeDecodeError (a ValueError subclass) and the
        # null-byte ValueError raised by ast.parse on older Pythons.
        print(f"Warning: Could not read file {file_path}: {err}")
        return violations

    protected_lists = frozenset(
        {
            "callbacks",
            "success_callback",
            "failure_callback",
            "_async_success_callback",
            "_async_failure_callback",
        }
    )
    forbidden_operations = frozenset({"append", "extend", "insert"})

    for node in ast.walk(tree):
        # Only method calls such as litellm.callbacks.append(...) are of interest.
        if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute)):
            continue
        attr_chain = _attribute_chain(node.func)
        if len(attr_chain) < 3 or attr_chain[0] != "litellm":
            continue
        protected_list, operation = attr_chain[1], attr_chain[2]
        if protected_list in protected_lists and operation in forbidden_operations:
            violating_line = lines[node.lineno - 1].strip()
            violations.append(
                f"Found violation in file {file_path} line {node.lineno}: '{violating_line}'. "
                f"Direct modification of 'litellm.{protected_list}' using '{operation}' is not allowed. "
                f"Please use LoggingCallbackManager instead. {warning_msg}"
            )
    return violations
def scan_directory_for_callback_modifications(base_dir):
    """
    Scans all Python files in the directory tree for unauthorized callback list modifications.

    Directories and files are visited in sorted order, so the returned list
    (and the printed report) is reproducible across runs and filesystems.
    os.walk on its own yields entries in arbitrary order.

    Args:
        base_dir: Root directory to walk recursively.

    Returns:
        All violation messages from every ``.py`` file under ``base_dir``.
    """
    all_violations = []
    for root, dirs, files in os.walk(base_dir):
        # With the default top-down walk, sorting dirs in place makes os.walk
        # descend into subdirectories in this order.
        dirs.sort()
        for file_name in sorted(files):
            if file_name.endswith(".py"):
                file_path = os.path.join(root, file_name)
                all_violations.extend(check_for_callback_modifications(file_path))
    return all_violations
def test_no_unauthorized_callback_modifications():
    """
    Test to ensure callback lists are not modified directly anywhere in the codebase.

    The litellm package directory is located relative to the current working
    directory. "./litellm" is used when running from the repository root (CI),
    and "../../litellm" when running from this test's directory (local testing).
    These are the same two prefixes used in ALLOWED_FILES.

    Raises:
        AssertionError: if the package directory cannot be found, or if any
            violations are detected.
    """
    candidate_dirs = ("./litellm", "../../litellm")
    base_dir = next((d for d in candidate_dirs if os.path.isdir(d)), None)
    if base_dir is None:
        # Without this guard, os.walk() on a missing directory yields nothing
        # and the test would pass without checking a single file.
        raise AssertionError(
            f"Could not locate the litellm package directory (tried {candidate_dirs}). "
            "Run this test from the repository root."
        )

    violations = scan_directory_for_callback_modifications(base_dir)
    if violations:
        print(f"\nFound {len(violations)} callback modification violations:")
        for violation in violations:
            print("\n" + violation)
        raise AssertionError(
            "Found unauthorized callback modifications. See above for details."
        )
if __name__ == "__main__":
    # Allow running the check directly as a script, outside pytest.
    test_no_unauthorized_callback_modifications()
|