from itertools import groupby import numpy as np class EntityClassifier: def __init__(self): pass def _get_grouped_by_length(self, entities): sorted_by_len = sorted(entities, key=lambda entity: len(entity.get_span()), reverse=True) entities_by_length = {} for length, group in groupby(sorted_by_len, lambda entity: len(entity.get_span())): entities = list(group) entities_by_length[length] = entities return entities_by_length def _filter_max_length(self, entities): entities_by_length = self._get_grouped_by_length(entities) max_length = max(list(entities_by_length.keys())) return entities_by_length[max_length] def _select_max_prior(self, entities): priors = [entity.get_prior() for entity in entities] return entities[np.argmax(priors)] def _get_casing_difference(self, word1, original): difference = 0 for w1, w2 in zip(word1, original): if w1 != w2: difference += 1 return difference def _filter_most_similar(self, entities): similarities = np.array( [self._get_casing_difference(entity.get_span().text, entity.get_original_alias()) for entity in entities]) min_indices = np.where(similarities == similarities.min())[0].tolist() return [entities[i] for i in min_indices] def __call__(self, entities): filtered_by_length = self._filter_max_length(entities) filtered_by_casing = self._filter_most_similar(filtered_by_length) return self._select_max_prior(filtered_by_casing)