File size: 3,670 Bytes
daa1bb4
f7f30cc
4d2b16a
f7f30cc
bcdcebb
daa1bb4
89d8cc9
daa1bb4
 
bcdcebb
89d8cc9
 
 
 
 
bcdcebb
 
 
 
89d8cc9
daa1bb4
 
bcdcebb
89d8cc9
daa1bb4
 
 
bcdcebb
daa1bb4
 
f7f30cc
 
 
 
 
 
 
 
bcdcebb
89d8cc9
f7f30cc
 
 
bcdcebb
f7f30cc
 
 
 
 
 
89d8cc9
 
 
 
bcdcebb
 
89d8cc9
4d2b16a
f7f30cc
4d2b16a
89d8cc9
 
 
4d2b16a
89d8cc9
 
 
4d2b16a
 
89d8cc9
4d2b16a
89d8cc9
4d2b16a
89d8cc9
 
 
f7f30cc
89d8cc9
f7f30cc
 
89d8cc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bcdcebb
89d8cc9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import json
import logging
import os
from typing import List, Optional, Union

import requests
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel

app = FastAPI()

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("proxy")

# TypeGPT API settings.
# Each setting can be overridden through the environment so deployments do not
# have to edit source; the literals are only used when the variable is unset.
API_URL = os.environ.get(
    "TYPEGPT_API_URL", "https://api.typegpt.net/v1/chat/completions"
)
# SECURITY(review): a credential is committed to source control here. Provide
# TYPEGPT_API_KEY via the environment, rotate this key, then remove the fallback.
API_KEY = os.environ.get(
    "TYPEGPT_API_KEY", "sk-XzS5hhsa3vpIcRLz3prQirBQXOx2hPydPzSpzdRcE1YddnNm"
)
# Every request is sent to this model regardless of the model the client names.
BACKEND_MODEL = os.environ.get("TYPEGPT_BACKEND_MODEL", "pixtral-large-latest")

# Load model-system-prompt mappings: client-facing model name -> system prompt.
# The relative path is resolved against the current working directory.
# Explicit UTF-8 so prompts with non-ASCII text load the same on every platform.
with open("model_map.json", "r", encoding="utf-8") as f:
    MODEL_PROMPTS = json.load(f)

# Request schema
class Message(BaseModel):
    """One entry of the client's ``messages`` array.

    ``content`` is either plain text or a list of content-part objects
    (assumed to follow the OpenAI content-part format, e.g. text and image
    parts — confirm against clients). Either form is forwarded unchanged to
    the backend, which lets multimodal requests reach a vision-capable model
    instead of failing validation.
    """

    role: str
    content: Union[str, List[dict]]

class ChatRequest(BaseModel):
    """Body of an incoming chat-completions request.

    proxy_chat builds it with ``ChatRequest(**body)``. ``model`` is the
    client-facing name: build_payload uses it only to choose a system prompt
    and echoes it back in responses, while the backend always receives
    BACKEND_MODEL. The sampling fields are copied verbatim into the payload.

    NOTE(review): pydantic ignores undeclared fields by default, so other
    client parameters (e.g. ``max_tokens``) are presumably dropped before
    reaching the backend — confirm whether that is intended.
    """
    model: str
    messages: List[Message]
    stream: Optional[bool] = False  # truthy -> proxy_chat relays an SSE stream
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0

# Build payload to send to actual backend API
def build_payload(chat: ChatRequest):
    """Translate a validated client request into the backend request body.

    The client's ``model`` only selects the system prompt. It is looked up in
    MODEL_PROMPTS, with a generic fallback. The body itself always targets
    BACKEND_MODEL. Sampling options are copied across unchanged, and the
    system prompt is placed ahead of the client's own messages.
    """
    system_message = {
        "role": "system",
        "content": MODEL_PROMPTS.get(chat.model, "You are a helpful assistant."),
    }

    body = {"model": BACKEND_MODEL}
    # Key order matches the original literal so the serialized JSON is identical.
    for option in (
        "stream",
        "temperature",
        "top_p",
        "n",
        "stop",
        "presence_penalty",
        "frequency_penalty",
    ):
        body[option] = getattr(chat, option)

    conversation = [system_message]
    conversation.extend(
        {"role": message.role, "content": message.content}
        for message in chat.messages
    )
    body["messages"] = conversation
    return body

