import copy
import math

import torch
import torch.nn.functional as F
from transformers.generation.logits_process import (
LogitsProcessor,
LOGITS_PROCESSOR_INPUTS_DOCSTRING,
)
from transformers.utils import add_start_docstrings
from transformers_gad.grammar_utils import IncrementalGrammarConstraint
from transformers_gad.oracle.oracle_trie import Trie
class GrammarConstrainedLogitsProcessor(LogitsProcessor):
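    """
    A logits processor for grammar-constrained decoding: at each generation
    step, tokens that cannot extend a grammatical prefix have their scores
    masked to -inf.
    """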
def __init__(self, grammar_constraint, parse_start_index=None, save_log=False):
# Parser variables
self.grammar_constraint = grammar_constraint
self.batch_parsing_states = None
self.parse_start_index = parse_start_index
# To start with a longer prefix in enumerative search
self.generate_start_index = None
self.generated_tokens = None
# Generation Log
self.save_log = save_log
self.history = []
def reset(self):
self.reset_parser()
self.reset_history()
def reset_parser(self):
self.batch_parsing_states = None
if self.grammar_constraint.is_incremental:
self.grammar_constraint.reset()
self.generate_start_index = None
self.generated_tokens = None
def reset_history(self):
self.history = []
def mask_scores(self, scores, device):
"""
resolve each stack to a tensor of True/False for each token
indicating acceptance
"""
masked_scores = scores.clone()
acceptance = self.grammar_constraint.batch_filter_vocab(
self.batch_parsing_states, device
)
if self.save_log:
self.store_detailed_history(acceptance, scores)
# Scores to -inf where False
masked_scores[~acceptance] = -math.inf
return masked_scores
def process_scores(self, input_ids, scores):
# we dynamically create stacks at the first call, so that we know the batch size and beam size
if self.batch_parsing_states is None:
self.batch_parsing_states = [
copy.deepcopy(
self.grammar_constraint.string_recognizer.get_initial_accept_state()
)
for _ in range(len(input_ids))
]
# assume the generation starts from the same index
if self.generate_start_index is None:
# the default is the end of input sequence of tokens
            self.generate_start_index = self.parse_start_index \
                if self.parse_start_index is not None else input_ids.size(1)
self.generated_tokens = input_ids[:, self.generate_start_index:]
# Advance parser states
self.batch_parsing_states = self.grammar_constraint.advance_token_ids(
input_ids, self.batch_parsing_states, self.parse_start_index
)
masked_scores = self.mask_scores(scores, scores.device)
return masked_scores
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
return self.process_scores(input_ids, scores)
def get_accepted_tokens(self, acceptance):
"""
Get the indices of accepted tokens and their corresponding string values for each item in the batch.
Parameters:
- acceptance (torch.Tensor): A boolean tensor indicating accepted tokens for each item in the batch.
"""
batch_size, _ = acceptance.shape
acceptance_np = acceptance.cpu().numpy()
accepted_x, accepted_y = acceptance_np.nonzero()
# Initialize the dictionary with empty lists for indices
accepted_token_indices = {i: [] for i in range(batch_size)}
for x, y in zip(accepted_x, accepted_y):
accepted_token_indices[x].append(y)
# Convert token IDs to tokens
accepted_tokens = {
i: [self.grammar_constraint.tokenizer.decode([token_id]) for token_id in token_ids]
for i, token_ids in accepted_token_indices.items()
}
return accepted_tokens
def store_detailed_history(self, acceptance, scores):
"""
        Processes and stores information for accepted tokens, including their
        IDs, string values, raw scores, and likelihoods.
        Parameters:
        - acceptance (torch.Tensor): A boolean tensor indicating accepted tokens for each item in the batch.
        - scores (torch.Tensor): The raw scores from the model output.
"""
likelihoods = F.softmax(scores, dim=-1)
# Initializing the list to store detailed information for each step
batch_accepted_info = []
for batch_index in range(acceptance.size(0)): # Iterate over batch items
accepted_info = []
accepted_indices = acceptance[batch_index].nonzero().squeeze(-1)
for idx in accepted_indices:
token_id = idx.item()
raw_score = scores[batch_index, idx].item()
likelihood = likelihoods[batch_index, idx].item()
token = self.grammar_constraint.tokenizer.decode([token_id])
# Store detailed information as a dictionary
accepted_info.append({
"token_id": token_id,
"token": str(token),
"raw_score": raw_score,
"raw_likelihood": likelihood
})
batch_accepted_info.append(accepted_info)
# Store this detailed information in the history
self.history.append(batch_accepted_info)
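
# --- Usage sketch (illustrative only) ---
# A minimal example of wiring the processor into `generate`. The checkpoint
# name, prompt, and grammar below are placeholder assumptions; any causal LM
# and any GBNF-style grammar accepted by IncrementalGrammarConstraint should
# work the same way.
def _example_constrained_generation():
    from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

    model_id = "gpt2"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    # A toy grammar: the continuation must be exactly "yes" or "no".
    grammar_str = 'root ::= "yes" | "no"'
    grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
    processor = GrammarConstrainedLogitsProcessor(grammar)

    inputs = tokenizer("Is the sky blue? Answer: ", return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_new_tokens=10,
        logits_processor=LogitsProcessorList([processor]),
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
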
class GrammarAlignedOracleLogitsProcessor(LogitsProcessor):
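    """
    A logits processor for grammar-aligned decoding with adaptive sampling
    (ASAp): grammar-accepted tokens are additionally reweighted by their
    estimated expected future grammaticality, which is tracked in an oracle
    trie that can be reused (and refined) across repeated generations.
    """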
    def __init__(self, grammar_constraint, oracle_trie=None, parse_start_index=None, save_log=False):
# Parser variables
self.grammar_constraint = grammar_constraint
self.batch_parsing_states = None
self.parse_start_index = parse_start_index
        # ASAp oracle trie (created per instance; a mutable default argument
        # would be shared across all instances)
        self.oracle_trie = oracle_trie if oracle_trie is not None else Trie()
# To start with a longer prefix in enumerative search
self.generate_start_index = None
self.generated_tokens = None
# Generation Log
self.save_log = save_log
self.history = []
def adjust_scores(self, scores, device):
"""
resolve each stack to a tensor of True/False for each token
indicating acceptance
"""
acceptance = self.grammar_constraint.batch_filter_vocab(
self.batch_parsing_states, device
)
current_parent = self.oracle_trie.search_last_parent(self.generated_tokens)
current_parent.insert_accepted_tokens(scores, acceptance)
adjusted_scores = self.apply_oracle_adjustments(acceptance, scores, current_parent)
if self.save_log:
self.store_detailed_history(acceptance, scores, adjusted_scores)
# Scores to -inf where False
adjusted_scores[~acceptance] = -math.inf
return adjusted_scores
def apply_oracle_adjustments(self, acceptance, scores, current_parent):
"""
        Multiply in the expected future grammaticality, using the
        normalized (and unmasked) probability.
        Parameters:
        - acceptance (torch.Tensor): A boolean mask of grammar-accepted tokens,
          used to update only valid tokens
- scores (torch.Tensor): Unnormalized logits from language model
- current_parent (TrieNode): The trie node for the current prefix
"""
adjusted_scores = scores.clone()
        # log_softmax is numerically safer than taking log(softmax(...))
        log_likelihoods = F.log_softmax(adjusted_scores, dim=-1)
for batch_index in range(acceptance.size(0)):
accepted_indices = acceptance[batch_index].nonzero().squeeze(-1)
for idx in accepted_indices:
token_id = idx.item()
log_likelihood = log_likelihoods[batch_index, idx].item()
# Get theta (log of expected future grammaticality) for this specific token
success_rate = current_parent.get_success_rate(token_id)
if not isinstance(success_rate, torch.Tensor):
success_rate = torch.tensor(success_rate, dtype=torch.float)
log_theta = torch.log(success_rate)
# Calculate adjusted score
adjusted_score = log_likelihood + log_theta
adjusted_scores[batch_index, idx] = adjusted_score
return adjusted_scores
def process_scores(self, input_ids, scores):
# we dynamically create stacks at the first call, so that we know the batch size and beam size
if self.batch_parsing_states is None:
self.batch_parsing_states = [
copy.deepcopy(
self.grammar_constraint.string_recognizer.get_initial_accept_state()
)
for _ in range(len(input_ids))
]
# assume the generation starts from the same index
if self.generate_start_index is None:
# the default is the end of input sequence of tokens
            self.generate_start_index = self.parse_start_index \
                if self.parse_start_index is not None else input_ids.size(1)
self.generated_tokens = input_ids[:, self.generate_start_index:]
# Advance parser states
self.batch_parsing_states = self.grammar_constraint.advance_token_ids(
input_ids, self.batch_parsing_states, self.parse_start_index
)
adjusted_scores = self.adjust_scores(scores, scores.device)
return adjusted_scores
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
return self.process_scores(input_ids, scores)
def reset(self):
self.reset_parser()
self.reset_history()
def reset_parser(self):
self.batch_parsing_states = None
if self.grammar_constraint.is_incremental:
self.grammar_constraint.reset()
self.generate_start_index = None
self.generated_tokens = None
def reset_history(self):
self.history = []
def reset_trie(self):
self.oracle_trie = Trie()
def get_accepted_tokens(self, acceptance):
"""
Get the indices of accepted tokens and their corresponding string values for each item in the batch.
Parameters:
- acceptance (torch.Tensor): A boolean tensor indicating accepted tokens for each item in the batch.
"""
batch_size, _ = acceptance.shape
acceptance_np = acceptance.cpu().numpy()
accepted_x, accepted_y = acceptance_np.nonzero()
# Initialize the dictionary with empty lists for indices
accepted_token_indices = {i: [] for i in range(batch_size)}
for x, y in zip(accepted_x, accepted_y):
accepted_token_indices[x].append(y)
# Convert token IDs to tokens
accepted_tokens = {
i: [self.grammar_constraint.tokenizer.decode([token_id]) for token_id in token_ids]
for i, token_ids in accepted_token_indices.items()
}
return accepted_tokens
def store_detailed_history(self, acceptance, scores, adjusted_scores):
"""
        Processes and stores information for accepted tokens, including their
        IDs, string values, raw scores, and raw/adjusted likelihoods.
Parameters:
- acceptance (torch.Tensor): A boolean tensor indicating accepted tokens for each item in the batch.
- scores (torch.Tensor): The raw scores from the model output.
- adjusted_scores (torch.Tensor): The adjusted scores after applying expected future grammaticality.
"""
likelihoods = F.softmax(scores, dim=-1)
adjusted_likelihoods = F.softmax(adjusted_scores, dim=-1)
# Initializing the list to store detailed information for each step
batch_accepted_info = []
for batch_index in range(acceptance.size(0)): # Iterate over batch items
accepted_info = []
accepted_indices = acceptance[batch_index].nonzero().squeeze(-1)
for idx in accepted_indices:
token_id = idx.item()
raw_score = scores[batch_index, idx].item()
likelihood = likelihoods[batch_index, idx].item()
adjusted_likelihood = adjusted_likelihoods[batch_index, idx].item()
token = self.grammar_constraint.tokenizer.decode([token_id])
# Store detailed information as a dictionary
accepted_info.append({
"token_id": token_id,
"token": str(token),
"raw_score": raw_score,
"raw_likelihood": likelihood,
"adjusted_likelihood": adjusted_likelihood
})
batch_accepted_info.append(accepted_info)
# Store this detailed information in the history
self.history.append(batch_accepted_info)
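
# --- ASAp usage sketch (illustrative only) ---
# Grammar-aligned decoding draws repeated samples while reusing one oracle
# trie, so the success-rate estimates (expected future grammaticality) are
# refined between iterations; the parser state is reset per sample, the trie
# is not. The setup mirrors the sketch above and is an assumption about how
# the pieces fit together, not prescribed API usage.
def _example_grammar_aligned_sampling(model, tokenizer, grammar_str, prompt, n_samples=8):
    from transformers import LogitsProcessorList

    grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
    trie = Trie()  # shared across samples so estimates accumulate
    samples = []
    for _ in range(n_samples):
        processor = GrammarAlignedOracleLogitsProcessor(grammar, oracle_trie=trie)
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(
            **inputs,
            do_sample=True,
            max_new_tokens=64,
            logits_processor=LogitsProcessorList([processor]),
        )
        processor.reset_parser()  # clear per-sample parser state, keep the trie
        samples.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
    return samples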