File size: 6,573 Bytes
bc4ed93
1953b4d
52f607b
1953b4d
bc7ed47
403a673
a92cfc8
44dac73
e1b2102
f382771
a5853e1
1953b4d
52f607b
403a673
889a28b
 
20f980f
e1b2102
 
20f980f
e1b2102
 
889a28b
 
20f980f
e4ddb7f
a92cfc8
889a28b
e1b2102
889a28b
 
 
 
 
403a673
bbe47aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20f980f
 
 
 
 
 
535a218
20f980f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83608e4
e1b2102
20f980f
 
 
 
 
 
 
 
 
 
 
a5d1c4e
b735536
1953b4d
 
 
52f607b
 
 
 
1953b4d
83608e4
b2af907
 
 
 
 
 
52f607b
889a28b
83608e4
889a28b
 
83608e4
 
 
2950611
 
 
889a28b
83608e4
b2af907
52f607b
83608e4
44dac73
52f607b
 
 
 
83608e4
1953b4d
 
 
889a28b
1953b4d
83608e4
889a28b
 
 
a92cfc8
889a28b
 
 
 
403a673
889a28b
ab730f6
3af6d28
 
 
83608e4
889a28b
e4ddb7f
 
 
 
 
 
 
20f980f
889a28b
1953b4d
2950611
1953b4d
2950611
1953b4d
889a28b
1953b4d
52f607b
20f980f
 
3af6d28
52f607b
83608e4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
from flask import Flask, request, jsonify, Response, stream_with_context
import requests
import json
import time
import random
import logging
import sys
import re
from logging.handlers import TimedRotatingFileHandler
from datetime import datetime

# Flask application object; the routes and the log handler in this file attach to it.
app = Flask(__name__)

# Logging configuration
class RequestFormatter(logging.Formatter):
    """Log formatter that injects Flask request details into each record.

    Sets ``url``, ``remote_addr`` and ``token`` on the record so a format
    string may reference ``%(url)s``, ``%(remote_addr)s`` and ``%(token)s``.
    """

    def format(self, record):
        """Return the formatted log line for *record*.

        A string is always returned. Logging handlers concatenate the result
        with a line terminator, so returning ``None`` (as happened for HEAD or
        OPTIONS requests) makes the handler fail instead of writing the line.
        Records emitted outside a request (for example at start-up) get
        placeholder values rather than raising on the unbound ``request``.
        """
        # Local import: keeps this class self-contained with respect to the
        # module's existing ``from flask import ...`` line.
        from flask import has_request_context

        if has_request_context():
            record.url = request.url
            record.remote_addr = request.remote_addr
            # NOTE(review): this stores the raw Authorization header value in
            # the log record; consider masking it if log files are shared.
            record.token = request.headers.get('Authorization', 'No Token')
        else:
            record.url = '-'
            record.remote_addr = '-'
            record.token = 'No Token'
        return super().format(record)

# Line layout: client address, timestamp, token, then the log message.
formatter = RequestFormatter(
    fmt='%(remote_addr)s - - [%(asctime)s] - Token: %(token)s - %(message)s',
    datefmt='%d/%b/%Y %H:%M:%S',
)

# File handler for app.log that rolls over at midnight and keeps 30 backups.
handler = TimedRotatingFileHandler(
    filename='app.log',
    when="midnight",
    interval=1,
    backupCount=30,
)
handler.setLevel(logging.INFO)
handler.setFormatter(formatter)

# Send the application logger's INFO-and-above records to the file handler.
app.logger.setLevel(logging.INFO)
app.logger.addHandler(handler)

# Model mapping
# Public model id -> {"provider": owner, "mapping": "<owner>/<upstream model name>"}.
# Each row below is (public id, provider, upstream model name); the upstream
# name can differ from the public id in letter case (see the FLUX entry).
MODEL_MAPPING = {
    public_id: {
        "provider": provider,
        "mapping": f"{provider}/{upstream_name}",
    }
    for public_id, provider, upstream_name in (
        ("flux.1-schnell", "black-forest-labs", "FLUX.1-schnell"),
        ("sd-turbo", "stabilityai", "sd-turbo"),
        ("sdxl-turbo", "stabilityai", "sdxl-turbo"),
        ("stable-diffusion-2-1", "stabilityai", "stable-diffusion-2-1"),
        ("stable-diffusion-3-medium", "stabilityai", "stable-diffusion-3-medium"),
        ("stable-diffusion-xl-base-1.0", "stabilityai", "stable-diffusion-xl-base-1.0"),
    )
}

# Mock authentication function
def getAuthCookie(req):
    """Return the request's ``Authorization`` header if it is a Bearer value.

    Args:
        req: Object exposing a ``headers`` mapping with a ``get`` method.

    Returns:
        The full header value (including the ``Bearer `` prefix) when the
        header is present, non-empty and starts with ``Bearer ``; otherwise
        ``None``.
    """
    header_value = req.headers.get('Authorization')
    if not header_value:
        return None
    # The prefix match is case-sensitive and requires the trailing space.
    return header_value if header_value.startswith('Bearer ') else None

