Harry
commited on
Better implement for using Enum (#106)
Browse files- src/EdgeGPT.py +43 -19
src/EdgeGPT.py
CHANGED
@@ -9,13 +9,14 @@ import random
|
|
9 |
import asyncio
|
10 |
import argparse
|
11 |
from enum import Enum
|
12 |
-
from typing import Generator, Optional
|
13 |
|
14 |
import requests
|
15 |
import websockets.client as websockets
|
16 |
|
17 |
DELIMITER = "\x1e"
|
18 |
|
|
|
19 |
# Generate random IP between range 13.104.0.0/14
|
20 |
FORWARDED_IP = (
|
21 |
f"13.{random.randint(104, 107)}.{random.randint(0, 255)}.{random.randint(0, 255)}"
|
@@ -55,6 +56,11 @@ class ConversationStyle(Enum):
|
|
55 |
precise = "h3precise"
|
56 |
|
57 |
|
|
|
|
|
|
|
|
|
|
|
58 |
def append_identifier(msg: dict) -> str:
|
59 |
"""
|
60 |
Appends special character to end of message to identify end of message
|
@@ -82,7 +88,12 @@ class ChatHubRequest:
|
|
82 |
self.conversation_signature: str = conversation_signature
|
83 |
self.invocation_id: int = invocation_id
|
84 |
|
85 |
-
def update(
|
|
|
|
|
|
|
|
|
|
|
86 |
"""
|
87 |
Updates request object
|
88 |
"""
|
@@ -94,6 +105,8 @@ class ChatHubRequest:
|
|
94 |
"enablemm",
|
95 |
]
|
96 |
if conversation_style:
|
|
|
|
|
97 |
options = [
|
98 |
"deepleo",
|
99 |
"enable_debug_commands",
|
@@ -141,7 +154,10 @@ class Conversation:
|
|
141 |
}
|
142 |
self.session = requests.Session()
|
143 |
self.session.headers.update(
|
144 |
-
{
|
|
|
|
|
|
|
145 |
if cookies is not None:
|
146 |
cookie_file = cookies
|
147 |
else:
|
@@ -192,7 +208,7 @@ class ChatHub:
|
|
192 |
)
|
193 |
|
194 |
async def ask_stream(
|
195 |
-
self, prompt: str, conversation_style:
|
196 |
) -> Generator[str, None, None]:
|
197 |
"""
|
198 |
Ask a question to the bot
|
@@ -206,8 +222,7 @@ class ChatHub:
|
|
206 |
)
|
207 |
await self.__initial_handshake()
|
208 |
# Construct a ChatHub request
|
209 |
-
self.request.update(
|
210 |
-
prompt=prompt, conversation_style=conversation_style)
|
211 |
# Send request
|
212 |
await self.wss.send(append_identifier(self.request.struct))
|
213 |
final = False
|
@@ -245,23 +260,30 @@ class Chatbot:
|
|
245 |
def __init__(self, cookiePath: str = "", cookies: Optional[dict] = None) -> None:
|
246 |
self.cookiePath: str = cookiePath
|
247 |
self.cookies: dict | None = cookies
|
248 |
-
self.chat_hub: ChatHub = ChatHub(
|
249 |
-
Conversation(self.cookiePath, self.cookies))
|
250 |
|
251 |
-
async def ask(
|
|
|
|
|
252 |
"""
|
253 |
Ask a question to the bot
|
254 |
"""
|
255 |
-
async for final, response in self.chat_hub.ask_stream(
|
|
|
|
|
256 |
if final:
|
257 |
return response
|
258 |
self.chat_hub.wss.close()
|
259 |
|
260 |
-
async def ask_stream(
|
|
|
|
|
261 |
"""
|
262 |
Ask a question to the bot
|
263 |
"""
|
264 |
-
async for response in self.chat_hub.ask_stream(
|
|
|
|
|
265 |
yield response
|
266 |
|
267 |
async def close(self):
|
@@ -332,13 +354,15 @@ async def main():
|
|
332 |
print("Bot:")
|
333 |
if args.no_stream:
|
334 |
print(
|
335 |
-
(await bot.ask(prompt=prompt, conversation_style=args.style))["item"][
|
336 |
-
|
337 |
-
]["body"][0]["text"],
|
338 |
)
|
339 |
else:
|
340 |
wrote = 0
|
341 |
-
async for final, response in bot.ask_stream(
|
|
|
|
|
342 |
if not final:
|
343 |
print(response[wrote:], end="")
|
344 |
wrote = len(response)
|
@@ -364,8 +388,9 @@ if __name__ == "__main__":
|
|
364 |
parser = argparse.ArgumentParser()
|
365 |
parser.add_argument("--enter-once", action="store_true")
|
366 |
parser.add_argument("--no-stream", action="store_true")
|
367 |
-
parser.add_argument(
|
368 |
-
|
|
|
369 |
parser.add_argument(
|
370 |
"--cookie-file",
|
371 |
type=str,
|
@@ -375,5 +400,4 @@ if __name__ == "__main__":
|
|
375 |
args = parser.parse_args()
|
376 |
os.environ["COOKIE_FILE"] = args.cookie_file
|
377 |
args = parser.parse_args()
|
378 |
-
args.style = ConversationStyle[args.style]
|
379 |
asyncio.run(main())
|
|
|
9 |
import asyncio
|
10 |
import argparse
|
11 |
from enum import Enum
|
12 |
+
from typing import Generator, Optional, Union, Literal
|
13 |
|
14 |
import requests
|
15 |
import websockets.client as websockets
|
16 |
|
17 |
DELIMITER = "\x1e"
|
18 |
|
19 |
+
|
20 |
# Generate random IP between range 13.104.0.0/14
|
21 |
FORWARDED_IP = (
|
22 |
f"13.{random.randint(104, 107)}.{random.randint(0, 255)}.{random.randint(0, 255)}"
|
|
|
56 |
precise = "h3precise"
|
57 |
|
58 |
|
59 |
+
CONVERSATION_STYLE_TYPE = Optional[
|
60 |
+
Union[ConversationStyle, Literal["creative", "balanced", "precise"]]
|
61 |
+
]
|
62 |
+
|
63 |
+
|
64 |
def append_identifier(msg: dict) -> str:
|
65 |
"""
|
66 |
Appends special character to end of message to identify end of message
|
|
|
88 |
self.conversation_signature: str = conversation_signature
|
89 |
self.invocation_id: int = invocation_id
|
90 |
|
91 |
+
def update(
|
92 |
+
self,
|
93 |
+
prompt: str,
|
94 |
+
conversation_style: CONVERSATION_STYLE_TYPE,
|
95 |
+
options: Optional[list] = None,
|
96 |
+
) -> None:
|
97 |
"""
|
98 |
Updates request object
|
99 |
"""
|
|
|
105 |
"enablemm",
|
106 |
]
|
107 |
if conversation_style:
|
108 |
+
if not isinstance(conversation_style, ConversationStyle):
|
109 |
+
conversation_style = getattr(ConversationStyle, conversation_style)
|
110 |
options = [
|
111 |
"deepleo",
|
112 |
"enable_debug_commands",
|
|
|
154 |
}
|
155 |
self.session = requests.Session()
|
156 |
self.session.headers.update(
|
157 |
+
{
|
158 |
+
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/108.0.0.0 Safari/537.36"
|
159 |
+
}
|
160 |
+
)
|
161 |
if cookies is not None:
|
162 |
cookie_file = cookies
|
163 |
else:
|
|
|
208 |
)
|
209 |
|
210 |
async def ask_stream(
|
211 |
+
self, prompt: str, conversation_style: CONVERSATION_STYLE_TYPE = None
|
212 |
) -> Generator[str, None, None]:
|
213 |
"""
|
214 |
Ask a question to the bot
|
|
|
222 |
)
|
223 |
await self.__initial_handshake()
|
224 |
# Construct a ChatHub request
|
225 |
+
self.request.update(prompt=prompt, conversation_style=conversation_style)
|
|
|
226 |
# Send request
|
227 |
await self.wss.send(append_identifier(self.request.struct))
|
228 |
final = False
|
|
|
260 |
def __init__(self, cookiePath: str = "", cookies: Optional[dict] = None) -> None:
|
261 |
self.cookiePath: str = cookiePath
|
262 |
self.cookies: dict | None = cookies
|
263 |
+
self.chat_hub: ChatHub = ChatHub(Conversation(self.cookiePath, self.cookies))
|
|
|
264 |
|
265 |
+
async def ask(
|
266 |
+
self, prompt: str, conversation_style: CONVERSATION_STYLE_TYPE = None
|
267 |
+
) -> dict:
|
268 |
"""
|
269 |
Ask a question to the bot
|
270 |
"""
|
271 |
+
async for final, response in self.chat_hub.ask_stream(
|
272 |
+
prompt=prompt, conversation_style=conversation_style
|
273 |
+
):
|
274 |
if final:
|
275 |
return response
|
276 |
self.chat_hub.wss.close()
|
277 |
|
278 |
+
async def ask_stream(
|
279 |
+
self, prompt: str, conversation_style: CONVERSATION_STYLE_TYPE = None
|
280 |
+
) -> Generator[str, None, None]:
|
281 |
"""
|
282 |
Ask a question to the bot
|
283 |
"""
|
284 |
+
async for response in self.chat_hub.ask_stream(
|
285 |
+
prompt=prompt, conversation_style=conversation_style
|
286 |
+
):
|
287 |
yield response
|
288 |
|
289 |
async def close(self):
|
|
|
354 |
print("Bot:")
|
355 |
if args.no_stream:
|
356 |
print(
|
357 |
+
(await bot.ask(prompt=prompt, conversation_style=args.style))["item"][
|
358 |
+
"messages"
|
359 |
+
][1]["adaptiveCards"][0]["body"][0]["text"],
|
360 |
)
|
361 |
else:
|
362 |
wrote = 0
|
363 |
+
async for final, response in bot.ask_stream(
|
364 |
+
prompt=prompt, conversation_style=args.style
|
365 |
+
):
|
366 |
if not final:
|
367 |
print(response[wrote:], end="")
|
368 |
wrote = len(response)
|
|
|
388 |
parser = argparse.ArgumentParser()
|
389 |
parser.add_argument("--enter-once", action="store_true")
|
390 |
parser.add_argument("--no-stream", action="store_true")
|
391 |
+
parser.add_argument(
|
392 |
+
"--style", choices=["creative", "balanced", "precise"], default="balanced"
|
393 |
+
)
|
394 |
parser.add_argument(
|
395 |
"--cookie-file",
|
396 |
type=str,
|
|
|
400 |
args = parser.parse_args()
|
401 |
os.environ["COOKIE_FILE"] = args.cookie_file
|
402 |
args = parser.parse_args()
|
|
|
403 |
asyncio.run(main())
|