ssa-perin / data /dataset.py
larkkin's picture
Add supporting code from perin
7daaa6b
#!/usr/bin/env python3
# coding=utf-8
import pickle
import torch
from data.parser.from_mrp.node_centric_parser import NodeCentricParser
from data.parser.from_mrp.labeled_edge_parser import LabeledEdgeParser
from data.parser.from_mrp.sequential_parser import SequentialParser
from data.parser.from_mrp.evaluation_parser import EvaluationParser
from data.parser.from_mrp.request_parser import RequestParser
from data.field.edge_field import EdgeField
from data.field.edge_label_field import EdgeLabelField
from data.field.field import Field
from data.field.mini_torchtext.field import Field as TorchTextField
from data.field.label_field import LabelField
from data.field.anchored_label_field import AnchoredLabelField
from data.field.nested_field import NestedField
from data.field.basic_field import BasicField
from data.field.bert_field import BertField
from data.field.anchor_field import AnchorField
from data.batch import Batch
def char_tokenize(word, affix_length=None):
    """Split *word* into a list of its characters.

    Args:
        word: the string to tokenize.
        affix_length: optional cap used to shorten long words. When given, a
            character at index ``i`` is kept only if ``i < affix_length`` or
            ``len(word) - i <= affix_length``, i.e. only the first and last
            ``affix_length`` characters survive. ``None`` (the default) keeps
            every character.

    Returns:
        A list of single-character strings.
    """
    if affix_length is None:
        return list(word)
    length = len(word)
    return [c for i, c in enumerate(word) if i < affix_length or length - i <= affix_length]
class Collate:
    """Collate callable passed to ``torch.utils.data.DataLoader``.

    Reorders the examples of a batch so the one with the largest first
    dimension of ``example["every_input"][0]`` comes first, then hands the
    list to ``Batch.build``.
    """

    @staticmethod
    def _input_length(example):
        # Sort key: size of the first dimension of the first "every_input" element.
        return example["every_input"][0].size(0)

    def __call__(self, batch):
        # Note: sorts the incoming list in place (longest first) before building.
        batch.sort(key=self._input_length, reverse=True)
        return Batch.build(batch)
