"""Gradio demo that sends a prompt to the generation server and shows the
baseline and optimized images next to each other."""

import base64
import hashlib
import os
import random
import time
from datetime import datetime, timedelta
from io import BytesIO

import gradio as gr
import requests
from PIL import Image

from wandb_data import TIMEZONE

SERVER_API = os.environ["SERVER_API"]
SERVER_API_KEY = os.environ["SERVER_API_KEY"]

# How long the result of a model lookup (successful or not) is reused.
MODEL_SYNC_INTERVAL = timedelta(minutes=5)
# Seconds to wait for the lightweight ``/model`` endpoint.
MODEL_REQUEST_TIMEOUT = 10
# (connect, read) seconds for ``/generate``; the read side is generous because
# the server only answers once the image has been generated.
GENERATE_REQUEST_TIMEOUT = (10, 600)
# Seeds sent to the server are kept in the range [0, 2**32).
SEED_MODULUS = 2 ** 32

current_model: str | None = None
last_current_model_sync: datetime = datetime.fromtimestamp(0, TIMEZONE)


def get_current_model() -> str | None:
    """Return a ``"<uid> - <url>"`` label for the model the server reports.

    The server is queried at most once per ``MODEL_SYNC_INTERVAL``; in between,
    the cached value is returned. A failed lookup is cached as ``None`` so the
    cache never claims the server is online right after a failure.

    Returns:
        The model label, or ``None`` when the server could not be reached or
        answered with an unexpected payload.
    """
    global current_model
    global last_current_model_sync

    now = datetime.now(TIMEZONE)
    if now - last_current_model_sync < MODEL_SYNC_INTERVAL:
        return current_model

    # Stamp the attempt before the request so failures are rate-limited too.
    last_current_model_sync = now

    try:
        response = requests.get(
            f"{SERVER_API}/model",
            timeout=MODEL_REQUEST_TIMEOUT,
        )
        response.raise_for_status()
        model = response.json()
        current_model = f"{model['uid']} - {model['url']}"
    except (requests.RequestException, ValueError, KeyError, TypeError) as error:
        print(f"Unable to connect to API: {error}")
        current_model = None

    return current_model


def image_from_base64(image_data: str) -> Image.Image:
    """Decode a base64-encoded image payload into a PIL image."""
    image_buffer = BytesIO(base64.b64decode(image_data))
    return Image.open(image_buffer)


def _seed_from_text(text: str) -> int | None:
    """Turn seed text into an integer seed in ``[0, SEED_MODULUS)``.

    Blank text yields ``None``. Text that parses as an integer is used as that
    number. Any other text is mapped through SHA-256, which is stable across
    processes, unlike the built-in ``hash`` of a ``str``.
    """
    text = text.strip()
    if not text:
        return None

    try:
        return int(text) % SEED_MODULUS
    except ValueError:
        digest = hashlib.sha256(text.encode("utf-8")).digest()
        return int.from_bytes(digest[:4], "big")


def submit(prompt: str, seed: int | str | None, baseline: bool) -> tuple:
    """Request one generated image from the server.

    Args:
        prompt: Text prompt forwarded to the server.
        seed: Seed for the generation. Text is converted with
            ``_seed_from_text``. A falsy seed (``None``, ``0`` or blank text)
            means "pick one automatically".
        baseline: Forwarded to the server as the ``baseline`` query parameter.

    Returns:
        ``(None, gr.Image)``: ``None`` for the first output component, and the
        received image labelled with its generation time.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an error status.
    """
    if isinstance(seed, str):
        seed = _seed_from_text(seed)

    if not seed:
        # A private generator gives the same value as seeding the global one
        # with the current second, without disturbing other users of `random`.
        # NOTE(review): presumably the second-granular seed lets the baseline
        # and optimized requests triggered by one submit share a seed - confirm.
        seed = random.Random(int(time.time())).randint(0, SEED_MODULUS - 1)

    print(f"Making request with prompt: {prompt}, seed: {seed}, baseline: {baseline}")

    response = requests.post(
        f"{SERVER_API}/generate",
        params={"prompt": prompt, "baseline": baseline, "seed": seed},
        headers={"X-API-Key": SERVER_API_KEY},
        timeout=GENERATE_REQUEST_TIMEOUT,
    )
    response.raise_for_status()

    result = response.json()
    generation_time = float(result["generation_time"])
    nsfw = result["nsfw"]
    image = image_from_base64(result["image"])

    print(f"Received image with generation time: {generation_time:.3f}s, NSFW: {nsfw}")

    return None, gr.Image(
        image,
        label=f"{generation_time:.3f}s",
        show_label=True,
    )


def create_demo():
    """Create the demo components and wire up their events.

    No ``gr.Blocks`` is created here, so this is presumably called by code
    that has already opened one - confirm against the caller.
    """
    offline_textbox = gr.Textbox(
        "The server is offline! Please come back later",
        interactive=False,
        show_label=False,
    )

    # NOTE(review): the nesting below was reconstructed from whitespace-collapsed
    # source; verify the layout against the rendered page.
    with gr.Group(visible=get_current_model() is not None):
        with gr.Group():
            with gr.Row():
                with gr.Column():
                    gr.Textbox("Baseline", interactive=False, show_label=False)
                    baseline_image_component = gr.Image(show_label=False)

                with gr.Column():
                    # Filled with the current model label on every page load.
                    textbox = gr.Textbox(interactive=False, show_label=False)
                    textbox.attach_load_event(lambda: get_current_model(), None)
                    optimized_image_component = gr.Image(show_label=False)

        with gr.Row():
            prompt = gr.Textbox(
                placeholder="Enter prompt...",
                interactive=True,
                submit_btn=True,
                show_label=False,
                autofocus=True,
                scale=10,
            )

            seed_input = gr.Textbox(
                placeholder="Enter seed...",
                interactive=True,
                show_label=False,
            )

    # On page load, show the offline notice or the inputs depending on whether
    # a model is currently known.
    offline_textbox.attach_load_event(lambda: gr.Textbox(visible=get_current_model() is None), None)
    prompt.attach_load_event(lambda: gr.Textbox(visible=get_current_model() is not None), None)
    seed_input.attach_load_event(lambda: gr.Textbox(visible=get_current_model() is not None), None)

    # One submit triggers two requests: one with baseline=True, one with baseline=False.
    prompt.submit(
        lambda prompt, seed: submit(prompt, seed, True),
        inputs=[prompt, seed_input],
        outputs=[prompt, baseline_image_component],
    )
    prompt.submit(
        lambda prompt, seed: submit(prompt, seed, False),
        inputs=[prompt, seed_input],
        outputs=[prompt, optimized_image_component],
    )