flux2api / app.py
smgc's picture
Update app.py
20f980f verified
raw
history blame
6.57 kB
import json
import logging
import random
import re
import sys
import time
from datetime import datetime
from logging.handlers import TimedRotatingFileHandler

import requests
from flask import (
    Flask,
    Response,
    has_request_context,
    jsonify,
    request,
    stream_with_context,
)
# The Flask application object; the routes below are registered on it.
app = Flask(import_name=__name__)
# Logging configuration.
class RequestFormatter(logging.Formatter):
    """Formatter that tags each log record with details of the current Flask request.

    Sets ``url``, ``remote_addr`` and ``token`` on every record so the format
    string can reference them. ``token`` is the ``Authorization`` header with
    all but its last four characters masked (``'No Token'`` when the header is
    absent or empty), so credentials are not written to the log in clear text.
    Outside a request context (e.g. start-up messages) all three are ``'-'``.

    ``format`` always returns a string: the logging handler concatenates the
    result with its line terminator, so returning ``None`` would make the
    handler raise instead of writing the record.
    """

    @staticmethod
    def _mask_secret(value):
        """Return *value* with all but its last four characters replaced by ``***``.

        Values of eight characters or fewer are masked completely.
        """
        if len(value) <= 8:
            return '***'
        return '***' + value[-4:]

    def format(self, record):
        # Touching flask.request outside a request raises RuntimeError, so fall
        # back to placeholders for records emitted without a request in flight.
        if has_request_context():
            record.url = request.url
            record.remote_addr = request.remote_addr
            auth = request.headers.get('Authorization')
            record.token = self._mask_secret(auth) if auth else 'No Token'
        else:
            record.url = record.remote_addr = record.token = '-'
        return super().format(record)
# One log line per record: client address, timestamp, token, then the message.
formatter = RequestFormatter(
    '%(remote_addr)s - - [%(asctime)s] - Token: %(token)s - %(message)s',
    datefmt='%d/%b/%Y %H:%M:%S'
)
# Rotate 'app.log' (relative to the working directory) at midnight and keep
# 30 rotated files. The encoding is pinned to UTF-8 so non-ASCII text in log
# messages (e.g. exception text) does not depend on the platform locale.
handler = TimedRotatingFileHandler(
    'app.log', when="midnight", interval=1, backupCount=30, encoding='utf-8'
)
handler.setFormatter(formatter)
handler.setLevel(logging.INFO)
app.logger.addHandler(handler)
app.logger.setLevel(logging.INFO)
# Model mapping: public model id (what clients send as "model") -> owning
# provider and the upstream model path used in the text-to-image URL.
# The provider is the first path segment of the upstream path.
MODEL_MAPPING = {
    public_id: {"provider": upstream.split("/", 1)[0], "mapping": upstream}
    for public_id, upstream in (
        ("flux.1-schnell", "black-forest-labs/FLUX.1-schnell"),
        ("sd-turbo", "stabilityai/sd-turbo"),
        ("sdxl-turbo", "stabilityai/sdxl-turbo"),
        ("stable-diffusion-2-1", "stabilityai/stable-diffusion-2-1"),
        ("stable-diffusion-3-medium", "stabilityai/stable-diffusion-3-medium"),
        ("stable-diffusion-xl-base-1.0", "stabilityai/stable-diffusion-xl-base-1.0"),
    )
}
# Mock authentication: the request only has to carry a non-empty Bearer
# credential; the credential itself is not validated here.
def getAuthCookie(req):
    """Return the request's ``Authorization`` header if it holds a Bearer token.

    Despite its name this reads a header, not a cookie; the name is kept for
    compatibility with existing callers.

    Args:
        req: An object exposing a ``headers`` mapping (e.g. ``flask.request``).

    Returns:
        The full header value (including the ``Bearer `` prefix) when it starts
        with ``Bearer `` and is followed by a non-blank credential, else ``None``.
    """
    auth_header = req.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return None
    # "Bearer " followed by nothing (or only whitespace) carries no credential.
    if not auth_header[len('Bearer '):].strip():
        return None
    return auth_header
@app.route('/ai/v1/models', methods=['GET'])
def get_models():
    """List the models this proxy exposes, as an OpenAI-style model listing.

    The caller must send an ``Authorization: Bearer <token>`` header (checked
    by ``getAuthCookie``); the token itself is not used here.

    Returns:
        A Flask response: 200 with ``{"object": "list", "data": [...]}``,
        401 when the header is missing or not a Bearer credential, or 500 if
        building the listing fails unexpectedly.
    """
    try:
        # Authenticate the caller.
        auth_cookie = getAuthCookie(request)
        if not auth_cookie:
            app.logger.info('GET /ai/v1/models - 401 Unauthorized')
            return jsonify({"error": "Unauthorized"}), 401

        # Take one timestamp so every entry in the listing reports the same value.
        created = int(time.time())
        models_list = [
            {
                "id": model_id,
                "object": "model",
                "created": created,
                "owned_by": info["provider"],
                "permission": [],
                "root": model_id,
                "parent": None
            }
            for model_id, info in MODEL_MAPPING.items()
        ]
        app.logger.info('GET /ai/v1/models - 200 OK')
        return jsonify({
            "object": "list",
            "data": models_list
        })
    except Exception as error:
        # Anything raised here is a server-side fault, not an authentication
        # failure, so report 500 (the same shape handle_request uses) and keep
        # the traceback in the log.
        app.logger.exception(f"Error: {str(error)}")
        return jsonify({"error": f"Internal Server Error: {str(error)}"}), 500
def _mask_token(token):
    """Return *token* with all but its last four characters hidden, for log lines."""
    token = str(token)
    return '***' + token[-4:] if len(token) > 8 else '***'


def _extract_prompt_text(content):
    """Return the text of a chat message ``content`` value, or ``None``.

    ``content`` may be a plain string (returned unchanged) or an OpenAI-style
    list of content parts, in which case the ``text`` of every
    ``{"type": "text", ...}`` part is joined with single spaces.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = [
            part['text']
            for part in content
            if isinstance(part, dict)
            and part.get('type') == 'text'
            and isinstance(part.get('text'), str)
        ]
        return ' '.join(texts) or None
    return None


def _generate_image(mapped_model, prompt, image_size, token):
    """Call the SiliconFlow text-to-image endpoint and return the first image URL.

    Raises:
        requests.RequestException: on network errors, timeouts or non-2xx replies.
        ValueError: if the reply lacks ``images[0].url`` (a non-JSON reply also
            raises a ``ValueError`` subclass from ``response.json()``).
    """
    url = f'https://api.siliconflow.cn/v1/{mapped_model}/text-to-image'
    payload = {
        "prompt": prompt,
        "image_size": image_size,
        "batch_size": 1,
        "num_inference_steps": 4,
        "guidance_scale": 1
    }
    headers = {
        'accept': 'application/json',
        'content-type': 'application/json',
        'Authorization': f'Bearer {token}'
    }
    response = requests.post(url, headers=headers, json=payload, timeout=60)
    response.raise_for_status()
    response_body = response.json()
    images = response_body.get('images') if isinstance(response_body, dict) else None
    if (not isinstance(images, list) or not images
            or not isinstance(images[0], dict) or 'url' not in images[0]):
        raise ValueError("Unexpected response structure from image generation API")
    return images[0]['url']


@app.route('/ai/v1/chat/completions', methods=['POST'])
def handle_request():
    """Generate an image for the last chat message and return it as a chat completion.

    Flow:
    1. Authenticate via ``get_random_token``.
    2. Validate the JSON body (``model``, ``messages``, optional ``stream``).
    3. Parse options out of the prompt with ``extract_params_from_prompt``.
    4. Optionally enhance the prompt.
    5. Call the SiliconFlow text-to-image API.
    6. Hand the image URL to ``stream_response`` or ``non_stream_response``.

    Returns:
        401 if no usable token can be derived from ``Authorization``;
        400 for a malformed body, unknown model, or a last message without text;
        502 if the image API call fails or returns an unexpected payload;
        500 for any other unexpected error; otherwise whatever the chosen
        response builder returns.
    """
    try:
        # Authenticate first so unauthenticated callers are rejected before any work is done.
        random_token = get_random_token(request.headers.get('Authorization'))
        if not random_token:
            return jsonify({"error": "Unauthorized: Invalid or missing Authorization header"}), 401

        # silent=True yields None (instead of raising) for a missing, non-JSON or
        # malformed body, so client mistakes become 400s rather than 500s.
        body = request.get_json(silent=True)
        if not isinstance(body, dict):
            return jsonify({"error": "Bad Request: body must be a JSON object"}), 400
        model = body.get('model')
        messages = body.get('messages')
        stream = body.get('stream', False)
        if not model or not messages or not isinstance(messages, list):
            return jsonify({"error": "Bad Request: Missing required fields"}), 400

        # Non-string ids (e.g. a list, which is unhashable) are treated as unknown models.
        model_info = MODEL_MAPPING.get(model) if isinstance(model, str) else None
        if model_info is None:
            return jsonify({"error": f"Model '{model}' not found"}), 400
        mapped_model = model_info['mapping']

        last_message = messages[-1]
        prompt = (_extract_prompt_text(last_message.get('content'))
                  if isinstance(last_message, dict) else None)
        if not prompt:
            return jsonify({"error": "Bad Request: last message has no text content"}), 400

        image_size, clean_prompt, use_original, size_param = extract_params_from_prompt(prompt)
        # use_original (logged as the "-o" flag below) skips the translate/enhance step.
        if use_original:
            enhanced_prompt = clean_prompt
        else:
            enhanced_prompt = translate_and_enhance_prompt(clean_prompt, random_token)

        try:
            image_url = _generate_image(mapped_model, enhanced_prompt, image_size, random_token)
        except (requests.RequestException, ValueError) as e:
            # Failures of the upstream image API are gateway errors, not bugs in this proxy.
            app.logger.exception(f"Image generation failed: {str(e)}")
            return jsonify({"error": f"Bad Gateway: {str(e)}"}), 502

        unique_id = str(int(time.time() * 1000))
        current_timestamp = int(time.time())
        system_fingerprint = "fp_" + ''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=9))
        image_data = {'data': [{'url': image_url}]}

        # Summarise the prompt options for the log; the size is listed only when it is not "16:9".
        params = []
        if size_param != "16:9":
            params.append(f"-s {size_param}")
        if use_original:
            params.append("-o")
        params_str = " ".join(params) if params else "no params"
        # Never write the raw upstream token to disk, only its last four characters.
        app.logger.info(
            f'POST /ai/v1/chat/completions - Status: 200 - Token: {_mask_token(random_token)} '
            f'- Model: {mapped_model} - Params: {params_str} - Image URL: {image_url}'
        )

        build_response = stream_response if stream else non_stream_response
        return build_response(unique_id, image_data, clean_prompt, enhanced_prompt, image_size,
                              current_timestamp, model, system_fingerprint, use_original)
    except Exception as e:
        app.logger.exception(f"Error: {str(e)}")
        return jsonify({"error": f"Internal Server Error: {str(e)}"}), 500
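# The rest of the code is unchanged and omitted from this listing, e.g. the
# stream_response, generate_stream and non_stream_response functions.
# handle_request also relies on extract_params_from_prompt, get_random_token
# and translate_and_enhance_prompt, which are likewise not shown here.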
if __name__ == '__main__':
    # Run Flask's built-in server on all network interfaces, port 8000.
    app.run(port=8000, host='0.0.0.0')