Spaces:
Runtime error
Runtime error
File size: 2,993 Bytes
650ec6e 11b632d 7084126 650ec6e 27cb35e 650ec6e 7084126 650ec6e e8720a0 5e05aea 650ec6e f0dfc26 976f94a 0fef3d0 976f94a 91602b2 f0dfc26 650ec6e ea2efae 650ec6e f0dfc26 650ec6e 3368fa1 650ec6e 4556b47 650ec6e edfe9df 4556b47 650ec6e 3368fa1 650ec6e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import random
import torch
import numpy as np
from tqdm import tqdm
from functools import partialmethod
import gradio as gr
from gradio.mix import Series
from transformers import pipeline, FSMTForConditionalGeneration, FSMTTokenizer
from rudalle.pipelines import generate_images
from rudalle import get_rudalle_model, get_tokenizer, get_vae
# Disable the tqdm progress bars that the rudalle pipeline prints internally.
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)
# Use the GPU when available; every model below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# EN -> RU translation model: ruDALL-E ("Malevich") was trained on Russian
# captions, so English prompts are translated first.
# NOTE(review): torch_dtype=torch.float16 already loads fp16 weights, making
# the extra .half() redundant; on a CPU-only host fp16 generate() commonly
# fails ("addmm_impl_cpu_" etc.) — consider casting to half only when CUDA
# is available. TODO confirm against the Space's runtime error log.
translation_model = FSMTForConditionalGeneration.from_pretrained("facebook/wmt19-en-ru", torch_dtype=torch.float16).half().to(device)
translation_tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-en-ru")
# ruDALL-E generator plus its text tokenizer and VAE image decoder.
dalle = get_rudalle_model("Malevich", pretrained=True, fp16=True, device=device)
tokenizer = get_tokenizer()
vae = get_vae().to(device)
def translation_wrapper(text: str) -> str:
    """Translate an English prompt to Russian for the ruDALL-E generator.

    Args:
        text: English (or already-Russian) prompt typed by the user.

    Returns:
        The translated text with special tokens stripped.
    """
    input_ids = translation_tokenizer.encode(text, return_tensors="pt")
    outputs = translation_model.generate(input_ids.to(device))
    # BUG FIX: decode() expects integer token ids; the previous code cast
    # outputs[0] to a float tensor (`.float()`) before decoding, which is
    # wrong for an id sequence and can break the id -> token lookup.
    decoded = translation_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return decoded
def dalle_wrapper(prompt: str):
    """Generate a single image for *prompt* with ruDALL-E.

    A (top_k, top_p) sampling pair is picked at random so repeated submits
    of the same prompt yield varied images.

    Returns:
        A (html_title, image) pair for the two Gradio outputs.
    """
    sampling_options = [
        (1024, 0.98),
        (512, 0.97),
        (384, 0.96),
    ]
    top_k, top_p = random.choice(sampling_options)
    generated, _ = generate_images(
        prompt,
        tokenizer,
        dalle,
        vae,
        top_k=top_k,
        images_num=1,
        top_p=top_p,
    )
    caption = f"<b>{prompt}</b>"
    return caption, generated[0]
# Stage 1: text -> text interface around the EN->RU translator.
# NOTE(review): gr.inputs / gr.outputs and gradio.mix.Series are legacy
# pre-3.0 Gradio APIs removed in Gradio 3 — this file needs an old
# gradio pin (or a port) to run; plausibly the source of the Space's
# "Runtime error" banner. TODO confirm the pinned gradio version.
translator = gr.Interface(fn=translation_wrapper,
                          inputs=[gr.inputs.Textbox(label='What would you like to see?')],
                          outputs="text")
# Stage 2 outputs: an HTML caption plus the generated image.
outputs = [
    gr.outputs.HTML(label=""),
    gr.outputs.Image(label=""),
]
generator = gr.Interface(fn=dalle_wrapper, inputs="text", outputs=outputs)
# UI copy rendered above and below the demo.
description = (
    "ruDALL-E is a 1.3B params text-to-image model by SberAI (links at the bottom). "
    "This demo uses an English-Russian translation model to adapt the prompts. "
    "Try pressing [Submit] multiple times to generate new images!"
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://github.com/sberbank-ai/ru-dalle'>GitHub</a> | "
    "<a href='https://habr.com/ru/company/sberbank/blog/586926/'>Article (in Russian)</a>"
    "</p>"
)
# Example prompts; the second is already Russian and is simply passed
# through the translation stage.
examples = [["A still life of grapes and a bottle of wine"],
            ["Город в стиле киберпанк"],
            ["A colorful photo of a coral reef"],
            ["A white cat sitting in a cardboard box"]]
# Chain the two interfaces: the translator's text output becomes the
# generator's prompt input. launch() blocks and serves the app.
series = Series(translator, generator,
                title='Kinda-English ruDALL-E',
                description=description,
                article=article,
                layout='horizontal',
                theme='huggingface',
                examples=examples,
                allow_flagging=False,
                live=False,
                enable_queue=True,
                )
series.launch()
|