from itertools import chain
from typing import List, Optional, Tuple

import numpy as np
from transformers import Pipeline


class RefSegPipeline(Pipeline):

    labels = [
        'publisher', 'source', 'url', 'other', 'author', 'editor', 'lpage',
        'volume', 'year', 'issue', 'title', 'fpage', 'edition'
    ]
    iob_labels = list(chain.from_iterable([['B-' + x, 'I-' + x] for x in labels])) + ['O']
    id2seg = {k: v for k, v in enumerate(iob_labels)}
    id2ref = {k: v for k, v in enumerate(['B-ref', 'I-ref'])}
    is_split_into_words = False

    def _sanitize_parameters(self, **kwargs):
        if "id2seg" in kwargs:
            self.id2seg = kwargs["id2seg"]
        if "id2ref" in kwargs:
            self.id2ref = kwargs["id2ref"]
        return {}, {}, {}

    def preprocess(self, sentence, offset_mapping=None, split_into_words=True):
        tokens = sentence
        offsets = None
        if split_into_words:
            # Pre-tokenize into words so that word-level offsets into the raw
            # sentence are available for later re-alignment.
            split_sentence = self.tokenizer.pre_tokenizer.pre_tokenize_str(sentence)
            tokens, offsets = zip(*split_sentence)
        model_inputs = self.tokenizer(
            tokens,
            return_offsets_mapping=True,
            padding='max_length',
            truncation=True,
            max_length=512,
            return_tensors="pt",
            return_special_tokens_mask=True,
            return_overflowing_tokens=True,
            is_split_into_words=split_into_words,
            stride=32,
        )
        if offset_mapping:
            model_inputs["offset_mapping"] = offset_mapping
        model_inputs["sentence"] = sentence
        model_inputs["token_offsets"] = offsets
        return model_inputs

    def _forward(self, model_inputs):
        special_tokens_mask = model_inputs.pop("special_tokens_mask")
        offset_mapping = model_inputs.pop("offset_mapping", None)
        sentence = model_inputs.pop("sentence")
        token_offsets = model_inputs.pop("token_offsets")
        overflow_mapping = model_inputs.pop("overflow_to_sample_mapping")
        if self.framework == "tf":
            logits = self.model(model_inputs.data)[0]
        else:
            logits = self.model(**model_inputs)[0]

        return {
            "logits": logits,
            "special_tokens_mask": special_tokens_mask,
            "offset_mapping": offset_mapping,
            "overflow_mapping": overflow_mapping,
            "sentence": sentence,
            "token_offsets": token_offsets,
            **model_inputs,
        }

    def postprocess(self, model_outputs):
        ignore_labels = ["O"]
        # The model returns two heads: token-level segmentation labels and
        # reference-boundary (B-ref/I-ref) labels.
        logits_seg = model_outputs["logits"][0].numpy()
        logits_ref = model_outputs["logits"][1].numpy()
        sentence = model_outputs["sentence"]
        token_offsets = model_outputs["token_offsets"]
        input_ids = model_outputs["input_ids"]
        special_tokens_mask = model_outputs["special_tokens_mask"]
        offset_mapping = model_outputs["offset_mapping"]

        # Numerically stable softmax over the label dimension for both heads.
        maxes_seg = np.max(logits_seg, axis=-1, keepdims=True)
        shifted_exp_seg = np.exp(logits_seg - maxes_seg)
        scores_seg = shifted_exp_seg / shifted_exp_seg.sum(axis=-1, keepdims=True)

        maxes_ref = np.max(logits_ref, axis=-1, keepdims=True)
        shifted_exp_ref = np.exp(logits_ref - maxes_ref)
        scores_ref = shifted_exp_ref / shifted_exp_ref.sum(axis=-1, keepdims=True)

        pre_entities = self.gather_pre_entities(
            input_ids, scores_seg, scores_ref, offset_mapping, special_tokens_mask
        )
        grouped_entities = self.aggregate(pre_entities, token_offsets, sentence)

        cleaned_groups = []
        for group in grouped_entities:
            start, end = None, None
            entities = []
            group_dict = {}
            for entity in group:
                if entity.get("entity_group", None) in ignore_labels:
                    continue
                if start is None or end is None:
                    start = entity["start"]
                    end = entity["end"]
                else:
                    start = min(start, entity["start"])
                    end = max(end, entity["end"])
                entities.append(entity)
            if entities:
                group_dict["reference_raw"] = sentence[start:end]
                group_dict["entities"] = entities
                cleaned_groups.append(group_dict)
