# NOTE(review): header text captured with this file ("Spaces: Configuration
# error"); it is not part of the tests and is kept only as a comment.
import os
from datetime import datetime, timezone
from typing import Any, Dict
from unittest.mock import MagicMock, patch

import pytest
from botocore.credentials import Credentials

from litellm.llms.bedrock.base_aws_llm import (
    AwsAuthError,
    BaseAWSLLM,
    Boto3CredentialsInfo,
)
# Test fixtures
@pytest.fixture
def base_aws_llm():
    """Provide a fresh ``BaseAWSLLM`` instance to each test that requests it."""
    return BaseAWSLLM()


@pytest.fixture
def mock_credentials():
    """Provide static botocore ``Credentials`` used as the mocked session result."""
    return Credentials(
        access_key="test_access", secret_key="test_secret", token="test_token"
    )


# Test cache key generation
def test_get_cache_key(base_aws_llm):
    """``get_cache_key`` returns a 64-character string for a credentials dict."""
    test_args = {
        "aws_access_key_id": "test_key",
        "aws_secret_access_key": "test_secret",
    }
    cache_key = base_aws_llm.get_cache_key(test_args)
    assert isinstance(cache_key, str)
    # 64 characters matches a SHA-256 hex digest.
    assert len(cache_key) == 64


# Test web identity token authentication
# NOTE(review): these @patch decorators were missing from the captured source
# and are reconstructed from what the test expects. They assume base_aws_llm
# calls ``boto3.client`` and imports ``get_secret`` into its own module
# namespace; confirm. Stacked @patch decorators inject mocks bottom-up, so the
# get_secret mock is the first argument.
@patch("boto3.client")
@patch("litellm.llms.bedrock.base_aws_llm.get_secret")
def test_auth_with_web_identity_token(mock_get_secret, mock_boto3_client, base_aws_llm):
    """Web-identity auth resolves the token via get_secret and returns Credentials."""
    # get_secret resolves the configured token reference to the OIDC token.
    mock_get_secret.return_value = "mocked_oidc_token"

    # The STS client hands back static credentials with no Expiration field.
    mock_sts = MagicMock()
    mock_sts.assume_role_with_web_identity.return_value = {
        "Credentials": {
            "AccessKeyId": "test_access",
            "SecretAccessKey": "test_secret",
            "SessionToken": "test_token",
        },
        "PackedPolicySize": 10,
    }
    mock_boto3_client.return_value = mock_sts

    credentials, ttl = base_aws_llm._auth_with_web_identity_token(
        aws_web_identity_token="test_token",
        aws_role_name="test_role",
        aws_session_name="test_session",
        aws_region_name="us-west-2",
        aws_sts_endpoint=None,
    )

    # The token argument must be routed through get_secret exactly once.
    mock_get_secret.assert_called_once_with("test_token")
    assert isinstance(credentials, Credentials)
    assert ttl == 3540  # default TTL (3600 - 60)


# Test AWS role authentication
# NOTE(review): @patch decorator was missing from the captured source; target
# reconstructed on the assumption that base_aws_llm calls ``boto3.client``.
@patch("boto3.client")
def test_auth_with_aws_role(mock_boto3_client, base_aws_llm):
    """Role assumption via STS yields Credentials and a float TTL."""
    mock_sts = MagicMock()
    expiry_time = datetime.now(timezone.utc)
    mock_sts.assume_role.return_value = {
        "Credentials": {
            "AccessKeyId": "test_access",
            "SecretAccessKey": "test_secret",
            "SessionToken": "test_token",
            "Expiration": expiry_time,
        }
    }
    mock_boto3_client.return_value = mock_sts

    credentials, ttl = base_aws_llm._auth_with_aws_role(
        aws_access_key_id="test_access",
        aws_secret_access_key="test_secret",
        aws_role_name="test_role",
        aws_session_name="test_session",
    )

    assert isinstance(credentials, Credentials)
    assert isinstance(ttl, float)


# Test AWS profile authentication
# NOTE(review): @patch decorator was missing from the captured source; target
# reconstructed on the assumption that base_aws_llm calls ``boto3.Session``.
@patch("boto3.Session")
def test_auth_with_aws_profile(mock_session, base_aws_llm, mock_credentials):
    """Profile auth returns the session's credentials with no TTL."""
    session_instance = MagicMock()
    session_instance.get_credentials.return_value = mock_credentials
    mock_session.return_value = session_instance

    credentials, ttl = base_aws_llm._auth_with_aws_profile("test_profile")

    assert credentials == mock_credentials
    assert ttl is None


# Test session token authentication
def test_auth_with_aws_session_token(base_aws_llm):
    """Explicit key/secret/token are wrapped in Credentials with no TTL."""
    credentials, ttl = base_aws_llm._auth_with_aws_session_token(
        aws_access_key_id="test_access",
        aws_secret_access_key="test_secret",
        aws_session_token="test_token",
    )

    assert isinstance(credentials, Credentials)
    assert credentials.access_key == "test_access"
    assert credentials.secret_key == "test_secret"
    assert credentials.token == "test_token"
    assert ttl is None


# Test access key and secret key authentication
# NOTE(review): @patch decorator was missing from the captured source; target
# reconstructed on the assumption that base_aws_llm calls ``boto3.Session``.
@patch("boto3.Session")
def test_auth_with_access_key_and_secret_key(
    mock_session, base_aws_llm, mock_credentials
):
    """Key/secret auth returns the session's credentials with the default TTL."""
    session_instance = MagicMock()
    session_instance.get_credentials.return_value = mock_credentials
    mock_session.return_value = session_instance

    credentials, ttl = base_aws_llm._auth_with_access_key_and_secret_key(
        aws_access_key_id="test_access",
        aws_secret_access_key="test_secret",
        aws_region_name="us-west-2",
    )

    assert credentials == mock_credentials
    assert ttl == 3540  # default TTL (3600 - 60)


# Test environment variables authentication
# NOTE(review): @patch decorator was missing from the captured source; target
# reconstructed on the assumption that base_aws_llm calls ``boto3.Session``.
@patch("boto3.Session")
def test_auth_with_env_vars(mock_session, base_aws_llm, mock_credentials):
    """Env-var auth returns the default session's credentials with no TTL."""
    session_instance = MagicMock()
    session_instance.get_credentials.return_value = mock_credentials
    mock_session.return_value = session_instance

    credentials, ttl = base_aws_llm._auth_with_env_vars()

    assert credentials == mock_credentials
    assert ttl is None


# Test runtime endpoint resolution
def test_get_runtime_endpoint(base_aws_llm):
    """With no overrides, both URLs are the regional bedrock-runtime host."""
    for region in ("us-west-2", "us-east-1"):
        expected = f"https://bedrock-runtime.{region}.amazonaws.com"
        endpoint_url, proxy_endpoint_url = base_aws_llm.get_runtime_endpoint(
            api_base=None,
            aws_bedrock_runtime_endpoint=None,
            aws_region_name=region,
        )
        assert endpoint_url == expected
        assert proxy_endpoint_url == expected


# autouse=True is required: no test requests this fixture by name, so without
# it the cache would never be cleared despite the docstring's promise.
@pytest.fixture(autouse=True)
def clear_cache(base_aws_llm):
    """Clear the cache before each test"""
    base_aws_llm.iam_cache.in_memory_cache.cache_dict = {}
    yield