import pickle
import numpy as np
import stanza
from pyarabic import araby  # provides strip_tashkeel / strip_tatweel
class Config:
    def __init__(self):
        super().__init__()
def read_conll_ner(path):
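    """Read a CoNLL-style NER file.

    Returns a list whose first element is the list of sentences (each sentence a
    list of token rows) and whose remaining elements are the unique values seen
    in each column after the token column.
    """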
with open(path) as f:
lines = f.readlines()
unique_entries = []
sentences = []
curr_sentence = []
for line in lines:
if not line.strip():
if curr_sentence:
sentences.append(curr_sentence)
curr_sentence = []
continue
entry = line.split()
curr_sentence.append(entry)
        if not unique_entries:
            unique_entries = [[] for _ in entry[1:]]
        for e, values in zip(entry[1:], unique_entries):
            if e not in values:
                values.append(e)
return [sentences] + unique_entries
def read_pickled_conll(path):
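    """Load pickled data (e.g. preprocessed CoNLL documents) from a file."""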
with open(path, "rb") as f:
data = pickle.load(f)
return data
def split_conll_docs(conll_sents, skip_docstart=True):
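    """Group CoNLL sentences into documents, splitting on '-DOCSTART-' sentences.

    The '-DOCSTART-' sentences themselves are dropped unless skip_docstart is False.
    """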
docs = []
curr_doc = []
for sent in conll_sents:
if sent[0][0] == '-DOCSTART-':
if curr_doc:
docs.append(curr_doc)
curr_doc = []
if skip_docstart:
continue
curr_doc.append(sent)
docs.append(curr_doc)
return docs
def create_context_data(docs, pos_col_id=1, tag_col_id=3, context_length=1, **kwargs):
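    """Build model input contexts from documents.

    Depending on the ``ctx_type`` keyword argument, appends BM25 candidate titles,
    linked titles, raw candidate text, matched spans, or infobox fields to each
    query, separated by ``sep_token``. For any other ctx_type, returns sliding
    windows of ``context_length`` consecutive sentences per document (wrapping
    around at the end of the document) together with their POS tags, NER tags,
    and sentence lengths.
    """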
ctx_type = kwargs.get("ctx_type", "other")
sep_token = kwargs.get("sep_token", "[SEP]")
if ctx_type == "cand_titles":
# create context for candidate titles scenario
for doc in docs:
doc["ctx_sent"] = doc["query"] + [sep_token] + f"<split>{sep_token}<split>".join([cand["doc_title"] for cand in doc["BM25_cands"]]).split("<split>")
return docs
if ctx_type == "cand_links":
for doc in docs:
doc_titles_list = f"<split>{sep_token}<split>".join([cand["doc_title"] for cand in doc["BM25_cands"]]).split("<split>")
linked_titles_list = f"<split>{sep_token}<split>".join([linked for cand in doc["BM25_cands"] for linked in cand["linked_titles"]]).split("<split>")
doc["ctx_sent"] = doc["query"] + [sep_token] + doc_titles_list + [sep_token] + linked_titles_list
return docs
if ctx_type == "raw_text":
# create context for candidate raw text
for doc in docs:
doc["ctx_sent"] = [doc["query"] + [sep_token] + [cand["processed_text"]] for cand in doc["BM25_cands"]]
return docs
if ctx_type == 'matched_spans':
matched_spans = kwargs.get('matched_spans')
        return [
            [[t[0] for t in d] + [t for span in ms for t in [sep_token] + span[1]],  # sentence tokens + matched span tokens
             None,  # pos tags
             [s[tag_col_id] for s in d] if tag_col_id > 0 else None,  # ner tags
             [len(d)]  # sentence length
             ]
            for d, ms in zip(docs, matched_spans)]
if ctx_type == 'bm25_matched_spans':
matched_spans = kwargs.get('matched_spans')
pickled_data = kwargs.get('pickled_data')
        docs = [[[t[0] for t in d] + [t for span in ms for t in [sep_token] + span[1]],  # sentence tokens + matched span tokens
                 None,  # pos tags
                 [s[tag_col_id] for s in d],  # ner tags
                 [len(d)]  # sentence length
                 ]
                for d, ms in zip(docs, matched_spans)]
        for ms, doc in zip(docs, pickled_data):
            doc_titles_list = f"<split>{sep_token}<split>".join([cand["doc_title"] for cand in doc["BM25_cands"]]).split("<split>")
            linked_titles_list = f"<split>{sep_token}<split>".join([linked for cand in doc["BM25_cands"] for linked in cand["linked_titles"]]).split("<split>")
            ms[0] = ms[0] + [sep_token] + doc_titles_list + [sep_token] + linked_titles_list
return docs
if ctx_type == "infobox":
infobox_keys_path = kwargs.get("infobox_keys_path")
infobox_keys = read_pickled_conll(infobox_keys_path)
if 'pred_spans' in docs[0]:
docs = get_pred_ent_bounds(docs)
for doc in docs:
if 'pred_spans' in doc:
ents = [' '.join(doc['query'][bd[0]:bd[1] + 1]) for bd in doc['pred_ent_bounds']]
ents_wo_space = [''.join(doc['query'][bd[0]:bd[1] + 1]) for bd in doc['pred_ent_bounds']]
else:
ents = [' '.join(doc['query'][bd[0]:bd[1] + 1]) for bd in doc['ent_bounds']]
ents_wo_space = [''.join(doc['query'][bd[0]:bd[1] + 1]) for bd in doc['ent_bounds']]
ents = list(set(ents + ents_wo_space))
infobox = [infobox_keys[en] for en in ents if en in infobox_keys and infobox_keys[en]]
for ibs in infobox:
ibs[0] = '[INFO] ' + ibs[0]
ibs[-1] = ibs[-1] + ' [/INFO]'
infobox = [i for j in infobox for i in j]
doc["ctx_sent"] = doc["query"] + [sep_token] + infobox
return docs
# create context type for other scenarios
res = []
for doc in docs:
ctx_len = context_length if context_length > 0 else len(doc)
# for the last sentences loop around to the beginning for context
padded_doc = doc + doc[:ctx_len]
for i in range(len(doc)):
res.append((
[s[0] for sent in padded_doc[i:i+ctx_len] for s in sent],
[s[pos_col_id] for sent in padded_doc[i:i+ctx_len] for s in sent] if pos_col_id > 0 else None,
[s[tag_col_id] for sent in padded_doc[i:i+ctx_len] for s in sent],
[len(sent) for sent in padded_doc[i:i+ctx_len]],
{} # dictionary for extra context
))
return res
def calc_correct(sentence):
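    """Chunk-level evaluation for a single sentence.

    Expects rows of the form [token, _, _, gold_tag, pred_tag]. Returns a dict
    with the number of gold chunks whose predictions match exactly ('correct'),
    the number that match when ignoring the entity type ('correct_tagless'),
    and the total gold and predicted chunk counts.
    """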
gold_chunks = []
parallel_chunks = []
pred_chunks = []
curr_gold_chunk = []
curr_parallel_chunk = []
curr_pred_chunk = []
prev_tag = None
for line in sentence:
_, _, _, gt, pt = line
curr_tag = None
if '-' in pt:
curr_tag = pt.split('-')[1]
if gt.startswith('B'):
if curr_gold_chunk:
gold_chunks.append(curr_gold_chunk)
parallel_chunks.append(curr_parallel_chunk)
curr_gold_chunk = [gt]
curr_parallel_chunk = [pt]
elif gt.startswith('I') or (pt.startswith('I') and curr_tag == prev_tag
and curr_gold_chunk):
curr_gold_chunk.append(gt)
curr_parallel_chunk.append(pt)
elif gt.startswith('O') and pt.startswith('O'):
if curr_gold_chunk:
gold_chunks.append(curr_gold_chunk)
parallel_chunks.append(curr_parallel_chunk)
curr_gold_chunk = []
curr_parallel_chunk = []
if pt.startswith('O'):
if curr_pred_chunk:
pred_chunks.append(curr_pred_chunk)
curr_pred_chunk = []
elif pt.startswith('B'):
if curr_pred_chunk:
pred_chunks.append(curr_pred_chunk)
curr_pred_chunk = [pt]
prev_tag = curr_tag
else:
if prev_tag is not None and curr_tag != prev_tag:
prev_tag = curr_tag
if curr_pred_chunk:
pred_chunks.append(curr_pred_chunk)
curr_pred_chunk = []
curr_pred_chunk.append(pt)
if curr_gold_chunk:
gold_chunks.append(curr_gold_chunk)
parallel_chunks.append(curr_parallel_chunk)
if curr_pred_chunk:
pred_chunks.append(curr_pred_chunk)
correct = sum([1 for gc, pc in zip(gold_chunks, parallel_chunks)
if not len([1 for g, p in zip(gc, pc) if g != p])])
correct_tagless = sum([1 for gc, pc in zip(gold_chunks, parallel_chunks)
if not len([1 for g, p in zip(gc, pc) if g[0] != p[0]])])
# return correct, gold_chunks, parallel_chunks, pred_chunks, ob1_correct, correct_tagless
return {'correct': correct,
'correct_tagless': correct_tagless,
'gold_count': len(gold_chunks),
'pred_count': len(pred_chunks)}
def tag_sentences(sentences):
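    """Tokenize and POS-tag raw sentences with stanza, returning [token, UPOS] pairs per sentence."""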
nlp = stanza.Pipeline(lang='en', processors='tokenize,pos', logging_level='WARNING')
tagged_sents = []
for sentence in sentences:
n = nlp(sentence)
tagged_sent = []
for s in n.sentences:
for w in s.words:
tagged_sent.append([w.text, w.upos])
tagged_sents.append(tagged_sent)
return tagged_sents
def extract_spans(sentence, tagless=False):
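    """Extract entity spans from a BIO tag sequence.

    Returns the span tag lists with their positions, the inclusive [start, end]
    bounds of each span, all bounds (entity spans marked 'E' and single non-entity
    tokens marked 'W'), and [tag, bound] pairs. If ``tagless`` is True, entity
    types are ignored when splitting spans.
    """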
spans_positions = []
span_bounds = []
all_bounds = []
span_tags = []
curr_tag = None
curr_span = []
curr_span_start = -1
# span ids, span types
for i, token in enumerate(sentence):
if token.startswith('B'):
if curr_span:
spans_positions.append([curr_span, len(all_bounds)])
span_bounds.append([curr_span_start, i-1])
all_bounds.append([[curr_span_start, i - 1], 'E', len(all_bounds)])
if not tagless:
span_tags.append(token.split('-')[1])
curr_span = []
curr_tag = None
curr_span.append(token)
curr_tag = None if tagless else token.split('-')[1]
curr_span_start = i
elif token.startswith('I'):
if not tagless:
tag = token.split('-')[1]
if tag != curr_tag and curr_tag is not None:
spans_positions.append([curr_span, len(all_bounds)])
span_bounds.append([curr_span_start, i - 1])
span_tags.append(token.split('-')[1])
all_bounds.append([[curr_span_start, i - 1], 'E', len(all_bounds)])
curr_span = []
curr_tag = tag
curr_span_start = i
elif curr_tag is None:
curr_span = []
curr_tag = tag
curr_span_start = i
elif not curr_span:
curr_span_start = i
curr_span.append(token)
elif token.startswith('O') or token.startswith('-'):
if curr_span:
spans_positions.append([curr_span, len(all_bounds)])
span_bounds.append([curr_span_start, i-1])
all_bounds.append([[curr_span_start, i-1], 'E', len(all_bounds)])
curr_span = []
curr_tag = None
all_bounds.append([[i], 'W', len(all_bounds)])
# check if sentence ended with a span
if curr_span:
spans_positions.append([curr_span, len(all_bounds)])
span_bounds.append([curr_span_start, len(sentence) - 1])
all_bounds.append([[curr_span_start, len(sentence) - 1], 'E', len(all_bounds)])
tagged_bounds = [[loc[0][0].split('-')[1] if '-' in loc[0][0] else loc[0][0], bound]
for loc, bound in zip(spans_positions, span_bounds)]
return spans_positions, span_bounds, all_bounds, tagged_bounds
def ner_corpus_stats(corpus_path):
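    """Compute span-length distributions, per-tag length distributions, tag
    frequencies, and POS-pattern statistics for a CoNLL NER corpus."""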
onto_train_cols = read_conll_ner(corpus_path)
tags = list(set([t.split('-')[1] for t in onto_train_cols[3] if '-' in t]))
onto_train_spans = [extract_spans([t[3] for t in sent])[3] for sent in
onto_train_cols[0]]
span_lens = [span[1][1] - span[1][0] + 1 for sent in onto_train_spans for
span in sent]
len_stats = [span_lens.count(i + 1) / len(span_lens) for i in
range(max(span_lens))]
flat_spans = [span for sent in onto_train_spans for span in sent]
tag_lens_dict = {k: [] for k in tags}
tag_counts_dict = {k: 0 for k in tags}
for span in flat_spans:
span_length = span[1][1] - span[1][0] + 1
span_tag = span[0][0].split('-')[1]
tag_lens_dict[span_tag].append(span_length)
tag_counts_dict[span_tag] += 1
x = list(tag_counts_dict.items())
x.sort(key=lambda l: l[1])
tag_counts = [list(l) for l in x]
for l in tag_counts:
l[1] = l[1] / len(span_lens)
tag_len_stats = {k: [v.count(i + 1) / len(v) for i in range(max(v))]
for k, v in tag_lens_dict.items()}
span_texts = [sent[span[1][0]:span[1][1] + 1]
for sent, spans in zip(onto_train_cols[0], onto_train_spans)
for span in spans]
span_pos = [[span[0][-1].split('-')[1], '_'.join(t[1] for t in span)]
for span in span_texts]
unique_pos = list(set([span[1] for span in span_pos]))
pos_dict = {k: 0 for k in unique_pos}
for span in span_pos:
pos_dict[span[1]] += 1
unique_pos.sort(key=lambda l: pos_dict[l], reverse=True)
pos_stats = [[p, pos_dict[p] / len(span_pos)] for p in unique_pos]
tag_pos_dict = {kt: {kp: 0 for kp in unique_pos} for kt in tags}
for span in span_pos:
tag_pos_dict[span[0]][span[1]] += 1
tag_pos_stats = {kt: [[p, tag_pos_dict[kt][p] / tag_counts_dict[kt]]
for p in unique_pos] for kt in tags}
for kt in tags:
tag_pos_stats[kt].sort(key=lambda l: l[1], reverse=True)
return len_stats, tag_len_stats, tag_counts, pos_stats, tag_pos_stats
def filter_by_max_ents(sentences, max_ent_length):
"""
Filters a given list of sentences and only returns the sentences that have
named entities shorter than or equal to the given max_ent_length.
:param sentences: sentences in conll format as extracted by read_conll_ner
:param max_ent_length: The maximum number of tokens in an entity
    :return: a list of sentences
"""
filtered_sents = []
for sent in sentences:
sent_span_lens = [s[1] - s[0] + 1
for s in extract_spans([t[3] for t in sent])[1]]
if not sent_span_lens or max(sent_span_lens) <= max_ent_length:
filtered_sents.append(sent)
return filtered_sents
def get_pred_ent_bounds(docs):
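    """Convert each doc's predicted BIO sequence ('pred_spans') into inclusive
    [start, end] entity bounds stored under 'pred_ent_bounds'."""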
for doc in docs:
eb = []
count = 0
for p_eb in doc['pred_spans']:
if p_eb == 'B':
eb.append([count,count])
elif p_eb == 'I' and len(eb) > 0:
eb[-1][1] = count
count += 1
doc['pred_ent_bounds'] = eb
return docs
def enumerate_spans(batch):
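    """Enumerate every possible [start, end] span (inclusive, start <= end) for each sequence in the batch."""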
    enumerated_spans_batch = []
    for sequence in batch:
        enumerated_spans = []
        for x in range(len(sequence)):
            for y in range(x, len(sequence)):
                enumerated_spans.append([x, y])
        enumerated_spans_batch.append(enumerated_spans)
    return enumerated_spans_batch
def compact_span_enumeration(batch):
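    """List-comprehension variant of enumerate_spans."""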
    sentence_lengths = [len(b) for b in batch]
    # enumerate only valid spans (x <= y), mirroring enumerate_spans
    enumerated_spans = [[[x, y]
                         for x in range(sentence_length)
                         for y in range(x, sentence_length)]
                        for sentence_length in sentence_lengths]
return enumerated_spans
def preprocess_data(data):
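    """Strip Arabic diacritics (tashkeel) and tatweel characters from the tokens
    of each sample, leaving the remaining fields unchanged."""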
clean_data = []
for sample in data:
clean_tokens = [araby.strip_tashkeel(token) for token in sample[0]]
clean_tokens = [araby.strip_tatweel(token) for token in clean_tokens]
clean_sample = [clean_tokens]
clean_sample.extend(sample[1:])
clean_data.append(clean_sample)
return clean_data
def generate_targets(enumerated_spans, sentences):
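    """Build binary target vectors over enumerated spans.

    For each sentence, marks an enumerated [start, end] span with 1 if it matches
    a gold span extracted by extract_spans, and 0 otherwise.
    """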
#### could be refactored into a helper function ####
    extracted_spans = [extract_spans(sentence, True)[3] for sentence in sentences]
    target_locations = []
    for span in extracted_spans:
        sentence_locations = []
        for location in span:
            sentence_locations.append(location[1])
        target_locations.append(sentence_locations)
    #### could be refactored into a helper function ####
    targets = []
    for span, location_list in zip(enumerated_spans, target_locations):
        # one binary label per enumerated span: 1 if it matches a gold span
        span_arr = [0] * len(span)
        target_indices = [span.index(span_location) for
                          span_location in location_list]
        for idx in target_indices:
            span_arr[idx] = 1
        targets.append(span_arr)
    return targets
def label_tags(tags):
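    """Map BIO tags to binary labels: 0 for 'O', 1 for any entity tag."""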
output_tags = []
for tag in tags:
        if tag == "O":
output_tags.append(0)
else:
output_tags.append(1)
return output_tags