import json
import os
import random
from datetime import datetime

import httpx
import pytz
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, Response, StreamingResponse
def get_si_key():
    """Return a randomly chosen SiliconFlow API key from the global ``keys``.

    Surrounding whitespace is stripped so the key can be placed directly in an
    ``Authorization: Bearer`` header.

    Raises:
        HTTPException: 503 when ``keys`` is empty (none configured, or all
            removed by /check).
    """
    if not keys:
        raise HTTPException(status_code=503, detail="No SiliconFlow API key available")
    # random.choice leaves the shared list untouched; the previous in-place
    # random.shuffle could reorder ``keys`` while /check was iterating over it.
    return random.choice(keys).strip()
app = FastAPI()

# Fully open CORS policy: every origin, HTTP method and request header is
# accepted, and credentials are allowed.
# NOTE(review): with "*" origins plus allow_credentials=True, Starlette echoes
# the caller's Origin back, so any site may make credentialed requests. This
# proxy does not appear to use cookies; consider allow_credentials=False once
# it is confirmed that no browser client depends on it.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Upstream OpenAI-compatible chat completions endpoint on SiliconFlow.
base_url = "https://api.siliconflow.cn/v1/chat/completions"

# Built-in table of short model aliases -> full SiliconFlow model ids.
_DEFAULT_MODEL_MAP = {
    "qwen-72b": "Qwen/Qwen2.5-72B-Instruct",
    "qwen-32b": "Qwen/Qwen2.5-32B-Instruct",
    "qwen-14b": "Qwen/Qwen2.5-14B-Instruct",
    "qwen-7b": "Qwen/Qwen2.5-7B-Instruct",
    "qwen-vl": "Qwen/Qwen2-VL-72B-Instruct",
    "qwen-coder": "Qwen/Qwen2.5-Coder-32B-Instruct",
    "qwq": "Qwen/QwQ-32B-Preview",
    "o1": "AIDC-AI/Marco-o1",
    "deepseek": "deepseek-ai/DeepSeek-V2.5",
    "deepseek-vl": "deepseek-ai/deepseek-vl2",
    "glm-9b": "THUDM/glm-4-9b-chat",
    "bce": "netease-youdao/bce-embedding-base_v1",
    "bge-m3": "BAAI/bge-m3",
    "bge-zh": "BAAI/bge-large-zh-v1.5",
}

# A non-empty MODEL_MAP environment variable (parsed as JSON, presumably an
# object of alias -> model id) replaces the built-in table entirely.
model_map = (
    json.loads(os.environ["MODEL_MAP"])
    if os.environ.get("MODEL_MAP")
    else _DEFAULT_MODEL_MAP
)
# API keys from the comma-separated SI_KEY environment variable. Entries are
# stripped and blank ones (e.g. from a trailing comma) are dropped. An unset
# variable yields an empty list instead of an AttributeError at import time;
# this is the same state /check reaches after removing every depleted key.
keys = [k.strip() for k in os.environ.get("SI_KEY", "").split(",") if k.strip()]
# Last balance seen per (stripped) key; filled in by /check.
key_balacnce = {}
# HTML fragment listing masked keys and their balances; shown on the index page.
key_balacnce_notes = ""
# UTC+8 (Asia/Shanghai) timezone used when /check stamps the update time.
tz = pytz.timezone("Asia/Shanghai")

# Time of the last /check run; an empty string until the first one.
now = ""
@app.get("/", response_class=HTMLResponse)
async def root():
    """Render the status page.

    The page shows the number of keys, one line per model alias, the time of
    the last /check run and the per-key balance notes.
    """
    # One <h2> heading per alias -> upstream model id pair.
    models = "".join(
        f"<h2>{alias}————{target}</h2>" for alias, target in model_map.items()
    )
    return f"""
    <html>
        <head>
            <title>富文本示例</title>
        </head>
        <body>
            <h1>有效key数量:{len(keys)}</h1>
            {models}
            <h1>最后更新时间:{now}</h1>
            {key_balacnce_notes}
        </body>
    </html>
    """
