# -*- coding:utf-8 -*-
"""
@Last modified date : 2020/12/23
"""
import re
import nltk
import torch.cuda
from nltk.stem import WordNetLemmatizer
from allennlp.predictors.predictor import Predictor
nltk.download('wordnet')
nltk.download('stopwords')

def deal_bracket(text, restore, leading_ent=None):
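    """Optionally prefix `text` with its leading entity and restore -LRB-/-RRB- tokens to round brackets."""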
    if leading_ent:
        leading_ent = ' '.join(leading_ent.split('_'))
        text = f'Things about {leading_ent}: ' + text
    if restore:
        text = text.replace('-LRB-', '(').replace('-RRB-', ')')
        text = text.replace('LRB', '(').replace('RRB', ')')
    return text

def refine_entity(entity):
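    """Strip a trailing -LRB- ... -RRB- disambiguation suffix from a Wikipedia-style entity title and normalize underscores and whitespace."""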
    entity = re.sub(r'-LRB- .+ -RRB-$', '', entity)
    entity = re.sub(r'LRB .+ RRB$', '', entity)
    entity = re.sub(r'_', ' ', entity)
    entity = re.sub(r'\s+', ' ', entity)
    return entity.strip()

def find_sub_seq(seq_a, seq_b, shift=0, uncased=False, lemmatizer=None):
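    """Find seq_b as a contiguous sub-sequence of seq_a, searching from index `shift`; return (start, end) token indices or (-1, -1)."""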
    if uncased:
        seq_a = [token.lower() for token in seq_a]
        seq_b = [token.lower() for token in seq_b]
    if lemmatizer is not None:
        seq_a = [lemmatizer.lemmatize(token) for token in seq_a]
        seq_b = [lemmatizer.lemmatize(token) for token in seq_b]
    for i in range(shift, len(seq_a)):
        if seq_a[i:i + len(seq_b)] == seq_b:
            return i, i + len(seq_b)
    return -1, -1

def is_sub_seq(seq_start, seq_end, all_seqs):
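    """Return the (start, end, is_candidate) span in all_seqs that contains [seq_start, seq_end), or None."""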
    for start, end, is_candidate in all_seqs:
        if start <= seq_start < seq_end <= end:
            return start, end, is_candidate
    return None

# extract named entities with the B-I-L-U-O tagging scheme
def extract_named_entity(tags):
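    """Collect (start, end, False) token spans of complete named entities from a BILOU tag sequence."""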
    all_NEs = []
    ne_type, ne_start = '', -1
    for i, t in enumerate(tags):
        if t == 'O':
            ne_type, ne_start = '', -1
            continue
        t1, t2 = t.split('-')
        if t1 == 'B':
            ne_type, ne_start = t2, i
        elif t1 == 'I' and t2 != ne_type:
            ne_type, ne_start = '', -1
        elif t1 == 'L' and t2 != ne_type:
            ne_type, ne_start = '', -1
        elif t1 == 'L' and t2 == ne_type:
            all_NEs.append((ne_start, i + 1, False))
            ne_type, ne_start = '', -1
        elif t1 == 'U':
            all_NEs.append((i, i + 1, False))
            ne_type, ne_start = '', -1
    return all_NEs

def refine_results(tokens, spans, stopwords):
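    """Convert token-level spans into (text, char_start, char_end) spans, trimming leading stopwords from non-candidate spans and dropping spans nested inside longer ones."""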
    all_spans = []
    for span_start, span_end, is_candidate in spans:
        # remove stopwords
        if not is_candidate:
            while span_start < span_end and tokens[span_start].lower() in stopwords:
                span_start += 1
        if span_start >= span_end:
            continue
        # add prefix
        if span_start > 0 and tokens[span_start - 1] in ['a', 'an', 'A', 'An', 'the', 'The']:
            span_start -= 1
        # convert token-level index into char-level index
        span = ' '.join(tokens[span_start:span_end])
        span_start = len(' '.join(tokens[:span_start])) + 1 * min(1, span_start)  # 1 for blank
        span_end = span_start + len(span)
        all_spans.append((span, span_start, span_end))
    all_spans = sorted(all_spans, key=lambda x: (x[1], x[1] - x[2]))
    # remove overlap
    refined_spans = []
    for span, span_start, span_end in all_spans:
        flag = True
        for _, start, end in refined_spans:
            if start <= span_start < span_end <= end:
                flag = False
                break
        if flag:
            refined_spans.append((span, span_start, span_end))
    return refined_spans

class SentenceParser:
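    """Extracts noun phrases, verbs, and modifiers from a sentence using AllenNLP NER and constituency-parser predictors."""
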
    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu',
                 ner_path="https://storage.googleapis.com/allennlp-public-models/ner-model-2020.02.10.tar.gz",
                 cp_path="https://storage.googleapis.com/allennlp-public-models/elmo-constituency-parser-2020.02.10.tar.gz"):
        self.device = self.parse_device(device)
        self.ner = Predictor.from_path(ner_path, cuda_device=self.device)
        print('* ner loaded')
        self.cp = Predictor.from_path(cp_path, cuda_device=self.device)
        print('* constituency parser loaded')
        self.lemmatizer = WordNetLemmatizer()
        # some heuristic rules can be added here
        self.stopwords = set(nltk.corpus.stopwords.words('english'))
        self.stopwords.update({'-', '\'s', 'try', 'tries', 'tried', 'trying',
                               'become', 'becomes', 'became', 'becoming',
                               'make', 'makes', 'made', 'making', 'call', 'called', 'calling',
                               'put', 'ever', 'something', 'someone', 'sometime'})
        self.special_tokens = ['only', 'most', 'before', 'after', 'behind']
        for token in self.special_tokens:
            if token in self.stopwords:
                self.stopwords.remove(token)
        if 'won' in self.stopwords:
            self.stopwords.remove('won')
        if 'own' in self.stopwords:
            self.stopwords.remove('own')

    def parse_device(self, device):
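        """Map a torch-style device string (e.g. 'cpu' or 'cuda:0') to an AllenNLP cuda_device index: -1 for CPU, otherwise the GPU id."""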
        if 'cpu' in device:
            return -1
        else:
            dev = re.findall(r'\d+', device)
            return 0 if len(dev) == 0 else int(dev[0])

    def identify_NPs(self, text, candidate_NPs=None):
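        """Parse `text` and return its noun phrases, verbs, and adjectives/adverbs as character-level spans, giving priority to the provided candidate_NPs."""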
        text = re.sub(r'\s+', ' ', text).strip()
        if len(text) == 0:
            return {'text': '', 'NPs': [], 'verbs': [], 'adjs': []}
        cp_outputs = self.cp.predict(text)
        ner_outputs = self.ner.predict(text)
        tokens = cp_outputs['tokens']
        pos_tags = cp_outputs['pos_tags']
        ner_tags = ner_outputs['tags']
        tree = cp_outputs['hierplane_tree']['root']

        # extract candidate noun phrases passed by user with token index
        all_NPs = []
        candidate_NPs = [refine_entity(np).split() for np in candidate_NPs] if candidate_NPs else []
        for np in sorted(candidate_NPs, key=len, reverse=True):
            np_start, np_end = find_sub_seq(tokens, np, 0, uncased=True, lemmatizer=self.lemmatizer)
            while np_start != -1 and np_end != -1:
                if not is_sub_seq(np_start, np_end, all_NPs):
                    all_NPs.append((np_start, np_end, True))
                np_start, np_end = find_sub_seq(tokens, np, np_end, uncased=True, lemmatizer=self.lemmatizer)

        # extract noun phrases from tree
        def _get_bottom_NPs(children):
            if 'children' not in children:
                return None
            if {'NP', 'OP', 'XP', 'QP'} & set(children['attributes']):
                is_bottom = True
                for child in children['children']:
                    if 'children' in child:
                        is_bottom = False
                if is_bottom:
                    bottom_NPs.append(children['word'].split())
                else:
                    for child in children['children']:
                        _get_bottom_NPs(child)
            else:
                for child in children['children']:
                    _get_bottom_NPs(child)

        bottom_NPs = []
        _get_bottom_NPs(tree)

        # find token indices of noun phrases
        np_index = -1
        for np in bottom_NPs:
            np_start, np_end = find_sub_seq(tokens, np, np_index + 1)
            if not is_sub_seq(np_start, np_end, all_NPs):
                all_NPs.append((np_start, np_end, False))
            np_index = np_end

        # extract named entities with token index
        all_NEs = extract_named_entity(ner_tags)

        # extract verbs with token index
        all_verbs = []
        for i, pos in enumerate(pos_tags):
            if pos[0] == 'V':
                if not is_sub_seq(i, i + 1, all_NPs) and not is_sub_seq(i, i + 1, all_NEs):
                    all_verbs.append((i, i + 1, False))

        # extract modifiers with token index
        all_modifiers = []
        for i, (token, pos) in enumerate(zip(tokens, pos_tags)):
            if pos in ['JJ', 'RB']:  # adj. and adv.
                if not is_sub_seq(i, i + 1, all_NPs) and not is_sub_seq(i, i + 1, all_NEs):
                    all_modifiers.append((i, i + 1, False))
            elif token in self.special_tokens:
                if not is_sub_seq(i, i + 1, all_NPs) and not is_sub_seq(i, i + 1, all_NEs):
                    all_modifiers.append((i, i + 1, False))

        # split noun phrases with named entities
        all_spans = []
        for np_start, np_end, np_is_candidate in all_NPs:
            if np_is_candidate:  # candidate noun phrases will be preserved
                all_spans.append((np_start, np_end, np_is_candidate))
            else:
                match = is_sub_seq(np_start, np_end, all_NEs)
                if match:  # if a noun phrase is a sub span of a named entity, the named entity will be preserved
                    all_spans.append(match)
                else:  # if a named entity is a sub span of a noun phrase, the noun phrase will be split
                    index = np_start
                    for ne_start, ne_end, ne_is_candidate in all_NEs:
                        if np_start <= ne_start < ne_end <= np_end:
                            all_modifiers.append((index, ne_start, False))
                            all_spans.append((ne_start, ne_end, ne_is_candidate))
                            index = ne_end
                    all_spans.append((index, np_end, False))

        # named entities without overlapping
        for ne_start, ne_end, is_candidate in all_NEs:
            if not is_sub_seq(ne_start, ne_end, all_spans):
                all_spans.append((ne_start, ne_end, is_candidate))

        all_spans = refine_results(tokens, all_spans, self.stopwords)
        all_verbs = refine_results(tokens, all_verbs, self.stopwords)
        all_modifiers = refine_results(tokens, all_modifiers, self.stopwords)
        return {'text': tree['word'], 'NPs': all_spans, 'verbs': all_verbs, 'adjs': all_modifiers}

if __name__ == '__main__':
    import json

    print('Initializing sentence parser.')
    client = SentenceParser(device='cpu')
    print('Parsing sentence.')
    sentence = "The Africa Cup of Nations is held in odd - numbered years due to conflict with the World Cup . "
    entities = ['Africa Cup of Nations', 'Africa_Cup_of_Nations', 'Africa Cup', 'Africa_Cup']
    results = client.identify_NPs(sentence, entities)
    print(json.dumps(results, ensure_ascii=False, indent=4))
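    # `results` maps 'text' to the parsed sentence and 'NPs' / 'verbs' / 'adjs' to lists of
    # (span_text, char_start, char_end) triples; the exact spans depend on the AllenNLP models loaded above.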
    # import random
    # from tqdm import tqdm
    # from utils import read_json_lines, save_json
    #
    # print('Parsing file.')
    # results = []
    # data = list(read_json_lines('data/train.jsonl'))
    # random.shuffle(data)
    # for entry in tqdm(data[:100]):
    #     results.append(client.identify_NPs(entry['claim']))
    # save_json(results, 'data/results.json')