# exbert/server/transformer_details.py
# (Hugging Face repo-page header residue: author bhoov, "First commit", rev 63858e7)
"""
Utilities for interfacing with the attentions from the front end.
"""
import torch
from typing import List, Union
from abc import ABC, abstractmethod
from transformer_formatter import TransformerOutputFormatter
from utils.token_processing import reshape
from spacyface import (
BertAligner,
GPT2Aligner,
RobertaAligner,
DistilBertAligner,
auto_aligner
)
from transformers import (
BertForMaskedLM,
GPT2LMHeadModel,
RobertaForMaskedLM,
DistilBertForMaskedLM,
)
from utils.f import delegates, pick, memoize
def get_cls(class_name):
    """Return the `*Details` class responsible for the given model name.

    Raises:
        KeyError: if `class_name` is not a supported model name.
    """
    # Supported models, grouped by the Details class that handles them.
    groups = (
        (BertDetails, (
            'bert-base-uncased',
            'bert-base-cased',
            'bert-large-uncased',
            'bert-large-cased',
        )),
        (GPT2Details, (
            'gpt2',
            'gpt2-medium',
            'gpt2-large',
            'distilgpt2',
        )),
        (RobertaDetails, (
            'roberta-base',
            'roberta-large',
            'roberta-large-mnli',
            'roberta-base-openai-detector',
            'roberta-large-openai-detector',
            'distilroberta-base',
        )),
        (DistilBertDetails, (
            'distilbert-base-uncased',
            'distilbert-base-uncased-distilled-squad',
        )),
    )
    # Flatten into a name -> class lookup; unsupported names raise KeyError.
    lookup = {name: details for details, names in groups for name in names}
    return lookup[class_name]
@memoize
def from_pretrained(model_name):
    """Convert model name into appropriate transformer details.

    Memoized so each model is loaded at most once per process.

    Raises:
        KeyError: if `model_name` is not in the supported-model table.
    """
    # Resolve the class first, then load OUTSIDE the try block: previously the
    # whole expression was wrapped in one try, so a KeyError raised while
    # actually loading the model was misreported as "unsupported model name".
    try:
        details_cls = get_cls(model_name)
    except KeyError:
        raise KeyError(f"The model name of '{model_name}' either does not exist or is currently not supported")
    return details_cls.from_pretrained(model_name)
