# ner-analyzer / src / negation.py
# Author: Kaelan — initial commit (f5e3fa7), 2.65 kB
import spacy
from negspacy.negation import Negex
from spacy.matcher import PhraseMatcher
from spacy.tokens import Span
def negation(model: "spacy.language.Language", entities: list):
    """
    Add the negspacy `negex` component to an existing spaCy pipeline.

    The dependency parser is removed (negex only needs sentence
    boundaries, which the lightweight sentencizer provides), then the
    sentencizer and negex components are added if not already present.

    Parameters:
        model: loaded spaCy pipeline
        entities: entity-type labels negex should consider; forwarded
            to the component as its ``ent_types`` setting

    Returns:
        model: the same pipeline with 'sentencizer' and 'negex' added
    """
    # The full parser is unnecessary overhead here; sentence boundaries
    # from the sentencizer are all negex requires.
    if 'parser' in model.pipe_names:
        model.remove_pipe('parser')
    if 'sentencizer' not in model.pipe_names:
        model.add_pipe('sentencizer')
    if 'negex' not in model.pipe_names:
        # BUG FIX: `config` must be a dict of component settings in
        # spaCy v3; passing the bare list raised. negspacy takes the
        # entity labels under the `ent_types` key.
        model.add_pipe('negex', config={'ent_types': entities})
    return model
def infer_negation(neg_model, model, text: str, pred_doc):
    """
    Match results from the negation model against the model's predictions.

    Entities that the negation model flags as negated are re-labelled
    "NEG" in the prediction doc; any predicted entity overlapping a
    negated span is removed so ``doc.ents`` stays non-overlapping.

    Parameters:
        neg_model: spaCy pipeline with the negex component (see `negation`)
        model: spaCy pipeline used for tokenization/vocab
        text: text sample
        pred_doc: prediction Doc for the same text from `model`

    Returns:
        pred_doc: Doc with every negated entity carrying the "NEG" label
    """
    doc = neg_model(text)
    results = {'ent': [], 'start': [], 'end': []}
    for ent in doc.ents:
        # negspacy sets the `negex` extension to True on negated entities.
        if ent._.negex:
            results['ent'].append(ent.text)
            results['start'].append(ent.start)
            results['end'].append(ent.end)
    print('Negation: ', results)
    # Build a phrase pattern for each negated surface string and locate
    # its occurrences in the prediction doc.
    patterns = [model.make_doc(ent_text) for ent_text in results['ent']]
    matcher = PhraseMatcher(model.vocab)
    # BUG FIX: spaCy v3 signature is add(key, docs); the old
    # add(key, None, *patterns) v2 form raises under v3.
    matcher.add('NEG', patterns)
    matches = list(matcher(pred_doc))
    new_entities = []
    entities = list(pred_doc.ents)
    # Exact-position matching: a phrase match only counts when it starts
    # at the same token index where the negation model saw the entity
    # (the phrase text alone could also occur elsewhere, un-negated).
    for neg_start in results['start']:
        for i, (match_id, start, end) in enumerate(matches):
            if start == neg_start:
                # match_id is the vocab hash of the 'NEG' key, so the
                # span label resolves to "NEG".
                new_entities.append(Span(pred_doc, start, end, label=match_id))
                # Drop any predicted entity overlapping the negated span.
                entities = [
                    e for e in entities if not (e.start < end and e.end > start)
                ]
                # Consume this match so it cannot serve another start
                # position. (BUG FIX: the original popped by a stale
                # index while iterating the same list, skipping and
                # mis-removing matches.)
                matches.pop(i)
                break
    pred_doc.ents = tuple(entities) + tuple(new_entities)
    return pred_doc