File size: 6,573 Bytes
bc4ed93 1953b4d 52f607b 1953b4d bc7ed47 403a673 a92cfc8 44dac73 e1b2102 f382771 a5853e1 1953b4d 52f607b 403a673 889a28b 20f980f e1b2102 20f980f e1b2102 889a28b 20f980f e4ddb7f a92cfc8 889a28b e1b2102 889a28b 403a673 bbe47aa 20f980f 535a218 20f980f 83608e4 e1b2102 20f980f a5d1c4e b735536 1953b4d 52f607b 1953b4d 83608e4 b2af907 52f607b 889a28b 83608e4 889a28b 83608e4 2950611 889a28b 83608e4 b2af907 52f607b 83608e4 44dac73 52f607b 83608e4 1953b4d 889a28b 1953b4d 83608e4 889a28b a92cfc8 889a28b 403a673 889a28b ab730f6 3af6d28 83608e4 889a28b e4ddb7f 20f980f 889a28b 1953b4d 2950611 1953b4d 2950611 1953b4d 889a28b 1953b4d 52f607b 20f980f 3af6d28 52f607b 83608e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
import json
import logging
import random
import re
import sys
import time
from datetime import datetime
from logging.handlers import TimedRotatingFileHandler

import requests
from flask import Flask, Response, has_request_context, jsonify, request, stream_with_context
# Flask application object; the routes below register on it.
app = Flask(import_name=__name__)
# Logging configuration
class RequestFormatter(logging.Formatter):
    """Log formatter that stamps each record with details of the current request.

    Sets ``url``, ``remote_addr`` and ``token`` (the raw ``Authorization``
    header, or ``'No Token'`` when absent) on the record so the format string
    can reference them. Outside a GET/POST request context the attributes are
    filled with placeholders, so formatting always produces a string.
    """

    def format(self, record):
        # Only GET and POST requests contribute request details.
        # has_request_context() guards against RuntimeError when a record is
        # emitted outside of a request (startup, background work, etc.).
        if has_request_context() and request.method in ('POST', 'GET'):
            record.url = request.url
            record.remote_addr = request.remote_addr
            # NOTE(review): this is the full Authorization header, so bearer
            # tokens end up in the log file in clear text — consider masking.
            record.token = request.headers.get('Authorization', 'No Token')
        else:
            # format() must always return a string: stream-based handlers
            # (including the file handler this formatter is attached to)
            # write ``msg + terminator``. The placeholders keep the
            # '%(remote_addr)s' / '%(token)s' fields resolvable.
            record.url = '-'
            record.remote_addr = '-'
            record.token = 'No Token'
        return super().format(record)
# Each line is prefixed with the client address, timestamp and the token
# attribute that RequestFormatter sets on the record.
formatter = RequestFormatter(
    '%(remote_addr)s - - [%(asctime)s] - Token: %(token)s - %(message)s',
    datefmt='%d/%b/%Y %H:%M:%S',
)
# Rotate app.log at midnight and keep 30 rotated files (one per day).
# The encoding is explicit so non-ASCII messages (e.g. Chinese error text)
# do not depend on the platform's locale encoding.
handler = TimedRotatingFileHandler(
    'app.log', when="midnight", interval=1, backupCount=30, encoding='utf-8'
)
handler.setFormatter(formatter)
handler.setLevel(logging.INFO)
app.logger.addHandler(handler)
app.logger.setLevel(logging.INFO)
# Model mapping: public model id -> provider and upstream model path.
# The model path is interpolated into the SiliconFlow text-to-image URL; the
# provider is simply the organisation prefix of that path.
MODEL_MAPPING = {
    model_id: {"provider": model_path.partition("/")[0], "mapping": model_path}
    for model_id, model_path in (
        ("flux.1-schnell", "black-forest-labs/FLUX.1-schnell"),
        ("sd-turbo", "stabilityai/sd-turbo"),
        ("sdxl-turbo", "stabilityai/sdxl-turbo"),
        ("stable-diffusion-2-1", "stabilityai/stable-diffusion-2-1"),
        ("stable-diffusion-3-medium", "stabilityai/stable-diffusion-3-medium"),
        ("stable-diffusion-xl-base-1.0", "stabilityai/stable-diffusion-xl-base-1.0"),
    )
}
# Mock authentication helper.
def getAuthCookie(req):
    """Return the request's ``Authorization`` header if it is a Bearer credential.

    Args:
        req: the incoming request; only ``req.headers`` is read.

    Returns:
        The full header value (``'Bearer '`` prefix included) when it starts
        with ``'Bearer '``, otherwise ``None``. The token itself is not
        validated here.
    """
    header_value = req.headers.get('Authorization')
    has_bearer_scheme = bool(header_value) and header_value.startswith('Bearer ')
    return header_value if has_bearer_scheme else None
@app.route('/ai/v1/models', methods=['GET'])
def get_models():
    """List the models exposed by this proxy in OpenAI ``/models`` format.

    Requires an ``Authorization`` header starting with ``'Bearer '``
    (checked by ``getAuthCookie``).

    Returns:
        200 with ``{"object": "list", "data": [...]}``, one entry per key of
        ``MODEL_MAPPING``; 401 when the header is missing or not a Bearer
        credential; 500 on an unexpected internal error.
    """
    try:
        if not getAuthCookie(request):
            app.logger.info('GET /ai/v1/models - 401 Unauthorized')
            return jsonify({"error": "Unauthorized"}), 401

        # One timestamp for the whole listing so every entry reports the
        # same "created" value.
        created = int(time.time())
        models_list = [
            {
                "id": model_id,
                "object": "model",
                "created": created,
                "owned_by": info["provider"],
                "permission": [],
                "root": model_id,
                "parent": None,
            }
            for model_id, info in MODEL_MAPPING.items()
        ]
        app.logger.info('GET /ai/v1/models - 200 OK')
        return jsonify({
            "object": "list",
            "data": models_list,
        })
    except Exception as error:
        # Anything reaching here is a server-side fault, not an auth failure,
        # so report it as 500 (as handle_request does) instead of 401, and
        # keep the traceback in the log.
        app.logger.exception(f"Error: {str(error)}")
        return jsonify({"error": "Internal Server Error", "details": str(error)}), 500
