|
import json
import logging
import ssl
import urllib.error
import urllib.parse
import urllib.request
from datetime import datetime
from http.server import BaseHTTPRequestHandler, HTTPServer
|
|
|
|
|
# Root logging: INFO and above go to stderr. basicConfig is a no-op when the
# root logger already has handlers (e.g. configured by an embedding app).
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
|
|
|
class ChatCompletionsHandler(BaseHTTPRequestHandler):
    """Serve an OpenAI-style chat-completions endpoint backed by text-to-image.

    Only ``POST /ai/v1/chat/completions`` is handled. The ``content`` of the
    last message is used as the prompt for SiliconFlow's
    ``/v1/{model}/text-to-image`` endpoint, which is called with the client's
    own ``Authorization`` header. The first returned image URL is sent back
    as Markdown image content: either a ``chat.completion`` JSON object or,
    when ``stream`` is true, a single server-sent-event chunk followed by
    ``data: [DONE]``.
    """

    # The only path served; every other path gets a 404.
    endpoint_path = '/ai/v1/chat/completions'
    # Upstream API root; "/{model}/text-to-image" is appended per request.
    upstream_base_url = 'https://api.siliconflow.cn/v1'
    # Seconds to wait on the upstream call. Without a timeout a stalled
    # upstream would tie up the handling thread indefinitely.
    upstream_timeout = 120
    # Verify the upstream TLS certificate. The client's bearer token is sent
    # upstream, so turning this off exposes that token to interception.
    verify_tls = True

    def do_POST(self):
        """Translate one chat-completion request into an image generation.

        Status codes:
            404 for any path other than ``endpoint_path``.
            400 for a missing/invalid Content-Length, a non-JSON body,
                missing or malformed ``model``/``messages``, or an invalid
                model name.
            401 when no ``Authorization`` header is sent.
            The upstream status when SiliconFlow answers with a 4xx.
            502 when the upstream fails otherwise (other HTTP error,
                network error, timeout, non-JSON reply, no image URL).
            500 for any other unexpected error.
            200 with the payload from ``_build_payload`` on success.
        """
        if self.path != self.endpoint_path:
            self.send_error(404, "Not Found")
            return

        try:
            model, prompt, stream = self._parse_request(self._read_json_body())
        except ValueError as err:
            # The helpers only raise fixed messages, so they are safe to put
            # in the status line.
            self.send_error(400, f"Bad Request: {err}")
            return

        authorization = self.headers.get('Authorization')
        if not authorization:
            self.send_error(401, "Unauthorized: Missing Authorization header")
            return

        try:
            image_url = self._request_image(model, prompt, authorization)
        except urllib.error.HTTPError as err:
            # Pass client errors (bad key, unknown model, ...) through so the
            # caller can act on them; anything else is a gateway failure.
            logger.error("Upstream returned HTTP %s: %s", err.code, err.reason)
            status = err.code if 400 <= err.code < 500 else 502
            self.send_error(status, f"Upstream Error: HTTP {err.code}",
                            str(err.reason))
            return
        except (OSError, ValueError) as err:
            # OSError covers URLError and socket timeouts; ValueError covers
            # a non-JSON reply or one without an image URL.
            logger.error("Upstream request failed: %s", err)
            self.send_error(502, "Bad Gateway", str(err))
            return
        except Exception:
            logger.exception("Unexpected error while generating image")
            self.send_error(500, "Internal Server Error")
            return

        # Respond outside the try: once headers have been written, a failure
        # must not trigger a second (error) response on the same connection.
        unique_id = int(datetime.now().timestamp() * 1000)
        payload = self._build_payload(model, prompt, image_url, stream, unique_id)
        self._send_payload(payload, stream)

    def _read_json_body(self):
        """Read the request body and decode it as UTF-8 JSON.

        Raises:
            ValueError: Content-Length is missing, non-numeric or negative, or
                the body is not valid UTF-8 JSON.
        """
        try:
            length = int(self.headers.get('Content-Length', ''))
        except ValueError:
            raise ValueError("Missing or invalid Content-Length") from None
        if length < 0:
            # rfile.read(-1) would read until EOF, which may never come.
            raise ValueError("Missing or invalid Content-Length")
        raw = self.rfile.read(length)
        try:
            return json.loads(raw.decode('utf-8'))
        except ValueError:
            # UnicodeDecodeError and json.JSONDecodeError both subclass ValueError.
            raise ValueError("Invalid JSON body") from None

    @staticmethod
    def _parse_request(body):
        """Return ``(model, prompt, stream)`` from a decoded request body.

        ``prompt`` is the string ``content`` of the last entry in ``messages``.
        ``stream`` is the truthiness of ``body['stream']`` (default False).

        Raises:
            ValueError: a required field is missing or malformed, or ``model``
                has empty, ``.`` or ``..`` path segments. The model name is
                spliced into the upstream URL path.
        """
        if not isinstance(body, dict):
            raise ValueError("Missing required fields")
        model = body.get('model')
        messages = body.get('messages')
        if (not isinstance(model, str) or not model
                or not isinstance(messages, list) or not messages):
            raise ValueError("Missing required fields")
        last_message = messages[-1]
        prompt = last_message.get('content') if isinstance(last_message, dict) else None
        if not isinstance(prompt, str) or not prompt:
            raise ValueError("Missing required fields")
        if any(segment in ('', '.', '..') for segment in model.split('/')):
            raise ValueError("Invalid model name")
        return model, prompt, bool(body.get('stream', False))

    def _request_image(self, model, prompt, authorization):
        """POST ``prompt`` to the upstream text-to-image endpoint for ``model``.

        Returns:
            The URL of the first generated image.

        Raises:
            urllib.error.HTTPError: the upstream answered with an error status.
            OSError: network failure or timeout (URLError is a subclass).
            ValueError: the reply is not JSON or carries no image URL.
        """
        # Percent-encode the model so characters such as "?" or "#" cannot
        # change which upstream URL is requested. "/" is kept literal so
        # multi-segment model names map onto the path as before.
        url = (f"{self.upstream_base_url}/"
               f"{urllib.parse.quote(model, safe='/')}/text-to-image")
        request_body = {
            "prompt": prompt,
            "image_size": "1024x1024",
            "batch_size": 1,
            "num_inference_steps": 4,
            "guidance_scale": 1,
        }
        request = urllib.request.Request(
            url,
            data=json.dumps(request_body).encode('utf-8'),
            headers={
                'accept': 'application/json',
                'content-type': 'application/json',
                'Authorization': authorization,
            },
            method='POST',
        )
        context = ssl.create_default_context()
        if not self.verify_tls:
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE
        with urllib.request.urlopen(request, context=context,
                                    timeout=self.upstream_timeout) as response:
            response_body = json.loads(response.read().decode('utf-8'))
        logger.info("API Response: %s", response_body)
        return self._extract_image_url(response_body)

    @staticmethod
    def _extract_image_url(response_body):
        """Return ``response_body['images'][0]['url']``.

        Raises:
            ValueError: there is no non-empty ``images`` list, or its first
                entry has no ``url``.
        """
        images = response_body.get('images') if isinstance(response_body, dict) else None
        if not isinstance(images, list) or not images:
            raise ValueError("No images in the response")
        first_image = images[0]
        image_url = first_image.get('url') if isinstance(first_image, dict) else None
        if not image_url:
            raise ValueError("No URL in the image response")
        return image_url

    @staticmethod
    def _build_payload(model, prompt, image_url, stream, unique_id):
        """Build the OpenAI-style response body for a generated image.

        ``unique_id`` is a millisecond timestamp. It becomes the completion id
        and, floored to seconds, the ``created`` field. The image is returned
        as Markdown image syntax, ``![image](url)``.
        """
        content = f"![image]({image_url})"
        created = unique_id // 1000
        # The OpenAI chat-completion schema types ``id`` as a string.
        completion_id = f"chatcmpl-{unique_id}"
        if stream:
            return {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model,
                "choices": [{
                    "index": 0,
                    "delta": {"role": "assistant", "content": content},
                    "finish_reason": "stop",
                }],
            }
        return {
            "id": completion_id,
            "object": "chat.completion",
            "created": created,
            "model": model,
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": content},
                "logprobs": None,
                # The full answer is delivered; "length" would tell clients
                # the output was truncated.
                "finish_reason": "stop",
            }],
            # Character counts, not real token counts (no tokenizer is used).
            "usage": {
                "prompt_tokens": len(prompt),
                "completion_tokens": len(image_url),
                "total_tokens": len(prompt) + len(image_url),
            },
        }

    def _send_payload(self, payload, stream):
        """Write ``payload`` as a complete 200 response.

        When ``stream`` is true, the body is ``text/event-stream``: one
        ``data:`` event followed by the ``data: [DONE]`` terminator.
        Otherwise the body is ``application/json``.
        """
        if stream:
            body = f"data: {json.dumps(payload)}\n\ndata: [DONE]\n\n".encode('utf-8')
            content_type = 'text/event-stream'
        else:
            body = json.dumps(payload).encode('utf-8')
            content_type = 'application/json'
        self.send_response(200)
        self.send_header('Content-Type', content_type)
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)
|
|
|
def run(server_class=HTTPServer, handler_class=ChatCompletionsHandler, port=8000, host=''):
    """Serve ``handler_class`` on ``(host, port)`` until interrupted.

    ``host`` defaults to ``''`` (all interfaces), which was the previously
    hard-coded bind address. Ctrl-C stops the server without a traceback.
    The listening socket is closed on every exit path.
    """
    server_address = (host, port)
    httpd = server_class(server_address, handler_class)
    logger.info("Starting server on port %s", port)
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        logger.info("Shutting down server")
    finally:
        httpd.server_close()
|
|
|
if __name__ == "__main__":
    # Script entry point: start the proxy with run()'s default settings.
    run()