import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Load the model and tokenizer
model_name = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
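# Both from_pretrained calls download the weights from the Hugging Face Hub on
# first run and reuse the local cache on subsequent runs.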
def run_tinystyler(source_text, target_example_texts, reranking, temperature, top_p):
    # Note: `reranking` is accepted to mirror the UI controls but is not used by
    # this simple sampling loop.
    concatenated_text = source_text + " " + target_example_texts
    inputs = tokenizer(concatenated_text, return_tensors="pt")
    # Generate the output with the specified temperature and top_p
    output = model.generate(
        inputs["input_ids"],
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        max_length=100,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text
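# A direct call (bypassing the UI) with one of the preset examples below would
# look roughly like this; the exact output varies because sampling is enabled:
#   run_tinystyler(
#       "The quick brown fox",
#       "A nimble, chocolate-colored fox swiftly darted through the emerald forest, "
#       "weaving between trees with grace and agility.",
#       reranking=5,
#       temperature=0.9,
#       top_p=0.9,
#   )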
# Preset examples with cached generations
preset_examples = {
    "Example 1": {
        "source_text": "Once upon a time in a small village",
        "target_example_texts": "In a land far away, there was a kingdom ruled by a wise king. Every day, the people of the kingdom would gather to listen to the king's stories, which were full of wisdom and kindness.",
        "reranking": 5,
        "temperature": 1.0,
        "top_p": 1.0,
        "output": "Once upon a time in a small village in a land far away, there was a kingdom ruled by a wise king. Every day, the people of the kingdom would gather to listen to the king's stories, which were full of wisdom and kindness."
    },
    "Example 2": {
        "source_text": "The quick brown fox",
        "target_example_texts": "A nimble, chocolate-colored fox swiftly darted through the emerald forest, weaving between trees with grace and agility.",
        "reranking": 5,
        "temperature": 0.9,
        "top_p": 0.9,
        "output": "The quick brown fox, a nimble, chocolate-colored fox, swiftly darted through the emerald forest, weaving between trees with grace and agility."
    }
}
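# The cached "output" strings above let the example dropdown populate the
# Output box immediately, without running the model.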
# Define Gradio interface
with gr.Blocks(theme="ParityError/[email protected]") as demo:
gr.Markdown("# TinyStyler Demo")
gr.Markdown("Style transfer the source text into the target style, given some example texts of the target style. You can adjust re-ranking and top_p to your desire to control the quality of style transfer. A higher re-ranking value will generally result in better generations, at slower speed.")
    with gr.Row():
        example_dropdown = gr.Dropdown(label="Examples", choices=list(preset_examples.keys()))
    source_text = gr.Textbox(lines=3, placeholder="Enter the source text to transform into the target style...", label="Source Text")
    target_example_texts = gr.Textbox(lines=5, placeholder="Enter example texts of the target style (one per line)...", label="Example Texts of the Target Style")
    reranking = gr.Slider(1, 10, value=5, step=1, label="Re-ranking")
    temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.1, label="Temperature")
    top_p = gr.Slider(0.0, 1.0, value=1.0, step=0.1, label="Top-P")
    output = gr.Textbox(lines=5, placeholder="Click 'Generate' to transform the source text into the target style.", label="Output", interactive=False)
    def set_example(example_name):
        example = preset_examples[example_name]
        return example["source_text"], example["target_example_texts"], example["reranking"], example["temperature"], example["top_p"], example["output"]
    example_dropdown.change(
        set_example,
        inputs=[example_dropdown],
        outputs=[source_text, target_example_texts, reranking, temperature, top_p, output]
    )
    btn = gr.Button("Generate")
    btn.click(run_tinystyler, [source_text, target_example_texts, reranking, temperature, top_p], output)
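    # btn.click passes the current values of the five input components to
    # run_tinystyler positionally and writes the returned string into `output`.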
    # Initialize the fields with the first example
    default_example = list(preset_examples.keys())[0]
    example_dropdown.value = default_example
    (source_text.value, target_example_texts.value, reranking.value,
     temperature.value, top_p.value, output.value) = set_example(default_example)
demo.launch()