"""Gradio demo: turn a short image description into an AI-written Stable Diffusion
prompt, pick the best candidate with a reward model, and render it."""
import gradio as gr
from transformers import pipeline
import numpy as np
from diffusers import DiffusionPipeline

# Text appended to the user's description before it is handed to the prompt writer.
# NOTE(review): in the original script this literal was empty (`img_desc + ''`), and
# the prompt was then extracted with `.split('')`, which always raises
# ValueError("empty separator"). The intended separator token appears to have been
# lost; confirm it against the toloka/gpt2-large-rl-prompt-writing model card and
# fill it in here. Prompt extraction below works whether or not it is empty.
PROMPT_SEPARATOR = ''

NUM_CANDIDATES = 2          # prompts sampled per request, then ranked by the reward model
MAX_NEW_TOKENS = 100        # generation budget for each candidate prompt
NUM_INFERENCE_STEPS = 50    # diffusion denoising steps

prompt_writer = pipeline('text-generation', model='toloka/gpt2-large-rl-prompt-writing')
prompt_reward_model = pipeline('text-classification', model='toloka/prompts_reward_model')
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")


def _extract_prompt(generated_text, model_input, separator=PROMPT_SEPARATOR):
    """Return only the newly generated prompt from a full generation output.

    The text-generation pipeline's default ``return_full_text=True`` echoes the
    input, so ``generated_text`` is expected to start with ``model_input``; that
    prefix is sliced off (if it is unexpectedly absent, the whole text is used).
    When ``separator`` is non-empty, the continuation is cut at its next
    occurrence, matching the ``split(separator)[1]`` the original intended.
    """
    if generated_text.startswith(model_input):
        continuation = generated_text[len(model_input):]
    else:
        continuation = generated_text
    if separator:
        # Stop at the next separator in case the model emitted more than one prompt.
        continuation = continuation.split(separator, 1)[0]
    return continuation.strip()


def write_prompt(img_desc):
    """Write a Stable Diffusion prompt for ``img_desc``.

    Samples ``NUM_CANDIDATES`` continuations from the prompt-writer model, scores
    each full text (description + separator + prompt) with the reward model, and
    returns the prompt part of the highest-scoring candidate.
    """
    model_input = img_desc + PROMPT_SEPARATOR
    outputs = prompt_writer(
        model_input,
        max_new_tokens=MAX_NEW_TOKENS,
        num_return_sequences=NUM_CANDIDATES,
    )
    full_texts = [out['generated_text'] for out in outputs]
    # function_to_apply='none' asks the classification pipeline for raw scores
    # (no sigmoid/softmax), so candidates are ranked on the model's logits.
    scores = [res['score'] for res in prompt_reward_model(full_texts, function_to_apply='none')]
    best = full_texts[int(np.argmax(scores))]
    return _extract_prompt(best, model_input)


def generate(text):
    """Gradio callback: write a prompt for ``text`` and render it.

    Returns ``(image, prompt)`` in the order expected by the output components.
    Raises ``gr.Error`` (shown to the user) when the description is blank.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter an image description.")
    prompt = write_prompt(text)
    img = pipe(prompt=prompt, num_inference_steps=NUM_INFERENCE_STEPS).images[0]
    return img, prompt


with gr.Blocks() as demo:
    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            text = gr.Textbox(
                label="Enter your image description, e.g., \"a cat\"",
                show_label=False,
                max_lines=1,
                placeholder="Enter your image description, e.g., \"a cat\"",
            ).style(
                container=False,
            )
            btn = gr.Button("Generate image").style(full_width=False)

        # gr.outputs.* are deprecated shims around these same components.
        written_prompt = gr.Textbox(label="AI-written prompt")
        gen_img = gr.Image(type="pil", label="Generated image").style(object_fit="contain", height=512)

    btn.click(generate, text, [gen_img, written_prompt])
    # The description box is single-line, so let Enter trigger generation as well.
    text.submit(generate, text, [gen_img, written_prompt])


if __name__ == "__main__":
    demo.launch()