@app.get("/check")
async def check():
    """Refresh key balances, drop depleted keys and rebuild the status HTML.

    Each key in ``keys`` is checked against SiliconFlow's ``/v1/user/info``.

    - A key that answers HTTP 200 with a balance below 0.1 is removed from
      ``keys`` and from ``key_balacnce``.
    - A key whose request fails or returns a non-200 status is kept unchanged.

    ``key_balacnce_notes`` is replaced with one masked ``<h2>`` line per
    surviving, successfully checked key. ``now`` is set to the current time in
    ``tz``.

    Returns:
        A confirmation string containing the update time.
    """
    global key_balacnce, key_balacnce_notes, now, keys
    url = "https://api.siliconflow.cn/v1/user/info"
    kept = []
    notes = []
    # Iterate over a snapshot and rebuild the list at the end. The previous
    # keys.pop(i) inside enumerate(keys) skipped the key after every removed
    # one, and this loop awaits, so other requests can run in between.
    async with httpx.AsyncClient() as client:  # one connection pool for all keys
        for key in list(keys):
            token = key.strip()
            try:
                res = await client.get(url, headers={"Authorization": f"Bearer {token}"})
            except httpx.HTTPError:
                # Network failure: treat like a non-200 answer and keep the key.
                kept.append(key)
                continue
            if res.status_code != 200:
                kept.append(key)
                continue
            balance = res.json()['data']['balance']
            if float(balance) < 0.1:
                key_balacnce.pop(token, None)  # forget the balance of a dropped key
                continue
            kept.append(key)
            key_balacnce[token] = balance
            notes.append(f'''<h2>{token[0:4]}****{token[-4:]}————{balance}</h2>''')
    # Update in place so every holder of the list sees the result. State is
    # only replaced once all keys are checked, never left half-rebuilt.
    keys[:] = kept
    key_balacnce_notes = "".join(notes)
    now = datetime.now(tz)
    return f"更新成功:{now}"
@app.post("/hf/v1/chat/completions")
async def reforword(request:Request):
    """Proxy an OpenAI-compatible chat completion request to SiliconFlow.

    The JSON body is forwarded to ``base_url`` with a key from ``get_si_key``.
    A ``model`` value that is a key of ``model_map`` is replaced by the mapped
    SiliconFlow model id; any other value is forwarded unchanged.

    Returns:
        For a truthy ``"stream"`` field, a ``StreamingResponse`` relaying the
        upstream bytes as ``text/event-stream``. Otherwise, the upstream JSON
        body. When SiliconFlow answers with a 4xx/5xx status, that status, body
        and content type are passed back to the caller unchanged.
    """
    body = await request.json()
    key = get_si_key()
    # The key is a secret credential and is deliberately never printed or logged.
    headers = {"Authorization": f"Bearer {key}"}
    # Generation easily outlasts httpx's 5 s default timeout. Allow long reads
    # while still failing fast when the upstream cannot be reached.
    timeout = httpx.Timeout(600.0, connect=10.0)
    if "model" in body and body["model"] in model_map:
        body["model"] = model_map[body["model"]]

    if body.get("stream"):
        client = httpx.AsyncClient(timeout=timeout)
        try:
            upstream = await client.send(
                client.build_request("POST", base_url, headers=headers, json=body),
                stream=True,
            )
        except BaseException:
            await client.aclose()
            raise
        if upstream.is_error:
            # Check the status before StreamingResponse sends its 200 headers.
            # Raising inside the generator would only truncate the stream.
            try:
                content = await upstream.aread()
            finally:
                await upstream.aclose()
                await client.aclose()
            return Response(
                content=content,
                status_code=upstream.status_code,
                media_type=upstream.headers.get("content-type"),
            )

        async def generate_response():
            # Relay upstream bytes as they arrive; always release the connection.
            try:
                async for chunk in upstream.aiter_bytes():
                    if chunk:
                        yield chunk
            finally:
                await upstream.aclose()
                await client.aclose()

        return StreamingResponse(generate_response(), media_type="text/event-stream")

    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.post(base_url, headers=headers, json=body)
    if response.is_error:
        # Pass upstream errors (e.g. unknown model, bad request) through
        # instead of turning them into an opaque 500.
        return Response(
            content=response.content,
            status_code=response.status_code,
            media_type=response.headers.get("content-type"),
        )
    return response.json()
@app.post("/hf/v1/embeddings")
async def embedding(request:Request):
    """Proxy an OpenAI-compatible embeddings request to SiliconFlow.

    A ``model`` value that is a key of ``model_map`` is replaced by the mapped
    SiliconFlow model id; any other value is forwarded unchanged.

    Returns:
        The upstream JSON body. When SiliconFlow answers with a 4xx/5xx status,
        that status, body and content type are passed back unchanged.
    """
    body = await request.json()
    if "model" in body and body["model"] in model_map:
        body["model"] = model_map[body["model"]]
    key = get_si_key()
    # The key is a secret credential and is deliberately never printed or logged.
    headers = {"Authorization": f"Bearer {key}"}
    # Embeddings have their own endpoint. ``base_url`` is the chat-completions
    # URL, so posting there sent every embedding request to the wrong API.
    embeddings_url = "https://api.siliconflow.cn/v1/embeddings"
    # Allow more than httpx's 5 s default for large input batches.
    timeout = httpx.Timeout(60.0, connect=10.0)
    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.post(embeddings_url, json=body, headers=headers)
    if response.is_error:
        # Pass upstream errors through instead of turning them into an opaque 500.
        return Response(
            content=response.content,
            status_code=response.status_code,
            media_type=response.headers.get("content-type"),
        )
    return response.json()