|
'''
|
|
This file is copied verbatim from the following pull request to the Transformers library:
|
|
https://github.com/huggingface/transformers/pull/27557
|
|
|
|
Author: Saibo-creator
|
|
Author GitHub: https://github.com/Saibo-creator
|
|
|
|
All credits go to the author.
|
|
'''
|
|
|
|
import math
|
|
|
|
import torch
|
|
from transformers.generation.logits_process import LogitsProcessor
|
|
from transformers.utils import add_start_docstrings
|
|
|
|
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
|
|
Args:
|
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
|
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
|
|
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
|
|
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
|
|
search or log softmax for each vocabulary token when using beam search
|
|
|
|
Return:
|
|
`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
|
|
|
|
"""
|
|
|
|
|
|
class GrammarConstrainedLogitsProcessor(LogitsProcessor):
    """Logits processor that masks out every token the grammar does not allow next.

    One parser state ("stack") is kept per batch row in ``self.batch_stacks``.
    On the first call the stacks are created and optionally fed a prefix of
    ``input_ids``. On every later call exactly one new token per row is
    expected; it is fed to that row's stack before the scores are masked.

    An instance tracks a single generation run. To constrain another input
    sequence, create a new instance.

    NOTE(review): stacks are matched to rows purely by position. If rows are
    reordered between calls (e.g. by beam search), the stacks are not
    reordered with them. Confirm whether this matters for your use case.
    """

    # Class-level default so subclasses that do not call this class's
    # ``__init__`` keep the original behaviour of ``__call__`` (parse nothing).
    parse_start_index = None

    def __init__(self, grammar_constraint, parse_start_index=None):
        """
        :param grammar_constraint: object providing ``init_stacks``,
            ``accept_token_ids``, ``accept_token_id`` and ``batch_filter_vocab``.
        :param parse_start_index: offset into each row of ``input_ids`` from
            which existing tokens are fed to the grammar on the first call made
            through ``__call__``. Default None means generate from scratch
            (no existing token is parsed). Set to 0 to parse all input_ids.
        """
        # Length of input_ids seen at the previous call; None until the first call.
        self.last_size = None
        self.grammar_constraint = grammar_constraint
        # One parser stack per batch row, created lazily on the first call
        # because the batch size is only known then.
        self.batch_stacks = None
        self.parse_start_index = parse_start_index

    def filter_logits(self, logits, device):
        """Set, in place, the logits of every token rejected by the grammar to -inf.

        :param logits: scores tensor; modified in place.
        :param device: device forwarded to ``batch_filter_vocab`` for building
            the acceptance mask.
        """
        # True where the row's stack accepts the token. Assumed to be a boolean
        # tensor with the same shape as ``logits`` — TODO confirm.
        acceptance = self.grammar_constraint.batch_filter_vocab(self.batch_stacks, device)
        logits[~acceptance] = -math.inf

    def process_logits(self, input_ids, scores, parse_start_index=None):
        """Advance the per-row grammar stacks with the newest tokens, then mask ``scores``.

        :param input_ids: token ids so far, indexed as ``input_ids[row][position]``.
        :param scores: next-token scores for each row; masked in place.
        :param parse_start_index: only used on the first call. Default None,
            which means generate from scratch. Set to 0 to parse all input_ids,
            or to ``k`` to parse ``input_ids[row][k:]``.
        :return: ``scores`` with grammar-rejected tokens set to ``-inf``.
        :raises RuntimeError: if the batch size changed since the stacks were
            created, or if the sequence length did not grow by exactly one
            token since the previous call.
        """
        # Stacks are created on the first call, once the batch size is known.
        if self.batch_stacks is None:
            self.batch_stacks = [self.grammar_constraint.init_stacks() for _ in range(len(input_ids))]

        # zip() below would silently drop rows on a batch-size mismatch,
        # leaving the stacks misaligned with the rows of ``scores``.
        if len(input_ids) != len(self.batch_stacks):
            raise RuntimeError(
                f"Batch size changed from {len(self.batch_stacks)} to {len(input_ids)} "
                "between calls to GrammarConstrainedLogitsProcessor. Please instantiate "
                "a new GrammarConstrainedLogitsProcessor for each generation."
            )

        if self.last_size is None:
            # First call: feed the requested prefix of every row (nothing by default).
            prefix_to_parse = [
                single_input_ids[parse_start_index:] if parse_start_index is not None else []
                for single_input_ids in input_ids
            ]
            self.batch_stacks = [
                self.grammar_constraint.accept_token_ids(prefix, stack)
                for prefix, stack in zip(prefix_to_parse, self.batch_stacks)
            ]
        elif len(input_ids[0]) == self.last_size + 1:
            # Incremental step: exactly one new token per row since the last call.
            self.batch_stacks = [
                self.grammar_constraint.accept_token_id(single_input_ids[-1], stack)
                for single_input_ids, stack in zip(input_ids, self.batch_stacks)
            ]
        else:
            # Any other length means this instance is being reused for a
            # different sequence, or a step was skipped.
            raise RuntimeError(
                "Input ID's length is inconsistent with the current state of "
                "the GrammarConstrainedLogitsProcessor. If you want to process "
                "another input sequence, please instantiate a new "
                "GrammarConstrainedLogitsProcessor."
            )

        self.filter_logits(scores, scores.device)

        self.last_size = len(input_ids[0])
        return scores

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # __call__ only receives input_ids and scores, so the prompt-parsing
        # offset comes from the value supplied at construction time.
        return self.process_logits(input_ids, scores, parse_start_index=self.parse_start_index)
|
|
|