DevsDoCode committed on
Commit
e628215
·
verified ·
1 Parent(s): edc4bc9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +168 -159
app.py CHANGED
@@ -1,160 +1,169 @@
1
- from flask import Flask, request, jsonify, Response
2
- from functools import wraps
3
- import uuid
4
- import json
5
- from typing import List, Optional
6
- from pydantic import BaseModel, ValidationError
7
- from API_provider import API_Inference
8
- from core_logic import (
9
- check_api_key_validity,
10
- update_request_count,
11
- get_rate_limit_status,
12
- get_subscription_status,
13
- get_available_models,
14
- get_model_info,
15
- )
16
-
17
- app = Flask(__name__)
18
-
19
- class Message(BaseModel):
20
- role: str
21
- content: str
22
-
23
- class ChatCompletionRequest(BaseModel):
24
- model: str
25
- messages: List[Message]
26
- stream: Optional[bool] = False
27
- max_tokens: Optional[int] = 4000
28
- temperature: Optional[float] = 0.5
29
- top_p: Optional[float] = 0.95
30
-
31
- def get_api_key():
32
- auth_header = request.headers.get('Authorization')
33
- if not auth_header or not auth_header.startswith('Bearer '):
34
- return None
35
- return auth_header.split(' ')[1]
36
-
37
- def requires_api_key(func):
38
- @wraps(func)
39
- def decorated(*args, **kwargs):
40
- api_key = get_api_key()
41
- if not api_key:
42
- return jsonify({'detail': 'Not authenticated'}), 401
43
- kwargs['api_key'] = api_key
44
- return func(*args, **kwargs)
45
- return decorated
46
-
47
- @app.route('/v1/chat/completions', methods=['POST'])
48
- @requires_api_key
49
- def chat_completions(api_key):
50
- try:
51
- # Parse and validate request data
52
- try:
53
- data = request.get_json()
54
- chat_request = ChatCompletionRequest(**data)
55
- except ValidationError as e:
56
- return jsonify({'detail': e.errors()}), 400
57
-
58
- # Check API key validity and rate limit
59
- is_valid, error_message = check_api_key_validity(api_key)
60
- if not is_valid:
61
- return jsonify({'detail': error_message}), 401
62
-
63
- messages = [{"role": msg.role, "content": msg.content} for msg in chat_request.messages]
64
-
65
- # Get model info
66
- model_info = get_model_info(chat_request.model)
67
- if not model_info:
68
- return jsonify({'detail': 'Invalid model specified'}), 400
69
-
70
- # Model mapping
71
- model_mapping = {
72
- "meta-llama-405b-turbo": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
73
- "claude-3.5-sonnet": "claude-3-sonnet-20240229",
74
- }
75
- model_name = model_mapping.get(chat_request.model, chat_request.model)
76
- credits_reduction = {
77
- "gpt-4o": 1,
78
- "claude-3-sonnet-20240229": 1,
79
- "gemini-1.5-pro": 1,
80
- "gemini-1-5-flash": 1,
81
- "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": 1,
82
- "o1-mini": 2,
83
- "o1-preview": 3,
84
- }.get(model_name, 0)
85
-
86
- if chat_request.stream:
87
- def generate():
88
- try:
89
- for chunk in API_Inference(messages, model=model_name, stream=True,
90
- max_tokens=chat_request.max_tokens,
91
- temperature=chat_request.temperature,
92
- top_p=chat_request.top_p):
93
- data = json.dumps({'choices': [{'delta': {'content': chunk}}]})
94
- yield f"data: {data}\n\n"
95
- yield f"data: [DONE]\n\nCredits used: {credits_reduction}\n\n"
96
- update_request_count(api_key, credits_reduction)
97
- except Exception as e:
98
- yield f"data: [ERROR] {str(e)}\n\n"
99
-
100
- return Response(generate(), mimetype='text/event-stream')
101
- else:
102
- response = API_Inference(messages, model=model_name, stream=False,
103
- max_tokens=chat_request.max_tokens,
104
- temperature=chat_request.temperature,
105
- top_p=chat_request.top_p)
106
- update_request_count(api_key, credits_reduction)
107
- prompt_tokens = sum(len(msg['content'].split()) for msg in messages)
108
- completion_tokens = len(response.split())
109
- total_tokens = prompt_tokens + completion_tokens
110
- return jsonify({
111
- "id": f"chatcmpl-{str(uuid.uuid4())}",
112
- "object": "chat.completion",
113
- "created": int(uuid.uuid1().time // 1e7),
114
- "model": model_name,
115
- "choices": [
116
- {
117
- "index": 0,
118
- "message": {
119
- "role": "assistant",
120
- "content": response
121
- },
122
- "finish_reason": "stop"
123
- }
124
- ],
125
- "usage": {
126
- "prompt_tokens": prompt_tokens,
127
- "completion_tokens": completion_tokens,
128
- "total_tokens": total_tokens
129
- },
130
- "credits_used": credits_reduction
131
- })
132
- except Exception as e:
133
- return jsonify({'detail': str(e)}), 500
134
-
135
- @app.route('/rate_limit/status', methods=['GET'])
136
- @requires_api_key
137
- def get_rate_limit_status_endpoint(api_key):
138
- is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
139
- if not is_valid:
140
- return jsonify({'detail': error_message}), 401
141
- return jsonify(get_rate_limit_status(api_key))
142
-
143
- @app.route('/subscription/status', methods=['GET'])
144
- @requires_api_key
145
- def get_subscription_status_endpoint(api_key):
146
- is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
147
- if not is_valid:
148
- return jsonify({'detail': error_message}), 401
149
- return jsonify(get_subscription_status(api_key))
150
-
151
- @app.route('/models', methods=['GET'])
152
- @requires_api_key
153
- def get_available_models_endpoint(api_key):
154
- is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
155
- if not is_valid:
156
- return jsonify({'detail': error_message}), 401
157
- return jsonify({"data": [{"id": model} for model in get_available_models().values()]})
158
-
159
- if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
160
  app.run(host="0.0.0.0", port=8000)
 
1
+ from flask import Flask, request, jsonify, Response
2
+ from functools import wraps
3
+ import uuid
4
+ import json
5
+ from typing import List, Optional
6
+ from pydantic import BaseModel, ValidationError
7
+ import logging
8
+ from API_provider import API_Inference
9
+ from core_logic import (
10
+ check_api_key_validity,
11
+ update_request_count,
12
+ get_rate_limit_status,
13
+ get_subscription_status,
14
+ get_available_models,
15
+ get_model_info,
16
+ )
17
+
18
app = Flask(__name__)
# NOTE(review): DEBUG-level logging is very verbose (and may leak request
# details) for a deployed service — consider INFO; confirm intent first.
logging.basicConfig(level=logging.DEBUG)
20
+
21
class Message(BaseModel):
    """A single chat message: the speaker role plus its text content."""
    role: str     # e.g. "user" / "assistant" / "system" — not validated here
    content: str  # the message text
24
+
25
class ChatCompletionRequest(BaseModel):
    """Request body for /v1/chat/completions (OpenAI-compatible subset)."""
    model: str                         # model id or friendly alias (see model_mapping)
    messages: List[Message]            # conversation history, oldest first
    stream: Optional[bool] = False     # True -> server-sent-events streaming response
    max_tokens: Optional[int] = 4000   # cap on completion length
    temperature: Optional[float] = 0.5 # sampling temperature passed to the backend
    top_p: Optional[float] = 0.95      # nucleus-sampling parameter passed to the backend
32
+
33
def get_api_key():
    """Extract the bearer token from the Authorization header.

    Returns the token string, or None when the header is absent or does
    not start with 'Bearer '.
    """
    header = request.headers.get('Authorization')
    if header and header.startswith('Bearer '):
        return header.split(' ')[1]
    return None
38
+
39
def requires_api_key(func):
    """View decorator: reject unauthenticated requests with a 401 JSON body.

    On success, the extracted key is passed to the wrapped view as the
    `api_key` keyword argument.
    """
    @wraps(func)
    def decorated(*args, **kwargs):
        key = get_api_key()
        if not key:
            return jsonify({'detail': 'Not authenticated'}), 401
        kwargs['api_key'] = key
        return func(*args, **kwargs)
    return decorated
48
+
49
@app.route('/')
def index():
    """Unauthenticated landing route; doubles as a liveness check."""
    return 'Hello, World!'
52
+
53
@app.route('/v1/chat/completions', methods=['POST'])
@requires_api_key
def chat_completions(api_key):
    """OpenAI-compatible chat-completions endpoint.

    Validates the JSON body against ChatCompletionRequest, checks the API
    key and rate limit, maps friendly model aliases to provider model
    names, then either streams SSE chunks or returns a single JSON
    completion. Credits are deducted per request based on the resolved
    model name.

    Returns: JSON completion (200), validation error (400), auth/rate
    error (401), or unexpected failure (500).
    """
    import time  # stdlib; used for the OpenAI-style `created` epoch timestamp

    logging.info("Received request for chat completions")
    try:
        # Parse and validate the request body.
        try:
            data = request.get_json()
            chat_request = ChatCompletionRequest(**data)
        except ValidationError as e:
            return jsonify({'detail': e.errors()}), 400

        # Check API key validity and rate limit.
        is_valid, error_message = check_api_key_validity(api_key)
        if not is_valid:
            return jsonify({'detail': error_message}), 401

        messages = [{"role": msg.role, "content": msg.content} for msg in chat_request.messages]

        # Reject models the backend does not know about.
        model_info = get_model_info(chat_request.model)
        if not model_info:
            return jsonify({'detail': 'Invalid model specified'}), 400

        # Map user-facing aliases to the provider's canonical model names.
        model_mapping = {
            "meta-llama-405b-turbo": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
            "claude-3.5-sonnet": "claude-3-sonnet-20240229",
        }
        model_name = model_mapping.get(chat_request.model, chat_request.model)
        # Credits charged per request; models not listed here cost 0.
        credits_reduction = {
            "gpt-4o": 1,
            "claude-3-sonnet-20240229": 1,
            "gemini-1.5-pro": 1,
            "gemini-1-5-flash": 1,
            "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": 1,
            "o1-mini": 2,
            "o1-preview": 3,
        }.get(model_name, 0)

        if chat_request.stream:
            def generate():
                try:
                    for chunk in API_Inference(messages, model=model_name, stream=True,
                                               max_tokens=chat_request.max_tokens,
                                               temperature=chat_request.temperature,
                                               top_p=chat_request.top_p):
                        payload = json.dumps({'choices': [{'delta': {'content': chunk}}]})
                        yield f"data: {payload}\n\n"
                    # FIX: emit credit usage as its own SSE event and send the
                    # standard [DONE] sentinel LAST. Previously the credits text
                    # was appended after "[DONE]" inside the same event, which is
                    # invalid SSE framing and breaks OpenAI-style terminators.
                    yield f"data: Credits used: {credits_reduction}\n\n"
                    yield "data: [DONE]\n\n"
                    update_request_count(api_key, credits_reduction)
                except Exception as e:
                    yield f"data: [ERROR] {str(e)}\n\n"

            return Response(generate(), mimetype='text/event-stream')
        else:
            response = API_Inference(messages, model=model_name, stream=False,
                                     max_tokens=chat_request.max_tokens,
                                     temperature=chat_request.temperature,
                                     top_p=chat_request.top_p)
            update_request_count(api_key, credits_reduction)
            # Whitespace word counts are only an approximation of real tokens.
            prompt_tokens = sum(len(msg['content'].split()) for msg in messages)
            completion_tokens = len(response.split())
            total_tokens = prompt_tokens + completion_tokens
            return jsonify({
                "id": f"chatcmpl-{uuid.uuid4()}",
                "object": "chat.completion",
                # FIX: Unix epoch seconds, as the OpenAI response format
                # specifies. The old `int(uuid.uuid1().time // 1e7)` counted
                # 100-ns ticks since 1582-10-15 and was not an epoch time.
                "created": int(time.time()),
                "model": model_name,
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": response
                        },
                        "finish_reason": "stop"
                    }
                ],
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": total_tokens
                },
                "credits_used": credits_reduction
            })
    except Exception as e:
        # Top-level boundary: surface the failure as a JSON 500.
        return jsonify({'detail': str(e)}), 500
