|
import gradio as gr |
|
|
|
from giskard.ml_worker.ml_worker import MLWorker |
|
from pydantic import AnyHttpUrl |
|
from giskard.settings import settings |
|
from urllib.parse import urlparse |
|
|
|
import asyncio |
|
import threading |
|
|
|
import sys |
|
|
|
# Path of the file that receives a copy of everything written to stdout
# (see the Logger class); read_logs() reads it back to display in the UI.
LOG_FILE: str = "output.log"
|
|
|
class Logger:
    """File-like object that tees every write to the original stdout and to a log file.

    It is meant to be installed as ``sys.stdout`` so that printed output is both
    shown in the terminal and persisted to *filename*, where it can be read back.

    Args:
        filename: Path of the log file. It is opened in "w" mode, so any existing
            content is truncated when the Logger is created.
    """

    def __init__(self, filename):
        # Remember the stream that was stdout when this Logger was created; writes
        # are forwarded to it so terminal output keeps working.
        self.terminal = sys.stdout
        self.log = open(filename, "w")

    def write(self, message):
        """Write *message* to both the terminal and the log file.

        Returns:
            The number of characters written, as ``io.TextIOBase.write`` does.
        """
        self.terminal.write(message)
        self.log.write(message)
        return len(message)

    def flush(self):
        """Flush both underlying streams so the log file is up to date on disk."""
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        # Always report "not a terminal" — presumably so libraries don't emit
        # colour/progress control sequences into the log file.
        return False

    def __getattr__(self, name):
        """Delegate any other stream attribute (``encoding``, ``errors``, ``fileno``...) to the terminal.

        Code that replaces ``sys.stdout`` must still look like a text stream:
        libraries commonly read e.g. ``sys.stdout.encoding`` and would otherwise
        get an AttributeError from this object.
        """
        # __getattr__ only runs when normal lookup fails. Read __dict__ directly so
        # that a half-initialised instance (no ``terminal`` yet) raises cleanly
        # instead of recursing into __getattr__("terminal") forever.
        terminal = self.__dict__.get("terminal")
        if terminal is None:
            raise AttributeError(name)
        return getattr(terminal, name)
|
|
|
# Redirect stdout through Logger: from here on, every print() in this process
# (including from the worker thread) goes both to the real terminal and to LOG_FILE.
# NOTE(review): only stdout is redirected. Output written to stderr — e.g. by
# logging's default StreamHandler or by uncaught exceptions in threads — will not
# reach LOG_FILE and so won't show in the UI log; confirm this is intended.
sys.stdout = Logger(LOG_FILE)
|
|
|
|
|
def read_logs():
    """Return the whole content of LOG_FILE as a string.

    stdout is flushed first so that anything still buffered in the tee'd
    log file is on disk before it is read.
    """
    sys.stdout.flush()
    with open(LOG_FILE) as log_file:
        contents = log_file.read()
    return contents
