Spaces:
Running
Running
import os | |
import time | |
import json | |
import asyncio | |
import aiohttp | |
import logging | |
from aiohttp import web, ClientTimeout | |
from aiohttp.web import StreamResponse | |
# Configure logging. LOG_LEVEL is matched case-insensitively; unknown values
# fall back to INFO instead of crashing at import time.
def parse_log_level(name, default=logging.INFO):
    """Translate a level name such as ``'debug'`` or ``'WARNING'`` to its number.

    Args:
        name: Level name from configuration; case is ignored.
        default: Value returned when *name* is not a standard logging level.

    Returns:
        The numeric ``logging`` level.
    """
    level = getattr(logging, str(name).upper(), None)
    # getattr can also hit non-level attributes (e.g. BASIC_FORMAT is a str),
    # so only genuine integer level constants are accepted.
    return level if isinstance(level, int) else default


log_level = os.getenv('LOG_LEVEL', 'INFO')
logging.basicConfig(level=parse_log_level(log_level), format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# GCP project that hosts the Vertex AI endpoint, plus the OAuth2 client
# credentials and refresh token used to mint access tokens at TOKEN_URL.
PROJECT_ID = os.getenv('PROJECT_ID')
CLIENT_ID = os.getenv('CLIENT_ID')
CLIENT_SECRET = os.getenv('CLIENT_SECRET')
REFRESH_TOKEN = os.getenv('REFRESH_TOKEN')
# Shared secret that clients must send in the x-api-key header.
API_KEY = os.getenv('API_KEY')
# Vertex model ID used when the request body does not name a model.
MODEL = 'claude-3-5-sonnet@20240620'
TOKEN_URL = 'https://www.googleapis.com/oauth2/v4/token'

# Process-wide OAuth token cache shared by all requests:
#   access_token:    last token obtained ('' until the first refresh)
#   expiry:          Unix timestamp (time.time() based) when the token expires
#   refresh_promise: the in-flight refresh (an awaitable) that concurrent
#                    callers wait on, or None when no refresh is running
token_cache = {
    'access_token': '',
    'expiry': 0,
    'refresh_promise': None
}
async def get_access_token():
    """Return a Google OAuth2 access token, refreshing it when needed.

    The token is cached in ``token_cache`` and reused until 120 seconds before
    its recorded expiry. A refresh runs as a single shared ``asyncio.Task``
    stored in ``token_cache['refresh_promise']``, so concurrent callers wait
    on the same refresh instead of each starting their own.

    Returns:
        The access token string.

    Raises:
        aiohttp.ClientResponseError: the token endpoint returned an error status.
        aiohttp.ClientError: the token endpoint could not be reached.
        KeyError: the token response lacked ``access_token`` or ``expires_in``.
    """
    now = time.time()
    if token_cache['access_token'] and now < token_cache['expiry'] - 120:
        logger.info("Using cached access token")
        return token_cache['access_token']

    pending = token_cache['refresh_promise']
    if pending is not None:
        logger.info("Waiting for ongoing token refresh")
        # A Task (unlike a bare coroutine) can be awaited by many callers;
        # shield() keeps one cancelled caller from cancelling the shared refresh.
        await asyncio.shield(pending)
        return token_cache['access_token']

    async def refresh_token():
        logger.info("Refreshing access token")
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(TOKEN_URL, json={
                    'client_id': CLIENT_ID,
                    'client_secret': CLIENT_SECRET,
                    'refresh_token': REFRESH_TOKEN,
                    'grant_type': 'refresh_token'
                }) as response:
                    # Surface HTTP errors instead of a KeyError on an error payload.
                    response.raise_for_status()
                    data = await response.json()
            token_cache['access_token'] = data['access_token']
            # `now` predates the request, so the recorded expiry errs on the early side.
            token_cache['expiry'] = now + data['expires_in']
            logger.info("Access token refreshed successfully")
        finally:
            # Always clear the slot so a failed refresh can be retried next time.
            token_cache['refresh_promise'] = None

    task = asyncio.create_task(refresh_token())
    token_cache['refresh_promise'] = task
    await asyncio.shield(task)
    return token_cache['access_token']
def get_location():
    """Pick the Vertex AI region to send this request to.

    Alternates between two regions using the seconds field of the current
    local time: seconds below 30 select ``europe-west1``, anything else
    selects ``us-east5``.

    Returns:
        The region name string.
    """
    if time.localtime().tm_sec >= 30:
        location = 'us-east5'
    else:
        location = 'europe-west1'
    logger.info(f"Selected location: {location}")
    return location
def construct_api_url(location, model):
    """Build the Vertex AI ``streamRawPredict`` URL for an Anthropic model.

    Args:
        location: Vertex region; used both in the hostname and the path.
        model: Vertex model ID (e.g. ``claude-3-5-sonnet@20240620``).

    Returns:
        The full endpoint URL under the project named by ``PROJECT_ID``.
    """
    host = f'https://{location}-aiplatform.googleapis.com'
    resource = (
        f'/v1/projects/{PROJECT_ID}/locations/{location}'
        f'/publishers/anthropic/models/{model}:streamRawPredict'
    )
    url = host + resource
    logger.info(f"Constructed API URL: {url}")
    return url
def format_model_name(model):
    """Map an Anthropic-style model ID to its Vertex AI form.

    Only ``claude-3-5-sonnet-20240620`` is translated (to
    ``claude-3-5-sonnet@20240620``); every other value is returned unchanged.
    """
    anthropic_id = 'claude-3-5-sonnet-20240620'
    return 'claude-3-5-sonnet@20240620' if model == anthropic_id else model
async def handle_root(request):
    """Serve ``GET /`` with a plain-text banner and HTTP 200."""
    logger.info("Received request to root path /")
    banner = "GCP Vertex Claude API Proxy"
    return web.Response(status=200, text=banner)
def _is_authorized(api_key):
    """Return True when *api_key* matches the configured ``API_KEY``.

    Fails closed: when ``API_KEY`` is unset every request is rejected.
    Otherwise a request without the header would compare ``None == None``
    and be let through.
    """
    import hmac

    if not API_KEY or api_key is None:
        return False
    # Constant-time comparison. Bytes sidestep compare_digest's ASCII-only
    # restriction on str arguments.
    return hmac.compare_digest(
        api_key.encode('utf-8', 'surrogateescape'),
        API_KEY.encode('utf-8', 'surrogateescape'),
    )


async def _proxy_to_vertex(request, api_url, headers, request_body):
    """POST *request_body* to Vertex and stream the upstream reply to the client.

    The client response is prepared only after upstream headers arrive. The
    upstream status code and Content-Type are therefore forwarded, and
    failures before that point become JSON error responses (504 timeout,
    502 transport error, 500 otherwise). Once headers have been sent, a
    failure can only be logged and the client receives a truncated body.
    """
    # No overall deadline: a long generation may legitimately stream for
    # minutes. Instead, bound socket connection setup and the idle gap
    # between reads.
    timeout = ClientTimeout(total=None, sock_connect=30, sock_read=300)
    response = None
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(api_url, json=request_body, headers=headers, timeout=timeout) as api_response:
                logger.info(f"Received response from Anthropic API. Status: {api_response.status}")
                logger.info(f"Response headers: {dict(api_response.headers)}")
                response = StreamResponse(status=api_response.status)
                # Forward the upstream type: streamed requests are not plain JSON.
                response.headers['Content-Type'] = api_response.headers.get('Content-Type', 'application/json')
                response.headers['Access-Control-Allow-Origin'] = '*'
                response.headers['Access-Control-Allow-Methods'] = 'POST, GET, OPTIONS'
                response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, x-api-key, anthropic-version, model'
                await response.prepare(request)
                async for chunk in api_response.content.iter_any():
                    await response.write(chunk)
                logger.info("Finished streaming response from Anthropic API")
    # Timeouts are checked first: aiohttp's ServerTimeoutError is also a ClientError.
    except asyncio.TimeoutError:
        logger.error("Request to Anthropic API timed out")
        if response is None:
            return create_error_response('Request timed out', 504)
    except aiohttp.ClientError as e:
        logger.error(f"Error communicating with Anthropic API: {str(e)}")
        if response is None:
            return create_error_response('Error communicating with AI service', 502)
    except Exception as e:
        logger.exception(f"Unexpected error: {str(e)}")
        if response is None:
            return create_error_response('An unexpected error occurred', 500)

    try:
        await response.write_eof()
    except ConnectionError:
        logger.warning("Client connection closed before the response was finished")
    logger.info("Finished sending response back to client")
    return response


async def handle_request(request):
    """Proxy an Anthropic Messages API request to Claude on Google Vertex AI.

    Steps:
      1. Answer CORS preflight (OPTIONS).
      2. Authenticate the caller via the ``x-api-key`` header.
      3. Validate that the body is a non-empty JSON object.
      4. Rewrite the body for Vertex. The model moves from the body into the
         URL, and ``anthropic_version`` is forced to ``vertex-2023-10-16``.
      5. Stream the upstream response back.

    Errors detected before any bytes reach the client are returned via
    ``create_error_response`` with an appropriate HTTP status.
    """
    logger.info(f"Received {request.method} request to {request.path}")
    if request.method == 'OPTIONS':
        # handle_options() logs the preflight itself.
        return handle_options()

    if not _is_authorized(request.headers.get('x-api-key')):
        logger.warning("Invalid API key")
        return create_error_response('Your API key does not have permission to use the specified resource.', 403)

    try:
        request_body = await request.json()
    except ValueError:
        # Covers json.JSONDecodeError and UnicodeDecodeError (undecodable
        # body); both subclass ValueError.
        logger.error("Invalid JSON in request body")
        return create_error_response('Invalid JSON in request body', 400)
    # Prompts may contain sensitive data, so the full body is logged only at DEBUG.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"Received request body: {json.dumps(request_body)}")
    if not request_body:
        logger.error("Empty request body")
        return create_error_response('Empty request body', 400)
    if not isinstance(request_body, dict):
        logger.error("Request body is not a JSON object")
        return create_error_response('Request body must be a JSON object', 400)

    # A missing, null, or empty model falls back to the default.
    raw_model = request_body.pop('model', None) or MODEL
    if not isinstance(raw_model, str):
        logger.error("Invalid model in request body")
        return create_error_response('Model must be a string', 400)
    model = format_model_name(raw_model)
    logger.info(f"Using model: {model}")
    # Vertex requires its own version tag, replacing whatever the client sent.
    request_body['anthropic_version'] = 'vertex-2023-10-16'

    try:
        access_token = await get_access_token()
    except Exception:
        logger.exception("Failed to obtain access token")
        return create_error_response('Failed to authenticate with AI service', 502)

    location = get_location()
    api_url = construct_api_url(location, model)
    headers = {
        'Authorization': f'Bearer {access_token}',
        'Content-Type': 'application/json; charset=utf-8',
    }
    logger.info(f"Sending request to Anthropic API: {api_url}")
    return await _proxy_to_vertex(request, api_url, headers, request_body)
def create_error_response(message, status_code):
    """Build a JSON error response carrying CORS headers.

    Body shape:
    ``{"type": "error", "error": {"type": "request_error", "message": message}}``.

    Args:
        message: Human-readable error text placed in the body.
        status_code: HTTP status of the response.

    Returns:
        An ``aiohttp.web.Response`` with ``application/json`` content.
    """
    logger.error(f"Creating error response: {message} (Status: {status_code})")
    payload = {'type': 'error', 'error': {'type': 'request_error', 'message': message}}
    cors = (
        ('Access-Control-Allow-Origin', '*'),
        ('Access-Control-Allow-Methods', 'POST, GET, OPTIONS, DELETE, HEAD'),
        ('Access-Control-Allow-Headers', 'Content-Type, Authorization, x-api-key, anthropic-version, model'),
    )
    error_response = web.Response(text=json.dumps(payload), status=status_code, content_type='application/json')
    for header_name, header_value in cors:
        error_response.headers[header_name] = header_value
    return error_response
def handle_options():
    """Answer a CORS preflight (OPTIONS) request.

    Returns an empty 204 response that allows any origin to call the proxy
    with the listed methods and request headers.
    """
    logger.info("Handling OPTIONS request")
    allowed_headers = 'Content-Type, Authorization, x-api-key, anthropic-version, model'
    preflight = dict([
        ('Access-Control-Allow-Origin', '*'),
        ('Access-Control-Allow-Methods', 'POST, GET, OPTIONS'),
        ('Access-Control-Allow-Headers', allowed_headers),
    ])
    return web.Response(headers=preflight, status=204)
# Application wiring: "/" serves a plain-text banner; every HTTP method on
# "/ai/v1/messages" goes to handle_request, which also answers OPTIONS.
app = web.Application()
app.router.add_get('/', handle_root)
app.router.add_route('*', '/ai/v1/messages', handle_request)

if __name__ == '__main__':
    # Let the hosting platform choose the port via PORT; 8080 stays the default.
    port = int(os.getenv('PORT', '8080'))
    logger.info(f"Starting server on port {port}")
    web.run_app(app, port=port)