# src/model_demo.py
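"""Gradio demo for the edge-maxxing dashboard.

Requests image generations from the inference server configured via the
SERVER_API / SERVER_API_KEY environment variables and shows the baseline
model's output next to the currently benchmarked (optimized) model's output
for the same prompt and seed.
"""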
import base64
import os
import random
import time
from datetime import datetime, timedelta
from io import BytesIO

import gradio as gr
import requests
from PIL import Image

from wandb_data import TIMEZONE

SERVER_API = os.environ["SERVER_API"]
SERVER_API_KEY = os.environ["SERVER_API_KEY"]

current_model: str | None = None
last_current_model_sync: datetime = datetime.fromtimestamp(0, TIMEZONE)


def get_current_model() -> str | None:
    """Return the current model as "<uid> - <url>", or None if the server is unreachable."""
    global current_model
    global last_current_model_sync

    # Refresh from the server at most once every 5 minutes; otherwise serve the cached value.
    now = datetime.now(TIMEZONE)
    if now - last_current_model_sync < timedelta(minutes=5):
        return current_model
    last_current_model_sync = now

    try:
        response = requests.get(f"{SERVER_API}/model")
        response.raise_for_status()
        model = response.json()
        current_model = f"{model['uid']} - {model['url']}"
        return current_model
    except Exception:
        # Broad catch: a server outage or malformed response should never break the dashboard.
        print("Unable to connect to API")
        return None


def image_from_base64(image_data: str) -> Image.Image:
    # Decode a base64-encoded image payload into a PIL image.
    image_buffer = BytesIO(base64.b64decode(image_data))
    return Image.open(image_buffer)


def submit(prompt: str, seed: int | str | None, baseline: bool) -> tuple:
    # Non-numeric seeds are hashed down to the 32-bit range the server expects;
    # an empty or zero seed is replaced with a random one.
    if isinstance(seed, str):
        seed = hash(seed) % 2**32
    if not seed:
        random.seed(int(time.time()))
        seed = random.randint(0, 2**32 - 1)

    print(f"Making request with prompt: {prompt}, seed: {seed}, baseline: {baseline}")
    response = requests.post(
        f"{SERVER_API}/generate",
        params={"prompt": prompt, "baseline": baseline, "seed": seed},
        headers={"X-API-Key": SERVER_API_KEY},
    )
    response.raise_for_status()

    result = response.json()
    generation_time = float(result["generation_time"])
    nsfw = result["nsfw"]
    image = image_from_base64(result["image"])
    print(f"Received image with generation time: {generation_time:.3f}s, NSFW: {nsfw}")

    # Clear the prompt box and show the generated image labeled with its generation time.
    return None, gr.Image(
        image,
        label=f"{generation_time:.3f}s",
        show_label=True,
    )


def create_demo():
    # Shown instead of the demo inputs when the server cannot be reached.
    offline_textbox = gr.Textbox("The server is offline! Please come back later", interactive=False, show_label=False)

    with gr.Group(visible=get_current_model() is not None):
        with gr.Group():
            with gr.Row():
                with gr.Column():
                    gr.Textbox("Baseline", interactive=False, show_label=False)
                    baseline_image_component = gr.Image(show_label=False)
                with gr.Column():
                    textbox = gr.Textbox(interactive=False, show_label=False)
                    textbox.attach_load_event(lambda: get_current_model(), None)
                    optimized_image_component = gr.Image(show_label=False)
        with gr.Row():
            prompt = gr.Textbox(
                placeholder="Enter prompt...",
                interactive=True,
                submit_btn=True,
                show_label=False,
                autofocus=True,
                scale=10,
            )
            seed_input = gr.Textbox(
                placeholder="Enter seed...",
                interactive=True,
                show_label=False,
            )

    # Toggle between the offline notice and the demo inputs on page load.
    offline_textbox.attach_load_event(lambda: gr.Textbox(visible=get_current_model() is None), None)
    prompt.attach_load_event(lambda: gr.Textbox(visible=get_current_model() is not None), None)
    seed_input.attach_load_event(lambda: gr.Textbox(visible=get_current_model() is not None), None)

    # Each submission renders both the baseline and the optimized output for the same prompt/seed.
    prompt.submit(
        lambda prompt, seed: submit(prompt, seed, True),
        inputs=[prompt, seed_input],
        outputs=[prompt, baseline_image_component],
    )
    prompt.submit(
        lambda prompt, seed: submit(prompt, seed, False),
        inputs=[prompt, seed_input],
        outputs=[prompt, optimized_image_component],
    )
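

# Hypothetical usage sketch (not part of the original file): the dashboard app is
# assumed to call create_demo() from within its own gr.Blocks layout, roughly:
#
#     with gr.Blocks() as app:
#         create_demo()
#     app.launch()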