#!/usr/bin/env python3
# coding=utf-8
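
# Turns a single model prediction into a graph: nodes with string labels,
# edges picked from pairwise scores, and per-node token anchors. The layout
# of the `prediction` dict is inferred from how it is indexed in this file
# (an assumption, not a guaranteed interface):
#   "labels"          - label index per node
#   "anchors"         - anchor score per (node, token) pair
#   "edge presence"   - edge score per (source, target) node pair
#   "edge labels"     - edge-label index per (source, target) node pair
#   "token intervals" - a "from"/"to" pair per token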
class AbstractParser:
    def __init__(self, dataset):
        self.dataset = dataset
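
    # One node per predicted label, decoded into a human-readable string.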
    def create_nodes(self, prediction):
        return [
            {"id": i, "label": self.label_to_str(l, prediction["anchors"][i], prediction)}
            for i, l in enumerate(prediction["labels"])
        ]
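
    # Label indices are shifted by one before the vocabulary lookup; the
    # anchors argument is unused in this base implementation.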
    def label_to_str(self, label, anchors, prediction):
        return self.dataset.label_field.vocab.itos[label - 1]
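
    # Greedy edge selection over the pairwise edge-presence scores. Candidates
    # are visited from the highest score down; an edge scoring at least 0.5 is
    # always kept, a lower-scoring edge is kept only if it connects two
    # previously disconnected components (tracked with shared per-node sets),
    # and the scan stops once scores drop below 0.5 and at least N - 1 edges
    # have been produced.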
    def create_edges(self, prediction, nodes):
        N = len(nodes)
        node_sets = [{"id": n, "set": set([n])} for n in range(N)]
        _, indices = prediction["edge presence"][:N, :N].reshape(-1).sort(descending=True)
        sources, targets = indices // N, indices % N

        edges = []
        for i in range((N - 1) * N // 2):
            source, target = sources[i].item(), targets[i].item()
            p = prediction["edge presence"][source, target]

            if p < 0.5 and len(edges) >= N - 1:
                break
            if node_sets[source]["set"] is node_sets[target]["set"] and p < 0.5:
                continue

            self.create_edge(source, target, prediction, edges, nodes)
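
            # Union step: make every node in the two components share one set
            # object, so future membership checks compare identities.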
            if node_sets[source]["set"] is not node_sets[target]["set"]:
                from_set = node_sets[source]["set"]
                for n in node_sets[target]["set"]:
                    from_set.add(n)
                    node_sets[n]["set"] = from_set
        return edges
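
    # Append a single {source, target, label} edge.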
    def create_edge(self, source, target, prediction, edges, nodes):
        label = self.get_edge_label(prediction, source, target)
        edge = {"source": source, "target": target, "label": label}
        edges.append(edge)
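
    # Attach token anchors to every node. Anchor scores are thresholded at 0.5
    # (or at the node's best score when `at_least_one` is set, so no node ends
    # up anchorless), mapped to token intervals, and then optionally collapsed
    # into one covering span (`single_anchor`) or merged when they run into
    # each other (`join_contiguous`).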
    def create_anchors(self, prediction, nodes, join_contiguous=True, at_least_one=False, single_anchor=False, mode="anchors"):
        for i, node in enumerate(nodes):
            threshold = 0.5 if not at_least_one else min(0.5, prediction[mode][i].max().item())
            node[mode] = (prediction[mode][i] >= threshold).nonzero(as_tuple=False).squeeze(-1)
            node[mode] = prediction["token intervals"][node[mode], :]

            if single_anchor and len(node[mode]) > 1:
                start = min(a[0].item() for a in node[mode])
                end = max(a[1].item() for a in node[mode])
                node[mode] = [{"from": start, "to": end}]
                continue

            node[mode] = [{"from": f.item(), "to": t.item()} for f, t in node[mode]]
            node[mode] = sorted(node[mode], key=lambda a: a["from"])
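
            # Merge consecutive anchors while the next one starts at or before
            # the current end.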
            if join_contiguous and len(node[mode]) > 1:
                cleaned_anchors = []
                end, start = node[mode][0]["from"], node[mode][0]["from"]
                for anchor in node[mode]:
                    if end < anchor["from"]:
                        cleaned_anchors.append({"from": start, "to": end})
                        start = anchor["from"]
                    end = anchor["to"]
                cleaned_anchors.append({"from": start, "to": end})

                node[mode] = cleaned_anchors

        return nodes
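
    # Decode the predicted edge-label index through the dataset vocabulary.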
    def get_edge_label(self, prediction, source, target):
        return self.dataset.edge_label_field.vocab.itos[prediction["edge labels"][source, target].item()]