File size: 3,979 Bytes
dd9ae3e
 
 
3fdf5e9
dd9ae3e
 
 
0b4f475
dd9ae3e
3fdf5e9
 
 
 
 
 
 
 
 
 
 
 
 
dd9ae3e
3fdf5e9
 
 
 
 
 
dd9ae3e
3fdf5e9
 
dd9ae3e
 
 
3fdf5e9
 
 
 
dd9ae3e
 
 
 
 
 
 
 
 
 
3fdf5e9
 
 
 
dd9ae3e
 
 
 
 
 
 
3fdf5e9
 
 
 
 
dd9ae3e
 
 
 
 
 
 
 
 
 
 
f21f2de
dd9ae3e
3fdf5e9
dd9ae3e
3fdf5e9
dd9ae3e
f21f2de
3fdf5e9
 
 
 
f21f2de
 
3fdf5e9
 
 
 
 
 
dd9ae3e
 
 
 
 
 
 
0b4f475
3fdf5e9
 
 
 
 
 
dd9ae3e
3fdf5e9
 
dd9ae3e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import hmac
import os
import threading
import time

import requests
from flask import Flask, request, jsonify, Response, make_response

# Flask application object; the /ai/v1/messages route is registered on it.
app = Flask(__name__)

# Model identifier interpolated into the upstream URL by construct_api_url().
MODEL = 'claude-3-5-sonnet@20240620'
# Deployment settings read from the environment. os.getenv() yields None for
# any variable that is not set.
PROJECT_ID = os.getenv('PROJECT_ID')
CLIENT_ID = os.getenv('CLIENT_ID')
CLIENT_SECRET = os.getenv('CLIENT_SECRET')
REFRESH_TOKEN = os.getenv('REFRESH_TOKEN')
# Value that handle_request() checks the incoming x-api-key header against.
API_KEY = os.getenv('API_KEY')
# Endpoint that get_access_token() posts the refresh-token grant to.
TOKEN_URL = 'https://www.googleapis.com/oauth2/v4/token'

# Module-level cache shared by every call to get_access_token().
# - 'access_token': last token obtained; the empty string means none yet.
# - 'expiry': time.time()-style timestamp after which the token is stale.
# - 'refresh_promise': starts as None (no refresh recorded at start-up).
token_cache = {
    'access_token': '',
    'expiry': 0,
    'refresh_promise': None
}

# Serializes token refreshes so concurrent requests trigger one refresh only.
_token_refresh_lock = threading.Lock()


def get_access_token():
    """Return a Google OAuth access token, refreshing it when needed.

    The cached token is reused while it has more than 120 seconds of
    validity left. Otherwise a refresh-token grant is posted to TOKEN_URL
    and the result is stored in ``token_cache``.

    Refreshes run synchronously under a lock. Callers that arrive while a
    refresh is in progress wait for it and then reuse its result. The
    ``'refresh_promise'`` entry of ``token_cache`` is left untouched.

    Returns:
        The access token string.

    Raises:
        RuntimeError: if the refresh fails (network error, non-2xx status,
            or malformed response) and no unexpired cached token exists.
    """
    def cached_token(margin):
        # Return the cached token if it stays valid for `margin` more seconds.
        token = token_cache['access_token']
        if token and time.time() < token_cache['expiry'] - margin:
            return token
        return None

    token = cached_token(120)
    if token:
        return token

    with _token_refresh_lock:
        # Another caller may have refreshed while we waited for the lock.
        token = cached_token(120)
        if token:
            return token

        # Taken before the request so the stored expiry errs on the early side.
        now = time.time()
        try:
            response = requests.post(TOKEN_URL, json={
                'client_id': CLIENT_ID,
                'client_secret': CLIENT_SECRET,
                'refresh_token': REFRESH_TOKEN,
                'grant_type': 'refresh_token'
            }, timeout=10)
            response.raise_for_status()
            data = response.json()
            access_token = data['access_token']
            expires_in = data['expires_in']
        except (requests.RequestException, ValueError, KeyError) as exc:
            # Best effort: a token inside the 120 s margin is still usable.
            token = cached_token(0)
            if token:
                return token
            raise RuntimeError(
                'Failed to refresh the Google OAuth access token'
            ) from exc

        token_cache['access_token'] = access_token
        token_cache['expiry'] = now + expires_in
        return access_token

def get_location():
    """Pick the Vertex AI region from the current local-time second.

    Returns:
        'europe-west1' during seconds 0-29 of each minute, 'us-east5'
        during seconds 30 and later.
    """
    if time.localtime().tm_sec >= 30:
        return 'us-east5'
    return 'europe-west1'

