import json
import logging
import random
import re
import sys
import time
from logging.handlers import TimedRotatingFileHandler

import requests
from flask import Flask, Response, has_request_context, jsonify, request, stream_with_context

# Flask application object; every route in this module is registered on it.
app = Flask(import_name=__name__)

# 配置日志
class RequestFormatter(logging.Formatter):
    def format(self, record):
        if request.method in ['POST', 'GET']:  # 记录 POST 和 GET 请求
            record.url = request.url
            record.remote_addr = request.remote_addr
            record.token = request.headers.get('Authorization', 'No Token')
            return super().format(record)
        return None

# Access-log style line: client address, timestamp, Authorization header, message.
formatter = RequestFormatter(
    '%(remote_addr)s - - [%(asctime)s] - Token: %(token)s - %(message)s',
    datefmt='%d/%b/%Y %H:%M:%S'
)

# "app.log" rotated at midnight, keeping 30 backups. The encoding is explicit
# because messages can carry non-ASCII text (header values, upstream error
# strings) that the platform's default encoding may not be able to write.
handler = TimedRotatingFileHandler(
    'app.log', when="midnight", interval=1, backupCount=30, encoding='utf-8'
)
handler.setFormatter(formatter)
handler.setLevel(logging.INFO)

app.logger.addHandler(handler)
app.logger.setLevel(logging.INFO)

# Model mapping: public model id -> {"provider", "mapping"}, where "mapping" is
# the upstream SiliconFlow model path ("<provider>/<repo>") and "provider" is
# that path's first segment (reported as "owned_by" by /ai/v1/models).
MODEL_MAPPING = {
    model_id: {"provider": upstream_path.split("/", 1)[0], "mapping": upstream_path}
    for model_id, upstream_path in (
        ("flux.1-schnell", "black-forest-labs/FLUX.1-schnell"),
        ("sd-turbo", "stabilityai/sd-turbo"),
        ("sdxl-turbo", "stabilityai/sdxl-turbo"),
        ("stable-diffusion-2-1", "stabilityai/stable-diffusion-2-1"),
        ("stable-diffusion-3-medium", "stabilityai/stable-diffusion-3-medium"),
        ("stable-diffusion-xl-base-1.0", "stabilityai/stable-diffusion-xl-base-1.0"),
    )
}

# System prompt for the prompt-enhancement chat call (translate_and_enhance_prompt).
# It instructs the model to turn a short scene description (the worked example
# is in Chinese) into a detailed, comma-separated English Stable Diffusion
# "Positive Prompt", explaining weighting syntax and the prefix/subject/scene
# structure. NOTE(review): sent verbatim to the model — edit with care.
SYSTEM_ASSISTANT = """作为 Stable Diffusion Prompt 提示词专家,您将从关键词中创建提示,通常来自 Danbooru 等数据库。
提示通常描述图像,使用常见词汇,按重要性排列,并用逗号分隔。避免使用"-"或".",但可以接受空格和自然语言。避免词汇重复。

为了强调关键词,请将其放在括号中以增加其权重。例如,"(flowers)"将'flowers'的权重增加1.1倍,而"(((flowers)))"将其增加1.331倍。使用"(flowers:1.5)"将'flowers'的权重增加1.5倍。只为重要的标签增加权重。

提示包括三个部分:**前缀**(质量标签+风格词+效果器)+ **主题**(图像的主要焦点)+ **场景**(背景、环境)。

*   前缀影响图像质量。像"masterpiece"、"best quality"、"4k"这样的标签可以提高图像的细节。像"illustration"、"lensflare"这样的风格词定义图像的风格。像"bestlighting"、"lensflare"、"depthoffield"这样的效果器会影响光照和深度。

*   主题是图像的主要焦点,如角色或场景。对主题进行详细描述可以确保图像丰富而详细。增加主题的权重以增强其清晰度。对于角色,描述面部、头发、身体、服装、姿势等特征。

*   场景描述环境。没有场景,图像的背景是平淡的,主题显得过大。某些主题本身包含场景(例如建筑物、风景)。像"花草草地"、"阳光"、"河流"这样的环境词可以丰富场景。你的任务是设计图像生成的提示。请按照以下步骤进行操作:

1.  我会发送给您一个图像场景。需要你生成详细的图像描述
2.  图像描述必须是英文,输出为Positive Prompt。

示例:

我发送:二战时期的护士。
您回复只回复:
A WWII-era nurse in a German uniform, holding a wine bottle and stethoscope, sitting at a table in white attire, with a table in the background, masterpiece, best quality, 4k, illustration style, best lighting, depth of field, detailed character, detailed environment.
"""