entity.get("entity_group", None) not in ignore_labels # ] # if entities: # cleaned_groups.append(entities) return { "number_of_references": len(cleaned_groups), "references": cleaned_groups, } def gather_pre_entities( self, input_ids: np.ndarray, scores_seg: np.ndarray, scores_ref: np.ndarray, offset_mappings: Optional[List[Tuple[int, int]]], special_tokens_masks: np.ndarray, ) -> List[dict]: """Fuse various numpy arrays into dicts with all the information needed for aggregation""" pre_entities = [] for idx_list, (input_id, offset_mapping, special_tokens_mask, s_seg, s_ref) in enumerate( zip(input_ids, offset_mappings, special_tokens_masks, scores_seg, scores_ref)): for idx, iid in enumerate(input_id): skip = False if idx_list != 0 and idx <= 32: skip = True if special_tokens_mask[idx]: continue word = self.tokenizer.convert_ids_to_tokens(int(input_id[idx])) if offset_mapping is not None: start_ind, end_ind = offset_mapping[idx] if not isinstance(start_ind, int): if self.framework == "pt": start_ind = start_ind.item() end_ind = end_ind.item() is_subword = not word.startswith('\u2581') if int(input_id[idx]) == self.tokenizer.unk_token_id: is_subword = False else: start_ind = None end_ind = None is_subword = False pre_entity = { "word": word, "scores_seg": s_seg[idx], "scores_ref": s_ref[idx], "start": start_ind, "end": end_ind, "index": idx, "is_subword": is_subword, "is_stride": skip, } pre_entities.append(pre_entity) return pre_entities def aggregate(self, pre_entities: List[dict], token_offsets: List[tuple], sentence: str) -> List[dict]: entities = self.aggregate_words(pre_entities, token_offsets) return self.group_entities(entities, sentence) def aggregate_word(self, entities: List[dict], token_offset: tuple) -> dict: word = self.tokenizer.convert_tokens_to_string([entity["word"] for entity in entities]) scores_seg = entities[0]["scores_seg"] idx_seg = scores_seg.argmax() score_seg = scores_seg[idx_seg] entity_seg = self.id2seg[idx_seg] scores_ref = np.stack([entity["scores_ref"] for entity in entities]) indices_ref = scores_ref.argmax(axis=1) idx_ref = 1 if all(indices_ref) else 0 entity_ref = self.id2ref[idx_ref] new_entity = { "entity_seg": entity_seg, "score_seg": score_seg, "entity_ref": entity_ref, "word": word, "start": entities[0]["start"] + token_offset[0], "end": entities[-1]["end"] + token_offset[0], } return new_entity def aggregate_words(self, entities: List[dict], token_offsets: List[tuple]) -> List[dict]: """ Override tokens from a given word that disagree to force agreement on word boundaries. Example: micro|soft| com|pany| B-ENT I-NAME I-ENT I-ENT will be rewritten with first strategy as microsoft| company| B-ENT I-ENT """ word_entities = [] word_group = None idx = 0 for entity in entities: if entity["is_stride"]: continue if word_group is None: word_group = [entity] elif entity["is_subword"]: word_group.append(entity) else: word_entities.append(self.aggregate_word(word_group, token_offsets[idx])) word_group = [entity] idx += 1 word_entities.append(self.aggregate_word(word_group, token_offsets[idx])) idx += 1 return word_entities def group_entities(self, entities: List[dict], sentence: str) -> List[dict]: """ Find and group together the adjacent tokens with the same entity predicted. Args: entities (`dict`): The entities predicted by the pipeline. 
""" entity_chunk = [] entity_chunk_disagg = [] for entity in entities: if not entity_chunk_disagg: entity_chunk_disagg.append(entity) continue bi_ref, tag_ref = self.get_tag(entity["entity_ref"]) last_bi_ref, last_tag_ref = self.get_tag(entity_chunk_disagg[-1]["entity_ref"]) if tag_ref == last_tag_ref and bi_ref != "B": entity_chunk_disagg.append(entity) else: entity_chunk.append(entity_chunk_disagg) entity_chunk_disagg = [entity] if entity_chunk_disagg: entity_chunk.append(entity_chunk_disagg) entity_chunks_all = [] for chunk in entity_chunk: entity_groups = [] entity_group_disagg = [] for entity in chunk: if not entity_group_disagg: entity_group_disagg.append(entity) continue bi_seg, tag_seg = self.get_tag(entity["entity_seg"]) last_bi_seg, last_tag_seg = self.get_tag(entity_group_disagg[-1]["entity_seg"]) if tag_seg == last_tag_seg and bi_seg != "B": entity_group_disagg.append(entity) else: entity_groups.append(self.group_sub_entities(entity_group_disagg, sentence)) entity_group_disagg = [entity] if entity_group_disagg: entity_groups.append(self.group_sub_entities(entity_group_disagg, sentence)) entity_chunks_all.append(entity_groups) return entity_chunks_all def group_sub_entities(self, entities: List[dict], sentence: str) -> dict: """ Group together the adjacent tokens with the same entity predicted. Args: entities (`dict`): The entities predicted by the pipeline. """ entity = entities[0]["entity_seg"].split("-")[-1] scores = np.nanmean([entity["score_seg"] for entity in entities]) start = min([entity["start"] for entity in entities]) end = max([entity["end"] for entity in entities]) word = sentence[start:end] entity_group = { "entity_group": entity, "score": np.mean(scores), "word": word, "start": entities[0]["start"], "end": entities[-1]["end"], } return entity_group def get_tag(self, entity_name: str) -> Tuple[str, str]: if entity_name.startswith("B-"): bi = "B" tag = entity_name[2:] elif entity_name.startswith("I-"): bi = "I" tag = entity_name[2:] else: bi = "I" tag = entity_name return bi, tag