AjayP13 committed on
Commit
0174717
·
verified ·
1 Parent(s): 6d06d55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -4
app.py CHANGED
@@ -9,9 +9,12 @@ from sentence_transformers import SentenceTransformer
9
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
10
  model_name = "google/flan-t5-large"
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
12
- embedding_model = SentenceTransformer('AnnaWegmann/Style-Embedding', device='cpu')
13
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
14
  model.to(device)
 
 
 
 
15
 
16
  def get_target_style_embeddings(target_texts_batch):
17
  all_target_texts = [target_text for target_texts in target_texts_batch for target_text in target_texts]
@@ -19,13 +22,29 @@ def get_target_style_embeddings(target_texts_batch):
19
  lengths = [len(target_texts) for target_texts in target_texts_batch]
20
  split_embeddings = torch.split(embeddings, lengths)
21
  padded_embeddings = pad_sequence(split_embeddings, batch_first=True, padding_value=0.0)
22
- mask = (torch.arange(padded_embeddings.size(1))[None, :] < torch.tensor(lengths)[:, None]).to(torch.float32).unsqueeze(-1)
23
  mean_embeddings = torch.sum(padded_embeddings * mask, dim=1) / mask.sum(dim=1)
24
- return mean_embeddings.cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def run_tinystyler_batch(source_texts, target_texts_batch, reranking, temperature, top_p):
27
- inputs = tokenizer(source_texts, return_tensors="pt")
28
  target_style_embeddings = get_target_style_embeddings(target_texts_batch)
 
 
 
29
 
30
  # Generate the output with specified temperature and top_p
31
  output = model.generate(
 
9
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
10
  model_name = "google/flan-t5-large"
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
12
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
13
  model.to(device)
14
+ embedding_model = SentenceTransformer('AnnaWegmann/Style-Embedding', device='cpu').half()
15
+ luar_model = AutoModel.from_pretrained("rrivera1849/LUAR-MUD", trust_remote_code=True).half()
16
+ luar_tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD", trust_remote_code=True)
17
+
18
 
19
  def get_target_style_embeddings(target_texts_batch):
20
  all_target_texts = [target_text for target_texts in target_texts_batch for target_text in target_texts]
 
22
  lengths = [len(target_texts) for target_texts in target_texts_batch]
23
  split_embeddings = torch.split(embeddings, lengths)
24
  padded_embeddings = pad_sequence(split_embeddings, batch_first=True, padding_value=0.0)
25
+ mask = (torch.arange(padded_embeddings.size(1))[None, :] < torch.tensor(lengths)[:, None]).to(embeddings.dtype).unsqueeze(-1)
26
  mean_embeddings = torch.sum(padded_embeddings * mask, dim=1) / mask.sum(dim=1)
27
+ return mean_embeddings.float().cpu().numpy()
28
+
def get_luar_embeddings(texts_batch):
    """Compute LUAR author-style embeddings for a batch of text "episodes".

    Args:
        texts_batch: list of episodes, where each episode is a list of strings
            attributed to one author/style.

    Returns:
        A float32 numpy array with one embedding row per episode (model output
        moved to CPU).
    """
    episodes = texts_batch
    # Tokenize each episode independently. padding="longest" only pads WITHIN
    # one episode, so different episodes may end up with different sequence
    # lengths and different episode sizes.
    tokenized_episodes = [
        luar_tokenizer(
            episode,
            max_length=512,
            padding="longest",
            truncation=True,
            return_tensors="pt",
        ).to(device)
        for episode in episodes
    ]
    episode_lengths = [t["attention_mask"].shape[0] for t in tokenized_episodes]
    max_episode_length = max(episode_lengths)
    sequence_lengths = [t["attention_mask"].shape[1] for t in tokenized_episodes]
    max_sequence_length = max(sequence_lengths)
    # Pad BOTH dimensions to a common rectangular shape:
    #   - right-pad the sequence dim up to max_sequence_length,
    #   - bottom-pad the episode dim up to max_episode_length.
    # The original code padded only the episode dim (max_sequence_length was
    # computed but unused), so torch.stack failed whenever episodes tokenized
    # to different sequence lengths. Zero padding is masked out by the
    # matching zero attention-mask entries.
    padded_input_ids = [
        torch.nn.functional.pad(
            t["input_ids"],
            (
                0, max_sequence_length - t["input_ids"].shape[1],
                0, max_episode_length - t["input_ids"].shape[0],
            ),
        )
        for t in tokenized_episodes
    ]
    padded_attention_mask = [
        torch.nn.functional.pad(
            t["attention_mask"],
            (
                0, max_sequence_length - t["attention_mask"].shape[1],
                0, max_episode_length - t["attention_mask"].shape[0],
            ),
        )
        for t in tokenized_episodes
    ]
    input_ids = torch.stack(padded_input_ids)
    attention_mask = torch.stack(padded_attention_mask)
    # NOTE(review): inputs are moved to `device`, but luar_model is loaded with
    # .half() and never .to(device) — confirm model placement before running on GPU.
    return luar_model(input_ids=input_ids, attention_mask=attention_mask).float().cpu().numpy()
41
 
42
  def run_tinystyler_batch(source_texts, target_texts_batch, reranking, temperature, top_p):
43
+ inputs = tokenizer(source_texts, return_tensors="pt").to(device)
44
  target_style_embeddings = get_target_style_embeddings(target_texts_batch)
45
+ source_style_luar_embeddings = get_luar_embeddings([[st] for st in source_texts])
46
+ target_style_luar_embeddings = get_luar_embeddings(target_texts_batch)
47
+
48
 
49
  # Generate the output with specified temperature and top_p
50
  output = model.generate(