from itertools import groupby

import numpy as np


class EntityClassifier:
    """Selects the best entity candidate by span length, casing similarity and prior."""

    def __init__(self):
        pass

    def _get_grouped_by_length(self, entities):
        # Group candidates by the token length of their span.
        sorted_by_len = sorted(entities, key=lambda entity: len(entity.get_span()), reverse=True)
        entities_by_length = {}
        for length, group in groupby(sorted_by_len, lambda entity: len(entity.get_span())):
            entities_by_length[length] = list(group)
        return entities_by_length

    def _filter_max_length(self, entities):
        # Keep only the candidates with the longest span.
        entities_by_length = self._get_grouped_by_length(entities)
        max_length = max(entities_by_length.keys())
        return entities_by_length[max_length]

    def _select_max_prior(self, entities):
        # Return the candidate with the highest prior probability.
        priors = [entity.get_prior() for entity in entities]
        return entities[np.argmax(priors)]

    def _get_casing_difference(self, word1, original):
        # Count character positions where the two strings differ (e.g. in casing).
        difference = 0
        for w1, w2 in zip(word1, original):
            if w1 != w2:
                difference += 1
        return difference

    def _filter_most_similar(self, entities):
        # Keep the candidates whose span text is closest to the original alias.
        similarities = np.array(
            [self._get_casing_difference(entity.get_span().text, entity.get_original_alias())
             for entity in entities])
        min_indices = np.where(similarities == similarities.min())[0].tolist()
        return [entities[i] for i in min_indices]

    def __call__(self, entities):
        # Pipeline: longest span -> most similar casing -> highest prior.
        filtered_by_length = self._filter_max_length(entities)
        filtered_by_casing = self._filter_most_similar(filtered_by_length)
        return self._select_max_prior(filtered_by_casing)
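

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): the stub classes below are
# hypothetical and only assume the interface this classifier actually calls,
# i.e. get_span() returning an object supporting len() and exposing .text,
# get_prior(), and get_original_alias(). In the real pipeline the candidate
# entities would come from an upstream entity-linking step instead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _StubSpan:
        def __init__(self, tokens):
            self.tokens = tokens
            self.text = " ".join(tokens)

        def __len__(self):
            return len(self.tokens)

    class _StubEntity:
        def __init__(self, tokens, prior, original_alias):
            self._span = _StubSpan(tokens)
            self._prior = prior
            self._alias = original_alias

        def get_span(self):
            return self._span

        def get_prior(self):
            return self._prior

        def get_original_alias(self):
            return self._alias

    candidates = [
        _StubEntity(["new"], 0.9, "New"),               # shorter span, filtered out first
        _StubEntity(["New", "York"], 0.4, "New York"),  # longest span, exact casing
        _StubEntity(["NEW", "YORK"], 0.8, "New York"),  # longest span, casing differs
    ]

    classifier = EntityClassifier()
    best = classifier(candidates)
    # Expected: the "New York" candidate with prior 0.4
    # (longest span, then closest casing to the original alias).
    print(best.get_span().text, best.get_prior())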