# Hugging Face Spaces page residue -- status shown when this source was captured:
# "Runtime error".
import gradio as gr
import numpy as np
import torch
from diffusers import DiffusionPipeline
from transformers import pipeline
# Model setup: runs once at import time and downloads weights from the Hub.
# Generator that continues "<description></s>" with a candidate prompt.
prompt_writer = pipeline('text-generation', model='toloka/gpt2-large-rl-prompt-writing')
# Classifier whose score is used to rank the candidate prompts.
prompt_reward_model = pipeline('text-classification', model='toloka/prompts_reward_model')
# The original "runwayml/stable-diffusion-v1-5" repository is no longer on the
# Hugging Face Hub; load the community-maintained mirror of the SD v1.5 weights.
# Use the GPU when one is available; otherwise the pipeline stays on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5").to(device)
def write_prompt(img_desc):
    """Turn a short image description into a text-to-image prompt.

    Samples two candidate continuations of ``img_desc + '</s>'`` from
    ``prompt_writer``. Each full generated text (description, separator,
    candidate) is scored with ``prompt_reward_model`` using
    ``function_to_apply='none'``, i.e. no sigmoid/softmax on the output.
    The highest-scoring candidate is returned.

    Args:
        img_desc: Free-text description of the desired image.

    Returns:
        The text the model generated after the ``'</s>'`` separator, up to
        any further ``'</s>'`` it emitted, stripped of surrounding whitespace.
    """
    separator = '</s>'
    query = img_desc + separator
    outputs = prompt_writer(query, max_new_tokens=100, num_return_sequences=2)
    prompts = [p['generated_text'] for p in outputs]
    # The reward model scores the full "<description></s><prompt>" text.
    scores = [p['score'] for p in prompt_reward_model(prompts, function_to_apply='none')]
    best = prompts[int(np.argmax(scores))]
    # Remove the echoed query by length instead of splitting on the first
    # separator, so a description that itself contains '</s>' does not shift
    # which segment is returned. If the query is not echoed, fall back to
    # "everything after the first separator" (or the whole text if there is
    # none) rather than raising IndexError.
    if best.startswith(query):
        completion = best[len(query):]
    else:
        completion = best.split(separator, 1)[-1]
    # Keep only the first prompt if the model emitted another separator.
    return completion.split(separator, 1)[0].strip()
def generate(text):
    """Write a prompt for ``text`` and render it with the diffusion pipeline.

    Args:
        text: The user's image description.

    Returns:
        A ``(image, prompt)`` pair: the first image from ``pipe`` run for 50
        inference steps, and the prompt that ``write_prompt`` produced.
    """
    written = write_prompt(text)
    result = pipe(prompt=written, num_inference_steps=50)
    first_image = result.images[0]
    return first_image, written
# Gradio UI: a description box and a button in one row, with the AI-written
# prompt and the generated image below them.
# Gradio 4 removed the `.style()` method and the `gr.outputs` namespace, so
# styling now goes through constructor arguments.
with gr.Blocks() as demo:
    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            text = gr.Textbox(
                label="Enter your image description, e.g., \"a cat\"",
                show_label=False,
                max_lines=1,
                placeholder="Enter your image description, e.g., \"a cat\"",
                container=False,
            )
            # scale=0 keeps the button at its natural width (was full_width=False).
            btn = gr.Button("Generate image", scale=0)
        written_prompt = gr.Textbox(label="AI-written prompt")
        # `object_fit` has no Gradio 4 equivalent on gr.Image; keep the height.
        gen_img = gr.Image(type="pil", label="Generated image", height=512)
    btn.click(generate, text, [gen_img, written_prompt])
    # Pressing Enter in the single-line textbox triggers the same action.
    text.submit(generate, text, [gen_img, written_prompt])
if __name__ == "__main__":
    # Serve the Blocks app only when this file is executed directly.
    demo.launch()