# Aspect-ratio label ("width:height") accepted by the "-s" prompt flag ->
# "WIDTHxHEIGHT" image size sent to the generation API. "16:9" is the default.
RATIO_MAP = {
    "1:1": "1024x1024",
    "1:2": "1024x2048",
    "3:2": "1536x1024",
    # 4:3 is landscape (width > height); its portrait counterpart is 3:4.
    "4:3": "2048x1536",
    "3:4": "1536x2048",
    "16:9": "2048x1152",
    "9:16": "1152x2048"
}

# Simulated authentication check.
def getAuthCookie(req):
    """Return the ``Authorization`` header of *req* if it is a bearer token.

    Only the ``'Bearer '`` prefix is checked; the token itself is not
    validated. Returns ``None`` when the header is absent, empty, or uses
    another scheme.
    """
    header_value = req.headers.get('Authorization')
    is_bearer = bool(header_value) and header_value.startswith('Bearer ')
    return header_value if is_bearer else None

@app.route('/ai/v1/models', methods=['GET'])
def get_models():
    """List the mapped image models in OpenAI ``/v1/models`` format.

    Requires an ``Authorization: Bearer ...`` header; only its presence and
    prefix are checked here (see ``getAuthCookie``).

    Returns:
        200 with ``{"object": "list", "data": [...]}``, one entry per key of
        ``MODEL_MAPPING``; 401 when the bearer header is missing; 500 on an
        unexpected error.
    """
    try:
        if not getAuthCookie(request):
            app.logger.info('GET /ai/v1/models - 401 Unauthorized')
            return jsonify({"error": "Unauthorized"}), 401

        # One timestamp for the whole listing so every entry agrees.
        created = int(time.time())
        models_list = [
            {
                "id": model_id,
                "object": "model",
                "created": created,
                "owned_by": info["provider"],
                "permission": [],
                "root": model_id,
                "parent": None,
            }
            for model_id, info in MODEL_MAPPING.items()
        ]

        app.logger.info('GET /ai/v1/models - 200 OK')
        return jsonify({
            "object": "list",
            "data": models_list
        })

    except Exception as error:
        # Top-level boundary: keep the traceback. Anything failing after the
        # auth check is a server error, not an authentication failure.
        app.logger.exception(f"Error: {str(error)}")
        return jsonify({"error": "Internal Server Error", "details": str(error)}), 500

