SuperPrompt-v1 / app.py
Nick088's picture
Fixed expected all tensors to be on the same device
31328c7 verified
raw
history blame
2.37 kB
import gradio as gr
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Pick the compute device once at import time. The model and every input
# tensor are later moved to this same device so that generation does not mix
# CPU and GPU tensors.
if torch.cuda.is_available():
    device = "cuda"
    print("Using GPU")
else:
    device = "cpu"
    print("Using CPU")
# The tokenizer comes from the "google/flan-t5-small" checkpoint, while the
# weights come from the "roborovski/superprompt-v1" checkpoint.
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")

# Only use half precision on the GPU: float16 kernels are unsupported or very
# slow on CPU in many PyTorch builds, so fall back to float32 there.
model = T5ForConditionalGeneration.from_pretrained(
    "roborovski/superprompt-v1",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
model.to(device)
def generate(
    prompt, history, max_new_tokens=512, repetition_penalty=1.2, temperature=0.5, top_p=1, top_k=1, seed=42
):
    """Expand ``prompt`` into a more detailed prompt with the loaded T5 model.

    Args:
        prompt: Text of the latest user message (task instruction included).
        history: Earlier chat turns passed in by ``gr.ChatInterface``. It is
            accepted for interface compatibility but is not fed to the model:
            interpolating it into the input would append the Python repr of
            the history object to the prompt.
        max_new_tokens: Upper bound on the number of generated tokens.
        repetition_penalty: Penalty applied to already generated tokens.
            Non-positive values are treated as 1.0 (no penalty) because the
            generation backend rejects them.
        temperature: Sampling temperature. A value of 0 or less switches to
            greedy decoding, since sampling requires a strictly positive
            temperature.
        top_p: Nucleus-sampling probability mass, clamped to [0, 1].
        top_k: Number of highest-probability tokens kept when sampling
            (at least 1).
        seed: Seed for the torch random generators, applied before
            generating so identical settings reproduce the same output.

    Returns:
        The generated text with special tokens (padding / end-of-sequence
        markers) removed.
    """
    # UI widgets may deliver whole numbers as int and integers as float; the
    # generation backend validates exact types, so normalise them up front.
    max_new_tokens = int(max_new_tokens)
    repetition_penalty = float(repetition_penalty)
    temperature = float(temperature)
    top_p = min(max(float(top_p), 0.0), 1.0)
    top_k = max(int(top_k), 1)

    if repetition_penalty <= 0:
        repetition_penalty = 1.0

    # The seed was previously accepted but never applied.
    torch.manual_seed(int(seed))

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }
    if temperature > 0:
        generation_kwargs.update(
            do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k
        )
    else:
        # Temperature 0 is not a valid sampling setting; decode greedily.
        generation_kwargs["do_sample"] = False

    outputs = model.generate(input_ids, **generation_kwargs)
    # Skip special tokens so "<pad>" / "</s>" do not leak into the reply.
    better_prompt = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return better_prompt
# Extra controls shown under the chat box. Their values are passed to
# `generate` positionally after (prompt, history), so the order here must
# match that function's parameter order.
#
# Ranges are limited to values the generation backend accepts: repetition
# penalty and temperature must be strictly positive, and top-p is a
# probability mass that cannot exceed 1.
additional_inputs = [
    gr.Slider(
        value=512,
        minimum=250,
        maximum=512,
        step=1,
        interactive=True,
        label="Max New Tokens",
        info="The maximum numbers of new tokens, controls how long is the output",
    ),
    gr.Slider(
        value=1.2,
        minimum=0.05,
        maximum=2,
        step=0.05,
        interactive=True,
        label="Repetition Penalty",
        info="Penalize repeated tokens, making the AI repeat less itself",
    ),
    gr.Slider(
        value=0.5,
        minimum=0.05,
        maximum=1,
        step=0.05,
        interactive=True,
        label="Temperature",
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        value=1,
        minimum=0,
        maximum=1,
        step=0.05,
        interactive=True,
        label="Top P",
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        value=1,
        minimum=1,
        maximum=100,
        step=1,
        interactive=True,
        label="Top K",
        info="Higher k means more diverse outputs by considering a range of tokens",
    ),
    # precision=0 makes the widget deliver an integer seed.
    gr.Number(
        value=42,
        precision=0,
        interactive=True,
        label="Seed",
        info="A starting point to initiate the generation process",
    ),
]
# Example rows offered in the UI. Each row starts with the chat message; the
# trailing None entries presumably leave the corresponding additional inputs
# unset -- TODO confirm against Gradio's examples handling.
examples = [
    [
        "Expand the following prompt to add more detail: A storefront with 'Text to Image' written on it.",
        None,
        None,
    ],
]
# Chat display widget: panel layout, no label or share button, with copy and
# like buttons enabled.
chat_display = gr.Chatbot(
    show_label=False,
    show_share_button=False,
    show_copy_button=True,
    likeable=True,
    layout="panel",
)

# Wire `generate` into a chat UI together with the extra generation controls
# and the example prompts.
demo = gr.ChatInterface(
    fn=generate,
    chatbot=chat_display,
    additional_inputs=additional_inputs,
    title="SuperPrompt-v1",
    description="Make your prompts more detailed! Especially for AI Art!!!",
    examples=examples,
    concurrency_limit=20,
)

# Start the web app with the API page hidden.
demo.launch(show_api=False)