CarlosMalaga's picture
Upload 201 files
2f044c1 verified
import contextlib
import logging
from typing import Any, Dict, Iterator, List
import numpy as np
import torch
import transformers as tr
from lightning_fabric.utilities import move_data_to_device
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm
from relik.common.log import get_logger
from relik.common.torch_utils import get_autocast_context
from relik.common.utils import get_callable_from_string
from relik.inference.data.objects import AnnotationType
from relik.reader.data.relik_reader_sample import RelikReaderSample
from relik.reader.pytorch_modules.base import RelikReaderBase
from relik.retriever.pytorch_modules import PRECISION_MAP
logger = get_logger(__name__, level=logging.INFO)
class RelikReaderForTripletExtraction(RelikReaderBase):
"""
A class for the RelikReader model for triplet extraction.
Args:
transformer_model (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`):
The transformer model to use. If `None`, the default model is used.
additional_special_symbols (:obj:`int`, `optional`, defaults to 0):
The number of additional special symbols to add to the tokenizer.
num_layers (:obj:`int`, `optional`):
The number of layers to use. If `None`, all layers are used.
activation (:obj:`str`, `optional`, defaults to "gelu"):
The activation function to use.
linears_hidden_size (:obj:`int`, `optional`, defaults to 512):
The hidden size of the linears.
use_last_k_layers (:obj:`int`, `optional`, defaults to 1):
The number of last layers to use.
training (:obj:`bool`, `optional`, defaults to False):
Whether the model is in training mode.
device (:obj:`str` or :obj:`torch.device` or :obj:`None`, `optional`):
The device to use. If `None`, the default device is used.
tokenizer (:obj:`str` or :obj:`transformers.PreTrainedTokenizer` or :obj:`None`, `optional`):
The tokenizer to use. If `None`, the default tokenizer is used.
dataset (:obj:`IterableDataset` or :obj:`str` or :obj:`None`, `optional`):
The dataset to use. If `None`, the default dataset is used.
dataset_kwargs (:obj:`Dict[str, Any]` or :obj:`None`, `optional`):
The keyword arguments to pass to the dataset class.
default_reader_class (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`):
The default reader class to use. If `None`, the default reader class is used.
**kwargs:
Keyword arguments.
"""
default_reader_class: str = (
"relik.reader.pytorch_modules.hf.modeling_relik.RelikReaderREModel"
)
default_data_class: str = "relik.reader.data.relik_reader_re_data.RelikREDataset"
def __init__(
self,
transformer_model: str | tr.PreTrainedModel | None = None,
additional_special_symbols: int = 0,
additional_special_symbols_types: int = 0,
entity_type_loss: bool | None = None,
add_entity_embedding: bool | None = None,
num_layers: int | None = None,
activation: str = "gelu",
linears_hidden_size: int | None = 512,
use_last_k_layers: int = 1,
training: bool = False,
device: str | torch.device | None = None,
tokenizer: str | tr.PreTrainedTokenizer | None = None,
dataset: IterableDataset | str | None = None,
dataset_kwargs: Dict[str, Any] | None = None,
default_reader_class: tr.PreTrainedModel | str | None = None,
**kwargs,
):
super().__init__(
transformer_model=transformer_model,
additional_special_symbols=additional_special_symbols,
additional_special_symbols_types=additional_special_symbols_types,
entity_type_loss=entity_type_loss,
add_entity_embedding=add_entity_embedding,
num_layers=num_layers,
activation=activation,
linears_hidden_size=linears_hidden_size,
use_last_k_layers=use_last_k_layers,
training=training,
device=device,
tokenizer=tokenizer,
dataset=dataset,
default_reader_class=default_reader_class,
**kwargs,
)
# and instantiate the dataset class
self.dataset = dataset
if self.dataset is None and training is False:
self.default_data_class = get_callable_from_string(self.default_data_class)
default_data_kwargs = dict(
dataset_path=None,
materialize_samples=False,
transformer_model=self.tokenizer,
special_symbols=self.default_data_class.get_special_symbols_re(
self.relik_reader_model.config.additional_special_symbols,
use_nme=kwargs.get("use_nme_re", False),
),
special_symbols_types=self.default_data_class.get_special_symbols(
self.relik_reader_model.config.additional_special_symbols_types - 1
)
if self.relik_reader_model.config.additional_special_symbols_types > 0
else [],
for_inference=True,
use_nme=kwargs.get("use_nme", False),
)
# merge the default data kwargs with the ones passed to the model
default_data_kwargs.update(dataset_kwargs or {})
self.dataset = self.default_data_class(**default_data_kwargs)
@torch.no_grad()
@torch.inference_mode()
def _read(
self,
samples: List[RelikReaderSample] | None = None,
input_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
token_type_ids: torch.Tensor | None = None,
prediction_mask: torch.Tensor | None = None,
special_symbols_mask: torch.Tensor | None = None,
max_length: int = 2048,
max_batch_size: int = 128,
token_batch_size: int = 2048,
precision: str = 32,
annotation_type: AnnotationType = AnnotationType.CHAR,
progress_bar: bool = False,
*args: object,
**kwargs: object,
) -> List[RelikReaderSample] | List[List[RelikReaderSample]]:
"""
A wrapper around the forward method that returns the predicted labels for each sample.
Args:
samples (:obj:`List[RelikReaderSample]`, `optional`):
The samples to read. If provided, `text` and `candidates` are ignored.
input_ids (:obj:`torch.Tensor`, `optional`):
The input ids of the text. If `samples` is provided, this is ignored.
attention_mask (:obj:`torch.Tensor`, `optional`):
The attention mask of the text. If `samples` is provided, this is ignored.
token_type_ids (:obj:`torch.Tensor`, `optional`):
The token type ids of the text. If `samples` is provided, this is ignored.
prediction_mask (:obj:`torch.Tensor`, `optional`):
The prediction mask of the text. If `samples` is provided, this is ignored.
special_symbols_mask (:obj:`torch.Tensor`, `optional`):
The special symbols mask of the text. If `samples` is provided, this is ignored.
max_length (:obj:`int`, `optional`, defaults to 1000):
The maximum length of the text.
max_batch_size (:obj:`int`, `optional`, defaults to 128):
The maximum batch size.
token_batch_size (:obj:`int`, `optional`):
The token batch size.
progress_bar (:obj:`bool`, `optional`, defaults to False):
Whether to show a progress bar.
precision (:obj:`str`, `optional`, defaults to 32):
The precision to use for the model.
annotation_type (`AnnotationType`, `optional`, defaults to `AnnotationType.CHAR`):
The type of annotation to return. If `char`, the spans will be in terms of
character offsets. If `word`, the spans will be in terms of word offsets.
*args:
Positional arguments.
**kwargs:
Keyword arguments.
Returns:
:obj:`List[RelikReaderSample]` or :obj:`List[List[RelikReaderSample]]`:
The predicted labels for each sample.
"""
precision = precision or self.precision
if samples is not None:
def _read_iterator():
def samples_it():
for i, sample in enumerate(samples):
assert sample._mixin_prediction_position is None
sample._mixin_prediction_position = i
if sample.spans is not None and len(sample.spans) > 0:
entities = []
offset_span = sample.char2token_start[str(sample.offset)]
for span_start, span_end in sample.spans:
if str(span_start) not in sample.char2token_start:
# span_start is in the middle of a word
# retrieve the first token of the word
while str(span_start) not in sample.char2token_start:
span_start -= 1
# skip
if span_start < 0:
break
if str(span_end) not in sample.char2token_end:
# span_end is in the middle of a word
# retrieve the last token of the word
while str(span_end) not in sample.char2token_end:
span_end += 1
# skip
if span_end >= int(list(sample.char2token_end.keys())[-1]):
break
if span_start < 0 or span_end > int(list(sample.char2token_end.keys())[-1]):
continue
entities.append([sample.char2token_start[str(span_start)]-offset_span, sample.char2token_end[str(span_end)]+1-offset_span, ""])
sample.entities = entities
yield sample
next_prediction_position = 0
position2predicted_sample = {}
# instantiate dataset
if self.dataset is None:
raise ValueError(
"You need to pass a dataset to the model in order to predict"
)
self.dataset.samples = samples_it()
self.dataset.model_max_length = max_length
self.dataset.tokens_per_batch = token_batch_size
self.dataset.max_batch_size = max_batch_size
# instantiate dataloader
iterator = DataLoader(
self.dataset, batch_size=None, num_workers=0, shuffle=False
)
if progress_bar:
iterator = tqdm(iterator, desc="Predicting with RelikReader")
with get_autocast_context(self.device, precision):
for batch in iterator:
batch = move_data_to_device(batch, self.device)
batch.update(kwargs)
batch_out = self._batch_predict(**batch)
for sample in batch_out:
if (
sample._mixin_prediction_position
>= next_prediction_position
):
position2predicted_sample[
sample._mixin_prediction_position
] = sample
# yield
while next_prediction_position in position2predicted_sample:
yield position2predicted_sample[next_prediction_position]
del position2predicted_sample[next_prediction_position]
next_prediction_position += 1
outputs = list(_read_iterator())
for sample in outputs:
self.dataset.merge_patches_predictions(sample)
if annotation_type == AnnotationType.CHAR:
self.dataset.convert_to_char_annotations(sample)
elif annotation_type == AnnotationType.WORD:
self.dataset.convert_to_word_annotations(sample)
else:
raise ValueError(
f"Annotation type {annotation_type} not recognized. "
f"Please choose one of {list(AnnotationType)}."
)
else:
outputs = list(
self._batch_predict(
input_ids,
attention_mask,
token_type_ids,
prediction_mask,
special_symbols_mask,
*args,
**kwargs,
)
)
return outputs
def _batch_predict(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: torch.Tensor | None = None,
prediction_mask: torch.Tensor | None = None,
special_symbols_mask: torch.Tensor | None = None,
special_symbols_mask_entities: torch.Tensor | None = None,
sample: List[RelikReaderSample] | None = None,
*args,
**kwargs,
) -> Iterator[RelikReaderSample]:
"""
A wrapper around the forward method that returns the predicted labels for each sample.
It also adds the predicted labels to the samples.
Args:
input_ids (:obj:`torch.Tensor`):
The input ids of the text.
attention_mask (:obj:`torch.Tensor`):
The attention mask of the text.
token_type_ids (:obj:`torch.Tensor`, `optional`):
The token type ids of the text.
prediction_mask (:obj:`torch.Tensor`, `optional`):
The prediction mask of the text.
special_symbols_mask (:obj:`torch.Tensor`, `optional`):
The special symbols mask of the text.
sample (:obj:`List[RelikReaderSample]`, `optional`):
The samples to read. If provided, `text` and `candidates` are ignored.
top_k (:obj:`int`, `optional`, defaults to 5):
The amount of top-k most probable entities to predict.
*args:
Positional arguments.
**kwargs:
Keyword arguments.
Returns:
The predicted labels for each sample.
"""
forward_output = self.forward(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
prediction_mask=prediction_mask,
special_symbols_mask=special_symbols_mask,
special_symbols_mask_entities=special_symbols_mask_entities,
is_prediction=True,
*args,
**kwargs,
)
ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
ned_end_predictions = forward_output["ned_end_predictions"] # .cpu().numpy()
ed_predictions = forward_output["re_entities_predictions"].cpu().numpy()
ned_type_predictions = forward_output["ned_type_predictions"].cpu().numpy()
re_predictions = forward_output["re_predictions"].cpu().numpy()
re_probabilities = forward_output["re_probabilities"].detach().cpu().numpy()
for ts, ne_st, ne_end, re_pred, re_prob, edp, ne_et in zip(
sample,
ned_start_predictions,
ned_end_predictions,
re_predictions,
re_probabilities,
ed_predictions,
ned_type_predictions,
):
ne_end = ne_end.cpu().numpy()
entities = []
if self.relik_reader_model.config.entity_type_loss:
starts = np.argwhere(ne_st)
i = 0
for start, end in zip(starts, ne_end):
ends = np.argwhere(end)
for e in ends:
entities.append([start[0], e[0], ne_et[i]])
i += 1
# if i == len(ne_et):
# break
# if i == len(ne_et):
# break
else:
starts = np.argwhere(ne_st)
for start, end in zip(starts, ne_end):
ends = np.argwhere(end)
for e in ends:
entities.append([start[0], e[0]])
edp = edp[: len(entities)]
re_pred = re_pred[: len(entities), : len(entities)]
re_prob = re_prob[: len(entities), : len(entities)]
possible_re = np.argwhere(re_pred)
predicted_triplets = []
predicted_triplets_prob = []
for i, j, r in possible_re:
if self.relik_reader_model.relation_disambiguation_loss:
if not (
i != j
and edp[i, r] == 1
and edp[j, r] == 1
and edp[i, 0] == 0
and edp[j, 0] == 0
):
continue
predicted_triplets.append([i, j, r])
predicted_triplets_prob.append(re_prob[i, j, r])
ts._d["predicted_relations"] = predicted_triplets
ts._d["predicted_entities"] = entities
ts._d["predicted_relations_probabilities"] = predicted_triplets_prob
# try-out for a new format
ts._d["predicted_triples"] = predicted_triplets
yield ts