Upload 53 files
Note: this view is limited to 50 files because the commit contains too many changes.
- tasks/clue/afqmc.py +94 -0
- tasks/clue/cmnli.py +103 -0
- tasks/clue/csl.py +93 -0
- tasks/clue/data.py +69 -0
- tasks/clue/finetune.py +130 -0
- tasks/clue/iflytek.py +94 -0
- tasks/clue/ocnli.py +102 -0
- tasks/clue/tnews.py +95 -0
- tasks/clue/wsc.py +117 -0
- tasks/clue/zc.py +96 -0
- tasks/data_utils.py +118 -0
- tasks/ensemble_classifier.py +149 -0
- tasks/eval_utils.py +250 -0
- tasks/finetune_utils.py +330 -0
- tasks/glue/data.py +69 -0
- tasks/glue/finetune.py +93 -0
- tasks/glue/mnli.py +84 -0
- tasks/glue/qqp.py +101 -0
- tasks/label_dict.py +73 -0
- tasks/main.py +121 -0
- tasks/msdp/README.md +19 -0
- tasks/msdp/evaluate.py +58 -0
- tasks/msdp/main.py +79 -0
- tasks/msdp/metrics.py +77 -0
- tasks/msdp/preprocessing.py +595 -0
- tasks/msdp/prompt.py +322 -0
- tasks/orqa/README.md +36 -0
- tasks/orqa/evaluate_orqa.py +52 -0
- tasks/orqa/evaluate_utils.py +188 -0
- tasks/orqa/supervised/data.py +300 -0
- tasks/orqa/supervised/eval_utils.py +206 -0
- tasks/orqa/supervised/finetune.py +251 -0
- tasks/orqa/unsupervised/nq.py +228 -0
- tasks/orqa/unsupervised/qa_utils.py +177 -0
- tasks/orqa/unsupervised/tokenizers.py +243 -0
- tasks/race/data.py +135 -0
- tasks/race/finetune.py +67 -0
- tasks/vision/classification/classification.py +94 -0
- tasks/vision/classification/eval_utils.py +129 -0
- tasks/vision/finetune_utils.py +312 -0
- tasks/vision/main.py +66 -0
- tasks/vision/segmentation/cityscapes.py +207 -0
- tasks/vision/segmentation/data.py +154 -0
- tasks/vision/segmentation/finetune_segformer.py +251 -0
- tasks/vision/segmentation/finetune_setr.py +225 -0
- tasks/vision/segmentation/metrics.py +594 -0
- tasks/vision/segmentation/seg_heads.py +140 -0
- tasks/vision/segmentation/seg_models.py +92 -0
- tasks/vision/segmentation/transforms.py +433 -0
- tasks/vision/segmentation/utils.py +85 -0
tasks/clue/afqmc.py
ADDED
@@ -0,0 +1,94 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""AFQMC dataset."""

import json

from megatron import print_rank_0
from tasks.data_utils import clean_text
from tasks.label_dict import get_label_dict
from .data import GLUEAbstractDataset

LABELS = get_label_dict("AFQMC")


class AFQMCDataset(GLUEAbstractDataset):

    def __init__(self, name, datapaths, tokenizer, max_seq_length,
                 test_label='0'):
        self.test_label = test_label
        super().__init__('AFQMC', name, datapaths,
                         tokenizer, max_seq_length)

    def process_samples_from_single_path(self, filename):
        """Implement abstract method."""
        print_rank_0(' > Processing {} ...'.format(filename))

        samples = []
        total = 0
        drop_cnt = 0
        first = True
        is_test = False
        with open(filename, 'r') as f:
            lines = [json.loads(line.strip()) for line in f]
        for index, row in enumerate(lines):
            if "id" not in row:
                row["id"] = index
            if first:
                first = False
                is_test = "label" not in row
                if is_test:
                    print_rank_0('    reading {}, {} and {} columns and '
                                 'setting labels to {}'.format(
                                     row["id"], row["sentence1"].strip(),
                                     row["sentence2"].strip(),
                                     self.test_label))
                else:
                    print_rank_0('    reading {}, {}, {}, and {} columns '
                                 '...'.format(
                                     row["id"], row["sentence1"].strip(),
                                     row["sentence2"].strip(),
                                     row["label"].strip()))

            text_a = clean_text(row["sentence1"].strip())
            text_b = clean_text(row["sentence2"].strip())
            unique_id = int(row["id"])

            if is_test:
                label = self.test_label
            else:
                label = row["label"].strip()

            assert len(text_a) > 0
            assert len(text_b) > 0
            assert label in LABELS, "found label {} {}".format(label, row)
            assert unique_id >= 0

            sample = {'text_a': text_a,
                      'text_b': text_b,
                      'label': LABELS[label],
                      'uid': unique_id}
            total += 1
            samples.append(sample)

            if total % 5000 == 0:
                print_rank_0('  > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        print_rank_0(' >> dropped {} samples.'.format(drop_cnt))

        return samples

tasks/clue/cmnli.py
ADDED
@@ -0,0 +1,103 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CMNLI dataset."""

import json

from megatron import print_rank_0
from tasks.data_utils import clean_text
from tasks.label_dict import get_label_dict
from .data import GLUEAbstractDataset

LABELS = get_label_dict("CMNLI")


class CMNLIDataset(GLUEAbstractDataset):

    def __init__(self, name, datapaths, tokenizer, max_seq_length,
                 test_label='contradiction'):
        self.test_label = test_label
        super().__init__('CMNLI', name, datapaths,
                         tokenizer, max_seq_length)

    def process_samples_from_single_path(self, filename):
        """Implement abstract method."""
        print_rank_0(' > Processing {} ...'.format(filename))

        samples = []
        total = 0
        drop_cnt = 0
        first = True
        is_test = False
        with open(filename, 'r') as f:
            lines = [json.loads(line.strip()) for line in f]
        for index, row in enumerate(lines):
            row["id"] = index
            if first:
                first = False
                is_test = "label" not in row
                if is_test:
                    print_rank_0('    reading {}, {} and {} columns and '
                                 'setting labels to {}'.format(
                                     row["id"], row["sentence1"].strip(),
                                     row["sentence2"].strip(),
                                     self.test_label))
                else:
                    print_rank_0('    reading {}, {}, {}, and {} columns '
                                 '...'.format(
                                     row["id"], row["sentence1"].strip(),
                                     row["sentence2"].strip(),
                                     row["label"].strip()))

            text_a = clean_text(row["sentence1"].strip())
            text_b = clean_text(row["sentence2"].strip())
            unique_id = int(row["id"])

            if is_test:
                label = self.test_label
            else:
                label = row["label"].strip()

            # Samples without a gold label are dropped.
            if label == "-":
                drop_cnt += 1
                continue

            assert len(text_a) > 0
            assert len(text_b) > 0
            assert label in LABELS, "found label {} {}".format(label, row)
            assert unique_id >= 0

            sample = {'text_a': text_a,
                      'text_b': text_b,
                      'label': LABELS[label],
                      'uid': unique_id}
            total += 1
            samples.append(sample)

            if total % 5000 == 0:
                print_rank_0('  > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        print_rank_0(' >> dropped {} samples.'.format(drop_cnt))

        return samples

tasks/clue/csl.py
ADDED
@@ -0,0 +1,93 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CSL dataset."""

import json

from megatron import print_rank_0
from tasks.data_utils import clean_text
from tasks.label_dict import get_label_dict
from .data import GLUEAbstractDataset

LABELS = get_label_dict("CSL")


class CSLDataset(GLUEAbstractDataset):

    def __init__(self, name, datapaths, tokenizer, max_seq_length,
                 test_label='0'):
        self.test_label = test_label
        super().__init__('CSL', name, datapaths,
                         tokenizer, max_seq_length)

    def process_samples_from_single_path(self, filename):
        """Implement abstract method."""
        print_rank_0(' > Processing {} ...'.format(filename))

        samples = []
        total = 0
        drop_cnt = 0
        first = True
        is_test = False
        with open(filename, 'r') as f:
            lines = [json.loads(line.strip()) for line in f]
        for index, row in enumerate(lines):
            row["id"] = index
            if first:
                first = False
                is_test = "label" not in row
                if is_test:
                    print_rank_0('    reading {}, {} and {} columns and '
                                 'setting labels to {}'.format(
                                     row["id"],
                                     " ".join(row["keyword"]).strip(),
                                     row["abst"].strip(), self.test_label))
                else:
                    print_rank_0('    reading {}, {}, {}, and {} columns '
                                 '...'.format(
                                     row["id"],
                                     " ".join(row["keyword"]).strip(),
                                     row["abst"].strip(),
                                     row["label"].strip()))

            text_a = clean_text(" ".join(row["keyword"]).strip())
            text_b = clean_text(row["abst"].strip())
            unique_id = int(row["id"])

            if is_test:
                label = self.test_label
            else:
                label = row["label"].strip()

            assert len(text_a) > 0
            assert len(text_b) > 0
            assert label in LABELS, "found label {} {}".format(label, row)
            assert unique_id >= 0

            sample = {'text_a': text_a,
                      'text_b': text_b,
                      'label': LABELS[label],
                      'uid': unique_id}
            total += 1
            samples.append(sample)

            if total % 5000 == 0:
                print_rank_0('  > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        print_rank_0(' >> dropped {} samples.'.format(drop_cnt))

        return samples

tasks/clue/data.py
ADDED
@@ -0,0 +1,69 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GLUE dataset."""

from abc import ABC
from abc import abstractmethod

from torch.utils.data import Dataset

from megatron import print_rank_0
from tasks.data_utils import build_sample
from tasks.data_utils import build_tokens_types_paddings_from_text


class GLUEAbstractDataset(ABC, Dataset):
    """GLUE base dataset class."""

    def __init__(self, task_name, dataset_name, datapaths,
                 tokenizer, max_seq_length):
        # Store inputs.
        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(
            self.task_name, self.dataset_name))
        # Process the files.
        string = '  > paths:'
        for path in datapaths:
            string += ' ' + path
        print_rank_0(string)
        self.samples = []
        for datapath in datapaths:
            self.samples.extend(
                self.process_samples_from_single_path(datapath))
        print_rank_0('  >> total number of samples: {}'.format(
            len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        raw_sample = self.samples[idx]
        ids, types, paddings = build_tokens_types_paddings_from_text(
            raw_sample['text_a'], raw_sample['text_b'],
            self.tokenizer, self.max_seq_length)
        sample = build_sample(ids, types, paddings,
                              raw_sample['label'], raw_sample['uid'])
        return sample

    @abstractmethod
    def process_samples_from_single_path(self, datapath):
        """Abstract method that takes a single path / filename and
        returns a list of dataset samples, each sample being a dict of
            {'text_a': string, 'text_b': string, 'label': int, 'uid': int}
        """
        pass

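For illustration, a minimal sketch (not part of this upload) of a concrete subclass satisfying the contract documented in process_samples_from_single_path; the tab-separated "text_a<TAB>text_b<TAB>label" file format is an assumption for the example only:

    class ToyPairDataset(GLUEAbstractDataset):
        """Hypothetical sentence-pair dataset read from TSV files."""

        def __init__(self, name, datapaths, tokenizer, max_seq_length):
            super().__init__('TOY', name, datapaths, tokenizer, max_seq_length)

        def process_samples_from_single_path(self, datapath):
            # Each returned sample follows the dict contract above.
            samples = []
            with open(datapath, 'r') as f:
                for uid, line in enumerate(f):
                    text_a, text_b, label = line.rstrip('\n').split('\t')
                    samples.append({'text_a': text_a, 'text_b': text_b,
                                    'label': int(label), 'uid': uid})
            return samples
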
tasks/clue/finetune.py
ADDED
@@ -0,0 +1,130 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CLUE finetuning/evaluation."""

from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.model.classification import Classification
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune


def clue_classification(num_classes, Dataset,
                        name_from_datapath_func):

    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()
        tokenizer = get_tokenizer()

        train_dataset = Dataset('training', args.train_data,
                                tokenizer, args.seq_length)
        valid_dataset = Dataset('validation', args.valid_data,
                                tokenizer, args.seq_length)

        return train_dataset, valid_dataset

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()

        print_rank_0('building classification model for {} ...'.format(
            args.task))
        model = Classification(num_classes=num_classes, num_tokentypes=2,
                               pre_process=pre_process,
                               post_process=post_process)

        return model

    def metrics_func_provider():
        """Provide metrics callback function."""
        def single_dataset_provider(datapath):
            args = get_args()
            tokenizer = get_tokenizer()
            name = name_from_datapath_func(datapath)
            return Dataset(name, [datapath], tokenizer, args.seq_length)
        return accuracy_func_provider(single_dataset_provider)

    # Finetune/evaluate.
    finetune(train_valid_datasets_provider, model_provider,
             end_of_epoch_callback_provider=metrics_func_provider)


def main():
    args = get_args()

    if args.task == 'AFQMC':
        num_classes = 2
        from tasks.clue.afqmc import AFQMCDataset as Dataset

        def name_from_datapath(datapath):
            return "afqmc"

    elif args.task == 'CSL':
        num_classes = 2
        from tasks.clue.csl import CSLDataset as Dataset

        def name_from_datapath(datapath):
            return "csl"

    elif args.task == 'IFLYTEK':
        num_classes = 119
        from tasks.clue.iflytek import IFLYTEKDataset as Dataset

        def name_from_datapath(datapath):
            return "iflytek"

    elif args.task == 'OCNLI':
        num_classes = 3
        from tasks.clue.ocnli import OCNLIDataset as Dataset

        def name_from_datapath(datapath):
            return "ocnli"

    elif args.task == 'TNEWS':
        num_classes = 15
        from tasks.clue.tnews import TNEWSDataset as Dataset

        def name_from_datapath(datapath):
            return "tnews"

    elif args.task == 'WSC':
        num_classes = 2
        from tasks.clue.wsc import WSCDataset as Dataset

        def name_from_datapath(datapath):
            return "wsc"

    elif args.task == 'CMNLI':
        num_classes = 3
        from tasks.clue.cmnli import CMNLIDataset as Dataset

        def name_from_datapath(datapath):
            return "cmnli"

    elif args.task == 'ZC':
        num_classes = 2
        from tasks.clue.zc import ZCDataset as Dataset

        def name_from_datapath(datapath):
            return "zc"

    else:
        raise NotImplementedError('CLUE task {} is not implemented.'.format(
            args.task))

    clue_classification(num_classes, Dataset, name_from_datapath)

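The dispatch in main() keys purely off args.task, so wiring in a new CLUE-style task takes a Dataset class plus one more branch. A minimal sketch following the pattern above (the MYTASK name, class count, and module are hypothetical):

    elif args.task == 'MYTASK':
        num_classes = 4  # hypothetical label count
        from tasks.clue.mytask import MyTaskDataset as Dataset  # hypothetical module

        def name_from_datapath(datapath):
            return "mytask"
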
tasks/clue/iflytek.py
ADDED
@@ -0,0 +1,94 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""IFLYTEK dataset."""

import json

from megatron import print_rank_0
from tasks.data_utils import clean_text
from tasks.label_dict import get_label_dict
from .data import GLUEAbstractDataset

LABELS = get_label_dict("IFLYTEK")


class IFLYTEKDataset(GLUEAbstractDataset):

    def __init__(self, name, datapaths, tokenizer, max_seq_length,
                 test_label='0'):
        self.test_label = test_label
        super().__init__('IFLYTEK', name, datapaths,
                         tokenizer, max_seq_length)

    def process_samples_from_single_path(self, filename):
        """Implement abstract method."""
        print_rank_0(' > Processing {} ...'.format(filename))

        samples = []
        total = 0
        drop_cnt = 0
        first = True
        is_test = False
        with open(filename, 'r') as f:
            lines = [json.loads(line.strip()) for line in f]
        for index, row in enumerate(lines):
            if "id" not in row:
                row["id"] = index
            if first:
                first = False
                is_test = "label" not in row
                if is_test:
                    print_rank_0('    reading {}, {} and {} columns and '
                                 'setting labels to {}'.format(
                                     row["id"], row["sentence"].strip(),
                                     None, self.test_label))
                else:
                    print_rank_0('    reading {}, {}, {}, and {} columns '
                                 '...'.format(
                                     row["id"], row["sentence"].strip(),
                                     None, row["label"].strip()))

            text_a = clean_text(row["sentence"].strip())
            text_b = None
            unique_id = int(row["id"])

            if is_test:
                label = self.test_label
            else:
                label = row["label"].strip()

            assert len(text_a) > 0
            assert label in LABELS, "found label {} {}".format(label, row)
            assert unique_id >= 0

            sample = {'text_a': text_a,
                      'text_b': text_b,
                      'label': LABELS[label],
                      'uid': unique_id}
            total += 1
            samples.append(sample)

            if total % 5000 == 0:
                print_rank_0('  > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        print_rank_0(' >> dropped {} samples.'.format(drop_cnt))

        return samples

tasks/clue/ocnli.py
ADDED
@@ -0,0 +1,102 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""OCNLI dataset."""

import json

from megatron import print_rank_0
from tasks.data_utils import clean_text
from tasks.label_dict import get_label_dict
from .data import GLUEAbstractDataset

LABELS = get_label_dict("OCNLI")


class OCNLIDataset(GLUEAbstractDataset):

    def __init__(self, name, datapaths, tokenizer, max_seq_length,
                 test_label='contradiction'):
        self.test_label = test_label
        super().__init__('OCNLI', name, datapaths,
                         tokenizer, max_seq_length)

    def process_samples_from_single_path(self, filename):
        """Implement abstract method."""
        print_rank_0(' > Processing {} ...'.format(filename))

        samples = []
        total = 0
        drop_cnt = 0
        first = True
        is_test = False
        with open(filename, 'r') as f:
            lines = [json.loads(line.strip()) for line in f]
        for index, row in enumerate(lines):
            if first:
                first = False
                is_test = "label" not in row
                if is_test:
                    print_rank_0('    reading {}, {} and {} columns and '
                                 'setting labels to {}'.format(
                                     row["id"], row["sentence1"].strip(),
                                     row["sentence2"].strip(),
                                     self.test_label))
                else:
                    print_rank_0('    reading {}, {}, {}, and {} columns '
                                 '...'.format(
                                     row["id"], row["sentence1"].strip(),
                                     row["sentence2"].strip(),
                                     row["label"].strip()))

            text_a = clean_text(row["sentence1"].strip())
            text_b = clean_text(row["sentence2"].strip())
            unique_id = int(row["id"])

            if is_test:
                label = self.test_label
            else:
                label = row["label"].strip()

            # Samples without a gold label are dropped.
            if label == "-":
                drop_cnt += 1
                continue

            assert len(text_a) > 0
            assert len(text_b) > 0
            assert label in LABELS, "found label {} {}".format(label, row)
            assert unique_id >= 0

            sample = {'text_a': text_a,
                      'text_b': text_b,
                      'label': LABELS[label],
                      'uid': unique_id}
            total += 1
            samples.append(sample)

            if total % 5000 == 0:
                print_rank_0('  > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        print_rank_0(' >> dropped {} samples.'.format(drop_cnt))

        return samples

tasks/clue/tnews.py
ADDED
@@ -0,0 +1,95 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TNEWS dataset."""

import json

from megatron import print_rank_0
from tasks.data_utils import clean_text
from tasks.label_dict import get_label_dict
from .data import GLUEAbstractDataset

LABELS = get_label_dict("TNEWS")


class TNEWSDataset(GLUEAbstractDataset):

    def __init__(self, name, datapaths, tokenizer, max_seq_length,
                 test_label='100'):
        self.test_label = test_label
        super().__init__('TNEWS', name, datapaths,
                         tokenizer, max_seq_length)

    def process_samples_from_single_path(self, filename):
        """Implement abstract method."""
        print_rank_0(' > Processing {} ...'.format(filename))

        samples = []
        total = 0
        drop_cnt = 0
        first = True
        is_test = False
        with open(filename, 'r') as f:
            lines = [json.loads(line.strip()) for line in f]
        for index, row in enumerate(lines):
            if "id" not in row:
                row["id"] = index
            if first:
                first = False
                is_test = "label" not in row
                if is_test:
                    print_rank_0('    reading {}, {} and {} columns and '
                                 'setting labels to {}'.format(
                                     row["id"], row["sentence"].strip(),
                                     None, self.test_label))
                else:
                    print_rank_0('    reading {}, {}, {}, and {} columns '
                                 '...'.format(
                                     row["id"], row["sentence"].strip(),
                                     None, row["label"].strip()))

            text_a = clean_text(row["sentence"].strip())
            text_b = clean_text(row["keywords"].strip())
            unique_id = int(row["id"])

            if is_test:
                label = self.test_label
            else:
                label = row["label"].strip()

            assert len(text_a) > 0
            assert label in LABELS, "found label {} {}".format(label, row)
            assert unique_id >= 0

            sample = {'text_a': text_a,
                      'text_b': text_b,
                      'label': LABELS[label],
                      'uid': unique_id}
            total += 1
            samples.append(sample)

            if total % 5000 == 0:
                print_rank_0('  > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        print_rank_0(' >> dropped {} samples.'.format(drop_cnt))

        return samples

tasks/clue/wsc.py
ADDED
@@ -0,0 +1,117 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""WSC dataset."""

import json

from megatron import print_rank_0
from tasks.data_utils import clean_text
from tasks.label_dict import get_label_dict
from .data import GLUEAbstractDataset

LABELS = get_label_dict("WSC")


class WSCDataset(GLUEAbstractDataset):

    def __init__(self, name, datapaths, tokenizer, max_seq_length,
                 test_label="false"):
        self.test_label = test_label
        super().__init__('WSC', name, datapaths,
                         tokenizer, max_seq_length)

    def process_samples_from_single_path(self, filename):
        """Implement abstract method."""
        print_rank_0(' > Processing {} ...'.format(filename))

        samples = []
        total = 0
        drop_cnt = 0
        first = True
        is_test = False
        with open(filename, 'r') as f:
            lines = [json.loads(line.strip()) for line in f]
        for index, row in enumerate(lines):
            if "id" not in row:
                row["id"] = index
            text_a = row['text']
            text_a_list = list(text_a)
            target = row['target']
            query = target['span1_text']
            query_idx = target['span1_index']
            pronoun = target['span2_text']
            pronoun_idx = target['span2_index']
            assert text_a[pronoun_idx: (pronoun_idx + len(pronoun))] \
                == pronoun, "pronoun: {}".format(pronoun)
            assert text_a[query_idx: (query_idx + len(query))] \
                == query, "query: {}".format(query)
            # Wrap the candidate span in underscores and the pronoun in
            # brackets, shifting indices to account for inserted markers.
            if pronoun_idx > query_idx:
                text_a_list.insert(query_idx, "_")
                text_a_list.insert(query_idx + len(query) + 1, "_")
                text_a_list.insert(pronoun_idx + 2, "[")
                text_a_list.insert(pronoun_idx + len(pronoun) + 2 + 1, "]")
            else:
                text_a_list.insert(pronoun_idx, "[")
                text_a_list.insert(pronoun_idx + len(pronoun) + 1, "]")
                text_a_list.insert(query_idx + 2, "_")
                text_a_list.insert(query_idx + len(query) + 2 + 1, "_")
            text_a = "".join(text_a_list)
            # text_b = "在这句话中,{}指代的是{}".format(pronoun, query)
            text_b = None
            if first:
                first = False
                is_test = "label" not in row
                if is_test:
                    print_rank_0('    reading {}, {} and {} columns and '
                                 'setting labels to {}'.format(
                                     row["id"], text_a,
                                     text_b, self.test_label))
                else:
                    print_rank_0('    reading {}, {}, {}, and {} columns '
                                 '...'.format(
                                     row["id"], text_a,
                                     text_b, row["label"].strip()))

            unique_id = int(row["id"])

            if is_test:
                label = self.test_label
            else:
                label = row["label"].strip()

            assert len(text_a) > 0
            assert label in LABELS, \
                "found label {} {} {}".format(label, row, type(label))
            assert unique_id >= 0

            sample = {'text_a': text_a,
                      'text_b': text_b,
                      'label': LABELS[label],
                      'uid': unique_id}
            total += 1
            samples.append(sample)

            if total % 5000 == 0:
                print_rank_0('  > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        print_rank_0(' >> dropped {} samples.'.format(drop_cnt))

        return samples

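To make the marker arithmetic above concrete, a toy trace (the sentence is invented; indices follow the code):

    # text = "Anna told her"; query = "Anna" (span1_index 0);
    # pronoun = "her" (span2_index 10), so pronoun_idx > query_idx:
    #   insert "_" at 0 and at 0 + len("Anna") + 1   -> "_Anna_ told her"
    #   the pronoun shifted right by the 2 inserted "_", hence the +2:
    #   insert "[" at 10 + 2 and "]" at 10 + len("her") + 2 + 1
    #   -> "_Anna_ told [her]"
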
tasks/clue/zc.py
ADDED
@@ -0,0 +1,96 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ZC dataset."""

import json

from megatron import print_rank_0
from tasks.data_utils import clean_text
from tasks.label_dict import get_label_dict
from .data import GLUEAbstractDataset

LABELS = get_label_dict("ZC")


class ZCDataset(GLUEAbstractDataset):

    def __init__(self, name, datapaths, tokenizer, max_seq_length,
                 test_label='negative'):
        self.test_label = test_label
        super().__init__('ZC', name, datapaths,
                         tokenizer, max_seq_length)

    def process_samples_from_single_path(self, filename):
        """Implement abstract method."""
        print_rank_0(' > Processing {} ...'.format(filename))

        samples = []
        total = 0
        drop_cnt = 0
        first = True
        is_test = False
        with open(filename, 'r') as f:
            lines = [json.loads(line.strip()) for line in f]
        for index, row in enumerate(lines):
            row["id"] = index
            if first:
                first = False
                # The test split carries no labels, so key off the filename.
                is_test = "test.json" in filename
                if is_test:
                    print_rank_0('    reading {}, {} and {} columns and '
                                 'setting labels to {}'.format(
                                     row["id"], row["text"].strip(),
                                     None, self.test_label))
                else:
                    print_rank_0('    reading {}, {}, {}, and {} columns '
                                 '...'.format(
                                     row["id"], row["text"].strip(),
                                     None, row["label"].strip()))

            text_a = clean_text(row["text"].strip())
            text_b = None
            unique_id = int(row["id"])

            if is_test:
                label = self.test_label
            else:
                label = row["label"].strip()

            assert len(text_a) > 0
            assert label in LABELS, "found label {} {}".format(label, row)
            assert unique_id >= 0

            sample = {'text_a': text_a,
                      'text_b': text_b,
                      'label': LABELS[label],
                      'uid': unique_id}
            total += 1
            samples.append(sample)

            if total % 5000 == 0:
                print_rank_0('  > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        print_rank_0(' >> dropped {} samples.'.format(drop_cnt))

        return samples

tasks/data_utils.py
ADDED
@@ -0,0 +1,118 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tasks data utility."""

import re

import numpy as np


def clean_text(text):
    """Remove new lines and multiple spaces and adjust end of sentence dot."""

    text = text.replace("\n", " ")
    text = re.sub(r'\s+', ' ', text)
    for _ in range(3):
        text = text.replace(' . ', '. ')

    return text


def build_sample(ids, types, paddings, label, unique_id):
    """Convert to numpy and return a sample consumed by the batch producer."""

    ids_np = np.array(ids, dtype=np.int64)
    types_np = np.array(types, dtype=np.int64)
    paddings_np = np.array(paddings, dtype=np.int64)
    sample = {'text': ids_np,
              'types': types_np,
              'padding_mask': paddings_np,
              'label': int(label),
              'uid': int(unique_id)}

    return sample


def build_tokens_types_paddings_from_text(text_a, text_b,
                                          tokenizer, max_seq_length):
    """Build token types and paddings, trim if needed, and pad if needed."""

    text_a_ids = tokenizer.tokenize(text_a)
    text_b_ids = None
    if text_b is not None:
        text_b_ids = tokenizer.tokenize(text_b)

    return build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids,
                                                max_seq_length, tokenizer.cls,
                                                tokenizer.sep, tokenizer.pad)


def build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids,
                                         max_seq_length,
                                         cls_id, sep_id, pad_id):
    """Build token types and paddings, trim if needed, and pad if needed."""

    ids = []
    types = []
    paddings = []

    # [CLS].
    ids.append(cls_id)
    types.append(0)
    paddings.append(1)

    # A.
    len_text_a = len(text_a_ids)
    ids.extend(text_a_ids)
    types.extend([0] * len_text_a)
    paddings.extend([1] * len_text_a)

    # [SEP].
    ids.append(sep_id)
    types.append(0)
    paddings.append(1)

    # B.
    if text_b_ids is not None:
        len_text_b = len(text_b_ids)
        ids.extend(text_b_ids)
        types.extend([1] * len_text_b)
        paddings.extend([1] * len_text_b)

    # Cap the size.
    trimmed = False
    if len(ids) >= max_seq_length:
        max_seq_length_m1 = max_seq_length - 1
        ids = ids[0:max_seq_length_m1]
        types = types[0:max_seq_length_m1]
        paddings = paddings[0:max_seq_length_m1]
        trimmed = True

    # [SEP].
    if (text_b_ids is not None) or trimmed:
        ids.append(sep_id)
        if text_b_ids is None:
            types.append(0)
        else:
            types.append(1)
        paddings.append(1)

    # Padding.
    padding_length = max_seq_length - len(ids)
    if padding_length > 0:
        ids.extend([pad_id] * padding_length)
        types.extend([pad_id] * padding_length)
        paddings.extend([0] * padding_length)

    return ids, types, paddings

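A toy trace of build_tokens_types_paddings_from_ids may help; the cls/sep/pad ids below are illustrative BERT-style values, not constants from this repo:

    # Pack "[CLS] A [SEP] B [SEP]" and pad out to max_seq_length.
    ids, types, paddings = build_tokens_types_paddings_from_ids(
        [7, 8], [9], max_seq_length=8, cls_id=101, sep_id=102, pad_id=0)
    assert ids == [101, 7, 8, 102, 9, 102, 0, 0]
    assert types == [0, 0, 0, 0, 1, 1, 0, 0]      # segment ids; pads reuse pad_id
    assert paddings == [1, 1, 1, 1, 1, 1, 0, 0]   # 1 = real token, 0 = pad
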
tasks/ensemble_classifier.py
ADDED
@@ -0,0 +1,149 @@
import argparse
import collections
import os

import numpy as np
import torch


def process_files(args):
    all_predictions = collections.OrderedDict()
    all_labels = collections.OrderedDict()
    all_uid = collections.OrderedDict()
    for path in args.paths:
        path = os.path.join(path, args.prediction_name)
        try:
            data = torch.load(path)
            for dataset in data:
                name, d = dataset
                predictions, labels, uid = d
                if name not in all_predictions:
                    all_predictions[name] = np.array(predictions)
                    if args.labels is None:
                        args.labels = list(range(all_predictions[name].shape[1]))
                    if args.eval:
                        all_labels[name] = np.array(labels)
                    all_uid[name] = np.array(uid)
                else:
                    # Accumulate probabilities across ensemble members.
                    all_predictions[name] += np.array(predictions)
                    assert np.allclose(all_uid[name], np.array(uid))
        except Exception as e:
            print(e)
            continue
    return all_predictions, all_labels, all_uid


def get_threshold(all_predictions, all_labels, one_threshold=False):
    if one_threshold:
        all_predictions = {'combined': np.concatenate(list(all_predictions.values()))}
        all_labels = {'combined': np.concatenate(list(all_labels.values()))}
    out_thresh = []
    for dataset in all_predictions:
        preds = all_predictions[dataset]
        labels = all_labels[dataset]
        out_thresh.append(calc_threshold(preds, labels))
    return out_thresh


def calc_threshold(p, l):
    trials = [i / 100. for i in range(100)]
    best_acc = float('-inf')
    best_thresh = 0
    for t in trials:
        acc = (apply_threshold(p, t).argmax(-1) == l).astype(float).mean()
        if acc > best_acc:
            best_acc = acc
            best_thresh = t
    return best_thresh


def apply_threshold(preds, t):
    assert np.allclose(preds.sum(-1), np.ones(preds.shape[0]))
    prob = preds[:, -1]
    thresholded = (prob >= t).astype(int)
    preds = np.zeros_like(preds)
    preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1
    return preds


def threshold_predictions(all_predictions, threshold):
    if len(threshold) != len(all_predictions):
        # Pad the threshold list with its last value.
        threshold = threshold + \
            [threshold[-1]] * (len(all_predictions) - len(threshold))
    for i, dataset in enumerate(all_predictions):
        thresh = threshold[i]
        preds = all_predictions[dataset]
        all_predictions[dataset] = apply_threshold(preds, thresh)
    return all_predictions


def postprocess_predictions(all_predictions, all_labels, args):
    for d in all_predictions:
        all_predictions[d] = all_predictions[d] / len(args.paths)

    if args.calc_threshold:
        args.threshold = get_threshold(all_predictions, all_labels,
                                       args.one_threshold)
        print('threshold', args.threshold)

    if args.threshold is not None:
        all_predictions = threshold_predictions(all_predictions, args.threshold)

    return all_predictions, all_labels


def write_predictions(all_predictions, all_labels, all_uid, args):
    all_correct = 0
    count = 0
    for dataset in all_predictions:
        preds = all_predictions[dataset]
        preds = np.argmax(preds, -1)
        if args.eval:
            correct = (preds == all_labels[dataset]).sum()
            num = len(all_labels[dataset])
            accuracy = correct / num
            count += num
            all_correct += correct
            print(accuracy)
        if not os.path.exists(os.path.join(args.outdir, dataset)):
            os.makedirs(os.path.join(args.outdir, dataset))
        outpath = os.path.join(
            args.outdir, dataset, os.path.splitext(
                args.prediction_name)[0] + '.tsv')
        with open(outpath, 'w') as f:
            f.write('id\tlabel\n')
            f.write('\n'.join(str(uid) + '\t' + str(args.labels[p])
                              for uid, p in zip(all_uid[dataset],
                                                preds.tolist())))
    if args.eval:
        print(all_correct / count)


def ensemble_predictions(args):
    all_predictions, all_labels, all_uid = process_files(args)
    all_predictions, all_labels = postprocess_predictions(all_predictions,
                                                          all_labels, args)
    write_predictions(all_predictions, all_labels, all_uid, args)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--paths', required=True, nargs='+',
                        help='paths to checkpoint directories used in ensemble')
    parser.add_argument('--eval', action='store_true',
                        help='compute accuracy metrics against labels (dev set)')
    parser.add_argument('--outdir',
                        help='directory to place ensembled predictions in')
    parser.add_argument('--prediction-name', default='test_predictions.pt',
                        help='name of predictions in checkpoint directories')
    parser.add_argument('--calc-threshold', action='store_true',
                        help='calculate threshold for classification')
    parser.add_argument('--one-threshold', action='store_true',
                        help='use one threshold for all subdatasets')
    parser.add_argument('--threshold', nargs='+', default=None, type=float,
                        help='user supplied threshold for classification')
    parser.add_argument('--labels', nargs='+', default=None,
                        help='whitespace separated list of label names')
    args = parser.parse_args()
    ensemble_predictions(args)


if __name__ == '__main__':
    main()

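For reference, a minimal sketch of the on-disk format process_files() expects; the directory name, dataset name, and probabilities are illustrative only:

    import torch

    # Each prediction file is a list of (dataset_name, (probs, labels, uids)).
    data = [('afqmc', ([[0.9, 0.1], [0.2, 0.8]],  # per-sample class probabilities
                       [0, 1],                    # gold labels (read when --eval is set)
                       [0, 1]))]                  # sample uids
    torch.save(data, 'ckpt_a/test_predictions.pt')
    # Then e.g.: python tasks/ensemble_classifier.py --paths ckpt_a ckpt_b \
    #            --eval --outdir ensembled
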
tasks/eval_utils.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation utilities."""

import os
import time
from functools import partial

import torch

from megatron import get_args
from megatron import print_rank_last, is_last_rank
from megatron import mpu
from megatron.schedules import get_forward_backward_func
from tasks.finetune_utils import build_data_loader
from tasks.finetune_utils import process_batch
import json
import numpy as np
from tasks.label_dict import get_label_dict


def accuracy_func_provider(single_dataset_provider):
    """Provide function that calculates accuracies."""
    args = get_args()

    # Build dataloaders.
    datapaths = [args.valid_data[0], args.test_data[0]]
    dataloaders = []
    for datapath in datapaths:
        dataset = single_dataset_provider(datapath)
        dataloader = build_data_loader(
            dataset, args.micro_batch_size, num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1))
        dataloaders.append((dataset.dataset_name, dataloader))

    def _generate_prediction_json(predictions, step, save_acc):

        probs_list = predictions[0]
        # labels_list = predictions[1]
        ids_list = predictions[2]
        min_id = min(ids_list)
        max_id = max(ids_list)
        LABELS = get_label_dict(args.task, write2file=True)
        output_submit_file = os.path.join(
            args.res_path[0],
            args.task + "_prediction_{}_{}.json".format(step, save_acc))
        with open(output_submit_file, "w") as writer:
            for i in range(min_id, max_id + 1):
                label_index = ids_list.index(i)
                pred_prob_list = probs_list[label_index]
                label = pred_prob_list.index(max(pred_prob_list))
                json_d = {}
                if min_id == 1:
                    json_d['id'] = i - 1
                else:
                    json_d['id'] = i
                json_d["label"] = LABELS[str(label)]
                writer.write(json.dumps(json_d) + '\n')

    def _generate_prediction_prob(predictions, step, save_acc):

        probs_list = predictions[0]
        ids_list = predictions[2]
        min_id = min(ids_list)
        max_id = max(ids_list)

        output_prob_file = os.path.join(
            args.res_path[0], args.task + "_prob_{}_{}".format(step, save_acc))
        prob_arr = []
        for i in range(min_id, max_id + 1):
            label_index = ids_list.index(i)
            prob_arr.append(probs_list[label_index])
        prob_arr = np.array(prob_arr)
        np.save(output_prob_file, prob_arr)

    def metrics_func(model, step):
        print_rank_last('calculating metrics ...')
        correct = 0
        total = 0

        for index, (name, dataloader) in enumerate(dataloaders):
            if index == 1:
                output_predictions = True
                assert mpu.get_data_parallel_world_size() == 1
                named_predictions = []
                names = 'predictions'
            else:
                output_predictions = False

            output = calculate_correct_answers(name, model, dataloader,
                                               step, output_predictions)
            if not output_predictions:
                correct_ans, total_count = output
            else:
                correct_ans, total_count, predictions = output
                named_predictions.append((name, predictions))
                names += '_' + name
            if not output_predictions:
                correct += correct_ans
                total += total_count
        # Encode the validation accuracy as a four-character string,
        # e.g. 0.9123 -> '9123'; it is embedded in the output filenames.
        save_acc = str(round(correct / total, 4) * 10000)[:4]

        if output_predictions:
            print_rank_last("generate prediction...")
            _generate_prediction_json(predictions, step, save_acc)
            _generate_prediction_prob(predictions, step, save_acc)
            print_rank_last("generate done")
        # if is_last_rank():
        #     percent = float(correct) * 100.0 / float(total)
        #     print(' >> |step: {}| overall: correct / total = {} / {} = '
        #           '{:.4f} %'.format(step, correct, total, percent))
        # if output_predictions and is_last_rank():
        #     assert args.load is not None
        #     filename = os.path.join(args.load, names + '.pt')
        #     torch.save(named_predictions, filename)

    return metrics_func


def calculate_correct_answers(name, model, dataloader,
                              step, output_predictions):
    """Calculate correct over total answers and return predictions if
    `output_predictions` is true."""
    args = get_args()
    forward_backward_func = get_forward_backward_func()
    start_time = time.time()
    for m in model:
        m.eval()
    saved_micro_batch_size = args.micro_batch_size
    saved_global_batch_size = args.global_batch_size

    ds = dataloader.dataset
    if hasattr(ds, 'sample_multiplier'):
        # If our dataset has a sample_multiplier attribute, it means that
        # each "sample" from the dataset actually has multiple samples
        # that will collapse into the batch dimension (for example in
        # the RACE dataset that has several options), and we need to
        # account for that when setting the micro batch size.
        sample_multiplier = ds.sample_multiplier
    else:
        sample_multiplier = 1
    micro_batch_size_times_data_parallel = args.orig_micro_batch_size * args.data_parallel_size
    num_micro_batches = args.orig_global_batch_size // micro_batch_size_times_data_parallel

    def loss_func(output_predictions, labels, output_tensor):
        logits = output_tensor

        loss_dict = {}
        # Add output predictions.
        if output_predictions:
            loss_dict['softmaxes'] = torch.nn.Softmax(dim=-1)(
                logits.float()).data.cpu().numpy().tolist()
            loss_dict['labels'] = labels.data.cpu().numpy().tolist()
            # `batch` is the current dataloader batch from the enclosing loop.
            loss_dict['ids'] = batch['uid'].cpu().numpy().tolist()
        # Compute the correct answers.
        predicted = torch.argmax(logits, dim=-1)
        corrects = (predicted == labels)
        # Add to the counters.
        loss_dict['total'] = labels.size(0)
        loss_dict['correct'] = corrects.sum().item()

        return 0, loss_dict

    # defined inside to capture output_predictions
    def correct_answers_forward_step(batch, model):
        try:
            batch_ = next(batch)
        except BaseException:
            batch_ = batch
        tokens, types, labels, attention_mask = process_batch(batch_)

        # Forward model.
        args = get_args()
        output_tensor = model(tokens, attention_mask, tokentype_ids=types)

        return output_tensor, partial(loss_func, output_predictions, labels)

    with torch.no_grad():
        # For all the batches in the dataset.
        total = 0
        correct = 0
        if output_predictions:
            # This option is only possible when data parallel size is 1.
            assert mpu.get_data_parallel_world_size() == 1
            softmaxes = []
            labels = []
            ids = []
        for _, batch in enumerate(dataloader):
            # For evaluation-only mode we use drop_last = False to get all the
            # samples, which means we might not have a full batch, so we
            # adjust batch_size here to the actual batch size of the data
            actual_batch_size = len(batch['label'])
            # ... applying sample_multiplier if necessary
            args.micro_batch_size = actual_batch_size * sample_multiplier
            args.global_batch_size = actual_batch_size * sample_multiplier * num_micro_batches

            loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
                                               optimizer=None, timers=None, forward_only=True)

            for loss_dict in loss_dicts:
                if output_predictions:
                    softmaxes.extend(loss_dict['softmaxes'])
                    labels.extend(loss_dict['labels'])
                    ids.extend(loss_dict['ids'])
                total += loss_dict['total']
                correct += loss_dict['correct']

    for m in model:
        m.train()
    args.micro_batch_size = saved_micro_batch_size
    args.global_batch_size = saved_global_batch_size

    # Reduce.
    if mpu.is_pipeline_last_stage():
        unreduced = torch.cuda.LongTensor([correct, total])
        torch.distributed.all_reduce(unreduced,
                                     group=mpu.get_data_parallel_group())

        # Print on screen.
        correct_ans = unreduced[0].item()
        total_count = unreduced[1].item()
        percent = float(correct_ans) * 100.0 / float(total_count)
        elapsed_time = time.time() - start_time
        if not output_predictions:
            print_rank_last(' > |step: {} | metrics for {}: correct / total '
                            '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
                                step, name, correct_ans, total_count,
                                percent, elapsed_time))

        if output_predictions:
            return correct_ans, total_count, (softmaxes, labels, ids)
        return correct_ans, total_count
    if output_predictions:
        return 0, 0, ()
    return 0, 0
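The JSON-writing logic above is easiest to see on toy data. Below is a minimal, dependency-free sketch; the probability rows, uids, output filename, and two-class label dict are hypothetical stand-ins, not values from this repo. It shows how the argmax over each softmax row is mapped back to a dataset label via the inverted label dict, and how 1-based uids are shifted to 0-based ids in the output.

import json

# Hypothetical predictions as collected by metrics_func:
# one softmax row per example, plus the example uids (1-based here).
probs_list = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
ids_list = [1, 2, 3]
# Inverted label dict, as returned by get_label_dict(task, write2file=True).
LABELS = {'0': '0', '1': '1'}

min_id, max_id = min(ids_list), max(ids_list)
with open('demo_prediction.json', 'w') as writer:
    for i in range(min_id, max_id + 1):
        pred_prob_list = probs_list[ids_list.index(i)]
        label = pred_prob_list.index(max(pred_prob_list))
        json_d = {'id': i - 1 if min_id == 1 else i,
                  'label': LABELS[str(label)]}
        writer.write(json.dumps(json_d) + '\n')
# Writes: {"id": 0, "label": "1"}, {"id": 1, "label": "0"}, {"id": 2, "label": "1"}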
tasks/finetune_utils.py
ADDED
@@ -0,0 +1,330 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Finetune utilities."""

from functools import partial
import sys
import torch

from megatron import get_args, get_num_microbatches
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.model import ModelType
from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.training import train_step
from megatron.training import training_log
from megatron.utils import average_losses_across_data_parallel_group
from megatron.utils import calc_params_l2_norm
from megatron.utils import check_adlr_autoresume_termination


def process_batch(batch):
    """Process batch and produce inputs for the model."""
    args = get_args()

    tokens = batch['text'].long().cuda().contiguous()
    types = batch['types'].long().cuda().contiguous()
    labels = batch['label'].long().cuda().contiguous()
    attention_mask = batch['padding_mask'].float().cuda().contiguous()
    if args.fp16:
        attention_mask = attention_mask.half()

    return tokens, types, labels, attention_mask


def cross_entropy_loss_func(labels, output_tensor):
    logits = output_tensor

    # Cross-entropy loss.
    loss_func = torch.nn.CrossEntropyLoss()
    loss = loss_func(logits.contiguous().float(), labels)

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    return loss, {'training loss': averaged_loss[0]}


def _cross_entropy_forward_step(batch, model):
    """Simple forward step with cross-entropy loss."""
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    try:
        batch_ = next(batch)
    except BaseException:
        batch_ = batch
    tokens, types, labels, attention_mask = process_batch(batch_)
    timers('batch-generator').stop()

    # Forward model.
    output_tensor = model(tokens, attention_mask, tokentype_ids=types)

    return output_tensor, partial(cross_entropy_loss_func, labels)


def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
                      task_collate_fn=None):
    """Data loader. Note that batch-size is the local (per GPU) batch-size."""

    # Sampler.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank)

    # Data loader. Note that batch size is the per GPU batch size.
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=micro_batch_size,
                                              sampler=sampler,
                                              shuffle=False,
                                              num_workers=num_workers,
                                              drop_last=drop_last,
                                              pin_memory=True,
                                              collate_fn=task_collate_fn)

    return data_loader


def _build_infinite_size_dataloader(dataloader):
    """Build a looped dataloader with infinite size."""

    iterator = dataloader.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            iterator = dataloader.__iter__()


def _build_train_valid_dataloaders(train_dataset, valid_dataset,
                                   task_collate_fn=None):
    """Training and validation dataloaders."""
    args = get_args()

    print_rank_0('building train and validation dataloaders ...')
    # Training dataset.
    train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
                                         args.num_workers, not args.keep_last,
                                         task_collate_fn)
    # Set the training iterations.
    args.train_iters_per_epoch = len(train_dataloader)
    args.train_iters = args.epochs * args.train_iters_per_epoch
    # Validation dataset. For this dataset, we do not need to set up
    # shuffling so we can just use a simple infinite loop.
    valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
                                          args.num_workers, not args.keep_last,
                                          task_collate_fn)
    valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)

    # Now that we've built the data loaders, set batch_size arguments
    # to the actual batch size the model will see for this dataset.
    # This is necessary so pipeline transfers know what size they are
    # and the LR schedule, which is based on samples seen, gets set
    # correctly.
    args.orig_micro_batch_size = args.micro_batch_size
    args.orig_global_batch_size = args.global_batch_size
    if hasattr(train_dataset, 'sample_multiplier'):
        # If our dataset has a sample_multiplier attribute, it means that
        # each "sample" from the dataset actually has multiple samples
        # that will collapse into the batch dimension (for example in
        # the RACE dataset that has several options), and we need to
        # account for that when setting the micro batch size.
        args.micro_batch_size *= train_dataset.sample_multiplier
        args.global_batch_size *= train_dataset.sample_multiplier

    return train_dataloader, valid_dataloader


def _train(model, optimizer, opt_param_scheduler, forward_step,
           train_dataloader, valid_dataloader, end_of_epoch_callback):
    """Train the model."""
    args = get_args()
    timers = get_timers()

    assert get_num_microbatches() == 1, \
        "finetuning with gradient accumulation doesn't currently work"

    # Turn on training mode which enables dropout.
    for m in model:
        m.train()

    # Tracking loss.
    losses_dict_sum = {}

    # Starting epoch and iteration
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    iteration = args.iteration

    # Memory reporting flag.
    report_memory_flag = True

    # For each remaining epoch
    timers('interval-time').start()
    for epoch in range(start_epoch, args.epochs):
        print_rank_0('working on epoch {} ...'.format(epoch + 1))

        # Set the data loader epoch to shuffle the index iterator.
        train_dataloader.sampler.set_epoch(args.seed + epoch)

        # For all the batches in the dataset.
        for iteration_, batch in enumerate(train_dataloader):

            # Ignore the iterations before starting value
            if iteration_ < start_iteration:
                continue
            # Set to zero so the next epoch does not skip any batches.
            start_iteration = 0

            # Train for one step.
            out = train_step(forward_step, batch, model, optimizer, opt_param_scheduler)

            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
            iteration += 1

            # Logging.
            params_norm = None
            if args.log_params_norm:
                params_norm = calc_params_l2_norm(model)
            report_memory_flag = training_log(losses_dict, losses_dict_sum,
                                              optimizer.param_groups[0]['lr'],
                                              iteration,
                                              optimizer.get_loss_scale().item(),
                                              report_memory_flag, skipped_iter,
                                              grad_norm, params_norm, num_zeros_in_grad, None)

            # Autoresume
            if args.adlr_autoresume and \
               (iteration % args.adlr_autoresume_interval == 0):
                check_adlr_autoresume_termination(iteration, model,
                                                  optimizer, opt_param_scheduler)

            # Checkpointing
            saved_checkpoint = False
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
                saved_checkpoint = True

            # Evaluation
            if args.eval_interval and iteration % args.eval_interval == 0:
                prefix = 'iteration {}'.format(iteration)
                evaluate_and_print_results(prefix, forward_step,
                                           valid_dataloader, model,
                                           iteration, None, False)
                if end_of_epoch_callback is not None:
                    end_of_epoch_callback(model, iteration)
                print_rank_0('-' * 72 + '\n')

            # Exiting based on iterations
            if args.exit_interval and iteration % args.exit_interval == 0:
                if not saved_checkpoint:
                    save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
                torch.distributed.barrier()
                print_rank_0('exiting program at iteration {}'.format(iteration))
                sys.exit()

        # Checkpointing at the end of each epoch.
        if args.save:
            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

        prefix = 'iteration {}'.format(iteration)
        evaluate_and_print_results(prefix, forward_step,
                                   valid_dataloader, model,
                                   iteration, None, False)
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, iteration)
        print_rank_0('-' * 72 + '\n')

    # Callback at the end of each epoch.
    # if end_of_epoch_callback is not None:
    #     end_of_epoch_callback(model, epoch)


def finetune(train_valid_datasets_provider, model_provider,
             model_type=ModelType.encoder_or_decoder,
             forward_step=_cross_entropy_forward_step,
             end_of_epoch_callback_provider=None,
             task_collate_fn=None):
    """Main finetune function used across all tasks."""
    args = get_args()
    timers = get_timers()

    assert args.rampup_batch_size is None, \
        'batch size scaling is not supported for finetuning'

    # Train and validation data loaders.
    timers('train/valid/test dataset/dataloader').start()
    if args.epochs > 0:
        train_dataset, valid_dataset = train_valid_datasets_provider()
        train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
            train_dataset, valid_dataset, task_collate_fn)
    else:
        args.train_iters = 0
    timers('train/valid/test dataset/dataloader').stop()

    # Build callback function.
    timers('callback function').start()
    end_of_epoch_callback = None
    if end_of_epoch_callback_provider is not None:
        end_of_epoch_callback = end_of_epoch_callback_provider()
    timers('callback function').stop()

    # Build model, optimizer and learning rate scheduler.
    timers('model and optimizer').start()
    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, model_type)
    timers('model and optimizer').stop()

    # If a pretrained checkpoint is provided and we have not trained for
    # any iteration (i.e., iteration is zero), then load the pretrained
    # checkpoint.
    timers('pretrained checkpoint').start()
    if args.iteration == 0 and args.pretrained_checkpoint is not None:
        original_load = args.load
        args.load = args.pretrained_checkpoint
        original_rng = args.no_load_rng
        args.no_load_rng = True
        _ = load_checkpoint(model, None, None)
        args.load = original_load
        args.no_load_rng = original_rng
        # This is critical when only the model is loaded. We should make sure
        # main parameters are also updated.
        optimizer.reload_model_params()
    timers('pretrained checkpoint').stop()

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['train/valid/test dataset/dataloader', 'callback function',
                'model and optimizer', 'pretrained checkpoint'])
    print_rank_0('training ...')

    # Finetune the model.
    if args.epochs > 0:
        _train(model, optimizer, opt_param_scheduler, forward_step,
               train_dataloader, valid_dataloader, end_of_epoch_callback)
    # Or just evaluate.
    else:
        raise NotImplementedError(
            'evaluation-only mode (--epochs 0) is not implemented')
        # if end_of_epoch_callback is not None:
        #     print_rank_0('evaluation only mode, setting epoch to -1')
        #     end_of_epoch_callback(model, epoch=-1, output_predictions=True)
    print_rank_0('done :-)')
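`_build_infinite_size_dataloader` above simply restarts the underlying iterator whenever it is exhausted, so validation can draw batches forever. A standalone sketch of the same pattern on a toy dataset (no Megatron required; the five-element dataset is a made-up stand-in):

import torch

dataset = list(range(5))  # toy stand-in for a torch Dataset
loader = torch.utils.data.DataLoader(dataset, batch_size=2)

def infinite(dataloader):
    # Same looping trick as _build_infinite_size_dataloader.
    iterator = iter(dataloader)
    while True:
        try:
            yield next(iterator)
        except StopIteration:
            iterator = iter(dataloader)

stream = infinite(loader)
for _ in range(5):
    # cycles tensor([0, 1]), tensor([2, 3]), tensor([4]), then restarts
    print(next(stream))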
tasks/glue/data.py
ADDED
@@ -0,0 +1,69 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GLUE dataset."""

from abc import ABC
from abc import abstractmethod

from torch.utils.data import Dataset

from megatron import print_rank_0
from tasks.data_utils import build_sample
from tasks.data_utils import build_tokens_types_paddings_from_text


class GLUEAbstractDataset(ABC, Dataset):
    """GLUE base dataset class."""

    def __init__(self, task_name, dataset_name, datapaths,
                 tokenizer, max_seq_length):
        # Store inputs.
        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
                                                             self.dataset_name))
        # Process the files.
        string = '  > paths:'
        for path in datapaths:
            string += ' ' + path
        print_rank_0(string)
        self.samples = []
        for datapath in datapaths:
            self.samples.extend(self.process_samples_from_single_path(datapath))
        print_rank_0('  >> total number of samples: {}'.format(
            len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        raw_sample = self.samples[idx]
        ids, types, paddings = build_tokens_types_paddings_from_text(
            raw_sample['text_a'], raw_sample['text_b'],
            self.tokenizer, self.max_seq_length)
        sample = build_sample(ids, types, paddings,
                              raw_sample['label'], raw_sample['uid'])
        return sample

    @abstractmethod
    def process_samples_from_single_path(self, datapath):
        """Abstract method that takes a single path / filename and
        returns a list of dataset samples, each sample being a dict of
            {'text_a': string, 'text_b': string, 'label': int, 'uid': int}
        """
        pass
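For reference, a hedged sketch of what a concrete subclass has to supply: only `process_samples_from_single_path`, returning dicts with the keys documented above. The 4-column TSV layout and the `ToyPairDataset` name are hypothetical, and running it assumes this repo (and its megatron dependency) is importable:

from tasks.glue.data import GLUEAbstractDataset

class ToyPairDataset(GLUEAbstractDataset):
    """Hypothetical 4-column TSV: uid <tab> text_a <tab> text_b <tab> label."""

    def __init__(self, name, datapaths, tokenizer, max_seq_length):
        super().__init__('TOY', name, datapaths, tokenizer, max_seq_length)

    def process_samples_from_single_path(self, datapath):
        samples = []
        with open(datapath, 'r') as f:
            for line in f:
                uid, text_a, text_b, label = line.rstrip('\n').split('\t')
                samples.append({'text_a': text_a,
                                'text_b': text_b,
                                'label': int(label),
                                'uid': int(uid)})
        return samples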
tasks/glue/finetune.py
ADDED
@@ -0,0 +1,93 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GLUE finetuning/evaluation."""

from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.model.classification import Classification
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune


def glue_classification(num_classes, Dataset,
                        name_from_datapath_func):

    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()
        tokenizer = get_tokenizer()

        train_dataset = Dataset('training', args.train_data,
                                tokenizer, args.seq_length)
        valid_dataset = Dataset('validation', args.valid_data,
                                tokenizer, args.seq_length)

        return train_dataset, valid_dataset

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()

        print_rank_0('building classification model for {} ...'.format(
            args.task))
        model = Classification(num_classes=num_classes, num_tokentypes=2,
                               pre_process=pre_process, post_process=post_process)

        return model

    def metrics_func_provider():
        """Provide metrics callback function."""
        def single_dataset_provider(datapath):
            args = get_args()
            tokenizer = get_tokenizer()

            name = name_from_datapath_func(datapath)
            return Dataset(name, [datapath], tokenizer, args.seq_length)
        return accuracy_func_provider(single_dataset_provider)

    # Finetune/evaluate.
    finetune(train_valid_datasets_provider, model_provider,
             end_of_epoch_callback_provider=metrics_func_provider)


def main():
    args = get_args()

    if args.task == 'MNLI':

        num_classes = 3
        from tasks.glue.mnli import MNLIDataset as Dataset

        def name_from_datapath(datapath):
            return datapath.split('MNLI')[-1].strip(
                '.tsv').strip('/').replace('_', '-')

    elif args.task == 'QQP':

        num_classes = 2
        from tasks.glue.qqp import QQPDataset as Dataset

        def name_from_datapath(datapath):
            return datapath.split('QQP')[-1].strip(
                '.tsv').strip('/').replace('_', '-')

    else:
        raise NotImplementedError('GLUE task {} is not implemented.'.format(
            args.task))

    glue_classification(num_classes, Dataset, name_from_datapath)
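A quick illustration of the `name_from_datapath` helpers above, on a hypothetical path. Note that `str.strip('.tsv')` strips any of the characters `.`, `t`, `s`, `v` from both ends rather than removing a suffix; it happens to produce the intended result for these filenames:

datapath = 'data/glue/MNLI/dev_matched.tsv'  # hypothetical path
name = datapath.split('MNLI')[-1].strip('.tsv').strip('/').replace('_', '-')
print(name)  # 'dev-matched'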
tasks/glue/mnli.py
ADDED
@@ -0,0 +1,84 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MNLI dataset."""

from megatron import print_rank_0
from tasks.data_utils import clean_text
from .data import GLUEAbstractDataset


LABELS = {'contradiction': 0, 'entailment': 1, 'neutral': 2}


class MNLIDataset(GLUEAbstractDataset):

    def __init__(self, name, datapaths, tokenizer, max_seq_length,
                 test_label='contradiction'):
        self.test_label = test_label
        super().__init__('MNLI', name, datapaths,
                         tokenizer, max_seq_length)

    def process_samples_from_single_path(self, filename):
        """Implement abstract method."""
        print_rank_0(' > Processing {} ...'.format(filename))

        samples = []
        total = 0
        first = True
        is_test = False
        with open(filename, 'r') as f:
            for line in f:
                row = line.strip().split('\t')
                if first:
                    first = False
                    if len(row) == 10:
                        is_test = True
                        print_rank_0('   reading {}, {} and {} columns and '
                                     'setting labels to {}'.format(
                                         row[0].strip(), row[8].strip(),
                                         row[9].strip(), self.test_label))
                    else:
                        print_rank_0('   reading {}, {}, {}, and {} columns '
                                     '...'.format(
                                         row[0].strip(), row[8].strip(),
                                         row[9].strip(), row[-1].strip()))
                    continue

                text_a = clean_text(row[8].strip())
                text_b = clean_text(row[9].strip())
                unique_id = int(row[0].strip())
                label = row[-1].strip()
                if is_test:
                    label = self.test_label

                assert len(text_a) > 0
                assert len(text_b) > 0
                assert label in LABELS
                assert unique_id >= 0

                sample = {'text_a': text_a,
                          'text_b': text_b,
                          'label': LABELS[label],
                          'uid': unique_id}
                total += 1
                samples.append(sample)

                if total % 50000 == 0:
                    print_rank_0('  > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        return samples
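To make the row handling above concrete, here is the sample dict built from one hypothetical dev row (pair id in column 0, the two sentences in columns 8 and 9, gold label last; the sentences are made up):

row = ['42', '', '', '', '', '', '', '',
       'A man is playing a guitar.',      # column 8: premise
       'A person plays an instrument.',   # column 9: hypothesis
       'entailment']                      # last column: gold label
LABELS = {'contradiction': 0, 'entailment': 1, 'neutral': 2}
sample = {'text_a': row[8], 'text_b': row[9],
          'label': LABELS[row[-1]], 'uid': int(row[0])}
print(sample['label'], sample['uid'])  # 1 42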
tasks/glue/qqp.py
ADDED
@@ -0,0 +1,101 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""QQP dataset."""

from megatron import print_rank_0
from tasks.data_utils import clean_text
from .data import GLUEAbstractDataset


LABELS = [0, 1]


class QQPDataset(GLUEAbstractDataset):

    def __init__(self, name, datapaths, tokenizer, max_seq_length,
                 test_label=0):
        self.test_label = test_label
        super().__init__('QQP', name, datapaths,
                         tokenizer, max_seq_length)

    def process_samples_from_single_path(self, filename):
        """Implement abstract method."""
        print_rank_0(' > Processing {} ...'.format(filename))

        samples = []
        total = 0
        first = True
        is_test = False
        with open(filename, 'r') as f:
            for line in f:
                row = line.strip().split('\t')
                if first:
                    first = False
                    if len(row) == 3:
                        is_test = True
                        print_rank_0('   reading {}, {}, and {} columns and '
                                     'setting labels to {}'.format(
                                         row[0].strip(), row[1].strip(),
                                         row[2].strip(), self.test_label))
                    else:
                        assert len(row) == 6
                        print_rank_0('   reading {}, {}, {}, and {} columns'
                                     ' ...'.format(
                                         row[0].strip(), row[3].strip(),
                                         row[4].strip(), row[5].strip()))
                    continue

                if is_test:
                    assert len(row) == 3, 'expected length 3: {}'.format(row)
                    uid = int(row[0].strip())
                    text_a = clean_text(row[1].strip())
                    text_b = clean_text(row[2].strip())
                    label = self.test_label
                    assert len(text_a) > 0
                    assert len(text_b) > 0
                else:
                    if len(row) == 6:
                        uid = int(row[0].strip())
                        text_a = clean_text(row[3].strip())
                        text_b = clean_text(row[4].strip())
                        label = int(row[5].strip())
                    else:
                        print_rank_0('***WARNING*** index error, '
                                     'skipping: {}'.format(row))
                        continue
                    if len(text_a) == 0:
                        print_rank_0('***WARNING*** zero length a, '
                                     'skipping: {}'.format(row))
                        continue
                    if len(text_b) == 0:
                        print_rank_0('***WARNING*** zero length b, '
                                     'skipping: {}'.format(row))
                        continue
                assert label in LABELS
                assert uid >= 0

                sample = {'uid': uid,
                          'text_a': text_a,
                          'text_b': text_b,
                          'label': label}
                total += 1
                samples.append(sample)

                if total % 50000 == 0:
                    print_rank_0('  > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        return samples
tasks/label_dict.py
ADDED
@@ -0,0 +1,73 @@

AFQMC_LABELS = {
    '0': '0',
    '1': '1',
}

CSL_LABELS = {
    '0': '0',
    '1': '1',
    '2': '2',
}

IFLYTEK_LABELS = {}
for i in range(119):
    IFLYTEK_LABELS[str(i)] = str(i)

OCNLI_LABELS = {
    'contradiction': '0',
    'entailment': '1',
    'neutral': '2'
}

CMNLI_LABELS = {
    'contradiction': '0',
    'entailment': '1',
    'neutral': '2'
}

# TNEWS has 17 candidate category ids (100-116) of which 105 and 111 are
# unused; the remaining 15 are remapped to contiguous 0-based indices.
TNEWS_LABELS = {}
tnews_list = []
for i in range(17):
    if i == 5 or i == 11:
        continue
    tnews_list.append(i)
for i in range(len(tnews_list)):
    TNEWS_LABELS[str(100 + tnews_list[i])] = str(i)

WSC_LABELS = {
    'true': '0',
    'false': '1',
}

ZC_LABELS = {
    'negative': '0',
    'positive': '1',
}


def get_label_dict(task_name, write2file=False):

    if task_name == "AFQMC":
        label_dict = AFQMC_LABELS
    elif task_name == "CSL":
        label_dict = CSL_LABELS
    elif task_name == "IFLYTEK":
        label_dict = IFLYTEK_LABELS
    elif task_name == "OCNLI":
        label_dict = OCNLI_LABELS
    elif task_name == "TNEWS":
        label_dict = TNEWS_LABELS
    elif task_name == "WSC":
        label_dict = WSC_LABELS
    elif task_name == "CMNLI":
        label_dict = CMNLI_LABELS
    elif task_name == "ZC":
        label_dict = ZC_LABELS
    else:
        raise NotImplementedError(
            'Label dict for task {} is not implemented.'.format(task_name))

    if write2file:
        # Invert the mapping so model class indices map back to dataset labels.
        label_dict = {v: k for k, v in label_dict.items()}

    return label_dict
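A short usage sketch for `get_label_dict`, using the TNEWS tables defined above; `write2file=True` inverts the mapping for writing predictions (assumes this repo is on the Python path):

from tasks.label_dict import get_label_dict

# Dataset label -> model class index (used when reading the data).
tnews = get_label_dict('TNEWS')
print(tnews['100'], tnews['116'])        # 0 14

# Model class index -> dataset label (used when writing predictions).
tnews_inv = get_label_dict('TNEWS', write2file=True)
print(tnews_inv['0'], tnews_inv['14'])   # 100 116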
tasks/main.py
ADDED
@@ -0,0 +1,121 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main tasks functionality."""

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))

from megatron import get_args
from megatron.initialize import initialize_megatron


def get_tasks_args(parser):
    """Provide extra arguments required for tasks."""
    group = parser.add_argument_group(title='tasks')

    group.add_argument('--task', type=str, required=True,
                       help='Task name.')
    group.add_argument('--epochs', type=int, default=None,
                       help='Number of finetuning epochs. Zero results in '
                       'evaluation only.')
    group.add_argument('--pretrained-checkpoint', type=str, default=None,
                       help='Pretrained checkpoint used for finetuning.')
    group.add_argument('--keep-last', action='store_true',
                       help='Keep the last batch (maybe incomplete) in '
                       'the data loader.')
    group.add_argument('--train-data', nargs='+', default=None,
                       help='Whitespace separated paths or corpora names '
                       'for training.')
    group.add_argument('--valid-data', nargs='*', default=None,
                       help='path(s) to the validation data.')
    group.add_argument('--test-data', nargs='*', default=None,
                       help='path(s) to the test data.')
    group.add_argument('--res-path', nargs='*', default=None,
                       help='path(s) to the test result.')
    group.add_argument('--overlapping-eval', type=int, default=32,
                       help='Sliding window for overlapping evaluation.')
    group.add_argument('--strict-lambada', action='store_true',
                       help='Use more difficult formulation of lambada.')
    # Retriever args
    group.add_argument('--qa-data-dev', type=str, default=None,
                       help='Path to the QA dataset dev file.')
    group.add_argument('--qa-data-test', type=str, default=None,
                       help='Path to the QA dataset test file.')

    # Faiss arguments for retriever
    group.add_argument('--faiss-use-gpu', action='store_true',
                       help='Whether to create the FaissMIPSIndex on GPU.')
    group.add_argument('--faiss-match', type=str, default='string',
                       choices=['regex', 'string'],
                       help='Answer matching logic type.')
    group.add_argument('--faiss-topk-retrievals', type=int, default=100,
                       help='Number of blocks to use as top-k during retrieval.')

    # finetune for retriever
    group.add_argument('--eval-micro-batch-size', type=int, default=None,
                       help='Eval batch size per model instance (local batch '
                       'size). Global batch size is local batch size '
                       'times data parallel size.')
    group.add_argument('--train-with-neg', action='store_true',
                       help='Whether to use negative examples during model '
                       'training.')
    group.add_argument('--train-hard-neg', type=int, default=0,
                       help='Number of hard negative examples to use during '
                       'training.')

    # parameters for Av.rank validation method
    # Following options/arguments have been taken directly from DPR codebase
    group.add_argument('--val-av-rank-hard-neg', type=int, default=30,
                       help='Av.rank validation: how many hard negatives to'
                       ' take from each question pool.')
    group.add_argument('--val-av-rank-other-neg', type=int, default=30,
                       help='Av.rank validation: how many other negatives to'
                       ' take from each question pool.')

    return parser


if __name__ == '__main__':

    initialize_megatron(extra_args_provider=get_tasks_args)

    args = get_args()

    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for downstream tasks.")
        exit()

    if args.task == 'RACE':
        from race.finetune import main
    elif args.task in ['MNLI', 'QQP']:
        from glue.finetune import main
    elif args.task in ['AFQMC', 'CSL', 'IFLYTEK', 'OCNLI', 'TNEWS', 'WSC', 'CMNLI', 'ZC']:
        from clue.finetune import main
    elif args.task in ['LAMBADA', 'WIKITEXT103']:
        from zeroshot_gpt.evaluate import main
    elif args.task in ['ICT-ZEROSHOT-NQ', 'RETRIEVER-EVAL']:
        from orqa.evaluate_orqa import main
    elif args.task in ['RET-FINETUNE-NQ']:
        from orqa.supervised.finetune import main
    else:
        raise NotImplementedError('Task {} is not implemented.'.format(
            args.task))

    main()
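The `extra_args_provider` hook used above lets each task graft its own flags onto Megatron's parser before parsing. A minimal sketch of the same pattern with plain argparse; `initialize` and the demo argv are hypothetical stand-ins for `initialize_megatron` and a real command line:

import argparse

def get_tasks_args(parser):
    """Task-specific flags, grafted onto the base parser."""
    group = parser.add_argument_group(title='tasks')
    group.add_argument('--task', type=str, required=True)
    return parser

def initialize(extra_args_provider=None):
    """Hypothetical stand-in for initialize_megatron's argument handling."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1234)
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)
    return parser

parser = initialize(extra_args_provider=get_tasks_args)
args = parser.parse_args(['--task', 'TNEWS'])  # demo argv
print(args.task, args.seed)  # TNEWS 1234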
tasks/msdp/README.md
ADDED
@@ -0,0 +1,19 @@

# Multi-Stage Prompting for Knowledgeable Dialogue Generation

Below we present the steps to run our multi-stage dialogue prompting (MSDP) framework.

## Multi-Stage Dialogue Prompting

### Data Preparation
1. Dataset Download: [Wizard of Wikipedia](https://parl.ai/projects/wizard_of_wikipedia/) and [Wizard of Internet](https://parl.ai/projects/sea/)
2. Data Processing: We provide the script to run the [`data processing`](../../examples/msdp/data_processing.sh) of the datasets.

### Stage-1: Prompting for Knowledge Generation
1. We provide the script to perform the [`first-stage prompting`](../../examples/msdp/prompt_knwl_gen.sh) for the knowledge generation.
2. We provide the [`evaluation script`](../../examples/msdp/eval_knwl_generation.sh) for the automatic evaluation (i.e., F1, BLEU, METEOR, and ROUGE-L) of the knowledge generation.

### Stage-2: Prompting for Response Generation
1. We provide the script to [`prepare the input file`](../../examples/msdp/prep_resp_gen.sh) for the response generation (based on the previously generated knowledge file).
2. We provide the script to perform the [`second-stage prompting`](../../examples/msdp/prompt_resp_gen.sh) for the response generation.
3. We provide the [`evaluation script`](../../examples/msdp/eval_resp_generation.sh) for the automatic evaluation (i.e., F1, KF1, BLEU, METEOR, and ROUGE-L) of the response generation.
tasks/msdp/evaluate.py
ADDED
@@ -0,0 +1,58 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Model evaluation"""

from megatron import get_args
from megatron import print_rank_0
from tasks.msdp.metrics import F1Metric
from tqdm import tqdm


def evaluate_f1(guess_file, answer_file):
    """Evaluating F1 Score"""

    guess_list = []
    print_rank_0('reading %s' % guess_file)
    with open(guess_file, "r") as f:
        for i, line in enumerate(tqdm(f)):
            line = line.strip()
            if "<|endoftext|>" in line:
                line = line.replace("<|endoftext|>", "")
            guess_list.append(line)

    answer_list = []
    print_rank_0('reading %s' % answer_file)
    with open(answer_file, "r") as f:
        for i, line in enumerate(tqdm(f)):
            line = line.strip()
            if line == "no_passages_used":
                line = ""
            answer_list.append(line)

    assert len(guess_list) == len(answer_list), \
        "lengths of guess and answer are different!"

    precision, recall, f1 = F1Metric.compute_all_pairs(guess_list, answer_list)
    print_rank_0('Precision: %.4f; recall: %.4f; f1: %.4f' % (precision, recall, f1))

    print_rank_0('done :-)')


def main():
    args = get_args()

    evaluate_f1(args.guess_file, args.answer_file)
tasks/msdp/main.py
ADDED
@@ -0,0 +1,79 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run multi-stage dialogue prompting (MSDP)."""

import os
import sys
sys.path.append(os.path.abspath(os.path.join(
    os.path.join(os.path.dirname(__file__), os.path.pardir), os.path.pardir)))
from megatron import get_args
from megatron.initialize import initialize_megatron


def get_tasks_args(parser):
    """Provide extra arguments required for tasks."""
    group = parser.add_argument_group(title='tasks')

    # parameters for the knowledgeable dialogue generation
    group.add_argument('--task', type=str, required=True,
                       help='Task name.')
    group.add_argument("--sample-input-file", type=str, default=None,
                       help='Get input from file instead of interactive mode, '
                       'each line is an input.')
    group.add_argument("--sample-output-file", type=str, default=None,
                       help='Output file got from --sample-input-file')
    group.add_argument('--prompt-file', type=str, default=None,
                       help='prompting file')
    group.add_argument('--prompt-type', type=str, default=None,
                       choices=['knowledge', 'response'],
                       help='prompt type (knowledge or response)')
    group.add_argument('--num-prompt-examples', type=int, default=10,
                       help='number of prompt examples')
    group.add_argument('--guess-file', type=str, default=None,
                       help='datapath for generated sentences')
    group.add_argument('--answer-file', type=str, default=None,
                       help='datapath for golden sentences')
    group.add_argument('--out-seq-length', type=int, default=100,
                       help='output sequence length')
    group.add_argument('--api-prompt', default=False, action="store_true",
                       help='setup model api for prompting')
    group.add_argument('--megatron-api-url', type=str, default=None,
                       help='url of the megatron api')

    return parser


if __name__ == '__main__':

    initialize_megatron(extra_args_provider=get_tasks_args)

    args = get_args()

    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for downstream tasks.")
        exit()

    if args.task == 'MSDP-PROMPT':
        from tasks.msdp.prompt import main

    elif args.task == 'MSDP-EVAL-F1':
        from tasks.msdp.evaluate import main

    else:
        raise NotImplementedError('Task {} is not implemented.'.format(
            args.task))

    main()
tasks/msdp/metrics.py
ADDED
@@ -0,0 +1,77 @@

# The following code is adapted from
# https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/metrics.py,
# which is licensed under the MIT license. More details on the license can be
# found at https://github.com/facebookresearch/ParlAI/blob/master/LICENSE.

"""Provides standard metric evaluations for dialog."""

from collections import Counter
from typing import List
import numpy as np
import re

re_art = re.compile(r'\b(a|an|the)\b')
re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')


def normalize_answer(s):
    """
    Lower text and remove punctuation, articles and extra whitespace.
    """
    s = s.lower()
    s = re_punc.sub(' ', s)
    s = re_art.sub(' ', s)
    s = ' '.join(s.split())
    return s


class F1Metric:
    """
    Helper class which computes token-level F1.
    """

    @staticmethod
    def _prec_recall_f1_score(pred_items, gold_items):
        """
        Compute precision, recall and f1 given a set of gold and prediction items.
        :param pred_items: iterable of predicted values
        :param gold_items: iterable of gold values
        :return: tuple (p, r, f1) for precision, recall, f1
        """
        common = Counter(gold_items) & Counter(pred_items)
        num_same = sum(common.values())
        if num_same == 0:
            return 0, 0, 0
        precision = 1.0 * num_same / len(pred_items)
        recall = 1.0 * num_same / len(gold_items)
        f1 = (2 * precision * recall) / (precision + recall)
        return precision, recall, f1

    @staticmethod
    def compute_each_pair(guess: str, answer: str):
        if answer == "":
            return None, None, None
        if guess == "":
            return 0, 0, 0
        g_tokens = normalize_answer(guess).split()
        a_tokens = normalize_answer(answer).split()

        precision, recall, f1 = F1Metric._prec_recall_f1_score(g_tokens, a_tokens)
        return precision, recall, f1

    @staticmethod
    def compute_all_pairs(guesses: List[str], answers: List[str]):
        # additional argument check:
        assert len(guesses) == len(answers)

        precision_list, recall_list, f1_list = [], [], []
        for guess, answer in zip(guesses, answers):
            precision, recall, f1 = F1Metric.compute_each_pair(guess, answer)
            if precision is None or recall is None or f1 is None:
                continue
            precision_list.append(precision)
            recall_list.append(recall)
            f1_list.append(f1)

        return np.mean(precision_list), np.mean(recall_list), np.mean(f1_list)
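A quick end-to-end check of `F1Metric` on one made-up guess/answer pair: normalization drops articles and punctuation, leaving guess tokens ['cat', 'sat'] against answer tokens ['cat', 'sat', 'down'], i.e. precision 2/2, recall 2/3, F1 0.8 (assumes this repo is on the Python path):

from tasks.msdp.metrics import F1Metric

# One hypothetical pair; compute_all_pairs averages over the whole list.
precision, recall, f1 = F1Metric.compute_all_pairs(
    ['The cat sat!'], ['a cat sat down'])
print(round(precision, 3), round(recall, 3), round(f1, 3))  # 1.0 0.667 0.8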
tasks/msdp/preprocessing.py
ADDED
@@ -0,0 +1,595 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets"""

import torch
import argparse
from nltk import word_tokenize
from tqdm import tqdm
import numpy as np
import json

def get_args():
    parser = argparse.ArgumentParser(description="Preprocessing")

    parser.add_argument("--func", type=str, default=None,
                        help="choose to run which function")
    parser.add_argument("--raw_file", type=str, default=None,
                        help="path of the input file")
    parser.add_argument("--processed_file", type=str, default=None,
                        help="path of the output file")
    parser.add_argument("--knwl_ref_file", type=str, default=None,
                        help="path of the knowledge reference file")
    parser.add_argument("--resp_ref_file", type=str, default=None,
                        help="path of the response reference file")
    parser.add_argument("--knwl_gen_file", type=str, default=None,
                        help="path of the generated knowledge file")
    parser.add_argument("--test_file", type=str, default=None,
                        help="path of the test file")
    parser.add_argument("--train_file", type=str, default=None,
                        help="path of the train file")
    parser.add_argument("--model_file", type=str, default=None,
                        help="path of the model file")
    parser.add_argument("--data_type", type=str, default=None,
                        help="data types, choose one out of three types: \
                              wow_seen, wow_unseen, and woi")
    parser.add_argument("--seed", type=int, default=1234,
                        help="random seed")

    args = parser.parse_args()
    return args


def process_wow_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
    """
    This is a function used for processing the wizard of wikipedia (wow) dataset
    Expected processed format:
    topic \t dialogue context \t golden knowledge \t golden response
    """

    # loading the raw data
    print("> Loading data from %s" % raw_file)
    with open(raw_file, "r") as fr:
        dialog_data = json.load(fr)

    print("> Processing data ...")
    fproc = open(processed_file, "w")
    fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
    fresp = open(resp_ref_file, "w") if resp_ref_file else None

    for i, sample in enumerate(tqdm(dialog_data)):
        # get all the dialog data for a single dialog sample
        dialog = sample["dialog"]

        turn_list = []  # collect the dialog history
        # processing for each single dialog sample
        for j, turn in enumerate(dialog):
            # text of each turn
            text = turn["text"]
            if not (text.endswith("?") or text.endswith(".") or text.endswith("!")):
                text = text + "."

            if j == 0:
                # first turn
                turn_list.append(text)
                continue

            speaker = turn["speaker"].lower()
            if "wizard" in speaker:
                checked_sentence = list(turn["checked_sentence"].values())  # knowledge
                checked_passage = list(turn["checked_passage"].values())  # topic

                assert len(checked_sentence) <= 1

                # get the ground truth knowledge
                if len(checked_sentence) > 0:
                    checked_sentence = checked_sentence[0]
                else:
                    checked_sentence = "no_passages_used"

                if len(checked_passage) == 1:
                    checked_passage = checked_passage[0]
                else:
                    checked_passage = "no_passages_used"

                # get the topic
                if checked_passage != "no_passages_used":
                    topic = checked_passage
                else:
                    topic = sample["chosen_topic"]

                dialog_context = " [SEP] ".join(turn_list)
                knowledge = checked_sentence
                response = text
                # add the response into the dialog history
                turn_list.append(response)

                # write to the output files
                fproc.write(topic + "\t" + dialog_context + "\t" + \
                            knowledge + "\t" + response + "\n")

                if fknwl:
                    fknwl.write(knowledge + "\n")
                if fresp:
                    # tokenize for evaluation
                    response = " ".join(word_tokenize(response))
                    fresp.write(response + "\n")

            else:
                assert "apprentice" in speaker
                turn_list.append(text)

    fproc.close()
    if fknwl:
        fknwl.close()
    if fresp:
        fresp.close()


def process_woi_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
    """
    This is a function used for processing the wizard of internet (woi) dataset
    Expected processed format:
    topic \t dialogue context \t golden knowledge \t golden response
    """

    print("> Processing %s" % raw_file)
    fproc = open(processed_file, "w")
    fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
    fresp = open(resp_ref_file, "w") if resp_ref_file else None

    with open(raw_file, "r") as fr:
        for i, line in tqdm(enumerate(fr)):
            # read line by line, each line uses json format
            line = line.strip()
            item_dict = json.loads(line)

            # item_dict is a dictionary
            # its key is the data id, and its value contains all the data content
            item_dict = item_dict.values()
            item_dict = list(item_dict)[0]  # len(item_dict) == 1

            # get the whole dialog data for a single dialog sample
            dialog_data = item_dict['dialog_history']
            length = len(dialog_data)

            turn_list = []  # collect the dialog history
            search_text = ""
            for i in range(length):
                item = dialog_data[i]
                action = item['action']

                if action == "Wizard => SearchAgent":
                    search_text = item['text']

                elif action == "Wizard => Apprentice":
                    if len(turn_list) == 0:
                        # first turn
                        turn = item['text']
                        turn_list.append(turn)
                        continue

                    # get the relevant content
                    contents = item["context"]["contents"]
                    selects = item["context"]["selected_contents"]
                    flag = selects[0][0]
                    selects = selects[1:]
                    assert len(selects) == len(contents)

                    # get the topic
                    if flag:
                        # no knowledge sentence is used for the response
                        topic = "no_topic"
                        knwl_sent = "no_passages_used"
                    else:
                        # we consider the search text as the topic
                        topic = search_text
                        # get the knowledge sentence
                        knwl_sent = ""
                        for content, select in zip(contents, selects):
                            content = content['content']
                            assert len(content) == len(select)
                            for c, s in zip(content, select):
                                if s:
                                    knwl_sent = c
                                    break

                        if knwl_sent == "":
                            # no knowledge is used for the response
                            topic = "no_topic"
                            knwl_sent = "no_passages_used"

                    # get dialogue context, knowledge, and response
                    dialog_context = " [SEP] ".join(turn_list)
                    response = item['text']

                    # processing
                    topic = topic.replace("\n", "").replace("\r", \
                        "").replace("\t", "")
                    dialog_context = dialog_context.replace("\n", "").replace("\r", \
                        "").replace("\t", "")
                    knwl_sent = knwl_sent.replace("\n", "").replace("\r", \
                        "").replace("\t", "")
                    response = response.replace("\n", "").replace("\r", \
                        "").replace("\t", "")

                    if topic != "no_topic":
                        # write to the output files
                        fproc.write(topic + "\t" + dialog_context + "\t" + \
                                    knwl_sent + "\t" + response + "\n")
                        if fknwl:
                            fknwl.write(knwl_sent + "\n")
                        if fresp:
                            # tokenize for evaluation
                            response = " ".join(word_tokenize(response))
                            fresp.write(response + "\n")

                    turn_list.append(response)

                elif action == "Apprentice => Wizard":
                    turn = item['text']
                    turn_list.append(turn)

                else:
                    assert action == "SearchAgent => Wizard", \
                        "Please check whether you have used the correct data!"

    fproc.close()
    if fknwl:
        fknwl.close()
    if fresp:
        fresp.close()


def get_database(test_datapath, train_datapath, data_type):
    """Get the database by topics"""

    assert data_type in ["wow_seen", "wow_unseen", "woi"], \
        "Please input a correct data type!!"

    # get test data topic dictionary
    print("> reading test data from %s" % test_datapath)
    test_topics = {}
    with open(test_datapath, "r") as f:
        for i, line in enumerate(f):
            line = line.strip()
            splits = line.split("\t")
            topic = splits[0]
            test_topics[topic] = True

    print("> reading data from %s" % train_datapath)
    train_data_by_topic = {}
    dialog_data_by_topic = {}
    dialog_examples = []
    with open(train_datapath, "r") as f:
        for i, line in enumerate(f):
            line = line.strip()
            splits = line.split("\t")
            topic = splits[0]
            turns = splits[1].split(" [SEP] ")[-3:]
            knowledge = splits[2]
            response = splits[3]
            # filtering data samples
            if knowledge == "no_passages_used":
                # when no knowledge is used
                continue
            if data_type != "wow_seen" and ("(" in knowledge or ")" in knowledge):
                # when bracket exists in the knowledge
                continue
            if data_type != "wow_seen" and topic not in knowledge:
                # when topic does not exist in the knowledge
                continue

            # get the instance
            last_turn = turns[-1]
            instance = "( " + last_turn + " ) " + topic + " => " + knowledge

            # construct dialog example
            dialog_example = ""
            if data_type != "wow_seen":
                dialog_example += "( " + topic + " ) "
            for i, turn in enumerate(turns):
                if i != 0:
                    dialog_example += " "
                dialog_example += turn

            # check overlaps
            if topic in test_topics:
                if topic not in train_data_by_topic:
                    train_data_by_topic[topic] = [instance]
                else:
                    train_data_by_topic[topic].append(instance)

                if topic not in dialog_data_by_topic:
                    dialog_data_by_topic[topic] = [dialog_example]
                else:
                    dialog_data_by_topic[topic].append(dialog_example)

            else:
                # filtering data samples
                if len(knowledge.split()) > 20:
                    # knowledge is too long
                    continue
                if knowledge.startswith("It") or knowledge.startswith("it") or \
                        knowledge.startswith("This") or knowledge.startswith("this"):
                    continue

                # append all the data into dialogue examples list
                dialog_examples.append((topic, dialog_example, instance))

    return train_data_by_topic, dialog_data_by_topic, dialog_examples


emb_dict = {}
def select_prompts_based_on_similarity(
        query, dialog_list, prompt_list, topic, tokenizer, encoder, topk):
    """Select samples based on the similarity"""

    with torch.no_grad():
        # get the query embeddings
        query_ids = tokenizer.encode(query)
        query_ids = torch.LongTensor([query_ids]).cuda()
        query_emb = encoder(input_ids=query_ids).pooler_output
        query_emb = query_emb[0]

        # calculate embeddings for the samples in the database
        if topic in emb_dict:
            example_embeddings = emb_dict[topic]
            example_embeddings = example_embeddings.cuda()
        else:
            for idx, example in enumerate(dialog_list):
                example_ids = tokenizer.encode(example)
                example_ids = torch.LongTensor([example_ids]).cuda()
                example_emb = encoder(input_ids=example_ids).pooler_output
                if idx == 0:
                    example_embeddings = example_emb
                else:
                    example_embeddings = torch.cat(
                        (example_embeddings, example_emb), dim=0)
            emb_dict[topic] = example_embeddings.cpu()

        # compare the similarity and select the topk samples
        similarity_list = example_embeddings.matmul(query_emb)
        _, indices = torch.topk(similarity_list, k=topk)

        indices = indices.tolist()
        indices = indices[::-1]  # reverse the order
        selected_prompts = []
        for index in indices:
            # index = index.item()
            selected_prompts.append(prompt_list[index])

    return selected_prompts


def prompt_selection_for_knowledge_generation(
        test_datapath, train_datapath, model_path, output_prompt_path, data_type):
    """Selecting prompts for the knowledge generation"""

    print("> Selecting prompts for the knowledge generation")

    train_data_by_topic, dialog_data_by_topic, dialog_examples = \
        get_database(test_datapath, train_datapath, data_type)

    from transformers import DPRQuestionEncoderTokenizer
    print("> loading tokenizer and encoder")
    tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
        'facebook/dpr-question_encoder-single-nq-base')
    encoder = torch.load(model_path).cuda()

    print("> getting dialog embeddings")
    with torch.no_grad():
        for idx, example in tqdm(enumerate(dialog_examples)):
            dialog = example[1]
            dialog_ids = tokenizer.encode(dialog)
            dialog_ids = torch.LongTensor([dialog_ids]).cuda()
            dialog_emb = encoder(input_ids=dialog_ids).pooler_output

            if idx == 0:
                dialog_embeddings = dialog_emb
            else:
                dialog_embeddings = torch.cat((dialog_embeddings, dialog_emb), dim=0)

    print("> reading test data from %s" % test_datapath)
    prompt_list_for_each_sample = []
    with open(test_datapath, "r") as f:
        for i, line in tqdm(enumerate(f)):
            line = line.strip()

            splits = line.split("\t")
            topic = splits[0]
            turns = splits[1].split(" [SEP] ")[-3:]

            # get the query sentence
            query_sent = ""
            if data_type != "seen":
                query_sent += "( " + topic + " ) "
            for i, turn in enumerate(turns):
                if i != 0:
                    query_sent += " "
                query_sent += turn

            if topic not in train_data_by_topic:
                # get the query embedding
                query_ids = tokenizer.encode(query_sent)
                query_ids = torch.LongTensor([query_ids]).cuda()
                query_emb = encoder(input_ids=query_ids).pooler_output
                query_emb = query_emb[0]

                # calculate the similarity
                similarity_list = dialog_embeddings.matmul(query_emb)
                _, indices = torch.sort(similarity_list)
                indices = indices.tolist()
                selected_topics = {}
                selected_prompts = []
                num_prompt = 0
                for index in indices:
                    example = dialog_examples[index]
                    topic_temp = example[0]
                    if topic_temp not in selected_topics:
                        selected_topics[topic_temp] = True
                        selected_prompts.append(example[2])
                        num_prompt += 1
                        if num_prompt == 10:
                            break

                # get the selected samples
                example_list = selected_prompts[::-1]
                key = topic + " " + turns[-1]
                prompt_list_for_each_sample.append({key: example_list})

            else:
                num_data_sample = min(len(train_data_by_topic[topic]), 10)
                total_example_list = train_data_by_topic[topic]

                dialog_list = dialog_data_by_topic[topic]
                assert len(dialog_list) == len(train_data_by_topic[topic])

                # calculate the similarity
                example_list = select_prompts_based_on_similarity(
                    query_sent, dialog_list, total_example_list,
                    topic, tokenizer, encoder, topk=num_data_sample)

                key = topic + " " + turns[-1]
                prompt_list_for_each_sample.append({key: example_list})

    print("writing to %s" % output_prompt_path)
    with open(output_prompt_path, "w") as f:
        for instance in tqdm(prompt_list_for_each_sample):
            json.dump(instance, f)
            f.write("\n")


def prompt_selection_for_response_generation(input_path, output_path, seed):
    """Selecting prompts for the response generation"""

    print("> Selecting prompts for the response generation")
    print("> set random seed")
    np.random.seed(seed)

    prompt_example_list = []
    print("> reading data from %s" % input_path)
    with open(input_path, "r") as f:
        for i, line in tqdm(enumerate(f)):
            line = line.strip()
            splits = line.split("\t")

            # get the topic, context, knowledge and response
            topic = splits[0]
            dialog_context = splits[1]
            knowledge = splits[2]
            response = splits[3]
            turns = dialog_context.split(" [SEP] ")[-3:]
            if knowledge == "no_passages_used":
                continue

            # calculate the overlap ratio
            from nltk import word_tokenize
            knowledge_sent_token_list = word_tokenize(knowledge)
            knowledge_sent_token_dict = {token: True for token in knowledge_sent_token_list}
            knowledge_len = len(knowledge_sent_token_list)
            response_token_list = word_tokenize(response)
            response_len = len(response_token_list)
            num_overlap_token = 0
            accumulator = 0
            for token in response_token_list:
                if token in knowledge_sent_token_dict:
                    accumulator += 1
                else:
                    if accumulator >= 10:
                        num_overlap_token += accumulator
                    accumulator = 0
            if accumulator >= 10:
                num_overlap_token += accumulator

            # filtering the data based on the ratio
            if num_overlap_token > response_len * 0.9 or num_overlap_token < response_len * 0.6:
                continue
            if num_overlap_token < knowledge_len * 0.8:
                continue

            last_turn = " ".join(word_tokenize(turns[-1]))
            knowledge = " ".join(word_tokenize(knowledge))
            response = " ".join(word_tokenize(response))
            prompt_example = ""
            # add dialog context
            prompt_example += "Topic: " + topic + ". "
            prompt_example += "User says: " + last_turn + " "
            prompt_example += "We know that: " + knowledge + " "
            prompt_example += "System replies: " + response

            prompt_example_list.append(prompt_example)

    # shuffle the prompt examples
    np.random.shuffle(prompt_example_list)

    print("> writing to %s" % output_path)
    with open(output_path, "w") as f:
        # f.write("Generate the System's response based on the knowledge sentence:\n")
        for i in tqdm(range(20)):
            example = prompt_example_list[i]
            f.write(example + "\n")


def prepare_input_for_response_generation(test_file, knwl_gen_file, processed_file):
    """Preparing inputs for the response generation"""

    print("> Reading knowledge file from %s" % knwl_gen_file)
    # get the knowledge list
    with open(knwl_gen_file, "r") as f:
        knowledge_list = f.readlines()

    print("> Processing ...")
    with open(test_file, "r") as fr:
        with open(processed_file, "w") as fw:
            for line_num, line in enumerate(tqdm(fr)):
                line = line.strip()
                splits = line.split("\t")
                # prepare topic, context, knowledge and response
                topic = splits[0]
                dialog_context = splits[1]
                response = splits[3]
                knowledge = knowledge_list[line_num]
                knowledge = knowledge.strip()
                if "<|endoftext|>" in knowledge:
                    knowledge = knowledge.replace("<|endoftext|>", "")

                # write to the output file
                fw.write(topic + "\t" + dialog_context + "\t" \
                         + knowledge + "\t" + response + "\n")


if __name__ == "__main__":

    args = get_args()
    if args.func == "process_wow_dataset":
        process_wow_dataset(args.raw_file, args.processed_file, args.knwl_ref_file, args.resp_ref_file)

    elif args.func == "process_woi_dataset":
        process_woi_dataset(args.raw_file, args.processed_file, args.knwl_ref_file, args.resp_ref_file)

    elif args.func == "get_knwl_gen_prompts":
        prompt_selection_for_knowledge_generation(
            args.test_file, args.train_file, args.model_file,
            args.processed_file, args.data_type)

    elif args.func == "get_resp_gen_prompts":
        prompt_selection_for_response_generation(
            args.train_file, args.processed_file, args.seed)

    elif args.func == "prepare_input":
        prepare_input_for_response_generation(
            args.test_file, args.knwl_gen_file, args.processed_file)
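A minimal sketch of driving the WoW step above directly from Python rather than via `--func process_wow_dataset`; every file path here is a hypothetical placeholder, not one shipped with the commit.

from tasks.msdp.preprocessing import process_wow_dataset

# Emits one tab-separated line per wizard turn:
#   topic \t dialogue context \t golden knowledge \t golden response
# plus optional knowledge/response reference files used for evaluation.
process_wow_dataset(
    raw_file="data/wow/test_random_split.json",    # hypothetical path
    processed_file="data/wow/test_processed.txt",  # hypothetical path
    knwl_ref_file="data/wow/test_knwl_ref.txt",    # hypothetical path
    resp_ref_file="data/wow/test_resp_ref.txt")    # hypothetical path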
tasks/msdp/prompt.py
ADDED
@@ -0,0 +1,322 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Prompting the pretrained language model to generate knowledge/response"""

import json
import torch
import requests
from nltk import word_tokenize
from megatron import mpu
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.text_generation import generate_and_post_process


def call_model_api(inputs, tokens_to_generate):
    """Calling the model api to get the output generations"""

    args = get_args()

    # The following is an example of using the Megatron API
    # You can also implement your own API function to place this part
    headers = {'Content-Type': 'application/json; charset=UTF-8'}
    data = {"prompts": [inputs], "tokens_to_generate": tokens_to_generate, "top_k": 1}
    data_json = json.dumps(data)
    outputs = requests.put(args.megatron_api_url, headers=headers, data=data_json).json()["text"][0]

    input_len = len(inputs)
    outputs = outputs[input_len:]
    outputs = outputs.split("\n")[0].strip()

    return outputs


def read_prompts(prompt_path, prompt_type, n_example):
    """Read prompt data"""

    if prompt_type == "knowledge":
        # prompts for the knowledge generation
        prompt_examples_dict = {}
        # read prompt_path
        with open(prompt_path, "r") as f:
            for i, line in enumerate(f):
                line = line.strip()
                line_dict = json.loads(line)
                key = list(line_dict.keys())[0]

                if key not in prompt_examples_dict:
                    prompt_examples = line_dict[key]
                    prompt = ""
                    for instance in prompt_examples:
                        instance = instance.strip()
                        prompt += instance + " \n"
                    prompt_examples_dict[key] = prompt

        return prompt_examples_dict

    else:
        # prompts for the response generation
        # read prompt_path
        prompt = ""
        with open(prompt_path, "r") as f:
            prompt_examples = f.readlines()
            prompt_examples = prompt_examples[:n_example]
            for instance in prompt_examples:
                instance = instance.strip()
                prompt += instance + " \n"

        return prompt


def generate_samples_by_calling_api():
    """Generate outputs by calling the model API"""
    args = get_args()
    assert args.prompt_type in ["knowledge", "response"], \
        "Please input a correct prompt type!"

    if args.prompt_type == "knowledge":
        # read knowledge generation prompts
        knwl_gen_prompt_dict = read_prompts(
            args.prompt_file, args.prompt_type, args.num_prompt_examples)

    else:
        resp_gen_prompt = read_prompts(
            args.prompt_file, args.prompt_type, args.num_prompt_examples)

    # read the test data
    fname = open(args.sample_input_file, "r")
    test_sample_list = fname.readlines()
    # create output file
    fname_out = open(args.sample_output_file, "w")

    # call the api to get the output generations
    for test_sample in test_sample_list:
        test_sample = test_sample.strip()
        splits = test_sample.split("\t")
        topic = splits[0]

        # prepare the inputs for the api
        if args.prompt_type == "knowledge":
            # inputs = prompt + current test
            # get the prompt
            turns = splits[1].split(" [SEP] ")
            last_turn = turns[-1]
            key = topic + " " + last_turn
            inputs = knwl_gen_prompt_dict[key]

            # add current test
            inputs += "( " + last_turn + " ) " + topic + " =>"

        else:
            # inputs = prompt + current test
            # get the prompt
            inputs = resp_gen_prompt

            # add current test
            turns = splits[1].split(" [SEP] ")
            knowledge = splits[2]
            last_turn = turns[-1]
            last_turn = " ".join(word_tokenize(last_turn))
            knowledge = " ".join(word_tokenize(knowledge))
            knowledge = knowledge.strip()
            last_turn = last_turn.strip()
            inputs += "Topic: " + topic + ". "
            inputs += "User says: " + last_turn + " "
            inputs += "We know that: " + knowledge + " "
            inputs += "System replies:"

        # get the output generations from the api,
        # and write to the output file
        generations = call_model_api(inputs, args.out_seq_length)
        fname_out.write(generations)
        fname_out.write("\n")

    fname.close()
    fname_out.close()


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building GPT model ...')
    model = GPTModel(
        num_tokentypes=0,
        parallel_output=True,
        pre_process=pre_process,
        post_process=post_process
    )
    return model


def generate_samples_by_prompting_input_from_file(model):
    """Prompt a pretrained language model to generate knowledge/response"""

    # get tokenizer
    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file

        fname_out = open(sample_output_file, "w")

    # only two prompt types (i.e., knowledge and response) are allowed
    assert args.prompt_type in ["knowledge", "response"], \
        "Please input a correct prompt type!"

    # Read the prompt file
    if args.prompt_type == "knowledge":
        # read the prompts for the knowledge generation
        prompt_examples_dict = {}
        with open(args.prompt_file, "r") as f:
            for i, line in enumerate(f):
                line = line.strip()
                line_dict = json.loads(line)
                key = list(line_dict.keys())[0]

                # get the prompt examples based on the key
                if key not in prompt_examples_dict:
                    prompt_examples = line_dict[key]
                    prompt = ""
                    for instance in prompt_examples:
                        instance = instance.strip()
                        prompt += instance + " \n"
                    prompt_examples_dict[key] = prompt

    else:
        # read the prompts for the response generation
        # prompts are fixed for all test samples
        with open(args.prompt_file, "r") as f:
            prompt_examples = f.readlines()
            prompt_examples = prompt_examples[:args.num_prompt_examples]

            prompt = ""
            for instance in prompt_examples:
                instance = instance.strip()
                prompt += instance + " \n"

    input_pos = 0
    model.eval()
    # perform prompting
    with torch.no_grad():
        while True:
            raw_text_len = 0
            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                input_str = all_raw_text[input_pos]
                input_str = input_str.strip()
                splits = input_str.split("\t")
                topic = splits[0]

                if args.prompt_type == "knowledge":
                    # first add the prompt into the raw_text
                    turns = splits[1].split(" [SEP] ")
                    last_turn = turns[-1]
                    key = topic + " " + last_turn
                    raw_text = prompt_examples_dict[key]

                    # construct inputs for knowledge generation
                    # then add the constructed inputs into the raw_text
                    raw_text += "( " + last_turn + " ) " + topic + " =>"

                else:
                    # first add the prompt into the raw_text
                    raw_text = prompt

                    # construct inputs for response generation
                    # then add the constructed inputs into the raw_text
                    turns = splits[1].split(" [SEP] ")
                    knowledge = splits[2]
                    last_turn = turns[-1]
                    last_turn = " ".join(word_tokenize(last_turn))
                    knowledge = " ".join(word_tokenize(knowledge))
                    knowledge = knowledge.strip()
                    last_turn = last_turn.strip()
                    raw_text += "Topic: " + topic + ". "
                    raw_text += "User says: " + last_turn + " "
                    raw_text += "We know that: " + knowledge + " "
                    raw_text += "System replies:"

                input_pos += 1
                raw_text_len = len(raw_text)

            else:
                raw_text = "EMPTY TEXT"

            if input_pos % 100 == 0:
                print_rank_0("input_pos: %d" % input_pos)

            outputs = generate_and_post_process(
                model=model,
                prompts=[raw_text],
                tokens_to_generate=args.out_seq_length,
                top_k_sampling=1)
            prompts_plus_generations = outputs[0]
            prompts_plus_generations = prompts_plus_generations[0]

            # write the generated output to the output file
            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():

                    generations = prompts_plus_generations[raw_text_len:]
                    generations = generations.split("\n")[0]
                    generations = generations.strip()
                    fname_out.write(generations)
                    fname_out.write("\n")

            raw_text = None
            if input_pos == input_count:
                return


def main():

    args = get_args()
    if args.api_prompt:
        # obtain the generations by calling the api
        generate_samples_by_calling_api()
        return

    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()

    # Set up model and load checkpoint.
    model = get_model(model_provider, wrap_with_ddp=False)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]

    # perform the prompting
    generate_samples_by_prompting_input_from_file(model)
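For reference, a standalone sketch of the request that `call_model_api` issues. The URL is a placeholder for whatever `--megatron-api-url` points at, and the prompt string is illustrative; the payload shape mirrors the code above (greedy decoding via `top_k = 1`).

import json
import requests

# Mirrors call_model_api: PUT a JSON payload to a running Megatron
# text-generation server. URL and prompt are placeholders.
headers = {'Content-Type': 'application/json; charset=UTF-8'}
data = {"prompts": ["( Hello ! ) music =>"],
        "tokens_to_generate": 64, "top_k": 1}
response = requests.put("http://localhost:5000/api",
                        headers=headers, data=json.dumps(data))
print(response.json()["text"][0])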
tasks/orqa/README.md
ADDED
@@ -0,0 +1,36 @@
## End-to-End Training of Neural Retrievers for Open-Domain Question Answering

Below we present the steps to run unsupervised and supervised training and evaluation of the retriever for [open domain question answering](https://arxiv.org/abs/2101.00408).

## Retriever Training

#### Unsupervised pretraining
1. Use `tools/preprocess_data.py` to preprocess the dataset for Inverse Cloze Task (ICT), which we call unsupervised pretraining. This script takes as input a corpus in loose JSON format and creates fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block and multiple blocks per document. Run [`tools/preprocess_data.py`](../../tools/preprocess_data.py) to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. We construct two datasets, one with the title of every document and another with the body. A rough sketch of what an ICT training example looks like follows the command below.

<pre>
python tools/preprocess_data.py \
    --input /path/to/corpus.json \
    --json-keys text title \
    --split-sentences \
    --tokenizer-type BertWordPieceLowerCase \
    --vocab-file /path/to/vocab.txt \
    --output-prefix corpus_indexed \
    --workers 10
</pre>
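As an illustration only (assumed logic, not this repository's actual sampling code), an ICT example pairs one sentence of a block with the rest of that block, so the query encoder learns to retrieve the context a sentence came from:

<pre>
import random

# Hypothetical sketch of Inverse Cloze Task example construction:
# one sentence of a block becomes the query, and the remaining
# sentences of the same block become the context to retrieve.
def make_ict_example(block_sentences):
    idx = random.randrange(len(block_sentences))
    query = block_sentences[idx]
    context = " ".join(block_sentences[:idx] + block_sentences[idx + 1:])
    return query, context
</pre>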

2. The [`examples/pretrain_ict.sh`](../../examples/pretrain_ict.sh) script runs a single-GPU 217M-parameter biencoder model for ICT retriever training. Single-GPU training is primarily intended for debugging purposes, as the code is developed for distributed training. The script uses a pretrained BERT model and we use a total batch size of 4096 for the ICT training.

3. Evaluate the pretrained ICT model using [`examples/evaluate_retriever_nq.sh`](../../examples/evaluate_retriever_nq.sh) for [Google's Natural Questions Open dataset](https://arxiv.org/pdf/1906.00300.pdf).

#### Supervised finetuning

1. Use the above pretrained ICT model to finetune using [Google's Natural Questions Open dataset](https://github.com/google-research/language/tree/master/language/orqa). The script [`examples/finetune_retriever_distributed.sh`](../../examples/finetune_retriever_distributed.sh) provides an example for how to perform the training. Our finetuning process includes retriever score scaling and longer training (80 epochs) on top of [DPR training](https://arxiv.org/abs/2004.04906).

2. Evaluate the finetuned model using the same evaluation script as mentioned above for the unsupervised model.

More details on the retriever are available in [our paper](https://arxiv.org/abs/2101.00408).

## Reader Training

The reader component will be available soon.
tasks/orqa/evaluate_orqa.py
ADDED
@@ -0,0 +1,52 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main tasks functionality."""

from megatron import get_args, print_rank_0
from megatron.indexer import IndexBuilder
from tasks.orqa.evaluate_utils import ORQAEvaluator

def main():
    """
    Main program
    """

    args = get_args()

    # Create a BlockData data structure by running an IndexBuilder over an
    # ICT Dataset and then evaluate on the NQ task

    print_rank_0("Starting index builder!")

    index_builder = IndexBuilder()
    index_builder.build_and_save_index()
    print_rank_0("Build and save indices: done!")


    print_rank_0("Starting evaluations!")

    # Set up the model and evaluator
    evaluator = ORQAEvaluator()

    # Run evaluation
    if args.qa_data_dev is not None:
        evaluator.evaluate(args.qa_data_dev, "DEV")

    if args.qa_data_test is not None:
        evaluator.evaluate(args.qa_data_test, "TEST")
tasks/orqa/evaluate_utils.py
ADDED
@@ -0,0 +1,188 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from megatron import get_args, print_rank_0
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex
from megatron.model.biencoder_model import get_model_provider
from megatron.training import get_model
from tasks.orqa.unsupervised.nq import get_nq_dataset
from tasks.orqa.unsupervised.nq import get_one_epoch_nq_dataloader
from tasks.orqa.unsupervised.nq import process_nq_batch
from tasks.orqa.unsupervised.qa_utils import calculate_matches


class ORQAEvaluator(object):
    def __init__(self):
        args = get_args()
        self.embedding_size = args.hidden_size
        self.faiss_use_gpu = args.faiss_use_gpu
        self.evidence_embedder_obj = None
        self.evidence_dataset = None
        self.mips_index = None
        self.eval_dataset = None

        # Get Evidence (Wikipedia) dataset
        self.get_evidence_dataset()

        # Load query encoder checkpoint
        only_query_model = True
        if args.biencoder_shared_query_context_model:
            only_query_model = False

        model = get_model(get_model_provider(only_query_model=only_query_model,
            biencoder_shared_query_context_model=args.biencoder_shared_query_context_model))

        self.model = load_biencoder_checkpoint(model,
            only_query_model=only_query_model)

        assert len(self.model) == 1
        self.model[0].eval()

        # Load faiss indexer
        self.faiss_wrapper()

    def get_evidence_embedding(self):
        # This will load the embedding from the embedding path
        self.evidence_embedder_obj = OpenRetreivalDataStore(load_from_path=True)

    def get_evidence_dataset(self):
        self.evidence_dataset = get_open_retrieval_wiki_dataset()

    def faiss_wrapper(self):
        # Initialize the FAISS wrapper on local rank = 0, as the evidence
        # embeddings are distributed over all the GPUs in a node and FAISS
        # is not thread-safe
        args = get_args()
        if args.local_rank == 0:
            # Get evidence embeddings computed using context encoder
            self.get_evidence_embedding()

            assert self.evidence_embedder_obj is not None
            self.mips_index = FaissMIPSIndex(embed_size=self.embedding_size,
                                             embed_data=self.evidence_embedder_obj,
                                             use_gpu=self.faiss_use_gpu)

        # Wait for the FAISS index to be initialized in all the nodes
        torch.distributed.barrier()

    def generate_query_vectors(self, qa_data, split):

        self.eval_dataset = get_nq_dataset(qa_data, split)
        dataloader = get_one_epoch_nq_dataloader(self.eval_dataset)

        query_vectors = []
        reference_list = []

        for batch in dataloader:
            # batch also has query_tokens and query_pad_data
            query_tokens, query_mask, query_types, \
                query_len, reference = process_nq_batch(batch)

            assert len(self.model) == 1
            unwrapped_model = self.model[0]
            while not hasattr(unwrapped_model, 'embed_text'):
                unwrapped_model = unwrapped_model.module

            with torch.no_grad():
                query_logits = unwrapped_model.embed_text(
                    unwrapped_model.query_model, query_tokens,
                    query_mask, query_types)

            reference_list.extend(reference)
            query_vectors.extend(query_logits.split(1, dim=0))
            if len(query_vectors) % 100 == 0:
                print_rank_0('Encoded queries {}'.format(len(query_vectors)))

        query_tensor = torch.cat(query_vectors, dim=0)
        print_rank_0('Total encoded queries tensor {}'.format(query_tensor.size()))

        assert query_tensor.size(0) == len(self.eval_dataset)
        return query_tensor, reference_list

    def evaluate(self, qa_data, split):
        args = get_args()
        query_tensor, reference_list = self.generate_query_vectors(qa_data, \
            split)
        local_rank = args.local_rank
        rank = torch.distributed.get_rank()
        device_count = torch.cuda.device_count()
        num_nodes = torch.distributed.get_world_size() // device_count
        node_id = rank // device_count

        for node in range(num_nodes):
            start_rank = node * device_count
            end_rank = (node + 1) * device_count
            ranks_list = list(range(start_rank, end_rank))
            node_group = torch.distributed.new_group(ranks=ranks_list)

            if node_id == node:
                device_start_rank = start_rank
                group = node_group

        input_ = torch.empty_like(query_tensor).copy_(query_tensor).detach_()
        tensor_list = [torch.empty_like(input_) for _ in range(device_count)]
        torch.distributed.all_gather(tensor_list, query_tensor, group=group)

        if local_rank == 0 and self.mips_index is not None:
            all_query_tensor = torch.cat(tensor_list, dim=0).contiguous()

            distance, topkindex = self.mips_index.search_mips_index(
                all_query_tensor, top_k=args.faiss_topk_retrievals,
                reconstruct=False)
            distance = torch.from_numpy(distance).cuda()
            topkindex = torch.LongTensor(topkindex).cuda()

        if local_rank != 0:
            distance = torch.empty(device_count * len(query_tensor), \
                args.faiss_topk_retrievals, dtype=torch.float32).cuda()
            topkindex = torch.empty(device_count * len(query_tensor), \
                args.faiss_topk_retrievals, dtype=torch.int64).cuda()

        torch.distributed.broadcast(distance, src=device_start_rank, \
            group=group)
        torch.distributed.broadcast(topkindex, src=device_start_rank, \
            group=group)

        distance = torch.split(distance, len(query_tensor), dim=0)\
            [local_rank]
        topkindex = torch.split(topkindex, len(query_tensor), dim=0)\
            [local_rank]

        top_ids_and_scores = []
        for darray, topkarray in zip(distance, topkindex):
            top_ids_and_scores.append((topkarray.tolist(), darray.tolist()))

        passages = self.evidence_dataset.id2text
        match_stats = calculate_matches(passages,
                                        reference_list,
                                        top_ids_and_scores,
                                        workers_num=args.num_workers,
                                        match_type=args.faiss_match)
        top_k_hits = match_stats.top_k_hits

        print_rank_0("{} SET RESULTS".format(split))
        print_rank_0("topk-{} documents hits {}".format(
            args.faiss_topk_retrievals, top_k_hits))
        top_k_hits = [v / len(top_ids_and_scores) for v in top_k_hits]
        print_rank_0("top-k documents hits accuracy {}".format(top_k_hits))

        for i in args.retriever_report_topk_accuracies:
            print_rank_0("top-{}: {:.2f}".format(i, top_k_hits[i-1] * 100))

        return
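A small illustration, with hypothetical numbers, of the final normalization in `evaluate` above: `top_k_hits[k-1]` counts queries whose answer appears within the first k retrieved passages, and dividing by the number of queries turns those counts into top-k accuracies.

# Hypothetical counts for k = 1, 2, 3 over 100 queries.
top_k_hits = [55, 70, 82]
num_queries = 100
accuracies = [v / num_queries for v in top_k_hits]
print(accuracies)  # [0.55, 0.7, 0.82] -> top-1/2/3 retrieval accuracy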
tasks/orqa/supervised/data.py
ADDED
@@ -0,0 +1,300 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ORQA dataset."""

import json
import random
from abc import ABC
from abc import abstractmethod

import numpy as np
from torch.utils.data import Dataset

from megatron import print_rank_0, get_args
from megatron.data.biencoder_dataset_utils import make_attention_mask

def build_token_types_from_context_list(ctx_list, tokenizer, max_seq_length):
    ctx_id_list, ctx_types_list = [], []
    for context in ctx_list:
        title_ids = tokenizer.tokenize(context['title'])
        ctx_ids = tokenizer.tokenize(context['text'])
        ctx_ids = title_ids + [tokenizer.sep_id] + ctx_ids

        ctx_ids, ctx_types, _ = build_tokens_types_paddings_from_ids(ctx_ids,
            max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)
        ctx_id_list.append(ctx_ids)
        ctx_types_list.append(ctx_types)

    return ctx_id_list, ctx_types_list


def build_tokens_types_paddings_from_text(query, context,
                                          tokenizer, max_seq_length):
    """Build token types and paddings, trim if needed, and pad if needed."""

    query_ids = tokenizer.tokenize(query)
    query_ids, query_types, query_pad_mask = \
        build_tokens_types_paddings_from_ids(query_ids, max_seq_length, \
            tokenizer.cls, tokenizer.sep, tokenizer.pad)

    # Append the title of the context at the front.
    extended_ctx_ids = None
    if context is not None:
        title_ids = tokenizer.tokenize(context['title'])
        ctx_ids = tokenizer.tokenize(context['text'])
        extended_ctx_ids = title_ids + [tokenizer.sep] + ctx_ids

    ctx_ids, ctx_types, ctx_pad_mask = \
        build_tokens_types_paddings_from_ids(extended_ctx_ids,
            max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)

    return query_ids, query_types, query_pad_mask, \
        ctx_ids, ctx_types, ctx_pad_mask


# Similar code to tasks/data_utils with some changes
def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
                                         cls_id, sep_id, pad_id):
    """Build token types and paddings, trim if needed, and pad if needed."""
    enc_ids = []
    tokentypes_enc = []

    # [CLS].
    enc_ids.append(cls_id)
    tokentypes_enc.append(0)

    # A.
    len_src = len(text_ids)
    enc_ids.extend(text_ids)
    tokentypes_enc.extend([0] * len_src)

    # Cap the size.
    if len(enc_ids) > max_seq_length - 1:
        enc_ids = enc_ids[0: max_seq_length - 1]
        tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]

    # [SEP].
    enc_ids.append(sep_id)
    tokentypes_enc.append(0)

    num_tokens_enc = len(enc_ids)
    # Padding.
    padding_length = max_seq_length - len(enc_ids)
    if padding_length > 0:
        enc_ids.extend([pad_id] * padding_length)
        tokentypes_enc.extend([pad_id] * padding_length)

    pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
    pad_mask = np.array(pad_mask, dtype=np.int64)

    return enc_ids, tokentypes_enc, pad_mask


def build_sample(query_ids, query_types, query_pad_mask,
                 ctx_ids, ctx_types, ctx_pad_mask, answers,
                 neg_ctx_id_list=None, neg_ctx_types_list=None,
                 include_neg=False):
    """Convert to numpy and return a sample consumed by the batch producer."""

    query_ids = np.array(query_ids, dtype=np.int64)
    query_types = np.array(query_types, dtype=np.int64)
    query_mask = make_attention_mask(query_ids, query_ids)

    ctx_ids = np.array(ctx_ids, dtype=np.int64)
    ctx_types = np.array(ctx_types, dtype=np.int64)
    ctx_mask = make_attention_mask(ctx_ids, ctx_ids)

    sample = ({
        'query': query_ids,
        'query_mask': query_mask,
        'query_types': query_types,
        'query_pad_mask': query_pad_mask,
        'context': ctx_ids,
        'context_mask': ctx_mask,
        'context_types': ctx_types,
        'context_pad_mask': ctx_pad_mask,
        'reference': answers
    })

    if include_neg:
        neg_ctx_ids = np.array(neg_ctx_id_list, dtype=np.int64)
        neg_ctx_id_types = np.array(neg_ctx_types_list, dtype=np.int64)
        neg_ctx_mask = np.array([make_attention_mask(ids, ids) \
            for ids in neg_ctx_ids], dtype=np.int64)

        sample['neg_context'] = neg_ctx_ids
        sample['neg_context_types'] = neg_ctx_id_types
        sample['neg_context_mask'] = neg_ctx_mask

    return sample


class OpenRetrievalAbstractDataset(ABC, Dataset):
    """Open Retrieval base dataset class."""

    def __init__(self, task_name, dataset_name, datapaths, tokenizer, \
                 max_seq_length, evaluate=False):
        # Store inputs.
        args = get_args()
        self.evaluate = evaluate
        self.val_av_rank_hard_neg = args.val_av_rank_hard_neg
        self.val_av_rank_other_neg = args.val_av_rank_other_neg
        self.train_with_neg = args.train_with_neg
        self.train_hard_neg = args.train_hard_neg

        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
                                                             self.dataset_name))
        # Process the files.
        string = ' > paths:'
        for path in datapaths:
            string += ' ' + path
        print_rank_0(string)
        self.samples = []
        for datapath in datapaths:
            self.samples.extend(self.process_samples_from_single_path(datapath))

        args = get_args()
        if args.sample_rate < 1:  # subsample
            k = int(len(self.samples) * args.sample_rate)
            self.samples = random.sample(self.samples, k)

        print_rank_0(' >> total number of samples: {}'.format(
                        len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        raw_sample = self.samples[idx]

        query_ids, query_types, query_pad_mask, ctx_ids, ctx_types, \
            ctx_pad_mask = build_tokens_types_paddings_from_text( \
                raw_sample['question'], raw_sample['pos_context'], \
                self.tokenizer, self.max_seq_length)

        if self.evaluate:
            neg_ctx_list = \
                raw_sample['negative_context'][:self.val_av_rank_other_neg] + \
                raw_sample['hard_negative_context'][:self.val_av_rank_hard_neg]
            neg_ctx_id_list, neg_ctx_types_list = \
                build_token_types_from_context_list(neg_ctx_list, \
                    self.tokenizer, self.max_seq_length)

        elif self.train_with_neg:
            hard_negative_ctx = raw_sample['hard_negative_context']
            negative_ctx = raw_sample['negative_context']
            if True:  # TODO: fix this or remove this condition
                random.shuffle(hard_negative_ctx)
                random.shuffle(negative_ctx)

            neg_ctx_list = hard_negative_ctx[:self.train_hard_neg]
            # In the Google NQ dataset prepared by the DPR paper, more than
            # 50 training examples are missing hard negatives.
            # In those cases, substitute hard negatives with simple negatives.
            if len(neg_ctx_list) < self.train_hard_neg:
                neg_ctx_list += negative_ctx[:self.train_hard_neg - \
                    len(neg_ctx_list)]

            neg_ctx_id_list, neg_ctx_types_list = \
                build_token_types_from_context_list(neg_ctx_list,
                    self.tokenizer, self.max_seq_length)
        else:
            neg_ctx_id_list = None
            neg_ctx_types_list = None

        sample = build_sample(query_ids, query_types, query_pad_mask,
                              ctx_ids, ctx_types, ctx_pad_mask,
                              raw_sample['answers'],
                              neg_ctx_id_list, neg_ctx_types_list,
                              include_neg=self.evaluate or self.train_with_neg)

        return sample

    @staticmethod
    @abstractmethod
    def process_samples_from_single_path(filename):
        """Abstract method that takes a filename and
        returns a list of dataset samples, each sample being a dict of
            {'question': string, 'pos_context': dict,
             'hard_negative_context': list, 'negative_context': list,
             'answers': list}
        """
        pass


def normalize_question(question):
    if question[-1] == '?':
        question = question[:-1]
    return question

# The following class reads the datasets for training the retriever as
# prepared by the DPR codebase (https://github.com/facebookresearch/DPR)

class NQSupervisedDataset(OpenRetrievalAbstractDataset):

    def __init__(self, name, datapaths, tokenizer, max_seq_length, \
                 evaluate=False):
        super().__init__('natural_questions_ret',
                         name,
                         datapaths,
                         tokenizer,
                         max_seq_length,
                         evaluate=evaluate)

    @staticmethod
    def process_samples_from_single_path(filename):
        """Implement abstract method."""
        print_rank_0(' > Processing {} ...'.format(filename))
        samples = []
        total = 0

        with open(filename, 'r', encoding="utf-8") as f:
            data = json.load(f)
            for row in data:
                question = normalize_question(row['question'])
                pos_context = row['positive_ctxs'][0]

                # Hard negative contexts.
                if len(row['hard_negative_ctxs']) > 0:
                    hard_neg_context = row['hard_negative_ctxs']
                else:
                    hard_neg_context = []

                # Negative contexts.
                if len(row['negative_ctxs']) > 0:
                    neg_context = row['negative_ctxs']
                else:
                    neg_context = []

                answers = row['answers']
                sample = {'question': question,
                          'pos_context': pos_context,
                          'hard_negative_context': hard_neg_context,
                          'negative_context': neg_context,
                          'answers': answers}
                total += 1
                samples.append(sample)

                if total % 5000 == 0:
                    print_rank_0(' > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        return samples
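
For orientation, here is a minimal sketch of one row in the DPR-format JSON file that NQSupervisedDataset.process_samples_from_single_path expects; the keys mirror the fields read in the code above, and the values are invented for illustration only.

# Hypothetical DPR-style training row (values are made up; only the keys
# are taken from process_samples_from_single_path above).
row = {
    "question": "who composed the opera carmen?",
    "answers": ["Georges Bizet"],
    "positive_ctxs": [{"title": "Carmen",
                       "text": "Carmen is an opera by Georges Bizet ..."}],
    "hard_negative_ctxs": [],
    "negative_ctxs": [],
}

# The loader strips a trailing '?' and keeps only the first positive context.
question = row['question'][:-1] if row['question'][-1] == '?' else row['question']
sample = {'question': question,
          'pos_context': row['positive_ctxs'][0],
          'hard_negative_context': row['hard_negative_ctxs'],
          'negative_context': row['negative_ctxs'],
          'answers': row['answers']}
print(sample['question'])  # 'who composed the opera carmen'
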

tasks/orqa/supervised/eval_utils.py ADDED @@ -0,0 +1,206 @@

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation utilities."""
from collections import OrderedDict
import math
import numpy as np
import time
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.utils import average_losses_across_data_parallel_group
from tasks.finetune_utils import build_data_loader

def task_collate_fn(batch_data):
    # Generate batch.
    batch_size = len(batch_data)
    tensorized = OrderedDict()
    for d in batch_data:
        for k, v in d.items():
            tensorized.setdefault(k, []).append(v)

    tensorized['query'] = torch.LongTensor(tensorized['query'])
    tensorized['query_mask'] = torch.LongTensor(tensorized['query_mask'])
    tensorized['query_types'] = torch.LongTensor(tensorized['query_types'])
    tensorized['query_pad_mask'] = \
        torch.LongTensor(tensorized['query_pad_mask'])

    tensorized['context'] = torch.LongTensor(tensorized['context'])
    tensorized['context_mask'] = \
        torch.LongTensor(tensorized['context_mask'])
    tensorized['context_types'] = \
        torch.LongTensor(tensorized['context_types'])
    tensorized['context_pad_mask'] = \
        torch.LongTensor(tensorized['context_pad_mask'])

    if 'neg_context' in tensorized:
        tensorized['neg_context'] = \
            torch.LongTensor(np.concatenate(tensorized['neg_context']))
        tensorized['neg_context_mask'] = \
            torch.LongTensor(np.concatenate(tensorized['neg_context_mask']))
        tensorized['neg_context_types'] = \
            torch.LongTensor(np.concatenate(tensorized['neg_context_types']))

    return tensorized


def process_batch(batch):
    """Process batch and produce inputs for the model."""
    query_tokens = batch['query'].long().cuda()
    query_mask = (batch['query_mask'] < 0.5).cuda()
    query_types = batch['query_types'].long().cuda()
    query_pad_mask = batch['query_pad_mask'].long().cuda()

    context_tokens = batch['context'].long().cuda()
    context_mask = (batch['context_mask'] < 0.5).cuda()
    context_types = batch['context_types'].long().cuda()
    context_pad_mask = batch['context_pad_mask'].long().cuda()

    if 'neg_context' in batch:
        neg_context_tokens = batch['neg_context'].long().cuda()
        neg_context_mask = (batch['neg_context_mask'] < 0.5).cuda()
        neg_context_types = batch['neg_context_types'].long().cuda()
    else:
        neg_context_tokens = None
        neg_context_mask = None
        neg_context_types = None

    reference = batch['reference']

    return query_tokens, query_mask, query_types, query_pad_mask, \
        context_tokens, context_mask, context_types, context_pad_mask, \
        neg_context_tokens, neg_context_mask, neg_context_types, reference

def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
    """Provide function that calculates accuracies."""
    args = get_args()

    print_rank_0("accuracy_func_provider is CALLED")

    # Build dataloaders.
    datapath = args.valid_data
    dataset = single_dataset_provider(datapath)

    drop_last = False
    if mpu.get_data_parallel_world_size() > 1 and not rank0sampler:
        drop_last = True

    print_rank_0(datapath)
    print_rank_0(rank0sampler)

    dataloader = build_data_loader(dataset,
                                   args.eval_micro_batch_size,
                                   num_workers=args.num_workers,
                                   drop_last=drop_last,
                                   task_collate_fn=task_collate_fn)
    dataloaders = (dataset.dataset_name, dataloader)

    def metrics_func(model, epoch, output_predictions=False):
        print_rank_0('calculating metrics by accuracy func in ORQA...')

        if output_predictions:
            assert rank0sampler
            names = 'predictions'
        name, dataloader = dataloaders
        if args.task == "RET-FINETUNE-NQ":
            start_time = time.time()
            output = retrieval_loss(model, dataloader)
            stats_dict, total = output
            format_string = ""
            for k, v in stats_dict.items():
                format_string += "|{} = {:.2f}".format(k, v / total)
            print_rank_0("epoch:{}{}".format(epoch, format_string))
            print_rank_0("time taken to calculate metrics {:.3f}".format(\
                time.time() - start_time))
        else:
            raise AssertionError("{} Task not supported".format(args.task))

    return metrics_func


def retrieval_loss(model, dataloader):
    args = get_args()
    total = 0
    topk_stats_dict = {'top{}_acc'.format(k): 0 for k in \
        args.retriever_report_topk_accuracies}
    stats_dict = dict(rank=0, **topk_stats_dict)

    assert len(model) == 1
    unwrapped_model = model[0]
    unwrapped_model.eval()

    with torch.no_grad():
        # For all the batches in the dataset.
        for batch in dataloader:
            # Run the model forward.
            query_tokens, query_mask, query_types, _, \
                context_tokens, context_mask, context_types, _, \
                neg_context_tokens, neg_context_mask, neg_context_types, \
                reference = process_batch(batch)

            query_logits, context_logits = unwrapped_model(query_tokens,
                query_mask, query_types,
                torch.cat([context_tokens, neg_context_tokens]),
                torch.cat([context_mask, neg_context_mask]),
                torch.cat([context_types, neg_context_types]))

            retrieval_scores = torch.matmul(query_logits,
                                torch.transpose(context_logits, 0, 1))

            if args.retriever_score_scaling:
                retrieval_scores = retrieval_scores / \
                    math.sqrt(args.hidden_size)

            local_batch_size = query_logits.shape[0]
            labels = torch.arange(local_batch_size).long().cuda()

            softmax_scores = F.softmax(retrieval_scores, dim=1)
            sorted_vals, sorted_indices = torch.topk(softmax_scores,
                                                     k=softmax_scores.shape[1],
                                                     sorted=True)

            def topk_accuracy(k):
                return torch.cuda.FloatTensor(
                    [sum([int(labels[i] in sorted_indices[i, :k]) for i in \
                        range(local_batch_size)])])

            def get_rank():
                return torch.cuda.FloatTensor(
                    [sum([torch.nonzero(labels[i] == sorted_indices[i])[0][0] \
                        for i in range(local_batch_size)])])

            topk_accs = [topk_accuracy(k) for k in \
                args.retriever_report_topk_accuracies]
            rank = get_rank()
            losses = average_losses_across_data_parallel_group([rank, \
                *topk_accs])

            # Create stats_dict with the retrieval rank and all specified
            # top-k accuracies.
            topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
                zip(args.retriever_report_topk_accuracies, losses[1:])}
            temp_stats_dict = dict(rank=losses[0], **topk_acc_dict)
            for k in stats_dict.keys():
                stats_dict[k] += temp_stats_dict[k]
            total += local_batch_size

    unwrapped_model.train()

    return stats_dict, total
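
As a sanity check, the top-k bookkeeping in retrieval_loss can be exercised in isolation. The sketch below mirrors the topk_accuracy logic on CPU tensors with a hand-made 3x5 score matrix; the scores are fabricated, while the real code runs on CUDA tensors produced by the biencoder.

import torch

# Fabricated retrieval scores for 3 queries over 5 contexts; the gold
# context for query i is column i (in-batch positives on the diagonal).
scores = torch.tensor([[0.7, 0.2, 0.1, 0.0, 0.0],
                       [0.3, 0.25, 0.2, 0.15, 0.1],
                       [0.0, 0.1, 0.8, 0.05, 0.05]])
labels = torch.arange(scores.shape[0])
_, sorted_indices = torch.topk(scores, k=scores.shape[1], sorted=True)

def topk_accuracy(k):
    # fraction of queries whose gold context appears in the top-k list
    return sum(int(labels[i] in sorted_indices[i, :k])
               for i in range(scores.shape[0])) / scores.shape[0]

print(topk_accuracy(1), topk_accuracy(2))  # 0.666..., 1.0
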

tasks/orqa/supervised/finetune.py ADDED @@ -0,0 +1,251 @@

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ORQA finetuning/evaluation."""

from functools import partial
import sys

import math
import torch
import torch.nn.functional as F

from megatron import get_args, get_timers, get_tokenizer
from megatron import mpu, print_rank_0
from megatron.indexer import IndexBuilder
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.utils import average_losses_across_data_parallel_group
from pretrain_ict import get_group_world_size_rank
from tasks.finetune_utils import finetune
from tasks.orqa.supervised.eval_utils import accuracy_func_provider
from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
from tasks.orqa.evaluate_utils import ORQAEvaluator

# input_ is a 2D tensor
def check_and_append_tensor_for_gather(group, rank, world_size, input_):

    # Gather the size of the first dimension of the tensor from all ranks.
    current_length = input_.size()[0]
    first_dim = torch.tensor([[current_length]],
                             device=torch.cuda.current_device())
    input_list = [torch.empty_like(first_dim) for _ in range(world_size)]
    input_list[rank].copy_(first_dim)
    torch.distributed.all_gather(input_list, first_dim, group=group)
    all_input_list = torch.cat(input_list, dim=0).contiguous()
    max_length = torch.max(all_input_list)

    # If the size differs from the max, extend the tensor accordingly.
    if max_length > current_length:
        padding = tuple([0] * (input_.dim() * 2 - 1)) + \
            tuple([max_length - current_length])
        input_ = F.pad(input=input_, pad=padding)

    return input_

def orqa(Dataset):

    def cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        timers = get_timers()
        tokenizer = get_tokenizer()

        # Get the batch.
        timers('batch generator').start()
        try:
            batch_ = next(batch)
        except BaseException:
            batch_ = batch

        group, rank, world_size = get_group_world_size_rank()

        query_tokens, query_mask, query_types, query_pad_mask, \
            context_tokens, context_mask, context_types, context_pad_mask, \
            neg_context_tokens, neg_context_mask, neg_context_types, \
            reference = process_batch(batch_)

        timers('batch generator').stop()
        local_batch_size = query_tokens.shape[0]

        # Text representation of query and context.
        query_list, context_list = [], []
        for i in range(local_batch_size):
            query_list.append(tokenizer.decode(query_tokens[i].tolist()))
            context_list.append(tokenizer.decode(context_tokens[i].tolist()))

        if neg_context_tokens is not None:
            neg_context_tokens = check_and_append_tensor_for_gather(group,
                rank, world_size, neg_context_tokens)
            neg_context_mask = check_and_append_tensor_for_gather(group,
                rank, world_size, neg_context_mask)
            neg_context_types = check_and_append_tensor_for_gather(group,
                rank, world_size, neg_context_types)

        if neg_context_tokens is not None:
            context_tokens = torch.cat([context_tokens, neg_context_tokens])
            context_mask = torch.cat([context_mask, neg_context_mask])
            context_types = torch.cat([context_types, neg_context_types])

        # Forward model.
        output_tensor = model(query_tokens, query_mask,
                              query_types, context_tokens,
                              context_mask, context_types)
        return output_tensor, partial(cross_entropy_loss_func, query_tokens, context_tokens)


    def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor):
        args = get_args()

        local_batch_size = query_tokens.shape[0]
        group, rank, world_size = get_group_world_size_rank()
        # Recall we assert that model_parallel_size == 1.
        global_batch_size = world_size * local_batch_size

        query_logits, context_logits = output_tensor

        if world_size > 1:
            input_ = torch.empty_like(context_logits).copy_(\
                context_logits).detach_()
            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
            tensor_list[rank].copy_(input_)
            torch.distributed.all_gather(tensor_list, input_, group=group)

            # Check if all-gather happens in order.
            assert tensor_list[rank].sum().item() == \
                context_logits.sum().item()

            # Preserves the gradient.
            tensor_list[rank] = context_logits
            all_context_logits = torch.cat(tensor_list, dim=0).contiguous()

            # Query tensors.
            input_ = torch.empty_like(query_logits).copy_(\
                query_logits).detach_()
            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
            tensor_list[rank].copy_(input_)
            torch.distributed.all_gather(tensor_list, input_, group=group)

            # Check if all-gather happens in order.
            assert tensor_list[rank].sum().item() == query_logits.sum().item()

            # Preserves the gradient.
            tensor_list[rank] = query_logits
            all_query_logits = torch.cat(tensor_list, dim=0).contiguous()
        else:
            all_query_logits = query_logits
            all_context_logits = context_logits

        retrieval_scores = torch.matmul(all_query_logits,
                            torch.transpose(all_context_logits, 0, 1))
        # Scaling the retrieval scores.
        if args.retriever_score_scaling:
            retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)

        if args.train_with_neg:
            # If the world size is 3, local batch size is 4, and
            # local context size is 8, what we want is
            # labels = [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19]
            labels = []
            local_context_size = context_tokens.shape[0]
            for i in range(world_size):
                j = i * local_context_size
                labels.extend(list(range(j, j + local_batch_size)))
            labels = torch.LongTensor(labels).cuda()
            assert len(labels) == global_batch_size
        else:
            labels = torch.arange(global_batch_size).long().cuda()

        # Cross-entropy loss.
        softmax_scores = F.log_softmax(retrieval_scores, dim=1)

        loss = F.nll_loss(softmax_scores, labels, reduction='mean')

        max_score, max_idxs = torch.max(softmax_scores, 1)
        correct_predictions_count = (max_idxs == labels).sum().float()

        # Reduce loss for logging.
        reduced_loss = average_losses_across_data_parallel_group([loss, \
            correct_predictions_count])

        # Loss scaling for correct losses in supervised retrieval.
        loss = loss * mpu.get_data_parallel_world_size()

        return loss, {'lm loss': reduced_loss[0],
                      'correct_prediction_count': reduced_loss[1]}


    def train_valid_datasets_provider():
        """Build train and validation datasets."""
        args = get_args()
        tokenizer = get_tokenizer()

        train_dataset = Dataset('training',
                                args.train_data,
                                tokenizer,
                                args.retriever_seq_length,
                                evaluate=False)
        valid_dataset = Dataset('validation',
                                args.valid_data,
                                tokenizer,
                                args.retriever_seq_length,
                                evaluate=True)
        return train_dataset, valid_dataset

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()
        print_rank_0('building retriever model for {} ...'.format(args.task))

        model = biencoder_model_provider(only_context_model=False,
                    only_query_model=False,
                    biencoder_shared_query_context_model=\
                    args.biencoder_shared_query_context_model,
                    pre_process=pre_process, post_process=post_process)

        return model

    def single_dataset_provider(datapath):
        args = get_args()
        tokenizer = get_tokenizer()

        name = datapath[0].split('/')[-1].split('.')[0]
        return Dataset(name,
                       datapath,
                       tokenizer,
                       args.retriever_seq_length,
                       evaluate=True)

    def metrics_func_provider():
        """Provide metrics callback function."""
        return accuracy_func_provider(single_dataset_provider)

    """Finetune/evaluate."""
    finetune(train_valid_datasets_provider,
             model_provider,
             forward_step=cross_entropy_forward_step,
             end_of_epoch_callback_provider=metrics_func_provider,
             task_collate_fn=task_collate_fn)

def main():
    args = get_args()

    if args.task == 'RET-FINETUNE-NQ':
        from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset
    else:
        raise NotImplementedError('ORQA task {} is not implemented.'.format(
            args.task))

    orqa(Dataset)
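
The label layout described in the train_with_neg branch of cross_entropy_loss_func is easy to verify standalone. With the numbers from that comment (world size 3, local batch size 4, local context size 8, all assumed), each rank's positives sit at the start of its own context block:

# Standalone check of the label layout from cross_entropy_loss_func().
world_size, local_batch_size, local_context_size = 3, 4, 8
labels = []
for i in range(world_size):
    j = i * local_context_size
    labels.extend(range(j, j + local_batch_size))
print(labels)  # [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19]
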

tasks/orqa/unsupervised/nq.py ADDED @@ -0,0 +1,228 @@

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Data loader for the Google NQ dataset.
"""

from abc import ABC
import csv
from collections import OrderedDict
import numpy as np

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, BatchSampler

from megatron import print_rank_0, get_args, get_tokenizer, mpu
from megatron.data.biencoder_dataset_utils import make_attention_mask

def get_nq_dataset(qa_data, split):
    args = get_args()
    tokenizer = get_tokenizer()

    dataset = NQDataset('Google NQ {} Split'.format(split),
                        'Google Natural Questions',
                        qa_data,
                        tokenizer,
                        args.retriever_seq_length)
    return dataset


def process_nq_batch(batch):
    query_tokens = batch['token_ids'].long().cuda()
    query_mask = (batch['token_mask'] < 0.5).cuda()
    query_types = batch['token_types'].long().cuda()
    query_len = batch['seq_len'].long().cuda()
    reference = batch['reference']

    return query_tokens, query_mask, query_types, query_len, reference


class CustomDataLoader(DataLoader):
    def __init__(self, dataset, eval=False, **kwargs):
        if kwargs.get('collate_fn', None) is None:
            kwargs['collate_fn'] = self._collate_fn
        self.eval = eval
        super().__init__(dataset, **kwargs)

    def _collate_fn(self, batch_data):
        # Generate batch.
        batch_size = len(batch_data)
        tensorized = OrderedDict()
        for d in batch_data:
            for k, v in d.items():
                tensorized.setdefault(k, []).append(v)
        assert len(tensorized) == 5

        tensorized['token_ids'] = torch.LongTensor(tensorized['token_ids'])
        tensorized['token_mask'] = torch.LongTensor(tensorized['token_mask'])
        tensorized['token_types'] = torch.LongTensor(tensorized['token_types'])
        tensorized['seq_len'] = torch.LongTensor(tensorized['seq_len'])
        return tensorized


def get_one_epoch_nq_dataloader(dataset, micro_batch_size=None):
    """Data loader. Note that batch-size is the local (per GPU) batch-size.
       NOTE: This dataloader is not distributed !!!
    """

    args = get_args()
    if micro_batch_size is None:
        micro_batch_size = args.micro_batch_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    # Importantly, drop_last must be False to get all the data.
    batch_sampler = BatchSampler(sampler,
                                 batch_size=micro_batch_size,
                                 drop_last=False)

    # Data loader. Note that batch size is the per GPU batch size.
    data_loader = CustomDataLoader(dataset,
                                   batch_sampler=batch_sampler,
                                   num_workers=num_workers,
                                   pin_memory=True)
    return data_loader


def build_tokens_types_paddings_from_text(src_text, tokenizer, max_seq_length):
    """Build token types and paddings, trim if needed, and pad if needed."""

    src_text_ids = tokenizer.tokenize(src_text)

    return build_tokens_types_paddings_from_ids(src_text_ids,
                                                max_seq_length,
                                                tokenizer.cls,
                                                tokenizer.sep,
                                                tokenizer.pad)


def build_tokens_types_paddings_from_ids(src_ids, max_seq_length, cls_id, \
    sep_id, pad_id):
    """
    Build token types and paddings, trim if needed, and pad if needed.

    TODO: Design a modular interface to reuse this function. It is getting
    repeated multiple times in different tasks.
    """

    enc_ids = []
    tokentypes_enc = []

    # [CLS].
    enc_ids.append(cls_id)
    tokentypes_enc.append(0)

    # A.
    len_src = len(src_ids)
    enc_ids.extend(src_ids)
    tokentypes_enc.extend([0] * len_src)

    # Cap the size.
    if len(enc_ids) > max_seq_length - 1:
        enc_ids = enc_ids[0: max_seq_length - 1]
        tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]

    # [SEP].
    enc_ids.append(sep_id)
    tokentypes_enc.append(0)

    num_tokens_enc = len(enc_ids)
    # Padding.
    padding_length = max_seq_length - len(enc_ids)
    if padding_length > 0:
        enc_ids.extend([pad_id] * padding_length)
        tokentypes_enc.extend([pad_id] * padding_length)

    return enc_ids, tokentypes_enc, num_tokens_enc


def build_sample(token_ids, token_types, num_tokens, reference):
    """
    Convert to numpy and return a sample consumed by the
    batch producer.
    """

    token_ids = np.array(token_ids, dtype=np.int64)
    token_types = np.array(token_types, dtype=np.int64)
    token_mask = make_attention_mask(token_ids, token_ids)

    sample = ({
        'token_ids': token_ids,
        'token_mask': token_mask,
        'token_types': token_types,
        'seq_len': num_tokens,
        'reference': reference
    })
    return sample


class NQDataset(ABC, Dataset):
    """
    Open Retrieval Question Answering evaluation using the Google NQ dataset.
    """

    def __init__(self, task_name, dataset_name, datapath,
                 tokenizer, max_seq_length):
        # Store inputs.
        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
                                                             self.dataset_name))
        print_rank_0(datapath)
        self.samples = self.process_samples_from_single_path(datapath)
        print_rank_0(' >> total number of samples: {}'.format(\
                        len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        raw_sample = self.samples[idx]

        ques_tokens, tokentypes_enc, num_tokens_ques = \
            build_tokens_types_paddings_from_text(raw_sample['question'],
                self.tokenizer, self.max_seq_length)

        sample = build_sample(ques_tokens,
                              tokentypes_enc,
                              num_tokens_ques,
                              raw_sample['answers'])
        return sample

    @staticmethod
    def process_samples_from_single_path(filename):
        print_rank_0(' > Processing {} ...'.format(filename))
        samples = []
        total = 0

        with open(filename, 'r') as ifile:
            reader = csv.reader(ifile, delimiter='\t')
            for row in reader:
                question = row[0]
                answers = eval(row[1])

                sample = {'question': question, 'answers': answers}
                total += 1
                samples.append(sample)

                if total % 1000 == 0:
                    print_rank_0(' > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        return samples
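
For reference, a sketch of the tab-separated question/answers format that NQDataset.process_samples_from_single_path parses. The row content is invented; note that the answers column holds a Python-literal list, which is why the loader calls eval on it.

import csv
import io

# One hypothetical TSV row: question, tab, Python-literal list of answers.
tsv_line = 'who composed the opera carmen?\t["Georges Bizet"]\n'
reader = csv.reader(io.StringIO(tsv_line), delimiter='\t')
for row in reader:
    question, answers = row[0], eval(row[1])  # same parsing as the loader
    print(question, answers)
# who composed the opera carmen? ['Georges Bizet']
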

tasks/orqa/unsupervised/qa_utils.py ADDED @@ -0,0 +1,177 @@

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#

# The following code has been taken from
# https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0
# licensed as of now. More details on the license can be found
# at https://github.com/facebookresearch/DPR/blob/master/LICENSE

"""
Set of utilities for Q&A results validation tasks - Retriever passage
validation and Reader predicted answer validation
"""

import collections
import logging
import string
import unicodedata
from functools import partial
from multiprocessing import Pool as ProcessPool
from typing import Tuple, List, Dict

import regex as re
from tasks.orqa.unsupervised.tokenizers import SimpleTokenizer

logger = logging.getLogger(__name__)

QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits',
                                                       'questions_doc_hits'])

def calculate_matches(all_docs: Dict[object, Tuple[str, str]],
    answers: List[List[str]], closest_docs: List[Tuple[List[object],
    List[float]]], workers_num: int, match_type: str) -> QAMatchStats:
    """
    Evaluates the presence of answers in the set of documents. This function
    is supposed to be used with a large collection of documents and results.
    It internally forks multiple sub-processes for evaluation and then
    merges the results.
    :param all_docs: dictionary of the entire documents database.
        doc_id -> (doc_text, title)
    :param answers: list of answer lists, one list per question
    :param closest_docs: document ids of the top results along with their
        scores
    :param workers_num: number of parallel processes used to process the data
    :param match_type: type of answer matching. Refer to has_answer code for
        available options
    :return: matching information tuple.
    top_k_hits - a list where the index is the number of top documents
        retrieved and the value is the total number of valid matches across
        the entire dataset.
    questions_doc_hits - more detailed info with answer matches for every
        question and every retrieved document
    """
    global dpr_all_documents
    dpr_all_documents = all_docs

    tok_opts = {}
    tokenizer = SimpleTokenizer(**tok_opts)

    processes = ProcessPool(
        processes=workers_num,
    )

    logger.info('Matching answers in top docs...')

    get_score_partial = partial(check_answer, match_type=match_type,
                                tokenizer=tokenizer)

    questions_answers_docs = zip(answers, closest_docs)

    scores = processes.map(get_score_partial, questions_answers_docs)

    logger.info('Per question validation results len=%d', len(scores))

    n_docs = len(closest_docs[0][0])
    top_k_hits = [0] * n_docs
    for question_hits in scores:
        best_hit = next((i for i, x in enumerate(question_hits) if x), None)
        if best_hit is not None:
            top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]

    return QAMatchStats(top_k_hits, scores)


def check_answer(questions_answers_docs, tokenizer, match_type) -> List[bool]:
    """
    Search through all the top docs to see if they have any of the answers.
    """
    answers, (doc_ids, doc_scores) = questions_answers_docs

    global dpr_all_documents
    hits = []

    for i, doc_id in enumerate(doc_ids):
        doc = dpr_all_documents[doc_id]
        text = doc[0]

        answer_found = False
        if text is None:  # cannot find the document for some reason
            logger.warning("no doc in db")
            hits.append(False)
            continue

        if has_answer(answers, text, tokenizer, match_type):
            answer_found = True
        hits.append(answer_found)
    return hits


def has_answer(answers, text, tokenizer, match_type) -> bool:
    """
    Check if a document contains an answer string.
    If `match_type` is string, token matching is done between the text
    and answer.
    If `match_type` is regex, we search the whole text with the regex.
    """
    text = _normalize(text)

    if match_type == 'string':
        # Answer is a list of possible strings.
        text = tokenizer.tokenize(text).words(uncased=True)

        for single_answer in answers:
            single_answer = _normalize(single_answer)
            single_answer = tokenizer.tokenize(single_answer)
            single_answer = single_answer.words(uncased=True)

            for i in range(0, len(text) - len(single_answer) + 1):
                if single_answer == text[i: i + len(single_answer)]:
                    return True

    elif match_type == 'regex':
        # Answer is a regex.
        for single_answer in answers:
            single_answer = _normalize(single_answer)
            if regex_match(text, single_answer):
                return True
    return False


def regex_match(text, pattern):
    """Test if a regex pattern is contained within a text."""
    try:
        pattern = re.compile(
            pattern,
            flags=re.IGNORECASE + re.UNICODE + re.MULTILINE,
        )
    except BaseException:
        return False
    return pattern.search(text) is not None


# Function for the reader model answer validation.
def exact_match_score(prediction, ground_truth):
    return _normalize_answer(prediction) == _normalize_answer(ground_truth)


def _normalize_answer(s):
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def _normalize(text):
    return unicodedata.normalize('NFD', text)
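
A quick standalone demo of the answer normalization behind exact_match_score: lower-case, strip punctuation, drop English articles, collapse whitespace. This re-implements _normalize_answer with the stdlib re module so it runs without the third-party regex package.

import re
import string

def normalize_answer(s):
    s = s.lower()                                                     # lower
    s = ''.join(ch for ch in s if ch not in set(string.punctuation))  # remove_punc
    s = re.sub(r'\b(a|an|the)\b', ' ', s)                             # remove_articles
    return ' '.join(s.split())                                        # white_space_fix

print(normalize_answer("The Eiffel Tower!") == normalize_answer("eiffel tower"))  # True
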

tasks/orqa/unsupervised/tokenizers.py ADDED @@ -0,0 +1,243 @@

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#

# The following code has been taken from
# https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0
# licensed as of now. More details on the license can be found
# at https://github.com/facebookresearch/DPR/blob/master/LICENSE

"""
Most of the tokenizer code here is copied from the DrQA codebase to avoid
adding an extra dependency
"""

import copy
import logging

import regex
import spacy

logger = logging.getLogger(__name__)


class Tokens(object):
    """A class to represent a list of tokenized text."""
    TEXT = 0
    TEXT_WS = 1
    SPAN = 2
    POS = 3
    LEMMA = 4
    NER = 5

    def __init__(self, data, annotators, opts=None):
        self.data = data
        self.annotators = annotators
        self.opts = opts or {}

    def __len__(self):
        """The number of tokens."""
        return len(self.data)

    def slice(self, i=None, j=None):
        """Return a view of the list of tokens from [i, j)."""
        new_tokens = copy.copy(self)
        new_tokens.data = self.data[i: j]
        return new_tokens

    def untokenize(self):
        """Returns the original text (with whitespace reinserted)."""
        return ''.join([t[self.TEXT_WS] for t in self.data]).strip()

    def words(self, uncased=False):
        """Returns a list of the text of each token

        Args:
            uncased: lower cases text
        """
        if uncased:
            return [t[self.TEXT].lower() for t in self.data]
        else:
            return [t[self.TEXT] for t in self.data]

    def offsets(self):
        """Returns a list of [start, end) character offsets of each token."""
        return [t[self.SPAN] for t in self.data]

    def pos(self):
        """Returns a list of part-of-speech tags of each token.
        Returns None if this annotation was not included.
        """
        if 'pos' not in self.annotators:
            return None
        return [t[self.POS] for t in self.data]

    def lemmas(self):
        """Returns a list of the lemmatized text of each token.
        Returns None if this annotation was not included.
        """
        if 'lemma' not in self.annotators:
            return None
        return [t[self.LEMMA] for t in self.data]

    def entities(self):
        """Returns a list of named-entity-recognition tags of each token.
        Returns None if this annotation was not included.
        """
        if 'ner' not in self.annotators:
            return None
        return [t[self.NER] for t in self.data]

    def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
        """Returns a list of all ngrams from length 1 to n.

        Args:
            n: upper limit of ngram length
            uncased: lower cases text
            filter_fn: user function that takes in an ngram list and returns
              True or False to keep or not keep the ngram
            as_strings: return each ngram as a string instead of a list
        """

        def _skip(gram):
            if not filter_fn:
                return False
            return filter_fn(gram)

        words = self.words(uncased)
        ngrams = [(s, e + 1)
                  for s in range(len(words))
                  for e in range(s, min(s + n, len(words)))
                  if not _skip(words[s:e + 1])]

        # Concatenate into strings.
        if as_strings:
            ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams]

        return ngrams

    def entity_groups(self):
        """Group consecutive entity tokens with the same NER tag."""
        entities = self.entities()
        if not entities:
            return None
        non_ent = self.opts.get('non_ent', 'O')
        groups = []
        idx = 0
        while idx < len(entities):
            ner_tag = entities[idx]
            # Check for entity tag.
            if ner_tag != non_ent:
                # Chomp the sequence.
                start = idx
                while (idx < len(entities) and entities[idx] == ner_tag):
                    idx += 1
                groups.append((self.slice(start, idx).untokenize(), ner_tag))
            else:
                idx += 1
        return groups


class Tokenizer(object):
    """Base tokenizer class.
    Tokenizers implement tokenize, which should return a Tokens class.
    """

    def tokenize(self, text):
        raise NotImplementedError

    def shutdown(self):
        pass

    def __del__(self):
        self.shutdown()


class SimpleTokenizer(Tokenizer):
    ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
    NON_WS = r'[^\p{Z}\p{C}]'

    def __init__(self, **kwargs):
        """
        Args:
            annotators: None or empty set (only tokenizes).
        """
        self._regexp = regex.compile(
            '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
            flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
        )
        if len(kwargs.get('annotators', {})) > 0:
            logger.warning('%s only tokenizes! Skipping annotators: %s' %
                           (type(self).__name__, kwargs.get('annotators')))
        self.annotators = set()

    def tokenize(self, text):
        data = []
        matches = [m for m in self._regexp.finditer(text)]
        for i in range(len(matches)):
            # Get text.
            token = matches[i].group()

            # Get whitespace.
            span = matches[i].span()
            start_ws = span[0]
            if i + 1 < len(matches):
                end_ws = matches[i + 1].span()[0]
            else:
                end_ws = span[1]

            # Format data.
            data.append((
                token,
                text[start_ws: end_ws],
                span,
            ))
        return Tokens(data, self.annotators)


class SpacyTokenizer(Tokenizer):

    def __init__(self, **kwargs):
        """
        Args:
            annotators: set that can include pos, lemma, and ner.
            model: spaCy model to use (either path, or keyword like 'en').
        """
        model = kwargs.get('model', 'en')
        self.annotators = copy.deepcopy(kwargs.get('annotators', set()))
        nlp_kwargs = {'parser': False}
        if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
            nlp_kwargs['tagger'] = False
        if 'ner' not in self.annotators:
            nlp_kwargs['entity'] = False
        self.nlp = spacy.load(model, **nlp_kwargs)

    def tokenize(self, text):
        # We don't treat new lines as tokens.
        clean_text = text.replace('\n', ' ')
        tokens = self.nlp.tokenizer(clean_text)
        if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
            self.nlp.tagger(tokens)
        if 'ner' in self.annotators:
            self.nlp.entity(tokens)

        data = []
        for i in range(len(tokens)):
            # Get whitespace.
            start_ws = tokens[i].idx
            if i + 1 < len(tokens):
                end_ws = tokens[i + 1].idx
            else:
                end_ws = tokens[i].idx + len(tokens[i].text)

            data.append((
                tokens[i].text,
                text[start_ws: end_ws],
                (tokens[i].idx, tokens[i].idx + len(tokens[i].text)),
                tokens[i].tag_,
                tokens[i].lemma_,
                tokens[i].ent_type_,
            ))

        # Set special option for non-entity tag: '' vs 'O' in spaCy
        return Tokens(data, self.annotators, opts={'non_ent': ''})
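
To see what SimpleTokenizer extracts, its pattern can be run directly with the same ALPHA_NUM and NON_WS character classes. This uses the third-party regex package that the module above already imports; the sample sentence is arbitrary.

import regex

ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'   # runs of letters/digits/marks
NON_WS = r'[^\p{Z}\p{C}]'           # any other single visible character
pattern = regex.compile('(%s)|(%s)' % (ALPHA_NUM, NON_WS),
                        flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE)

text = "Who wrote Les Misérables?"
print([m.group() for m in pattern.finditer(text)])
# ['Who', 'wrote', 'Les', 'Misérables', '?']
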

tasks/race/data.py ADDED @@ -0,0 +1,135 @@

import glob
import json
import os
import time

from torch.utils.data import Dataset

from megatron import print_rank_0
from tasks.data_utils import build_sample
from tasks.data_utils import build_tokens_types_paddings_from_ids
from tasks.data_utils import clean_text


NUM_CHOICES = 4
MAX_QA_LENGTH = 128


class RaceDataset(Dataset):

    def __init__(self, dataset_name, datapaths, tokenizer, max_seq_length,
                 max_qa_length=MAX_QA_LENGTH):

        self.dataset_name = dataset_name
        print_rank_0(' > building RACE dataset for {}:'.format(
            self.dataset_name))

        string = ' > paths:'
        for path in datapaths:
            string += ' ' + path
        print_rank_0(string)

        self.samples = []
        for datapath in datapaths:
            self.samples.extend(process_single_datapath(datapath, tokenizer,
                                                        max_qa_length,
                                                        max_seq_length))

        print_rank_0(' >> total number of samples: {}'.format(
            len(self.samples)))

        # This indicates that each "sample" has multiple samples that
        # will collapse into batch dimension
        self.sample_multiplier = NUM_CHOICES

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def process_single_datapath(datapath, tokenizer, max_qa_length, max_seq_length):
    """Read in RACE files, combine, clean-up, tokenize, and convert to
    samples."""

    print_rank_0(' > working on {}'.format(datapath))
    start_time = time.time()

    # Get list of files.
    filenames = glob.glob(os.path.join(datapath, '*.txt'))

    samples = []
    num_docs = 0
    num_questions = 0
    num_samples = 0
    # Load all the files
    for filename in filenames:
        with open(filename, 'r') as f:
            for line in f:
                data = json.loads(line)
                num_docs += 1

                context = data["article"]
                questions = data["questions"]
                choices = data["options"]
                answers = data["answers"]
                # Check the length.
                assert len(questions) == len(answers)
                assert len(questions) == len(choices)

                # Context: clean up and convert to ids.
                context = clean_text(context)
                context_ids = tokenizer.tokenize(context)

                # Loop over questions.
                for qi, question in enumerate(questions):
                    num_questions += 1
                    # Label.
                    label = ord(answers[qi]) - ord("A")
                    assert label >= 0
                    assert label < NUM_CHOICES
                    assert len(choices[qi]) == NUM_CHOICES

                    # For each question, build num-choices samples.
                    ids_list = []
                    types_list = []
                    paddings_list = []
                    for ci in range(NUM_CHOICES):
                        choice = choices[qi][ci]
                        # Merge with choice.
                        if "_" in question:
                            qa = question.replace("_", choice)
                        else:
                            qa = " ".join([question, choice])
                        # Clean QA.
                        qa = clean_text(qa)
                        # Tokenize.
                        qa_ids = tokenizer.tokenize(qa)
                        # Trim if needed.
                        if len(qa_ids) > max_qa_length:
                            qa_ids = qa_ids[0:max_qa_length]

                        # Build the sample.
                        ids, types, paddings \
                            = build_tokens_types_paddings_from_ids(
                                qa_ids, context_ids, max_seq_length,
                                tokenizer.cls, tokenizer.sep, tokenizer.pad)

                        ids_list.append(ids)
                        types_list.append(types)
                        paddings_list.append(paddings)

                    # Convert to numpy and add to samples
                    samples.append(build_sample(ids_list, types_list,
                                                paddings_list, label,
                                                num_samples))
                    num_samples += 1

    elapsed_time = time.time() - start_time
    print_rank_0(' > processed {} documents, {} questions, and {} samples'
                 ' in {:.2f} seconds'.format(num_docs, num_questions,
                                             num_samples, elapsed_time))

    return samples
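The question/choice merge above handles both cloze-style and plain questions. A tiny illustration with made-up strings:

    # Illustration of the question/choice merge (hypothetical example strings).
    question = "The writer mainly wants to tell us _ ."
    choice = "the value of honesty"
    if "_" in question:                   # cloze-style question: fill the blank
        qa = question.replace("_", choice)
    else:                                 # plain question: append the choice
        qa = " ".join([question, choice])
    print(qa)  # -> "The writer mainly wants to tell us the value of honesty ."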
tasks/race/finetune.py
ADDED
@@ -0,0 +1,67 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Race."""

from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.model.multiple_choice import MultipleChoice
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
from tasks.race.data import RaceDataset


def train_valid_datasets_provider():
    """Provide train and validation datasets."""
    args = get_args()
    tokenizer = get_tokenizer()

    train_dataset = RaceDataset('training', args.train_data,
                                tokenizer, args.seq_length)
    valid_dataset = RaceDataset('validation', args.valid_data,
                                tokenizer, args.seq_length)

    return train_dataset, valid_dataset


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building multichoice model for RACE ...')
    model = MultipleChoice(num_tokentypes=2,
                           pre_process=pre_process,
                           post_process=post_process)

    return model


def metrics_func_provider():
    """Provide metrics callback function."""
    args = get_args()
    tokenizer = get_tokenizer()

    def single_dataset_provider(datapath):
        name = datapath.split('RACE')[-1].strip('/').replace('/', '-')
        return RaceDataset(name, [datapath], tokenizer, args.seq_length)

    return accuracy_func_provider(single_dataset_provider)


def main():

    finetune(train_valid_datasets_provider, model_provider,
             end_of_epoch_callback_provider=metrics_func_provider)
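Every task in this upload follows the same wiring: provider callables are handed to a generic finetune() driver, which invokes them only after Megatron's arguments and tokenizer exist. A minimal sketch of that contract (toy driver, illustrative names only, not the repo's implementation):

    def toy_finetune(datasets_provider, model_provider,
                     end_of_epoch_callback_provider=None):
        # Providers run lazily, after global args/tokenizer are initialized.
        train_ds, valid_ds = datasets_provider()
        model = model_provider()
        callback = None
        if end_of_epoch_callback_provider is not None:
            callback = end_of_epoch_callback_provider()
        # ... training loop here; call callback(model, epoch) after each epoch.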
tasks/vision/classification/classification.py
ADDED
@@ -0,0 +1,94 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Vision-classification finetuning/evaluation."""

import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import print_rank_0
from megatron.model.vision.classification import VitClassificationModel
from megatron.data.vit_dataset import build_train_valid_datasets
from tasks.vision.classification.eval_utils import accuracy_func_provider
from tasks.vision.finetune_utils import finetune
from megatron.utils import average_losses_across_data_parallel_group


def classification():
    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()

        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w),
        )
        return train_ds, valid_ds

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()

        print_rank_0("building classification model for ImageNet ...")

        return VitClassificationModel(num_classes=args.num_classes, finetune=True,
                                      pre_process=pre_process, post_process=post_process)

    def process_batch(batch):
        """Process batch and produce inputs for the model."""
        images = batch[0].cuda().contiguous()
        labels = batch[1].cuda().contiguous()
        return images, labels

    def cross_entropy_loss_func(labels, output_tensor):
        logits = output_tensor

        # Cross-entropy loss.
        loss = F.cross_entropy(logits.contiguous().float(), labels)

        # Reduce loss for logging.
        averaged_loss = average_losses_across_data_parallel_group([loss])

        return loss, {'lm loss': averaged_loss[0]}

    def _cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        timers = get_timers()

        # Get the batch.
        timers("batch generator").start()
        try:
            batch_ = next(batch)
        except BaseException:
            batch_ = batch
        images, labels = process_batch(batch_)
        timers("batch generator").stop()

        # Forward model.
        output_tensor = model(images)

        return output_tensor, partial(cross_entropy_loss_func, labels)

    # Finetune/evaluate.
    finetune(
        train_valid_datasets_provider,
        model_provider,
        forward_step=_cross_entropy_forward_step,
        end_of_epoch_callback_provider=accuracy_func_provider,
    )


def main():
    classification()
tasks/vision/classification/eval_utils.py
ADDED
@@ -0,0 +1,129 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation utilities."""

import os
from functools import partial

import torch

from megatron import get_args
from megatron import print_rank_0, print_rank_last
from megatron import mpu
from megatron.schedules import get_forward_backward_func
from tasks.vision.finetune_utils import build_data_loader
from tasks.vision.finetune_utils import process_batch
from torchvision import datasets, transforms


def accuracy_func_provider():
    """Provide function that calculates accuracies."""
    args = get_args()
    data_path = args.data_path
    crop_size = (args.img_h, args.img_w)

    # Build dataloaders.
    val_data_path = data_path[1]
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    transform_val = transforms.Compose(
        [
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            normalize,
        ]
    )
    dataset = datasets.ImageFolder(root=val_data_path, transform=transform_val)

    dataloader = build_data_loader(
        dataset,
        args.micro_batch_size,
        num_workers=args.num_workers,
        drop_last=(mpu.get_data_parallel_world_size() > 1),
        shuffle=False
    )

    def metrics_func(model, epoch):
        print_rank_0("calculating metrics ...")
        correct, total = calculate_correct_answers(model, dataloader, epoch)
        percent = float(correct) * 100.0 / float(total)
        print_rank_last(
            " >> |epoch: {}| overall: correct / total = {} / {} = "
            "{:.4f} %".format(epoch, correct, total, percent)
        )

    return metrics_func


def calculate_correct_answers(model, dataloader, epoch):
    """Calculate correct over total answers"""

    forward_backward_func = get_forward_backward_func()
    for m in model:
        m.eval()

    def loss_func(labels, output_tensor):
        logits = output_tensor

        loss_dict = {}
        # Compute the correct answers.
        predicted = torch.argmax(logits, dim=-1)
        corrects = (predicted == labels).float()
        # Add to the counters.
        loss_dict['total'] = labels.size(0)
        loss_dict['correct'] = corrects.sum().item()

        return 0, loss_dict

    # defined inside to capture output_predictions
    def correct_answers_forward_step(batch, model):
        try:
            batch_ = next(batch)
        except BaseException:
            batch_ = batch
        images, labels = process_batch(batch_)

        # Forward model.
        output_tensor = model(images)

        return output_tensor, partial(loss_func, labels)

    with torch.no_grad():
        # For all the batches in the dataset.
        total = 0
        correct = 0
        for _, batch in enumerate(dataloader):

            loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
                                               optimizer=None, timers=None, forward_only=True)

            for loss_dict in loss_dicts:
                total += loss_dict['total']
                correct += loss_dict['correct']

    for m in model:
        m.train()

    # Reduce.
    if mpu.is_pipeline_last_stage():
        unreduced = torch.cuda.LongTensor([correct, total])
        torch.distributed.all_reduce(unreduced,
                                     group=mpu.get_data_parallel_group())

        # Print on screen.
        correct_ans = unreduced[0].item()
        total_count = unreduced[1].item()
        return correct_ans, total_count
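Per-rank (correct, total) counters are all-reduced across the data-parallel group before the percentage is computed. The same arithmetic, mocked for a single process:

    # Mocked (correct, total) counts from two data-parallel ranks.
    rank_counts = [(947, 1000), (962, 1000)]
    correct = sum(c for c, _ in rank_counts)   # what all_reduce would produce
    total = sum(t for _, t in rank_counts)
    percent = float(correct) * 100.0 / float(total)
    print("correct / total = {} / {} = {:.4f} %".format(correct, total, percent))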
tasks/vision/finetune_utils.py
ADDED
@@ -0,0 +1,312 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Finetune utilities."""

import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu, utils
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.training import train_step
from megatron.training import training_log
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import average_losses_across_data_parallel_group, print_params_min_max_norm
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module, ModelType


def process_batch(batch):
    """Process batch and produce inputs for the model."""
    images = batch[0].cuda().contiguous()
    labels = batch[1].cuda().contiguous()
    return images, labels


def build_data_loader(dataset, micro_batch_size,
                      num_workers, drop_last, shuffle):
    """Data loader. Note that batch-size is the local (per GPU) batch-size."""

    # Sampler.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank,
        drop_last=drop_last, shuffle=shuffle
    )

    # Data loader. Note that batch size is the per GPU batch size.
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=micro_batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=num_workers,
        drop_last=drop_last,
        pin_memory=True,
    )

    return data_loader


def _build_infinite_size_dataloader(dataloader):
    """Build a looped dataloader with infinite size."""

    iterator = dataloader.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            iterator = dataloader.__iter__()


def _build_train_valid_dataloaders(train_dataset, valid_dataset):
    """Training and validation dataloaders."""
    args = get_args()

    print_rank_0('building train and validation dataloaders ...')
    # Training dataset.
    train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
                                         args.num_workers, False, True)
    # Set the training iterations.
    args.train_iters_per_epoch = len(train_dataloader)
    args.train_iters = args.epochs * args.train_iters_per_epoch
    # Validation dataset. For this dataset, we do not need to set up
    # shuffling so we can just use a simple infinite loop.
    valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
                                          args.num_workers, True, False)
    valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)

    # Now that we've built the data loaders, set batch_size arguments
    # to the actual batch size the model will see for this dataset.
    # This is necessary so pipeline transfers know what size they are
    # and the LR schedule, which is based on samples seen, gets set
    # correctly.
    args.orig_micro_batch_size = args.micro_batch_size
    args.orig_global_batch_size = args.global_batch_size

    return train_dataloader, valid_dataloader


def _train(
    model,
    optimizer,
    opt_param_scheduler,
    forward_step,
    train_dataloader,
    valid_dataloader,
    end_of_epoch_callback,
    process_non_loss_data_func=None
):
    """Train the model."""
    args = get_args()
    timers = get_timers()

    # Turn on training mode which enables dropout.
    for m in model:
        m.train()

    # Tracking loss.
    losses_dict_sum = {}

    # Starting epoch and iteration
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    iteration = args.iteration

    # Memory reporting flag.
    report_memory_flag = True

    # For each remaining epoch
    timers("interval-time").start()
    for epoch in range(start_epoch, args.epochs):
        print_rank_0("working on epoch {} ...".format(epoch + 1))

        # Set the data loader epoch to shuffle the index iterator.
        train_dataloader.sampler.set_epoch(args.seed + epoch)
        train_dataloader.dataset.set_epoch(epoch)

        # For all the batches in the dataset.
        for iteration_, batch in enumerate(train_dataloader):

            # Ignore the iterations before starting value
            if iteration_ < start_iteration:
                continue
            # Set to zero so the next epoch does not skip any batches.
            start_iteration = 0

            # Train for one step.
            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
                forward_step, batch, model, optimizer, opt_param_scheduler
            )
            iteration += 1

            # Logging.
            params_norm = None

            report_memory_flag = training_log(
                losses_dict,
                losses_dict_sum,
                optimizer.param_groups[0]["lr"],
                iteration,
                optimizer.get_loss_scale().item(),
                report_memory_flag,
                skipped_iter,
                grad_norm,
                params_norm,
                num_zeros_in_grad
            )

            # Autoresume
            if args.adlr_autoresume and \
               iteration % args.adlr_autoresume_interval == 0:
                check_adlr_autoresume_termination(iteration, model, optimizer,
                                                  opt_param_scheduler)

            # Checkpointing
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer,
                                opt_param_scheduler)

            # Evaluation
            if args.eval_interval and iteration % args.eval_interval == 0:
                prefix = "iteration {}".format(iteration)
                evaluate_and_print_results(
                    prefix,
                    forward_step,
                    valid_dataloader,
                    model,
                    iteration,
                    process_non_loss_data_func,
                    False,
                )

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)


def finetune(
    train_valid_datasets_provider,
    model_provider,
    forward_step,
    model_type=ModelType.encoder_or_decoder,
    process_non_loss_data_func=None,
    end_of_epoch_callback_provider=None,
):
    """Main finetune function used across all tasks."""
    args = get_args()
    timers = get_timers()

    # Train and validation data loaders.
    timers("train/valid/test dataset/dataloader").start()
    if args.epochs > 0:
        train_dataset, valid_dataset = train_valid_datasets_provider()
        train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
            train_dataset, valid_dataset
        )
    timers("train/valid/test dataset/dataloader").stop()

    # Build callback function.
    timers("callback function").start()
    end_of_epoch_callback = None
    if end_of_epoch_callback_provider is not None:
        end_of_epoch_callback = end_of_epoch_callback_provider()
    timers("callback function").stop()

    # Build model, optimizer and learning rate scheduler.
    timers("model and optimizer").start()
    model, optimizer, opt_param_scheduler = \
        setup_model_and_optimizer(
            model_provider,
            model_type,
            scale_lr_cond=lambda name, param: ".head." in name,
            lr_mult=args.head_lr_mult)
    timers("model and optimizer").stop()

    # If pretrained checkpoint is provided and we have not trained for
    # any iteration (i.e., iteration is zero), then load the pretrained
    # checkpoint.
    timers("pretrained checkpoint").start()
    if args.iteration == 0 and args.pretrained_checkpoint is not None:
        if args.pretrained_checkpoint_type == 'default':
            original_load = args.load
            args.load = args.pretrained_checkpoint
            _ = load_checkpoint(model, None, None, strict=False)
            args.load = original_load
        elif args.pretrained_checkpoint_type == 'external':
            unwrap_model = utils.unwrap_model(model)
            state_dict = torch.load(args.pretrained_checkpoint,
                                    map_location="cpu")
            unwrap_model[0].module.backbone.load_state_dict(state_dict,
                                                            strict=False)
        elif args.pretrained_checkpoint_type == 'contrastive':
            unwrap_model = utils.unwrap_model(model)
            state_dict = torch.load(args.pretrained_checkpoint,
                                    map_location="cpu")
            state_dict = state_dict["model"]
            state_dict = {k.replace("teacher.backbone.", ""): v
                          for k, v in state_dict.items()
                          if k.startswith("teacher.backbone.")}
            unwrap_model[0].module.backbone.load_state_dict(state_dict,
                                                            strict=False)
        else:
            raise Exception("pretrained checkpoint type {} not supported".format(args.pretrained_checkpoint_type))

        # This is critical when only model is loaded. We should make sure
        # master parameters are also updated.
        optimizer.reload_model_params()

    timers("pretrained checkpoint").stop()

    # Print setup timing.
    print_rank_0("done with setups ...")
    timers.log(
        [
            "train/valid/test dataset/dataloader",
            "callback function",
            "model and optimizer",
            "pretrained checkpoint",
        ]
    )
    print_rank_0("training ...")

    # Finetune the model.
    if args.epochs > 0:
        _train(
            model,
            optimizer,
            opt_param_scheduler,
            forward_step,
            train_dataloader,
            valid_dataloader,
            end_of_epoch_callback,
            process_non_loss_data_func,
        )
    # Or just evaluate.
    else:
        if end_of_epoch_callback is not None:
            print_rank_0("evaluation only mode, setting epoch to -1")
            end_of_epoch_callback(model, epoch=-1)

    print_rank_0("done :-)")
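Mid-epoch resume in _train is pure integer arithmetic on the checkpointed iteration counter. A worked check with made-up numbers:

    # Say each epoch has 1250 iterations and a checkpoint stored iteration 3100.
    train_iters_per_epoch = 1250
    iteration = 3100
    start_epoch = iteration // train_iters_per_epoch      # -> 2 (third epoch)
    start_iteration = iteration % train_iters_per_epoch   # -> 600
    # The loop skips the first 600 batches of epoch 2, then resets the
    # offset to 0 so later epochs run in full.
    print(start_epoch, start_iteration)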
tasks/vision/main.py
ADDED
@@ -0,0 +1,66 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main tasks functionality."""

import os
import sys

sys.path.append(
    os.path.abspath(
        os.path.join(
            os.path.join(os.path.dirname(__file__), os.path.pardir),
            os.path.pardir,
        )
    )
)
from megatron import get_args
from megatron.initialize import initialize_megatron


def get_tasks_args(parser):
    """Provide extra arguments required for tasks."""
    group = parser.add_argument_group(title="tasks")

    # NOTE: the default 'segment' is not among the listed choices and argparse
    # does not validate defaults against choices, so pass --task explicitly.
    group.add_argument('--task', type=str, default='segment',
                       choices=['classify', 'segment_setr', 'segment_segformer'],
                       help='task name.')
    group.add_argument("--epochs", type=int, default=None,
                       help="Number of finetuning epochs. Zero results in "
                       "evaluation only.")
    group.add_argument('--pretrained-checkpoint-type', type=str, default='default',
                       choices=['default', 'external', 'contrastive'],
                       help='Type of pretrained checkpoint')
    group.add_argument("--pretrained-checkpoint", type=str, default=None,
                       help="Pretrained checkpoint used for finetuning.")
    group.add_argument('--seg-stride', type=int, default=None,
                       help='sliding window stride during evaluation')
    return parser


if __name__ == "__main__":

    initialize_megatron(extra_args_provider=get_tasks_args)
    args = get_args()

    if args.task == 'classify':
        from tasks.vision.classification.classification import main
        main()
    elif args.task == 'segment_setr':
        from tasks.vision.segmentation.finetune_setr import main
        main()
    elif args.task == 'segment_segformer':
        from tasks.vision.segmentation.finetune_segformer import main
        main()
tasks/vision/segmentation/cityscapes.py
ADDED
@@ -0,0 +1,207 @@
# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# code taken from
# https://github.com/pytorch/vision/blob/main/torchvision/datasets/cityscapes.py
# modified it to change max label index from 255 to 19 (num_classes)

import torch
import json
import os
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
import numpy as np
from torchvision.datasets.utils import extract_archive, verify_str_arg, iterable_to_str
from torchvision.datasets import VisionDataset
from PIL import Image
from megatron import print_rank_0


class Cityscapes(VisionDataset):
    """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
    Args:
        root (string): Root directory of dataset where directory ``leftImg8bit``
            and ``gtFine`` or ``gtCoarse`` are located.
        split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
            otherwise ``train``, ``train_extra`` or ``val``
        mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
        target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
            or ``color``. Can also be a list to output a tuple with all specified target types.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    Examples:
        Get semantic segmentation target
        .. code-block:: python
            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type='semantic')
            img, smnt = dataset[0]
        Get multiple targets
        .. code-block:: python
            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type=['instance', 'color', 'polygon'])
            img, (inst, col, poly) = dataset[0]
        Validate on the "coarse" set
        .. code-block:: python
            dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
                                 target_type='semantic')
            img, smnt = dataset[0]
    """
    num_classes = 19
    ignore_index = 19
    color_table = torch.tensor(
        [[128, 64, 128],
         [244, 35, 232],
         [70, 70, 70],
         [102, 102, 156],
         [190, 153, 153],
         [153, 153, 153],
         [250, 170, 30],
         [220, 220, 0],
         [107, 142, 35],
         [152, 251, 152],
         [70, 130, 180],
         [220, 20, 60],
         [255, 0, 0],
         [0, 0, 142],
         [0, 0, 70],
         [0, 60, 100],
         [0, 80, 100],
         [0, 0, 230],
         [119, 11, 32],
         [0, 0, 0]], dtype=torch.float, device='cuda')

    # Based on https://github.com/mcordts/cityscapesScripts
    CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id',
        'category', 'category_id', 'has_instances', 'ignore_in_eval', 'color'])

    classes = [
        CityscapesClass('unlabeled', 0, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('ego vehicle', 1, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('rectification border', 2, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('out of roi', 3, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('static', 4, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('dynamic', 5, 19, 'void', 0, False, True, (111, 74, 0)),
        CityscapesClass('ground', 6, 19, 'void', 0, False, True, (81, 0, 81)),
        CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
        CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
        CityscapesClass('parking', 9, 19, 'flat', 1, False, True, (250, 170, 160)),
        CityscapesClass('rail track', 10, 19, 'flat', 1, False, True, (230, 150, 140)),
        CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
        CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
        CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
        CityscapesClass('guard rail', 14, 19, 'construction', 2, False, True, (180, 165, 180)),
        CityscapesClass('bridge', 15, 19, 'construction', 2, False, True, (150, 100, 100)),
        CityscapesClass('tunnel', 16, 19, 'construction', 2, False, True, (150, 120, 90)),
        CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
        CityscapesClass('polegroup', 18, 19, 'object', 3, False, True, (153, 153, 153)),
        CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
        CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
        CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
        CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
        CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
        CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
        CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
        CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
        CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
        CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
        CityscapesClass('caravan', 29, 19, 'vehicle', 7, True, True, (0, 0, 90)),
        CityscapesClass('trailer', 30, 19, 'vehicle', 7, True, True, (0, 0, 110)),
        CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
        CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
        CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
        CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
    ]

    # label2trainid
    label2trainid = {label.id: label.train_id for label in classes}

    def __init__(
        self,
        root: str,
        split: str = "train",
        mode: str = "fine",
        resolution: int = 1024,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super(Cityscapes, self).__init__(root, transforms, transform, target_transform)
        self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
        self.images_dir = os.path.join(self.root, 'leftImg8bit_trainvaltest/leftImg8bit', split)
        self.targets_dir = os.path.join(self.root, 'gtFine_trainvaltest/gtFine', split)
        self.split = split
        self.resolution = resolution
        self.images = []
        self.targets = []

        for city in sorted(os.listdir(self.images_dir)):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in os.listdir(img_dir):
                target_name = '{}_{}_labelIds.png'.format(file_name.split('_leftImg8bit')[0], self.mode)
                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(os.path.join(target_dir, target_name))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
            than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
        """
        image = Image.open(self.images[index]).convert('RGB')

        target = Image.open(self.targets[index])
        target = np.array(target)

        target_copy = target.copy()
        for k, v in Cityscapes.label2trainid.items():
            binary_target = (target == k)
            target_copy[binary_target] = v
        target = target_copy

        target = Image.fromarray(target.astype(np.uint8))

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        return len(self.images)
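The __getitem__ loop rewrites raw Cityscapes label ids into the 19 train ids, with everything else mapped to 19, the ignore index. An equivalent lookup-table check on a toy 2x3 label image (illustrative; assumes the Cityscapes class above is in scope):

    import numpy as np

    # e.g. road 7 -> 0, car 26 -> 13, and void ids like 0 ('unlabeled') -> 19.
    raw = np.array([[7, 26, 0],
                    [8, 33, 4]])
    lut = np.full(34, 19, dtype=np.uint8)          # default to ignore_index
    for label_id, train_id in Cityscapes.label2trainid.items():
        if 0 <= label_id < 34:                     # skip the -1 license-plate id
            lut[label_id] = train_id
    print(lut[raw])   # [[ 0 13 19], [ 1 18 19]]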
tasks/vision/segmentation/data.py
ADDED
@@ -0,0 +1,154 @@
import random
import os
import math
import mmcv
import torch
import numpy as np
import torchvision.transforms as T
from torchvision import datasets
from torch.utils.data import Dataset
from megatron.data.autoaugment import ImageNetPolicy
from tasks.vision.segmentation.cityscapes import Cityscapes
import tasks.vision.segmentation.transforms as ET
from megatron import get_args
from PIL import Image, ImageOps


class VitSegmentationJointTransform():
    def __init__(self, train=True, resolution=None):
        self.train = train
        if self.train:
            self.transform0 = ET.RandomSizeAndCrop(resolution)
            self.transform1 = ET.RandomHorizontallyFlip()

    def __call__(self, img, mask):
        if self.train:
            img, mask = self.transform0(img, mask)
            img, mask = self.transform1(img, mask)
        return img, mask


class VitSegmentationImageTransform():
    def __init__(self, train=True, resolution=None):
        args = get_args()
        self.train = train
        assert args.fp16 or args.bf16
        self.data_type = torch.half if args.fp16 else torch.bfloat16
        self.mean_std = args.mean_std
        if self.train:
            assert resolution is not None
            self.transform = T.Compose([
                ET.PhotoMetricDistortion(),
                T.ToTensor(),
                T.Normalize(*self.mean_std),
                T.ConvertImageDtype(self.data_type)
            ])
        else:
            self.transform = T.Compose([
                T.ToTensor(),
                T.Normalize(*self.mean_std),
                T.ConvertImageDtype(self.data_type)
            ])

    def __call__(self, input):
        output = self.transform(input)
        return output


class VitSegmentationTargetTransform():
    def __init__(self, train=True, resolution=None):
        self.train = train

    def __call__(self, input):
        output = torch.from_numpy(np.array(input, dtype=np.int32)).long()
        return output


class RandomSeedSegmentationDataset(Dataset):
    def __init__(self,
                 dataset,
                 joint_transform,
                 image_transform,
                 target_transform):

        args = get_args()
        self.base_seed = args.seed
        self.curr_seed = self.base_seed
        self.dataset = dataset
        self.joint_transform = joint_transform
        self.image_transform = image_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.dataset)

    def set_epoch(self, epoch):
        self.curr_seed = self.base_seed + 100 * epoch

    def __getitem__(self, idx):
        seed = idx + self.curr_seed
        img, mask = self.dataset[idx]

        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        img, mask = self.joint_transform(img, mask)
        img = self.image_transform(img)
        mask = self.target_transform(mask)

        return img, mask


def build_cityscapes_train_valid_datasets(data_path, image_size):
    args = get_args()
    args.num_classes = Cityscapes.num_classes
    args.ignore_index = Cityscapes.ignore_index
    args.color_table = Cityscapes.color_table
    args.mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_joint_transform = \
        VitSegmentationJointTransform(train=True, resolution=image_size)
    val_joint_transform = \
        VitSegmentationJointTransform(train=False, resolution=image_size)
    train_image_transform = \
        VitSegmentationImageTransform(train=True, resolution=image_size)
    val_image_transform = \
        VitSegmentationImageTransform(train=False, resolution=image_size)
    train_target_transform = \
        VitSegmentationTargetTransform(train=True, resolution=image_size)
    val_target_transform = \
        VitSegmentationTargetTransform(train=False, resolution=image_size)

    # training dataset
    train_data = Cityscapes(
        root=data_path[0],
        split='train',
        mode='fine',
        resolution=image_size
    )
    train_data = RandomSeedSegmentationDataset(
        train_data,
        joint_transform=train_joint_transform,
        image_transform=train_image_transform,
        target_transform=train_target_transform)

    # validation dataset
    val_data = Cityscapes(
        root=data_path[0],
        split='val',
        mode='fine',
        resolution=image_size
    )

    val_data = RandomSeedSegmentationDataset(
        val_data,
        joint_transform=val_joint_transform,
        image_transform=val_image_transform,
        target_transform=val_target_transform)

    return train_data, val_data


def build_train_valid_datasets(data_path, image_size):
    return build_cityscapes_train_valid_datasets(data_path, image_size)
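RandomSeedSegmentationDataset pins every augmentation to seed = base_seed + 100 * epoch + idx, so a given (epoch, sample) pair always draws the same crop/flip no matter which worker serves it. A toy demonstration of that determinism (standalone sketch, no Megatron args):

    import random

    base_seed = 1234

    def augment(idx, epoch):
        # mirror of the seeding in set_epoch/__getitem__ above
        seed = base_seed + 100 * epoch + idx
        random.seed(seed)
        return random.random() < 0.5   # e.g. a horizontal-flip decision

    assert augment(idx=7, epoch=3) == augment(idx=7, epoch=3)  # reproducible
    print(augment(7, 3), augment(7, 4))   # may differ across epochs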
tasks/vision/segmentation/finetune_segformer.py
ADDED
@@ -0,0 +1,251 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Vision-classification finetuning/evaluation."""
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from functools import partial
|
22 |
+
from megatron import get_args, get_timers
|
23 |
+
from megatron import mpu, print_rank_0, print_rank_last
|
24 |
+
from tasks.vision.finetune_utils import finetune
|
25 |
+
from tasks.vision.finetune_utils import build_data_loader
|
26 |
+
from megatron.utils import average_losses_across_data_parallel_group
|
27 |
+
from megatron.schedules import get_forward_backward_func
|
28 |
+
from tasks.vision.segmentation.data import build_train_valid_datasets
|
29 |
+
from tasks.vision.segmentation.seg_models import SegformerSegmentationModel
|
30 |
+
from megatron.model.vision.utils import resize
|
31 |
+
|
32 |
+
|
33 |
+
def calculate_iou(hist_data):
|
34 |
+
acc = np.diag(hist_data).sum() / hist_data.sum()
|
35 |
+
acc_cls = np.diag(hist_data) / hist_data.sum(axis=1)
|
36 |
+
acc_cls = np.nanmean(acc_cls)
|
37 |
+
divisor = hist_data.sum(axis=1) + hist_data.sum(axis=0) - \
|
38 |
+
np.diag(hist_data)
|
39 |
+
iu = np.diag(hist_data) / divisor
|
40 |
+
return iu, acc, acc_cls
|
41 |
+
|
42 |
+
|
43 |
+
def fast_hist(pred, gtruth, num_classes):
|
44 |
+
# mask indicates pixels we care about
|
45 |
+
mask = (gtruth >= 0) & (gtruth < num_classes)
|
46 |
+
|
47 |
+
# stretch ground truth labels by num_classes
|
48 |
+
# class 0 -> 0
|
49 |
+
# class 1 -> 19
|
50 |
+
# class 18 -> 342
|
51 |
+
#
|
52 |
+
# TP at 0 + 0, 1 + 1, 2 + 2 ...
|
53 |
+
#
|
54 |
+
# TP exist where value == num_classes*class_id + class_id
|
55 |
+
# FP = row[class].sum() - TP
|
56 |
+
# FN = col[class].sum() - TP
|
57 |
+
hist = np.bincount(num_classes * gtruth[mask].astype(int) + pred[mask],
|
58 |
+
minlength=num_classes ** 2)
|
59 |
+
hist = hist.reshape(num_classes, num_classes)
|
60 |
+
return hist
|
61 |
+
|
62 |
+
|
63 |
+
def segmentation():

    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()

        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        return train_ds, valid_ds

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()

        model = SegformerSegmentationModel(num_classes=args.num_classes,
                                           pre_process=pre_process,
                                           post_process=post_process)
        print_rank_0("model = {}".format(model))
        return model

    def process_batch(batch):
        """Process batch and produce inputs for the model."""
        images = batch[0].cuda().contiguous()
        masks = batch[1].cuda().contiguous()
        return images, masks

    def calculate_weight(masks, num_classes):
        # Per-class pixel histogram turned into inverse-frequency weights:
        # classes that occur get a weight in (1, 2], rarer classes higher.
        bins = torch.histc(masks, bins=num_classes, min=0.0, max=num_classes)
        hist_norm = bins.float() / bins.sum()
        hist = ((bins != 0).float() * (1. - hist_norm)) + 1.0
        return hist

    def cross_entropy_loss_func(images, masks, output_tensor,
                                non_loss_data=False):
        args = get_args()
        ignore_index = args.ignore_index
        color_table = args.color_table
        logits = output_tensor.contiguous().float()
        logits = resize(logits, size=masks.shape[1:],
                        mode='bilinear', align_corners=False)

        # Cross-entropy loss.
        # weight = calculate_weight(masks, num_classes)
        loss = F.cross_entropy(logits, masks, ignore_index=ignore_index)

        if not non_loss_data:
            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])
            return loss, {'lm loss': averaged_loss[0]}
        else:
            # Render prediction and ground truth as color images for logging.
            seg_mask = logits.argmax(dim=1)
            output_mask = F.embedding(seg_mask, color_table).permute(0, 3, 1, 2)
            gt_mask = F.embedding(masks, color_table).permute(0, 3, 1, 2)
            return torch.cat((images, output_mask, gt_mask), dim=2), loss

    def _cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        timers = get_timers()

        # Get the batch.
        timers("batch generator").start()
        import types
        if isinstance(batch, types.GeneratorType):
            batch_ = next(batch)
        else:
            batch_ = batch
        images, masks = process_batch(batch_)
        timers("batch generator").stop()

        # Forward model.
        output_tensor = model(images)

        return output_tensor, partial(cross_entropy_loss_func, images, masks)

    def calculate_correct_answers(model, dataloader, epoch):
        """Calculate correct over total answers"""

        forward_backward_func = get_forward_backward_func()
        for m in model:
            m.eval()

        def loss_func(labels, output_tensor):
            args = get_args()
            logits = output_tensor
            logits = resize(logits, size=labels.shape[1:],
                            mode='bilinear', align_corners=False)

            loss_dict = {}
            # Compute the correct answers.
            probs = logits.contiguous().float().softmax(dim=1)
            max_probs, preds = torch.max(probs, 1)

            preds = preds.cpu().numpy()
            performs = fast_hist(preds.flatten(),
                                 labels.cpu().numpy().flatten(),
                                 args.ignore_index)
            loss_dict['performs'] = performs
            return 0, loss_dict

        # defined inside to capture output_predictions
        def correct_answers_forward_step(batch, model):
            try:
                batch_ = next(batch)
            except BaseException:
                batch_ = batch
            images, labels = process_batch(batch_)

            # Forward model.
            output_tensor = model(images)

            return output_tensor, partial(loss_func, labels)

        with torch.no_grad():
            # For all the batches in the dataset.
            performs = None
            for _, batch in enumerate(dataloader):
                loss_dicts = forward_backward_func(correct_answers_forward_step,
                                                   batch, model,
                                                   optimizer=None,
                                                   timers=None,
                                                   forward_only=True)
                for loss_dict in loss_dicts:
                    if performs is None:
                        performs = loss_dict['performs']
                    else:
                        performs += loss_dict['performs']

        for m in model:
            m.train()
        # Reduce.
        if mpu.is_pipeline_last_stage():
            performs_tensor = torch.cuda.FloatTensor(performs)
            torch.distributed.all_reduce(performs_tensor,
                                         group=mpu.get_data_parallel_group())
            hist = performs_tensor.cpu().numpy()
            iu, acc, acc_cls = calculate_iou(hist)
            miou = np.nanmean(iu)

            return iu, miou

    def accuracy_func_provider():
        """Provide function that calculates accuracies."""
        args = get_args()

        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        dataloader = build_data_loader(
            valid_ds,
            args.micro_batch_size,
            num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1),
            shuffle=False
        )

        def metrics_func(model, epoch):
            print_rank_0("calculating metrics ...")
            iou, miou = calculate_correct_answers(model, dataloader, epoch)
            print_rank_last(
                " >> |epoch: {}| overall: iou = {}, "
                "miou = {:.4f} %".format(epoch, iou, miou * 100.0)
            )
        return metrics_func

    def dump_output_data(data, iteration, writer):
        for (output_tb, loss) in data:
            # output_tb[output_tb < 0] = 0
            # output_tb[output_tb > 1] = 1
            writer.add_images("image-outputseg-realseg", output_tb,
                              global_step=None, walltime=None,
                              dataformats='NCHW')

    """Finetune/evaluate."""
    finetune(
        train_valid_datasets_provider,
        model_provider,
        forward_step=_cross_entropy_forward_step,
        process_non_loss_data_func=dump_output_data,
        end_of_epoch_callback_provider=accuracy_func_provider,
    )


def main():
    segmentation()
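
A note on the evaluation path above: `calculate_correct_answers` sums the `fast_hist` confusion matrices over all batches and data-parallel ranks, and `calculate_iou` then turns the summed histogram into per-class IoU. Both helpers live outside this excerpt, so the sketch below is only one conventional implementation of that contract (flattened integer prediction/label arrays plus a class count, which the loop passes as `args.ignore_index`), not necessarily this upload's own code:

import numpy as np

def fast_hist(pred, gt, num_classes):
    # Keep pixels whose ground-truth label is a valid class id.
    valid = (gt >= 0) & (gt < num_classes)
    # Encode each (gt, pred) pair as a single index and histogram them
    # into a num_classes x num_classes confusion matrix.
    return np.bincount(num_classes * gt[valid].astype(int) + pred[valid],
                       minlength=num_classes ** 2
                       ).reshape(num_classes, num_classes)

def calculate_iou(hist):
    # Diagonal holds true positives; row + column sums give the union.
    iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
    acc = np.diag(hist).sum() / hist.sum()
    acc_cls = np.nanmean(np.diag(hist) / hist.sum(1))
    return iu, acc, acc_cls
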
tasks/vision/segmentation/finetune_setr.py
ADDED
@@ -0,0 +1,225 @@
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Vision-segmentation finetuning/evaluation."""

import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import mpu, print_rank_0, print_rank_last
from tasks.vision.finetune_utils import finetune
from tasks.vision.finetune_utils import build_data_loader
from megatron.utils import average_losses_across_data_parallel_group
from megatron.schedules import get_forward_backward_func
from tasks.vision.segmentation.metrics import CFMatrix
from tasks.vision.segmentation.data import build_train_valid_datasets
from tasks.vision.segmentation.seg_models import SetrSegmentationModel
from tasks.vision.segmentation.utils import slidingcrops, slidingjoins

def segmentation():
    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()

        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        return train_ds, valid_ds

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()

        return SetrSegmentationModel(num_classes=args.num_classes,
                                     pre_process=pre_process,
                                     post_process=post_process)

    def process_batch(batch):
        """Process batch and produce inputs for the model."""
        images = batch[0].cuda().contiguous()
        masks = batch[1].cuda().contiguous()
        return images, masks

    def calculate_weight(masks, num_classes):
        # Per-class pixel histogram turned into inverse-frequency weights.
        bins = torch.histc(masks, bins=num_classes, min=0.0, max=num_classes)
        hist_norm = bins.float() / bins.sum()
        hist = ((bins != 0).float() * (1. - hist_norm)) + 1.0
        return hist

    def cross_entropy_loss_func(images, masks, output_tensor, non_loss_data=False):
        args = get_args()
        ignore_index = args.ignore_index
        color_table = args.color_table
        weight = calculate_weight(masks, args.num_classes)
        logits = output_tensor.contiguous().float()
        loss = F.cross_entropy(logits, masks, weight=weight, ignore_index=ignore_index)

        if not non_loss_data:
            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])

            return loss, {'lm loss': averaged_loss[0]}
        else:
            seg_mask = logits.argmax(dim=1)
            output_mask = F.embedding(seg_mask, color_table).permute(0, 3, 1, 2)
            gt_mask = F.embedding(masks, color_table).permute(0, 3, 1, 2)
            return torch.cat((images, output_mask, gt_mask), dim=2), loss

    def _cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        args = get_args()
        timers = get_timers()

        # Get the batch.
        timers("batch generator").start()
        import types
        if isinstance(batch, types.GeneratorType):
            batch_ = next(batch)
        else:
            batch_ = batch
        images, masks = process_batch(batch_)
        timers("batch generator").stop()

        # Forward model. At eval time, tile each image into sliding crops
        # and run the model crop by crop within the micro batch size.
        if not model.training:
            images, masks, _, _ = slidingcrops(images, masks)
        #print_rank_0("images size = {}".format(images.size()))

        if not model.training:
            output_tensor = torch.cat([model(image) for image in torch.split(images, args.micro_batch_size)])
        else:
            output_tensor = model(images)

        return output_tensor, partial(cross_entropy_loss_func, images, masks)

    def calculate_correct_answers(model, dataloader, epoch):
        """Calculate correct over total answers"""

        forward_backward_func = get_forward_backward_func()
        for m in model:
            m.eval()

        def loss_func(labels, slices_info, img_size, output_tensor):
            args = get_args()
            logits = output_tensor

            loss_dict = {}
            # Compute the correct answers.
            probs = logits.contiguous().float().softmax(dim=1)
            max_probs, preds = torch.max(probs, 1)
            preds = preds.int()
            preds, labels = slidingjoins(preds, max_probs, labels, slices_info, img_size)
            _, performs = CFMatrix()(preds, labels, args.ignore_index)

            loss_dict['performs'] = performs
            return 0, loss_dict

        # defined inside to capture output_predictions
        def correct_answers_forward_step(batch, model):
            args = get_args()
            try:
                batch_ = next(batch)
            except BaseException:
                batch_ = batch
            images, labels = process_batch(batch_)

            assert not model.training
            images, labels, slices_info, img_size = slidingcrops(images, labels)
            # Forward model.
            output_tensor = torch.cat([model(image) for image in torch.split(images, args.micro_batch_size)])

            return output_tensor, partial(loss_func, labels, slices_info, img_size)

        with torch.no_grad():
            # For all the batches in the dataset.
            performs = None
            for _, batch in enumerate(dataloader):
                loss_dicts = forward_backward_func(correct_answers_forward_step,
                                                   batch, model,
                                                   optimizer=None,
                                                   timers=None,
                                                   forward_only=True)
                for loss_dict in loss_dicts:
                    if performs is None:
                        performs = loss_dict['performs']
                    else:
                        performs += loss_dict['performs']

        for m in model:
            m.train()
        # Reduce.
        if mpu.is_pipeline_last_stage():
            torch.distributed.all_reduce(performs,
                                         group=mpu.get_data_parallel_group())
            # Print on screen.
            # performs[int(ch), :] = [nb_tp, nb_fp, nb_tn, nb_fn]
            true_positive = performs[:, 0]
            false_positive = performs[:, 1]
            false_negative = performs[:, 3]

            iou = true_positive / (true_positive + false_positive + false_negative)
            miou = iou[~torch.isnan(iou)].mean()

            return iou.tolist(), miou.item()

    def accuracy_func_provider():
        """Provide function that calculates accuracies."""
        args = get_args()

        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        dataloader = build_data_loader(
            valid_ds,
            args.micro_batch_size,
            num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1),
            shuffle=False
        )

        def metrics_func(model, epoch):
            print_rank_0("calculating metrics ...")
            iou, miou = calculate_correct_answers(model, dataloader, epoch)
            print_rank_last(
                " >> |epoch: {}| overall: iou = {}, "
                "miou = {:.4f} %".format(epoch, iou, miou * 100.0)
            )
        return metrics_func

    def dump_output_data(data, iteration, writer):
        for (output_tb, loss) in data:
            # output_tb[output_tb < 0] = 0
            # output_tb[output_tb > 1] = 1
            writer.add_images("image-outputseg-realseg", output_tb,
                              global_step=None, walltime=None,
                              dataformats='NCHW')

    """Finetune/evaluate."""
    finetune(
        train_valid_datasets_provider,
        model_provider,
        forward_step=_cross_entropy_forward_step,
        process_non_loss_data_func=dump_output_data,
        end_of_epoch_callback_provider=accuracy_func_provider,
    )


def main():
    segmentation()
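
A note on the SETR evaluation above: `slidingcrops` tiles each image into fixed-size crops and `slidingjoins` stitches the per-crop predictions back into a full-resolution map (both come from tasks/vision/segmentation/utils.py in this upload). The sketch below only illustrates the general sliding-window technique under assumed shapes and an averaging stitch rule; it is not that file's code:

import torch

def sliding_windows(h, w, crop, stride):
    # Window origins covering an h x w image; assumes h, w >= crop.
    ys = list(range(0, h - crop + 1, stride))
    xs = list(range(0, w - crop + 1, stride))
    if ys[-1] != h - crop:
        ys.append(h - crop)  # make the last row of windows touch the border
    if xs[-1] != w - crop:
        xs.append(w - crop)
    return [(y, y + crop, x, x + crop) for y in ys for x in xs]

def stitch(prob_crops, windows, num_classes, h, w):
    # Average class probabilities over overlapping windows.
    full = torch.zeros(num_classes, h, w)
    count = torch.zeros(1, h, w)
    for probs, (y1, y2, x1, x2) in zip(prob_crops, windows):
        full[:, y1:y2, x1:x2] += probs
        count[:, y1:y2, x1:x2] += 1
    return full / count
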
tasks/vision/segmentation/metrics.py
ADDED
@@ -0,0 +1,594 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
#copyright (c) go-hiroaki & Chokurei
#email: [email protected]
#
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

eps = 1e-6

def _binarize(y_data, threshold):
    """
    args:
        y_data : [float] 4-d tensor in [batch_size, channels, img_rows, img_cols]
        threshold : [float] [0.0, 1.0]
    return 4-d binarized y_data
    """
    y_data[y_data < threshold] = 0.0
    y_data[y_data >= threshold] = 1.0
    return y_data

def _argmax(y_data, dim):
    """
    args:
        y_data : 4-d tensor in [batch_size, chs, img_rows, img_cols]
        dim : int
    return 3-d [int] y_data
    """
    return torch.argmax(y_data, dim).int()


def _get_tp(y_pred, y_true):
    """
    args:
        y_true : [int] 3-d in [batch_size, img_rows, img_cols]
        y_pred : [int] 3-d in [batch_size, img_rows, img_cols]
    return [float] true_positive
    """
    return torch.sum(y_true * y_pred).float()


def _get_fp(y_pred, y_true):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
    return [float] false_positive
    """
    return torch.sum((1 - y_true) * y_pred).float()


def _get_tn(y_pred, y_true):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
    return [float] true_negative
    """
    return torch.sum((1 - y_true) * (1 - y_pred)).float()


def _get_fn(y_pred, y_true):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
    return [float] false_negative
    """
    return torch.sum(y_true * (1 - y_pred)).float()


def _get_weights(y_true, nb_ch):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        nb_ch : int
    return [float] weights
    """
    batch_size, img_rows, img_cols = y_true.shape
    pixels = batch_size * img_rows * img_cols
    weights = [torch.sum(y_true == ch).item() / pixels for ch in range(nb_ch)]
    return weights


class CFMatrix(object):
    def __init__(self, des=None):
        self.des = des

    def __repr__(self):
        return "ConfusionMatrix"

    def __call__(self, y_pred, y_true, ignore_index, threshold=0.5):
        """
        args:
            y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return confusion matrix
        """
        batch_size, img_rows, img_cols = y_pred.shape
        chs = ignore_index
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_tn = _get_tn(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            mperforms = [nb_tp, nb_fp, nb_tn, nb_fn]
            performs = None
        else:
            performs = torch.zeros(chs, 4).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_false_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_true_ch[y_true == ch] = 1
                y_false_ch[torch.logical_and((y_true != ch), (y_true != ignore_index))] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = torch.sum(y_false_ch * y_pred_ch).float()
                nb_tn = torch.sum(y_false_ch * (1 - y_pred_ch)).float()
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                performs[int(ch), :] = torch.FloatTensor([nb_tp, nb_fp, nb_tn, nb_fn])
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs


class OAAcc(object):
    def __init__(self, des="Overall Accuracy"):
        self.des = des

    def __repr__(self):
        return "OAcc"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return (tp+tn)/total
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)

        nb_tp_tn = torch.sum(y_true == y_pred).float()
        mperforms = nb_tp_tn / (batch_size * img_rows * img_cols)
        performs = None
        return mperforms, performs


class Precision(object):
    def __init__(self, des="Precision"):
        self.des = des

    def __repr__(self):
        return "Prec"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return tp/(tp+fp)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            mperforms = nb_tp / (nb_tp + nb_fp + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                performs[int(ch)] = nb_tp / (nb_tp + nb_fp + eps)
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs


class Recall(object):
    def __init__(self, des="Recall"):
        self.des = des

    def __repr__(self):
        return "Reca"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return tp/(tp+fn)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            mperforms = nb_tp / (nb_tp + nb_fn + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                performs[int(ch)] = nb_tp / (nb_tp + nb_fn + eps)
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs


class F1Score(object):
    def __init__(self, des="F1Score"):
        self.des = des

    def __repr__(self):
        return "F1Sc"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return 2*precision*recall/(precision+recall)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            _precision = nb_tp / (nb_tp + nb_fp + eps)
            _recall = nb_tp / (nb_tp + nb_fn + eps)
            mperforms = 2 * _precision * _recall / (_precision + _recall + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                _precision = nb_tp / (nb_tp + nb_fp + eps)
                _recall = nb_tp / (nb_tp + nb_fn + eps)
                performs[int(ch)] = 2 * _precision * \
                    _recall / (_precision + _recall + eps)
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs


class Kappa(object):
    def __init__(self, des="Kappa"):
        self.des = des

    def __repr__(self):
        return "Kapp"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return (Po-Pe)/(1-Pe)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_tn = _get_tn(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            nb_total = nb_tp + nb_fp + nb_tn + nb_fn
            Po = (nb_tp + nb_tn) / nb_total
            Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn) +
                  (nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total**2)
            mperforms = (Po - Pe) / (1 - Pe + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                nb_tn = _get_tn(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                nb_total = nb_tp + nb_fp + nb_tn + nb_fn
                Po = (nb_tp + nb_tn) / nb_total
                Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn)
                      + (nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total**2)
                performs[int(ch)] = (Po - Pe) / (1 - Pe + eps)
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs


class Jaccard(object):
    def __init__(self, des="Jaccard"):
        self.des = des

    def __repr__(self):
        return "Jacc"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return intersection / (sum-intersection)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            _intersec = torch.sum(y_true * y_pred).float()
            _sum = torch.sum(y_true + y_pred).float()
            mperforms = _intersec / (_sum - _intersec + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                _intersec = torch.sum(y_true_ch * y_pred_ch).float()
                _sum = torch.sum(y_true_ch + y_pred_ch).float()
                performs[int(ch)] = _intersec / (_sum - _intersec + eps)
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs


class MSE(object):
    def __init__(self, des="Mean Square Error"):
        self.des = des

    def __repr__(self):
        return "MSE"

    def __call__(self, y_pred, y_true, dim=1, threshold=None):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return mean_squared_error, smaller the better
        """
        if threshold:
            y_pred = _binarize(y_pred, threshold)
        return torch.mean((y_pred - y_true) ** 2)


class PSNR(object):
    def __init__(self, des="Peak Signal to Noise Ratio"):
        self.des = des

    def __repr__(self):
        return "PSNR"

    def __call__(self, y_pred, y_true, dim=1, threshold=None):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return PSNR, larger the better
        """
        if threshold:
            y_pred = _binarize(y_pred, threshold)
        mse = torch.mean((y_pred - y_true) ** 2)
        return 10 * torch.log10(1 / mse)


class SSIM(object):
    '''
    modified from https://github.com/jorge-pessoa/pytorch-msssim
    '''
    def __init__(self, des="structural similarity index"):
        self.des = des

    def __repr__(self):
        return "SSIM"

    def gaussian(self, w_size, sigma):
        gauss = torch.Tensor([math.exp(-(x - w_size // 2)**2 / float(2 * sigma**2)) for x in range(w_size)])
        return gauss / gauss.sum()

    def create_window(self, w_size, channel=1):
        _1D_window = self.gaussian(w_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        window = _2D_window.expand(channel, 1, w_size, w_size).contiguous()
        return window

    def __call__(self, y_pred, y_true, w_size=11, size_average=True, full=False):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            w_size : int, default 11
            size_average : boolean, default True
            full : boolean, default False
        return ssim, larger the better
        """
        # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
        if torch.max(y_pred) > 128:
            max_val = 255
        else:
            max_val = 1

        if torch.min(y_pred) < -0.5:
            min_val = -1
        else:
            min_val = 0
        L = max_val - min_val

        padd = 0
        (_, channel, height, width) = y_pred.size()
        window = self.create_window(w_size, channel=channel).to(y_pred.device)

        mu1 = F.conv2d(y_pred, window, padding=padd, groups=channel)
        mu2 = F.conv2d(y_true, window, padding=padd, groups=channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = F.conv2d(y_pred * y_pred, window, padding=padd, groups=channel) - mu1_sq
        sigma2_sq = F.conv2d(y_true * y_true, window, padding=padd, groups=channel) - mu2_sq
        sigma12 = F.conv2d(y_pred * y_true, window, padding=padd, groups=channel) - mu1_mu2

        C1 = (0.01 * L) ** 2
        C2 = (0.03 * L) ** 2

        v1 = 2.0 * sigma12 + C2
        v2 = sigma1_sq + sigma2_sq + C2
        cs = torch.mean(v1 / v2)  # contrast sensitivity

        ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

        if size_average:
            ret = ssim_map.mean()
        else:
            ret = ssim_map.mean(1).mean(1).mean(1)

        if full:
            return ret, cs
        return ret


class AE(object):
    """
    Modified from matlab : colorangle.m, MATLAB V2019b
    angle = acos(RGB1' * RGB2 / (norm(RGB1) * norm(RGB2)));
    angle = 180 / pi * angle;
    """
    def __init__(self, des='average Angular Error'):
        self.des = des

    def __repr__(self):
        return "AE"

    def __call__(self, y_pred, y_true):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
        return average AE, smaller the better
        """
        dotP = torch.sum(y_pred * y_true, dim=1)
        Norm_pred = torch.sqrt(torch.sum(y_pred * y_pred, dim=1))
        Norm_true = torch.sqrt(torch.sum(y_true * y_true, dim=1))
        ae = 180 / math.pi * torch.acos(dotP / (Norm_pred * Norm_true + eps))
        return ae.mean(1).mean(1)


if __name__ == "__main__":
    for ch in [3, 1]:
        batch_size, img_row, img_col = 1, 224, 224
        y_true = torch.rand(batch_size, ch, img_row, img_col)
        noise = torch.zeros(y_true.size()).data.normal_(0, std=0.1)
        y_pred = y_true + noise
        for cuda in [False, True]:
            if cuda:
                y_pred = y_pred.cuda()
                y_true = y_true.cuda()

            print('#' * 20, 'Cuda : {} ; size : {}'.format(cuda, y_true.size()))
            ########### similarity metrics
            metric = MSE()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))

            metric = PSNR()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))

            metric = SSIM()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))

            # LPIPS is neither defined nor imported in this module; it would
            # need an external perceptual-similarity implementation, so the
            # call is disabled here.
            # metric = LPIPS(cuda)
            # acc = metric(y_pred, y_true).item()
            # print("{} ==> {}".format(repr(metric), acc))

            metric = AE()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))

            ########### accuracy metrics
            metric = OAAcc()
            maccu, accu = metric(y_pred, y_true)
            print('mAccu:', maccu, 'Accu', accu)

            metric = Precision()
            mprec, prec = metric(y_pred, y_true)
            print('mPrec:', mprec, 'Prec', prec)

            metric = Recall()
            mreca, reca = metric(y_pred, y_true)
            print('mReca:', mreca, 'Reca', reca)

            metric = F1Score()
            mf1sc, f1sc = metric(y_pred, y_true)
            print('mF1sc:', mf1sc, 'F1sc', f1sc)

            metric = Kappa()
            mkapp, kapp = metric(y_pred, y_true)
            print('mKapp:', mkapp, 'Kapp', kapp)

            metric = Jaccard()
            mjacc, jacc = metric(y_pred, y_true)
            print('mJacc:', mjacc, 'Jacc', jacc)
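
A note on the multi-class branch of CFMatrix above: it expects 3-d integer prediction and label maps and reuses the `ignore_index` argument as the class count, exactly as the SETR evaluation loop calls it. A minimal toy usage, assuming the module is importable under the path it is uploaded to:

import torch
from tasks.vision.segmentation.metrics import CFMatrix

num_classes = 4
preds = torch.randint(0, num_classes, (2, 8, 8))    # [batch, H, W]
labels = torch.randint(0, num_classes, (2, 8, 8))

mperf, perf = CFMatrix()(preds, labels, num_classes)
# perf[c] = [tp, fp, tn, fn] for class c
tp, fp, fn = perf[:, 0], perf[:, 1], perf[:, 3]
iou = tp / (tp + fp + fn)  # nan for classes absent from both maps
print("per-class IoU:", iou.tolist())
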
tasks/vision/segmentation/seg_heads.py
ADDED
@@ -0,0 +1,140 @@
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model import LayerNorm
from megatron.model.module import MegatronModule
from megatron.model.vision.utils import resize


class SetrSegmentationHead(MegatronModule):
    def __init__(self, hidden_size, num_classes):
        super(SetrSegmentationHead, self).__init__()
        args = get_args()
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.patch_dim = args.patch_dim

        self.layernorm = LayerNorm(hidden_size, eps=args.layernorm_epsilon)
        self.conv_0 = torch.nn.Conv2d(hidden_size, hidden_size,
                                      1, 1, bias=False)
        self.norm_0 = apex.parallel.SyncBatchNorm(hidden_size)
        self.conv_1 = torch.nn.Conv2d(hidden_size, num_classes, 1, 1)

    def to_2D(self, x):
        n, hw, c = x.shape
        h = self.img_h // self.patch_dim
        w = self.img_w // self.patch_dim
        assert(hw == h * w)
        x = x.transpose(1, 2).reshape(n, c, h, w)
        return x

    def forward(self, hidden_states):
        # [b c h w]
        hidden_states = self.layernorm(hidden_states)
        hidden_states = self.to_2D(hidden_states)

        hidden_states = self.conv_0(hidden_states)
        hidden_states = self.norm_0(hidden_states)
        hidden_states = torch.tanh(hidden_states)
        hidden_states = self.conv_1(hidden_states)

        # [b c h w]
        result = F.interpolate(hidden_states,
                               size=(self.img_h, self.img_w),
                               mode='bilinear')

        return result


class MLP(torch.nn.Module):
    """
    Linear Embedding
    """
    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = torch.nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x


class SegformerSegmentationHead(MegatronModule):
    def __init__(self, feature_strides, in_channels,
                 embedding_dim, dropout_ratio):
        super(SegformerSegmentationHead, self).__init__()
        assert len(feature_strides) == len(in_channels)
        assert min(feature_strides) == feature_strides[0]
        args = get_args()
        self.feature_strides = feature_strides
        self.in_channels = in_channels
        self.embedding_dim = embedding_dim
        self.num_classes = args.num_classes
        self.dropout_ratio = dropout_ratio

        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = \
            self.in_channels

        self.linear_c4 = MLP(input_dim=c4_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c2 = MLP(input_dim=c2_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c1 = MLP(input_dim=c1_in_channels,
                             embed_dim=self.embedding_dim)

        self.conv_fuse = torch.nn.Conv2d(self.embedding_dim * 4,
                                         self.embedding_dim, 1, 1)
        self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim)

        self.dropout = torch.nn.Dropout2d(self.dropout_ratio)
        self.linear_pred = torch.nn.Conv2d(self.embedding_dim,
                                           self.num_classes,
                                           kernel_size=1)

    def forward(self, inputs):
        c1, c2, c3, c4 = inputs

        ############## MLP decoder on C1-C4 ###########
        n, _, h, w = c4.shape

        _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3])
        _c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3])
        _c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3])
        _c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3])

        _c = self.conv_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
        x = self.norm(_c)
        x = F.relu(x, inplace=True)
        x = self.dropout(x)
        x = self.linear_pred(x)

        return x
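
A note on the SegFormer decoder above: each pyramid stage is flattened to tokens, linearly projected to a shared width, reshaped back to 2-D, and bilinearly resized to the stage-1 resolution before fusion. A small shape walkthrough of the `MLP` projection, with a random tensor standing in for stage-4 features (run with the `MLP` class above in scope):

import torch

mlp = MLP(input_dim=512, embed_dim=768)   # stage-4 (c4) projection
c4 = torch.randn(2, 512, 8, 8)            # [N, C4, H/32, W/32]
tokens = mlp(c4)                          # flatten(2) + transpose + Linear
print(tokens.shape)                       # torch.Size([2, 64, 768])
# The head then restores the spatial layout before resizing and fusing:
grid = tokens.permute(0, 2, 1).reshape(2, -1, 8, 8)
print(grid.shape)                         # torch.Size([2, 768, 8, 8])
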
tasks/vision/segmentation/seg_models.py
ADDED
@@ -0,0 +1,92 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model.module import MegatronModule
from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead
from megatron.model.vision.mit_backbone import mit_b3, mit_b5
from tasks.vision.segmentation.seg_heads import SetrSegmentationHead, SegformerSegmentationHead


class SetrSegmentationModel(MegatronModule):

    def __init__(self,
                 num_classes,
                 pre_process=True,
                 post_process=True):
        super(SetrSegmentationModel, self).__init__()
        args = get_args()
        assert post_process & pre_process
        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.backbone = VitBackbone(
            pre_process=pre_process,
            post_process=post_process,
            class_token=False,
            post_layer_norm=False,
            drop_path_rate=0.1
        )

        self.head = SetrSegmentationHead(
            self.hidden_size,
            self.num_classes
        )

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        pass

    def forward(self, input):
        # [b hw c]
        hidden_states = self.backbone(input)
        result_final = self.head(hidden_states)
        return result_final


class SegformerSegmentationModel(MegatronModule):

    def __init__(self,
                 num_classes,
                 pre_process=True,
                 post_process=True):
        super(SegformerSegmentationModel, self).__init__()
        args = get_args()
        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.pre_process = pre_process
        self.post_process = post_process

        self.backbone = mit_b5()
        self.head = SegformerSegmentationHead(
            feature_strides=[4, 8, 16, 32],
            in_channels=[64, 128, 320, 512],
            embedding_dim=768,
            dropout_ratio=0.1
        )

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        pass

    def forward(self, input):
        # [b hw c]
        hidden_states = self.backbone(input)
        hidden_states = self.head(hidden_states)
        return hidden_states
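
A note on the backbone/head wiring above: `SegformerSegmentationModel` pairs `mit_b5` with a head configured for `feature_strides=[4, 8, 16, 32]` and `in_channels=[64, 128, 320, 512]`, so the backbone is assumed to return a four-level feature pyramid with exactly those strides and channel widths. A shape-only illustration of that contract, with random tensors standing in for real backbone outputs:

import torch

n, h, w = 2, 512, 512
features = [
    torch.randn(n, 64,  h // 4,  w // 4),    # c1, stride 4
    torch.randn(n, 128, h // 8,  w // 8),    # c2, stride 8
    torch.randn(n, 320, h // 16, w // 16),   # c3, stride 16
    torch.randn(n, 512, h // 32, w // 32),   # c4, stride 32
]
# SegformerSegmentationHead.forward unpacks exactly this pyramid:
# c1, c2, c3, c4 = features
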
tasks/vision/segmentation/transforms.py
ADDED
@@ -0,0 +1,433 @@
1 |
+
# Copyright (c) 2020 The MMSegmenation Authors.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import random
|
7 |
+
import os
|
8 |
+
import math
|
9 |
+
import mmcv
|
10 |
+
import torch
|
11 |
+
import numpy as np
|
12 |
+
import torchvision.transforms as T
|
13 |
+
from torchvision import datasets
|
14 |
+
from torch.utils.data import Dataset
|
15 |
+
from megatron import print_rank_0
|
16 |
+
from megatron import get_args
|
17 |
+
from PIL import Image, ImageOps, ImageEnhance
|
18 |
+
import torchvision.transforms as torch_tr
|
19 |
+
|
20 |
+
def _is_pil_image(img):
|
21 |
+
return isinstance(img, Image.Image)
|
22 |
+
|
23 |
+
class PhotoMetricDistortion(object):
|
24 |
+
"""Apply photometric distortion to image sequentially, every transformation
|
25 |
+
is applied with a probability of 0.5. The position of random contrast is in
|
26 |
+
second or second to last.
|
27 |
+
1. random brightness
|
28 |
+
2. random contrast (mode 0)
|
29 |
+
3. convert color from BGR to HSV
|
30 |
+
4. random saturation
|
31 |
+
5. random hue
|
32 |
+
6. convert color from HSV to BGR
|
33 |
+
7. random contrast (mode 1)
|
34 |
+
8. randomly swap channels
|
35 |
+
Args:
|
36 |
+
brightness_delta (int): delta of brightness.
|
37 |
+
contrast_range (tuple): range of contrast.
|
38 |
+
saturation_range (tuple): range of saturation.
|
39 |
+
hue_delta (int): delta of hue.
|
40 |
+
"""
|
41 |
+
|
42 |
+
def __init__(self,
|
43 |
+
brightness_delta=32,
|
44 |
+
contrast_range=(0.5, 1.5),
|
45 |
+
saturation_range=(0.5, 1.5),
|
46 |
+
hue_delta=18):
|
47 |
+
self.brightness_delta = brightness_delta
|
48 |
+
self.contrast_lower, self.contrast_upper = contrast_range
|
49 |
+
self.saturation_lower, self.saturation_upper = saturation_range
|
50 |
+
self.hue_delta = hue_delta
|
51 |
+
|
52 |
+
def convert(self, img, alpha=1, beta=0):
|
53 |
+
"""Multiple with alpha and add beat with clip."""
|
54 |
+
img = img.astype(np.float32) * alpha + beta
|
55 |
+
img = np.clip(img, 0, 255)
|
56 |
+
return img.astype(np.uint8)
|
57 |
+
|
58 |
+
def brightness(self, img):
|
59 |
+
"""Brightness distortion."""
|
60 |
+
if random.randint(0, 1):
|
61 |
+
return self.convert(
|
62 |
+
img,
|
63 |
+
beta=random.uniform(-self.brightness_delta,
|
64 |
+
self.brightness_delta))
|
65 |
+
return img
|
66 |
+
|
67 |
+
def contrast(self, img):
|
68 |
+
"""Contrast distortion."""
|
69 |
+
if random.randint(0, 1):
|
70 |
+
return self.convert(
|
71 |
+
img,
|
72 |
+
alpha=random.uniform(self.contrast_lower, self.contrast_upper))
|
73 |
+
return img
|
74 |
+
|
75 |
+
def saturation(self, img):
|
76 |
+
"""Saturation distortion."""
|
77 |
+
if random.randint(0, 1):
|
78 |
+
img = mmcv.bgr2hsv(img)
|
79 |
+
img[:, :, 1] = self.convert(
|
80 |
+
img[:, :, 1],
|
81 |
+
alpha=random.uniform(self.saturation_lower,
|
82 |
+
self.saturation_upper))
|
83 |
+
img = mmcv.hsv2bgr(img)
|
84 |
+
return img
|
85 |
+
|
86 |
+
def hue(self, img):
|
87 |
+
"""Hue distortion."""
|
88 |
+
if random.randint(0, 1):
|
89 |
+
img = mmcv.bgr2hsv(img)
|
90 |
+
img[:, :,
|
91 |
+
0] = (img[:, :, 0].astype(int) +
|
92 |
+
random.randint(-self.hue_delta, self.hue_delta)) % 180
|
93 |
+
img = mmcv.hsv2bgr(img)
|
94 |
+
return img
|
95 |
+
|
96 |
+
def __call__(self, img):
|
97 |
+
"""Call function to perform photometric distortion on images.
|
98 |
+
Args:
|
99 |
+
results (dict): Result dict from loading pipeline.
|
100 |
+
Returns:
|
101 |
+
dict: Result dict with images distorted.
|
102 |
+
"""
|
103 |
+
img = np.array(img)
|
104 |
+
|
105 |
+
# random brightness
|
106 |
+
img = self.brightness(img)
|
107 |
+
|
108 |
+
# mode == 0 --> do random contrast first
|
109 |
+
# mode == 1 --> do random contrast last
|
110 |
+
mode = random.randint(0, 1)
|
111 |
+
if mode == 1:
|
112 |
+
img = self.contrast(img)
|
113 |
+
|
114 |
+
# random saturation
|
115 |
+
img = self.saturation(img)
|
116 |
+
|
117 |
+
# random hue
|
118 |
+
img = self.hue(img)
|
119 |
+
|
120 |
+
# random contrast
|
121 |
+
if mode == 0:
|
122 |
+
img = self.contrast(img)
|
123 |
+
|
124 |
+
img = Image.fromarray(img.astype(np.uint8)).convert('RGB')
|
125 |
+
return img
|
126 |
+
|
127 |
+
|
128 |
+
class RandomCrop(object):
    """
    Take a random crop from the image.

    First the image or crop size may need to be adjusted if the incoming image
    is too small...

    If the image is smaller than the crop, then:
         the image is padded up to the size of the crop
         unless 'nopad', in which case the crop size is shrunk to fit the image

    A random crop is taken such that the crop fits within the image.

    If cfg.DATASET.TRANSLATION_AUG_FIX is set, we ensure that there is always
    translation randomness of at least that value around the image.

    if image < crop_size:
        # slide crop within image, random offset
    else:
        # slide image within crop
    """
    def __init__(self, crop_size):
        args = get_args()
        self.size = crop_size
        self.cat_max_ratio = 0.75
        self.ignore_index = args.ignore_index
        self.pad_color = (0, 0, 0)

    def get_crop_bbox(self, img):
        """Randomly get a crop bounding box."""
        img_w, img_h = img.size
        target_h, target_w = self.size  # [H, W]
        margin_h = max(img_h - target_h, 0)
        margin_w = max(img_w - target_w, 0)
        offset_h = random.randint(0, margin_h)
        offset_w = random.randint(0, margin_w)
        crop_y1, crop_y2 = offset_h, offset_h + target_h
        crop_x1, crop_x2 = offset_w, offset_w + target_w

        return crop_y1, crop_y2, crop_x1, crop_x2

    def crop(self, img, crop_bbox):
        """Crop from ``img``."""
        crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
        img = img.crop((crop_x1, crop_y1, crop_x2, crop_y2))
        return img

    @staticmethod
    def crop_in_image(target_w, target_h, w, h, img, mask):
        if w == target_w:
            x1 = 0
        else:
            x1 = random.randint(0, w - target_w)
        if h == target_h:
            y1 = 0
        else:
            y1 = random.randint(0, h - target_h)

        return [img.crop((x1, y1, x1 + target_w, y1 + target_h)),
                mask.crop((x1, y1, x1 + target_w, y1 + target_h))]

    def __call__(self, img, mask):
        w, h = img.size
        target_h, target_w = self.size  # assume [H, W]

        if w == target_w and h == target_h:
            return img, mask

        # Pad image if image < crop
        if target_h > h:
            pad_h = (target_h - h) // 2 + 1
        else:
            pad_h = 0
        if target_w > w:
            pad_w = (target_w - w) // 2 + 1
        else:
            pad_w = 0
        border = (pad_w, pad_h, pad_w, pad_h)
        if pad_h or pad_w:
            img = ImageOps.expand(img, border=border, fill=(0, 0, 0))
            mask = ImageOps.expand(mask, border=border, fill=self.ignore_index)
            w, h = img.size

        crop_bbox = self.get_crop_bbox(img)
        if self.cat_max_ratio < 1.:
            # Re-sample the crop up to 10 times until no single category
            # (ignore_index excluded) dominates more than cat_max_ratio of it.
            for _ in range(10):
                seg_temp = self.crop(mask, crop_bbox)
                labels, cnt = np.unique(seg_temp, return_counts=True)
                cnt = cnt[labels != self.ignore_index]
                if len(cnt) > 1 and np.max(cnt) / np.sum(
                        cnt) < self.cat_max_ratio:
                    break
                crop_bbox = self.get_crop_bbox(img)

        # crop the image
        img = self.crop(img, crop_bbox)

        # crop the semantic segmentation mask with the same bbox
        mask = self.crop(mask, crop_bbox)
        assert(img.size[0] == self.size[1] and img.size[1] == self.size[0])

        return img, mask

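A minimal usage sketch of the pad-then-crop behavior (my illustration, not part of the diff; it assumes Megatron's argument parser has been initialized so `get_args()` resolves `ignore_index`):

# Illustrative only: an image smaller than the crop is padded up, then cropped.
from PIL import Image
import numpy as np

crop = RandomCrop(crop_size=(512, 512))            # (H, W)
img = Image.new('RGB', (400, 300))                 # smaller than the crop
mask = Image.fromarray(np.zeros((300, 400), dtype=np.uint8))
img_c, mask_c = crop(img, mask)                    # padded, then randomly cropped
assert img_c.size == (512, 512)                    # PIL size is (W, H)
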
class RandomSizeAndCrop(object):
    def __init__(self,
                 crop_size,
                 scale_min=0.5,
                 scale_max=2.0):
        self.crop = RandomCrop(crop_size)
        self.scale_min = scale_min
        self.scale_max = scale_max

    def __call__(self, img, mask):
        scale_amt = random.uniform(self.scale_min, self.scale_max)
        w, h = [int(i * scale_amt) for i in img.size]

        resized_img = img.resize((w, h), Image.BICUBIC)
        resized_mask = mask.resize((w, h), Image.NEAREST)
        img, mask = self.crop(resized_img, resized_mask)
        return img, mask


class RandomHorizontallyFlip(object):
    def __call__(self, img, mask):
        if random.random() < 0.5:
            return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(
                Image.FLIP_LEFT_RIGHT)
        return img, mask

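Together these give the usual scale-jitter, crop, flip training pipeline for paired image/mask inputs. A hedged sketch of how they might be chained (the `JointCompose` wrapper below is illustrative, not from this file, and again assumes initialized Megatron args):

# Illustrative composition: chain transforms that take (img, mask) pairs.
class JointCompose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, mask):
        for t in self.transforms:
            img, mask = t(img, mask)
        return img, mask

train_joint_transform = JointCompose([
    RandomSizeAndCrop(crop_size=(512, 512), scale_min=0.5, scale_max=2.0),
    RandomHorizontallyFlip(),
])
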
def adjust_brightness(img, brightness_factor):
    """Adjust brightness of an Image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        brightness_factor (float): How much to adjust the brightness. Can be
            any non-negative number. 0 gives a black image, 1 gives the
            original image while 2 increases the brightness by a factor of 2.

    Returns:
        PIL Image: Brightness adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Brightness(img)
    img = enhancer.enhance(brightness_factor)
    return img


def adjust_contrast(img, contrast_factor):
    """Adjust contrast of an Image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non-negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.

    Returns:
        PIL Image: Contrast adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Contrast(img)
    img = enhancer.enhance(contrast_factor)
    return img

def adjust_saturation(img, saturation_factor):
    """Adjust color saturation of an image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        saturation_factor (float): How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image while
            2 will enhance the saturation by a factor of 2.

    Returns:
        PIL Image: Saturation adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Color(img)
    img = enhancer.enhance(saturation_factor)
    return img

def adjust_hue(img, hue_factor):
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to the original image mode.

    `hue_factor` is the amount of shift in the H channel and must be in the
    interval `[-0.5, 0.5]`.

    See https://en.wikipedia.org/wiki/Hue for more details on Hue.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        hue_factor (float): How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        PIL Image: Hue adjusted image.
    """
    if not(-0.5 <= hue_factor <= 0.5):
        raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    input_mode = img.mode
    if input_mode in {'L', '1', 'I', 'F'}:
        return img

    h, s, v = img.convert('HSV').split()

    np_h = np.array(h, dtype=np.uint8)
    # uint8 addition takes care of rotation across boundaries
    with np.errstate(over='ignore'):
        np_h += np.uint8(hue_factor * 255)
    h = Image.fromarray(np_h, 'L')

    img = Image.merge('HSV', (h, s, v)).convert(input_mode)
    return img

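The uint8 overflow is exactly what makes the hue shift cyclic: a hue value of 240 shifted by hue_factor=0.3 (np.uint8(0.3 * 255) == 76) wraps to (240 + 76) mod 256 = 60 instead of saturating at 255. A quick check of that behavior (illustration only):

# Sanity check of the modular hue shift.
import numpy as np

np_h = np.array([240], dtype=np.uint8)
with np.errstate(over='ignore'):
    np_h += np.uint8(0.3 * 255)        # 240 + 76 wraps around 256
assert np_h[0] == (240 + 76) % 256     # == 60
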
class ColorJitter(object):
    """Randomly change the brightness, contrast and saturation of an image.

    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue (float): How much to jitter hue. hue_factor is chosen uniformly
            from [-hue, hue]. Should be >= 0 and <= 0.5.
    """
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on an image.

        Arguments are the same as those of __init__.

        Returns:
            Transform which randomly adjusts brightness, contrast and
            saturation in a random order.
        """
        transforms = []
        if brightness > 0:
            brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
            transforms.append(
                torch_tr.Lambda(lambda img: adjust_brightness(img, brightness_factor)))

        if contrast > 0:
            contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
            transforms.append(
                torch_tr.Lambda(lambda img: adjust_contrast(img, contrast_factor)))

        if saturation > 0:
            saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
            transforms.append(
                torch_tr.Lambda(lambda img: adjust_saturation(img, saturation_factor)))

        if hue > 0:
            hue_factor = np.random.uniform(-hue, hue)
            transforms.append(
                torch_tr.Lambda(lambda img: adjust_hue(img, hue_factor)))

        np.random.shuffle(transforms)
        transform = torch_tr.Compose(transforms)

        return transform

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Input image.

        Returns:
            PIL Image: Color jittered image.
        """
        transform = self.get_params(self.brightness, self.contrast,
                                    self.saturation, self.hue)
        return transform(img)
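Note that ColorJitter rebuilds a freshly ordered, freshly sampled transform on every call, so two invocations on the same image generally differ. A hedged usage sketch (illustration only; parameter values are mine):

# Illustrative usage of ColorJitter.
from PIL import Image

jitter = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
img = Image.new('RGB', (64, 64), color=(128, 64, 32))
out = jitter(img)          # new random factors and order on each call
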
tasks/vision/segmentation/utils.py
ADDED
@@ -0,0 +1,85 @@
import math
import torch
import numpy as np
from megatron import get_args


def slidingcrops(img, mask):
    # img: [b c h w]
    # mask: [b h w]
    args = get_args()
    assert args.img_h == args.img_w
    crop_size = args.img_h
    stride = args.seg_stride
    ignore_index = args.ignore_index
    n, c, h, w = img.shape
    assert h >= crop_size
    assert w >= crop_size
    long_size = max(h, w)

    img_slices, mask_slices, slices_info = [], [], []
    if long_size > crop_size:
        # Slide a crop_size x crop_size window over the image with the given
        # stride; the step counts are chosen so the last window reaches the
        # bottom/right edge.
        assert stride <= crop_size
        h_step_num = int(math.ceil((h - crop_size) / float(stride))) + 1
        w_step_num = int(math.ceil((w - crop_size) / float(stride))) + 1
        for yy in range(h_step_num):
            for xx in range(w_step_num):
                sy, sx = yy * stride, xx * stride
                ey, ex = sy + crop_size, sx + crop_size
                img_sub = img[:, :, sy: ey, sx: ex]
                mask_sub = mask[:, sy: ey, sx: ex]

                # Pad edge windows back up to crop_size on the bottom/right;
                # the unpadded extent is recorded in slices_info for stitching.
                sub_h, sub_w = img_sub.shape[2:]
                pad_h = max(crop_size - sub_h, 0)
                pad_w = max(crop_size - sub_w, 0)
                img_sub = torch.nn.functional.pad(img_sub, pad=(0, pad_w, 0, pad_h), value=ignore_index)
                mask_sub = torch.nn.functional.pad(mask_sub, pad=(0, pad_w, 0, pad_h))

                img_slices.append(img_sub)
                mask_slices.append(mask_sub)
                slices_info.append([sy, ey, sx, ex, sub_h, sub_w])

        return torch.cat(img_slices), torch.cat(mask_slices), slices_info, (h, w)
    else:
        return img, mask, [[0, h, 0, w, h, w]], (h, w)

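The step-count arithmetic guarantees edge coverage: with h = 1024, crop_size = 768, stride = 512, we get h_step_num = ceil((1024 - 768) / 512) + 1 = 2, and the second window starts at 512, extends past the edge, and is padded back to crop_size. A sketch of the resulting shapes (illustration only; assumes initialized Megatron args with img_h == img_w == 768 and seg_stride == 512):

# Illustrative shapes for a 1024 x 1024 input under the assumed args.
img = torch.zeros(1, 3, 1024, 1024)
mask = torch.zeros(1, 1024, 1024)
imgs, masks, info, size = slidingcrops(img, mask)
# 2 x 2 grid of windows, each padded to the crop size:
# imgs.shape == (4, 3, 768, 768); info[-1] == [512, 1280, 512, 1280, 512, 512]
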
def slidingjoins(preds, probs, labels, slices_info, img_size):
    args = get_args()
    num_slices = len(slices_info)

    if num_slices == 1:
        return preds, labels

    h, w = img_size
    split_size = args.micro_batch_size

    preds_split = torch.split(preds, split_size)
    probs_split = torch.split(probs, split_size)
    labels_split = torch.split(labels, split_size)

    assert(len(preds_split) == num_slices)

    total_max_probs = torch.zeros((split_size, h, w), dtype=torch.float, device='cuda')
    total_preds = torch.zeros((split_size, h, w), dtype=torch.int, device='cuda')
    total_labels = torch.zeros((split_size, h, w), dtype=torch.int, device='cuda')

    for i in range(num_slices):
        sy, ey, sx, ex, sub_h, sub_w = slices_info[i]
        assert sy + sub_h <= h
        assert sx + sub_w <= w
        curr_max_probs = total_max_probs[:, sy:sy + sub_h, sx:sx + sub_w]
        curr_preds = total_preds[:, sy:sy + sub_h, sx:sx + sub_w]

        # Drop the padded bottom/right region before stitching.
        local_max_probs = probs_split[i][:, :sub_h, :sub_w]
        local_preds = preds_split[i][:, :sub_h, :sub_w]

        # Where windows overlap, keep the prediction with the higher
        # per-pixel confidence.
        result_max_probs = torch.maximum(curr_max_probs, local_max_probs)
        result_preds = torch.where(curr_max_probs >= local_max_probs, curr_preds, local_preds)

        total_max_probs[:, sy:sy + sub_h, sx:sx + sub_w] = result_max_probs
        total_preds[:, sy:sy + sub_h, sx:sx + sub_w] = result_preds
        total_labels[:, sy:sy + sub_h, sx:sx + sub_w] = labels_split[i][0, :sub_h, :sub_w]

    return total_preds, total_labels
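A hedged end-to-end sketch of how these two helpers are meant to pair during evaluation (`model`, the softmax step, and the tensor dtypes/devices are my assumptions, not from this file):

# Illustration only: predict each sliding window, then stitch back.
def sliding_window_inference(model, img, mask):
    # Assumes img/mask are CUDA tensors and Megatron args are initialized.
    crops, mask_crops, info, size = slidingcrops(img, mask)
    logits = model(crops)                              # [N, C, crop, crop] (assumed)
    probs, preds = torch.max(torch.softmax(logits, dim=1), dim=1)
    # slidingjoins keeps, per pixel, the most confident overlapping window.
    return slidingjoins(preds.int(), probs.float(), mask_crops.int(), info, size)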