import pickle

import numpy as np
import stanza
from pyarabic import araby  # provides strip_tashkeel / strip_tatweel used in preprocess_data


class Config:
    # empty configuration container
    def __init__(self):
        super(Config, self).__init__()


def read_conll_ner(path):
    with open(path) as f:
        lines = f.readlines()
    unique_entries = []
    sentences = []
    curr_sentence = []
    for line in lines:
        if not line.strip():
            # a blank line marks the end of a sentence
            if curr_sentence:
                sentences.append(curr_sentence)
                curr_sentence = []
            continue
        entry = line.split()
        curr_sentence.append(entry)
        if not len(unique_entries):
            unique_entries = [[] for _ in entry[1:]]
        # track the unique values of every column after the token
        for e, col_values in zip(entry[1:], unique_entries):
            if e not in col_values:
                col_values.append(e)
    if curr_sentence:
        # flush the last sentence when the file does not end with a blank line
        sentences.append(curr_sentence)
    return [sentences] + unique_entries
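
# Example: for a CoNLL-2003 style file whose lines look like
# "EU NNP B-NP B-ORG" (token, POS, chunk, NER tag), the call below returns the
# parsed sentences plus the unique values of each column after the token.
# The path is only illustrative.
#
#   sentences, pos_vals, chunk_vals, ner_vals = read_conll_ner("data/train.conll")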


def read_pickled_conll(path):
    with open(path, "rb") as f:
        data = pickle.load(f)
    return data


def split_conll_docs(conll_sents, skip_docstart=True):
    """Split a flat list of CoNLL sentences into documents at '-DOCSTART-' markers."""
    docs = []
    curr_doc = []
    for sent in conll_sents:
        if sent[0][0] == '-DOCSTART-':
            if curr_doc:
                docs.append(curr_doc)
                curr_doc = []
            if skip_docstart:
                continue
        curr_doc.append(sent)
    docs.append(curr_doc)
    return docs
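
# Example: '-DOCSTART-' rows delimit documents, so with the default
# skip_docstart=True the marker sentences themselves are dropped
# (toy sentences with single-column rows):
#
#   sents = [[['-DOCSTART-']], [['EU'], ['rejects']], [['-DOCSTART-']], [['Peter']]]
#   split_conll_docs(sents)
#   # -> [[[['EU'], ['rejects']]], [[['Peter']]]]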


def create_context_data(docs, pos_col_id=1, tag_col_id=3, context_length=1, **kwargs):
    ctx_type = kwargs.get("ctx_type", "other")
    sep_token = kwargs.get("sep_token", "[SEP]")
    if ctx_type == "cand_titles":
        # context = query + [SEP]-separated titles of the BM25 candidates
        for doc in docs:
            doc["ctx_sent"] = doc["query"] + [sep_token] + f"<split>{sep_token}<split>".join(
                [cand["doc_title"] for cand in doc["BM25_cands"]]).split("<split>")
        return docs
    if ctx_type == "cand_links":
        # context = query + candidate titles + titles linked from the candidates
        for doc in docs:
            doc_titles_list = f"<split>{sep_token}<split>".join(
                [cand["doc_title"] for cand in doc["BM25_cands"]]).split("<split>")
            linked_titles_list = f"<split>{sep_token}<split>".join(
                [linked for cand in doc["BM25_cands"] for linked in cand["linked_titles"]]).split("<split>")
            doc["ctx_sent"] = doc["query"] + [sep_token] + doc_titles_list + [sep_token] + linked_titles_list
        return docs
    if ctx_type == "raw_text":
        # one context per candidate: query + the candidate's processed raw text
        for doc in docs:
            doc["ctx_sent"] = [doc["query"] + [sep_token] + [cand["processed_text"]]
                               for cand in doc["BM25_cands"]]
        return docs
    if ctx_type == 'matched_spans':
        matched_spans = kwargs.get('matched_spans')
        return [
            [[t[0] for t in d] + [t for span in ms for t in [sep_token] + span[1]],  # sentence tokens + matched spans
             None,  # pos tags
             [s[tag_col_id] for s in d] if tag_col_id > 0 else None,  # ner tags
             [len(d)]  # sentence length
             ]
            for d, ms in zip(docs, matched_spans)]
    if ctx_type == 'bm25_matched_spans':
        matched_spans = kwargs.get('matched_spans')
        pickled_data = kwargs.get('pickled_data')
        docs = [[[t[0] for t in d] + [t for span in ms for t in [sep_token] + span[1]],  # sentence tokens + matched spans
                 None,  # pos tags
                 [s[tag_col_id] for s in d],  # ner tags
                 [len(d)]  # sentence length
                 ]
                for d, ms in zip(docs, matched_spans)]
        # append the BM25 candidate titles and their linked titles to each row's tokens
        for ms, doc in zip(docs, pickled_data):
            doc_titles_list = f"<split>{sep_token}<split>".join(
                [cand["doc_title"] for cand in doc["BM25_cands"]]).split("<split>")
            linked_titles_list = f"<split>{sep_token}<split>".join(
                [linked for cand in doc["BM25_cands"] for linked in cand["linked_titles"]]).split("<split>")
            ms[0] = ms[0] + [sep_token] + doc_titles_list + [sep_token] + linked_titles_list
        return docs
if ctx_type == "infobox": | |
infobox_keys_path = kwargs.get("infobox_keys_path") | |
infobox_keys = read_pickled_conll(infobox_keys_path) | |
if 'pred_spans' in docs[0]: | |
docs = get_pred_ent_bounds(docs) | |
for doc in docs: | |
if 'pred_spans' in doc: | |
ents = [' '.join(doc['query'][bd[0]:bd[1] + 1]) for bd in doc['pred_ent_bounds']] | |
ents_wo_space = [''.join(doc['query'][bd[0]:bd[1] + 1]) for bd in doc['pred_ent_bounds']] | |
else: | |
ents = [' '.join(doc['query'][bd[0]:bd[1] + 1]) for bd in doc['ent_bounds']] | |
ents_wo_space = [''.join(doc['query'][bd[0]:bd[1] + 1]) for bd in doc['ent_bounds']] | |
ents = list(set(ents + ents_wo_space)) | |
infobox = [infobox_keys[en] for en in ents if en in infobox_keys and infobox_keys[en]] | |
for ibs in infobox: | |
ibs[0] = '[INFO] ' + ibs[0] | |
ibs[-1] = ibs[-1] + ' [/INFO]' | |
infobox = [i for j in infobox for i in j] | |
doc["ctx_sent"] = doc["query"] + [sep_token] + infobox | |
return docs | |
    # default scenario: sliding windows of `context_length` sentences per document
    res = []
    for doc in docs:
        ctx_len = context_length if context_length > 0 else len(doc)
        # for the last sentences, loop around to the beginning of the document for context
        padded_doc = doc + doc[:ctx_len]
        for i in range(len(doc)):
            res.append((
                [s[0] for sent in padded_doc[i:i + ctx_len] for s in sent],
                [s[pos_col_id] for sent in padded_doc[i:i + ctx_len] for s in sent] if pos_col_id > 0 else None,
                [s[tag_col_id] for sent in padded_doc[i:i + ctx_len] for s in sent],
                [len(sent) for sent in padded_doc[i:i + ctx_len]],
                {}  # dictionary for extra context
            ))
    return res
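
# Example (default branch, no ctx_type given): each document from
# split_conll_docs yields one tuple per sentence, containing the tokens, POS
# tags and NER tags of a window of `context_length` sentences (wrapping around
# to the start of the document), the per-sentence lengths of that window, and
# an empty dict for extra context. The path is only illustrative.
#
#   docs = split_conll_docs(read_conll_ner("data/train.conll")[0])
#   rows = create_context_data(docs, pos_col_id=1, tag_col_id=3, context_length=2)
#   tokens, pos_tags, ner_tags, sent_lens, extra = rows[0]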


def calc_correct(sentence):
    """Count correctly predicted chunks in a sentence of (token, _, _, gold_tag, pred_tag) rows."""
    gold_chunks = []
    parallel_chunks = []
    pred_chunks = []
    curr_gold_chunk = []
    curr_parallel_chunk = []
    curr_pred_chunk = []
    prev_tag = None
    for line in sentence:
        _, _, _, gt, pt = line
        curr_tag = None
        if '-' in pt:
            curr_tag = pt.split('-')[1]
        # build the gold chunks and, in parallel, the predicted tags covering them
        if gt.startswith('B'):
            if curr_gold_chunk:
                gold_chunks.append(curr_gold_chunk)
                parallel_chunks.append(curr_parallel_chunk)
            curr_gold_chunk = [gt]
            curr_parallel_chunk = [pt]
        elif gt.startswith('I') or (pt.startswith('I') and curr_tag == prev_tag
                                    and curr_gold_chunk):
            curr_gold_chunk.append(gt)
            curr_parallel_chunk.append(pt)
        elif gt.startswith('O') and pt.startswith('O'):
            if curr_gold_chunk:
                gold_chunks.append(curr_gold_chunk)
                parallel_chunks.append(curr_parallel_chunk)
                curr_gold_chunk = []
                curr_parallel_chunk = []
        # build the predicted chunks independently
        if pt.startswith('O'):
            if curr_pred_chunk:
                pred_chunks.append(curr_pred_chunk)
                curr_pred_chunk = []
        elif pt.startswith('B'):
            if curr_pred_chunk:
                pred_chunks.append(curr_pred_chunk)
            curr_pred_chunk = [pt]
            prev_tag = curr_tag
        else:
            if prev_tag is not None and curr_tag != prev_tag:
                prev_tag = curr_tag
                if curr_pred_chunk:
                    pred_chunks.append(curr_pred_chunk)
                    curr_pred_chunk = []
            curr_pred_chunk.append(pt)
    # flush any chunks still open at the end of the sentence
    if curr_gold_chunk:
        gold_chunks.append(curr_gold_chunk)
        parallel_chunks.append(curr_parallel_chunk)
    if curr_pred_chunk:
        pred_chunks.append(curr_pred_chunk)
    correct = sum([1 for gc, pc in zip(gold_chunks, parallel_chunks)
                   if not len([1 for g, p in zip(gc, pc) if g != p])])
    correct_tagless = sum([1 for gc, pc in zip(gold_chunks, parallel_chunks)
                           if not len([1 for g, p in zip(gc, pc) if g[0] != p[0]])])
    # return correct, gold_chunks, parallel_chunks, pred_chunks, ob1_correct, correct_tagless
    return {'correct': correct,
            'correct_tagless': correct_tagless,
            'gold_count': len(gold_chunks),
            'pred_count': len(pred_chunks)}
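
# Example: rows are (token, _, _, gold tag, predicted tag). Here the PER chunk
# is predicted exactly, while the LOC chunk gets the right boundary but the
# wrong type, so it only counts towards `correct_tagless`:
#
#   calc_correct([['John', '_', '_', 'B-PER', 'B-PER'],
#                 ['Smith', '_', '_', 'I-PER', 'I-PER'],
#                 ['visited', '_', '_', 'O', 'O'],
#                 ['Paris', '_', '_', 'B-LOC', 'B-ORG']])
#   # -> {'correct': 1, 'correct_tagless': 2, 'gold_count': 2, 'pred_count': 2}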


def tag_sentences(sentences):
    nlp = stanza.Pipeline(lang='en', processors='tokenize,pos', logging_level='WARNING')
    tagged_sents = []
    for sentence in sentences:
        n = nlp(sentence)
        tagged_sent = []
        for s in n.sentences:
            for w in s.words:
                tagged_sent.append([w.text, w.upos])
        tagged_sents.append(tagged_sent)
    return tagged_sents
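
# Example: sentences are passed to stanza as raw strings and come back as
# [token, universal POS] pairs. The English models must be available
# (stanza.download('en')); the exact tags depend on the model.
#
#   tag_sentences(["John lives in Berlin."])
#   # -> [[['John', 'PROPN'], ['lives', 'VERB'], ['in', 'ADP'],
#   #      ['Berlin', 'PROPN'], ['.', 'PUNCT']]]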


def extract_spans(sentence, tagless=False):
    spans_positions = []
    span_bounds = []
    all_bounds = []
    span_tags = []
    curr_tag = None
    curr_span = []
    curr_span_start = -1
    # span ids, span types
    for i, token in enumerate(sentence):
        if token.startswith('B'):
            # a new span starts; close the previous one if it is still open
            if curr_span:
                spans_positions.append([curr_span, len(all_bounds)])
                span_bounds.append([curr_span_start, i - 1])
                all_bounds.append([[curr_span_start, i - 1], 'E', len(all_bounds)])
                if not tagless:
                    span_tags.append(token.split('-')[1])
                curr_span = []
                curr_tag = None
            curr_span.append(token)
            curr_tag = None if tagless else token.split('-')[1]
            curr_span_start = i
        elif token.startswith('I'):
            if not tagless:
                tag = token.split('-')[1]
                if tag != curr_tag and curr_tag is not None:
                    # the entity type changed mid-span; close the current span
                    spans_positions.append([curr_span, len(all_bounds)])
                    span_bounds.append([curr_span_start, i - 1])
                    span_tags.append(token.split('-')[1])
                    all_bounds.append([[curr_span_start, i - 1], 'E', len(all_bounds)])
                    curr_span = []
                    curr_tag = tag
                    curr_span_start = i
                elif curr_tag is None:
                    curr_span = []
                    curr_tag = tag
                    curr_span_start = i
            elif not curr_span:
                curr_span_start = i
            curr_span.append(token)
        elif token.startswith('O') or token.startswith('-'):
            # outside token: close any open span and record a single-word boundary
            if curr_span:
                spans_positions.append([curr_span, len(all_bounds)])
                span_bounds.append([curr_span_start, i - 1])
                all_bounds.append([[curr_span_start, i - 1], 'E', len(all_bounds)])
                curr_span = []
                curr_tag = None
            all_bounds.append([[i], 'W', len(all_bounds)])
    # check if sentence ended with a span
    if curr_span:
        spans_positions.append([curr_span, len(all_bounds)])
        span_bounds.append([curr_span_start, len(sentence) - 1])
        all_bounds.append([[curr_span_start, len(sentence) - 1], 'E', len(all_bounds)])
    tagged_bounds = [[loc[0][0].split('-')[1] if '-' in loc[0][0] else loc[0][0], bound]
                     for loc, bound in zip(spans_positions, span_bounds)]
    return spans_positions, span_bounds, all_bounds, tagged_bounds
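
# Example with BIO tags: tagged_bounds (the fourth return value) pairs each
# entity type with its [start, end] token bounds:
#
#   extract_spans(['B-PER', 'I-PER', 'O', 'B-LOC'])[3]
#   # -> [['PER', [0, 1]], ['LOC', [3, 3]]]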


def ner_corpus_stats(corpus_path):
    onto_train_cols = read_conll_ner(corpus_path)
    tags = list(set([t.split('-')[1] for t in onto_train_cols[3] if '-' in t]))
    onto_train_spans = [extract_spans([t[3] for t in sent])[3]
                        for sent in onto_train_cols[0]]
    # distribution of span lengths over the whole corpus
    span_lens = [span[1][1] - span[1][0] + 1 for sent in onto_train_spans
                 for span in sent]
    len_stats = [span_lens.count(i + 1) / len(span_lens)
                 for i in range(max(span_lens))]
    flat_spans = [span for sent in onto_train_spans for span in sent]
    tag_lens_dict = {k: [] for k in tags}
    tag_counts_dict = {k: 0 for k in tags}
    for span in flat_spans:
        span_length = span[1][1] - span[1][0] + 1
        span_tag = span[0][0].split('-')[1]
        tag_lens_dict[span_tag].append(span_length)
        tag_counts_dict[span_tag] += 1
    # relative frequency of each tag, sorted by count
    x = list(tag_counts_dict.items())
    x.sort(key=lambda l: l[1])
    tag_counts = [list(l) for l in x]
    for l in tag_counts:
        l[1] = l[1] / len(span_lens)
    # span-length distribution per tag
    tag_len_stats = {k: [v.count(i + 1) / len(v) for i in range(max(v))]
                     for k, v in tag_lens_dict.items()}
    # POS-sequence statistics of the entity spans
    span_texts = [sent[span[1][0]:span[1][1] + 1]
                  for sent, spans in zip(onto_train_cols[0], onto_train_spans)
                  for span in spans]
    span_pos = [[span[0][-1].split('-')[1], '_'.join(t[1] for t in span)]
                for span in span_texts]
    unique_pos = list(set([span[1] for span in span_pos]))
    pos_dict = {k: 0 for k in unique_pos}
    for span in span_pos:
        pos_dict[span[1]] += 1
    unique_pos.sort(key=lambda l: pos_dict[l], reverse=True)
    pos_stats = [[p, pos_dict[p] / len(span_pos)] for p in unique_pos]
    tag_pos_dict = {kt: {kp: 0 for kp in unique_pos} for kt in tags}
    for span in span_pos:
        tag_pos_dict[span[0]][span[1]] += 1
    tag_pos_stats = {kt: [[p, tag_pos_dict[kt][p] / tag_counts_dict[kt]]
                          for p in unique_pos] for kt in tags}
    for kt in tags:
        tag_pos_stats[kt].sort(key=lambda l: l[1], reverse=True)
    return len_stats, tag_len_stats, tag_counts, pos_stats, tag_pos_stats
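
# Example (illustrative path): the five return values are the corpus-wide span
# length distribution, the per-tag length distributions, the relative tag
# frequencies, the POS-sequence frequencies of entity spans, and the
# POS-sequence frequencies broken down per tag:
#
#   len_stats, tag_len_stats, tag_counts, pos_stats, tag_pos_stats = \
#       ner_corpus_stats("data/train.conll")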


def filter_by_max_ents(sentences, max_ent_length):
    """
    Filters a given list of sentences and only returns the sentences whose
    named entities are shorter than or equal to the given max_ent_length.
    :param sentences: sentences in CoNLL format as extracted by read_conll_ner
    :param max_ent_length: the maximum number of tokens in an entity
    :return: a list of sentences
    """
    filtered_sents = []
    for sent in sentences:
        sent_span_lens = [s[1] - s[0] + 1
                          for s in extract_spans([t[3] for t in sent])[1]]
        if not sent_span_lens or max(sent_span_lens) <= max_ent_length:
            filtered_sents.append(sent)
    return filtered_sents


def get_pred_ent_bounds(docs):
    # convert each doc's BIO 'pred_spans' sequence into [start, end] token bounds
    for doc in docs:
        eb = []
        count = 0
        for p_eb in doc['pred_spans']:
            if p_eb == 'B':
                eb.append([count, count])
            elif p_eb == 'I' and len(eb) > 0:
                eb[-1][1] = count
            count += 1
        doc['pred_ent_bounds'] = eb
    return docs
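
# Example: a BIO sequence of predicted span labels becomes [start, end] bounds:
#
#   get_pred_ent_bounds([{'pred_spans': ['B', 'I', 'O', 'B']}])
#   # -> [{'pred_spans': ['B', 'I', 'O', 'B'], 'pred_ent_bounds': [[0, 1], [3, 3]]}]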


def enumerate_spans(batch):
    # enumerate every contiguous span [x, y] (x <= y) for each sentence in the batch
    enumerated_spans_batch = []
    for idx in range(0, len(batch)):
        sentence = batch[idx]
        enumerated_spans = []
        for x in range(len(sentence)):
            for y in range(x, len(sentence)):
                enumerated_spans.append([x, y])
        enumerated_spans_batch.append(enumerated_spans)
    return enumerated_spans_batch
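
# Example: for a single three-token sentence, every contiguous span is listed:
#
#   enumerate_spans([['New', 'York', 'City']])
#   # -> [[[0, 0], [0, 1], [0, 2], [1, 1], [1, 2], [2, 2]]]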


def compact_span_enumeration(batch):
    # comprehension-based variant; note that it enumerates all (x, y) pairs,
    # including those with y < x, unlike enumerate_spans above
    sentence_lengths = [len(b) for b in batch]
    enumerated_spans = [[[x, y]
                         for y in range(0, sentence_length)
                         for x in range(sentence_length)]
                        for sentence_length in sentence_lengths]
    return enumerated_spans


def preprocess_data(data):
    clean_data = []
    for sample in data:
        # remove Arabic diacritics (tashkeel) and elongation characters (tatweel)
        clean_tokens = [araby.strip_tashkeel(token) for token in sample[0]]
        clean_tokens = [araby.strip_tatweel(token) for token in clean_tokens]
        clean_sample = [clean_tokens]
        clean_sample.extend(sample[1:])
        clean_data.append(clean_sample)
    return clean_data


def generate_targets(enumerated_spans, sentences):
    # NOTE: this block could be refactored into a helper function:
    # collect the gold span bounds of every sentence
    extracted_spans = [extract_spans(sentence, True)[3] for sentence in sentences]
    target_locations = []
    for spans in extracted_spans:
        sentence_locations = []
        for location in spans:
            sentence_locations.append(location[1])
        target_locations.append(sentence_locations)
    # end of the block that could be refactored
    # mark each enumerated span with 1 if it is a gold span, 0 otherwise
    targets = []
    for span, location_list in zip(enumerated_spans, target_locations):
        span_arr = np.zeros_like(span).tolist()
        target_indices = [span.index(span_location)
                          for span_location in location_list]
        for idx in target_indices:
            span_arr[idx] = 1
        span_arr = [0 if x != 1 else x for x in span_arr]
        targets.append(list(span_arr))
    return targets
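
# Example: with the spans from enumerate_spans and BIO-tagged sentences, the
# targets mark which enumerated spans are gold entities (here only [0, 1]):
#
#   spans = enumerate_spans([['John', 'Smith', 'ran']])
#   generate_targets(spans, [['B-PER', 'I-PER', 'O']])
#   # -> [[0, 1, 0, 0, 0, 0]]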


def label_tags(tags):
    # binary labels: 0 for 'O', 1 for any entity tag
    output_tags = []
    for tag in tags:
        if tag == "O":
            output_tags.append(0)
        else:
            output_tags.append(1)
    return output_tags
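
# Example: any non-'O' tag is mapped to 1:
#
#   label_tags(['O', 'B-PER', 'I-PER', 'O'])
#   # -> [0, 1, 1, 0]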