Kandinsky-API / app.py
ehristoforu's picture
Update app.py
6b72f87 verified
import json
import time
import requests
import base64
from io import BytesIO
from PIL import Image
import gradio as gr
import os
# Service credentials are read from the process environment so they never
# appear in the source; each is None when its variable is unset.
api_key = os.environ.get("api_key")
secret_key = os.environ.get("secret_key")
class Text2ImageAPI:
    """Small HTTP client for a text-to-image generation service.

    The workflow is: look up a model id with :meth:`get_model`, submit a
    job with :meth:`generate`, then poll :meth:`check_generation` with the
    returned request id until the images are ready.

    Every request carries the ``X-Key`` / ``X-Secret`` authentication
    headers built in ``__init__``.
    """

    # Seconds to wait on any single HTTP call. Without a timeout `requests`
    # waits forever, so one stalled connection would hang the caller.
    REQUEST_TIMEOUT = 30

    def __init__(self, url, api_key, secret_key):
        """Store the service base URL and build the auth headers.

        Args:
            url: Base URL of the service. Endpoint paths are appended
                directly, so it must end with a trailing slash.
            api_key: Value sent in the ``X-Key`` header (prefixed ``Key ``).
            secret_key: Value sent in the ``X-Secret`` header
                (prefixed ``Secret ``).
        """
        self.URL = url
        self.AUTH_HEADERS = {
            'X-Key': f'Key {api_key}',
            'X-Secret': f'Secret {secret_key}',
        }

    def get_model(self):
        """Return the ``id`` of the first model listed by the service.

        Raises:
            requests.HTTPError: If the service answers with a 4xx/5xx status.
        """
        response = requests.get(
            self.URL + 'key/api/v1/models',
            headers=self.AUTH_HEADERS,
            timeout=self.REQUEST_TIMEOUT,
        )
        # Fail with the real HTTP error instead of a confusing KeyError /
        # IndexError while picking apart an error payload.
        response.raise_for_status()
        data = response.json()
        return data[0]['id']

    def generate(self, prompt, width, height, model):
        """Submit a generation job and return its request id.

        Args:
            prompt: Text description of the image; sent as the ``query``.
            width: Requested image width.
            height: Requested image height.
            model: Model id, as returned by :meth:`get_model`.

        Returns:
            The ``uuid`` field of the service response, to be passed to
            :meth:`check_generation`.

        Raises:
            requests.HTTPError: If the service answers with a 4xx/5xx status.
        """
        params = {
            "type": "GENERATE",
            "numImages": 1,
            "width": width,
            "height": height,
            "censored": True,
            "generateParams": {
                "query": f"{prompt}"
            }
        }
        # Passed through `files=` so the request is encoded as
        # multipart/form-data; a `None` filename makes each entry a plain
        # form field rather than a file upload.
        form_fields = {
            'model_id': (None, model),
            'params': (None, json.dumps(params), 'application/json')
        }
        response = requests.post(
            self.URL + 'key/api/v1/text2image/run',
            headers=self.AUTH_HEADERS,
            files=form_fields,
            timeout=self.REQUEST_TIMEOUT,
        )
        response.raise_for_status()
        data = response.json()
        return data['uuid']

    def check_generation(self, request_id, attempts=10, delay=10):
        """Poll the job status until it finishes or the attempts run out.

        Args:
            request_id: Id returned by :meth:`generate`.
            attempts: Maximum number of status requests to make.
            delay: Seconds to sleep between two consecutive requests.

        Returns:
            The ``images`` field of the status response once the status is
            ``'DONE'``; ``None`` if the job reports ``'FAIL'`` or is still
            unfinished after ``attempts`` polls.

        Raises:
            requests.HTTPError: If the service answers with a 4xx/5xx status.
        """
        while attempts > 0:
            response = requests.get(
                self.URL + 'key/api/v1/text2image/status/' + request_id,
                headers=self.AUTH_HEADERS,
                timeout=self.REQUEST_TIMEOUT,
            )
            response.raise_for_status()
            data = response.json()
            status = data['status']
            if status == 'DONE':
                return data['images']
            if status == 'FAIL':
                # Terminal failure: further polling cannot succeed, so give
                # up now with the same result an exhausted poll would give.
                return None
            attempts -= 1
            # Only wait when another poll will follow; sleeping after the
            # final attempt would just delay the `None` result.
            if attempts > 0:
                time.sleep(delay)
        return None
def api_gradio(prompt, width, height):
    """Generate one image for `prompt` and return it for the gallery.

    Creates a `Text2ImageAPI` client with the module-level credentials,
    submits a generation job, waits for the result and decodes the first
    returned image from base64.

    Args:
        prompt: Text description entered by the user.
        width: Requested image width.
        height: Requested image height.

    Returns:
        A one-element list holding the decoded `PIL.Image`.

    Raises:
        gr.Error: If polling ended without any image, so the UI shows a
            readable message.
    """
    api = Text2ImageAPI('https://api-key.fusionbrain.ai/', api_key, secret_key)
    model_id = api.get_model()
    request_id = api.generate(prompt, width, height, model_id)
    images = api.check_generation(request_id)
    # check_generation yields None when the job never reached 'DONE';
    # indexing that would crash with an opaque TypeError.
    if not images:
        raise gr.Error(
            "Image generation did not finish or returned no image. "
            "Please try again."
        )
    decoded_data = base64.b64decode(images[0])
    image = Image.open(BytesIO(decoded_data))
    return [image]
# Custom stylesheet passed to gr.Blocks(css=css).
# The `footer` rule hides the page footer. The remaining rules target
# elements by id (#generate_button, #save_button, #settings_header).
# NOTE(review): no component visible in this file sets a matching elem_id,
# so those three rules appear to have no effect here; confirm before
# relying on them.
css = """
footer {
visibility: hidden
}
#generate_button {
color: white;
border-color: #007bff;
background: #2563eb;
}
#save_button {
color: white;
border-color: #028b40;
background: #01b97c;
width: 200px;
}
#settings_header {
background: rgb(245, 105, 105);
}
"""
# Build the demo page: a prompt row, collapsible size options, and a
# single-image gallery for the result.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Kandinsky ```API DEMO```")

    # Prompt box and trigger button share one row; the scale values give
    # the textbox most of the width.
    with gr.Row():
        prompt = gr.Textbox(
            show_label=False,
            placeholder="Enter your prompt",
            max_lines=3,
            lines=1,
            interactive=True,
            scale=20,
        )
        button = gr.Button(value="Generate", scale=1)

    # Output size controls, collapsed by default.
    with gr.Accordion("Advanced options", open=False):
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=888,
                maximum=1024,
                step=1,
                value=888,
                interactive=True,
            )
            height = gr.Slider(
                label="Height",
                minimum=888,
                maximum=1024,
                step=1,
                value=888,
                interactive=True,
            )

    with gr.Row():
        gallery = gr.Gallery(
            show_label=False,
            rows=1,
            columns=1,
            allow_preview=True,
            preview=True,
        )

    # Clicking "Generate" sends the prompt and size to api_gradio and shows
    # the returned image list in the gallery.
    button.click(api_gradio, inputs=[prompt, width, height], outputs=gallery)

# Enable request queuing, then start the server without the API docs page.
demo.queue().launch(show_api=False)