def construct_api_url(location):
    """Build the streamRawPredict URL for the given region.

    Args:
        location: Region name; used both as the host prefix and as the
            ``locations/`` path segment.

    Returns:
        The full HTTPS URL for MODEL under PROJECT_ID in that region.
    """
    host = f'{location}-aiplatform.googleapis.com'
    resource = (
        f'projects/{PROJECT_ID}/locations/{location}'
        f'/publishers/anthropic/models/{MODEL}'
    )
    return f'https://{host}/v1/{resource}:streamRawPredict'

# CORS headers attached to every response produced by handle_request().
_CORS_HEADERS = {
    'Access-Control-Allow-Origin': '*',
    'Access-Control-Allow-Methods': 'POST, GET, OPTIONS',
    'Access-Control-Allow-Headers': 'Content-Type, Authorization, x-api-key, anthropic-version, model',
}


def _with_cors(response):
    """Attach the shared CORS headers to *response* and return it."""
    for header, value in _CORS_HEADERS.items():
        response.headers[header] = value
    return response


def _error_response(status, error_type, message):
    """Build a JSON error response (with CORS headers) for the given status."""
    payload = jsonify({
        'type': 'error',
        'error': {
            'type': error_type,
            'message': message
        }
    })
    return _with_cors(make_response(payload, status))


def _is_authorized(provided_key):
    """Return True only if an API key is configured and *provided_key* matches."""
    # An unset API_KEY must never authorize: with the old `!=` check a request
    # without an x-api-key header passed whenever API_KEY was None.
    if not API_KEY or not provided_key:
        return False
    # Compare as bytes: compare_digest rejects non-ASCII str arguments.
    return hmac.compare_digest(
        provided_key.encode('utf-8'), API_KEY.encode('utf-8')
    )


@app.route('/ai/v1/messages', methods=['POST', 'OPTIONS'])
def handle_request():
    """Proxy a messages request to the Vertex AI model endpoint.

    OPTIONS requests are answered by handle_options(). POST requests must
    carry the configured key in the ``x-api-key`` header and a JSON object
    body. Any client-supplied ``model`` and ``anthropic_version`` fields are
    replaced by the values this proxy uses.

    Returns:
        The upstream response body, status code and Content-Type with CORS
        headers added, or a JSON error response:
        403 for a missing or wrong API key, 400 for a body that is not a
        JSON object, 502 when the upstream endpoint cannot be reached.
    """
    if request.method == 'OPTIONS':
        return handle_options()

    if not _is_authorized(request.headers.get('x-api-key')):
        return _error_response(
            403,
            'permission_error',
            'Your API key does not have permission to use the specified resource.'
        )

    # silent=True yields None for a missing or malformed JSON body instead of
    # raising, so the client always gets a CORS-enabled JSON error.
    request_body = request.get_json(silent=True)
    if not isinstance(request_body, dict):
        return _error_response(
            400,
            'invalid_request_error',
            'Request body must be a JSON object.'
        )

    # The model is selected by the URL, and Vertex expects its own version tag.
    request_body.pop('model', None)
    request_body['anthropic_version'] = 'vertex-2023-10-16'

    access_token = get_access_token()
    api_url = construct_api_url(get_location())

    headers = {
        'Authorization': f'Bearer {access_token}',
        'Content-Type': 'application/json; charset=utf-8'
    }

    try:
        # (connect, read) timeouts: fail fast on connect, allow long generations.
        response = requests.post(
            api_url, headers=headers, json=request_body, timeout=(10, 600)
        )
    except requests.RequestException:
        return _error_response(
            502,
            'api_error',
            'Failed to reach the upstream model endpoint.'
        )

    modified_response = make_response(response.content, response.status_code)
    # Preserve the upstream media type; otherwise Flask labels it text/html.
    upstream_content_type = response.headers.get('Content-Type')
    if upstream_content_type:
        modified_response.headers['Content-Type'] = upstream_content_type

    return _with_cors(modified_response)

def handle_options():
    """Answer a CORS preflight request.

    Returns:
        A ``(body, status, headers)`` tuple: empty body, status 204, and
        the CORS headers describing the allowed origin, methods and
        request headers.
    """
    cors = {}
    cors['Access-Control-Allow-Origin'] = '*'
    cors['Access-Control-Allow-Methods'] = 'POST, GET, OPTIONS'
    cors['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, x-api-key, anthropic-version, model'
    return ('', 204, cors)

# Run the app on port 8080 only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    app.run(port=8080)