|
from flask import Flask, request, jsonify, Response, stream_with_context |
|
import requests |
|
import json |
|
import time |
|
import random |
|
import logging |
|
import sys |
|
import re |
|
from logging.handlers import TimedRotatingFileHandler |
|
from datetime import datetime |
|
|
|
# The Flask application object; the routes and the file log handler defined
# in this module are registered on it.
app = Flask(import_name=__name__)
|
|
|
|
|
class RequestFormatter(logging.Formatter):
    """Logging formatter that attaches details of the current Flask request.

    Before delegating to ``logging.Formatter.format`` it sets three
    attributes on the record so the format string can reference them:
    ``url``, ``remote_addr`` and ``token`` (the raw ``Authorization`` header,
    or ``'No Token'`` when absent).
    """

    def format(self, record):
        """Populate the request attributes on *record* and return the formatted text.

        Outside a Flask request context (e.g. records logged at start-up or
        from another thread) ``request`` cannot be read, so placeholder
        values are used instead of raising ``RuntimeError``.
        """
        from flask import has_request_context

        if has_request_context():
            record.url = request.url
            record.remote_addr = request.remote_addr
            # NOTE(review): this writes the full Authorization header
            # (credentials) to the log in clear text; consider masking it.
            record.token = request.headers.get('Authorization', 'No Token')
        else:
            record.url = '-'
            record.remote_addr = '-'
            record.token = 'No Token'
        # Always return a string: Handler.emit writes the result out directly,
        # so a None here would make the handler fail for this record.
        return super().format(record)
|
|
|
# Daily-rotated access log: app.log rolls over at midnight and the 30 most
# recent rotated files are kept. Records are rendered by RequestFormatter,
# which supplies the per-request fields used in the format string.
formatter = RequestFormatter(
    fmt='%(remote_addr)s - - [%(asctime)s] - Token: %(token)s - %(message)s',
    datefmt='%d/%b/%Y %H:%M:%S',
)

handler = TimedRotatingFileHandler(
    filename='app.log',
    when="midnight",
    interval=1,
    backupCount=30,
)
handler.setLevel(logging.INFO)
handler.setFormatter(formatter)

app.logger.setLevel(logging.INFO)
app.logger.addHandler(handler)
|
|
|
|
|
# Public model alias -> {"provider": organisation reported as ``owned_by`` by
# /ai/v1/models, "mapping": upstream model path interpolated into the
# SiliconFlow text-to-image URL}. Entry order is the listing order.
MODEL_MAPPING = {
    alias: {"provider": provider, "mapping": upstream_path}
    for alias, provider, upstream_path in (
        ("flux.1-schnell", "black-forest-labs", "black-forest-labs/FLUX.1-schnell"),
        ("sd-turbo", "stabilityai", "stabilityai/sd-turbo"),
        ("sdxl-turbo", "stabilityai", "stabilityai/sdxl-turbo"),
        ("stable-diffusion-2-1", "stabilityai", "stabilityai/stable-diffusion-2-1"),
        ("stable-diffusion-3-medium", "stabilityai", "stabilityai/stable-diffusion-3-medium"),
        ("stable-diffusion-xl-base-1.0", "stabilityai", "stabilityai/stable-diffusion-xl-base-1.0"),
    )
}
|
|
|
|
|
def getAuthCookie(req):
    """Return the ``Authorization`` header of *req* if it carries a Bearer token.

    Despite the name, this reads the ``Authorization`` request header, not a
    cookie. The auth scheme is matched case-insensitively (``Bearer``,
    ``bearer``, ...) and a non-blank credential must follow it. The token is
    not validated any further here.

    Args:
        req: Request-like object exposing a ``headers`` mapping with ``get``.

    Returns:
        The unmodified header string on success, otherwise ``None``.
    """
    auth_header = req.headers.get('Authorization')
    if not auth_header:
        return None
    scheme, _, credentials = auth_header.partition(' ')
    # "Bearer " with nothing after it is not a credential.
    if scheme.lower() != 'bearer' or not credentials.strip():
        return None
    return auth_header
|
|
|
@app.route('/ai/v1/models', methods=['GET'])
def get_models():
    """List the entries of ``MODEL_MAPPING`` in OpenAI ``/v1/models`` format.

    Requires a Bearer ``Authorization`` header (checked by ``getAuthCookie``);
    the token itself is not validated here.

    Returns:
        200 with ``{"object": "list", "data": [...]}``; 401 when the header is
        missing or not a Bearer token; 500 on any unexpected error.
    """
    try:
        if not getAuthCookie(request):
            app.logger.info('GET /ai/v1/models - 401 Unauthorized')
            return jsonify({"error": "Unauthorized"}), 401

        # One timestamp for the whole listing so all entries agree.
        created = int(time.time())
        models_list = [
            {
                "id": model_id,
                "object": "model",
                "created": created,
                "owned_by": info["provider"],
                "permission": [],
                "root": model_id,
                "parent": None,
            }
            for model_id, info in MODEL_MAPPING.items()
        ]

        app.logger.info('GET /ai/v1/models - 200 OK')
        return jsonify({"object": "list", "data": models_list})

    except Exception as error:
        # Anything reaching this point is a server-side fault, not an
        # authentication failure, so report it as a 500.
        app.logger.exception(f"Error: {str(error)}")
        return jsonify({"error": f"Internal Server Error: {str(error)}"}), 500
