rkihacker committed
Commit 89d8cc9 · verified · 1 Parent(s): 4d2b16a

Update main.py

Files changed (1): main.py (+49 -36)
main.py CHANGED
@@ -4,17 +4,24 @@ from pydantic import BaseModel
 from typing import List, Optional, Union
 import requests
 import json
+import logging
 
 app = FastAPI()
 
+# Setup logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger("proxy")
+
+# TypeGPT API settings
 API_URL = "https://api.typegpt.net/v1/chat/completions"
 API_KEY = "sk-XzS5hhsa3vpIcRLz3prQirBQXOx2hPydPzSpzdRcE1YddnNm"
 BACKEND_MODEL = "pixtral-large-latest"
 
-# Load virtual model -> system prompt mappings
+# Load model-system-prompt mappings
 with open("model_map.json", "r") as f:
     MODEL_PROMPTS = json.load(f)
 
+# Request schema
 class Message(BaseModel):
     role: str
     content: str
@@ -30,62 +37,68 @@ class ChatRequest(BaseModel):
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
 
+# Build payload to send to actual backend API
 def build_payload(chat: ChatRequest):
     system_prompt = MODEL_PROMPTS.get(chat.model, "You are a helpful assistant.")
-
-    messages = [{"role": "system", "content": system_prompt}] + [
-        {"role": msg.role, "content": msg.content} for msg in chat.messages
-    ]
-
     return {
         "model": BACKEND_MODEL,
-        "messages": messages,
         "stream": chat.stream,
         "temperature": chat.temperature,
         "top_p": chat.top_p,
         "n": chat.n,
         "stop": chat.stop,
         "presence_penalty": chat.presence_penalty,
-        "frequency_penalty": chat.frequency_penalty
+        "frequency_penalty": chat.frequency_penalty,
+        "messages": [{"role": "system", "content": system_prompt}] + [
+            {"role": msg.role, "content": msg.content} for msg in chat.messages
+        ]
     }
 
+# Properly forward streaming data and replace model
 def stream_generator(requested_model: str, payload: dict, headers: dict):
     with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
         for line in r.iter_lines(decode_unicode=True):
-            if line and line.startswith("data:"):
-                # Remove "data: " prefix
+            if not line:
+                continue
+            if line.startswith("data:"):
                 content = line[6:].strip()
+                if content == "[DONE]":
+                    yield "data: [DONE]\n\n"
+                    continue
                 try:
-                    # Try to parse and replace model field
                     json_obj = json.loads(content)
-                    if "model" in json_obj and json_obj["model"] == BACKEND_MODEL:
+                    if json_obj.get("model") == BACKEND_MODEL:
                         json_obj["model"] = requested_model
-                    fixed_line = f"data: {json.dumps(json_obj)}\n\n"
+                    yield f"data: {json.dumps(json_obj)}\n\n"
                 except json.JSONDecodeError:
-                    fixed_line = f"data: {content}\n\n"
-                yield fixed_line
-            elif line:
-                yield f"data: {line}\n\n"
-        yield "data: [DONE]\n\n"
+                    logger.warning("Invalid JSON in stream chunk: %s", content)
+            else:
+                logger.debug("Non-data stream line skipped: %s", line)
 
+# Main endpoint
 @app.post("/v1/chat/completions")
 async def proxy_chat(request: Request):
-    body = await request.json()
-    chat_request = ChatRequest(**body)
-    payload = build_payload(chat_request)
-    headers = {
-        "Authorization": f"Bearer {API_KEY}",
-        "Content-Type": "application/json"
-    }
+    try:
+        body = await request.json()
+        chat_request = ChatRequest(**body)
+        payload = build_payload(chat_request)
+        headers = {
+            "Authorization": f"Bearer {API_KEY}",
+            "Content-Type": "application/json"
+        }
+
+        if chat_request.stream:
+            return StreamingResponse(
+                stream_generator(chat_request.model, payload, headers),
+                media_type="text/event-stream"
+            )
+        else:
+            response = requests.post(API_URL, headers=headers, json=payload)
+            data = response.json()
+            if "model" in data and data["model"] == BACKEND_MODEL:
+                data["model"] = chat_request.model
+            return JSONResponse(content=data)
 
-    if chat_request.stream:
-        return StreamingResponse(
-            stream_generator(chat_request.model, payload, headers),
-            media_type="text/event-stream"
-        )
-    else:
-        response = requests.post(API_URL, headers=headers, json=payload)
-        data = response.json()
-        if "model" in data and data["model"] == BACKEND_MODEL:
-            data["model"] = chat_request.model
-        return JSONResponse(content=data)
+    except Exception as e:
+        logger.error("Error in proxy_chat: %s", str(e))
+        return JSONResponse(content={"error": "Internal server error."}, status_code=500)
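
Note: model_map.json itself is not part of this commit. Judging only from how MODEL_PROMPTS is used (a lookup from the requested model name to a system prompt, falling back to "You are a helpful assistant."), a minimal sketch of the expected file shape might be (both entries are hypothetical):

{
  "virtual-gpt": "You are a helpful assistant specialized in concise answers.",
  "translator": "You translate every user message into English."
}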
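
For reference, a minimal client sketch against this proxy, assuming it is served locally on port 8000 (the host, port, and the virtual model name "translator" are assumptions, not part of the commit):

import requests

# "translator" stands in for any key defined in model_map.json;
# unknown names silently fall back to the default system prompt.
payload = {
    "model": "translator",
    "stream": True,
    "messages": [{"role": "user", "content": "Bonjour tout le monde"}]
}

with requests.post("http://localhost:8000/v1/chat/completions",
                   json=payload, stream=True) as r:
    for line in r.iter_lines(decode_unicode=True):
        if line:
            print(line)  # SSE lines: "data: {...}", then "data: [DONE]"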
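
Because stream_generator rewrites the model field, each forwarded chunk should report the requested virtual name rather than pixtral-large-latest. Illustratively (every field besides model depends on the backend and is made up here):

data: {"id": "chatcmpl-abc123", "model": "translator", "choices": [{"delta": {"content": "Hello"}}]}

data: [DONE]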