import pickle

import numpy as np
import stanza
from pyarabic import araby


class Config:
    def __init__(self):
        super(Config, self).__init__()


def read_conll_ner(path):
    with open(path) as f:
        lines = f.readlines()
    unique_entries = []
    sentences = []
    curr_sentence = []
    for line in lines:
        if not line.strip():
            if curr_sentence:
                sentences.append(curr_sentence)
                curr_sentence = []
            continue
        entry = line.split()
        curr_sentence.append(entry)
        if not len(unique_entries):
            unique_entries = [[] for _ in entry[1:]]
        for e, entry_list in zip(entry[1:], unique_entries):
            if e not in entry_list:
                entry_list.append(e)
    # flush the last sentence when the file does not end with a blank line
    if curr_sentence:
        sentences.append(curr_sentence)
    return [sentences] + unique_entries


def read_pickled_conll(path):
    with open(path, "rb") as f:
        data = pickle.load(f)
    return data


def split_conll_docs(conll_sents, skip_docstart=True):
    docs = []
    curr_doc = []
    for sent in conll_sents:
        if sent[0][0] == '-DOCSTART-':
            if curr_doc:
                docs.append(curr_doc)
                curr_doc = []
            if skip_docstart:
                continue
        curr_doc.append(sent)
    docs.append(curr_doc)
    return docs
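
# Hedged usage sketch (not part of the original module): shows how the reader
# functions above compose. "sample.conll" is a placeholder path for a
# whitespace-separated CoNLL file whose first column is the token and whose
# remaining columns are tags.
def _example_read_and_split(path="sample.conll"):
    columns = read_conll_ner(path)      # [sentences, unique values of column 1, ...]
    sentences = columns[0]              # each sentence is a list of row lists
    docs = split_conll_docs(sentences)  # group sentences by -DOCSTART- markers
    return docs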
def create_context_data(docs, pos_col_id=1, tag_col_id=3, context_length=1, **kwargs):
    ctx_type = kwargs.get("ctx_type", "other")
    sep_token = kwargs.get("sep_token", "[SEP]")
    if ctx_type == "cand_titles":
        # create context for candidate titles scenario
        for doc in docs:
            # NOTE: the original `.split("")` is invalid Python (empty separator); joining
            # with spaces around the separator and splitting on spaces is an assumed
            # reconstruction of the intended tokenisation.
            cand_titles = f" {sep_token} ".join([cand["doc_title"] for cand in doc["BM25_cands"]]).split(" ")
            doc["ctx_sent"] = doc["query"] + [sep_token] + cand_titles
        return docs
    if ctx_type == "cand_links":
        for doc in docs:
            # same assumed fix for the invalid `.split("")` as above
            doc_titles_list = f" {sep_token} ".join([cand["doc_title"] for cand in doc["BM25_cands"]]).split(" ")
            linked_titles_list = f" {sep_token} ".join([linked for cand in doc["BM25_cands"]
                                                        for linked in cand["linked_titles"]]).split(" ")
            doc["ctx_sent"] = doc["query"] + [sep_token] + doc_titles_list + [sep_token] + linked_titles_list
        return docs
    if ctx_type == "raw_text":
        # create context for candidate raw text
        for doc in docs:
            doc["ctx_sent"] = [doc["query"] + [sep_token] + [cand["processed_text"]] for cand in doc["BM25_cands"]]
        return docs
    if ctx_type == 'matched_spans':
        matched_spans = kwargs.get('matched_spans')
        return [
            [[t[0] for t in d] + [t for span in spans for t in [sep_token] + span[1]],  # sentence tokens + spans
             None,  # pos tags
             [s[tag_col_id] for s in d] if tag_col_id > 0 else None,  # ner tags
             [len(d)]  # sentence length
             ] for d, spans in zip(docs, matched_spans)]
    if ctx_type == 'bm25_matched_spans':
        matched_spans = kwargs.get('matched_spans')
        pickled_data = kwargs.get('pickled_data')
        docs = [[[t[0] for t in d] + [t for span in spans for t in [sep_token] + span[1]],  # sentence tokens + spans
                 None,  # pos tags
                 [s[tag_col_id] for s in d],  # ner tags
                 [len(d)]  # sentence length
                 ] for d, spans in zip(docs, matched_spans)]
        for entry, doc in zip(docs, pickled_data):
            # same assumed fix for the invalid `.split("")` as above
            doc_titles_list = f" {sep_token} ".join([cand["doc_title"] for cand in doc["BM25_cands"]]).split(" ")
            linked_titles_list = f" {sep_token} ".join([linked for cand in doc["BM25_cands"]
                                                        for linked in cand["linked_titles"]]).split(" ")
            entry[0] = entry[0] + [sep_token] + doc_titles_list + [sep_token] + linked_titles_list
        return docs
    if ctx_type == "infobox":
        infobox_keys_path = kwargs.get("infobox_keys_path")
        infobox_keys = read_pickled_conll(infobox_keys_path)
        if 'pred_spans' in docs[0]:
            docs = get_pred_ent_bounds(docs)
        for doc in docs:
            if 'pred_spans' in doc:
                ents = [' '.join(doc['query'][bd[0]:bd[1] + 1]) for bd in doc['pred_ent_bounds']]
                ents_wo_space = [''.join(doc['query'][bd[0]:bd[1] + 1]) for bd in doc['pred_ent_bounds']]
            else:
                ents = [' '.join(doc['query'][bd[0]:bd[1] + 1]) for bd in doc['ent_bounds']]
                ents_wo_space = [''.join(doc['query'][bd[0]:bd[1] + 1]) for bd in doc['ent_bounds']]
            ents = list(set(ents + ents_wo_space))
            infobox = [infobox_keys[en] for en in ents if en in infobox_keys and infobox_keys[en]]
            for ibs in infobox:
                ibs[0] = '[INFO] ' + ibs[0]
                ibs[-1] = ibs[-1] + ' [/INFO]'
            infobox = [i for j in infobox for i in j]
            doc["ctx_sent"] = doc["query"] + [sep_token] + infobox
        return docs
    # create context for the other scenarios
    res = []
    for doc in docs:
        ctx_len = context_length if context_length > 0 else len(doc)
        # for the last sentences, loop around to the beginning for context
        padded_doc = doc + doc[:ctx_len]
        for i in range(len(doc)):
            res.append((
                [s[0] for sent in padded_doc[i:i + ctx_len] for s in sent],
                [s[pos_col_id] for sent in padded_doc[i:i + ctx_len] for s in sent] if pos_col_id > 0 else None,
                [s[tag_col_id] for sent in padded_doc[i:i + ctx_len] for s in sent],
                [len(sent) for sent in padded_doc[i:i + ctx_len]],
                {}  # dictionary for extra context
            ))
    return res


def calc_correct(sentence):
    gold_chunks = []
    parallel_chunks = []
    pred_chunks = []
    curr_gold_chunk = []
    curr_parallel_chunk = []
    curr_pred_chunk = []
    prev_tag = None
    for line in sentence:
        _, _, _, gt, pt = line
        curr_tag = None
        if '-' in pt:
            curr_tag = pt.split('-')[1]
        if gt.startswith('B'):
            if curr_gold_chunk:
                gold_chunks.append(curr_gold_chunk)
                parallel_chunks.append(curr_parallel_chunk)
            curr_gold_chunk = [gt]
            curr_parallel_chunk = [pt]
        elif gt.startswith('I') or (pt.startswith('I') and curr_tag == prev_tag and curr_gold_chunk):
            curr_gold_chunk.append(gt)
            curr_parallel_chunk.append(pt)
        elif gt.startswith('O') and pt.startswith('O'):
            if curr_gold_chunk:
                gold_chunks.append(curr_gold_chunk)
                parallel_chunks.append(curr_parallel_chunk)
            curr_gold_chunk = []
            curr_parallel_chunk = []
        if pt.startswith('O'):
            if curr_pred_chunk:
                pred_chunks.append(curr_pred_chunk)
            curr_pred_chunk = []
        elif pt.startswith('B'):
            if curr_pred_chunk:
                pred_chunks.append(curr_pred_chunk)
            curr_pred_chunk = [pt]
            prev_tag = curr_tag
        else:
            if prev_tag is not None and curr_tag != prev_tag:
                prev_tag = curr_tag
                if curr_pred_chunk:
                    pred_chunks.append(curr_pred_chunk)
                curr_pred_chunk = []
            curr_pred_chunk.append(pt)
    if curr_gold_chunk:
        gold_chunks.append(curr_gold_chunk)
        parallel_chunks.append(curr_parallel_chunk)
    if curr_pred_chunk:
        pred_chunks.append(curr_pred_chunk)
    correct = sum([1 for gc, pc in zip(gold_chunks, parallel_chunks)
                   if not len([1 for g, p in zip(gc, pc) if g != p])])
    correct_tagless = sum([1 for gc, pc in zip(gold_chunks, parallel_chunks)
                           if not len([1 for g, p in zip(gc, pc) if g[0] != p[0]])])
    # return correct, gold_chunks, parallel_chunks, pred_chunks, ob1_correct, correct_tagless
    return {'correct': correct, 'correct_tagless': correct_tagless,
            'gold_count': len(gold_chunks), 'pred_count': len(pred_chunks)}


def tag_sentences(sentences):
    nlp = stanza.Pipeline(lang='en', processors='tokenize,pos', logging_level='WARNING')
    tagged_sents = []
    for sentence in sentences:
        n = nlp(sentence)
        tagged_sent = []
        for s in n.sentences:
            for w in s.words:
                tagged_sent.append([w.text, w.upos])
        tagged_sents.append(tagged_sent)
    return tagged_sents
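
# Hedged example (toy data, not from the original code base): calc_correct expects
# rows of [token, col1, col2, gold_tag, predicted_tag] for one sentence and returns
# chunk-level counts that can be aggregated into precision/recall. For this toy
# sentence, precision and recall both come out to 0.5.
def _example_chunk_scores():
    sentence = [
        ["John", "NNP", "_", "B-PER", "B-PER"],
        ["Smith", "NNP", "_", "I-PER", "I-PER"],
        ["visited", "VBD", "_", "O", "O"],
        ["Paris", "NNP", "_", "B-LOC", "B-ORG"],
    ]
    counts = calc_correct(sentence)
    precision = counts['correct'] / counts['pred_count']
    recall = counts['correct'] / counts['gold_count']
    return precision, recall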
def extract_spans(sentence, tagless=False):
    spans_positions = []
    span_bounds = []
    all_bounds = []
    span_tags = []
    curr_tag = None
    curr_span = []
    curr_span_start = -1
    # span ids, span types
    for i, token in enumerate(sentence):
        if token.startswith('B'):
            if curr_span:
                spans_positions.append([curr_span, len(all_bounds)])
                span_bounds.append([curr_span_start, i - 1])
                all_bounds.append([[curr_span_start, i - 1], 'E', len(all_bounds)])
                if not tagless:
                    span_tags.append(token.split('-')[1])
                curr_span = []
                curr_tag = None
            curr_span.append(token)
            curr_tag = None if tagless else token.split('-')[1]
            curr_span_start = i
        elif token.startswith('I'):
            if not tagless:
                tag = token.split('-')[1]
                if tag != curr_tag and curr_tag is not None:
                    spans_positions.append([curr_span, len(all_bounds)])
                    span_bounds.append([curr_span_start, i - 1])
                    span_tags.append(token.split('-')[1])
                    all_bounds.append([[curr_span_start, i - 1], 'E', len(all_bounds)])
                    curr_span = []
                    curr_tag = tag
                    curr_span_start = i
                elif curr_tag is None:
                    curr_span = []
                    curr_tag = tag
                    curr_span_start = i
            elif not curr_span:
                curr_span_start = i
            curr_span.append(token)
        elif token.startswith('O') or token.startswith('-'):
            if curr_span:
                spans_positions.append([curr_span, len(all_bounds)])
                span_bounds.append([curr_span_start, i - 1])
                all_bounds.append([[curr_span_start, i - 1], 'E', len(all_bounds)])
                curr_span = []
                curr_tag = None
            all_bounds.append([[i], 'W', len(all_bounds)])
    # check if the sentence ended with a span
    if curr_span:
        spans_positions.append([curr_span, len(all_bounds)])
        span_bounds.append([curr_span_start, len(sentence) - 1])
        all_bounds.append([[curr_span_start, len(sentence) - 1], 'E', len(all_bounds)])
    tagged_bounds = [[loc[0][0].split('-')[1] if '-' in loc[0][0] else loc[0][0], bound]
                     for loc, bound in zip(spans_positions, span_bounds)]
    return spans_positions, span_bounds, all_bounds, tagged_bounds


def ner_corpus_stats(corpus_path):
    onto_train_cols = read_conll_ner(corpus_path)
    tags = list(set([t.split('-')[1] for t in onto_train_cols[3] if '-' in t]))
    onto_train_spans = [extract_spans([t[3] for t in sent])[3] for sent in onto_train_cols[0]]
    span_lens = [span[1][1] - span[1][0] + 1 for sent in onto_train_spans for span in sent]
    len_stats = [span_lens.count(i + 1) / len(span_lens) for i in range(max(span_lens))]
    flat_spans = [span for sent in onto_train_spans for span in sent]
    tag_lens_dict = {k: [] for k in tags}
    tag_counts_dict = {k: 0 for k in tags}
    for span in flat_spans:
        span_length = span[1][1] - span[1][0] + 1
        # tagged bounds already store the bare tag; the original `span[0][0].split('-')[1]`
        # indexed into the first character of the tag and would raise an IndexError
        span_tag = span[0]
        tag_lens_dict[span_tag].append(span_length)
        tag_counts_dict[span_tag] += 1
    x = list(tag_counts_dict.items())
    x.sort(key=lambda l: l[1])
    tag_counts = [list(l) for l in x]
    for l in tag_counts:
        l[1] = l[1] / len(span_lens)
    tag_len_stats = {k: [v.count(i + 1) / len(v) for i in range(max(v))] for k, v in tag_lens_dict.items()}
    span_texts = [sent[span[1][0]:span[1][1] + 1]
                  for sent, spans in zip(onto_train_cols[0], onto_train_spans) for span in spans]
    span_pos = [[span[0][-1].split('-')[1], '_'.join(t[1] for t in span)] for span in span_texts]
    unique_pos = list(set([span[1] for span in span_pos]))
    pos_dict = {k: 0 for k in unique_pos}
    for span in span_pos:
        pos_dict[span[1]] += 1
    unique_pos.sort(key=lambda l: pos_dict[l], reverse=True)
    pos_stats = [[p, pos_dict[p] / len(span_pos)] for p in unique_pos]
    tag_pos_dict = {kt: {kp: 0 for kp in unique_pos} for kt in tags}
    for span in span_pos:
        tag_pos_dict[span[0]][span[1]] += 1
    tag_pos_stats = {kt: [[p, tag_pos_dict[kt][p] / tag_counts_dict[kt]] for p in unique_pos] for kt in tags}
    for kt in tags:
        tag_pos_stats[kt].sort(key=lambda l: l[1], reverse=True)
    return len_stats, tag_len_stats, tag_counts, pos_stats, tag_pos_stats
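
# Hedged example (illustrative BIO tags, not from the original repository): extract_spans
# takes a flat tag sequence and returns four parallel views of the entity spans.
def _example_extract_spans():
    tags = ["B-PER", "I-PER", "O", "B-LOC"]
    spans_positions, span_bounds, all_bounds, tagged_bounds = extract_spans(tags)
    # span_bounds   -> [[0, 1], [3, 3]]
    # tagged_bounds -> [['PER', [0, 1]], ['LOC', [3, 3]]]
    return tagged_bounds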
def filter_by_max_ents(sentences, max_ent_length):
    """
    Filters a given list of sentences and only returns the sentences whose named
    entities are shorter than or equal to the given max_ent_length.

    :param sentences: sentences in conll format as extracted by read_conll_ner
    :param max_ent_length: the maximum number of tokens in an entity
    :return: a list of sentences
    """
    filtered_sents = []
    for sent in sentences:
        sent_span_lens = [s[1] - s[0] + 1 for s in extract_spans([t[3] for t in sent])[1]]
        if not sent_span_lens or max(sent_span_lens) <= max_ent_length:
            filtered_sents.append(sent)
    return filtered_sents


def get_pred_ent_bounds(docs):
    for doc in docs:
        eb = []
        count = 0
        for p_eb in doc['pred_spans']:
            if p_eb == 'B':
                eb.append([count, count])
            elif p_eb == 'I' and len(eb) > 0:
                eb[-1][1] = count
            count += 1
        doc['pred_ent_bounds'] = eb
    return docs


def enumerate_spans(batch):
    enumerated_spans_batch = []
    for sentence in batch:
        enumerated_spans = []
        for x in range(len(sentence)):
            for y in range(x, len(sentence)):
                enumerated_spans.append([x, y])
        enumerated_spans_batch.append(enumerated_spans)
    return enumerated_spans_batch


def compact_span_enumeration(batch):
    # comprehension-based equivalent of enumerate_spans; the original iterated x and y
    # independently, which also produced invalid spans with x > y
    sentence_lengths = [len(b) for b in batch]
    enumerated_spans = [[[x, y] for x in range(sentence_length) for y in range(x, sentence_length)]
                        for sentence_length in sentence_lengths]
    return enumerated_spans


def preprocess_data(data):
    clean_data = []
    for sample in data:
        clean_tokens = [araby.strip_tashkeel(token) for token in sample[0]]
        clean_tokens = [araby.strip_tatweel(token) for token in clean_tokens]
        clean_sample = [clean_tokens]
        clean_sample.extend(sample[1:])
        clean_data.append(clean_sample)
    return clean_data


def generate_targets(enumerated_spans, sentences):
    #### could be refactored into a helper function ####
    extracted_spans = [extract_spans(sentence, True)[3] for sentence in sentences]
    target_locations = []
    for span in extracted_spans:
        sentence_locations = []
        for location in span:
            sentence_locations.append(location[1])
        target_locations.append(sentence_locations)
    #### could be refactored into a helper function ####
    targets = []
    for span, location_list in zip(enumerated_spans, target_locations):
        # binary vector over the enumerated spans: 1 where the span is a gold entity span
        span_arr = np.zeros(len(span), dtype=int)
        target_indices = [span.index(span_location) for span_location in location_list]
        for idx in target_indices:
            span_arr[idx] = 1
        targets.append(span_arr.tolist())
    return targets


def label_tags(tags):
    output_tags = []
    for tag in tags:
        if tag == "O":
            output_tags.append(0)
        else:
            output_tags.append(1)
    return output_tags
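
# Hedged end-to-end sketch (toy tag sequence, assumed shapes): enumerate candidate spans
# for a tagged sentence, mark which candidates correspond to gold entity spans, and build
# binary token labels.
def _example_span_targets():
    tag_sequences = [["B-PER", "I-PER", "O"]]
    candidate_spans = enumerate_spans(tag_sequences)             # [[[0,0],[0,1],[0,2],[1,1],[1,2],[2,2]]]
    targets = generate_targets(candidate_spans, tag_sequences)   # [[0, 1, 0, 0, 0, 0]]
    token_labels = label_tags(tag_sequences[0])                  # [1, 1, 0]
    return targets, token_labels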