Shawn001 committed
Commit c2c125c · 1 Parent(s): 23bd7af

Upload 53 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set.
Files changed (50):
  1. tasks/clue/afqmc.py +94 -0
  2. tasks/clue/cmnli.py +103 -0
  3. tasks/clue/csl.py +93 -0
  4. tasks/clue/data.py +69 -0
  5. tasks/clue/finetune.py +130 -0
  6. tasks/clue/iflytek.py +94 -0
  7. tasks/clue/ocnli.py +102 -0
  8. tasks/clue/tnews.py +95 -0
  9. tasks/clue/wsc.py +117 -0
  10. tasks/clue/zc.py +96 -0
  11. tasks/data_utils.py +118 -0
  12. tasks/ensemble_classifier.py +149 -0
  13. tasks/eval_utils.py +250 -0
  14. tasks/finetune_utils.py +330 -0
  15. tasks/glue/data.py +69 -0
  16. tasks/glue/finetune.py +93 -0
  17. tasks/glue/mnli.py +84 -0
  18. tasks/glue/qqp.py +101 -0
  19. tasks/label_dict.py +73 -0
  20. tasks/main.py +121 -0
  21. tasks/msdp/README.md +19 -0
  22. tasks/msdp/evaluate.py +58 -0
  23. tasks/msdp/main.py +79 -0
  24. tasks/msdp/metrics.py +77 -0
  25. tasks/msdp/preprocessing.py +595 -0
  26. tasks/msdp/prompt.py +322 -0
  27. tasks/orqa/README.md +36 -0
  28. tasks/orqa/evaluate_orqa.py +52 -0
  29. tasks/orqa/evaluate_utils.py +188 -0
  30. tasks/orqa/supervised/data.py +300 -0
  31. tasks/orqa/supervised/eval_utils.py +206 -0
  32. tasks/orqa/supervised/finetune.py +251 -0
  33. tasks/orqa/unsupervised/nq.py +228 -0
  34. tasks/orqa/unsupervised/qa_utils.py +177 -0
  35. tasks/orqa/unsupervised/tokenizers.py +243 -0
  36. tasks/race/data.py +135 -0
  37. tasks/race/finetune.py +67 -0
  38. tasks/vision/classification/classification.py +94 -0
  39. tasks/vision/classification/eval_utils.py +129 -0
  40. tasks/vision/finetune_utils.py +312 -0
  41. tasks/vision/main.py +66 -0
  42. tasks/vision/segmentation/cityscapes.py +207 -0
  43. tasks/vision/segmentation/data.py +154 -0
  44. tasks/vision/segmentation/finetune_segformer.py +251 -0
  45. tasks/vision/segmentation/finetune_setr.py +225 -0
  46. tasks/vision/segmentation/metrics.py +594 -0
  47. tasks/vision/segmentation/seg_heads.py +140 -0
  48. tasks/vision/segmentation/seg_models.py +92 -0
  49. tasks/vision/segmentation/transforms.py +433 -0
  50. tasks/vision/segmentation/utils.py +85 -0
tasks/clue/afqmc.py ADDED
@@ -0,0 +1,94 @@
+ # coding=utf-8
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """AFQMC dataset."""
+
+ import json
+
+ from megatron import print_rank_0
+ from tasks.data_utils import clean_text
+ from tasks.label_dict import get_label_dict
+ from .data import GLUEAbstractDataset
+
+ LABELS = get_label_dict("AFQMC")
+
+
+ class AFQMCDataset(GLUEAbstractDataset):
+
+     def __init__(self, name, datapaths, tokenizer, max_seq_length,
+                  test_label='0'):
+         self.test_label = test_label
+         super().__init__('AFQMC', name, datapaths,
+                          tokenizer, max_seq_length)
+
+     def process_samples_from_single_path(self, filename):
+         """Implement abstract method."""
+         print_rank_0(' > Processing {} ...'.format(filename))
+
+         samples = []
+         total = 0
+         first = True
+         is_test = False
+         drop_cnt = 0
+         with open(filename, 'r') as f:
+             lines = [json.loads(line.strip()) for line in f.readlines()]
+         for index, row in enumerate(lines):
+             if "id" not in row:
+                 row["id"] = index
+             if first:
+                 first = False
+                 if "label" not in row:
+                     is_test = True
+                     print_rank_0(' reading {}, {} and {} columns and '
+                                  'setting labels to {}'.format(
+                                      row["id"], row["sentence1"].strip(),
+                                      row["sentence2"].strip(),
+                                      self.test_label))
+                 else:
+                     is_test = False
+                     print_rank_0(' reading {}, {}, {}, and {} columns '
+                                  '...'.format(
+                                      row["id"], row["sentence1"].strip(),
+                                      row["sentence2"].strip(),
+                                      row["label"].strip()))
+
+             text_a = clean_text(row["sentence1"].strip())
+             text_b = clean_text(row["sentence2"].strip())
+             unique_id = int(row["id"])
+
+             if is_test:
+                 label = self.test_label
+             else:
+                 label = row["label"].strip()
+
+             assert len(text_a) > 0
+             assert len(text_b) > 0
+             assert label in LABELS, "found label {} {}".format(label, row)
+             assert unique_id >= 0
+
+             sample = {'text_a': text_a,
+                       'text_b': text_b,
+                       'label': LABELS[label],
+                       'uid': unique_id}
+             total += 1
+             samples.append(sample)
+
+             if total % 5000 == 0:
+                 print_rank_0(' > processed {} so far ...'.format(total))
+
+         print_rank_0(' >> processed {} samples.'.format(len(samples)))
+         print_rank_0(' >> dropped {} samples.'.format(drop_cnt))
+
+         return samples
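
For reference, a minimal sketch of what this loader consumes and produces (the field values below are illustrative, not taken from the real AFQMC data): each input line is a JSON object, and each output sample is the dict shape required by GLUEAbstractDataset.

# Hypothetical AFQMC input line (illustrative values):
#   {"id": 0, "sentence1": "...", "sentence2": "...", "label": "1"}
# Assuming LABELS maps "1" -> 1, process_samples_from_single_path yields:
sample = {'text_a': '...',   # cleaned sentence1
          'text_b': '...',   # cleaned sentence2
          'label': 1,        # LABELS[label]
          'uid': 0}          # int(row["id"])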
tasks/clue/cmnli.py ADDED
@@ -0,0 +1,103 @@
+ # coding=utf-8
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """CMNLI dataset."""
+
+ import json
+
+ from megatron import print_rank_0
+ from tasks.data_utils import clean_text
+ from tasks.label_dict import get_label_dict
+ from .data import GLUEAbstractDataset
+
+ LABELS = get_label_dict("CMNLI")
+
+
+ class CMNLIDataset(GLUEAbstractDataset):
+
+     def __init__(self, name, datapaths, tokenizer, max_seq_length,
+                  test_label='contradiction'):
+         self.test_label = test_label
+         super().__init__('CMNLI', name, datapaths,
+                          tokenizer, max_seq_length)
+
+     def process_samples_from_single_path(self, filename):
+         """Implement abstract method."""
+         print_rank_0(' > Processing {} ...'.format(filename))
+
+         samples = []
+         total = 0
+         first = True
+         is_test = False
+         drop_cnt = 0
+         with open(filename, 'r') as f:
+             lines = [json.loads(line.strip()) for line in f.readlines()]
+         for index, row in enumerate(lines):
+             row["id"] = index
+             if first:
+                 first = False
+                 if "label" not in row:
+                     is_test = True
+                     print_rank_0(' reading {}, {} and {} columns and '
+                                  'setting labels to {}'.format(
+                                      row["id"], row["sentence1"].strip(),
+                                      row["sentence2"].strip(),
+                                      self.test_label))
+                 else:
+                     is_test = False
+                     print_rank_0(' reading {}, {}, {}, and {} columns '
+                                  '...'.format(
+                                      row["id"], row["sentence1"].strip(),
+                                      row["sentence2"].strip(),
+                                      row["label"].strip()))
+
+             text_a = clean_text(row["sentence1"].strip())
+             text_b = clean_text(row["sentence2"].strip())
+             unique_id = int(row["id"])
+
+             if is_test:
+                 label = self.test_label
+             else:
+                 label = row["label"].strip()
+
+             # Drop examples without annotator consensus.
+             if label == "-":
+                 drop_cnt += 1
+                 continue
+
+             assert len(text_a) > 0
+             assert len(text_b) > 0
+             assert label in LABELS, "found label {} {}".format(label, row)
+             assert unique_id >= 0
+
+             sample = {'text_a': text_a,
+                       'text_b': text_b,
+                       'label': LABELS[label],
+                       'uid': unique_id}
+             total += 1
+             samples.append(sample)
+
+             if total % 5000 == 0:
+                 print_rank_0(' > processed {} so far ...'.format(total))
+
+         print_rank_0(' >> processed {} samples.'.format(len(samples)))
+         print_rank_0(' >> dropped {} samples.'.format(drop_cnt))
+
+         return samples
tasks/clue/csl.py ADDED
@@ -0,0 +1,93 @@
+ # coding=utf-8
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """CSL dataset."""
+
+ import json
+
+ from megatron import print_rank_0
+ from tasks.data_utils import clean_text
+ from tasks.label_dict import get_label_dict
+ from .data import GLUEAbstractDataset
+
+ LABELS = get_label_dict("CSL")
+
+
+ class CSLDataset(GLUEAbstractDataset):
+
+     def __init__(self, name, datapaths, tokenizer, max_seq_length,
+                  test_label='0'):
+         self.test_label = test_label
+         super().__init__('CSL', name, datapaths,
+                          tokenizer, max_seq_length)
+
+     def process_samples_from_single_path(self, filename):
+         """Implement abstract method."""
+         print_rank_0(' > Processing {} ...'.format(filename))
+
+         samples = []
+         total = 0
+         first = True
+         is_test = False
+         drop_cnt = 0
+         with open(filename, 'r') as f:
+             lines = [json.loads(line.strip()) for line in f.readlines()]
+         for index, row in enumerate(lines):
+             row["id"] = index
+             if first:
+                 first = False
+                 if "label" not in row:
+                     is_test = True
+                     print_rank_0(' reading {}, {} and {} columns and '
+                                  'setting labels to {}'.format(
+                                      row["id"], " ".join(row["keyword"]).strip(),
+                                      row["abst"].strip(), self.test_label))
+                 else:
+                     is_test = False
+                     print_rank_0(' reading {}, {}, {}, and {} columns '
+                                  '...'.format(
+                                      row["id"], " ".join(row["keyword"]).strip(),
+                                      row["abst"].strip(), row["label"].strip()))
+
+             text_a = clean_text(" ".join(row["keyword"]).strip())
+             text_b = clean_text(row["abst"].strip())
+             unique_id = int(row["id"])
+
+             if is_test:
+                 label = self.test_label
+             else:
+                 label = row["label"].strip()
+
+             assert len(text_a) > 0
+             assert len(text_b) > 0
+             assert label in LABELS, "found label {} {}".format(label, row)
+             assert unique_id >= 0
+
+             sample = {'text_a': text_a,
+                       'text_b': text_b,
+                       'label': LABELS[label],
+                       'uid': unique_id}
+             total += 1
+             samples.append(sample)
+
+             if total % 5000 == 0:
+                 print_rank_0(' > processed {} so far ...'.format(total))
+
+         print_rank_0(' >> processed {} samples.'.format(len(samples)))
+         print_rank_0(' >> dropped {} samples.'.format(drop_cnt))
+
+         return samples
tasks/clue/data.py ADDED
@@ -0,0 +1,69 @@
+ # coding=utf-8
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """GLUE dataset."""
+
+ from abc import ABC
+ from abc import abstractmethod
+
+ from torch.utils.data import Dataset
+
+ from megatron import print_rank_0
+ from tasks.data_utils import build_sample
+ from tasks.data_utils import build_tokens_types_paddings_from_text
+
+
+ class GLUEAbstractDataset(ABC, Dataset):
+     """GLUE base dataset class."""
+
+     def __init__(self, task_name, dataset_name, datapaths,
+                  tokenizer, max_seq_length):
+         # Store inputs.
+         self.task_name = task_name
+         self.dataset_name = dataset_name
+         self.tokenizer = tokenizer
+         self.max_seq_length = max_seq_length
+         print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
+                                                              self.dataset_name))
+         # Process the files.
+         string = ' > paths:'
+         for path in datapaths:
+             string += ' ' + path
+         print_rank_0(string)
+         self.samples = []
+         for datapath in datapaths:
+             self.samples.extend(self.process_samples_from_single_path(datapath))
+         print_rank_0(' >> total number of samples: {}'.format(
+             len(self.samples)))
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, idx):
+         raw_sample = self.samples[idx]
+         ids, types, paddings = build_tokens_types_paddings_from_text(
+             raw_sample['text_a'], raw_sample['text_b'],
+             self.tokenizer, self.max_seq_length)
+         sample = build_sample(ids, types, paddings,
+                               raw_sample['label'], raw_sample['uid'])
+         return sample
+
+     @abstractmethod
+     def process_samples_from_single_path(self, datapath):
+         """Abstract method that takes a single path / filename and
+         returns a list of dataset samples, each sample being a dict of
+         {'text_a': string, 'text_b': string, 'label': int, 'uid': int}
+         """
+         pass
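
The contract above is small: a subclass names its task and implements one parsing method, while tokenization, truncation, and padding all happen in __getitem__. A minimal sketch of a new task (all names here are illustrative, not part of this commit; the real subclasses such as AFQMCDataset above follow the same shape):

# Minimal sketch of a GLUEAbstractDataset subclass (illustrative only).
class ToyDataset(GLUEAbstractDataset):

    def __init__(self, name, datapaths, tokenizer, max_seq_length):
        super().__init__('TOY', name, datapaths, tokenizer, max_seq_length)

    def process_samples_from_single_path(self, datapath):
        # Return one dict per example in the required shape.
        return [{'text_a': 'first sentence', 'text_b': 'second sentence',
                 'label': 0, 'uid': 0}]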
tasks/clue/finetune.py ADDED
@@ -0,0 +1,130 @@
+ # coding=utf-8
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """CLUE finetuning/evaluation."""
+
+ from megatron import get_args
+ from megatron import print_rank_0
+ from megatron import get_tokenizer
+ from megatron import mpu
+ from megatron.model.classification import Classification
+ from tasks.eval_utils import accuracy_func_provider
+ from tasks.finetune_utils import finetune
+
+
+ def clue_classification(num_classes, Dataset,
+                         name_from_datapath_func):
+
+     def train_valid_datasets_provider():
+         """Build train and validation datasets."""
+         args = get_args()
+         tokenizer = get_tokenizer()
+
+         train_dataset = Dataset('training', args.train_data,
+                                 tokenizer, args.seq_length)
+         valid_dataset = Dataset('validation', args.valid_data,
+                                 tokenizer, args.seq_length)
+
+         return train_dataset, valid_dataset
+
+     def model_provider(pre_process=True, post_process=True):
+         """Build the model."""
+         args = get_args()
+
+         print_rank_0('building classification model for {} ...'.format(
+             args.task))
+         model = Classification(num_classes=num_classes, num_tokentypes=2,
+                                pre_process=pre_process,
+                                post_process=post_process)
+
+         return model
+
+     def metrics_func_provider():
+         """Provide a metrics callback function."""
+         def single_dataset_provider(datapath):
+             args = get_args()
+             tokenizer = get_tokenizer()
+             name = name_from_datapath_func(datapath)
+             return Dataset(name, [datapath], tokenizer, args.seq_length)
+         return accuracy_func_provider(single_dataset_provider)
+
+     # Finetune/evaluate.
+     finetune(train_valid_datasets_provider, model_provider,
+              end_of_epoch_callback_provider=metrics_func_provider)
+
+
+ def main():
+     args = get_args()
+
+     if args.task == 'AFQMC':
+         num_classes = 2
+         from tasks.clue.afqmc import AFQMCDataset as Dataset
+
+         def name_from_datapath(datapath):
+             return "afqmc"
+
+     elif args.task == 'CSL':
+         num_classes = 2
+         from tasks.clue.csl import CSLDataset as Dataset
+
+         def name_from_datapath(datapath):
+             return "csl"
+
+     elif args.task == 'IFLYTEK':
+         num_classes = 119
+         from tasks.clue.iflytek import IFLYTEKDataset as Dataset
+
+         def name_from_datapath(datapath):
+             return "iflytek"
+
+     elif args.task == 'OCNLI':
+         num_classes = 3
+         from tasks.clue.ocnli import OCNLIDataset as Dataset
+
+         def name_from_datapath(datapath):
+             return "ocnli"
+
+     elif args.task == 'TNEWS':
+         num_classes = 15
+         from tasks.clue.tnews import TNEWSDataset as Dataset
+
+         def name_from_datapath(datapath):
+             return "tnews"
+
+     elif args.task == 'WSC':
+         num_classes = 2
+         from tasks.clue.wsc import WSCDataset as Dataset
+
+         def name_from_datapath(datapath):
+             return "wsc"
+
+     elif args.task == 'CMNLI':
+         num_classes = 3
+         from tasks.clue.cmnli import CMNLIDataset as Dataset
+
+         def name_from_datapath(datapath):
+             return "cmnli"
+
+     elif args.task == 'ZC':
+         num_classes = 2
+         from tasks.clue.zc import ZCDataset as Dataset
+
+         def name_from_datapath(datapath):
+             return "zc"
+
+     else:
+         raise NotImplementedError('CLUE task {} is not implemented.'.format(
+             args.task))
+
+     clue_classification(num_classes, Dataset, name_from_datapath)
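
Assuming Megatron's standard task entry point (tasks/main.py) and its usual finetuning arguments, a run might look like the sketch below; the paths, sequence length, and batch size are placeholders, and the usual model and parallelism flags are omitted:

python tasks/main.py \
    --task AFQMC \
    --train-data /path/to/afqmc/train.json \
    --valid-data /path/to/afqmc/dev.json \
    --test-data /path/to/afqmc/test.json \
    --seq-length 128 \
    --micro-batch-size 8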
tasks/clue/iflytek.py ADDED
@@ -0,0 +1,94 @@
+ # coding=utf-8
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """IFLYTEK dataset."""
+
+ import json
+
+ from megatron import print_rank_0
+ from tasks.data_utils import clean_text
+ from tasks.label_dict import get_label_dict
+ from .data import GLUEAbstractDataset
+
+ LABELS = get_label_dict("IFLYTEK")
+
+
+ class IFLYTEKDataset(GLUEAbstractDataset):
+
+     def __init__(self, name, datapaths, tokenizer, max_seq_length,
+                  test_label='0'):
+         self.test_label = test_label
+         super().__init__('IFLYTEK', name, datapaths,
+                          tokenizer, max_seq_length)
+
+     def process_samples_from_single_path(self, filename):
+         """Implement abstract method."""
+         print_rank_0(' > Processing {} ...'.format(filename))
+
+         samples = []
+         total = 0
+         first = True
+         is_test = False
+         drop_cnt = 0
+         with open(filename, 'r') as f:
+             lines = [json.loads(line.strip()) for line in f.readlines()]
+         for index, row in enumerate(lines):
+             if "id" not in row:
+                 row["id"] = index
+             if first:
+                 first = False
+                 if "label" not in row:
+                     is_test = True
+                     print_rank_0(' reading {}, {} and {} columns and '
+                                  'setting labels to {}'.format(
+                                      row["id"], row["sentence"].strip(),
+                                      None, self.test_label))
+                 else:
+                     is_test = False
+                     print_rank_0(' reading {}, {}, {}, and {} columns '
+                                  '...'.format(
+                                      row["id"], row["sentence"].strip(),
+                                      None, row["label"].strip()))
+
+             text_a = clean_text(row["sentence"].strip())
+             text_b = None
+             unique_id = int(row["id"])
+
+             if is_test:
+                 label = self.test_label
+             else:
+                 label = row["label"].strip()
+
+             assert len(text_a) > 0
+             # assert len(text_b) > 0  (single-sentence task, text_b is None)
+             assert label in LABELS, "found label {} {}".format(label, row)
+             assert unique_id >= 0
+
+             sample = {'text_a': text_a,
+                       'text_b': text_b,
+                       'label': LABELS[label],
+                       'uid': unique_id}
+             total += 1
+             samples.append(sample)
+
+             if total % 5000 == 0:
+                 print_rank_0(' > processed {} so far ...'.format(total))
+
+         print_rank_0(' >> processed {} samples.'.format(len(samples)))
+         print_rank_0(' >> dropped {} samples.'.format(drop_cnt))
+
+         return samples
tasks/clue/ocnli.py ADDED
@@ -0,0 +1,102 @@
+ # coding=utf-8
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """OCNLI dataset."""
+
+ import json
+
+ from megatron import print_rank_0
+ from tasks.data_utils import clean_text
+ from tasks.label_dict import get_label_dict
+ from .data import GLUEAbstractDataset
+
+ LABELS = get_label_dict("OCNLI")
+
+
+ class OCNLIDataset(GLUEAbstractDataset):
+
+     def __init__(self, name, datapaths, tokenizer, max_seq_length,
+                  test_label='contradiction'):
+         self.test_label = test_label
+         super().__init__('OCNLI', name, datapaths,
+                          tokenizer, max_seq_length)
+
+     def process_samples_from_single_path(self, filename):
+         """Implement abstract method."""
+         print_rank_0(' > Processing {} ...'.format(filename))
+
+         samples = []
+         total = 0
+         first = True
+         is_test = False
+         drop_cnt = 0
+         with open(filename, 'r') as f:
+             lines = [json.loads(line.strip()) for line in f.readlines()]
+         for index, row in enumerate(lines):
+             if first:
+                 first = False
+                 if "label" not in row:
+                     is_test = True
+                     print_rank_0(' reading {}, {} and {} columns and '
+                                  'setting labels to {}'.format(
+                                      row["id"], row["sentence1"].strip(),
+                                      row["sentence2"].strip(),
+                                      self.test_label))
+                 else:
+                     is_test = False
+                     print_rank_0(' reading {}, {}, {}, and {} columns '
+                                  '...'.format(
+                                      row["id"], row["sentence1"].strip(),
+                                      row["sentence2"].strip(),
+                                      row["label"].strip()))
+
+             text_a = clean_text(row["sentence1"].strip())
+             text_b = clean_text(row["sentence2"].strip())
+             unique_id = int(row["id"])
+
+             if is_test:
+                 label = self.test_label
+             else:
+                 label = row["label"].strip()
+
+             # Drop examples without annotator consensus.
+             if label == "-":
+                 drop_cnt += 1
+                 continue
+
+             assert len(text_a) > 0
+             assert len(text_b) > 0
+             assert label in LABELS, "found label {} {}".format(label, row)
+             assert unique_id >= 0
+
+             sample = {'text_a': text_a,
+                       'text_b': text_b,
+                       'label': LABELS[label],
+                       'uid': unique_id}
+             total += 1
+             samples.append(sample)
+
+             if total % 5000 == 0:
+                 print_rank_0(' > processed {} so far ...'.format(total))
+
+         print_rank_0(' >> processed {} samples.'.format(len(samples)))
+         print_rank_0(' >> dropped {} samples.'.format(drop_cnt))
+
+         return samples
tasks/clue/tnews.py ADDED
@@ -0,0 +1,95 @@
+ # coding=utf-8
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """TNEWS dataset."""
+
+ import json
+
+ from megatron import print_rank_0
+ from tasks.data_utils import clean_text
+ from tasks.label_dict import get_label_dict
+ from .data import GLUEAbstractDataset
+
+ LABELS = get_label_dict("TNEWS")
+
+
+ class TNEWSDataset(GLUEAbstractDataset):
+
+     def __init__(self, name, datapaths, tokenizer, max_seq_length,
+                  test_label='100'):
+         self.test_label = test_label
+         super().__init__('TNEWS', name, datapaths,
+                          tokenizer, max_seq_length)
+
+     def process_samples_from_single_path(self, filename):
+         """Implement abstract method."""
+         print_rank_0(' > Processing {} ...'.format(filename))
+
+         samples = []
+         total = 0
+         first = True
+         is_test = False
+         drop_cnt = 0
+         with open(filename, 'r') as f:
+             lines = [json.loads(line.strip()) for line in f.readlines()]
+         for index, row in enumerate(lines):
+             if "id" not in row:
+                 row["id"] = index
+             if first:
+                 first = False
+                 if "label" not in row:
+                     is_test = True
+                     print_rank_0(' reading {}, {} and {} columns and '
+                                  'setting labels to {}'.format(
+                                      row["id"], row["sentence"].strip(),
+                                      None, self.test_label))
+                 else:
+                     is_test = False
+                     print_rank_0(' reading {}, {}, {}, and {} columns '
+                                  '...'.format(
+                                      row["id"], row["sentence"].strip(),
+                                      None, row["label"].strip()))
+
+             text_a = clean_text(row["sentence"].strip())
+             # The news keywords are used as the second segment.
+             text_b = clean_text(row["keywords"].strip())
+             unique_id = int(row["id"])
+
+             if is_test:
+                 label = self.test_label
+             else:
+                 label = row["label"].strip()
+
+             assert len(text_a) > 0
+             # assert len(text_b) > 0  (keywords may be empty)
+             assert label in LABELS, "found label {} {}".format(label, row)
+             assert unique_id >= 0
+
+             sample = {'text_a': text_a,
+                       'text_b': text_b,
+                       'label': LABELS[label],
+                       'uid': unique_id}
+             total += 1
+             samples.append(sample)
+
+             if total % 5000 == 0:
+                 print_rank_0(' > processed {} so far ...'.format(total))
+
+         print_rank_0(' >> processed {} samples.'.format(len(samples)))
+         print_rank_0(' >> dropped {} samples.'.format(drop_cnt))
+
+         return samples
tasks/clue/wsc.py ADDED
@@ -0,0 +1,117 @@
+ # coding=utf-8
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """WSC dataset."""
+
+ import json
+
+ from megatron import print_rank_0
+ from tasks.data_utils import clean_text
+ from tasks.label_dict import get_label_dict
+ from .data import GLUEAbstractDataset
+
+ LABELS = get_label_dict("WSC")
+
+
+ class WSCDataset(GLUEAbstractDataset):
+
+     def __init__(self, name, datapaths, tokenizer, max_seq_length,
+                  test_label="false"):
+         self.test_label = test_label
+         super().__init__('WSC', name, datapaths,
+                          tokenizer, max_seq_length)
+
+     def process_samples_from_single_path(self, filename):
+         """Implement abstract method."""
+         print_rank_0(' > Processing {} ...'.format(filename))
+
+         samples = []
+         total = 0
+         first = True
+         is_test = False
+         drop_cnt = 0
+         with open(filename, 'r') as f:
+             lines = [json.loads(line.strip()) for line in f.readlines()]
+         for index, row in enumerate(lines):
+             if "id" not in row:
+                 row["id"] = index
+             text_a = row['text']
+             text_a_list = list(text_a)
+             target = row['target']
+             query = target['span1_text']
+             query_idx = target['span1_index']
+             pronoun = target['span2_text']
+             pronoun_idx = target['span2_index']
+             assert text_a[pronoun_idx: (pronoun_idx + len(pronoun))] == pronoun, \
+                 "pronoun: {}".format(pronoun)
+             assert text_a[query_idx: (query_idx + len(query))] == query, \
+                 "query: {}".format(query)
+             # Mark the candidate span with underscores and the pronoun with
+             # brackets; the later span's indices shift by the two characters
+             # already inserted for the earlier span.
+             if pronoun_idx > query_idx:
+                 text_a_list.insert(query_idx, "_")
+                 text_a_list.insert(query_idx + len(query) + 1, "_")
+                 text_a_list.insert(pronoun_idx + 2, "[")
+                 text_a_list.insert(pronoun_idx + len(pronoun) + 2 + 1, "]")
+             else:
+                 text_a_list.insert(pronoun_idx, "[")
+                 text_a_list.insert(pronoun_idx + len(pronoun) + 1, "]")
+                 text_a_list.insert(query_idx + 2, "_")
+                 text_a_list.insert(query_idx + len(query) + 2 + 1, "_")
+             text_a = "".join(text_a_list)
+             # text_b = "在这句话中,{}指代的是{}".format(pronoun, query)
+             # (i.e. "In this sentence, {} refers to {}")
+             text_b = None
+             if first:
+                 first = False
+                 if "label" not in row:
+                     is_test = True
+                     print_rank_0(' reading {}, {} and {} columns and '
+                                  'setting labels to {}'.format(
+                                      row["id"], text_a,
+                                      text_b, self.test_label))
+                 else:
+                     is_test = False
+                     print_rank_0(' reading {}, {}, {}, and {} columns '
+                                  '...'.format(
+                                      row["id"], text_a,
+                                      text_b, row["label"].strip()))
+
+             unique_id = int(row["id"])
+
+             if is_test:
+                 label = self.test_label
+             else:
+                 label = row["label"].strip()
+
+             assert len(text_a) > 0
+             # assert len(text_b) > 0  (single-sentence formulation)
+             assert label in LABELS, "found label {} {} {}".format(label, row, type(label))
+             assert unique_id >= 0
+
+             sample = {'text_a': text_a,
+                       'text_b': text_b,
+                       'label': LABELS[label],
+                       'uid': unique_id}
+             total += 1
+             samples.append(sample)
+
+             if total % 5000 == 0:
+                 print_rank_0(' > processed {} so far ...'.format(total))
+
+         print_rank_0(' >> processed {} samples.'.format(len(samples)))
+         print_rank_0(' >> dropped {} samples.'.format(drop_cnt))
+
+         return samples
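
A worked example of the markup above, on a hypothetical sentence (not from the real WSC data): with text "Sam took his dog out", query "Sam" at index 0, and pronoun "his" at index 9, the pronoun comes after the query, so the query is underscored first and the pronoun's indices shift by the two inserted characters:

# "Sam took his dog out"
#   insert "_" at 0 and at 0 + 3 + 1:          "_Sam_ took his dog out"
#   insert "[" at 9 + 2 and "]" at 9 + 3 + 3:  "_Sam_ took [his] dog out"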
tasks/clue/zc.py ADDED
@@ -0,0 +1,96 @@
+ # coding=utf-8
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """ZC dataset."""
+
+ import json
+
+ from megatron import print_rank_0
+ from tasks.data_utils import clean_text
+ from tasks.label_dict import get_label_dict
+ from .data import GLUEAbstractDataset
+
+ LABELS = get_label_dict("ZC")
+
+
+ class ZCDataset(GLUEAbstractDataset):
+
+     def __init__(self, name, datapaths, tokenizer, max_seq_length,
+                  test_label='negative'):
+         self.test_label = test_label
+         super().__init__('ZC', name, datapaths,
+                          tokenizer, max_seq_length)
+
+     def process_samples_from_single_path(self, filename):
+         """Implement abstract method."""
+         print_rank_0(' > Processing {} ...'.format(filename))
+
+         samples = []
+         total = 0
+         first = True
+         is_test = False
+         drop_cnt = 0
+         with open(filename, 'r') as f:
+             lines = [json.loads(line.strip()) for line in f.readlines()]
+         for index, row in enumerate(lines):
+             row["id"] = index
+             if first:
+                 first = False
+                 # Test mode is keyed off the filename rather than a
+                 # missing "label" field.
+                 if "test.json" in filename:
+                     is_test = True
+                     print_rank_0(' reading {}, {} and {} columns and '
+                                  'setting labels to {}'.format(
+                                      row["id"], row["text"].strip(),
+                                      None, self.test_label))
+                 else:
+                     is_test = False
+                     print_rank_0(' reading {}, {}, {}, and {} columns '
+                                  '...'.format(
+                                      row["id"], row["text"].strip(),
+                                      None, row["label"].strip()))
+
+             text_a = clean_text(row["text"].strip())
+             text_b = None
+             unique_id = int(row["id"])
+
+             if is_test:
+                 label = self.test_label
+             else:
+                 label = row["label"].strip()
+
+             assert len(text_a) > 0
+             # assert len(text_b) > 0  (single-sentence task, text_b is None)
+             assert label in LABELS, "found label {} {}".format(label, row)
+             assert unique_id >= 0
+
+             sample = {'text_a': text_a,
+                       'text_b': text_b,
+                       'label': LABELS[label],
+                       'uid': unique_id}
+             total += 1
+             samples.append(sample)
+
+             if total % 5000 == 0:
+                 print_rank_0(' > processed {} so far ...'.format(total))
+
+         print_rank_0(' >> processed {} samples.'.format(len(samples)))
+         print_rank_0(' >> dropped {} samples.'.format(drop_cnt))
+
+         return samples
tasks/data_utils.py ADDED
@@ -0,0 +1,118 @@
+ # coding=utf-8
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Tasks data utility."""
+
+ import re
+
+ import numpy as np
+
+
+ def clean_text(text):
+     """Remove new lines and multiple spaces and adjust end of sentence dot."""
+
+     text = text.replace("\n", " ")
+     text = re.sub(r'\s+', ' ', text)
+     for _ in range(3):
+         text = text.replace(' . ', '. ')
+
+     return text
+
+
+ def build_sample(ids, types, paddings, label, unique_id):
+     """Convert to numpy and return a sample consumed by the batch producer."""
+
+     ids_np = np.array(ids, dtype=np.int64)
+     types_np = np.array(types, dtype=np.int64)
+     paddings_np = np.array(paddings, dtype=np.int64)
+     sample = ({'text': ids_np,
+                'types': types_np,
+                'padding_mask': paddings_np,
+                'label': int(label),
+                'uid': int(unique_id)})
+
+     return sample
+
+
+ def build_tokens_types_paddings_from_text(text_a, text_b,
+                                           tokenizer, max_seq_length):
+     """Build token types and paddings, trim if needed, and pad if needed."""
+
+     text_a_ids = tokenizer.tokenize(text_a)
+     text_b_ids = None
+     if text_b is not None:
+         text_b_ids = tokenizer.tokenize(text_b)
+
+     return build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids,
+                                                 max_seq_length, tokenizer.cls,
+                                                 tokenizer.sep, tokenizer.pad)
+
+
+ def build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids, max_seq_length,
+                                          cls_id, sep_id, pad_id):
+     """Build token types and paddings, trim if needed, and pad if needed."""
+
+     ids = []
+     types = []
+     paddings = []
+
+     # [CLS].
+     ids.append(cls_id)
+     types.append(0)
+     paddings.append(1)
+
+     # A.
+     len_text_a = len(text_a_ids)
+     ids.extend(text_a_ids)
+     types.extend([0] * len_text_a)
+     paddings.extend([1] * len_text_a)
+
+     # [SEP].
+     ids.append(sep_id)
+     types.append(0)
+     paddings.append(1)
+
+     # B.
+     if text_b_ids is not None:
+         len_text_b = len(text_b_ids)
+         ids.extend(text_b_ids)
+         types.extend([1] * len_text_b)
+         paddings.extend([1] * len_text_b)
+
+     # Cap the size.
+     trimmed = False
+     if len(ids) >= max_seq_length:
+         max_seq_length_m1 = max_seq_length - 1
+         ids = ids[0:max_seq_length_m1]
+         types = types[0:max_seq_length_m1]
+         paddings = paddings[0:max_seq_length_m1]
+         trimmed = True
+
+     # [SEP].
+     if (text_b_ids is not None) or trimmed:
+         ids.append(sep_id)
+         if text_b_ids is None:
+             types.append(0)
+         else:
+             types.append(1)
+         paddings.append(1)
+
+     # Padding.
+     padding_length = max_seq_length - len(ids)
+     if padding_length > 0:
+         ids.extend([pad_id] * padding_length)
+         types.extend([pad_id] * padding_length)
+         paddings.extend([0] * padding_length)
+
+     return ids, types, paddings
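
A worked example of build_tokens_types_paddings_from_ids with toy ids (illustrative, not real vocabulary ids): text_a_ids = [11, 12], text_b_ids = [21], max_seq_length = 8, cls_id = 1, sep_id = 2, pad_id = 0 gives

# ids      = [1, 11, 12, 2, 21, 2, 0, 0]   # [CLS] A A [SEP] B [SEP] pad pad
# types    = [0,  0,  0, 0,  1, 1, 0, 0]   # segment B marked with 1s
# paddings = [1,  1,  1, 1,  1, 1, 0, 0]   # 1 on real tokens, 0 on padding
# Note the padded positions extend `types` with pad_id (0 here), as coded above.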
tasks/ensemble_classifier.py ADDED
@@ -0,0 +1,149 @@
+ import os
+ import argparse
+ import collections
+
+ import numpy as np
+ import torch
+
+
+ def process_files(args):
+     all_predictions = collections.OrderedDict()
+     all_labels = collections.OrderedDict()
+     all_uid = collections.OrderedDict()
+     for path in args.paths:
+         path = os.path.join(path, args.prediction_name)
+         try:
+             data = torch.load(path)
+             for dataset in data:
+                 name, d = dataset
+                 predictions, labels, uid = d
+                 if name not in all_predictions:
+                     all_predictions[name] = np.array(predictions)
+                     if args.labels is None:
+                         args.labels = [i for i in range(all_predictions[name].shape[1])]
+                     if args.eval:
+                         all_labels[name] = np.array(labels)
+                     all_uid[name] = np.array(uid)
+                 else:
+                     all_predictions[name] += np.array(predictions)
+                     assert np.allclose(all_uid[name], np.array(uid))
+         except Exception as e:
+             print(e)
+             continue
+     return all_predictions, all_labels, all_uid
+
+
+ def get_threshold(all_predictions, all_labels, one_threshold=False):
+     if one_threshold:
+         all_predictions = {'combined': np.concatenate(list(all_predictions.values()))}
+         all_labels = {'combined': np.concatenate(list(all_labels.values()))}
+     out_thresh = []
+     for dataset in all_predictions:
+         preds = all_predictions[dataset]
+         labels = all_labels[dataset]
+         out_thresh.append(calc_threshold(preds, labels))
+     return out_thresh
+
+
+ def calc_threshold(p, l):
+     trials = [(i) * (1. / 100.) for i in range(100)]
+     best_acc = float('-inf')
+     best_thresh = 0
+     for t in trials:
+         acc = ((apply_threshold(p, t).argmax(-1) == l).astype(float)).mean()
+         if acc > best_acc:
+             best_acc = acc
+             best_thresh = t
+     return best_thresh
+
+
+ def apply_threshold(preds, t):
+     assert (np.allclose(preds.sum(-1), np.ones(preds.shape[0])))
+     prob = preds[:, -1]
+     thresholded = (prob >= t).astype(int)
+     preds = np.zeros_like(preds)
+     preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1
+     return preds
+
+
+ def threshold_predictions(all_predictions, threshold):
+     if len(threshold) != len(all_predictions):
+         # Pad with the last supplied threshold so every dataset gets one.
+         threshold = list(threshold) + \
+             [threshold[-1]] * (len(all_predictions) - len(threshold))
+     for i, dataset in enumerate(all_predictions):
+         thresh = threshold[i]
+         preds = all_predictions[dataset]
+         all_predictions[dataset] = apply_threshold(preds, thresh)
+     return all_predictions
+
+
+ def postprocess_predictions(all_predictions, all_labels, args):
+     for d in all_predictions:
+         all_predictions[d] = all_predictions[d] / len(args.paths)
+
+     if args.calc_threshold:
+         args.threshold = get_threshold(all_predictions, all_labels, args.one_threshold)
+         print('threshold', args.threshold)
+
+     if args.threshold is not None:
+         all_predictions = threshold_predictions(all_predictions, args.threshold)
+
+     return all_predictions, all_labels
+
+
+ def write_predictions(all_predictions, all_labels, all_uid, args):
+     all_correct = 0
+     count = 0
+     for dataset in all_predictions:
+         preds = all_predictions[dataset]
+         preds = np.argmax(preds, -1)
+         if args.eval:
+             correct = (preds == all_labels[dataset]).sum()
+             num = len(all_labels[dataset])
+             accuracy = correct / num
+             count += num
+             all_correct += correct
+             print(accuracy)
+         if not os.path.exists(os.path.join(args.outdir, dataset)):
+             os.makedirs(os.path.join(args.outdir, dataset))
+         outpath = os.path.join(
+             args.outdir, dataset, os.path.splitext(
+                 args.prediction_name)[0] + '.tsv')
+         with open(outpath, 'w') as f:
+             f.write('id\tlabel\n')
+             f.write('\n'.join(str(uid) + '\t' + str(args.labels[p])
+                               for uid, p in zip(all_uid[dataset], preds.tolist())))
+     if args.eval:
+         print(all_correct / count)
+
+
+ def ensemble_predictions(args):
+     all_predictions, all_labels, all_uid = process_files(args)
+     all_predictions, all_labels = postprocess_predictions(all_predictions, all_labels, args)
+     write_predictions(all_predictions, all_labels, all_uid, args)
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--paths', required=True, nargs='+',
+                         help='paths to checkpoint directories used in ensemble')
+     parser.add_argument('--eval', action='store_true',
+                         help='compute accuracy metrics against labels (dev set)')
+     parser.add_argument('--outdir',
+                         help='directory to place ensembled predictions in')
+     parser.add_argument('--prediction-name', default='test_predictions.pt',
+                         help='name of predictions in checkpoint directories')
+     parser.add_argument('--calc-threshold', action='store_true',
+                         help='calculate threshold classification')
+     parser.add_argument('--one-threshold', action='store_true',
+                         help='use one threshold for all subdatasets')
+     parser.add_argument('--threshold', nargs='+', default=None, type=float,
+                         help='user supplied threshold for classification')
+     parser.add_argument('--labels', nargs='+', default=None,
+                         help='whitespace separated list of label names')
+     args = parser.parse_args()
+     ensemble_predictions(args)
+
+
+ if __name__ == '__main__':
+     main()
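
A hypothetical invocation of the ensembler (paths are placeholders); it averages the saved softmax outputs across the listed checkpoint directories, writes one TSV of id/label pairs per sub-dataset, and prints accuracies when --eval is given:

python tasks/ensemble_classifier.py \
    --paths ckpt_run1 ckpt_run2 ckpt_run3 \
    --outdir ensembled \
    --prediction-name test_predictions.pt \
    --eval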
tasks/eval_utils.py ADDED
@@ -0,0 +1,250 @@
+ # coding=utf-8
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Evaluation utilities."""
+
+ import json
+ import os
+ import time
+ from functools import partial
+
+ import numpy as np
+ import torch
+
+ from megatron import get_args
+ from megatron import print_rank_last, is_last_rank
+ from megatron import mpu
+ from megatron.schedules import get_forward_backward_func
+ from tasks.finetune_utils import build_data_loader
+ from tasks.finetune_utils import process_batch
+ from tasks.label_dict import get_label_dict
+
+
+ def accuracy_func_provider(single_dataset_provider):
+     """Provide a function that calculates accuracies."""
+     args = get_args()
+
+     # Build dataloaders.
+     datapaths = [args.valid_data[0], args.test_data[0]]
+     dataloaders = []
+     for datapath in datapaths:
+         dataset = single_dataset_provider(datapath)
+         dataloader = build_data_loader(
+             dataset, args.micro_batch_size, num_workers=args.num_workers,
+             drop_last=(mpu.get_data_parallel_world_size() > 1))
+         dataloaders.append((dataset.dataset_name, dataloader))
+
+     def _generate_prediction_json(predictions, step, save_acc):
+         probs_list = predictions[0]
+         # labels_list = predictions[1]
+         ids_list = predictions[2]
+         min_id = min(ids_list)
+         max_id = max(ids_list)
+         LABELS = get_label_dict(args.task, write2file=True)
+         output_submit_file = os.path.join(
+             args.res_path[0],
+             args.task + "_prediction_{}_{}.json".format(step, save_acc))
+         with open(output_submit_file, "w") as writer:
+             for i in range(min_id, max_id + 1):
+                 label_index = ids_list.index(i)
+                 pred_prob_list = probs_list[label_index]
+                 label = pred_prob_list.index(max(pred_prob_list))
+                 json_d = {}
+                 if min_id == 1:
+                     json_d['id'] = i - 1
+                 else:
+                     json_d['id'] = i
+                 json_d["label"] = LABELS[str(label)]
+                 writer.write(json.dumps(json_d) + '\n')
+
+     def _generate_prediction_prob(predictions, step, save_acc):
+         probs_list = predictions[0]
+         ids_list = predictions[2]
+         min_id = min(ids_list)
+         max_id = max(ids_list)
+
+         output_prob_file = os.path.join(
+             args.res_path[0],
+             args.task + "_prob_{}_{}".format(step, save_acc))
+         prob_arr = []
+         for i in range(min_id, max_id + 1):
+             label_index = ids_list.index(i)
+             prob_arr.append(probs_list[label_index])
+         prob_arr = np.array(prob_arr)
+         np.save(output_prob_file, prob_arr)
+
+     def metrics_func(model, step):
+         print_rank_last('calculating metrics ...')
+         correct = 0
+         total = 0
+
+         for index, (name, dataloader) in enumerate(dataloaders):
+             if index == 1:
+                 output_predictions = True
+                 assert mpu.get_data_parallel_world_size() == 1
+                 named_predictions = []
+                 names = 'predictions'
+             else:
+                 output_predictions = False
+
+             output = calculate_correct_answers(name, model, dataloader,
+                                                step, output_predictions)
+             if not output_predictions:
+                 correct_ans, total_count = output
+             else:
+                 correct_ans, total_count, predictions = output
+                 named_predictions.append((name, predictions))
+                 names += '_' + name
+             if not output_predictions:
+                 correct += correct_ans
+                 total += total_count
+                 save_acc = str(round(correct / total, 4) * 10000)[:4]
+
+         if output_predictions:
+             print_rank_last("generate prediction...")
+             _generate_prediction_json(predictions, step, save_acc)
+             _generate_prediction_prob(predictions, step, save_acc)
+             print_rank_last("generate done")
+         # if is_last_rank():
+         #     percent = float(correct) * 100.0 / float(total)
+         #     print(' >> |step: {}| overall: correct / total = {} / {} = '
+         #           '{:.4f} %'.format(step, correct, total, percent))
+         # if output_predictions and is_last_rank():
+         #     assert args.load is not None
+         #     filename = os.path.join(args.load, names + '.pt')
+         #     torch.save(named_predictions, filename)
+
+     return metrics_func
+
+
+ def calculate_correct_answers(name, model, dataloader,
+                               step, output_predictions):
+     """Calculate correct over total answers, and return predictions if
+     `output_predictions` is true."""
+     args = get_args()
+     forward_backward_func = get_forward_backward_func()
+     start_time = time.time()
+     for m in model:
+         m.eval()
+     saved_micro_batch_size = args.micro_batch_size
+     saved_global_batch_size = args.global_batch_size
+
+     ds = dataloader.dataset
+     if hasattr(ds, 'sample_multiplier'):
+         # If our dataset has a sample_multiplier attribute, it means each
+         # "sample" from the dataset actually has multiple samples that will
+         # collapse into the batch dimension (for example in the RACE dataset
+         # that has several options), so we need to account for that when
+         # setting the micro batch size.
+         sample_multiplier = ds.sample_multiplier
+     else:
+         sample_multiplier = 1
+     micro_batch_size_times_data_parallel = args.orig_micro_batch_size * args.data_parallel_size
+     num_micro_batches = args.orig_global_batch_size // micro_batch_size_times_data_parallel
+
+     def loss_func(output_predictions, labels, output_tensor):
+         logits = output_tensor
+
+         loss_dict = {}
+         # Add output predictions.
+         if output_predictions:
+             loss_dict['softmaxes'] = torch.nn.Softmax(dim=-1)(
+                 logits.float()).data.cpu().numpy().tolist()
+             loss_dict['labels'] = labels.data.cpu().numpy().tolist()
+             loss_dict['ids'] = batch['uid'].cpu().numpy().tolist()
+         # Compute the correct answers.
+         predicted = torch.argmax(logits, dim=-1)
+         corrects = (predicted == labels)
+         # Add to the counters.
+         loss_dict['total'] = labels.size(0)
+         loss_dict['correct'] = corrects.sum().item()
+
+         return 0, loss_dict
+
+     # Defined inside to capture output_predictions.
+     def correct_answers_forward_step(batch, model):
+         try:
+             batch_ = next(batch)
+         except BaseException:
+             batch_ = batch
+         tokens, types, labels, attention_mask = process_batch(batch_)
+
+         # Forward model.
+         args = get_args()
+         output_tensor = model(tokens, attention_mask, tokentype_ids=types)
+
+         return output_tensor, partial(loss_func, output_predictions, labels)
+
+     with torch.no_grad():
+         # For all the batches in the dataset.
+         total = 0
+         correct = 0
+         if output_predictions:
+             # This option is only possible when data parallel size is 1.
+             assert mpu.get_data_parallel_world_size() == 1
+             softmaxes = []
+             labels = []
+             ids = []
+         for _, batch in enumerate(dataloader):
+             # For evaluation-only mode we use drop_last = False to get all
+             # the samples, which means we might not have a full batch, so
+             # we adjust batch_size here to the actual batch size of data ...
+             actual_batch_size = len(batch['label'])
+             # ... applying sample_multiplier if necessary.
+             args.micro_batch_size = actual_batch_size * sample_multiplier
+             args.global_batch_size = actual_batch_size * sample_multiplier * num_micro_batches
+
+             loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
+                                                optimizer=None, timers=None, forward_only=True)
+
+             for loss_dict in loss_dicts:
+                 if output_predictions:
+                     softmaxes.extend(loss_dict['softmaxes'])
+                     labels.extend(loss_dict['labels'])
+                     ids.extend(loss_dict['ids'])
+                 total += loss_dict['total']
+                 correct += loss_dict['correct']
+
+     for m in model:
+         m.train()
+     args.micro_batch_size = saved_micro_batch_size
+     args.global_batch_size = saved_global_batch_size
+
+     # Reduce.
+     if mpu.is_pipeline_last_stage():
+         unreduced = torch.cuda.LongTensor([correct, total])
+         torch.distributed.all_reduce(unreduced,
+                                      group=mpu.get_data_parallel_group())
+
+         # Print on screen.
+         correct_ans = unreduced[0].item()
+         total_count = unreduced[1].item()
+         percent = float(correct_ans) * 100.0 / float(total_count)
+         elapsed_time = time.time() - start_time
+         if not output_predictions:
+             print_rank_last(' > |step: {} | metrics for {}: correct / total '
+                             '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
+                                 step, name, correct_ans, total_count,
+                                 percent, elapsed_time))
+
+         if output_predictions:
+             return correct_ans, total_count, (softmaxes, labels, ids)
+         return correct_ans, total_count
+     if output_predictions:
+         return 0, 0, ()
+     return 0, 0
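
For orientation, _generate_prediction_json writes one JSON object per test example in id order, mapping the argmax class index back through the task's label dictionary, so a line of the submission file looks like (illustrative label):

# {"id": 0, "label": "entailment"}

and _generate_prediction_prob saves the raw (num_examples, num_classes) probability matrix with np.save for later ensembling by tasks/ensemble_classifier.py.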
tasks/finetune_utils.py ADDED
@@ -0,0 +1,330 @@
+ # coding=utf-8
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Finetune utilities."""
+
+ from functools import partial
+ import sys
+ import torch
+
+ from megatron import get_args, get_num_microbatches
+ from megatron import print_rank_0
+ from megatron import get_timers
+ from megatron import mpu
+ from megatron.checkpointing import load_checkpoint
+ from megatron.checkpointing import save_checkpoint
+ from megatron.model import ModelType
+ from megatron.training import evaluate_and_print_results
+ from megatron.training import setup_model_and_optimizer
+ from megatron.training import train_step
+ from megatron.training import training_log
+ from megatron.utils import average_losses_across_data_parallel_group
+ from megatron.utils import calc_params_l2_norm
+ from megatron.utils import check_adlr_autoresume_termination
+
+
+ def process_batch(batch):
+     """Process batch and produce inputs for the model."""
+     args = get_args()
+
+     tokens = batch['text'].long().cuda().contiguous()
+     types = batch['types'].long().cuda().contiguous()
+     labels = batch['label'].long().cuda().contiguous()
+     attention_mask = batch['padding_mask'].float().cuda().contiguous()
+     if args.fp16:
+         attention_mask = attention_mask.half()
+
+     return tokens, types, labels, attention_mask
+
+
+ def cross_entropy_loss_func(labels, output_tensor):
+     logits = output_tensor
+
+     # Cross-entropy loss.
+     loss_func = torch.nn.CrossEntropyLoss()
+     loss = loss_func(logits.contiguous().float(), labels)
+
+     # Reduce loss for logging.
+     averaged_loss = average_losses_across_data_parallel_group([loss])
+
+     return loss, {'training loss': averaged_loss[0]}
+
+
+ def _cross_entropy_forward_step(batch, model):
+     """Simple forward step with cross-entropy loss."""
+     timers = get_timers()
+
+     # Get the batch.
+     timers('batch-generator').start()
+     try:
+         batch_ = next(batch)
+     except BaseException:
+         batch_ = batch
+     tokens, types, labels, attention_mask = process_batch(batch_)
+     timers('batch-generator').stop()
+
+     # Forward model.
+     output_tensor = model(tokens, attention_mask, tokentype_ids=types)
+
+     return output_tensor, partial(cross_entropy_loss_func, labels)
+
+
+ def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
+                       task_collate_fn=None):
+     """Data loader. Note that batch-size is the local (per GPU) batch-size."""
+
+     # Sampler.
+     world_size = mpu.get_data_parallel_world_size()
+     rank = mpu.get_data_parallel_rank()
+     sampler = torch.utils.data.distributed.DistributedSampler(
+         dataset, num_replicas=world_size, rank=rank)
+
+     # Data loader. Note that batch size is the per GPU batch size.
+     data_loader = torch.utils.data.DataLoader(dataset,
+                                               batch_size=micro_batch_size,
+                                               sampler=sampler,
+                                               shuffle=False,
+                                               num_workers=num_workers,
+                                               drop_last=drop_last,
+                                               pin_memory=True,
+                                               collate_fn=task_collate_fn)
+
+     return data_loader
+
+
+ def _build_infinite_size_dataloader(dataloader):
+     """Build a looped dataloader with infinite size."""
+
+     iterator = dataloader.__iter__()
+     while True:
+         try:
+             yield iterator.__next__()
+         except StopIteration:
+             iterator = dataloader.__iter__()
+
+
+ def _build_train_valid_dataloaders(train_dataset, valid_dataset,
+                                    task_collate_fn=None):
+     """Training and validation dataloaders."""
+     args = get_args()
+
+     print_rank_0('building train and validation dataloaders ...')
+     # Training dataset.
+     train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
+                                          args.num_workers, not args.keep_last,
+                                          task_collate_fn)
+     # Set the training iterations.
+     args.train_iters_per_epoch = len(train_dataloader)
+     args.train_iters = args.epochs * args.train_iters_per_epoch
+     # Validation dataset. For this dataset, we do not need to set up
+     # shuffling so we can just use a simple infinite loop.
+     valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
+                                           args.num_workers, not args.keep_last,
+                                           task_collate_fn)
+     valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
+
+     # Now that we've built the data loaders, set batch_size arguments
+     # to the actual batch size the model will see for this dataset.
+     # This is necessary so pipeline transfers know what size they are
+     # and the LR schedule, which is based on samples seen, gets set
+     # correctly.
+     args.orig_micro_batch_size = args.micro_batch_size
+     args.orig_global_batch_size = args.global_batch_size
+     if hasattr(train_dataset, 'sample_multiplier'):
+         # If our dataset has a sample_multiplier attribute, it means
+         # each "sample" from the dataset actually has multiple samples
+         # that will collapse into the batch dimension (for example in
+         # the RACE dataset, which has several options), so we need to
+         # account for that when setting the micro batch size.
+         args.micro_batch_size *= train_dataset.sample_multiplier
+         args.global_batch_size *= train_dataset.sample_multiplier
+
+     return train_dataloader, valid_dataloader
+
+
+ def _train(model, optimizer, opt_param_scheduler, forward_step,
+            train_dataloader, valid_dataloader, end_of_epoch_callback):
+     """Train the model."""
+     args = get_args()
+     timers = get_timers()
+
+     assert get_num_microbatches() == 1, "finetuning with gradient accumulation doesn't currently work"
+
+     # Turn on training mode which enables dropout.
+     for m in model:
+         m.train()
+
+     # Tracking loss.
+     losses_dict_sum = {}
+
+     # Starting epoch and iteration.
+     start_epoch = args.iteration // args.train_iters_per_epoch
+     start_iteration = args.iteration % args.train_iters_per_epoch
+     iteration = args.iteration
+
+     # Memory reporting flag.
+     report_memory_flag = True
+
+     # For each remaining epoch ...
+     timers('interval-time').start()
+     for epoch in range(start_epoch, args.epochs):
+         print_rank_0('working on epoch {} ...'.format(epoch + 1))
+
+         # Set the data loader epoch to shuffle the index iterator.
+         train_dataloader.sampler.set_epoch(args.seed + epoch)
+
+         # For all the batches in the dataset.
+         for iteration_, batch in enumerate(train_dataloader):
+
+             # Ignore the iterations before the starting value.
+             if iteration_ < start_iteration:
+                 continue
+             # Set to zero so the next epoch does not skip any batches.
+             start_iteration = 0
+
+             # Train for one step.
+             out = train_step(forward_step, batch, model, optimizer, opt_param_scheduler)
+
+             losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
+             iteration += 1
+
+             # Logging.
+             params_norm = None
+             if args.log_params_norm:
+                 params_norm = calc_params_l2_norm(model)
+             report_memory_flag = training_log(losses_dict, losses_dict_sum,
+                                               optimizer.param_groups[0]['lr'],
+                                               iteration,
+                                               optimizer.get_loss_scale().item(),
+                                               report_memory_flag, skipped_iter,
+                                               grad_norm, params_norm, num_zeros_in_grad, None)
+
+             # Autoresume.
+             if args.adlr_autoresume and \
+                (iteration % args.adlr_autoresume_interval == 0):
+                 check_adlr_autoresume_termination(iteration, model,
+                                                   optimizer, opt_param_scheduler)
+
+             # Checkpointing.
+             saved_checkpoint = False
+             if args.save and args.save_interval and \
+                iteration % args.save_interval == 0:
+                 save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
+                 saved_checkpoint = True
+
+             # Evaluation.
+             if args.eval_interval and iteration % args.eval_interval == 0:
+                 prefix = 'iteration {}'.format(iteration)
+                 evaluate_and_print_results(prefix, forward_step,
+                                            valid_dataloader, model,
+                                            iteration, None, False)
+                 if end_of_epoch_callback is not None:
+                     end_of_epoch_callback(model, iteration)
+                 print_rank_0('-' * 72 + '\n')
+
+             # Exiting based on iterations.
+             if args.exit_interval and iteration % args.exit_interval == 0:
+                 if not saved_checkpoint:
+                     save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
+                 torch.distributed.barrier()
+                 print_rank_0('exiting program at iteration {}'.format(iteration))
+                 sys.exit()
+
+         # Checkpointing at the end of each epoch.
+         if args.save:
+             save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
+
+         prefix = 'iteration {}'.format(iteration)
+         evaluate_and_print_results(prefix, forward_step,
+                                    valid_dataloader, model,
+                                    iteration, None, False)
+         if end_of_epoch_callback is not None:
+             end_of_epoch_callback(model, iteration)
+         print_rank_0('-' * 72 + '\n')
+
+         # Callback at the end of each epoch.
+         # if end_of_epoch_callback is not None:
+         #     end_of_epoch_callback(model, epoch)
+
+
+ def finetune(train_valid_datasets_provider, model_provider,
+              model_type=ModelType.encoder_or_decoder,
+              forward_step=_cross_entropy_forward_step,
+              end_of_epoch_callback_provider=None,
+              task_collate_fn=None):
+     """Main finetune function used across all tasks."""
+     args = get_args()
+     timers = get_timers()
+
+     assert args.rampup_batch_size is None, \
+         'batch size scaling is not supported for finetuning'
+
+     # Train and validation data loaders.
+     timers('train/valid/test dataset/dataloader').start()
+     if args.epochs > 0:
+         train_dataset, valid_dataset = train_valid_datasets_provider()
+         train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
+             train_dataset, valid_dataset, task_collate_fn)
+     else:
+         args.train_iters = 0
+     timers('train/valid/test dataset/dataloader').stop()
+
+     # Build callback function.
+     timers('callback function').start()
+     end_of_epoch_callback = None
+     if end_of_epoch_callback_provider is not None:
+         end_of_epoch_callback = end_of_epoch_callback_provider()
+     timers('callback function').stop()
+
+     # Build model, optimizer and learning rate scheduler.
+     timers('model and optimizer').start()
+     model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, model_type)
+     timers('model and optimizer').stop()
+
+     # If a pretrained checkpoint is provided and we have not trained for
+     # any iteration (i.e., iteration is zero), then load the pretrained
+     # checkpoint.
+     timers('pretrained checkpoint').start()
+     if args.iteration == 0 and args.pretrained_checkpoint is not None:
+         original_load = args.load
+         args.load = args.pretrained_checkpoint
+         original_rng = args.no_load_rng
+         args.no_load_rng = True
+         _ = load_checkpoint(model, None, None)
+         args.load = original_load
+         args.no_load_rng = original_rng
+         # This is critical when only the model is loaded. We should make sure
+         # main parameters are also updated.
+         optimizer.reload_model_params()
+     timers('pretrained checkpoint').stop()
+
+     # Print setup timing.
+     print_rank_0('done with setups ...')
+     timers.log(['train/valid/test dataset/dataloader', 'callback function',
+                 'model and optimizer', 'pretrained checkpoint'])
+     print_rank_0('training ...')
+
+     # Finetune the model.
+     if args.epochs > 0:
+         _train(model, optimizer, opt_param_scheduler, forward_step,
+                train_dataloader, valid_dataloader, end_of_epoch_callback)
+     # Or just evaluate.
+     else:
+         raise NotImplementedError('evaluation-only mode (--epochs 0) is not '
+                                   'implemented here')
+         # if end_of_epoch_callback is not None:
+         #     print_rank_0('evaluation only mode, setting epoch to -1')
+         #     end_of_epoch_callback(model, epoch=-1, output_predictions=True)
+     print_rank_0('done :-)')
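
One detail worth noting in `_train()` above is the checkpoint-resume arithmetic. A minimal sketch with hypothetical numbers:

```python
# A checkpoint saved at iteration 250, with 100 iterations per epoch,
# resumes in (0-based) epoch 2 and skips the first 50 batches of that epoch.
train_iters_per_epoch = 100   # assumed value for illustration
iteration = 250               # iteration restored from the checkpoint
start_epoch = iteration // train_iters_per_epoch      # -> 2
start_iteration = iteration % train_iters_per_epoch   # -> 50
assert (start_epoch, start_iteration) == (2, 50)
```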
tasks/glue/data.py ADDED
@@ -0,0 +1,69 @@
+ # coding=utf-8
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """GLUE dataset."""
+
+ from abc import ABC
+ from abc import abstractmethod
+
+ from torch.utils.data import Dataset
+
+ from megatron import print_rank_0
+ from tasks.data_utils import build_sample
+ from tasks.data_utils import build_tokens_types_paddings_from_text
+
+
+ class GLUEAbstractDataset(ABC, Dataset):
+     """GLUE base dataset class."""
+
+     def __init__(self, task_name, dataset_name, datapaths,
+                  tokenizer, max_seq_length):
+         # Store inputs.
+         self.task_name = task_name
+         self.dataset_name = dataset_name
+         self.tokenizer = tokenizer
+         self.max_seq_length = max_seq_length
+         print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
+                                                              self.dataset_name))
+         # Process the files.
+         string = '  > paths:'
+         for path in datapaths:
+             string += ' ' + path
+         print_rank_0(string)
+         self.samples = []
+         for datapath in datapaths:
+             self.samples.extend(self.process_samples_from_single_path(datapath))
+         print_rank_0('  >> total number of samples: {}'.format(
+             len(self.samples)))
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, idx):
+         raw_sample = self.samples[idx]
+         ids, types, paddings = build_tokens_types_paddings_from_text(
+             raw_sample['text_a'], raw_sample['text_b'],
+             self.tokenizer, self.max_seq_length)
+         sample = build_sample(ids, types, paddings,
+                               raw_sample['label'], raw_sample['uid'])
+         return sample
+
+     @abstractmethod
+     def process_samples_from_single_path(self, datapath):
+         """Abstract method that takes a single path / filename and
+         returns a list of dataset samples, each sample being a dict of
+             {'text_a': string, 'text_b': string, 'label': int, 'uid': int}
+         """
+         pass
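
A minimal sketch of a concrete subclass (hypothetical data, not part of the repo) showing the sample shape `process_samples_from_single_path` is expected to return:

```python
class ToyGLUEDataset(GLUEAbstractDataset):
    """Toy dataset returning a single hard-coded sample."""

    def process_samples_from_single_path(self, datapath):
        # Each sample is a dict of text_a/text_b strings plus an integer
        # label and uid, as documented in the abstract method above.
        return [{'text_a': 'a man is eating food.',
                 'text_b': 'a person eats.',
                 'label': 1,
                 'uid': 0}]
```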
tasks/glue/finetune.py ADDED
@@ -0,0 +1,93 @@
+ # coding=utf-8
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """GLUE finetuning/evaluation."""
+
+ from megatron import get_args
+ from megatron import print_rank_0
+ from megatron import get_tokenizer
+ from megatron import mpu
+ from megatron.model.classification import Classification
+ from tasks.eval_utils import accuracy_func_provider
+ from tasks.finetune_utils import finetune
+
+
+ def glue_classification(num_classes, Dataset,
+                         name_from_datapath_func):
+
+     def train_valid_datasets_provider():
+         """Build train and validation datasets."""
+         args = get_args()
+         tokenizer = get_tokenizer()
+
+         train_dataset = Dataset('training', args.train_data,
+                                 tokenizer, args.seq_length)
+         valid_dataset = Dataset('validation', args.valid_data,
+                                 tokenizer, args.seq_length)
+
+         return train_dataset, valid_dataset
+
+     def model_provider(pre_process=True, post_process=True):
+         """Build the model."""
+         args = get_args()
+
+         print_rank_0('building classification model for {} ...'.format(
+             args.task))
+         model = Classification(num_classes=num_classes, num_tokentypes=2,
+                                pre_process=pre_process, post_process=post_process)
+
+         return model
+
+     def metrics_func_provider():
+         """Provide metrics callback function."""
+         def single_dataset_provider(datapath):
+             args = get_args()
+             tokenizer = get_tokenizer()
+
+             name = name_from_datapath_func(datapath)
+             return Dataset(name, [datapath], tokenizer, args.seq_length)
+         return accuracy_func_provider(single_dataset_provider)
+
+     """Finetune/evaluate."""
+     finetune(train_valid_datasets_provider, model_provider,
+              end_of_epoch_callback_provider=metrics_func_provider)
+
+
+ def main():
+     args = get_args()
+
+     if args.task == 'MNLI':
+
+         num_classes = 3
+         from tasks.glue.mnli import MNLIDataset as Dataset
+
+         def name_from_datapath(datapath):
+             return datapath.split('MNLI')[-1].strip(
+                 '.tsv').strip('/').replace('_', '-')
+
+     elif args.task == 'QQP':
+
+         num_classes = 2
+         from tasks.glue.qqp import QQPDataset as Dataset
+
+         def name_from_datapath(datapath):
+             return datapath.split('QQP')[-1].strip(
+                 '.tsv').strip('/').replace('_', '-')
+
+     else:
+         raise NotImplementedError('GLUE task {} is not implemented.'.format(
+             args.task))
+
+     glue_classification(num_classes, Dataset, name_from_datapath)
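
For reference, the `name_from_datapath` string manipulation above resolves like this (hypothetical path):

```python
datapath = '/data/glue/MNLI/dev_matched.tsv'
# split on the task name, strip the extension characters and slashes,
# then convert underscores to dashes:
name = datapath.split('MNLI')[-1].strip('.tsv').strip('/').replace('_', '-')
assert name == 'dev-matched'
```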
tasks/glue/mnli.py ADDED
@@ -0,0 +1,84 @@
+ # coding=utf-8
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """MNLI dataset."""
+
+ from megatron import print_rank_0
+ from tasks.data_utils import clean_text
+ from .data import GLUEAbstractDataset
+
+
+ LABELS = {'contradiction': 0, 'entailment': 1, 'neutral': 2}
+
+
+ class MNLIDataset(GLUEAbstractDataset):
+
+     def __init__(self, name, datapaths, tokenizer, max_seq_length,
+                  test_label='contradiction'):
+         self.test_label = test_label
+         super().__init__('MNLI', name, datapaths,
+                          tokenizer, max_seq_length)
+
+     def process_samples_from_single_path(self, filename):
+         """Implement abstract method."""
+         print_rank_0(' > Processing {} ...'.format(filename))
+
+         samples = []
+         total = 0
+         first = True
+         is_test = False
+         with open(filename, 'r') as f:
+             for line in f:
+                 row = line.strip().split('\t')
+                 if first:
+                     first = False
+                     if len(row) == 10:
+                         is_test = True
+                         print_rank_0(
+                             '   reading {}, {} and {} columns and setting '
+                             'labels to {}'.format(
+                                 row[0].strip(), row[8].strip(),
+                                 row[9].strip(), self.test_label))
+                     else:
+                         print_rank_0('   reading {}, {}, {}, and {} columns '
+                                      '...'.format(
+                                          row[0].strip(), row[8].strip(),
+                                          row[9].strip(), row[-1].strip()))
+                     continue
+
+                 text_a = clean_text(row[8].strip())
+                 text_b = clean_text(row[9].strip())
+                 unique_id = int(row[0].strip())
+                 label = row[-1].strip()
+                 if is_test:
+                     label = self.test_label
+
+                 assert len(text_a) > 0
+                 assert len(text_b) > 0
+                 assert label in LABELS
+                 assert unique_id >= 0
+
+                 sample = {'text_a': text_a,
+                           'text_b': text_b,
+                           'label': LABELS[label],
+                           'uid': unique_id}
+                 total += 1
+                 samples.append(sample)
+
+                 if total % 50000 == 0:
+                     print_rank_0('  > processed {} so far ...'.format(total))
+
+         print_rank_0(' >> processed {} samples.'.format(len(samples)))
+         return samples
tasks/glue/qqp.py ADDED
@@ -0,0 +1,101 @@
+ # coding=utf-8
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """QQP dataset."""
+
+ from megatron import print_rank_0
+ from tasks.data_utils import clean_text
+ from .data import GLUEAbstractDataset
+
+
+ LABELS = [0, 1]
+
+
+ class QQPDataset(GLUEAbstractDataset):
+
+     def __init__(self, name, datapaths, tokenizer, max_seq_length,
+                  test_label=0):
+         self.test_label = test_label
+         super().__init__('QQP', name, datapaths,
+                          tokenizer, max_seq_length)
+
+     def process_samples_from_single_path(self, filename):
+         """Implement abstract method."""
+         print_rank_0(' > Processing {} ...'.format(filename))
+
+         samples = []
+         total = 0
+         first = True
+         is_test = False
+         with open(filename, 'r') as f:
+             for line in f:
+                 row = line.strip().split('\t')
+                 if first:
+                     first = False
+                     if len(row) == 3:
+                         is_test = True
+                         print_rank_0('   reading {}, {}, and {} columns and '
+                                      'setting labels to {}'.format(
+                                          row[0].strip(), row[1].strip(),
+                                          row[2].strip(), self.test_label))
+                     else:
+                         assert len(row) == 6
+                         print_rank_0('   reading {}, {}, {}, and {} columns'
+                                      ' ...'.format(
+                                          row[0].strip(), row[3].strip(),
+                                          row[4].strip(), row[5].strip()))
+                     continue
+
+                 if is_test:
+                     assert len(row) == 3, 'expected length 3: {}'.format(row)
+                     uid = int(row[0].strip())
+                     text_a = clean_text(row[1].strip())
+                     text_b = clean_text(row[2].strip())
+                     label = self.test_label
+                     assert len(text_a) > 0
+                     assert len(text_b) > 0
+                 else:
+                     if len(row) == 6:
+                         uid = int(row[0].strip())
+                         text_a = clean_text(row[3].strip())
+                         text_b = clean_text(row[4].strip())
+                         label = int(row[5].strip())
+                     else:
+                         print_rank_0('***WARNING*** index error, '
+                                      'skipping: {}'.format(row))
+                         continue
+                     if len(text_a) == 0:
+                         print_rank_0('***WARNING*** zero length a, '
+                                      'skipping: {}'.format(row))
+                         continue
+                     if len(text_b) == 0:
+                         print_rank_0('***WARNING*** zero length b, '
+                                      'skipping: {}'.format(row))
+                         continue
+                     assert label in LABELS
+                     assert uid >= 0
+
+                 sample = {'uid': uid,
+                           'text_a': text_a,
+                           'text_b': text_b,
+                           'label': label}
+                 total += 1
+                 samples.append(sample)
+
+                 if total % 50000 == 0:
+                     print_rank_0('  > processed {} so far ...'.format(total))
+
+         print_rank_0(' >> processed {} samples.'.format(len(samples)))
+         return samples
tasks/label_dict.py ADDED
@@ -0,0 +1,73 @@
+
+ AFQMC_LABELS = {
+     '0': '0',
+     '1': '1',
+ }
+
+ CSL_LABELS = {
+     '0': '0',
+     '1': '1',
+     '2': '2',
+ }
+
+ IFLYTEK_LABELS = {}
+ for i in range(119):
+     IFLYTEK_LABELS[str(i)] = str(i)
+
+ OCNLI_LABELS = {
+     'contradiction': '0',
+     'entailment': '1',
+     'neutral': '2'
+ }
+
+ CMNLI_LABELS = {
+     'contradiction': '0',
+     'entailment': '1',
+     'neutral': '2'
+ }
+
+ TNEWS_LABELS = {}
+ tnews_list = []
+ for i in range(17):
+     if i == 5 or i == 11:
+         continue
+     tnews_list.append(i)
+ for i in range(len(tnews_list)):
+     TNEWS_LABELS[str(100 + tnews_list[i])] = str(i)
+
+ WSC_LABELS = {
+     'true': '0',
+     'false': '1',
+ }
+
+ ZC_LABELS = {
+     'negative': '0',
+     'positive': '1',
+ }
+
+ def get_label_dict(task_name, write2file=False):
+
+     if task_name == "AFQMC":
+         label_dict = AFQMC_LABELS
+     elif task_name == "CSL":
+         label_dict = CSL_LABELS
+     elif task_name == "IFLYTEK":
+         label_dict = IFLYTEK_LABELS
+     elif task_name == "OCNLI":
+         label_dict = OCNLI_LABELS
+     elif task_name == "TNEWS":
+         label_dict = TNEWS_LABELS
+     elif task_name == "WSC":
+         label_dict = WSC_LABELS
+     elif task_name == "CMNLI":
+         label_dict = CMNLI_LABELS
+     elif task_name == "ZC":
+         label_dict = ZC_LABELS
+     else:
+         raise NotImplementedError('task {} is not implemented.'.format(task_name))
+
+     if write2file:
+         # Invert the mapping when writing predictions back to file.
+         label_dict = {v: k for k, v in label_dict.items()}
+
+     return label_dict
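
Example usage, following the tables defined above. `write2file=True` inverts the mapping, so model predictions can be written back as the original CLUE label ids:

```python
labels = get_label_dict('TNEWS')
# TNEWS skips ids 105 and 111, so the 15 remaining ids are renumbered 0-14.
assert labels['100'] == '0' and labels['116'] == '14'

inv = get_label_dict('TNEWS', write2file=True)
assert inv['0'] == '100'
```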
tasks/main.py ADDED
@@ -0,0 +1,121 @@
+ # coding=utf-8
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Main tasks functionality."""
+
+ import os
+ import sys
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
+                                              os.path.pardir)))
+
+ from megatron import get_args
+ from megatron.initialize import initialize_megatron
+
+
+ def get_tasks_args(parser):
+     """Provide extra arguments required for tasks."""
+     group = parser.add_argument_group(title='tasks')
+
+     group.add_argument('--task', type=str, required=True,
+                        help='Task name.')
+     group.add_argument('--epochs', type=int, default=None,
+                        help='Number of finetuning epochs. Zero results in '
+                        'evaluation only.')
+     group.add_argument('--pretrained-checkpoint', type=str, default=None,
+                        help='Pretrained checkpoint used for finetuning.')
+     group.add_argument('--keep-last', action='store_true',
+                        help='Keep the last batch (maybe incomplete) in '
+                        'the data loader.')
+     group.add_argument('--train-data', nargs='+', default=None,
+                        help='Whitespace separated paths or corpora names '
+                        'for training.')
+     group.add_argument('--valid-data', nargs='*', default=None,
+                        help='path(s) to the validation data.')
+     group.add_argument('--test-data', nargs='*', default=None,
+                        help='path(s) to the test data.')
+     group.add_argument('--res-path', nargs='*', default=None,
+                        help='path(s) to the test result.')
+     group.add_argument('--overlapping-eval', type=int, default=32,
+                        help='Sliding window for overlapping evaluation.')
+     group.add_argument('--strict-lambada', action='store_true',
+                        help='Use more difficult formulation of lambada.')
+     # Retriever args
+     group.add_argument('--qa-data-dev', type=str, default=None,
+                        help='Path to the QA dataset dev file.')
+     group.add_argument('--qa-data-test', type=str, default=None,
+                        help='Path to the QA dataset test file.')
+
+     # Faiss arguments for retriever
+     group.add_argument('--faiss-use-gpu', action='store_true',
+                        help='Whether to create the FaissMIPSIndex on GPU.')
+     group.add_argument('--faiss-match', type=str, default='string',
+                        choices=['regex', 'string'],
+                        help='Answer matching logic type.')
+     group.add_argument('--faiss-topk-retrievals', type=int, default=100,
+                        help='Number of blocks to use as top-k during retrieval.')
+
+     # finetune for retriever
+     group.add_argument('--eval-micro-batch-size', type=int, default=None,
+                        help='Eval batch size per model instance (local batch '
+                        'size). Global batch size is local batch size '
+                        'times data parallel size.')
+     group.add_argument('--train-with-neg', action='store_true',
+                        help='Whether to use negative examples during model '
+                        'training.')
+     group.add_argument('--train-hard-neg', type=int, default=0,
+                        help='Number of hard negative examples to use during '
+                        'training.')
+
+     # parameters for the Av.rank validation method
+     # The following options/arguments have been taken directly from the DPR codebase.
+     group.add_argument('--val-av-rank-hard-neg', type=int, default=30,
+                        help='Av.rank validation: how many hard negatives to'
+                        ' take from each question pool.')
+     group.add_argument('--val-av-rank-other-neg', type=int, default=30,
+                        help='Av.rank validation: how many other negatives to'
+                        ' take from each question pool.')
+
+     return parser
+
+
+ if __name__ == '__main__':
+
+     initialize_megatron(extra_args_provider=get_tasks_args)
+
+     args = get_args()
+
+     if args.num_layers_per_virtual_pipeline_stage is not None:
+         print("Interleaved pipeline schedule is not yet supported for downstream tasks.")
+         exit()
+
+     if args.task == 'RACE':
+         from race.finetune import main
+     elif args.task in ['MNLI', 'QQP']:
+         from glue.finetune import main
+     elif args.task in ['AFQMC', 'CSL', 'IFLYTEK', 'OCNLI', 'TNEWS', 'WSC', 'CMNLI', 'ZC']:
+         from clue.finetune import main
+     elif args.task in ['LAMBADA', 'WIKITEXT103']:
+         from zeroshot_gpt.evaluate import main
+     elif args.task in ['ICT-ZEROSHOT-NQ', 'RETRIEVER-EVAL']:
+         from orqa.evaluate_orqa import main
+     elif args.task in ['RET-FINETUNE-NQ']:
+         from orqa.supervised.finetune import main
+     else:
+         raise NotImplementedError('Task {} is not implemented.'.format(
+             args.task))
+
+     main()
tasks/msdp/README.md ADDED
@@ -0,0 +1,19 @@
+
+ # Multi-Stage Prompting for Knowledgeable Dialogue Generation
+
+ Below we present the steps to run our multi-stage dialogue prompting (MSDP) framework.
+
+ ## Multi-Stage Dialogue Prompting
+
+ ### Data Preparation
+ 1. Dataset Download: [Wizard of Wikipedia](https://parl.ai/projects/wizard_of_wikipedia/) and [Wizard of Internet](https://parl.ai/projects/sea/)
+ 2. Data Processing: We provide the script to run the [`data processing`](../../examples/msdp/data_processing.sh) of the datasets.
+
+ ### Stage-1: Prompting for Knowledge Generation
+ 1. We provide the script to perform the [`first-stage prompting`](../../examples/msdp/prompt_knwl_gen.sh) for the knowledge generation.
+ 2. We provide the [`evaluation script`](../../examples/msdp/eval_knwl_generation.sh) for the automatic evaluation (i.e., F1, BLEU, METEOR, and ROUGE-L) of the knowledge generation.
+
+ ### Stage-2: Prompting for Response Generation
+ 1. We provide the script to [`prepare the input file`](../../examples/msdp/prep_resp_gen.sh) for the response generation (based on the previously generated knowledge file).
+ 2. We provide the script to perform the [`second-stage prompting`](../../examples/msdp/prompt_resp_gen.sh) for the response generation.
+ 3. We provide the [`evaluation script`](../../examples/msdp/eval_resp_generation.sh) for the automatic evaluation (i.e., F1, KF1, BLEU, METEOR, and ROUGE-L) of the response generation.
tasks/msdp/evaluate.py ADDED
@@ -0,0 +1,58 @@
+ # coding=utf-8
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Model evaluation"""
+
+ from megatron import get_args
+ from megatron import print_rank_0
+ from tasks.msdp.metrics import F1Metric
+ from tqdm import tqdm
+
+
+ def evaluate_f1(guess_file, answer_file):
+     """Evaluating F1 Score"""
+
+     guess_list = []
+     print_rank_0('reading %s' % guess_file)
+     with open(guess_file, "r") as f:
+         for i, line in enumerate(tqdm(f)):
+             line = line.strip()
+             if "<|endoftext|>" in line:
+                 line = line.replace("<|endoftext|>", "")
+             guess_list.append(line)
+
+     answer_list = []
+     print_rank_0('reading %s' % answer_file)
+     with open(answer_file, "r") as f:
+         for i, line in enumerate(tqdm(f)):
+             line = line.strip()
+             if line == "no_passages_used":
+                 line = ""
+             answer_list.append(line)
+
+     assert len(guess_list) == len(answer_list), \
+         "lengths of guess and answer are different!"
+
+     precision, recall, f1 = F1Metric.compute_all_pairs(guess_list, answer_list)
+     print_rank_0('Precision: %.4f; recall: %.4f; f1: %.4f' % (precision, recall, f1))
+
+     print_rank_0('done :-)')
+
+
+ def main():
+     args = get_args()
+
+     evaluate_f1(args.guess_file, args.answer_file)
+
tasks/msdp/main.py ADDED
@@ -0,0 +1,79 @@
+ # coding=utf-8
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Run multi-stage dialogue prompting (MSDP)."""
+
+ import os
+ import sys
+ sys.path.append(os.path.abspath(os.path.join(
+     os.path.join(os.path.dirname(__file__), os.path.pardir), os.path.pardir)))
+ from megatron import get_args
+ from megatron.initialize import initialize_megatron
+
+
+ def get_tasks_args(parser):
+     """Provide extra arguments required for tasks."""
+     group = parser.add_argument_group(title='tasks')
+
+     # parameters for the knowledgeable dialogue generation
+     group.add_argument('--task', type=str, required=True,
+                        help='Task name.')
+     group.add_argument("--sample-input-file", type=str, default=None,
+                        help='Get input from file instead of interactive mode, '
+                        'each line is an input.')
+     group.add_argument("--sample-output-file", type=str, default=None,
+                        help='Output file for results generated from '
+                        '--sample-input-file.')
+     group.add_argument('--prompt-file', type=str, default=None,
+                        help='prompting file')
+     group.add_argument('--prompt-type', type=str, default=None,
+                        choices=['knowledge', 'response'],
+                        help='prompt type (knowledge or response)')
+     group.add_argument('--num-prompt-examples', type=int, default=10,
+                        help='number of prompt examples')
+     group.add_argument('--guess-file', type=str, default=None,
+                        help='datapath for generated sentences')
+     group.add_argument('--answer-file', type=str, default=None,
+                        help='datapath for golden sentences')
+     group.add_argument('--out-seq-length', type=int, default=100,
+                        help='output sequence length')
+     group.add_argument('--api-prompt', default=False, action="store_true",
+                        help='setup model api for prompting')
+     group.add_argument('--megatron-api-url', type=str, default=None,
+                        help='url of the megatron api')
+
+     return parser
+
+
+ if __name__ == '__main__':
+
+     initialize_megatron(extra_args_provider=get_tasks_args)
+
+     args = get_args()
+
+     if args.num_layers_per_virtual_pipeline_stage is not None:
+         print("Interleaved pipeline schedule is not yet supported for downstream tasks.")
+         exit()
+
+     if args.task == 'MSDP-PROMPT':
+         from tasks.msdp.prompt import main
+
+     elif args.task == 'MSDP-EVAL-F1':
+         from tasks.msdp.evaluate import main
+
+     else:
+         raise NotImplementedError('Task {} is not implemented.'.format(
+             args.task))
+
+     main()
tasks/msdp/metrics.py ADDED
@@ -0,0 +1,77 @@
+
+ # The following code is adapted from
+ # https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/metrics.py,
+ # which is licensed under the MIT license. More details on the license can be
+ # found at https://github.com/facebookresearch/ParlAI/blob/master/LICENSE.
+
+ """Provides standard metric evaluations for dialog."""
+
+ from collections import Counter
+ from typing import List
+ import numpy as np
+ import re
+
+ re_art = re.compile(r'\b(a|an|the)\b')
+ re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')
+
+
+ def normalize_answer(s):
+     """
+     Lower text and remove punctuation, articles and extra whitespace.
+     """
+     s = s.lower()
+     s = re_punc.sub(' ', s)
+     s = re_art.sub(' ', s)
+     s = ' '.join(s.split())
+     return s
+
+
+ class F1Metric:
+     """
+     Helper class which computes token-level F1.
+     """
+
+     @staticmethod
+     def _prec_recall_f1_score(pred_items, gold_items):
+         """
+         Compute precision, recall and f1 given a set of gold and prediction items.
+         :param pred_items: iterable of predicted values
+         :param gold_items: iterable of gold values
+         :return: tuple (p, r, f1) for precision, recall, f1
+         """
+         common = Counter(gold_items) & Counter(pred_items)
+         num_same = sum(common.values())
+         if num_same == 0:
+             return 0, 0, 0
+         precision = 1.0 * num_same / len(pred_items)
+         recall = 1.0 * num_same / len(gold_items)
+         f1 = (2 * precision * recall) / (precision + recall)
+         return precision, recall, f1
+
+     @staticmethod
+     def compute_each_pair(guess: str, answer: str):
+         if answer == "":
+             return None, None, None
+         if guess == "":
+             return 0, 0, 0
+         g_tokens = normalize_answer(guess).split()
+         a_tokens = normalize_answer(answer).split()
+
+         precision, recall, f1 = F1Metric._prec_recall_f1_score(g_tokens, a_tokens)
+         return precision, recall, f1
+
+     @staticmethod
+     def compute_all_pairs(guesses: List[str], answers: List[str]):
+         # additional argument check:
+         assert len(guesses) == len(answers)
+
+         precision_list, recall_list, f1_list = [], [], []
+         for guess, answer in zip(guesses, answers):
+             precision, recall, f1 = F1Metric.compute_each_pair(guess, answer)
+             if precision is None or recall is None or f1 is None:
+                 continue
+             precision_list.append(precision)
+             recall_list.append(recall)
+             f1_list.append(f1)
+
+         return np.mean(precision_list), np.mean(recall_list), np.mean(f1_list)
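
A worked example of the token-level F1 computed by `F1Metric` above (hypothetical strings):

```python
# guess tokens {hello, world} vs. answer tokens {hello, there, world}:
# precision = 2/2 = 1.0, recall = 2/3, f1 = 2*1.0*(2/3)/(1.0 + 2/3) = 0.8
p, r, f1 = F1Metric.compute_all_pairs(['hello world'], ['hello there world'])
assert abs(f1 - 0.8) < 1e-6
```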
tasks/msdp/preprocessing.py ADDED
@@ -0,0 +1,595 @@
+ # coding=utf-8
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets"""
+
+ import torch
+ import argparse
+ from nltk import word_tokenize
+ from tqdm import tqdm
+ import numpy as np
+ import json
+
+ def get_args():
+     parser = argparse.ArgumentParser(description="Preprocessing")
+
+     parser.add_argument("--func", type=str, default=None,
+                         help="choose to run which function")
+     parser.add_argument("--raw_file", type=str, default=None,
+                         help="path of the input file")
+     parser.add_argument("--processed_file", type=str, default=None,
+                         help="path of the output file")
+     parser.add_argument("--knwl_ref_file", type=str, default=None,
+                         help="path of the knowledge reference file")
+     parser.add_argument("--resp_ref_file", type=str, default=None,
+                         help="path of the response reference file")
+     parser.add_argument("--knwl_gen_file", type=str, default=None,
+                         help="path of the generated knowledge file")
+     parser.add_argument("--test_file", type=str, default=None,
+                         help="path of the test file")
+     parser.add_argument("--train_file", type=str, default=None,
+                         help="path of the train file")
+     parser.add_argument("--model_file", type=str, default=None,
+                         help="path of the model file")
+     parser.add_argument("--data_type", type=str, default=None,
+                         help="data types, choose one out of three types: \
+                               wow_seen, wow_unseen, and woi")
+     parser.add_argument("--seed", type=int, default=1234,
+                         help="random seed")
+
+     args = parser.parse_args()
+     return args
+
+
+ def process_wow_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
+     """
+     Process the Wizard of Wikipedia (wow) dataset.
+     Expected processed format:
+     topic \t dialogue context \t golden knowledge \t golden response
+     """
+
+     # load the raw data
+     print("> Loading data from %s" % raw_file)
+     with open(raw_file, "r") as fr:
+         dialog_data = json.load(fr)
+
+     print("> Processing data ...")
+     fproc = open(processed_file, "w")
+     fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
+     fresp = open(resp_ref_file, "w") if resp_ref_file else None
+
+     for i, sample in enumerate(tqdm(dialog_data)):
+         # get all the dialog data for a single dialog sample
+         dialog = sample["dialog"]
+
+         turn_list = []  # collect the dialog history
+         # processing for each single dialog sample
+         for j, turn in enumerate(dialog):
+             # text of each turn
+             text = turn["text"]
+             if not (text.endswith("?") or text.endswith(".") or text.endswith("!")):
+                 text = text + "."
+
+             if j == 0:
+                 # first turn
+                 turn_list.append(text)
+                 continue
+
+             speaker = turn["speaker"].lower()
+             if "wizard" in speaker:
+                 checked_sentence = list(turn["checked_sentence"].values())  # knowledge
+                 checked_passage = list(turn["checked_passage"].values())  # topic
+
+                 assert len(checked_sentence) <= 1
+
+                 # get the ground truth knowledge
+                 if len(checked_sentence) > 0:
+                     checked_sentence = checked_sentence[0]
+                 else:
+                     checked_sentence = "no_passages_used"
+
+                 if len(checked_passage) == 1:
+                     checked_passage = checked_passage[0]
+                 else:
+                     checked_passage = "no_passages_used"
+
+                 # get the topic
+                 if checked_passage != "no_passages_used":
+                     topic = checked_passage
+                 else:
+                     topic = sample["chosen_topic"]
+
+                 dialog_context = " [SEP] ".join(turn_list)
+                 knowledge = checked_sentence
+                 response = text
+                 # add the response into the dialog history
+                 turn_list.append(response)
+
+                 # write to the output files
+                 fproc.write(topic + "\t" + dialog_context + "\t" + \
+                             knowledge + "\t" + response + "\n")
+
+                 if fknwl:
+                     fknwl.write(knowledge + "\n")
+                 if fresp:
+                     # tokenize for evaluation
+                     response = " ".join(word_tokenize(response))
+                     fresp.write(response + "\n")
+
+             else:
+                 assert "apprentice" in speaker
+                 turn_list.append(text)
+
+     fproc.close()
+     if fknwl:
+         fknwl.close()
+     if fresp:
+         fresp.close()
+
+
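The processed file written above holds one example per line in the `topic \t context \t knowledge \t response` layout; a minimal sketch with hypothetical values:

```python
line = ("Science fiction\tI love sci-fi movies. [SEP] Me too!\t"
        "Science fiction is a genre of speculative fiction.\tIt really is.")
topic, dialog_context, knowledge, response = line.split("\t")
assert topic == "Science fiction"
```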
142
+ def process_woi_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
143
+ """
144
+ This is a function used for processing the wizard of internet (woi) dataset
145
+ Expected processed format:
146
+ topic \t dialogue context \t golden knowledge \t golden response
147
+ """
148
+
149
+ print("> Processing %s" % raw_file)
150
+ fproc = open(processed_file, "w")
151
+ fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
152
+ fresp = open(resp_ref_file, "w") if resp_ref_file else None
153
+
154
+ with open(raw_file, "r") as fr:
155
+ for i, line in tqdm(enumerate(fr)):
156
+ # read line by line, each line uses json format
157
+ line = line.strip()
158
+ item_dict = json.loads(line)
159
+
160
+ # item_dict is a dictionary
161
+ # its key is the data id, and its value contains all the data content
162
+ item_dict = item_dict.values()
163
+ item_dict = list(item_dict)[0] # len(item_dict) == 1
164
+
165
+ # get the whole dialog data for a single dialog sample
166
+ dialog_data = item_dict['dialog_history']
167
+ length = len(dialog_data)
168
+
169
+ turn_list = [] # collect the dialog history
170
+ search_text = ""
171
+ for i in range(length):
172
+ item = dialog_data[i]
173
+ action = item['action']
174
+
175
+ if action == "Wizard => SearchAgent":
176
+ search_text = item['text']
177
+
178
+ elif action == "Wizard => Apprentice":
179
+ if len(turn_list) == 0:
180
+ # first turn
181
+ turn = item['text']
182
+ turn_list.append(turn)
183
+ continue
184
+
185
+ # get the relevant content
186
+ contents = item["context"]["contents"]
187
+ selects = item["context"]["selected_contents"]
188
+ flag = selects[0][0]
189
+ selects = selects[1:]
190
+ assert len(selects) == len(contents)
191
+
192
+ # get the topic
193
+ if flag:
194
+ # no knowledge sentence is used for the response
195
+ topic = "no_topic"
196
+ knwl_sent = "no_passages_used"
197
+ else:
198
+ # we consider the search text as the topic
199
+ topic = search_text
200
+ # get the knowledge sentence
201
+ knwl_sent = ""
202
+ for content, select in zip(contents, selects):
203
+ content = content['content']
204
+ assert len(content) == len(select)
205
+ for c, s in zip(content, select):
206
+ if s:
207
+ knwl_sent = c
208
+ break
209
+
210
+ if knwl_sent == "":
211
+ # no knowledge is used for the response
212
+ topic = "no_topic"
213
+ knwl_sent = "no_passages_used"
214
+
215
+ # get dialogue context, knowledge, and response
216
+ dialog_context = " [SEP] ".join(turn_list)
217
+ response = item['text']
218
+
219
+ # processing
220
+ topic = topic.replace("\n", "").replace("\r", \
221
+ "").replace("\t", "")
222
+ dialog_context = dialog_context.replace("\n", "").replace("\r", \
223
+ "").replace("\t", "")
224
+ knwl_sent = knwl_sent.replace("\n", "").replace("\r", \
225
+ "").replace("\t", "")
226
+ response = response.replace("\n", "").replace("\r", \
227
+ "").replace("\t", "")
228
+
229
+ if topic != "no_topic":
230
+ # write to the ouput files
231
+ fproc.write(topic + "\t" + dialog_context + "\t" + \
232
+ knwl_sent + "\t" + response + "\n")
233
+ if fknwl:
234
+ fknwl.write(knwl_sent + "\n")
235
+ if fresp:
236
+ # tokenize for evaluation
237
+ response = " ".join(word_tokenize(response))
238
+ fresp.write(response + "\n")
239
+
240
+ turn_list.append(response)
241
+
242
+ elif action == "Apprentice => Wizard":
243
+ turn = item['text']
244
+ turn_list.append(turn)
245
+
246
+ else:
247
+ assert action == "SearchAgent => Wizard", \
248
+ "Please check whether you have used the correct data!"
249
+
250
+ fproc.close()
251
+ if fknwl:
252
+ fknwl.close()
253
+ if fresp:
254
+ fresp.close()
255
+
256
+
257
+ def get_database(test_datapath, train_datapath, data_type):
258
+ """Get the database by topics"""
259
+
260
+ assert data_type in ["wow_seen", "wow_unseen", "woi"], \
261
+ "Please input a correct data type!!"
262
+
263
+ # get test data topic dictionary
264
+ print("> reading test data from %s" % test_datapath)
265
+ test_topics = {}
266
+ with open(test_datapath, "r") as f:
267
+ for i, line in enumerate(f):
268
+ line = line.strip()
269
+ splits = line.split("\t")
270
+ topic = splits[0]
271
+ test_topics[topic] = True
272
+
273
+ print("> reading data from %s" % train_datapath)
274
+ train_data_by_topic = {}
275
+ dialog_data_by_topic = {}
276
+ dialog_examples = []
277
+ with open(train_datapath, "r") as f:
278
+ for i, line in enumerate(f):
279
+ line = line.strip()
280
+ splits = line.split("\t")
281
+ topic = splits[0]
282
+ turns = splits[1].split(" [SEP] ")[-3:]
283
+ knowledge = splits[2]
284
+ response = splits[3]
285
+ # filtering data samples
286
+ if knowledge == "no_passages_used":
287
+ # when no knowledge is used
288
+ continue
289
+ if data_type != "wow_seen" and ("(" in knowledge or ")" in knowledge):
290
+ # when bracket exists in the knowledge
291
+ continue
292
+ if data_type != "wow_seen" and topic not in knowledge:
293
+ # when topic does not exist in the knowledge
294
+ continue
295
+
296
+ # get the instance
297
+ last_turn = turns[-1]
298
+ instance = "( " + last_turn + " ) " + topic + " => " + knowledge
299
+
300
+ # construct dialog example
301
+ dialog_example = ""
302
+ if data_type != "wow_seen":
303
+ dialog_example += "( " + topic + " ) "
304
+ for i, turn in enumerate(turns):
305
+ if i != 0:
306
+ dialog_example += " "
307
+ dialog_example += turn
308
+
309
+ # check overlaps
310
+ if topic in test_topics:
311
+ if topic not in train_data_by_topic:
312
+ train_data_by_topic[topic] = [instance]
313
+ else:
314
+ train_data_by_topic[topic].append(instance)
315
+
316
+ if topic not in dialog_data_by_topic:
317
+ dialog_data_by_topic[topic] = [dialog_example]
318
+ else:
319
+ dialog_data_by_topic[topic].append(dialog_example)
320
+
321
+ else:
322
+ # filtering data samples
323
+ if len(knowledge.split()) > 20:
324
+ # knowledge is too long
325
+ continue
326
+ if knowledge.startswith("It") or knowledge.startswith("it") or \
327
+ knowledge.startswith("This") or knowledge.startswith("this"):
328
+ continue
329
+
330
+ # append all the data into dialogue examples list
331
+ dialog_examples.append((topic, dialog_example, instance))
332
+
333
+ return train_data_by_topic, dialog_data_by_topic, dialog_examples
334
+
335
+
336
+ emb_dict = {}
337
+ def select_prompts_based_on_similarity(
338
+ query, dialog_list, prompt_list, topic, tokenizer, encoder, topk):
339
+ """Select samples based on the similarity"""
340
+
341
+ with torch.no_grad():
342
+ # get the query embeddings
343
+ query_ids = tokenizer.encode(query)
344
+ query_ids = torch.LongTensor([query_ids]).cuda()
345
+ query_emb = encoder(input_ids=query_ids).pooler_output
346
+ query_emb = query_emb[0]
347
+
348
+ # calculate embeddings for the samples in the database
349
+ if topic in emb_dict:
350
+ example_embeddings = emb_dict[topic]
351
+ example_embeddings = example_embeddings.cuda()
352
+ else:
353
+ for idx, example in enumerate(dialog_list):
354
+ example_ids = tokenizer.encode(example)
355
+ example_ids = torch.LongTensor([example_ids]).cuda()
356
+ example_emb = encoder(input_ids=example_ids).pooler_output
357
+ if idx == 0:
358
+ example_embeddings = example_emb
359
+ else:
360
+ example_embeddings = torch.cat(
361
+ (example_embeddings, example_emb), dim=0)
362
+ emb_dict[topic] = example_embeddings.cpu()
363
+
364
+ # compare the similarity and select the topk samples
365
+ similarity_list = example_embeddings.matmul(query_emb)
366
+ _, indices = torch.topk(similarity_list, k=topk)
367
+
368
+ indices = indices.tolist()
369
+ indices = indices[::-1] # reverse the order
370
+ selected_prompts = []
371
+ for index in indices:
372
+ # index = index.item()
373
+ selected_prompts.append(prompt_list[index])
374
+
375
+ return selected_prompts
376
+
377
+
378
+ def prompt_selection_for_knowledge_generation(
379
+ test_datapath, train_datapath, model_path, output_prompt_path, data_type):
380
+ """Selecting prompts for the knowledge generation"""
381
+
382
+ print("> Selecting prompts for the knowledge generation")
383
+
384
+ train_data_by_topic, dialog_data_by_topic, dialog_examples = \
385
+ get_database(test_datapath, train_datapath, data_type)
386
+
387
+ from transformers import DPRQuestionEncoderTokenizer
388
+ print("> loading tokenizer and encoder")
389
+ tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
390
+ 'facebook/dpr-question_encoder-single-nq-base')
391
+ encoder = torch.load(model_path).cuda()
392
+
393
+ print("> getting dialog embeddings")
394
+ with torch.no_grad():
395
+ for idx, example in tqdm(enumerate(dialog_examples)):
396
+ dialog = example[1]
397
+ dialog_ids = tokenizer.encode(dialog)
398
+ dialog_ids = torch.LongTensor([dialog_ids]).cuda()
399
+ dialog_emb = encoder(input_ids=dialog_ids).pooler_output
400
+
401
+ if idx == 0:
402
+ dialog_embeddings = dialog_emb
403
+ else:
404
+ dialog_embeddings = torch.cat((dialog_embeddings, dialog_emb), dim=0)
405
+
406
+ print("> reading test data from %s" % test_datapath)
407
+ prompt_list_for_each_sample = []
408
+ with open(test_datapath, "r") as f:
409
+ for i, line in tqdm(enumerate(f)):
410
+ line = line.strip()
411
+
412
+ splits = line.split("\t")
413
+ topic = splits[0]
414
+ turns = splits[1].split(" [SEP] ")[-3:]
415
+
416
+ # get the query sentence
417
+ query_sent = ""
418
+ if data_type != "seen":
419
+ query_sent += "( " + topic + " ) "
420
+ for i, turn in enumerate(turns):
421
+ if i != 0:
422
+ query_sent += " "
423
+ query_sent += turn
424
+
425
+ if topic not in train_data_by_topic:
426
+ # get the query embedding
427
+ query_ids = tokenizer.encode(query_sent)
428
+ query_ids = torch.LongTensor([query_ids]).cuda()
429
+ query_emb = encoder(input_ids=query_ids).pooler_output
430
+ query_emb = query_emb[0]
431
+
432
+ # calculate the similarity
433
+ similarity_list = dialog_embeddings.matmul(query_emb)
434
+ _, indices = torch.sort(similarity_list)
435
+ indices = indices.tolist()
436
+ selected_topics = {}
437
+ selected_prompts = []
438
+ num_prompt = 0
439
+ for index in indices:
440
+ example = dialog_examples[index]
441
+ topic_temp = example[0]
442
+ if topic_temp not in selected_topics:
443
+ selected_topics[topic_temp] = True
444
+ selected_prompts.append(example[2])
445
+ num_prompt += 1
446
+ if num_prompt == 10:
447
+ break
448
+
449
+ # get the selected samples
450
+ example_list = selected_prompts[::-1]
451
+ key = topic + " " + turns[-1]
452
+ prompt_list_for_each_sample.append({key: example_list})
453
+
454
+ else:
455
+ num_data_sample = min(len(train_data_by_topic[topic]), 10)
456
+ total_example_list = train_data_by_topic[topic]
457
+
458
+ dialog_list = dialog_data_by_topic[topic]
459
+ assert len(dialog_list) == len(train_data_by_topic[topic])
460
+
461
+ # calculate the similarity
462
+ example_list = select_prompts_based_on_similarity(
463
+ query_sent, dialog_list, total_example_list,
464
+ topic, tokenizer, encoder, topk=num_data_sample)
465
+
466
+ key = topic + " " + turns[-1]
467
+ prompt_list_for_each_sample.append({key: example_list})
468
+
469
+ print("writing to %s" % output_prompt_path)
470
+ with open(output_prompt_path, "w") as f:
471
+ for instance in tqdm(prompt_list_for_each_sample):
472
+ json.dump(instance, f)
473
+ f.write("\n")
474
+
475
+
476
+ def prompt_selection_for_response_generation(input_path, output_path, seed):
477
+ """Selecting prompts for the response generation"""
478
+
479
+ print("> Selecting prompts for the response generation")
480
+ print("> set random seed")
481
+ np.random.seed(seed)
482
+
483
+ prompt_example_list = []
484
+ print("> reading data from %s" % input_path)
485
+ with open(input_path, "r") as f:
486
+ for i, line in tqdm(enumerate(f)):
487
+ line = line.strip()
488
+ splits = line.split("\t")
489
+
490
+ # get the topic, context, knowledge and response
491
+ topic = splits[0]
492
+ dialog_context = splits[1]
493
+ knowledge = splits[2]
494
+ response = splits[3]
495
+ turns = dialog_context.split(" [SEP] ")[-3:]
496
+ if knowledge == "no_passages_used":
497
+ continue
498
+
499
+ # calculate the overlap ratio
500
+ from nltk import word_tokenize
501
+ knowledge_sent_token_list = word_tokenize(knowledge)
502
+ knowledge_sent_token_dict = {token: True for token in knowledge_sent_token_list}
503
+ knowledge_len = len(knowledge_sent_token_list)
504
+ response_token_list = word_tokenize(response)
505
+ response_len = len(response_token_list)
506
+ num_overlap_token = 0
507
+ accumulator = 0
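+ # count only tokens inside contiguous runs of >= 10 response tokens
+ # that all appear in the knowledge sentence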
508
+ for token in response_token_list:
509
+ if token in knowledge_sent_token_dict:
510
+ accumulator += 1
511
+ else:
512
+ if accumulator >= 10:
513
+ num_overlap_token += accumulator
514
+ accumulator = 0
515
+ if accumulator >= 10:
516
+ num_overlap_token += accumulator
517
+
518
+ # filtering the data based on the ratio
519
+ if num_overlap_token > response_len * 0.9 or num_overlap_token < response_len * 0.6:
520
+ continue
521
+ if num_overlap_token < knowledge_len * 0.8:
522
+ continue
523
+
524
+ last_turn = " ".join(word_tokenize(turns[-1]))
525
+ knowledge = " ".join(word_tokenize(knowledge))
526
+ response = " ".join(word_tokenize(response))
527
+ prompt_example = ""
528
+ # add dialog context
529
+ prompt_example += "Topic: " + topic + ". "
530
+ prompt_example += "User says: " + last_turn + " "
531
+ prompt_example += "We know that: " + knowledge + " "
532
+ prompt_example += "System replies: " + response
533
+
534
+ prompt_example_list.append(prompt_example)
535
+
536
+ # shuffle the prompt examples
537
+ np.random.shuffle(prompt_example_list)
538
+
539
+ print("> writing to %s" % output_path)
540
+ with open(output_path, "w") as f:
541
+ # f.write("Generate the System's response based on the knowledge sentence:\n")
542
+ for i in tqdm(range(20)):
543
+ example = prompt_example_list[i]
544
+ f.write(example + "\n")
545
+
546
+
547
+ def prepare_input_for_response_generation(test_file, knwl_gen_file, processed_file):
548
+ """Preparing inputs for the response generation"""
549
+
550
+ print("> Reading knowledge file from %s" % knwl_gen_file)
551
+ # get the knowledge list
552
+ with open(knwl_gen_file, "r") as f:
553
+ knowledge_list = f.readlines()
554
+
555
+ print("> Processing ...")
556
+ with open(test_file, "r") as fr:
557
+ with open(processed_file, "w") as fw:
558
+ for line_num, line in enumerate(tqdm(fr)):
559
+ line = line.strip()
560
+ splits = line.split("\t")
561
+ # prepare topic, context, knowledge and response
562
+ topic = splits[0]
563
+ dialog_context = splits[1]
564
+ response = splits[3]
565
+ knowledge = knowledge_list[line_num]
566
+ knowledge = knowledge.strip()
567
+ if "<|endoftext|>" in knowledge:
568
+ knowledge = knowledge.replace("<|endoftext|>", "")
569
+
570
+ # write to the output file
571
+ fw.write(topic + "\t" + dialog_context + "\t" \
572
+ + knowledge + "\t" + response + "\n")
573
+
574
+
575
+ if __name__ == "__main__":
576
+
577
+ args = get_args()
578
+ if args.func == "process_wow_dataset":
579
+ process_wow_dataset(args.raw_file, args.processed_file, args.knwl_ref_file, args.resp_ref_file)
580
+
581
+ elif args.func == "process_woi_dataset":
582
+ process_woi_dataset(args.raw_file, args.processed_file, args.knwl_ref_file, args.resp_ref_file)
583
+
584
+ elif args.func == "get_knwl_gen_prompts":
585
+ prompt_selection_for_knowledge_generation(
586
+ args.test_file, args.train_file, args.model_file,
587
+ args.processed_file, args.data_type)
588
+
589
+ elif args.func == "get_resp_gen_prompts":
590
+ prompt_selection_for_response_generation(
591
+ args.train_file, args.processed_file, args.seed)
592
+
593
+ elif args.func == "prepare_input":
594
+ prepare_input_for_response_generation(
595
+ args.test_file, args.knwl_gen_file, args.processed_file)
tasks/msdp/prompt.py ADDED
@@ -0,0 +1,322 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Prompting the pretrained language model to generate knowledge/response"""
17
+
18
+ import json
19
+ import torch
20
+ import requests
21
+ from nltk import word_tokenize
22
+ from megatron import mpu
23
+ from megatron import get_args
24
+ from megatron import print_rank_0
25
+ from megatron import get_tokenizer
26
+ from megatron.model import GPTModel
27
+ from megatron.training import get_model
28
+ from megatron.checkpointing import load_checkpoint
29
+ from megatron.initialize import initialize_megatron
30
+ from megatron.text_generation import generate_and_post_process
31
+
32
+
33
+ def call_model_api(inputs, tokens_to_generate):
34
+ """Calling the model api to get the output generations"""
35
+
36
+ args = get_args()
37
+
38
+ # The following is an example of using the Megatron API
39
+ # You can also implement your own API function to replace this part
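+ # A request/response round-trip looks like the following
+ # (values are illustrative only):
+ # request body: {"prompts": ["..."], "tokens_to_generate": 64, "top_k": 1}
+ # response body: {"text": ["<prompt followed by the generated continuation>"]}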
40
+ headers = {'Content-Type': 'application/json; charset=UTF-8'}
41
+ data = {"prompts": [inputs], "tokens_to_generate": tokens_to_generate, "top_k": 1}
42
+ data_json = json.dumps(data)
43
+ outputs = requests.put(args.megatron_api_url, headers=headers, data=data_json).json()["text"][0]
44
+
45
+ input_len = len(inputs)
46
+ outputs = outputs[input_len:]
47
+ outputs = outputs.split("\n")[0].strip()
48
+
49
+ return outputs
50
+
51
+
52
+ def read_prompts(prompt_path, prompt_type, n_example):
53
+ """Read prompt data"""
54
+
55
+ if prompt_type == "knowledge":
56
+ # prompts for the knowledge generation
57
+ prompt_examples_dict = {}
58
+ # read prompt_path
59
+ with open(prompt_path, "r") as f:
60
+ for i, line in enumerate(f):
61
+ line = line.strip()
62
+ line_dict = json.loads(line)
63
+ key = list(line_dict.keys())[0]
64
+
65
+ if key not in prompt_examples_dict:
66
+ prompt_examples = line_dict[key]
67
+ prompt = ""
68
+ for instance in prompt_examples:
69
+ instance = instance.strip()
70
+ prompt += instance + " \n"
71
+ prompt_examples_dict[key] = prompt
72
+
73
+ return prompt_examples_dict
74
+
75
+ else:
76
+ # prompts for the response generation
77
+ # read prompt_path
78
+ prompt = ""
79
+ with open(prompt_path, "r") as f:
80
+ prompt_examples = f.readlines()
81
+ prompt_examples = prompt_examples[:n_example]
82
+ for instance in prompt_examples:
83
+ instance = instance.strip()
84
+ prompt += instance + " \n"
85
+
86
+ return prompt
87
+
88
+
89
+ def generate_samples_by_calling_api():
90
+ """ Generate outputs by calling"""
91
+ args = get_args()
92
+ assert args.prompt_type in ["knowledge", "response"], \
93
+ "Please input a correct prompt type!"
94
+
95
+ if args.prompt_type == "knowledge":
96
+ # read knowledge generation prompts
97
+ knwl_gen_prompt_dict = read_prompts(
98
+ args.prompt_file, args.prompt_type, args.num_prompt_examples)
99
+
100
+ else:
101
+ resp_gen_prompt = read_prompts(
102
+ args.prompt_file, args.prompt_type, args.num_prompt_examples)
103
+
104
+ # read the test data
105
+ fname = open(args.sample_input_file, "r")
106
+ test_sample_list = fname.readlines()
107
+ # create output file
108
+ fname_out = open(args.sample_output_file, "w")
109
+
110
+ # call the api to get the output generations
111
+ for test_sample in test_sample_list:
112
+ test_sample = test_sample.strip()
113
+ splits = test_sample.split("\t")
114
+ topic = splits[0]
115
+
116
+ # prepare the inputs for the api
117
+ if args.prompt_type == "knowledge":
118
+ ## inputs = prompt + current test
119
+ # get the prompt
120
+ turns = splits[1].split(" [SEP] ")
121
+ last_turn = turns[-1]
122
+ key = topic + " " + last_turn
123
+ inputs = knwl_gen_prompt_dict[key]
124
+
125
+ # add current test
126
+ inputs += "( " + last_turn + " ) " + topic + " =>"
127
+
128
+ else:
129
+ # inputs = prompt + current test
130
+ # get the prompt
131
+ inputs = resp_gen_prompt
132
+
133
+ # add current test
134
+ turns = splits[1].split(" [SEP] ")
135
+ knowledge = splits[2]
136
+ last_turn = turns[-1]
137
+ last_turn = " ".join(word_tokenize(last_turn))
138
+ knowledge = " ".join(word_tokenize(knowledge))
139
+ knowledge = knowledge.strip()
140
+ last_turn = last_turn.strip()
141
+ inputs += "Topic: " + topic + ". "
142
+ inputs += "User says: " + last_turn + " "
143
+ inputs += "We know that: " + knowledge + " "
144
+ inputs += "System replies:"
145
+
146
+ # get the output generations from the api,
147
+ # and write to the output file
148
+ generations = call_model_api(inputs, args.out_seq_length)
149
+ fname_out.write(generations)
150
+ fname_out.write("\n")
151
+
152
+ fname.close()
153
+ fname_out.close()
154
+
155
+
156
+ def model_provider(pre_process=True, post_process=True):
157
+ """Build the model."""
158
+
159
+ print_rank_0('building GPT model ...')
160
+ model = GPTModel(
161
+ num_tokentypes=0,
162
+ parallel_output=True,
163
+ pre_process=pre_process,
164
+ post_process=post_process
165
+ )
166
+ return model
167
+
168
+
169
+ def generate_samples_by_prompting_input_from_file(model):
170
+ """Prompt a pretrained language model to generate knowledge/response"""
171
+
172
+ # get tokenizer
173
+ args = get_args()
174
+ tokenizer = get_tokenizer()
175
+
176
+ # Read the sample file and open the output file.
177
+ assert args.sample_input_file is not None, \
178
+ 'sample input file is not provided.'
179
+ if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
180
+ fname = open(args.sample_input_file, "r")
181
+ all_raw_text = fname.readlines()
182
+ input_count = len(all_raw_text)
183
+ if args.sample_output_file is None:
184
+ sample_output_file = args.sample_input_file + ".out"
185
+ print('`sample-output-file` not specified, setting '
186
+ 'it to {}'.format(sample_output_file))
187
+ else:
188
+ sample_output_file = args.sample_output_file
189
+
190
+ fname_out = open(sample_output_file, "w")
191
+
192
+ # only two prompt types (i.e., knowledge and response) are allowed
193
+ assert args.prompt_type in ["knowledge", "response"], \
194
+ "Please input a correct prompt type!"
195
+
196
+ # Read the prompt file
197
+ if args.prompt_type == "knowledge":
198
+ # read the prompts for the knowledge generation
199
+ prompt_examples_dict = {}
200
+ with open(args.prompt_file, "r") as f:
201
+ for i, line in enumerate(f):
202
+ line = line.strip()
203
+ line_dict = json.loads(line)
204
+ key = list(line_dict.keys())[0]
205
+
206
+ # get the prompt examples based on the key
207
+ if key not in prompt_examples_dict:
208
+ prompt_examples = line_dict[key]
209
+ prompt = ""
210
+ for instance in prompt_examples:
211
+ instance = instance.strip()
212
+ prompt += instance + " \n"
213
+ prompt_examples_dict[key] = prompt
214
+
215
+ else:
216
+ # read the prompts for the response generation
217
+ # prompts are fixed for all test samples
218
+ with open(args.prompt_file, "r") as f:
219
+ prompt_examples = f.readlines()
220
+ prompt_examples = prompt_examples[:args.num_prompt_examples]
221
+
222
+ prompt = ""
223
+ for instance in prompt_examples:
224
+ instance = instance.strip()
225
+ prompt += instance + " \n"
226
+
227
+ input_pos = 0
228
+ model.eval()
229
+ # perform prompting
230
+ with torch.no_grad():
231
+ while True:
232
+ raw_text_len = 0
233
+ if mpu.is_pipeline_first_stage() \
234
+ and mpu.get_tensor_model_parallel_rank() == 0:
235
+ input_str = all_raw_text[input_pos]
236
+ input_str = input_str.strip()
237
+ splits = input_str.split("\t")
238
+ topic = splits[0]
239
+
240
+ if args.prompt_type == "knowledge":
241
+ # first add the prompt into the raw_text
242
+ turns = splits[1].split(" [SEP] ")
243
+ last_turn = turns[-1]
244
+ key = topic + " " + last_turn
245
+ raw_text = prompt_examples_dict[key]
246
+
247
+ # construct inputs for knowledge generation
248
+ # then add the constructed inputs into the raw_text
249
+ raw_text += "( " + last_turn + " ) " + topic + " =>"
250
+
251
+ else:
252
+ # first add the prompt into the raw_text
253
+ raw_text = prompt
254
+
255
+ # construct inputs for response generation
256
+ # then add the constructed inputs into the raw_text
257
+ turns = splits[1].split(" [SEP] ")
258
+ knowledge = splits[2]
259
+ last_turn = turns[-1]
260
+ last_turn = " ".join(word_tokenize(last_turn))
261
+ knowledge = " ".join(word_tokenize(knowledge))
262
+ knowledge = knowledge.strip()
263
+ last_turn = last_turn.strip()
264
+ raw_text += "Topic: " + topic + ". "
265
+ raw_text += "User says: " + last_turn + " "
266
+ raw_text += "We know that: " + knowledge + " "
267
+ raw_text += "System replies:"
268
+
269
+ input_pos += 1
270
+ raw_text_len = len(raw_text)
271
+
272
+ else:
273
+ raw_text = "EMPTY TEXT"
274
+
275
+ if input_pos % 100 == 0:
276
+ print_rank_0("input_pos: %d" % input_pos)
277
+
278
+ outputs = generate_and_post_process(
279
+ model=model,
280
+ prompts=[raw_text],
281
+ tokens_to_generate=args.out_seq_length,
282
+ top_k_sampling=1)
283
+ prompts_plus_generations = outputs[0]
284
+ prompts_plus_generations = prompts_plus_generations[0]
285
+
286
+ # write the generated output to the output file
287
+ if mpu.get_tensor_model_parallel_rank() == 0:
288
+ if mpu.is_pipeline_first_stage():
289
+
290
+ generations = prompts_plus_generations[raw_text_len:]
291
+ generations = generations.split("\n")[0]
292
+ generations = generations.strip()
293
+ fname_out.write(generations)
294
+ fname_out.write("\n")
295
+
296
+ raw_text = None
297
+ if input_pos == input_count:
298
+ return
299
+
300
+
301
+ def main():
302
+
303
+ args = get_args()
304
+ if args.api_prompt:
305
+ # obtain the generations by calling the api
306
+ generate_samples_by_calling_api()
307
+ return
308
+
309
+ if args.num_layers_per_virtual_pipeline_stage is not None:
310
+ print("Interleaved pipeline schedule is not yet supported for text generation.")
311
+ exit()
312
+
313
+ # Set up model and load checkpoint.
314
+ model = get_model(model_provider, wrap_with_ddp=False)
315
+ if args.load is not None:
316
+ _ = load_checkpoint(model, None, None)
317
+
318
+ assert len(model) == 1, "Above condition should have caught this"
319
+ model = model[0]
320
+
321
+ # perform the prompting
322
+ generate_samples_by_prompting_input_from_file(model)
tasks/orqa/README.md ADDED
@@ -0,0 +1,36 @@
1
+ ## End-to-End Training of Neural Retrievers for Open-Domain Question Answering
2
+
3
+ Below we present the steps to run unsupervised and supervised training and evaluation of the retriever for [open domain question answering](https://arxiv.org/abs/2101.00408).
4
+
5
+ ## Retriever Training
6
+
7
+ #### Unsupervised pretraining
8
+ 1. Use [`tools/preprocess_data.py`](../../tools/preprocess_data.py) to preprocess the dataset for the Inverse Cloze Task (ICT), which we call unsupervised pretraining. This script takes a corpus in loose JSON format as input and creates fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this means multiple sentences per block and multiple blocks per document. Run it with the `--split-sentences` argument to make sentences the basic unit, constructing one or more indexed datasets. We construct two datasets: one with the title of every document and another with the body.
9
+
10
+ <pre>
11
+ python tools/preprocess_data.py \
12
+ --input /path/to/corpus.json \
13
+ --json-keys text title \
14
+ --split-sentences \
15
+ --tokenizer-type BertWordPieceLowerCase \
16
+ --vocab-file /path/to/vocab.txt \
17
+ --output-prefix corpus_indexed \
18
+ --workers 10
19
+ </pre>
20
+
21
+ 2. The [`examples/pretrain_ict.sh`](../../examples/pretrain_ict.sh) script runs single-GPU training of a 217M-parameter biencoder model for ICT retriever training. Single-GPU training is intended primarily for debugging, as the code is developed for distributed training. The script uses a pretrained BERT model and a total batch size of 4096 for ICT training (the objective is sketched after this list).
22
+
23
+ 3. Evaluate the pretrained ICT model using [`examples/evaluate_retriever_nq.sh`](../../examples/evaluate_retriever_nq.sh) for [Google's Natural Questions Open dataset](https://arxiv.org/pdf/1906.00300.pdf).
24
+
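+ At its core, ICT trains the biencoder with an in-batch softmax over query-block retrieval scores, the same objective used for supervised retrieval in [`tasks/orqa/supervised/finetune.py`](../../tasks/orqa/supervised/finetune.py). The following is only a minimal sketch of that loss; the tensor names are illustrative, not part of the training code:
+
+ <pre>
+ import torch
+ import torch.nn.functional as F
+
+ # query_emb, block_emb: [batch, hidden] outputs of the two encoders
+ scores = torch.matmul(query_emb, block_emb.t())               # [batch, batch]
+ labels = torch.arange(scores.shape[0], device=scores.device)  # diagonal pairs are positives
+ loss = F.nll_loss(F.log_softmax(scores, dim=1), labels)
+ </pre>
+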
25
+ #### Supervised finetuning
26
+
27
+ 1. Use the above pretrained ICT model for finetuning on [Google's Natural Questions Open dataset](https://github.com/google-research/language/tree/master/language/orqa). The script [`examples/finetune_retriever_distributed.sh`](../../examples/finetune_retriever_distributed.sh) provides an example of how to perform the training. Our finetuning process includes retriever score scaling (sketched after this list) and longer training (80 epochs) on top of [DPR training](https://arxiv.org/abs/2004.04906).
28
+
29
+ 2. Evaluate the finetuned model using the same evaluation script as mentioned above for the unsupervised model.
30
+
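+ Retriever score scaling simply divides the query-context dot-product scores by the square root of the hidden size before the softmax, as done in [`tasks/orqa/supervised/finetune.py`](../../tasks/orqa/supervised/finetune.py). A minimal sketch, with illustrative tensor names:
+
+ <pre>
+ import math
+
+ scores = torch.matmul(query_emb, context_emb.t())
+ scores = scores / math.sqrt(hidden_size)   # retriever score scaling
+ </pre>
+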
31
+ More details on the retriever are available in [our paper](https://arxiv.org/abs/2101.00408).
32
+
33
+ ## Reader Training
34
+
35
+ The reader component will be available soon.
36
+
tasks/orqa/evaluate_orqa.py ADDED
@@ -0,0 +1,52 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Main tasks functionality."""
17
+
18
+ from megatron import get_args, print_rank_0
19
+ from megatron.indexer import IndexBuilder
20
+ from tasks.orqa.evaluate_utils import ORQAEvaluator
21
+
22
+ def main():
23
+ """
24
+ Main program
25
+ """
26
+
27
+ args = get_args()
28
+
29
+ """
30
+ Create a BlockData data structure by running an IndexBuilder over an
31
+ ICT Dataset and then evaluate on NQ task
32
+ """
33
+
34
+ print_rank_0("Starting index builder!")
35
+
36
+ index_builder = IndexBuilder()
37
+ index_builder.build_and_save_index()
38
+ print_rank_0("Build and save indices: done!")
39
+
40
+
41
+ print_rank_0("Starting evaluations!")
42
+
43
+ # Set up the model and evaluator
44
+ evaluator = ORQAEvaluator()
45
+
46
+ # Run evaluation
47
+ if args.qa_data_dev is not None:
48
+ evaluator.evaluate(args.qa_data_dev, "DEV")
49
+
50
+ if args.qa_data_test is not None:
51
+ evaluator.evaluate(args.qa_data_test, "TEST")
52
+
tasks/orqa/evaluate_utils.py ADDED
@@ -0,0 +1,188 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+
18
+ from megatron import get_args, print_rank_0
19
+ from megatron.checkpointing import load_biencoder_checkpoint
20
+ from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
21
+ from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex
22
+ from megatron.model.biencoder_model import get_model_provider
23
+ from megatron.training import get_model
24
+ from tasks.orqa.unsupervised.nq import get_nq_dataset
25
+ from tasks.orqa.unsupervised.nq import get_one_epoch_nq_dataloader
26
+ from tasks.orqa.unsupervised.nq import process_nq_batch
27
+ from tasks.orqa.unsupervised.qa_utils import calculate_matches
28
+
29
+
30
+ class ORQAEvaluator(object):
31
+ def __init__(self):
32
+ args = get_args()
33
+ self.embedding_size = args.hidden_size
34
+ self.faiss_use_gpu = args.faiss_use_gpu
35
+ self.evidence_embedder_obj = None
36
+ self.evidence_dataset = None
37
+ self.mips_index = None
38
+ self.eval_dataset = None
39
+
40
+ # Get Evidence (Wikipedia) dataset
41
+ self.get_evidence_dataset()
42
+
43
+ # Load query encoder checkpoint
44
+ only_query_model = True
45
+ if args.biencoder_shared_query_context_model:
46
+ only_query_model = False
47
+
48
+ model = get_model(get_model_provider(only_query_model=only_query_model,
49
+ biencoder_shared_query_context_model=args.biencoder_shared_query_context_model))
50
+
51
+ self.model = load_biencoder_checkpoint(model,
52
+ only_query_model=only_query_model)
53
+
54
+ assert len(self.model) == 1
55
+ self.model[0].eval()
56
+
57
+ # Load faiss indexer
58
+ self.faiss_wrapper()
59
+
60
+ def get_evidence_embedding(self):
61
+ # This will load the embedding from the embedding path
62
+ self.evidence_embedder_obj = OpenRetreivalDataStore(load_from_path=True)
63
+
64
+ def get_evidence_dataset(self):
65
+ self.evidence_dataset = get_open_retrieval_wiki_dataset()
66
+
67
+ def faiss_wrapper(self):
68
+ # Initialize the FAISS wrapper on local rank 0, as the evidence embeddings
69
+ # are distributed over all the GPUs in a node and FAISS is not
70
+ # thread-safe
71
+ args = get_args()
72
+ if args.local_rank == 0:
73
+ # Get evidence embeddings computed using context encoder
74
+ self.get_evidence_embedding()
75
+
76
+ assert self.evidence_embedder_obj is not None
77
+ self.mips_index = FaissMIPSIndex(embed_size=self.embedding_size,
78
+ embed_data=self.evidence_embedder_obj,
79
+ use_gpu=self.faiss_use_gpu)
80
+
81
+ # Wait for the FAISS index to be initialized in all the nodes
82
+ torch.distributed.barrier()
83
+
84
+ def generate_query_vectors(self, qa_data, split):
85
+
86
+ self.eval_dataset = get_nq_dataset(qa_data, split)
87
+ dataloader = get_one_epoch_nq_dataloader(self.eval_dataset)
88
+
89
+ query_vectors = []
90
+ reference_list = []
91
+
92
+ for batch in dataloader:
93
+ # batch also has query_tokens and query_pad_data
94
+ query_tokens, query_mask, query_types, \
95
+ query_len, reference = process_nq_batch(batch)
96
+
97
+ assert len(self.model) == 1
98
+ unwrapped_model = self.model[0]
99
+ while not hasattr(unwrapped_model, 'embed_text'):
100
+ unwrapped_model = unwrapped_model.module
101
+
102
+ with torch.no_grad():
103
+ query_logits = unwrapped_model.embed_text(
104
+ unwrapped_model.query_model, query_tokens,
105
+ query_mask, query_types)
106
+
107
+ reference_list.extend(reference)
108
+ query_vectors.extend(query_logits.split(1, dim=0))
109
+ if len(query_vectors) % 100 == 0:
110
+ print_rank_0('Encoded queries {}'.format(len(query_vectors)))
111
+
112
+ query_tensor = torch.cat(query_vectors, dim=0)
113
+ print_rank_0('Total encoded queries tensor {}'.format(query_tensor.size()))
114
+
115
+ assert query_tensor.size(0) == len(self.eval_dataset)
116
+ return query_tensor, reference_list
117
+
118
+ def evaluate(self, qa_data, split):
119
+ args = get_args()
120
+ query_tensor, reference_list = self.generate_query_vectors(qa_data, \
121
+ split)
122
+ local_rank = args.local_rank
123
+ rank = torch.distributed.get_rank()
124
+ device_count = torch.cuda.device_count()
125
+ num_nodes = torch.distributed.get_world_size() // device_count
126
+ node_id = rank // device_count
127
+
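+ # Build one process group per node: queries encoded on each GPU of a
+ # node are all-gathered within that node so its local rank 0 can run
+ # the FAISS search over them.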
128
+ for node in range(num_nodes):
129
+ start_rank = node * device_count
130
+ end_rank = (node + 1) * device_count
131
+ ranks_list = list(range(start_rank, end_rank))
132
+ node_group = torch.distributed.new_group(ranks=ranks_list)
133
+
134
+ if node_id == node:
135
+ device_start_rank = start_rank
136
+ group = node_group
137
+
138
+ input_ = torch.empty_like(query_tensor).copy_(query_tensor).detach_()
139
+ tensor_list = [torch.empty_like(input_) for _ in range(device_count)]
140
+ torch.distributed.all_gather(tensor_list, query_tensor, group=group)
141
+
142
+ if local_rank == 0 and self.mips_index is not None:
143
+ all_query_tensor = torch.cat(tensor_list, dim=0).contiguous()
144
+
145
+ distance, topkindex = self.mips_index.search_mips_index(
146
+ all_query_tensor, top_k=args.faiss_topk_retrievals,
147
+ reconstruct=False)
148
+ distance = torch.from_numpy(distance).cuda()
149
+ topkindex = torch.LongTensor(topkindex).cuda()
150
+
151
+ if local_rank != 0:
152
+ distance = torch.empty(device_count * len(query_tensor), \
153
+ args.faiss_topk_retrievals, dtype=torch.float32).cuda()
154
+ topkindex = torch.empty(device_count * len(query_tensor), \
155
+ args.faiss_topk_retrievals, dtype=torch.int64).cuda()
156
+
157
+ torch.distributed.broadcast(distance, src=device_start_rank, \
158
+ group=group)
159
+ torch.distributed.broadcast(topkindex, src=device_start_rank, \
160
+ group=group)
161
+
162
+ distance = torch.split(distance, len(query_tensor), dim=0)\
163
+ [local_rank]
164
+ topkindex = torch.split(topkindex, len(query_tensor), dim=0)\
165
+ [local_rank]
166
+
167
+ top_ids_and_scores = []
168
+ for darray, topkarray in zip(distance, topkindex):
169
+ top_ids_and_scores.append((topkarray.tolist(), darray.tolist()))
170
+
171
+ passages = self.evidence_dataset.id2text
172
+ match_stats = calculate_matches(passages,
173
+ reference_list,
174
+ top_ids_and_scores,
175
+ workers_num=args.num_workers,
176
+ match_type=args.faiss_match)
177
+ top_k_hits = match_stats.top_k_hits
178
+
179
+ print_rank_0("{} SET RESULTS".format(split))
180
+ print_rank_0("topk-{} documents hits {}".format(
181
+ args.faiss_topk_retrievals, top_k_hits))
182
+ top_k_hits = [v / len(top_ids_and_scores) for v in top_k_hits]
183
+ print_rank_0("top-k documents hits accuracy {}".format(top_k_hits))
184
+
185
+ for i in args.retriever_report_topk_accuracies:
186
+ print_rank_0("top-{}: {:.2f}".format(i, top_k_hits[i-1] * 100))
187
+
188
+ return
tasks/orqa/supervised/data.py ADDED
@@ -0,0 +1,300 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ORQA dataset."""
17
+
18
+ import json
19
+ import random
20
+ from abc import ABC
21
+ from abc import abstractmethod
22
+
23
+ import numpy as np
24
+ from torch.utils.data import Dataset
25
+
26
+ from megatron import print_rank_0, get_args
27
+ from megatron.data.biencoder_dataset_utils import make_attention_mask
28
+
29
+ def build_token_types_from_context_list(ctx_list, tokenizer, max_seq_length):
30
+ ctx_id_list, ctx_types_list = [], []
31
+ for context in ctx_list:
32
+ title_ids = tokenizer.tokenize(context['title'])
33
+ ctx_ids = tokenizer.tokenize(context['text'])
34
+ ctx_ids = title_ids + [tokenizer.sep_id] + ctx_ids
35
+
36
+ ctx_ids, ctx_types, _ = build_tokens_types_paddings_from_ids(ctx_ids,
37
+ max_seq_length, tokenizer.cls,
38
+ tokenizer.sep, tokenizer.pad)
39
+ ctx_id_list.append(ctx_ids)
40
+ ctx_types_list.append(ctx_types)
41
+
42
+ return ctx_id_list, ctx_types_list
43
+
44
+
45
+ def build_tokens_types_paddings_from_text(query, context,
46
+ tokenizer, max_seq_length):
47
+ """Build token types and paddings, trim if needed, and pad if needed."""
48
+
49
+ query_ids = tokenizer.tokenize(query)
50
+ query_ids, query_types, query_pad_mask = \
51
+ build_tokens_types_paddings_from_ids(query_ids, max_seq_length, \
52
+ tokenizer.cls, tokenizer.sep, tokenizer.pad)
53
+
54
+ # Appending the title of the context at front
55
+ extended_ctx_ids = None
56
+ if context is not None:
57
+ title_ids = tokenizer.tokenize(context['title'])
58
+ ctx_ids = tokenizer.tokenize(context['text'])
59
+ extended_ctx_ids = title_ids + [tokenizer.sep] + ctx_ids
60
+
61
+ ctx_ids, ctx_types, ctx_pad_mask = \
62
+ build_tokens_types_paddings_from_ids(extended_ctx_ids,
63
+ max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)
64
+
65
+ return query_ids, query_types, query_pad_mask, \
66
+ ctx_ids, ctx_types, ctx_pad_mask
67
+
68
+
69
+ # Similar to the code in tasks/data_utils.py, with some changes
70
+ def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
71
+ cls_id, sep_id, pad_id):
72
+ """Build token types and paddings, trim if needed, and pad if needed."""
73
+ enc_ids = []
74
+ tokentypes_enc = []
75
+
76
+ # [CLS].
77
+ enc_ids.append(cls_id)
78
+ tokentypes_enc.append(0)
79
+
80
+ # A.
81
+ len_src = len(text_ids)
82
+ enc_ids.extend(text_ids)
83
+ tokentypes_enc.extend([0] * len_src)
84
+
85
+ # Cap the size.
86
+ if len(enc_ids) > max_seq_length - 1:
87
+ enc_ids = enc_ids[0: max_seq_length - 1]
88
+ tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]
89
+
90
+ # [SEP].
91
+ enc_ids.append(sep_id)
92
+ tokentypes_enc.append(0)
93
+
94
+ num_tokens_enc = len(enc_ids)
95
+ # Padding.
96
+ padding_length = max_seq_length - len(enc_ids)
97
+ if padding_length > 0:
98
+ enc_ids.extend([pad_id] * padding_length)
99
+ tokentypes_enc.extend([pad_id] * padding_length)
100
+
101
+ pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
102
+ pad_mask = np.array(pad_mask, dtype=np.int64)
103
+
104
+ return enc_ids, tokentypes_enc, pad_mask
105
+
106
+
107
+ def build_sample(query_ids, query_types, query_pad_mask,
108
+ ctx_ids, ctx_types, ctx_pad_mask, answers,
109
+ neg_ctx_id_list=None, neg_ctx_types_list=None,
110
+ include_neg=False):
111
+ """Convert to numpy and return a sample consumed by the batch producer."""
112
+
113
+ query_ids = np.array(query_ids, dtype=np.int64)
114
+ query_types = np.array(query_types, dtype=np.int64)
115
+ query_mask = make_attention_mask(query_ids, query_ids)
116
+
117
+ ctx_ids = np.array(ctx_ids, dtype=np.int64)
118
+ ctx_types = np.array(ctx_types, dtype=np.int64)
119
+ ctx_mask = make_attention_mask(ctx_ids, ctx_ids)
120
+
121
+ sample = ({
122
+ 'query': query_ids,
123
+ 'query_mask': query_mask,
124
+ 'query_types': query_types,
125
+ 'query_pad_mask': query_pad_mask,
126
+ 'context': ctx_ids,
127
+ 'context_mask': ctx_mask,
128
+ 'context_types': ctx_types,
129
+ 'context_pad_mask': ctx_pad_mask,
130
+ 'reference': answers
131
+ })
132
+
133
+ if include_neg:
134
+ neg_ctx_ids = np.array(neg_ctx_id_list, dtype=np.int64)
135
+ neg_ctx_id_types = np.array(neg_ctx_types_list, dtype=np.int64)
136
+ neg_ctx_mask = np.array([make_attention_mask(ids, ids) \
137
+ for ids in neg_ctx_ids], dtype=np.int64)
138
+
139
+ sample['neg_context'] = neg_ctx_ids
140
+ sample['neg_context_types'] = neg_ctx_id_types
141
+ sample['neg_context_mask'] = neg_ctx_mask
142
+
143
+ return sample
144
+
145
+
146
+ class OpenRetrievalAbstractDataset(ABC, Dataset):
147
+ """Open Retrieval base dataset class."""
148
+
149
+ def __init__(self, task_name, dataset_name, datapaths, tokenizer, \
150
+ max_seq_length, evaluate=False):
151
+ # Store inputs.
152
+ args = get_args()
153
+ self.evaluate = evaluate
154
+ self.val_av_rank_hard_neg = args.val_av_rank_hard_neg
155
+ self.val_av_rank_other_neg = args.val_av_rank_other_neg
156
+ self.train_with_neg = args.train_with_neg
157
+ self.train_hard_neg = args.train_hard_neg
158
+
159
+ self.task_name = task_name
160
+ self.dataset_name = dataset_name
161
+ self.tokenizer = tokenizer
162
+ self.max_seq_length = max_seq_length
163
+ print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
164
+ self.dataset_name))
165
+ # Process the files.
166
+ string = ' > paths:'
167
+ for path in datapaths:
168
+ string += ' ' + path
169
+ print_rank_0(string)
170
+ self.samples = []
171
+ for datapath in datapaths:
172
+ self.samples.extend(self.process_samples_from_single_path(datapath))
173
+
174
+ args = get_args()
175
+ if args.sample_rate < 1: # subsample
176
+ k = int(len(self.samples) * args.sample_rate)
177
+ self.samples = random.sample(self.samples, k)
178
+
179
+ print_rank_0(' >> total number of samples: {}'.format(
180
+ len(self.samples)))
181
+
182
+ def __len__(self):
183
+ return len(self.samples)
184
+
185
+ def __getitem__(self, idx):
186
+ raw_sample = self.samples[idx]
187
+
188
+ query_ids, query_types, query_pad_mask, ctx_ids, ctx_types, \
189
+ ctx_pad_mask = build_tokens_types_paddings_from_text( \
190
+ raw_sample['question'], raw_sample['pos_context'], \
191
+ self.tokenizer, self.max_seq_length)
192
+
193
+ if self.evaluate:
194
+ neg_ctx_list = \
195
+ raw_sample['negative_context'][:self.val_av_rank_other_neg] + \
196
+ raw_sample['hard_negative_context'][:self.val_av_rank_hard_neg]
197
+ neg_ctx_id_list, neg_ctx_types_list = \
198
+ build_token_types_from_context_list(neg_ctx_list, \
199
+ self.tokenizer, self.max_seq_length)
200
+
201
+ elif self.train_with_neg:
202
+ hard_negative_ctx = raw_sample['hard_negative_context']
203
+ negative_ctx = raw_sample['negative_context']
204
+ if True: # TODO: fix this or remove this condition
205
+ random.shuffle(hard_negative_ctx)
206
+ random.shuffle(negative_ctx)
207
+
208
+ neg_ctx_list = hard_negative_ctx[:self.train_hard_neg]
209
+ # In the Google NQ dataset prepared by the DPR paper, somewhat more than
210
+ # 50 training samples are missing hard negatives.
211
+ # In those cases, substitute hard negatives with simple negatives.
212
+ if len(neg_ctx_list) < self.train_hard_neg:
213
+ neg_ctx_list += negative_ctx[:self.train_hard_neg - \
214
+ len(neg_ctx_list)]
215
+
216
+ neg_ctx_id_list, neg_ctx_types_list = \
217
+ build_token_types_from_context_list(neg_ctx_list,
218
+ self.tokenizer, self.max_seq_length)
219
+ else:
220
+ neg_ctx_id_list = None
221
+ neg_ctx_types_list = None
222
+
223
+ sample = build_sample(query_ids, query_types, query_pad_mask,
224
+ ctx_ids, ctx_types, ctx_pad_mask,
225
+ raw_sample['answers'],
226
+ neg_ctx_id_list, neg_ctx_types_list,
227
+ include_neg=self.evaluate or self.train_with_neg)
228
+
229
+ return sample
230
+
231
+ @staticmethod
232
+ @abstractmethod
233
+ def process_samples_from_single_path(filename):
234
+ """Abstract method that takes a filename and
235
+ returns a list of dataset samples, each sample being a dict of
236
+ {'question': string, 'pos_context': dict, 'negative_context': list, 'hard_negative_context': list, 'answers': list}
237
+ """
238
+ pass
239
+
240
+
241
+
242
+ def normalize_question(question):
243
+ if question[-1] == '?':
244
+ question = question[:-1]
245
+ return question
246
+
247
+ # The following class reads the datasets for training retriever as
248
+ # prepared by the DPR codebase (https://github.com/facebookresearch/DPR)
249
+
250
+ class NQSupervisedDataset(OpenRetrievalAbstractDataset):
251
+
252
+ def __init__(self, name, datapaths, tokenizer, max_seq_length, \
253
+ evaluate=False):
254
+ super().__init__('natural_questions_ret',
255
+ name,
256
+ datapaths,
257
+ tokenizer,
258
+ max_seq_length,
259
+ evaluate=evaluate)
260
+
261
+ @staticmethod
262
+ def process_samples_from_single_path(filename):
263
+ """"Implement abstract method."""
264
+ print_rank_0(' > Processing {} ...'.format(filename))
265
+ samples = []
266
+ total = 0
267
+
268
+ with open(filename, 'r', encoding="utf-8") as f:
269
+ data = json.load(f)
270
+ for row in data:
271
+ question = normalize_question(row['question'])
272
+ pos_context = row['positive_ctxs'][0]
273
+
274
+ # Hard Negative Contexts
275
+ if len(row['hard_negative_ctxs']) > 0:
276
+ hard_neg_context = row['hard_negative_ctxs']
277
+ else:
278
+ hard_neg_context = []
279
+
280
+ # Negative Contexts
281
+ if len(row['negative_ctxs']) > 0:
282
+ neg_context = row['negative_ctxs']
283
+ else:
284
+ neg_context = []
285
+
286
+ answers = row['answers']
287
+ sample = {'question': question,
288
+ 'pos_context': pos_context,
289
+ 'hard_negative_context': hard_neg_context,
290
+ 'negative_context': neg_context,
291
+ 'answers': answers}
292
+ total += 1
293
+ samples.append(sample)
294
+
295
+ if total % 5000 == 0:
296
+ print_rank_0(' > processed {} so far ...'.format(total))
297
+
298
+ print_rank_0(' >> processed {} samples.'.format(len(samples)))
299
+ return samples
300
+
tasks/orqa/supervised/eval_utils.py ADDED
@@ -0,0 +1,206 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Evaluation utilities."""
17
+ from collections import OrderedDict
18
+ import math
19
+ import numpy as np
20
+ import time
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from torch.utils.data import DataLoader
24
+
25
+ from megatron import get_args, print_rank_0
26
+ from megatron import mpu
27
+ from megatron.utils import average_losses_across_data_parallel_group
28
+ from tasks.finetune_utils import build_data_loader
29
+
30
+ def task_collate_fn(batch_data):
31
+ # generate batch
32
+ batch_size = len(batch_data)
33
+ tensorized = OrderedDict()
34
+ for d in batch_data:
35
+ for k, v in d.items():
36
+ tensorized.setdefault(k, []).append(v)
37
+
38
+ tensorized['query'] = torch.LongTensor(tensorized['query'])
39
+ tensorized['query_mask'] = torch.LongTensor(tensorized['query_mask'])
40
+ tensorized['query_types'] = torch.LongTensor(tensorized['query_types'])
41
+ tensorized['query_pad_mask'] = \
42
+ torch.LongTensor(tensorized['query_pad_mask'])
43
+
44
+ tensorized['context'] = torch.LongTensor(tensorized['context'])
45
+ tensorized['context_mask'] = \
46
+ torch.LongTensor(tensorized['context_mask'])
47
+ tensorized['context_types'] = \
48
+ torch.LongTensor(tensorized['context_types'])
49
+ tensorized['context_pad_mask'] = \
50
+ torch.LongTensor(tensorized['context_pad_mask'])
51
+
52
+ if 'neg_context' in tensorized:
53
+ tensorized['neg_context'] = \
54
+ torch.LongTensor(np.concatenate(tensorized['neg_context']))
55
+ tensorized['neg_context_mask'] = \
56
+ torch.LongTensor(np.concatenate(tensorized['neg_context_mask']))
57
+ tensorized['neg_context_types'] = \
58
+ torch.LongTensor(np.concatenate(tensorized['neg_context_types']))
59
+
60
+ return tensorized
61
+
62
+
63
+
64
+ def process_batch(batch):
65
+ """Process batch and produce inputs for the model."""
66
+ query_tokens = batch['query'].long().cuda()
67
+ query_mask = (batch['query_mask'] < 0.5).cuda()
68
+ query_types = batch['query_types'].long().cuda()
69
+ query_pad_mask = batch['query_pad_mask'].long().cuda()
70
+
71
+ context_tokens = batch['context'].long().cuda()
72
+ context_mask = (batch['context_mask'] < 0.5).cuda()
73
+ context_types = batch['context_types'].long().cuda()
74
+ context_pad_mask = batch['context_pad_mask'].long().cuda()
75
+
76
+ if 'neg_context' in batch:
77
+ neg_context_tokens = batch['neg_context'].long().cuda()
78
+ neg_context_mask = (batch['neg_context_mask'] < 0.5).cuda()
79
+ neg_context_types = batch['neg_context_types'].long().cuda()
80
+ else:
81
+ neg_context_tokens = None
82
+ neg_context_mask = None
83
+ neg_context_types = None
84
+
85
+ reference = batch['reference']
86
+
87
+ return query_tokens, query_mask, query_types, query_pad_mask, \
88
+ context_tokens, context_mask, context_types, context_pad_mask, \
89
+ neg_context_tokens, neg_context_mask, neg_context_types, reference
90
+
91
+ def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
92
+ """Provide function that calculates accuracies."""
93
+ args = get_args()
94
+
95
+ print_rank_0("accuracy_func_provider is CALLED")
96
+
97
+ # Build dataloaders
98
+ datapath = args.valid_data
99
+ dataset = single_dataset_provider(datapath)
100
+
101
+ drop_last = False
102
+ if mpu.get_data_parallel_world_size() > 1 and not rank0sampler:
103
+ drop_last = True
104
+
105
+ print_rank_0(datapath)
106
+ print_rank_0(rank0sampler)
107
+
108
+ dataloader = build_data_loader(dataset,
109
+ args.eval_micro_batch_size,
110
+ num_workers=args.num_workers,
111
+ drop_last=drop_last,
112
+ task_collate_fn=task_collate_fn)
113
+ dataloaders = (dataset.dataset_name, dataloader)
114
+
115
+ def metrics_func(model, epoch, output_predictions=False):
116
+ print_rank_0('calculating metrics by accuracy func in ORQA...')
117
+
118
+ if output_predictions:
119
+ assert rank0sampler
120
+ names = 'predictions'
121
+ name, dataloader = dataloaders
122
+ if args.task == "RET-FINETUNE-NQ":
123
+ start_time = time.time()
124
+ output = retrieval_loss(model, dataloader)
125
+ stats_dict, total = output
126
+ format_string = ""
127
+ for k, v in stats_dict.items():
128
+ format_string += "|{} = {:.2f}".format(k, v / total)
129
+ print_rank_0("epoch:{}{}".format(epoch, format_string))
130
+ print_rank_0("taken time to calcuate metrics {:.3f}".format(\
131
+ time.time() - start_time))
132
+ else:
133
+ raise AssertionError("{} Task not supported".format(args.task))
134
+
135
+ return metrics_func
136
+
137
+
138
+ def retrieval_loss(model, dataloader):
139
+ args = get_args()
140
+ total = 0
141
+ topk_stats_dict = {'top{}_acc'.format(k): 0 for k in \
142
+ args.retriever_report_topk_accuracies}
143
+ stats_dict = dict(rank=0, **topk_stats_dict)
144
+
145
+ assert len(model) == 1
146
+ unwrapped_model = model[0]
147
+ unwrapped_model.eval()
148
+
149
+ with torch.no_grad():
150
+ # For all the batches in the dataset.
151
+ for batch in dataloader:
152
+ # Run the model forward.
153
+ query_tokens, query_mask, query_types, _, \
154
+ context_tokens, context_mask, context_types, _, \
155
+ neg_context_tokens, neg_context_mask, neg_context_types, \
156
+ reference = process_batch(batch)
157
+
158
+ query_logits, context_logits = unwrapped_model(query_tokens,
159
+ query_mask, query_types,
160
+ torch.cat([context_tokens, neg_context_tokens]),
161
+ torch.cat([context_mask, neg_context_mask]),
162
+ torch.cat([context_types, neg_context_types]))
163
+
164
+ retrieval_scores = torch.matmul(query_logits,
165
+ torch.transpose(context_logits, 0, 1))
166
+
167
+ if args.retriever_score_scaling:
168
+ retrieval_scores = retrieval_scores / \
169
+ math.sqrt(args.hidden_size)
170
+
171
+ local_batch_size = query_logits.shape[0]
172
+ labels = torch.arange(local_batch_size).long().cuda()
173
+
174
+ softmax_scores = F.softmax(retrieval_scores, dim=1)
175
+ sorted_vals, sorted_indices = torch.topk(softmax_scores,
176
+ k=softmax_scores.shape[1],
177
+ sorted=True)
178
+
179
+ def topk_accuracy(k):
180
+ return torch.cuda.FloatTensor(
181
+ [sum([int(labels[i] in sorted_indices[i, :k]) for i in \
182
+ range(local_batch_size)])])
183
+
184
+ def get_rank():
185
+ return torch.cuda.FloatTensor(
186
+ [sum([torch.nonzero(labels[i] == sorted_indices[i])[0][0] \
187
+ for i in range(local_batch_size)])])
188
+
189
+ topk_accs = [topk_accuracy(k) for k in \
190
+ args.retriever_report_topk_accuracies]
191
+ rank = get_rank()
192
+ losses = average_losses_across_data_parallel_group([rank, \
193
+ *topk_accs])
194
+
195
+ # create stats_dict with retrieval loss and all specified
196
+ # top-k accuracies
197
+ topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
198
+ zip(args.retriever_report_topk_accuracies, losses[1:])}
199
+ temp_stats_dict = dict(rank=losses[0], **topk_acc_dict)
200
+ for k in stats_dict.keys():
201
+ stats_dict[k] += temp_stats_dict[k]
202
+ total += local_batch_size
203
+
204
+ unwrapped_model.train()
205
+
206
+ return stats_dict, total
tasks/orqa/supervised/finetune.py ADDED
@@ -0,0 +1,251 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ORQA finetuning/evaluation."""
17
+
18
+ from functools import partial
19
+ import sys
20
+
21
+ import math
22
+ import torch
23
+ import torch.nn.functional as F
24
+
25
+ from megatron import get_args, get_timers, get_tokenizer
26
+ from megatron import mpu, print_rank_0
27
+ from megatron.indexer import IndexBuilder
28
+ from megatron.model.biencoder_model import biencoder_model_provider
29
+ from megatron.utils import average_losses_across_data_parallel_group
30
+ from pretrain_ict import get_group_world_size_rank
31
+ from tasks.finetune_utils import finetune
32
+ from tasks.orqa.supervised.eval_utils import accuracy_func_provider
33
+ from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
34
+ from tasks.orqa.evaluate_utils import ORQAEvaluator
35
+
36
+ # input_ is a 2D tensor
37
+ def check_and_append_tensor_for_gather(group, rank, world_size, input_):
38
+
39
+ # gather the size of the first dimension of the tensor from all ranks
40
+ current_length = input_.size()[0]
41
+ first_dim = torch.tensor([[current_length]],
42
+ device=torch.cuda.current_device())
43
+ input_list = [torch.empty_like(first_dim) for _ in range(world_size)]
44
+ input_list[rank].copy_(first_dim)
45
+ torch.distributed.all_gather(input_list, first_dim, group=group)
46
+ all_input_list = torch.cat(input_list, dim=0).contiguous()
47
+ max_length = torch.max(all_input_list)
48
+
49
+ # if the size are different than the max, extend the tensor
50
+ # accordingly
51
+ if max_length > current_length:
52
+ padding=tuple([0] * (input_.dim() * 2 - 1)) + \
53
+ tuple([max_length - current_length])
54
+ input_ = F.pad(input=input_, pad=padding)
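+ # e.g. a [3, L] tensor on this rank with max_length 5 across ranks
+ # becomes [5, L], zero-padded at the end of the first dimension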
55
+
56
+ return input_
57
+
58
+ def orqa(Dataset):
59
+
60
+ def cross_entropy_forward_step(batch, model):
61
+ """Simple forward step with cross-entropy loss."""
62
+ timers = get_timers()
63
+ tokenizer = get_tokenizer()
64
+
65
+ # Get the batch.
66
+ timers('batch generator').start()
67
+ try:
68
+ batch_ = next(batch)
69
+ except BaseException:
70
+ batch_ = batch
71
+
72
+ group, rank, world_size = get_group_world_size_rank()
73
+
74
+ query_tokens, query_mask, query_types, query_pad_mask, \
75
+ context_tokens, context_mask, context_types, context_pad_mask, \
76
+ neg_context_tokens, neg_context_mask, neg_context_types, \
77
+ reference = process_batch(batch_)
78
+
79
+ timers('batch generator').stop()
80
+ local_batch_size = query_tokens.shape[0]
81
+
82
+ # Text representation of query and context
83
+ query_list, context_list = [], []
84
+ for i in range(local_batch_size):
85
+ query_list.append(tokenizer.decode(query_tokens[i].tolist()))
86
+ context_list.append(tokenizer.decode(context_tokens[i].tolist()))
87
+
88
+ if neg_context_tokens is not None:
89
+ neg_context_tokens = check_and_append_tensor_for_gather(group,
90
+ rank, world_size, neg_context_tokens)
91
+ neg_context_mask = check_and_append_tensor_for_gather(group,
92
+ rank, world_size, neg_context_mask)
93
+ neg_context_types = check_and_append_tensor_for_gather(group,
94
+ rank, world_size, neg_context_types)
95
+
96
+ if neg_context_tokens is not None:
97
+ context_tokens = torch.cat([context_tokens, neg_context_tokens])
98
+ context_mask = torch.cat([context_mask, neg_context_mask])
99
+ context_types = torch.cat([context_types, neg_context_types])
100
+
101
+ # Forward model.
102
+ output_tensor = model(query_tokens, query_mask,
103
+ query_types, context_tokens,
104
+ context_mask, context_types)
105
+ return output_tensor, partial(cross_entropy_loss_func, query_tokens, context_tokens)
106
+
107
+
108
+ def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor):
109
+ args = get_args()
110
+
111
+ local_batch_size = query_tokens.shape[0]
112
+ group, rank, world_size = get_group_world_size_rank()
113
+ # recall we assert that model_parallel_size == 1
114
+ global_batch_size = world_size * local_batch_size
115
+
116
+ query_logits, context_logits = output_tensor
117
+
118
+ if world_size > 1:
119
+ input_ = torch.empty_like(context_logits).copy_(\
120
+ context_logits).detach_()
121
+ tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
122
+ tensor_list[rank].copy_(input_)
123
+ torch.distributed.all_gather(tensor_list, input_, group=group)
124
+
125
+ # Check if all-gather happens in order
126
+ assert tensor_list[rank].sum().item() == \
127
+ context_logits.sum().item()
128
+
129
+ # Preserves the gradient
130
+ tensor_list[rank] = context_logits
131
+ all_context_logits = torch.cat(tensor_list, dim=0).contiguous()
132
+
133
+ # Query tensors
134
+ input_ = torch.empty_like(query_logits).copy_(\
135
+ query_logits).detach_()
136
+ tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
137
+ tensor_list[rank].copy_(input_)
138
+ torch.distributed.all_gather(tensor_list, input_, group=group)
139
+
140
+ # Check if all-gather happens in order
141
+ assert tensor_list[rank].sum().item() == query_logits.sum().item()
142
+
143
+ # Preserves the gradient
144
+ tensor_list[rank] = query_logits
145
+ all_query_logits = torch.cat(tensor_list, dim=0).contiguous()
146
+ else:
147
+ all_query_logits = query_logits
148
+ all_context_logits = context_logits
149
+
150
+ retrieval_scores = torch.matmul(all_query_logits,
151
+ torch.transpose(all_context_logits, 0, 1))
152
+ # Scaling the retrieval scores
153
+ if args.retriever_score_scaling:
154
+ retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)
155
+
156
+ if args.train_with_neg:
157
+ # if the world size is 3, local batch size is 4, and
158
+ # local context size is 8, what we want is
159
+ # labels = [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19]
160
+ labels = []
161
+ local_context_size = context_tokens.shape[0]
162
+ for i in range(world_size):
163
+ j = i * local_context_size
164
+ labels.extend(list(range(j, j + local_batch_size)))
165
+ labels = torch.LongTensor(labels).cuda()
166
+ assert len(labels) == global_batch_size
167
+ else:
168
+ labels = torch.arange(global_batch_size).long().cuda()
169
+
170
+ # Cross-entropy loss.
171
+ softmax_scores = F.log_softmax(retrieval_scores, dim=1)
172
+
173
+ loss = F.nll_loss(softmax_scores, labels, reduction='mean')
174
+
175
+ max_score, max_idxs = torch.max(softmax_scores, 1)
176
+ correct_predictions_count = (max_idxs == labels).sum().float()
177
+
178
+ # Reduce loss for logging.
179
+ reduced_loss = average_losses_across_data_parallel_group([loss, \
180
+ correct_predictions_count])
181
+
182
+ # Loss scaling for correct losses in Supervised Retrieval
183
+ loss = loss * mpu.get_data_parallel_world_size()
184
+
185
+ return loss, {'lm loss': reduced_loss[0],
186
+ 'correct_prediction_count': reduced_loss[1]}
187
+
188
+
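Stripped of the cross-rank gathers and hard-negative bookkeeping above, the loss reduces to an in-batch softmax over a query-context score matrix whose diagonal pairs are the positives. A minimal single-GPU sketch (world_size == 1, no extra negatives), assuming query_logits and context_logits are [batch, hidden] tensors:

    import math
    import torch
    import torch.nn.functional as F

    def in_batch_retrieval_loss(query_logits, context_logits, hidden_size=None):
        # Score every query against every context in the batch.
        scores = torch.matmul(query_logits, context_logits.t())
        if hidden_size is not None:
            # Optional scaling, mirroring args.retriever_score_scaling above.
            scores = scores / math.sqrt(hidden_size)
        # The i-th context is the positive for the i-th query.
        labels = torch.arange(scores.shape[0], device=scores.device)
        return F.nll_loss(F.log_softmax(scores, dim=1), labels, reduction='mean')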
189
+ def train_valid_datasets_provider():
190
+ """Build train and validation dataset."""
191
+ args = get_args()
192
+ tokenizer = get_tokenizer()
193
+
194
+ train_dataset = Dataset('training',
195
+ args.train_data,
196
+ tokenizer,
197
+ args.retriever_seq_length,
198
+ evaluate=False)
199
+ valid_dataset = Dataset('validation',
200
+ args.valid_data,
201
+ tokenizer,
202
+ args.retriever_seq_length,
203
+ evaluate=True)
204
+ return train_dataset, valid_dataset
205
+
206
+ def model_provider(pre_process=True, post_process=True):
207
+ """Build the model."""
208
+ args = get_args()
209
+ print_rank_0('building retriever model for {} ...'.format(args.task))
210
+
211
+ model = biencoder_model_provider(only_context_model=False,
212
+ only_query_model=False,
213
+ biencoder_shared_query_context_model=\
214
+ args.biencoder_shared_query_context_model,
215
+ pre_process=pre_process, post_process=post_process)
216
+
217
+ return model
218
+
219
+ def single_dataset_provider(datapath):
220
+ args = get_args()
221
+ tokenizer = get_tokenizer()
222
+
223
+ name = datapath[0].split('/')[-1].split('.')[0]
224
+ return Dataset(name,
225
+ datapath,
226
+ tokenizer,
227
+ args.retriever_seq_length,
228
+ evaluate=True)
229
+
230
+ def metrics_func_provider():
231
+ """Provide metrics callback function."""
232
+ return accuracy_func_provider(single_dataset_provider)
233
+
234
+ """Finetune/evaluate."""
235
+ finetune(train_valid_datasets_provider,
236
+ model_provider,
237
+ forward_step=cross_entropy_forward_step,
238
+ end_of_epoch_callback_provider=metrics_func_provider,
239
+ task_collate_fn=task_collate_fn)
240
+
241
+ def main():
242
+ args = get_args()
243
+
244
+ if args.task == 'RET-FINETUNE-NQ':
245
+ from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset
246
+ else:
247
+ raise NotImplementedError('ORQA task {} is not implemented.'.format(
248
+ args.task))
249
+
250
+ orqa(Dataset)
251
+
tasks/orqa/unsupervised/nq.py ADDED
@@ -0,0 +1,228 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ Data Loader for Google NQ dataset
18
+ """
19
+
20
+ from abc import ABC
21
+ import csv
22
+ from collections import OrderedDict
23
+ import numpy as np
24
+
25
+ import torch
26
+ from torch.utils.data import DataLoader
27
+ from torch.utils.data import Dataset, BatchSampler
28
+
29
+ from megatron import print_rank_0, get_args, get_tokenizer, mpu
30
+ from megatron.data.biencoder_dataset_utils import make_attention_mask
31
+
32
+ def get_nq_dataset(qa_data, split):
33
+ args = get_args()
34
+ tokenizer = get_tokenizer()
35
+
36
+ dataset = NQDataset('Google NQ {} Split'.format(split),
37
+ 'Google Natural Questions',
38
+ qa_data,
39
+ tokenizer,
40
+ args.retriever_seq_length)
41
+ return dataset
42
+
43
+
44
+ def process_nq_batch(batch):
45
+ query_tokens = batch['token_ids'].long().cuda()
46
+ query_mask = (batch['token_mask'] < 0.5).cuda()
47
+ query_types = batch['token_types'].long().cuda()
48
+ query_len = batch['seq_len'].long().cuda()
49
+ reference = batch['reference']
50
+
51
+ return query_tokens, query_mask, query_types, query_len, reference
52
+
53
+
54
+ class CustomDataLoader(DataLoader):
55
+ def __init__(self, dataset, eval=False, **kwargs):
56
+ if kwargs.get('collate_fn', None) is None:
57
+ kwargs['collate_fn'] = self._collate_fn
58
+ self.eval = eval
59
+ super().__init__(dataset, **kwargs)
60
+
61
+ def _collate_fn(self, batch_data):
62
+ # generate batch
63
+ batch_size = len(batch_data)
64
+ tensorized = OrderedDict()
65
+ for d in batch_data:
66
+ for k, v in d.items():
67
+ tensorized.setdefault(k, []).append(v)
68
+ assert len(tensorized) == 5
69
+
70
+ tensorized['token_ids'] = torch.LongTensor(tensorized['token_ids'])
71
+ tensorized['token_mask'] = torch.LongTensor(tensorized['token_mask'])
72
+ tensorized['token_types'] = torch.LongTensor(tensorized['token_types'])
73
+ tensorized['seq_len'] = torch.LongTensor(tensorized['seq_len'])
74
+ return tensorized
75
+
76
+
77
+ def get_one_epoch_nq_dataloader(dataset, micro_batch_size=None):
78
+ """Data loader. Note that batch-size is the local (per GPU) batch-size.
79
+ NOTE: This dataloader is not distributed !!!
80
+ """
81
+
82
+ args = get_args()
83
+ if micro_batch_size is None:
84
+ micro_batch_size = args.micro_batch_size
85
+ num_workers = args.num_workers
86
+
87
+ sampler = torch.utils.data.SequentialSampler(dataset)
88
+ # importantly, drop_last must be False to get all the data.
89
+ batch_sampler = BatchSampler(sampler,
90
+ batch_size=micro_batch_size,
91
+ drop_last=False)
92
+
93
+ # Data loader. Note that batch size is the per GPU batch size.
94
+ data_loader = CustomDataLoader(dataset,
95
+ batch_sampler=batch_sampler,
96
+ num_workers=num_workers,
97
+ pin_memory=True)
98
+ return data_loader
99
+
100
+
101
+ def build_tokens_types_paddings_from_text(src_text, tokenizer, max_seq_length):
102
+ """Build token types and paddings, trim if needed, and pad if needed."""
103
+
104
+ src_text_ids = tokenizer.tokenize(src_text)
105
+
106
+ return build_tokens_types_paddings_from_ids(src_text_ids,
107
+ max_seq_length,
108
+ tokenizer.cls,
109
+ tokenizer.sep,
110
+ tokenizer.pad)
111
+
112
+
113
+ def build_tokens_types_paddings_from_ids(src_ids, max_seq_length, cls_id, \
114
+ sep_id, pad_id):
115
+ """
116
+ Build token types and paddings, trim if needed, and pad if needed.
117
+
118
+ TODO: Design modular interface to reuse this function. This is getting
119
+ repeated multiple times in different tasks
120
+ """
121
+
122
+ enc_ids = []
123
+ tokentypes_enc = []
124
+
125
+ # [CLS].
126
+ enc_ids.append(cls_id)
127
+ tokentypes_enc.append(0)
128
+
129
+ # A.
130
+ len_src = len(src_ids)
131
+ enc_ids.extend(src_ids)
132
+ tokentypes_enc.extend([0] * len_src)
133
+
134
+ # Cap the size.
135
+ if len(enc_ids) > max_seq_length - 1:
136
+ enc_ids = enc_ids[0: max_seq_length - 1]
137
+ tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]
138
+
139
+ # [SEP].
140
+ enc_ids.append(sep_id)
141
+ tokentypes_enc.append(0)
142
+
143
+ num_tokens_enc = len(enc_ids)
144
+ # Padding.
145
+ padding_length = max_seq_length - len(enc_ids)
146
+ if padding_length > 0:
147
+ enc_ids.extend([pad_id] * padding_length)
148
+ tokentypes_enc.extend([pad_id] * padding_length)
149
+
150
+ return enc_ids, tokentypes_enc, num_tokens_enc
151
+
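A worked example of the layout the function above produces, with hypothetical token ids and max_seq_length = 8 (note that the token-type padding reuses pad_id rather than 0):

    enc_ids        = [cls, q1, q2, q3, sep, pad, pad, pad]
    tokentypes_enc = [0,   0,  0,  0,  0,   pad, pad, pad]
    num_tokens_enc = 5    # length before padding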
152
+
153
+ def build_sample(token_ids, token_types, num_tokens, reference):
154
+ """
155
+ Convert to numpy and return a sample consumed by the
156
+ batch producer.
157
+ """
158
+
159
+ token_ids = np.array(token_ids, dtype=np.int64)
160
+ token_types = np.array(token_types, dtype=np.int64)
161
+ token_mask = make_attention_mask(token_ids, token_ids)
162
+
163
+ sample = ({
164
+ 'token_ids': token_ids,
165
+ 'token_mask': token_mask,
166
+ 'token_types': token_types,
167
+ 'seq_len': num_tokens,
168
+ 'reference': reference
169
+ })
170
+ return sample
171
+
172
+
173
+ class NQDataset(ABC, Dataset):
174
+ """
175
+ Open Retrieval Question Answering evaluation using Google NQ dataset.
176
+ """
177
+
178
+ def __init__(self, task_name, dataset_name, datapath,
179
+ tokenizer, max_seq_length):
180
+ # Store inputs.
181
+ self.task_name = task_name
182
+ self.dataset_name = dataset_name
183
+ self.tokenizer = tokenizer
184
+ self.max_seq_length = max_seq_length
185
+ print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
186
+ self.dataset_name))
187
+ print_rank_0(datapath)
188
+ self.samples = self.process_samples_from_single_path(datapath)
189
+ print_rank_0(' >> total number of samples: {}'.format(\
190
+ len(self.samples)))
191
+
192
+ def __len__(self):
193
+ return len(self.samples)
194
+
195
+ def __getitem__(self, idx):
196
+ raw_sample = self.samples[idx]
197
+
198
+ ques_tokens, tokentypes_enc, num_tokens_ques = \
199
+ build_tokens_types_paddings_from_text(raw_sample['question'],
200
+ self.tokenizer, self.max_seq_length)
201
+
202
+ sample = build_sample(ques_tokens,
203
+ tokentypes_enc,
204
+ num_tokens_ques,
205
+ raw_sample['answers'])
206
+ return sample
207
+
208
+ @staticmethod
209
+ def process_samples_from_single_path(filename):
210
+ print_rank_0(' > Processing {} ...'.format(filename))
211
+ samples = []
212
+ total = 0
213
+
214
+ with open(filename, 'r') as ifile:
215
+ reader = csv.reader(ifile, delimiter='\t')
216
+ for row in reader:
217
+ question = row[0]
218
+ answers = eval(row[1])
219
+
220
+ sample = {'question': question, 'answers': answers}
221
+ total += 1
222
+ samples.append(sample)
223
+
224
+ if total % 1000 == 0:
225
+ print_rank_0(' > processed {} so far ...'.format(total))
226
+
227
+ print_rank_0(' >> processed {} samples.'.format(len(samples)))
228
+ return samples
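The reader above expects a tab-separated file whose first column is the question and whose second column is a Python-literal list of answer strings; a hypothetical row (\t marks the tab separator):

    when was the eiffel tower built\t["1887", "1889"]

Because the answer column is parsed with eval, ast.literal_eval would be a safer drop-in when the file is not fully trusted.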
tasks/orqa/unsupervised/qa_utils.py ADDED
@@ -0,0 +1,177 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ # All rights reserved.
4
+ #
5
+
6
+ # The following code has been taken from
7
+ # https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0
8
+ # licensed as of now. More details on the license can be found
9
+ # at https://github.com/facebookresearch/DPR/blob/master/LICENSE
10
+
11
+ """
12
+ Set of utilities for Q&A results validation tasks - Retriever passage
13
+ validation and Reader predicted answer validation
14
+ """
15
+
16
+ import collections
17
+ import logging
18
+ import string
19
+ import unicodedata
20
+ from functools import partial
21
+ from multiprocessing import Pool as ProcessPool
22
+ from typing import Tuple, List, Dict
23
+
24
+ import regex as re
25
+ from tasks.orqa.unsupervised.tokenizers import SimpleTokenizer
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits',\
30
+ 'questions_doc_hits'])
31
+
32
+ def calculate_matches(all_docs: Dict[object, Tuple[str, str]],
33
+ answers: List[List[str]], closest_docs: List[Tuple[List[object],
34
+ List[float]]], workers_num: int, match_type: str) -> QAMatchStats:
35
+ """
36
+ Evaluates answers presence in the set of documents. This function is
37
+ supposed to be used with a large collection of documents and results.
38
+ It internally forks multiple sub-processes for evaluation and then
39
+ merges results
40
+ :param all_docs: dictionary of the entire documents database.
41
+ doc_id -> (doc_text, title)
42
+ :param answers: list of answer lists, one list per question
43
+ :param closest_docs: document ids of the top results along with their
44
+ scores
45
+ :param workers_num: number of parallel worker processes used to process the data
46
+ :param match_type: type of answer matching. Refer to has_answer code for
47
+ available options
48
+ :return: matching information tuple.
49
+ top_k_hits - a list where the index is the number of top documents retrieved
50
+ and the value is the total number of valid matches across the entire
51
+ dataset.
52
+ questions_doc_hits - more detailed info with answer matches for every
53
+ question and every retrieved document
54
+ """
55
+ global dpr_all_documents
56
+ dpr_all_documents = all_docs
57
+
58
+ tok_opts = {}
59
+ tokenizer = SimpleTokenizer(**tok_opts)
60
+
61
+ processes = ProcessPool(
62
+ processes=workers_num,
63
+ )
64
+
65
+ logger.info('Matching answers in top docs...')
66
+
67
+ get_score_partial = partial(check_answer, match_type=match_type,
68
+ tokenizer=tokenizer)
69
+
70
+ questions_answers_docs = zip(answers, closest_docs)
71
+
72
+ scores = processes.map(get_score_partial, questions_answers_docs)
73
+
74
+ logger.info('Per question validation results len=%d', len(scores))
75
+
76
+ n_docs = len(closest_docs[0][0])
77
+ top_k_hits = [0] * n_docs
78
+ for question_hits in scores:
79
+ best_hit = next((i for i, x in enumerate(question_hits) if x), None)
80
+ if best_hit is not None:
81
+ top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]
82
+
83
+ return QAMatchStats(top_k_hits, scores)
84
+
85
+
86
+ def check_answer(questions_answers_docs, tokenizer, match_type) -> List[bool]:
87
+ """
88
+ Search through all the top docs to see if they have any of the answers.
89
+ """
90
+ answers, (doc_ids, doc_scores) = questions_answers_docs
91
+
92
+ global dpr_all_documents
93
+ hits = []
94
+
95
+ for i, doc_id in enumerate(doc_ids):
96
+ doc = dpr_all_documents[doc_id]
97
+ text = doc[0]
98
+
99
+ answer_found = False
100
+ if text is None: # cannot find the document for some reason
101
+ logger.warning("no doc in db")
102
+ hits.append(False)
103
+ continue
104
+
105
+ if has_answer(answers, text, tokenizer, match_type):
106
+ answer_found = True
107
+ hits.append(answer_found)
108
+ return hits
109
+
110
+
111
+ def has_answer(answers, text, tokenizer, match_type) -> bool:
112
+ """
113
+ Check if a document contains an answer string.
114
+ If `match_type` is string, token matching is done between the text
115
+ and answer.
116
+ If `match_type` is regex, we search the whole text with the regex.
117
+ """
118
+ text = _normalize(text)
119
+
120
+ if match_type == 'string':
121
+ # Answer is a list of possible strings
122
+ text = tokenizer.tokenize(text).words(uncased=True)
123
+
124
+ for single_answer in answers:
125
+ single_answer = _normalize(single_answer)
126
+ single_answer = tokenizer.tokenize(single_answer)
127
+ single_answer = single_answer.words(uncased=True)
128
+
129
+ for i in range(0, len(text) - len(single_answer) + 1):
130
+ if single_answer == text[i: i + len(single_answer)]:
131
+ return True
132
+
133
+ elif match_type == 'regex':
134
+ # Answer is a regex
135
+ for single_answer in answers:
136
+ single_answer = _normalize(single_answer)
137
+ if regex_match(text, single_answer):
138
+ return True
139
+ return False
140
+
141
+
142
+ def regex_match(text, pattern):
143
+ """Test if a regex pattern is contained within a text."""
144
+ try:
145
+ pattern = re.compile(
146
+ pattern,
147
+ flags=re.IGNORECASE + re.UNICODE + re.MULTILINE,
148
+ )
149
+ except BaseException:
150
+ return False
151
+ return pattern.search(text) is not None
152
+
153
+
154
+ # function for the reader model answer validation
155
+ def exact_match_score(prediction, ground_truth):
156
+ return _normalize_answer(prediction) == _normalize_answer(ground_truth)
157
+
158
+
159
+ def _normalize_answer(s):
160
+ def remove_articles(text):
161
+ return re.sub(r'\b(a|an|the)\b', ' ', text)
162
+
163
+ def white_space_fix(text):
164
+ return ' '.join(text.split())
165
+
166
+ def remove_punc(text):
167
+ exclude = set(string.punctuation)
168
+ return ''.join(ch for ch in text if ch not in exclude)
169
+
170
+ def lower(text):
171
+ return text.lower()
172
+
173
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
174
+
175
+
176
+ def _normalize(text):
177
+ return unicodedata.normalize('NFD', text)
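In the returned QAMatchStats, top_k_hits[k] counts the questions with at least one match among their top k+1 retrieved documents (the counters are cumulative), so top-k retrieval accuracy follows directly. A short sketch, assuming stats is the value returned by calculate_matches:

    num_questions = len(stats.questions_doc_hits)
    top_k_accuracy = [hits / num_questions for hits in stats.top_k_hits]
    print('top-1: {:.3f}  top-5: {:.3f}'.format(top_k_accuracy[0], top_k_accuracy[4]))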
tasks/orqa/unsupervised/tokenizers.py ADDED
@@ -0,0 +1,243 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ # All rights reserved.
4
+ #
5
+
6
+ # The following code has been taken from
7
+ # https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0
8
+ # licensed as of now. More details on the license can be found
9
+ # at https://github.com/facebookresearch/DPR/blob/master/LICENSE
10
+
11
+ """
12
+ Most of the tokenizer code here is copied from the DrQA codebase to avoid adding an extra dependency
13
+ """
14
+
15
+ import copy
16
+ import logging
17
+
18
+ import regex
19
+ import spacy
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ class Tokens(object):
25
+ """A class to represent a list of tokenized text."""
26
+ TEXT = 0
27
+ TEXT_WS = 1
28
+ SPAN = 2
29
+ POS = 3
30
+ LEMMA = 4
31
+ NER = 5
32
+
33
+ def __init__(self, data, annotators, opts=None):
34
+ self.data = data
35
+ self.annotators = annotators
36
+ self.opts = opts or {}
37
+
38
+ def __len__(self):
39
+ """The number of tokens."""
40
+ return len(self.data)
41
+
42
+ def slice(self, i=None, j=None):
43
+ """Return a view of the list of tokens from [i, j)."""
44
+ new_tokens = copy.copy(self)
45
+ new_tokens.data = self.data[i: j]
46
+ return new_tokens
47
+
48
+ def untokenize(self):
49
+ """Returns the original text (with whitespace reinserted)."""
50
+ return ''.join([t[self.TEXT_WS] for t in self.data]).strip()
51
+
52
+ def words(self, uncased=False):
53
+ """Returns a list of the text of each token
54
+
55
+ Args:
56
+ uncased: lower cases text
57
+ """
58
+ if uncased:
59
+ return [t[self.TEXT].lower() for t in self.data]
60
+ else:
61
+ return [t[self.TEXT] for t in self.data]
62
+
63
+ def offsets(self):
64
+ """Returns a list of [start, end) character offsets of each token."""
65
+ return [t[self.SPAN] for t in self.data]
66
+
67
+ def pos(self):
68
+ """Returns a list of part-of-speech tags of each token.
69
+ Returns None if this annotation was not included.
70
+ """
71
+ if 'pos' not in self.annotators:
72
+ return None
73
+ return [t[self.POS] for t in self.data]
74
+
75
+ def lemmas(self):
76
+ """Returns a list of the lemmatized text of each token.
77
+ Returns None if this annotation was not included.
78
+ """
79
+ if 'lemma' not in self.annotators:
80
+ return None
81
+ return [t[self.LEMMA] for t in self.data]
82
+
83
+ def entities(self):
84
+ """Returns a list of named-entity-recognition tags of each token.
85
+ Returns None if this annotation was not included.
86
+ """
87
+ if 'ner' not in self.annotators:
88
+ return None
89
+ return [t[self.NER] for t in self.data]
90
+
91
+ def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
92
+ """Returns a list of all ngrams from length 1 to n.
93
+
94
+ Args:
95
+ n: upper limit of ngram length
96
+ uncased: lower cases text
97
+ filter_fn: user function that takes in an ngram list and returns
98
+ True or False to keep or not keep the ngram
99
+ as_strings: return each ngram as a space-joined string instead of a list
100
+ """
101
+
102
+ def _skip(gram):
103
+ if not filter_fn:
104
+ return False
105
+ return filter_fn(gram)
106
+
107
+ words = self.words(uncased)
108
+ ngrams = [(s, e + 1)
109
+ for s in range(len(words))
110
+ for e in range(s, min(s + n, len(words)))
111
+ if not _skip(words[s:e + 1])]
112
+
113
+ # Concatenate into strings
114
+ if as_strings:
115
+ ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams]
116
+
117
+ return ngrams
118
+
119
+ def entity_groups(self):
120
+ """Group consecutive entity tokens with the same NER tag."""
121
+ entities = self.entities()
122
+ if not entities:
123
+ return None
124
+ non_ent = self.opts.get('non_ent', 'O')
125
+ groups = []
126
+ idx = 0
127
+ while idx < len(entities):
128
+ ner_tag = entities[idx]
129
+ # Check for entity tag
130
+ if ner_tag != non_ent:
131
+ # Chomp the sequence
132
+ start = idx
133
+ while (idx < len(entities) and entities[idx] == ner_tag):
134
+ idx += 1
135
+ groups.append((self.slice(start, idx).untokenize(), ner_tag))
136
+ else:
137
+ idx += 1
138
+ return groups
139
+
140
+
141
+ class Tokenizer(object):
142
+ """Base tokenizer class.
143
+ Tokenizers implement tokenize, which should return a Tokens class.
144
+ """
145
+
146
+ def tokenize(self, text):
147
+ raise NotImplementedError
148
+
149
+ def shutdown(self):
150
+ pass
151
+
152
+ def __del__(self):
153
+ self.shutdown()
154
+
155
+
156
+ class SimpleTokenizer(Tokenizer):
157
+ ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
158
+ NON_WS = r'[^\p{Z}\p{C}]'
159
+
160
+ def __init__(self, **kwargs):
161
+ """
162
+ Args:
163
+ annotators: None or empty set (only tokenizes).
164
+ """
165
+ self._regexp = regex.compile(
166
+ '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
167
+ flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
168
+ )
169
+ if len(kwargs.get('annotators', {})) > 0:
170
+ logger.warning('%s only tokenizes! Skipping annotators: %s' %
171
+ (type(self).__name__, kwargs.get('annotators')))
172
+ self.annotators = set()
173
+
174
+ def tokenize(self, text):
175
+ data = []
176
+ matches = [m for m in self._regexp.finditer(text)]
177
+ for i in range(len(matches)):
178
+ # Get text
179
+ token = matches[i].group()
180
+
181
+ # Get whitespace
182
+ span = matches[i].span()
183
+ start_ws = span[0]
184
+ if i + 1 < len(matches):
185
+ end_ws = matches[i + 1].span()[0]
186
+ else:
187
+ end_ws = span[1]
188
+
189
+ # Format data
190
+ data.append((
191
+ token,
192
+ text[start_ws: end_ws],
193
+ span,
194
+ ))
195
+ return Tokens(data, self.annotators)
196
+
197
+
198
+ class SpacyTokenizer(Tokenizer):
199
+
200
+ def __init__(self, **kwargs):
201
+ """
202
+ Args:
203
+ annotators: set that can include pos, lemma, and ner.
204
+ model: spaCy model to use (either path, or keyword like 'en').
205
+ """
206
+ model = kwargs.get('model', 'en')
207
+ self.annotators = copy.deepcopy(kwargs.get('annotators', set()))
208
+ nlp_kwargs = {'parser': False}
209
+ if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
210
+ nlp_kwargs['tagger'] = False
211
+ if 'ner' not in self.annotators:
212
+ nlp_kwargs['entity'] = False
213
+ self.nlp = spacy.load(model, **nlp_kwargs)
214
+
215
+ def tokenize(self, text):
216
+ # We don't treat new lines as tokens.
217
+ clean_text = text.replace('\n', ' ')
218
+ tokens = self.nlp.tokenizer(clean_text)
219
+ if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
220
+ self.nlp.tagger(tokens)
221
+ if 'ner' in self.annotators:
222
+ self.nlp.entity(tokens)
223
+
224
+ data = []
225
+ for i in range(len(tokens)):
226
+ # Get whitespace
227
+ start_ws = tokens[i].idx
228
+ if i + 1 < len(tokens):
229
+ end_ws = tokens[i + 1].idx
230
+ else:
231
+ end_ws = tokens[i].idx + len(tokens[i].text)
232
+
233
+ data.append((
234
+ tokens[i].text,
235
+ text[start_ws: end_ws],
236
+ (tokens[i].idx, tokens[i].idx + len(tokens[i].text)),
237
+ tokens[i].tag_,
238
+ tokens[i].lemma_,
239
+ tokens[i].ent_type_,
240
+ ))
241
+
242
+ # Set special option for non-entity tag: '' vs 'O' in spaCy
243
+ return Tokens(data, self.annotators, opts={'non_ent': ''})
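A quick illustration of SimpleTokenizer, which matches runs of unicode letters/numbers/marks and otherwise keeps single non-whitespace characters:

    tok = SimpleTokenizer()
    tokens = tok.tokenize('Alexander Graham Bell invented the telephone.')
    tokens.words(uncased=True)
    # ['alexander', 'graham', 'bell', 'invented', 'the', 'telephone', '.']
    tokens.offsets()[0]
    # (0, 9)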
tasks/race/data.py ADDED
@@ -0,0 +1,135 @@
1
+
2
+ import glob
3
+ import json
4
+ import os
5
+ import time
6
+
7
+ from torch.utils.data import Dataset
8
+
9
+ from megatron import print_rank_0
10
+ from tasks.data_utils import build_sample
11
+ from tasks.data_utils import build_tokens_types_paddings_from_ids
12
+ from tasks.data_utils import clean_text
13
+
14
+
15
+ NUM_CHOICES = 4
16
+ MAX_QA_LENGTH = 128
17
+
18
+
19
+ class RaceDataset(Dataset):
20
+
21
+ def __init__(self, dataset_name, datapaths, tokenizer, max_seq_length,
22
+ max_qa_length=MAX_QA_LENGTH):
23
+
24
+ self.dataset_name = dataset_name
25
+ print_rank_0(' > building RACE dataset for {}:'.format(
26
+ self.dataset_name))
27
+
28
+ string = ' > paths:'
29
+ for path in datapaths:
30
+ string += ' ' + path
31
+ print_rank_0(string)
32
+
33
+ self.samples = []
34
+ for datapath in datapaths:
35
+ self.samples.extend(process_single_datapath(datapath, tokenizer,
36
+ max_qa_length,
37
+ max_seq_length))
38
+
39
+ print_rank_0(' >> total number of samples: {}'.format(
40
+ len(self.samples)))
41
+
42
+ # This indicates that each "sample" has multiple samples that
43
+ # will collapse into batch dimension
44
+ self.sample_multiplier = NUM_CHOICES
45
+
46
+ def __len__(self):
47
+ return len(self.samples)
48
+
49
+ def __getitem__(self, idx):
50
+ return self.samples[idx]
51
+
52
+
53
+ def process_single_datapath(datapath, tokenizer, max_qa_length, max_seq_length):
54
+ """Read in RACE files, combine, clean-up, tokenize, and convert to
55
+ samples."""
56
+
57
+ print_rank_0(' > working on {}'.format(datapath))
58
+ start_time = time.time()
59
+
60
+ # Get list of files.
61
+ filenames = glob.glob(os.path.join(datapath, '*.txt'))
62
+
63
+ samples = []
64
+ num_docs = 0
65
+ num_questions = 0
66
+ num_samples = 0
67
+ # Load all the files
68
+ for filename in filenames:
69
+ with open(filename, 'r') as f:
70
+ for line in f:
71
+ data = json.loads(line)
72
+ num_docs += 1
73
+
74
+ context = data["article"]
75
+ questions = data["questions"]
76
+ choices = data["options"]
77
+ answers = data["answers"]
78
+ # Check the length.
79
+ assert len(questions) == len(answers)
80
+ assert len(questions) == len(choices)
81
+
82
+ # Context: clean up and convert to ids.
83
+ context = clean_text(context)
84
+ context_ids = tokenizer.tokenize(context)
85
+
86
+ # Loop over questions.
87
+ for qi, question in enumerate(questions):
88
+ num_questions += 1
89
+ # Label.
90
+ label = ord(answers[qi]) - ord("A")
91
+ assert label >= 0
92
+ assert label < NUM_CHOICES
93
+ assert len(choices[qi]) == NUM_CHOICES
94
+
95
+ # For each question, build num-choices samples.
96
+ ids_list = []
97
+ types_list = []
98
+ paddings_list = []
99
+ for ci in range(NUM_CHOICES):
100
+ choice = choices[qi][ci]
101
+ # Merge with choice.
102
+ if "_" in question:
103
+ qa = question.replace("_", choice)
104
+ else:
105
+ qa = " ".join([question, choice])
106
+ # Clean QA.
107
+ qa = clean_text(qa)
108
+ # Tokenize.
109
+ qa_ids = tokenizer.tokenize(qa)
110
+ # Trim if needed.
111
+ if len(qa_ids) > max_qa_length:
112
+ qa_ids = qa_ids[0:max_qa_length]
113
+
114
+ # Build the sample.
115
+ ids, types, paddings \
116
+ = build_tokens_types_paddings_from_ids(
117
+ qa_ids, context_ids, max_seq_length,
118
+ tokenizer.cls, tokenizer.sep, tokenizer.pad)
119
+
120
+ ids_list.append(ids)
121
+ types_list.append(types)
122
+ paddings_list.append(paddings)
123
+
124
+ # Convert to numpy and add to samples
125
+ samples.append(build_sample(ids_list, types_list,
126
+ paddings_list, label,
127
+ num_samples))
128
+ num_samples += 1
129
+
130
+ elapsed_time = time.time() - start_time
131
+ print_rank_0(' > processed {} documents, {} questions, and {} samples'
132
+ ' in {:.2f} seconds'.format(num_docs, num_questions,
133
+ num_samples, elapsed_time))
134
+
135
+ return samples
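For reference, each input line is a JSON object carrying the four keys read above, with answers given as letters that map onto choice indices via ord(answer) - ord('A'); a schematic (hypothetical) record:

    {"article": "passage text ...",
     "questions": ["What did the author mean by _ ?"],
     "options": [["choice A", "choice B", "choice C", "choice D"]],
     "answers": ["B"]}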
tasks/race/finetune.py ADDED
@@ -0,0 +1,67 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Race."""
17
+
18
+ from megatron import get_args
19
+ from megatron import print_rank_0
20
+ from megatron import get_tokenizer
21
+ from megatron import mpu
22
+ from megatron.model.multiple_choice import MultipleChoice
23
+ from tasks.eval_utils import accuracy_func_provider
24
+ from tasks.finetune_utils import finetune
25
+ from tasks.race.data import RaceDataset
26
+
27
+
28
+ def train_valid_datasets_provider():
29
+ """Provide train and validation datasets."""
30
+ args = get_args()
31
+ tokenizer = get_tokenizer()
32
+
33
+ train_dataset = RaceDataset('training', args.train_data,
34
+ tokenizer, args.seq_length)
35
+ valid_dataset = RaceDataset('validation', args.valid_data,
36
+ tokenizer, args.seq_length)
37
+
38
+ return train_dataset, valid_dataset
39
+
40
+
41
+ def model_provider(pre_process=True, post_process=True):
42
+ """Build the model."""
43
+
44
+ print_rank_0('building multichoice model for RACE ...')
45
+ model = MultipleChoice(num_tokentypes=2,
46
+ pre_process=pre_process,
47
+ post_process=post_process)
48
+
49
+ return model
50
+
51
+
52
+ def metrics_func_provider():
53
+ """Provide metrics callback function."""
54
+ args = get_args()
55
+ tokenizer = get_tokenizer()
56
+
57
+ def single_dataset_provider(datapath):
58
+ name = datapath.split('RACE')[-1].strip('/').replace('/', '-')
59
+ return RaceDataset(name, [datapath], tokenizer, args.seq_length)
60
+
61
+ return accuracy_func_provider(single_dataset_provider)
62
+
63
+
64
+ def main():
65
+
66
+ finetune(train_valid_datasets_provider, model_provider,
67
+ end_of_epoch_callback_provider=metrics_func_provider)
tasks/vision/classification/classification.py ADDED
@@ -0,0 +1,94 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Vision-classification finetuning/evaluation."""
17
+
18
+ import torch.nn.functional as F
19
+ from functools import partial
20
+ from megatron import get_args, get_timers
21
+ from megatron import print_rank_0
22
+ from megatron.model.vision.classification import VitClassificationModel
23
+ from megatron.data.vit_dataset import build_train_valid_datasets
24
+ from tasks.vision.classification.eval_utils import accuracy_func_provider
25
+ from tasks.vision.finetune_utils import finetune
26
+ from megatron.utils import average_losses_across_data_parallel_group
27
+
28
+
29
+ def classification():
30
+ def train_valid_datasets_provider():
31
+ """Build train and validation dataset."""
32
+ args = get_args()
33
+
34
+ train_ds, valid_ds = build_train_valid_datasets(
35
+ data_path=args.data_path,
36
+ image_size=(args.img_h, args.img_w),
37
+ )
38
+ return train_ds, valid_ds
39
+
40
+ def model_provider(pre_process=True, post_process=True):
41
+ """Build the model."""
42
+ args = get_args()
43
+
44
+ print_rank_0("building classification model for ImageNet ...")
45
+
46
+ return VitClassificationModel(num_classes=args.num_classes, finetune=True,
47
+ pre_process=pre_process, post_process=post_process)
48
+
49
+ def process_batch(batch):
50
+ """Process batch and produce inputs for the model."""
51
+ images = batch[0].cuda().contiguous()
52
+ labels = batch[1].cuda().contiguous()
53
+ return images, labels
54
+
55
+ def cross_entropy_loss_func(labels, output_tensor):
56
+ logits = output_tensor
57
+
58
+ # Cross-entropy loss.
59
+ loss = F.cross_entropy(logits.contiguous().float(), labels)
60
+
61
+ # Reduce loss for logging.
62
+ averaged_loss = average_losses_across_data_parallel_group([loss])
63
+
64
+ return loss, {'lm loss': averaged_loss[0]}
65
+
66
+ def _cross_entropy_forward_step(batch, model):
67
+ """Simple forward step with cross-entropy loss."""
68
+ timers = get_timers()
69
+
70
+ # Get the batch.
71
+ timers("batch generator").start()
72
+ try:
73
+ batch_ = next(batch)
74
+ except BaseException:
75
+ batch_ = batch
76
+ images, labels = process_batch(batch_)
77
+ timers("batch generator").stop()
78
+
79
+ # Forward model.
80
+ output_tensor = model(images)
81
+
82
+ return output_tensor, partial(cross_entropy_loss_func, labels)
83
+
84
+ """Finetune/evaluate."""
85
+ finetune(
86
+ train_valid_datasets_provider,
87
+ model_provider,
88
+ forward_step=_cross_entropy_forward_step,
89
+ end_of_epoch_callback_provider=accuracy_func_provider,
90
+ )
91
+
92
+ def main():
93
+ classification()
94
+
tasks/vision/classification/eval_utils.py ADDED
@@ -0,0 +1,129 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Evaluation utilities."""
17
+
18
+ import os
19
+ from functools import partial
20
+
21
+ import torch
22
+
23
+ from megatron import get_args
24
+ from megatron import print_rank_0, print_rank_last
25
+ from megatron import mpu
26
+ from megatron.schedules import get_forward_backward_func
27
+ from tasks.vision.finetune_utils import build_data_loader
28
+ from tasks.vision.finetune_utils import process_batch
29
+ from torchvision import datasets, transforms
30
+
31
+
32
+ def accuracy_func_provider():
33
+ """Provide function that calculates accuracies."""
34
+ args = get_args()
35
+ data_path = args.data_path
36
+ crop_size = (args.img_h, args.img_w)
37
+
38
+ # Build dataloaders.
39
+ val_data_path = data_path[1]
40
+ normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
41
+ transform_val = transforms.Compose(
42
+ [
43
+ transforms.Resize(crop_size),
44
+ transforms.CenterCrop(crop_size),
45
+ transforms.ToTensor(),
46
+ normalize,
47
+ ]
48
+ )
49
+ dataset = datasets.ImageFolder(root=val_data_path, transform=transform_val)
50
+
51
+ dataloader = build_data_loader(
52
+ dataset,
53
+ args.micro_batch_size,
54
+ num_workers=args.num_workers,
55
+ drop_last=(mpu.get_data_parallel_world_size() > 1),
56
+ shuffle=False
57
+ )
58
+
59
+ def metrics_func(model, epoch):
60
+ print_rank_0("calculating metrics ...")
61
+ correct, total = calculate_correct_answers(model, dataloader, epoch)
62
+ percent = float(correct) * 100.0 / float(total)
63
+ print_rank_last(
64
+ " >> |epoch: {}| overall: correct / total = {} / {} = "
65
+ "{:.4f} %".format(epoch, correct, total, percent)
66
+ )
67
+
68
+ return metrics_func
69
+
70
+
71
+ def calculate_correct_answers(model, dataloader, epoch):
72
+ """Calculate correct over total answers"""
73
+
74
+ forward_backward_func = get_forward_backward_func()
75
+ for m in model:
76
+ m.eval()
77
+
78
+ def loss_func(labels, output_tensor):
79
+ logits = output_tensor
80
+
81
+ loss_dict = {}
82
+ # Compute the correct answers.
83
+ predicted = torch.argmax(logits, dim=-1)
84
+ corrects = (predicted == labels).float()
85
+ # Add to the counters.
86
+ loss_dict['total'] = labels.size(0)
87
+ loss_dict['correct'] = corrects.sum().item()
88
+
89
+ return 0, loss_dict
90
+
91
+ #defined inside to capture output_predictions
92
+ def correct_answers_forward_step(batch, model):
93
+ try:
94
+ batch_ = next(batch)
95
+ except BaseException:
96
+ batch_ = batch
97
+ images, labels = process_batch(batch_)
98
+
99
+ # Forward model.
100
+ output_tensor = model(images)
101
+
102
+ return output_tensor, partial(loss_func, labels)
103
+
104
+ with torch.no_grad():
105
+ # For all the batches in the dataset.
106
+ total = 0
107
+ correct = 0
108
+ for _, batch in enumerate(dataloader):
109
+
110
+ loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
111
+ optimizer=None, timers=None, forward_only=True)
112
+
113
+ for loss_dict in loss_dicts:
114
+ total += loss_dict['total']
115
+ correct += loss_dict['correct']
116
+
117
+ for m in model:
118
+ m.train()
119
+
120
+ # Reduce.
121
+ if mpu.is_pipeline_last_stage():
122
+ unreduced = torch.cuda.LongTensor([correct, total])
123
+ torch.distributed.all_reduce(unreduced,
124
+ group=mpu.get_data_parallel_group())
125
+
126
+ # Print on screen.
127
+ correct_ans = unreduced[0].item()
128
+ total_count = unreduced[1].item()
129
+ return correct_ans, total_count
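The per-rank correct/total counters are summed across the data-parallel group before the percentage is computed; the essence of the reduction, assuming torch.distributed is initialized:

    counts = torch.cuda.LongTensor([correct, total])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    accuracy = 100.0 * counts[0].item() / counts[1].item()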
tasks/vision/finetune_utils.py ADDED
@@ -0,0 +1,312 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Finetune utilities."""
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from megatron import get_args
21
+ from megatron import print_rank_0
22
+ from megatron import get_timers
23
+ from megatron import mpu, utils
24
+ from megatron.checkpointing import load_checkpoint
25
+ from megatron.checkpointing import save_checkpoint
26
+ from megatron.training import evaluate_and_print_results
27
+ from megatron.training import setup_model_and_optimizer
28
+ from megatron.training import train_step
29
+ from megatron.training import training_log
30
+ from megatron.utils import check_adlr_autoresume_termination
31
+ from megatron.utils import average_losses_across_data_parallel_group, print_params_min_max_norm
32
+ from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
33
+ from megatron.model import DistributedDataParallel as LocalDDP
34
+ from megatron.model import Float16Module, ModelType
35
+
36
+
37
+ def process_batch(batch):
38
+ """Process batch and produce inputs for the model."""
39
+ images = batch[0].cuda().contiguous()
40
+ labels = batch[1].cuda().contiguous()
41
+ return images, labels
42
+
43
+
44
+ def build_data_loader(dataset, micro_batch_size,
45
+ num_workers, drop_last, shuffle):
46
+ """Data loader. Note that batch-size is the local (per GPU) batch-size."""
47
+
48
+ # Sampler.
49
+ world_size = mpu.get_data_parallel_world_size()
50
+ rank = mpu.get_data_parallel_rank()
51
+ sampler = torch.utils.data.distributed.DistributedSampler(
52
+ dataset, num_replicas=world_size, rank=rank,
53
+ drop_last=drop_last, shuffle=shuffle
54
+ )
55
+
56
+ # Data loader. Note that batch size is the per GPU batch size.
57
+ data_loader = torch.utils.data.DataLoader(
58
+ dataset,
59
+ batch_size=micro_batch_size,
60
+ sampler=sampler,
61
+ shuffle=False,
62
+ num_workers=num_workers,
63
+ drop_last=drop_last,
64
+ pin_memory=True,
65
+ )
66
+
67
+ return data_loader
68
+
69
+
70
+ def _build_infinite_size_dataloader(dataloader):
71
+ """Build a looped dataloader with infinite size."""
72
+
73
+ iterator = dataloader.__iter__()
74
+ while True:
75
+ try:
76
+ yield iterator.__next__()
77
+ except StopIteration:
78
+ iterator = dataloader.__iter__()
79
+
80
+
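The generator above restarts the underlying loader whenever it is exhausted, so callers can draw validation batches indefinitely without handling StopIteration:

    valid_iter = _build_infinite_size_dataloader(valid_dataloader)
    batch = next(valid_iter)  # wraps around at the end of the dataset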
81
+ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
82
+ """Training and validation dataloaders."""
83
+ args = get_args()
84
+
85
+ print_rank_0('building train and validation dataloaders ...')
86
+ # Training dataset.
87
+ train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
88
+ args.num_workers, False, True)
89
+ # Set the training iterations.
90
+ args.train_iters_per_epoch = len(train_dataloader)
91
+ args.train_iters = args.epochs * args.train_iters_per_epoch
92
+ # Validation dataset. For this dataset, we do not need to set up
93
+ # shuffling so we can just use a simple infinite loop.
94
+ valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
95
+ args.num_workers, True, False)
96
+ valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
97
+
98
+ # Now that we've built the data loaders, set batch_size arguments
99
+ # to the actual batch size the model will see for this dataset.
100
+ # This is necessary so pipeline transfers know what size they are
101
+ # and the LR schedule, which is based on samples seen, gets set
102
+ # correctly.
103
+ args.orig_micro_batch_size = args.micro_batch_size
104
+ args.orig_global_batch_size = args.global_batch_size
105
+
106
+ return train_dataloader, valid_dataloader
107
+
108
+
109
+ def _train(
110
+ model,
111
+ optimizer,
112
+ opt_param_scheduler,
113
+ forward_step,
114
+ train_dataloader,
115
+ valid_dataloader,
116
+ end_of_epoch_callback,
117
+ process_non_loss_data_func=None
118
+ ):
119
+ """Train the model."""
120
+ args = get_args()
121
+ timers = get_timers()
122
+
123
+ # Turn on training mode which enables dropout.
124
+ for m in model:
125
+ m.train()
126
+
127
+ # Tracking loss.
128
+ losses_dict_sum = {}
129
+
130
+ # Starting epoch and iteration
131
+ start_epoch = args.iteration // args.train_iters_per_epoch
132
+ start_iteration = args.iteration % args.train_iters_per_epoch
133
+ iteration = args.iteration
134
+
135
+ # Memory reporting flag.
136
+ report_memory_flag = True
137
+
138
+ # For each remaining epoch
139
+ timers("interval-time").start()
140
+ for epoch in range(start_epoch, args.epochs):
141
+ print_rank_0("working on epoch {} ...".format(epoch + 1))
142
+
143
+ # Set the data loader epoch to shuffle the index iterator.
144
+ train_dataloader.sampler.set_epoch(args.seed + epoch)
145
+ train_dataloader.dataset.set_epoch(epoch)
146
+
147
+ # For all the batches in the dataset.
148
+ for iteration_, batch in enumerate(train_dataloader):
149
+
150
+ # Ignore the iterations before starting value
151
+ if iteration_ < start_iteration:
152
+ continue
153
+ # Set to zero so the next epoch does not skip any batches.
154
+ start_iteration = 0
155
+
156
+ # Train for one step.
157
+ losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
158
+ forward_step, batch, model, optimizer, opt_param_scheduler
159
+ )
160
+ iteration += 1
161
+
162
+ # Logging.
163
+ params_norm = None
164
+
165
+ report_memory_flag = training_log(
166
+ losses_dict,
167
+ losses_dict_sum,
168
+ optimizer.param_groups[0]["lr"],
169
+ iteration,
170
+ optimizer.get_loss_scale().item(),
171
+ report_memory_flag,
172
+ skipped_iter,
173
+ grad_norm,
174
+ params_norm,
175
+ num_zeros_in_grad
176
+ )
177
+
178
+ # Autoresume
179
+ if args.adlr_autoresume and \
180
+ iteration % args.adlr_autoresume_interval == 0:
181
+ check_adlr_autoresume_termination(iteration, model, optimizer,
182
+ opt_param_scheduler)
183
+
184
+ # Checkpointing
185
+ if args.save and args.save_interval and \
186
+ iteration % args.save_interval == 0:
187
+ save_checkpoint(iteration, model, optimizer,
188
+ opt_param_scheduler)
189
+
190
+ # Evaluation
191
+ if args.eval_interval and iteration % args.eval_interval == 0:
192
+ prefix = "iteration {}".format(iteration)
193
+ evaluate_and_print_results(
194
+ prefix,
195
+ forward_step,
196
+ valid_dataloader,
197
+ model,
198
+ iteration,
199
+ process_non_loss_data_func,
200
+ False,
201
+ )
202
+
203
+ # Callback at the end of each epoch.
204
+ if end_of_epoch_callback is not None:
205
+ end_of_epoch_callback(model, epoch)
206
+
207
+
208
+ def finetune(
209
+ train_valid_datasets_provider,
210
+ model_provider,
211
+ forward_step,
212
+ model_type=ModelType.encoder_or_decoder,
213
+ process_non_loss_data_func=None,
214
+ end_of_epoch_callback_provider=None,
215
+ ):
216
+ """Main finetune function used across all tasks."""
217
+ args = get_args()
218
+ timers = get_timers()
219
+
220
+ # Train and validation data loaders.
221
+ timers("train/valid/test dataset/dataloader").start()
222
+ if args.epochs > 0:
223
+ train_dataset, valid_dataset = train_valid_datasets_provider()
224
+ train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
225
+ train_dataset, valid_dataset
226
+ )
227
+ timers("train/valid/test dataset/dataloader").stop()
228
+
229
+ # Build callback function.
230
+ timers("callback function").start()
231
+ end_of_epoch_callback = None
232
+ if end_of_epoch_callback_provider is not None:
233
+ end_of_epoch_callback = end_of_epoch_callback_provider()
234
+ timers("callback function").stop()
235
+
236
+ # Build model, optimizer and learning rate scheduler.
237
+ timers("model and optimizer").start()
238
+ model, optimizer, opt_param_scheduler = \
239
+ setup_model_and_optimizer(
240
+ model_provider,
241
+ model_type,
242
+ scale_lr_cond=lambda name, param: ".head." in name,
243
+ lr_mult=args.head_lr_mult)
244
+ timers("model and optimizer").stop()
245
+
246
+ # If pretrained checkpoint is provided and we have not trained for
247
+ # any iteration (i.e., iteration is zero), then load the pretrained
248
+ # checkpoint.
249
+ timers("pretrained checkpoint").start()
250
+ if args.iteration == 0 and args.pretrained_checkpoint is not None:
251
+ if args.pretrained_checkpoint_type == 'default':
252
+ original_load = args.load
253
+ args.load = args.pretrained_checkpoint
254
+ _ = load_checkpoint(model, None, None, strict=False)
255
+ args.load = original_load
256
+ elif args.pretrained_checkpoint_type == 'external':
257
+ unwrap_model = utils.unwrap_model(model)
258
+ state_dict = torch.load(args.pretrained_checkpoint,
259
+ map_location="cpu")
260
+ unwrap_model[0].module.backbone.load_state_dict(state_dict,
261
+ strict=False)
262
+ elif args.pretrained_checkpoint_type == 'constrastive':
263
+ unwrap_model = utils.unwrap_model(model)
264
+ state_dict = torch.load(args.pretrained_checkpoint,
265
+ map_location="cpu")
266
+ state_dict = state_dict["model"]
267
+ state_dict = {k.replace("teacher.backbone.", ""): v
268
+ for k, v in state_dict.items()
269
+ if k.startswith("teacher.backbone.")}
270
+ unwrap_model[0].module.backbone.load_state_dict(state_dict,
271
+ strict=False)
272
+ else:
273
+ raise Exception("pretrained checkpoint type {} not supported".format(args.pretrained_checkpoint_type))
274
+
275
+ # This is critical when only model is loaded. We should make sure
276
+ # master parameters are also updated.
277
+ optimizer.reload_model_params()
278
+
279
+ timers("pretrained checkpoint").stop()
280
+
281
+ # Print setup timing.
282
+ print_rank_0("done with setups ...")
283
+ timers.log(
284
+ [
285
+ "train/valid/test dataset/dataloader",
286
+ "callback function",
287
+ "model and optimizer",
288
+ "pretrained checkpoint",
289
+ ]
290
+ )
291
+ print_rank_0("training ...")
292
+
293
+ # Finetune the model.
294
+ if args.epochs > 0:
295
+ _train(
296
+ model,
297
+ optimizer,
298
+ opt_param_scheduler,
299
+ forward_step,
300
+ train_dataloader,
301
+ valid_dataloader,
302
+ end_of_epoch_callback,
303
+ process_non_loss_data_func,
304
+ )
305
+ # Or just evaluate.
306
+ else:
307
+ if end_of_epoch_callback is not None:
308
+ print_rank_0("evaluation only mode, setting epoch to -1")
309
+ end_of_epoch_callback(model, epoch=-1)
310
+
311
+ print_rank_0("done :-)")
312
+
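The scale_lr_cond hook passed to setup_model_and_optimizer above selects parameters whose names contain ".head." for a learning rate scaled by args.head_lr_mult; with hypothetical parameter names:

    cond = lambda name, param: '.head.' in name
    cond('module.backbone.blocks.0.mlp.weight', None)   # False -> base lr
    cond('module.head.fc.weight', None)                 # True  -> lr * head_lr_mult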
tasks/vision/main.py ADDED
@@ -0,0 +1,66 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Main tasks functionality."""
17
+
18
+ import os
19
+ import sys
20
+
21
+ sys.path.append(
22
+ os.path.abspath(
23
+ os.path.join(
24
+ os.path.join(os.path.dirname(__file__), os.path.pardir),
25
+ os.path.pardir,
26
+ )
27
+ )
28
+ )
29
+ from megatron import get_args
30
+ from megatron.initialize import initialize_megatron
31
+
32
+ def get_tasks_args(parser):
33
+ """Provide extra arguments required for tasks."""
34
+ group = parser.add_argument_group(title="tasks")
35
+
36
+ group.add_argument('--task', type=str, default='segment',
37
+ choices=['classify', 'segment_setr', 'segment_segformer'],
38
+ help='task name.')
39
+ group.add_argument("--epochs", type=int, default=None,
40
+ help="Number of finetuning epochs. Zero results in "
41
+ "evaluation only.")
42
+ group.add_argument('--pretrained-checkpoint-type', type=str, default='default',
43
+ choices=['default', 'external', 'constrastive'],
44
+ help='Type of pretrained checkpoint')
45
+ group.add_argument("--pretrained-checkpoint", type=str, default=None,
46
+ help="Pretrained checkpoint used for finetuning.")
47
+ group.add_argument('--seg-stride', type=int, default=None,
48
+ help='sliding window stride during evaluation')
49
+ return parser
50
+
51
+
52
+ if __name__ == "__main__":
53
+
54
+ initialize_megatron(extra_args_provider=get_tasks_args)
55
+ args = get_args()
56
+
57
+ if args.task == 'classify':
58
+ from tasks.vision.classification.classification import main
59
+ main()
60
+ elif args.task == 'segment_setr':
61
+ from tasks.vision.segmentation.finetune_setr import main
62
+ main()
63
+ elif args.task == 'segment_segformer':
64
+ from tasks.vision.segmentation.finetune_segformer import main
65
+ main()
66
+
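As an example of the task-specific flags defined above, running tasks/vision/main.py with --task classify, --epochs 10, and --pretrained-checkpoint pointing at a saved model fine-tunes the classifier (the usual Megatron model and distributed arguments are omitted here), while --epochs 0 runs the end-of-epoch evaluation callback only, as the --epochs help string notes.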
tasks/vision/segmentation/cityscapes.py ADDED
@@ -0,0 +1,207 @@
1
+ # BSD 3-Clause License
2
+ #
3
+ # Copyright (c) Soumith Chintala 2016,
4
+ # All rights reserved.
5
+ #
6
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Code taken from
# https://github.com/pytorch/vision/blob/main/torchvision/datasets/cityscapes.py
# and modified to change the max label index from 255 to 19 (num_classes).

import torch
import json
import os
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
import numpy as np
from torchvision.datasets.utils import extract_archive, verify_str_arg, iterable_to_str
from torchvision.datasets import VisionDataset
from PIL import Image
from megatron import print_rank_0


class Cityscapes(VisionDataset):
    """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.

    This modified copy always loads the semantic ``labelIds`` targets.

    Args:
        root (string): Root directory of dataset where directory ``leftImg8bit``
            and ``gtFine`` or ``gtCoarse`` are located.
        split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
            otherwise ``train``, ``train_extra`` or ``val``
        mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Example:
        Get the semantic segmentation target

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine')
            img, smnt = dataset[0]
    """
    num_classes = 19
    ignore_index = 19
    color_table = torch.tensor(
        [[128, 64, 128],
         [244, 35, 232],
         [70, 70, 70],
         [102, 102, 156],
         [190, 153, 153],
         [153, 153, 153],
         [250, 170, 30],
         [220, 220, 0],
         [107, 142, 35],
         [152, 251, 152],
         [70, 130, 180],
         [220, 20, 60],
         [255, 0, 0],
         [0, 0, 142],
         [0, 0, 70],
         [0, 60, 100],
         [0, 80, 100],
         [0, 0, 230],
         [119, 11, 32],
         [0, 0, 0]], dtype=torch.float, device='cuda')

    # Based on https://github.com/mcordts/cityscapesScripts
    CityscapesClass = namedtuple('CityscapesClass',
                                 ['name', 'id', 'train_id', 'category',
                                  'category_id', 'has_instances',
                                  'ignore_in_eval', 'color'])

    classes = [
        CityscapesClass('unlabeled', 0, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('ego vehicle', 1, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('rectification border', 2, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('out of roi', 3, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('static', 4, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('dynamic', 5, 19, 'void', 0, False, True, (111, 74, 0)),
        CityscapesClass('ground', 6, 19, 'void', 0, False, True, (81, 0, 81)),
        CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
        CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
        CityscapesClass('parking', 9, 19, 'flat', 1, False, True, (250, 170, 160)),
        CityscapesClass('rail track', 10, 19, 'flat', 1, False, True, (230, 150, 140)),
        CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
        CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
        CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
        CityscapesClass('guard rail', 14, 19, 'construction', 2, False, True, (180, 165, 180)),
        CityscapesClass('bridge', 15, 19, 'construction', 2, False, True, (150, 100, 100)),
        CityscapesClass('tunnel', 16, 19, 'construction', 2, False, True, (150, 120, 90)),
        CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
        CityscapesClass('polegroup', 18, 19, 'object', 3, False, True, (153, 153, 153)),
        CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
        CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
        CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
        CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
        CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
        CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
        CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
        CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
        CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
        CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
        CityscapesClass('caravan', 29, 19, 'vehicle', 7, True, True, (0, 0, 90)),
        CityscapesClass('trailer', 30, 19, 'vehicle', 7, True, True, (0, 0, 110)),
        CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
        CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
        CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
        CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
    ]

    # Map raw label ids to train ids.
    label2trainid = {label.id: label.train_id for label in classes}

    def __init__(
        self,
        root: str,
        split: str = "train",
        mode: str = "fine",
        resolution: int = 1024,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super(Cityscapes, self).__init__(root, transforms, transform, target_transform)
        self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
        self.images_dir = os.path.join(self.root, 'leftImg8bit_trainvaltest/leftImg8bit', split)
        self.targets_dir = os.path.join(self.root, 'gtFine_trainvaltest/gtFine', split)
        self.split = split
        self.resolution = resolution
        self.images = []
        self.targets = []

        for city in sorted(os.listdir(self.images_dir)):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in os.listdir(img_dir):
                target_name = '{}_{}_labelIds.png'.format(
                    file_name.split('_leftImg8bit')[0], self.mode)
                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(os.path.join(target_dir, target_name))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is the semantic segmentation mask.
        """
        image = Image.open(self.images[index]).convert('RGB')

        target = Image.open(self.targets[index])
        target = np.array(target)

        target_copy = target.copy()
        for k, v in Cityscapes.label2trainid.items():
            binary_target = (target == k)
            target_copy[binary_target] = v
        target = target_copy

        target = Image.fromarray(target.astype(np.uint8))

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        return len(self.images)
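
# Note (editorial): the remapping loop in __getitem__ above sends the 34 raw
# Cityscapes label ids to the 19 train ids via label2trainid, with every
# "ignore_in_eval" class collapsed onto ignore_index (19). A minimal sketch of
# the same remap done in one vectorized pass, assuming `target` is a 2-D
# uint8 array of raw ids in [0, 33]:
#
#     lut = np.full(34, Cityscapes.ignore_index, dtype=np.uint8)
#     for cls in Cityscapes.classes:
#         if cls.id >= 0:
#             lut[cls.id] = cls.train_id
#     target = lut[target]
#
# This is equivalent to the per-class boolean-mask loop; it is a sketch, not
# part of the original file.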
tasks/vision/segmentation/data.py ADDED
@@ -0,0 +1,154 @@
import random
import os
import math
import mmcv
import torch
import numpy as np
import torchvision.transforms as T
from torchvision import datasets
from torch.utils.data import Dataset
from megatron.data.autoaugment import ImageNetPolicy
from tasks.vision.segmentation.cityscapes import Cityscapes
import tasks.vision.segmentation.transforms as ET
from megatron import get_args
from PIL import Image, ImageOps


class VitSegmentationJointTransform():
    def __init__(self, train=True, resolution=None):
        self.train = train
        if self.train:
            self.transform0 = ET.RandomSizeAndCrop(resolution)
            self.transform1 = ET.RandomHorizontallyFlip()

    def __call__(self, img, mask):
        if self.train:
            img, mask = self.transform0(img, mask)
            img, mask = self.transform1(img, mask)
        return img, mask


class VitSegmentationImageTransform():
    def __init__(self, train=True, resolution=None):
        args = get_args()
        self.train = train
        assert args.fp16 or args.bf16
        self.data_type = torch.half if args.fp16 else torch.bfloat16
        self.mean_std = args.mean_std
        if self.train:
            assert resolution is not None
            self.transform = T.Compose([
                ET.PhotoMetricDistortion(),
                T.ToTensor(),
                T.Normalize(*self.mean_std),
                T.ConvertImageDtype(self.data_type)
            ])
        else:
            self.transform = T.Compose([
                T.ToTensor(),
                T.Normalize(*self.mean_std),
                T.ConvertImageDtype(self.data_type)
            ])

    def __call__(self, input):
        output = self.transform(input)
        return output


class VitSegmentationTargetTransform():
    def __init__(self, train=True, resolution=None):
        self.train = train

    def __call__(self, input):
        output = torch.from_numpy(np.array(input, dtype=np.int32)).long()
        return output


class RandomSeedSegmentationDataset(Dataset):
    def __init__(self,
                 dataset,
                 joint_transform,
                 image_transform,
                 target_transform):

        args = get_args()
        self.base_seed = args.seed
        self.curr_seed = self.base_seed
        self.dataset = dataset
        self.joint_transform = joint_transform
        self.image_transform = image_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.dataset)

    def set_epoch(self, epoch):
        self.curr_seed = self.base_seed + 100 * epoch

    def __getitem__(self, idx):
        seed = idx + self.curr_seed
        img, mask = self.dataset[idx]

        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        img, mask = self.joint_transform(img, mask)
        img = self.image_transform(img)
        mask = self.target_transform(mask)

        return img, mask
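
# Note (editorial): RandomSeedSegmentationDataset reseeds torch, random and
# numpy from (base_seed + 100 * epoch + idx) before applying the random
# augmentations, so every (epoch, sample) pair sees the same crop/flip/jitter
# across workers and restarts. A training loop is expected to call
# set_epoch(epoch) once per epoch; a minimal sketch, assuming `train_data` is
# the wrapped dataset (`num_epochs` and `loader` are placeholders, not names
# from this repo):
#
#     for epoch in range(num_epochs):
#         train_data.set_epoch(epoch)
#         for img, mask in loader:
#             ...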

def build_cityscapes_train_valid_datasets(data_path, image_size):
    args = get_args()
    args.num_classes = Cityscapes.num_classes
    args.ignore_index = Cityscapes.ignore_index
    args.color_table = Cityscapes.color_table
    args.mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_joint_transform = \
        VitSegmentationJointTransform(train=True, resolution=image_size)
    val_joint_transform = \
        VitSegmentationJointTransform(train=False, resolution=image_size)
    train_image_transform = \
        VitSegmentationImageTransform(train=True, resolution=image_size)
    val_image_transform = \
        VitSegmentationImageTransform(train=False, resolution=image_size)
    train_target_transform = \
        VitSegmentationTargetTransform(train=True, resolution=image_size)
    val_target_transform = \
        VitSegmentationTargetTransform(train=False, resolution=image_size)

    # training dataset
    train_data = Cityscapes(
        root=data_path[0],
        split='train',
        mode='fine',
        resolution=image_size
    )
    train_data = RandomSeedSegmentationDataset(
        train_data,
        joint_transform=train_joint_transform,
        image_transform=train_image_transform,
        target_transform=train_target_transform)

    # validation dataset
    val_data = Cityscapes(
        root=data_path[0],
        split='val',
        mode='fine',
        resolution=image_size
    )
    val_data = RandomSeedSegmentationDataset(
        val_data,
        joint_transform=val_joint_transform,
        image_transform=val_image_transform,
        target_transform=val_target_transform)

    return train_data, val_data


def build_train_valid_datasets(data_path, image_size):
    return build_cityscapes_train_valid_datasets(data_path, image_size)
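
# Note (editorial): args.mean_std set above is the standard ImageNet
# normalization (mean/std) used by most pretrained vision backbones. The
# Cityscapes class attributes (num_classes=19, ignore_index=19, color_table)
# are pushed onto args here so the finetune/eval scripts below can read them
# back with get_args().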
tasks/vision/segmentation/finetune_segformer.py ADDED
@@ -0,0 +1,251 @@
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Vision segmentation (Segformer) finetuning/evaluation."""

import numpy as np
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import mpu, print_rank_0, print_rank_last
from tasks.vision.finetune_utils import finetune
from tasks.vision.finetune_utils import build_data_loader
from megatron.utils import average_losses_across_data_parallel_group
from megatron.schedules import get_forward_backward_func
from tasks.vision.segmentation.data import build_train_valid_datasets
from tasks.vision.segmentation.seg_models import SegformerSegmentationModel
from megatron.model.vision.utils import resize


def calculate_iou(hist_data):
    acc = np.diag(hist_data).sum() / hist_data.sum()
    acc_cls = np.diag(hist_data) / hist_data.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    divisor = hist_data.sum(axis=1) + hist_data.sum(axis=0) - \
        np.diag(hist_data)
    iu = np.diag(hist_data) / divisor
    return iu, acc, acc_cls


def fast_hist(pred, gtruth, num_classes):
    # mask indicates pixels we care about
    mask = (gtruth >= 0) & (gtruth < num_classes)

    # stretch ground truth labels by num_classes
    #   class 0  -> 0
    #   class 1  -> 19
    #   class 18 -> 342
    #
    # TP at 0 + 0, 1 + 1, 2 + 2 ...
    #
    # TP exist where value == num_classes * class_id + class_id
    # FN = row[class].sum() - TP   (ground truth = class, predicted otherwise)
    # FP = col[class].sum() - TP   (predicted = class, ground truth otherwise)
    hist = np.bincount(num_classes * gtruth[mask].astype(int) + pred[mask],
                       minlength=num_classes ** 2)
    hist = hist.reshape(num_classes, num_classes)
    return hist
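
# Note (editorial): a tiny worked example of the fast_hist trick above, with
# num_classes = 3 instead of 19. For gtruth = [0, 1, 2, 1] and
# pred = [0, 2, 2, 1], the flattened index 3 * gt + pred is [0, 5, 8, 4], so
# bincount(minlength=9) reshaped to 3x3 gives
#
#     [[1, 0, 0],      # row = ground truth, col = prediction
#      [0, 1, 1],
#      [0, 0, 1]]
#
# The diagonal holds the per-class true positives, row sums minus the diagonal
# give false negatives, and column sums minus the diagonal give false
# positives -- exactly what calculate_iou consumes.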

def segmentation():

    def train_valid_datasets_provider():
        """Build train and validation datasets."""
        args = get_args()

        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        return train_ds, valid_ds

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()

        model = SegformerSegmentationModel(num_classes=args.num_classes,
                                           pre_process=pre_process,
                                           post_process=post_process)
        print_rank_0("model = {}".format(model))
        return model

    def process_batch(batch):
        """Process batch and produce inputs for the model."""
        images = batch[0].cuda().contiguous()
        masks = batch[1].cuda().contiguous()
        return images, masks

    def calculate_weight(masks, num_classes):
        bins = torch.histc(masks.float(), bins=num_classes,
                           min=0.0, max=num_classes)
        hist_norm = bins.float() / bins.sum()
        hist = ((bins != 0).float() * (1. - hist_norm)) + 1.0
        return hist

    def cross_entropy_loss_func(images, masks, output_tensor,
                                non_loss_data=False):
        args = get_args()
        ignore_index = args.ignore_index
        color_table = args.color_table
        logits = output_tensor.contiguous().float()
        logits = resize(logits, size=masks.shape[1:],
                        mode='bilinear', align_corners=False)

        # Cross-entropy loss.
        # weight = calculate_weight(masks, args.num_classes)
        loss = F.cross_entropy(logits, masks, ignore_index=ignore_index)

        if not non_loss_data:
            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])
            return loss, {'lm loss': averaged_loss[0]}
        else:
            seg_mask = logits.argmax(dim=1)
            output_mask = F.embedding(seg_mask, color_table).permute(0, 3, 1, 2)
            gt_mask = F.embedding(masks, color_table).permute(0, 3, 1, 2)
            return torch.cat((images, output_mask, gt_mask), dim=2), loss

    def _cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        timers = get_timers()

        # Get the batch.
        timers("batch generator").start()
        import types
        if isinstance(batch, types.GeneratorType):
            batch_ = next(batch)
        else:
            batch_ = batch
        images, masks = process_batch(batch_)
        timers("batch generator").stop()

        # Forward model.
        output_tensor = model(images)

        return output_tensor, partial(cross_entropy_loss_func, images, masks)

    def calculate_correct_answers(model, dataloader, epoch):
        """Calculate correct over total answers."""

        forward_backward_func = get_forward_backward_func()
        for m in model:
            m.eval()

        def loss_func(labels, output_tensor):
            args = get_args()
            logits = output_tensor
            logits = resize(logits, size=labels.shape[1:],
                            mode='bilinear', align_corners=False)

            loss_dict = {}
            # Compute the correct answers.
            probs = logits.contiguous().float().softmax(dim=1)
            max_probs, preds = torch.max(probs, 1)

            preds = preds.cpu().numpy()
            # args.ignore_index equals num_classes (19) for Cityscapes, so it
            # doubles as the class count here.
            performs = fast_hist(preds.flatten(),
                                 labels.cpu().numpy().flatten(),
                                 args.ignore_index)
            loss_dict['performs'] = performs
            return 0, loss_dict

        # defined inside to capture output_predictions
        def correct_answers_forward_step(batch, model):
            try:
                batch_ = next(batch)
            except BaseException:
                batch_ = batch
            images, labels = process_batch(batch_)

            # Forward model.
            output_tensor = model(images)

            return output_tensor, partial(loss_func, labels)

        with torch.no_grad():
            # For all the batches in the dataset.
            performs = None
            for _, batch in enumerate(dataloader):
                loss_dicts = forward_backward_func(correct_answers_forward_step,
                                                   batch, model,
                                                   optimizer=None,
                                                   timers=None,
                                                   forward_only=True)
                for loss_dict in loss_dicts:
                    if performs is None:
                        performs = loss_dict['performs']
                    else:
                        performs += loss_dict['performs']

        for m in model:
            m.train()
        # Reduce.
        if mpu.is_pipeline_last_stage():
            performs_tensor = torch.cuda.FloatTensor(performs)
            torch.distributed.all_reduce(performs_tensor,
                                         group=mpu.get_data_parallel_group())
            hist = performs_tensor.cpu().numpy()
            iu, acc, acc_cls = calculate_iou(hist)
            miou = np.nanmean(iu)

            return iu, miou

    def accuracy_func_provider():
        """Provide function that calculates accuracies."""
        args = get_args()

        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        dataloader = build_data_loader(
            valid_ds,
            args.micro_batch_size,
            num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1),
            shuffle=False
        )

        def metrics_func(model, epoch):
            print_rank_0("calculating metrics ...")
            iou, miou = calculate_correct_answers(model, dataloader, epoch)
            print_rank_last(
                " >> |epoch: {}| overall: iou = {}, "
                "miou = {:.4f} %".format(epoch, iou, miou * 100.0)
            )
        return metrics_func

    def dump_output_data(data, iteration, writer):
        for (output_tb, loss) in data:
            # output_tb[output_tb < 0] = 0
            # output_tb[output_tb > 1] = 1
            writer.add_images("image-outputseg-realseg", output_tb,
                              global_step=None, walltime=None,
                              dataformats='NCHW')

    # Finetune/evaluate.
    finetune(
        train_valid_datasets_provider,
        model_provider,
        forward_step=_cross_entropy_forward_step,
        process_non_loss_data_func=dump_output_data,
        end_of_epoch_callback_provider=accuracy_func_provider,
    )


def main():
    segmentation()
tasks/vision/segmentation/finetune_setr.py ADDED
@@ -0,0 +1,225 @@
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Vision segmentation (SETR) finetuning/evaluation."""

import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import mpu, print_rank_0, print_rank_last
from tasks.vision.finetune_utils import finetune
from tasks.vision.finetune_utils import build_data_loader
from megatron.utils import average_losses_across_data_parallel_group
from megatron.schedules import get_forward_backward_func
from tasks.vision.segmentation.metrics import CFMatrix
from tasks.vision.segmentation.data import build_train_valid_datasets
from tasks.vision.segmentation.seg_models import SetrSegmentationModel
from tasks.vision.segmentation.utils import slidingcrops, slidingjoins


def segmentation():
    def train_valid_datasets_provider():
        """Build train and validation datasets."""
        args = get_args()

        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        return train_ds, valid_ds

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()

        return SetrSegmentationModel(num_classes=args.num_classes,
                                     pre_process=pre_process,
                                     post_process=post_process)

    def process_batch(batch):
        """Process batch and produce inputs for the model."""
        images = batch[0].cuda().contiguous()
        masks = batch[1].cuda().contiguous()
        return images, masks

    def calculate_weight(masks, num_classes):
        # Down-weight frequent classes: each non-empty class gets weight
        # 1 + (1 - its pixel frequency); empty classes get weight 1.
        bins = torch.histc(masks.float(), bins=num_classes,
                           min=0.0, max=num_classes)
        hist_norm = bins.float() / bins.sum()
        hist = ((bins != 0).float() * (1. - hist_norm)) + 1.0
        return hist
64
+ def cross_entropy_loss_func(images, masks, output_tensor, non_loss_data=False):
65
+ args = get_args()
66
+ ignore_index = args.ignore_index
67
+ color_table = args.color_table
68
+ weight = calculate_weight(masks, args.num_classes)
69
+ logits = output_tensor.contiguous().float()
70
+ loss = F.cross_entropy(logits, masks, weight=weight, ignore_index=ignore_index)
71
+
72
+ if not non_loss_data:
73
+ # Reduce loss for logging.
74
+ averaged_loss = average_losses_across_data_parallel_group([loss])
75
+
76
+ return loss, {'lm loss': averaged_loss[0]}
77
+ else:
78
+ seg_mask = logits.argmax(dim=1)
79
+ output_mask = F.embedding(seg_mask, color_table).permute(0, 3, 1, 2)
80
+ gt_mask = F.embedding(masks, color_table).permute(0, 3, 1, 2)
81
+ return torch.cat((images, output_mask, gt_mask), dim=2), loss
82
+
83
+ def _cross_entropy_forward_step(batch, model):
84
+ """Simple forward step with cross-entropy loss."""
85
+ args = get_args()
86
+ timers = get_timers()
87
+
88
+ # Get the batch.
89
+ timers("batch generator").start()
90
+ import types
91
+ if isinstance(batch, types.GeneratorType):
92
+ batch_ = next(batch)
93
+ else:
94
+ batch_ = batch
95
+ images, masks = process_batch(batch_)
96
+ timers("batch generator").stop()
97
+
98
+ # Forward model.
99
+ if not model.training:
100
+ images, masks, _, _ = slidingcrops(images, masks)
101
+ #print_rank_0("images size = {}".format(images.size()))
102
+
103
+ if not model.training:
104
+ output_tensor = torch.cat([model(image) for image in torch.split(images, args.micro_batch_size)])
105
+ else:
106
+ output_tensor = model(images)
107
+
108
+ return output_tensor, partial(cross_entropy_loss_func, images, masks)
109
+
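
    # Note (editorial): at eval time SETR runs sliding-window inference: the
    # helper slidingcrops (tasks/vision/segmentation/utils.py) tiles each
    # full-resolution image into crops, every crop is scored above in
    # micro-batch-sized chunks, and slidingjoins later stitches the per-crop
    # predictions back to full resolution (it is passed max_probs, presumably
    # to resolve overlapping regions by confidence).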

    def calculate_correct_answers(model, dataloader, epoch):
        """Calculate correct over total answers."""

        forward_backward_func = get_forward_backward_func()
        for m in model:
            m.eval()

        def loss_func(labels, slices_info, img_size, output_tensor):
            args = get_args()
            logits = output_tensor

            loss_dict = {}
            # Compute the correct answers.
            probs = logits.contiguous().float().softmax(dim=1)
            max_probs, preds = torch.max(probs, 1)
            preds = preds.int()
            preds, labels = slidingjoins(preds, max_probs, labels,
                                         slices_info, img_size)
            _, performs = CFMatrix()(preds, labels, args.ignore_index)

            loss_dict['performs'] = performs
            return 0, loss_dict

        # defined inside to capture output_predictions
        def correct_answers_forward_step(batch, model):
            args = get_args()
            try:
                batch_ = next(batch)
            except BaseException:
                batch_ = batch
            images, labels = process_batch(batch_)

            assert not model.training
            images, labels, slices_info, img_size = slidingcrops(images, labels)
            # Forward model.
            output_tensor = torch.cat([model(image) for image in
                                       torch.split(images, args.micro_batch_size)])

            return output_tensor, partial(loss_func, labels, slices_info, img_size)

        with torch.no_grad():
            # For all the batches in the dataset.
            performs = None
            for _, batch in enumerate(dataloader):
                loss_dicts = forward_backward_func(correct_answers_forward_step,
                                                   batch, model,
                                                   optimizer=None,
                                                   timers=None,
                                                   forward_only=True)
                for loss_dict in loss_dicts:
                    if performs is None:
                        performs = loss_dict['performs']
                    else:
                        performs += loss_dict['performs']

        for m in model:
            m.train()
        # Reduce.
        if mpu.is_pipeline_last_stage():
            torch.distributed.all_reduce(performs,
                                         group=mpu.get_data_parallel_group())
            # performs[int(ch), :] = [nb_tp, nb_fp, nb_tn, nb_fn]
            true_positive = performs[:, 0]
            false_positive = performs[:, 1]
            false_negative = performs[:, 3]

            iou = true_positive / (true_positive + false_positive + false_negative)
            miou = iou[~torch.isnan(iou)].mean()

            return iou.tolist(), miou.item()

    def accuracy_func_provider():
        """Provide function that calculates accuracies."""
        args = get_args()

        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        dataloader = build_data_loader(
            valid_ds,
            args.micro_batch_size,
            num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1),
            shuffle=False
        )

        def metrics_func(model, epoch):
            print_rank_0("calculating metrics ...")
            iou, miou = calculate_correct_answers(model, dataloader, epoch)
            print_rank_last(
                " >> |epoch: {}| overall: iou = {}, "
                "miou = {:.4f} %".format(epoch, iou, miou * 100.0)
            )
        return metrics_func

    def dump_output_data(data, iteration, writer):
        for (output_tb, loss) in data:
            # output_tb[output_tb < 0] = 0
            # output_tb[output_tb > 1] = 1
            writer.add_images("image-outputseg-realseg", output_tb,
                              global_step=None, walltime=None,
                              dataformats='NCHW')

    # Finetune/evaluate.
    finetune(
        train_valid_datasets_provider,
        model_provider,
        forward_step=_cross_entropy_forward_step,
        process_non_loss_data_func=dump_output_data,
        end_of_epoch_callback_provider=accuracy_func_provider,
    )


def main():
    segmentation()
tasks/vision/segmentation/metrics.py ADDED
@@ -0,0 +1,594 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# copyright (c) go-hiroaki & Chokurei
# email: [email protected]
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

eps = 1e-6


def _binarize(y_data, threshold):
    """
    args:
        y_data : [float] 4-d tensor in [batch_size, channels, img_rows, img_cols]
        threshold : [float] [0.0, 1.0]
    return 4-d binarized y_data
    """
    y_data[y_data < threshold] = 0.0
    y_data[y_data >= threshold] = 1.0
    return y_data


def _argmax(y_data, dim):
    """
    args:
        y_data : 4-d tensor in [batch_size, chs, img_rows, img_cols]
        dim : int
    return 3-d [int] y_data
    """
    return torch.argmax(y_data, dim).int()


def _get_tp(y_pred, y_true):
    """
    args:
        y_true : [int] 3-d in [batch_size, img_rows, img_cols]
        y_pred : [int] 3-d in [batch_size, img_rows, img_cols]
    return [float] true_positive
    """
    return torch.sum(y_true * y_pred).float()


def _get_fp(y_pred, y_true):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
    return [float] false_positive
    """
    return torch.sum((1 - y_true) * y_pred).float()


def _get_tn(y_pred, y_true):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
    return [float] true_negative
    """
    return torch.sum((1 - y_true) * (1 - y_pred)).float()


def _get_fn(y_pred, y_true):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
    return [float] false_negative
    """
    return torch.sum(y_true * (1 - y_pred)).float()


def _get_weights(y_true, nb_ch):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        nb_ch : int
    return [float] weights
    """
    batch_size, img_rows, img_cols = y_true.shape
    pixels = batch_size * img_rows * img_cols
    weights = [torch.sum(y_true == ch).item() / pixels for ch in range(nb_ch)]
    return weights


class CFMatrix(object):
    def __init__(self, des=None):
        self.des = des

    def __repr__(self):
        return "ConfusionMatrix"

    def __call__(self, y_pred, y_true, ignore_index, threshold=0.5):
        """
        args:
            y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
            y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return confusion matrix
        """
        batch_size, img_rows, img_cols = y_pred.shape
        chs = ignore_index
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_tn = _get_tn(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            mperforms = [nb_tp, nb_fp, nb_tn, nb_fn]
            performs = None
        else:
            performs = torch.zeros(chs, 4).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_false_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_true_ch[y_true == ch] = 1
                y_false_ch[torch.logical_and((y_true != ch), (y_true != ignore_index))] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = torch.sum(y_false_ch * y_pred_ch).float()
                nb_tn = torch.sum(y_false_ch * (1 - y_pred_ch)).float()
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                performs[int(ch), :] = torch.FloatTensor([nb_tp, nb_fp, nb_tn, nb_fn])
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs
138
+ class OAAcc(object):
139
+ def __init__(self, des="Overall Accuracy"):
140
+ self.des = des
141
+
142
+ def __repr__(self):
143
+ return "OAcc"
144
+
145
+ def __call__(self, y_pred, y_true, threshold=0.5):
146
+ """
147
+ args:
148
+ y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
149
+ y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
150
+ threshold : [0.0, 1.0]
151
+ return (tp+tn)/total
152
+ """
153
+ batch_size, chs, img_rows, img_cols = y_true.shape
154
+ device = y_true.device
155
+ if chs == 1:
156
+ y_pred = _binarize(y_pred, threshold)
157
+ y_true = _binarize(y_true, threshold)
158
+ else:
159
+ y_pred = _argmax(y_pred, 1)
160
+ y_true = _argmax(y_true, 1)
161
+
162
+ nb_tp_tn = torch.sum(y_true == y_pred).float()
163
+ mperforms = nb_tp_tn / (batch_size * img_rows * img_cols)
164
+ performs = None
165
+ return mperforms, performs
166
+
167
+
168
+ class Precision(object):
169
+ def __init__(self, des="Precision"):
170
+ self.des = des
171
+
172
+ def __repr__(self):
173
+ return "Prec"
174
+
175
+ def __call__(self, y_pred, y_true, threshold=0.5):
176
+ """
177
+ args:
178
+ y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
179
+ y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
180
+ threshold : [0.0, 1.0]
181
+ return tp/(tp+fp)
182
+ """
183
+ batch_size, chs, img_rows, img_cols = y_true.shape
184
+ device = y_true.device
185
+ if chs == 1:
186
+ y_pred = _binarize(y_pred, threshold)
187
+ y_true = _binarize(y_true, threshold)
188
+ nb_tp = _get_tp(y_pred, y_true)
189
+ nb_fp = _get_fp(y_pred, y_true)
190
+ mperforms = nb_tp / (nb_tp + nb_fp + esp)
191
+ performs = None
192
+ else:
193
+ y_pred = _argmax(y_pred, 1)
194
+ y_true = _argmax(y_true, 1)
195
+ performs = torch.zeros(chs, 1).to(device)
196
+ weights = _get_weights(y_true, chs)
197
+ for ch in range(chs):
198
+ y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
199
+ y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
200
+ y_true_ch[y_true == ch] = 1
201
+ y_pred_ch[y_pred == ch] = 1
202
+ nb_tp = _get_tp(y_pred_ch, y_true_ch)
203
+ nb_fp = _get_fp(y_pred_ch, y_true_ch)
204
+ performs[int(ch)] = nb_tp / (nb_tp + nb_fp + esp)
205
+ mperforms = sum([i*j for (i, j) in zip(performs, weights)])
206
+ return mperforms, performs
207
+
208
+
209
+ class Recall(object):
210
+ def __init__(self, des="Recall"):
211
+ self.des = des
212
+
213
+ def __repr__(self):
214
+ return "Reca"
215
+
216
+ def __call__(self, y_pred, y_true, threshold=0.5):
217
+ """
218
+ args:
219
+ y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
220
+ y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
221
+ threshold : [0.0, 1.0]
222
+ return tp/(tp+fn)
223
+ """
224
+ batch_size, chs, img_rows, img_cols = y_true.shape
225
+ device = y_true.device
226
+ if chs == 1:
227
+ y_pred = _binarize(y_pred, threshold)
228
+ y_true = _binarize(y_true, threshold)
229
+ nb_tp = _get_tp(y_pred, y_true)
230
+ nb_fn = _get_fn(y_pred, y_true)
231
+ mperforms = nb_tp / (nb_tp + nb_fn + esp)
232
+ performs = None
233
+ else:
234
+ y_pred = _argmax(y_pred, 1)
235
+ y_true = _argmax(y_true, 1)
236
+ performs = torch.zeros(chs, 1).to(device)
237
+ weights = _get_weights(y_true, chs)
238
+ for ch in range(chs):
239
+ y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
240
+ y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
241
+ y_true_ch[y_true == ch] = 1
242
+ y_pred_ch[y_pred == ch] = 1
243
+ nb_tp = _get_tp(y_pred_ch, y_true_ch)
244
+ nb_fn = _get_fn(y_pred_ch, y_true_ch)
245
+ performs[int(ch)] = nb_tp / (nb_tp + nb_fn + esp)
246
+ mperforms = sum([i*j for (i, j) in zip(performs, weights)])
247
+ return mperforms, performs
248
+
249
+
250
+ class F1Score(object):
251
+ def __init__(self, des="F1Score"):
252
+ self.des = des
253
+
254
+ def __repr__(self):
255
+ return "F1Sc"
256
+
257
+ def __call__(self, y_pred, y_true, threshold=0.5):
258
+
259
+ """
260
+ args:
261
+ y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
262
+ y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
263
+ threshold : [0.0, 1.0]
264
+ return 2*precision*recall/(precision+recall)
265
+ """
266
+ batch_size, chs, img_rows, img_cols = y_true.shape
267
+ device = y_true.device
268
+ if chs == 1:
269
+ y_pred = _binarize(y_pred, threshold)
270
+ y_true = _binarize(y_true, threshold)
271
+ nb_tp = _get_tp(y_pred, y_true)
272
+ nb_fp = _get_fp(y_pred, y_true)
273
+ nb_fn = _get_fn(y_pred, y_true)
274
+ _precision = nb_tp / (nb_tp + nb_fp + esp)
275
+ _recall = nb_tp / (nb_tp + nb_fn + esp)
276
+ mperforms = 2 * _precision * _recall / (_precision + _recall + esp)
277
+ performs = None
278
+ else:
279
+ y_pred = _argmax(y_pred, 1)
280
+ y_true = _argmax(y_true, 1)
281
+ performs = torch.zeros(chs, 1).to(device)
282
+ weights = _get_weights(y_true, chs)
283
+ for ch in range(chs):
284
+ y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
285
+ y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
286
+ y_true_ch[y_true == ch] = 1
287
+ y_pred_ch[y_pred == ch] = 1
288
+ nb_tp = _get_tp(y_pred_ch, y_true_ch)
289
+ nb_fp = _get_fp(y_pred_ch, y_true_ch)
290
+ nb_fn = _get_fn(y_pred_ch, y_true_ch)
291
+ _precision = nb_tp / (nb_tp + nb_fp + esp)
292
+ _recall = nb_tp / (nb_tp + nb_fn + esp)
293
+ performs[int(ch)] = 2 * _precision * \
294
+ _recall / (_precision + _recall + esp)
295
+ mperforms = sum([i*j for (i, j) in zip(performs, weights)])
296
+ return mperforms, performs
297
+
298
+
299
+ class Kappa(object):
300
+ def __init__(self, des="Kappa"):
301
+ self.des = des
302
+
303
+ def __repr__(self):
304
+ return "Kapp"
305
+
306
+ def __call__(self, y_pred, y_true, threshold=0.5):
307
+
308
+ """
309
+ args:
310
+ y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
311
+ y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
312
+ threshold : [0.0, 1.0]
313
+ return (Po-Pe)/(1-Pe)
314
+ """
315
+ batch_size, chs, img_rows, img_cols = y_true.shape
316
+ device = y_true.device
317
+ if chs == 1:
318
+ y_pred = _binarize(y_pred, threshold)
319
+ y_true = _binarize(y_true, threshold)
320
+ nb_tp = _get_tp(y_pred, y_true)
321
+ nb_fp = _get_fp(y_pred, y_true)
322
+ nb_tn = _get_tn(y_pred, y_true)
323
+ nb_fn = _get_fn(y_pred, y_true)
324
+ nb_total = nb_tp + nb_fp + nb_tn + nb_fn
325
+ Po = (nb_tp + nb_tn) / nb_total
326
+ Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn) +
327
+ (nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total**2)
328
+ mperforms = (Po - Pe) / (1 - Pe + esp)
329
+ performs = None
330
+ else:
331
+ y_pred = _argmax(y_pred, 1)
332
+ y_true = _argmax(y_true, 1)
333
+ performs = torch.zeros(chs, 1).to(device)
334
+ weights = _get_weights(y_true, chs)
335
+ for ch in range(chs):
336
+ y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
337
+ y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
338
+ y_true_ch[y_true == ch] = 1
339
+ y_pred_ch[y_pred == ch] = 1
340
+ nb_tp = _get_tp(y_pred_ch, y_true_ch)
341
+ nb_fp = _get_fp(y_pred_ch, y_true_ch)
342
+ nb_tn = _get_tn(y_pred_ch, y_true_ch)
343
+ nb_fn = _get_fn(y_pred_ch, y_true_ch)
344
+ nb_total = nb_tp + nb_fp + nb_tn + nb_fn
345
+ Po = (nb_tp + nb_tn) / nb_total
346
+ Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn)
347
+ + (nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total**2)
348
+ performs[int(ch)] = (Po - Pe) / (1 - Pe + esp)
349
+ mperforms = sum([i*j for (i, j) in zip(performs, weights)])
350
+ return mperforms, performs
351
+
352
+
353
+ class Jaccard(object):
354
+ def __init__(self, des="Jaccard"):
355
+ self.des = des
356
+
357
+ def __repr__(self):
358
+ return "Jacc"
359
+
360
+ def __call__(self, y_pred, y_true, threshold=0.5):
361
+ """
362
+ args:
363
+ y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
364
+ y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
365
+ threshold : [0.0, 1.0]
366
+ return intersection / (sum-intersection)
367
+ """
368
+ batch_size, chs, img_rows, img_cols = y_true.shape
369
+ device = y_true.device
370
+ if chs == 1:
371
+ y_pred = _binarize(y_pred, threshold)
372
+ y_true = _binarize(y_true, threshold)
373
+ _intersec = torch.sum(y_true * y_pred).float()
374
+ _sum = torch.sum(y_true + y_pred).float()
375
+ mperforms = _intersec / (_sum - _intersec + esp)
376
+ performs = None
377
+ else:
378
+ y_pred = _argmax(y_pred, 1)
379
+ y_true = _argmax(y_true, 1)
380
+ performs = torch.zeros(chs, 1).to(device)
381
+ weights = _get_weights(y_true, chs)
382
+ for ch in range(chs):
383
+ y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
384
+ y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
385
+ y_true_ch[y_true == ch] = 1
386
+ y_pred_ch[y_pred == ch] = 1
387
+ _intersec = torch.sum(y_true_ch * y_pred_ch).float()
388
+ _sum = torch.sum(y_true_ch + y_pred_ch).float()
389
+ performs[int(ch)] = _intersec / (_sum - _intersec + esp)
390
+ mperforms = sum([i*j for (i, j) in zip(performs, weights)])
391
+ return mperforms, performs
392
+
393
+
394
+ class MSE(object):
395
+ def __init__(self, des="Mean Square Error"):
396
+ self.des = des
397
+
398
+ def __repr__(self):
399
+ return "MSE"
400
+
401
+ def __call__(self, y_pred, y_true, dim=1, threshold=None):
402
+ """
403
+ args:
404
+ y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
405
+ y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
406
+ threshold : [0.0, 1.0]
407
+ return mean_squared_error, smaller the better
408
+ """
409
+ if threshold:
410
+ y_pred = _binarize(y_pred, threshold)
411
+ return torch.mean((y_pred - y_true) ** 2)
412
+
413
+
414
+ class PSNR(object):
415
+ def __init__(self, des="Peak Signal to Noise Ratio"):
416
+ self.des = des
417
+
418
+ def __repr__(self):
419
+ return "PSNR"
420
+
421
+ def __call__(self, y_pred, y_true, dim=1, threshold=None):
422
+ """
423
+ args:
424
+ y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
425
+ y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
426
+ threshold : [0.0, 1.0]
427
+ return PSNR, larger the better
428
+ """
429
+ if threshold:
430
+ y_pred = _binarize(y_pred, threshold)
431
+ mse = torch.mean((y_pred - y_true) ** 2)
432
+ return 10 * torch.log10(1 / mse)
433
+
434
+
435
+ class SSIM(object):
436
+ '''
437
+ modified from https://github.com/jorge-pessoa/pytorch-msssim
438
+ '''
439
+ def __init__(self, des="structural similarity index"):
440
+ self.des = des
441
+
442
+ def __repr__(self):
443
+ return "SSIM"
444
+
445
+ def gaussian(self, w_size, sigma):
446
+ gauss = torch.Tensor([math.exp(-(x - w_size//2)**2/float(2*sigma**2)) for x in range(w_size)])
447
+ return gauss/gauss.sum()
448
+
449
+ def create_window(self, w_size, channel=1):
450
+ _1D_window = self.gaussian(w_size, 1.5).unsqueeze(1)
451
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
452
+ window = _2D_window.expand(channel, 1, w_size, w_size).contiguous()
453
+ return window
454
+
455
+ def __call__(self, y_pred, y_true, w_size=11, size_average=True, full=False):
456
+ """
457
+ args:
458
+ y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
459
+ y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
460
+ w_size : int, default 11
461
+ size_average : boolean, default True
462
+ full : boolean, default False
463
+ return ssim, larger the better
464
+ """
465
+ # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
466
+ if torch.max(y_pred) > 128:
467
+ max_val = 255
468
+ else:
469
+ max_val = 1
470
+
471
+ if torch.min(y_pred) < -0.5:
472
+ min_val = -1
473
+ else:
474
+ min_val = 0
475
+ L = max_val - min_val
476
+
477
+ padd = 0
478
+ (_, channel, height, width) = y_pred.size()
479
+ window = self.create_window(w_size, channel=channel).to(y_pred.device)
480
+
481
+ mu1 = F.conv2d(y_pred, window, padding=padd, groups=channel)
482
+ mu2 = F.conv2d(y_true, window, padding=padd, groups=channel)
483
+
484
+ mu1_sq = mu1.pow(2)
485
+ mu2_sq = mu2.pow(2)
486
+ mu1_mu2 = mu1 * mu2
487
+
488
+ sigma1_sq = F.conv2d(y_pred * y_pred, window, padding=padd, groups=channel) - mu1_sq
489
+ sigma2_sq = F.conv2d(y_true * y_true, window, padding=padd, groups=channel) - mu2_sq
490
+ sigma12 = F.conv2d(y_pred * y_true, window, padding=padd, groups=channel) - mu1_mu2
491
+
492
+ C1 = (0.01 * L) ** 2
493
+ C2 = (0.03 * L) ** 2
494
+
495
+ v1 = 2.0 * sigma12 + C2
496
+ v2 = sigma1_sq + sigma2_sq + C2
497
+ cs = torch.mean(v1 / v2) # contrast sensitivity
498
+
499
+ ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
500
+
501
+ if size_average:
502
+ ret = ssim_map.mean()
503
+ else:
504
+ ret = ssim_map.mean(1).mean(1).mean(1)
505
+
506
+ if full:
507
+ return ret, cs
508
+ return ret
509
+
510
+
511
+ class AE(object):
512
+ """
513
+ Modified from matlab : colorangle.m, MATLAB V2019b
514
+ angle = acos(RGB1' * RGB2 / (norm(RGB1) * norm(RGB2)));
515
+ angle = 180 / pi * angle;
516
+ """
517
+ def __init__(self, des='average Angular Error'):
518
+ self.des = des
519
+
520
+ def __repr__(self):
521
+ return "AE"
522
+
523
+ def __call__(self, y_pred, y_true):
524
+ """
525
+ args:
526
+ y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
527
+ y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
528
+ return average AE, smaller the better
529
+ """
530
+ dotP = torch.sum(y_pred * y_true, dim=1)
531
+ Norm_pred = torch.sqrt(torch.sum(y_pred * y_pred, dim=1))
532
+ Norm_true = torch.sqrt(torch.sum(y_true * y_true, dim=1))
533
+ ae = 180 / math.pi * torch.acos(dotP / (Norm_pred * Norm_true + eps))
534
+ return ae.mean(1).mean(1)
535
+
536
+
537
+ if __name__ == "__main__":
538
+ for ch in [3, 1]:
539
+ batch_size, img_row, img_col = 1, 224, 224
540
+ y_true = torch.rand(batch_size, ch, img_row, img_col)
541
+ noise = torch.zeros(y_true.size()).data.normal_(0, std=0.1)
542
+ y_pred = y_true + noise
543
+ for cuda in [False, True]:
544
+ if cuda:
545
+ y_pred = y_pred.cuda()
546
+ y_true = y_true.cuda()
547
+
548
+ print('#'*20, 'Cuda : {} ; size : {}'.format(cuda, y_true.size()))
549
+ ########### similarity metrics
550
+ metric = MSE()
551
+ acc = metric(y_pred, y_true).item()
552
+ print("{} ==> {}".format(repr(metric), acc))
553
+
554
+ metric = PSNR()
555
+ acc = metric(y_pred, y_true).item()
556
+ print("{} ==> {}".format(repr(metric), acc))
557
+
558
+ metric = SSIM()
559
+ acc = metric(y_pred, y_true).item()
560
+ print("{} ==> {}".format(repr(metric), acc))
561
+
562
+ metric = LPIPS(cuda)
563
+ acc = metric(y_pred, y_true).item()
564
+ print("{} ==> {}".format(repr(metric), acc))
565
+
566
+ metric = AE()
567
+ acc = metric(y_pred, y_true).item()
568
+ print("{} ==> {}".format(repr(metric), acc))
569
+
570
+ ########### accuracy metrics
571
+ metric = OAAcc()
572
+ maccu, accu = metric(y_pred, y_true)
573
+ print('mAccu:', maccu, 'Accu', accu)
574
+
575
+ metric = Precision()
576
+ mprec, prec = metric(y_pred, y_true)
577
+ print('mPrec:', mprec, 'Prec', prec)
578
+
579
+ metric = Recall()
580
+ mreca, reca = metric(y_pred, y_true)
581
+ print('mReca:', mreca, 'Reca', reca)
582
+
583
+ metric = F1Score()
584
+ mf1sc, f1sc = metric(y_pred, y_true)
585
+ print('mF1sc:', mf1sc, 'F1sc', f1sc)
586
+
587
+ metric = Kappa()
588
+ mkapp, kapp = metric(y_pred, y_true)
589
+ print('mKapp:', mkapp, 'Kapp', kapp)
590
+
591
+ metric = Jaccard()
592
+ mjacc, jacc = metric(y_pred, y_true)
593
+ print('mJacc:', mjacc, 'Jacc', jacc)
594
+
tasks/vision/segmentation/seg_heads.py ADDED
@@ -0,0 +1,140 @@
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model import LayerNorm
from megatron.model.module import MegatronModule
from megatron.model.vision.utils import resize


class SetrSegmentationHead(MegatronModule):
    def __init__(self, hidden_size, num_classes):
        super(SetrSegmentationHead, self).__init__()
        args = get_args()
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.patch_dim = args.patch_dim

        self.layernorm = LayerNorm(hidden_size, eps=args.layernorm_epsilon)
        self.conv_0 = torch.nn.Conv2d(hidden_size, hidden_size,
                                      1, 1, bias=False)
        self.norm_0 = apex.parallel.SyncBatchNorm(hidden_size)
        self.conv_1 = torch.nn.Conv2d(hidden_size, num_classes, 1, 1)

    def to_2D(self, x):
        n, hw, c = x.shape
        h = self.img_h // self.patch_dim
        w = self.img_w // self.patch_dim
        assert hw == h * w
        x = x.transpose(1, 2).reshape(n, c, h, w)
        return x

    def forward(self, hidden_states):
        # hidden_states: [b, hw, c] from the ViT backbone.
        hidden_states = self.layernorm(hidden_states)
        hidden_states = self.to_2D(hidden_states)

        hidden_states = self.conv_0(hidden_states)
        hidden_states = self.norm_0(hidden_states)
        hidden_states = torch.tanh(hidden_states)
        hidden_states = self.conv_1(hidden_states)

        # [b, c, h, w]
        result = F.interpolate(hidden_states,
                               size=(self.img_h, self.img_w),
                               mode='bilinear')

        return result


class MLP(torch.nn.Module):
    """
    Linear Embedding
    """
    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = torch.nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x
+
82
+ class SegformerSegmentationHead(MegatronModule):
83
+ def __init__(self, feature_strides, in_channels,
84
+ embedding_dim, dropout_ratio):
85
+ super(SegformerSegmentationHead, self).__init__()
86
+ assert len(feature_strides) == len(in_channels)
87
+ assert min(feature_strides) == feature_strides[0]
88
+ args = get_args()
89
+ self.feature_strides = feature_strides
90
+ self.in_channels = in_channels
91
+ self.embedding_dim = embedding_dim
92
+ self.num_classes = args.num_classes
93
+ self.dropout_ratio = dropout_ratio
94
+
95
+ c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = \
96
+ self.in_channels
97
+
98
+ self.linear_c4 = MLP(input_dim=c4_in_channels,
99
+ embed_dim=self.embedding_dim)
100
+ self.linear_c3 = MLP(input_dim=c3_in_channels,
101
+ embed_dim=self.embedding_dim)
102
+ self.linear_c2 = MLP(input_dim=c2_in_channels,
103
+ embed_dim=self.embedding_dim)
104
+ self.linear_c1 = MLP(input_dim=c1_in_channels,
105
+ embed_dim=self.embedding_dim)
106
+
107
+ self.conv_fuse = torch.nn.Conv2d(self.embedding_dim*4,
108
+ self.embedding_dim, 1, 1)
109
+ self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim)
110
+
111
+ self.dropout = torch.nn.Dropout2d(self.dropout_ratio)
112
+ self.linear_pred = torch.nn.Conv2d(self.embedding_dim,
113
+ self.num_classes,
114
+ kernel_size=1)
115
+
116
+ def forward(self, inputs):
117
+ c1, c2, c3, c4 = inputs
118
+
119
+ ############## MLP decoder on C1-C4 ###########
120
+ n, _, h, w = c4.shape
121
+
122
+ _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3])
123
+ _c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)
124
+
125
+ _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3])
126
+ _c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False)
127
+
128
+ _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3])
129
+ _c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False)
130
+
131
+ _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3])
132
+
133
+ _c = self.conv_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
134
+ x = self.norm(_c)
135
+ x = F.relu(x, inplace=True)
136
+ x = self.dropout(x)
137
+ x = self.linear_pred(x)
138
+
139
+ return x
140
+
tasks/vision/segmentation/seg_models.py ADDED
@@ -0,0 +1,92 @@
+ # coding=utf-8
2
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import math
16
+ import einops
17
+ import torch
18
+ import apex
19
+ import torch.nn.functional as F
20
+ from megatron import get_args
21
+ from megatron.model.module import MegatronModule
22
+ from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead
23
+ from megatron.model.vision.mit_backbone import mit_b3, mit_b5
24
+ from tasks.vision.segmentation.seg_heads import SetrSegmentationHead, SegformerSegmentationHead
25
+
26
+
27
+ class SetrSegmentationModel(MegatronModule):
28
+
29
+ def __init__(self,
30
+ num_classes,
31
+ pre_process=True,
32
+ post_process=True):
33
+ super(SetrSegmentationModel, self).__init__()
34
+ args = get_args()
35
+ assert post_process & pre_process
36
+ self.hidden_size = args.hidden_size
37
+ self.num_classes = num_classes
38
+ self.backbone = VitBackbone(
39
+ pre_process=pre_process,
40
+ post_process=post_process,
41
+ class_token=False,
42
+ post_layer_norm=False,
43
+ drop_path_rate=0.1
44
+ )
45
+
46
+ self.head = SetrSegmentationHead(
47
+ self.hidden_size,
48
+ self.num_classes
49
+ )
50
+
51
+ def set_input_tensor(self, input_tensor):
52
+ """See megatron.model.transformer.set_input_tensor()"""
53
+ pass
54
+
55
+ def forward(self, input):
56
+ # [b hw c]
57
+ hidden_states = self.backbone(input)
58
+ result_final = self.head(hidden_states)
59
+ return result_final
60
+
61
+
62
+ class SegformerSegmentationModel(MegatronModule):
63
+
64
+ def __init__(self,
65
+ num_classes,
66
+ pre_process=True,
67
+ post_process=True):
68
+ super(SegformerSegmentationModel, self).__init__()
69
+ args = get_args()
70
+ self.hidden_size = args.hidden_size
71
+ self.num_classes = num_classes
72
+ self.pre_process = pre_process
73
+ self.post_process = post_process
74
+
75
+ self.backbone = mit_b5()
76
+ self.head = SegformerSegmentationHead(
77
+ feature_strides=[4, 8, 16, 32],
78
+ in_channels=[64, 128, 320, 512],
79
+ embedding_dim=768,
80
+ dropout_ratio=0.1
81
+ )
82
+
83
+ def set_input_tensor(self, input_tensor):
84
+ """See megatron.model.transformer.set_input_tensor()"""
85
+ pass
86
+
87
+ def forward(self, input):
88
+ # [b hw c]
89
+ hidden_states = self.backbone(input)
90
+ hidden_states = self.head(hidden_states)
91
+ return hidden_states
92
+
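Both wrappers assume Megatron global state: `get_args()` must already be populated (for `hidden_size`, etc.) before either model is constructed. A hedged sketch of how a Megatron task would typically build the SegFormer variant (the `model_provider` name follows the usual Megatron convention; `num_classes=19` is the Cityscapes value and is an assumption here, not read from args):

from tasks.vision.segmentation.seg_models import SegformerSegmentationModel

def model_provider(pre_process=True, post_process=True):
    # Assumes megatron.initialize_megatron(...) has already run, so that
    # get_args() inside the constructor returns populated arguments.
    return SegformerSegmentationModel(num_classes=19,  # Cityscapes (assumed)
                                      pre_process=pre_process,
                                      post_process=post_process)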
tasks/vision/segmentation/transforms.py ADDED
@@ -0,0 +1,433 @@
+ # Copyright (c) 2020 The MMSegmentation Authors.
+ #
+ # This source code is licensed under the Apache license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import random
+ import os
+ import math
+ import mmcv
+ import torch
+ import numpy as np
+ import torchvision.transforms as T
+ from torchvision import datasets
+ from torch.utils.data import Dataset
+ from megatron import print_rank_0
+ from megatron import get_args
+ from PIL import Image, ImageOps, ImageEnhance
+ import torchvision.transforms as torch_tr
+
+ def _is_pil_image(img):
+     return isinstance(img, Image.Image)
+
+ class PhotoMetricDistortion(object):
+     """Apply photometric distortion to an image sequentially; every
+     transformation is applied with a probability of 0.5. The position of
+     random contrast is second or second to last.
+     1. random brightness
+     2. random contrast (mode 0)
+     3. convert color from BGR to HSV
+     4. random saturation
+     5. random hue
+     6. convert color from HSV to BGR
+     7. random contrast (mode 1)
+     8. randomly swap channels
+     Args:
+         brightness_delta (int): delta of brightness.
+         contrast_range (tuple): range of contrast.
+         saturation_range (tuple): range of saturation.
+         hue_delta (int): delta of hue.
+     """
+
+     def __init__(self,
+                  brightness_delta=32,
+                  contrast_range=(0.5, 1.5),
+                  saturation_range=(0.5, 1.5),
+                  hue_delta=18):
+         self.brightness_delta = brightness_delta
+         self.contrast_lower, self.contrast_upper = contrast_range
+         self.saturation_lower, self.saturation_upper = saturation_range
+         self.hue_delta = hue_delta
+
+     def convert(self, img, alpha=1, beta=0):
+         """Multiply with alpha and add beta, with clipping."""
+         img = img.astype(np.float32) * alpha + beta
+         img = np.clip(img, 0, 255)
+         return img.astype(np.uint8)
+
+     def brightness(self, img):
+         """Brightness distortion."""
+         if random.randint(0, 1):
+             return self.convert(
+                 img,
+                 beta=random.uniform(-self.brightness_delta,
+                                     self.brightness_delta))
+         return img
+
+     def contrast(self, img):
+         """Contrast distortion."""
+         if random.randint(0, 1):
+             return self.convert(
+                 img,
+                 alpha=random.uniform(self.contrast_lower, self.contrast_upper))
+         return img
+
+     def saturation(self, img):
+         """Saturation distortion."""
+         if random.randint(0, 1):
+             img = mmcv.bgr2hsv(img)
+             img[:, :, 1] = self.convert(
+                 img[:, :, 1],
+                 alpha=random.uniform(self.saturation_lower,
+                                      self.saturation_upper))
+             img = mmcv.hsv2bgr(img)
+         return img
+
+     def hue(self, img):
+         """Hue distortion."""
+         if random.randint(0, 1):
+             img = mmcv.bgr2hsv(img)
+             img[:, :,
+                 0] = (img[:, :, 0].astype(int) +
+                       random.randint(-self.hue_delta, self.hue_delta)) % 180
+             img = mmcv.hsv2bgr(img)
+         return img
+
+     def __call__(self, img):
+         """Perform photometric distortion on an image.
+         Args:
+             img (PIL Image): Image to be distorted.
+         Returns:
+             PIL Image: Distorted image.
+         """
+         img = np.array(img)
+
+         # random brightness
+         img = self.brightness(img)
+
+         # mode == 0 --> do random contrast first
+         # mode == 1 --> do random contrast last
+         mode = random.randint(0, 1)
+         if mode == 1:
+             img = self.contrast(img)
+
+         # random saturation
+         img = self.saturation(img)
+
+         # random hue
+         img = self.hue(img)
+
+         # random contrast
+         if mode == 0:
+             img = self.contrast(img)
+
+         img = Image.fromarray(img.astype(np.uint8)).convert('RGB')
+         return img
+
+
+ class RandomCrop(object):
+     """
+     Take a random crop from the image.
+
+     First, the image or crop size may need to be adjusted if the incoming
+     image is too small.
+
+     If the image is smaller than the crop, then:
+         the image is padded up to the size of the crop,
+         unless 'nopad', in which case the crop size is shrunk to fit the image.
+
+     A random crop is taken such that the crop fits within the image.
+
+     If cfg.DATASET.TRANSLATION_AUG_FIX is set, we ensure that there is always
+     translation randomness of at least that value around the image.
+
+     if image < crop_size:
+         # slide crop within image, random offset
+     else:
+         # slide image within crop
+     """
+     def __init__(self, crop_size):
+         args = get_args()
+         self.size = crop_size
+         self.cat_max_ratio = 0.75
+         self.ignore_index = args.ignore_index
+         self.pad_color = (0, 0, 0)
+
+     def get_crop_bbox(self, img):
+         """Randomly get a crop bounding box."""
+         img_w, img_h = img.size
+         target_h, target_w = self.size  # [H W]
+         margin_h = max(img_h - target_h, 0)
+         margin_w = max(img_w - target_w, 0)
+         offset_h = random.randint(0, margin_h)
+         offset_w = random.randint(0, margin_w)
+         crop_y1, crop_y2 = offset_h, offset_h + target_h
+         crop_x1, crop_x2 = offset_w, offset_w + target_w
+
+         return crop_y1, crop_y2, crop_x1, crop_x2
+
+     def crop(self, img, crop_bbox):
+         """Crop from ``img``"""
+         crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
+         img = img.crop((crop_x1, crop_y1, crop_x2, crop_y2))
+         return img
+
+     @staticmethod
+     def crop_in_image(target_w, target_h, w, h, img, mask):
+         if w == target_w:
+             x1 = 0
+         else:
+             x1 = random.randint(0, w - target_w)
+         if h == target_h:
+             y1 = 0
+         else:
+             y1 = random.randint(0, h - target_h)
+
+         return [img.crop((x1, y1, x1 + target_w, y1 + target_h)),
+                 mask.crop((x1, y1, x1 + target_w, y1 + target_h))]
+
+     def __call__(self, img, mask):
+         w, h = img.size
+         target_h, target_w = self.size  # ASSUME H, W
+
+         if w == target_w and h == target_h:
+             return img, mask
+
+         # Pad image if image < crop
+         if target_h > h:
+             pad_h = (target_h - h) // 2 + 1
+         else:
+             pad_h = 0
+         if target_w > w:
+             pad_w = (target_w - w) // 2 + 1
+         else:
+             pad_w = 0
+         border = (pad_w, pad_h, pad_w, pad_h)
+         if pad_h or pad_w:
+             img = ImageOps.expand(img, border=border, fill=(0, 0, 0))
+             mask = ImageOps.expand(mask, border=border, fill=self.ignore_index)
+             w, h = img.size
+
+         crop_bbox = self.get_crop_bbox(img)
+         if self.cat_max_ratio < 1.:
+             # Repeat 10 times
+             for _ in range(10):
+                 seg_temp = self.crop(mask, crop_bbox)
+                 labels, cnt = np.unique(seg_temp, return_counts=True)
+                 cnt = cnt[labels != self.ignore_index]
+                 if len(cnt) > 1 and np.max(cnt) / np.sum(
+                         cnt) < self.cat_max_ratio:
+                     break
+                 crop_bbox = self.get_crop_bbox(img)
+
+         # crop the image
+         img = self.crop(img, crop_bbox)
+
+         # crop the semantic segmentation mask
+         mask = self.crop(mask, crop_bbox)
+         assert(img.size[0] == self.size[1] and img.size[1] == self.size[0])
+
+         return img, mask
+
+
+ class RandomSizeAndCrop(object):
+     def __init__(self,
+                  crop_size,
+                  scale_min=0.5,
+                  scale_max=2.0):
+         self.crop = RandomCrop(crop_size)
+         self.scale_min = scale_min
+         self.scale_max = scale_max
+
+     def __call__(self, img, mask):
+
+         scale_amt = random.uniform(self.scale_min, self.scale_max)
+         w, h = [int(i * scale_amt) for i in img.size]
+
+         resized_img = img.resize((w, h), Image.BICUBIC)
+         resized_mask = mask.resize((w, h), Image.NEAREST)
+         img, mask = self.crop(resized_img, resized_mask)
+         return img, mask
+
+ class RandomHorizontallyFlip(object):
+     def __call__(self, img, mask):
+         if random.random() < 0.5:
+             return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(
+                 Image.FLIP_LEFT_RIGHT)
+         return img, mask
+
+
+ def adjust_brightness(img, brightness_factor):
+     """Adjust brightness of an Image.
+
+     Args:
+         img (PIL Image): PIL Image to be adjusted.
+         brightness_factor (float): How much to adjust the brightness. Can be
+             any non-negative number. 0 gives a black image, 1 gives the
+             original image while 2 increases the brightness by a factor of 2.
+
+     Returns:
+         PIL Image: Brightness adjusted image.
+     """
+     if not _is_pil_image(img):
+         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+     enhancer = ImageEnhance.Brightness(img)
+     img = enhancer.enhance(brightness_factor)
+     return img
+
+
+ def adjust_contrast(img, contrast_factor):
+     """Adjust contrast of an Image.
+
+     Args:
+         img (PIL Image): PIL Image to be adjusted.
+         contrast_factor (float): How much to adjust the contrast. Can be any
+             non-negative number. 0 gives a solid gray image, 1 gives the
+             original image while 2 increases the contrast by a factor of 2.
+
+     Returns:
+         PIL Image: Contrast adjusted image.
+     """
+     if not _is_pil_image(img):
+         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+     enhancer = ImageEnhance.Contrast(img)
+     img = enhancer.enhance(contrast_factor)
+     return img
+
+
+ def adjust_saturation(img, saturation_factor):
+     """Adjust color saturation of an image.
+
+     Args:
+         img (PIL Image): PIL Image to be adjusted.
+         saturation_factor (float): How much to adjust the saturation. 0 will
+             give a black and white image, 1 will give the original image while
+             2 will enhance the saturation by a factor of 2.
+
+     Returns:
+         PIL Image: Saturation adjusted image.
+     """
+     if not _is_pil_image(img):
+         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+     enhancer = ImageEnhance.Color(img)
+     img = enhancer.enhance(saturation_factor)
+     return img
+
+
+ def adjust_hue(img, hue_factor):
+     """Adjust hue of an image.
+
+     The image hue is adjusted by converting the image to HSV and
+     cyclically shifting the intensities in the hue channel (H).
+     The image is then converted back to the original image mode.
+
+     `hue_factor` is the amount of shift in the H channel and must be in the
+     interval `[-0.5, 0.5]`.
+
+     See https://en.wikipedia.org/wiki/Hue for more details on Hue.
+
+     Args:
+         img (PIL Image): PIL Image to be adjusted.
+         hue_factor (float): How much to shift the hue channel. Should be in
+             [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of the hue
+             channel in HSV space in the positive and negative direction
+             respectively. 0 means no shift. Therefore, both -0.5 and 0.5
+             will give an image with complementary colors while 0 gives the
+             original image.
+
+     Returns:
+         PIL Image: Hue adjusted image.
+     """
+     if not(-0.5 <= hue_factor <= 0.5):
+         raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
+
+     if not _is_pil_image(img):
+         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+     input_mode = img.mode
+     if input_mode in {'L', '1', 'I', 'F'}:
+         return img
+
+     h, s, v = img.convert('HSV').split()
+
+     np_h = np.array(h, dtype=np.uint8)
+     # uint8 addition takes care of rotation across boundaries
+     with np.errstate(over='ignore'):
+         np_h += np.uint8(hue_factor * 255)
+     h = Image.fromarray(np_h, 'L')
+
+     img = Image.merge('HSV', (h, s, v)).convert(input_mode)
+     return img
+
+
+ class ColorJitter(object):
+     """Randomly change the brightness, contrast and saturation of an image.
+
+     Args:
+         brightness (float): How much to jitter brightness. brightness_factor
+             is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
+         contrast (float): How much to jitter contrast. contrast_factor
+             is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
+         saturation (float): How much to jitter saturation. saturation_factor
+             is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
+         hue (float): How much to jitter hue. hue_factor is chosen uniformly
+             from [-hue, hue]. Should be >= 0 and <= 0.5.
+     """
+     def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
+         self.brightness = brightness
+         self.contrast = contrast
+         self.saturation = saturation
+         self.hue = hue
+
+     @staticmethod
+     def get_params(brightness, contrast, saturation, hue):
+         """Get a randomized transform to be applied on an image.
+
+         Arguments are the same as those of __init__.
+
+         Returns:
+             Transform which randomly adjusts brightness, contrast and
+             saturation in a random order.
+         """
+         transforms = []
+         if brightness > 0:
+             brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
+             transforms.append(
+                 torch_tr.Lambda(lambda img: adjust_brightness(img, brightness_factor)))
+
+         if contrast > 0:
+             contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
+             transforms.append(
+                 torch_tr.Lambda(lambda img: adjust_contrast(img, contrast_factor)))
+
+         if saturation > 0:
+             saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
+             transforms.append(
+                 torch_tr.Lambda(lambda img: adjust_saturation(img, saturation_factor)))
+
+         if hue > 0:
+             hue_factor = np.random.uniform(-hue, hue)
+             transforms.append(
+                 torch_tr.Lambda(lambda img: adjust_hue(img, hue_factor)))
+
+         np.random.shuffle(transforms)
+         transform = torch_tr.Compose(transforms)
+
+         return transform
+
+     def __call__(self, img):
+         """
+         Args:
+             img (PIL Image): Input image.
+
+         Returns:
+             PIL Image: Color jittered image.
+         """
+         transform = self.get_params(self.brightness, self.contrast,
+                                     self.saturation, self.hue)
+         return transform(img)
+
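Taken together, the classes above form a typical joint image/mask augmentation pipeline. A hedged usage sketch (file paths and crop size are illustrative; note that `RandomCrop` reads `ignore_index` via `get_args()`, so Megatron must be initialized before the pipeline is built):

from PIL import Image
from tasks.vision.segmentation.transforms import (
    PhotoMetricDistortion, RandomSizeAndCrop, RandomHorizontallyFlip)

resize_crop = RandomSizeAndCrop(crop_size=(768, 768))  # [H, W]
flip = RandomHorizontallyFlip()
distort = PhotoMetricDistortion()

img = Image.open('frame.png').convert('RGB')  # hypothetical input pair
mask = Image.open('frame_labels.png')

img, mask = resize_crop(img, mask)  # joint geometric transform
img, mask = flip(img, mask)         # flips image and mask together
img = distort(img)                  # photometric noise, image only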
tasks/vision/segmentation/utils.py ADDED
@@ -0,0 +1,85 @@
+ import math
+ import torch
+ import numpy as np
+ from megatron import get_args
+
+ def slidingcrops(img, mask):
+     # img: [b c h w]
+     # mask: [b h w]
+     args = get_args()
+     assert args.img_h == args.img_w
+     crop_size = args.img_h
+     stride = args.seg_stride
+     ignore_index = args.ignore_index
+     n, c, h, w = img.shape
+     assert h >= crop_size
+     assert w >= crop_size
+     long_size = max(h, w)
+
+     img_slices, mask_slices, slices_info = [], [], []
+     if long_size > crop_size:
+         assert stride <= crop_size
+         h_step_num = int(math.ceil((h - crop_size) / float(stride))) + 1
+         w_step_num = int(math.ceil((w - crop_size) / float(stride))) + 1
+         for yy in range(h_step_num):
+             for xx in range(w_step_num):
+                 sy, sx = yy * stride, xx * stride
+                 ey, ex = sy + crop_size, sx + crop_size
+                 img_sub = img[:, :, sy: ey, sx: ex]
+                 mask_sub = mask[:, sy: ey, sx: ex]
+
+                 # padding
+                 sub_h, sub_w = img_sub.shape[2:]
+                 pad_h = max(crop_size - sub_h, 0)
+                 pad_w = max(crop_size - sub_w, 0)
+                 img_sub = torch.nn.functional.pad(img_sub, pad=(0, pad_w, 0, pad_h), value=ignore_index)
+                 mask_sub = torch.nn.functional.pad(mask_sub, pad=(0, pad_w, 0, pad_h))
+
+                 img_slices.append(img_sub)
+                 mask_slices.append(mask_sub)
+                 slices_info.append([sy, ey, sx, ex, sub_h, sub_w])
+
+         return torch.cat(img_slices), torch.cat(mask_slices), slices_info, (h, w)
+     else:
+         return img, mask, [[0, h, 0, w, h, w]], (h, w)
+
+
+ def slidingjoins(preds, probs, labels, slices_info, img_size):
+     args = get_args()
+     num_slices = len(slices_info)
+
+     if num_slices == 1:
+         return preds, labels
+
+     h, w = img_size
+     split_size = args.micro_batch_size
+
+     preds_split = torch.split(preds, split_size)
+     probs_split = torch.split(probs, split_size)
+     labels_split = torch.split(labels, split_size)
+
+     assert(len(preds_split) == num_slices)
+
+     total_max_probs = torch.zeros((split_size, h, w), dtype=torch.float, device='cuda')
+     total_preds = torch.zeros((split_size, h, w), dtype=torch.int, device='cuda')
+     total_labels = torch.zeros((split_size, h, w), dtype=torch.int, device='cuda')
+
+     for i in range(num_slices):
+         sy, ey, sx, ex, sub_h, sub_w = slices_info[i]
+         assert sy + sub_h <= h
+         assert sx + sub_w <= w
+         curr_max_probs = total_max_probs[:, sy:sy + sub_h, sx:sx + sub_w]
+         curr_preds = total_preds[:, sy:sy + sub_h, sx:sx + sub_w]
+
+         local_max_probs = probs_split[i][:, :sub_h, :sub_w]
+         local_preds = preds_split[i][:, :sub_h, :sub_w]
+
+         result_max_probs = torch.maximum(curr_max_probs, local_max_probs)
+         result_preds = torch.where(curr_max_probs >= local_max_probs, curr_preds, local_preds)
+
+         total_max_probs[:, sy:sy + sub_h, sx:sx + sub_w] = result_max_probs
+         total_preds[:, sy:sy + sub_h, sx:sx + sub_w] = result_preds
+         total_labels[:, sy:sy + sub_h, sx:sx + sub_w] = labels_split[i][0, :sub_h, :sub_w]
+
+     return total_preds, total_labels
+
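A hedged sketch of the sliding-window evaluation round trip these two helpers implement: `slidingcrops` tiles an oversized image into crop-sized windows, the model scores each window, and `slidingjoins` stitches the per-window predictions back onto the full canvas, keeping the most confident class wherever windows overlap (`model`, `img`, and `mask` here are hypothetical stand-ins for the task's model and batch):

import torch
from tasks.vision.segmentation.utils import slidingcrops, slidingjoins

# img: [b, c, H, W] float tensor; mask: [b, H, W] int tensor, both on GPU.
img_slices, mask_slices, slices_info, img_size = slidingcrops(img, mask)
logits = model(img_slices)                    # [b*k, num_classes, h, w]
probs, preds = torch.max(torch.softmax(logits, dim=1), dim=1)
preds, labels = slidingjoins(preds.int(), probs.float(), mask_slices,
                             slices_info, img_size)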