|
|
|
@app.route('/ai/v1/chat/completions', methods=['POST'])
def handle_request():
    """Turn the last chat message into an image and answer as a chat completion.

    The JSON body must contain ``model`` (a key of ``MODEL_MAPPING``) and a
    non-empty ``messages`` list whose last item is an object with a
    ``content`` field; ``stream`` is optional. The content is split by
    ``extract_params_from_prompt`` into an image size, a cleaned prompt and
    flags. Unless the use-original flag is set, the prompt is rewritten by
    ``translate_and_enhance_prompt``. The prompt is posted to the SiliconFlow
    text-to-image endpoint of the mapped model, and the first returned image
    URL is handed to ``stream_response`` or ``non_stream_response``.

    Returns:
        400 for a malformed body or unknown model; 401 when
        ``get_random_token`` yields no token; 500 for any other failure
        (including upstream HTTP errors). Otherwise, whatever the selected
        response builder returns.
    """
    try:
        # silent=True: a missing or non-JSON body yields None instead of
        # raising, so it is reported as 400 rather than hitting the 500 below.
        body = request.get_json(silent=True)
        if not isinstance(body, dict):
            return jsonify({"error": "Bad Request: Body must be a JSON object"}), 400

        model = body.get('model')
        messages = body.get('messages')
        stream = body.get('stream', False)
        # Type checks keep malformed values (an unhashable model, a dict for
        # messages, ...) from surfacing later as 500s.
        if not isinstance(model, str) or not model or not isinstance(messages, list) or not messages:
            return jsonify({"error": "Bad Request: Missing required fields"}), 400

        model_info = MODEL_MAPPING.get(model)
        if model_info is None:
            return jsonify({"error": f"Model '{model}' not found"}), 400
        mapped_model = model_info['mapping']

        last_message = messages[-1]
        if not isinstance(last_message, dict) or 'content' not in last_message:
            return jsonify({"error": "Bad Request: Last message has no content"}), 400
        prompt = last_message['content']
        image_size, clean_prompt, use_original, size_param = extract_params_from_prompt(prompt)

        auth_header = request.headers.get('Authorization')
        random_token = get_random_token(auth_header)
        if not random_token:
            return jsonify({"error": "Unauthorized: Invalid or missing Authorization header"}), 401

        # Skip the prompt-rewriting call when the caller asked for the prompt
        # to be used verbatim.
        if use_original:
            enhanced_prompt = clean_prompt
        else:
            enhanced_prompt = translate_and_enhance_prompt(clean_prompt, random_token)

        new_url = f'https://api.siliconflow.cn/v1/{mapped_model}/text-to-image'
        new_request_body = {
            "prompt": enhanced_prompt,
            "image_size": image_size,
            "batch_size": 1,
            "num_inference_steps": 4,
            "guidance_scale": 1,
        }
        headers = {
            'accept': 'application/json',
            'content-type': 'application/json',
            'Authorization': f'Bearer {random_token}',
        }

        response = requests.post(new_url, headers=headers, json=new_request_body, timeout=60)
        response.raise_for_status()
        response_body = response.json()

        if 'images' in response_body and response_body['images'] and 'url' in response_body['images'][0]:
            image_url = response_body['images'][0]['url']
        else:
            raise ValueError("Unexpected response structure from image generation API")

        unique_id = str(int(time.time() * 1000))
        current_timestamp = int(time.time())
        system_fingerprint = "fp_" + ''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=9))
        image_data = {'data': [{'url': image_url}]}

        # Human-readable summary of the non-default options for the log line.
        params = []
        if size_param != "16:9":
            params.append(f"-s {size_param}")
        if use_original:
            params.append("-o")
        params_str = " ".join(params) or "no params"

        # NOTE(review): this writes the upstream API token to the log file in
        # clear text; consider masking it.
        app.logger.info(f'POST /ai/v1/chat/completions - Status: 200 - Token: {random_token} - Model: {mapped_model} - Params: {params_str} - Image URL: {image_url}')

        responder = stream_response if stream else non_stream_response
        return responder(unique_id, image_data, clean_prompt, enhanced_prompt, image_size,
                         current_timestamp, model, system_fingerprint, use_original)

    except Exception as e:
        # Top-level boundary: log with traceback, report a 500 to the client.
        app.logger.exception(f"Error: {str(e)}")
        return jsonify({"error": f"Internal Server Error: {str(e)}"}), 500
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Start Flask's built-in server, listening on all IPv4 interfaces.
    listen_host, listen_port = '0.0.0.0', 8000
    app.run(host=listen_host, port=listen_port)
|
|