Update helpers/required_classes.py
Browse files
helpers/required_classes.py
CHANGED
@@ -24,14 +24,15 @@ class BertEmbedder:
|
|
24 |
self.embedder.to(self.device)
|
25 |
|
26 |
def __call__(self, text: str):
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
35 |
|
36 |
def batch_predict(self, texts: List[str]):
|
37 |
encoded_input = self.tokenizer(texts,
|
|
|
24 |
self.embedder.to(self.device)
|
25 |
|
26 |
def __call__(self, text: str):
|
27 |
+
with torch.no_grad():
|
28 |
+
encoded_input = self.tokenizer(text,
|
29 |
+
return_tensors='pt',
|
30 |
+
max_length=self.max_length,
|
31 |
+
padding=True,
|
32 |
+
truncation=True).to(self.device)
|
33 |
+
model_output = self.embedder(**encoded_input)
|
34 |
+
text_embed = model_output.pooler_output[0].cpu()
|
35 |
+
return text_embed
|
36 |
|
37 |
def batch_predict(self, texts: List[str]):
|
38 |
encoded_input = self.tokenizer(texts,
|