Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from transformers import StoppingCriteria | |
class StopWordStoppingCriteria(StoppingCriteria):
    """Stop generation once the decoded text ends with a given stop word.

    Only the first sequence of the batch (``input_ids[0]``) is inspected.
    Carriage returns and line feeds are removed from the decoded text before
    the comparison. The same characters are removed from the stop word, so a
    stop word containing line breaks can still match.

    Args:
        tokenizer: Object providing ``decode(ids) -> str``; presumably a
            Hugging Face tokenizer — confirm against callers.
        stop_word (str): Text whose appearance at the end of the decoded
            output should halt generation. An empty stop word (or one made
            only of line breaks) never matches.
    """

    def __init__(self, tokenizer, stop_word):
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        # Kept for backward compatibility with code that reads ``length``.
        self.length = len(self.stop_word)
        # ``__call__`` strips '\r' and '\n' from the decoded text; normalise
        # the stop word the same way, otherwise a stop word containing a line
        # break could never be found at the end of the text.
        self._target = self._strip_newlines(stop_word)

    @staticmethod
    def _strip_newlines(text):
        """Return ``text`` with every carriage return and line feed removed."""
        return text.replace('\r', '').replace('\n', '')

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        """Report whether the first sequence's decoded text ends with the stop word.

        Args:
            input_ids: Token ids; ``input_ids[0]`` is decoded (assumed shape
                ``(batch, seq_len)`` — TODO confirm against callers).
            *args, **kwargs: Ignored; accepted for ``StoppingCriteria``
                call-signature compatibility (e.g. ``scores``).

        Returns:
            bool: True if generation should stop.
        """
        # ``text[-0:]`` is the whole string, not an empty suffix, so an empty
        # target is handled explicitly rather than through slicing.
        if not self._target:
            return False
        # NOTE(review): the full sequence is re-decoded on every step; fine
        # for short outputs, but cost grows with the generated length.
        cur_text = self._strip_newlines(self.tokenizer.decode(input_ids[0]))
        return cur_text.endswith(self._target)