import itertools
import torch
from statistics import mean
import numpy as np
from torch.nn.utils.rnn import pad_sequence
import gradio as gr
from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
from mutual_implication_score import MIS
from time import time
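# TinyStyler demo: rewrite a source text in the style of a few example texts.
# flan-t5-large generates candidate rewrites; auxiliary models (style
# embeddings, LUAR, MIS) supply signals for evaluating and re-ranking them.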
# Load the model and tokenizer
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_name = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model.to(device)
embedding_model = SentenceTransformer('AnnaWegmann/Style-Embedding', device='cpu').half()
luar_model = AutoModel.from_pretrained("rrivera1849/LUAR-MUD", revision="51b0d9ecec5336314e02f191dd8ca4acc0652fe1", trust_remote_code=True).half()
luar_model.to(device)
luar_tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD", revision="51b0d9ecec5336314e02f191dd8ca4acc0652fe1", trust_remote_code=True)
mis_model = MIS(device=device)
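# Auxiliary models used for evaluation and re-ranking:
#   - Style-Embedding (SentenceTransformer): mean target-style vectors
#   - LUAR-MUD: author/style representations over "episodes" of texts
#   - MIS: mutual implication score, a meaning-preservation measure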
def get_target_style_embeddings(target_texts_batch):
    # Embed every target text in the batch at once, then average the
    # embeddings per example to obtain one target-style vector per example.
    all_target_texts = [target_text for target_texts in target_texts_batch for target_text in target_texts]
    embeddings = embedding_model.encode(all_target_texts, batch_size=len(all_target_texts), convert_to_tensor=True, show_progress_bar=False)
    lengths = [len(target_texts) for target_texts in target_texts_batch]
    split_embeddings = torch.split(embeddings, lengths)
    padded_embeddings = pad_sequence(split_embeddings, batch_first=True, padding_value=0.0)
    # Mask out the padding rows so they do not contribute to the mean.
    mask = (torch.arange(padded_embeddings.size(1))[None, :] < torch.tensor(lengths)[:, None]).to(embeddings.dtype).unsqueeze(-1)
    mean_embeddings = torch.sum(padded_embeddings * mask, dim=1) / mask.sum(dim=1)
    return mean_embeddings.float().cpu().numpy()
@torch.no_grad()
def get_luar_embeddings(texts_batch):
    # Each "episode" is the list of texts for one example; all episodes in the
    # batch must contain the same number of texts.
    assert len(set([len(texts) for texts in texts_batch])) == 1
    episodes = texts_batch
    tokenized_episodes = [luar_tokenizer(episode, max_length=512, padding="longest", truncation=True, return_tensors="pt").to(device) for episode in episodes]
    episode_lengths = [t["attention_mask"].shape[0] for t in tokenized_episodes]
    max_episode_length = max(episode_lengths)
    sequence_lengths = [t["attention_mask"].shape[1] for t in tokenized_episodes]
    max_sequence_length = max(sequence_lengths)
    # Pad both the sequence dimension (each episode was tokenized with
    # padding="longest" independently, so sequence lengths differ across
    # episodes) and the episode dimension, so the tensors can be stacked.
    padded_input_ids = [torch.nn.functional.pad(t["input_ids"], (0, max_sequence_length - t["input_ids"].shape[1], 0, max_episode_length - t["input_ids"].shape[0]), value=luar_tokenizer.pad_token_id) for t in tokenized_episodes]
    padded_attention_mask = [torch.nn.functional.pad(t["attention_mask"], (0, max_sequence_length - t["attention_mask"].shape[1], 0, max_episode_length - t["attention_mask"].shape[0])) for t in tokenized_episodes]
    input_ids = torch.stack(padded_input_ids)
    attention_mask = torch.stack(padded_attention_mask)
    return luar_model(input_ids=input_ids, attention_mask=attention_mask).float().cpu().numpy()
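# Example (hypothetical shapes): get_luar_embeddings([["a", "b"], ["c", "d"]])
# embeds two 2-text episodes and returns one embedding row per episode.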
def compute_mis(source_texts, target_texts_batch):
    # Pair each source text with each of its target texts, score every pair in
    # one batch, then average the scores per example.
    a_texts = list(itertools.chain.from_iterable([[st] * len(target_texts) for st, target_texts in zip(source_texts, target_texts_batch)]))
    b_texts = list(itertools.chain.from_iterable(target_texts_batch))
    scores = mis_model.compute(a_texts, b_texts, batch_size=len(a_texts))
    # Identical pairs trivially imply each other; pin their score to 1.0.
    for idx, (a_text, b_text) in enumerate(zip(a_texts, b_texts)):
        if a_text == b_text:
            scores[idx] = 1.0
    final_scores = []
    current_idx = 0
    for target_texts in target_texts_batch:
        final_scores.append(mean(scores[current_idx:current_idx + len(target_texts)]))
        current_idx += len(target_texts)
    return final_scores
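# Example (hypothetical): compute_mis(["hi there"], [["hello", "hey"]]) scores
# "hi there" against both targets and returns one averaged score per example.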
def run_tinystyler_batch(source_texts, target_texts_batch, reranking, temperature, top_p):
    inputs = tokenizer(source_texts, padding=True, truncation=True, return_tensors="pt").to(device)
    target_style_embeddings = get_target_style_embeddings(target_texts_batch)
    source_style_luar_embeddings = get_luar_embeddings([[st] for st in source_texts])
    print("Computed source LUAR embeddings", time(), source_style_luar_embeddings.shape)
    target_style_luar_embeddings = get_luar_embeddings(target_texts_batch)
    print("Computed target LUAR embeddings", time(), target_style_luar_embeddings.shape)
    baseline_sim = compute_mis(source_texts, target_texts_batch)
    print("Computed baseline MIS", time(), len(baseline_sim))
    # Generate `reranking` candidate outputs per source text with the
    # specified temperature and top_p
    output = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        max_length=1024,
        num_return_sequences=reranking,
    )
    print("Generated candidates", time(), output.shape)
    generated_texts = tokenizer.batch_decode(output, skip_special_tokens=True)
    generated_texts = [generated_texts[i * reranking:(i + 1) * reranking] for i in range(inputs["input_ids"].shape[0])]  # Unflatten
    # Evaluate candidates
    candidates_luar_embeddings = [get_luar_embeddings([[candidates[i]] for candidates in generated_texts]) for i in range(reranking)]
    candidates_sim = [compute_mis([candidates[i] for candidates in generated_texts], target_texts_batch) for i in range(reranking)]
    print("Evaluated candidates", time(), len(candidates_luar_embeddings), len(candidates_luar_embeddings[0]))
    # Get best candidate based on re-ranking. NOTE: this selection rule is an
    # assumption (the original stub simply returned the first candidate): pick,
    # per example, the candidate with the highest mean MIS against the target
    # texts. The style/LUAR embeddings and baseline_sim computed above are not
    # yet used in this selection.
    best_indices = np.argmax(np.array(candidates_sim), axis=0)
    generated_texts = [texts[int(i)] for texts, i in zip(generated_texts, best_indices)]
    print("Selected outputs", time(), len(generated_texts))
    return generated_texts
def run_tinystyler(source_text, target_texts, reranking, temperature, top_p):
    # Split the newline-separated exemplar texts, dropping blank lines.
    target_texts = [target_text.strip() for target_text in target_texts.split("\n") if target_text.strip()]
    return run_tinystyler_batch([source_text], [target_texts], reranking, temperature, top_p)[0]
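# Example (hypothetical values): transfer a casual source text toward a more
# formal style given two newline-separated exemplars:
#   run_tinystyler("hey, what's up?", "Good evening.\nHow do you do?", 5, 1.0, 1.0)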
# Preset examples with cached generations
preset_examples = {
    "Example 1": {
        "source_text": "Once upon a time in a small village",
        "target_texts": "In a land far away, there was a kingdom ruled by a wise king. Every day, the people of the kingdom would gather to listen to the king's stories, which were full of wisdom and kindness.",
        "reranking": 5,
        "temperature": 1.0,
        "top_p": 1.0,
        "output": "Once upon a time in a small village in a land far away, there was a kingdom ruled by a wise king. Every day, the people of the kingdom would gather to listen to the king's stories, which were full of wisdom and kindness."
    },
    "Example 2": {
        "source_text": "The quick brown fox",
        "target_texts": "A nimble, chocolate-colored fox swiftly darted through the emerald forest, weaving between trees with grace and agility.",
        "reranking": 5,
        "temperature": 0.9,
        "top_p": 0.9,
        "output": "The quick brown fox, a nimble, chocolate-colored fox, swiftly darted through the emerald forest, weaving between trees with grace and agility."
    }
}
# Define Gradio interface
with gr.Blocks(theme="ParityError/[email protected]") as demo:
gr.Markdown("# TinyStyler Demo")
gr.Markdown("Style transfer the source text into the target style, given some example texts of the target style. You can adjust re-ranking and top_p to your desire to control the quality of style transfer. A higher re-ranking value will generally result in better generations, at slower speed.")
with gr.Row():
example_dropdown = gr.Dropdown(label="Examples", choices=list(preset_examples.keys()))
source_text = gr.Textbox(lines=3, placeholder="Enter the source text to transform into the target style...", label="Source Text")
target_texts = gr.Textbox(lines=5, placeholder="Enter example texts of the target style (one per line)...", label="Example Texts of the Target Style")
reranking = gr.Slider(1, 10, value=5, step=1, label="Re-ranking")
temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
top_p = gr.Slider(0.0, 1.0, value=1.0, step=0.1, label="Top-P")
output = gr.Textbox(lines=5, placeholder="Click 'Generate' to transform the source text into the target style.", label="Output", interactive=False)
def set_example(example_name):
example = preset_examples[example_name]
return example["source_text"], example["target_texts"], example["reranking"], example["temperature"], example["top_p"], example["output"]
example_dropdown.change(
set_example,
inputs=[example_dropdown],
outputs=[source_text, target_texts, reranking, temperature, top_p, output]
)
btn = gr.Button("Generate")
btn.click(run_tinystyler, [source_text, target_texts, reranking, temperature, top_p], output)
# Initialize the fields with the first example
example_dropdown.value, (source_text.value, target_texts.value, reranking.value, temperature.value, top_p.value, output.value) = list(preset_examples.keys())[0], set_example(list(preset_examples.keys())[0])
demo.launch()