File size: 3,769 Bytes
039d611
 
 
 
 
 
 
 
f6fe4f2
0872eb7
039d611
 
 
 
 
 
 
 
 
 
 
 
 
 
0872eb7
f6fe4f2
 
0872eb7
 
 
 
 
 
 
f6fe4f2
0872eb7
 
 
 
 
 
 
f6fe4f2
0872eb7
039d611
d15e848
 
 
f6fe4f2
 
 
 
0872eb7
 
5727bd1
 
 
87fc391
8285b48
0872eb7
f6fe4f2
 
0872eb7
 
f6fe4f2
 
0872eb7
 
 
87fc391
 
f6fe4f2
87fc391
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the model and tokenizer once, at import time; run_tinystyler below
# reuses these module-level objects for every request.
model_name = "google/flan-t5-large"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
)

def run_tinystyler(source_text, target_example_texts, reranking, temperature, top_p):
    """Rewrite *source_text* using the module-level seq2seq model.

    The source text and the target-style examples are joined with a single
    space and used as the model prompt; one sequence is generated and decoded.

    Args:
        source_text: Text to transform. Empty or whitespace-only input
            short-circuits to "" without running the model.
        target_example_texts: Example texts of the target style (``None`` is
            treated as "").
        reranking: Value of the "Re-ranking" slider.
            NOTE(review): currently unused — no candidate re-ranking is
            implemented here even though the UI describes one.
        temperature: Sampling temperature. Values <= 0 switch to greedy
            decoding, because sampling requires a strictly positive
            temperature and the UI slider allows 0.0.
        top_p: Nucleus-sampling probability mass (only used when sampling).

    Returns:
        The decoded text of the first generated sequence, special tokens
        stripped.
    """
    if not source_text or not source_text.strip():
        return ""

    concatenated_text = source_text + " " + (target_example_texts or "")
    inputs = tokenizer(concatenated_text, return_tensors="pt")

    temperature = float(temperature)
    if temperature > 0:
        decoding_kwargs = {
            "do_sample": True,
            "temperature": temperature,
            "top_p": float(top_p),
        }
    else:
        # transformers rejects temperature <= 0 when sampling; treat the
        # slider's 0.0 as "deterministic" and decode greedily instead.
        decoding_kwargs = {"do_sample": False}

    output = model.generate(
        inputs["input_ids"],
        # Pass the mask explicitly so generate() does not have to infer it.
        attention_mask=inputs["attention_mask"],
        max_length=100,
        **decoding_kwargs,
    )

    return tokenizer.decode(output[0], skip_special_tokens=True)

# Preset demo inputs, each paired with a stored "output" string. Selecting a
# preset in the UI displays that stored output; it is not regenerated.
# Keys are listed in insertion order, which is the order the dropdown shows.
preset_examples = dict()

preset_examples["Example 1"] = dict(
    source_text="Once upon a time in a small village",
    target_example_texts="In a land far away, there was a kingdom ruled by a wise king. Every day, the people of the kingdom would gather to listen to the king's stories, which were full of wisdom and kindness.",
    reranking=5,
    temperature=1.0,
    top_p=1.0,
    output="Once upon a time in a small village in a land far away, there was a kingdom ruled by a wise king. Every day, the people of the kingdom would gather to listen to the king's stories, which were full of wisdom and kindness.",
)

preset_examples["Example 2"] = dict(
    source_text="The quick brown fox",
    target_example_texts="A nimble, chocolate-colored fox swiftly darted through the emerald forest, weaving between trees with grace and agility.",
    reranking=5,
    temperature=0.9,
    top_p=0.9,
    output="The quick brown fox, a nimble, chocolate-colored fox, swiftly darted through the emerald forest, weaving between trees with grace and agility.",
)

# Define Gradio interface
# NOTE(review): the theme string below looks mangled by an e-mail obfuscation
# filter ("[email protected]") and is unlikely to name a real Hub theme;
# restore the intended "owner/theme@version" value. Kept verbatim here.
with gr.Blocks(theme="ParityError/[email protected]") as demo:
    gr.Markdown("# TinyStyler Demo")
    gr.Markdown("Style transfer the source text into the target style, given some example texts of the target style. You can adjust re-ranking and top_p to your desire to control the quality of style transfer. A higher re-ranking value will generally result in better results, at slower speed.")

    with gr.Row():
        example_dropdown = gr.Dropdown(label="Examples", choices=list(preset_examples.keys()))

    source_text = gr.Textbox(lines=3, placeholder="Enter the source text to transform into the target style...", label="Source Text")
    target_example_texts = gr.Textbox(lines=5, placeholder="Enter example texts of the target style (one per line)...", label="Example Texts of the Target Style")
    reranking = gr.Slider(1, 10, value=5, step=1, label="Re-ranking")
    temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.1, label="Temperature")
    top_p = gr.Slider(0.0, 1.0, value=1.0, step=0.1, label="Top-P")

    output = gr.Textbox(lines=5, placeholder="Click 'Generate' to transform the source text into the target style.", label="Output", interactive=False)

    def set_example(example_name):
        """Return the stored inputs and output of the selected preset.

        Values are returned in the order of the ``outputs`` list wired up
        below. If *example_name* is not a known preset (for example the
        dropdown value is None), every component is left unchanged instead
        of raising KeyError inside the UI callback.
        """
        example = preset_examples.get(example_name)
        if example is None:
            # gr.update() with no arguments leaves a component as-is.
            return tuple(gr.update() for _ in range(6))
        return (
            example["source_text"],
            example["target_example_texts"],
            example["reranking"],
            example["temperature"],
            example["top_p"],
            example["output"],
        )

    example_dropdown.change(
        set_example,
        inputs=[example_dropdown],
        outputs=[source_text, target_example_texts, reranking, temperature, top_p, output]
    )

    btn = gr.Button("Generate")
    btn.click(run_tinystyler, [source_text, target_example_texts, reranking, temperature, top_p], output)

demo.launch()