import argparse import asyncio import json import os import traceback import urllib.request from EdgeGPT import Chatbot from aiohttp import web public_dir = '/public' async def process_message(user_message, context, _U, locale): chatbot = None try: if _U: cookies = loaded_cookies + [{"name": "_U", "value": _U}] else: cookies = loaded_cookies chatbot = await Chatbot.create(cookies=cookies, proxy=args.proxy) async for _, response in chatbot.ask_stream(prompt=user_message, conversation_style="creative", raw=True, webpage_context=context, search_result=True, locale=locale): yield response except: yield {"type": "error", "error": traceback.format_exc()} finally: if chatbot: await chatbot.close() async def http_handler(request): file_path = request.path if file_path == "/": file_path = "/index.html" full_path = os.path.realpath('.' + public_dir + file_path) if not full_path.startswith(os.path.realpath('.' + public_dir)): raise web.HTTPForbidden() response = web.FileResponse(full_path) response.headers['Cache-Control'] = 'no-store' return response async def websocket_handler(request): ws = web.WebSocketResponse() await ws.prepare(request) async for msg in ws: if msg.type == web.WSMsgType.TEXT: request = json.loads(msg.data) user_message = request['message'] context = request['context'] locale = request['locale'] _U = request.get('_U') async for response in process_message(user_message, context, _U, locale=locale): await ws.send_json(response) return ws async def main(host, port): app = web.Application() app.router.add_get('/ws/', websocket_handler) app.router.add_get('/{tail:.*}', http_handler) runner = web.AppRunner(app) await runner.setup() site = web.TCPSite(runner, host, port) await site.start() print(f"Go to http://{host}:{port} to start chatting!") if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("--host", "-H", help="host:port for the server", default="localhost:65432") parser.add_argument("--proxy", "-p", help='proxy address like "http://localhost:7890"', default=urllib.request.getproxies().get('https')) args = parser.parse_args() print(f"Proxy used: {args.proxy}") host, port = args.host.split(":") port = int(port) if os.path.isfile("cookies.json"): with open("cookies.json", 'r') as f: loaded_cookies = json.load(f) print("Loaded cookies.json") else: loaded_cookies = [] print("cookies.json not found") loop = asyncio.get_event_loop() try: loop.run_until_complete(main(host, port)) loop.run_forever() except KeyboardInterrupt: pass finally: loop.close()