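# TinyStyler demo: a Gradio app that rewrites a source text in the style of a
# few example texts. Candidates are generated with a seq2seq model, then
# re-ranked using style embeddings and a mutual implication score (MIS).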
import itertools
import torch
from statistics import mean
import numpy as np
from torch.nn.utils.rnn import pad_sequence
import gradio as gr
from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
from mutual_implication_score import MIS
from time import time

# Load the model and tokenizer
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_name = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model.to(device)
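# Evaluation models: Wegmann style embeddings, LUAR author embeddings, and MIS
# (mutual implication score) for meaning preservation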
embedding_model = SentenceTransformer('AnnaWegmann/Style-Embedding', device='cpu').half()
luar_model = AutoModel.from_pretrained("rrivera1849/LUAR-MUD", revision="51b0d9ecec5336314e02f191dd8ca4acc0652fe1", trust_remote_code=True).half()
luar_model.to(device)
luar_tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD", revision="51b0d9ecec5336314e02f191dd8ca4acc0652fe1", trust_remote_code=True)
mis_model = MIS(device=device)

def get_target_style_embeddings(target_texts_batch):
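    """Mean-pool style embeddings over each example's target texts.

    Encodes every target text in one pass, splits the embeddings back per
    example, and mean-pools with a length mask so examples may supply
    different numbers of target texts.
    """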
    all_target_texts = [target_text for target_texts in target_texts_batch for target_text in target_texts]
    embeddings = embedding_model.encode(all_target_texts, batch_size=len(all_target_texts), convert_to_tensor=True, show_progress_bar=False)
    lengths = [len(target_texts) for target_texts in target_texts_batch]
    split_embeddings = torch.split(embeddings, lengths)
    padded_embeddings = pad_sequence(split_embeddings, batch_first=True, padding_value=0.0)
    mask = (torch.arange(padded_embeddings.size(1))[None, :] < torch.tensor(lengths)[:, None]).to(embeddings.dtype).unsqueeze(-1)
    mean_embeddings = torch.sum(padded_embeddings * mask, dim=1) / mask.sum(dim=1)
    return mean_embeddings.float().cpu().numpy()

@torch.no_grad()
def get_luar_embeddings(texts_batch):
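    """Embed each episode (a list of texts) into a single LUAR author-style vector.

    All episodes in the batch must contain the same number of texts (asserted
    below); token sequences are padded so the episodes can be stacked.
    """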
    assert len(set([len(texts) for texts in texts_batch])) == 1
    episodes = texts_batch
    tokenized_episodes = [luar_tokenizer(episode, max_length=512, padding="longest", truncation=True, return_tensors="pt").to(device) for episode in episodes]
    episode_lengths = [t["attention_mask"].shape[0] for t in tokenized_episodes]
    max_episode_length = max(episode_lengths)
    sequence_lengths = [t["attention_mask"].shape[1] for t in tokenized_episodes]
    max_sequence_length = max(sequence_lengths)
    # Pad both the sequence and episode dimensions so the episodes can be stacked
    padded_input_ids = [torch.nn.functional.pad(t["input_ids"], (0, max_sequence_length - t["input_ids"].shape[1], 0, max_episode_length - t["input_ids"].shape[0])) for t in tokenized_episodes]
    padded_attention_mask = [torch.nn.functional.pad(t["attention_mask"], (0, max_sequence_length - t["attention_mask"].shape[1], 0, max_episode_length - t["attention_mask"].shape[0])) for t in tokenized_episodes]
    input_ids = torch.stack(padded_input_ids)
    attention_mask = torch.stack(padded_attention_mask)
    return luar_model(input_ids=input_ids, attention_mask=attention_mask).float().cpu().numpy()

def compute_mis(source_texts, target_texts_batch):
    # Pair each source text with every one of its target-style exemplars
    a_texts = list(itertools.chain.from_iterable([[st] * len(target_texts) for st, target_texts in zip(source_texts, target_texts_batch)]))
    b_texts = list(itertools.chain.from_iterable(target_texts_batch))
    scores = mis_model.compute(a_texts, b_texts, batch_size=len(a_texts))
    # Identical pairs trivially imply each other
    for idx, (a_text, b_text) in enumerate(zip(a_texts, b_texts)):
        if a_text == b_text:
            scores[idx] = 1.0
    # Average the pairwise scores back into one score per source text
    final_scores = []
    current_idx = 0
    for target_texts in target_texts_batch:
        final_scores.append(mean(scores[current_idx:current_idx + len(target_texts)]))
        current_idx += len(target_texts)
    return final_scores

def run_tinystyler_batch(source_texts, target_texts_batch, reranking, temperature, top_p):
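    """Style-transfer a batch of source texts toward their target exemplars.

    Generates `reranking` candidates per source text, embeds and scores them
    against the target texts, and returns one output per source text.
    """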
    inputs = tokenizer(source_texts, padding=True, return_tensors="pt").to(device)
    target_style_embeddings = get_target_style_embeddings(target_texts_batch)
    source_style_luar_embeddings = get_luar_embeddings([[st] for st in source_texts])
    print("Log 0", time(), source_style_luar_embeddings.shape)
    target_style_luar_embeddings = get_luar_embeddings(target_texts_batch)
    print("Log 1", time(), target_style_luar_embeddings.shape)
    baseline_sim = compute_mis(source_texts, target_texts_batch)
    print("Log 1.5", time(), len(baseline_sim))
    
    
    # Generate candidate outputs with the specified temperature and top_p
    output = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        max_length=1024,
        num_return_sequences=reranking,
    )
    print("Log 2", time(), output.shape)
    generated_texts = tokenizer.batch_decode(output, skip_special_tokens=True)
    generated_texts = [generated_texts[i * reranking:(i + 1) * reranking] for i in range(inputs["input_ids"].shape[0])] # Unflatten

    # Evaluate candidates
    candidates_luar_embeddings = [get_luar_embeddings([[candidates[i]] for candidates in generated_texts]) for i in range(reranking)]
    candidates_sim = [compute_mis([candidates[i] for candidates in generated_texts], target_texts_batch) for i in range(reranking)]
    print("Log 3", time(), len(candidates_luar_embeddings), len(candidates_luar_embeddings[0]))

    # Get best candidate per example based on re-ranking
    # (a simple choice: the candidate with the highest mean MIS against the
    # target exemplars; the LUAR embeddings above could be folded in as well)
    best_indices = np.argmax(np.array(candidates_sim), axis=0)
    generated_texts = [candidates[best_idx] for candidates, best_idx in zip(generated_texts, best_indices)]
    print("Final Log", time(), len(generated_texts))
    
    return generated_texts
    
def run_tinystyler(source_text, target_texts, reranking, temperature, top_p):
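    """Single-example wrapper around run_tinystyler_batch. The UI passes the
    target-style examples as one newline-separated string."""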
    target_texts = [target_text.strip() for target_text in target_texts.split("\n") if target_text.strip()]
    return run_tinystyler_batch([source_text], [target_texts], reranking, temperature, top_p)[0]

# Preset examples with cached generations
preset_examples = {
    "Example 1": {
        "source_text": "Once upon a time in a small village",
        "target_texts": "In a land far away, there was a kingdom ruled by a wise king. Every day, the people of the kingdom would gather to listen to the king's stories, which were full of wisdom and kindness.",
        "reranking": 5,
        "temperature": 1.0,
        "top_p": 1.0,
        "output": "Once upon a time in a small village in a land far away, there was a kingdom ruled by a wise king. Every day, the people of the kingdom would gather to listen to the king's stories, which were full of wisdom and kindness."
    },
    "Example 2": {
        "source_text": "The quick brown fox",
        "target_texts": "A nimble, chocolate-colored fox swiftly darted through the emerald forest, weaving between trees with grace and agility.",
        "reranking": 5,
        "temperature": 0.9,
        "top_p": 0.9,
        "output": "The quick brown fox, a nimble, chocolate-colored fox, swiftly darted through the emerald forest, weaving between trees with grace and agility."
    }
}

# Define Gradio interface
with gr.Blocks(theme="ParityError/Anime") as demo:  # theme name assumed; the original was obfuscated in the source
    gr.Markdown("# TinyStyler Demo")
    gr.Markdown("Transfer the source text into the target style, given a few example texts of the target style. You can adjust re-ranking and top_p to control the quality of the style transfer. A higher re-ranking value will generally produce better generations, at slower speed.")
    
    with gr.Row():
        example_dropdown = gr.Dropdown(label="Examples", choices=list(preset_examples.keys()))
    
    source_text = gr.Textbox(lines=3, placeholder="Enter the source text to transform into the target style...", label="Source Text")
    target_texts = gr.Textbox(lines=5, placeholder="Enter example texts of the target style (one per line)...", label="Example Texts of the Target Style")
    reranking = gr.Slider(1, 10, value=5, step=1, label="Re-ranking")
    temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
    top_p = gr.Slider(0.0, 1.0, value=1.0, step=0.1, label="Top-P")
    
    output = gr.Textbox(lines=5, placeholder="Click 'Generate' to transform the source text into the target style.", label="Output", interactive=False)

    def set_example(example_name):
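        """Populate the input fields and cached output for the chosen preset."""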
        example = preset_examples[example_name]
        return example["source_text"], example["target_texts"], example["reranking"], example["temperature"], example["top_p"], example["output"]

    example_dropdown.change(
        set_example,
        inputs=[example_dropdown],
        outputs=[source_text, target_texts, reranking, temperature, top_p, output]
    )
    
    btn = gr.Button("Generate")
    btn.click(run_tinystyler, [source_text, target_texts, reranking, temperature, top_p], output)

    # Initialize the fields with the first example
    first_example = list(preset_examples.keys())[0]
    example_dropdown.value = first_example
    (
        source_text.value,
        target_texts.value,
        reranking.value,
        temperature.value,
        top_p.value,
        output.value,
    ) = set_example(first_example)

demo.launch()