@app.route('/ai/v1/chat/completions', methods=['POST'])
def handle_request():
    """Generate an image from the last chat message (OpenAI chat-completions shape).

    Expects a JSON object with ``model`` (a key of ``MODEL_MAPPING``),
    ``messages`` (non-empty list; the last message's ``content`` is the
    prompt, either a string or a list of text parts) and optional ``stream``.
    The prompt may carry ``-s <ratio>`` / ``-o`` flags (see
    ``extract_params_from_prompt``). One token is picked at random from the
    comma-separated bearer tokens in ``Authorization`` and used for both the
    optional prompt-enhancement call and the image request.

    Returns:
        An SSE stream or a JSON chat completion embedding the image URL;
        400 for a malformed body or unknown model, 401 without a usable
        token, 500 for any other failure (including upstream HTTP errors).
    """
    try:
        # silent=True yields None (instead of raising) on a missing or
        # non-JSON body, so such requests get a 400 rather than a 500.
        body = request.get_json(silent=True)
        if not isinstance(body, dict):
            return jsonify({"error": "Bad Request: Missing required fields"}), 400
        model = body.get('model')
        messages = body.get('messages')
        stream = body.get('stream', False)
        if not isinstance(model, str) or not model or not isinstance(messages, list) or not messages:
            return jsonify({"error": "Bad Request: Missing required fields"}), 400

        # Map the public model id to the upstream model path.
        if model not in MODEL_MAPPING:
            return jsonify({"error": f"Model '{model}' not found"}), 400
        mapped_model = MODEL_MAPPING[model]['mapping']

        last_message = messages[-1]
        prompt = _message_text(last_message.get('content')) if isinstance(last_message, dict) else None
        if not prompt:
            return jsonify({"error": "Bad Request: Missing required fields"}), 400

        image_size, clean_prompt, use_original, size_param = extract_params_from_prompt(prompt)

        random_token = get_random_token(request.headers.get('Authorization'))
        if not random_token:
            return jsonify({"error": "Unauthorized: Invalid or missing Authorization header"}), 401

        # "-o" skips the LLM rewrite and sends the user's text unchanged.
        if use_original:
            enhanced_prompt = clean_prompt
        else:
            enhanced_prompt = translate_and_enhance_prompt(clean_prompt, random_token)

        new_url = f'https://api.siliconflow.cn/v1/{mapped_model}/text-to-image'
        new_request_body = {
            "prompt": enhanced_prompt,
            "image_size": image_size,
            "batch_size": 1,
            "num_inference_steps": 4,
            "guidance_scale": 1
        }

        headers = {
            'accept': 'application/json',
            'content-type': 'application/json',
            'Authorization': f'Bearer {random_token}'
        }

        response = requests.post(new_url, headers=headers, json=new_request_body, timeout=60)
        response.raise_for_status()
        response_body = response.json()

        images = response_body.get('images') if isinstance(response_body, dict) else None
        if not images or not isinstance(images[0], dict) or 'url' not in images[0]:
            raise ValueError("Unexpected response structure from image generation API")
        image_url = images[0]['url']

        unique_id = str(int(time.time() * 1000))
        current_timestamp = int(time.time())
        system_fingerprint = "fp_" + ''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=9))

        image_data = {'data': [{'url': image_url}]}

        # Summarize non-default flags for the log line.
        params = []
        if size_param != "16:9":
            params.append(f"-s {size_param}")
        if use_original:
            params.append("-o")
        params_str = " ".join(params) if params else "no params"

        # NOTE(review): this logs the full upstream API token; consider masking it.
        app.logger.info(f'POST /ai/v1/chat/completions - Status: 200 - Token: {random_token} - Model: {mapped_model} - Params: {params_str} - Image URL: {image_url}')

        if stream:
            return stream_response(unique_id, image_data, clean_prompt, enhanced_prompt, image_size, current_timestamp, model, system_fingerprint, use_original)
        return non_stream_response(unique_id, image_data, clean_prompt, enhanced_prompt, image_size, current_timestamp, model, system_fingerprint, use_original)
    except Exception as e:
        # Top-level boundary: keep the traceback in the log and report a 500.
        app.logger.exception(f"Error: {str(e)}")
        return jsonify({"error": f"Internal Server Error: {str(e)}"}), 500


