# Source: Hugging Face Space (status: Running); original file size 3,979 bytes.
import hmac
import os
import threading
import time

import requests
from flask import Flask, request, jsonify, Response, make_response
app = Flask(__name__)

# Claude model identifier as addressed through Vertex AI's Anthropic publisher.
MODEL = 'claude-3-5-sonnet@20240620'

# Deployment configuration pulled from the environment; each is None if unset.
PROJECT_ID, CLIENT_ID, CLIENT_SECRET, REFRESH_TOKEN, API_KEY = map(
    os.environ.get,
    ('PROJECT_ID', 'CLIENT_ID', 'CLIENT_SECRET', 'REFRESH_TOKEN', 'API_KEY'),
)

# Google OAuth endpoint used to exchange REFRESH_TOKEN for an access token.
TOKEN_URL = 'https://www.googleapis.com/oauth2/v4/token'

# Process-wide access-token cache: the token string, its absolute expiry as a
# time.time() timestamp, and the in-flight refresh thread (if any).
token_cache = dict(
    access_token='',
    expiry=0,
    refresh_promise=None,
)
# Serialises token refreshes: without it, every concurrent request that sees a
# stale token would call TOKEN_URL on its own.
_token_refresh_lock = threading.Lock()


def _cached_token_is_fresh():
    """Return True if a cached token exists and has more than 120 s of validity left."""
    return bool(token_cache['access_token']) and time.time() < token_cache['expiry'] - 120


def get_access_token():
    """Return a Google OAuth access token, refreshing it when near expiry.

    The token held in ``token_cache`` is reused while more than 120 seconds
    remain before its recorded expiry. Otherwise a new token is requested
    from TOKEN_URL with the refresh-token grant and stored in the cache.
    Refreshes run synchronously under a lock, so callers never receive an
    empty or stale token after a failed refresh. The ``refresh_promise``
    cache entry is no longer used.

    Returns:
        The access token string.

    Raises:
        requests.RequestException: if the token endpoint is unreachable,
            times out, or answers with an HTTP error status.
        ValueError: if the endpoint's reply is not valid JSON.
        KeyError: if the reply lacks ``access_token`` or ``expires_in``.
    """
    if _cached_token_is_fresh():
        return token_cache['access_token']

    with _token_refresh_lock:
        # Another thread may have refreshed the token while we were waiting.
        if _cached_token_is_fresh():
            return token_cache['access_token']

        # Measure expiry from before the request so it errs on the early side.
        requested_at = time.time()
        response = requests.post(
            TOKEN_URL,
            json={
                'client_id': CLIENT_ID,
                'client_secret': CLIENT_SECRET,
                'refresh_token': REFRESH_TOKEN,
                'grant_type': 'refresh_token',
            },
            timeout=30,
        )
        response.raise_for_status()
        data = response.json()

        # Write the token before the expiry. A lock-free reader can then never
        # pair the old token with the new expiry. At worst it sees the new
        # token with the old expiry and takes the locked path.
        token_cache['access_token'] = data['access_token']
        token_cache['expiry'] = requested_at + data['expires_in']
        return token_cache['access_token']
def get_location():
    """Choose the Vertex AI region for the next upstream call.

    Traffic is split by the current local wall-clock second: seconds 0-29
    map to 'europe-west1', and every later second maps to 'us-east5'.
    """
    if time.localtime().tm_sec >= 30:
        return 'us-east5'
    return 'europe-west1'
def construct_api_url(location):
    """Build the Vertex AI ``streamRawPredict`` URL for MODEL in *location*.

    *location* is used both as the regional host prefix and as the
    ``locations/`` path segment.
    """
    host = f'https://{location}-aiplatform.googleapis.com'
    resource = f'/v1/projects/{PROJECT_ID}/locations/{location}/publishers/anthropic/models/{MODEL}'
    return host + resource + ':streamRawPredict'
# CORS headers attached to every response produced by the proxy endpoint.
_CORS_HEADERS = {
    'Access-Control-Allow-Origin': '*',
    'Access-Control-Allow-Methods': 'POST, GET, OPTIONS',
    'Access-Control-Allow-Headers': 'Content-Type, Authorization, x-api-key, anthropic-version, model',
}


