Hansimov commited on
Commit
8d91150
1 Parent(s): 7796b5b

:boom: [Fix] Close aiohttp client session after websockets finished

Browse files
Files changed (1) hide show
  1. conversation_creater.py +13 -11
conversation_creater.py CHANGED
@@ -59,15 +59,15 @@ class ConversationConnector:
59
  + f"?sec_access_token={urllib.parse.quote(self.sec_access_token)}"
60
  )
61
 
62
- async def _init_handshake(self, wss):
63
- await wss.send_str(
64
  serialize_websocket_message({"protocol": "json", "version": 1})
65
  )
66
- await wss.receive_str()
67
- await wss.send_str(serialize_websocket_message({"type": 6}))
68
 
69
  async def stream_chat(self, prompt=""):
70
- self.aio_session = aiohttp.ClientSession(cookies=self.cookies)
71
  request_headers = {
72
  "Accept-Encoding": " gzip, deflate, br",
73
  "Accept-Language": "en-US,en;q=0.9,zh-CN;q=0.8,zh;q=0.7",
@@ -81,13 +81,13 @@ class ConversationConnector:
81
  "Upgrade": "websocket",
82
  "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/118.0.0.0 Safari/537.36",
83
  }
84
- wss = await self.aio_session.ws_connect(
85
  self.ws_url,
86
  headers=request_headers,
87
  proxy=http_proxy,
88
  )
89
 
90
- await self._init_handshake(wss)
91
  chathub_request_constructor = ChathubRequestConstructor(
92
  prompt=prompt,
93
  conversation_style="precise",
@@ -97,13 +97,13 @@ class ConversationConnector:
97
  )
98
  chathub_request_constructor.construct()
99
 
100
- await wss.send_str(
101
  serialize_websocket_message(chathub_request_constructor.request_message)
102
  )
103
 
104
  delta_content_pointer = 0
105
- while not wss.closed:
106
- response_lines_str = await wss.receive_str()
107
  if isinstance(response_lines_str, str):
108
  response_lines = response_lines_str.split("\x1e")
109
  else:
@@ -160,7 +160,8 @@ class ConversationConnector:
160
  # message_text = message["text"]
161
  elif data.get("type") == 3:
162
  logger.success("[Finished]")
163
- await wss.close()
 
164
  break
165
  elif data.get("type") == 6:
166
  continue
@@ -181,6 +182,7 @@ if __name__ == "__main__":
181
  conversation_id=creator.response_content["conversationId"],
182
  )
183
  prompt = "Today's weather of California"
 
184
  logger.success(f"\n[User]: ", end="")
185
  logger.mesg(f"{prompt}")
186
  logger.success(f"\n[Bing]:")
 
59
  + f"?sec_access_token={urllib.parse.quote(self.sec_access_token)}"
60
  )
61
 
62
+ async def _init_handshake(self):
63
+ await self.wss.send_str(
64
  serialize_websocket_message({"protocol": "json", "version": 1})
65
  )
66
+ await self.wss.receive_str()
67
+ await self.wss.send_str(serialize_websocket_message({"type": 6}))
68
 
69
  async def stream_chat(self, prompt=""):
70
+ self.aiohttp_session = aiohttp.ClientSession(cookies=self.cookies)
71
  request_headers = {
72
  "Accept-Encoding": " gzip, deflate, br",
73
  "Accept-Language": "en-US,en;q=0.9,zh-CN;q=0.8,zh;q=0.7",
 
81
  "Upgrade": "websocket",
82
  "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/118.0.0.0 Safari/537.36",
83
  }
84
+ self.wss = await self.aiohttp_session.ws_connect(
85
  self.ws_url,
86
  headers=request_headers,
87
  proxy=http_proxy,
88
  )
89
 
90
+ await self._init_handshake()
91
  chathub_request_constructor = ChathubRequestConstructor(
92
  prompt=prompt,
93
  conversation_style="precise",
 
97
  )
98
  chathub_request_constructor.construct()
99
 
100
+ await self.wss.send_str(
101
  serialize_websocket_message(chathub_request_constructor.request_message)
102
  )
103
 
104
  delta_content_pointer = 0
105
+ while not self.wss.closed:
106
+ response_lines_str = await self.wss.receive_str()
107
  if isinstance(response_lines_str, str):
108
  response_lines = response_lines_str.split("\x1e")
109
  else:
 
160
  # message_text = message["text"]
161
  elif data.get("type") == 3:
162
  logger.success("[Finished]")
163
+ await self.wss.close()
164
+ await self.aiohttp_session.close()
165
  break
166
  elif data.get("type") == 6:
167
  continue
 
182
  conversation_id=creator.response_content["conversationId"],
183
  )
184
  prompt = "Today's weather of California"
185
+ # prompt = "Tell me your name. Your output should be no more than 3 words."
186
  logger.success(f"\n[User]: ", end="")
187
  logger.mesg(f"{prompt}")
188
  logger.success(f"\n[Bing]:")