Spaces:
Running
Running
import base64 | |
import os | |
import random | |
import time | |
from datetime import datetime, timedelta | |
from io import BytesIO | |
import gradio as gr | |
import requests | |
from PIL import Image | |
from wandb_data import TIMEZONE | |
# Connection settings for the generation server; a missing variable raises
# KeyError as soon as this module is imported.
SERVER_API: str = os.environ["SERVER_API"]
SERVER_API_KEY: str = os.environ["SERVER_API_KEY"]

# Cache state for get_current_model(): the last model description fetched and
# the moment it was fetched. Starting at the Unix epoch guarantees the very
# first call is outside the refresh window and hits the API.
current_model: str | None = None
last_current_model_sync: datetime = datetime.fromtimestamp(0, tz=TIMEZONE)
def get_current_model() -> str | None:
    """Return ``"<uid> - <url>"`` for the model the server reports, or None.

    The result is cached in the module globals ``current_model`` and
    ``last_current_model_sync`` and only refreshed from ``{SERVER_API}/model``
    once at least five minutes have passed since the last attempt.

    Returns:
        The model description, or None when the API could not be reached or
        answered with an unusable payload. A failure is cached like a success,
        so callers see a consistent "offline" answer until the next refresh.
    """
    global current_model
    global last_current_model_sync

    now = datetime.now(TIMEZONE)
    if now - last_current_model_sync < timedelta(minutes=5):
        return current_model

    # Record the attempt up front so a failing API is not re-polled on every call.
    last_current_model_sync = now
    try:
        # Bounded wait: this runs inside page-load handlers, which must not hang.
        response = requests.get(f"{SERVER_API}/model", timeout=10)
        response.raise_for_status()
        model = response.json()
        current_model = f"{model['uid']} - {model['url']}"
    except (requests.RequestException, ValueError, KeyError, TypeError):
        # RequestException: connection/timeout/HTTP status errors.
        # ValueError: body is not JSON. KeyError/TypeError: JSON lacks the
        # expected 'uid'/'url' mapping.
        print("Unable to connect to API")
        # Clear the stale value; otherwise calls within the next five minutes
        # would keep reporting the previous model as if the server were up.
        current_model = None
    return current_model
def image_from_base64(image_data: str | bytes) -> Image.Image:
    """Decode a base64-encoded image file into a PIL image.

    Args:
        image_data: Base64 text (or bytes) of an encoded image, e.g. the
            ``"image"`` field of the ``/generate`` response.

    Returns:
        The opened ``PIL.Image.Image``. ``Image.open`` is lazy, so pixel data
        is only decoded when first accessed.
    """
    # Annotated as Image.Image: `Image` alone is the PIL module, not a type.
    image_buffer = BytesIO(base64.b64decode(image_data))
    return Image.open(image_buffer)
def _resolve_seed(seed: int | str | None) -> int | None:
    """Map the user's seed input to an int in ``[0, 2**32)``, or None if absent.

    None and blank strings mean "no seed". A string holding an integer uses
    that integer; any other text is hashed with CRC-32. The builtin ``hash()``
    of a str is salted per interpreter process, so it would turn the same text
    into a different seed after every restart; CRC-32 is stable.
    """
    if seed is None:
        return None
    if isinstance(seed, str):
        text = seed.strip()
        if not text:
            return None
        try:
            return int(text) % 2 ** 32
        except ValueError:
            import zlib

            return zlib.crc32(text.encode("utf-8"))
    return seed % 2 ** 32


def submit(prompt: str, seed: int | str | None, baseline: bool) -> tuple:
    """Request one generated image from the server and wrap it for display.

    Args:
        prompt: Text prompt sent to ``{SERVER_API}/generate``.
        seed: Seed as entered by the user (the UI passes the seed textbox
            string). Integers/numeric strings are used modulo 2**32; other
            text is hashed deterministically; None or blank picks a random seed.
        baseline: Forwarded as the ``baseline`` query parameter.

    Returns:
        ``(None, gr.Image)``: None clears the prompt textbox; the image is
        labelled with the server-reported generation time.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        KeyError: if the response JSON lacks an expected field.
    """
    seed = _resolve_seed(seed)
    if seed is None:
        # Reseeding from the wall clock (1 s resolution) makes calls within the
        # same second draw the same seed. One prompt submission triggers this
        # function twice (baseline and optimized), so this presumably keeps both
        # images on the same seed. NOTE(review): confirm; calls that straddle a
        # second boundary will still diverge.
        random.seed(int(time.time()))
        seed = random.randint(0, 2 ** 32 - 1)

    print(f"Making request with prompt: {prompt}, seed: {seed}, baseline: {baseline}")

    response = requests.post(
        f"{SERVER_API}/generate",
        params={"prompt": prompt, "baseline": baseline, "seed": seed},
        headers={"X-API-Key": SERVER_API_KEY},
    )
    response.raise_for_status()

    result = response.json()
    generation_time = float(result["generation_time"])
    nsfw = result["nsfw"]
    image = image_from_base64(result["image"])

    print(f"Received image with generation time: {generation_time:.3f}s, NSFW: {nsfw}")

    return None, gr.Image(
        image,
        label=f"{generation_time:.3f}s",
        show_label=True,
    )
def create_demo():
    """Lay out the baseline-vs-optimized comparison UI and wire its events.

    Presumably invoked inside an enclosing ``gr.Blocks`` context defined
    elsewhere (not visible here). Components are created in display order, so
    the statement order below is the page layout.
    """

    def _server_online() -> bool:
        # True whenever get_current_model() reports a model description.
        return get_current_model() is not None

    offline_notice = gr.Textbox("The server is offline! Please come back later", interactive=False, show_label=False)

    # NOTE(review): this visibility is decided once, when the UI is built. The
    # page-load events further down toggle only the notice and the two input
    # boxes, not this group. Confirm that staying hidden after an offline
    # start-up is intended.
    with gr.Group(visible=_server_online()):
        with gr.Group():
            with gr.Row():
                with gr.Column():
                    gr.Textbox("Baseline", interactive=False, show_label=False)
                    baseline_image = gr.Image(show_label=False)
                with gr.Column():
                    model_label = gr.Textbox(interactive=False, show_label=False)
                    # Filled with get_current_model()'s result on each page load.
                    model_label.attach_load_event(lambda: get_current_model(), None)
                    optimized_image = gr.Image(show_label=False)
            with gr.Row():
                prompt_box = gr.Textbox(
                    placeholder="Enter prompt...",
                    interactive=True,
                    submit_btn=True,
                    show_label=False,
                    autofocus=True,
                    scale=10,
                )
                seed_box = gr.Textbox(
                    placeholder="Enter seed...",
                    interactive=True,
                    show_label=False,
                )

    # On every page load: show the notice when offline, the inputs when online.
    offline_notice.attach_load_event(lambda: gr.Textbox(visible=not _server_online()), None)
    prompt_box.attach_load_event(lambda: gr.Textbox(visible=_server_online()), None)
    seed_box.attach_load_event(lambda: gr.Textbox(visible=_server_online()), None)

    # A single prompt submission fires two independent generation requests;
    # each clears the prompt and fills its own image slot.
    prompt_box.submit(
        lambda text, seed_text: submit(text, seed_text, True),
        inputs=[prompt_box, seed_box],
        outputs=[prompt_box, baseline_image],
    )
    prompt_box.submit(
        lambda text, seed_text: submit(text, seed_text, False),
        inputs=[prompt_box, seed_box],
        outputs=[prompt_box, optimized_image],
    )