def _message_text(content):
    """Return the text of an OpenAI-style message ``content``, or None.

    Accepts a plain string, or a list of content parts whose
    ``{"type": "text", "text": ...}`` entries are joined with spaces. Any
    other shape yields None.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return ' '.join(
            part['text'] for part in content
            if isinstance(part, dict) and part.get('type') == 'text' and isinstance(part.get('text'), str)
        )
    return None

# Command flags must stand alone: preceded by start-of-text or whitespace
# (and, for -o, followed by whitespace or end-of-text), so hyphenated words
# such as "well-organized" or "semi-opaque" are not mistaken for flags.
_SIZE_FLAG_RE = re.compile(r'(?<!\S)-s\s+(\S+)')
_ORIGINAL_FLAG_RE = re.compile(r'(?<!\S)-o(?!\S)')


def extract_params_from_prompt(prompt, ratio_map=None):
    """Split the ``-s <ratio>`` and ``-o`` flags out of a chat prompt.

    Args:
        prompt: raw prompt text.
        ratio_map: ratio label -> ``"WxH"`` size; defaults to ``RATIO_MAP``.
            Must contain ``"16:9"``, used as the fallback.

    Returns:
        ``(image_size, clean_prompt, use_original, size)``: ``size`` is the
        requested ratio label (``"16:9"`` when no ``-s`` is given),
        ``image_size`` its mapped size (the ``"16:9"`` size for unknown
        labels), ``clean_prompt`` the prompt with flags removed, and
        ``use_original`` whether ``-o`` was present.
    """
    if ratio_map is None:
        ratio_map = RATIO_MAP

    size_match = _SIZE_FLAG_RE.search(prompt)
    if size_match:
        size = size_match.group(1)
        clean_prompt = _SIZE_FLAG_RE.sub('', prompt).strip()
    else:
        size = "16:9"
        clean_prompt = prompt

    # Look for -o after removing "-s <value>", so "-o" given as a size value
    # is not treated as a flag.
    use_original = bool(_ORIGINAL_FLAG_RE.search(clean_prompt))
    if use_original:
        clean_prompt = _ORIGINAL_FLAG_RE.sub('', clean_prompt).strip()

    image_size = ratio_map.get(size, ratio_map["16:9"])
    return image_size, clean_prompt, use_original, size

def get_random_token(auth_header):
    """Pick one API token at random from an ``Authorization`` header value.

    The header may hold several comma-separated tokens, optionally after a
    case-sensitive ``'Bearer '`` prefix. Surrounding whitespace is stripped
    and empty entries are ignored.

    Returns:
        One token chosen with ``random.choice``, or ``None`` when the header
        is missing/empty or contains no non-blank token.
    """
    if not auth_header:
        return None
    raw_tokens = auth_header[len('Bearer '):] if auth_header.startswith('Bearer ') else auth_header
    candidates = list(filter(None, (piece.strip() for piece in raw_tokens.split(','))))
    return random.choice(candidates) if candidates else None

def translate_and_enhance_prompt(prompt, auth_token):
    """Rewrite *prompt* into an English Stable Diffusion prompt via an LLM.

    Calls the SiliconFlow chat-completions endpoint with model
    ``Qwen/Qwen2-72B-Instruct``, ``SYSTEM_ASSISTANT`` as the system message
    and *prompt* as the user message, authenticating with *auth_token*.

    Returns:
        The ``content`` of the first choice's message.

    Raises:
        requests.RequestException: on network failure, the 30 s timeout, or
            a non-2xx status (via ``raise_for_status``).
    """
    chat_messages = [
        {'role': 'system', 'content': SYSTEM_ASSISTANT},
        {'role': 'user', 'content': prompt},
    ]
    payload = {'model': 'Qwen/Qwen2-72B-Instruct', 'messages': chat_messages}
    request_headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {auth_token}',
    }

    reply = requests.post(
        'https://api.siliconflow.cn/v1/chat/completions',
        headers=request_headers,
        json=payload,
        timeout=30,
    )
    reply.raise_for_status()
    first_choice = reply.json()['choices'][0]
    return first_choice['message']['content']

def stream_response(unique_id, image_data, original_prompt, translated_prompt, size, created, model, system_fingerprint, use_original):
    """Wrap ``generate_stream`` in a Flask ``text/event-stream`` response.

    ``stream_with_context`` keeps the request context available while the
    generator is consumed after the view function has returned.
    """
    events = generate_stream(
        unique_id, image_data, original_prompt, translated_prompt,
        size, created, model, system_fingerprint, use_original,
    )
    return Response(stream_with_context(events), content_type='text/event-stream')

def generate_stream(unique_id, image_data, original_prompt, translated_prompt, size, created, model, system_fingerprint, use_original, delay=0.5):
    """Yield OpenAI-style ``chat.completion.chunk`` SSE events for a finished image.

    The image already exists when this runs; the progress messages and the
    *delay* between them only simulate incremental generation.

    Args:
        unique_id: completion ``id`` repeated in every chunk.
        image_data: ``{'data': [{'url': ...}]}``; the first URL is embedded
            as a markdown image in the last content chunk.
        original_prompt: prompt text shown first.
        translated_prompt: enhanced prompt, shown only when *use_original* is false.
        size: image size string shown to the user.
        created: value for each chunk's ``created`` field.
        model: model id echoed in each chunk.
        system_fingerprint: echoed in each chunk.
        use_original: when true, the translated-prompt chunk is omitted.
        delay: seconds to sleep between consecutive content chunks (0 disables).

    Yields:
        ``"data: <json>\\n\\n"`` strings: one per content chunk (the first
        delta also carries ``"role": "assistant"``), a final chunk with an
        empty delta and ``finish_reason: "stop"``, then the ``data: [DONE]``
        terminator that OpenAI-compatible clients expect.
    """
    chunks = [
        f"原始提示词:\n{original_prompt}\n",
    ]

    if not use_original:
        chunks.append(f"翻译后的提示词:\n{translated_prompt}\n")

    chunks.extend([
        f"图像规格:{size}\n",
        "正在根据提示词生成图像...\n",
        "图像正在处理中...\n",
        "即将完成...\n",
        f"生成成功!\n图像生成完毕,以下是结果:\n\n![生成的图像]({image_data['data'][0]['url']})"
    ])

    def _event(delta, finish_reason):
        """Serialize one chunk with the shared envelope fields as an SSE line."""
        payload = json.dumps({
            "id": unique_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "system_fingerprint": system_fingerprint,
            "choices": [{
                "index": 0,
                "delta": delta,
                "logprobs": None,
                "finish_reason": finish_reason
            }]
        })
        return f"data: {payload}\n\n"

    for i, chunk in enumerate(chunks):
        # Simulated generation time, only between chunks (not after the last).
        if i and delay:
            time.sleep(delay)
        delta = {"role": "assistant", "content": chunk} if i == 0 else {"content": chunk}
        yield _event(delta, None)

    yield _event({}, "stop")
    yield "data: [DONE]\n\n"

def non_stream_response(unique_id, image_data, original_prompt, translated_prompt, size, created, model, system_fingerprint, use_original):
    """Build a non-streaming OpenAI ``chat.completion`` JSON response.

    The assistant message lists the original prompt, the translated prompt
    (unless *use_original*), the image size and the first image URL as a
    markdown image. The ``usage`` figures are character counts of the
    original prompt and the reply, not real token counts.
    """
    image_url = image_data['data'][0]['url']

    sections = [f"原始提示词:{original_prompt}\n"]
    if not use_original:
        sections.append(f"翻译后的提示词:{translated_prompt}\n")
    sections.append(
        f"图像规格:{size}\n"
        "图像生成成功!\n"
        "以下是结果:\n\n"
        f"![生成的图像]({image_url})"
    )
    content = "".join(sections)

    prompt_chars = len(original_prompt)
    reply_chars = len(content)

    choice = {
        'index': 0,
        'message': {
            'role': "assistant",
            'content': content
        },
        'finish_reason': "stop"
    }
    usage = {
        'prompt_tokens': prompt_chars,
        'completion_tokens': reply_chars,
        'total_tokens': prompt_chars + reply_chars
    }
    return jsonify({
        'id': unique_id,
        'object': "chat.completion",
        'created': created,
        'model': model,
        'system_fingerprint': system_fingerprint,
        'choices': [choice],
        'usage': usage
    })

if __name__ == '__main__':
    # Run Flask's built-in server, bound to all interfaces on port 8000.
    bind_host, bind_port = '0.0.0.0', 8000
    app.run(host=bind_host, port=bind_port)