File size: 2,820 Bytes
df1aa82 66e8b15 4708837 e7aeb95 df1aa82 66e8b15 6348ca6 66e8b15 6348ca6 66e8b15 6348ca6 66e8b15 e7aeb95 18f1fae e7aeb95 66e8b15 df1aa82 6348ca6 df1aa82 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import gradio as gr
from giskard.ml_worker.ml_worker import MLWorker
from pydantic import AnyHttpUrl
from giskard.settings import settings
from urllib.parse import urlparse
import asyncio
import threading
import sys
# File (relative to the working directory) that receives a copy of everything
# written to sys.stdout once the Logger tee below is installed.
LOG_FILE: str = "output.log"
class Logger:
    """Stand-in for ``sys.stdout`` that tees every write to a log file.

    Each message goes both to the stream that was ``sys.stdout`` when the
    Logger was created and to ``filename``, which is truncated on open
    (mode ``"w"``).
    """

    def __init__(self, filename):
        # Remember the real stdout so output still reaches the console.
        self.terminal = sys.stdout
        self.log = open(filename, "w")

    def write(self, message):
        """Write *message* to the terminal and the log file.

        Returns:
            The number of characters written, as ``io.TextIOBase.write``
            does, so callers that rely on the file protocol keep working.
        """
        self.terminal.write(message)
        self.log.write(message)
        return len(message)

    def flush(self):
        """Flush both the terminal stream and the log file."""
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        # Always report "not a terminal" -- presumably so libraries do not
        # emit TTY-only output (colours, cursor control) into the log file.
        return False

    @property
    def encoding(self):
        """Text encoding, as reported by the wrapped terminal stream.

        Code commonly reads ``sys.stdout.encoding``; without this attribute
        such code would fail with AttributeError once the tee is installed.
        Falls back to the log file's encoding if the terminal has none.
        """
        return getattr(self.terminal, "encoding", None) or self.log.encoding

    @property
    def errors(self):
        """Encoding error handler, as reported by the wrapped terminal stream."""
        return getattr(self.terminal, "errors", None) or self.log.errors
# Install the tee: from here on, everything printed through sys.stdout is
# also appended to LOG_FILE, which read_logs() serves to the UI.
sys.stdout = Logger(filename=LOG_FILE)
def read_logs():
    """Return the whole contents of LOG_FILE for the UI log panel.

    ``sys.stdout`` is flushed first so text printed so far reaches the file.
    """
    sys.stdout.flush()
    # The ML worker runs on another thread and may be mid-write, so the file
    # can momentarily end inside a multi-byte character.  Replace undecodable
    # bytes rather than raising UnicodeDecodeError on the periodic refresh.
    with open(LOG_FILE, "r", errors="replace") as f:
        return f.read()
# Module-level worker state shared by start_ml_worker / stop_ml_worker:
#   ml_worker    -- the MLWorker instance, None until a worker is started.
#   previous_url -- backend URL of the last started worker ("" initially).
ml_worker, previous_url = None, ""
def run_ml_worker(ml_worker: MLWorker):
    """Run ``ml_worker.start()`` to completion on a fresh asyncio event loop.

    Used as a ``threading.Thread`` target, so it creates and installs its own
    loop for the current thread.  The loop is always closed, even when
    ``start()`` raises.
    """
    import traceback

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(ml_worker.start())
    except Exception:
        # An exception escaping a thread target is reported by
        # threading.excepthook on sys.stderr, which the stdout log tee does
        # not capture.  Echo it to stdout so it also shows in the UI log,
        # then re-raise to keep the default reporting.
        traceback.print_exc(file=sys.stdout)
        raise
    finally:
        loop.close()
