# vertex2api / app.py
# Source: Hugging Face file view — author smgc, commit 3faecc7 ("Update app.py", verified), 7.33 kB.
import asyncio
import hmac
import json
import logging
import os
import time
from urllib.parse import quote

import aiohttp
from aiohttp import web, ClientTimeout
from aiohttp.web import StreamResponse
# Logging setup. LOG_LEVEL names a standard level (DEBUG, INFO, WARNING, ...).
def resolve_log_level(name, default=logging.INFO):
    """Return the numeric ``logging`` level named by *name*.

    Matching is case-insensitive and ignores surrounding whitespace, so
    ``LOG_LEVEL=debug`` works. Unknown names return *default* instead of
    raising ``AttributeError`` at import time, as a bare
    ``getattr(logging, name)`` would.
    """
    level = getattr(logging, str(name).strip().upper(), None)
    # The logging module also has non-level upper-case attributes
    # (e.g. BASIC_FORMAT, a str), so only accept integer levels.
    return level if isinstance(level, int) else default


log_level = os.getenv('LOG_LEVEL', 'INFO')
logging.basicConfig(level=resolve_log_level(log_level), format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Google Cloud project and OAuth2 client credentials used to mint access tokens.
PROJECT_ID = os.getenv('PROJECT_ID')
CLIENT_ID = os.getenv('CLIENT_ID')
CLIENT_SECRET = os.getenv('CLIENT_SECRET')
REFRESH_TOKEN = os.getenv('REFRESH_TOKEN')
# Shared secret that clients send in the x-api-key header.
API_KEY = os.getenv('API_KEY')
# Default Vertex model ID, used when a request does not name one.
MODEL = 'claude-3-5-sonnet@20240620'
TOKEN_URL = 'https://www.googleapis.com/oauth2/v4/token'

# Process-wide OAuth2 access-token cache.
token_cache = {
    'access_token': '',       # last token obtained ('' until the first refresh)
    'expiry': 0,              # time.time()-based timestamp at which the token expires
    'refresh_promise': None,  # awaitable for an in-flight refresh, else None
}
async def _refresh_access_token():
    """Exchange REFRESH_TOKEN for a new access token and store it in token_cache.

    Raises:
        RuntimeError: the token endpoint did not return an ``access_token``
            (for example a revoked refresh token or bad client credentials).
        Errors from the HTTP request itself (e.g. ``aiohttp.ClientError``)
        propagate unchanged.
    """
    logger.info("Refreshing access token")
    requested_at = time.time()
    async with aiohttp.ClientSession() as session:
        async with session.post(TOKEN_URL, json={
            'client_id': CLIENT_ID,
            'client_secret': CLIENT_SECRET,
            'refresh_token': REFRESH_TOKEN,
            'grant_type': 'refresh_token'
        }) as response:
            status = response.status
            data = await response.json()
    if status != 200 or not isinstance(data, dict) or 'access_token' not in data:
        # Report only the endpoint's error fields; never echo credentials.
        detail = (data.get('error_description') or data.get('error')) if isinstance(data, dict) else None
        raise RuntimeError(f"Access token refresh failed (HTTP {status}): {detail}")
    token_cache['access_token'] = data['access_token']
    # Measured from when the request was sent, so the expiry errs early.
    token_cache['expiry'] = requested_at + float(data.get('expires_in', 0))
    logger.info("Access token refreshed successfully")


def _clear_refresh_task(task):
    """Done-callback: forget a finished refresh Task so the next call can retry."""
    if token_cache['refresh_promise'] is task:
        token_cache['refresh_promise'] = None
    # Retrieving the exception here also prevents asyncio's
    # "Task exception was never retrieved" warning when no caller is left.
    if not task.cancelled() and task.exception() is not None:
        logger.error(f"Access token refresh failed: {task.exception()}")


async def get_access_token():
    """Return a Google OAuth2 access token, refreshing it when needed.

    A cached token is reused until 120 s before its expiry. Otherwise one
    refresh runs as an asyncio Task stored in ``token_cache['refresh_promise']``,
    and concurrent callers await that same Task. A bare coroutine can be
    awaited only once, so sharing one would raise ``RuntimeError`` in every
    caller but the first. The Task is cleared when it finishes, whether it
    succeeded or failed, so a failed refresh is retried on the next call.

    Returns:
        The current access token string.

    Raises:
        Whatever the refresh raised (see ``_refresh_access_token``).
    """
    now = time.time()
    if token_cache['access_token'] and now < token_cache['expiry'] - 120:
        logger.info("Using cached access token")
        return token_cache['access_token']

    task = token_cache['refresh_promise']
    if task is None:
        task = asyncio.create_task(_refresh_access_token())
        token_cache['refresh_promise'] = task
        task.add_done_callback(_clear_refresh_task)
    else:
        logger.info("Waiting for ongoing token refresh")
    # shield(): cancelling one waiter (e.g. its client disconnected) must not
    # cancel the refresh that other waiters depend on.
    await asyncio.shield(task)
    return token_cache['access_token']
def get_location():
    """Pick the Vertex AI region for this request.

    Alternates by wall-clock second: requests in the first half of a minute
    go to europe-west1, those in the second half to us-east5.
    """
    second = time.localtime().tm_sec
    if second >= 30:
        location = 'us-east5'
    else:
        location = 'europe-west1'
    logger.info(f"Selected location: {location}")
    return location
def construct_api_url(location, model):
    """Build the Vertex AI ``streamRawPredict`` URL for an Anthropic model.

    Args:
        location: Vertex AI region. It is used both as the host prefix and as
            the ``locations/`` path segment.
        model: Vertex model ID, e.g. ``claude-3-5-sonnet@20240620``. It comes
            from the client's request body, so it is percent-encoded
            (keeping ``@``). A value containing ``/`` therefore stays a single
            path segment and cannot steer the request to a different Vertex
            endpoint that would run under the proxy's credentials.

    Returns:
        The full request URL string.
    """
    model_segment = quote(str(model), safe='@')
    url = (
        f'https://{location}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}'
        f'/locations/{location}/publishers/anthropic/models/{model_segment}:streamRawPredict'
    )
    logger.info(f"Constructed API URL: {url}")
    return url
def format_model_name(model):
    """Translate an Anthropic-API model ID into Vertex AI's ``name@version`` form.

    Only ``claude-3-5-sonnet-20240620`` is rewritten; any other value is
    returned unchanged.
    """
    anthropic_id = 'claude-3-5-sonnet-20240620'
    vertex_id = 'claude-3-5-sonnet@20240620'
    if model != anthropic_id:
        return model
    return vertex_id
async def handle_root(request):
    """Root endpoint: reply 200 with a short plain-text banner."""
    logger.info("Received request to root path /")
    banner = "GCP Vertex Claude API Proxy"
    return web.Response(status=200, text=banner)
# CORS headers attached to every proxied (streamed) response.
_CORS_HEADERS = {
    'Access-Control-Allow-Origin': '*',
    'Access-Control-Allow-Methods': 'POST, GET, OPTIONS',
    'Access-Control-Allow-Headers': 'Content-Type, Authorization, x-api-key, anthropic-version, model',
}

# Upstream timeouts. There is deliberately no total cap: a long generation,
# streamed or not, can legitimately take minutes, and a total limit would cut
# it off mid-response. Instead the connect phase is bounded, and so is the
# idle gap between bytes received from Vertex.
_UPSTREAM_TIMEOUT = ClientTimeout(total=None, sock_connect=30, sock_read=300)


def _is_valid_api_key(candidate):
    """Return True if *candidate* equals the configured ``API_KEY``.

    Fails closed when ``API_KEY`` is unset or empty. Otherwise a request with
    no ``x-api-key`` header (``None``) would match an unset key (``None``).
    The comparison is constant-time so the key cannot be probed via timing.
    """
    if not API_KEY or not candidate:
        return False
    return hmac.compare_digest(
        candidate.encode('utf-8', 'surrogateescape'),
        API_KEY.encode('utf-8', 'surrogateescape'),
    )


async def handle_request(request):
    """Proxy an Anthropic Messages API request to Claude on Vertex AI.

    Steps:
      1. OPTIONS requests are answered by ``handle_options`` (CORS preflight).
      2. The ``x-api-key`` header is checked against ``API_KEY``.
      3. The JSON body is rewritten for Vertex: ``model`` moves into the URL
         (defaulting to ``MODEL``) and ``anthropic_version`` is forced to
         ``vertex-2023-10-16``.
      4. The body is POSTed to Vertex. The upstream status code, Content-Type
         (e.g. ``text/event-stream`` for streaming) and body are relayed to
         the client as they arrive.

    Returns:
        A ``web.Response`` for OPTIONS and for every error detected before any
        bytes were sent (4xx for bad client input, 502/504/500 for upstream or
        internal failures). Otherwise the prepared ``StreamResponse``. If
        upstream fails mid-stream, the status cannot change any more, so the
        body is ended early and the failure is logged.
    """
    logger.info(f"Received {request.method} request to {request.path}")
    if request.method == 'OPTIONS':
        # handle_options() logs the preflight itself.
        return handle_options()

    if not _is_valid_api_key(request.headers.get('x-api-key')):
        logger.warning("Invalid API key")
        return create_error_response('Your API key does not have permission to use the specified resource.', 403)

    try:
        request_body = await request.json()
    except (json.JSONDecodeError, UnicodeDecodeError):
        logger.error("Invalid JSON in request body")
        return create_error_response('Invalid JSON in request body', 400)
    if logger.isEnabledFor(logging.DEBUG):
        # Bodies carry user prompts, so they are only logged at DEBUG.
        logger.debug(f"Received request body: {json.dumps(request_body)}")
    if not request_body:
        logger.error("Empty request body")
        return create_error_response('Empty request body', 400)
    if not isinstance(request_body, dict):
        logger.error("Request body is not a JSON object")
        return create_error_response('Request body must be a JSON object', 400)

    # Vertex addresses the model in the URL, not in the body. A missing or
    # null/empty model falls back to the default.
    requested_model = request_body.pop('model', None) or MODEL
    if not isinstance(requested_model, str):
        logger.error("Model is not a string")
        return create_error_response('Model must be a string', 400)
    model = format_model_name(requested_model)
    logger.info(f"Using model: {model}")
    # Vertex only accepts its own version tag; any client-supplied value is replaced.
    request_body['anthropic_version'] = 'vertex-2023-10-16'

    try:
        access_token = await get_access_token()
    except Exception:
        logger.exception("Failed to obtain access token")
        return create_error_response('Failed to authenticate with the AI service', 502)

    api_url = construct_api_url(get_location(), model)
    headers = {
        'Authorization': f'Bearer {access_token}',
        'Content-Type': 'application/json; charset=utf-8'
    }
    logger.info(f"Sending request to Anthropic API: {api_url}")

    response = None
    failure = None  # (message, status) describing why proxying stopped, if it did
    try:
        async with aiohttp.ClientSession(timeout=_UPSTREAM_TIMEOUT) as session:
            async with session.post(api_url, json=request_body, headers=headers) as api_response:
                logger.info(f"Received response from Anthropic API. Status: {api_response.status}")
                logger.info(f"Response headers: {dict(api_response.headers)}")
                # Mirror the upstream status and Content-Type so API errors
                # (4xx/5xx) and SSE streams reach the client intact.
                response = StreamResponse(status=api_response.status)
                response.headers['Content-Type'] = api_response.headers.get('Content-Type', 'application/json')
                response.headers.update(_CORS_HEADERS)
                await response.prepare(request)
                async for chunk in api_response.content.iter_any():
                    await response.write(chunk)
                logger.info("Finished streaming response from Anthropic API")
    # aiohttp's timeout errors subclass both asyncio.TimeoutError and
    # ClientError, so the timeout branch must come first.
    except asyncio.TimeoutError:
        logger.error("Request to Anthropic API timed out")
        failure = ('Request timed out', 504)
    except aiohttp.ClientError as e:
        logger.error(f"Error communicating with Anthropic API: {str(e)}")
        failure = ('Error communicating with AI service', 502)
    except Exception as e:
        logger.exception(f"Unexpected error: {str(e)}")
        failure = ('An unexpected error occurred', 500)

    if response is None or not response.prepared:
        # Nothing has been sent yet, so a real error status can still be used.
        message, status = failure or ('An unexpected error occurred', 500)
        return create_error_response(message, status)

    if failure is not None:
        # Headers were already sent; appending an error object would corrupt
        # the JSON/SSE body, so the response is just ended here.
        logger.warning("Upstream response interrupted; client response is truncated")
    try:
        await response.write_eof()
    except ConnectionResetError:
        logger.warning("Client disconnected before the response completed")
    logger.info("Finished sending response back to client")
    return response
# Anthropic API ``error.type`` values, keyed by HTTP status.
_ANTHROPIC_ERROR_TYPES = {
    400: 'invalid_request_error',
    401: 'authentication_error',
    403: 'permission_error',
    404: 'not_found_error',
    413: 'request_too_large',
    429: 'rate_limit_error',
    529: 'overloaded_error',
}


def anthropic_error_type(status_code):
    """Return the Anthropic-style ``error.type`` for an HTTP status code.

    Statuses without a specific entry map to ``api_error`` (5xx) or
    ``invalid_request_error`` (anything else).
    """
    default = 'api_error' if status_code >= 500 else 'invalid_request_error'
    return _ANTHROPIC_ERROR_TYPES.get(status_code, default)


def create_error_response(message, status_code):
    """Build a JSON error response shaped like the Anthropic API's errors.

    Body: ``{"type": "error", "error": {"type": <per-status type>, "message": message}}``.
    The CORS headers match the ones sent on OPTIONS and proxied responses.

    Args:
        message: Human-readable error description returned to the client.
        status_code: HTTP status of the response.

    Returns:
        An ``aiohttp.web.Response`` with ``application/json`` content.
    """
    logger.error(f"Creating error response: {message} (Status: {status_code})")
    payload = {
        'type': 'error',
        'error': {
            'type': anthropic_error_type(status_code),
            'message': message
        }
    }
    response = web.Response(
        text=json.dumps(payload),
        status=status_code,
        content_type='application/json'
    )
    response.headers['Access-Control-Allow-Origin'] = '*'
    response.headers['Access-Control-Allow-Methods'] = 'POST, GET, OPTIONS'
    response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, x-api-key, anthropic-version, model'
    return response
def handle_options():
    """Answer a CORS preflight: 204 No Content plus the allowed origin, methods and headers."""
    logger.info("Handling OPTIONS request")
    preflight_headers = dict([
        ('Access-Control-Allow-Origin', '*'),
        ('Access-Control-Allow-Methods', 'POST, GET, OPTIONS'),
        ('Access-Control-Allow-Headers', 'Content-Type, Authorization, x-api-key, anthropic-version, model'),
    ])
    return web.Response(headers=preflight_headers, status=204)
app = web.Application()
app.router.add_get('/', handle_root)
# '*' sends every method, including OPTIONS preflights, to handle_request.
app.router.add_route('*', '/ai/v1/messages', handle_request)

if __name__ == '__main__':
    # PORT lets the hosting platform choose the listen port; 8080 stays the default.
    port = int(os.getenv('PORT', '8080'))
    logger.info(f"Starting server on port {port}")
    web.run_app(app, port=port)