Spaces:
Running
Running
File size: 3,773 Bytes
6d46915 dfdce18 6d46915 dfdce18 6d46915 eef5820 6d46915 92b542d dfdce18 b343c97 dfdce18 8e943a0 92b542d 8e943a0 dfdce18 8e943a0 dfdce18 8e943a0 dfdce18 6d46915 eef5820 6d46915 8e943a0 6d46915 8e943a0 6d46915 92b542d 6d46915 92b542d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import base64
import os
import random
import time
import zlib
from datetime import datetime, timedelta
from io import BytesIO

import gradio as gr
import requests
from PIL import Image

from wandb_data import TIMEZONE
# Base URL and API key for the generation backend; both are required and the
# app fails fast at import time if either is missing from the environment.
SERVER_API = os.environ["SERVER_API"]
SERVER_API_KEY = os.environ["SERVER_API_KEY"]
# Cache for the "<uid> - <url>" string describing the currently served model,
# refreshed at most every five minutes by get_current_model().
current_model: str | None = None
# Initialized to the epoch so the very first get_current_model() call syncs.
last_current_model_sync: datetime = datetime.fromtimestamp(0, TIMEZONE)
def get_current_model() -> str | None:
    """Return "<uid> - <url>" for the currently served model, or None on failure.

    The value is cached for five minutes. The sync timestamp is updated
    *before* the request so a failing server is retried at most once per
    window instead of on every UI event (note: on failure, a previously
    fetched ``current_model`` is still served from cache within the window).
    """
    global current_model
    global last_current_model_sync
    now = datetime.now(TIMEZONE)
    if now - last_current_model_sync < timedelta(minutes=5):
        return current_model
    last_current_model_sync = now
    try:
        # timeout prevents a dead server from hanging the UI thread forever.
        response = requests.get(f"{SERVER_API}/model", timeout=10)
        response.raise_for_status()
        model = response.json()
        current_model = f"{model['uid']} - {model['url']}"
        return current_model
    except (requests.RequestException, ValueError, KeyError) as error:
        # Narrowed from a bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit): network/HTTP errors, malformed JSON
        # (ValueError), or a payload missing 'uid'/'url' (KeyError).
        print(f"Unable to connect to API: {error}")
        return None
def image_from_base64(image_data: str) -> Image.Image:
    """Decode a base64-encoded image payload into a PIL image.

    The return annotation is ``Image.Image`` (the class); the previous
    ``-> Image`` annotated the PIL module itself, which is not a type.
    The underlying buffer stays open, as required by lazy ``Image.open``.
    """
    image_buffer = BytesIO(base64.b64decode(image_data))
    image = Image.open(image_buffer)
    return image
def submit(prompt: str, seed: int | str | None, baseline: bool) -> tuple:
if isinstance(seed, str):
seed = hash(seed) % 2 ** 32
if not seed:
random.seed(int(time.time()))
seed = random.randint(0, 2 ** 32 - 1)
print(f"Making request with prompt: {prompt}, seed: {seed}, baseline: {baseline}")
response = requests.post(
f"{SERVER_API}/generate",
params={"prompt": prompt, "baseline": baseline, "seed": seed},
headers={"X-API-Key": SERVER_API_KEY},
)
response.raise_for_status()
result = response.json()
generation_time = float(result["generation_time"])
nsfw = result["nsfw"]
image = image_from_base64(result["image"])
print(f"Received image with generation time: {generation_time:.3f}s, NSFW: {nsfw}")
return None, gr.Image(
image,
label=f"{generation_time:.3f}s",
show_label=True
)
def create_demo():
    """Build the demo UI: a baseline vs. optimized side-by-side image view.

    NOTE(review): assumes it runs inside a ``gr.Blocks()`` context created by
    the caller — confirm against the entry point. Nesting below is
    reconstructed from a whitespace-mangled source; verify against history.
    """
    # Shown instead of the main UI when the backend is unreachable.
    offline_textbox = gr.Textbox("The server is offline! Please come back later", interactive=False, show_label=False)
    # Main UI; initial visibility reflects server availability at build time.
    with gr.Group(visible=get_current_model() is not None):
        with gr.Group():
            with gr.Row():
                # Left column: output of the baseline model.
                with gr.Column():
                    gr.Textbox("Baseline", interactive=False, show_label=False)
                    baseline_image_component = gr.Image(show_label=False)
                # Right column: output of the currently served (optimized) model,
                # headed by a label that shows "<uid> - <url>".
                with gr.Column():
                    textbox = gr.Textbox(interactive=False, show_label=False)
                    # NOTE(review): attach_load_event is not part of the public
                    # gradio API — presumably a HF Spaces/project extension that
                    # re-evaluates the lambda on page load; verify.
                    textbox.attach_load_event(lambda: get_current_model(), None)
                    optimized_image_component = gr.Image(show_label=False)
        with gr.Row():
            prompt = gr.Textbox(
                placeholder="Enter prompt...",
                interactive=True,
                submit_btn=True,
                show_label=False,
                autofocus=True,
                scale=10,
            )
            seed_input = gr.Textbox(
                placeholder="Enter seed...",
                interactive=True,
                show_label=False,
            )
    # On each page load, re-check availability and swap the offline notice
    # and the input widgets accordingly.
    offline_textbox.attach_load_event(lambda: gr.Textbox(visible=get_current_model() is None), None)
    prompt.attach_load_event(lambda: gr.Textbox(visible=get_current_model() is not None), None)
    seed_input.attach_load_event(lambda: gr.Textbox(visible=get_current_model() is not None), None)
    # One prompt submission fires both generations, baseline and optimized;
    # each handler also writes back to the prompt component (submit returns
    # None for it).
    prompt.submit(lambda prompt, seed: submit(prompt, seed, True), inputs=[prompt, seed_input], outputs=[prompt, baseline_image_component])
    prompt.submit(lambda prompt, seed: submit(prompt, seed, False), inputs=[prompt, seed_input], outputs=[prompt, optimized_image_component])
|