@app.route('/ai/v1/models', methods=['GET'])
def get_models():
    """List the models in ``MODEL_MAPPING`` as an OpenAI-style ``list`` payload.

    Returns:
        401 with ``{"error": "Unauthorized"}`` when ``getAuthCookie`` finds no
        Bearer ``Authorization`` header; otherwise a JSON object with
        ``object: "list"`` and one ``data`` entry per configured model.
        Unexpected failures produce a 500 response (previously they were
        misreported as a 401 "Authentication failed").
    """
    try:
        # Authenticate the caller.
        if not getAuthCookie(request):
            app.logger.info('GET /ai/v1/models - 401 Unauthorized')
            return jsonify({"error": "Unauthorized"}), 401

        # One timestamp for the whole listing so every entry agrees.
        created = int(time.time())
        models_list = [
            {
                "id": model_id,
                "object": "model",
                "created": created,
                "owned_by": info["provider"],
                "permission": [],
                "root": model_id,
                "parent": None
            }
            for model_id, info in MODEL_MAPPING.items()
        ]

        app.logger.info('GET /ai/v1/models - 200 OK')

        return jsonify({
            "object": "list",
            "data": models_list
        })

    except Exception as error:
        # Anything reaching this point is a server-side fault, not an
        # authentication failure, so report 500 like the completions route.
        app.logger.error(f"Error: {str(error)}")
        return jsonify({"error": "Internal Server Error", "details": str(error)}), 500

@app.route('/ai/v1/chat/completions', methods=['POST'])
def handle_request():
    """Serve a chat-completions style request by generating an image upstream.

    Reads ``model``, ``messages`` and optional ``stream`` from the JSON body,
    takes the prompt from the last message's ``content``, posts it to the
    mapped model's text-to-image endpoint and passes the resulting image URL
    to ``stream_response`` or ``non_stream_response``.

    Returns:
        400 when the body is not a JSON object, required fields are missing,
        ``messages`` is not a non-empty list, the last message has no
        ``content``, or the model is not in ``MODEL_MAPPING``;
        401 when ``get_random_token`` yields nothing for the Authorization
        header; 500 for any other failure (including upstream HTTP errors);
        otherwise whatever the stream / non-stream helper returns.
    """
    try:
        # silent=True gives None for a missing, non-JSON or unparsable body
        # instead of raising, so malformed requests get a 400 rather than
        # falling through to the generic 500 handler.
        body = request.get_json(silent=True)
        if not isinstance(body, dict):
            return jsonify({"error": "Bad Request: Invalid or missing JSON body"}), 400

        model = body.get('model')
        messages = body.get('messages')
        stream = body.get('stream', False)
        if not model or not isinstance(messages, list) or not messages:
            return jsonify({"error": "Bad Request: Missing required fields"}), 400

        # Map the public model id to the upstream model path.
        # The isinstance check keeps unhashable values (e.g. a list) from
        # raising during the dictionary lookup.
        if not isinstance(model, str) or model not in MODEL_MAPPING:
            return jsonify({"error": f"Model '{model}' not found"}), 400
        mapped_model = MODEL_MAPPING[model]['mapping']

        last_message = messages[-1]
        if not isinstance(last_message, dict) or 'content' not in last_message:
            return jsonify({"error": "Bad Request: Last message has no content"}), 400
        prompt = last_message['content']

        # extract_params_from_prompt is defined elsewhere in this file; it
        # returns (image_size, clean_prompt, use_original, size_param).
        image_size, clean_prompt, use_original, size_param = extract_params_from_prompt(prompt)

        auth_header = request.headers.get('Authorization')
        random_token = get_random_token(auth_header)
        if not random_token:
            return jsonify({"error": "Unauthorized: Invalid or missing Authorization header"}), 401

        # use_original skips the translate/enhance step and sends the cleaned
        # prompt as-is.
        if use_original:
            enhanced_prompt = clean_prompt
        else:
            enhanced_prompt = translate_and_enhance_prompt(clean_prompt, random_token)

        new_url = f'https://api.siliconflow.cn/v1/{mapped_model}/text-to-image'
        new_request_body = {
            "prompt": enhanced_prompt,
            "image_size": image_size,
            "batch_size": 1,
            "num_inference_steps": 4,
            "guidance_scale": 1
        }

        headers = {
            'accept': 'application/json',
            'content-type': 'application/json',
            'Authorization': f'Bearer {random_token}'
        }

        response = requests.post(new_url, headers=headers, json=new_request_body, timeout=60)
        response.raise_for_status()
        response_body = response.json()

        if 'images' in response_body and response_body['images'] and 'url' in response_body['images'][0]:
            image_url = response_body['images'][0]['url']
        else:
            raise ValueError("Unexpected response structure from image generation API")

        # Millisecond timestamp used as the response id.
        unique_id = str(int(time.time() * 1000))
        current_timestamp = int(time.time())
        system_fingerprint = "fp_" + ''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=9))

        image_data = {'data': [{'url': image_url}]}

        # Summarise the prompt flags for the log line; "16:9" is treated as
        # the not-worth-logging size (presumably the default - confirm against
        # extract_params_from_prompt).
        params = []
        if size_param != "16:9":
            params.append(f"-s {size_param}")
        if use_original:
            params.append("-o")
        params_str = " ".join(params) if params else "no params"

        # NOTE(review): this writes the upstream API token to the log file in
        # plain text; consider masking it.
        app.logger.info(f'POST /ai/v1/chat/completions - Status: 200 - Token: {random_token} - Model: {mapped_model} - Params: {params_str} - Image URL: {image_url}')

        if stream:
            return stream_response(unique_id, image_data, clean_prompt, enhanced_prompt, image_size, current_timestamp, model, system_fingerprint, use_original)
        else:
            return non_stream_response(unique_id, image_data, clean_prompt, enhanced_prompt, image_size, current_timestamp, model, system_fingerprint, use_original)
    except Exception as e:
        # Top-level boundary for the route: log and report a generic 500.
        app.logger.error(f"Error: {str(e)}")
        return jsonify({"error": f"Internal Server Error: {str(e)}"}), 500

# The rest of the code is unchanged and omitted here,
# e.g. the stream_response, generate_stream and non_stream_response functions.

# Script entry point: run the Flask server on 0.0.0.0, port 8000.
if __name__ == '__main__':
    app.run(port=8000, host='0.0.0.0')