Spaces:
Runtime error
Runtime error
File size: 1,067 Bytes
efe0924 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
import traceback
from queue import Queue
from threading import Thread
import collections.abc
import torch
from transformers import StoppingCriteria
class StoppingCriteriaSub(StoppingCriteria):
def __init__(self, stops=[], encounters=[]):
super().__init__()
assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
self.encounters = encounters
self.stops = [stop.to("cuda") for stop in stops]
self.num_stops = [0] * len(stops)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
for stopi, stop in enumerate(self.stops):
if torch.all((stop == input_ids[0][-len(stop):])).item():
self.num_stops[stopi] += 1
if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
return True
# print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
# print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
return False
|