lyangas commited on
Commit
4585ef5
1 Parent(s): e980b69

Update helpers/required_classes.py

Browse files
Files changed (1) hide show
  1. helpers/required_classes.py +9 -8
helpers/required_classes.py CHANGED
@@ -24,14 +24,15 @@ class BertEmbedder:
24
  self.embedder.to(self.device)
25
 
26
  def __call__(self, text: str):
27
- encoded_input = self.tokenizer(text,
28
- return_tensors='pt',
29
- max_length=self.max_length,
30
- padding=True,
31
- truncation=True).to(self.device)
32
- model_output = self.embedder(**encoded_input)
33
- text_embed = model_output.pooler_output[0].cpu()
34
- return text_embed
 
35
 
36
  def batch_predict(self, texts: List[str]):
37
  encoded_input = self.tokenizer(texts,
 
24
  self.embedder.to(self.device)
25
 
26
  def __call__(self, text: str):
27
+ with torch.no_grad():
28
+ encoded_input = self.tokenizer(text,
29
+ return_tensors='pt',
30
+ max_length=self.max_length,
31
+ padding=True,
32
+ truncation=True).to(self.device)
33
+ model_output = self.embedder(**encoded_input)
34
+ text_embed = model_output.pooler_output[0].cpu()
35
+ return text_embed
36
 
37
  def batch_predict(self, texts: List[str]):
38
  encoded_input = self.tokenizer(texts,