lmt commited on
Commit
566ac72
·
1 Parent(s): 8aefd0f

增加重置对话功能

Browse files
Files changed (1) hide show
  1. main.py +17 -11
main.py CHANGED
@@ -10,13 +10,10 @@ import requests
10
  import uvicorn
11
 
12
  api_key = os.environ.get('api_key')
13
- initial_prompt = "You are a helpful assistant."
14
  API_URL = "https://api.openai.com/v1/chat/completions"
15
 
16
  app = FastAPI()
17
 
18
- openai.api_key = api_key
19
-
20
 
21
  @app.get("/")
22
  def read_root():
@@ -30,10 +27,14 @@ connection_history: Dict[str, List[Dict[str, str]]] = {}
30
  def get_sys_prompt():
31
  return [{
32
  "role": "system",
33
- "content": f"You are a helpful assistant. Current time is {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}. 默认用户在成都"
34
  }]
35
 
36
 
 
 
 
 
37
  def get_ai_response(messages, stream=True):
38
  '''获取ChatGPT答复,使用流式返回'''
39
  headers = {
@@ -69,13 +70,6 @@ class ResponseMessage(BaseModel):
69
 
70
  @app.websocket("/api/ws")
71
  async def websocket_endpoint(websocket: WebSocket):
72
- # auth_header = websocket.headers.get("Authorization")
73
- # ws_key = "Bearer YOUR_API_KEY" # 将 YOUR_API_KEY 替换为实际的 API 密钥
74
-
75
- # if auth_header != ws_key:
76
- # await websocket.close(code=1008) # 关闭连接,发送策略原因错误
77
- # return
78
-
79
  await websocket.accept()
80
  connection_id = str(id(websocket))
81
  connection_history[connection_id] = []
@@ -86,6 +80,12 @@ async def websocket_endpoint(websocket: WebSocket):
86
  message = Message(**json.loads(data))
87
  print(message)
88
 
 
 
 
 
 
 
89
  user_message = {"role": "user", "content": message.msg}
90
  connection_history[connection_id].append(user_message)
91
 
@@ -125,6 +125,8 @@ async def websocket_endpoint(websocket: WebSocket):
125
  except WebSocketDisconnect as e:
126
  # 在这里处理断开连接的情况,例如记录日志、清理资源等
127
  print(f"WebSocket disconnected with code: {e.code}")
 
 
128
 
129
 
130
  class Item(BaseModel):
@@ -141,6 +143,10 @@ def chat(item: Item):
141
  return res
142
 
143
 
 
 
 
 
144
  def get_response(system_prompt, history):
145
 
146
  history = [construct_system(system_prompt), *history]
 
10
  import uvicorn
11
 
12
  api_key = os.environ.get('api_key')
 
13
  API_URL = "https://api.openai.com/v1/chat/completions"
14
 
15
  app = FastAPI()
16
 
 
 
17
 
18
  @app.get("/")
19
  def read_root():
 
27
  def get_sys_prompt():
28
  return [{
29
  "role": "system",
30
+ "content": f"You are a helpful assistant. Current time is {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}."
31
  }]
32
 
33
 
34
+ def test_reset(msg):
35
+ return msg == "重置对话" or msg == "开启新对话"
36
+
37
+
38
  def get_ai_response(messages, stream=True):
39
  '''获取ChatGPT答复,使用流式返回'''
40
  headers = {
 
70
 
71
  @app.websocket("/api/ws")
72
  async def websocket_endpoint(websocket: WebSocket):
 
 
 
 
 
 
 
73
  await websocket.accept()
74
  connection_id = str(id(websocket))
75
  connection_history[connection_id] = []
 
80
  message = Message(**json.loads(data))
81
  print(message)
82
 
83
+ # 判断是否重置对话
84
+ if test_reset(message.msg):
85
+ connection_history[connection_id] = []
86
+ print("OK,对话已重置")
87
+ continue
88
+
89
  user_message = {"role": "user", "content": message.msg}
90
  connection_history[connection_id].append(user_message)
91
 
 
125
  except WebSocketDisconnect as e:
126
  # 在这里处理断开连接的情况,例如记录日志、清理资源等
127
  print(f"WebSocket disconnected with code: {e.code}")
128
+ except:
129
+ print(f"WebSocket disconnected with unknown reason")
130
 
131
 
132
  class Item(BaseModel):
 
143
  return res
144
 
145
 
146
+ openai.api_key = api_key
147
+ initial_prompt = "You are a helpful assistant."
148
+
149
+
150
  def get_response(system_prompt, history):
151
 
152
  history = [construct_system(system_prompt), *history]