File size: 3,773 Bytes
6d46915
 
 
 
dfdce18
6d46915
 
 
 
 
 
dfdce18
 
6d46915
eef5820
6d46915
92b542d
dfdce18
 
b343c97
dfdce18
8e943a0
 
 
92b542d
8e943a0
 
 
dfdce18
8e943a0
dfdce18
 
 
 
 
8e943a0
dfdce18
 
 
6d46915
 
 
 
 
 
 
 
 
 
 
 
 
 
eef5820
 
 
 
 
6d46915
 
 
8e943a0
6d46915
8e943a0
6d46915
 
 
 
 
 
 
 
 
92b542d
 
 
6d46915
92b542d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import base64
import os
import random
import time
import zlib
from datetime import datetime, timedelta
from io import BytesIO

import gradio as gr
import requests
from PIL import Image

from wandb_data import TIMEZONE

# Base URL of the generation server; endpoint paths are appended to it
# (f"{SERVER_API}/model", f"{SERVER_API}/generate"). Read at import time, so a
# missing environment variable raises KeyError when this module is imported.
SERVER_API: str = os.environ["SERVER_API"]
# Secret sent as the "X-API-Key" header on generation requests.
SERVER_API_KEY: str = os.environ["SERVER_API_KEY"]

# Cache used by get_current_model(): the last "<uid> - <url>" string built from
# the server's /model response (None before any successful sync).
current_model: str | None = None
# Time of the last sync attempt; starts at the Unix epoch so the first
# get_current_model() call always queries the server.
last_current_model_sync: datetime = datetime.fromtimestamp(0, TIMEZONE)


def get_current_model() -> str | None:
    """Return a display string for the model the server is currently serving.

    The result is cached for five minutes: calls inside that window return the
    cached value without contacting the server. Once the window has elapsed,
    ``GET {SERVER_API}/model`` is queried and the ``uid`` and ``url`` fields of
    its JSON body are formatted as ``"<uid> - <url>"``.

    Returns:
        The ``"<uid> - <url>"`` string, or ``None`` if the server could not be
        reached or returned an unusable response. A failure is cached for the
        same five-minute window, so callers get a consistent answer instead of
        a stale model from an earlier sync.
    """
    global current_model
    global last_current_model_sync

    now = datetime.now(TIMEZONE)
    if now - last_current_model_sync < timedelta(minutes=5):
        return current_model
    # Stamp the sync time before the request so that a slow or failing server
    # is queried at most once per window rather than on every call.
    last_current_model_sync = now

    try:
        # Without a timeout, requests would wait indefinitely on a hung server.
        response = requests.get(f"{SERVER_API}/model", timeout=10)
        response.raise_for_status()
        model = response.json()
        current_model = f"{model['uid']} - {model['url']}"
    except (requests.RequestException, ValueError, KeyError, TypeError) as error:
        # RequestException: connection/HTTP errors; ValueError: body is not
        # JSON; KeyError/TypeError: JSON does not have the expected shape.
        # Clear the cache so the rest of the window reports "offline"
        # consistently instead of returning a stale model string.
        current_model = None
        print(f"Unable to connect to API: {error}")

    return current_model


def image_from_base64(image_data: str) -> Image.Image:
    """Decode a base64-encoded image file into a PIL image.

    Args:
        image_data: Base64 text of an encoded image file in any format PIL can
            open (the server's ``image`` response field).

    Returns:
        The opened ``PIL.Image.Image``. The backing ``BytesIO`` buffer stays
        referenced by the image object, so lazy pixel loading still works.

    Raises:
        binascii.Error: If ``image_data`` is not valid base64.
        PIL.UnidentifiedImageError: If the decoded bytes are not a recognised
            image format.
    """
    raw_bytes = base64.b64decode(image_data)
    return Image.open(BytesIO(raw_bytes))


def _resolve_seed(seed: int | str | None) -> int:
    """Convert the raw seed input into the integer seed sent to the server.

    - ``None`` or a blank string: draw a seed from a time-seeded RNG (below).
    - A string holding an integer literal (e.g. ``"42"``): that integer,
      reduced modulo 2**32.
    - Any other string: its CRC-32, so the same text maps to the same seed in
      every process. (The built-in ``hash`` is salted per process for ``str``
      and would give a different seed after every restart.)
    - An ``int``: used as-is, including ``0``.
    """
    if isinstance(seed, str):
        text = seed.strip()
        if not text:
            seed = None
        else:
            try:
                seed = int(text) % 2 ** 32
            except ValueError:
                seed = zlib.crc32(text.encode("utf-8"))

    if seed is None:
        # Presumably seeded from the wall-clock second so that the baseline and
        # optimized requests fired by one prompt submission (two separate
        # submit() calls in create_demo) draw the same seed.
        # NOTE(review): they only match if both calls land in the same second.
        # A private Random instance yields the same value as seeding the global
        # RNG would, without clobbering the global random state.
        seed = random.Random(int(time.time())).randint(0, 2 ** 32 - 1)

    return seed


