from fastapi import FastAPI
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from typing import List, Optional, Union
import requests
import json

app = FastAPI()

API_URL = "https://api.typegpt.net/v1/chat/completions"
API_KEY = "sk-XzS5hhsa3vpIcRLz3prQirBQXOx2hPydPzSpzdRcE1YddnNm"
BACKEND_MODEL = "pixtral-large-latest"

# Load virtual model -> system prompt mappings
with open("model_map.json", "r") as f:
    MODEL_PROMPTS = json.load(f)
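
# Illustrative shape of model_map.json (an assumption; the actual file ships
# alongside this module and maps each virtual model name to a system prompt):
#
#   {
#     "code-helper": "You are an expert programming assistant.",
#     "translator": "You translate text between languages."
#   }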

class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    stream: Optional[bool] = False
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0

def build_payload(chat: ChatRequest):
    """Wrap the client's messages with the system prompt mapped to the requested virtual model."""
    system_prompt = MODEL_PROMPTS.get(chat.model, "You are a helpful assistant.")

    messages = [{"role": "system", "content": system_prompt}] + [
        {"role": msg.role, "content": msg.content} for msg in chat.messages
    ]

    return {
        "model": BACKEND_MODEL,
        "messages": messages,
        "stream": chat.stream,
        "temperature": chat.temperature,
        "top_p": chat.top_p,
        "n": chat.n,
        "stop": chat.stop,
        "presence_penalty": chat.presence_penalty,
        "frequency_penalty": chat.frequency_penalty
    }
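
# For example (illustrative values, assuming "code-helper" maps to the prompt
# sketched above), a request for virtual model "code-helper" with one user
# message becomes a backend payload like:
#
#   {
#     "model": "pixtral-large-latest",
#     "messages": [
#       {"role": "system", "content": "You are an expert programming assistant."},
#       {"role": "user", "content": "Refactor this loop."}
#     ],
#     "stream": false,
#     ...
#   }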

def stream_generator(requested_model, payload, headers):
    """Relay the backend's SSE stream, rewriting the model name on the fly."""
    with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
        for line in r.iter_lines():
            if not line:
                continue
            decoded = line.decode("utf-8")
            # Backend lines already carry the "data: " prefix; strip it so the
            # event is not double-prefixed when re-emitted below.
            if decoded.startswith("data: "):
                decoded = decoded[len("data: "):]
            # Drop the backend's terminator; a single [DONE] is emitted below.
            if decoded.strip() == "[DONE]":
                continue
            # Rewrite the backend model name so clients see the virtual model.
            decoded = decoded.replace(BACKEND_MODEL, requested_model)
            yield f"data: {decoded}\n\n"
        yield "data: [DONE]\n\n"

@app.post("/v1/chat/completions")
def proxy_chat(chat_request: ChatRequest):
    """Proxy an OpenAI-style chat request to the backend model.

    Declared as a plain `def` so FastAPI runs it in a threadpool: the blocking
    `requests` calls below would otherwise stall the event loop. Letting
    FastAPI parse `ChatRequest` directly also yields a proper 422 on invalid
    bodies instead of an unhandled exception.
    """
    payload = build_payload(chat_request)
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json"
    }

    if chat_request.stream:
        return StreamingResponse(
            stream_generator(chat_request.model, payload, headers),
            media_type="text/event-stream"
        )

    response = requests.post(API_URL, headers=headers, json=payload)
    data = response.json()
    # Report the virtual model name, not the backend one, in the final result.
    if data.get("model") == BACKEND_MODEL:
        data["model"] = chat_request.model
    return JSONResponse(content=data, status_code=response.status_code)
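
# A minimal local runner (a sketch, assuming uvicorn is installed; not part of
# the original file). With the server up, any OpenAI-compatible client can be
# pointed at http://localhost:8000/v1 and request one of the virtual models
# defined in model_map.json, e.g.:
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "code-helper", "messages": [{"role": "user", "content": "hi"}]}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)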