Spaces:
Configuration error
Configuration error
File size: 3,387 Bytes
447ebeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock, Mock, patch
import pytest
import base64
import httpx
# Put the directory two levels above the *current working directory* first on
# the import path (os.pardir is ".."). Note this is resolved against the CWD,
# not against this file's location. Presumably this lets a local litellm
# checkout shadow an installed copy when tests are run from tests/<subdir>/ —
# confirm against how the suite is invoked.
sys.path.insert(0, os.path.abspath(os.path.join(os.pardir, os.pardir)))
import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
# Canned Bedrock response bodies that the mocked HTTP client returns below.
# Each carries a 3-element vector, which the test's length assertion relies on.

# Titan-style payload: a single flat vector under "embedding".
titan_embedding_response = dict(
    embedding=[0.1, 0.2, 0.3],
    inputTextTokenCount=10,
)

# Cohere-style payload: a list of vectors under "embeddings".
cohere_embedding_response = dict(
    embeddings=[[0.1, 0.2, 0.3]],
    inputTextTokenCount=10,
)

# Input sent for the image-model case. NOTE(review): currently an empty
# string; presumably meant to hold a base64-encoded image — confirm.
img_base_64 = ''
@pytest.mark.parametrize(
    "model,input_type,embed_response",
    [
        ("bedrock/amazon.titan-embed-text-v1", "text", titan_embedding_response),  # V1 text model
        ("bedrock/amazon.titan-embed-text-v2:0", "text", titan_embedding_response),  # V2 text model
        ("bedrock/amazon.titan-embed-image-v1", "image", titan_embedding_response),  # Image model
        ("bedrock/cohere.embed-english-v3", "text", cohere_embedding_response),  # Cohere English
        ("bedrock/cohere.embed-multilingual-v3", "text", cohere_embedding_response),  # Cohere Multilingual
    ],
)
def test_bedrock_embedding_models(model, input_type, embed_response):
    """Test embedding for several Bedrock models against a mocked HTTP client.

    ``HTTPHandler.post`` is patched to return ``embed_response`` serialized as
    JSON, so no network request is made. The test checks that
    ``litellm.embedding`` turns the mocked payload into an
    ``EmbeddingResponse`` with a 3-element vector. It also checks that the
    AWS connection kwargs (``aws_region_name``,
    ``aws_bedrock_runtime_endpoint``) are not copied into the JSON request
    body passed to ``post``.

    Args:
        model: litellm model string of the form ``bedrock/<model-id>``.
        input_type: ``"image"`` sends ``img_base_64``; any other value sends
            a fixed text prompt.
        embed_response: JSON-serializable body the mocked endpoint returns.
    """
    # NOTE(review): this global flag is never reset and leaks into later tests.
    litellm.set_verbose = True
    client = HTTPHandler()

    with patch.object(client, "post") as mock_post:
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.text = json.dumps(embed_response)
        mock_response.json = lambda: json.loads(mock_response.text)
        mock_post.return_value = mock_response

        # Image models get the base64 payload; text models get a plain prompt.
        input_data = img_base_64 if input_type == "image" else "Hello world from litellm"

        # Exceptions are deliberately NOT caught here. Letting them propagate
        # keeps pytest's full traceback and assertion introspection. Wrapping
        # them in pytest.fail would collapse every failure into one string.
        response = litellm.embedding(
            model=model,
            input=input_data,
            client=client,
            aws_region_name="us-west-2",
            aws_bedrock_runtime_endpoint="https://bedrock-runtime.us-west-2.amazonaws.com",
        )

        # Verify response structure
        assert isinstance(response, litellm.EmbeddingResponse)
        print(response.data)
        assert isinstance(response.data[0]['embedding'], list)
        assert len(response.data[0]['embedding']) == 3  # Based on mock response

        # Fetch the request body. Fail with a clear message (not an
        # AttributeError on None) if the mocked client was never used.
        assert mock_post.call_args is not None, "HTTPHandler.post was never called"
        request_data = json.loads(mock_post.call_args.kwargs["data"])

        # Verify AWS params are not in request body
        aws_params = ["aws_region_name", "aws_bedrock_runtime_endpoint"]
        for param in aws_params:
            assert param not in request_data, f"AWS param {param} should not be in request body"
|