Spaces:
Running
Running
# coding:utf-8 | |
import argparse | |
import requests | |
import json | |
import os | |
from cerebras import CerebrasUnofficial | |
from flask import Flask, request, Response, stream_with_context, jsonify | |
import sys | |
# -- Start of Config --
# Session token for cerebras.ai, copied from the `authjs.session-token` cookie.
# It can also be supplied via the environment:
#   `set AUTHJS_SESSION_TOKEN=authjs.session-token`
# The token is valid for one month.
authjs_session_token = '12345678-abcd-abcd-abcd-12345678abcd'
# Key that clients must present to this server; any string such as
# `my-api-key` works. It can also be supplied via the environment:
#   `set SERVER_API_KEY=my-api-key`
# Set it so that the session token can be updated later on.
server_api_key = 'my-api-key'
# -- End of Config --

# Print only the exception type and message for uncaught errors (no frames).
sys.tracebacklimit = 0

# Environment variables take precedence over the literals configured above.
authjs_session_token = os.getenv('AUTHJS_SESSION_TOKEN', authjs_session_token)
server_api_key = os.getenv('SERVER_API_KEY', server_api_key)

# NOTE(review): both secrets are written to stdout in plaintext here.
print(f'Using the cookie: authjs.session-token={authjs_session_token}')
print(f'Your api key: {server_api_key}')

# Shared upstream client; rebound at runtime by renew_token().
cerebras_ai = CerebrasUnofficial(authjs_session_token)

app = Flask(__name__)
app.json.sort_keys = False  # keep JSON keys in insertion order

# Command-line options are parsed at import time and consumed by the
# `__main__` guard at the bottom of the file.
parser = argparse.ArgumentParser(description='Cerebras.AI API.')
parser.add_argument(
    '--host',
    type=str,
    default='0.0.0.0',
    help='Set the ip address.(default: 0.0.0.0)',
)
parser.add_argument(
    '--port',
    type=int,
    default=7860,
    help='Set the port.(default: 7860)',
)
args = parser.parse_args()
class Provider:
    """Upstream connection details resolved for a single incoming request.

    ``key`` and ``api_url`` remain empty strings unless the request key
    equals the module-level ``server_api_key``.
    """

    # Class-level defaults; instances overwrite `key` and `api_url` only
    # when the request key is accepted.
    key = ''
    max_tokens = None
    api_url = ''

    def __init__(self, request_key, model):
        """Record the caller's key and requested model, then resolve upstream info."""
        self.request_key = request_key
        self.model = model
        self.init_request_info()

    def init_request_info(self):
        """Copy ``api_url`` and an API key from the shared ``cerebras_ai`` client."""
        if self.request_key != server_api_key:
            return
        self.api_url = cerebras_ai.api_url
        self.key = cerebras_ai.get_api_key()
def renew_token():
    """Replace the shared Cerebras client with one built from a new session token.

    Reads two query-string parameters from the current Flask request:
    ``key`` (must equal ``server_api_key``) and ``token`` (the new
    session token).

    NOTE(review): no route decorator is visible in this part of the file;
    presumably this function is registered as a Flask view - confirm.

    Returns:
        str: a confirmation message containing the (HTML-escaped) new token.

    Raises:
        Exception: if ``key`` does not match ``server_api_key``, or if
            ``token`` is missing or empty.
    """
    from html import escape

    global cerebras_ai

    if server_api_key != request.args.get('key', ''):
        raise Exception('invalid api key')

    request_token = request.args.get('token', '')
    if not request_token:
        # Without this guard a request lacking `token` would discard the
        # current client in favour of one built from an empty string.
        raise Exception('missing token')

    cerebras_ai = CerebrasUnofficial(request_token)
    # Flask serves a returned str as text/html, so the echoed query value
    # must be escaped before it is embedded in the response body.
    return f'new authjs.session_token: {escape(request_token)}'
def model_list():
    """Return the fixed list of advertised models as a JSON response.

    The payload has the shape ``{'object': 'list', 'data': [...]}`` where
    each entry carries ``id``, ``object``, ``created`` and ``owned_by``,
    in that key order.
    """
    # One (id, created, owned_by) row per advertised model.
    catalogue = (
        ('llama3.1-8b', 1721692800, 'Meta'),
        ('llama-3.3-70b', 1733443200, 'Meta'),
        ('deepseek-r1-distill-llama-70b', 1733443200, 'deepseek'),
    )
    entries = [
        {
            'id': model_id,
            'object': 'model',
            'created': created,
            'owned_by': owner,
        }
        for model_id, created, owner in catalogue
    ]
    return jsonify({'object': 'list', 'data': entries})
def proxy():
    """Forward a chat-completion request upstream and stream the reply back.

    The caller's key is taken from the second space-separated field of the
    ``Authorization`` header and must equal ``server_api_key``. The incoming
    headers are copied (minus ``Host`` and ``Content-Length``), the
    ``Authorization`` header is replaced with the upstream key obtained via
    ``Provider``, and the JSON body is POSTed to
    ``{provider.api_url}/v1/chat/completions``.

    NOTE(review): no route decorator is visible in this part of the file;
    presumably this function is registered as a Flask view - confirm.

    Returns:
        Response: a ``text/event-stream`` response whose body is the upstream
        response relayed chunk by chunk.

    Raises:
        Exception: if the presented key is absent from the header value or
            does not match ``server_api_key``.
    """
    # A header value without a space has no key field; treat it as an
    # invalid key instead of failing with IndexError.
    auth_fields = request.headers['Authorization'].split(' ')
    request_key = auth_fields[1] if len(auth_fields) > 1 else ''
    if server_api_key != request_key:
        raise Exception('invalid api key')

    headers = dict(request.headers)
    # These describe the incoming request, not the outgoing one.
    headers.pop('Host', None)
    headers.pop('Content-Length', None)
    headers['X-Use-Cache'] = 'false'

    payload = request.get_json()
    provider = Provider(request_key, payload['model'])
    headers['Authorization'] = f'Bearer {provider.key}'
    chat_api = f'{provider.api_url}/v1/chat/completions'

    def generate():
        """Yield the upstream body as raw bytes while it arrives."""
        # The upstream call is made lazily, when the response is iterated.
        with requests.post(chat_api, json=payload, headers=headers, stream=True) as resp:
            for chunk in resp.iter_content(chunk_size=1024):
                if chunk:
                    # Relay bytes untouched: decoding fixed-size chunks as
                    # UTF-8 raises UnicodeDecodeError whenever a multi-byte
                    # character straddles a chunk boundary.
                    yield chunk

    return Response(stream_with_context(generate()), content_type='text/event-stream')
if __name__ == '__main__':
    # The server binds to 0.0.0.0 by default, and Flask's debug mode enables
    # the interactive Werkzeug debugger, which can execute arbitrary code.
    # Keep it off unless explicitly requested, e.g. `set FLASK_DEBUG=1`
    # for local development.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    app.run(host=args.host, port=args.port, debug=debug_enabled)