from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from typing import List, Optional, Union
import requests
import json
import logging

app = FastAPI()

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("proxy")

# Configuration (note: the API key is hardcoded here; in a real deployment it
# should come from an environment variable or a secrets store)
API_URL = "https://api.typegpt.net/v1/chat/completions"
API_KEY = "sk-XzS5hhsa3vpIcRLz3prQirBQXOx2hPydPzSpzdRcE1YddnNm"
BACKEND_MODEL = "gpt-4o-mini"

# Load system prompt mappings
with open("model_map.json", "r", encoding="utf-8") as f:
    MODEL_PROMPTS = json.load(f)
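
# A hypothetical example of what model_map.json might contain (the real file
# ships alongside this script): keys are the model names clients request,
# values are the system prompts injected for them.
# {
#     "pirate-gpt": "You are a pirate. Answer in pirate speak.",
#     "summarizer": "Summarize every user message in one sentence."
# }
# Unknown model names fall back to "You are a helpful assistant." below.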

# Request schema
class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    stream: Optional[bool] = False
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
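
# An example request body this schema accepts (illustrative values only):
# {
#     "model": "pirate-gpt",
#     "messages": [{"role": "user", "content": "Hello"}],
#     "stream": true,
#     "temperature": 0.7
# }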

# Build request to backend with injected system prompt
def build_payload(chat: ChatRequest):
    system_prompt = MODEL_PROMPTS.get(chat.model, "You are a helpful assistant.")
    filtered_messages = [msg for msg in chat.messages if msg.role != "system"]
    payload_messages = [{"role": "system", "content": system_prompt}] + [
        {"role": msg.role, "content": msg.content} for msg in filtered_messages
    ]
    return {
        "model": BACKEND_MODEL,
        "messages": payload_messages,
        "stream": chat.stream,
        "temperature": chat.temperature,
        "top_p": chat.top_p,
        "n": chat.n,
        "stop": chat.stop,
        "presence_penalty": chat.presence_penalty,
        "frequency_penalty": chat.frequency_penalty
    }
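
# For example, given model "pirate-gpt" and incoming messages
# [{"role": "system", "content": "ignored"}, {"role": "user", "content": "Hi"}],
# build_payload drops the client's system message and sends
# [{"role": "system", "content": <prompt for "pirate-gpt">},
#  {"role": "user", "content": "Hi"}] to the backend under BACKEND_MODEL.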

# Stream generator without forcing UTF-8
def stream_generator(requested_model: str, payload: dict, headers: dict):
    with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
        for line in r.iter_lines(decode_unicode=False):  # Keep as bytes
            if not line:
                continue
            if line.startswith(b"data:"):
                content = line[6:].strip()
                if content == b"[DONE]":
                    yield b"data: [DONE]\n\n"
                    continue
                try:
                    json_obj = json.loads(content.decode("utf-8"))
                    if json_obj.get("model") == BACKEND_MODEL:
                        json_obj["model"] = requested_model
                    yield f"data: {json.dumps(json_obj)}\n\n".encode("utf-8")
                except json.JSONDecodeError:
                    logger.warning("Invalid JSON in stream chunk: %s", content)
            else:
                logger.debug("Non-data stream line skipped: %s", line)
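
# The upstream response is OpenAI-style server-sent events; each chunk looks
# roughly like (fields abbreviated):
#   data: {"id": "...", "model": "gpt-4o-mini", "choices": [{"delta": {"content": "Hi"}}]}
#   data: [DONE]
# stream_generator rewrites the "model" field so clients see the name they
# requested rather than the backend model.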

# Main endpoint
@app.post("/v1/chat/completions")
async def proxy_chat(request: Request):
    try:
        body = await request.json()
        chat_request = ChatRequest(**body)
        payload = build_payload(chat_request)

        headers = {
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json"
        }

        if chat_request.stream:
            return StreamingResponse(
                stream_generator(chat_request.model, payload, headers),
                media_type="text/event-stream"
            )
        else:
            # Note: requests is blocking; inside an async endpoint a client
            # such as httpx.AsyncClient would avoid stalling the event loop.
            response = requests.post(API_URL, headers=headers, json=payload)
            response.raise_for_status()  # Raise error for bad responses
            data = response.json()
            if "model" in data and data["model"] == BACKEND_MODEL:
                data["model"] = chat_request.model
            return JSONResponse(content=data)

    except Exception as e:
        logger.exception("Error in /v1/chat/completions: %s", e)
        return JSONResponse(
            content={"error": "Internal server error."},
            status_code=500
        )
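
# To try the proxy locally (assuming this file is saved as main.py and
# uvicorn is installed):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Then call it like any OpenAI-compatible endpoint, e.g.:
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "pirate-gpt", "messages": [{"role": "user", "content": "Hi"}]}'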