# What is this?
## Unit Tests for OpenAI Batches API
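#
# To run this suite locally (illustrative invocation - substitute the path to this
# file in your checkout):
#   pytest -x -v <path-to-this-file>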
import asyncio
import json
import os
import sys
import traceback
import tempfile

from dotenv import load_dotenv

load_dotenv()
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import logging
import time
import pytest
from typing import Optional

import litellm
from litellm import create_batch, create_file
from litellm._logging import verbose_logger

verbose_logger.setLevel(logging.DEBUG)

from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import StandardLoggingPayload
import random
from unittest.mock import patch, MagicMock

def load_vertex_ai_credentials():
    # Define the path to the vertex_key.json file
    print("loading vertex ai credentials")
    os.environ["GCS_FLUSH_INTERVAL"] = "1"
    filepath = os.path.dirname(os.path.abspath(__file__))
    vertex_key_path = filepath + "/pathrise-convert-1606954137718.json"

    # Read the existing content of the file or create an empty dictionary
    try:
        with open(vertex_key_path, "r") as file:
            # Read the file content
            print("Read vertexai file path")
            content = file.read()

            # If the file is empty or not valid JSON, create an empty dictionary
            if not content or not content.strip():
                service_account_key_data = {}
            else:
                # Attempt to load the existing JSON content
                file.seek(0)
                service_account_key_data = json.load(file)
    except FileNotFoundError:
        # If the file doesn't exist, create an empty dictionary
        service_account_key_data = {}

    # Update the service_account_key_data with environment variables
    private_key_id = os.environ.get("GCS_PRIVATE_KEY_ID", "")
    private_key = os.environ.get("GCS_PRIVATE_KEY", "")
    private_key = private_key.replace("\\n", "\n")
    service_account_key_data["private_key_id"] = private_key_id
    service_account_key_data["private_key"] = private_key

    # Create a temporary file
    with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
        # Write the updated content to the temporary file
        json.dump(service_account_key_data, temp_file, indent=2)

    # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
    os.environ["GCS_PATH_SERVICE_ACCOUNT"] = os.path.abspath(temp_file.name)
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
    print("created gcs path service account=", os.environ["GCS_PATH_SERVICE_ACCOUNT"])
# , "azure" | |
async def test_create_batch(provider): | |
""" | |
1. Create File for Batch completion | |
2. Create Batch Request | |
3. Retrieve the specific batch | |
""" | |
if provider == "azure": | |
# Don't have anymore Azure Quota | |
return | |
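    # For reference, each line of openai_batch_completions.jsonl is expected to be an
    # OpenAI-style batch request object. Illustrative example (not the file's actual
    # contents):
    # {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions",
    #  "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]}}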
file_name = "openai_batch_completions.jsonl" | |
_current_dir = os.path.dirname(os.path.abspath(__file__)) | |
file_path = os.path.join(_current_dir, file_name) | |
file_obj = await litellm.acreate_file( | |
file=open(file_path, "rb"), | |
purpose="batch", | |
custom_llm_provider=provider, | |
) | |
print("Response from creating file=", file_obj) | |
batch_input_file_id = file_obj.id | |
assert ( | |
batch_input_file_id is not None | |
), "Failed to create file, expected a non null file_id but got {batch_input_file_id}" | |
await asyncio.sleep(1) | |
create_batch_response = await litellm.acreate_batch( | |
completion_window="24h", | |
endpoint="/v1/chat/completions", | |
input_file_id=batch_input_file_id, | |
custom_llm_provider=provider, | |
metadata={"key1": "value1", "key2": "value2"}, | |
) | |
print("response from litellm.create_batch=", create_batch_response) | |
await asyncio.sleep(6) | |
assert ( | |
create_batch_response.id is not None | |
), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}" | |
assert ( | |
create_batch_response.endpoint == "/v1/chat/completions" | |
or create_batch_response.endpoint == "/chat/completions" | |
), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}" | |
assert ( | |
create_batch_response.input_file_id == batch_input_file_id | |
), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}" | |
retrieved_batch = await litellm.aretrieve_batch( | |
batch_id=create_batch_response.id, custom_llm_provider=provider | |
) | |
print("retrieved batch=", retrieved_batch) | |
# just assert that we retrieved a non None batch | |
assert retrieved_batch.id == create_batch_response.id | |
# list all batches | |
list_batches = await litellm.alist_batches(custom_llm_provider=provider, limit=2) | |
print("list_batches=", list_batches) | |
file_content = await litellm.afile_content( | |
file_id=batch_input_file_id, custom_llm_provider=provider | |
) | |
result = file_content.content | |
result_file_name = "batch_job_results_furniture.jsonl" | |
with open(result_file_name, "wb") as file: | |
file.write(result) | |
# Cancel Batch | |
cancel_batch_response = await litellm.acancel_batch( | |
batch_id=create_batch_response.id, | |
custom_llm_provider=provider, | |
) | |
print("cancel_batch_response=", cancel_batch_response) | |
pass | |
class TestCustomLogger(CustomLogger):
    def __init__(self):
        super().__init__()
        self.standard_logging_object: Optional[StandardLoggingPayload] = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(
            "Success event logged with kwargs=",
            kwargs,
            "and response_obj=",
            response_obj,
        )
        self.standard_logging_object = kwargs["standard_logging_object"]
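

# standard_logging_object is litellm's normalized per-request log payload; the async
# batch test below registers this logger and asserts that fields passed via
# `litellm_metadata` propagate into the captured payload.
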
def cleanup_azure_files():
    """
    Delete all files for Azure - helper for when we run out of Azure Files quota
    """
    azure_files = litellm.file_list(
        custom_llm_provider="azure",
        api_key=os.getenv("AZURE_FT_API_KEY"),
        api_base=os.getenv("AZURE_FT_API_BASE"),
    )
    print("azure_files=", azure_files)
    for _file in azure_files:
        print("deleting file=", _file)
        delete_file_response = litellm.file_delete(
            file_id=_file.id,
            custom_llm_provider="azure",
            api_key=os.getenv("AZURE_FT_API_KEY"),
            api_base=os.getenv("AZURE_FT_API_BASE"),
        )
        print("delete_file_response=", delete_file_response)
        assert delete_file_response.id == _file.id

def cleanup_azure_ft_models():
    """
    Test CLEANUP: Delete all existing fine-tuning jobs for Azure
    """
    try:
        from openai import AzureOpenAI
        import requests

        client = AzureOpenAI(
            api_key=os.getenv("AZURE_FT_API_KEY"),
            azure_endpoint=os.getenv("AZURE_FT_API_BASE"),
            api_version=os.getenv("AZURE_API_VERSION"),
        )
        _list_ft_jobs = client.fine_tuning.jobs.list()
        print("_list_ft_jobs=", _list_ft_jobs)

        # Delete all fine-tuning jobs via direct REST calls against the Azure endpoint
        for job in _list_ft_jobs:
            try:
                endpoint = os.getenv("AZURE_FT_API_BASE").rstrip("/")
                url = f"{endpoint}/openai/fine_tuning/jobs/{job.id}?api-version=2024-10-21"
                print("url=", url)
                headers = {
                    "api-key": os.getenv("AZURE_FT_API_KEY"),
                    "Content-Type": "application/json",
                }
                response = requests.delete(url, headers=headers)
                print(f"Deleting job {job.id}: Status {response.status_code}")
                if response.status_code != 204:
                    print(f"Error deleting job {job.id}: {response.text}")
            except Exception as e:
                print(f"Error deleting job {job.id}: {str(e)}")
    except Exception as e:
        print(f"Error on cleanup_azure_ft_models: {str(e)}")

@pytest.mark.parametrize("provider", ["openai"])  # assumption: provider list mirrors test_create_batch above
@pytest.mark.asyncio
async def test_async_create_batch(provider):
""" | |
1. Create File for Batch completion | |
2. Create Batch Request | |
3. Retrieve the specific batch | |
""" | |
litellm._turn_on_debug() | |
print("Testing async create batch") | |
litellm.logging_callback_manager._reset_all_callbacks() | |
custom_logger = TestCustomLogger() | |
litellm.callbacks = [custom_logger, "datadog"] | |
file_name = "openai_batch_completions.jsonl" | |
_current_dir = os.path.dirname(os.path.abspath(__file__)) | |
file_path = os.path.join(_current_dir, file_name) | |
file_obj = await litellm.acreate_file( | |
file=open(file_path, "rb"), | |
purpose="batch", | |
custom_llm_provider=provider, | |
) | |
print("Response from creating file=", file_obj) | |
await asyncio.sleep(10) | |
batch_input_file_id = file_obj.id | |
assert ( | |
batch_input_file_id is not None | |
), "Failed to create file, expected a non null file_id but got {batch_input_file_id}" | |
extra_metadata_field = { | |
"user_api_key_alias": "special_api_key_alias", | |
"user_api_key_team_alias": "special_team_alias", | |
} | |
create_batch_response = await litellm.acreate_batch( | |
completion_window="24h", | |
endpoint="/v1/chat/completions", | |
input_file_id=batch_input_file_id, | |
custom_llm_provider=provider, | |
metadata={"key1": "value1", "key2": "value2"}, | |
# litellm specific param - used for logging metadata on logging callback | |
litellm_metadata=extra_metadata_field, | |
) | |
print("response from litellm.create_batch=", create_batch_response) | |
assert ( | |
create_batch_response.id is not None | |
), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}" | |
assert ( | |
create_batch_response.endpoint == "/v1/chat/completions" | |
or create_batch_response.endpoint == "/chat/completions" | |
), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}" | |
assert ( | |
create_batch_response.input_file_id == batch_input_file_id | |
), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}" | |
await asyncio.sleep(6) | |
# Assert that the create batch event is logged on CustomLogger | |
assert custom_logger.standard_logging_object is not None | |
print( | |
"standard_logging_object=", | |
json.dumps(custom_logger.standard_logging_object, indent=4, default=str), | |
) | |
assert ( | |
custom_logger.standard_logging_object["metadata"]["user_api_key_alias"] | |
== extra_metadata_field["user_api_key_alias"] | |
) | |
assert ( | |
custom_logger.standard_logging_object["metadata"]["user_api_key_team_alias"] | |
== extra_metadata_field["user_api_key_team_alias"] | |
) | |
retrieved_batch = await litellm.aretrieve_batch( | |
batch_id=create_batch_response.id, custom_llm_provider=provider | |
) | |
print("retrieved batch=", retrieved_batch) | |
# just assert that we retrieved a non None batch | |
assert retrieved_batch.id == create_batch_response.id | |
# list all batches | |
list_batches = await litellm.alist_batches(custom_llm_provider=provider, limit=2) | |
print("list_batches=", list_batches) | |
# try to get file content for our original file | |
file_content = await litellm.afile_content( | |
file_id=batch_input_file_id, custom_llm_provider=provider | |
) | |
print("file content = ", file_content) | |
# file obj | |
file_obj = await litellm.afile_retrieve( | |
file_id=batch_input_file_id, custom_llm_provider=provider | |
) | |
print("file obj = ", file_obj) | |
assert file_obj.id == batch_input_file_id | |
# delete file | |
delete_file_response = await litellm.afile_delete( | |
file_id=batch_input_file_id, custom_llm_provider=provider | |
) | |
print("delete file response = ", delete_file_response) | |
assert delete_file_response.id == batch_input_file_id | |
all_files_list = await litellm.afile_list( | |
custom_llm_provider=provider, | |
) | |
print("all_files_list = ", all_files_list) | |
result_file_name = "batch_job_results_furniture.jsonl" | |
with open(result_file_name, "wb") as file: | |
file.write(file_content.content) | |
# Cancel Batch | |
cancel_batch_response = await litellm.acancel_batch( | |
batch_id=create_batch_response.id, | |
custom_llm_provider=provider, | |
) | |
print("cancel_batch_response=", cancel_batch_response) | |
    if random.randint(1, 3) == 1:
        print("Running random cleanup of Azure files and models...")
        cleanup_azure_files()
        cleanup_azure_ft_models()

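# Canned responses for the mocked Vertex AI test below, which never talks to GCP:
# AsyncHTTPHandler.post/get are patched so that the file upload returns this GCS
# `storage#object` metadata payload, and batch create/retrieve return the
# `batchPredictionJobs` payload that follows it.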
mock_file_response = {
    "kind": "storage#object",
    "id": "litellm-local/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/5f7b99ad-9203-4430-98bf-3b45451af4cb/1739598666670574",
    "selfLink": "https://www.googleapis.com/storage/v1/b/litellm-local/o/litellm-vertex-files%2Fpublishers%2Fgoogle%2Fmodels%2Fgemini-1.5-flash-001%2F5f7b99ad-9203-4430-98bf-3b45451af4cb",
    "mediaLink": "https://storage.googleapis.com/download/storage/v1/b/litellm-local/o/litellm-vertex-files%2Fpublishers%2Fgoogle%2Fmodels%2Fgemini-1.5-flash-001%2F5f7b99ad-9203-4430-98bf-3b45451af4cb?generation=1739598666670574&alt=media",
    "name": "litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/5f7b99ad-9203-4430-98bf-3b45451af4cb",
    "bucket": "litellm-local",
    "generation": "1739598666670574",
    "metageneration": "1",
    "contentType": "application/json",
    "storageClass": "STANDARD",
    "size": "416",
    "md5Hash": "hbBNj7C8KJ7oVH+JmyRM6A==",
    "crc32c": "oDmiUA==",
    "etag": "CO7D0IT+xIsDEAE=",
    "timeCreated": "2025-02-15T05:51:06.741Z",
    "updated": "2025-02-15T05:51:06.741Z",
    "timeStorageClassUpdated": "2025-02-15T05:51:06.741Z",
    "timeFinalized": "2025-02-15T05:51:06.741Z",
}
mock_vertex_batch_response = {
    "name": "projects/123456789/locations/us-central1/batchPredictionJobs/test-batch-id-456",
    "displayName": "litellm_batch_job",
    "model": "projects/123456789/locations/us-central1/models/gemini-1.5-flash-001",
    "modelVersionId": "v1",
    "inputConfig": {
        "gcsSource": {
            "uris": [
                "gs://litellm-local/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/5f7b99ad-9203-4430-98bf-3b45451af4cb"
            ]
        }
    },
    "outputConfig": {
        "gcsDestination": {"outputUriPrefix": "gs://litellm-local/batch-outputs/"}
    },
    "dedicatedResources": {
        "machineSpec": {
            "machineType": "n1-standard-4",
            "acceleratorType": "NVIDIA_TESLA_T4",
            "acceleratorCount": 1,
        },
        "startingReplicaCount": 1,
        "maxReplicaCount": 1,
    },
    "state": "JOB_STATE_RUNNING",
    "createTime": "2025-02-15T05:51:06.741Z",
    "startTime": "2025-02-15T05:51:07.741Z",
    "updateTime": "2025-02-15T05:51:08.741Z",
    "labels": {"key1": "value1", "key2": "value2"},
    "completionStats": {"successfulCount": 0, "failedCount": 0, "remainingCount": 100},
}

@pytest.mark.asyncio
async def test_avertex_batch_prediction(monkeypatch):
    monkeypatch.setenv("GCS_BUCKET_NAME", "litellm-local")
    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

    client = AsyncHTTPHandler()
    async def mock_side_effect(*args, **kwargs):
        print("args", args, "kwargs", kwargs)
        url = kwargs.get("url", "")
        if "files" in url:
            mock_response.json.return_value = mock_file_response
        elif "batch" in url:
            mock_response.json.return_value = mock_vertex_batch_response
        mock_response.status_code = 200
        return mock_response
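    # Note: mock_side_effect closes over `mock_response`, which is only created inside
    # the `with` block below. Python resolves closure names at call time, so the
    # MagicMock exists by the time the mocked post() actually runs.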
    with patch.object(
        client, "post", side_effect=mock_side_effect
    ) as mock_post, patch(
        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post"
    ) as mock_global_post:
        # Configure mock responses
        mock_response = MagicMock()
        mock_response.raise_for_status.return_value = None

        # Set up different responses for different API calls
        mock_post.side_effect = mock_side_effect
        mock_global_post.side_effect = mock_side_effect

        # load_vertex_ai_credentials()
        litellm.set_verbose = True
        litellm._turn_on_debug()

        file_name = "vertex_batch_completions.jsonl"
        _current_dir = os.path.dirname(os.path.abspath(__file__))
        file_path = os.path.join(_current_dir, file_name)

        # Create file
        file_obj = await litellm.acreate_file(
            file=open(file_path, "rb"),
            purpose="batch",
            custom_llm_provider="vertex_ai",
            client=client,
        )
        print("Response from creating file=", file_obj)
        assert (
            file_obj.id
            == "gs://litellm-local/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/5f7b99ad-9203-4430-98bf-3b45451af4cb"
        )

        # Create batch
        create_batch_response = await litellm.acreate_batch(
            completion_window="24h",
            endpoint="/v1/chat/completions",
            input_file_id=file_obj.id,
            custom_llm_provider="vertex_ai",
            metadata={"key1": "value1", "key2": "value2"},
        )
        print("create_batch_response=", create_batch_response)
        assert create_batch_response.id == "test-batch-id-456"
        assert (
            create_batch_response.input_file_id
            == "gs://litellm-local/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/5f7b99ad-9203-4430-98bf-3b45451af4cb"
        )

        # Mock the retrieve-batch response
        with patch(
            "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.get"
        ) as mock_get:
            mock_get_response = MagicMock()
            mock_get_response.json.return_value = mock_vertex_batch_response
            mock_get_response.status_code = 200
            mock_get_response.raise_for_status.return_value = None
            mock_get.return_value = mock_get_response

            retrieved_batch = await litellm.aretrieve_batch(
                batch_id=create_batch_response.id,
                custom_llm_provider="vertex_ai",
            )
            print("retrieved_batch=", retrieved_batch)
            assert retrieved_batch.id == "test-batch-id-456"