ethiotech4848 committed on
Commit
b06ae13
·
verified ·
1 Parent(s): 0ecbccb

Create usage_inference.py

Browse files
Files changed (1) hide show
  1. usage_inference.py +158 -0
usage_inference.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import json
import time
from typing import Any, Dict, Optional

import requests

class APITester:
    """Smoke-test client for an OpenAI-style HTTP API.

    Every ``test_*`` method calls one endpoint, prints what came back and
    then prints a pass/fail line. Failures are reported on stdout instead of
    being raised, so one broken endpoint does not stop the remaining checks.

    Args:
        base_url: Root URL of the API under test.
        timeout: Seconds to wait for the server before a request is
            abandoned; without it a stalled server would hang the script.
        stream_delay: Seconds to pause after each printed streaming
            fragment (purely cosmetic pacing); ``0`` disables the pause.
        session: Object exposing ``get``/``post`` like ``requests.Session``.
            A fresh ``requests.Session`` is created when omitted.
    """

    def __init__(
        self,
        base_url: str = "http://localhost:8000",
        *,
        timeout: float = 30.0,
        stream_delay: float = 0.1,
        session: Optional["requests.Session"] = None,
    ):
        self.base_url = base_url
        self.timeout = timeout
        self.stream_delay = stream_delay
        self.session = requests.Session() if session is None else session

    def _url(self, path: str) -> str:
        """Join *path* onto the base URL without producing a double slash."""
        return f"{self.base_url.rstrip('/')}{path}"

    @staticmethod
    def _require_ok(response: Any) -> None:
        """Raise ``RuntimeError`` unless *response* has status code 200.

        An explicit raise is used instead of ``assert`` because assertions
        are removed when Python runs with ``-O``.
        """
        if response.status_code != 200:
            raise RuntimeError(f"unexpected status code {response.status_code}")

    @staticmethod
    def _delta_content(event_data: str) -> Optional[str]:
        """Return the text fragment carried by one streamed event, if any.

        Returns ``None`` when *event_data* is not JSON or carries no
        ``choices[0].delta.content`` value.
        """
        try:
            event = json.loads(event_data)
        except json.JSONDecodeError:
            return None
        if not isinstance(event, dict):
            return None
        choices = event.get("choices") or []
        if not choices:
            return None
        delta = choices[0].get("delta") or {}
        return delta.get("content")

    def test_health_check(self) -> None:
        """Test the health check endpoint."""
        print("\n=== Testing Health Check Endpoint ===")
        try:
            response = self.session.get(
                self._url("/health_check"), timeout=self.timeout
            )
            print(f"Status Code: {response.status_code}")
            # Check the status first so an error page is reported as a bad
            # status rather than as a JSON decoding problem.
            self._require_ok(response)
            print(f"Response: {response.json()}")
            print("✅ Health check test passed!")
        except Exception as e:  # report any failure and keep going
            print(f"❌ Health check test failed: {e}")

    def test_models(self) -> None:
        """Test the models endpoint."""
        print("\n=== Testing Models Endpoint ===")
        try:
            response = self.session.get(self._url("/models"), timeout=self.timeout)
            print(f"Status Code: {response.status_code}")
            self._require_ok(response)
            models = response.json()["data"]
            print(f"Number of models available: {len(models)}")
            print("Sample models:")
            for model in models[:5]:  # Show first 5 models
                print(f"- {model['id']}")
            print("✅ Models endpoint test passed!")
        except Exception as e:  # report any failure and keep going
            print(f"❌ Models endpoint test failed: {e}")

    def test_chat_completions_non_streaming(self) -> None:
        """Test the chat completions endpoint without streaming."""
        print("\n=== Testing Chat Completions Endpoint (Non-Streaming) ===")
        payload = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Tell me a short joke about programming."},
            ],
            "temperature": 0.7,
            "max_tokens": 150,
            "stream": False,
        }

        try:
            response = self.session.post(
                self._url("/chat/completions"),
                json=payload,
                timeout=self.timeout,
            )
            print(f"Status Code: {response.status_code}")
            self._require_ok(response)
            data = response.json()
            print("Response content:")
            print(data["choices"][0]["message"]["content"])
            print("✅ Chat completions (non-streaming) test passed!")
        except Exception as e:  # report any failure and keep going
            print(f"❌ Chat completions (non-streaming) test failed: {e}")

    def test_chat_completions_streaming(self) -> None:
        """Test the chat completions endpoint with streaming."""
        print("\n=== Testing Chat Completions Endpoint (Streaming) ===")
        payload = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Write 5 lines about India"},
            ],
            "temperature": 0.7,
            "max_tokens": 150,
            "stream": True,
        }

        try:
            with self.session.post(
                self._url("/chat/completions"),
                json=payload,
                stream=True,
                headers={"Accept": "text/event-stream"},
                timeout=self.timeout,
            ) as response:
                print(f"Status Code: {response.status_code}")
                # Without this check an error response would fall through
                # the loop below and still be reported as a pass.
                self._require_ok(response)
                print("Streaming response:")

                for raw_line in response.iter_lines():
                    if not raw_line:
                        continue  # blank separator lines between events
                    line = raw_line.decode("utf-8")
                    if line.startswith("data:"):
                        line = line[len("data:"):]
                    line = line.strip()
                    if line == "[DONE]":
                        break
                    content = self._delta_content(line)
                    if content:
                        print(content, end="", flush=True)
                        if self.stream_delay > 0:
                            # Small pause so the output appears gradually.
                            time.sleep(self.stream_delay)

            print("\n✅ Chat completions (streaming) test passed!")
        except Exception as e:  # report any failure and keep going
            print(f"❌ Chat completions (streaming) test failed: {e}")

    def test_developer_info(self) -> None:
        """Test the developer info endpoint."""
        print("\n=== Testing Developer Info Endpoint ===")
        try:
            response = self.session.get(
                self._url("/developer_info"), timeout=self.timeout
            )
            print(f"Status Code: {response.status_code}")
            self._require_ok(response)
            print("Developer Info:")
            print(json.dumps(response.json(), indent=2))
            print("✅ Developer info test passed!")
        except Exception as e:  # report any failure and keep going
            print(f"❌ Developer info test failed: {e}")

    def run_all_tests(self) -> None:
        """Run all tests sequentially and print the total elapsed time."""
        tests = [
            self.test_health_check,
            self.test_models,
            self.test_chat_completions_non_streaming,
            self.test_chat_completions_streaming,
            self.test_developer_info,
        ]

        print("🚀 Starting API Tests...")
        # perf_counter is monotonic, so the duration cannot go negative if
        # the wall clock is adjusted while the tests run.
        start_time = time.perf_counter()

        for test in tests:
            test()

        duration = time.perf_counter() - start_time

        print("\n============================")
        print(f"🏁 All tests completed in {duration:.2f} seconds")
        print("============================")
def main(base_url: str = "http://localhost:8000") -> None:
    """Run every API check against *base_url*.

    Args:
        base_url: Root URL of the API to exercise. Defaults to the local
            development server address.
    """
    # Initialize tester with your API's base URL
    tester = APITester(base_url)

    # Run all tests
    tester.run_all_tests()
# Run the checks only when executed as a script, not when imported.
if __name__ == "__main__":
    main()