File size: 485 Bytes
71e7434 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
import torch
from transformers import StoppingCriteria
class StopWordsCriteria(StoppingCriteria):
    """Stop generation once the first sequence ends with a given token-id suffix.

    Batched generation is not supported: only row 0 of ``input_ids`` is
    inspected and all other rows are ignored.
    """

    def __init__(self, stop_indices: list):
        """Store the stop sequence.

        Args:
            stop_indices: Token ids of the stop sequence, in generation order.
                Any iterable is accepted; it is copied into a new list. An
                empty stop sequence never triggers a stop.
        """
        super().__init__()
        # Copy so later mutation of the caller's list cannot silently change
        # the stop sequence this criterion checks for.
        self.stop_indices = list(stop_indices)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return whether ``input_ids[0]`` currently ends with the stop sequence.

        Args:
            input_ids: Generated token ids. Only ``input_ids[0]`` is examined;
                presumably shaped (batch, seq_len) -- TODO confirm with caller.
            scores: Unused.
            **kwargs: Unused.

        Returns:
            ``True`` if the last ``len(self.stop_indices)`` tokens of row 0
            equal ``self.stop_indices``. ``False`` otherwise, including when
            the stop sequence is empty or longer than the generated row.
        """
        stop_len = len(self.stop_indices)
        sequence = input_ids[0]
        # An empty stop sequence would otherwise "match" immediately and end
        # generation on the first step. A row shorter than the stop sequence
        # cannot end with it, and indexing past its start would raise.
        if stop_len == 0 or len(sequence) < stop_len:
            return False
        # Compare the trailing window as plain Python values in one step.
        return sequence[-stop_len:].tolist() == list(self.stop_indices)
|