# Properly forward streaming data and replace model
def stream_generator(requested_model: str, payload: dict, headers: dict):
    """Relay the backend's server-sent-event stream, renaming the model.

    Opens a streaming POST to API_URL and re-emits each ``data:`` event.
    The handling of each line is:

    * JSON chunks whose ``model`` equals BACKEND_MODEL are rewritten to
      ``requested_model``, so the client sees the name it asked for.
    * The ``[DONE]`` sentinel is passed through.
    * Malformed chunks and non-``data:`` lines are logged and dropped.

    A non-2xx backend reply, or a network failure, is reported as a single
    ``data: {"error": ...}`` event rather than a silently empty stream.
    The HTTP status line has already been sent once streaming starts, so an
    error event is the only signal left for the client.

    Args:
        requested_model: Model name the client asked for.
        payload: JSON body for the backend (as produced by build_payload).
        headers: HTTP headers for the backend, including authorization.

    Yields:
        SSE-framed ``data:`` event strings.
    """
    try:
        with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
            if not r.ok:
                # An error reply is an ordinary JSON/text body, not an SSE
                # stream. Without this check every line would be skipped as
                # non-data and the client would get an empty stream.
                logger.error("Backend stream error %s: %s", r.status_code, r.text)
                error = {"error": {"message": r.text, "code": r.status_code}}
                yield f"data: {json.dumps(error)}\n\n"
                return

            # SSE is always UTF-8. requests defaults text/* responses without
            # an explicit charset to ISO-8859-1, which would garble non-ASCII text.
            r.encoding = "utf-8"

            for line in r.iter_lines(decode_unicode=True):
                if not line:
                    continue
                if not line.startswith("data:"):
                    logger.debug("Non-data stream line skipped: %s", line)
                    continue

                # The space after "data:" is optional in SSE, so strip only the
                # field name and then whitespace, instead of assuming "data: ".
                content = line[len("data:"):].strip()
                if content == "[DONE]":
                    yield "data: [DONE]\n\n"
                    continue

                try:
                    json_obj = json.loads(content)
                except json.JSONDecodeError:
                    logger.warning("Invalid JSON in stream chunk: %s", content)
                    continue

                # Guard against valid-but-non-object JSON, which has no .get().
                if isinstance(json_obj, dict) and json_obj.get("model") == BACKEND_MODEL:
                    json_obj["model"] = requested_model
                yield f"data: {json.dumps(json_obj)}\n\n"
    except requests.RequestException as exc:
        # Connection failures before or during streaming would otherwise escape
        # the generator after the response headers were already sent.
        logger.error("Backend stream failed: %s", exc)
        error = {"error": {"message": "Upstream stream failed."}}
        yield f"data: {json.dumps(error)}\n\n"

# Main endpoint
@app.post("/v1/chat/completions")
async def proxy_chat(request: Request):
    """OpenAI-compatible chat-completions endpoint that proxies to the backend.

    The body is validated as a ChatRequest and translated by build_payload.
    Streaming requests are relayed through stream_generator. Otherwise the
    backend's JSON reply is returned with the backend's HTTP status, and its
    ``model`` field is renamed back to the model the client requested.

    Returns:
        The backend's response on success.
        400 for a malformed or invalid request body.
        502 when the backend is unreachable or replies with non-JSON.
        500 for any other unexpected failure.
    """
    try:
        body = await request.json()
        chat_request = ChatRequest(**body)
    except (ValueError, TypeError) as e:
        # Both json.JSONDecodeError and pydantic's ValidationError subclass
        # ValueError, and a non-object body fails ** unpacking with TypeError.
        # These are client mistakes, so report 400 rather than 500.
        logger.warning("Rejected invalid request body: %s", e)
        return JSONResponse(content={"error": "Invalid request body."}, status_code=400)

    try:
        payload = build_payload(chat_request)
        headers = {
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json"
        }

        if chat_request.stream:
            return StreamingResponse(
                stream_generator(chat_request.model, payload, headers),
                media_type="text/event-stream"
            )

        try:
            # requests is blocking. Calling it directly in this coroutine would
            # stall the event loop, and every other request, until the
            # backend answers, so it runs in the threadpool instead.
            # NOTE(review): no timeout is set, so a hung backend holds a worker
            # thread indefinitely; consider adding one.
            response = await run_in_threadpool(
                requests.post, API_URL, headers=headers, json=payload
            )
            data = response.json()
        except (requests.RequestException, ValueError) as e:
            # Network failure or a non-JSON body (requests raises a
            # ValueError subclass from .json()) means the upstream failed.
            logger.error("Backend request failed: %s", e)
            return JSONResponse(content={"error": "Bad gateway."}, status_code=502)

        # The backend may return a JSON array or error object; only rename a dict.
        if isinstance(data, dict) and data.get("model") == BACKEND_MODEL:
            data["model"] = chat_request.model
        # Forward the upstream status so backend errors are not reported as 200.
        return JSONResponse(content=data, status_code=response.status_code)

    except Exception:
        logger.exception("Error in proxy_chat")
        return JSONResponse(content={"error": "Internal server error."}, status_code=500)