File size: 3,634 Bytes
1953b4d
 
52f607b
1953b4d
bc7ed47
a5853e1
1953b4d
52f607b
a5d1c4e
 
fab1a10
a5d1c4e
b735536
1953b4d
 
 
52f607b
 
 
f8d271e
52f607b
1953b4d
f8d271e
bc7ed47
 
 
 
 
 
 
 
 
 
 
52f607b
1953b4d
f8d271e
52f607b
 
 
 
 
 
 
f8d271e
1953b4d
 
 
bc7ed47
1953b4d
f8d271e
1953b4d
 
f8d271e
1953b4d
70ac8d0
 
f8d271e
1953b4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52f607b
1953b4d
 
 
 
 
52f607b
1953b4d
 
 
f8d271e
1953b4d
 
52f607b
 
f8d271e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import json
import random
import time
from urllib.parse import quote

import requests
from flask import Flask, request, jsonify, Response

# Flask application object: the routes in this module are registered on it with
# @app.route, and the __main__ guard at the bottom serves it with app.run.
app = Flask(__name__)

@app.route('/')
def index():
    """Answer the root URL with a short plain-text description of this service (HTTP 200)."""
    description = "text-to-image with siliconflow"
    return description, 200

# Root of the upstream SiliconFlow API; the model name is appended as a path segment.
_SILICONFLOW_API_BASE = 'https://api.siliconflow.cn/v1'
# Seconds to wait on the upstream image service. requests waits forever by default,
# which would pin a worker on a stalled connection.
_UPSTREAM_TIMEOUT_SECONDS = 120
# Cap on how much of an upstream error body is echoed back to the client.
_MAX_UPSTREAM_ERROR_CHARS = 500


def _parse_tokens(authorization_header):
    """Return the API keys carried in an ``Authorization`` header value.

    The value is expected to look like ``Bearer key1,key2,...``; the scheme word before
    the first space is not checked. Whitespace around each key is stripped and empty
    entries (e.g. from a trailing comma) are dropped, so the result may be empty.
    """
    _scheme, _sep, credentials = authorization_header.strip().partition(' ')
    return [token.strip() for token in credentials.split(',') if token.strip()]


def _extract_prompt(message):
    """Return the text of one OpenAI-style chat message, or ``None`` if it has none.

    ``content`` may be a plain string or a list of content parts. For a list, the
    ``text`` of every ``{"type": "text"}`` part is joined with newlines; other parts
    (such as images) are ignored.
    """
    if not isinstance(message, dict):
        return None
    content = message.get('content')
    if isinstance(content, str):
        return content or None
    if isinstance(content, list):
        texts = [
            part['text']
            for part in content
            if isinstance(part, dict)
            and part.get('type') == 'text'
            and isinstance(part.get('text'), str)
            and part['text']
        ]
        return '\n'.join(texts) or None
    return None


