Itamarl commited on
Commit
1b9dc29
·
1 Parent(s): 3ac4d0b

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +10 -10
handler.py CHANGED
@@ -22,16 +22,16 @@ class EndpointHandler():
22
  print("tokenizer created ", datetime.now())
23
 
24
 
25
- stop_token_ids = self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])
26
-
27
- class StopOnTokens(StoppingCriteria):
28
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
29
- for stop_id in stop_token_ids:
30
- if input_ids[0][-1] == stop_id:
31
- return True
32
- return False
33
-
34
- stopping_criteria = StoppingCriteriaList([StopOnTokens()])
35
 
36
  self.generate_text = transformers.pipeline(
37
  model=self.model,
 
22
  print("tokenizer created ", datetime.now())
23
 
24
 
25
+ stop_token_ids = self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])
26
+
27
+ class StopOnTokens(StoppingCriteria):
28
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
29
+ for stop_id in stop_token_ids:
30
+ if input_ids[0][-1] == stop_id:
31
+ return True
32
+ return False
33
+
34
+ stopping_criteria = StoppingCriteriaList([StopOnTokens()])
35
 
36
  self.generate_text = transformers.pipeline(
37
  model=self.model,