def submit(prompt: str, seed: int | str | None, baseline: bool) -> tuple:
    """Request one image from the generation server and wrap it for Gradio.

    Args:
        prompt: Text prompt forwarded to ``POST {SERVER_API}/generate``.
        seed: Raw seed input; see ``_resolve_seed`` for how it is interpreted.
        baseline: Forwarded as the ``baseline`` query parameter.

    Returns:
        ``(None, gr.Image)``. ``None`` is written to the first output (the
        prompt textbox in ``create_demo``). The image component is labelled
        with the server-reported generation time.

    Raises:
        requests.RequestException: On connection failure or, via
            ``raise_for_status``, an HTTP error status.
        KeyError: If the response JSON lacks an expected field.
    """
    seed = _resolve_seed(seed)

    print(f"Making request with prompt: {prompt}, seed: {seed}, baseline: {baseline}")
    response = requests.post(
        f"{SERVER_API}/generate",
        params={"prompt": prompt, "baseline": baseline, "seed": seed},
        headers={"X-API-Key": SERVER_API_KEY},
        # Bound only the connection phase: generation itself can legitimately
        # take a while, so the read is left unbounded.
        timeout=(10, None),
    )
    response.raise_for_status()
    result = response.json()
    generation_time = float(result["generation_time"])
    nsfw = result["nsfw"]
    image = image_from_base64(result["image"])
    print(f"Received image with generation time: {generation_time:.3f}s, NSFW: {nsfw}")

    return None, gr.Image(
        image,
        label=f"{generation_time:.3f}s",
        show_label=True,
    )


def create_demo():
    """Build the baseline-vs-optimized image comparison UI.

    Components are created inside nested ``with`` layout blocks, so the order
    of statements here determines the on-page layout. Presumably the caller
    invokes this inside an enclosing ``gr.Blocks`` context — confirm at the
    call site.

    Layout: an "offline" notice, then a group with two side-by-side columns
    (baseline image on the left; on the right a textbox showing the current
    model string above the optimized image), followed by a prompt box and a
    seed box. Submitting the prompt fires two independent ``submit`` calls,
    one with ``baseline=True`` and one with ``baseline=False``.
    """
    offline_textbox = gr.Textbox("The server is offline! Please come back later", interactive=False, show_label=False)
    # NOTE(review): this visibility is evaluated once, when the UI is built.
    # The load events below toggle the prompt/seed inputs but not this group,
    # so if the server was offline at build time the group may stay hidden
    # even after it comes back — confirm.
    with gr.Group(visible=get_current_model() is not None):
        with gr.Group():
            with gr.Row():
                with gr.Column():
                    gr.Textbox("Baseline", interactive=False, show_label=False)
                    baseline_image_component = gr.Image(show_label=False)

                with gr.Column():
                    textbox = gr.Textbox(interactive=False, show_label=False)
                    # Fill the header with the current model string via a load
                    # event (presumably re-run on each page load).
                    textbox.attach_load_event(lambda: get_current_model(), None)
                    optimized_image_component = gr.Image(show_label=False)
        with gr.Row():
            prompt = gr.Textbox(
                placeholder="Enter prompt...",
                interactive=True,
                submit_btn=True,
                show_label=False,
                autofocus=True,
                scale=10,
            )

            seed_input = gr.Textbox(
                placeholder="Enter seed...",
                interactive=True,
                show_label=False,
            )

        # On load, show the offline notice only when no model is available,
        # and show the prompt/seed inputs only when one is.
        offline_textbox.attach_load_event(lambda: gr.Textbox(visible=get_current_model() is None), None)
        prompt.attach_load_event(lambda: gr.Textbox(visible=get_current_model() is not None), None)
        seed_input.attach_load_event(lambda: gr.Textbox(visible=get_current_model() is not None), None)

        # Two handlers on the same submit: both read the prompt and seed, each
        # writes submit()'s first result (None) back to the prompt box and its
        # image to its own column.
        prompt.submit(lambda prompt, seed: submit(prompt, seed, True), inputs=[prompt, seed_input], outputs=[prompt, baseline_image_component])
        prompt.submit(lambda prompt, seed: submit(prompt, seed, False), inputs=[prompt, seed_input], outputs=[prompt, optimized_image_component])