smgc committed on
Commit
dd9ae3e
·
verified ·
1 Parent(s): 3faecc7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -133
app.py CHANGED
@@ -1,24 +1,17 @@
1
- import os
 
 
2
  import time
3
- import json
4
- import asyncio
5
- import aiohttp
6
- import logging
7
- from aiohttp import web, ClientTimeout
8
- from aiohttp.web import StreamResponse
9
-
10
- # Configure logging
11
- log_level = os.getenv('LOG_LEVEL', 'INFO')
12
- logging.basicConfig(level=getattr(logging, log_level), format='%(asctime)s - %(levelname)s - %(message)s')
13
- logger = logging.getLogger(__name__)
14
 
 
15
  PROJECT_ID = os.getenv('PROJECT_ID')
16
  CLIENT_ID = os.getenv('CLIENT_ID')
17
  CLIENT_SECRET = os.getenv('CLIENT_SECRET')
18
  REFRESH_TOKEN = os.getenv('REFRESH_TOKEN')
19
  API_KEY = os.getenv('API_KEY')
20
- MODEL = 'claude-3-5-sonnet@20240620'
21
-
22
  TOKEN_URL = 'https://www.googleapis.com/oauth2/v4/token'
23
 
24
  token_cache = {
@@ -27,85 +20,66 @@ token_cache = {
27
  'refresh_promise': None
28
  }
29
 
30
- async def get_access_token():
31
  now = time.time()
32
 
33
  if token_cache['access_token'] and now < token_cache['expiry'] - 120:
34
- logger.info("Using cached access token")
35
  return token_cache['access_token']
36
 
37
  if token_cache['refresh_promise']:
38
- logger.info("Waiting for ongoing token refresh")
39
- await token_cache['refresh_promise']
40
  return token_cache['access_token']
41
 
42
- async def refresh_token():
43
- logger.info("Refreshing access token")
44
- async with aiohttp.ClientSession() as session:
45
- async with session.post(TOKEN_URL, json={
46
  'client_id': CLIENT_ID,
47
  'client_secret': CLIENT_SECRET,
48
  'refresh_token': REFRESH_TOKEN,
49
  'grant_type': 'refresh_token'
50
- }) as response:
51
- data = await response.json()
52
- token_cache['access_token'] = data['access_token']
53
- token_cache['expiry'] = now + data['expires_in']
54
- logger.info("Access token refreshed successfully")
55
-
56
- token_cache['refresh_promise'] = refresh_token()
57
- await token_cache['refresh_promise']
58
- token_cache['refresh_promise'] = None
 
59
  return token_cache['access_token']
60
 
61
  def get_location():
62
  current_seconds = time.localtime().tm_sec
63
- location = 'europe-west1' if current_seconds < 30 else 'us-east5'
64
- logger.info(f"Selected location: {location}")
65
- return location
66
-
67
- def construct_api_url(location, model):
68
- url = f'https://{location}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{location}/publishers/anthropic/models/{model}:streamRawPredict'
69
- logger.info(f"Constructed API URL: {url}")
70
- return url
71
-
72
- def format_model_name(model):
73
- if model == 'claude-3-5-sonnet-20240620':
74
- return 'claude-3-5-sonnet@20240620'
75
- return model
76
-
77
- async def handle_root(request):
78
- logger.info("Received request to root path /")
79
- return web.Response(text="GCP Vertex Claude API Proxy", status=200)
80
-
81
- async def handle_request(request):
82
- logger.info(f"Received {request.method} request to {request.path}")
83
-
84
  if request.method == 'OPTIONS':
85
- logger.info("Handling OPTIONS request")
86
  return handle_options()
87
 
88
  api_key = request.headers.get('x-api-key')
89
  if api_key != API_KEY:
90
- logger.warning("Invalid API key")
91
- return create_error_response('Your API key does not have permission to use the specified resource.', 403)
92
-
93
- try:
94
- request_body = await request.json()
95
- logger.info(f"Received request body: {json.dumps(request_body)}")
96
- except json.JSONDecodeError:
97
- logger.error("Invalid JSON in request body")
98
- return create_error_response('Invalid JSON in request body', 400)
99
-
100
- if not request_body:
101
- logger.error("Empty request body")
102
- return create_error_response('Empty request body', 400)
103
 
104
- access_token = await get_access_token()
105
  location = get_location()
 
106
 
107
- model = format_model_name(request_body.get('model', MODEL))
108
- logger.info(f"Using model: {model}")
109
 
110
  if 'anthropic_version' in request_body:
111
  del request_body['anthropic_version']
@@ -114,81 +88,26 @@ async def handle_request(request):
114
 
115
  request_body['anthropic_version'] = 'vertex-2023-10-16'
116
 
117
- api_url = construct_api_url(location, model)
118
-
119
  headers = {
120
  'Authorization': f'Bearer {access_token}',
121
  'Content-Type': 'application/json; charset=utf-8'
122
  }
123
 
124
- logger.info(f"Sending request to Anthropic API: {api_url}")
125
-
126
- # Create a StreamResponse object
127
- response = StreamResponse(status=200)
128
- response.headers['Content-Type'] = 'application/json'
129
- response.headers['Access-Control-Allow-Origin'] = '*'
130
- response.headers['Access-Control-Allow-Methods'] = 'POST, GET, OPTIONS'
131
- response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, x-api-key, anthropic-version, model'
132
-
133
- await response.prepare(request)
134
-
135
- try:
136
- async with aiohttp.ClientSession() as session:
137
- async with session.post(api_url, json=request_body, headers=headers, timeout=ClientTimeout(total=30)) as api_response:
138
- api_response.raise_for_status()
139
- logger.info(f"Received response from Anthropic API. Status: {api_response.status}")
140
- logger.info(f"Response headers: {dict(api_response.headers)}")
141
-
142
- async for chunk in api_response.content.iter_any():
143
- await response.write(chunk)
144
-
145
- logger.info("Finished streaming response from Anthropic API")
146
- except aiohttp.ClientError as e:
147
- logger.error(f"Error communicating with Anthropic API: {str(e)}")
148
- await response.write(json.dumps({'error': 'Error communicating with AI service'}).encode())
149
- except asyncio.TimeoutError:
150
- logger.error("Request to Anthropic API timed out")
151
- await response.write(json.dumps({'error': 'Request timed out'}).encode())
152
- except Exception as e:
153
- logger.error(f"Unexpected error: {str(e)}")
154
- await response.write(json.dumps({'error': 'An unexpected error occurred'}).encode())
155
- finally:
156
- await response.write_eof()
157
- logger.info("Finished sending response back to client")
158
-
159
- return response
160
-
161
- def create_error_response(message, status_code):
162
- logger.error(f"Creating error response: {message} (Status: {status_code})")
163
- response = web.Response(
164
- text=json.dumps({
165
- 'type': 'error',
166
- 'error': {
167
- 'type': 'request_error',
168
- 'message': message
169
- }
170
- }),
171
- status=status_code,
172
- content_type='application/json'
173
- )
174
- response.headers['Access-Control-Allow-Origin'] = '*'
175
- response.headers['Access-Control-Allow-Methods'] = 'POST, GET, OPTIONS, DELETE, HEAD'
176
- response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, x-api-key, anthropic-version, model'
177
- return response
178
 
179
  def handle_options():
180
- logger.info("Handling OPTIONS request")
181
  headers = {
182
  'Access-Control-Allow-Origin': '*',
183
  'Access-Control-Allow-Methods': 'POST, GET, OPTIONS',
184
  'Access-Control-Allow-Headers': 'Content-Type, Authorization, x-api-key, anthropic-version, model'
185
  }
186
- return web.Response(status=204, headers=headers)
187
-
188
- app = web.Application()
189
- app.router.add_get('/', handle_root)
190
- app.router.add_route('*', '/ai/v1/messages', handle_request)
191
 
192
  if __name__ == '__main__':
193
- logger.info("Starting server on port 8080")
194
- web.run_app(app, port=8080)
 
1
+ from flask import Flask, request, jsonify, Response, make_response
2
+ import requests
3
+ import threading
4
  import time
5
+ import os
6
+
7
+ app = Flask(__name__)
 
 
 
 
 
 
 
 
8
 
9
# Vertex AI publisher model identifier used when building the request URL.
MODEL = 'claude-3-5-sonnet@20240620'

# Deployment settings read from the environment; each is None when unset.
PROJECT_ID = os.environ.get('PROJECT_ID')
CLIENT_ID = os.environ.get('CLIENT_ID')
CLIENT_SECRET = os.environ.get('CLIENT_SECRET')
REFRESH_TOKEN = os.environ.get('REFRESH_TOKEN')
API_KEY = os.environ.get('API_KEY')

# Endpoint that the refresh-token grant is posted to.
TOKEN_URL = 'https://www.googleapis.com/oauth2/v4/token'
16
 
17
  token_cache = {
 
20
  'refresh_promise': None
21
  }
22
 
23
def get_access_token():
    """Return an OAuth access token, refreshing the cached one when needed.

    The token stored in ``token_cache`` is reused while it is more than
    120 seconds away from its recorded expiry. Otherwise a worker thread
    posts the client credentials and refresh token to ``TOKEN_URL`` and
    stores the new token and expiry in ``token_cache``. If a refresh is
    already running, this call waits for it instead of starting another.

    Returns:
        The value of ``token_cache['access_token']`` after any refresh
        attempt. If the refresh fails, the exception is raised inside the
        worker thread only and the previously cached value (possibly
        ``None``) is returned unchanged.
    """
    now = time.time()

    if token_cache['access_token'] and now < token_cache['expiry'] - 120:
        return token_cache['access_token']

    # Read the shared slot once: the worker resets it to None when it
    # finishes, so looking it up again right before join() could hand us
    # None and raise AttributeError.
    in_flight = token_cache['refresh_promise']
    if in_flight is not None:
        in_flight.join()
        return token_cache['access_token']

    def refresh_token():
        try:
            response = requests.post(
                TOKEN_URL,
                json={
                    'client_id': CLIENT_ID,
                    'client_secret': CLIENT_SECRET,
                    'refresh_token': REFRESH_TOKEN,
                    'grant_type': 'refresh_token',
                },
                # Without a timeout a stalled token endpoint would block
                # this request (and everyone joined on it) indefinitely.
                timeout=30,
            )
            # Surface HTTP errors explicitly rather than as a KeyError on
            # the missing 'access_token' field.
            response.raise_for_status()
            data = response.json()
            token_cache['access_token'] = data['access_token']
            token_cache['expiry'] = now + data['expires_in']
        finally:
            token_cache['refresh_promise'] = None

    # Keep a local handle to the thread for the same reason as above: the
    # worker may clear token_cache['refresh_promise'] before we join.
    worker = threading.Thread(target=refresh_token)
    token_cache['refresh_promise'] = worker
    worker.start()
    worker.join()
    return token_cache['access_token']
51
 
52
def get_location():
    """Choose the Vertex AI region from the current local clock second.

    Returns:
        ``'europe-west1'`` when the seconds field of the local time is
        below 30, ``'us-east5'`` otherwise.
    """
    second_of_minute = time.localtime().tm_sec
    if second_of_minute < 30:
        return 'europe-west1'
    return 'us-east5'
55
+
56
def construct_api_url(location, model=None):
    """Build the ``streamRawPredict`` endpoint URL for an Anthropic model.

    Args:
        location: Region name; used both as the host prefix and as the
            ``locations/`` path segment.
        model: Publisher model identifier to place in the path. When
            omitted (or ``None``) the module-level ``MODEL`` is used, so
            existing one-argument callers get the same URL as before.

    Returns:
        The endpoint URL as a string, built from ``PROJECT_ID``,
        ``location`` and the chosen model.
    """
    if model is None:
        model = MODEL
    return (
        f'https://{location}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}'
        f'/locations/{location}/publishers/anthropic/models/{model}:streamRawPredict'
    )
58
+
59
+ @app.route('/ai/v1/messages', methods=['POST', 'OPTIONS'])
60
+ def handle_request():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  if request.method == 'OPTIONS':
 
62
  return handle_options()
63
 
64
  api_key = request.headers.get('x-api-key')
65
  if api_key != API_KEY:
66
+ error_response = make_response(jsonify({
67
+ 'type': 'error',
68
+ 'error': {
69
+ 'type': 'permission_error',
70
+ 'message': 'Your API key does not have permission to use the specified resource.'
71
+ }
72
+ }), 403)
73
+ error_response.headers['Access-Control-Allow-Origin'] = '*'
74
+ error_response.headers['Access-Control-Allow-Methods'] = 'POST, GET, OPTIONS, DELETE, HEAD'
75
+ error_response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, x-api-key, anthropic-version, model'
76
+ return error_response
 
 
77
 
78
+ access_token = get_access_token()
79
  location = get_location()
80
+ api_url = construct_api_url(location)
81
 
82
+ request_body = request.json
 
83
 
84
  if 'anthropic_version' in request_body:
85
  del request_body['anthropic_version']
 
88
 
89
  request_body['anthropic_version'] = 'vertex-2023-10-16'
90
 
 
 
91
  headers = {
92
  'Authorization': f'Bearer {access_token}',
93
  'Content-Type': 'application/json; charset=utf-8'
94
  }
95
 
96
+ response = requests.post(api_url, headers=headers, json=request_body)
97
+ modified_response = make_response(response.content, response.status_code)
98
+ modified_response.headers['Access-Control-Allow-Origin'] = '*'
99
+ modified_response.headers['Access-Control-Allow-Methods'] = 'POST, GET, OPTIONS'
100
+ modified_response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, x-api-key, anthropic-version, model'
101
+
102
+ return modified_response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
def handle_options():
    """Build the reply for a CORS preflight (``OPTIONS``) request.

    Returns:
        A ``(body, status, headers)`` tuple: an empty body, status 204,
        and the CORS headers allowing any origin, the POST/GET/OPTIONS
        methods and the request headers this proxy accepts.
    """
    cors_headers = {}
    cors_headers['Access-Control-Allow-Origin'] = '*'
    cors_headers['Access-Control-Allow-Methods'] = 'POST, GET, OPTIONS'
    cors_headers['Access-Control-Allow-Headers'] = (
        'Content-Type, Authorization, x-api-key, anthropic-version, model'
    )
    return ('', 204, cors_headers)
 
 
 
 
111
 
112
# Script entry point: start the Flask application's server on port 8080.
if __name__ == '__main__':
    app.run(port=8080)