Spaces:
Configuration error
Configuration error
File size: 3,685 Bytes
447ebeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import asyncio
import httpx
import json
import pytest
import sys
from typing import Any, Dict, List, Optional
from unittest.mock import MagicMock, Mock, patch
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm.exceptions import BadRequestError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import CustomStreamWrapper
from openai.types.image import Image
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import StandardLoggingPayload
class TestCustomLogger(CustomLogger):
    """Callback that records the standard logging payload of the latest success event.

    The image-generation tests register an instance in ``litellm.callbacks`` and
    afterwards read ``standard_logging_payload`` to inspect what was logged.
    """

    # The name starts with "Test" and the class defines __init__, so pytest would
    # otherwise try to collect it as a test class and emit a
    # PytestCollectionWarning in every test module that imports it.
    __test__ = False

    def __init__(self):
        super().__init__()
        # Payload captured by async_log_success_event; stays None until one fires.
        self.standard_logging_payload: Optional[StandardLoggingPayload] = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Store ``kwargs["standard_logging_object"]`` (None if the key is absent)."""
        self.standard_logging_payload = kwargs.get("standard_logging_object")
# test_example.py
from abc import ABC, abstractmethod
class BaseImageGenTest(ABC):
    """
    Abstract base test class that enforces a common test across all test classes.

    Subclasses provide provider-specific keyword arguments through
    ``get_base_image_generation_call_args``. The shared
    ``test_basic_image_generation`` then calls ``litellm.aimage_generation``
    with them.
    """

    @abstractmethod
    def get_base_image_generation_call_args(self) -> dict:
        """Must return the base image generation call args"""
        pass

    @pytest.mark.asyncio(scope="module")
    async def test_basic_image_generation(self):
        """Test basic image generation.

        Verifies that:

        - the success callback receives a standard logging payload with a
          positive ``response_cost``;
        - the response validates as an OpenAI ``ImagesResponse``;
        - every returned item is an ``Image`` carrying ``b64_json`` or ``url``.

        Flaky provider errors skip the test rather than letting it pass
        silently: rate limits, content-policy rejections, internal server
        errors and safety-system refusals.
        """
        custom_logger = TestCustomLogger()
        litellm.logging_callback_manager._reset_all_callbacks()
        litellm.callbacks = [custom_logger]
        base_image_generation_call_args = self.get_base_image_generation_call_args()
        litellm.set_verbose = True

        # Only the provider call can raise the tolerated errors, so the try is
        # kept narrow. Assertion failures below then surface with pytest's
        # normal introspection instead of being rewrapped by a broad handler.
        try:
            response = await litellm.aimage_generation(
                **base_image_generation_call_args, prompt="A image of a otter"
            )
        except litellm.RateLimitError as e:
            pytest.skip(f"Rate limited by provider - {e}")
        except litellm.ContentPolicyViolationError as e:
            # Azure randomly raises these errors - skip when they occur
            pytest.skip(f"Content policy violation - {e}")
        except litellm.InternalServerError as e:
            pytest.skip(f"Provider internal server error - {e}")
        except Exception as e:
            if "Your task failed as a result of our safety system." in str(e):
                pytest.skip(f"Rejected by provider safety system - {e}")
            raise  # unexpected failure: keep the original traceback
        print(response)

        # Give the success callback time to populate the payload before it is
        # read. NOTE(review): assumes the callback finishes within 1 second.
        await asyncio.sleep(1)

        # Cost is checked through the logged standard payload.
        logged_standard_logging_payload = custom_logger.standard_logging_payload
        print("logged_standard_logging_payload", logged_standard_logging_payload)
        assert logged_standard_logging_payload is not None
        assert logged_standard_logging_payload["response_cost"] is not None
        assert logged_standard_logging_payload["response_cost"] > 0

        import openai
        from openai.types.images_response import ImagesResponse

        # print openai version
        print("openai version=", openai.__version__)

        response_dict = dict(response)
        # `usage` may be present but None; only convert an actual usage object.
        if response_dict.get("usage") is not None:
            response_dict["usage"] = dict(response_dict["usage"])
        print("response usage=", response_dict.get("usage"))
        ImagesResponse.model_validate(response_dict)

        for d in response.data:
            assert isinstance(d, Image)
            print("data in response.data", d)
            assert d.b64_json is not None or d.url is not None