vertex2api / app.py
smgc's picture
Update app.py
dd9ae3e verified
raw
history blame
3.98 kB
import hmac
import os
import threading
import time

import requests
from flask import Flask, request, jsonify, Response, make_response
# Flask application object that the route decorators below attach to.
app = Flask(__name__)

# Model identifier interpolated into the Vertex AI request URL.
MODEL = 'claude-3-5-sonnet@20240620'

# Settings pulled from the process environment; each is None when unset.
PROJECT_ID = os.environ.get('PROJECT_ID')
CLIENT_ID = os.environ.get('CLIENT_ID')
CLIENT_SECRET = os.environ.get('CLIENT_SECRET')
REFRESH_TOKEN = os.environ.get('REFRESH_TOKEN')
API_KEY = os.environ.get('API_KEY')

# Endpoint the refresh-token grant is posted to.
TOKEN_URL = 'https://www.googleapis.com/oauth2/v4/token'

# Shared, mutable cache for the current access token:
#   access_token    - last token obtained ('' until the first refresh)
#   expiry          - time.time()-based timestamp after which it is stale
#   refresh_promise - the threading.Thread doing a refresh, or None
token_cache = dict(
    access_token='',
    expiry=0,
    refresh_promise=None,
)
def get_access_token():
    """Return an OAuth2 access token, refreshing the cached one when needed.

    The cached token is reused until it is within 120 seconds of its
    recorded expiry.  Otherwise a worker thread exchanges the refresh
    token at ``TOKEN_URL`` and stores the result in ``token_cache``; a
    caller that finds a refresh already running waits for it instead of
    starting a second one.

    Returns:
        The value of ``token_cache['access_token']`` after the refresh
        attempt.  Exceptions raised inside the worker thread do not
        propagate through ``join()``, so on a failed refresh this is
        whatever was cached before (possibly the empty string).
    """
    now = time.time()
    if token_cache['access_token'] and now < token_cache['expiry'] - 120:
        return token_cache['access_token']

    # Take a local reference before joining: the worker resets the cache
    # slot to None when it finishes, so reading the slot a second time
    # could yield None and fail on ``.join()``.
    pending = token_cache['refresh_promise']
    if pending:
        pending.join()
        return token_cache['access_token']

    def refresh_token():
        """Exchange the refresh token and store the new access token."""
        try:
            response = requests.post(
                TOKEN_URL,
                json={
                    'client_id': CLIENT_ID,
                    'client_secret': CLIENT_SECRET,
                    'refresh_token': REFRESH_TOKEN,
                    'grant_type': 'refresh_token',
                },
                # Without a timeout a stalled connection would block this
                # thread, and every request joined on it, indefinitely.
                timeout=30,
            )
            data = response.json()
            token_cache['access_token'] = data['access_token']
            # Expiry is measured from before the request was sent, so the
            # cached lifetime is never longer than the real one.
            token_cache['expiry'] = now + data['expires_in']
        finally:
            token_cache['refresh_promise'] = None

    # Keep the thread in a local as well: a fast-finishing worker may have
    # already cleared the cache slot by the time join() is reached.
    worker = threading.Thread(target=refresh_token)
    token_cache['refresh_promise'] = worker
    worker.start()
    worker.join()
    return token_cache['access_token']
def get_location():
    """Pick a Vertex AI region based on the current second of the minute.

    Returns:
        ``'europe-west1'`` while the local clock's seconds field is below
        30, ``'us-east5'`` otherwise.
    """
    second_of_minute = time.localtime().tm_sec
    if second_of_minute >= 30:
        return 'us-east5'
    return 'europe-west1'
def construct_api_url(location):
    """Build the ``streamRawPredict`` URL for ``MODEL`` in *location*.

    Args:
        location: Region name; used both as the hostname prefix and as the
            ``locations/`` path segment.

    Returns:
        The full HTTPS URL as a string.
    """
    host = f'{location}-aiplatform.googleapis.com'
    model_path = (
        f'projects/{PROJECT_ID}/locations/{location}'
        f'/publishers/anthropic/models/{MODEL}'
    )
    return f'https://{host}/v1/{model_path}:streamRawPredict'