def stop_ml_worker():
    """Stop the current ML worker, if any, and forget it.

    Returns:
        A status message for the UI: "ML worker stopped" when a worker was
        stopped, "ML worker not started" when there was none.
    """
    global ml_worker, previous_url
    if ml_worker is None:
        return "ML worker not started"

    print(f"Stopping ML worker for {previous_url}")
    try:
        ml_worker.stop()
    finally:
        # Drop the reference (even if stop() fails) so a repeated Stop click,
        # or the stop done by start_ml_worker, does not call stop() again on
        # the same already-stopped worker.
        ml_worker = None
    print("ML worker stopped")
    return "ML worker stopped"
def start_ml_worker(url, api_key, hf_token):
    """(Re)start an external ML worker connected to a Giskard Hub.

    The URL is validated before anything else happens, so a typo does not
    stop a worker that is already running.  A missing scheme defaults to
    ``http``; a missing (or bare ``/``) path defaults to ``settings.ws_path``.

    Args:
        url: Giskard Hub URL as entered in the UI.
        api_key: Giskard Hub API key, passed through to ``MLWorker``.
        hf_token: Hugging Face Spaces token, passed through to ``MLWorker``.

    Returns:
        A status message for the UI "Status" textbox.
    """
    global ml_worker, previous_url

    url = (url or "").strip()
    if not url:
        return "Please enter the Giskard Hub URL"
    # urlparse only fills `hostname` when a scheme is present: "hub.example"
    # lands in `path`, and "localhost:19000" is read as scheme "localhost".
    if "://" not in url:
        url = f"http://{url}"
    parsed_url = urlparse(url)
    hostname = parsed_url.hostname
    if not hostname:
        return f"Invalid Giskard Hub URL: {url}"
    try:
        port = parsed_url.port  # raises ValueError for a malformed port
    except ValueError:
        return f"Invalid port in Giskard Hub URL: {url}"

    scheme = parsed_url.scheme
    # A bare "/" (URL pasted with a trailing slash) counts as "no path".
    path = parsed_url.path if parsed_url.path.strip("/") else settings.ws_path
    # `hostname` excludes the port, so put it back explicitly.
    authority = hostname if port is None else f"{hostname}:{port}"

    backend_url = AnyHttpUrl(
        # Exactly one "/" between authority and path, whether or not `path`
        # already starts with one.
        url=f"{scheme}://{authority}/{path.lstrip('/')}",
        scheme=scheme,
        host=hostname,
        port=None if port is None else str(port),
        path=path,
    )

    stop_ml_worker()
    print(f"Starting ML worker for {backend_url}")
    # Always run an external ML worker: the first argument False is presumably
    # MLWorker's is_server flag -- confirm against the giskard version in use.
    ml_worker = MLWorker(False, backend_url, api_key, hf_token)
    previous_url = backend_url
    thread = threading.Thread(target=run_ml_worker, args=(ml_worker,))
    thread.start()
    return f"ML worker running for {backend_url}"
with gr.Blocks() as iface:
    # Layout: connection form and controls in one column, worker log beside it.
    with gr.Row():
        with gr.Column():
            url_input = gr.Textbox(label="Giskard Hub URL")
            api_key_input = gr.Textbox(
                label="Giskard Hub API Key",
                placeholder="gsk-xxxxxxxxxxxxxxxxxxxxxxxxxxxx",
            )
            hf_token_input = gr.Textbox(label="Hugging Face Spaces Token")
            output = gr.Textbox(label="Status")
            with gr.Row():
                run_btn = gr.Button("Run")
                stop_btn = gr.Button("Stop")
        logs = gr.Textbox(label="Giskard ML worker log:")

    # Event wiring; listeners must be registered inside the Blocks context.
    run_btn.click(
        fn=start_ml_worker,
        inputs=[url_input, api_key_input, hf_token_input],
        outputs=output,
    )
    stop_btn.click(fn=stop_ml_worker, inputs=None, outputs=output)
    # Re-read the log file every 0.5 s to refresh the log panel.
    iface.load(fn=read_logs, inputs=None, outputs=logs, every=0.5)

iface.queue()
iface.launch()
|