@app.route('/ai/v1/chat/completions', methods=['POST'])
def handle_request():
    """Serve an OpenAI-style chat completion by generating one image on SiliconFlow.

    The text of the last entry in ``messages`` is sent as the prompt to
    ``<_SILICONFLOW_API_BASE>/<model>/text-to-image``. The request is authenticated
    with one key chosen at random from the caller's ``Authorization: Bearer k1,k2,...``
    header. The first returned image URL is sent back as Markdown (``![](url)``):
    - if ``stream`` is truthy: as one server-sent-event chunk followed by
      ``data: [DONE]``;
    - otherwise: as a JSON ``chat.completion`` object.

    Returns:
        A Flask response with one of these statuses:
        - 200 with the completion;
        - 400 for a malformed body or a last message without text;
        - 401 when no API key is supplied;
        - SiliconFlow's own status code when it rejects the request;
        - 502 when SiliconFlow is unreachable or its reply has no image URL;
        - 500 for any other failure.
    """
    try:
        # silent=True yields None (instead of raising) for a missing or non-JSON body.
        body = request.get_json(silent=True)
        if not isinstance(body, dict):
            return jsonify({"error": "Bad Request: Body must be a JSON object"}), 400

        model = body.get('model')
        messages = body.get('messages')
        stream = body.get('stream', False)

        if (not isinstance(model, str) or not model
                or not isinstance(messages, list) or not messages):
            return jsonify({"error": "Bad Request: Missing required fields"}), 400

        authorization_header = request.headers.get('Authorization')
        if not authorization_header:
            return jsonify({"error": "Unauthorized: Missing Authorization header"}), 401

        tokens = _parse_tokens(authorization_header)
        if not tokens:
            return jsonify({"error": "Unauthorized: No API key in Authorization header"}), 401
        # Use one of the supplied keys, chosen at random.
        selected_token = random.choice(tokens)

        prompt = _extract_prompt(messages[-1])
        if prompt is None:
            return jsonify({"error": "Bad Request: Last message has no text content"}), 400

        # Keep '/' so a namespaced model id (presumably "org/name") stays a path, but
        # escape '?', '#', spaces etc. so the caller cannot rewrite the upstream URL.
        new_url = f'{_SILICONFLOW_API_BASE}/{quote(model, safe="/")}/text-to-image'

        new_request_body = {
            "prompt": prompt,
            "image_size": "1024x1024",
            "batch_size": 1,
            "num_inference_steps": 4,
            "guidance_scale": 1
        }

        headers = {
            'accept': 'application/json',
            'content-type': 'application/json',
            'Authorization': f'Bearer {selected_token}'
        }

        try:
            response = requests.post(
                new_url,
                headers=headers,
                json=new_request_body,
                timeout=_UPSTREAM_TIMEOUT_SECONDS,
            )
        except requests.RequestException as exc:
            return jsonify({"error": f"Bad Gateway: {exc}"}), 502

        if not response.ok:
            # Pass the upstream status through (e.g. 401 for a bad key, 429 for rate limits).
            detail = response.text[:_MAX_UPSTREAM_ERROR_CHARS]
            return (
                jsonify({"error": f"Upstream error {response.status_code}: {detail}"}),
                response.status_code,
            )

        try:
            image_url = response.json()['images'][0]['url']
        except (ValueError, KeyError, IndexError, TypeError):
            # ValueError covers a non-JSON body; the rest cover an unexpected shape.
            image_url = None
        if not isinstance(image_url, str) or not image_url:
            return jsonify({"error": "Bad Gateway: Upstream response contained no image URL"}), 502

        now_ms = int(time.time() * 1000)
        unique_id = str(now_ms)  # millisecond timestamp doubles as the completion id
        current_timestamp = now_ms // 1000
        content = f"![]({image_url})"

        if stream:
            chunk = {
                "id": unique_id,
                "object": "chat.completion.chunk",
                "created": current_timestamp,
                "model": model,
                "choices": [
                    {
                        "index": 0,
                        "delta": {
                            "role": "assistant",
                            "content": content
                        },
                        "finish_reason": "stop"
                    }
                ]
            }
            # OpenAI-compatible event streams are terminated by a literal [DONE] event.
            return Response(
                f"data: {json.dumps(chunk)}\n\ndata: [DONE]\n\n",
                content_type='text/event-stream',
            )

        completion = {
            "id": unique_id,
            "object": "chat.completion",
            "created": current_timestamp,
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": content
                    },
                    "logprobs": None,
                    # The whole answer was produced, so "stop" (not "length", which
                    # signals truncation).
                    "finish_reason": "stop"
                }
            ],
            # Character counts stand in for token counts; no tokenizer is involved.
            "usage": {
                "prompt_tokens": len(prompt),
                "completion_tokens": len(image_url),
                "total_tokens": len(prompt) + len(image_url)
            }
        }
        # A non-streaming completion is a plain JSON document, not an event stream.
        return jsonify(completion)

    except Exception as e:  # top-level boundary: report anything unexpected as a 500
        return jsonify({"error": f"Internal Server Error: {str(e)}"}), 500

if __name__ == '__main__':
    # Run as a script: serve the app with Flask's built-in server on all
    # interfaces, port 8000.
    listen_host, listen_port = '0.0.0.0', 8000
    app.run(host=listen_host, port=listen_port)