class TransformerBaseDetails(ABC):
    """ All API calls will interact with this class to get the hidden states and attentions for any input sentence."""

    def __init__(self, model, aligner):
        # model: a HuggingFace LM-head model created with output_attentions /
        #   output_hidden_states / output_additional_info (see subclasses).
        # aligner: the matching spacyface aligner/tokenizer for that model.
        self.model = model
        self.aligner = aligner
        self.model.eval()  # inference only — disable dropout / training behavior
        # Keys picked out of the parsed inputs for the model's forward pass
        self.forward_inputs = ['input_ids', 'attention_mask']

    @classmethod
    def from_pretrained(cls, model_name: str):
        """Abstract alternate constructor — each subclass binds a Model and an Aligner."""
        raise NotImplementedError(
            """Inherit from this class and specify the Model and Aligner to use"""
        )

    def att_from_sentence(self, s: str, mask_attentions=False) -> TransformerOutputFormatter:
        """Get formatted attention from a single sentence input"""
        tokens = self.aligner.tokenize(s)
        return self.att_from_tokens(tokens, s, add_special_tokens=True, mask_attentions=mask_attentions)

    def att_from_tokens(
        self, tokens: List[str], orig_sentence, add_special_tokens=False, mask_attentions=False
    ) -> TransformerOutputFormatter:
        """Get formatted attention from a list of tokens, using the original sentence for getting Spacy Metadata"""
        ids = self.aligner.convert_tokens_to_ids(tokens)

        # For GPT2, add the beginning of sentence token to the input. Note that this will work on all models but XLM
        bost = self.aligner.bos_token_id
        clst = self.aligner.cls_token_id
        # Only prepend BOS when the tokenizer has one that is distinct from CLS
        # (when BOS == CLS, prepare_for_model below adds the special token itself).
        if (bost is not None) and (bost != clst) and add_special_tokens:
            ids.insert(0, bost)

        inputs = self.aligner.prepare_for_model(ids, add_special_tokens=add_special_tokens, return_tensors="pt")
        # NOTE: format_model_input -> parse_inputs also MUTATES `inputs`,
        # adding a 'special_tokens_mask' key that format_model_output reads below.
        parsed_input = self.format_model_input(inputs, mask_attentions=mask_attentions)
        output = self.model(parsed_input['input_ids'], attention_mask=parsed_input['attention_mask'])
        return self.format_model_output(inputs, orig_sentence, output)

    def format_model_output(self, inputs, sentence:str, output, topk=5):
        """Convert model output to the desired format.

        Formatter additionally needs access to the tokens and the original sentence
        """
        hidden_state, attentions, contexts, logits = self.select_outputs(output)

        # Top-k predicted words and their probabilities at every position
        words, probs = self.logits2words(logits, topk)

        tokens = self.view_ids(inputs["input_ids"])
        # spaCy-derived metadata for each token (via the spacyface aligner)
        toks = self.aligner.meta_from_tokens(sentence, tokens, perform_check=False)

        # 'special_tokens_mask' exists here because parse_inputs added it to
        # `inputs` as a side effect during format_model_input.
        formatted_output = TransformerOutputFormatter(
            sentence,
            toks,
            inputs["special_tokens_mask"],
            attentions,
            hidden_state,
            contexts,
            words,
            probs.tolist()
        )
        return formatted_output

    def select_outputs(self, output):
        """Extract the desired hidden states as passed by a particular model through the output

        In all cases, we care for:
            - hidden state embeddings (tuple of n_layers + 1)
            - attentions (tuple of n_layers)
            - contexts (tuple of n_layers)
            - Top predicted words
            - Probabilities of top predicted words
        """
        # Default tuple layout; subclasses override when their model's output
        # tuple is ordered differently (e.g. GPT-2 has an extra element).
        logits, hidden_state, attentions, contexts = output
        return hidden_state, attentions, contexts, logits

    def format_model_input(self, inputs, mask_attentions=False):
        """Parse the input for the model according to what is expected in the forward pass.

        If not otherwise defined, outputs a dict containing the keys:

        {'input_ids', 'attention_mask'}
        """
        return pick(self.forward_inputs, self.parse_inputs(inputs, mask_attentions=mask_attentions))

    def logits2words(self, logits, topk=5):
        # Drop the batch dim, softmax over the vocab dim, take top-k per position.
        probs, idxs = torch.topk(torch.softmax(logits.squeeze(0), 1), topk)
        words = [self.aligner.convert_ids_to_tokens(i) for i in idxs]
        return words, probs

    def view_ids(self, ids: Union[List[int], torch.Tensor]) -> List[str]:
        """View what the tokenizer thinks certain ids are"""
        if type(ids) == torch.Tensor:
            # Remove batch dimension
            ids = ids.squeeze(0).tolist()

        out = self.aligner.convert_ids_to_tokens(ids)
        return out

    def parse_inputs(self, inputs, mask_attentions=False):
        """Parse the output from `tokenizer.prepare_for_model` to the desired attention mask from special tokens

        Args:
            - inputs: The output of `tokenizer.prepare_for_model`.
                A dict with keys: {'special_token_mask', 'token_type_ids', 'input_ids'}
            - mask_attentions: Flag indicating whether to mask the attentions or not

        Returns:
            Dict with keys: {'input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask'}

        Usage:

            ```
            s = "test sentence"

            # from raw sentence to tokens
            tokens = tokenizer.tokenize(s)

            # From tokens to ids
            ids = tokenizer.convert_tokens_to_ids(tokens)

            # From ids to input
            inputs = tokenizer.prepare_for_model(ids, return_tensors='pt')

            # Parse the input. Optionally mask the special tokens from the analysis.
            parsed_input = parse_inputs(inputs)

            # Run the model, pick from this output whatever inputs you want
            from utils.f import pick
            out = model(**pick(['input_ids'], parse_inputs(inputs)))
            ```
        """
        out = inputs.copy()

        # DEFINE SPECIAL TOKENS MASK
        if "special_tokens_mask" not in inputs.keys():
            special_tokens = set([self.aligner.unk_token_id, self.aligner.cls_token_id, self.aligner.sep_token_id, self.aligner.bos_token_id, self.aligner.eos_token_id, self.aligner.pad_token_id])
            in_ids = inputs['input_ids'][0]
            special_tok_mask = [1 if int(i) in special_tokens else 0 for i in in_ids]
            # NOTE(review): this mutates the caller's `inputs` dict rather than
            # the shallow copy `out` — format_model_output depends on that
            # side effect to find 'special_tokens_mask' later.
            inputs['special_tokens_mask'] = special_tok_mask

        if mask_attentions:
            # Zero out attention at the special-token positions
            out["attention_mask"] = torch.tensor(
                [int(not i) for i in inputs.get("special_tokens_mask")]
            ).unsqueeze(0)
        else:
            # Attend to every position
            out["attention_mask"] = torch.tensor(
                [1 for i in inputs.get("special_tokens_mask")]
            ).unsqueeze(0)

        return out
