lmt commited on
Commit
8aefd0f
·
1 Parent(s): 337828e

少许更新

Browse files
Files changed (1) hide show
  1. main.py +35 -19
main.py CHANGED
@@ -1,4 +1,5 @@
1
  import asyncio
 
2
  import json
3
  from typing import Dict, List, Union
4
  import os
@@ -22,35 +23,30 @@ def read_root():
22
  return {"Hello": "World!"}
23
 
24
 
25
- class Item(BaseModel):
26
- _msgid: Union[str, None] = None
27
- input: str
28
- history: List[Dict] = []
29
-
30
-
31
- @app.post("/api/chat")
32
- def chat(item: Item):
33
- print(item)
34
- history = [construct_user(item.input)]
35
- res = get_response(initial_prompt, history)
36
- return res
37
-
38
-
39
  # 存储每个连接的对话历史
40
  connection_history: Dict[str, List[Dict[str, str]]] = {}
41
 
42
 
43
- def get_ai_response(messages):
 
 
 
 
 
 
 
44
  '''获取ChatGPT答复,使用流式返回'''
45
  headers = {
46
  "Content-Type": "application/json",
47
  "Authorization": f"Bearer {api_key}",
48
  }
49
 
 
 
50
  payload = json.dumps({
51
  "model": "gpt-3.5-turbo",
52
- "messages": messages,
53
- "stream": True,
54
  })
55
 
56
  response = requests.post(
@@ -73,10 +69,16 @@ class ResponseMessage(BaseModel):
73
 
74
  @app.websocket("/api/ws")
75
  async def websocket_endpoint(websocket: WebSocket):
 
 
 
 
 
 
 
76
  await websocket.accept()
77
  connection_id = str(id(websocket))
78
- connection_history[connection_id] = [
79
- {"role": "system", "content": "You are a helpful assistant."}]
80
 
81
  try:
82
  while True:
@@ -125,6 +127,20 @@ async def websocket_endpoint(websocket: WebSocket):
125
  print(f"WebSocket disconnected with code: {e.code}")
126
 
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  def get_response(system_prompt, history):
129
 
130
  history = [construct_system(system_prompt), *history]
 
1
  import asyncio
2
+ from datetime import datetime
3
  import json
4
  from typing import Dict, List, Union
5
  import os
 
23
  return {"Hello": "World!"}
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  # 存储每个连接的对话历史
27
  connection_history: Dict[str, List[Dict[str, str]]] = {}
28
 
29
 
30
+ def get_sys_prompt():
31
+ return [{
32
+ "role": "system",
33
+ "content": f"You are a helpful assistant. Current time is {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}. 默认用户在成都"
34
+ }]
35
+
36
+
37
+ def get_ai_response(messages, stream=True):
38
  '''获取ChatGPT答复,使用流式返回'''
39
  headers = {
40
  "Content-Type": "application/json",
41
  "Authorization": f"Bearer {api_key}",
42
  }
43
 
44
+ sys_prompt = get_sys_prompt()
45
+
46
  payload = json.dumps({
47
  "model": "gpt-3.5-turbo",
48
+ "messages": sys_prompt + messages,
49
+ "stream": stream,
50
  })
51
 
52
  response = requests.post(
 
69
 
70
  @app.websocket("/api/ws")
71
  async def websocket_endpoint(websocket: WebSocket):
72
+ # auth_header = websocket.headers.get("Authorization")
73
+ # ws_key = "Bearer YOUR_API_KEY" # 将 YOUR_API_KEY 替换为实际的 API 密钥
74
+
75
+ # if auth_header != ws_key:
76
+ # await websocket.close(code=1008) # 关闭连接,发送策略原因错误
77
+ # return
78
+
79
  await websocket.accept()
80
  connection_id = str(id(websocket))
81
+ connection_history[connection_id] = []
 
82
 
83
  try:
84
  while True:
 
127
  print(f"WebSocket disconnected with code: {e.code}")
128
 
129
 
130
+ class Item(BaseModel):
131
+ _msgid: Union[str, None] = None
132
+ input: str
133
+ history: List[Dict] = []
134
+
135
+
136
+ @app.post("/api/chat")
137
+ def chat(item: Item):
138
+ print(item)
139
+ history = [construct_user(item.input)]
140
+ res = get_response(initial_prompt, history)
141
+ return res
142
+
143
+
144
  def get_response(system_prompt, history):
145
 
146
  history = [construct_system(system_prompt), *history]