143
+
144
@app.route('/rate_limit/status', methods=['GET'])
@requires_api_key
def get_rate_limit_status_endpoint(api_key):
    """Return the caller's current rate-limit status as JSON."""
    valid, err = check_api_key_validity(api_key, check_rate_limit=False)
    if not valid:
        return jsonify({'detail': err}), 401
    return jsonify(get_rate_limit_status(api_key))
151
+
152
@app.route('/subscription/status', methods=['GET'])
@requires_api_key
def get_subscription_status_endpoint(api_key):
    """Return the caller's subscription details as JSON."""
    valid, err = check_api_key_validity(api_key, check_rate_limit=False)
    if not valid:
        return jsonify({'detail': err}), 401
    return jsonify(get_subscription_status(api_key))
159
+
160
@app.route('/models', methods=['GET'])
@requires_api_key
def get_available_models_endpoint(api_key):
    """List available model ids in an OpenAI-style {"data": [...]} envelope."""
    valid, err = check_api_key_validity(api_key, check_rate_limit=False)
    if not valid:
        return jsonify({'detail': err}), 401
    model_entries = [{"id": model_id} for model_id in get_available_models().values()]
    return jsonify({"data": model_entries})
167
+
168
if __name__ == "__main__":
    # Dev entry point: Flask's built-in server on all interfaces, port 8000.
    app.run(host="0.0.0.0", port=8000)