import json
import time
from typing import Any, Callable, Dict, Iterable, Iterator

import requests

class APITester:
    """Smoke tests for an OpenAI-style HTTP API.

    Each ``test_*`` method sends one request through a shared
    ``requests.Session``, prints what came back and prints a ✅/❌ verdict.
    Failures are caught and reported instead of raised, so one broken
    endpoint does not stop the remaining checks. Every ``test_*`` method
    returns ``True`` when it passed and ``False`` when it failed.
    """

    def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 60.0):
        """Bind the tester to an API root.

        Args:
            base_url: Root URL of the API. A trailing ``/`` is dropped so
                endpoint paths can be appended without producing ``//``.
            timeout: Seconds ``requests`` waits to connect and between reads
                before failing. ``requests`` applies no timeout by default,
                so without one an unresponsive server hangs the whole run.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.session = requests.Session()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _run_check(label: str, check: Callable[[], None]) -> bool:
        """Run *check*, print a pass/fail line for *label*, and return the verdict.

        Any exception raised by *check* counts as a failure and is printed,
        not re-raised. This covers network errors, bad status codes and
        missing payload keys.
        """
        try:
            check()
        except Exception as exc:  # reporting boundary: record the failure, keep going
            print(f"❌ {label} test failed: {exc}")
            return False
        print(f"✅ {label} test passed!")
        return True

    @staticmethod
    def _expect_status(response: Any, expected: int = 200) -> None:
        """Raise ``AssertionError`` unless ``response.status_code == expected``.

        An explicit check replaces ``assert`` because assert statements are
        stripped under ``python -O``, which would make every test pass. The
        first 200 characters of the body go into the message for diagnosis.
        """
        if response.status_code != expected:
            raise AssertionError(
                f"expected HTTP {expected}, got {response.status_code}: "
                f"{response.text[:200]}"
            )

    @staticmethod
    def _iter_stream_content(lines: Iterable[Any]) -> Iterator[str]:
        """Yield the non-empty ``choices[0].delta.content`` strings from SSE lines.

        *lines* may be ``bytes`` (UTF-8 decoded here) or ``str``. Only lines
        starting with ``data:`` are considered. Blank keep-alives and other
        lines are skipped, and iteration stops at ``data: [DONE]``. Payloads
        that are not JSON objects, or that have no non-empty string content
        (e.g. a role-only delta or ``"content": null``), are skipped.
        """
        for raw in lines:
            if not raw:
                continue
            line = raw.decode("utf-8") if isinstance(raw, bytes) else raw
            if not line.startswith("data:"):
                continue
            # SSE allows one optional space after "data:"; strip() covers it.
            payload = line[len("data:"):].strip()
            if payload == "[DONE]":
                return
            try:
                event = json.loads(payload)
            except json.JSONDecodeError:
                continue
            choices = event.get("choices") if isinstance(event, dict) else None
            if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
                continue
            delta = choices[0].get("delta")
            content = delta.get("content") if isinstance(delta, dict) else None
            if isinstance(content, str) and content:
                yield content

    # -------------------------------------------------------------------- tests

    def test_health_check(self) -> bool:
        """Check that ``GET /health_check`` answers 200 and print its JSON body."""
        print("\n=== Testing Health Check Endpoint ===")

        def check() -> None:
            response = self.session.get(f"{self.base_url}/health_check", timeout=self.timeout)
            print(f"Status Code: {response.status_code}")
            # Check the status before parsing, so an error page is reported as a
            # bad status rather than as a confusing JSON decode error.
            self._expect_status(response)
            print(f"Response: {response.json()}")

        return self._run_check("Health check", check)

    def test_models(self) -> bool:
        """Check that ``GET /models`` answers 200 and print up to five model ids."""
        print("\n=== Testing Models Endpoint ===")

        def check() -> None:
            response = self.session.get(f"{self.base_url}/models", timeout=self.timeout)
            print(f"Status Code: {response.status_code}")
            self._expect_status(response)
            models = response.json()["data"]
            print(f"Number of models available: {len(models)}")
            print("Sample models:")
            for model in models[:5]:  # Show first 5 models
                print(f"- {model['id']}")

        return self._run_check("Models endpoint", check)

    def test_chat_completions_non_streaming(self) -> bool:
        """Check that a non-streaming ``POST /chat/completions`` answers 200 and print the reply."""
        print("\n=== Testing Chat Completions Endpoint (Non-Streaming) ===")
        payload: Dict[str, Any] = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Tell me a short joke about programming."},
            ],
            "temperature": 0.7,
            "max_tokens": 150,
            "stream": False,
        }

        def check() -> None:
            response = self.session.post(
                f"{self.base_url}/chat/completions",
                json=payload,
                timeout=self.timeout,
            )
            print(f"Status Code: {response.status_code}")
            self._expect_status(response)
            data = response.json()
            print("Response content:")
            print(data["choices"][0]["message"]["content"])

        return self._run_check("Chat completions (non-streaming)", check)

    def test_chat_completions_streaming(self) -> bool:
        """Check that a streaming ``POST /chat/completions`` answers 200 and echo the streamed text."""
        print("\n=== Testing Chat Completions Endpoint (Streaming) ===")
        payload: Dict[str, Any] = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Write 5 lines about India"},
            ],
            "temperature": 0.7,
            "max_tokens": 150,
            "stream": True,
        }

        def check() -> None:
            with self.session.post(
                f"{self.base_url}/chat/completions",
                json=payload,
                stream=True,
                headers={"Accept": "text/event-stream"},
                timeout=self.timeout,
            ) as response:
                print(f"Status Code: {response.status_code}")
                # Without this check an error response would be read as an
                # empty stream and reported as a pass.
                self._expect_status(response)
                print("Streaming response:")
                for piece in self._iter_stream_content(response.iter_lines()):
                    # Flush per piece so the text appears as it arrives.
                    print(piece, end="", flush=True)
            print()  # end the streamed line before the verdict

        return self._run_check("Chat completions (streaming)", check)

    def test_developer_info(self) -> bool:
        """Check that ``GET /developer_info`` answers 200 and pretty-print its JSON body."""
        print("\n=== Testing Developer Info Endpoint ===")

        def check() -> None:
            response = self.session.get(f"{self.base_url}/developer_info", timeout=self.timeout)
            print(f"Status Code: {response.status_code}")
            self._expect_status(response)
            print("Developer Info:")
            print(json.dumps(response.json(), indent=2))

        return self._run_check("Developer info", check)

    def run_all_tests(self) -> bool:
        """Run every check in order, then print the elapsed time and pass count.

        Returns:
            ``True`` if every check passed, ``False`` otherwise.
        """
        tests = [
            self.test_health_check,
            self.test_models,
            self.test_chat_completions_non_streaming,
            self.test_chat_completions_streaming,
            self.test_developer_info,
        ]

        print("🚀 Starting API Tests...")
        # perf_counter is monotonic, so the duration is unaffected by clock changes.
        start_time = time.perf_counter()
        results = [test() for test in tests]
        duration = time.perf_counter() - start_time

        passed = sum(1 for ok in results if ok)
        print("\n============================")
        print(f"🏁 All tests completed in {duration:.2f} seconds")
        print(f"   {passed}/{len(results)} passed")
        print("============================")
        return passed == len(results)

def main():
    """Run the full API smoke-test suite against a local server."""
    # Initialize tester with your API's base URL
    tester = APITester("http://localhost:8000")
    try:
        tester.run_all_tests()
    finally:
        # Release the pooled connections held by the tester's requests.Session,
        # even if the run is interrupted.
        tester.session.close()

# Run the smoke tests only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()