ml-worker / app.py
inoki-giskard's picture
Allow to start and stop worker
6348ca6
raw
history blame
2.82 kB
import gradio as gr
from giskard.ml_worker.ml_worker import MLWorker
from pydantic import AnyHttpUrl
from giskard.settings import settings
from urllib.parse import urlparse
import asyncio
import threading
import sys
# Relative path (resolved against the current working directory) of the file
# that mirrors everything written to stdout; read back by read_logs().
LOG_FILE: str = "output.log"
class Logger:
    """Tee stream: echo every write to the original stdout and to a file.

    The file is opened in ``"w"`` mode, so any previous content is discarded
    when a Logger is created.  The file handle is never closed by this class.
    """

    def __init__(self, filename):
        # Capture whatever stdout is at construction time so output still
        # reaches the console after this object replaces ``sys.stdout``.
        self.terminal = sys.stdout
        self.log = open(filename, "w")

    def write(self, message):
        """Write *message* to both sinks.

        Returns the number of characters written, matching the contract of
        ``io.TextIOBase.write`` for callers that use the return value.
        """
        self.terminal.write(message)
        self.log.write(message)
        return len(message)

    def flush(self):
        """Flush the terminal stream and the log file."""
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        # Always report a non-interactive stream.  NOTE(review): presumably so
        # libraries do not emit terminal-only output (e.g. colour codes) into
        # the log file -- confirm.
        return False
# From here on, everything written to sys.stdout (process-wide, so including
# the worker thread) is echoed to the console and mirrored into LOG_FILE.
sys.stdout = Logger(filename=LOG_FILE)
def read_logs():
    """Return the full contents of LOG_FILE for the UI log panel.

    stdout is flushed first so that text still buffered by the stdout tee
    (when installed) is on disk before reading.  Undecodable bytes are
    replaced rather than raising: output may be written concurrently from
    the worker thread, which can leave a truncated multi-byte character at
    the end of the file while this polling callback reads it.  A missing
    log file yields an empty string instead of an exception.
    """
    sys.stdout.flush()
    try:
        with open(LOG_FILE, "r", errors="replace") as f:
            return f.read()
    except FileNotFoundError:
        return ""
# Module-level state shared by start_ml_worker() and stop_ml_worker():
# the backend URL the current worker was started for (used in status
# messages) and the worker instance itself (None until one is started).
previous_url, ml_worker = "", None
def run_ml_worker(ml_worker: "MLWorker"):
    """Run ``ml_worker.start()`` to completion on a fresh event loop.

    Used as the target of a background thread.  A new loop is created,
    installed as the calling thread's current loop, and always closed
    afterwards -- even when ``start()`` raises.  On failure the traceback is
    also printed to stdout (so it reaches the log file shown in the UI when
    the stdout tee is installed) before the exception is re-raised; uncaught
    thread exceptions otherwise only go to stderr.
    """
    import traceback

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(ml_worker.start())
    except Exception:
        traceback.print_exc(file=sys.stdout)
        raise
    finally:
        loop.close()
def stop_ml_worker():
    """Stop the running ML worker, if any, and return a status string for the UI.

    After a successful ``stop()`` the module-level ``ml_worker`` and
    ``previous_url`` are cleared.  A second call (pressing Stop twice, or
    start_ml_worker() stopping before a restart) therefore reports that no
    worker is running instead of calling ``stop()`` on an already-stopped
    worker again.
    """
    global ml_worker, previous_url
    if ml_worker is None:
        return "ML worker not started"
    print(f"Stopping ML worker for {previous_url}")
    ml_worker.stop()
    # Only forget the worker once stop() has returned; if it raised, the
    # reference is kept so the user can retry.
    ml_worker = None
    previous_url = ""
    print("ML worker stopped")
    return "ML worker stopped"
def _parse_hub_url(url, default_path):
    """Split a user-entered Hub URL into ``(scheme, host, port, path)``.

    Surrounding whitespace is ignored and a missing scheme defaults to
    ``http``.  An empty path or a bare ``/`` falls back to *default_path*.
    ``host`` is None when no host could be parsed and ``port`` is None when
    none was given.  Raises ValueError for a non-numeric or out-of-range port.
    """
    url = (url or "").strip()
    if "://" not in url:
        # Without a scheme, urlparse reads "host:port" as "scheme:path".
        url = f"http://{url}"
    parsed = urlparse(url)
    path = parsed.path if parsed.path not in ("", "/") else default_path
    return parsed.scheme or "http", parsed.hostname, parsed.port, path


def start_ml_worker(url, api_key, hf_token):
    """(Re)start the external ML worker against the Giskard Hub at *url*.

    The URL is validated first; if it has no usable host, a status message
    is returned and any running worker is left alone.  Otherwise the running
    worker (if any) is stopped, a new MLWorker is created and
    run_ml_worker() runs it on a daemon thread.  Returns a status string for
    the UI.
    """
    global ml_worker, previous_url
    try:
        scheme, host, port, path = _parse_hub_url(url, settings.ws_path)
    except ValueError:  # e.g. a non-numeric or out-of-range port
        host = None
    if not host:
        return f"Invalid Giskard Hub URL: {url!r}"

    # Always run an external ML worker
    stop_ml_worker()

    # urlparse strips the brackets from IPv6 literals; restore them for the URL string.
    netloc = f"[{host}]" if ":" in host else host
    if port is not None:
        netloc = f"{netloc}:{port}"
    backend_url = AnyHttpUrl(
        url=f"{scheme}://{netloc}/{path.lstrip('/')}",
        scheme=scheme,
        host=host,
        # NOTE(review): pydantic v1's AnyUrl stores the port as a string;
        # confirm MLWorker is happy with that representation.
        port=str(port) if port is not None else None,
        path=path,
    )
    print(f"Starting ML worker for {backend_url}")
    ml_worker = MLWorker(False, backend_url, api_key, hf_token)
    previous_url = backend_url
    # Daemon so a worker that is still running cannot keep the process alive
    # once the Gradio server has exited.
    thread = threading.Thread(target=run_ml_worker, args=(ml_worker,), daemon=True)
    thread.start()
    return f"ML worker running for {backend_url}"
# Gradio UI: connection inputs, Run/Stop buttons with a status box, and a log
# panel refreshed from read_logs() every 0.5 s.
with gr.Blocks() as iface:
    with gr.Row():
        with gr.Column():
            url_input = gr.Textbox(label="Giskard Hub URL")
            # Credentials are masked so they are not displayed in clear text.
            api_key_input = gr.Textbox(
                label="Giskard Hub API Key",
                placeholder="gsk-xxxxxxxxxxxxxxxxxxxxxxxxxxxx",
                type="password",
            )
            hf_token_input = gr.Textbox(label="Hugging Face Spaces Token", type="password")
            output = gr.Textbox(label="Status")
            with gr.Row():
                run_btn = gr.Button("Run")
                run_btn.click(start_ml_worker, [url_input, api_key_input, hf_token_input], output)
                stop_btn = gr.Button("Stop")
                stop_btn.click(stop_ml_worker, None, output)
    logs = gr.Textbox(label="Giskard ML worker log:")
    # Event listeners must be registered inside the Blocks context.
    iface.load(read_logs, None, logs, every=0.5)

# Periodic `every=` updates rely on the queue being enabled.
iface.queue()
iface.launch()