AjayP13 committed
Commit 8655f82 · verified · 1 parent: 88976aa

Update app.py

Files changed (1)
  app.py +24 -5
app.py CHANGED
@@ -1,14 +1,30 @@
+import torch
+import numpy as np
+from torch.nn.utils.rnn import pad_sequence
+from sentence_transformers import SentenceTransformer
 import gradio as gr
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
 # Load the model and tokenizer
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
 model_name = "google/flan-t5-large"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
+style_model = SentenceTransformer('AnnaWegmann/Style-Embedding', device='cpu')
 model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+model.to(device)
 
-def run_tinystyler(source_text, target_example_texts, reranking, temperature, top_p):
-    concatenated_text = source_text + " " + target_example_texts
-    inputs = tokenizer(concatenated_text, return_tensors="pt")
+def get_target_style_embeddings(target_texts_batch):
+    all_target_texts = [target_text for target_texts in target_texts_batch for target_text in target_texts]
+    embeddings = style_model.encode(all_target_texts, batch_size=len(all_target_texts), convert_to_tensor=True, show_progress_bar=False)
+    lengths = [len(target_texts) for target_texts in target_texts_batch]
+    split_embeddings = torch.split(embeddings, lengths)
+    padded_embeddings = pad_sequence(split_embeddings, batch_first=True, padding_value=0.0)
+    mask = (torch.arange(padded_embeddings.size(1))[None, :] < torch.tensor(lengths)[:, None]).to(torch.float32).unsqueeze(-1)
+    mean_embeddings = torch.sum(padded_embeddings * mask, dim=1) / mask.sum(dim=1)
+    return mean_embeddings.cpu().numpy()
+
+def run_tinystyler_batch(source_texts, target_example_texts_batch, reranking, temperature, top_p):
+    inputs = tokenizer(source_texts, return_tensors="pt", padding=True).to(device)
 
     # Generate the output with specified temperature and top_p
     output = model.generate(
@@ -19,8 +35,11 @@ def run_tinystyler(source_text, target_example_texts, reranking, temperature, top_p):
         max_length=1024
     )
 
-    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
-    return generated_text
+    generated_texts = tokenizer.batch_decode(output, skip_special_tokens=True)
+    return generated_texts
+
+def run_tinystyler(source_text, target_example_texts, reranking, temperature, top_p):
+    return run_tinystyler_batch([source_text], [target_example_texts], reranking, temperature, top_p)[0]
 
 # Preset examples with cached generations
 preset_examples = {
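
The heart of the new `get_target_style_embeddings` helper is a masked mean: each example's variable-length list of style embeddings is padded into one rectangular tensor, and a 0/1 mask keeps the padding rows out of the average. A minimal, self-contained sketch of that pooling step on dummy tensors (shapes and values here are illustrative, not from the app):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two examples: one with 3 style embeddings, one with 1 (embedding dim 4).
groups = [torch.ones(3, 4), 2.0 * torch.ones(1, 4)]
lengths = [g.size(0) for g in groups]

# Pad to a (batch, max_len, dim) tensor, then mask out the padding rows.
padded = pad_sequence(groups, batch_first=True, padding_value=0.0)         # (2, 3, 4)
mask = (torch.arange(padded.size(1))[None, :]
        < torch.tensor(lengths)[:, None]).to(torch.float32).unsqueeze(-1)  # (2, 3, 1)
mean = (padded * mask).sum(dim=1) / mask.sum(dim=1)                        # (2, 4)

# Each row equals the plain mean of its own group; padding contributes nothing.
assert torch.allclose(mean[0], groups[0].mean(dim=0))
assert torch.allclose(mean[1], groups[1].mean(dim=0))
```

Padding to a rectangle keeps the pooling in one vectorized pass instead of looping over examples with different numbers of target texts.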
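The hunk elides the middle of the `model.generate(...)` call, so the exact flags are not visible here. For orientation only, a sampling call consistent with the surrounding comment and the `max_length=1024` line might look like the sketch below; the `do_sample` flag is an assumption (without it, `temperature` and `top_p` are ignored under greedy decoding):

```python
# Sketch only: the committed call is partially elided in the diff above.
output = model.generate(
    **inputs,
    do_sample=True,  # assumption: sampling must be enabled for temperature/top_p to apply
    temperature=temperature,
    top_p=top_p,
    max_length=1024,
)
```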
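Finally, the single-input `run_tinystyler` is now a thin wrapper that delegates to the batch path and unwraps the first result, so the Gradio callback keeps its original signature. A hypothetical invocation (argument values are made up; `target_example_texts` is assumed to be a list of strings, and `reranking` is accepted but unused in the code shown in this diff):

```python
# Hypothetical call; values are illustrative only.
styled = run_tinystyler(
    source_text="hey, whats up?",
    target_example_texts=["Good evening.", "I trust you are well."],
    reranking=5,
    temperature=0.7,
    top_p=0.9,
)
print(styled)
```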