Harry commited on
Commit
da7eaeb
·
unverified ·
1 Parent(s): a2c6dba

Better implement for using Enum (#106)

Browse files
Files changed (1) hide show
  1. src/EdgeGPT.py +43 -19
src/EdgeGPT.py CHANGED
@@ -9,13 +9,14 @@ import random
9
  import asyncio
10
  import argparse
11
  from enum import Enum
12
- from typing import Generator, Optional
13
 
14
  import requests
15
  import websockets.client as websockets
16
 
17
  DELIMITER = "\x1e"
18
 
 
19
  # Generate random IP between range 13.104.0.0/14
20
  FORWARDED_IP = (
21
  f"13.{random.randint(104, 107)}.{random.randint(0, 255)}.{random.randint(0, 255)}"
@@ -55,6 +56,11 @@ class ConversationStyle(Enum):
55
  precise = "h3precise"
56
 
57
 
 
 
 
 
 
58
  def append_identifier(msg: dict) -> str:
59
  """
60
  Appends special character to end of message to identify end of message
@@ -82,7 +88,12 @@ class ChatHubRequest:
82
  self.conversation_signature: str = conversation_signature
83
  self.invocation_id: int = invocation_id
84
 
85
- def update(self, prompt: str, conversation_style: Optional[ConversationStyle], options: Optional[list] = None) -> None:
 
 
 
 
 
86
  """
87
  Updates request object
88
  """
@@ -94,6 +105,8 @@ class ChatHubRequest:
94
  "enablemm",
95
  ]
96
  if conversation_style:
 
 
97
  options = [
98
  "deepleo",
99
  "enable_debug_commands",
@@ -141,7 +154,10 @@ class Conversation:
141
  }
142
  self.session = requests.Session()
143
  self.session.headers.update(
144
- {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/108.0.0.0 Safari/537.36"})
 
 
 
145
  if cookies is not None:
146
  cookie_file = cookies
147
  else:
@@ -192,7 +208,7 @@ class ChatHub:
192
  )
193
 
194
  async def ask_stream(
195
- self, prompt: str, conversation_style: Optional[ConversationStyle] = None
196
  ) -> Generator[str, None, None]:
197
  """
198
  Ask a question to the bot
@@ -206,8 +222,7 @@ class ChatHub:
206
  )
207
  await self.__initial_handshake()
208
  # Construct a ChatHub request
209
- self.request.update(
210
- prompt=prompt, conversation_style=conversation_style)
211
  # Send request
212
  await self.wss.send(append_identifier(self.request.struct))
213
  final = False
@@ -245,23 +260,30 @@ class Chatbot:
245
  def __init__(self, cookiePath: str = "", cookies: Optional[dict] = None) -> None:
246
  self.cookiePath: str = cookiePath
247
  self.cookies: dict | None = cookies
248
- self.chat_hub: ChatHub = ChatHub(
249
- Conversation(self.cookiePath, self.cookies))
250
 
251
- async def ask(self, prompt: str, conversation_style: ConversationStyle = None) -> dict:
 
 
252
  """
253
  Ask a question to the bot
254
  """
255
- async for final, response in self.chat_hub.ask_stream(prompt=prompt, conversation_style=conversation_style):
 
 
256
  if final:
257
  return response
258
  self.chat_hub.wss.close()
259
 
260
- async def ask_stream(self, prompt: str, conversation_style: ConversationStyle = None) -> Generator[str, None, None]:
 
 
261
  """
262
  Ask a question to the bot
263
  """
264
- async for response in self.chat_hub.ask_stream(prompt=prompt, conversation_style=conversation_style):
 
 
265
  yield response
266
 
267
  async def close(self):
@@ -332,13 +354,15 @@ async def main():
332
  print("Bot:")
333
  if args.no_stream:
334
  print(
335
- (await bot.ask(prompt=prompt, conversation_style=args.style))["item"]["messages"][1]["adaptiveCards"][
336
- 0
337
- ]["body"][0]["text"],
338
  )
339
  else:
340
  wrote = 0
341
- async for final, response in bot.ask_stream(prompt=prompt, conversation_style=args.style):
 
 
342
  if not final:
343
  print(response[wrote:], end="")
344
  wrote = len(response)
@@ -364,8 +388,9 @@ if __name__ == "__main__":
364
  parser = argparse.ArgumentParser()
365
  parser.add_argument("--enter-once", action="store_true")
366
  parser.add_argument("--no-stream", action="store_true")
367
- parser.add_argument("--style",
368
- choices=["creative", "balanced", "precise"], default="balanced")
 
369
  parser.add_argument(
370
  "--cookie-file",
371
  type=str,
@@ -375,5 +400,4 @@ if __name__ == "__main__":
375
  args = parser.parse_args()
376
  os.environ["COOKIE_FILE"] = args.cookie_file
377
  args = parser.parse_args()
378
- args.style = ConversationStyle[args.style]
379
  asyncio.run(main())
 
9
  import asyncio
10
  import argparse
11
  from enum import Enum
12
+ from typing import Generator, Optional, Union, Literal
13
 
14
  import requests
15
  import websockets.client as websockets
16
 
17
  DELIMITER = "\x1e"
18
 
19
+
20
  # Generate random IP between range 13.104.0.0/14
21
  FORWARDED_IP = (
22
  f"13.{random.randint(104, 107)}.{random.randint(0, 255)}.{random.randint(0, 255)}"
 
56
  precise = "h3precise"
57
 
58
 
59
+ CONVERSATION_STYLE_TYPE = Optional[
60
+ Union[ConversationStyle, Literal["creative", "balanced", "precise"]]
61
+ ]
62
+
63
+
64
  def append_identifier(msg: dict) -> str:
65
  """
66
  Appends special character to end of message to identify end of message
 
88
  self.conversation_signature: str = conversation_signature
89
  self.invocation_id: int = invocation_id
90
 
91
+ def update(
92
+ self,
93
+ prompt: str,
94
+ conversation_style: CONVERSATION_STYLE_TYPE,
95
+ options: Optional[list] = None,
96
+ ) -> None:
97
  """
98
  Updates request object
99
  """
 
105
  "enablemm",
106
  ]
107
  if conversation_style:
108
+ if not isinstance(conversation_style, ConversationStyle):
109
+ conversation_style = getattr(ConversationStyle, conversation_style)
110
  options = [
111
  "deepleo",
112
  "enable_debug_commands",
 
154
  }
155
  self.session = requests.Session()
156
  self.session.headers.update(
157
+ {
158
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/108.0.0.0 Safari/537.36"
159
+ }
160
+ )
161
  if cookies is not None:
162
  cookie_file = cookies
163
  else:
 
208
  )
209
 
210
  async def ask_stream(
211
+ self, prompt: str, conversation_style: CONVERSATION_STYLE_TYPE = None
212
  ) -> Generator[str, None, None]:
213
  """
214
  Ask a question to the bot
 
222
  )
223
  await self.__initial_handshake()
224
  # Construct a ChatHub request
225
+ self.request.update(prompt=prompt, conversation_style=conversation_style)
 
226
  # Send request
227
  await self.wss.send(append_identifier(self.request.struct))
228
  final = False
 
260
  def __init__(self, cookiePath: str = "", cookies: Optional[dict] = None) -> None:
261
  self.cookiePath: str = cookiePath
262
  self.cookies: dict | None = cookies
263
+ self.chat_hub: ChatHub = ChatHub(Conversation(self.cookiePath, self.cookies))
 
264
 
265
+ async def ask(
266
+ self, prompt: str, conversation_style: CONVERSATION_STYLE_TYPE = None
267
+ ) -> dict:
268
  """
269
  Ask a question to the bot
270
  """
271
+ async for final, response in self.chat_hub.ask_stream(
272
+ prompt=prompt, conversation_style=conversation_style
273
+ ):
274
  if final:
275
  return response
276
  self.chat_hub.wss.close()
277
 
278
+ async def ask_stream(
279
+ self, prompt: str, conversation_style: CONVERSATION_STYLE_TYPE = None
280
+ ) -> Generator[str, None, None]:
281
  """
282
  Ask a question to the bot
283
  """
284
+ async for response in self.chat_hub.ask_stream(
285
+ prompt=prompt, conversation_style=conversation_style
286
+ ):
287
  yield response
288
 
289
  async def close(self):
 
354
  print("Bot:")
355
  if args.no_stream:
356
  print(
357
+ (await bot.ask(prompt=prompt, conversation_style=args.style))["item"][
358
+ "messages"
359
+ ][1]["adaptiveCards"][0]["body"][0]["text"],
360
  )
361
  else:
362
  wrote = 0
363
+ async for final, response in bot.ask_stream(
364
+ prompt=prompt, conversation_style=args.style
365
+ ):
366
  if not final:
367
  print(response[wrote:], end="")
368
  wrote = len(response)
 
388
  parser = argparse.ArgumentParser()
389
  parser.add_argument("--enter-once", action="store_true")
390
  parser.add_argument("--no-stream", action="store_true")
391
+ parser.add_argument(
392
+ "--style", choices=["creative", "balanced", "precise"], default="balanced"
393
+ )
394
  parser.add_argument(
395
  "--cookie-file",
396
  type=str,
 
400
  args = parser.parse_args()
401
  os.environ["COOKIE_FILE"] = args.cookie_file
402
  args = parser.parse_args()
 
403
  asyncio.run(main())