import datetime
import json
import os
import sys
import unittest
from unittest.mock import ANY, MagicMock, patch

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the directory two levels above the current working directory to the system-path

from litellm.integrations.athina import AthinaLogger


class TestAthinaLogger(unittest.TestCase):
    """Tests for ``AthinaLogger`` with ``litellm.module_level_client.post`` mocked.

    Every test runs with ``ATHINA_API_KEY`` and ``ATHINA_BASE_URL`` patched
    into ``os.environ`` and calls ``log_event`` with a one-second window
    between ``start_time`` and ``end_time``.
    """

    def setUp(self):
        """Patch the environment and build the logger and shared fixtures."""
        # Set up environment variables for testing
        self.env_patcher = patch.dict(
            "os.environ",
            {
                "ATHINA_API_KEY": "test-api-key",
                "ATHINA_BASE_URL": "https://test.athina.ai",
            },
        )
        self.env_patcher.start()
        # Register the undo immediately: unittest skips tearDown() when
        # setUp() raises, so a failure in AthinaLogger() below would
        # otherwise leave os.environ patched for subsequent tests.
        self.addCleanup(self.env_patcher.stop)

        self.logger = AthinaLogger()

        # Setup common test variables
        self.start_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
        self.end_time = datetime.datetime(2023, 1, 1, 12, 0, 1)
        self.print_verbose = MagicMock()

    def test_init(self):
        """Test the initialization of AthinaLogger"""
        self.assertEqual(self.logger.athina_api_key, "test-api-key")
        self.assertEqual(
            self.logger.athina_logging_url,
            "https://test.athina.ai/api/v1/log/inference",
        )
        self.assertEqual(
            self.logger.headers,
            {"athina-api-key": "test-api-key", "Content-Type": "application/json"},
        )

    @patch("litellm.module_level_client.post")
    def test_log_event_success(self, mock_post):
        """Test successful logging of an event"""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.text = "Success"
        mock_post.return_value = mock_response

        # Create test data
        kwargs = {
            "model": "gpt-4",
            "messages": [{"role": "user", "content": "Hello"}],
            "stream": False,
            "litellm_params": {
                "metadata": {
                    "environment": "test-environment",
                    "prompt_slug": "test-prompt",
                    "customer_id": "test-customer",
                    "customer_user_id": "test-user",
                    "session_id": "test-session",
                    "external_reference_id": "test-ext-ref",
                    "context": "test-context",
                    "expected_response": "test-expected",
                    "user_query": "test-query",
                    "tags": ["test-tag"],
                    "user_feedback": "test-feedback",
                    "model_options": {"test-opt": "test-val"},
                    "custom_attributes": {"test-attr": "test-val"},
                }
            },
        }

        response_obj = MagicMock()
        response_obj.model_dump.return_value = {
            "id": "resp-123",
            "choices": [{"message": {"content": "Hi there"}}],
            "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
        }

        # Call the method
        self.logger.log_event(
            kwargs, response_obj, self.start_time, self.end_time, self.print_verbose
        )

        # Verify the results
        mock_post.assert_called_once()
        call_args = mock_post.call_args
        self.assertEqual(call_args[0][0], "https://test.athina.ai/api/v1/log/inference")
        self.assertEqual(call_args[1]["headers"], self.logger.headers)

        # Parse and verify the sent data
        sent_data = json.loads(call_args[1]["data"])
        self.assertEqual(sent_data["language_model_id"], "gpt-4")
        self.assertEqual(sent_data["prompt"], kwargs["messages"])
        self.assertEqual(sent_data["prompt_tokens"], 10)
        self.assertEqual(sent_data["completion_tokens"], 5)
        self.assertEqual(sent_data["total_tokens"], 15)
        self.assertEqual(sent_data["response_time"], 1000)  # 1 second = 1000ms
        self.assertEqual(sent_data["customer_id"], "test-customer")
        self.assertEqual(sent_data["session_id"], "test-session")
        self.assertEqual(sent_data["environment"], "test-environment")
        self.assertEqual(sent_data["prompt_slug"], "test-prompt")
        self.assertEqual(sent_data["external_reference_id"], "test-ext-ref")
        self.assertEqual(sent_data["context"], "test-context")
        self.assertEqual(sent_data["expected_response"], "test-expected")
        self.assertEqual(sent_data["user_query"], "test-query")
        self.assertEqual(sent_data["tags"], ["test-tag"])
        self.assertEqual(sent_data["user_feedback"], "test-feedback")
        self.assertEqual(sent_data["model_options"], {"test-opt": "test-val"})
        self.assertEqual(sent_data["custom_attributes"], {"test-attr": "test-val"})

        # Verify the print_verbose was called
        self.print_verbose.assert_called_once_with("Athina Logger Succeeded - Success")

    @patch("litellm.module_level_client.post")
    def test_log_event_error_response(self, mock_post):
        """Test handling of error response from the API"""
        # Setup mock error response
        mock_response = MagicMock()
        mock_response.status_code = 400
        mock_response.text = "Bad Request"
        mock_post.return_value = mock_response

        # Create test data
        kwargs = {
            "model": "gpt-4",
            "messages": [{"role": "user", "content": "Hello"}],
            "stream": False,
        }

        response_obj = MagicMock()
        response_obj.model_dump.return_value = {
            "id": "resp-123",
            "choices": [{"message": {"content": "Hi there"}}],
            "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
        }

        # Call the method
        self.logger.log_event(
            kwargs, response_obj, self.start_time, self.end_time, self.print_verbose
        )

        # Verify print_verbose was called with error message
        self.print_verbose.assert_called_once_with(
            "Athina Logger Error - Bad Request, 400"
        )

    @patch("litellm.module_level_client.post")
    def test_log_event_exception(self, mock_post):
        """Test handling of exceptions during logging"""
        # Setup mock to raise exception
        mock_post.side_effect = Exception("Test exception")

        # Create test data
        kwargs = {
            "model": "gpt-4",
            "messages": [{"role": "user", "content": "Hello"}],
            "stream": False,
        }

        response_obj = MagicMock()
        response_obj.model_dump.return_value = {}

        # Call the method
        self.logger.log_event(
            kwargs, response_obj, self.start_time, self.end_time, self.print_verbose
        )

        # Verify print_verbose was called with exception info
        self.print_verbose.assert_called_once()
        self.assertIn(
            "Athina Logger Error - Test exception", self.print_verbose.call_args[0][0]
        )

    @patch("litellm.module_level_client.post")
    def test_log_event_with_tools(self, mock_post):
        """Test logging with tools/functions data"""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_post.return_value = mock_response

        # Create test data with tools
        kwargs = {
            "model": "gpt-4",
            "messages": [{"role": "user", "content": "What's the weather?"}],
            "stream": False,
            "optional_params": {
                "tools": [{"type": "function", "function": {"name": "get_weather"}}]
            },
        }

        response_obj = MagicMock()
        response_obj.model_dump.return_value = {
            "id": "resp-123",
            "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
        }

        # Call the method
        self.logger.log_event(
            kwargs, response_obj, self.start_time, self.end_time, self.print_verbose
        )

        # Verify the results
        # Check the call happened first: call_args is None on an uncalled
        # mock, and subscripting it would fail with an unhelpful TypeError.
        mock_post.assert_called_once()
        sent_data = json.loads(mock_post.call_args[1]["data"])
        self.assertEqual(
            sent_data["tools"],
            [{"type": "function", "function": {"name": "get_weather"}}],
        )


if __name__ == "__main__":
    unittest.main()