File size: 6,917 Bytes
0ad74ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
from __future__ import annotations
import os
import socket
import threading
import time
from functools import partial
from typing import TYPE_CHECKING
import uvicorn
from uvicorn.config import Config
from gradio.exceptions import ServerFailedToStartError
from gradio.routes import App
from gradio.utils import SourceFileReloader, watchfn
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
pass
# By default, the local server will try to open on localhost, port 7860.
# If that is not available, then it will try 7861, 7862, ... 7959.
INITIAL_PORT_VALUE = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
TRY_NUM_PORTS = int(os.getenv("GRADIO_NUM_PORTS", "100"))
LOCALHOST_NAME = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
should_watch = bool(os.getenv("GRADIO_WATCH_DIRS", ""))
GRADIO_WATCH_DIRS = (
os.getenv("GRADIO_WATCH_DIRS", "").split(",") if should_watch else []
)
GRADIO_WATCH_MODULE_NAME = os.getenv("GRADIO_WATCH_MODULE_NAME", "app")
GRADIO_WATCH_DEMO_NAME = os.getenv("GRADIO_WATCH_DEMO_NAME", "demo")
GRADIO_WATCH_DEMO_PATH = os.getenv("GRADIO_WATCH_DEMO_PATH", "")
class Server(uvicorn.Server):
def __init__(
self, config: Config, reloader: SourceFileReloader | None = None
) -> None:
self.running_app = config.app
super().__init__(config)
self.reloader = reloader
if self.reloader:
self.event = threading.Event()
self.watch = partial(watchfn, self.reloader)
def install_signal_handlers(self):
pass
def run_in_thread(self):
self.thread = threading.Thread(target=self.run, daemon=True)
if self.reloader:
self.watch_thread = threading.Thread(target=self.watch, daemon=True)
self.watch_thread.start()
self.thread.start()
start = time.time()
while not self.started:
time.sleep(1e-3)
if time.time() - start > 5:
raise ServerFailedToStartError(
"Server failed to start. Please check that the port is available."
)
def close(self):
self.should_exit = True
if self.reloader:
self.reloader.stop()
self.watch_thread.join()
self.thread.join(timeout=5)
def start_server(
app: App,
server_name: str | None = None,
server_port: int | None = None,
ssl_keyfile: str | None = None,
ssl_certfile: str | None = None,
ssl_keyfile_password: str | None = None,
) -> tuple[str, int, str, Server]:
"""Launches a local server running the provided Interface
Parameters:
app: the FastAPI app object to run
server_name: to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME.
server_port: will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT.
auth: If provided, username and password (or list of username-password tuples) required to access the Blocks. Can also provide function that takes username and password and returns True if valid login.
ssl_keyfile: If a path to a file is provided, will use this as the private key file to create a local server running on https.
ssl_certfile: If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
ssl_keyfile_password: If a password is provided, will use this with the ssl certificate for https.
Returns:
server_name: the name of the server (default is "localhost")
port: the port number the server is running on
path_to_local_server: the complete address that the local server can be accessed at
server: the server object that is a subclass of uvicorn.Server (used to close the server)
"""
if ssl_keyfile is not None and ssl_certfile is None:
raise ValueError("ssl_certfile must be provided if ssl_keyfile is provided.")
server_name = server_name or LOCALHOST_NAME
url_host_name = "localhost" if server_name == "0.0.0.0" else server_name
# Strip IPv6 brackets from the address if they exist.
# This is needed as http://[::1]:port/ is a valid browser address,
# but not a valid IPv6 address, so asyncio will throw an exception.
if server_name.startswith("[") and server_name.endswith("]"):
host = server_name[1:-1]
else:
host = server_name
server_ports = (
[server_port]
if server_port is not None
else range(INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS)
)
for port in server_ports:
try:
# The fastest way to check if a port is available is to try to bind to it with socket.
# If the port is not available, socket will throw an OSError.
s = socket.socket()
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# Really, we should be checking if (server_name, server_port) is available, but
# socket.bind() doesn't seem to throw an OSError with ipv6 addresses, based on my testing.
# Instead, we just check if the port is available on localhost.
s.bind((LOCALHOST_NAME, port))
s.close()
# To avoid race conditions, so we also check if the port by trying to start the uvicorn server.
# If the port is not available, this will throw a ServerFailedToStartError.
config = uvicorn.Config(
app=app,
port=port,
host=host,
log_level="warning",
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
ssl_keyfile_password=ssl_keyfile_password,
)
reloader = None
if GRADIO_WATCH_DIRS:
reloader = SourceFileReloader(
app=app,
watch_dirs=GRADIO_WATCH_DIRS,
watch_module_name=GRADIO_WATCH_MODULE_NAME,
demo_name=GRADIO_WATCH_DEMO_NAME,
stop_event=threading.Event(),
demo_file=GRADIO_WATCH_DEMO_PATH,
)
server = Server(config=config, reloader=reloader)
server.run_in_thread()
break
except (OSError, ServerFailedToStartError):
pass
else:
raise OSError(
f"Cannot find empty port in range: {min(server_ports)}-{max(server_ports)}. You can specify a different port by setting the GRADIO_SERVER_PORT environment variable or passing the `server_port` parameter to `launch()`."
)
if ssl_keyfile is not None:
path_to_local_server = f"https://{url_host_name}:{port}/"
else:
path_to_local_server = f"http://{url_host_name}:{port}/"
return server_name, port, path_to_local_server, server
|