Spaces:
Runtime error
Runtime error
File size: 576 Bytes
476ac07 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
# Copyright (c) OpenMMLab. All rights reserved.
from transformers import StoppingCriteria
class StopWordStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires when the decoded text ends with a stop word.

    Args:
        tokenizer: Object exposing ``decode``; it is called with
            ``input_ids[0]`` on every check.
        stop_word: Text that, when found at the very end of the decoded
            output, makes the criterion return ``True``.

    Attributes set on the instance: ``tokenizer``, ``stop_word`` and
    ``length`` (``len(stop_word)``).
    """

    def __init__(self, tokenizer, stop_word):
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        # Cached so every check knows how many trailing characters to compare.
        self.length = len(stop_word)

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        """Return whether the decoded first sequence ends with the stop word.

        Only ``input_ids[0]`` is decoded; any other entries of ``input_ids``
        and all extra positional/keyword arguments are ignored.

        NOTE(review): carriage returns and line feeds are removed from the
        decoded text before comparing, so a stop word that itself contains
        ``\\r`` or ``\\n`` can never match. An empty stop word matches only
        when the decoded text is empty, because ``[-0:]`` selects the whole
        string.
        """
        text = self.tokenizer.decode(input_ids[0])
        # Strip line-break characters so they cannot split or trail the
        # stop word in the decoded output.
        for line_break in ('\r', '\n'):
            text = text.replace(line_break, '')
        tail = text[-self.length:]
        return tail == self.stop_word
|