class Dataset:
    """Parse the training/validation/test splits and expose them as data loaders.

    Construction creates all fields, then calls ``load_dataset``. That method:

    * parses the three splits;
    * builds the vocabularies;
    * computes label and edge frequency statistics;
    * stores three ``torch.utils.data.DataLoader`` objects as ``self.train``,
      ``self.val`` and ``self.test``.
    """

    # Training examples whose ``input`` has more than this many items are filtered out.
    max_train_input_length = 256

    def __init__(self, args, verbose=True):
        """Create all fields and immediately load the dataset described by *args*.

        Args:
            args: run configuration. This class reads ``graph_mode``,
                ``batch_size``, ``workers`` and ``query_length`` from it. The
                parsers receive it too and may read more.
            verbose: when False, ``log`` output and the sample-batch dump are
                suppressed.
        """
        self.verbose = verbose
        self.sos, self.eos, self.pad, self.unk = "<sos>", "<eos>", "<pad>", "<unk>"

        self.bert_input_field = BertField()
        self.scatter_field = BasicField()
        self.every_word_input_field = Field(lower=True, init_token=self.sos, eos_token=self.eos, batch_first=True, include_lengths=True)

        char_form_nesting = TorchTextField(tokenize=char_tokenize, init_token=self.sos, eos_token=self.eos, batch_first=True)
        self.char_form_field = NestedField(char_form_nesting, include_lengths=True)

        self.label_field = LabelField(preprocessing=lambda nodes: [n["label"] for n in nodes])
        self.anchored_label_field = AnchoredLabelField()
        self.id_field = Field(batch_first=True, tokenize=lambda x: [x])
        self.edge_presence_field = EdgeField()
        self.edge_label_field = EdgeLabelField()
        self.anchor_field = AnchorField()
        self.source_anchor_field = AnchorField()
        self.target_anchor_field = AnchorField()
        self.token_interval_field = BasicField()

        self.load_dataset(args)

    def log(self, text):
        """Print *text* (flushed) unless the dataset was created with ``verbose=False``."""
        if self.verbose:
            print(text, flush=True)

    def load_state_dict(self, args, d):
        """Restore vocabularies previously produced by ``state_dict``.

        Args:
            args: unused; kept for interface compatibility.
            d: mapping whose ``"vocabs"`` entry maps attribute names to pickled
                vocab objects.
        """
        # NOTE(review): pickle.loads can execute arbitrary code. Only load
        # checkpoints that come from a trusted source.
        for key, value in d["vocabs"].items():
            getattr(self, key).vocab = pickle.loads(value)

    def state_dict(self):
        """Return ``{"vocabs": {attr_name: pickled_vocab}}`` for every attribute that has a ``vocab``."""
        return {
            "vocabs": {key: pickle.dumps(value.vocab) for key, value in self.__dict__.items() if hasattr(value, "vocab")}
        }

    def _fields(self, with_graph=False):
        """Return the parser field mapping.

        The result always contains the input-side fields. When *with_graph* is
        True it also contains the graph-target fields used for training and
        validation. Key order matches the order the mappings were originally
        written in. A fresh dict is returned on every call.
        """
        fields = {
            "input": [("every_input", self.every_word_input_field), ("char_form_input", self.char_form_field)],
            "bert input": ("input", self.bert_input_field),
            "to scatter": ("input_scatter", self.scatter_field),
        }
        if with_graph:
            fields.update({
                "nodes": ("labels", self.label_field),
                "anchored labels": ("anchored_labels", self.anchored_label_field),
                "edge presence": ("edge_presence", self.edge_presence_field),
                "edge labels": ("edge_labels", self.edge_label_field),
                "anchor edges": ("anchor", self.anchor_field),
                "source anchor edges": ("source_anchor", self.source_anchor_field),
                "target anchor edges": ("target_anchor", self.target_anchor_field),
            })
        fields["token anchors"] = ("token_intervals", self.token_interval_field)
        fields["id"] = ("id", self.id_field)
        return fields

    def load_sentences(self, sentences, args):
        """Wrap raw *sentences* in a ``RequestParser`` and return an unshuffled loader with batch size 1.

        The word and id fields call ``build_vocab`` on the request data here,
        so whatever vocabularies they held before are presumably replaced.
        """
        dataset = RequestParser(sentences, args, fields=self._fields())
        self.every_word_input_field.build_vocab(dataset, min_freq=1, specials=[self.pad, self.unk, self.sos, self.eos])
        self.id_field.build_vocab(dataset, min_freq=1, specials=[])
        return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=Collate())

    def _build_vocabs(self, train, val, test):
        """Build every field's vocabulary from the parsed splits and share vocabularies where needed."""
        # NOTE(review): the word vocabulary is built from the validation and test
        # splits only (not train). This is kept unchanged; confirm it is intentional.
        self.every_word_input_field.build_vocab(val, test, min_freq=1, specials=[self.pad, self.unk, self.sos, self.eos])
        self.char_form_field.build_vocab(train, min_freq=1, specials=[self.pad, self.unk, self.sos, self.eos])
        # The nested per-character field reuses the outer field's vocabulary.
        self.char_form_field.nesting_field.vocab = self.char_form_field.vocab
        self.id_field.build_vocab(train, val, test, min_freq=1, specials=[])
        self.label_field.build_vocab(train)
        # Anchored labels share the node-label vocabulary.
        self.anchored_label_field.vocab = self.label_field.vocab
        self.edge_label_field.build_vocab(train)
        # Routed through log() so that verbose=False silences it like every other message.
        self.log(list(self.edge_label_field.vocab.freqs.keys()))

    def _make_loader(self, dataset, args, shuffle, drop_last=False):
        """Wrap *dataset* in a DataLoader using ``args.batch_size`` and ``args.workers``."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=shuffle,
            num_workers=args.workers,
            collate_fn=Collate(),
            pin_memory=True,
            drop_last=drop_last,
        )

    def load_dataset(self, args):
        """Parse all three splits, build vocabularies and statistics, and create the data loaders.

        Raises:
            KeyError: if ``args.graph_mode`` is not one of the supported parser modes.
        """
        parser = {
            "sequential": SequentialParser,
            "node-centric": NodeCentricParser,
            "labeled-edge": LabeledEdgeParser,
        }[args.graph_mode]

        max_length = self.max_train_input_length
        train = parser(
            args, "training",
            fields=self._fields(with_graph=True),
            filter_pred=lambda example: len(example.input) <= max_length,
        )
        val = parser(args, "validation", fields=self._fields(with_graph=True))
        test = EvaluationParser(args, fields=self._fields())

        # TODO: why? The raw `data` attributes and every field's `preprocessing`
        # callable are dropped before the loaders are built.
        # NOTE(review): possibly to free memory, or to keep the datasets picklable
        # for worker processes (label_field's preprocessing is a lambda); confirm.
        del train.data, val.data, test.data
        for f in list(train.fields.values()) + list(val.fields.values()) + list(test.fields.values()):
            if hasattr(f, "preprocessing"):
                del f.preprocessing

        self.train_size = len(train)
        self.val_size = len(val)
        self.test_size = len(test)
        self.log(f"\n{self.train_size} sentences in the train split")
        self.log(f"{self.val_size} sentences in the validation split")
        self.log(f"{self.test_size} sentences in the test split")

        self.node_count = train.node_counter
        self.token_count = train.input_count
        self.edge_count = train.edge_counter
        self.no_edge_count = train.no_edge_counter
        self.anchor_freq = train.anchor_freq
        # Not every parser provides these; fall back to 0.5.
        self.source_anchor_freq = getattr(train, "source_anchor_freq", 0.5)
        self.target_anchor_freq = getattr(train, "target_anchor_freq", 0.5)
        self.log(f"{self.node_count} nodes in the train split")

        self._build_vocabs(train, val, test)
        self.char_form_vocab_size = len(self.char_form_field.vocab)

        self.create_label_freqs(args)
        self.create_edge_freqs(args)

        self.log(f"Edge frequency: {self.edge_presence_freq*100:.2f} %")
        self.log(f"{len(self.label_field.vocab)} words in the label vocabulary")
        self.log(f"{len(self.anchored_label_field.vocab)} words in the anchored label vocabulary")
        self.log(f"{len(self.edge_label_field.vocab)} words in the edge label vocabulary")
        self.log(f"{len(self.char_form_field.vocab)} characters in the vocabulary")
        self.log(self.label_field.vocab.freqs)
        self.log(self.anchored_label_field.vocab.freqs)

        self.train = self._make_loader(train, args, shuffle=True, drop_last=True)
        self.val = self._make_loader(val, args, shuffle=False)
        self.test = self._make_loader(test, args, shuffle=False)

        if self.verbose:
            # With drop_last=True the train loader is empty when the split is smaller
            # than batch_size; skip the sample dump instead of raising StopIteration.
            batch = next(iter(self.train), None)
            if batch is not None:
                print(f"\nBatch content: {Batch.to_str(batch)}\n")
            print(flush=True)

    def create_label_freqs(self, args):
        """Set ``self.label_freqs`` to the relative frequency of a blank class followed by each node label.

        The blank count is ``args.query_length * self.token_count - self.node_count``.
        All counts are divided by ``self.node_count + blank_count``. Label
        counts are read from the label vocabulary's ``freqs`` in ``itos`` order.
        """
        vocab = self.label_field.vocab
        blank_count = args.query_length * self.token_count - self.node_count
        label_counts = [blank_count] + [vocab.freqs[vocab.itos[i]] for i in range(len(vocab))]
        self.label_freqs = torch.FloatTensor(label_counts) / (self.node_count + blank_count)
        self.log(f"Label frequency: {self.label_freqs}")

    def create_edge_freqs(self, args):
        """Set per-label edge frequencies and the overall edge-presence frequency.

        * ``self.edge_label_freqs`` holds each edge label's count (in ``itos``
          order) divided by ``self.edge_count``.
        * ``self.edge_presence_freq`` is
          ``edge_count / (edge_count + no_edge_count)``.

        ``args`` is unused and kept for interface compatibility.
        """
        vocab = self.edge_label_field.vocab
        edge_counter = torch.FloatTensor([vocab.freqs[vocab.itos[i]] for i in range(len(vocab))])
        self.edge_label_freqs = edge_counter / self.edge_count
        self.edge_presence_freq = self.edge_count / (self.edge_count + self.no_edge_count)