File size: 3,088 Bytes
7bdc0ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c717871
7bdc0ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b72f87
 
7bdc0ee
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import json
import time

import requests
import base64
from io import BytesIO
from PIL import Image
import gradio as gr
import os

# FusionBrain credentials, read from the environment variables "api_key" and
# "secret_key"; each one is None when its variable is not set.
api_key, secret_key = (os.getenv(var) for var in ("api_key", "secret_key"))

class Text2ImageAPI:
    """Minimal client for the FusionBrain text-to-image REST API.

    Every request carries ``X-Key`` / ``X-Secret`` authentication headers
    built from the credentials passed to the constructor, and is bounded by
    a per-request timeout.
    """

    def __init__(self, url, api_key, secret_key, timeout=60):
        """Store the base URL, auth headers and request timeout.

        Args:
            url: Base URL of the service, e.g. ``'https://api-key.fusionbrain.ai/'``.
                A missing trailing slash is tolerated.
            api_key: Sent in the ``X-Key`` header, prefixed with ``Key ``.
            secret_key: Sent in the ``X-Secret`` header, prefixed with ``Secret ``.
            timeout: Seconds to wait for each HTTP request. ``requests`` waits
                indefinitely when no timeout is given, which would hang the app.
        """
        self.URL = url
        self.AUTH_HEADERS = {
            'X-Key': f'Key {api_key}',
            'X-Secret': f'Secret {secret_key}',
        }
        self.timeout = timeout

    def _endpoint(self, path):
        """Join the base URL and *path* with exactly one slash between them."""
        return self.URL.rstrip('/') + '/' + path

    def _request(self, method, path, **kwargs):
        """Send an authenticated request to *path* and return the decoded JSON body.

        Raises:
            requests.HTTPError: if the server answers with a 4xx/5xx status.
        """
        response = requests.request(
            method,
            self._endpoint(path),
            headers=self.AUTH_HEADERS,
            timeout=self.timeout,
            **kwargs,
        )
        # Fail loudly on HTTP errors instead of surfacing a confusing
        # KeyError/IndexError from an error payload further down.
        response.raise_for_status()
        return response.json()

    @staticmethod
    def _is_done(status):
        """Return True if the status payload reports ``DONE``, False otherwise.

        Raises:
            RuntimeError: if the payload reports ``FAIL``; polling further
                would only waste the remaining attempts.
        """
        state = status['status']
        if state == 'FAIL':
            reason = status.get('errorDescription', 'unknown error')
            raise RuntimeError(f"Generation failed: {reason}")
        return state == 'DONE'

    def get_model(self):
        """Return the ``id`` of the first model listed by the service."""
        models = self._request('GET', 'key/api/v1/models')
        return models[0]['id']

    def generate(self, prompt, width, height, model):
        """Submit a single-image generation job and return its ``uuid``.

        Args:
            prompt: Text description of the desired image.
            width: Requested image width.
            height: Requested image height.
            model: Model id, e.g. as returned by :meth:`get_model`.
        """
        params = {
            "type": "GENERATE",
            "numImages": 1,
            "width": width,
            "height": height,
            "censored": True,
            "generateParams": {
                "query": str(prompt),
            },
        }
        # Sent as multipart/form-data: (None, value) tuples make `requests`
        # emit plain form fields rather than file uploads.
        form = {
            'model_id': (None, model),
            'params': (None, json.dumps(params), 'application/json'),
        }
        job = self._request('POST', 'key/api/v1/text2image/run', files=form)
        return job['uuid']

    def check_generation(self, request_id, attempts=10, delay=10):
        """Poll a job's status until it is ``DONE`` or *attempts* run out.

        Args:
            request_id: The ``uuid`` returned by :meth:`generate`.
            attempts: Maximum number of status requests to make.
            delay: Seconds to sleep between consecutive status requests.

        Returns:
            The ``images`` field of the status payload once the job is
            ``DONE``, or ``None`` if it has not finished after all attempts.

        Raises:
            RuntimeError: if the service reports the job as ``FAIL``.
        """
        for attempt in range(attempts):
            status = self._request('GET', 'key/api/v1/text2image/status/' + request_id)
            if self._is_done(status):
                return status['images']
            # No point sleeping after the final attempt.
            if attempt < attempts - 1:
                time.sleep(delay)
        return None


def api_gradio(prompt, width, height):
    """Generate one image for *prompt* and return it as a gallery value.

    Args:
        prompt: Text prompt from the UI textbox.
        width: Requested image width from the UI slider.
        height: Requested image height from the UI slider.

    Returns:
        A one-element list holding the decoded PIL image.

    Raises:
        gr.Error: for missing credentials, an empty prompt, or a job that
            produced no images, so the message is shown in the UI instead of
            an opaque traceback.
    """
    if not api_key or not secret_key:
        # Otherwise the headers would literally read "Key None" / "Secret None".
        raise gr.Error("API credentials are not configured: set the api_key and secret_key environment variables.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    api = Text2ImageAPI('https://api-key.fusionbrain.ai/', api_key, secret_key)
    model_id = api.get_model()
    # Cast defensively: slider values may arrive as floats, but the sizes
    # are pixel counts.
    request_id = api.generate(prompt, int(width), int(height), model_id)
    images = api.check_generation(request_id)
    # check_generation returns None when the job did not finish within its
    # polling attempts; indexing None would raise an unhelpful TypeError.
    if not images:
        raise gr.Error("Image generation did not finish in time; please try again.")

    decoded_data = base64.b64decode(images[0])
    image = Image.open(BytesIO(decoded_data))

    return [image]

# Custom stylesheet handed to gr.Blocks(css=...).
# - `footer` hides any <footer> element (presumably Gradio's built-in page
#   footer — confirm against the Gradio version in use).
# - The `#...` rules only take effect on components created with a matching
#   elem_id; NOTE(review): verify each id is actually assigned in the layout.
css = """
footer {
    visibility: hidden
}
#generate_button {
    color: white;
    border-color: #007bff;
    background: #2563eb;
}
#save_button {
    color: white;
    border-color: #028b40;
    background: #01b97c;
    width: 200px;
}
#settings_header {
    background: rgb(245, 105, 105);
}
"""

# Page layout: a prompt box with a Generate button, collapsible size sliders,
# and a single-image gallery that receives api_gradio's output.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Kandinsky ```API DEMO```")
    with gr.Row():
        prompt = gr.Textbox(show_label=False, placeholder="Enter your prompt", max_lines=3, lines=1, interactive=True, scale=20)
        # elem_id links this button to the #generate_button rule in `css`;
        # without it that rule matched nothing.
        button = gr.Button(value="Generate", scale=1, elem_id="generate_button")
    with gr.Accordion("Advanced options", open=False):
        with gr.Row():
            width = gr.Slider(label="Width", minimum=888, maximum=1024, step=1, value=888, interactive=True)
            height = gr.Slider(label="Height", minimum=888, maximum=1024, step=1, value=888, interactive=True)
    with gr.Row():
        gallery = gr.Gallery(show_label=False, rows=1, columns=1, allow_preview=True, preview=True)

    generate_inputs = [prompt, width, height]
    button.click(api_gradio, inputs=generate_inputs, outputs=gallery)
    # Pressing Enter in the prompt box triggers the same generation.
    prompt.submit(api_gradio, inputs=generate_inputs, outputs=gallery)

demo.queue().launch(show_api=False)