Spaces:
Runtime error
Runtime error
Add application code and models, update README
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +39 -7
- app.py +80 -0
- config/__init__.py +0 -0
- config/params.py +89 -0
- data/__init__.py +0 -0
- data/batch.py +95 -0
- data/dataset.py +245 -0
- data/field/__init__.py +0 -0
- data/field/anchor_field.py +19 -0
- data/field/anchored_label_field.py +38 -0
- data/field/basic_field.py +11 -0
- data/field/bert_field.py +18 -0
- data/field/edge_field.py +63 -0
- data/field/edge_label_field.py +67 -0
- data/field/field.py +70 -0
- data/field/label_field.py +36 -0
- data/field/mini_torchtext/example.py +100 -0
- data/field/mini_torchtext/field.py +637 -0
- data/field/mini_torchtext/pipeline.py +86 -0
- data/field/mini_torchtext/utils.py +256 -0
- data/field/mini_torchtext/vocab.py +116 -0
- data/field/nested_field.py +50 -0
- data/parser/__init__.py +0 -0
- data/parser/from_mrp/__init__.py +0 -0
- data/parser/from_mrp/abstract_parser.py +50 -0
- data/parser/from_mrp/evaluation_parser.py +18 -0
- data/parser/from_mrp/labeled_edge_parser.py +70 -0
- data/parser/from_mrp/node_centric_parser.py +69 -0
- data/parser/from_mrp/request_parser.py +23 -0
- data/parser/from_mrp/sequential_parser.py +90 -0
- data/parser/json_parser.py +35 -0
- data/parser/to_mrp/__init__.py +0 -0
- data/parser/to_mrp/abstract_parser.py +80 -0
- data/parser/to_mrp/labeled_edge_parser.py +52 -0
- data/parser/to_mrp/node_centric_parser.py +35 -0
- data/parser/to_mrp/sequential_parser.py +35 -0
- model/__init__.py +0 -0
- model/head/__init__.py +0 -0
- model/head/abstract_head.py +274 -0
- model/head/labeled_edge_head.py +67 -0
- model/head/node_centric_head.py +25 -0
- model/head/sequential_head.py +24 -0
- model/model.py +82 -0
- model/module/__init__.py +0 -0
- model/module/anchor_classifier.py +32 -0
- model/module/biaffine.py +20 -0
- model/module/bilinear.py +43 -0
- model/module/char_embedding.py +42 -0
- model/module/edge_classifier.py +56 -0
- model/module/encoder.py +95 -0
README.md
CHANGED
@@ -1,13 +1,45 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
-
license: cc-by-4.0
|
11 |
---
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: Sentiment Analysis
|
3 |
+
emoji: 🤔
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.1.7
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
10 |
---
|
11 |
|
12 |
+
This space provides a gradio demo and an easy-to-run wrapper of the pre-trained model for structured sentiment analysis in Norwegian language, pre-trained on the [NoReC dataset](https://huggingface.co/datasets/norec).
|
13 |
+
This space containt an implementation of method described in "Direct parsing to sentiment graphs" (Samuel _et al._, ACL 2022). The main repository that also contains the scripts for training the model, can be found on the project [github](https://github.com/jerbarnes/direct_parsing_to_sent_graph).
|
14 |
+
|
15 |
+
The proposed method suggests three different ways to encode the sentiment graph: "node-centric", "labeled-edge", and "opinion-tuple". The current model uses the "labeled-edge" graph encoding, and achieves the following results on the held-out set of the NoReC dataset:
|
16 |
+
|
17 |
+
| Unlabeled sentiment tuple F1 | Target F1 | Relative polarity precision |
|
18 |
+
|:----------------------------:|:----------:|:---------------------------:|
|
19 |
+
| 0.434 | 0.541 | 0.926 |
|
20 |
+
|
21 |
+
|
22 |
+
In "Word Substitution with Masked Language Models as Data Augmentation for Sentiment Analysis", we analyzed data augmentation strategies for improving performance of the model. Using masked-language modeling (MLM), we augmented the sentences with MLM-substituted words inside, outside, or inside+outside the actual sentiment tuples. The results below show that augmentation may be improve the model performance. This space, however, runs the original model trained without augmentation.
|
23 |
+
|
24 |
+
| | Augmentation rate | Unlabeled sentiment tuple F1 | Target F1 | Relative polarity precision |
|
25 |
+
|----------------|-------------------|------------------------------|-----------|-----------------------------|
|
26 |
+
| Baseline | 0% | 43.39 | 54.13 | 92.59 |
|
27 |
+
| Outside | 59% | **45.08** | 56.18 | 92.95 |
|
28 |
+
| Inside | 9% | 43.38 | 55.62 | 92.49 |
|
29 |
+
| Inside+Outside | 27% | 44.12 | **56.44** | **93.19** |
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
The model can be easily used for predicting sentiment tuples as follows:
|
34 |
+
|
35 |
+
```python
|
36 |
+
>>> import model_wrapper
|
37 |
+
>>> model = model_wrapper.PredictionModel()
|
38 |
+
>>> model.predict(['vi liker svart kaffe'])
|
39 |
+
[{'sent_id': '0',
|
40 |
+
'text': 'vi liker svart kaffe',
|
41 |
+
'opinions': [{'Source': [['vi'], ['0:2']],
|
42 |
+
'Target': [['svart', 'kaffe'], ['9:14', '15:20']],
|
43 |
+
'Polar_expression': [['liker'], ['3:8']],
|
44 |
+
'Polarity': 'Positive'}]}]
|
45 |
+
```
|
app.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import model_wrapper
|
3 |
+
|
4 |
+
|
5 |
+
model = model_wrapper.PredictionModel()
|
6 |
+
|
7 |
+
|
8 |
+
def pretty_print_opinion(opinion_dict):
|
9 |
+
res = []
|
10 |
+
maxlen = max([len(key) for key in opinion_dict.keys()]) + 2
|
11 |
+
maxlen = 0
|
12 |
+
for key, value in opinion_dict.items():
|
13 |
+
if key == 'Polarity':
|
14 |
+
res.append(f'{(key + ":").ljust(maxlen)} {value}')
|
15 |
+
else:
|
16 |
+
res.append(f'{(key + ":").ljust(maxlen)} \'{" ".join(value[0])}\'')
|
17 |
+
return '\n'.join(res) + '\n'
|
18 |
+
|
19 |
+
|
20 |
+
def predict(text):
|
21 |
+
print(f'Input message "{text}"')
|
22 |
+
try:
|
23 |
+
predictions = model.predict([text])
|
24 |
+
prediction = predictions[0]
|
25 |
+
results = []
|
26 |
+
if not prediction['opinions']:
|
27 |
+
return 'No opinions detected'
|
28 |
+
for opinion in prediction['opinions']:
|
29 |
+
results.append(pretty_print_opinion(opinion))
|
30 |
+
print(f'Successfully predicted SA for input message "{text}": {results}')
|
31 |
+
return '\n'.join(results)
|
32 |
+
except Exception as e:
|
33 |
+
print(f'Error for input message "{text}": {e}')
|
34 |
+
raise e
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
markdown_text = '''
|
39 |
+
<br>
|
40 |
+
<br>
|
41 |
+
This space provides a gradio demo and an easy-to-run wrapper of the pre-trained model for structured sentiment analysis in Norwegian language, pre-trained on the [NoReC dataset](https://huggingface.co/datasets/norec).
|
42 |
+
This model is an implementation of the paper "Direct parsing to sentiment graphs" (Samuel _et al._, ACL 2022). The main repository that also contains the scripts for training the model, can be found on the project [github](https://github.com/jerbarnes/direct_parsing_to_sent_graph).
|
43 |
+
|
44 |
+
The current model uses the 'labeled-edge' graph encoding, and achieves the following results on the NoReC dataset:
|
45 |
+
|
46 |
+
| Unlabeled sentiment tuple F1 | Target F1 | Relative polarity precision |
|
47 |
+
|:----------------------------:|:----------:|:---------------------------:|
|
48 |
+
| 0.393 | 0.468 | 0.939 |
|
49 |
+
|
50 |
+
|
51 |
+
The model can be easily used for predicting sentiment tuples as follows:
|
52 |
+
|
53 |
+
```python
|
54 |
+
>>> import model_wrapper
|
55 |
+
>>> model = model_wrapper.PredictionModel()
|
56 |
+
>>> model.predict(['vi liker svart kaffe'])
|
57 |
+
[{'sent_id': '0',
|
58 |
+
'text': 'vi liker svart kaffe',
|
59 |
+
'opinions': [{'Source': [['vi'], ['0:2']],
|
60 |
+
'Target': [['svart', 'kaffe'], ['9:14', '15:20']],
|
61 |
+
'Polar_expression': [['liker'], ['3:8']],
|
62 |
+
'Polarity': 'Positive'}]}]
|
63 |
+
```
|
64 |
+
'''
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
with gr.Blocks() as demo:
|
69 |
+
with gr.Row() as row:
|
70 |
+
text_input = gr.Textbox(label="input")
|
71 |
+
text_output = gr.Textbox(label="output")
|
72 |
+
with gr.Row() as row:
|
73 |
+
text_button = gr.Button("submit")
|
74 |
+
|
75 |
+
text_button.click(fn=predict, inputs=text_input, outputs=text_output)
|
76 |
+
|
77 |
+
gr.Markdown(markdown_text)
|
78 |
+
|
79 |
+
|
80 |
+
demo.launch()
|
config/__init__.py
ADDED
File without changes
|
config/params.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import yaml
|
2 |
+
|
3 |
+
|
4 |
+
class Params:
|
5 |
+
def __init__(self):
|
6 |
+
self.graph_mode = "sequential" # possibilities: {sequential, node-centric, edge-labeled}
|
7 |
+
self.accumulation_steps = 1 # number of gradient accumulation steps for achieving a bigger batch_size
|
8 |
+
self.activation = "relu" # transformer (decoder) activation function, supported values: {'relu', 'gelu', 'sigmoid', 'mish'}
|
9 |
+
self.predict_intensity = False
|
10 |
+
self.batch_size = 32 # batch size (further divided into multiple GPUs)
|
11 |
+
self.beta_2 = 0.98 # beta 2 parameter for Adam(W) optimizer
|
12 |
+
self.blank_weight = 1.0 # weight of cross-entropy loss for predicting an empty label
|
13 |
+
self.char_embedding = True # use character embedding in addition to bert
|
14 |
+
self.char_embedding_size = 128 # dimension of the character embedding layer in the character embedding module
|
15 |
+
self.decoder_delay_steps = 0 # number of initial steps with frozen decoder
|
16 |
+
self.decoder_learning_rate = 6e-4 # initial decoder learning rate
|
17 |
+
self.decoder_weight_decay = 1.2e-6 # amount of weight decay
|
18 |
+
self.dropout_anchor = 0.5 # dropout at the last layer of anchor classifier
|
19 |
+
self.dropout_edge_label = 0.5 # dropout at the last layer of edge label classifier
|
20 |
+
self.dropout_edge_presence = 0.5 # dropout at the last layer of edge presence classifier
|
21 |
+
self.dropout_label = 0.5 # dropout at the last layer of label classifier
|
22 |
+
self.dropout_transformer = 0.5 # dropout for the transformer layers (decoder)
|
23 |
+
self.dropout_transformer_attention = 0.1 # dropout for the transformer's attention (decoder)
|
24 |
+
self.dropout_word = 0.1 # probability of dropping out a whole word from the encoder (in favour of char embedding)
|
25 |
+
self.encoder = "xlm-roberta-base" # pretrained encoder model
|
26 |
+
self.encoder_delay_steps = 2000 # number of initial steps with frozen XLM-R
|
27 |
+
self.encoder_freeze_embedding = True # freeze the first embedding layer in XLM-R
|
28 |
+
self.encoder_learning_rate = 6e-5 # initial encoder learning rate
|
29 |
+
self.encoder_weight_decay = 1e-2 # amount of weight decay
|
30 |
+
self.lr_decay_multiplier = 100
|
31 |
+
self.epochs = 100 # number of epochs for train
|
32 |
+
self.focal = True # use focal loss for the label prediction
|
33 |
+
self.freeze_bert = False # use focal loss for the label prediction
|
34 |
+
self.group_ops = False # group 'opN' edge labels into one
|
35 |
+
self.hidden_size_ff = 4 * 768 # hidden size of the transformer feed-forward submodule
|
36 |
+
self.hidden_size_anchor = 128 # hidden size anchor biaffine layer
|
37 |
+
self.hidden_size_edge_label = 256 # hidden size for edge label biaffine layer
|
38 |
+
self.hidden_size_edge_presence = 512 # hidden size for edge label biaffine layer
|
39 |
+
self.layerwise_lr_decay = 1.0 # layerwise decay of learning rate in the encoder
|
40 |
+
self.n_attention_heads = 8 # number of attention heads in the decoding transformer
|
41 |
+
self.n_layers = 3 # number of layers in the decoder
|
42 |
+
self.query_length = 4 # number of queries genereted for each word on the input
|
43 |
+
self.pre_norm = True # use pre-normalized version of the transformer (as in Transformers without Tears)
|
44 |
+
self.warmup_steps = 6000 # number of the warm-up steps for the inverse_sqrt scheduler
|
45 |
+
|
46 |
+
def init_data_paths(self):
|
47 |
+
directory_1 = {
|
48 |
+
"sequential": "node_centric_mrp",
|
49 |
+
"node-centric": "node_centric_mrp",
|
50 |
+
"labeled-edge": "labeled_edge_mrp"
|
51 |
+
}[self.graph_mode]
|
52 |
+
directory_2 = {
|
53 |
+
("darmstadt", "en"): "darmstadt_unis",
|
54 |
+
("mpqa", "en"): "mpqa",
|
55 |
+
("multibooked", "ca"): "multibooked_ca",
|
56 |
+
("multibooked", "eu"): "multibooked_eu",
|
57 |
+
("norec", "no"): "norec",
|
58 |
+
("opener", "en"): "opener_en",
|
59 |
+
("opener", "es"): "opener_es",
|
60 |
+
}[(self.framework, self.language)]
|
61 |
+
|
62 |
+
self.training_data = f"{self.data_directory}/{directory_1}/{directory_2}/train.mrp"
|
63 |
+
self.validation_data = f"{self.data_directory}/{directory_1}/{directory_2}/dev.mrp"
|
64 |
+
self.test_data = f"{self.data_directory}/{directory_1}/{directory_2}/test.mrp"
|
65 |
+
|
66 |
+
self.raw_training_data = f"{self.data_directory}/raw/{directory_2}/train.json"
|
67 |
+
self.raw_validation_data = f"{self.data_directory}/raw/{directory_2}/dev.json"
|
68 |
+
|
69 |
+
return self
|
70 |
+
|
71 |
+
def load_state_dict(self, d):
|
72 |
+
for k, v in d.items():
|
73 |
+
setattr(self, k, v)
|
74 |
+
return self
|
75 |
+
|
76 |
+
def state_dict(self):
|
77 |
+
members = [attr for attr in dir(self) if not callable(getattr(self, attr)) and not attr.startswith("__")]
|
78 |
+
return {k: self.__dict__[k] for k in members}
|
79 |
+
|
80 |
+
def load(self, args):
|
81 |
+
with open(args.config, "r", encoding="utf-8") as f:
|
82 |
+
params = yaml.safe_load(f)
|
83 |
+
self.load_state_dict(params)
|
84 |
+
self.init_data_paths()
|
85 |
+
|
86 |
+
def save(self, json_path):
|
87 |
+
with open(json_path, "w", encoding="utf-8") as f:
|
88 |
+
d = self.state_dict()
|
89 |
+
yaml.dump(d, f)
|
data/__init__.py
ADDED
File without changes
|
data/batch.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class Batch:
|
9 |
+
@staticmethod
|
10 |
+
def build(data):
|
11 |
+
fields = list(data[0].keys())
|
12 |
+
transposed = {}
|
13 |
+
for field in fields:
|
14 |
+
if isinstance(data[0][field], tuple):
|
15 |
+
transposed[field] = tuple(Batch._stack(field, [example[field][i] for example in data]) for i in range(len(data[0][field])))
|
16 |
+
else:
|
17 |
+
transposed[field] = Batch._stack(field, [example[field] for example in data])
|
18 |
+
|
19 |
+
return transposed
|
20 |
+
|
21 |
+
@staticmethod
|
22 |
+
def _stack(field: str, examples):
|
23 |
+
if field == "anchored_labels":
|
24 |
+
return examples
|
25 |
+
|
26 |
+
dim = examples[0].dim()
|
27 |
+
|
28 |
+
if dim == 0:
|
29 |
+
return torch.stack(examples)
|
30 |
+
|
31 |
+
lengths = [max(example.size(i) for example in examples) for i in range(dim)]
|
32 |
+
if any(length == 0 for length in lengths):
|
33 |
+
return torch.LongTensor(len(examples), *lengths)
|
34 |
+
|
35 |
+
examples = [F.pad(example, Batch._pad_size(example, lengths)) for example in examples]
|
36 |
+
return torch.stack(examples)
|
37 |
+
|
38 |
+
@staticmethod
|
39 |
+
def _pad_size(example, total_size):
|
40 |
+
return [p for i, l in enumerate(total_size[::-1]) for p in (0, l - example.size(-1 - i))]
|
41 |
+
|
42 |
+
@staticmethod
|
43 |
+
def index_select(batch, indices):
|
44 |
+
filtered_batch = {}
|
45 |
+
for key, examples in batch.items():
|
46 |
+
if isinstance(examples, list) or isinstance(examples, tuple):
|
47 |
+
filtered_batch[key] = [example.index_select(0, indices) for example in examples]
|
48 |
+
else:
|
49 |
+
filtered_batch[key] = examples.index_select(0, indices)
|
50 |
+
|
51 |
+
return filtered_batch
|
52 |
+
|
53 |
+
@staticmethod
|
54 |
+
def to_str(batch):
|
55 |
+
string = "\n".join([f"\t{name}: {Batch._short_str(item)}" for name, item in batch.items()])
|
56 |
+
return string
|
57 |
+
|
58 |
+
@staticmethod
|
59 |
+
def to(batch, device):
|
60 |
+
converted = {}
|
61 |
+
for field in batch.keys():
|
62 |
+
converted[field] = Batch._to(batch[field], device)
|
63 |
+
return converted
|
64 |
+
|
65 |
+
@staticmethod
|
66 |
+
def _short_str(tensor):
|
67 |
+
# unwrap variable to tensor
|
68 |
+
if not torch.is_tensor(tensor):
|
69 |
+
# (1) unpack variable
|
70 |
+
if hasattr(tensor, "data"):
|
71 |
+
tensor = getattr(tensor, "data")
|
72 |
+
# (2) handle include_lengths
|
73 |
+
elif isinstance(tensor, tuple) or isinstance(tensor, list):
|
74 |
+
return str(tuple(Batch._short_str(t) for t in tensor))
|
75 |
+
# (3) fallback to default str
|
76 |
+
else:
|
77 |
+
return str(tensor)
|
78 |
+
|
79 |
+
# copied from torch _tensor_str
|
80 |
+
size_str = "x".join(str(size) for size in tensor.size())
|
81 |
+
device_str = "" if not tensor.is_cuda else " (GPU {})".format(tensor.get_device())
|
82 |
+
strt = "[{} of size {}{}]".format(torch.typename(tensor), size_str, device_str)
|
83 |
+
return strt
|
84 |
+
|
85 |
+
@staticmethod
|
86 |
+
def _to(tensor, device):
|
87 |
+
if not torch.is_tensor(tensor):
|
88 |
+
if isinstance(tensor, tuple):
|
89 |
+
return tuple(Batch._to(t, device) for t in tensor)
|
90 |
+
elif isinstance(tensor, list):
|
91 |
+
return [Batch._to(t, device) for t in tensor]
|
92 |
+
else:
|
93 |
+
raise Exception(f"unsupported type of {tensor} to be casted to cuda")
|
94 |
+
|
95 |
+
return tensor.to(device, non_blocking=True)
|
data/dataset.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
import pickle
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from data.parser.from_mrp.node_centric_parser import NodeCentricParser
|
9 |
+
from data.parser.from_mrp.labeled_edge_parser import LabeledEdgeParser
|
10 |
+
from data.parser.from_mrp.sequential_parser import SequentialParser
|
11 |
+
from data.parser.from_mrp.evaluation_parser import EvaluationParser
|
12 |
+
from data.parser.from_mrp.request_parser import RequestParser
|
13 |
+
from data.field.edge_field import EdgeField
|
14 |
+
from data.field.edge_label_field import EdgeLabelField
|
15 |
+
from data.field.field import Field
|
16 |
+
from data.field.mini_torchtext.field import Field as TorchTextField
|
17 |
+
from data.field.label_field import LabelField
|
18 |
+
from data.field.anchored_label_field import AnchoredLabelField
|
19 |
+
from data.field.nested_field import NestedField
|
20 |
+
from data.field.basic_field import BasicField
|
21 |
+
from data.field.bert_field import BertField
|
22 |
+
from data.field.anchor_field import AnchorField
|
23 |
+
from data.batch import Batch
|
24 |
+
|
25 |
+
|
26 |
+
def char_tokenize(word):
|
27 |
+
return [c for i, c in enumerate(word)] # if i < 10 or len(word) - i <= 10]
|
28 |
+
|
29 |
+
|
30 |
+
class Collate:
|
31 |
+
def __call__(self, batch):
|
32 |
+
batch.sort(key=lambda example: example["every_input"][0].size(0), reverse=True)
|
33 |
+
return Batch.build(batch)
|
34 |
+
|
35 |
+
|
36 |
+
class Dataset:
|
37 |
+
def __init__(self, args, verbose=True):
|
38 |
+
self.verbose = verbose
|
39 |
+
self.sos, self.eos, self.pad, self.unk = "<sos>", "<eos>", "<pad>", "<unk>"
|
40 |
+
|
41 |
+
self.bert_input_field = BertField()
|
42 |
+
self.scatter_field = BasicField()
|
43 |
+
self.every_word_input_field = Field(lower=True, init_token=self.sos, eos_token=self.eos, batch_first=True, include_lengths=True)
|
44 |
+
|
45 |
+
char_form_nesting = TorchTextField(tokenize=char_tokenize, init_token=self.sos, eos_token=self.eos, batch_first=True)
|
46 |
+
self.char_form_field = NestedField(char_form_nesting, include_lengths=True)
|
47 |
+
|
48 |
+
self.label_field = LabelField(preprocessing=lambda nodes: [n["label"] for n in nodes])
|
49 |
+
self.anchored_label_field = AnchoredLabelField()
|
50 |
+
|
51 |
+
self.id_field = Field(batch_first=True, tokenize=lambda x: [x])
|
52 |
+
self.edge_presence_field = EdgeField()
|
53 |
+
self.edge_label_field = EdgeLabelField()
|
54 |
+
self.anchor_field = AnchorField()
|
55 |
+
self.source_anchor_field = AnchorField()
|
56 |
+
self.target_anchor_field = AnchorField()
|
57 |
+
self.token_interval_field = BasicField()
|
58 |
+
|
59 |
+
self.load_dataset(args)
|
60 |
+
|
61 |
+
def log(self, text):
|
62 |
+
if not self.verbose:
|
63 |
+
return
|
64 |
+
print(text, flush=True)
|
65 |
+
|
66 |
+
def load_state_dict(self, args, d):
|
67 |
+
for key, value in d["vocabs"].items():
|
68 |
+
getattr(self, key).vocab = pickle.loads(value)
|
69 |
+
|
70 |
+
def state_dict(self):
|
71 |
+
return {
|
72 |
+
"vocabs": {key: pickle.dumps(value.vocab) for key, value in self.__dict__.items() if hasattr(value, "vocab")}
|
73 |
+
}
|
74 |
+
|
75 |
+
def load_sentences(self, sentences, args):
|
76 |
+
dataset = RequestParser(
|
77 |
+
sentences, args,
|
78 |
+
fields={
|
79 |
+
"input": [("every_input", self.every_word_input_field), ("char_form_input", self.char_form_field)],
|
80 |
+
"bert input": ("input", self.bert_input_field),
|
81 |
+
"to scatter": ("input_scatter", self.scatter_field),
|
82 |
+
"token anchors": ("token_intervals", self.token_interval_field),
|
83 |
+
"id": ("id", self.id_field),
|
84 |
+
},
|
85 |
+
)
|
86 |
+
|
87 |
+
self.every_word_input_field.build_vocab(dataset, min_freq=1, specials=[self.pad, self.unk, self.sos, self.eos])
|
88 |
+
self.id_field.build_vocab(dataset, min_freq=1, specials=[])
|
89 |
+
|
90 |
+
return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=Collate())
|
91 |
+
|
92 |
+
def load_dataset(self, args):
|
93 |
+
parser = {
|
94 |
+
"sequential": SequentialParser,
|
95 |
+
"node-centric": NodeCentricParser,
|
96 |
+
"labeled-edge": LabeledEdgeParser
|
97 |
+
}[args.graph_mode]
|
98 |
+
|
99 |
+
train = parser(
|
100 |
+
args, "training",
|
101 |
+
fields={
|
102 |
+
"input": [("every_input", self.every_word_input_field), ("char_form_input", self.char_form_field)],
|
103 |
+
"bert input": ("input", self.bert_input_field),
|
104 |
+
"to scatter": ("input_scatter", self.scatter_field),
|
105 |
+
"nodes": ("labels", self.label_field),
|
106 |
+
"anchored labels": ("anchored_labels", self.anchored_label_field),
|
107 |
+
"edge presence": ("edge_presence", self.edge_presence_field),
|
108 |
+
"edge labels": ("edge_labels", self.edge_label_field),
|
109 |
+
"anchor edges": ("anchor", self.anchor_field),
|
110 |
+
"source anchor edges": ("source_anchor", self.source_anchor_field),
|
111 |
+
"target anchor edges": ("target_anchor", self.target_anchor_field),
|
112 |
+
"token anchors": ("token_intervals", self.token_interval_field),
|
113 |
+
"id": ("id", self.id_field),
|
114 |
+
},
|
115 |
+
filter_pred=lambda example: len(example.input) <= 256,
|
116 |
+
)
|
117 |
+
|
118 |
+
val = parser(
|
119 |
+
args, "validation",
|
120 |
+
fields={
|
121 |
+
"input": [("every_input", self.every_word_input_field), ("char_form_input", self.char_form_field)],
|
122 |
+
"bert input": ("input", self.bert_input_field),
|
123 |
+
"to scatter": ("input_scatter", self.scatter_field),
|
124 |
+
"nodes": ("labels", self.label_field),
|
125 |
+
"anchored labels": ("anchored_labels", self.anchored_label_field),
|
126 |
+
"edge presence": ("edge_presence", self.edge_presence_field),
|
127 |
+
"edge labels": ("edge_labels", self.edge_label_field),
|
128 |
+
"anchor edges": ("anchor", self.anchor_field),
|
129 |
+
"source anchor edges": ("source_anchor", self.source_anchor_field),
|
130 |
+
"target anchor edges": ("target_anchor", self.target_anchor_field),
|
131 |
+
"token anchors": ("token_intervals", self.token_interval_field),
|
132 |
+
"id": ("id", self.id_field),
|
133 |
+
},
|
134 |
+
)
|
135 |
+
|
136 |
+
test = EvaluationParser(
|
137 |
+
args,
|
138 |
+
fields={
|
139 |
+
"input": [("every_input", self.every_word_input_field), ("char_form_input", self.char_form_field)],
|
140 |
+
"bert input": ("input", self.bert_input_field),
|
141 |
+
"to scatter": ("input_scatter", self.scatter_field),
|
142 |
+
"token anchors": ("token_intervals", self.token_interval_field),
|
143 |
+
"id": ("id", self.id_field),
|
144 |
+
},
|
145 |
+
)
|
146 |
+
|
147 |
+
del train.data, val.data, test.data # TODO: why?
|
148 |
+
for f in list(train.fields.values()) + list(val.fields.values()) + list(test.fields.values()): # TODO: why?
|
149 |
+
if hasattr(f, "preprocessing"):
|
150 |
+
del f.preprocessing
|
151 |
+
|
152 |
+
self.train_size = len(train)
|
153 |
+
self.val_size = len(val)
|
154 |
+
self.test_size = len(test)
|
155 |
+
|
156 |
+
self.log(f"\n{self.train_size} sentences in the train split")
|
157 |
+
self.log(f"{self.val_size} sentences in the validation split")
|
158 |
+
self.log(f"{self.test_size} sentences in the test split")
|
159 |
+
|
160 |
+
self.node_count = train.node_counter
|
161 |
+
self.token_count = train.input_count
|
162 |
+
self.edge_count = train.edge_counter
|
163 |
+
self.no_edge_count = train.no_edge_counter
|
164 |
+
self.anchor_freq = train.anchor_freq
|
165 |
+
|
166 |
+
self.source_anchor_freq = train.source_anchor_freq if hasattr(train, "source_anchor_freq") else 0.5
|
167 |
+
self.target_anchor_freq = train.target_anchor_freq if hasattr(train, "target_anchor_freq") else 0.5
|
168 |
+
self.log(f"{self.node_count} nodes in the train split")
|
169 |
+
|
170 |
+
self.every_word_input_field.build_vocab(val, test, min_freq=1, specials=[self.pad, self.unk, self.sos, self.eos])
|
171 |
+
self.char_form_field.build_vocab(train, min_freq=1, specials=[self.pad, self.unk, self.sos, self.eos])
|
172 |
+
self.char_form_field.nesting_field.vocab = self.char_form_field.vocab
|
173 |
+
self.id_field.build_vocab(train, val, test, min_freq=1, specials=[])
|
174 |
+
self.label_field.build_vocab(train)
|
175 |
+
self.anchored_label_field.vocab = self.label_field.vocab
|
176 |
+
self.edge_label_field.build_vocab(train)
|
177 |
+
print(list(self.edge_label_field.vocab.freqs.keys()), flush=True)
|
178 |
+
|
179 |
+
self.char_form_vocab_size = len(self.char_form_field.vocab)
|
180 |
+
self.create_label_freqs(args)
|
181 |
+
self.create_edge_freqs(args)
|
182 |
+
|
183 |
+
self.log(f"Edge frequency: {self.edge_presence_freq*100:.2f} %")
|
184 |
+
self.log(f"{len(self.label_field.vocab)} words in the label vocabulary")
|
185 |
+
self.log(f"{len(self.anchored_label_field.vocab)} words in the anchored label vocabulary")
|
186 |
+
self.log(f"{len(self.edge_label_field.vocab)} words in the edge label vocabulary")
|
187 |
+
self.log(f"{len(self.char_form_field.vocab)} characters in the vocabulary")
|
188 |
+
|
189 |
+
self.log(self.label_field.vocab.freqs)
|
190 |
+
self.log(self.anchored_label_field.vocab.freqs)
|
191 |
+
|
192 |
+
self.train = torch.utils.data.DataLoader(
|
193 |
+
train,
|
194 |
+
batch_size=args.batch_size,
|
195 |
+
shuffle=True,
|
196 |
+
num_workers=args.workers,
|
197 |
+
collate_fn=Collate(),
|
198 |
+
pin_memory=True,
|
199 |
+
drop_last=True
|
200 |
+
)
|
201 |
+
self.train_size = len(self.train.dataset)
|
202 |
+
|
203 |
+
self.val = torch.utils.data.DataLoader(
|
204 |
+
val,
|
205 |
+
batch_size=args.batch_size,
|
206 |
+
shuffle=False,
|
207 |
+
num_workers=args.workers,
|
208 |
+
collate_fn=Collate(),
|
209 |
+
pin_memory=True,
|
210 |
+
)
|
211 |
+
self.val_size = len(self.val.dataset)
|
212 |
+
|
213 |
+
self.test = torch.utils.data.DataLoader(
|
214 |
+
test,
|
215 |
+
batch_size=args.batch_size,
|
216 |
+
shuffle=False,
|
217 |
+
num_workers=args.workers,
|
218 |
+
collate_fn=Collate(),
|
219 |
+
pin_memory=True,
|
220 |
+
)
|
221 |
+
self.test_size = len(self.test.dataset)
|
222 |
+
|
223 |
+
if self.verbose:
|
224 |
+
batch = next(iter(self.train))
|
225 |
+
print(f"\nBatch content: {Batch.to_str(batch)}\n")
|
226 |
+
print(flush=True)
|
227 |
+
|
228 |
+
def create_label_freqs(self, args):
|
229 |
+
n_rules = len(self.label_field.vocab)
|
230 |
+
blank_count = (args.query_length * self.token_count - self.node_count)
|
231 |
+
label_counts = [blank_count] + [
|
232 |
+
self.label_field.vocab.freqs[self.label_field.vocab.itos[i]]
|
233 |
+
for i in range(n_rules)
|
234 |
+
]
|
235 |
+
label_counts = torch.FloatTensor(label_counts)
|
236 |
+
self.label_freqs = label_counts / (self.node_count + blank_count)
|
237 |
+
self.log(f"Label frequency: {self.label_freqs}")
|
238 |
+
|
239 |
+
def create_edge_freqs(self, args):
|
240 |
+
edge_counter = [
|
241 |
+
self.edge_label_field.vocab.freqs[self.edge_label_field.vocab.itos[i]] for i in range(len(self.edge_label_field.vocab))
|
242 |
+
]
|
243 |
+
edge_counter = torch.FloatTensor(edge_counter)
|
244 |
+
self.edge_label_freqs = edge_counter / self.edge_count
|
245 |
+
self.edge_presence_freq = self.edge_count / (self.edge_count + self.no_edge_count)
|
data/field/__init__.py
ADDED
File without changes
|
data/field/anchor_field.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from data.field.mini_torchtext.field import RawField
|
6 |
+
|
7 |
+
|
8 |
+
class AnchorField(RawField):
|
9 |
+
def process(self, batch, device=None):
|
10 |
+
tensors, masks = self.pad(batch, device)
|
11 |
+
return tensors, masks
|
12 |
+
|
13 |
+
def pad(self, anchors, device):
|
14 |
+
tensor = torch.zeros(anchors[0], anchors[1], dtype=torch.long, device=device)
|
15 |
+
for anchor in anchors[-1]:
|
16 |
+
tensor[anchor[0], anchor[1]] = 1
|
17 |
+
mask = tensor.sum(-1) == 0
|
18 |
+
|
19 |
+
return tensor, mask
|
data/field/anchored_label_field.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from data.field.mini_torchtext.field import RawField
|
3 |
+
|
4 |
+
|
5 |
+
class AnchoredLabelField(RawField):
|
6 |
+
def __init__(self):
|
7 |
+
super(AnchoredLabelField, self).__init__()
|
8 |
+
self.vocab = None
|
9 |
+
|
10 |
+
def process(self, example, device=None):
|
11 |
+
example = self.numericalize(example)
|
12 |
+
tensor = self.pad(example, device)
|
13 |
+
return tensor
|
14 |
+
|
15 |
+
def pad(self, example, device):
|
16 |
+
n_labels = len(self.vocab)
|
17 |
+
n_nodes, n_tokens = len(example[1]), example[0]
|
18 |
+
|
19 |
+
tensor = torch.full([n_nodes, n_tokens, n_labels + 1], 0, dtype=torch.long, device=device)
|
20 |
+
for i_node, node in enumerate(example[1]):
|
21 |
+
for anchor, rule in node:
|
22 |
+
tensor[i_node, anchor, rule + 1] = 1
|
23 |
+
|
24 |
+
return tensor
|
25 |
+
|
26 |
+
def numericalize(self, arr):
|
27 |
+
def multi_map(array, function):
|
28 |
+
if isinstance(array, tuple):
|
29 |
+
return (array[0], function(array[1]))
|
30 |
+
elif isinstance(array, list):
|
31 |
+
return [multi_map(a, function) for a in array]
|
32 |
+
else:
|
33 |
+
return array
|
34 |
+
|
35 |
+
if self.vocab is not None:
|
36 |
+
arr = multi_map(arr, lambda x: self.vocab.stoi[x] if x in self.vocab.stoi else 0)
|
37 |
+
|
38 |
+
return arr
|
data/field/basic_field.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from data.field.mini_torchtext.field import RawField
|
6 |
+
|
7 |
+
|
8 |
+
class BasicField(RawField):
|
9 |
+
def process(self, example, device=None):
|
10 |
+
tensor = torch.tensor(example, dtype=torch.long, device=device)
|
11 |
+
return tensor
|
data/field/bert_field.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from data.field.mini_torchtext.field import RawField
|
6 |
+
|
7 |
+
|
8 |
+
class BertField(RawField):
|
9 |
+
def __init__(self):
|
10 |
+
super(BertField, self).__init__()
|
11 |
+
|
12 |
+
def process(self, example, device=None):
|
13 |
+
attention_mask = [1] * len(example)
|
14 |
+
|
15 |
+
example = torch.LongTensor(example, device=device)
|
16 |
+
attention_mask = torch.ones_like(example)
|
17 |
+
|
18 |
+
return example, attention_mask
|
data/field/edge_field.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from data.field.mini_torchtext.field import RawField
|
6 |
+
from data.field.mini_torchtext.vocab import Vocab
|
7 |
+
from collections import Counter
|
8 |
+
import types
|
9 |
+
|
10 |
+
|
11 |
+
class EdgeField(RawField):
|
12 |
+
def __init__(self):
|
13 |
+
super(EdgeField, self).__init__()
|
14 |
+
self.vocab = None
|
15 |
+
|
16 |
+
def process(self, edges, device=None):
|
17 |
+
edges = self.numericalize(edges)
|
18 |
+
tensor = self.pad(edges, device)
|
19 |
+
return tensor
|
20 |
+
|
21 |
+
def pad(self, edges, device):
|
22 |
+
tensor = torch.zeros(edges[0], edges[1], dtype=torch.long, device=device)
|
23 |
+
for edge in edges[-1]:
|
24 |
+
tensor[edge[0], edge[1]] = edge[2]
|
25 |
+
|
26 |
+
return tensor
|
27 |
+
|
28 |
+
def numericalize(self, arr):
|
29 |
+
def multi_map(array, function):
|
30 |
+
if isinstance(array, tuple):
|
31 |
+
return (array[0], array[1], function(array[2]))
|
32 |
+
elif isinstance(array, list):
|
33 |
+
return [multi_map(array[i], function) for i in range(len(array))]
|
34 |
+
else:
|
35 |
+
return array
|
36 |
+
|
37 |
+
if self.vocab is not None:
|
38 |
+
arr = multi_map(arr, lambda x: self.vocab.stoi[x] if x is not None else 0)
|
39 |
+
return arr
|
40 |
+
|
41 |
+
def build_vocab(self, *args):
|
42 |
+
def generate(l):
|
43 |
+
if isinstance(l, tuple):
|
44 |
+
yield l[2]
|
45 |
+
elif isinstance(l, list) or isinstance(l, types.GeneratorType):
|
46 |
+
for i in l:
|
47 |
+
yield from generate(i)
|
48 |
+
else:
|
49 |
+
return
|
50 |
+
|
51 |
+
counter = Counter()
|
52 |
+
sources = []
|
53 |
+
for arg in args:
|
54 |
+
if isinstance(arg, torch.utils.data.Dataset):
|
55 |
+
sources += [arg.get_examples(name) for name, field in arg.fields.items() if field is self]
|
56 |
+
else:
|
57 |
+
sources.append(arg)
|
58 |
+
|
59 |
+
for x in generate(sources):
|
60 |
+
if x is not None:
|
61 |
+
counter.update([x])
|
62 |
+
|
63 |
+
self.vocab = Vocab(counter, specials=[])
|
data/field/edge_label_field.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from data.field.mini_torchtext.field import RawField
|
6 |
+
from data.field.mini_torchtext.vocab import Vocab
|
7 |
+
from collections import Counter
|
8 |
+
import types
|
9 |
+
|
10 |
+
|
11 |
+
class EdgeLabelField(RawField):
|
12 |
+
def process(self, edges, device=None):
|
13 |
+
edges, masks = self.numericalize(edges)
|
14 |
+
edges, masks = self.pad(edges, masks, device)
|
15 |
+
|
16 |
+
return edges, masks
|
17 |
+
|
18 |
+
def pad(self, edges, masks, device):
|
19 |
+
n_labels = len(self.vocab)
|
20 |
+
|
21 |
+
tensor = torch.zeros(edges[0], edges[1], n_labels, dtype=torch.long, device=device)
|
22 |
+
mask_tensor = torch.zeros(edges[0], edges[1], dtype=torch.bool, device=device)
|
23 |
+
|
24 |
+
for edge in edges[-1]:
|
25 |
+
tensor[edge[0], edge[1], edge[2]] = 1
|
26 |
+
|
27 |
+
for mask in masks[-1]:
|
28 |
+
mask_tensor[mask[0], mask[1]] = mask[2]
|
29 |
+
|
30 |
+
return tensor, mask_tensor
|
31 |
+
|
32 |
+
def numericalize(self, arr):
|
33 |
+
def multi_map(array, function):
|
34 |
+
if isinstance(array, tuple):
|
35 |
+
return (array[0], array[1], function(array[2]))
|
36 |
+
elif isinstance(array, list):
|
37 |
+
return [multi_map(array[i], function) for i in range(len(array))]
|
38 |
+
else:
|
39 |
+
return array
|
40 |
+
|
41 |
+
mask = multi_map(arr, lambda x: x is None)
|
42 |
+
arr = multi_map(arr, lambda x: self.vocab.stoi[x] if x in self.vocab.stoi else 0)
|
43 |
+
return arr, mask
|
44 |
+
|
45 |
+
def build_vocab(self, *args):
|
46 |
+
def generate(l):
|
47 |
+
if isinstance(l, tuple):
|
48 |
+
yield l[2]
|
49 |
+
elif isinstance(l, list) or isinstance(l, types.GeneratorType):
|
50 |
+
for i in l:
|
51 |
+
yield from generate(i)
|
52 |
+
else:
|
53 |
+
return
|
54 |
+
|
55 |
+
counter = Counter()
|
56 |
+
sources = []
|
57 |
+
for arg in args:
|
58 |
+
if isinstance(arg, torch.utils.data.Dataset):
|
59 |
+
sources += [arg.get_examples(name) for name, field in arg.fields.items() if field is self]
|
60 |
+
else:
|
61 |
+
sources.append(arg)
|
62 |
+
|
63 |
+
for x in generate(sources):
|
64 |
+
if x is not None:
|
65 |
+
counter.update([x])
|
66 |
+
|
67 |
+
self.vocab = Vocab(counter, specials=[])
|
data/field/field.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from data.field.mini_torchtext.field import Field as TorchTextField
|
3 |
+
from collections import Counter, OrderedDict
|
4 |
+
|
5 |
+
|
6 |
+
# small change of vocab building to correspond to our version of Dataset
|
7 |
+
class Field(TorchTextField):
|
8 |
+
def build_vocab(self, *args, **kwargs):
|
9 |
+
counter = Counter()
|
10 |
+
sources = []
|
11 |
+
for arg in args:
|
12 |
+
if isinstance(arg, torch.utils.data.Dataset):
|
13 |
+
sources += [arg.get_examples(name) for name, field in arg.fields.items() if field is self]
|
14 |
+
else:
|
15 |
+
sources.append(arg)
|
16 |
+
for data in sources:
|
17 |
+
for x in data:
|
18 |
+
if not self.sequential:
|
19 |
+
x = [x]
|
20 |
+
counter.update(x)
|
21 |
+
|
22 |
+
specials = list(
|
23 |
+
OrderedDict.fromkeys(
|
24 |
+
tok
|
25 |
+
for tok in [self.unk_token, self.pad_token, self.init_token, self.eos_token] + kwargs.pop("specials", [])
|
26 |
+
if tok is not None
|
27 |
+
)
|
28 |
+
)
|
29 |
+
self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)
|
30 |
+
|
31 |
+
def process(self, example, device=None):
|
32 |
+
if self.include_lengths:
|
33 |
+
example = example, len(example)
|
34 |
+
tensor = self.numericalize(example, device=device)
|
35 |
+
return tensor
|
36 |
+
|
37 |
+
def numericalize(self, ex, device=None):
|
38 |
+
if self.include_lengths and not isinstance(ex, tuple):
|
39 |
+
raise ValueError("Field has include_lengths set to True, but input data is not a tuple of (data batch, batch lengths).")
|
40 |
+
|
41 |
+
if isinstance(ex, tuple):
|
42 |
+
ex, lengths = ex
|
43 |
+
lengths = torch.tensor(lengths, dtype=self.dtype, device=device)
|
44 |
+
|
45 |
+
if self.use_vocab:
|
46 |
+
if self.sequential:
|
47 |
+
ex = [self.vocab.stoi[x] for x in ex]
|
48 |
+
else:
|
49 |
+
ex = self.vocab.stoi[ex]
|
50 |
+
|
51 |
+
if self.postprocessing is not None:
|
52 |
+
ex = self.postprocessing(ex, self.vocab)
|
53 |
+
else:
|
54 |
+
numericalization_func = self.dtypes[self.dtype]
|
55 |
+
|
56 |
+
if not self.sequential:
|
57 |
+
ex = numericalization_func(ex) if isinstance(ex, str) else ex
|
58 |
+
if self.postprocessing is not None:
|
59 |
+
ex = self.postprocessing(ex, None)
|
60 |
+
|
61 |
+
var = torch.tensor(ex, dtype=self.dtype, device=device)
|
62 |
+
|
63 |
+
if self.sequential and not self.batch_first:
|
64 |
+
var.t_()
|
65 |
+
if self.sequential:
|
66 |
+
var = var.contiguous()
|
67 |
+
|
68 |
+
if self.include_lengths:
|
69 |
+
return var, lengths
|
70 |
+
return var
|
data/field/label_field.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from data.field.mini_torchtext.field import RawField
|
3 |
+
from data.field.mini_torchtext.vocab import Vocab
|
4 |
+
from collections import Counter
|
5 |
+
|
6 |
+
|
7 |
+
class LabelField(RawField):
|
8 |
+
def __self__(self, preprocessing):
|
9 |
+
super(LabelField, self).__init__(preprocessing=preprocessing)
|
10 |
+
self.vocab = None
|
11 |
+
|
12 |
+
def build_vocab(self, *args, **kwargs):
|
13 |
+
sources = []
|
14 |
+
for arg in args:
|
15 |
+
if isinstance(arg, torch.utils.data.Dataset):
|
16 |
+
sources += [arg.get_examples(name) for name, field in arg.fields.items() if field is self]
|
17 |
+
else:
|
18 |
+
sources.append(arg)
|
19 |
+
|
20 |
+
counter = Counter()
|
21 |
+
for data in sources:
|
22 |
+
for x in data:
|
23 |
+
counter.update(x)
|
24 |
+
|
25 |
+
self.vocab = Vocab(counter, specials=[])
|
26 |
+
|
27 |
+
def process(self, example, device=None):
|
28 |
+
tensor, lengths = self.numericalize(example, device=device)
|
29 |
+
return tensor, lengths
|
30 |
+
|
31 |
+
def numericalize(self, example, device=None):
|
32 |
+
example = [self.vocab.stoi[x] + 1 for x in example]
|
33 |
+
length = torch.LongTensor([len(example)], device=device).squeeze(0)
|
34 |
+
tensor = torch.LongTensor(example, device=device)
|
35 |
+
|
36 |
+
return tensor, length
|
data/field/mini_torchtext/example.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import six
|
2 |
+
import json
|
3 |
+
from functools import reduce
|
4 |
+
|
5 |
+
|
6 |
+
class Example(object):
|
7 |
+
"""Defines a single training or test example.
|
8 |
+
|
9 |
+
Stores each column of the example as an attribute.
|
10 |
+
"""
|
11 |
+
@classmethod
|
12 |
+
def fromJSON(cls, data, fields):
|
13 |
+
ex = cls()
|
14 |
+
obj = json.loads(data)
|
15 |
+
|
16 |
+
for key, vals in fields.items():
|
17 |
+
if vals is not None:
|
18 |
+
if not isinstance(vals, list):
|
19 |
+
vals = [vals]
|
20 |
+
|
21 |
+
for val in vals:
|
22 |
+
# for processing the key likes 'foo.bar'
|
23 |
+
name, field = val
|
24 |
+
ks = key.split('.')
|
25 |
+
|
26 |
+
def reducer(obj, key):
|
27 |
+
if isinstance(obj, list):
|
28 |
+
results = []
|
29 |
+
for data in obj:
|
30 |
+
if key not in data:
|
31 |
+
# key error
|
32 |
+
raise ValueError("Specified key {} was not found in "
|
33 |
+
"the input data".format(key))
|
34 |
+
else:
|
35 |
+
results.append(data[key])
|
36 |
+
return results
|
37 |
+
else:
|
38 |
+
# key error
|
39 |
+
if key not in obj:
|
40 |
+
raise ValueError("Specified key {} was not found in "
|
41 |
+
"the input data".format(key))
|
42 |
+
else:
|
43 |
+
return obj[key]
|
44 |
+
|
45 |
+
v = reduce(reducer, ks, obj)
|
46 |
+
setattr(ex, name, field.preprocess(v))
|
47 |
+
return ex
|
48 |
+
|
49 |
+
@classmethod
|
50 |
+
def fromdict(cls, data, fields):
|
51 |
+
ex = cls()
|
52 |
+
for key, vals in fields.items():
|
53 |
+
if key not in data:
|
54 |
+
raise ValueError("Specified key {} was not found in "
|
55 |
+
"the input data".format(key))
|
56 |
+
if vals is not None:
|
57 |
+
if not isinstance(vals, list):
|
58 |
+
vals = [vals]
|
59 |
+
for val in vals:
|
60 |
+
name, field = val
|
61 |
+
setattr(ex, name, field.preprocess(data[key]))
|
62 |
+
return ex
|
63 |
+
|
64 |
+
@classmethod
|
65 |
+
def fromCSV(cls, data, fields, field_to_index=None):
|
66 |
+
if field_to_index is None:
|
67 |
+
return cls.fromlist(data, fields)
|
68 |
+
else:
|
69 |
+
assert(isinstance(fields, dict))
|
70 |
+
data_dict = {f: data[idx] for f, idx in field_to_index.items()}
|
71 |
+
return cls.fromdict(data_dict, fields)
|
72 |
+
|
73 |
+
@classmethod
|
74 |
+
def fromlist(cls, data, fields):
|
75 |
+
ex = cls()
|
76 |
+
for (name, field), val in zip(fields, data):
|
77 |
+
if field is not None:
|
78 |
+
if isinstance(val, six.string_types):
|
79 |
+
val = val.rstrip('\n')
|
80 |
+
# Handle field tuples
|
81 |
+
if isinstance(name, tuple):
|
82 |
+
for n, f in zip(name, field):
|
83 |
+
setattr(ex, n, f.preprocess(val))
|
84 |
+
else:
|
85 |
+
setattr(ex, name, field.preprocess(val))
|
86 |
+
return ex
|
87 |
+
|
88 |
+
@classmethod
|
89 |
+
def fromtree(cls, data, fields, subtrees=False):
|
90 |
+
try:
|
91 |
+
from nltk.tree import Tree
|
92 |
+
except ImportError:
|
93 |
+
print("Please install NLTK. "
|
94 |
+
"See the docs at http://nltk.org for more information.")
|
95 |
+
raise
|
96 |
+
tree = Tree.fromstring(data)
|
97 |
+
if subtrees:
|
98 |
+
return [cls.fromlist(
|
99 |
+
[' '.join(t.leaves()), t.label()], fields) for t in tree.subtrees()]
|
100 |
+
return cls.fromlist([' '.join(tree.leaves()), tree.label()], fields)
|
data/field/mini_torchtext/field.py
ADDED
@@ -0,0 +1,637 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf8
|
2 |
+
from collections import Counter, OrderedDict
|
3 |
+
from itertools import chain
|
4 |
+
import six
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from .pipeline import Pipeline
|
8 |
+
from .utils import get_tokenizer, dtype_to_attr, is_tokenizer_serializable
|
9 |
+
from .vocab import Vocab
|
10 |
+
|
11 |
+
|
12 |
+
class RawField(object):
|
13 |
+
""" Defines a general datatype.
|
14 |
+
|
15 |
+
Every dataset consists of one or more types of data. For instance, a text
|
16 |
+
classification dataset contains sentences and their classes, while a
|
17 |
+
machine translation dataset contains paired examples of text in two
|
18 |
+
languages. Each of these types of data is represented by a RawField object.
|
19 |
+
A RawField object does not assume any property of the data type and
|
20 |
+
it holds parameters relating to how a datatype should be processed.
|
21 |
+
|
22 |
+
Attributes:
|
23 |
+
preprocessing: The Pipeline that will be applied to examples
|
24 |
+
using this field before creating an example.
|
25 |
+
Default: None.
|
26 |
+
postprocessing: A Pipeline that will be applied to a list of examples
|
27 |
+
using this field before assigning to a batch.
|
28 |
+
Function signature: (batch(list)) -> object
|
29 |
+
Default: None.
|
30 |
+
is_target: Whether this field is a target variable.
|
31 |
+
Affects iteration over batches. Default: False
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(self, preprocessing=None, postprocessing=None, is_target=False):
|
35 |
+
self.preprocessing = preprocessing
|
36 |
+
self.postprocessing = postprocessing
|
37 |
+
self.is_target = is_target
|
38 |
+
|
39 |
+
def preprocess(self, x):
|
40 |
+
""" Preprocess an example if the `preprocessing` Pipeline is provided. """
|
41 |
+
if hasattr(self, "preprocessing") and self.preprocessing is not None:
|
42 |
+
return self.preprocessing(x)
|
43 |
+
else:
|
44 |
+
return x
|
45 |
+
|
46 |
+
def process(self, batch, *args, **kwargs):
|
47 |
+
""" Process a list of examples to create a batch.
|
48 |
+
|
49 |
+
Postprocess the batch with user-provided Pipeline.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
batch (list(object)): A list of object from a batch of examples.
|
53 |
+
Returns:
|
54 |
+
object: Processed object given the input and custom
|
55 |
+
postprocessing Pipeline.
|
56 |
+
"""
|
57 |
+
if self.postprocessing is not None:
|
58 |
+
batch = self.postprocessing(batch)
|
59 |
+
return batch
|
60 |
+
|
61 |
+
|
62 |
+
class Field(RawField):
|
63 |
+
"""Defines a datatype together with instructions for converting to Tensor.
|
64 |
+
|
65 |
+
Field class models common text processing datatypes that can be represented
|
66 |
+
by tensors. It holds a Vocab object that defines the set of possible values
|
67 |
+
for elements of the field and their corresponding numerical representations.
|
68 |
+
The Field object also holds other parameters relating to how a datatype
|
69 |
+
should be numericalized, such as a tokenization method and the kind of
|
70 |
+
Tensor that should be produced.
|
71 |
+
|
72 |
+
If a Field is shared between two columns in a dataset (e.g., question and
|
73 |
+
answer in a QA dataset), then they will have a shared vocabulary.
|
74 |
+
|
75 |
+
Attributes:
|
76 |
+
sequential: Whether the datatype represents sequential data. If False,
|
77 |
+
no tokenization is applied. Default: True.
|
78 |
+
use_vocab: Whether to use a Vocab object. If False, the data in this
|
79 |
+
field should already be numerical. Default: True.
|
80 |
+
init_token: A token that will be prepended to every example using this
|
81 |
+
field, or None for no initial token. Default: None.
|
82 |
+
eos_token: A token that will be appended to every example using this
|
83 |
+
field, or None for no end-of-sentence token. Default: None.
|
84 |
+
fix_length: A fixed length that all examples using this field will be
|
85 |
+
padded to, or None for flexible sequence lengths. Default: None.
|
86 |
+
dtype: The torch.dtype class that represents a batch of examples
|
87 |
+
of this kind of data. Default: torch.long.
|
88 |
+
preprocessing: The Pipeline that will be applied to examples
|
89 |
+
using this field after tokenizing but before numericalizing. Many
|
90 |
+
Datasets replace this attribute with a custom preprocessor.
|
91 |
+
Default: None.
|
92 |
+
postprocessing: A Pipeline that will be applied to examples using
|
93 |
+
this field after numericalizing but before the numbers are turned
|
94 |
+
into a Tensor. The pipeline function takes the batch as a list, and
|
95 |
+
the field's Vocab.
|
96 |
+
Default: None.
|
97 |
+
lower: Whether to lowercase the text in this field. Default: False.
|
98 |
+
tokenize: The function used to tokenize strings using this field into
|
99 |
+
sequential examples. If "spacy", the SpaCy tokenizer is
|
100 |
+
used. If a non-serializable function is passed as an argument,
|
101 |
+
the field will not be able to be serialized. Default: string.split.
|
102 |
+
tokenizer_language: The language of the tokenizer to be constructed.
|
103 |
+
Various languages currently supported only in SpaCy.
|
104 |
+
include_lengths: Whether to return a tuple of a padded minibatch and
|
105 |
+
a list containing the lengths of each examples, or just a padded
|
106 |
+
minibatch. Default: False.
|
107 |
+
batch_first: Whether to produce tensors with the batch dimension first.
|
108 |
+
Default: False.
|
109 |
+
pad_token: The string token used as padding. Default: "<pad>".
|
110 |
+
unk_token: The string token used to represent OOV words. Default: "<unk>".
|
111 |
+
pad_first: Do the padding of the sequence at the beginning. Default: False.
|
112 |
+
truncate_first: Do the truncating of the sequence at the beginning. Default: False
|
113 |
+
stop_words: Tokens to discard during the preprocessing step. Default: None
|
114 |
+
is_target: Whether this field is a target variable.
|
115 |
+
Affects iteration over batches. Default: False
|
116 |
+
"""
|
117 |
+
|
118 |
+
vocab_cls = Vocab
|
119 |
+
# Dictionary mapping PyTorch tensor dtypes to the appropriate Python
|
120 |
+
# numeric type.
|
121 |
+
dtypes = {
|
122 |
+
torch.float32: float,
|
123 |
+
torch.float: float,
|
124 |
+
torch.float64: float,
|
125 |
+
torch.double: float,
|
126 |
+
torch.float16: float,
|
127 |
+
torch.half: float,
|
128 |
+
|
129 |
+
torch.uint8: int,
|
130 |
+
torch.int8: int,
|
131 |
+
torch.int16: int,
|
132 |
+
torch.short: int,
|
133 |
+
torch.int32: int,
|
134 |
+
torch.int: int,
|
135 |
+
torch.int64: int,
|
136 |
+
torch.long: int,
|
137 |
+
}
|
138 |
+
|
139 |
+
ignore = ['dtype', 'tokenize']
|
140 |
+
|
141 |
+
def __init__(self, sequential=True, use_vocab=True, init_token=None,
|
142 |
+
eos_token=None, fix_length=None, dtype=torch.long,
|
143 |
+
preprocessing=None, postprocessing=None, lower=False,
|
144 |
+
tokenize=None, tokenizer_language='en', include_lengths=False,
|
145 |
+
batch_first=False, pad_token="<pad>", unk_token="<unk>",
|
146 |
+
pad_first=False, truncate_first=False, stop_words=None,
|
147 |
+
is_target=False):
|
148 |
+
self.sequential = sequential
|
149 |
+
self.use_vocab = use_vocab
|
150 |
+
self.init_token = init_token
|
151 |
+
self.eos_token = eos_token
|
152 |
+
self.unk_token = unk_token
|
153 |
+
self.fix_length = fix_length
|
154 |
+
self.dtype = dtype
|
155 |
+
self.preprocessing = preprocessing
|
156 |
+
self.postprocessing = postprocessing
|
157 |
+
self.lower = lower
|
158 |
+
# store params to construct tokenizer for serialization
|
159 |
+
# in case the tokenizer isn't picklable (e.g. spacy)
|
160 |
+
self.tokenizer_args = (tokenize, tokenizer_language)
|
161 |
+
self.tokenize = get_tokenizer(tokenize, tokenizer_language)
|
162 |
+
self.include_lengths = include_lengths
|
163 |
+
self.batch_first = batch_first
|
164 |
+
self.pad_token = pad_token if self.sequential else None
|
165 |
+
self.pad_first = pad_first
|
166 |
+
self.truncate_first = truncate_first
|
167 |
+
try:
|
168 |
+
self.stop_words = set(stop_words) if stop_words is not None else None
|
169 |
+
except TypeError:
|
170 |
+
raise ValueError("Stop words must be convertible to a set")
|
171 |
+
self.is_target = is_target
|
172 |
+
|
173 |
+
def __getstate__(self):
|
174 |
+
str_type = dtype_to_attr(self.dtype)
|
175 |
+
if is_tokenizer_serializable(*self.tokenizer_args):
|
176 |
+
tokenize = self.tokenize
|
177 |
+
else:
|
178 |
+
# signal to restore in `__setstate__`
|
179 |
+
tokenize = None
|
180 |
+
attrs = {k: v for k, v in self.__dict__.items() if k not in self.ignore}
|
181 |
+
attrs['dtype'] = str_type
|
182 |
+
attrs['tokenize'] = tokenize
|
183 |
+
|
184 |
+
return attrs
|
185 |
+
|
186 |
+
def __setstate__(self, state):
|
187 |
+
state['dtype'] = getattr(torch, state['dtype'])
|
188 |
+
if not state['tokenize']:
|
189 |
+
state['tokenize'] = get_tokenizer(*state['tokenizer_args'])
|
190 |
+
self.__dict__.update(state)
|
191 |
+
|
192 |
+
def __hash__(self):
|
193 |
+
# we don't expect this to be called often
|
194 |
+
return 42
|
195 |
+
|
196 |
+
def __eq__(self, other):
|
197 |
+
if not isinstance(other, RawField):
|
198 |
+
return False
|
199 |
+
|
200 |
+
return self.__dict__ == other.__dict__
|
201 |
+
|
202 |
+
def preprocess(self, x):
|
203 |
+
"""Load a single example using this field, tokenizing if necessary.
|
204 |
+
|
205 |
+
If the input is a Python 2 `str`, it will be converted to Unicode
|
206 |
+
first. If `sequential=True`, it will be tokenized. Then the input
|
207 |
+
will be optionally lowercased and passed to the user-provided
|
208 |
+
`preprocessing` Pipeline."""
|
209 |
+
if (six.PY2 and isinstance(x, six.string_types)
|
210 |
+
and not isinstance(x, six.text_type)):
|
211 |
+
x = Pipeline(lambda s: six.text_type(s, encoding='utf-8'))(x)
|
212 |
+
if self.sequential and isinstance(x, six.text_type):
|
213 |
+
x = self.tokenize(x.rstrip('\n'))
|
214 |
+
if self.lower:
|
215 |
+
x = Pipeline(six.text_type.lower)(x)
|
216 |
+
if self.sequential and self.use_vocab and self.stop_words is not None:
|
217 |
+
x = [w for w in x if w not in self.stop_words]
|
218 |
+
if hasattr(self, "preprocessing") and self.preprocessing is not None:
|
219 |
+
return self.preprocessing(x)
|
220 |
+
else:
|
221 |
+
return x
|
222 |
+
|
223 |
+
def process(self, batch, device=None):
|
224 |
+
""" Process a list of examples to create a torch.Tensor.
|
225 |
+
|
226 |
+
Pad, numericalize, and postprocess a batch and create a tensor.
|
227 |
+
|
228 |
+
Args:
|
229 |
+
batch (list(object)): A list of object from a batch of examples.
|
230 |
+
Returns:
|
231 |
+
torch.autograd.Variable: Processed object given the input
|
232 |
+
and custom postprocessing Pipeline.
|
233 |
+
"""
|
234 |
+
padded = self.pad(batch)
|
235 |
+
tensor = self.numericalize(padded, device=device)
|
236 |
+
return tensor
|
237 |
+
|
238 |
+
def pad(self, minibatch):
|
239 |
+
"""Pad a batch of examples using this field.
|
240 |
+
|
241 |
+
Pads to self.fix_length if provided, otherwise pads to the length of
|
242 |
+
the longest example in the batch. Prepends self.init_token and appends
|
243 |
+
self.eos_token if those attributes are not None. Returns a tuple of the
|
244 |
+
padded list and a list containing lengths of each example if
|
245 |
+
`self.include_lengths` is `True` and `self.sequential` is `True`, else just
|
246 |
+
returns the padded list. If `self.sequential` is `False`, no padding is applied.
|
247 |
+
"""
|
248 |
+
minibatch = list(minibatch)
|
249 |
+
if not self.sequential:
|
250 |
+
return minibatch
|
251 |
+
if self.fix_length is None:
|
252 |
+
max_len = max(len(x) for x in minibatch)
|
253 |
+
else:
|
254 |
+
max_len = self.fix_length + (
|
255 |
+
self.init_token, self.eos_token).count(None) - 2
|
256 |
+
padded, lengths = [], []
|
257 |
+
for x in minibatch:
|
258 |
+
if self.pad_first:
|
259 |
+
padded.append(
|
260 |
+
[self.pad_token] * max(0, max_len - len(x))
|
261 |
+
+ ([] if self.init_token is None else [self.init_token])
|
262 |
+
+ list(x[-max_len:] if self.truncate_first else x[:max_len])
|
263 |
+
+ ([] if self.eos_token is None else [self.eos_token]))
|
264 |
+
else:
|
265 |
+
padded.append(
|
266 |
+
([] if self.init_token is None else [self.init_token])
|
267 |
+
+ list(x[-max_len:] if self.truncate_first else x[:max_len])
|
268 |
+
+ ([] if self.eos_token is None else [self.eos_token])
|
269 |
+
+ [self.pad_token] * max(0, max_len - len(x)))
|
270 |
+
lengths.append(len(padded[-1]) - max(0, max_len - len(x)))
|
271 |
+
if self.include_lengths:
|
272 |
+
return (padded, lengths)
|
273 |
+
return padded
|
274 |
+
|
275 |
+
def build_vocab(self, *args, **kwargs):
|
276 |
+
"""Construct the Vocab object for this field from one or more datasets.
|
277 |
+
|
278 |
+
Arguments:
|
279 |
+
Positional arguments: Dataset objects or other iterable data
|
280 |
+
sources from which to construct the Vocab object that
|
281 |
+
represents the set of possible values for this field. If
|
282 |
+
a Dataset object is provided, all columns corresponding
|
283 |
+
to this field are used; individual columns can also be
|
284 |
+
provided directly.
|
285 |
+
Remaining keyword arguments: Passed to the constructor of Vocab.
|
286 |
+
"""
|
287 |
+
counter = Counter()
|
288 |
+
sources = []
|
289 |
+
for arg in args:
|
290 |
+
sources.append(arg)
|
291 |
+
for data in sources:
|
292 |
+
for x in data:
|
293 |
+
if not self.sequential:
|
294 |
+
x = [x]
|
295 |
+
try:
|
296 |
+
counter.update(x)
|
297 |
+
except TypeError:
|
298 |
+
counter.update(chain.from_iterable(x))
|
299 |
+
specials = list(OrderedDict.fromkeys(
|
300 |
+
tok for tok in [self.unk_token, self.pad_token, self.init_token,
|
301 |
+
self.eos_token] + kwargs.pop('specials', [])
|
302 |
+
if tok is not None))
|
303 |
+
self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)
|
304 |
+
|
305 |
+
def numericalize(self, arr, device=None):
|
306 |
+
"""Turn a batch of examples that use this field into a Variable.
|
307 |
+
|
308 |
+
If the field has include_lengths=True, a tensor of lengths will be
|
309 |
+
included in the return value.
|
310 |
+
|
311 |
+
Arguments:
|
312 |
+
arr (List[List[str]], or tuple of (List[List[str]], List[int])):
|
313 |
+
List of tokenized and padded examples, or tuple of List of
|
314 |
+
tokenized and padded examples and List of lengths of each
|
315 |
+
example if self.include_lengths is True.
|
316 |
+
device (str or torch.device): A string or instance of `torch.device`
|
317 |
+
specifying which device the Variables are going to be created on.
|
318 |
+
If left as default, the tensors will be created on cpu. Default: None.
|
319 |
+
"""
|
320 |
+
if self.include_lengths and not isinstance(arr, tuple):
|
321 |
+
raise ValueError("Field has include_lengths set to True, but "
|
322 |
+
"input data is not a tuple of "
|
323 |
+
"(data batch, batch lengths).")
|
324 |
+
if isinstance(arr, tuple):
|
325 |
+
arr, lengths = arr
|
326 |
+
lengths = torch.tensor(lengths, dtype=self.dtype, device=device)
|
327 |
+
|
328 |
+
if self.use_vocab:
|
329 |
+
if self.sequential:
|
330 |
+
arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
|
331 |
+
else:
|
332 |
+
arr = [self.vocab.stoi[x] for x in arr]
|
333 |
+
|
334 |
+
if self.postprocessing is not None:
|
335 |
+
arr = self.postprocessing(arr, self.vocab)
|
336 |
+
else:
|
337 |
+
if self.dtype not in self.dtypes:
|
338 |
+
raise ValueError(
|
339 |
+
"Specified Field dtype {} can not be used with "
|
340 |
+
"use_vocab=False because we do not know how to numericalize it. "
|
341 |
+
"Please raise an issue at "
|
342 |
+
"https://github.com/pytorch/text/issues".format(self.dtype))
|
343 |
+
numericalization_func = self.dtypes[self.dtype]
|
344 |
+
# It doesn't make sense to explicitly coerce to a numeric type if
|
345 |
+
# the data is sequential, since it's unclear how to coerce padding tokens
|
346 |
+
# to a numeric type.
|
347 |
+
if not self.sequential:
|
348 |
+
arr = [numericalization_func(x) if isinstance(x, six.string_types)
|
349 |
+
else x for x in arr]
|
350 |
+
if self.postprocessing is not None:
|
351 |
+
arr = self.postprocessing(arr, None)
|
352 |
+
|
353 |
+
var = torch.tensor(arr, dtype=self.dtype, device=device)
|
354 |
+
|
355 |
+
if self.sequential and not self.batch_first:
|
356 |
+
var.t_()
|
357 |
+
if self.sequential:
|
358 |
+
var = var.contiguous()
|
359 |
+
|
360 |
+
if self.include_lengths:
|
361 |
+
return var, lengths
|
362 |
+
return var
|
363 |
+
|
364 |
+
|
365 |
+
class NestedField(Field):
|
366 |
+
"""A nested field.
|
367 |
+
|
368 |
+
A nested field holds another field (called *nesting field*), accepts an untokenized
|
369 |
+
string or a list string tokens and groups and treats them as one field as described
|
370 |
+
by the nesting field. Every token will be preprocessed, padded, etc. in the manner
|
371 |
+
specified by the nesting field. Note that this means a nested field always has
|
372 |
+
``sequential=True``. The two fields' vocabularies will be shared. Their
|
373 |
+
numericalization results will be stacked into a single tensor. And NestedField will
|
374 |
+
share the same include_lengths with nesting_field, so one shouldn't specify the
|
375 |
+
include_lengths in the nesting_field. This field is
|
376 |
+
primarily used to implement character embeddings. See ``tests/data/test_field.py``
|
377 |
+
for examples on how to use this field.
|
378 |
+
|
379 |
+
Arguments:
|
380 |
+
nesting_field (Field): A field contained in this nested field.
|
381 |
+
use_vocab (bool): Whether to use a Vocab object. If False, the data in this
|
382 |
+
field should already be numerical. Default: ``True``.
|
383 |
+
init_token (str): A token that will be prepended to every example using this
|
384 |
+
field, or None for no initial token. Default: ``None``.
|
385 |
+
eos_token (str): A token that will be appended to every example using this
|
386 |
+
field, or None for no end-of-sentence token. Default: ``None``.
|
387 |
+
fix_length (int): A fixed length that all examples using this field will be
|
388 |
+
padded to, or ``None`` for flexible sequence lengths. Default: ``None``.
|
389 |
+
dtype: The torch.dtype class that represents a batch of examples
|
390 |
+
of this kind of data. Default: ``torch.long``.
|
391 |
+
preprocessing (Pipeline): The Pipeline that will be applied to examples
|
392 |
+
using this field after tokenizing but before numericalizing. Many
|
393 |
+
Datasets replace this attribute with a custom preprocessor.
|
394 |
+
Default: ``None``.
|
395 |
+
postprocessing (Pipeline): A Pipeline that will be applied to examples using
|
396 |
+
this field after numericalizing but before the numbers are turned
|
397 |
+
into a Tensor. The pipeline function takes the batch as a list, and
|
398 |
+
the field's Vocab. Default: ``None``.
|
399 |
+
include_lengths: Whether to return a tuple of a padded minibatch and
|
400 |
+
a list containing the lengths of each examples, or just a padded
|
401 |
+
minibatch. Default: False.
|
402 |
+
tokenize: The function used to tokenize strings using this field into
|
403 |
+
sequential examples. If "spacy", the SpaCy tokenizer is
|
404 |
+
used. If a non-serializable function is passed as an argument,
|
405 |
+
the field will not be able to be serialized. Default: string.split.
|
406 |
+
tokenizer_language: The language of the tokenizer to be constructed.
|
407 |
+
Various languages currently supported only in SpaCy.
|
408 |
+
pad_token (str): The string token used as padding. If ``nesting_field`` is
|
409 |
+
sequential, this will be set to its ``pad_token``. Default: ``"<pad>"``.
|
410 |
+
pad_first (bool): Do the padding of the sequence at the beginning. Default:
|
411 |
+
``False``.
|
412 |
+
"""
|
413 |
+
|
414 |
+
def __init__(self, nesting_field, use_vocab=True, init_token=None, eos_token=None,
|
415 |
+
fix_length=None, dtype=torch.long, preprocessing=None,
|
416 |
+
postprocessing=None, tokenize=None, tokenizer_language='en',
|
417 |
+
include_lengths=False, pad_token='<pad>',
|
418 |
+
pad_first=False, truncate_first=False):
|
419 |
+
if isinstance(nesting_field, NestedField):
|
420 |
+
raise ValueError('nesting field must not be another NestedField')
|
421 |
+
if nesting_field.include_lengths:
|
422 |
+
raise ValueError('nesting field cannot have include_lengths=True')
|
423 |
+
|
424 |
+
if nesting_field.sequential:
|
425 |
+
pad_token = nesting_field.pad_token
|
426 |
+
super(NestedField, self).__init__(
|
427 |
+
use_vocab=use_vocab,
|
428 |
+
init_token=init_token,
|
429 |
+
eos_token=eos_token,
|
430 |
+
fix_length=fix_length,
|
431 |
+
dtype=dtype,
|
432 |
+
preprocessing=preprocessing,
|
433 |
+
postprocessing=postprocessing,
|
434 |
+
lower=nesting_field.lower,
|
435 |
+
tokenize=tokenize,
|
436 |
+
tokenizer_language=tokenizer_language,
|
437 |
+
batch_first=True,
|
438 |
+
pad_token=pad_token,
|
439 |
+
unk_token=nesting_field.unk_token,
|
440 |
+
pad_first=pad_first,
|
441 |
+
truncate_first=truncate_first,
|
442 |
+
include_lengths=include_lengths
|
443 |
+
)
|
444 |
+
self.nesting_field = nesting_field
|
445 |
+
# in case the user forget to do that
|
446 |
+
self.nesting_field.batch_first = True
|
447 |
+
|
448 |
+
def preprocess(self, xs):
|
449 |
+
"""Preprocess a single example.
|
450 |
+
|
451 |
+
Firstly, tokenization and the supplied preprocessing pipeline is applied. Since
|
452 |
+
this field is always sequential, the result is a list. Then, each element of
|
453 |
+
the list is preprocessed using ``self.nesting_field.preprocess`` and the resulting
|
454 |
+
list is returned.
|
455 |
+
|
456 |
+
Arguments:
|
457 |
+
xs (list or str): The input to preprocess.
|
458 |
+
|
459 |
+
Returns:
|
460 |
+
list: The preprocessed list.
|
461 |
+
"""
|
462 |
+
return [self.nesting_field.preprocess(x)
|
463 |
+
for x in super(NestedField, self).preprocess(xs)]
|
464 |
+
|
465 |
+
def pad(self, minibatch):
|
466 |
+
"""Pad a batch of examples using this field.
|
467 |
+
|
468 |
+
If ``self.nesting_field.sequential`` is ``False``, each example in the batch must
|
469 |
+
be a list of string tokens, and pads them as if by a ``Field`` with
|
470 |
+
``sequential=True``. Otherwise, each example must be a list of list of tokens.
|
471 |
+
Using ``self.nesting_field``, pads the list of tokens to
|
472 |
+
``self.nesting_field.fix_length`` if provided, or otherwise to the length of the
|
473 |
+
longest list of tokens in the batch. Next, using this field, pads the result by
|
474 |
+
filling short examples with ``self.nesting_field.pad_token``.
|
475 |
+
|
476 |
+
Example:
|
477 |
+
>>> import pprint
|
478 |
+
>>> pp = pprint.PrettyPrinter(indent=4)
|
479 |
+
>>>
|
480 |
+
>>> nesting_field = Field(pad_token='<c>', init_token='<w>', eos_token='</w>')
|
481 |
+
>>> field = NestedField(nesting_field, init_token='<s>', eos_token='</s>')
|
482 |
+
>>> minibatch = [
|
483 |
+
... [list('john'), list('loves'), list('mary')],
|
484 |
+
... [list('mary'), list('cries')],
|
485 |
+
... ]
|
486 |
+
>>> padded = field.pad(minibatch)
|
487 |
+
>>> pp.pprint(padded)
|
488 |
+
[ [ ['<w>', '<s>', '</w>', '<c>', '<c>', '<c>', '<c>'],
|
489 |
+
['<w>', 'j', 'o', 'h', 'n', '</w>', '<c>'],
|
490 |
+
['<w>', 'l', 'o', 'v', 'e', 's', '</w>'],
|
491 |
+
['<w>', 'm', 'a', 'r', 'y', '</w>', '<c>'],
|
492 |
+
['<w>', '</s>', '</w>', '<c>', '<c>', '<c>', '<c>']],
|
493 |
+
[ ['<w>', '<s>', '</w>', '<c>', '<c>', '<c>', '<c>'],
|
494 |
+
['<w>', 'm', 'a', 'r', 'y', '</w>', '<c>'],
|
495 |
+
['<w>', 'c', 'r', 'i', 'e', 's', '</w>'],
|
496 |
+
['<w>', '</s>', '</w>', '<c>', '<c>', '<c>', '<c>'],
|
497 |
+
['<c>', '<c>', '<c>', '<c>', '<c>', '<c>', '<c>']]]
|
498 |
+
|
499 |
+
Arguments:
|
500 |
+
minibatch (list): Each element is a list of string if
|
501 |
+
``self.nesting_field.sequential`` is ``False``, a list of list of string
|
502 |
+
otherwise.
|
503 |
+
|
504 |
+
Returns:
|
505 |
+
list: The padded minibatch. or (padded, sentence_lens, word_lengths)
|
506 |
+
"""
|
507 |
+
minibatch = list(minibatch)
|
508 |
+
if not self.nesting_field.sequential:
|
509 |
+
return super(NestedField, self).pad(minibatch)
|
510 |
+
|
511 |
+
# Save values of attributes to be monkeypatched
|
512 |
+
old_pad_token = self.pad_token
|
513 |
+
old_init_token = self.init_token
|
514 |
+
old_eos_token = self.eos_token
|
515 |
+
old_fix_len = self.nesting_field.fix_length
|
516 |
+
# Monkeypatch the attributes
|
517 |
+
if self.nesting_field.fix_length is None:
|
518 |
+
max_len = max(len(xs) for ex in minibatch for xs in ex)
|
519 |
+
fix_len = max_len + 2 - (self.nesting_field.init_token,
|
520 |
+
self.nesting_field.eos_token).count(None)
|
521 |
+
self.nesting_field.fix_length = fix_len
|
522 |
+
self.pad_token = [self.pad_token] * self.nesting_field.fix_length
|
523 |
+
if self.init_token is not None:
|
524 |
+
# self.init_token = self.nesting_field.pad([[self.init_token]])[0]
|
525 |
+
self.init_token = [self.init_token]
|
526 |
+
if self.eos_token is not None:
|
527 |
+
# self.eos_token = self.nesting_field.pad([[self.eos_token]])[0]
|
528 |
+
self.eos_token = [self.eos_token]
|
529 |
+
# Do padding
|
530 |
+
old_include_lengths = self.include_lengths
|
531 |
+
self.include_lengths = True
|
532 |
+
self.nesting_field.include_lengths = True
|
533 |
+
padded, sentence_lengths = super(NestedField, self).pad(minibatch)
|
534 |
+
padded_with_lengths = [self.nesting_field.pad(ex) for ex in padded]
|
535 |
+
word_lengths = []
|
536 |
+
final_padded = []
|
537 |
+
max_sen_len = len(padded[0])
|
538 |
+
for (pad, lens), sentence_len in zip(padded_with_lengths, sentence_lengths):
|
539 |
+
if sentence_len == max_sen_len:
|
540 |
+
lens = lens
|
541 |
+
pad = pad
|
542 |
+
elif self.pad_first:
|
543 |
+
lens[:(max_sen_len - sentence_len)] = (
|
544 |
+
[0] * (max_sen_len - sentence_len))
|
545 |
+
pad[:(max_sen_len - sentence_len)] = (
|
546 |
+
[self.pad_token] * (max_sen_len - sentence_len))
|
547 |
+
else:
|
548 |
+
lens[-(max_sen_len - sentence_len):] = (
|
549 |
+
[0] * (max_sen_len - sentence_len))
|
550 |
+
pad[-(max_sen_len - sentence_len):] = (
|
551 |
+
[self.pad_token] * (max_sen_len - sentence_len))
|
552 |
+
word_lengths.append(lens)
|
553 |
+
final_padded.append(pad)
|
554 |
+
padded = final_padded
|
555 |
+
|
556 |
+
# Restore monkeypatched attributes
|
557 |
+
self.nesting_field.fix_length = old_fix_len
|
558 |
+
self.pad_token = old_pad_token
|
559 |
+
self.init_token = old_init_token
|
560 |
+
self.eos_token = old_eos_token
|
561 |
+
self.include_lengths = old_include_lengths
|
562 |
+
if self.include_lengths:
|
563 |
+
return padded, sentence_lengths, word_lengths
|
564 |
+
return padded
|
565 |
+
|
566 |
+
def build_vocab(self, *args, **kwargs):
|
567 |
+
"""Construct the Vocab object for nesting field and combine it with this field's vocab.
|
568 |
+
|
569 |
+
Arguments:
|
570 |
+
Positional arguments: Dataset objects or other iterable data
|
571 |
+
sources from which to construct the Vocab object that
|
572 |
+
represents the set of possible values for the nesting field. If
|
573 |
+
a Dataset object is provided, all columns corresponding
|
574 |
+
to this field are used; individual columns can also be
|
575 |
+
provided directly.
|
576 |
+
Remaining keyword arguments: Passed to the constructor of Vocab.
|
577 |
+
"""
|
578 |
+
sources = []
|
579 |
+
for arg in args:
|
580 |
+
sources.append(arg)
|
581 |
+
|
582 |
+
flattened = []
|
583 |
+
for source in sources:
|
584 |
+
flattened.extend(source)
|
585 |
+
old_vectors = None
|
586 |
+
old_unk_init = None
|
587 |
+
old_vectors_cache = None
|
588 |
+
if "vectors" in kwargs.keys():
|
589 |
+
old_vectors = kwargs["vectors"]
|
590 |
+
kwargs["vectors"] = None
|
591 |
+
if "unk_init" in kwargs.keys():
|
592 |
+
old_unk_init = kwargs["unk_init"]
|
593 |
+
kwargs["unk_init"] = None
|
594 |
+
if "vectors_cache" in kwargs.keys():
|
595 |
+
old_vectors_cache = kwargs["vectors_cache"]
|
596 |
+
kwargs["vectors_cache"] = None
|
597 |
+
# just build vocab and does not load vector
|
598 |
+
self.nesting_field.build_vocab(*flattened, **kwargs)
|
599 |
+
super(NestedField, self).build_vocab()
|
600 |
+
self.vocab.extend(self.nesting_field.vocab)
|
601 |
+
self.vocab.freqs = self.nesting_field.vocab.freqs.copy()
|
602 |
+
if old_vectors is not None:
|
603 |
+
self.vocab.load_vectors(old_vectors,
|
604 |
+
unk_init=old_unk_init, cache=old_vectors_cache)
|
605 |
+
|
606 |
+
self.nesting_field.vocab = self.vocab
|
607 |
+
|
608 |
+
def numericalize(self, arrs, device=None):
|
609 |
+
"""Convert a padded minibatch into a variable tensor.
|
610 |
+
|
611 |
+
Each item in the minibatch will be numericalized independently and the resulting
|
612 |
+
tensors will be stacked at the first dimension.
|
613 |
+
|
614 |
+
Arguments:
|
615 |
+
arr (List[List[str]]): List of tokenized and padded examples.
|
616 |
+
device (str or torch.device): A string or instance of `torch.device`
|
617 |
+
specifying which device the Variables are going to be created on.
|
618 |
+
If left as default, the tensors will be created on cpu. Default: None.
|
619 |
+
"""
|
620 |
+
numericalized = []
|
621 |
+
self.nesting_field.include_lengths = False
|
622 |
+
if self.include_lengths:
|
623 |
+
arrs, sentence_lengths, word_lengths = arrs
|
624 |
+
|
625 |
+
for arr in arrs:
|
626 |
+
numericalized_ex = self.nesting_field.numericalize(
|
627 |
+
arr, device=device)
|
628 |
+
numericalized.append(numericalized_ex)
|
629 |
+
padded_batch = torch.stack(numericalized)
|
630 |
+
|
631 |
+
self.nesting_field.include_lengths = True
|
632 |
+
if self.include_lengths:
|
633 |
+
sentence_lengths = \
|
634 |
+
torch.tensor(sentence_lengths, dtype=self.dtype, device=device)
|
635 |
+
word_lengths = torch.tensor(word_lengths, dtype=self.dtype, device=device)
|
636 |
+
return (padded_batch, sentence_lengths, word_lengths)
|
637 |
+
return padded_batch
|
data/field/mini_torchtext/pipeline.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class Pipeline(object):
|
2 |
+
"""Defines a pipeline for transforming sequence data.
|
3 |
+
|
4 |
+
The input is assumed to be utf-8 encoded `str` (Python 3) or
|
5 |
+
`unicode` (Python 2).
|
6 |
+
|
7 |
+
Attributes:
|
8 |
+
convert_token: The function to apply to input sequence data.
|
9 |
+
pipes: The Pipelines that will be applied to input sequence
|
10 |
+
data in order.
|
11 |
+
"""
|
12 |
+
|
13 |
+
def __init__(self, convert_token=None):
|
14 |
+
"""Create a pipeline.
|
15 |
+
|
16 |
+
Arguments:
|
17 |
+
convert_token: The function to apply to input sequence data.
|
18 |
+
If None, the identity function is used. Default: None
|
19 |
+
"""
|
20 |
+
if convert_token is None:
|
21 |
+
self.convert_token = Pipeline.identity
|
22 |
+
elif callable(convert_token):
|
23 |
+
self.convert_token = convert_token
|
24 |
+
else:
|
25 |
+
raise ValueError("Pipeline input convert_token {} is not None "
|
26 |
+
"or callable".format(convert_token))
|
27 |
+
self.pipes = [self]
|
28 |
+
|
29 |
+
def __call__(self, x, *args):
|
30 |
+
"""Apply the the current Pipeline(s) to an input.
|
31 |
+
|
32 |
+
Arguments:
|
33 |
+
x: The input to process with the Pipeline(s).
|
34 |
+
Positional arguments: Forwarded to the `call` function
|
35 |
+
of the Pipeline(s).
|
36 |
+
"""
|
37 |
+
for pipe in self.pipes:
|
38 |
+
x = pipe.call(x, *args)
|
39 |
+
return x
|
40 |
+
|
41 |
+
def call(self, x, *args):
|
42 |
+
"""Apply _only_ the convert_token function of the current pipeline
|
43 |
+
to the input. If the input is a list, a list with the results of
|
44 |
+
applying the `convert_token` function to all input elements is
|
45 |
+
returned.
|
46 |
+
|
47 |
+
Arguments:
|
48 |
+
x: The input to apply the convert_token function to.
|
49 |
+
Positional arguments: Forwarded to the `convert_token` function
|
50 |
+
of the current Pipeline.
|
51 |
+
"""
|
52 |
+
if isinstance(x, list):
|
53 |
+
return [self.convert_token(tok, *args) for tok in x]
|
54 |
+
return self.convert_token(x, *args)
|
55 |
+
|
56 |
+
def add_before(self, pipeline):
|
57 |
+
"""Add a Pipeline to be applied before this processing pipeline.
|
58 |
+
|
59 |
+
Arguments:
|
60 |
+
pipeline: The Pipeline or callable to apply before this
|
61 |
+
Pipeline.
|
62 |
+
"""
|
63 |
+
if not isinstance(pipeline, Pipeline):
|
64 |
+
pipeline = Pipeline(pipeline)
|
65 |
+
self.pipes = pipeline.pipes[:] + self.pipes[:]
|
66 |
+
return self
|
67 |
+
|
68 |
+
def add_after(self, pipeline):
|
69 |
+
"""Add a Pipeline to be applied after this processing pipeline.
|
70 |
+
|
71 |
+
Arguments:
|
72 |
+
pipeline: The Pipeline or callable to apply after this
|
73 |
+
Pipeline.
|
74 |
+
"""
|
75 |
+
if not isinstance(pipeline, Pipeline):
|
76 |
+
pipeline = Pipeline(pipeline)
|
77 |
+
self.pipes = self.pipes[:] + pipeline.pipes[:]
|
78 |
+
return self
|
79 |
+
|
80 |
+
@staticmethod
|
81 |
+
def identity(x):
|
82 |
+
"""Return a copy of the input.
|
83 |
+
|
84 |
+
This is here for serialization compatibility with pickle.
|
85 |
+
"""
|
86 |
+
return x
|
data/field/mini_torchtext/utils.py
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from contextlib import contextmanager
|
3 |
+
from copy import deepcopy
|
4 |
+
import re
|
5 |
+
|
6 |
+
from functools import partial
|
7 |
+
|
8 |
+
|
9 |
+
def _split_tokenizer(x):
|
10 |
+
return x.split()
|
11 |
+
|
12 |
+
|
13 |
+
def _spacy_tokenize(x, spacy):
|
14 |
+
return [tok.text for tok in spacy.tokenizer(x)]
|
15 |
+
|
16 |
+
|
17 |
+
_patterns = [r'\'',
|
18 |
+
r'\"',
|
19 |
+
r'\.',
|
20 |
+
r'<br \/>',
|
21 |
+
r',',
|
22 |
+
r'\(',
|
23 |
+
r'\)',
|
24 |
+
r'\!',
|
25 |
+
r'\?',
|
26 |
+
r'\;',
|
27 |
+
r'\:',
|
28 |
+
r'\s+']
|
29 |
+
|
30 |
+
_replacements = [' \' ',
|
31 |
+
'',
|
32 |
+
' . ',
|
33 |
+
' ',
|
34 |
+
' , ',
|
35 |
+
' ( ',
|
36 |
+
' ) ',
|
37 |
+
' ! ',
|
38 |
+
' ? ',
|
39 |
+
' ',
|
40 |
+
' ',
|
41 |
+
' ']
|
42 |
+
|
43 |
+
_patterns_dict = list((re.compile(p), r) for p, r in zip(_patterns, _replacements))
|
44 |
+
|
45 |
+
|
46 |
+
def _basic_english_normalize(line):
|
47 |
+
r"""
|
48 |
+
Basic normalization for a line of text.
|
49 |
+
Normalization includes
|
50 |
+
- lowercasing
|
51 |
+
- complete some basic text normalization for English words as follows:
|
52 |
+
add spaces before and after '\''
|
53 |
+
remove '\"',
|
54 |
+
add spaces before and after '.'
|
55 |
+
replace '<br \/>'with single space
|
56 |
+
add spaces before and after ','
|
57 |
+
add spaces before and after '('
|
58 |
+
add spaces before and after ')'
|
59 |
+
add spaces before and after '!'
|
60 |
+
add spaces before and after '?'
|
61 |
+
replace ';' with single space
|
62 |
+
replace ':' with single space
|
63 |
+
replace multiple spaces with single space
|
64 |
+
|
65 |
+
Returns a list of tokens after splitting on whitespace.
|
66 |
+
"""
|
67 |
+
|
68 |
+
line = line.lower()
|
69 |
+
for pattern_re, replaced_str in _patterns_dict:
|
70 |
+
line = pattern_re.sub(replaced_str, line)
|
71 |
+
return line.split()
|
72 |
+
|
73 |
+
|
74 |
+
def get_tokenizer(tokenizer, language='en'):
|
75 |
+
r"""
|
76 |
+
Generate tokenizer function for a string sentence.
|
77 |
+
|
78 |
+
Arguments:
|
79 |
+
tokenizer: the name of tokenizer function. If None, it returns split()
|
80 |
+
function, which splits the string sentence by space.
|
81 |
+
If basic_english, it returns _basic_english_normalize() function,
|
82 |
+
which normalize the string first and split by space. If a callable
|
83 |
+
function, it will return the function. If a tokenizer library
|
84 |
+
(e.g. spacy, moses, toktok, revtok, subword), it returns the
|
85 |
+
corresponding library.
|
86 |
+
language: Default en
|
87 |
+
|
88 |
+
Examples:
|
89 |
+
>>> import torchtext
|
90 |
+
>>> from torchtext.data import get_tokenizer
|
91 |
+
>>> tokenizer = get_tokenizer("basic_english")
|
92 |
+
>>> tokens = tokenizer("You can now install TorchText using pip!")
|
93 |
+
>>> tokens
|
94 |
+
>>> ['you', 'can', 'now', 'install', 'torchtext', 'using', 'pip', '!']
|
95 |
+
|
96 |
+
"""
|
97 |
+
|
98 |
+
# default tokenizer is string.split(), added as a module function for serialization
|
99 |
+
if tokenizer is None:
|
100 |
+
return _split_tokenizer
|
101 |
+
|
102 |
+
if tokenizer == "basic_english":
|
103 |
+
if language != 'en':
|
104 |
+
raise ValueError("Basic normalization is only available for Enlish(en)")
|
105 |
+
return _basic_english_normalize
|
106 |
+
|
107 |
+
# simply return if a function is passed
|
108 |
+
if callable(tokenizer):
|
109 |
+
return tokenizer
|
110 |
+
|
111 |
+
if tokenizer == "spacy":
|
112 |
+
try:
|
113 |
+
import spacy
|
114 |
+
spacy = spacy.load(language)
|
115 |
+
return partial(_spacy_tokenize, spacy=spacy)
|
116 |
+
except ImportError:
|
117 |
+
print("Please install SpaCy. "
|
118 |
+
"See the docs at https://spacy.io for more information.")
|
119 |
+
raise
|
120 |
+
except AttributeError:
|
121 |
+
print("Please install SpaCy and the SpaCy {} tokenizer. "
|
122 |
+
"See the docs at https://spacy.io for more "
|
123 |
+
"information.".format(language))
|
124 |
+
raise
|
125 |
+
elif tokenizer == "moses":
|
126 |
+
try:
|
127 |
+
from sacremoses import MosesTokenizer
|
128 |
+
moses_tokenizer = MosesTokenizer()
|
129 |
+
return moses_tokenizer.tokenize
|
130 |
+
except ImportError:
|
131 |
+
print("Please install SacreMoses. "
|
132 |
+
"See the docs at https://github.com/alvations/sacremoses "
|
133 |
+
"for more information.")
|
134 |
+
raise
|
135 |
+
elif tokenizer == "toktok":
|
136 |
+
try:
|
137 |
+
from nltk.tokenize.toktok import ToktokTokenizer
|
138 |
+
toktok = ToktokTokenizer()
|
139 |
+
return toktok.tokenize
|
140 |
+
except ImportError:
|
141 |
+
print("Please install NLTK. "
|
142 |
+
"See the docs at https://nltk.org for more information.")
|
143 |
+
raise
|
144 |
+
elif tokenizer == 'revtok':
|
145 |
+
try:
|
146 |
+
import revtok
|
147 |
+
return revtok.tokenize
|
148 |
+
except ImportError:
|
149 |
+
print("Please install revtok.")
|
150 |
+
raise
|
151 |
+
elif tokenizer == 'subword':
|
152 |
+
try:
|
153 |
+
import revtok
|
154 |
+
return partial(revtok.tokenize, decap=True)
|
155 |
+
except ImportError:
|
156 |
+
print("Please install revtok.")
|
157 |
+
raise
|
158 |
+
raise ValueError("Requested tokenizer {}, valid choices are a "
|
159 |
+
"callable that takes a single string as input, "
|
160 |
+
"\"revtok\" for the revtok reversible tokenizer, "
|
161 |
+
"\"subword\" for the revtok caps-aware tokenizer, "
|
162 |
+
"\"spacy\" for the SpaCy English tokenizer, or "
|
163 |
+
"\"moses\" for the NLTK port of the Moses tokenization "
|
164 |
+
"script.".format(tokenizer))
|
165 |
+
|
166 |
+
|
167 |
+
def is_tokenizer_serializable(tokenizer, language):
|
168 |
+
"""Extend with other tokenizers which are found to not be serializable
|
169 |
+
"""
|
170 |
+
if tokenizer == 'spacy':
|
171 |
+
return False
|
172 |
+
return True
|
173 |
+
|
174 |
+
|
175 |
+
def interleave_keys(a, b):
|
176 |
+
"""Interleave bits from two sort keys to form a joint sort key.
|
177 |
+
|
178 |
+
Examples that are similar in both of the provided keys will have similar
|
179 |
+
values for the key defined by this function. Useful for tasks with two
|
180 |
+
text fields like machine translation or natural language inference.
|
181 |
+
"""
|
182 |
+
def interleave(args):
|
183 |
+
return ''.join([x for t in zip(*args) for x in t])
|
184 |
+
return int(''.join(interleave(format(x, '016b') for x in (a, b))), base=2)
|
185 |
+
|
186 |
+
|
187 |
+
def get_torch_version():
|
188 |
+
import torch
|
189 |
+
v = torch.__version__
|
190 |
+
version_substrings = v.split('.')
|
191 |
+
major, minor = version_substrings[0], version_substrings[1]
|
192 |
+
return int(major), int(minor)
|
193 |
+
|
194 |
+
|
195 |
+
def dtype_to_attr(dtype):
|
196 |
+
# convert torch.dtype to dtype string id
|
197 |
+
# e.g. torch.int32 -> "int32"
|
198 |
+
# used for serialization
|
199 |
+
_, dtype = str(dtype).split('.')
|
200 |
+
return dtype
|
201 |
+
|
202 |
+
|
203 |
+
# TODO: Write more tests!
|
204 |
+
def ngrams_iterator(token_list, ngrams):
|
205 |
+
"""Return an iterator that yields the given tokens and their ngrams.
|
206 |
+
|
207 |
+
Arguments:
|
208 |
+
token_list: A list of tokens
|
209 |
+
ngrams: the number of ngrams.
|
210 |
+
|
211 |
+
Examples:
|
212 |
+
>>> token_list = ['here', 'we', 'are']
|
213 |
+
>>> list(ngrams_iterator(token_list, 2))
|
214 |
+
>>> ['here', 'here we', 'we', 'we are', 'are']
|
215 |
+
"""
|
216 |
+
|
217 |
+
def _get_ngrams(n):
|
218 |
+
return zip(*[token_list[i:] for i in range(n)])
|
219 |
+
|
220 |
+
for x in token_list:
|
221 |
+
yield x
|
222 |
+
for n in range(2, ngrams + 1):
|
223 |
+
for x in _get_ngrams(n):
|
224 |
+
yield ' '.join(x)
|
225 |
+
|
226 |
+
|
227 |
+
class RandomShuffler(object):
|
228 |
+
"""Use random functions while keeping track of the random state to make it
|
229 |
+
reproducible and deterministic."""
|
230 |
+
|
231 |
+
def __init__(self, random_state=None):
|
232 |
+
self._random_state = random_state
|
233 |
+
if self._random_state is None:
|
234 |
+
self._random_state = random.getstate()
|
235 |
+
|
236 |
+
@contextmanager
|
237 |
+
def use_internal_state(self):
|
238 |
+
"""Use a specific RNG state."""
|
239 |
+
old_state = random.getstate()
|
240 |
+
random.setstate(self._random_state)
|
241 |
+
yield
|
242 |
+
self._random_state = random.getstate()
|
243 |
+
random.setstate(old_state)
|
244 |
+
|
245 |
+
@property
|
246 |
+
def random_state(self):
|
247 |
+
return deepcopy(self._random_state)
|
248 |
+
|
249 |
+
@random_state.setter
|
250 |
+
def random_state(self, s):
|
251 |
+
self._random_state = s
|
252 |
+
|
253 |
+
def __call__(self, data):
|
254 |
+
"""Shuffle and return a new list."""
|
255 |
+
with self.use_internal_state():
|
256 |
+
return random.sample(data, len(data))
|
data/field/mini_torchtext/vocab.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import unicode_literals
|
2 |
+
from collections import defaultdict
|
3 |
+
import logging
|
4 |
+
|
5 |
+
logger = logging.getLogger(__name__)
|
6 |
+
|
7 |
+
|
8 |
+
class Vocab(object):
|
9 |
+
"""Defines a vocabulary object that will be used to numericalize a field.
|
10 |
+
|
11 |
+
Attributes:
|
12 |
+
freqs: A collections.Counter object holding the frequencies of tokens
|
13 |
+
in the data used to build the Vocab.
|
14 |
+
stoi: A collections.defaultdict instance mapping token strings to
|
15 |
+
numerical identifiers.
|
16 |
+
itos: A list of token strings indexed by their numerical identifiers.
|
17 |
+
"""
|
18 |
+
|
19 |
+
# TODO (@mttk): Populate classs with default values of special symbols
|
20 |
+
UNK = '<unk>'
|
21 |
+
|
22 |
+
def __init__(self, counter, max_size=None, min_freq=1, specials=['<unk>', '<pad>'], specials_first=True):
|
23 |
+
"""Create a Vocab object from a collections.Counter.
|
24 |
+
|
25 |
+
Arguments:
|
26 |
+
counter: collections.Counter object holding the frequencies of
|
27 |
+
each value found in the data.
|
28 |
+
max_size: The maximum size of the vocabulary, or None for no
|
29 |
+
maximum. Default: None.
|
30 |
+
min_freq: The minimum frequency needed to include a token in the
|
31 |
+
vocabulary. Values less than 1 will be set to 1. Default: 1.
|
32 |
+
specials: The list of special tokens (e.g., padding or eos) that
|
33 |
+
will be prepended to the vocabulary. Default: ['<unk'>, '<pad>']
|
34 |
+
specials_first: Whether to add special tokens into the vocabulary at first.
|
35 |
+
If it is False, they are added into the vocabulary at last.
|
36 |
+
Default: True.
|
37 |
+
"""
|
38 |
+
self.freqs = counter
|
39 |
+
counter = counter.copy()
|
40 |
+
min_freq = max(min_freq, 1)
|
41 |
+
|
42 |
+
self.itos = list()
|
43 |
+
self.unk_index = None
|
44 |
+
if specials_first:
|
45 |
+
self.itos = list(specials)
|
46 |
+
# only extend max size if specials are prepended
|
47 |
+
max_size = None if max_size is None else max_size + len(specials)
|
48 |
+
|
49 |
+
# frequencies of special tokens are not counted when building vocabulary
|
50 |
+
# in frequency order
|
51 |
+
for tok in specials:
|
52 |
+
del counter[tok]
|
53 |
+
|
54 |
+
# sort by frequency, then alphabetically
|
55 |
+
words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
|
56 |
+
words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
|
57 |
+
|
58 |
+
for word, freq in words_and_frequencies:
|
59 |
+
if freq < min_freq or len(self.itos) == max_size:
|
60 |
+
break
|
61 |
+
self.itos.append(word)
|
62 |
+
|
63 |
+
if Vocab.UNK in specials: # hard-coded for now
|
64 |
+
unk_index = specials.index(Vocab.UNK) # position in list
|
65 |
+
# account for ordering of specials, set variable
|
66 |
+
self.unk_index = unk_index if specials_first else len(self.itos) + unk_index
|
67 |
+
self.stoi = defaultdict(self._default_unk_index)
|
68 |
+
else:
|
69 |
+
self.stoi = defaultdict()
|
70 |
+
|
71 |
+
if not specials_first:
|
72 |
+
self.itos.extend(list(specials))
|
73 |
+
|
74 |
+
# stoi is simply a reverse dict for itos
|
75 |
+
self.stoi.update({tok: i for i, tok in enumerate(self.itos)})
|
76 |
+
|
77 |
+
def _default_unk_index(self):
|
78 |
+
return self.unk_index
|
79 |
+
|
80 |
+
def __getitem__(self, token):
|
81 |
+
return self.stoi.get(token, self.stoi.get(Vocab.UNK))
|
82 |
+
|
83 |
+
def __getstate__(self):
|
84 |
+
# avoid picking defaultdict
|
85 |
+
attrs = dict(self.__dict__)
|
86 |
+
# cast to regular dict
|
87 |
+
attrs['stoi'] = dict(self.stoi)
|
88 |
+
return attrs
|
89 |
+
|
90 |
+
def __setstate__(self, state):
|
91 |
+
if state.get("unk_index", None) is None:
|
92 |
+
stoi = defaultdict()
|
93 |
+
else:
|
94 |
+
stoi = defaultdict(self._default_unk_index)
|
95 |
+
stoi.update(state['stoi'])
|
96 |
+
state['stoi'] = stoi
|
97 |
+
self.__dict__.update(state)
|
98 |
+
|
99 |
+
def __eq__(self, other):
|
100 |
+
if self.freqs != other.freqs:
|
101 |
+
return False
|
102 |
+
if self.stoi != other.stoi:
|
103 |
+
return False
|
104 |
+
if self.itos != other.itos:
|
105 |
+
return False
|
106 |
+
return True
|
107 |
+
|
108 |
+
def __len__(self):
|
109 |
+
return len(self.itos)
|
110 |
+
|
111 |
+
def extend(self, v, sort=False):
|
112 |
+
words = sorted(v.itos) if sort else v.itos
|
113 |
+
for w in words:
|
114 |
+
if w not in self.stoi:
|
115 |
+
self.itos.append(w)
|
116 |
+
self.stoi[w] = len(self.itos) - 1
|
data/field/nested_field.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from data.field.mini_torchtext.field import NestedField as TorchTextNestedField
|
6 |
+
|
7 |
+
|
8 |
+
class NestedField(TorchTextNestedField):
|
9 |
+
def pad(self, example):
|
10 |
+
self.nesting_field.include_lengths = self.include_lengths
|
11 |
+
if not self.include_lengths:
|
12 |
+
return self.nesting_field.pad(example)
|
13 |
+
|
14 |
+
sentence_length = len(example)
|
15 |
+
example, word_lengths = self.nesting_field.pad(example)
|
16 |
+
return example, sentence_length, word_lengths
|
17 |
+
|
18 |
+
def numericalize(self, arr, device=None):
|
19 |
+
numericalized = []
|
20 |
+
self.nesting_field.include_lengths = False
|
21 |
+
if self.include_lengths:
|
22 |
+
arr, sentence_length, word_lengths = arr
|
23 |
+
|
24 |
+
numericalized = self.nesting_field.numericalize(arr, device=device)
|
25 |
+
|
26 |
+
self.nesting_field.include_lengths = True
|
27 |
+
if self.include_lengths:
|
28 |
+
sentence_length = torch.tensor(sentence_length, dtype=self.dtype, device=device)
|
29 |
+
word_lengths = torch.tensor(word_lengths, dtype=self.dtype, device=device)
|
30 |
+
return (numericalized, sentence_length, word_lengths)
|
31 |
+
return numericalized
|
32 |
+
|
33 |
+
def build_vocab(self, *args, **kwargs):
|
34 |
+
sources = []
|
35 |
+
for arg in args:
|
36 |
+
if isinstance(arg, torch.utils.data.Dataset):
|
37 |
+
sources += [arg.get_examples(name) for name, field in arg.fields.items() if field is self]
|
38 |
+
else:
|
39 |
+
sources.append(arg)
|
40 |
+
|
41 |
+
flattened = []
|
42 |
+
for source in sources:
|
43 |
+
flattened.extend(source)
|
44 |
+
|
45 |
+
# just build vocab and does not load vector
|
46 |
+
self.nesting_field.build_vocab(*flattened, **kwargs)
|
47 |
+
super(TorchTextNestedField, self).build_vocab()
|
48 |
+
self.vocab.extend(self.nesting_field.vocab)
|
49 |
+
self.vocab.freqs = self.nesting_field.vocab.freqs.copy()
|
50 |
+
self.nesting_field.vocab = self.vocab
|
data/parser/__init__.py
ADDED
File without changes
|
data/parser/from_mrp/__init__.py
ADDED
File without changes
|
data/parser/from_mrp/abstract_parser.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from data.parser.json_parser import example_from_json
|
6 |
+
|
7 |
+
|
8 |
+
class AbstractParser(torch.utils.data.Dataset):
|
9 |
+
def __init__(self, fields, data, filter_pred=None):
|
10 |
+
super(AbstractParser, self).__init__()
|
11 |
+
|
12 |
+
self.examples = [example_from_json(d, fields) for _, d in sorted(data.items())]
|
13 |
+
|
14 |
+
if isinstance(fields, dict):
|
15 |
+
fields, field_dict = [], fields
|
16 |
+
for field in field_dict.values():
|
17 |
+
if isinstance(field, list):
|
18 |
+
fields.extend(field)
|
19 |
+
else:
|
20 |
+
fields.append(field)
|
21 |
+
|
22 |
+
if filter_pred is not None:
|
23 |
+
make_list = isinstance(self.examples, list)
|
24 |
+
self.examples = filter(filter_pred, self.examples)
|
25 |
+
if make_list:
|
26 |
+
self.examples = list(self.examples)
|
27 |
+
|
28 |
+
self.fields = dict(fields)
|
29 |
+
|
30 |
+
# Unpack field tuples
|
31 |
+
for n, f in list(self.fields.items()):
|
32 |
+
if isinstance(n, tuple):
|
33 |
+
self.fields.update(zip(n, f))
|
34 |
+
del self.fields[n]
|
35 |
+
|
36 |
+
def __getitem__(self, i):
|
37 |
+
item = self.examples[i]
|
38 |
+
processed_item = {}
|
39 |
+
for (name, field) in self.fields.items():
|
40 |
+
if field is not None:
|
41 |
+
processed_item[name] = field.process(getattr(item, name), device=None)
|
42 |
+
return processed_item
|
43 |
+
|
44 |
+
def __len__(self):
|
45 |
+
return len(self.examples)
|
46 |
+
|
47 |
+
def get_examples(self, attr):
|
48 |
+
if attr in self.fields:
|
49 |
+
for x in self.examples:
|
50 |
+
yield getattr(x, attr)
|
data/parser/from_mrp/evaluation_parser.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
from data.parser.from_mrp.abstract_parser import AbstractParser
|
5 |
+
import utility.parser_utils as utils
|
6 |
+
|
7 |
+
|
8 |
+
class EvaluationParser(AbstractParser):
|
9 |
+
def __init__(self, args, fields):
|
10 |
+
path = args.test_data
|
11 |
+
self.data = utils.load_dataset(path)
|
12 |
+
|
13 |
+
for sentence in self.data.values():
|
14 |
+
sentence["token anchors"] = [[a["from"], a["to"]] for a in sentence["token anchors"]]
|
15 |
+
|
16 |
+
utils.create_bert_tokens(self.data, args.encoder)
|
17 |
+
|
18 |
+
super(EvaluationParser, self).__init__(fields, self.data)
|
data/parser/from_mrp/labeled_edge_parser.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
from data.parser.from_mrp.abstract_parser import AbstractParser
|
5 |
+
import utility.parser_utils as utils
|
6 |
+
|
7 |
+
|
8 |
+
class LabeledEdgeParser(AbstractParser):
|
9 |
+
def __init__(self, args, part: str, fields, filter_pred=None, **kwargs):
|
10 |
+
assert part == "training" or part == "validation"
|
11 |
+
path = args.training_data if part == "training" else args.validation_data
|
12 |
+
|
13 |
+
self.data = utils.load_dataset(path)
|
14 |
+
utils.anchor_ids_from_intervals(self.data)
|
15 |
+
|
16 |
+
self.node_counter, self.edge_counter, self.no_edge_counter = 0, 0, 0
|
17 |
+
anchor_count, n_node_token_pairs = 0, 0
|
18 |
+
|
19 |
+
for sentence_id, sentence in list(self.data.items()):
|
20 |
+
for edge in sentence["edges"]:
|
21 |
+
if "label" not in edge:
|
22 |
+
del self.data[sentence_id]
|
23 |
+
break
|
24 |
+
|
25 |
+
for node, sentence in utils.node_generator(self.data):
|
26 |
+
node["label"] = "Node"
|
27 |
+
|
28 |
+
self.node_counter += 1
|
29 |
+
|
30 |
+
utils.create_bert_tokens(self.data, args.encoder)
|
31 |
+
|
32 |
+
# create edge vectors
|
33 |
+
for sentence in self.data.values():
|
34 |
+
assert sentence["tops"] == [0], sentence
|
35 |
+
N = len(sentence["nodes"])
|
36 |
+
|
37 |
+
edge_count = utils.create_edges(sentence)
|
38 |
+
self.edge_counter += edge_count
|
39 |
+
self.no_edge_counter += N * (N - 1) - edge_count
|
40 |
+
|
41 |
+
sentence["nodes"] = sentence["nodes"][1:]
|
42 |
+
N = len(sentence["nodes"])
|
43 |
+
|
44 |
+
sentence["anchor edges"] = [N, len(sentence["input"]), []]
|
45 |
+
sentence["source anchor edges"] = [N, len(sentence["input"]), []] # dummy
|
46 |
+
sentence["target anchor edges"] = [N, len(sentence["input"]), []] # dummy
|
47 |
+
sentence["anchored labels"] = [len(sentence["input"]), []]
|
48 |
+
for i, node in enumerate(sentence["nodes"]):
|
49 |
+
anchored_labels = []
|
50 |
+
|
51 |
+
for anchor in node["anchors"]:
|
52 |
+
sentence["anchor edges"][-1].append((i, anchor))
|
53 |
+
anchored_labels.append((anchor, node["label"]))
|
54 |
+
|
55 |
+
sentence["anchored labels"][1].append(anchored_labels)
|
56 |
+
|
57 |
+
anchor_count += len(node["anchors"])
|
58 |
+
n_node_token_pairs += len(sentence["input"])
|
59 |
+
|
60 |
+
sentence["id"] = [sentence["id"]]
|
61 |
+
|
62 |
+
self.anchor_freq = anchor_count / n_node_token_pairs
|
63 |
+
self.source_anchor_freq = self.target_anchor_freq = 0.5 # dummy
|
64 |
+
self.input_count = sum(len(sentence["input"]) for sentence in self.data.values())
|
65 |
+
|
66 |
+
super(LabeledEdgeParser, self).__init__(fields, self.data, filter_pred)
|
67 |
+
|
68 |
+
@staticmethod
|
69 |
+
def node_similarity_key(node):
|
70 |
+
return tuple([node["label"]] + node["anchors"])
|
data/parser/from_mrp/node_centric_parser.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
from data.parser.from_mrp.abstract_parser import AbstractParser
|
5 |
+
import utility.parser_utils as utils
|
6 |
+
|
7 |
+
|
8 |
+
class NodeCentricParser(AbstractParser):
|
9 |
+
def __init__(self, args, part: str, fields, filter_pred=None, **kwargs):
|
10 |
+
assert part == "training" or part == "validation"
|
11 |
+
path = args.training_data if part == "training" else args.validation_data
|
12 |
+
|
13 |
+
self.data = utils.load_dataset(path)
|
14 |
+
utils.anchor_ids_from_intervals(self.data)
|
15 |
+
|
16 |
+
self.node_counter, self.edge_counter, self.no_edge_counter = 0, 0, 0
|
17 |
+
anchor_count, n_node_token_pairs = 0, 0
|
18 |
+
|
19 |
+
for sentence_id, sentence in list(self.data.items()):
|
20 |
+
for node in sentence["nodes"]:
|
21 |
+
if "label" not in node:
|
22 |
+
del self.data[sentence_id]
|
23 |
+
break
|
24 |
+
|
25 |
+
for node, _ in utils.node_generator(self.data):
|
26 |
+
self.node_counter += 1
|
27 |
+
|
28 |
+
# print(f"Number of unlabeled nodes: {unlabeled_count}", flush=True)
|
29 |
+
|
30 |
+
utils.create_bert_tokens(self.data, args.encoder)
|
31 |
+
|
32 |
+
# create edge vectors
|
33 |
+
for sentence in self.data.values():
|
34 |
+
N = len(sentence["nodes"])
|
35 |
+
|
36 |
+
edge_count = utils.create_edges(sentence)
|
37 |
+
self.edge_counter += edge_count
|
38 |
+
# self.no_edge_counter += len([n for n in sentence["nodes"] if n["label"] in ["Source", "Target"]]) * len([n for n in sentence["nodes"] if n["label"] not in ["Source", "Target"]]) - edge_count
|
39 |
+
self.no_edge_counter += N * (N - 1) - edge_count
|
40 |
+
|
41 |
+
sentence["anchor edges"] = [N, len(sentence["input"]), []]
|
42 |
+
sentence["source anchor edges"] = [N, len(sentence["input"]), []] # dummy
|
43 |
+
sentence["target anchor edges"] = [N, len(sentence["input"]), []] # dummy
|
44 |
+
sentence["anchored labels"] = [len(sentence["input"]), []]
|
45 |
+
for i, node in enumerate(sentence["nodes"]):
|
46 |
+
anchored_labels = []
|
47 |
+
#if len(node["anchors"]) == 0:
|
48 |
+
# print(f"Empty node in {sentence['id']}", flush=True)
|
49 |
+
|
50 |
+
for anchor in node["anchors"]:
|
51 |
+
sentence["anchor edges"][-1].append((i, anchor))
|
52 |
+
anchored_labels.append((anchor, node["label"]))
|
53 |
+
|
54 |
+
sentence["anchored labels"][1].append(anchored_labels)
|
55 |
+
|
56 |
+
anchor_count += len(node["anchors"])
|
57 |
+
n_node_token_pairs += len(sentence["input"])
|
58 |
+
|
59 |
+
sentence["id"] = [sentence["id"]]
|
60 |
+
|
61 |
+
self.anchor_freq = anchor_count / n_node_token_pairs
|
62 |
+
self.source_anchor_freq = self.target_anchor_freq = 0.5 # dummy
|
63 |
+
self.input_count = sum(len(sentence["input"]) for sentence in self.data.values())
|
64 |
+
|
65 |
+
super(NodeCentricParser, self).__init__(fields, self.data, filter_pred)
|
66 |
+
|
67 |
+
@staticmethod
|
68 |
+
def node_similarity_key(node):
|
69 |
+
return tuple([node["label"]] + node["anchors"])
|
data/parser/from_mrp/request_parser.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
import utility.parser_utils as utils
|
5 |
+
from data.parser.from_mrp.abstract_parser import AbstractParser
|
6 |
+
|
7 |
+
|
8 |
+
class RequestParser(AbstractParser):
|
9 |
+
def __init__(self, sentences, args, fields):
|
10 |
+
self.data = {i: {"id": str(i), "sentence": sentence} for i, sentence in enumerate(sentences)}
|
11 |
+
|
12 |
+
sentences = [example["sentence"] for example in self.data.values()]
|
13 |
+
|
14 |
+
for example in self.data.values():
|
15 |
+
example["input"] = example["sentence"].strip().split(' ')
|
16 |
+
example["token anchors"], offset = [], 0
|
17 |
+
for token in example["input"]:
|
18 |
+
example["token anchors"].append([offset, offset + len(token)])
|
19 |
+
offset += len(token) + 1
|
20 |
+
|
21 |
+
utils.create_bert_tokens(self.data, args.encoder)
|
22 |
+
|
23 |
+
super(RequestParser, self).__init__(fields, self.data)
|
data/parser/from_mrp/sequential_parser.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
from data.parser.from_mrp.abstract_parser import AbstractParser
|
5 |
+
import utility.parser_utils as utils
|
6 |
+
|
7 |
+
|
8 |
+
class SequentialParser(AbstractParser):
|
9 |
+
def __init__(self, args, part: str, fields, filter_pred=None, **kwargs):
|
10 |
+
assert part == "training" or part == "validation"
|
11 |
+
path = args.training_data if part == "training" else args.validation_data
|
12 |
+
|
13 |
+
self.data = utils.load_dataset(path)
|
14 |
+
utils.anchor_ids_from_intervals(self.data)
|
15 |
+
|
16 |
+
self.node_counter, self.edge_counter, self.no_edge_counter = 0, 0, 0
|
17 |
+
anchor_count, source_anchor_count, target_anchor_count, n_node_token_pairs = 0, 0, 0, 0
|
18 |
+
|
19 |
+
for sentence_id, sentence in list(self.data.items()):
|
20 |
+
for node in sentence["nodes"]:
|
21 |
+
if "label" not in node:
|
22 |
+
del self.data[sentence_id]
|
23 |
+
break
|
24 |
+
|
25 |
+
for node, _ in utils.node_generator(self.data):
|
26 |
+
node["target anchors"] = []
|
27 |
+
node["source anchors"] = []
|
28 |
+
|
29 |
+
for sentence in self.data.values():
|
30 |
+
for e in sentence["edges"]:
|
31 |
+
source, target = e["source"], e["target"]
|
32 |
+
|
33 |
+
if sentence["nodes"][target]["label"] == "Target":
|
34 |
+
sentence["nodes"][source]["target anchors"] += sentence["nodes"][target]["anchors"]
|
35 |
+
elif sentence["nodes"][target]["label"] == "Source":
|
36 |
+
sentence["nodes"][source]["source anchors"] += sentence["nodes"][target]["anchors"]
|
37 |
+
|
38 |
+
for i, node in list(enumerate(sentence["nodes"]))[::-1]:
|
39 |
+
if "label" not in node or node["label"] in ["Source", "Target"]:
|
40 |
+
del sentence["nodes"][i]
|
41 |
+
sentence["edges"] = []
|
42 |
+
|
43 |
+
for node, sentence in utils.node_generator(self.data):
|
44 |
+
self.node_counter += 1
|
45 |
+
|
46 |
+
utils.create_bert_tokens(self.data, args.encoder)
|
47 |
+
|
48 |
+
# create edge vectors
|
49 |
+
for sentence in self.data.values():
|
50 |
+
N = len(sentence["nodes"])
|
51 |
+
|
52 |
+
utils.create_edges(sentence)
|
53 |
+
self.no_edge_counter += N * (N - 1)
|
54 |
+
|
55 |
+
sentence["anchor edges"] = [N, len(sentence["input"]), []]
|
56 |
+
sentence["source anchor edges"] = [N, len(sentence["input"]), []]
|
57 |
+
sentence["target anchor edges"] = [N, len(sentence["input"]), []]
|
58 |
+
|
59 |
+
sentence["anchored labels"] = [len(sentence["input"]), []]
|
60 |
+
for i, node in enumerate(sentence["nodes"]):
|
61 |
+
anchored_labels = []
|
62 |
+
|
63 |
+
for anchor in node["anchors"]:
|
64 |
+
sentence["anchor edges"][-1].append((i, anchor))
|
65 |
+
anchored_labels.append((anchor, node["label"]))
|
66 |
+
|
67 |
+
for anchor in node["source anchors"]:
|
68 |
+
sentence["source anchor edges"][-1].append((i, anchor))
|
69 |
+
for anchor in node["target anchors"]:
|
70 |
+
sentence["target anchor edges"][-1].append((i, anchor))
|
71 |
+
|
72 |
+
sentence["anchored labels"][1].append(anchored_labels)
|
73 |
+
|
74 |
+
anchor_count += len(node["anchors"])
|
75 |
+
source_anchor_count += len(node["source anchors"])
|
76 |
+
target_anchor_count += len(node["target anchors"])
|
77 |
+
n_node_token_pairs += len(sentence["input"])
|
78 |
+
|
79 |
+
sentence["id"] = [sentence["id"]]
|
80 |
+
|
81 |
+
self.anchor_freq = anchor_count / n_node_token_pairs
|
82 |
+
self.source_anchor_freq = anchor_count / n_node_token_pairs
|
83 |
+
self.target_anchor_freq = anchor_count / n_node_token_pairs
|
84 |
+
self.input_count = sum(len(sentence["input"]) for sentence in self.data.values())
|
85 |
+
|
86 |
+
super(SequentialParser, self).__init__(fields, self.data, filter_pred)
|
87 |
+
|
88 |
+
@staticmethod
|
89 |
+
def node_similarity_key(node):
|
90 |
+
return tuple([node["label"]] + node["anchors"])
|
data/parser/json_parser.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import reduce
|
2 |
+
from data.field.mini_torchtext.example import Example
|
3 |
+
|
4 |
+
|
5 |
+
def example_from_json(obj, fields):
|
6 |
+
ex = Example()
|
7 |
+
for key, vals in fields.items():
|
8 |
+
if vals is not None:
|
9 |
+
if not isinstance(vals, list):
|
10 |
+
vals = [vals]
|
11 |
+
for val in vals:
|
12 |
+
# for processing the key likes 'foo.bar'
|
13 |
+
name, field = val
|
14 |
+
ks = key.split(".")
|
15 |
+
|
16 |
+
def reducer(obj, key):
|
17 |
+
if isinstance(obj, list):
|
18 |
+
results = []
|
19 |
+
for data in obj:
|
20 |
+
if key not in data:
|
21 |
+
# key error
|
22 |
+
raise ValueError("Specified key {} was not found in " "the input data".format(key))
|
23 |
+
else:
|
24 |
+
results.append(data[key])
|
25 |
+
return results
|
26 |
+
else:
|
27 |
+
# key error
|
28 |
+
if key not in obj:
|
29 |
+
raise ValueError("Specified key {} was not found in " "the input data".format(key))
|
30 |
+
else:
|
31 |
+
return obj[key]
|
32 |
+
|
33 |
+
v = reduce(reducer, ks, obj)
|
34 |
+
setattr(ex, name, field.preprocess(v))
|
35 |
+
return ex
|
data/parser/to_mrp/__init__.py
ADDED
File without changes
|
data/parser/to_mrp/abstract_parser.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
class AbstractParser:
|
5 |
+
def __init__(self, dataset):
|
6 |
+
self.dataset = dataset
|
7 |
+
|
8 |
+
def create_nodes(self, prediction):
|
9 |
+
return [
|
10 |
+
{"id": i, "label": self.label_to_str(l, prediction["anchors"][i], prediction)}
|
11 |
+
for i, l in enumerate(prediction["labels"])
|
12 |
+
]
|
13 |
+
|
14 |
+
def label_to_str(self, label, anchors, prediction):
|
15 |
+
return self.dataset.label_field.vocab.itos[label - 1]
|
16 |
+
|
17 |
+
def create_edges(self, prediction, nodes):
|
18 |
+
N = len(nodes)
|
19 |
+
node_sets = [{"id": n, "set": set([n])} for n in range(N)]
|
20 |
+
_, indices = prediction["edge presence"][:N, :N].reshape(-1).sort(descending=True)
|
21 |
+
sources, targets = indices // N, indices % N
|
22 |
+
|
23 |
+
edges = []
|
24 |
+
for i in range((N - 1) * N // 2):
|
25 |
+
source, target = sources[i].item(), targets[i].item()
|
26 |
+
p = prediction["edge presence"][source, target]
|
27 |
+
|
28 |
+
if p < 0.5 and len(edges) >= N - 1:
|
29 |
+
break
|
30 |
+
|
31 |
+
if node_sets[source]["set"] is node_sets[target]["set"] and p < 0.5:
|
32 |
+
continue
|
33 |
+
|
34 |
+
self.create_edge(source, target, prediction, edges, nodes)
|
35 |
+
|
36 |
+
if node_sets[source]["set"] is not node_sets[target]["set"]:
|
37 |
+
from_set = node_sets[source]["set"]
|
38 |
+
for n in node_sets[target]["set"]:
|
39 |
+
from_set.add(n)
|
40 |
+
node_sets[n]["set"] = from_set
|
41 |
+
|
42 |
+
return edges
|
43 |
+
|
44 |
+
def create_edge(self, source, target, prediction, edges, nodes):
|
45 |
+
label = self.get_edge_label(prediction, source, target)
|
46 |
+
edge = {"source": source, "target": target, "label": label}
|
47 |
+
|
48 |
+
edges.append(edge)
|
49 |
+
|
50 |
+
def create_anchors(self, prediction, nodes, join_contiguous=True, at_least_one=False, single_anchor=False, mode="anchors"):
|
51 |
+
for i, node in enumerate(nodes):
|
52 |
+
threshold = 0.5 if not at_least_one else min(0.5, prediction[mode][i].max().item())
|
53 |
+
node[mode] = (prediction[mode][i] >= threshold).nonzero(as_tuple=False).squeeze(-1)
|
54 |
+
node[mode] = prediction["token intervals"][node[mode], :]
|
55 |
+
|
56 |
+
if single_anchor and len(node[mode]) > 1:
|
57 |
+
start = min(a[0].item() for a in node[mode])
|
58 |
+
end = max(a[1].item() for a in node[mode])
|
59 |
+
node[mode] = [{"from": start, "to": end}]
|
60 |
+
continue
|
61 |
+
|
62 |
+
node[mode] = [{"from": f.item(), "to": t.item()} for f, t in node[mode]]
|
63 |
+
node[mode] = sorted(node[mode], key=lambda a: a["from"])
|
64 |
+
|
65 |
+
if join_contiguous and len(node[mode]) > 1:
|
66 |
+
cleaned_anchors = []
|
67 |
+
end, start = node[mode][0]["from"], node[mode][0]["from"]
|
68 |
+
for anchor in node[mode]:
|
69 |
+
if end < anchor["from"]:
|
70 |
+
cleaned_anchors.append({"from": start, "to": end})
|
71 |
+
start = anchor["from"]
|
72 |
+
end = anchor["to"]
|
73 |
+
cleaned_anchors.append({"from": start, "to": end})
|
74 |
+
|
75 |
+
node[mode] = cleaned_anchors
|
76 |
+
|
77 |
+
return nodes
|
78 |
+
|
79 |
+
def get_edge_label(self, prediction, source, target):
|
80 |
+
return self.dataset.edge_label_field.vocab.itos[prediction["edge labels"][source, target].item()]
|
data/parser/to_mrp/labeled_edge_parser.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
from data.parser.to_mrp.abstract_parser import AbstractParser
|
5 |
+
|
6 |
+
|
7 |
+
class LabeledEdgeParser(AbstractParser):
|
8 |
+
def __init__(self, *args):
|
9 |
+
super().__init__(*args)
|
10 |
+
self.source_id = self.dataset.edge_label_field.vocab.stoi["Source"]
|
11 |
+
self.target_id = self.dataset.edge_label_field.vocab.stoi["Target"]
|
12 |
+
|
13 |
+
def parse(self, prediction):
|
14 |
+
output = {}
|
15 |
+
|
16 |
+
output["id"] = self.dataset.id_field.vocab.itos[prediction["id"].item()]
|
17 |
+
output["nodes"] = self.create_nodes(prediction)
|
18 |
+
output["nodes"] = self.create_anchors(prediction, output["nodes"], join_contiguous=True, at_least_one=True)
|
19 |
+
output["nodes"] = [{"id": 0}] + output["nodes"]
|
20 |
+
output["edges"] = self.create_edges(prediction, output["nodes"])
|
21 |
+
|
22 |
+
return output
|
23 |
+
|
24 |
+
def create_nodes(self, prediction):
|
25 |
+
return [{"id": i + 1} for i, l in enumerate(prediction["labels"])]
|
26 |
+
|
27 |
+
def create_edges(self, prediction, nodes):
|
28 |
+
N = len(nodes)
|
29 |
+
edge_prediction = prediction["edge presence"][:N, :N]
|
30 |
+
|
31 |
+
edges = []
|
32 |
+
for target in range(1, N):
|
33 |
+
if edge_prediction[0, target] >= 0.5:
|
34 |
+
prediction["edge labels"][0, target, self.source_id] = float("-inf")
|
35 |
+
prediction["edge labels"][0, target, self.target_id] = float("-inf")
|
36 |
+
self.create_edge(0, target, prediction, edges, nodes)
|
37 |
+
|
38 |
+
for source in range(1, N):
|
39 |
+
for target in range(1, N):
|
40 |
+
if source == target:
|
41 |
+
continue
|
42 |
+
if edge_prediction[source, target] < 0.5:
|
43 |
+
continue
|
44 |
+
for i in range(prediction["edge labels"].size(2)):
|
45 |
+
if i not in [self.source_id, self.target_id]:
|
46 |
+
prediction["edge labels"][source, target, i] = float("-inf")
|
47 |
+
self.create_edge(source, target, prediction, edges, nodes)
|
48 |
+
|
49 |
+
return edges
|
50 |
+
|
51 |
+
def get_edge_label(self, prediction, source, target):
|
52 |
+
return self.dataset.edge_label_field.vocab.itos[prediction["edge labels"][source, target].argmax(-1).item()]
|
data/parser/to_mrp/node_centric_parser.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
from data.parser.to_mrp.abstract_parser import AbstractParser
|
5 |
+
|
6 |
+
|
7 |
+
class NodeCentricParser(AbstractParser):
|
8 |
+
def parse(self, prediction):
|
9 |
+
output = {}
|
10 |
+
|
11 |
+
output["id"] = self.dataset.id_field.vocab.itos[prediction["id"].item()]
|
12 |
+
output["nodes"] = self.create_nodes(prediction)
|
13 |
+
output["nodes"] = self.create_anchors(prediction, output["nodes"], join_contiguous=True, at_least_one=True)
|
14 |
+
output["edges"] = self.create_edges(prediction, output["nodes"])
|
15 |
+
|
16 |
+
return output
|
17 |
+
|
18 |
+
def create_edge(self, source, target, prediction, edges, nodes):
|
19 |
+
edge = {"source": source, "target": target, "label": None}
|
20 |
+
edges.append(edge)
|
21 |
+
|
22 |
+
def create_edges(self, prediction, nodes):
|
23 |
+
N = len(nodes)
|
24 |
+
edge_prediction = prediction["edge presence"][:N, :N]
|
25 |
+
|
26 |
+
targets = [i for i, node in enumerate(nodes) if node["label"] in ["Source", "Target"]]
|
27 |
+
sources = [i for i, node in enumerate(nodes) if node["label"] not in ["Source", "Target"]]
|
28 |
+
|
29 |
+
edges = []
|
30 |
+
for target in targets:
|
31 |
+
for source in sources:
|
32 |
+
if edge_prediction[source, target] >= 0.5:
|
33 |
+
self.create_edge(source, target, prediction, edges, nodes)
|
34 |
+
|
35 |
+
return edges
|
data/parser/to_mrp/sequential_parser.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
from data.parser.to_mrp.abstract_parser import AbstractParser
|
5 |
+
|
6 |
+
|
7 |
+
class SequentialParser(AbstractParser):
|
8 |
+
def parse(self, prediction):
|
9 |
+
output = {}
|
10 |
+
|
11 |
+
output["id"] = self.dataset.id_field.vocab.itos[prediction["id"].item()]
|
12 |
+
output["nodes"] = self.create_nodes(prediction)
|
13 |
+
output["nodes"] = self.create_anchors(prediction, output["nodes"], join_contiguous=True, at_least_one=True, mode="anchors")
|
14 |
+
output["nodes"] = self.create_anchors(prediction, output["nodes"], join_contiguous=True, at_least_one=False, mode="source anchors")
|
15 |
+
output["nodes"] = self.create_anchors(prediction, output["nodes"], join_contiguous=True, at_least_one=False, mode="target anchors")
|
16 |
+
output["edges"], output["nodes"] = self.create_targets_sources(output["nodes"])
|
17 |
+
|
18 |
+
return output
|
19 |
+
|
20 |
+
def create_targets_sources(self, nodes):
|
21 |
+
edges, new_nodes = [], []
|
22 |
+
for i, node in enumerate(nodes):
|
23 |
+
new_node_id = len(nodes) + len(new_nodes)
|
24 |
+
if len(node["source anchors"]) > 0:
|
25 |
+
new_nodes.append({"id": new_node_id, "label": "Source", "anchors": node["source anchors"]})
|
26 |
+
edges.append({"source": i, "target": new_node_id, "label": ""})
|
27 |
+
new_node_id += 1
|
28 |
+
del node["source anchors"]
|
29 |
+
|
30 |
+
if len(node["target anchors"]) > 0:
|
31 |
+
new_nodes.append({"id": new_node_id, "label": "Target", "anchors": node["target anchors"]})
|
32 |
+
edges.append({"source": i, "target": new_node_id, "label": ""})
|
33 |
+
del node["target anchors"]
|
34 |
+
|
35 |
+
return edges, nodes + new_nodes
|
model/__init__.py
ADDED
File without changes
|
model/head/__init__.py
ADDED
File without changes
|
model/head/abstract_head.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
import math
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from model.module.edge_classifier import EdgeClassifier
|
10 |
+
from model.module.anchor_classifier import AnchorClassifier
|
11 |
+
from utility.cross_entropy import cross_entropy, binary_cross_entropy
|
12 |
+
from utility.hungarian_matching import get_matching, reorder, match_anchor, match_label
|
13 |
+
from utility.utils import create_padding_mask
|
14 |
+
|
15 |
+
|
16 |
+
class AbstractHead(nn.Module):
|
17 |
+
def __init__(self, dataset, args, config, initialize: bool):
|
18 |
+
super(AbstractHead, self).__init__()
|
19 |
+
|
20 |
+
self.edge_classifier = self.init_edge_classifier(dataset, args, config, initialize)
|
21 |
+
self.label_classifier = self.init_label_classifier(dataset, args, config, initialize)
|
22 |
+
self.anchor_classifier = self.init_anchor_classifier(dataset, args, config, initialize, mode="anchor")
|
23 |
+
self.source_anchor_classifier = self.init_anchor_classifier(dataset, args, config, initialize, mode="source_anchor")
|
24 |
+
self.target_anchor_classifier = self.init_anchor_classifier(dataset, args, config, initialize, mode="target_anchor")
|
25 |
+
|
26 |
+
self.query_length = args.query_length
|
27 |
+
self.focal = args.focal
|
28 |
+
self.dataset = dataset
|
29 |
+
|
30 |
+
def forward(self, encoder_output, decoder_output, encoder_mask, decoder_mask, batch):
|
31 |
+
output = {}
|
32 |
+
|
33 |
+
decoder_lens = self.query_length * batch["every_input"][1]
|
34 |
+
output["label"] = self.forward_label(decoder_output)
|
35 |
+
output["anchor"] = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="anchor") # shape: (B, T_l, T_w)
|
36 |
+
output["source_anchor"] = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="source_anchor") # shape: (B, T_l, T_w)
|
37 |
+
output["target_anchor"] = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="target_anchor") # shape: (B, T_l, T_w)
|
38 |
+
|
39 |
+
cost_matrices = self.create_cost_matrices(output, batch, decoder_lens)
|
40 |
+
matching = get_matching(cost_matrices)
|
41 |
+
decoder_output = reorder(decoder_output, matching, batch["labels"][0].size(1))
|
42 |
+
output["edge presence"], output["edge label"] = self.forward_edge(decoder_output)
|
43 |
+
|
44 |
+
return self.loss(output, batch, matching, decoder_mask)
|
45 |
+
|
46 |
+
def predict(self, encoder_output, decoder_output, encoder_mask, decoder_mask, batch, **kwargs):
|
47 |
+
every_input, word_lens = batch["every_input"]
|
48 |
+
decoder_lens = self.query_length * word_lens
|
49 |
+
batch_size = every_input.size(0)
|
50 |
+
|
51 |
+
label_pred = self.forward_label(decoder_output)
|
52 |
+
anchor_pred = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="anchor") # shape: (B, T_l, T_w)
|
53 |
+
source_anchor_pred = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="source_anchor") # shape: (B, T_l, T_w)
|
54 |
+
target_anchor_pred = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="target_anchor") # shape: (B, T_l, T_w)
|
55 |
+
|
56 |
+
labels = [[] for _ in range(batch_size)]
|
57 |
+
anchors, source_anchors, target_anchors = [[] for _ in range(batch_size)], [[] for _ in range(batch_size)], [[] for _ in range(batch_size)]
|
58 |
+
|
59 |
+
for b in range(batch_size):
|
60 |
+
label_indices = self.inference_label(label_pred[b, :decoder_lens[b], :]).cpu()
|
61 |
+
for t in range(label_indices.size(0)):
|
62 |
+
label_index = label_indices[t].item()
|
63 |
+
if label_index == 0:
|
64 |
+
continue
|
65 |
+
|
66 |
+
decoder_output[b, len(labels[b]), :] = decoder_output[b, t, :]
|
67 |
+
|
68 |
+
labels[b].append(label_index)
|
69 |
+
if anchor_pred is None:
|
70 |
+
anchors[b].append(list(range(t // self.query_length, word_lens[b])))
|
71 |
+
else:
|
72 |
+
anchors[b].append(self.inference_anchor(anchor_pred[b, t, :word_lens[b]]).cpu())
|
73 |
+
|
74 |
+
if source_anchor_pred is None:
|
75 |
+
source_anchors[b].append(list(range(t // self.query_length, word_lens[b])))
|
76 |
+
else:
|
77 |
+
source_anchors[b].append(self.inference_anchor(source_anchor_pred[b, t, :word_lens[b]]).cpu())
|
78 |
+
|
79 |
+
if target_anchor_pred is None:
|
80 |
+
target_anchors[b].append(list(range(t // self.query_length, word_lens[b])))
|
81 |
+
else:
|
82 |
+
target_anchors[b].append(self.inference_anchor(target_anchor_pred[b, t, :word_lens[b]]).cpu())
|
83 |
+
|
84 |
+
decoder_output = decoder_output[:, : max(len(l) for l in labels), :]
|
85 |
+
edge_presence, edge_labels = self.forward_edge(decoder_output)
|
86 |
+
|
87 |
+
outputs = [
|
88 |
+
self.parser.parse(
|
89 |
+
{
|
90 |
+
"labels": labels[b],
|
91 |
+
"anchors": anchors[b],
|
92 |
+
"source anchors": source_anchors[b],
|
93 |
+
"target anchors": target_anchors[b],
|
94 |
+
"edge presence": self.inference_edge_presence(edge_presence, b),
|
95 |
+
"edge labels": self.inference_edge_label(edge_labels, b),
|
96 |
+
"id": batch["id"][b].cpu(),
|
97 |
+
"tokens": batch["every_input"][0][b, : word_lens[b]].cpu(),
|
98 |
+
"token intervals": batch["token_intervals"][b, :, :].cpu(),
|
99 |
+
},
|
100 |
+
**kwargs
|
101 |
+
)
|
102 |
+
for b in range(batch_size)
|
103 |
+
]
|
104 |
+
|
105 |
+
return outputs
|
106 |
+
|
107 |
+
def loss(self, output, batch, matching, decoder_mask):
|
108 |
+
batch_size = batch["every_input"][0].size(0)
|
109 |
+
device = batch["every_input"][0].device
|
110 |
+
T_label = batch["labels"][0].size(1)
|
111 |
+
T_input = batch["every_input"][0].size(1)
|
112 |
+
T_edge = batch["edge_presence"].size(1)
|
113 |
+
|
114 |
+
input_mask = create_padding_mask(batch_size, T_input, batch["every_input"][1], device) # shape: (B, T_input)
|
115 |
+
label_mask = create_padding_mask(batch_size, T_label, batch["labels"][1], device) # shape: (B, T_label)
|
116 |
+
edge_mask = torch.eye(T_label, T_label, device=device, dtype=torch.bool).unsqueeze(0) # shape: (1, T_label, T_label)
|
117 |
+
edge_mask = edge_mask | label_mask.unsqueeze(1) | label_mask.unsqueeze(2) # shape: (B, T_label, T_label)
|
118 |
+
if T_edge != T_label:
|
119 |
+
edge_mask = F.pad(edge_mask, (T_edge - T_label, 0, T_edge - T_label, 0), value=0)
|
120 |
+
edge_label_mask = (batch["edge_presence"] == 0) | edge_mask
|
121 |
+
|
122 |
+
if output["edge label"] is not None:
|
123 |
+
batch["edge_labels"] = (
|
124 |
+
batch["edge_labels"][0][:, :, :, :output["edge label"].size(-1)],
|
125 |
+
batch["edge_labels"][1],
|
126 |
+
)
|
127 |
+
|
128 |
+
losses = {}
|
129 |
+
losses.update(self.loss_label(output, batch, decoder_mask, matching))
|
130 |
+
losses.update(self.loss_anchor(output, batch, input_mask, matching, mode="anchor"))
|
131 |
+
losses.update(self.loss_anchor(output, batch, input_mask, matching, mode="source_anchor"))
|
132 |
+
losses.update(self.loss_anchor(output, batch, input_mask, matching, mode="target_anchor"))
|
133 |
+
losses.update(self.loss_edge_presence(output, batch, edge_mask))
|
134 |
+
losses.update(self.loss_edge_label(output, batch, edge_label_mask.unsqueeze(-1)))
|
135 |
+
|
136 |
+
stats = {f"{key}": value.detach().cpu().item() for key, value in losses.items()}
|
137 |
+
total_loss = sum(losses.values()) / len(losses)
|
138 |
+
|
139 |
+
return total_loss, stats
|
140 |
+
|
141 |
+
@torch.no_grad()
|
142 |
+
def create_cost_matrices(self, output, batch, decoder_lens):
|
143 |
+
batch_size = len(batch["labels"][1])
|
144 |
+
decoder_lens = decoder_lens.cpu()
|
145 |
+
|
146 |
+
matrices = []
|
147 |
+
for b in range(batch_size):
|
148 |
+
label_cost_matrix = self.label_cost_matrix(output, batch, decoder_lens, b)
|
149 |
+
anchor_cost_matrix = self.anchor_cost_matrix(output, batch, decoder_lens, b)
|
150 |
+
|
151 |
+
cost_matrix = label_cost_matrix * anchor_cost_matrix
|
152 |
+
matrices.append(cost_matrix.cpu())
|
153 |
+
|
154 |
+
return matrices
|
155 |
+
|
156 |
+
def init_edge_classifier(self, dataset, args, config, initialize: bool):
|
157 |
+
if not config["edge presence"] and not config["edge label"]:
|
158 |
+
return None
|
159 |
+
return EdgeClassifier(dataset, args, initialize, presence=config["edge presence"], label=config["edge label"])
|
160 |
+
|
161 |
+
def init_label_classifier(self, dataset, args, config, initialize: bool):
|
162 |
+
if not config["label"]:
|
163 |
+
return None
|
164 |
+
|
165 |
+
classifier = nn.Sequential(
|
166 |
+
nn.Dropout(args.dropout_label),
|
167 |
+
nn.Linear(args.hidden_size, len(dataset.label_field.vocab) + 1, bias=True)
|
168 |
+
)
|
169 |
+
if initialize:
|
170 |
+
classifier[1].bias.data = dataset.label_freqs.log()
|
171 |
+
|
172 |
+
return classifier
|
173 |
+
|
174 |
+
def init_anchor_classifier(self, dataset, args, config, initialize: bool, mode="anchor"):
|
175 |
+
if not config[mode]:
|
176 |
+
return None
|
177 |
+
|
178 |
+
return AnchorClassifier(dataset, args, initialize, mode=mode)
|
179 |
+
|
180 |
+
def forward_edge(self, decoder_output):
|
181 |
+
if self.edge_classifier is None:
|
182 |
+
return None, None
|
183 |
+
return self.edge_classifier(decoder_output)
|
184 |
+
|
185 |
+
def forward_label(self, decoder_output):
|
186 |
+
if self.label_classifier is None:
|
187 |
+
return None
|
188 |
+
return torch.log_softmax(self.label_classifier(decoder_output), dim=-1)
|
189 |
+
|
190 |
+
def forward_anchor(self, decoder_output, encoder_output, encoder_mask, mode="anchor"):
|
191 |
+
classifier = getattr(self, f"{mode}_classifier")
|
192 |
+
if classifier is None:
|
193 |
+
return None
|
194 |
+
return classifier(decoder_output, encoder_output, encoder_mask)
|
195 |
+
|
196 |
+
def inference_label(self, prediction):
|
197 |
+
prediction = prediction.exp()
|
198 |
+
return torch.where(
|
199 |
+
prediction[:, 0] > prediction[:, 1:].sum(-1),
|
200 |
+
torch.zeros(prediction.size(0), dtype=torch.long, device=prediction.device),
|
201 |
+
prediction[:, 1:].argmax(dim=-1) + 1
|
202 |
+
)
|
203 |
+
|
204 |
+
def inference_anchor(self, prediction):
|
205 |
+
return prediction.sigmoid()
|
206 |
+
|
207 |
+
def inference_edge_presence(self, prediction, example_index: int):
|
208 |
+
if prediction is None:
|
209 |
+
return None
|
210 |
+
|
211 |
+
N = prediction.size(1)
|
212 |
+
mask = torch.eye(N, N, device=prediction.device, dtype=torch.bool)
|
213 |
+
return prediction[example_index, :, :].sigmoid().masked_fill(mask, 0.0).cpu()
|
214 |
+
|
215 |
+
def inference_edge_label(self, prediction, example_index: int):
|
216 |
+
if prediction is None:
|
217 |
+
return None
|
218 |
+
return prediction[example_index, :, :, :].cpu()
|
219 |
+
|
220 |
+
def loss_edge_presence(self, prediction, target, mask):
|
221 |
+
if self.edge_classifier is None or prediction["edge presence"] is None:
|
222 |
+
return {}
|
223 |
+
return {"edge presence": binary_cross_entropy(prediction["edge presence"], target["edge_presence"].float(), mask)}
|
224 |
+
|
225 |
+
def loss_edge_label(self, prediction, target, mask):
|
226 |
+
if self.edge_classifier is None or prediction["edge label"] is None:
|
227 |
+
return {}
|
228 |
+
return {"edge label": binary_cross_entropy(prediction["edge label"], target["edge_labels"][0].float(), mask)}
|
229 |
+
|
230 |
+
def loss_label(self, prediction, target, mask, matching):
|
231 |
+
if self.label_classifier is None or prediction["label"] is None:
|
232 |
+
return {}
|
233 |
+
|
234 |
+
prediction = prediction["label"]
|
235 |
+
target = match_label(
|
236 |
+
target["labels"][0], matching, prediction.shape[:-1], prediction.device, self.query_length
|
237 |
+
)
|
238 |
+
return {"label": cross_entropy(prediction, target, mask, focal=self.focal)}
|
239 |
+
|
240 |
+
def loss_anchor(self, prediction, target, mask, matching, mode="anchor"):
|
241 |
+
if getattr(self, f"{mode}_classifier") is None or prediction[mode] is None:
|
242 |
+
return {}
|
243 |
+
|
244 |
+
prediction = prediction[mode]
|
245 |
+
target, anchor_mask = match_anchor(target[mode], matching, prediction.shape, prediction.device)
|
246 |
+
mask = anchor_mask.unsqueeze(-1) | mask.unsqueeze(-2)
|
247 |
+
return {mode: binary_cross_entropy(prediction, target.float(), mask)}
|
248 |
+
|
249 |
+
def label_cost_matrix(self, output, batch, decoder_lens, b: int):
|
250 |
+
if output["label"] is None:
|
251 |
+
return 1.0
|
252 |
+
|
253 |
+
target_labels = batch["anchored_labels"][b] # shape: (num_nodes, num_inputs, num_classes)
|
254 |
+
label_prob = output["label"][b, : decoder_lens[b], :].exp().unsqueeze(0) # shape: (1, num_queries, num_classes)
|
255 |
+
tgt_label = target_labels.repeat_interleave(self.query_length, dim=1) # shape: (num_nodes, num_queries, num_classes)
|
256 |
+
cost_matrix = ((tgt_label * label_prob).sum(-1) * label_prob[:, :, 1:].sum(-1)).t().sqrt() # shape: (num_queries, num_nodes)
|
257 |
+
|
258 |
+
return cost_matrix
|
259 |
+
|
260 |
+
def anchor_cost_matrix(self, output, batch, decoder_lens, b: int):
|
261 |
+
if output["anchor"] is None:
|
262 |
+
return 1.0
|
263 |
+
|
264 |
+
num_nodes = batch["labels"][1][b]
|
265 |
+
word_lens = batch["every_input"][1]
|
266 |
+
target_anchors, _ = batch["anchor"]
|
267 |
+
pred_anchors = output["anchor"].sigmoid()
|
268 |
+
|
269 |
+
tgt_align = target_anchors[b, : num_nodes, : word_lens[b]] # shape: (num_nodes, num_inputs)
|
270 |
+
align_prob = pred_anchors[b, : decoder_lens[b], : word_lens[b]] # shape: (num_queries, num_inputs)
|
271 |
+
align_prob = align_prob.unsqueeze(1).expand(-1, num_nodes, -1) # shape: (num_queries, num_nodes, num_inputs)
|
272 |
+
align_prob = torch.where(tgt_align.unsqueeze(0).bool(), align_prob, 1.0 - align_prob) # shape: (num_queries, num_nodes, num_inputs)
|
273 |
+
cost_matrix = align_prob.log().mean(-1).exp() # shape: (num_queries, num_nodes)
|
274 |
+
return cost_matrix
|
model/head/labeled_edge_head.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from model.head.abstract_head import AbstractHead
|
8 |
+
from data.parser.to_mrp.labeled_edge_parser import LabeledEdgeParser
|
9 |
+
from utility.cross_entropy import binary_cross_entropy
|
10 |
+
from utility.hungarian_matching import match_label
|
11 |
+
|
12 |
+
|
13 |
+
class LabeledEdgeHead(AbstractHead):
|
14 |
+
def __init__(self, dataset, args, initialize):
|
15 |
+
config = {
|
16 |
+
"label": True,
|
17 |
+
"edge presence": True,
|
18 |
+
"edge label": True,
|
19 |
+
"anchor": True,
|
20 |
+
"source_anchor": False,
|
21 |
+
"target_anchor": False
|
22 |
+
}
|
23 |
+
super(LabeledEdgeHead, self).__init__(dataset, args, config, initialize)
|
24 |
+
|
25 |
+
self.top_node = nn.Parameter(torch.randn(1, 1, args.hidden_size), requires_grad=True)
|
26 |
+
self.parser = LabeledEdgeParser(dataset)
|
27 |
+
|
28 |
+
def init_label_classifier(self, dataset, args, config, initialize: bool):
|
29 |
+
classifier = nn.Sequential(
|
30 |
+
nn.Dropout(args.dropout_label),
|
31 |
+
nn.Linear(args.hidden_size, 1, bias=True)
|
32 |
+
)
|
33 |
+
if initialize:
|
34 |
+
bias_init = torch.tensor([dataset.label_freqs[1]])
|
35 |
+
classifier[1].bias.data = (bias_init / (1.0 - bias_init)).log()
|
36 |
+
|
37 |
+
return classifier
|
38 |
+
|
39 |
+
def forward_label(self, decoder_output):
|
40 |
+
return self.label_classifier(decoder_output)
|
41 |
+
|
42 |
+
def forward_edge(self, decoder_output):
|
43 |
+
top_node = self.top_node.expand(decoder_output.size(0), -1, -1)
|
44 |
+
decoder_output = torch.cat([top_node, decoder_output], dim=1)
|
45 |
+
return self.edge_classifier(decoder_output)
|
46 |
+
|
47 |
+
def loss_label(self, prediction, target, mask, matching):
|
48 |
+
prediction = prediction["label"]
|
49 |
+
target = match_label(
|
50 |
+
target["labels"][0], matching, prediction.shape[:-1], prediction.device, self.query_length
|
51 |
+
)
|
52 |
+
return {"label": binary_cross_entropy(prediction.squeeze(-1), target.float(), mask, focal=self.focal)}
|
53 |
+
|
54 |
+
def inference_label(self, prediction):
|
55 |
+
return (prediction.squeeze(-1) > 0.0).long()
|
56 |
+
|
57 |
+
def label_cost_matrix(self, output, batch, decoder_lens, b: int):
|
58 |
+
if output["label"] is None:
|
59 |
+
return 1.0
|
60 |
+
|
61 |
+
target_labels = batch["anchored_labels"][b] # shape: (num_nodes, num_inputs, 2)
|
62 |
+
label_prob = output["label"][b, : decoder_lens[b], :].sigmoid().unsqueeze(0) # shape: (1, num_queries, 1)
|
63 |
+
label_prob = torch.cat([1.0 - label_prob, label_prob], dim=-1) # shape: (1, num_queries, 2)
|
64 |
+
tgt_label = target_labels.repeat_interleave(self.query_length, dim=1) # shape: (num_nodes, num_queries, 2)
|
65 |
+
cost_matrix = ((tgt_label * label_prob).sum(-1) * label_prob[:, :, 1:].sum(-1)).t().sqrt() # shape: (num_queries, num_nodes)
|
66 |
+
|
67 |
+
return cost_matrix
|
model/head/node_centric_head.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from model.head.abstract_head import AbstractHead
|
7 |
+
from data.parser.to_mrp.node_centric_parser import NodeCentricParser
|
8 |
+
from utility.cross_entropy import binary_cross_entropy
|
9 |
+
|
10 |
+
|
11 |
+
class NodeCentricHead(AbstractHead):
|
12 |
+
def __init__(self, dataset, args, initialize):
|
13 |
+
config = {
|
14 |
+
"label": True,
|
15 |
+
"edge presence": True,
|
16 |
+
"edge label": False,
|
17 |
+
"anchor": True,
|
18 |
+
"source_anchor": False,
|
19 |
+
"target_anchor": False
|
20 |
+
}
|
21 |
+
super(NodeCentricHead, self).__init__(dataset, args, config, initialize)
|
22 |
+
|
23 |
+
self.source_id = dataset.label_field.vocab.stoi["Source"] + 1
|
24 |
+
self.target_id = dataset.label_field.vocab.stoi["Target"] + 1
|
25 |
+
self.parser = NodeCentricParser(dataset)
|
model/head/sequential_head.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from model.head.abstract_head import AbstractHead
|
9 |
+
from data.parser.to_mrp.sequential_parser import SequentialParser
|
10 |
+
from utility.cross_entropy import cross_entropy
|
11 |
+
|
12 |
+
|
13 |
+
class SequentialHead(AbstractHead):
|
14 |
+
def __init__(self, dataset, args, initialize):
|
15 |
+
config = {
|
16 |
+
"label": True,
|
17 |
+
"edge presence": False,
|
18 |
+
"edge label": False,
|
19 |
+
"anchor": True,
|
20 |
+
"source_anchor": True,
|
21 |
+
"target_anchor": True
|
22 |
+
}
|
23 |
+
super(SequentialHead, self).__init__(dataset, args, config, initialize)
|
24 |
+
self.parser = SequentialParser(dataset)
|
model/model.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from model.module.encoder import Encoder
|
8 |
+
|
9 |
+
from model.module.transformer import Decoder
|
10 |
+
from model.head.node_centric_head import NodeCentricHead
|
11 |
+
from model.head.labeled_edge_head import LabeledEdgeHead
|
12 |
+
from model.head.sequential_head import SequentialHead
|
13 |
+
from utility.utils import create_padding_mask
|
14 |
+
|
15 |
+
|
16 |
+
class Model(nn.Module):
|
17 |
+
def __init__(self, dataset, args, initialize=True):
|
18 |
+
super(Model, self).__init__()
|
19 |
+
self.encoder = Encoder(args, dataset)
|
20 |
+
if args.n_layers > 0:
|
21 |
+
self.decoder = Decoder(args)
|
22 |
+
else:
|
23 |
+
self.decoder = lambda x, *args: x # identity function, which ignores all arguments except the first one
|
24 |
+
|
25 |
+
if args.graph_mode == "sequential":
|
26 |
+
self.head = SequentialHead(dataset, args, initialize)
|
27 |
+
elif args.graph_mode == "node-centric":
|
28 |
+
self.head = NodeCentricHead(dataset, args, initialize)
|
29 |
+
elif args.graph_mode == "labeled-edge":
|
30 |
+
self.head = LabeledEdgeHead(dataset, args, initialize)
|
31 |
+
|
32 |
+
self.query_length = args.query_length
|
33 |
+
self.dataset = dataset
|
34 |
+
self.args = args
|
35 |
+
|
36 |
+
def forward(self, batch, inference=False, **kwargs):
|
37 |
+
every_input, word_lens = batch["every_input"]
|
38 |
+
decoder_lens = self.query_length * word_lens
|
39 |
+
batch_size, input_len = every_input.size(0), every_input.size(1)
|
40 |
+
device = every_input.device
|
41 |
+
|
42 |
+
encoder_mask = create_padding_mask(batch_size, input_len, word_lens, device)
|
43 |
+
decoder_mask = create_padding_mask(batch_size, self.query_length * input_len, decoder_lens, device)
|
44 |
+
|
45 |
+
encoder_output, decoder_input = self.encoder(batch["input"], batch["char_form_input"], batch["input_scatter"], input_len)
|
46 |
+
|
47 |
+
decoder_output = self.decoder(decoder_input, encoder_output, decoder_mask, encoder_mask)
|
48 |
+
|
49 |
+
if inference:
|
50 |
+
return self.head.predict(encoder_output, decoder_output, encoder_mask, decoder_mask, batch)
|
51 |
+
else:
|
52 |
+
return self.head(encoder_output, decoder_output, encoder_mask, decoder_mask, batch)
|
53 |
+
|
54 |
+
def get_params_for_optimizer(self, args):
|
55 |
+
encoder_decay, encoder_no_decay = self.get_encoder_parameters(args.n_encoder_layers)
|
56 |
+
decoder_decay, decoder_no_decay = self.get_decoder_parameters()
|
57 |
+
|
58 |
+
parameters = [{"params": p, "weight_decay": args.encoder_weight_decay} for p in encoder_decay]
|
59 |
+
parameters += [{"params": p, "weight_decay": 0.0} for p in encoder_no_decay]
|
60 |
+
parameters += [
|
61 |
+
{"params": decoder_decay, "weight_decay": args.decoder_weight_decay},
|
62 |
+
{"params": decoder_no_decay, "weight_decay": 0.0},
|
63 |
+
]
|
64 |
+
return parameters
|
65 |
+
|
66 |
+
def get_decoder_parameters(self):
|
67 |
+
no_decay = ["bias", "LayerNorm.weight", "_norm.weight"]
|
68 |
+
decay_params = (p for name, p in self.named_parameters() if not any(nd in name for nd in no_decay) and not name.startswith("encoder.bert") and p.requires_grad)
|
69 |
+
no_decay_params = (p for name, p in self.named_parameters() if any(nd in name for nd in no_decay) and not name.startswith("encoder.bert") and p.requires_grad)
|
70 |
+
|
71 |
+
return decay_params, no_decay_params
|
72 |
+
|
73 |
+
def get_encoder_parameters(self, n_layers):
|
74 |
+
no_decay = ["bias", "LayerNorm.weight", "_norm.weight"]
|
75 |
+
decay_params = [
|
76 |
+
[p for name, p in self.named_parameters() if not any(nd in name for nd in no_decay) and name.startswith(f"encoder.bert.encoder.layer.{n_layers - 1 - i}.") and p.requires_grad] for i in range(n_layers)
|
77 |
+
]
|
78 |
+
no_decay_params = [
|
79 |
+
[p for name, p in self.named_parameters() if any(nd in name for nd in no_decay) and name.startswith(f"encoder.bert.encoder.layer.{n_layers - 1 - i}.") and p.requires_grad] for i in range(n_layers)
|
80 |
+
]
|
81 |
+
|
82 |
+
return decay_params, no_decay_params
|
model/module/__init__.py
ADDED
File without changes
|
model/module/anchor_classifier.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from model.module.biaffine import Biaffine
|
8 |
+
|
9 |
+
|
10 |
+
class AnchorClassifier(nn.Module):
|
11 |
+
def __init__(self, dataset, args, initialize: bool, bias=True, mode="anchor"):
|
12 |
+
super(AnchorClassifier, self).__init__()
|
13 |
+
|
14 |
+
self.token_f = nn.Linear(args.hidden_size, args.hidden_size_anchor)
|
15 |
+
self.label_f = nn.Linear(args.hidden_size, args.hidden_size_anchor)
|
16 |
+
self.dropout = nn.Dropout(args.dropout_anchor)
|
17 |
+
|
18 |
+
if bias and initialize:
|
19 |
+
bias_init = torch.tensor([getattr(dataset, f"{mode}_freq")])
|
20 |
+
bias_init = (bias_init / (1.0 - bias_init)).log()
|
21 |
+
else:
|
22 |
+
bias_init = None
|
23 |
+
|
24 |
+
self.output = Biaffine(args.hidden_size_anchor, 1, bias=bias, bias_init=bias_init)
|
25 |
+
|
26 |
+
def forward(self, label, tokens, encoder_mask):
|
27 |
+
tokens = self.dropout(F.elu(self.token_f(tokens))) # shape: (B, T_w, H)
|
28 |
+
label = self.dropout(F.elu(self.label_f(label))) # shape: (B, T_l, H)
|
29 |
+
anchor = self.output(label, tokens).squeeze(-1) # shape: (B, T_l, T_w)
|
30 |
+
|
31 |
+
anchor = anchor.masked_fill(encoder_mask.unsqueeze(1), float("-inf")) # shape: (B, T_l, T_w)
|
32 |
+
return anchor
|
model/module/biaffine.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
from model.module.bilinear import Bilinear
|
6 |
+
|
7 |
+
|
8 |
+
class Biaffine(nn.Module):
|
9 |
+
def __init__(self, input_dim, output_dim, bias=True, bias_init=None):
|
10 |
+
super(Biaffine, self).__init__()
|
11 |
+
|
12 |
+
self.linear_1 = nn.Linear(input_dim, output_dim, bias=False)
|
13 |
+
self.linear_2 = nn.Linear(input_dim, output_dim, bias=False)
|
14 |
+
|
15 |
+
self.bilinear = Bilinear(input_dim, input_dim, output_dim, bias=bias)
|
16 |
+
if bias_init is not None:
|
17 |
+
self.bilinear.bias.data = bias_init
|
18 |
+
|
19 |
+
def forward(self, x, y):
|
20 |
+
return self.bilinear(x, y) + self.linear_1(x).unsqueeze(2) + self.linear_2(y).unsqueeze(1)
|
model/module/bilinear.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# from https://github.com/NLPInBLCU/BiaffineDependencyParsing/blob/master/modules/biaffine.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
|
7 |
+
class Bilinear(nn.Module):
|
8 |
+
"""
|
9 |
+
使用版本
|
10 |
+
A bilinear module that deals with broadcasting for efficient memory usage.
|
11 |
+
Input: tensors of sizes (N x L1 x D1) and (N x L2 x D2)
|
12 |
+
Output: tensor of size (N x L1 x L2 x O)"""
|
13 |
+
|
14 |
+
def __init__(self, input1_size, input2_size, output_size, bias=True):
|
15 |
+
super(Bilinear, self).__init__()
|
16 |
+
|
17 |
+
self.input1_size = input1_size
|
18 |
+
self.input2_size = input2_size
|
19 |
+
self.output_size = output_size
|
20 |
+
|
21 |
+
self.weight = nn.Parameter(torch.Tensor(input1_size, input2_size, output_size))
|
22 |
+
self.bias = nn.Parameter(torch.Tensor(output_size)) if bias else None
|
23 |
+
|
24 |
+
self.reset_parameters()
|
25 |
+
|
26 |
+
def reset_parameters(self):
|
27 |
+
nn.init.zeros_(self.weight)
|
28 |
+
|
29 |
+
def forward(self, input1, input2):
|
30 |
+
input1_size = list(input1.size())
|
31 |
+
input2_size = list(input2.size())
|
32 |
+
|
33 |
+
intermediate = torch.mm(input1.view(-1, input1_size[-1]), self.weight.view(-1, self.input2_size * self.output_size),)
|
34 |
+
|
35 |
+
input2 = input2.transpose(1, 2)
|
36 |
+
output = intermediate.view(input1_size[0], input1_size[1] * self.output_size, input2_size[2]).bmm(input2)
|
37 |
+
|
38 |
+
output = output.view(input1_size[0], input1_size[1], self.output_size, input2_size[1]).transpose(2, 3)
|
39 |
+
|
40 |
+
if self.bias is not None:
|
41 |
+
output = output + self.bias
|
42 |
+
|
43 |
+
return output
|
model/module/char_embedding.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
|
7 |
+
|
8 |
+
|
9 |
+
class CharEmbedding(nn.Module):
|
10 |
+
def __init__(self, vocab_size: int, embedding_size: int, output_size: int):
|
11 |
+
super(CharEmbedding, self).__init__()
|
12 |
+
|
13 |
+
self.embedding = nn.Embedding(vocab_size, embedding_size, sparse=False)
|
14 |
+
self.layer_norm = nn.LayerNorm(embedding_size)
|
15 |
+
self.gru = nn.GRU(embedding_size, embedding_size, num_layers=1, bidirectional=True)
|
16 |
+
self.out_linear = nn.Linear(2*embedding_size, output_size)
|
17 |
+
self.layer_norm_2 = nn.LayerNorm(output_size)
|
18 |
+
|
19 |
+
def forward(self, words, sentence_lens, word_lens):
|
20 |
+
# input shape: (B, W, C)
|
21 |
+
n_words = words.size(1)
|
22 |
+
sentence_lens = sentence_lens.cpu()
|
23 |
+
sentence_packed = pack_padded_sequence(words, sentence_lens, batch_first=True) # shape: (B*W, C)
|
24 |
+
lens_packed = pack_padded_sequence(word_lens, sentence_lens, batch_first=True) # shape: (B*W)
|
25 |
+
word_packed = pack_padded_sequence(sentence_packed.data, lens_packed.data.cpu(), batch_first=True, enforce_sorted=False) # shape: (B*W*C)
|
26 |
+
|
27 |
+
embedded = self.embedding(word_packed.data) # shape: (B*W*C, D)
|
28 |
+
embedded = self.layer_norm(embedded) # shape: (B*W*C, D)
|
29 |
+
|
30 |
+
embedded_packed = PackedSequence(embedded, word_packed[1], word_packed[2], word_packed[3])
|
31 |
+
_, embedded = self.gru(embedded_packed) # shape: (layers * 2, B*W, D)
|
32 |
+
|
33 |
+
embedded = embedded[-2:, :, :].transpose(0, 1).flatten(1, 2) # shape: (B*W, 2*D)
|
34 |
+
embedded = F.relu(embedded)
|
35 |
+
embedded = self.out_linear(embedded)
|
36 |
+
embedded = self.layer_norm_2(embedded)
|
37 |
+
|
38 |
+
embedded, _ = pad_packed_sequence(
|
39 |
+
PackedSequence(embedded, sentence_packed[1], sentence_packed[2], sentence_packed[3]), batch_first=True, total_length=n_words,
|
40 |
+
) # shape: (B, W, 2*D)
|
41 |
+
|
42 |
+
return embedded # shape: (B, W, 2*D)
|
model/module/edge_classifier.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from model.module.biaffine import Biaffine
|
8 |
+
|
9 |
+
|
10 |
+
class EdgeClassifier(nn.Module):
|
11 |
+
def __init__(self, dataset, args, initialize: bool, presence: bool, label: bool):
|
12 |
+
super(EdgeClassifier, self).__init__()
|
13 |
+
|
14 |
+
self.presence = presence
|
15 |
+
if self.presence:
|
16 |
+
if initialize:
|
17 |
+
presence_init = torch.tensor([dataset.edge_presence_freq])
|
18 |
+
presence_init = (presence_init / (1.0 - presence_init)).log()
|
19 |
+
else:
|
20 |
+
presence_init = None
|
21 |
+
|
22 |
+
self.edge_presence = EdgeBiaffine(
|
23 |
+
args.hidden_size, args.hidden_size_edge_presence, 1, args.dropout_edge_presence, bias_init=presence_init
|
24 |
+
)
|
25 |
+
|
26 |
+
self.label = label
|
27 |
+
if self.label:
|
28 |
+
label_init = (dataset.edge_label_freqs / (1.0 - dataset.edge_label_freqs)).log() if initialize else None
|
29 |
+
n_labels = len(dataset.edge_label_field.vocab)
|
30 |
+
self.edge_label = EdgeBiaffine(
|
31 |
+
args.hidden_size, args.hidden_size_edge_label, n_labels, args.dropout_edge_label, bias_init=label_init
|
32 |
+
)
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
presence, label = None, None
|
36 |
+
|
37 |
+
if self.presence:
|
38 |
+
presence = self.edge_presence(x).squeeze(-1) # shape: (B, T, T)
|
39 |
+
if self.label:
|
40 |
+
label = self.edge_label(x) # shape: (B, T, T, O_1)
|
41 |
+
|
42 |
+
return presence, label
|
43 |
+
|
44 |
+
|
45 |
+
class EdgeBiaffine(nn.Module):
|
46 |
+
def __init__(self, hidden_dim, bottleneck_dim, output_dim, dropout, bias_init=None):
|
47 |
+
super(EdgeBiaffine, self).__init__()
|
48 |
+
self.hidden = nn.Linear(hidden_dim, 2 * bottleneck_dim)
|
49 |
+
self.output = Biaffine(bottleneck_dim, output_dim, bias_init=bias_init)
|
50 |
+
self.dropout = nn.Dropout(dropout)
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
x = self.dropout(F.elu(self.hidden(x))) # shape: (B, T, 2H)
|
54 |
+
predecessors, current = x.chunk(2, dim=-1) # shape: (B, T, H), (B, T, H)
|
55 |
+
edge = self.output(current, predecessors) # shape: (B, T, T, O)
|
56 |
+
return edge
|
model/module/encoder.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# coding=utf-8
|
3 |
+
|
4 |
+
import math
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from transformers import AutoModel
|
11 |
+
from model.module.char_embedding import CharEmbedding
|
12 |
+
|
13 |
+
|
14 |
+
class WordDropout(nn.Dropout):
|
15 |
+
def forward(self, input_tensor):
|
16 |
+
if self.p == 0:
|
17 |
+
return input_tensor
|
18 |
+
|
19 |
+
ones = input_tensor.new_ones(input_tensor.shape[:-1])
|
20 |
+
dropout_mask = torch.nn.functional.dropout(ones, self.p, self.training, inplace=False)
|
21 |
+
|
22 |
+
return dropout_mask.unsqueeze(-1) * input_tensor
|
23 |
+
|
24 |
+
|
25 |
+
class Encoder(nn.Module):
|
26 |
+
def __init__(self, args, dataset):
|
27 |
+
super(Encoder, self).__init__()
|
28 |
+
|
29 |
+
self.dim = args.hidden_size
|
30 |
+
self.n_layers = args.n_encoder_layers
|
31 |
+
self.width_factor = args.query_length
|
32 |
+
|
33 |
+
self.bert = AutoModel.from_pretrained(args.encoder, add_pooling_layer=False)
|
34 |
+
# self.bert._set_gradient_checkpointing(self.bert.encoder, value=True)
|
35 |
+
if args.encoder_freeze_embedding:
|
36 |
+
self.bert.embeddings.requires_grad_(False)
|
37 |
+
self.bert.embeddings.LayerNorm.requires_grad_(True)
|
38 |
+
|
39 |
+
if args.freeze_bert:
|
40 |
+
self.bert.requires_grad_(False)
|
41 |
+
|
42 |
+
self.use_char_embedding = args.char_embedding
|
43 |
+
if self.use_char_embedding:
|
44 |
+
self.form_char_embedding = CharEmbedding(dataset.char_form_vocab_size, args.char_embedding_size, self.dim)
|
45 |
+
self.word_dropout = WordDropout(args.dropout_word)
|
46 |
+
|
47 |
+
self.post_layer_norm = nn.LayerNorm(self.dim)
|
48 |
+
self.subword_attention = nn.Linear(self.dim, 1)
|
49 |
+
|
50 |
+
if self.width_factor > 1:
|
51 |
+
self.query_generator = nn.Linear(self.dim, self.dim * self.width_factor)
|
52 |
+
else:
|
53 |
+
self.query_generator = nn.Identity()
|
54 |
+
|
55 |
+
self.encoded_layer_norm = nn.LayerNorm(self.dim)
|
56 |
+
self.scores = nn.Parameter(torch.zeros(self.n_layers, 1, 1, 1), requires_grad=True)
|
57 |
+
|
58 |
+
def forward(self, bert_input, form_chars, to_scatter, n_words):
|
59 |
+
tokens, mask = bert_input
|
60 |
+
batch_size = tokens.size(0)
|
61 |
+
|
62 |
+
encoded = self.bert(tokens, attention_mask=mask, output_hidden_states=True).hidden_states[1:]
|
63 |
+
encoded = torch.stack(encoded, dim=0) # shape: (12, B, T, H)
|
64 |
+
encoded = self.encoded_layer_norm(encoded)
|
65 |
+
|
66 |
+
if self.training:
|
67 |
+
time_len = encoded.size(2)
|
68 |
+
scores = self.scores.expand(-1, batch_size, time_len, -1)
|
69 |
+
dropout = torch.empty(self.n_layers, batch_size, 1, 1, dtype=torch.bool, device=self.scores.device)
|
70 |
+
dropout.bernoulli_(0.1)
|
71 |
+
scores = scores.masked_fill(dropout, float("-inf"))
|
72 |
+
else:
|
73 |
+
scores = self.scores
|
74 |
+
|
75 |
+
scores = F.softmax(scores, dim=0)
|
76 |
+
encoded = (scores * encoded).sum(0) # shape: (B, T, H)
|
77 |
+
encoded = encoded.masked_fill(mask.unsqueeze(-1) == 0, 0.0) # shape: (B, T, H)
|
78 |
+
|
79 |
+
subword_attention = self.subword_attention(encoded) / math.sqrt(self.dim) # shape: (B, T, 1)
|
80 |
+
subword_attention = subword_attention.expand_as(to_scatter) # shape: (B, T_subword, T_word)
|
81 |
+
subword_attention = subword_attention.masked_fill(to_scatter == 0, float("-inf")) # shape: (B, T_subword, T_word)
|
82 |
+
subword_attention = torch.softmax(subword_attention, dim=1) # shape: (B, T_subword, T_word)
|
83 |
+
subword_attention = subword_attention.masked_fill(to_scatter.sum(1, keepdim=True) == 0, value=0.0) # shape: (B, T_subword, T_word)
|
84 |
+
|
85 |
+
encoder_output = torch.einsum("bsd,bsw->bwd", encoded, subword_attention)
|
86 |
+
encoder_output = self.post_layer_norm(encoder_output)
|
87 |
+
|
88 |
+
if self.use_char_embedding:
|
89 |
+
form_char_embedding = self.form_char_embedding(form_chars[0], form_chars[1], form_chars[2])
|
90 |
+
encoder_output = self.word_dropout(encoder_output) + form_char_embedding
|
91 |
+
|
92 |
+
decoder_input = self.query_generator(encoder_output)
|
93 |
+
decoder_input = decoder_input.view(batch_size, -1, self.width_factor, self.dim).flatten(1, 2) # shape: (B, T*Q, D)
|
94 |
+
|
95 |
+
return encoder_output, decoder_input
|