OMG-LLaVA / xtuner /utils /stop_criteria.py
zhangtao-whu's picture
Upload folder using huggingface_hub
476ac07 verified
raw
history blame contribute delete
576 Bytes
# Copyright (c) OpenMMLab. All rights reserved.
from transformers import StoppingCriteria
class StopWordStoppingCriteria(StoppingCriteria):
    """StopWord stopping criteria.

    Signals generation to stop once the decoded text of the first sequence
    ends with ``stop_word``. Carriage returns and line feeds are removed from
    both the decoded text and the stop word before comparing, so newlines
    emitted around (or inside) the stop word do not prevent a match.

    Args:
        tokenizer: Object exposing ``decode(ids) -> str``; used to turn the
            generated ids back into text.
        stop_word (str): Text whose appearance at the end of the decoded
            sequence ends generation.
    """

    def __init__(self, tokenizer, stop_word):
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        self.length = len(self.stop_word)
        # The decoded text is newline-stripped before matching, so the stop
        # word must be stripped the same way; otherwise a stop word that
        # contains '\r' or '\n' could never match.
        self._match_word = self._strip_newlines(stop_word)

    @staticmethod
    def _strip_newlines(text):
        """Return ``text`` with every carriage return and line feed removed."""
        return text.replace('\r', '').replace('\n', '')

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        """Return True when the decoded sequence ends with the stop word.

        Args:
            input_ids: Generated token ids. Only ``input_ids[0]`` is decoded,
                so in batched generation the first sequence alone decides.

        Returns:
            bool: Whether generation should stop.
        """
        if not self.stop_word:
            # An empty stop word would trivially "match" every text; treat
            # it as never matching instead.
            return False
        raw_text = self.tokenizer.decode(input_ids[0])
        if not self._match_word:
            # The stop word consists only of newlines, which normalization
            # would erase; compare against the raw decoded text instead.
            return raw_text.endswith(self.stop_word)
        cur_text = self._strip_newlines(raw_text)
        return cur_text.endswith(self._match_word)