def _extract_prompt_text(message):
    """Return the text of a chat message, or ``None`` if it carries no text.

    ``content`` may be a plain string (returned unchanged) or a list of
    content parts, whose ``{"type": "text", "text": ...}`` entries are joined
    with single spaces; other part types are ignored.
    """
    if not isinstance(message, dict):
        return None
    content = message.get('content')
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = [
            part['text']
            for part in content
            if isinstance(part, dict)
            and part.get('type') == 'text'
            and isinstance(part.get('text'), str)
        ]
        return ' '.join(texts) if texts else None
    return None


@app.route('/ai/v1/chat/completions', methods=['POST'])
def handle_request():
    """Serve an OpenAI-style chat completion by generating an image.

    The text of the last message is passed to ``extract_params_from_prompt``
    (which yields the image size, the cleaned prompt, a "use original prompt"
    flag and the raw size parameter). The prompt is optionally rewritten by
    ``translate_and_enhance_prompt`` and sent to the SiliconFlow
    text-to-image endpoint of the mapped model. The resulting image URL is
    returned via ``stream_response`` or ``non_stream_response`` depending on
    the request's ``stream`` flag.

    Returns:
        A Flask response. Errors are JSON objects with an ``error`` key and
        status 400 (malformed body, missing fields, unknown model, no text in
        the last message), 401 (no usable token from the Authorization
        header) or 500 (upstream or internal failure).
    """
    try:
        # silent=True yields None for non-JSON / unparsable bodies instead of
        # raising, so those are reported as 400 rather than 500.
        body = request.get_json(silent=True)
        if not isinstance(body, dict):
            return jsonify({"error": "Bad Request: Body must be a JSON object"}), 400
        model = body.get('model')
        messages = body.get('messages')
        stream = body.get('stream', False)
        if not isinstance(model, str) or not model or not isinstance(messages, list) or not messages:
            return jsonify({"error": "Bad Request: Missing required fields"}), 400

        # Map the public model id to the upstream model path.
        if model not in MODEL_MAPPING:
            return jsonify({"error": f"Model '{model}' not found"}), 400
        mapped_model = MODEL_MAPPING[model]['mapping']

        prompt = _extract_prompt_text(messages[-1])
        if prompt is None:
            return jsonify({"error": "Bad Request: Last message has no text content"}), 400
        image_size, clean_prompt, use_original, size_param = extract_params_from_prompt(prompt)

        auth_header = request.headers.get('Authorization')
        random_token = get_random_token(auth_header)
        if not random_token:
            return jsonify({"error": "Unauthorized: Invalid or missing Authorization header"}), 401

        if use_original:
            enhanced_prompt = clean_prompt
        else:
            enhanced_prompt = translate_and_enhance_prompt(clean_prompt, random_token)

        new_url = f'https://api.siliconflow.cn/v1/{mapped_model}/text-to-image'
        new_request_body = {
            "prompt": enhanced_prompt,
            "image_size": image_size,
            "batch_size": 1,
            "num_inference_steps": 4,
            "guidance_scale": 1,
        }
        headers = {
            'accept': 'application/json',
            'content-type': 'application/json',
            'Authorization': f'Bearer {random_token}',
        }
        response = requests.post(new_url, headers=headers, json=new_request_body, timeout=60)
        response.raise_for_status()
        response_body = response.json()

        # Expect {"images": [{"url": ...}, ...]}; validate each level before
        # indexing so a malformed upstream reply gives a clear error.
        images = response_body.get('images') if isinstance(response_body, dict) else None
        first_image = images[0] if isinstance(images, list) and images else None
        if not isinstance(first_image, dict) or 'url' not in first_image:
            raise ValueError("Unexpected response structure from image generation API")
        image_url = first_image['url']

        unique_id = str(int(time.time() * 1000))
        current_timestamp = int(time.time())
        system_fingerprint = "fp_" + ''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=9))
        image_data = {'data': [{'url': image_url}]}

        # Summarise the request options for the log: sizes other than 16:9
        # and the "use original prompt" flag.
        params = []
        if size_param != "16:9":
            params.append(f"-s {size_param}")
        if use_original:
            params.append("-o")
        params_str = " ".join(params) if params else "no params"
        app.logger.info(f'POST /ai/v1/chat/completions - Status: 200 - Token: {random_token} - Model: {mapped_model} - Params: {params_str} - Image URL: {image_url}')

        responder = stream_response if stream else non_stream_response
        return responder(unique_id, image_data, clean_prompt, enhanced_prompt, image_size, current_timestamp, model, system_fingerprint, use_original)
    except Exception as e:
        # Top-level request boundary: keep the traceback in the log and
        # report a 500 to the client.
        app.logger.exception(f"Error: {str(e)}")
        return jsonify({"error": f"Internal Server Error: {str(e)}"}), 500
# The rest of the code is unchanged, e.g. stream_response, generate_stream,
# non_stream_response and similar functions.
# NOTE(review): those helpers, and extract_params_from_prompt,
# get_random_token and translate_and_enhance_prompt used by handle_request,
# are not visible in this listing — confirm they are defined before deploying.
if __name__ == '__main__':
    app.run(port=8000, host='0.0.0.0')
|