Spaces:
Runtime error
Runtime error
from http.server import BaseHTTPRequestHandler, HTTPServer | |
from typing import Optional | |
from phi.cli.settings import phi_cli_settings | |
class CliAuthRequestHandler(BaseHTTPRequestHandler):
    """Request handler that accepts the CLI auth token after the web-based auth flow.

    The browser POSTs the token payload to this local server. The raw request
    body is written to ``phi_cli_settings.tmp_token_path`` and the owning
    server's ``running`` flag is cleared so its request loop stops.

    References:
        https://medium.com/@hasinthaindrajee/browser-sso-for-cli-applications-b0be743fa656
        https://gist.github.com/mdonkers/63e115cc0c79b4f6b8b3a6b797e485c7

    TODO:
        * Fix the header and limit to only localhost or phidata.com
    """

    def _set_response(self):
        """Send a bodyless 200 response with permissive CORS headers."""
        self.send_response(200)
        self.send_header("Content-type", "application/json")
        # NOTE(review): "*" lets any web origin call this endpoint; see the
        # class-level TODO about restricting it.
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Access-Control-Allow-Headers", "*")
        self.send_header("Access-Control-Allow-Methods", "POST")
        self.end_headers()

    def do_OPTIONS(self):
        """Answer CORS preflight requests so the browser goes on to send the POST."""
        self._set_response()

    def do_POST(self):
        """Store the request body as the temporary token and stop the server loop.

        Requests with a missing, non-numeric or negative ``Content-Length``
        header, or with a body that is not valid UTF-8, are rejected with an
        error response. In those cases the server keeps running, so a malformed
        request cannot end the auth flow.
        """
        raw_length = self.headers.get("Content-Length")
        if raw_length is None:
            self.send_error(411, "Content-Length header is required")
            return
        try:
            content_length = int(raw_length)
        except ValueError:
            content_length = -1
        if content_length < 0:
            # A negative length would make rfile.read() read until EOF and block.
            self.send_error(400, "Invalid Content-Length header")
            return

        post_data = self.rfile.read(content_length)
        try:
            decoded_post_data = post_data.decode("utf-8")
        except UnicodeDecodeError:
            self.send_error(400, "Request body must be UTF-8")
            return

        # write_text() creates the file when needed; no separate touch() required.
        phi_cli_settings.tmp_token_path.write_text(decoded_post_data)
        # TODO: Add checks before shutting down the server
        # Clearing this flag ends the `while self._server.running` loop in CliAuthServer.run().
        self.server.running = False  # type: ignore
        self._set_response()

    def log_message(self, format, *args):
        """Suppress BaseHTTPRequestHandler's default per-request logging."""
        pass
class CliAuthServer:
    """Serve CliAuthRequestHandler requests until a handler clears ``running``.

    Source: https://stackoverflow.com/a/38196725/10953921
    """

    def __init__(self, port: int = 9191, host: str = "127.0.0.1", poll_interval: float = 0.5):
        """Create (but do not start) the server.

        Args:
            port: TCP port to listen on.
            host: Interface to bind. The default is loopback only, so other
                machines on the network cannot deliver a token. Pass ``""``
                to listen on all interfaces.
            poll_interval: Seconds ``handle_request()`` waits for a connection
                before ``run()`` re-checks the ``running`` flag.
        """
        import threading

        self._server = HTTPServer((host, port), CliAuthRequestHandler)
        # Without a timeout handle_request() blocks forever on an idle socket,
        # so shut_down() could never make run() observe running == False.
        self._server.timeout = poll_interval
        self._thread = threading.Thread(target=self.run)
        self._thread.daemon = True
        self._server.running = False  # type: ignore

    def run(self):
        """Handle requests until ``running`` is cleared, then close the listening socket."""
        self._server.running = True  # type: ignore
        try:
            while self._server.running:  # type: ignore
                self._server.handle_request()
        finally:
            # Release the port as soon as serving ends instead of waiting for GC
            # (self -> _thread -> self.run forms a reference cycle).
            self._server.server_close()

    def start(self):
        """Run the request loop on a background daemon thread."""
        self._thread.start()

    def shut_down(self):
        """Stop the request loop, wait for the background thread, and close the socket."""
        self._server.running = False  # type: ignore
        if self._thread.is_alive():
            # Returns within about poll_interval once any in-flight request finishes.
            self._thread.join()
        # Safe to call repeatedly; covers the case where run() never started.
        self._server.server_close()
def check_port(port: int):
    """Report whether something already accepts TCP connections on localhost:``port``.

    Returns True when a connection succeeds and False otherwise. If probing
    raises, the error is printed and the port is treated as not in use (False).
    """
    import socket

    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        try:
            status = probe.connect_ex(("localhost", port))
        except Exception as err:
            print(f"Error occurred: {err}")
            return False
        return status == 0
def get_port_for_auth_server(starting_port: int = 9191, max_attempts: int = 100) -> Optional[int]:
    """Return the first port not in use on localhost, scanning upward from ``starting_port``.

    Args:
        starting_port: First candidate port.
        max_attempts: Number of consecutive ports to try, i.e. the range
            ``[starting_port, starting_port + max_attempts)``.

    Returns:
        The first port for which ``check_port`` reports nothing listening, or
        None if every candidate is taken (or ``max_attempts`` <= 0).
    """
    for port in range(starting_port, starting_port + max_attempts):
        if not check_port(port):
            return port
    # Make the "no free port" outcome explicit rather than falling off the end.
    return None
def get_auth_token_from_web_flow(port: int) -> Optional[str]:
    """Block until the browser posts the auth token to a local server, then return it.

    Runs a CliAuthServer on ``port`` in the calling thread until a POST request
    has been handled. The handler stores the request body in
    ``phi_cli_settings.tmp_token_path``. This function reads that file, always
    deletes it, and returns the ``"AuthToken"`` value from the JSON payload.

    Manual testing (only POST delivers a token):
        curl -d '{"AuthToken": "<token>"}' http://localhost:9191

    Returns:
        The token, or None if no token file was written, the payload is not
        valid JSON, the payload is not a JSON object, or it has no
        ``"AuthToken"`` key.
    """
    import json

    server = CliAuthServer(port)
    server.run()

    token_path = phi_cli_settings.tmp_token_path
    if not token_path.is_file():
        return None
    try:
        auth_token_str = token_path.read_text()
    finally:
        # Remove the temporary credential file even if reading fails, so a
        # token is never left behind on disk.
        token_path.unlink()

    try:
        auth_token_json = json.loads(auth_token_str)
    except json.JSONDecodeError:
        return None
    if not isinstance(auth_token_json, dict):
        return None
    return auth_token_json.get("AuthToken", None)