# rudall-e / app.py
# Latest commit f0dfc26 (speech-test): "temporary API rate limit fix"
# File size: 2.45 kB
import html
import random

import gradio as gr
import torch
from gradio.mix import Series
from rudalle import get_rudalle_model, get_tokenizer, get_vae
from rudalle.pipelines import generate_images
from transformers import pipeline
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_on_gpu = device.type == "cuda"

# English -> Russian translation model. The transformers pipeline takes a CUDA
# device ordinal, or -1 for CPU, so follow the same CUDA check as ``device``
# instead of always requesting GPU 0 (which fails on CPU-only hosts).
translation_pipe = pipeline(
    "translation",
    model="facebook/wmt19-en-ru",
    device=0 if _on_gpu else -1,
)

# ruDALL-E model, tokenizer and VAE passed to ``generate_images``.
# Half precision is only requested when running on the GPU.
dalle = get_rudalle_model("Malevich", pretrained=True, fp16=_on_gpu, device=device)
tokenizer = get_tokenizer()
vae = get_vae().to(device)
def translation_wrapper(text: str):
    """Translate ``text`` with the module-level ``translation_pipe``.

    Args:
        text: The string handed to the translation pipeline.

    Returns:
        The ``"translation_text"`` field of the first result the pipeline
        returns.
    """
    results = translation_pipe(text)
    first_result = results[0]
    return first_result["translation_text"]
def dalle_wrapper(prompt: str):
    """Generate one ruDALL-E image for ``prompt``.

    Args:
        prompt: Text passed to ``generate_images`` as the image prompt. In
            this app it is the output of the translation step, so it
            ultimately derives from user input.

    Returns:
        A ``(title, image)`` pair: an HTML snippet showing the prompt in
        bold, and the first image returned by ``generate_images``.
    """
    # Sampling settings are drawn at random from a small set of presets on
    # every call.
    top_k, top_p = random.choice([
        (1024, 0.98),
        (512, 0.97),
        (384, 0.96),
    ])
    # The second return value of generate_images is not used here.
    images, _ = generate_images(
        prompt,
        tokenizer,
        dalle,
        vae,
        top_k=top_k,
        images_num=1,
        top_p=top_p,
    )
    # The title is shown through an HTML output component, so escape the
    # prompt to keep any markup in the text from being interpreted.
    title = f"<b>{html.escape(prompt)}</b>"
    return title, images[0]
# Static text displayed around the demo.
description = (
    "ruDALL-E is a 1.3B params text-to-image model by SberAI (links at the bottom). "
    "This demo uses an English-Russian translation model to adapt the prompts. "
    "Try pressing [Submit] multiple times to generate new images!"
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://github.com/sberbank-ai/ru-dalle'>GitHub</a> | "
    "<a href='https://habr.com/ru/company/sberbank/blog/586926/'>Article (in Russian)</a>"
    "</p>"
)
# Each example row holds a single value: the text typed into the input box.
examples = [
    ["A still life of grapes and a bottle of wine"],
    ["Город в стиле киберпанк"],
    ["A colorful photo of a coral reef"],
    ["A white cat sitting in a cardboard box"],
]

# First stage: a text box whose content goes through translation_wrapper.
translator = gr.Interface(
    fn=translation_wrapper,
    inputs=[gr.inputs.Textbox(label="What would you like to see?")],
    outputs="text",
)

# Second stage: text in, (HTML title, image) out via dalle_wrapper.
outputs = [gr.outputs.HTML(label=""), gr.outputs.Image(label="")]
generator = gr.Interface(fn=dalle_wrapper, inputs="text", outputs=outputs)

# Options for the combined interface, collected in one place.
series_options = {
    "title": "Kinda-English ruDALL-E",
    "description": description,
    "article": article,
    "layout": "horizontal",
    "theme": "huggingface",
    "examples": examples,
    "allow_flagging": False,
    "live": False,
    "enable_queue": True,
}

# Combine translator and generator into a single Series interface and serve it.
series = Series(translator, generator, **series_options)
series.launch()