def _with_cors(response, allow_methods='POST, GET, OPTIONS'):
    """Set the permissive CORS headers on *response* and return it."""
    response.headers['Access-Control-Allow-Origin'] = '*'
    response.headers['Access-Control-Allow-Methods'] = allow_methods
    response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, x-api-key, anthropic-version, model'
    return response


def _error_response(status, error_type, message, allow_methods='POST, GET, OPTIONS'):
    """Build a JSON error response in the ``{'type': 'error', ...}`` shape."""
    payload = jsonify({
        'type': 'error',
        'error': {
            'type': error_type,
            'message': message
        }
    })
    return _with_cors(make_response(payload, status), allow_methods)


@app.route('/ai/v1/messages', methods=['POST', 'OPTIONS'])
def handle_request():
    """Proxy a messages request to the Vertex AI ``streamRawPredict`` URL.

    Flow: answer CORS preflight, check the ``x-api-key`` header against
    ``API_KEY``, rewrite the JSON body (drop ``model``, force
    ``anthropic_version`` to ``'vertex-2023-10-16'``), forward it with a
    bearer token, and relay the upstream status and body with CORS headers.

    Returns:
        A Flask response: 204 for OPTIONS, 403 on a key mismatch, 400 when
        the body is not a JSON object, 502 when the upstream call cannot
        be completed, otherwise the upstream status code and body.
    """
    if request.method == 'OPTIONS':
        return handle_options()

    api_key = request.headers.get('x-api-key')
    # Reject when no key is configured: otherwise an unset API_KEY and a
    # missing header would both be None and compare equal.  Bytes are
    # compared in constant time so the key cannot be probed via timing.
    key_ok = (
        bool(API_KEY)
        and api_key is not None
        and hmac.compare_digest(api_key.encode('utf-8'), API_KEY.encode('utf-8'))
    )
    if not key_ok:
        return _error_response(
            403,
            'permission_error',
            'Your API key does not have permission to use the specified resource.',
            allow_methods='POST, GET, OPTIONS, DELETE, HEAD',
        )

    # silent=True yields None instead of raising on a missing or malformed
    # JSON body; anything that is not an object cannot be rewritten below.
    request_body = request.get_json(silent=True)
    if not isinstance(request_body, dict):
        return _error_response(
            400,
            'invalid_request_error',
            'Request body must be a JSON object.',
        )

    # The model is fixed by the URL, and the version is always overridden.
    request_body.pop('model', None)
    request_body['anthropic_version'] = 'vertex-2023-10-16'

    access_token = get_access_token()
    api_url = construct_api_url(get_location())
    headers = {
        'Authorization': f'Bearer {access_token}',
        'Content-Type': 'application/json; charset=utf-8'
    }

    try:
        # (connect, read) timeouts: fail fast on an unreachable host but
        # leave room for a long generation.  The body is read in full
        # before being relayed; it is not streamed to the client.
        response = requests.post(
            api_url,
            headers=headers,
            json=request_body,
            timeout=(10, 600),
        )
    except requests.RequestException:
        return _error_response(
            502,
            'api_error',
            'Failed to reach the upstream Vertex AI endpoint.',
        )

    modified_response = make_response(response.content, response.status_code)
    # Relay the upstream media type so JSON is not served under Flask's
    # default content type.
    upstream_type = response.headers.get('Content-Type')
    if upstream_type:
        modified_response.headers['Content-Type'] = upstream_type
    return _with_cors(modified_response)
def handle_options():
    """Answer a CORS preflight request.

    Returns:
        A ``(body, status, headers)`` tuple: an empty body, status 204 and
        the Access-Control-Allow-* headers.
    """
    cors_headers = dict([
        ('Access-Control-Allow-Origin', '*'),
        ('Access-Control-Allow-Methods', 'POST, GET, OPTIONS'),
        ('Access-Control-Allow-Headers', 'Content-Type, Authorization, x-api-key, anthropic-version, model'),
    ])
    empty_body, no_content = '', 204
    return empty_body, no_content, cors_headers
# Script entry point: when this file is executed directly (not imported),
# start Flask's built-in server on port 8080.
if __name__ == '__main__':
    app.run(port=8080)