AlexNijjar commited on
Commit
6d46915
·
1 Parent(s): c656c20

Add model comparison dashboard

Browse files
Files changed (1) hide show
  1. src/model_demo.py +78 -0
src/model_demo.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import os
3
+ import random
4
+ import time
5
+ from io import BytesIO
6
+
7
+ import gradio as gr
8
+ import requests
9
+ from PIL import Image
10
+
11
+ SERVER_API = os.environ["SERVER_API"]
12
+
13
+
14
+ def image_from_base64(image_data: str) -> Image:
15
+ image_buffer = BytesIO(base64.b64decode(image_data))
16
+ image = Image.open(image_buffer)
17
+ return image
18
+
19
+
20
+ def submit(prompt: str, seed: int | str | None, baseline: bool) -> tuple:
21
+ if isinstance(seed, str):
22
+ seed = hash(seed) % 2 ** 32
23
+ if not seed:
24
+ random.seed(int(time.time()))
25
+ seed = random.randint(0, 2 ** 32 - 1)
26
+
27
+ print(f"Making request with prompt: {prompt}, seed: {seed}, baseline: {baseline}")
28
+ response = requests.post(f"{SERVER_API}/generate", params={"prompt": prompt, "baseline": baseline, "seed": seed}, verify=False)
29
+ response.raise_for_status()
30
+ result = response.json()
31
+ generation_time = float(result["generation_time"])
32
+ image = image_from_base64(result["image"])
33
+ print(f"Received image with generation time: {generation_time:.3f}s")
34
+
35
+ return None, gr.Image(
36
+ image,
37
+ label=f"{generation_time:.3f}s",
38
+ show_label=True
39
+ )
40
+
41
+
42
+ def create_textbox() -> gr.Textbox:
43
+ response = requests.get(f"{SERVER_API}/model", verify=False)
44
+ response.raise_for_status()
45
+ model = response.json()
46
+ return gr.Textbox(f"{model['uid']} - {model['url']}", interactive=False, show_label=False)
47
+
48
+
49
+ def create_demo():
50
+ with gr.Accordion(f"EdgeMaxxing Model Comparison"):
51
+ with gr.Group():
52
+ with gr.Row():
53
+ with gr.Column():
54
+ gr.Textbox("Baseline", interactive=False, show_label=False)
55
+ baseline_image_component = gr.Image(show_label=False)
56
+
57
+ with gr.Column():
58
+ textbox = gr.Textbox()
59
+ textbox.attach_load_event(lambda: create_textbox(), None)
60
+ optimized_image_component = gr.Image(show_label=False)
61
+ with gr.Row():
62
+ prompt = gr.Textbox(
63
+ placeholder="Enter prompt...",
64
+ interactive=True,
65
+ submit_btn=True,
66
+ show_label=False,
67
+ autofocus=True,
68
+ scale=10,
69
+ )
70
+
71
+ seed_input = gr.Textbox(
72
+ placeholder="Enter seed...",
73
+ interactive=True,
74
+ show_label=False,
75
+ )
76
+
77
+ prompt.submit(lambda prompt, seed: submit(prompt, seed, True), inputs=[prompt, seed_input], outputs=[prompt, baseline_image_component])
78
+ prompt.submit(lambda prompt, seed: submit(prompt, seed, False), inputs=[prompt, seed_input], outputs=[prompt, optimized_image_component])