def _with_cors(response):
    """Set the proxy's CORS headers on *response* and return it."""
    for name, value in _CORS_HEADERS.items():
        response.headers[name] = value
    return response


def _error_response(status, error_type, message):
    """Build an Anthropic-style JSON error response carrying CORS headers."""
    body = jsonify({
        'type': 'error',
        'error': {
            'type': error_type,
            'message': message,
        },
    })
    return _with_cors(make_response(body, status))


@app.route('/ai/v1/messages', methods=['POST', 'OPTIONS'])
def handle_request():
    """Proxy an Anthropic Messages API request to Claude on Vertex AI.

    OPTIONS requests are answered as CORS preflights. POST requests must
    carry an ``x-api-key`` header equal to API_KEY and a JSON-object body.

    Before forwarding, the body's ``model`` field is removed (Vertex takes the
    model from the URL) and ``anthropic_version`` is set to the Vertex value.
    The upstream reply is then streamed back as it arrives, with its status
    code and Content-Type, so that ``stream: true`` requests deliver events
    incrementally.

    Returns:
        - 403 if the key is wrong, missing, or API_KEY is not configured.
        - 400 if the body is not a JSON object.
        - 502 if the token or upstream request fails at the transport or
          HTTP level.
    """
    if request.method == 'OPTIONS':
        return handle_options()

    api_key = request.headers.get('x-api-key')
    # Fail closed: with API_KEY unset, the plain `api_key != API_KEY` check
    # let requests without an x-api-key header through (None == None).
    # compare_digest avoids leaking the key through timing differences;
    # bytes are compared because it rejects non-ASCII str arguments.
    if (not API_KEY or api_key is None
            or not hmac.compare_digest(api_key.encode(), API_KEY.encode())):
        return _error_response(
            403,
            'permission_error',
            'Your API key does not have permission to use the specified resource.',
        )

    # silent=True yields None instead of raising on a missing/invalid body.
    request_body = request.get_json(silent=True)
    if not isinstance(request_body, dict):
        return _error_response(400, 'invalid_request_error', 'Request body must be a JSON object.')

    request_body.pop('model', None)
    request_body['anthropic_version'] = 'vertex-2023-10-16'

    try:
        headers = {
            'Authorization': f'Bearer {get_access_token()}',
            'Content-Type': 'application/json; charset=utf-8',
        }
        # 10 s to connect. The read timeout is generous because a
        # non-streaming generation can take minutes before the first byte.
        upstream = requests.post(
            construct_api_url(get_location()),
            headers=headers,
            json=request_body,
            stream=True,
            timeout=(10, 600),
        )
    except requests.RequestException:
        app.logger.exception('Upstream request to Vertex AI failed')
        return _error_response(502, 'api_error', 'Upstream request to Vertex AI failed.')

    # chunk_size=None relays data as it is received instead of buffering it.
    proxied = Response(
        upstream.iter_content(chunk_size=None),
        status=upstream.status_code,
        content_type=upstream.headers.get('Content-Type', 'application/json'),
    )
    # Release the upstream connection even if the client disconnects early.
    proxied.call_on_close(upstream.close)
    return _with_cors(proxied)
def handle_options():
    """Answer a CORS preflight with an empty body, status 204 and CORS headers."""
    preflight_headers = dict([
        ('Access-Control-Allow-Origin', '*'),
        ('Access-Control-Allow-Methods', 'POST, GET, OPTIONS'),
        ('Access-Control-Allow-Headers', 'Content-Type, Authorization, x-api-key, anthropic-version, model'),
    ])
    no_content = 204
    return '', no_content, preflight_headers
if __name__ == '__main__':
    # Run the development server. HOST and PORT may be overridden from the
    # environment. When HOST is unset, Flask's default host is used, and
    # PORT defaults to 8080, matching the original hard-coded call.
    app.run(host=os.getenv('HOST'), port=int(os.getenv('PORT', '8080')))