AjayP13 commited on
Commit
a5b34f7
·
verified ·
1 Parent(s): 8655f82

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -8,13 +8,13 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
8
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
9
  model_name = "google/flan-t5-large"
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
11
- model = SentenceTransformer('AnnaWegmann/Style-Embedding', device='cpu')
12
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
13
  model.to(device)
14
 
15
  def get_target_style_embeddings(target_texts_batch):
16
  all_target_texts = [target_text for target_texts in target_texts_batch for target_text in target_texts]
17
- embeddings = model.encode(all_target_texts, batch_size=len(all_target_texts), convert_to_tensor=True, show_progress_bar=False)
18
  lengths = [len(target_texts) for target_texts in target_texts_batch]
19
  split_embeddings = torch.split(embeddings, lengths)
20
  padded_embeddings = pad_sequence(split_embeddings, batch_first=True, padding_value=0.0)
 
8
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
9
  model_name = "google/flan-t5-large"
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+ embedding_model = SentenceTransformer('AnnaWegmann/Style-Embedding', device='cpu')
12
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
13
  model.to(device)
14
 
15
  def get_target_style_embeddings(target_texts_batch):
16
  all_target_texts = [target_text for target_texts in target_texts_batch for target_text in target_texts]
17
+ embeddings = embedding_model.encode(all_target_texts, batch_size=len(all_target_texts), convert_to_tensor=True, show_progress_bar=False)
18
  lengths = [len(target_texts) for target_texts in target_texts_batch]
19
  split_embeddings = torch.split(embeddings, lengths)
20
  padded_embeddings = pad_sequence(split_embeddings, batch_first=True, padding_value=0.0)