Promptist: reinforcement learning for automatic prompt optimization
News
- [Demo Release] Dec, 2022: Demo at HuggingFace Space
- [Model Release] Dec, 2022: link
- [Paper Release] Dec, 2022: Optimizing Prompts for Text-to-Image Generation
- Language models serve as a prompt interface that optimizes user input into model-preferred prompts.
- A language model is learned for automatic prompt optimization via reinforcement learning.
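The bullet above says the prompter is learned via reinforcement learning. As a generic illustration only, here is a minimal REINFORCE (score-function policy gradient) sketch on a toy softmax policy that picks one of several candidate rewrites. This is not the paper's training procedure. The candidate list, the fixed per-candidate rewards in `scores`, and the helpers `softmax`, `reinforce_update` and `train` are all hypothetical stand-ins. In practice the policy is a language model and the reward would be computed from the text-to-image side, and neither is modeled here.

```python
import math
import random


def softmax(logits):
    """Convert a list of logits into probabilities."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_update(logits, action, reward, baseline, lr=0.5):
    """Return new logits after one REINFORCE step on a softmax policy.

    For a softmax policy, d log pi(action) / d logit_i = 1[i == action] - p_i,
    so probability mass moves toward `action` when reward > baseline and away
    from it when reward < baseline.
    """
    probs = softmax(logits)
    advantage = reward - baseline
    return [
        z + lr * advantage * ((1.0 if i == action else 0.0) - p)
        for i, (z, p) in enumerate(zip(logits, probs))
    ]


def train(rewards, steps=2000, lr=0.5, seed=0):
    """Toy loop: the policy repeatedly samples one of len(rewards) candidates."""
    rng = random.Random(seed)
    logits = [0.0] * len(rewards)
    baseline = 0.0  # running average of past rewards, used for variance reduction
    for _ in range(steps):
        probs = softmax(logits)
        action = rng.choices(range(len(rewards)), weights=probs)[0]
        reward = rewards[action]  # hypothetical fixed score for this candidate
        logits = reinforce_update(logits, action, reward, baseline, lr)
        baseline = 0.9 * baseline + 0.1 * reward
    return softmax(logits)


# Hypothetical candidates and reward values, for illustration only.
candidates = [
    "A rabbit is wearing a space suit",
    "A rabbit is wearing a space suit, digital art, highly detailed",
    "rabbit, space suit",
]
scores = [0.2, 0.9, 0.4]
probs = train(scores)
best = max(range(len(probs)), key=probs.__getitem__)
print(candidates[best], [round(p, 3) for p in probs])
```

Because the rewards here are fixed numbers, the loop only shows the mechanics of the policy-gradient update. It says nothing about which rewrites Stable Diffusion actually prefers.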
Load Pretrained Model for Stable Diffusion v1.4
You can try the online demo at https://huggingface.co/spaces/microsoft/Promptist.
[Note] The online demo at HuggingFace Space runs on CPU, so expect slow generation. For faster generation, load the model locally on a GPU.
```python
import gradio as grad
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_prompter():
    prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
    # Promptist is used with the GPT-2 tokenizer.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 has no pad token: reuse EOS and pad on the left so that
    # generation continues directly from the end of the prompt.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    return prompter_model, tokenizer

prompter_model, prompter_tokenizer = load_prompter()

def generate(plain_text):
    # Input format used by this demo: "<user prompt> Rephrase:"
    prompt = plain_text.strip() + " Rephrase:"
    input_ids = prompter_tokenizer(prompt, return_tensors="pt").input_ids
    eos_id = prompter_tokenizer.eos_token_id
    # Deterministic beam search; a negative length_penalty favours shorter beams.
    outputs = prompter_model.generate(
        input_ids,
        do_sample=False,
        max_new_tokens=75,
        num_beams=8,
        num_return_sequences=8,
        eos_token_id=eos_id,
        pad_token_id=eos_id,
        length_penalty=-1.0,
    )
    output_texts = prompter_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # Keep the top-ranked beam and drop the echoed (stripped) input prompt.
    res = output_texts[0].replace(prompt, "").strip()
    return res

txt = grad.Textbox(lines=1, label="Initial Text", placeholder="Input Prompt")
out = grad.Textbox(lines=1, label="Optimized Prompt")
examples = [
    "A rabbit is wearing a space suit",
    "Several railroad tracks with one train passing by",
    "The roof is wet from the rain",
    "Cats dancing in a space club",
]

demo = grad.Interface(
    fn=generate,
    inputs=txt,
    outputs=out,
    title="Promptist Demo",
    description="Promptist is a prompt interface for Stable Diffusion v1-4 (https://huggingface.co/CompVis/stable-diffusion-v1-4) that optimizes user input into model-preferred prompts.",
    examples=examples,
    allow_flagging="never",
    cache_examples=False,
    theme="default",
)
# .queue() replaces the older launch(enable_queue=True) argument.
demo.queue().launch(debug=True)
```
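The demo decodes with `num_beams=8` and `length_penalty=-1.0`. To illustrate what a negative length penalty does, here is a minimal pure-Python sketch of length-normalised beam scoring. It assumes the formula described in the Hugging Face Transformers generation docs (summed log-probability divided by the sequence length raised to `length_penalty`). The helpers `beam_score` and `pick_best` and the toy log-probabilities are hypothetical, and the library's internal bookkeeping (for example, exactly which tokens count toward the length) may differ.

```python
def beam_score(token_logprobs, length_penalty):
    """Length-normalised score: sum of log-probs / length ** length_penalty."""
    return sum(token_logprobs) / (len(token_logprobs) ** length_penalty)


def pick_best(beams, length_penalty):
    """Return the hypothesis with the highest length-normalised score."""
    return max(beams, key=lambda lp: beam_score(lp, length_penalty))


# Hypothetical per-token log-probabilities for two finished hypotheses.
short = [-0.5, -0.5]                   # 2 tokens, total log-prob -1.0
long = [-0.3, -0.3, -0.3, -0.3, -0.3]  # 5 tokens, total log-prob about -1.5

for penalty in (-1.0, 0.0, 1.0):
    winner = "short" if pick_best([short, long], penalty) is short else "long"
    print(f"length_penalty={penalty:+.1f}: {winner} wins")
```

With `length_penalty=-1.0` the log-likelihood is multiplied by the length, so the shorter hypothesis wins; with `1.0` it is divided by the length, so the longer one wins. A negative value therefore nudges the selected rewrite toward fewer tokens.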