class BertDetails(TransformerBaseDetails):
    """BERT-family details: pairs a BertForMaskedLM with its spacyface aligner."""

    @classmethod
    def from_pretrained(cls, model_name: str):
        """Load the masked-LM model and matching aligner for `model_name`."""
        model = BertForMaskedLM.from_pretrained(
            model_name,
            output_attentions=True,
            output_hidden_states=True,
            # custom flag (exBERT fork of transformers) — presumably enables the
            # extra 'contexts' output consumed by select_outputs; confirm upstream
            output_additional_info=True,
        )
        aligner = BertAligner.from_pretrained(model_name)
        return cls(model, aligner)
class GPT2Details(TransformerBaseDetails):
    """GPT-2 details: language-model head plus the GPT-2 aligner."""

    @classmethod
    def from_pretrained(cls, model_name: str):
        """Load the LM-head model and matching aligner for `model_name`."""
        model = GPT2LMHeadModel.from_pretrained(
            model_name,
            output_attentions=True,
            output_hidden_states=True,
            # custom flag (exBERT fork of transformers) for the extra contexts output
            output_additional_info=True,
        )
        aligner = GPT2Aligner.from_pretrained(model_name)
        return cls(model, aligner)

    def select_outputs(self, output):
        """Reorder GPT-2's five-element output tuple to the base-class contract."""
        # GPT-2's output has one extra element in second position (discarded;
        # presumably the cached past key/values — confirm against model version).
        logits, _discard, hidden_states, attentions, contexts = output
        return hidden_states, attentions, contexts, logits
class RobertaDetails(TransformerBaseDetails):
    """RoBERTa details: pairs a RobertaForMaskedLM with its spacyface aligner."""

    @classmethod
    def from_pretrained(cls, model_name: str):
        """Load the masked-LM model and matching aligner for `model_name`."""
        model = RobertaForMaskedLM.from_pretrained(
            model_name,
            output_attentions=True,
            output_hidden_states=True,
            # custom flag (exBERT fork of transformers) for the extra contexts output
            output_additional_info=True,
        )
        aligner = RobertaAligner.from_pretrained(model_name)
        return cls(model, aligner)
class DistilBertDetails(TransformerBaseDetails):
    """DistilBERT details: masked-LM model plus aligner, with context reshaping."""

    def __init__(self, model, aligner):
        super().__init__(model, aligner)
        # Same keys as the base class; kept explicit for this model's forward pass.
        self.forward_inputs = ['input_ids', 'attention_mask']

    @classmethod
    def from_pretrained(cls, model_name: str):
        """Load the masked-LM model and matching aligner for `model_name`."""
        model = DistilBertForMaskedLM.from_pretrained(
            model_name,
            output_attentions=True,
            output_hidden_states=True,
            # custom flag (exBERT fork of transformers) for the extra contexts output
            output_additional_info=True,
        )
        aligner = DistilBertAligner.from_pretrained(model_name)
        return cls(model, aligner)

    def select_outputs(self, output):
        """Extract the desired hidden states as passed by a particular model through the output

        In all cases, we care for:
            - hidden state embeddings (tuple of n_layers + 1)
            - attentions (tuple of n_layers)
            - contexts (tuple of n_layers)
        """
        logits, hidden_states, attentions, raw_contexts = output
        # Swap dims 1 and 2 of each context tensor so the layout matches what
        # the other models emit; .contiguous() materializes the permutation.
        contexts = tuple(layer.permute(0, 2, 1, 3).contiguous() for layer in raw_contexts)
        return hidden_states, attentions, contexts, logits