File size: 3,708 Bytes
d1a7225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import logging
import os
import secrets
from typing import List

from fastapi import FastAPI, HTTPException, Depends, Header, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from g4f import ChatCompletion
from pydantic import BaseModel
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

app = FastAPI()

# Rate limiter keyed on the client's remote address. slowapi's documented
# FastAPI setup attaches the limiter to ``app.state`` and registers its
# RateLimitExceeded handler; that handler looks the limiter up through
# ``request.app.state.limiter``, so both lines are required together.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Model identifiers accepted by /v1/chat/completions and reported by /v1/models.
models = [
    "gpt-4o", "gpt-4o-mini", "gpt-4",
    "gpt-4-turbo", "gpt-3.5-turbo",
    "claude-3.7-sonnet", "o3-mini", "o1", "claude-3.5", "llama-3.1-405b",
    "gemini-flash", "blackboxai-pro", "openchat-3.5", "glm-4-9B", "blackboxai",
]

# Request/response schemas for the chat-completion endpoint.
class Message(BaseModel):
    """A single chat turn supplied by the client."""

    role: str  # free-form; not validated against a fixed set of roles
    content: str


class ChatRequest(BaseModel):
    """JSON body accepted by POST /v1/chat/completions."""

    model: str  # checked against the module-level ``models`` list by the endpoint
    messages: List[Message]
    streaming: bool = True  # when true the reply is sent as server-sent events


class ChatResponse(BaseModel):
    """Reply returned by the endpoint when streaming is disabled."""

    role: str
    content: str

# Dependency to check API key
async def verify_api_key(x_api_key: str = Header(...)):
    if x_api_key != "vs-5wEvIw6vfLKIypGm7uiNoWuXrJcg4vAL":  # Replace with your actual API key
        raise HTTPException(status_code=403, detail="Invalid API key")

@app.get("/v1/models", tags=["Models"])
async def get_models():
    """Report which model identifiers the chat endpoint will accept.

    Returns:
        A mapping with a single ``"models"`` key holding the module-level list.
    """
    payload = {"models": models}
    return payload

_logger = logging.getLogger(__name__)


def _format_sse(text):
    """Encode *text* as one server-sent-events ``data`` event.

    Every line of *text* gets its own ``data:`` field, so embedded line breaks
    cannot terminate the event early. An SSE client re-joins the lines with a
    newline, recovering the original text. Single-line text produces exactly
    ``"data: <text>"`` followed by a blank line.
    """
    lines = text.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    return "".join(f"data: {line}\n" for line in lines) + "\n"


def _chunk_texts(chunk):
    """Yield the text pieces carried by one item of a g4f stream.

    For dict chunks with a ``choices`` entry, yield the ``content`` of each
    choice's ``message`` or ``delta``, skipping missing or ``None`` content.
    Any other chunk is passed through as ``str(chunk)``.
    """
    if isinstance(chunk, dict) and "choices" in chunk:
        for choice in chunk["choices"]:
            if not isinstance(choice, dict):
                continue
            payload = choice.get("message") or choice.get("delta")
            if isinstance(payload, dict) and payload.get("content") is not None:
                yield str(payload["content"])
    else:
        yield str(chunk)


def _response_content(response):
    """Extract the assistant text from a non-streaming g4f result.

    Raises:
        HTTPException: 500 when *response* is neither a string nor an
            OpenAI-style ``{"choices": [{"message": {"content": str}}]}`` dict.
    """
    if isinstance(response, str):
        return response
    try:
        content = response["choices"][0]["message"]["content"]
    except (IndexError, KeyError, TypeError) as exc:
        raise HTTPException(status_code=500, detail="Unexpected response structure.") from exc
    if not isinstance(content, str):
        raise HTTPException(status_code=500, detail="Unexpected response structure.")
    return content


@app.post("/v1/chat/completions", tags=["Chat Completion"])
@limiter.limit("10/minute")  # Rate limit to 10 requests per minute
async def chat_completion(
    request: Request,
    chat_request: ChatRequest,
    api_key: str = Depends(verify_api_key)
):
    """Forward a chat conversation to g4f and return the assistant's reply.

    Args:
        request: Incoming request; slowapi's decorator requires this parameter.
        chat_request: Model name, messages and the ``streaming`` flag.
        api_key: Present only so the ``verify_api_key`` dependency runs.

    Returns:
        A ``text/event-stream`` StreamingResponse when ``streaming`` is true,
        otherwise a ChatResponse with role ``"assistant"``.

    Raises:
        HTTPException: 400 for an unknown model or empty message list; 500
            when the upstream call fails or returns an unexpected structure.
    """
    if chat_request.model not in models:
        raise HTTPException(status_code=400, detail="Invalid model selected.")
    if not chat_request.messages:
        raise HTTPException(status_code=400, detail="Messages cannot be empty.")

    # Convert messages to the plain-dict format ChatCompletion expects.
    formatted_messages = [
        {"role": msg.role, "content": msg.content} for msg in chat_request.messages
    ]

    if chat_request.streaming:
        def event_stream():
            # This generator runs after the endpoint has returned, so upstream
            # failures must be handled here. The response headers are already
            # sent, so report the error in-band as an SSE "error" event.
            try:
                response = ChatCompletion.create(
                    model=chat_request.model,
                    messages=formatted_messages,
                    stream=True,
                )
                for chunk in response:
                    for text in _chunk_texts(chunk):
                        yield _format_sse(text)
            except Exception as exc:
                _logger.exception("Streaming chat completion failed")
                yield "event: error\n" + _format_sse(str(exc))

        # StreamingResponse iterates a sync generator in a worker thread, so the
        # blocking g4f calls above do not stall the event loop.
        return StreamingResponse(event_stream(), media_type="text/event-stream")

    try:
        # ChatCompletion.create blocks; run it off the event loop.
        response = await run_in_threadpool(
            ChatCompletion.create,
            model=chat_request.model,
            messages=formatted_messages,
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return ChatResponse(role="assistant", content=_response_content(response))

if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on every interface, port 7860.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")