|
|
|
|
|
# Module-level state shared by the start/stop callbacks.
# URL the current worker was started for; in the code shown here it is only used
# in the "Stopping ML worker for ..." message.
previous_url: "AnyHttpUrl | str" = ""
# Handle to the most recently started MLWorker; None until a worker is started.
ml_worker: "MLWorker | None" = None
|
|
|
|
|
def run_ml_worker(ml_worker: "MLWorker"):
    """Run ``ml_worker.start()`` to completion on a brand-new event loop.

    Intended as the target of a dedicated thread: the loop is created for, and
    installed as the current loop of, the calling thread only.

    The loop is always closed, even when ``start()`` raises; the exception is
    then propagated to the caller (for a thread, to ``threading.excepthook``).
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(ml_worker.start())
    finally:
        # Don't leave a closed loop registered as this thread's current loop.
        asyncio.set_event_loop(None)
        loop.close()
|
|
|
def stop_ml_worker():
    """Stop the currently running ML worker, if there is one.

    After a successful ``stop()`` the global handle is cleared. That way a second
    Stop click — or the unconditional stop done before every start — does not call
    ``stop()`` again on a worker that is already stopped.

    Returns:
        A status message for the UI.
    """
    global ml_worker, previous_url

    if ml_worker is None:
        return "ML worker not started"

    print(f"Stopping ML worker for {previous_url}")
    ml_worker.stop()
    # Only forget the worker once stop() succeeded, so a failed stop can be retried.
    ml_worker = None
    print("ML worker stopped")
    return "ML worker stopped"
|
|
|
|
|
def _build_backend_url(url):
    """Turn the user-entered Hub URL into the AnyHttpUrl handed to MLWorker.

    - Surrounding whitespace is ignored.
    - A missing scheme defaults to ``http``.
    - The port is kept.
    - An empty or bare ``/`` path falls back to ``settings.ws_path``.

    Returns:
        The AnyHttpUrl, or None when no host can be extracted from *url*.

    Raises:
        ValueError: if the port part is not a valid port number (from ``urlparse().port``).
    """
    url = (url or "").strip()
    # urlparse only fills in hostname/port when the URL has a "//" authority part:
    # "hub.example.com" would be parsed as a path and "localhost:19000" as a bogus
    # scheme, so assume http:// when no scheme separator is present.
    if "://" not in url:
        url = f"http://{url}"
    parsed_url = urlparse(url)
    if not parsed_url.hostname:
        return None

    scheme = parsed_url.scheme or "http"
    port = parsed_url.port
    # A bare "/" (e.g. a URL pasted with a trailing slash) means "no explicit path".
    path = parsed_url.path if parsed_url.path not in ("", "/") else settings.ws_path

    authority = parsed_url.hostname if port is None else f"{parsed_url.hostname}:{port}"
    # Join host and path with exactly one slash (the old code produced "//path").
    url_path = path if path.startswith("/") else f"/{path}"
    return AnyHttpUrl(
        url=f"{scheme}://{authority}{url_path}",
        scheme=scheme,
        host=parsed_url.hostname,
        port=None if port is None else str(port),
        path=path,
    )


def start_ml_worker(url, api_key, hf_token):
    """(Re)start the ML worker against the given Giskard Hub in a background thread.

    The URL is validated first, so a typo does not kill a worker that is
    already running. Then any running worker is stopped, and a new MLWorker is
    started on its own thread/event loop.

    Args:
        url: Giskard Hub URL as typed by the user (scheme and path optional).
        api_key: Giskard Hub API key.
        hf_token: Hugging Face Spaces token.

    Returns:
        A status message for the UI.
    """
    global ml_worker, previous_url

    try:
        backend_url = _build_backend_url(url)
    except ValueError as err:
        return f"Invalid Giskard Hub URL {url!r}: {err}"
    if backend_url is None:
        return f"Invalid Giskard Hub URL {url!r}: no host found"

    stop_ml_worker()

    print(f"Starting ML worker for {backend_url}")
    ml_worker = MLWorker(False, backend_url, api_key, hf_token)
    previous_url = backend_url
    thread = threading.Thread(target=run_ml_worker, args=(ml_worker,))
    thread.start()
    return f"ML worker running for {backend_url}"
|
|
|
|
|
# Build the Gradio UI: connection form + Run/Stop buttons on the left, and the
# worker log (refreshed every 0.5 s from LOG_FILE via read_logs) next to it.
with gr.Blocks() as iface:
    with gr.Row():
        with gr.Column():
            url_input = gr.Textbox(label="Giskard Hub URL")
            # Secrets are masked so they are not exposed on screen / in screenshots.
            api_key_input = gr.Textbox(
                label="Giskard Hub API Key",
                placeholder="gsk-xxxxxxxxxxxxxxxxxxxxxxxxxxxx",
                type="password",
            )
            hf_token_input = gr.Textbox(label="Hugging Face Spaces Token", type="password")

            output = gr.Textbox(label="Status")
            with gr.Row():
                run_btn = gr.Button("Run")
                run_btn.click(start_ml_worker, [url_input, api_key_input, hf_token_input], output)

                stop_btn = gr.Button("Stop")
                stop_btn.click(stop_ml_worker, None, output)

        logs = gr.Textbox(label="Giskard ML worker log:")
        # Poll the log file periodically; `every=` requires the queue enabled below.
        iface.load(read_logs, None, logs, every=0.5)

iface.queue()
iface.launch()
|
|