# relik-entity-linking / relik / reader / relik_reader_core.py
# riccorl - first commit - 626eca0
import collections
from typing import Any, Dict, Iterator, List, Optional
import torch
from transformers import AutoModel
from transformers.activations import ClippedGELUActivation, GELUActivation
from transformers.modeling_utils import PoolerEndLogits
from relik.reader.data.relik_reader_sample import RelikReaderSample
# Registry used by the projection heads: maps the activation name given in the
# model configuration to the module instance placed inside the head.
activation2functions = dict(
    relu=torch.nn.ReLU(),
    gelu=GELUActivation(),
    gelu_10=ClippedGELUActivation(-10, 10),
)
class RelikReaderCoreModel(torch.nn.Module):
    """Encoder plus span-boundary and entity-classification heads.

    The module wraps a Hugging Face ``AutoModel`` encoder and adds:

    * ``ned_start_classifier`` - a two-way per-token head for span starts;
    * ``ned_end_classifier`` - a ``PoolerEndLogits`` head that scores span
      ends given a start position;
    * ``ed_start_projector`` / ``ed_end_projector`` - projections whose
      concatenation is scored (dot product) against the representations of
      the tokens selected by ``special_symbols_mask``.

    All three losses use ``torch.nn.CrossEntropyLoss``.
    """

    def __init__(
        self,
        transformer_model: str,
        additional_special_symbols: int,
        num_layers: Optional[int] = None,
        activation: str = "gelu",
        linears_hidden_size: Optional[int] = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
    ) -> None:
        """Build the encoder and the prediction heads.

        Args:
            transformer_model: Name or path forwarded to
                ``AutoModel.from_pretrained``.
            additional_special_symbols: Number of rows added to the token
                embedding matrix on top of ``config.vocab_size``.
            num_layers: If given, forwarded as ``num_hidden_layers`` when
                loading the encoder.
            activation: Key into ``activation2functions``.
            linears_hidden_size: Width of the hidden (and, by default, output)
                layer of every projection head.
            use_last_k_layers: Number of trailing encoder layers whose hidden
                states are concatenated to form the token features.
            training: Initial value of ``self.training``.
        """
        super().__init__()

        # Transformer model declaration
        self.transformer_model_name = transformer_model
        self.transformer_model = (
            AutoModel.from_pretrained(transformer_model)
            if num_layers is None
            else AutoModel.from_pretrained(
                transformer_model, num_hidden_layers=num_layers
            )
        )
        self.transformer_model.resize_token_embeddings(
            self.transformer_model.config.vocab_size + additional_special_symbols
        )

        self.activation = activation
        self.linears_hidden_size = linears_hidden_size
        self.use_last_k_layers = use_last_k_layers

        # named entity detection layers
        self.ned_start_classifier = self._get_projection_layer(
            self.activation, last_hidden=2, layer_norm=False
        )
        self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config)

        # END entity disambiguation layer
        self.ed_start_projector = self._get_projection_layer(self.activation)
        self.ed_end_projector = self._get_projection_layer(self.activation)

        # NOTE(review): this assigns only the top-level flag; sub-modules are
        # not switched by a plain attribute assignment. Confirm that callers
        # still invoke ``train()`` / ``eval()`` before using the model.
        self.training = training

        # criterion
        self.criterion = torch.nn.CrossEntropyLoss()

    def _get_projection_layer(
        self,
        activation: str,
        last_hidden: Optional[int] = None,
        input_hidden=None,
        layer_norm: bool = True,
    ) -> torch.nn.Sequential:
        """Create a Dropout-Linear-activation-Dropout-Linear(-LayerNorm) head.

        Args:
            activation: Key into ``activation2functions``.
            last_hidden: Output width; defaults to ``self.linears_hidden_size``.
            input_hidden: Input width; defaults to the encoder hidden size
                multiplied by ``self.use_last_k_layers``.
            layer_norm: Whether to append a ``LayerNorm`` over the output.

        Returns:
            The assembled ``torch.nn.Sequential`` head.
        """
        in_features = (
            self.transformer_model.config.hidden_size * self.use_last_k_layers
            if input_hidden is None
            else input_hidden
        )
        out_features = self.linears_hidden_size if last_hidden is None else last_hidden

        head_components = [
            torch.nn.Dropout(0.1),
            torch.nn.Linear(in_features, self.linears_hidden_size),
            activation2functions[activation],
            torch.nn.Dropout(0.1),
            torch.nn.Linear(self.linears_hidden_size, out_features),
        ]
        if layer_norm:
            head_components.append(
                torch.nn.LayerNorm(
                    out_features,
                    self.transformer_model.config.layer_norm_eps,
                )
            )
        return torch.nn.Sequential(*head_components)

    def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Replace the logits at masked positions with a large negative value.

        Positions where ``mask`` is 1 have their logits replaced; positions
        where it is 0 are left unchanged. The fill value is ``-65500`` when the
        first model parameter is ``float16`` and ``-1e30`` otherwise.

        NOTE(review): the arithmetic ``1 - mask`` assumes a numeric 0/1 mask,
        not a boolean tensor - confirm against the caller.
        """
        mask = mask.unsqueeze(-1)
        if next(self.parameters()).dtype == torch.float16:
            logits = logits * (1 - mask) - 65500 * mask
        else:
            logits = logits * (1 - mask) - 1e30 * mask
        return logits

    def _get_model_features(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run the encoder and return the per-token features.

        With ``use_last_k_layers == 1`` the first encoder output is returned.
        Otherwise the hidden states of the last ``use_last_k_layers`` layers
        are concatenated along the feature dimension.
        """
        model_input = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "output_hidden_states": self.use_last_k_layers > 1,
        }
        if token_type_ids is not None:
            model_input["token_type_ids"] = token_type_ids

        model_output = self.transformer_model(**model_input)

        if self.use_last_k_layers > 1:
            # Look the hidden states up by name: for encoders that also return
            # a pooled output, positional index 1 is that pooled tensor rather
            # than the tuple of per-layer hidden states. Plain tuples (no named
            # fields) keep the original positional behaviour.
            hidden_states = getattr(model_output, "hidden_states", None)
            if hidden_states is None:
                hidden_states = model_output[1]
            return torch.cat(hidden_states[-self.use_last_k_layers :], dim=-1)

        return model_output[0]

    def compute_ned_end_logits(
        self,
        start_predictions,
        start_labels,
        model_features,
        prediction_mask,
        batch_size,
    ) -> Optional[torch.Tensor]:
        """Score span ends for every start position in the batch.

        Start positions are the entries ``> 0`` of ``start_labels`` when
        ``self.training`` is true and of ``start_predictions`` otherwise. The
        features and prediction mask of each sample are repeated once per
        start found in that sample, so the output has one row per start.

        Returns:
            The output of ``self.ned_end_classifier``, or ``None`` when the
            batch contains no start position.
        """
        # todo: maybe when constraining on the spans,
        #  we should not use a prediction_mask for the end tokens.
        #  at least we should not during training imo
        start_positions = start_labels if self.training else start_predictions
        is_start = start_positions > 0

        # Token index of every start, in row-major (sample, token) order.
        start_positions_indices = (
            torch.arange(start_positions.size(1), device=start_positions.device)
            .unsqueeze(0)
            .expand(batch_size, -1)[is_start]
        )
        if len(start_positions_indices) == 0:
            return None

        # Computed once and moved to Python ints: one host sync for the batch.
        starts_per_sample = is_start.sum(dim=-1).tolist()

        expanded_features = torch.cat(
            [
                model_features[i].unsqueeze(0).expand(n_starts, -1, -1)
                for i, n_starts in enumerate(starts_per_sample)
                if n_starts > 0
            ],
            dim=0,
        ).to(start_positions_indices.device)

        expanded_prediction_mask = torch.cat(
            [
                prediction_mask[i].unsqueeze(0).expand(n_starts, -1)
                for i, n_starts in enumerate(starts_per_sample)
                if n_starts > 0
            ],
            dim=0,
        ).to(expanded_features.device)

        return self.ned_end_classifier(
            hidden_states=expanded_features,
            start_positions=start_positions_indices,
            p_mask=expanded_prediction_mask,
        )

    def compute_classification_logits(
        self,
        model_features,
        special_symbols_mask,
        prediction_mask,
        batch_size,
        start_positions=None,
        end_positions=None,
    ) -> torch.Tensor:
        """Score every token against the special-symbol representations.

        Each token is represented by the concatenation of its start and end
        projections. At span starts the end projection is overwritten with
        the end projection taken at the matching span end. Scores are dot
        products with the representations of the tokens selected by
        ``special_symbols_mask``, and are masked with ``prediction_mask``.

        NOTE(review): the number of classes is read from the first sample
        only, so every sample is assumed to select the same number of special
        symbols; the starts and ends are assumed to pair up in order.
        """
        if start_positions is None or end_positions is None:
            start_positions = torch.zeros_like(prediction_mask)
            end_positions = torch.zeros_like(prediction_mask)

        model_start_features = self.ed_start_projector(model_features)
        model_end_features = self.ed_end_projector(model_features)
        # Give each span start the end-projection of its span end.
        model_end_features[start_positions > 0] = model_end_features[end_positions > 0]

        model_ed_features = torch.cat(
            [model_start_features, model_end_features], dim=-1
        )

        # computing ed features
        classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item()
        special_symbols_representation = model_ed_features[special_symbols_mask].view(
            batch_size, classes_representations, -1
        )

        logits = torch.bmm(
            model_ed_features,
            torch.permute(special_symbols_representation, (0, 2, 1)),
        )
        return self._mask_logits(logits, prediction_mask)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        start_labels: Optional[torch.Tensor] = None,
        end_labels: Optional[torch.Tensor] = None,
        use_predefined_spans: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        """Detect spans (unless predefined) and classify them.

        Args:
            input_ids: 2-D tensor of token ids (batch, sequence).
            attention_mask: Forwarded to the encoder.
            token_type_ids: Forwarded to the encoder when not ``None``.
            prediction_mask: Positions set to 1 are excluded from prediction.
            special_symbols_mask: Selects the tokens that represent classes.
            start_labels: Per-token labels; values ``> 0`` mark span starts
                and ``-100`` is carried over to the binary start labels.
            end_labels: Per-token labels; values ``> 0`` mark span ends.
            use_predefined_spans: Skip span detection and take the span
                boundaries from the labels (all zeros if they are missing).

        Returns:
            A dict with ``batch_size`` and the logits, probabilities and
            predictions of the start, end and disambiguation steps. When both
            label tensors are given and ``self.training`` is true it also has
            ``ned_start_loss``, ``ned_end_loss``, ``ed_loss`` and ``loss``.

        NOTE: in that case ``start_labels`` is modified in place (non-start
        positions are set to ``-100``).
        """
        batch_size, seq_len = input_ids.shape

        model_features = self._get_model_features(
            input_ids, attention_mask, token_type_ids
        )

        # Binary start labels. They are derived here, before branching, because
        # the loss section needs them in both branches; computing them only in
        # the span-detection branch left the name unbound when training with
        # predefined spans.
        ned_start_labels = None
        if start_labels is not None:
            ned_start_labels = torch.zeros_like(start_labels)
            ned_start_labels[start_labels == -100] = -100
            ned_start_labels[start_labels > 0] = 1

        # named entity detection if required
        if use_predefined_spans:  # no need to compute spans
            ned_start_logits, ned_start_probabilities = None, None
            ned_end_logits, ned_end_probabilities = None, None
            ned_start_predictions = (
                torch.clone(start_labels)
                if start_labels is not None
                else torch.zeros_like(input_ids)
            )
            ned_end_predictions = (
                torch.clone(end_labels)
                if end_labels is not None
                else torch.zeros_like(input_ids)
            )
            ned_start_predictions[ned_start_predictions > 0] = 1
            ned_end_predictions[ned_end_predictions > 0] = 1
        else:  # compute spans
            # start boundary prediction
            ned_start_logits = self.ned_start_classifier(model_features)
            ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask)
            ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
            ned_start_predictions = ned_start_probabilities.argmax(dim=-1)

            # end boundary prediction
            ned_end_logits = self.compute_ned_end_logits(
                ned_start_predictions,
                ned_start_labels,
                model_features,
                prediction_mask,
                batch_size,
            )

            if ned_end_logits is not None:
                ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
                ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
            else:
                ned_end_logits, ned_end_probabilities = None, None
                ned_end_predictions = ned_start_predictions.new_zeros(batch_size)

            # flattening end predictions
            # (flattening can happen only if the
            # end boundaries were not predicted using the gold labels)
            if not self.training:
                flattened_end_predictions = torch.clone(ned_start_predictions)
                flattened_end_predictions[flattened_end_predictions > 0] = 0

                batch_start_predictions = [
                    torch.where(ned_start_predictions[elem_idx] > 0)[0].tolist()
                    for elem_idx in range(batch_size)
                ]

                # check that the total number of start predictions
                # is equal to the end predictions
                total_start_predictions = sum(map(len, batch_start_predictions))
                total_end_predictions = len(ned_end_predictions)
                assert (
                    total_start_predictions == 0
                    or total_start_predictions == total_end_predictions
                ), (
                    f"Total number of start predictions = {total_start_predictions}. "
                    f"Total number of end predictions = {total_end_predictions}"
                )

                # One transfer for all the end predictions instead of one
                # ``.item()`` call per span.
                end_predictions = ned_end_predictions.tolist()
                curr_end_pred_num = 0
                for elem_idx, bsp in enumerate(batch_start_predictions):
                    for sp in bsp:
                        # an end before its start is clamped to the start
                        ep = max(end_predictions[curr_end_pred_num], sp)

                        # if we already set this span throw it (no overlap)
                        if flattened_end_predictions[elem_idx, ep] == 1:
                            ned_start_predictions[elem_idx, sp] = 0
                        else:
                            flattened_end_predictions[elem_idx, ep] = 1

                        curr_end_pred_num += 1

                ned_end_predictions = flattened_end_predictions

        start_position, end_position = (
            (start_labels, end_labels)
            if self.training
            else (ned_start_predictions, ned_end_predictions)
        )

        # Entity disambiguation
        ed_logits = self.compute_classification_logits(
            model_features,
            special_symbols_mask,
            prediction_mask,
            batch_size,
            start_position,
            end_position,
        )
        ed_probabilities = torch.softmax(ed_logits, dim=-1)
        ed_predictions = torch.argmax(ed_probabilities, dim=-1)

        # output build
        output_dict = dict(
            batch_size=batch_size,
            ned_start_logits=ned_start_logits,
            ned_start_probabilities=ned_start_probabilities,
            ned_start_predictions=ned_start_predictions,
            ned_end_logits=ned_end_logits,
            ned_end_probabilities=ned_end_probabilities,
            ned_end_predictions=ned_end_predictions,
            ed_logits=ed_logits,
            ed_probabilities=ed_probabilities,
            ed_predictions=ed_predictions,
        )

        # compute loss if labels
        if start_labels is not None and end_labels is not None and self.training:
            # named entity detection loss

            # start
            if ned_start_logits is not None:
                ned_start_loss = self.criterion(
                    ned_start_logits.view(-1, ned_start_logits.shape[-1]),
                    ned_start_labels.view(-1),
                )
            else:
                ned_start_loss = 0

            # end
            if ned_end_logits is not None:
                ned_end_labels = torch.zeros_like(end_labels)
                ned_end_labels[end_labels == -100] = -100
                ned_end_labels[end_labels > 0] = 1

                # The target of each row is the token index of a gold end.
                ned_end_targets = (
                    torch.arange(ned_end_labels.size(1), device=ned_end_labels.device)
                    .unsqueeze(0)
                    .expand(batch_size, -1)[ned_end_labels > 0]
                ).to(ned_end_labels.device)
                ned_end_loss = self.criterion(ned_end_logits, ned_end_targets)
            else:
                ned_end_loss = 0

            # entity disambiguation loss
            # NOTE: writes into the caller's ``start_labels`` tensor.
            start_labels[ned_start_labels != 1] = -100
            ed_labels = torch.clone(start_labels)
            ed_labels[end_labels > 0] = end_labels[end_labels > 0]
            ed_loss = self.criterion(
                ed_logits.view(-1, ed_logits.shape[-1]),
                ed_labels.view(-1),
            )

            output_dict["ned_start_loss"] = ned_start_loss
            output_dict["ned_end_loss"] = ned_end_loss
            output_dict["ed_loss"] = ed_loss
            output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss

        return output_dict

    def batch_predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        sample: Optional[List["RelikReaderSample"]] = None,
        top_k: int = 5,  # the amount of top-k most probable entities to predict
        *args,
        **kwargs,
    ) -> Iterator["RelikReaderSample"]:
        """Run ``forward`` and attach the decoded predictions to each sample.

        For every sample a new entry ``ts._d["patches"][patch_offset]`` is
        written with ``predicted_window_labels`` (title -> list of
        ``[start, end]``), ``span_title_probabilities`` (``(start, end)`` ->
        list of ``(title, probability)``) and ``predictable_candidates``.
        Span indices are relative to the sequence without its first token.

        Args:
            top_k: Number of ``(title, probability)`` pairs kept per span;
                ``-1`` keeps all classes.
            **kwargs: Must contain ``predictable_candidates`` and
                ``patch_offset``, one entry per sample.

        Yields:
            The samples passed in ``sample``, updated in place.

        Raises:
            KeyError: If one of the two required keyword arguments is missing.
        """
        forward_output = self.forward(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
        )

        ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
        ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy()
        ed_predictions = forward_output["ed_predictions"].cpu().numpy()
        ed_probabilities = forward_output["ed_probabilities"].cpu().numpy()

        batch_predictable_candidates = kwargs["predictable_candidates"]
        patch_offset = kwargs["patch_offset"]

        for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip(
            sample,
            ned_start_predictions,
            ned_end_predictions,
            ed_predictions,
            ed_probabilities,
            batch_predictable_candidates,
            patch_offset,
        ):
            # The first token is skipped, hence the ``+ 1`` when indexing the
            # full-length prediction arrays below.
            ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0]
            ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0]

            final_class2predicted_spans = collections.defaultdict(list)
            spans2predicted_probabilities = dict()
            for start_token_index, end_token_index in zip(
                ne_start_indices, ne_end_indices
            ):
                # predicted candidate
                # NOTE(review): class index 0 maps to position -1, i.e. the
                # last candidate if ``pred_cands`` is a list - confirm that
                # class 0 is never meant to be looked up here.
                token_class = edp[start_token_index + 1] - 1
                predicted_candidate_title = pred_cands[token_class]
                final_class2predicted_spans[predicted_candidate_title].append(
                    [start_token_index, end_token_index]
                )

                # candidates probabilities
                classes_probabilities = edpr[start_token_index + 1]
                classes_probabilities_best_indices = classes_probabilities.argsort()[
                    ::-1
                ]
                # Kept in a local so the ``top_k`` argument is not rebound
                # while iterating.
                num_titles = (
                    min(top_k, len(classes_probabilities_best_indices))
                    if top_k != -1
                    else len(classes_probabilities_best_indices)
                )
                titles_2_probs = []
                for rank in range(num_titles):
                    class_index = classes_probabilities_best_indices[rank]
                    titles_2_probs.append(
                        (
                            pred_cands[class_index - 1],
                            classes_probabilities[class_index].item(),
                        )
                    )
                spans2predicted_probabilities[
                    (start_token_index, end_token_index)
                ] = titles_2_probs

            if "patches" not in ts._d:
                ts._d["patches"] = dict()

            ts._d["patches"][po] = dict()
            sample_patch = ts._d["patches"][po]

            sample_patch["predicted_window_labels"] = final_class2predicted_spans
            sample_patch["span_title_probabilities"] = spans2predicted_probabilities

            # additional info
            sample_patch["predictable_candidates"] = pred_cands

            yield ts