import torch
import torch.nn.functional as F
from transformers.generation import TopKLogitsWarper, TopPLogitsWarper


class CustomRepetitionPenaltyLogitsProcessorRepeat:
    """Repetition penalty that scales with how often a token has appeared in a
    sliding window over the most recent `past_window` generated tokens."""

    def __init__(self, penalty: float, max_input_ids: int, past_window: int):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(
                f"`penalty` has to be a strictly positive float, but is {penalty}"
            )
        self.penalty = penalty
        self.max_input_ids = max_input_ids
        self.past_window = past_window

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        # Only consider the last `past_window` tokens when counting repeats.
        if input_ids.size(1) > self.past_window:
            input_ids = input_ids.narrow(1, -self.past_window, self.past_window)
        # Per-token occurrence counts within the window: shape (batch, vocab).
        freq = F.one_hot(input_ids, scores.size(1)).sum(1)
        if freq.size(0) > self.max_input_ids:
            freq.narrow(
                0, self.max_input_ids, freq.size(0) - self.max_input_ids
            ).zero_()
        # Penalty factor grows exponentially with the repetition count.
        alpha = torch.pow(self.penalty, freq)
        scores = scores.contiguous()
        inp = scores.multiply(alpha)
        oth = scores.divide(alpha)
        con = scores < 0
        # Negative logits are multiplied by alpha, positive logits divided by it,
        # so repeated tokens are pushed toward lower probability either way.
        out = torch.where(con, inp, oth)
        del inp, oth, scores, con, alpha
        return out


def gen_logits(
    num_code: int,
    top_P=0.7,
    top_K=20,
    repetition_penalty=1.0,
):
    """Build the logits warpers (top-p / top-k) and logits processors
    (repetition penalty) used during sampling."""
    logits_warpers = []
    if top_P is not None:
        logits_warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
    if top_K is not None:
        logits_warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))

    logits_processors = []
    if repetition_penalty is not None and repetition_penalty != 1:
        logits_processors.append(
            CustomRepetitionPenaltyLogitsProcessorRepeat(
                repetition_penalty, num_code, 16
            )
        )

    return logits_warpers, logits_processors
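

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): it shows how the
    # returned processors and warpers could be chained over raw logits in a
    # sampling loop. The vocabulary size (626), batch/sequence shapes, and the
    # repetition_penalty value below are illustrative assumptions, not values
    # taken from the source.
    num_code = 626
    logits_warpers, logits_processors = gen_logits(
        num_code=num_code, top_P=0.7, top_K=20, repetition_penalty=1.05
    )

    input_ids = torch.randint(0, num_code, (1, 32))  # hypothetical decoded ids
    scores = torch.randn(1, num_code)                # hypothetical next-token logits

    # Processors (here, the repetition penalty) are applied before the warpers,
    # mirroring the usual transformers sampling order.
    for processor in logits_processors:
        scores = processor(input_ids, scores)
    for warper in logits_warpers:
        scores = warper(input_ids, scores)

    # Filtered positions are -inf after top-k / top-p, so softmax zeroes them out.
    probs = torch.softmax(scores, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    print(next_token)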