from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
from dataclasses import dataclass
from typing import List, Optional
from utils import get_preprocess_function, get_utterance_processing_functions, byt5_decode_batch, consistent
from utils import PROGRAM_SPECIAL_TOKEN, UTTERANCES_SPECIAL_TOKEN, GT_PROGRAM_SPECIAL_TOKEN
from greenery import parse
from greenery.parse import NoMatch
import numpy as np
import torch


class Agent:
    """Thin wrapper holding a seq2seq model, its tokenizer, and generation settings."""

    def __init__(self,
                 model_path: str,
                 gen_config: dict,
                 inference_batch_size: int = 1,
                 ):
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.gen_config = GenerationConfig(**gen_config)
        self.inference_batch_size = inference_batch_size


@dataclass
class ListenerOutput:
    programs: List[List[str]]
    idx: Optional[List[List[int]]] = None
    decoded: Optional[List[List[str]]] = None
    decoded_scores: Optional[List[List[float]]] = None
    pruned: Optional[List[List[str]]] = None


class Listener(Agent):
    def __init__(self,
                 model_path,
                 gen_config,
                 inference_batch_size=4,
                 label_pos="suffix",
                 idx: bool = True,
                 program_special_token=PROGRAM_SPECIAL_TOKEN,
                 utterances_special_token=UTTERANCES_SPECIAL_TOKEN
                 ):
        super().__init__(
            model_path,
            gen_config,
            inference_batch_size,
        )
        self.label_pos = label_pos
        self.idx = idx
        self.program_special_token = program_special_token
        self.utterances_special_token = utterances_special_token
        # Converters between structured utterance lists and the flat string
        # format the model consumes.
        self.utterances_to_string, self.string_to_utterances = (
            get_utterance_processing_functions(
                label_pos, idx, separator=utterances_special_token
            )
        )
        self.device = self.model.device

    def synthesize(self, context, return_scores=False, enforce_consistency=True):
        # If the context is given as structured utterance lists, flatten each
        # one to the string format the model expects.
        if isinstance(context[0], list):
            context_str = list(map(self.utterances_to_string, context))
        else:
            context_str = context

        # Prepend the utterances special token where missing, then batch-tokenize.
        context_tokens = self.tokenizer(
            [f"{self.utterances_special_token}{c}"
             if not c.startswith(self.utterances_special_token) else c
             for c in context_str],
            return_tensors="pt",
            padding=True
        ).to(self.device)

        # Seed decoding with the program special token so generation starts
        # directly in program space.
        decoder_inputs = self.tokenizer(
            [self.program_special_token for _ in context],
            return_tensors="pt",
            add_special_tokens=False
        ).to(self.device)

        outputs = self.model.generate(
            **context_tokens,
            decoder_input_ids=decoder_inputs.input_ids,
            generation_config=self.gen_config,
            return_dict_in_generate=True,
            output_scores=True
        )
        # Regroup the flat batch of sequences per context item, then decode.
        decoded_batch = byt5_decode_batch(
            outputs.sequences.reshape(
                (len(context), -1, outputs.sequences.shape[-1])
            ).tolist(),
            skip_position_token=True,
            skip_special_tokens=True
        )
        # Keep the programs (and their candidate indices) that are consistent
        # with their context; with enforce_consistency=False, keep everything.
        consistent_programs = []
        idxs = []
        for decoded, ctx in zip(decoded_batch, context):
            cp = []
            idx = []
            for i, p in enumerate(decoded):
                if not enforce_consistency or consistent(p, ctx):
                    cp.append(p)
                    idx.append(i)
            consistent_programs.append(cp)
            idxs.append(idx)
        # Per-token log-probabilities of the generated tokens. Position 0 of
        # each sequence is the forced program special token, so the scores
        # align with sequences[:, 1:]; -inf entries (padded or forced
        # positions) are zeroed before summing to sequence-level scores.
        logprobs = torch.stack(outputs.scores, dim=1).log_softmax(dim=-1)
        gen_probs = torch.gather(logprobs, 2, outputs.sequences[:, 1:, None]).squeeze(-1)
        gen_probs.masked_fill_(gen_probs.isinf(), 0)
        scores = gen_probs.sum(-1)

        # Reshape flat scores back to (n_contexts, n_sequences_per_context).
        n_decoded = scores.shape[0]
        n_seq = n_decoded // len(context)
        scores = scores.reshape((len(context), n_seq))
        scores_list = scores.tolist()
        if return_scores:
            return ListenerOutput(
                consistent_programs,
                idxs,
                decoded_batch,
                scores_list
            )
        else:
            return ListenerOutput(consistent_programs)
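

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The checkpoint path, generation settings,
# and context string below are placeholders, not values from this repo; the
# real utterance encoding is whatever get_utterance_processing_functions
# produces for your data.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    listener = Listener(
        model_path="checkpoints/listener",  # placeholder checkpoint path
        gen_config={
            "max_new_tokens": 128,
            "num_beams": 4,
            "num_return_sequences": 4,
        },
    )
    # Pre-formatted context strings are passed through unchanged by
    # synthesize; consistency filtering is disabled here so the example does
    # not depend on the semantics of utils.consistent.
    contexts = ["<utterance-string>"]  # placeholder context encoding
    out = listener.synthesize(contexts, return_scores=True,
                              enforce_consistency=False)
    print(out.decoded)         # all decoded candidate programs per context
    print(out.decoded_scores)  # sequence-level log-probabilities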