File size: 2,793 Bytes
7daaa6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#!/usr/bin/env python3
# coding=utf-8

import json
from itertools import chain
from transformers import AutoTokenizer

from utility.subtokenize import subtokenize

import os
# Set TOKENIZERS_PARALLELISM for this process (and any child processes).
# NOTE(review): presumably read by the Hugging Face `tokenizers` backend to
# keep parallel tokenization enabled -- confirm against the library docs.
os.environ["TOKENIZERS_PARALLELISM"] = "true"


def load_dataset(path):
    """Load a JSON-lines dataset and whitespace-tokenize every sentence.

    Every non-blank line of the file is parsed as one JSON object. Each
    object is then normalized in place:

    * missing "nodes" / "edges" entries default to empty lists;
    * the raw "input" string is kept under "sentence";
    * "input" is replaced by the list of tokens obtained by splitting the
      string on single spaces;
    * "token anchors" holds one ``{"from": start, "to": end}`` character
      span per token, as offsets into the raw string.

    Args:
        path: Path of a UTF-8 encoded file with one JSON object per line.

    Returns:
        A dict mapping each object's "id" to the normalized object. When an
        id occurs more than once, the last occurrence wins.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If an object lacks the "id" or "input" key.
    """
    data = {}
    with open(path, encoding="utf8") as f:
        # Iterate the file lazily instead of materializing it with readlines().
        for line in f:
            # Tolerate blank lines (e.g. a trailing empty line) rather than
            # failing inside json.loads.
            if not line.strip():
                continue

            sentence = json.loads(line)
            sentence.setdefault("nodes", [])
            sentence.setdefault("edges", [])

            sentence["sentence"] = sentence["input"]
            sentence["input"] = sentence["sentence"].split(' ')

            # Tokens are separated by exactly one space, so the next token
            # starts one character after the end of the current one.
            anchors, offset = [], 0
            for token in sentence["input"]:
                anchors.append({"from": offset, "to": offset + len(token)})
                offset += len(token) + 1
            sentence["token anchors"] = anchors

            data[sentence["id"]] = sentence
    return data


def node_generator(data):
    """Yield a ``(node, sentence)`` pair for every node in the dataset.

    Sentences are visited in the iteration order of ``data.values()`` and
    nodes in the order of each sentence's "nodes" list. The yielded objects
    are the original dicts, not copies.
    """
    for sentence in data.values():
        yield from ((node, sentence) for node in sentence["nodes"])


def anchor_ids_from_intervals(data):
    """Convert every node's character-interval anchors into token indices.

    For each node produced by ``node_generator(data)``:

    * a missing "anchors" entry is treated as an empty list;
    * the intervals, sorted by ("from", "to"), are stored under
      "anchor intervals";
    * "anchors" is replaced by the sorted indices of the sentence tokens
      whose "token anchors" span overlaps at least one interval.

    Afterwards each sentence's "token anchors" is rewritten from
    ``{"from": ..., "to": ...}`` dicts to ``[from, to]`` lists.

    Everything is modified in place; nothing is returned. Because both
    "anchors" and "token anchors" change shape, the function is meant to be
    run once per dataset.
    """
    for node, sentence in node_generator(data):
        spans = sorted(
            node.setdefault("anchors", []),
            key=lambda span: (span["from"], span["to"]),
        )
        covered = set()

        for span in spans:
            start, end = span["from"], span["to"]
            for index, token_span in enumerate(sentence["token anchors"]):
                # Token ends before the interval begins: not reached yet.
                if token_span["to"] <= start:
                    continue
                # Token starts at or after the interval's end: stop scanning
                # (relies on the token spans being in ascending order).
                if token_span["from"] >= end:
                    break

                covered.add(index)

        node["anchor intervals"] = spans
        node["anchors"] = sorted(covered)
        # No "token references" entry is left on the node.
        node.pop("token references", None)

    for sentence in data.values():
        sentence["token anchors"] = [
            [token_span["from"], token_span["to"]]
            for token_span in sentence["token anchors"]
        ]


def create_bert_tokens(data, encoder: str):
    """Attach subtoken information to every sentence in *data*.

    Loads the fast pretrained tokenizer identified by *encoder* once, then
    calls ``subtokenize`` on each sentence's "input". The two values it
    returns are stored under "bert input" and "to scatter" respectively.

    Sentences are modified in place; nothing is returned.
    """
    tokenizer = AutoTokenizer.from_pretrained(encoder, use_fast=True)

    for sentence in data.values():
        bert_input, to_scatter = subtokenize(sentence["input"], tokenizer)
        sentence["bert input"] = bert_input
        sentence["to scatter"] = to_scatter


def create_edges(sentence, label_f=None):
    """Store sparse edge-presence and edge-label entries on *sentence*.

    With N the number of nodes, two three-item lists are written:

    * "edge presence": ``[N, N, [(source, target, 1), ...]]``
    * "edge labels":   ``[N, N, [(source, target, label), ...]]``

    One tuple is appended to each list per entry of ``sentence["edges"]``.
    An edge without a "label" key gets the label "none". When *label_f* is
    given, every label (including the "none" default) is passed through it
    before being stored.

    Returns:
        The number of edges recorded.
    """
    node_count = len(sentence["nodes"])

    presence, labels = [], []
    sentence["edge presence"] = [node_count, node_count, presence]
    sentence["edge labels"] = [node_count, node_count, labels]

    for edge in sentence["edges"]:
        source = edge["source"]
        target = edge["target"]

        label = edge.get("label", "none")
        if label_f is not None:
            label = label_f(label)

        presence.append((source, target, 1))
        labels.append((source, target, label))

    return len(presence)