Spaces:
Running
Running
File size: 3,773 Bytes
6d46915 dfdce18 6d46915 dfdce18 6d46915 eef5820 6d46915 92b542d dfdce18 8e943a0 92b542d 8e943a0 dfdce18 8e943a0 dfdce18 8e943a0 dfdce18 6d46915 eef5820 6d46915 8e943a0 6d46915 8e943a0 6d46915 92b542d 6d46915 92b542d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import base64
import os
import random
import time
from datetime import datetime, timedelta
from io import BytesIO
import gradio as gr
import requests
from PIL import Image
from wandb_data import TIMEZONE
# Base URL of the generation server and the key sent in the "X-API-Key" header.
# Both are read with os.environ[...], so a missing variable raises KeyError at import time.
SERVER_API = os.environ["SERVER_API"]
SERVER_API_KEY = os.environ["SERVER_API_KEY"]
# Cached "<uid> - <url>" description of the served model; None until a refresh succeeds.
current_model: str | None = None
# Time of the last refresh attempt; starts at the epoch so the first call always queries the server.
last_current_model_sync: datetime = datetime.fromtimestamp(0, TIMEZONE)
def get_current_model() -> str | None:
    """Return a description of the model currently served, or ``None``.

    The description has the form ``"<uid> - <url>"`` and is built from the
    JSON returned by ``GET {SERVER_API}/model``. The server is queried at
    most once every five minutes; calls inside that window return the result
    of the most recent attempt, which is ``None`` when that attempt failed.

    Returns:
        The cached model description, or ``None`` if the server could not be
        reached or returned an unusable response.
    """
    global current_model
    global last_current_model_sync

    now = datetime.now(TIMEZONE)

    if now - last_current_model_sync < timedelta(minutes=5):
        return current_model

    # Stamp the attempt before the request so that a failing server is not
    # re-queried on every call during the next five minutes.
    last_current_model_sync = now

    try:
        # A timeout keeps an unresponsive server from blocking the caller forever.
        response = requests.get(f"{SERVER_API}/model", timeout=10)
        response.raise_for_status()
        model = response.json()
        current_model = f"{model['uid']} - {model['url']}"
    except (requests.RequestException, ValueError, KeyError, TypeError):
        print("Unable to connect to API")
        # Drop the stale value: otherwise cached calls would keep reporting
        # the previous model while the server is actually unavailable.
        current_model = None

    return current_model
def image_from_base64(image_data: str) -> Image.Image:
    """Decode a base64-encoded image payload into a PIL image.

    Args:
        image_data: Base64 text containing the encoded image file bytes.

    Returns:
        The image opened from an in-memory buffer holding the decoded bytes.
    """
    # ``Image`` is the PIL module here, so the image class is ``Image.Image``.
    image_buffer = BytesIO(base64.b64decode(image_data))
    return Image.open(image_buffer)
def submit(prompt: str, seed: int | str | None, baseline: bool) -> tuple:
    """Request one generated image from the server and wrap it for display.

    Args:
        prompt: Text prompt forwarded to ``POST {SERVER_API}/generate``.
        seed: Generation seed. An integer is used as given. Text holding an
            integer is parsed (modulo 2**32); any other non-empty text is
            mapped to a stable 32-bit value. ``None`` or blank text selects a
            time-derived seed.
        baseline: Forwarded to the server as the ``baseline`` query parameter.

    Returns:
        A ``(None, gr.Image)`` pair; the image is labelled with the reported
        generation time in seconds.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        KeyError: If the response JSON lacks an expected field.
    """
    import zlib  # local import: only needed to hash free-text seeds

    if isinstance(seed, str):
        seed_text = seed.strip()
        if not seed_text:
            seed = None
        else:
            try:
                seed = int(seed_text) % 2 ** 32
            except ValueError:
                # The built-in hash() of a str is salted per process, so it
                # would map the same text to a different seed after a restart.
                seed = zlib.crc32(seed_text.encode("utf-8"))

    if seed is None:
        # Seeding from the current second means calls made within the same
        # second derive the same seed. A private generator yields the same
        # value as reseeding the global one, without disturbing its state.
        seed = random.Random(int(time.time())).randint(0, 2 ** 32 - 1)

    print(f"Making request with prompt: {prompt}, seed: {seed}, baseline: {baseline}")

    # NOTE(review): no timeout is set on this request; an upper bound for the
    # server's generation time is not visible from here — confirm and add one.
    response = requests.post(
        f"{SERVER_API}/generate",
        params={"prompt": prompt, "baseline": baseline, "seed": seed},
        headers={"X-API-Key": SERVER_API_KEY},
    )
    response.raise_for_status()

    result = response.json()
    generation_time = float(result["generation_time"])
    nsfw = result["nsfw"]
    image = image_from_base64(result["image"])

    print(f"Received image with generation time: {generation_time:.3f}s, NSFW: {nsfw}")

    return None, gr.Image(
        image,
        label=f"{generation_time:.3f}s",
        show_label=True,
    )
# Build the demo's Gradio components: an offline notice, side-by-side baseline and
# optimized image panes, and prompt/seed inputs wired to submit().
# NOTE(review): indentation was lost in this copy of the file, so the nesting of the
# `with` blocks below cannot be confirmed from here — verify against the original layout.
def create_demo():
# Notice for when the server cannot be reached; its visibility is recomputed by the load event below.
offline_textbox = gr.Textbox("The server is offline! Please come back later", interactive=False, show_label=False)
# This visibility is evaluated once, when create_demo() runs.
with gr.Group(visible=get_current_model() is not None):
with gr.Group():
with gr.Row():
with gr.Column():
gr.Textbox("Baseline", interactive=False, show_label=False)
baseline_image_component = gr.Image(show_label=False)
with gr.Column():
textbox = gr.Textbox(interactive=False, show_label=False)
# Supplies the "<uid> - <url>" model description (or None) as this textbox's value.
# presumably attach_load_event runs the callable on page load — confirm against the Gradio version in use
textbox.attach_load_event(lambda: get_current_model(), None)
optimized_image_component = gr.Image(show_label=False)
with gr.Row():
prompt = gr.Textbox(
placeholder="Enter prompt...",
interactive=True,
submit_btn=True,
show_label=False,
autofocus=True,
scale=10,
)
# Free-text seed; submit() receives it as a string.
seed_input = gr.Textbox(
placeholder="Enter seed...",
interactive=True,
show_label=False,
)
# Show the offline notice only when no model is available, and the inputs only when one is.
offline_textbox.attach_load_event(lambda: gr.Textbox(visible=get_current_model() is None), None)
prompt.attach_load_event(lambda: gr.Textbox(visible=get_current_model() is not None), None)
seed_input.attach_load_event(lambda: gr.Textbox(visible=get_current_model() is not None), None)
# One submit event triggers two requests with the same prompt and seed input: baseline=True fills
# the baseline pane, baseline=False the optimized pane. The first returned value (None) goes to the prompt box.
prompt.submit(lambda prompt, seed: submit(prompt, seed, True), inputs=[prompt, seed_input], outputs=[prompt, baseline_image_component])
prompt.submit(lambda prompt, seed: submit(prompt, seed, False), inputs=[prompt, seed_input], outputs=[prompt, optimized_image_component])
|