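"""Logits processors for grammar-constrained and grammar-aligned generation.

GrammarConstrainedLogitsProcessor masks out tokens the grammar parser rejects.
GrammarAlignedOracleLogitsProcessor additionally reweights accepted tokens by
their expected future grammaticality, estimated with an oracle trie.
"""
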
import copy
import math

import torch
import torch.nn.functional as F
from transformers.generation.logits_process import (
    LogitsProcessor,
    LOGITS_PROCESSOR_INPUTS_DOCSTRING,
)
from transformers.utils import add_start_docstrings

from transformers_gad.grammar_utils import IncrementalGrammarConstraint
from transformers_gad.oracle.oracle_trie import Trie


class GrammarConstrainedLogitsProcessor(LogitsProcessor):
    """Masks logits so that only grammatically valid continuations remain."""

    def __init__(self, grammar_constraint, parse_start_index=None, save_log=False):
        self.grammar_constraint = grammar_constraint
        self.batch_parsing_states = None
        self.parse_start_index = parse_start_index

        # Resolved lazily on the first call to process_scores.
        self.generate_start_index = None
        self.generated_tokens = None

        # When save_log is set, per-step acceptance details are appended to history.
        self.save_log = save_log
        self.history = []

    def reset(self):
        self.reset_parser()
        self.reset_history()

    def reset_parser(self):
        self.batch_parsing_states = None
        if self.grammar_constraint.is_incremental:
            self.grammar_constraint.reset()

        self.generate_start_index = None
        self.generated_tokens = None

    def reset_history(self):
        self.history = []

    def mask_scores(self, scores, device):
        """
        Resolve each parsing state to a boolean acceptance vector over the
        vocabulary and set the scores of all rejected tokens to -inf.
        """
        masked_scores = scores.clone()
        acceptance = self.grammar_constraint.batch_filter_vocab(
            self.batch_parsing_states, device
        )

        if self.save_log:
            self.store_detailed_history(acceptance, scores)

        # Rejected tokens receive probability zero after the softmax.
        masked_scores[~acceptance] = -math.inf

        return masked_scores

    def process_scores(self, input_ids, scores):
        # Initialize one parsing state per batch element on the first call.
        if self.batch_parsing_states is None:
            self.batch_parsing_states = [
                copy.deepcopy(
                    self.grammar_constraint.string_recognizer.get_initial_accept_state()
                )
                for _ in range(len(input_ids))
            ]

        # Record where generation starts so generated tokens can be sliced out.
        # Compare against None so that an explicit parse_start_index of 0 is honored.
        if self.generate_start_index is None:
            self.generate_start_index = (
                self.parse_start_index
                if self.parse_start_index is not None
                else input_ids.size(1)
            )
        self.generated_tokens = input_ids[:, self.generate_start_index:]

        self.batch_parsing_states = self.grammar_constraint.advance_token_ids(
            input_ids, self.batch_parsing_states, self.parse_start_index
        )

        masked_scores = self.mask_scores(scores, scores.device)
        return masked_scores

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        return self.process_scores(input_ids, scores)

    def get_accepted_tokens(self, acceptance):
        """
        Get the indices of accepted tokens and their corresponding string values
        for each item in the batch.

        Parameters:
        - acceptance (torch.Tensor): A boolean tensor indicating accepted tokens
          for each item in the batch.
        """
        batch_size, _ = acceptance.shape
        acceptance_np = acceptance.cpu().numpy()
        accepted_x, accepted_y = acceptance_np.nonzero()

        # Group accepted token ids by batch index.
        accepted_token_indices = {i: [] for i in range(batch_size)}
        for x, y in zip(accepted_x, accepted_y):
            accepted_token_indices[x].append(y)

        # Decode each accepted token id into its string form.
        accepted_tokens = {
            i: [self.grammar_constraint.tokenizer.decode([token_id]) for token_id in token_ids]
            for i, token_ids in accepted_token_indices.items()
        }

        return accepted_tokens

    def store_detailed_history(self, acceptance, scores):
        """
        Processes and stores information for accepted tokens, including their
        IDs, decoded strings, raw scores, and likelihoods.

        Parameters:
        - acceptance (torch.Tensor): A boolean tensor indicating accepted tokens
          for each item in the batch.
        - scores (torch.Tensor): The raw scores from the model output.
        """
        likelihoods = F.softmax(scores, dim=-1)

        batch_accepted_info = []

        for batch_index in range(acceptance.size(0)):
            accepted_info = []
            accepted_indices = acceptance[batch_index].nonzero().squeeze(-1)

            for idx in accepted_indices:
                token_id = idx.item()
                raw_score = scores[batch_index, idx].item()
                likelihood = likelihoods[batch_index, idx].item()
                token = self.grammar_constraint.tokenizer.decode([token_id])

                accepted_info.append({
                    "token_id": token_id,
                    "token": str(token),
                    "raw_score": raw_score,
                    "raw_likelihood": likelihood,
                })

            batch_accepted_info.append(accepted_info)

        self.history.append(batch_accepted_info)
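

# A minimal usage sketch. Assumptions (not part of this module): "gpt2" as a
# stand-in model, a GBNF-style grammar string `grammar_str`, and an
# IncrementalGrammarConstraint constructor of the form
# (grammar_str, start_rule_name, tokenizer); signatures may differ in your
# version of transformers_gad.
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     model = AutoModelForCausalLM.from_pretrained("gpt2")
#     grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
#     processor = GrammarConstrainedLogitsProcessor(grammar)
#     prompt_ids = tokenizer("Generate a boolean:", return_tensors="pt").input_ids
#     output_ids = model.generate(
#         prompt_ids,
#         logits_processor=LogitsProcessorList([processor]),
#         max_new_tokens=32,
#     )
#     processor.reset()  # clear parser state before reusing the processor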


class GrammarAlignedOracleLogitsProcessor(LogitsProcessor):
    """
    Reweights grammatically valid tokens by their expected future
    grammaticality, estimated with an oracle trie, before masking.
    """

    def __init__(self, grammar_constraint, oracle_trie=None, parse_start_index=None, save_log=False):
        self.grammar_constraint = grammar_constraint
        self.batch_parsing_states = None
        self.parse_start_index = parse_start_index

        # Avoid a mutable default argument: create a fresh Trie per processor
        # unless one is supplied (e.g. to reuse estimates across generations).
        self.oracle_trie = oracle_trie if oracle_trie is not None else Trie()

        # Resolved lazily on the first call to process_scores.
        self.generate_start_index = None
        self.generated_tokens = None

        # When save_log is set, per-step acceptance details are appended to history.
        self.save_log = save_log
        self.history = []

    def adjust_scores(self, scores, device):
        """
        Resolve each parsing state to a boolean acceptance vector over the
        vocabulary, reweight the accepted tokens by their expected future
        grammaticality, and mask out the rejected tokens.
        """
        acceptance = self.grammar_constraint.batch_filter_vocab(
            self.batch_parsing_states, device
        )

        # Locate the trie node for the current prefix, record the accepted
        # tokens with their raw scores, then reweight those scores.
        current_parent = self.oracle_trie.search_last_parent(self.generated_tokens)
        current_parent.insert_accepted_tokens(scores, acceptance)
        adjusted_scores = self.apply_oracle_adjustments(acceptance, scores, current_parent)

        if self.save_log:
            self.store_detailed_history(acceptance, scores, adjusted_scores)

        # Rejected tokens receive probability zero after the softmax.
        adjusted_scores[~acceptance] = -math.inf

        return adjusted_scores

    def apply_oracle_adjustments(self, acceptance, scores, current_parent):
        """
        Multiply each accepted token's probability by its expected future
        grammaticality, using the normalized (and unmasked) probability.
        In log space: adjusted = log p(token) + log theta(token), e.g. a token
        probability of 0.5 and a theta of 0.25 yield log(0.125).

        Parameters:
        - acceptance (torch.Tensor): A boolean vector of valid tokens,
          used to update only valid tokens
        - scores (torch.Tensor): Unnormalized logits from the language model
        - current_parent (TrieNode): The trie node for the current prefix
        """
        adjusted_scores = scores.clone()
        # log_softmax is the numerically stable form of log(softmax(...)).
        log_likelihoods = F.log_softmax(adjusted_scores, dim=-1)

        for batch_index in range(acceptance.size(0)):
            accepted_indices = acceptance[batch_index].nonzero().squeeze(-1)

            for idx in accepted_indices:
                token_id = idx.item()
                log_likelihood = log_likelihoods[batch_index, idx].item()

                # theta: the oracle's current estimate of the probability that
                # this token leads to a grammatical completion.
                success_rate = current_parent.get_success_rate(token_id)
                if not isinstance(success_rate, torch.Tensor):
                    success_rate = torch.tensor(success_rate, dtype=torch.float)
                log_theta = torch.log(success_rate)

                adjusted_scores[batch_index, idx] = log_likelihood + log_theta

        return adjusted_scores

    def process_scores(self, input_ids, scores):
        # Initialize one parsing state per batch element on the first call.
        if self.batch_parsing_states is None:
            self.batch_parsing_states = [
                copy.deepcopy(
                    self.grammar_constraint.string_recognizer.get_initial_accept_state()
                )
                for _ in range(len(input_ids))
            ]

        # Record where generation starts so generated tokens can be sliced out.
        # Compare against None so that an explicit parse_start_index of 0 is honored.
        if self.generate_start_index is None:
            self.generate_start_index = (
                self.parse_start_index
                if self.parse_start_index is not None
                else input_ids.size(1)
            )
        self.generated_tokens = input_ids[:, self.generate_start_index:]

        self.batch_parsing_states = self.grammar_constraint.advance_token_ids(
            input_ids, self.batch_parsing_states, self.parse_start_index
        )

        adjusted_scores = self.adjust_scores(scores, scores.device)

        return adjusted_scores

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        return self.process_scores(input_ids, scores)

    def reset(self):
        self.reset_parser()
        self.reset_history()

    def reset_parser(self):
        self.batch_parsing_states = None
        if self.grammar_constraint.is_incremental:
            self.grammar_constraint.reset()

        self.generate_start_index = None
        self.generated_tokens = None

    def reset_history(self):
        self.history = []

    def reset_trie(self):
        self.oracle_trie = Trie()

    def get_accepted_tokens(self, acceptance):
        """
        Get the indices of accepted tokens and their corresponding string values
        for each item in the batch.

        Parameters:
        - acceptance (torch.Tensor): A boolean tensor indicating accepted tokens
          for each item in the batch.
        """
        batch_size, _ = acceptance.shape
        acceptance_np = acceptance.cpu().numpy()
        accepted_x, accepted_y = acceptance_np.nonzero()

        # Group accepted token ids by batch index.
        accepted_token_indices = {i: [] for i in range(batch_size)}
        for x, y in zip(accepted_x, accepted_y):
            accepted_token_indices[x].append(y)

        # Decode each accepted token id into its string form.
        accepted_tokens = {
            i: [self.grammar_constraint.tokenizer.decode([token_id]) for token_id in token_ids]
            for i, token_ids in accepted_token_indices.items()
        }

        return accepted_tokens

    def store_detailed_history(self, acceptance, scores, adjusted_scores):
        """
        Processes and stores information for accepted tokens, including their
        IDs, decoded strings, raw scores, and likelihoods before and after the
        oracle adjustment.

        Parameters:
        - acceptance (torch.Tensor): A boolean tensor indicating accepted tokens
          for each item in the batch.
        - scores (torch.Tensor): The raw scores from the model output.
        - adjusted_scores (torch.Tensor): The adjusted scores after applying
          expected future grammaticality.
        """
        likelihoods = F.softmax(scores, dim=-1)
        adjusted_likelihoods = F.softmax(adjusted_scores, dim=-1)

        batch_accepted_info = []

        for batch_index in range(acceptance.size(0)):
            accepted_info = []
            accepted_indices = acceptance[batch_index].nonzero().squeeze(-1)

            for idx in accepted_indices:
                token_id = idx.item()
                raw_score = scores[batch_index, idx].item()
                likelihood = likelihoods[batch_index, idx].item()
                adjusted_likelihood = adjusted_likelihoods[batch_index, idx].item()
                token = self.grammar_constraint.tokenizer.decode([token_id])

                accepted_info.append({
                    "token_id": token_id,
                    "token": str(token),
                    "raw_score": raw_score,
                    "raw_likelihood": likelihood,
                    "adjusted_likelihood": adjusted_likelihood,
                })

            batch_accepted_info.append(accepted_info)

        self.history.append(batch_accepted_info)
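

# A minimal usage sketch for the oracle processor, continuing the sketch above
# (model, tokenizer, grammar, and prompt_ids as before; num_samples is a
# hypothetical count). Repeated sampling is shown because each generation
# updates the shared trie's grammaticality estimates, so later samples can be
# reweighted more accurately; whether and how you iterate may differ in your
# setup.
#
#     trie = Trie()
#     processor = GrammarAlignedOracleLogitsProcessor(grammar, oracle_trie=trie)
#     for _ in range(num_samples):
#         output_ids = model.generate(
#             prompt_ids,
#             logits_processor=LogitsProcessorList([processor]),
#             do_sample=True,
#             max_new_tokens=32,
#         )
#         processor.reset()  # resets parser and log, but keeps the shared trie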