diff --git a/tasks/clue/afqmc.py b/tasks/clue/afqmc.py new file mode 100644 index 0000000000000000000000000000000000000000..7774bc744e687d1701c1df915638058b1b074b04 --- /dev/null +++ b/tasks/clue/afqmc.py @@ -0,0 +1,94 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MNLI dataset.""" + +from megatron import print_rank_0 +from tasks.data_utils import clean_text +from .data import GLUEAbstractDataset +import json +from tasks.label_dict import get_label_dict + +LABELS = get_label_dict("AFQMC") + +class AFQMCDataset(GLUEAbstractDataset): + + def __init__(self, name, datapaths, tokenizer, max_seq_length, + test_label='0'): + self.test_label = test_label + super().__init__('AFQMC', name, datapaths, + tokenizer, max_seq_length) + + def process_samples_from_single_path(self, filename): + """"Implement abstract method.""" + print_rank_0(' > Processing {} ...'.format(filename)) + + samples = [] + total = 0 + first = True + is_test = False + with open(filename, 'r') as f: + reader = f.readlines() + lines = [] + for line in reader: + lines.append(json.loads(line.strip())) + drop_cnt = 0 + for index, row in enumerate(lines): + if "id" not in row: + row["id"] = index + if first: + first = False + if "label" not in row: + is_test = True + print_rank_0( + ' reading {}, {} and {} columns and setting ' + 'labels to {}'.format( + row["id"], row["sentence1"].strip(), + row["sentence2"].strip(), self.test_label)) + else: + is_test = False + print_rank_0(' reading {} , {}, {}, and {} columns ' + '...'.format( + row["id"], row["sentence1"].strip(), + row["sentence2"].strip(), row["label"].strip())) + + text_a = clean_text(row["sentence1"].strip()) + text_b = clean_text(row["sentence2"].strip()) + unique_id = int(row["id"]) + + if is_test: + label = self.test_label + else: + label = row["label"].strip() + + assert len(text_a) > 0 + assert len(text_b) > 0 + assert label in LABELS, "found label {} {}".format(label, row) + assert unique_id >= 0 + + sample = {'text_a': text_a, + 'text_b': text_b, + 'label': LABELS[label], + 'uid': unique_id} + total += 1 + samples.append(sample) + + if total % 5000 == 0: + print_rank_0(' > processed {} so far ...'.format(total)) + + print_rank_0(' >> processed {} samples.'.format(len(samples))) + print_rank_0(' >> drop {} samples.'.format(drop_cnt)) + + return samples diff --git a/tasks/clue/cmnli.py b/tasks/clue/cmnli.py new file mode 100644 index 0000000000000000000000000000000000000000..54473b32d0fe27cbff752fcc6930d9633c595582 --- /dev/null +++ b/tasks/clue/cmnli.py @@ -0,0 +1,103 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
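Note: the AFQMC loader above converts JSONL rows into the sample dicts consumed by GLUEAbstractDataset. A minimal standalone sketch of that per-row conversion; the inline label map stands in for get_label_dict("AFQMC"), whose exact contents are not part of this diff, so treat it as an assumption:

import json

# Hypothetical stand-in for tasks.label_dict.get_label_dict("AFQMC").
LABELS = {"0": 0, "1": 1}

line = '{"sentence1": "双十一花呗提额在哪", "sentence2": "花呗的额度怎么提升", "label": "0"}'
row = json.loads(line.strip())

sample = {
    'text_a': row["sentence1"].strip(),
    'text_b': row["sentence2"].strip(),
    'label': LABELS[row["label"].strip()],
    'uid': row.get("id", 0),  # falls back to the line index when the row has no "id"
}
print(sample)  # {'text_a': '...', 'text_b': '...', 'label': 0, 'uid': 0}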
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MNLI dataset.""" + +from megatron import print_rank_0 +from tasks.data_utils import clean_text +from .data import GLUEAbstractDataset +import json +from tasks.label_dict import get_label_dict + +LABELS = get_label_dict("CMNLI") + +class CMNLIDataset(GLUEAbstractDataset): + + def __init__(self, name, datapaths, tokenizer, max_seq_length, + test_label='contradiction'): + self.test_label = test_label + super().__init__('CMNLI', name, datapaths, + tokenizer, max_seq_length) + + def process_samples_from_single_path(self, filename): + """"Implement abstract method.""" + print_rank_0(' > Processing {} ...'.format(filename)) + + samples = [] + total = 0 + first = True + is_test = False + with open(filename, 'r') as f: + reader = f.readlines() + lines = [] + for line in reader: + lines.append(json.loads(line.strip())) + drop_cnt = 0 + for index, row in enumerate(lines): + row["id"] = index + # line = line.strip() + # try: + # row = eval(line) + # except: + # print(">>>>>>>> ", line) + # continue + if first: + first = False + if "label" not in row: + is_test = True + print_rank_0( + ' reading {}, {} and {} columns and setting ' + 'labels to {}'.format( + row["id"], row["sentence1"].strip(), + row["sentence2"].strip(), self.test_label)) + else: + is_test = False + print_rank_0(' reading {} , {}, {}, and {} columns ' + '...'.format( + row["id"], row["sentence1"].strip(), + row["sentence2"].strip(), row["label"].strip())) + + text_a = clean_text(row["sentence1"].strip()) + text_b = clean_text(row["sentence2"].strip()) + unique_id = int(row["id"]) + + if is_test: + label = self.test_label + else: + label = row["label"].strip() + + if label == "-": + drop_cnt += 1 + continue + + assert len(text_a) > 0 + assert len(text_b) > 0 + assert label in LABELS, "found label {} {}".format(label, row) + assert unique_id >= 0 + + sample = {'text_a': text_a, + 'text_b': text_b, + 'label': LABELS[label], + 'uid': unique_id} + total += 1 + samples.append(sample) + + if total % 5000 == 0: + print_rank_0(' > processed {} so far ...'.format(total)) + + print_rank_0(' >> processed {} samples.'.format(len(samples))) + print_rank_0(' >> drop {} samples.'.format(drop_cnt)) + + return samples diff --git a/tasks/clue/csl.py b/tasks/clue/csl.py new file mode 100644 index 0000000000000000000000000000000000000000..489db578beeb513a53e85d43bc4df9436efd251f --- /dev/null +++ b/tasks/clue/csl.py @@ -0,0 +1,93 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
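Unlike AFQMC, the CMNLI loader below drops rows whose gold label is "-" (no annotator consensus in MNLI-style data) instead of tripping the label assertion. A condensed sketch of that filtering, with an assumed label-to-index mapping:

import json

LABELS = {"entailment": 0, "neutral": 1, "contradiction": 2}  # assumed mapping

def load_nli_rows(lines, test_label="contradiction"):
    samples, dropped = [], 0
    for idx, line in enumerate(lines):
        row = json.loads(line.strip())
        label = row.get("label", test_label).strip()  # test files carry no "label"
        if label == "-":  # no consensus label: skip the row instead of asserting
            dropped += 1
            continue
        samples.append({'text_a': row["sentence1"].strip(),
                        'text_b': row["sentence2"].strip(),
                        'label': LABELS[label],
                        'uid': idx})
    return samples, dropped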
+ +"""MNLI dataset.""" + +from megatron import print_rank_0 +from tasks.data_utils import clean_text +from .data import GLUEAbstractDataset +import json +from tasks.label_dict import get_label_dict + +LABELS = get_label_dict("CSL") + +class CSLDataset(GLUEAbstractDataset): + + def __init__(self, name, datapaths, tokenizer, max_seq_length, + test_label='0'): + self.test_label = test_label + super().__init__('CSL', name, datapaths, + tokenizer, max_seq_length) + + def process_samples_from_single_path(self, filename): + """"Implement abstract method.""" + print_rank_0(' > Processing {} ...'.format(filename)) + + samples = [] + total = 0 + first = True + is_test = False + with open(filename, 'r') as f: + reader = f.readlines() + lines = [] + for line in reader: + lines.append(json.loads(line.strip())) + drop_cnt = 0 + for index, row in enumerate(lines): + row["id"] = index + if first: + first = False + if "label" not in row: + is_test = True + print_rank_0( + ' reading {}, {} and {} columns and setting ' + 'labels to {}'.format( + row["id"], " ".join(row["keyword"]).strip(), + row["abst"].strip(), self.test_label)) + else: + is_test = False + print_rank_0(' reading {} , {}, {}, and {} columns ' + '...'.format( + row["id"], " ".join(row["keyword"]).strip(), + row["abst"].strip(), row["label"].strip())) + + text_a = clean_text(" ".join(row["keyword"]).strip()) + text_b = clean_text(row["abst"].strip()) + unique_id = int(row["id"]) + + if is_test: + label = self.test_label + else: + label = row["label"].strip() + + assert len(text_a) > 0 + assert len(text_b) > 0 + assert label in LABELS, "found label {} {}".format(label, row) + assert unique_id >= 0 + + sample = {'text_a': text_a, + 'text_b': text_b, + 'label': LABELS[label], + 'uid': unique_id} + total += 1 + samples.append(sample) + + if total % 5000 == 0: + print_rank_0(' > processed {} so far ...'.format(total)) + + print_rank_0(' >> processed {} samples.'.format(len(samples))) + print_rank_0(' >> drop {} samples.'.format(drop_cnt)) + + return samples diff --git a/tasks/clue/data.py b/tasks/clue/data.py new file mode 100644 index 0000000000000000000000000000000000000000..357ad130c3ac353bd06163822c5a9443b33d1510 --- /dev/null +++ b/tasks/clue/data.py @@ -0,0 +1,69 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GLUE dataset.""" + +from abc import ABC +from abc import abstractmethod + +from torch.utils.data import Dataset + +from megatron import print_rank_0 +from tasks.data_utils import build_sample +from tasks.data_utils import build_tokens_types_paddings_from_text + + +class GLUEAbstractDataset(ABC, Dataset): + """GLUE base dataset class.""" + + def __init__(self, task_name, dataset_name, datapaths, + tokenizer, max_seq_length): + # Store inputs. 
+ self.task_name = task_name + self.dataset_name = dataset_name + self.tokenizer = tokenizer + self.max_seq_length = max_seq_length + print_rank_0(' > building {} dataset for {}:'.format(self.task_name, + self.dataset_name)) + # Process the files. + string = ' > paths:' + for path in datapaths: + string += ' ' + path + print_rank_0(string) + self.samples = [] + for datapath in datapaths: + self.samples.extend(self.process_samples_from_single_path(datapath)) + print_rank_0(' >> total number of samples: {}'.format( + len(self.samples))) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + raw_sample = self.samples[idx] + ids, types, paddings = build_tokens_types_paddings_from_text( + raw_sample['text_a'], raw_sample['text_b'], + self.tokenizer, self.max_seq_length) + sample = build_sample(ids, types, paddings, + raw_sample['label'], raw_sample['uid']) + return sample + + @abstractmethod + def process_samples_from_single_path(self, datapath): + """Abstract method that takes a single path / filename and + returns a list of dataset samples, each sample being a dict of + {'text_a': string, 'text_b': string, 'label': int, 'uid': int} + """ + pass diff --git a/tasks/clue/finetune.py b/tasks/clue/finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..0c6a3691dddff204c9fbc5e904b7797b2ccd8c7c --- /dev/null +++ b/tasks/clue/finetune.py @@ -0,0 +1,130 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
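All of the CLUE loaders in this change specialize GLUEAbstractDataset the same way: the base class handles tokenization, padding and sample lookup, and a subclass only supplies process_samples_from_single_path. A hypothetical minimal subclass, purely to illustrate the expected contract (task name, file format and labels here are made up):

from .data import GLUEAbstractDataset

class ToyPairDataset(GLUEAbstractDataset):
    """Hypothetical tab-separated sentence-pair task with yes/no labels."""

    def __init__(self, name, datapaths, tokenizer, max_seq_length):
        super().__init__('TOY', name, datapaths, tokenizer, max_seq_length)

    def process_samples_from_single_path(self, datapath):
        labels = {'no': 0, 'yes': 1}
        samples = []
        with open(datapath, 'r') as f:
            for uid, line in enumerate(f):
                text_a, text_b, label = line.rstrip('\n').split('\t')
                samples.append({'text_a': text_a, 'text_b': text_b,
                                'label': labels[label], 'uid': uid})
        return samples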
+ +"""GLUE finetuning/evaluation.""" + +from megatron import get_args +from megatron import print_rank_0 +from megatron import get_tokenizer +from megatron import mpu +from megatron.model.classification import Classification +from tasks.eval_utils import accuracy_func_provider +from tasks.finetune_utils import finetune + + +def clue_classification(num_classes, Dataset, + name_from_datapath_func): + + def train_valid_datasets_provider(): + """Build train and validation dataset.""" + args = get_args() + tokenizer = get_tokenizer() + + train_dataset = Dataset('training', args.train_data, + tokenizer, args.seq_length) + valid_dataset = Dataset('validation', args.valid_data, + tokenizer, args.seq_length) + + return train_dataset, valid_dataset + + def model_provider(pre_process=True, post_process=True): + """Build the model.""" + args = get_args() + + print_rank_0('building classification model for {} ...'.format( + args.task)) + model = Classification(num_classes=num_classes, num_tokentypes=2, + pre_process=pre_process, post_process=post_process) + + return model + + def metrics_func_provider(): + """Privde metrics callback function.""" + def single_dataset_provider(datapath): + args = get_args() + tokenizer = get_tokenizer() + name = name_from_datapath_func(datapath) + return Dataset(name, [datapath], tokenizer, args.seq_length) + return accuracy_func_provider(single_dataset_provider) + + """Finetune/evaluate.""" + finetune(train_valid_datasets_provider, model_provider, + end_of_epoch_callback_provider=metrics_func_provider) + + +def main(): + args = get_args() + + if args.task == 'AFQMC': + num_classes = 2 + from tasks.clue.afqmc import AFQMCDataset as Dataset + + def name_from_datapath(datapath): + return "afqmc" + + elif args.task == 'CSL': + num_classes = 2 + from tasks.clue.csl import CSLDataset as Dataset + + def name_from_datapath(datapath): + return "csl" + + elif args.task == 'IFLYTEK': + num_classes = 119 + from tasks.clue.iflytek import IFLYTEKDataset as Dataset + + def name_from_datapath(datapath): + return "iflytek" + + elif args.task == 'OCNLI': + num_classes = 3 + from tasks.clue.ocnli import OCNLIDataset as Dataset + + def name_from_datapath(datapath): + return "ocnli" + + elif args.task == 'TNEWS': + num_classes = 15 + from tasks.clue.tnews import TNEWSDataset as Dataset + + def name_from_datapath(datapath): + return "tnews" + + elif args.task == 'WSC': + num_classes = 2 + from tasks.clue.wsc import WSCDataset as Dataset + + def name_from_datapath(datapath): + return "wsc" + + elif args.task == 'CMNLI': + num_classes = 3 + from tasks.clue.cmnli import CMNLIDataset as Dataset + + def name_from_datapath(datapath): + return "cmnli" + + elif args.task == 'ZC': + num_classes = 2 + from tasks.clue.zc import ZCDataset as Dataset + + def name_from_datapath(datapath): + return "zc" + + else: + raise NotImplementedError('GLUE task {} is not implemented.'.format( + args.task)) + + clue_classification(num_classes, Dataset, name_from_datapath) diff --git a/tasks/clue/iflytek.py b/tasks/clue/iflytek.py new file mode 100644 index 0000000000000000000000000000000000000000..a10fbf42e8f939c60118e0779cd40c15edd00873 --- /dev/null +++ b/tasks/clue/iflytek.py @@ -0,0 +1,94 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
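For reference, the if/elif dispatch in main() above boils down to the following task table (values copied from the branches; the dict is only an illustration, the code keeps explicit branches, presumably so each Dataset import stays lazy):

# CLUE task name -> (number of classes, dataset class used for finetuning).
CLUE_TASKS = {
    'AFQMC':   (2,   'tasks.clue.afqmc.AFQMCDataset'),
    'CSL':     (2,   'tasks.clue.csl.CSLDataset'),
    'IFLYTEK': (119, 'tasks.clue.iflytek.IFLYTEKDataset'),
    'OCNLI':   (3,   'tasks.clue.ocnli.OCNLIDataset'),
    'TNEWS':   (15,  'tasks.clue.tnews.TNEWSDataset'),
    'WSC':     (2,   'tasks.clue.wsc.WSCDataset'),
    'CMNLI':   (3,   'tasks.clue.cmnli.CMNLIDataset'),
    'ZC':      (2,   'tasks.clue.zc.ZCDataset'),
}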
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MNLI dataset.""" + +from megatron import print_rank_0 +from tasks.data_utils import clean_text +from .data import GLUEAbstractDataset +import json +from tasks.label_dict import get_label_dict + +LABELS = get_label_dict("IFLYTEK") + +class IFLYTEKDataset(GLUEAbstractDataset): + + def __init__(self, name, datapaths, tokenizer, max_seq_length, + test_label='0'): + self.test_label = test_label + super().__init__('IFLYTEK', name, datapaths, + tokenizer, max_seq_length) + + def process_samples_from_single_path(self, filename): + """"Implement abstract method.""" + print_rank_0(' > Processing {} ...'.format(filename)) + + samples = [] + total = 0 + first = True + is_test = False + with open(filename, 'r') as f: + reader = f.readlines() + lines = [] + for line in reader: + lines.append(json.loads(line.strip())) + drop_cnt = 0 + for index, row in enumerate(lines): + if "id" not in row: + row["id"] = index + if first: + first = False + if "label" not in row: + is_test = True + print_rank_0( + ' reading {}, {} and {} columns and setting ' + 'labels to {}'.format( + row["id"], row["sentence"].strip(), + None, self.test_label)) + else: + is_test = False + print_rank_0(' reading {} , {}, {}, and {} columns ' + '...'.format( + row["id"], row["sentence"].strip(), + None, row["label"].strip())) + + text_a = clean_text(row["sentence"].strip()) + text_b = None + unique_id = int(row["id"]) + + if is_test: + label = self.test_label + else: + label = row["label"].strip() + + assert len(text_a) > 0 + # assert len(text_b) > 0 + assert label in LABELS, "found label {} {}".format(label, row) + assert unique_id >= 0 + + sample = {'text_a': text_a, + 'text_b': text_b, + 'label': LABELS[label], + 'uid': unique_id} + total += 1 + samples.append(sample) + + if total % 5000 == 0: + print_rank_0(' > processed {} so far ...'.format(total)) + + print_rank_0(' >> processed {} samples.'.format(len(samples))) + print_rank_0(' >> drop {} samples.'.format(drop_cnt)) + + return samples diff --git a/tasks/clue/ocnli.py b/tasks/clue/ocnli.py new file mode 100644 index 0000000000000000000000000000000000000000..7a243582d8de42fcde8d897f0946113136230cc2 --- /dev/null +++ b/tasks/clue/ocnli.py @@ -0,0 +1,102 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""MNLI dataset.""" + +from megatron import print_rank_0 +from tasks.data_utils import clean_text +from .data import GLUEAbstractDataset +import json +from tasks.label_dict import get_label_dict + +LABELS = get_label_dict("OCNLI") + +class OCNLIDataset(GLUEAbstractDataset): + + def __init__(self, name, datapaths, tokenizer, max_seq_length, + test_label='contradiction'): + self.test_label = test_label + super().__init__('OCNLI', name, datapaths, + tokenizer, max_seq_length) + + def process_samples_from_single_path(self, filename): + """"Implement abstract method.""" + print_rank_0(' > Processing {} ...'.format(filename)) + + samples = [] + total = 0 + first = True + is_test = False + with open(filename, 'r') as f: + reader = f.readlines() + lines = [] + for line in reader: + lines.append(json.loads(line.strip())) + drop_cnt = 0 + for index, row in enumerate(lines): + # line = line.strip() + # try: + # row = eval(line) + # except: + # print(">>>>>>>> ", line) + # continue + if first: + first = False + if "label" not in row: + is_test = True + print_rank_0( + ' reading {}, {} and {} columns and setting ' + 'labels to {}'.format( + row["id"], row["sentence1"].strip(), + row["sentence2"].strip(), self.test_label)) + else: + is_test = False + print_rank_0(' reading {} , {}, {}, and {} columns ' + '...'.format( + row["id"], row["sentence1"].strip(), + row["sentence2"].strip(), row["label"].strip())) + + text_a = clean_text(row["sentence1"].strip()) + text_b = clean_text(row["sentence2"].strip()) + unique_id = int(row["id"]) + + if is_test: + label = self.test_label + else: + label = row["label"].strip() + + if label == "-": + drop_cnt += 1 + continue + + assert len(text_a) > 0 + assert len(text_b) > 0 + assert label in LABELS, "found label {} {}".format(label, row) + assert unique_id >= 0 + + sample = {'text_a': text_a, + 'text_b': text_b, + 'label': LABELS[label], + 'uid': unique_id} + total += 1 + samples.append(sample) + + if total % 5000 == 0: + print_rank_0(' > processed {} so far ...'.format(total)) + + print_rank_0(' >> processed {} samples.'.format(len(samples))) + print_rank_0(' >> drop {} samples.'.format(drop_cnt)) + + return samples diff --git a/tasks/clue/tnews.py b/tasks/clue/tnews.py new file mode 100644 index 0000000000000000000000000000000000000000..2821e26ae5a69f606ab270dc891acc8d3b13d5b2 --- /dev/null +++ b/tasks/clue/tnews.py @@ -0,0 +1,95 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""MNLI dataset.""" + +from megatron import print_rank_0 +from tasks.data_utils import clean_text +from .data import GLUEAbstractDataset +import json +from tasks.label_dict import get_label_dict + +LABELS = get_label_dict("TNEWS") + +class TNEWSDataset(GLUEAbstractDataset): + + def __init__(self, name, datapaths, tokenizer, max_seq_length, + test_label='100'): + self.test_label = test_label + super().__init__('TNEWS', name, datapaths, + tokenizer, max_seq_length) + + def process_samples_from_single_path(self, filename): + """"Implement abstract method.""" + print_rank_0(' > Processing {} ...'.format(filename)) + + samples = [] + total = 0 + first = True + is_test = False + with open(filename, 'r') as f: + reader = f.readlines() + lines = [] + for line in reader: + lines.append(json.loads(line.strip())) + drop_cnt = 0 + for index, row in enumerate(lines): + if "id" not in row: + row["id"] = index + if first: + first = False + if "label" not in row: + is_test = True + print_rank_0( + ' reading {}, {} and {} columns and setting ' + 'labels to {}'.format( + row["id"], row["sentence"].strip(), + None, self.test_label)) + else: + is_test = False + print_rank_0(' reading {} , {}, {}, and {} columns ' + '...'.format( + row["id"], row["sentence"].strip(), + None, row["label"].strip())) + + text_a = clean_text(row["sentence"].strip()) + text_b = clean_text(row["keywords"].strip()) + # text_b = None + unique_id = int(row["id"]) + + if is_test: + label = self.test_label + else: + label = row["label"].strip() + + assert len(text_a) > 0 + # assert len(text_b) > 0 + assert label in LABELS, "found label {} {}".format(label, row) + assert unique_id >= 0 + + sample = {'text_a': text_a, + 'text_b': text_b, + 'label': LABELS[label], + 'uid': unique_id} + total += 1 + samples.append(sample) + + if total % 5000 == 0: + print_rank_0(' > processed {} so far ...'.format(total)) + + print_rank_0(' >> processed {} samples.'.format(len(samples))) + print_rank_0(' >> drop {} samples.'.format(drop_cnt)) + + return samples diff --git a/tasks/clue/wsc.py b/tasks/clue/wsc.py new file mode 100644 index 0000000000000000000000000000000000000000..23ff6797599b9c43bd7840e51ca4bc09585de8dc --- /dev/null +++ b/tasks/clue/wsc.py @@ -0,0 +1,117 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
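TNEWS is a single-sentence classification task, but the loader above still builds a sentence pair by reusing the comma-separated keywords field as text_b. A toy illustration of one converted row; the label map is assumed (the real one comes from get_label_dict("TNEWS") and covers the 15 news categories):

import json

LABELS = {"100": 0, "101": 1, "102": 2}  # assumed fragment of the 15-class mapping

row = json.loads('{"label": "102", "sentence": "这部新电影值得一看吗", "keywords": "电影,影评,娱乐"}')
sample = {'text_a': row["sentence"].strip(),
          'text_b': row["keywords"].strip(),  # keywords become the second segment
          'label': LABELS[row["label"].strip()],
          'uid': 0}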
+ +"""MNLI dataset.""" + +from megatron import print_rank_0 +from tasks.data_utils import clean_text +from .data import GLUEAbstractDataset +import json +from tasks.label_dict import get_label_dict + +LABELS = get_label_dict("WSC") + +class WSCDataset(GLUEAbstractDataset): + + def __init__(self, name, datapaths, tokenizer, max_seq_length, + test_label="false"): + self.test_label = test_label + super().__init__('WSC', name, datapaths, + tokenizer, max_seq_length) + + def process_samples_from_single_path(self, filename): + """"Implement abstract method.""" + print_rank_0(' > Processing {} ...'.format(filename)) + + samples = [] + total = 0 + first = True + is_test = False + with open(filename, 'r') as f: + reader = f.readlines() + lines = [] + for line in reader: + lines.append(json.loads(line.strip())) + drop_cnt = 0 + for index, row in enumerate(lines): + if "id" not in row: + row["id"] = index + text_a = row['text'] + text_a_list = list(text_a) + target = row['target'] + query = target['span1_text'] + query_idx = target['span1_index'] + pronoun = target['span2_text'] + pronoun_idx = target['span2_index'] + assert text_a[pronoun_idx: (pronoun_idx + len(pronoun))] == pronoun, "pronoun: {}".format(pronoun) + assert text_a[query_idx: (query_idx + len(query))] == query, "query: {}".format(query) + if pronoun_idx > query_idx: + text_a_list.insert(query_idx, "_") + text_a_list.insert(query_idx + len(query) + 1, "_") + text_a_list.insert(pronoun_idx + 2, "[") + text_a_list.insert(pronoun_idx + len(pronoun) + 2 + 1, "]") + else: + text_a_list.insert(pronoun_idx, "[") + text_a_list.insert(pronoun_idx + len(pronoun) + 1, "]") + text_a_list.insert(query_idx + 2, "_") + text_a_list.insert(query_idx + len(query) + 2 + 1, "_") + text_a = "".join(text_a_list) + # text_b = "在这句话中,{}指代的是{}".format(pronoun, query) + text_b = None + if first: + first = False + if "label" not in row: + is_test = True + print_rank_0( + ' reading {}, {} and {} columns and setting ' + 'labels to {}'.format( + row["id"], text_a, + text_b, self.test_label)) + else: + is_test = False + print_rank_0(' reading {} , {}, {}, and {} columns ' + '...'.format( + row["id"], text_a, + text_b, row["label"].strip())) + + text_a = text_a + text_b = text_b + # text_b = None + unique_id = int(row["id"]) + + if is_test: + label = self.test_label + else: + label = row["label"].strip() + + assert len(text_a) > 0 + # assert len(text_b) > 0 + assert label in LABELS, "found label {} {} {}".format(label, row, type(label)) + assert unique_id >= 0 + + sample = {'text_a': text_a, + 'text_b': text_b, + 'label': LABELS[label], + 'uid': unique_id} + total += 1 + samples.append(sample) + + if total % 5000 == 0: + print_rank_0(' > processed {} so far ...'.format(total)) + + print_rank_0(' >> processed {} samples.'.format(len(samples))) + print_rank_0(' >> drop {} samples.'.format(drop_cnt)) + + return samples diff --git a/tasks/clue/zc.py b/tasks/clue/zc.py new file mode 100644 index 0000000000000000000000000000000000000000..36b409313c19eac4ace68a22b88bcfd394092ff4 --- /dev/null +++ b/tasks/clue/zc.py @@ -0,0 +1,96 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
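To make the index arithmetic in the WSC loader above easier to follow: the candidate span is wrapped in underscores, the pronoun in square brackets, and whichever span comes later in the sentence is shifted by the two marker characters already inserted in front of it. A standalone sketch of the same rule on a made-up sentence:

def mark_spans(text, query, query_idx, pronoun, pronoun_idx):
    chars = list(text)
    assert text[query_idx:query_idx + len(query)] == query
    assert text[pronoun_idx:pronoun_idx + len(pronoun)] == pronoun
    if pronoun_idx > query_idx:
        chars.insert(query_idx, "_")
        chars.insert(query_idx + len(query) + 1, "_")
        chars.insert(pronoun_idx + 2, "[")            # +2: two markers already added before it
        chars.insert(pronoun_idx + len(pronoun) + 2 + 1, "]")
    else:
        chars.insert(pronoun_idx, "[")
        chars.insert(pronoun_idx + len(pronoun) + 1, "]")
        chars.insert(query_idx + 2, "_")
        chars.insert(query_idx + len(query) + 2 + 1, "_")
    return "".join(chars)

print(mark_spans("小明说他明天来", "小明", 0, "他", 3))  # -> _小明_说[他]明天来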
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MNLI dataset.""" + +from megatron import print_rank_0 +from tasks.data_utils import clean_text +from .data import GLUEAbstractDataset +import json +from tasks.label_dict import get_label_dict + +LABELS = get_label_dict("ZC") + +class ZCDataset(GLUEAbstractDataset): + + def __init__(self, name, datapaths, tokenizer, max_seq_length, + test_label='negative'): + self.test_label = test_label + super().__init__('ZC', name, datapaths, + tokenizer, max_seq_length) + + def process_samples_from_single_path(self, filename): + """"Implement abstract method.""" + print_rank_0(' > Processing {} ...'.format(filename)) + + samples = [] + total = 0 + first = True + is_test = False + print('>>>>', filename) + with open(filename, 'r') as f: + reader = f.readlines() + lines = [] + for line in reader: + lines.append(json.loads(line.strip())) + drop_cnt = 0 + for index, row in enumerate(lines): + # if "id" not in row: + row["id"] = index + if first: + first = False + # if "label" not in row: + if "test.json" in filename: + is_test = True + print_rank_0( + ' reading {}, {} and {} columns and setting ' + 'labels to {}'.format( + row["id"], row["text"].strip(), + None, self.test_label)) + else: + is_test = False + print_rank_0(' reading {} , {}, {}, and {} columns ' + '...'.format( + row["id"], row["text"].strip(), + None, row["label"].strip())) + + text_a = clean_text(row["text"].strip()) + text_b = None + unique_id = int(row["id"]) + + if is_test: + label = self.test_label + else: + label = row["label"].strip() + + assert len(text_a) > 0 + # assert len(text_b) > 0 + assert label in LABELS, "found label {} {}".format(label, row) + assert unique_id >= 0 + + sample = {'text_a': text_a, + 'text_b': text_b, + 'label': LABELS[label], + 'uid': unique_id} + total += 1 + samples.append(sample) + + if total % 5000 == 0: + print_rank_0(' > processed {} so far ...'.format(total)) + + print_rank_0(' >> processed {} samples.'.format(len(samples))) + print_rank_0(' >> drop {} samples.'.format(drop_cnt)) + + return samples diff --git a/tasks/data_utils.py b/tasks/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..866a5e69a233d9a9a68a837e156ebb240be6bfee --- /dev/null +++ b/tasks/data_utils.py @@ -0,0 +1,118 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" Tasks data utility.""" + +import re +import numpy as np + + +def clean_text(text): + """Remove new lines and multiple spaces and adjust end of sentence dot.""" + + text = text.replace("\n", " ") + text = re.sub(r'\s+', ' ', text) + for _ in range(3): + text = text.replace(' . ', '. ') + + return text + + +def build_sample(ids, types, paddings, label, unique_id): + """Convert to numpy and return a sample consumed by the batch producer.""" + + ids_np = np.array(ids, dtype=np.int64) + types_np = np.array(types, dtype=np.int64) + paddings_np = np.array(paddings, dtype=np.int64) + sample = ({'text': ids_np, + 'types': types_np, + 'padding_mask': paddings_np, + 'label': int(label), + 'uid': int(unique_id)}) + + return sample + + +def build_tokens_types_paddings_from_text(text_a, text_b, + tokenizer, max_seq_length): + """Build token types and paddings, trim if needed, and pad if needed.""" + + text_a_ids = tokenizer.tokenize(text_a) + text_b_ids = None + if text_b is not None: + text_b_ids = tokenizer.tokenize(text_b) + + return build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids, + max_seq_length, tokenizer.cls, + tokenizer.sep, tokenizer.pad) + + +def build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids, max_seq_length, + cls_id, sep_id, pad_id): + """Build token types and paddings, trim if needed, and pad if needed.""" + + ids = [] + types = [] + paddings = [] + + # [CLS]. + ids.append(cls_id) + types.append(0) + paddings.append(1) + + # A. + len_text_a = len(text_a_ids) + ids.extend(text_a_ids) + types.extend([0] * len_text_a) + paddings.extend([1] * len_text_a) + + # [SEP]. + ids.append(sep_id) + types.append(0) + paddings.append(1) + + # B. + if text_b_ids is not None: + len_text_b = len(text_b_ids) + ids.extend(text_b_ids) + types.extend([1] * len_text_b) + paddings.extend([1] * len_text_b) + + # Cap the size. + trimmed = False + if len(ids) >= max_seq_length: + max_seq_length_m1 = max_seq_length - 1 + ids = ids[0:max_seq_length_m1] + types = types[0:max_seq_length_m1] + paddings = paddings[0:max_seq_length_m1] + trimmed = True + + # [SEP]. + if (text_b_ids is not None) or trimmed: + ids.append(sep_id) + if text_b_ids is None: + types.append(0) + else: + types.append(1) + paddings.append(1) + + # Padding. 
+    padding_length = max_seq_length - len(ids)
+    if padding_length > 0:
+        ids.extend([pad_id] * padding_length)
+        types.extend([pad_id] * padding_length)
+        paddings.extend([0] * padding_length)
+
+    return ids, types, paddings
diff --git a/tasks/ensemble_classifier.py b/tasks/ensemble_classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2333b70154b5761b47bcb7cdf50e11c3d500dda
--- /dev/null
+++ b/tasks/ensemble_classifier.py
@@ -0,0 +1,149 @@
+import os
+import argparse
+import collections
+
+import numpy as np
+import torch
+
+
+def process_files(args):
+    all_predictions = collections.OrderedDict()
+    all_labels = collections.OrderedDict()
+    all_uid = collections.OrderedDict()
+    for path in args.paths:
+        path = os.path.join(path, args.prediction_name)
+        try:
+            data = torch.load(path)
+            for dataset in data:
+                name, d = dataset
+                predictions, labels, uid = d
+                if name not in all_predictions:
+                    all_predictions[name] = np.array(predictions)
+                    if args.labels is None:
+                        args.labels = [i for i in range(all_predictions[name].shape[1])]
+                    if args.eval:
+                        all_labels[name] = np.array(labels)
+                    all_uid[name] = np.array(uid)
+                else:
+                    all_predictions[name] += np.array(predictions)
+                    assert np.allclose(all_uid[name], np.array(uid))
+        except Exception as e:
+            print(e)
+            continue
+    return all_predictions, all_labels, all_uid
+
+
+def get_threshold(all_predictions, all_labels, one_threshold=False):
+    if one_threshold:
+        all_predictions = {'combined': np.concatenate(list(all_predictions.values()))}
+        all_labels = {'combined': np.concatenate(list(all_labels.values()))}
+    out_thresh = []
+    for dataset in all_predictions:
+        preds = all_predictions[dataset]
+        labels = all_labels[dataset]
+        out_thresh.append(calc_threshold(preds, labels))
+    return out_thresh
+
+
+def calc_threshold(p, l):
+    trials = [(i) * (1. / 100.) for i in range(100)]
+    best_acc = float('-inf')
+    best_thresh = 0
+    for t in trials:
+        acc = ((apply_threshold(p, t).argmax(-1) == l).astype(float)).mean()
+        if acc > best_acc:
+            best_acc = acc
+            best_thresh = t
+    return best_thresh
+
+
+def apply_threshold(preds, t):
+    assert (np.allclose(preds.sum(-1), np.ones(preds.shape[0])))
+    prob = preds[:, -1]
+    thresholded = (prob >= t).astype(int)
+    preds = np.zeros_like(preds)
+    preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1
+    return preds
+
+
+def threshold_predictions(all_predictions, threshold):
+    if len(threshold) != len(all_predictions):
+        threshold = threshold + [threshold[-1]] * (len(all_predictions) - len(threshold))
+    for i, dataset in enumerate(all_predictions):
+        thresh = threshold[i]
+        preds = all_predictions[dataset]
+        all_predictions[dataset] = apply_threshold(preds, thresh)
+    return all_predictions
+
+
+def postprocess_predictions(all_predictions, all_labels, args):
+    for d in all_predictions:
+        all_predictions[d] = all_predictions[d] / len(args.paths)
+
+    if args.calc_threshold:
+        args.threshold = get_threshold(all_predictions, all_labels, args.one_threshold)
+        print('threshold', args.threshold)
+
+    if args.threshold is not None:
+        all_predictions = threshold_predictions(all_predictions, args.threshold)
+
+    return all_predictions, all_labels
+
+
+def write_predictions(all_predictions, all_labels, all_uid, args):
+    all_correct = 0
+    count = 0
+    for dataset in all_predictions:
+        preds = all_predictions[dataset]
+        preds = np.argmax(preds, -1)
+        if args.eval:
+            correct = (preds == all_labels[dataset]).sum()
+            num = len(all_labels[dataset])
+            count += num
+            all_correct += correct
+            accuracy = (preds == all_labels[dataset]).mean()
+            print(accuracy)
+        if not os.path.exists(os.path.join(args.outdir, dataset)):
+            os.makedirs(os.path.join(args.outdir, dataset))
+        outpath = os.path.join(
+            args.outdir, dataset, os.path.splitext(
+                args.prediction_name)[0] + '.tsv')
+        with open(outpath, 'w') as f:
+            f.write('id\tlabel\n')
+            f.write('\n'.join(str(uid) + '\t' + str(args.labels[p])
+                              for uid, p in zip(all_uid[dataset], preds.tolist())))
+    if args.eval:
+        print(all_correct / count)
+
+
+def ensemble_predictions(args):
+    all_predictions, all_labels, all_uid = process_files(args)
+    all_predictions, all_labels = postprocess_predictions(all_predictions, all_labels, args)
+    write_predictions(all_predictions, all_labels, all_uid, args)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--paths', required=True, nargs='+',
+                        help='paths to checkpoint directories used in ensemble')
+    parser.add_argument('--eval', action='store_true',
+                        help='compute accuracy metrics against labels (dev set)')
+    parser.add_argument('--outdir',
+                        help='directory to place ensembled predictions in')
+    parser.add_argument('--prediction-name', default='test_predictions.pt',
+                        help='name of predictions in checkpoint directories')
+    parser.add_argument('--calc-threshold', action='store_true',
+                        help='calculate a classification threshold from labeled data')
+    parser.add_argument('--one-threshold', action='store_true',
+                        help='use one threshold for all subdatasets')
+    parser.add_argument('--threshold', nargs='+', default=None, type=float,
+                        help='user supplied threshold for classification')
+    parser.add_argument('--labels', nargs='+', default=None,
+                        help='whitespace separated list of label names')
+    args = parser.parse_args()
+    ensemble_predictions(args)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tasks/eval_utils.py
b/tasks/eval_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..07eb72ec35bf56a8748aced24038b50786247331 --- /dev/null +++ b/tasks/eval_utils.py @@ -0,0 +1,250 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evaluation utilities.""" + +import os +import time +from functools import partial + +import torch + +from megatron import get_args +from megatron import print_rank_last, is_last_rank +from megatron import mpu +from megatron.schedules import get_forward_backward_func +from tasks.finetune_utils import build_data_loader +from tasks.finetune_utils import process_batch +import json +import numpy as np +from tasks.label_dict import get_label_dict + +def accuracy_func_provider(single_dataset_provider): + """Provide function that calculates accuracies.""" + args = get_args() + + # Build dataloaders. + datapaths = [args.valid_data[0], args.test_data[0]] + dataloaders = [] + for datapath in datapaths: + dataset = single_dataset_provider(datapath) + dataloader = build_data_loader( + dataset, args.micro_batch_size, num_workers=args.num_workers, + drop_last=(mpu.get_data_parallel_world_size() > 1)) + dataloaders.append((dataset.dataset_name, dataloader)) + + def _generate_prediction_json(predictions, step, save_acc): + + probs_list = predictions[0] + # labels_list = predictions[1] + ids_list = predictions[2] + min_id = min(ids_list) + max_id = max(ids_list) + LABELS = get_label_dict(args.task, write2file=True) + output_submit_file = os.path.join(args.res_path[0], args.task+"_prediction_{}_{}.json".format(step, save_acc)) + with open(output_submit_file, "w") as writer: + for i in range(min_id, max_id + 1): + label_index = ids_list.index(i) + pred_prob_list = probs_list[label_index] + label = pred_prob_list.index(max(pred_prob_list)) + json_d = {} + if min_id == 1: + json_d['id'] = i - 1 + else: + json_d['id'] = i + json_d["label"] = LABELS[str(label)] + writer.write(json.dumps(json_d) + '\n') + + def _generate_prediction_prob(predictions, step, save_acc): + + probs_list = predictions[0] + ids_list = predictions[2] + min_id = min(ids_list) + max_id = max(ids_list) + + output_prob_file = os.path.join(args.res_path[0], args.task+"_prob_{}_{}".format(step, save_acc)) + prob_arr = [] + for i in range(min_id, max_id + 1): + label_index = ids_list.index(i) + prob_arr.append(probs_list[label_index]) + prob_arr = np.array(prob_arr) + np.save(output_prob_file, prob_arr) + + def metrics_func(model, step): + print_rank_last('calculating metrics ...') + correct = 0 + total = 0 + + for index, (name, dataloader) in enumerate(dataloaders): + if index == 1: + output_predictions = True + assert mpu.get_data_parallel_world_size() == 1 + named_predictions = [] + names = 'predictions' + else: + output_predictions = False + + output = calculate_correct_answers(name, model, dataloader, + step, output_predictions) + if not output_predictions: + correct_ans, total_count = output + else: + correct_ans, total_count, 
predictions = output + named_predictions.append((name, predictions)) + names += '_' + name + if not output_predictions: + correct += correct_ans + total += total_count + save_acc = str(round(correct / total, 4) * 10000)[:4] + + if output_predictions: + print_rank_last("generate prediction...") + # import pdb;pdb.set_trace() + _generate_prediction_json(predictions, step, save_acc) + _generate_prediction_prob(predictions, step, save_acc) + print_rank_last("generate done") + # import pdb;pdb.set_trace() + # import pdb;pdb.set_trace() + # if is_last_rank(): + # percent = float(correct) * 100.0 / float(total) + # print(' >> |step: {}| overall: correct / total = {} / {} = ' + # '{:.4f} %'.format(step, correct, total, percent)) + # if output_predictions and is_last_rank(): + # assert args.load is not None + # filename = os.path.join(args.load, names + '.pt') + # torch.save(named_predictions, filename) + + return metrics_func + + +def calculate_correct_answers(name, model, dataloader, + step, output_predictions): + """Calculate correct over total answers and return prediction if the + `output_predictions` is true.""" + args = get_args() + forward_backward_func = get_forward_backward_func() + start_time = time.time() + for m in model: + m.eval() + saved_micro_batch_size = args.micro_batch_size + saved_global_batch_size = args.global_batch_size + + ds = dataloader.dataset + if hasattr(ds, 'sample_multiplier'): + # If our dataset as a sample_multiplier attribute that means + # each "sample" from the dataset actually has multiple samples + # that will collapse into the batch dimension (for example in + # the RACE dataset that has several options), we need to + # account for that when setting the micro batch size. + sample_multiplier = ds.sample_multiplier + else: + sample_multiplier = 1 + micro_batch_size_times_data_parallel = args.orig_micro_batch_size * args.data_parallel_size + num_micro_batches = args.orig_global_batch_size // micro_batch_size_times_data_parallel + + def loss_func(output_predictions, labels, output_tensor): + logits = output_tensor + + loss_dict = {} + # Add output predictions. + if output_predictions: + # assert False + loss_dict['softmaxes'] = torch.nn.Softmax(dim=-1)( + logits.float()).data.cpu().numpy().tolist() + loss_dict['labels'] = labels.data.cpu().numpy().tolist() + loss_dict['ids'] = batch['uid'].cpu().numpy().tolist() + # Compute the correct answers. + predicted = torch.argmax(logits, dim=-1) + corrects = (predicted == labels) + # Add to the counters. + loss_dict['total'] = labels.size(0) + loss_dict['correct'] = corrects.sum().item() + + return 0, loss_dict + + # defined inside to capture output_predictions + def correct_answers_forward_step(batch, model): + try: + batch_ = next(batch) + except BaseException: + batch_ = batch + tokens, types, labels, attention_mask = process_batch(batch_) + + # Forward model. + args = get_args() + output_tensor = model(tokens, attention_mask, tokentype_ids=types) + + return output_tensor, partial(loss_func, output_predictions, labels) + + with torch.no_grad(): + # For all the batches in the dataset. + total = 0 + correct = 0 + if output_predictions: + # This option is only possible when data parallel size is 1. 
+ assert mpu.get_data_parallel_world_size() == 1 + softmaxes = [] + labels = [] + ids = [] + for _, batch in enumerate(dataloader): + # For evaluation only mode we use drop_last = False to get all the + # samples, which means we might not have a full batch, so we + # adjust batch_size here to actual batch size of data + actual_batch_size = len(batch['label']) + # ... applying sample_multiplier if necessary + args.micro_batch_size = actual_batch_size * sample_multiplier + args.global_batch_size = actual_batch_size * sample_multiplier * num_micro_batches + + loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model, + optimizer=None, timers=None, forward_only=True) + + for loss_dict in loss_dicts: + if output_predictions: + softmaxes.extend(loss_dict['softmaxes']) + labels.extend(loss_dict['labels']) + ids.extend(loss_dict['ids']) + total += loss_dict['total'] + correct += loss_dict['correct'] + + + for m in model: + m.train() + args.micro_batch_size = saved_micro_batch_size + args.global_batch_size = saved_global_batch_size + + # Reduce. + if mpu.is_pipeline_last_stage(): + unreduced = torch.cuda.LongTensor([correct, total]) + torch.distributed.all_reduce(unreduced, + group=mpu.get_data_parallel_group()) + + # Print on screen. + + correct_ans = unreduced[0].item() + total_count = unreduced[1].item() + percent = float(correct_ans) * 100.0 / float(total_count) + elapsed_time = time.time() - start_time + if not output_predictions: + print_rank_last(' > |step: {} | metrics for {}: correct / total ' + '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format( + step, name, correct_ans, total_count, + percent, elapsed_time)) + + if output_predictions: + return correct_ans, total_count, (softmaxes, labels, ids) + return correct_ans, total_count + if output_predictions: + return 0, 0, () + return 0, 0 diff --git a/tasks/finetune_utils.py b/tasks/finetune_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f9e6c5e09112fbfbe7044ebc1128f88812cdcca5 --- /dev/null +++ b/tasks/finetune_utils.py @@ -0,0 +1,330 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
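On the output side, _generate_prediction_json above walks the collected (softmax, label, id) triples in id order and writes one CLUE-style submission line per example (it also shifts ids down by one when they start at 1). A small self-contained sketch of that conversion; the index-to-label-string map is assumed, as with the loaders:

import json

ID2LABEL = {"0": "entailment", "1": "neutral", "2": "contradiction"}  # assumed

def write_submission(probs_list, ids_list, path):
    # probs_list[i] is the softmax vector produced for example ids_list[i].
    with open(path, "w") as writer:
        for i in range(min(ids_list), max(ids_list) + 1):
            probs = probs_list[ids_list.index(i)]
            pred = probs.index(max(probs))  # argmax over the class probabilities
            writer.write(json.dumps({"id": i, "label": ID2LABEL[str(pred)]}) + "\n")

write_submission([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]], [0, 1], "ocnli_prediction.json")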
+ +"""Finetune utilities.""" + +from functools import partial +import sys +import torch + +from megatron import get_args, get_num_microbatches +from megatron import print_rank_0 +from megatron import get_timers +from megatron import mpu +from megatron.checkpointing import load_checkpoint +from megatron.checkpointing import save_checkpoint +from megatron.model import ModelType +from megatron.training import evaluate_and_print_results +from megatron.training import setup_model_and_optimizer +from megatron.training import train_step +from megatron.training import training_log +from megatron.utils import average_losses_across_data_parallel_group +from megatron.utils import calc_params_l2_norm +from megatron.utils import check_adlr_autoresume_termination + + +def process_batch(batch): + """Process batch and produce inputs for the model.""" + args = get_args() + + tokens = batch['text'].long().cuda().contiguous() + types = batch['types'].long().cuda().contiguous() + labels = batch['label'].long().cuda().contiguous() + attention_mask = batch['padding_mask'].float().cuda().contiguous() + if args.fp16: + attention_mask = attention_mask.half() + + return tokens, types, labels, attention_mask + + +def cross_entropy_loss_func(labels, output_tensor): + logits = output_tensor + + # Cross-entropy loss. + loss_func = torch.nn.CrossEntropyLoss() + loss = loss_func(logits.contiguous().float(), labels) + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + + return loss, {'training loss': averaged_loss[0]} + + +def _cross_entropy_forward_step(batch, model): + """Simple forward step with cross-entropy loss.""" + timers = get_timers() + + # Get the batch. + timers('batch-generator').start() + try: + batch_ = next(batch) + except BaseException: + batch_ = batch + tokens, types, labels, attention_mask = process_batch(batch_) + timers('batch-generator').stop() + + # Forward model. + output_tensor = model(tokens, attention_mask, tokentype_ids=types) + + return output_tensor, partial(cross_entropy_loss_func, labels) + + +def build_data_loader(dataset, micro_batch_size, num_workers, drop_last, + task_collate_fn=None): + """Data loader. Note that batch-size is the local (per GPU) batch-size.""" + + # Sampler. + world_size = mpu.get_data_parallel_world_size() + rank = mpu.get_data_parallel_rank() + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, num_replicas=world_size, rank=rank) + + # Data loader. Note that batch size is the per GPU batch size. + data_loader = torch.utils.data.DataLoader(dataset, + batch_size=micro_batch_size, + sampler=sampler, + shuffle=False, + num_workers=num_workers, + drop_last=drop_last, + pin_memory=True, + collate_fn=task_collate_fn) + + return data_loader + + +def _build_infinite_size_dataloader(dataloader): + """Build a looped dataloader with infinite size.""" + + iterator = dataloader.__iter__() + while True: + try: + yield iterator.__next__() + except StopIteration: + iterator = dataloader.__iter__() + + +def _build_train_valid_dataloaders(train_dataset, valid_dataset, + task_collate_fn=None): + """Traing and validation dataloaders.""" + args = get_args() + + print_rank_0('building train and validation dataloaders ...') + # Training dataset. + train_dataloader = build_data_loader(train_dataset, args.micro_batch_size, + args.num_workers, not args.keep_last, + task_collate_fn) + # Set the training iterations. 
+ args.train_iters_per_epoch = len(train_dataloader) + args.train_iters = args.epochs * args.train_iters_per_epoch + # Validation dataset. For this dataset, we do not need to set up + # shuffling so we can just use a simple infinite loop. + valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size, + args.num_workers, not args.keep_last, + task_collate_fn) + valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_) + + # Now that we've built the data loaders, set batch_size arguments + # to the actual batch size the model will see for this dataset. + # This is necessary so pipeline transfers know what size they are + # and the LR schedule, which is based on samples seen, gets set + # correctly. + args.orig_micro_batch_size = args.micro_batch_size + args.orig_global_batch_size = args.global_batch_size + if hasattr(train_dataset, 'sample_multiplier'): + # If our dataset as a sample_multiplier attribute that means + # each "sample" from the dataset actually has multiple samples + # that will collapse into the batch dimension (for example in + # the RACE dataset that has several options), we need to + # account for that when setting the micro batch size. + args.micro_batch_size *= train_dataset.sample_multiplier + args.global_batch_size *= train_dataset.sample_multiplier + + return train_dataloader, valid_dataloader + + +def _train(model, optimizer, opt_param_scheduler, forward_step, + train_dataloader, valid_dataloader, end_of_epoch_callback): + """Train the model.""" + args = get_args() + timers = get_timers() + + assert get_num_microbatches() == 1, "finetuning with gradient accumulation doesn't currently work" + + # Turn on training mode which enables dropout. + for m in model: + m.train() + + # Tracking loss. + losses_dict_sum = {} + + # Starting epoch and iteration + start_epoch = args.iteration // args.train_iters_per_epoch + start_iteration = args.iteration % args.train_iters_per_epoch + iteration = args.iteration + + # Memory reporting flag. + report_memory_flag = True + + # For each remaining epoch + timers('interval-time').start() + for epoch in range(start_epoch, args.epochs): + print_rank_0('working on epoch {} ...'.format(epoch + 1)) + + # Set the data loader epoch to shuffle the index iterator. + train_dataloader.sampler.set_epoch(args.seed + epoch) + + # For all the batches in the dataset. + for iteration_, batch in enumerate(train_dataloader): + + # Ignore the iterations before starting value + if iteration_ < start_iteration: + continue + # Set to zero so the next epoch does not skip any batches. + start_iteration = 0 + + # Train for one step. + out = train_step(forward_step, batch, model, optimizer, opt_param_scheduler) + + losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out + iteration += 1 + + # Logging. 
+ params_norm = None + if args.log_params_norm: + params_norm = calc_params_l2_norm(model) + report_memory_flag = training_log(losses_dict, losses_dict_sum, + optimizer.param_groups[0]['lr'], + iteration, + optimizer.get_loss_scale().item(), + report_memory_flag, skipped_iter, + grad_norm, params_norm, num_zeros_in_grad, None) + + # Autoresume + if args.adlr_autoresume and \ + (iteration % args.adlr_autoresume_interval == 0): + check_adlr_autoresume_termination(iteration, model, + optimizer, opt_param_scheduler) + + # Checkpointing + saved_checkpoint = False + if args.save and args.save_interval and \ + iteration % args.save_interval == 0: + save_checkpoint(iteration, model, optimizer, opt_param_scheduler) + saved_checkpoint = True + + # Evaluation + if args.eval_interval and iteration % args.eval_interval == 0: + prefix = 'iteration {}'.format(iteration) + evaluate_and_print_results(prefix, forward_step, + valid_dataloader, model, + iteration, None, False) + if end_of_epoch_callback is not None: + end_of_epoch_callback(model, iteration) + print_rank_0('-' * 72 + '\n') + + # Exiting based on iterations + if args.exit_interval and iteration % args.exit_interval == 0: + if not saved_checkpoint: + save_checkpoint(iteration, model, optimizer, opt_param_scheduler) + torch.distributed.barrier() + print_rank_0('exiting program at iteration {}'.format(iteration)) + sys.exit() + + # Checkpointing at the end of each epoch. + if args.save: + save_checkpoint(iteration, model, optimizer, opt_param_scheduler) + + prefix = 'iteration {}'.format(iteration) + evaluate_and_print_results(prefix, forward_step, + valid_dataloader, model, + iteration, None, False) + if end_of_epoch_callback is not None: + end_of_epoch_callback(model, iteration) + print_rank_0('-' * 72 + '\n') + + # Callback at the end of each epoch. + # if end_of_epoch_callback is not None: + # end_of_epoch_callback(model, epoch) + + +def finetune(train_valid_datasets_provider, model_provider, + model_type=ModelType.encoder_or_decoder, + forward_step=_cross_entropy_forward_step, + end_of_epoch_callback_provider=None, + task_collate_fn=None): + """Main finetune function used across all tasks.""" + args = get_args() + timers = get_timers() + + assert args.rampup_batch_size is None, \ + 'batch size scaling is not supported for finetuning' + + # Train and validation data loaders. + timers('train/valid/test dataset/dataloder').start() + if args.epochs > 0: + train_dataset, valid_dataset = train_valid_datasets_provider() + train_dataloader, valid_dataloader = _build_train_valid_dataloaders( + train_dataset, valid_dataset, task_collate_fn) + else: + args.train_iters = 0 + timers('train/valid/test dataset/dataloder').stop() + + # Build calback function. + timers('callback function').start() + end_of_epoch_callback = None + if end_of_epoch_callback_provider is not None: + end_of_epoch_callback = end_of_epoch_callback_provider() + timers('callback function').stop() + + # Build model, optimizer and learning rate scheduler. + timers('model and optimizer').start() + model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, model_type) + timers('model and optimizer').stop() + + # If pretrained checkpoint is provided and we have not trained for + # any iteration (i.e., iteration is zero), then load the pretrained + # checkpoint. 
+ timers('pretrained checkpoint').start() + if args.iteration == 0 and args.pretrained_checkpoint is not None: + original_load = args.load + args.load = args.pretrained_checkpoint + original_rng = args.no_load_rng + args.no_load_rng = True + _ = load_checkpoint(model, None, None) + args.load = original_load + args.no_load_rng = original_rng + # This is critical when only model is loaded. We should make sure + # main parameters are also updated. + optimizer.reload_model_params() + timers('pretrained checkpoint').stop() + + # Print setup timing. + print_rank_0('done with setups ...') + timers.log(['train/valid/test dataset/dataloder', 'callback function', + 'model and optimizer', 'pretrained checkpoint']) + print_rank_0('training ...') + + # Finetune the model. + if args.epochs > 0: + _train(model, optimizer, opt_param_scheduler, forward_step, + train_dataloader, valid_dataloader, end_of_epoch_callback) + # Or just evaluate. + else: + print_rank_0("Not Imp") + import pdb;pdb.set_trace() + # if end_of_epoch_callback is not None: + # print_rank_0('evaluation only mode, setting epoch to -1') + # end_of_epoch_callback(model, epoch=-1, output_predictions=True) + print_rank_0('done :-)') diff --git a/tasks/glue/data.py b/tasks/glue/data.py new file mode 100644 index 0000000000000000000000000000000000000000..357ad130c3ac353bd06163822c5a9443b33d1510 --- /dev/null +++ b/tasks/glue/data.py @@ -0,0 +1,69 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GLUE dataset.""" + +from abc import ABC +from abc import abstractmethod + +from torch.utils.data import Dataset + +from megatron import print_rank_0 +from tasks.data_utils import build_sample +from tasks.data_utils import build_tokens_types_paddings_from_text + + +class GLUEAbstractDataset(ABC, Dataset): + """GLUE base dataset class.""" + + def __init__(self, task_name, dataset_name, datapaths, + tokenizer, max_seq_length): + # Store inputs. + self.task_name = task_name + self.dataset_name = dataset_name + self.tokenizer = tokenizer + self.max_seq_length = max_seq_length + print_rank_0(' > building {} dataset for {}:'.format(self.task_name, + self.dataset_name)) + # Process the files. 
+ string = ' > paths:' + for path in datapaths: + string += ' ' + path + print_rank_0(string) + self.samples = [] + for datapath in datapaths: + self.samples.extend(self.process_samples_from_single_path(datapath)) + print_rank_0(' >> total number of samples: {}'.format( + len(self.samples))) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + raw_sample = self.samples[idx] + ids, types, paddings = build_tokens_types_paddings_from_text( + raw_sample['text_a'], raw_sample['text_b'], + self.tokenizer, self.max_seq_length) + sample = build_sample(ids, types, paddings, + raw_sample['label'], raw_sample['uid']) + return sample + + @abstractmethod + def process_samples_from_single_path(self, datapath): + """Abstract method that takes a single path / filename and + returns a list of dataset samples, each sample being a dict of + {'text_a': string, 'text_b': string, 'label': int, 'uid': int} + """ + pass diff --git a/tasks/glue/finetune.py b/tasks/glue/finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..ad1938b0c3fd087a79c5ac3dd76e45d97ce38106 --- /dev/null +++ b/tasks/glue/finetune.py @@ -0,0 +1,93 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
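To illustrate the abstract contract of GLUEAbstractDataset above (tasks/glue/data.py), here is a minimal sketch of a concrete subclass for a hypothetical four-column tab-separated file; the task name, label set and file layout are assumptions for illustration only and are not part of this diff.

# Minimal sketch of a concrete GLUE-style dataset (hypothetical file layout).
from megatron import print_rank_0
from tasks.data_utils import clean_text
from tasks.glue.data import GLUEAbstractDataset

LABELS = {'0': 0, '1': 1}  # assumed label set for this hypothetical task


class ToyPairDataset(GLUEAbstractDataset):
    """Hypothetical task: uid \t sentence1 \t sentence2 \t label per line."""

    def __init__(self, name, datapaths, tokenizer, max_seq_length,
                 test_label='0'):
        self.test_label = test_label
        super().__init__('TOY', name, datapaths, tokenizer, max_seq_length)

    def process_samples_from_single_path(self, filename):
        """Return a list of {'text_a', 'text_b', 'label', 'uid'} dicts."""
        samples = []
        with open(filename, 'r') as f:
            for line in f:
                uid, sent_a, sent_b, label = line.strip().split('\t')
                samples.append({'text_a': clean_text(sent_a),
                                'text_b': clean_text(sent_b),
                                'label': LABELS.get(label, LABELS[self.test_label]),
                                'uid': int(uid)})
        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        return samples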
+ +"""GLUE finetuning/evaluation.""" + +from megatron import get_args +from megatron import print_rank_0 +from megatron import get_tokenizer +from megatron import mpu +from megatron.model.classification import Classification +from tasks.eval_utils import accuracy_func_provider +from tasks.finetune_utils import finetune + + +def glue_classification(num_classes, Dataset, + name_from_datapath_func): + + def train_valid_datasets_provider(): + """Build train and validation dataset.""" + args = get_args() + tokenizer = get_tokenizer() + + train_dataset = Dataset('training', args.train_data, + tokenizer, args.seq_length) + valid_dataset = Dataset('validation', args.valid_data, + tokenizer, args.seq_length) + + return train_dataset, valid_dataset + + def model_provider(pre_process=True, post_process=True): + """Build the model.""" + args = get_args() + + print_rank_0('building classification model for {} ...'.format( + args.task)) + model = Classification(num_classes=num_classes, num_tokentypes=2, + pre_process=pre_process, post_process=post_process) + + return model + + def metrics_func_provider(): + """Privde metrics callback function.""" + def single_dataset_provider(datapath): + args = get_args() + tokenizer = get_tokenizer() + + name = name_from_datapath_func(datapath) + return Dataset(name, [datapath], tokenizer, args.seq_length) + return accuracy_func_provider(single_dataset_provider) + + """Finetune/evaluate.""" + finetune(train_valid_datasets_provider, model_provider, + end_of_epoch_callback_provider=metrics_func_provider) + + +def main(): + args = get_args() + + if args.task == 'MNLI': + + num_classes = 3 + from tasks.glue.mnli import MNLIDataset as Dataset + + def name_from_datapath(datapath): + return datapath.split('MNLI')[-1].strip( + '.tsv').strip('/').replace('_', '-') + + elif args.task == 'QQP': + + num_classes = 2 + from tasks.glue.qqp import QQPDataset as Dataset + + def name_from_datapath(datapath): + return datapath.split('QQP')[-1].strip( + '.tsv').strip('/').replace('_', '-') + + else: + raise NotImplementedError('GLUE task {} is not implemented.'.format( + args.task)) + + glue_classification(num_classes, Dataset, name_from_datapath) diff --git a/tasks/glue/mnli.py b/tasks/glue/mnli.py new file mode 100644 index 0000000000000000000000000000000000000000..547a2a0052e92d184d155f13b6576c43eee4546d --- /dev/null +++ b/tasks/glue/mnli.py @@ -0,0 +1,84 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""MNLI dataset.""" + +from megatron import print_rank_0 +from tasks.data_utils import clean_text +from .data import GLUEAbstractDataset + + +LABELS = {'contradiction': 0, 'entailment': 1, 'neutral': 2} + + +class MNLIDataset(GLUEAbstractDataset): + + def __init__(self, name, datapaths, tokenizer, max_seq_length, + test_label='contradiction'): + self.test_label = test_label + super().__init__('MNLI', name, datapaths, + tokenizer, max_seq_length) + + def process_samples_from_single_path(self, filename): + """"Implement abstract method.""" + print_rank_0(' > Processing {} ...'.format(filename)) + + samples = [] + total = 0 + first = True + is_test = False + with open(filename, 'r') as f: + for line in f: + row = line.strip().split('\t') + if first: + first = False + if len(row) == 10: + is_test = True + print_rank_0( + ' reading {}, {} and {} columns and setting ' + 'labels to {}'.format( + row[0].strip(), row[8].strip(), + row[9].strip(), self.test_label)) + else: + print_rank_0(' reading {} , {}, {}, and {} columns ' + '...'.format( + row[0].strip(), row[8].strip(), + row[9].strip(), row[-1].strip())) + continue + + text_a = clean_text(row[8].strip()) + text_b = clean_text(row[9].strip()) + unique_id = int(row[0].strip()) + label = row[-1].strip() + if is_test: + label = self.test_label + + assert len(text_a) > 0 + assert len(text_b) > 0 + assert label in LABELS + assert unique_id >= 0 + + sample = {'text_a': text_a, + 'text_b': text_b, + 'label': LABELS[label], + 'uid': unique_id} + total += 1 + samples.append(sample) + + if total % 50000 == 0: + print_rank_0(' > processed {} so far ...'.format(total)) + + print_rank_0(' >> processed {} samples.'.format(len(samples))) + return samples diff --git a/tasks/glue/qqp.py b/tasks/glue/qqp.py new file mode 100644 index 0000000000000000000000000000000000000000..a6adbd096c0fca59a49f55b7a81ebd680f893568 --- /dev/null +++ b/tasks/glue/qqp.py @@ -0,0 +1,101 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""QQP dataset.""" + +from megatron import print_rank_0 +from tasks.data_utils import clean_text +from .data import GLUEAbstractDataset + + +LABELS = [0, 1] + + +class QQPDataset(GLUEAbstractDataset): + + def __init__(self, name, datapaths, tokenizer, max_seq_length, + test_label=0): + self.test_label = test_label + super().__init__('QQP', name, datapaths, + tokenizer, max_seq_length) + + def process_samples_from_single_path(self, filename): + """"Implement abstract method.""" + print_rank_0(' > Processing {} ...'.format(filename)) + + samples = [] + total = 0 + first = True + is_test = False + with open(filename, 'r') as f: + for line in f: + row = line.strip().split('\t') + if first: + first = False + if len(row) == 3: + is_test = True + print_rank_0(' reading {}, {}, and {} columns and ' + 'setting labels to {}'.format( + row[0].strip(), row[1].strip(), + row[2].strip(), self.test_label)) + else: + assert len(row) == 6 + print_rank_0(' reading {}, {}, {}, and {} columns' + ' ...'.format( + row[0].strip(), row[3].strip(), + row[4].strip(), row[5].strip())) + continue + + if is_test: + assert len(row) == 3, 'expected length 3: {}'.format(row) + uid = int(row[0].strip()) + text_a = clean_text(row[1].strip()) + text_b = clean_text(row[2].strip()) + label = self.test_label + assert len(text_a) > 0 + assert len(text_b) > 0 + else: + if len(row) == 6: + uid = int(row[0].strip()) + text_a = clean_text(row[3].strip()) + text_b = clean_text(row[4].strip()) + label = int(row[5].strip()) + else: + print_rank_0('***WARNING*** index error, ' + 'skipping: {}'.format(row)) + continue + if len(text_a) == 0: + print_rank_0('***WARNING*** zero length a, ' + 'skipping: {}'.format(row)) + continue + if len(text_b) == 0: + print_rank_0('***WARNING*** zero length b, ' + 'skipping: {}'.format(row)) + continue + assert label in LABELS + assert uid >= 0 + + sample = {'uid': uid, + 'text_a': text_a, + 'text_b': text_b, + 'label': label} + total += 1 + samples.append(sample) + + if total % 50000 == 0: + print_rank_0(' > processed {} so far ...'.format(total)) + + print_rank_0(' >> processed {} samples.'.format(len(samples))) + return samples diff --git a/tasks/label_dict.py b/tasks/label_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..c439ec18efa943b380038a42fe1264fb13bf0fac --- /dev/null +++ b/tasks/label_dict.py @@ -0,0 +1,73 @@ + +AFQMC_LABELS = { + '0': '0', + '1': '1', +} + +CSL_LABELS = { + '0': '0', + '1': '1', + '2': '2', +} + +IFLYTEK_LABELS = {} +for i in range(119): + IFLYTEK_LABELS[str(i)] = str(i) + +OCNLI_LABELS = { + 'contradiction': '0', + 'entailment': '1', + 'neutral': '2' +} + +CMNLI_LABELS = { + 'contradiction': '0', + 'entailment': '1', + 'neutral': '2' +} + +TNEWS_LABELS = {} +tnews_list = [] +for i in range(17): + if i == 5 or i == 11: + continue + tnews_list.append(i) +for i in range(len(tnews_list)): + TNEWS_LABELS[str(100 + tnews_list[i])] = str(i) + +WSC_LABELS = { + 'true': '0', + 'false': '1', +} + +ZC_LABELS = { + 'negative': '0', + 'positive': '1', +} + +def get_label_dict(task_name, write2file=False): + + if task_name == "AFQMC": + label_dict = AFQMC_LABELS + elif task_name == "CSL": + label_dict = CSL_LABELS + elif task_name == "IFLYTEK": + label_dict = IFLYTEK_LABELS + elif task_name == "OCNLI": + label_dict = OCNLI_LABELS + elif task_name == "TNEWS": + label_dict = TNEWS_LABELS + elif task_name == "WSC": + label_dict = WSC_LABELS + elif task_name == "CMNLI": + label_dict = CMNLI_LABELS + elif task_name == "ZC": + label_dict = ZC_LABELS + else: + 
print("Not Imp") + import pdb;pdb.set_trace() + + if write2file: + label_dict = {v:k for k,v in label_dict.items()} + + return label_dict \ No newline at end of file diff --git a/tasks/main.py b/tasks/main.py new file mode 100644 index 0000000000000000000000000000000000000000..27bf89b7b94ad36dcdeb60a77040cec14a2bbe4d --- /dev/null +++ b/tasks/main.py @@ -0,0 +1,121 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Main tasks functionality.""" + +import os +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), + os.path.pardir))) + +from megatron import get_args +from megatron.initialize import initialize_megatron + + +def get_tasks_args(parser): + """Provide extra arguments required for tasks.""" + group = parser.add_argument_group(title='tasks') + + group.add_argument('--task', type=str, required=True, + help='Task name.') + group.add_argument('--epochs', type=int, default=None, + help='Number of finetunning epochs. Zero results in ' + 'evaluation only.') + group.add_argument('--pretrained-checkpoint', type=str, default=None, + help='Pretrained checkpoint used for finetunning.') + group.add_argument('--keep-last', action='store_true', + help='Keep the last batch (maybe incomplete) in' + 'the data loader') + group.add_argument('--train-data', nargs='+', default=None, + help='Whitespace separated paths or corpora names ' + 'for training.') + group.add_argument('--valid-data', nargs='*', default=None, + help='path(s) to the validation data.') + group.add_argument('--test-data', nargs='*', default=None, + help='path(s) to the test data.') + group.add_argument('--res-path', nargs='*', default=None, + help='path(s) to the test result.') + group.add_argument('--overlapping-eval', type=int, default=32, + help='Sliding window for overlapping evaluation.') + group.add_argument('--strict-lambada', action='store_true', + help='Use more difficult formulation of lambada.') + # Retriever args + group.add_argument('--qa-data-dev', type=str, default=None, + help='Path to the QA dataset dev file.') + group.add_argument('--qa-data-test', type=str, default=None, + help='Path to the QA dataset test file.') + + # Faiss arguments for retriever + group.add_argument('--faiss-use-gpu', action='store_true', + help='Whether create the FaissMIPSIndex on GPU') + group.add_argument('--faiss-match', type=str, default='string', \ + choices=['regex', 'string'], help="Answer matching '\ + 'logic type") + group.add_argument('--faiss-topk-retrievals', type=int, default=100, + help='Number of blocks to use as top-k during retrieval') + + # finetune for retriever + group.add_argument('--eval-micro-batch-size', type=int, default=None, + help='Eval Batch size per model instance (local batch ' + 'size). 
Global batch size is local batch size ' + 'times data parallel size.') + group.add_argument('--train-with-neg', action='store_true', + help='Whether to use negative examples during model ' + 'training') + group.add_argument('--train-hard-neg', type=int, default=0, + help='Number of hard negative exmaples to use during ' + 'training') + + + # parameters for Av.rank validation method + # Following options/arguments have been taken directly from DPR codebase + group.add_argument('--val-av-rank-hard-neg', type=int, default=30, + help='Av.rank validation: how many hard negatives to' + ' take from each question pool') + group.add_argument('--val-av-rank-other-neg', type=int, default=30, + help='Av.rank validation: how many other negatives to' + ' take from each question pool') + + + return parser + + +if __name__ == '__main__': + + initialize_megatron(extra_args_provider=get_tasks_args) + + args = get_args() + + if args.num_layers_per_virtual_pipeline_stage is not None: + print("Interleaved pipeline schedule is not yet supported for downstream tasks.") + exit() + + if args.task == 'RACE': + from race.finetune import main + elif args.task in ['MNLI', 'QQP']: + from glue.finetune import main + elif args.task in ['AFQMC', 'CSL', 'IFLYTEK','OCNLI', 'TNEWS', 'WSC', 'CMNLI', "ZC"]: + from clue.finetune import main + elif args.task in ['LAMBADA', 'WIKITEXT103']: + from zeroshot_gpt.evaluate import main + elif args.task in ['ICT-ZEROSHOT-NQ', 'RETRIEVER-EVAL']: + from orqa.evaluate_orqa import main + elif args.task in ['RET-FINETUNE-NQ']: + from orqa.supervised.finetune import main + else: + raise NotImplementedError('Task {} is not implemented.'.format( + args.task)) + + main() diff --git a/tasks/msdp/README.md b/tasks/msdp/README.md new file mode 100644 index 0000000000000000000000000000000000000000..27c8728eca146aea44c627a99d5f80184b6fbf84 --- /dev/null +++ b/tasks/msdp/README.md @@ -0,0 +1,19 @@ + +# Multi-Stage Prompting for Knowledgeable Dialogue Generation + +Below we present the steps to run our multi-stage dialogue prompting (MSDP) framework. + +## Multi-Stage Dialogue Prompting + +### Data Preparation +1. Dataset Download: [Wizard of Wikipedia](https://parl.ai/projects/wizard_of_wikipedia/) and [Wizard of Internet](https://parl.ai/projects/sea/) +2. Data Processing: We provide the script to run the [`data processing`](../../examples/msdp/data_processing.sh) of the datatsets. + +### Stage-1: Prompting for Knowledge Generation +1. We provide the script to perform the [`first-stage prompting`](../../examples/msdp/prompt_knwl_gen.sh) for the knowledge generation. +2. We provide the [`evaluation script`](../../examples/msdp/eval_knwl_generation.sh) for the automatic evaluation (i.e., F1, BLEU, METEOR, and ROUGE-L) of the knowledge generation. + +### Stage-2: Prompting for Response Generation +1. We provide the script to [`prepare the input file`](../../examples/msdp/prep_resp_gen.sh) for the response generation (based on the previously generated knowledge file). +2. We provide the script to perform the [`second-stage prompting`](../../examples/msdp/prompt_resp_gen.sh) for the response generation. +3. We provide the [`evaluation script`](../../examples/msdp/eval_resp_generation.sh) for the automatic evaluation (i.e., F1, KF1, BLEU, METEOR, and ROUGE-L) of the response generation. 
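For reference, the data-processing step mentioned above writes each example as one tab-separated line in the form topic \t dialogue context \t golden knowledge \t golden response (see process_wow_dataset in tasks/msdp/preprocessing.py later in this diff). A minimal sketch of reading a processed line back, assuming a hypothetical output file name:

# Read a processed file back into its four fields (file name is hypothetical).
with open('wow_test_processed.txt', 'r') as f:
    for line in f:
        topic, dialog_context, knowledge, response = line.strip('\n').split('\t')
        turns = dialog_context.split(' [SEP] ')  # turns are joined with ' [SEP] '
        # ... feed (topic, turns, knowledge) into the prompting stages ...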
diff --git a/tasks/msdp/evaluate.py b/tasks/msdp/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..18e2b1e08557b8834d3ca7ac5f1cb979b468301d --- /dev/null +++ b/tasks/msdp/evaluate.py @@ -0,0 +1,58 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model evaluation""" + +from megatron import get_args +from megatron import print_rank_0 +from tasks.msdp.metrics import F1Metric +from tqdm import tqdm + + +def evaluate_f1(guess_file, answer_file): + """Evaluating F1 Score""" + + guess_list = [] + print_rank_0('reading %s' % guess_file) + with open(guess_file, "r") as f: + for i, line in enumerate(tqdm(f)): + line = line.strip() + if "<|endoftext|>" in line: + line = line.replace("<|endoftext|>", "") + guess_list.append(line) + + answer_list = [] + print_rank_0('reading %s' % answer_file) + with open(answer_file, "r") as f: + for i, line in enumerate(tqdm(f)): + line = line.strip() + if line == "no_passages_used": + line = "" + answer_list.append(line) + + assert len(guess_list) == len(answer_list), \ + "lengths of guess and answer are different!" + + precision, recall, f1 = F1Metric.compute_all_pairs(guess_list, answer_list) + print_rank_0('Precision: %.4f; recall: %.4f; f1: %.4f' % (precision, recall, f1)) + + print_rank_0('done :-)') + + +def main(): + args = get_args() + + evaluate_f1(args.guess_file, args.answer_file) + diff --git a/tasks/msdp/main.py b/tasks/msdp/main.py new file mode 100644 index 0000000000000000000000000000000000000000..4966913fc03921a6a784e7daf68bcfd8692dcf7e --- /dev/null +++ b/tasks/msdp/main.py @@ -0,0 +1,79 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
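evaluate_f1 in tasks/msdp/evaluate.py above delegates the scoring to F1Metric.compute_all_pairs from tasks/msdp/metrics.py (added later in this diff); a minimal sketch of calling it directly on in-memory lists instead of files:

from tasks.msdp.metrics import F1Metric

# One generated sentence per gold reference; empty gold strings are skipped
# inside compute_each_pair, mirroring the "no_passages_used" -> "" handling above.
guesses = ['the quick brown fox', 'a dog sat on the mat']
answers = ['quick brown fox jumps', '']
precision, recall, f1 = F1Metric.compute_all_pairs(guesses, answers)
print('Precision: %.4f; recall: %.4f; f1: %.4f' % (precision, recall, f1))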
+ +"""Run multi-stage dialogue prompting (MSDP).""" + +import os +import sys +sys.path.append(os.path.abspath(os.path.join( + os.path.join(os.path.dirname(__file__), os.path.pardir), os.path.pardir))) +from megatron import get_args +from megatron.initialize import initialize_megatron + + +def get_tasks_args(parser): + """Provide extra arguments required for tasks.""" + group = parser.add_argument_group(title='tasks') + + # parameters for the knowledgeable dialogue generation + group.add_argument('--task', type=str, required=True, + help='Task name.') + group.add_argument("--sample-input-file", type=str, default=None, + help='Get input from file instead of interactive mode, ' + 'each line is an input.') + group.add_argument("--sample-output-file", type=str, default=None, + help='Output file got from --sample-input-file') + group.add_argument('--prompt-file', type=str, default=None, + help='prompting file') + group.add_argument('--prompt-type', type=str, default=None, + choices=['knowledge', 'response'], + help='prompt type (knowledge or response)') + group.add_argument('--num-prompt-examples', type=int, default=10, + help='number of prompt examples') + group.add_argument('--guess-file', type=str, default=None, + help='datapath for generated sentences') + group.add_argument('--answer-file', type=str, default=None, + help='datapath for golden sentences') + group.add_argument('--out-seq-length', type=int, default=100, + help='output sequence length') + group.add_argument('--api-prompt', default=False, action="store_true", + help='setup model api for prompting') + group.add_argument('--megatron-api-url', type=str, default=None, + help='url of the megatron api') + + return parser + + +if __name__ == '__main__': + + initialize_megatron(extra_args_provider=get_tasks_args) + + args = get_args() + + if args.num_layers_per_virtual_pipeline_stage is not None: + print("Interleaved pipeline schedule is not yet supported for downstream tasks.") + exit() + + if args.task == 'MSDP-PROMPT': + from tasks.msdp.prompt import main + + elif args.task == 'MSDP-EVAL-F1': + from tasks.msdp.evaluate import main + + else: + raise NotImplementedError('Task {} is not implemented.'.format( + args.task)) + + main() diff --git a/tasks/msdp/metrics.py b/tasks/msdp/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..fc7ce5dd5ad986f4d985eb962629627ac2f834a3 --- /dev/null +++ b/tasks/msdp/metrics.py @@ -0,0 +1,77 @@ + +# The following code is adapted from +# https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/metrics.py, +# which is licensed under the MIT license. More details on the license can be +# found at https://github.com/facebookresearch/ParlAI/blob/master/LICENSE. + +"""Provides standard metric evaluations for dialog.""" + +from collections import Counter +from typing import List +import numpy as np +import re + +re_art = re.compile(r'\b(a|an|the)\b') +re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']') + + +def normalize_answer(s): + """ + Lower text and remove punctuation, articles and extra whitespace. + """ + s = s.lower() + s = re_punc.sub(' ', s) + s = re_art.sub(' ', s) + s = ' '.join(s.split()) + return s + + +class F1Metric: + """ + Helper class which computes token-level F1. + """ + + @staticmethod + def _prec_recall_f1_score(pred_items, gold_items): + """ + Compute precision, recall and f1 given a set of gold and prediction items. 
+ :param pred_items: iterable of predicted values + :param gold_items: iterable of gold values + :return: tuple (p, r, f1) for precision, recall, f1 + """ + common = Counter(gold_items) & Counter(pred_items) + num_same = sum(common.values()) + if num_same == 0: + return 0, 0, 0 + precision = 1.0 * num_same / len(pred_items) + recall = 1.0 * num_same / len(gold_items) + f1 = (2 * precision * recall) / (precision + recall) + return precision, recall, f1 + + @staticmethod + def compute_each_pair(guess: str, answer: str): + if answer == "": + return None, None, None + if guess == "": + return 0, 0, 0 + g_tokens = normalize_answer(guess).split() + a_tokens = normalize_answer(answer).split() + + precision, recall, f1 = F1Metric._prec_recall_f1_score(g_tokens, a_tokens) + return precision, recall, f1 + + @staticmethod + def compute_all_pairs(guesses: List[str], answers: List[str]): + # additional augment: + assert len(guesses) == len(answers) + + precision_list, recall_list, f1_list = [], [], [] + for guess, answer in zip(guesses, answers): + precision, recall, f1 = F1Metric.compute_each_pair(guess, answer) + if precision is None or recall is None or f1 is None: + continue + precision_list.append(precision) + recall_list.append(recall) + f1_list.append(f1) + + return np.mean(precision_list), np.mean(recall_list), np.mean(f1_list) diff --git a/tasks/msdp/preprocessing.py b/tasks/msdp/preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..8468a4e5c789889e095fb02bd5afe5a99b582ed0 --- /dev/null +++ b/tasks/msdp/preprocessing.py @@ -0,0 +1,595 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
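To make the token-level metric above concrete, here is a small hand-checked example using normalize_answer and the (internal) _prec_recall_f1_score helper; the input strings are made up for illustration:

from tasks.msdp.metrics import F1Metric, normalize_answer

guess = 'The quick, brown fox!'
answer = 'a quick fox jumps'

g_tokens = normalize_answer(guess).split()   # ['quick', 'brown', 'fox']
a_tokens = normalize_answer(answer).split()  # ['quick', 'fox', 'jumps']

# Overlap is {'quick', 'fox'} -> 2 shared tokens out of 3 on each side:
# precision = 2/3, recall = 2/3, f1 = 2 * (2/3) * (2/3) / (4/3) = 2/3.
p, r, f1 = F1Metric._prec_recall_f1_score(g_tokens, a_tokens)
print(p, r, f1)  # 0.666..., 0.666..., 0.666...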
+ +"""Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets""" + +import torch +import argparse +from nltk import word_tokenize +from tqdm import tqdm +import numpy as np +import json + +def get_args(): + parser = argparse.ArgumentParser(description="Preprocessing") + + parser.add_argument("--func", type=str, default=None, + help="choose to run which function") + parser.add_argument("--raw_file", type=str, default=None, + help="path of the input file") + parser.add_argument("--processed_file", type=str, default=None, + help="path of the output file") + parser.add_argument("--knwl_ref_file", type=str, default=None, + help="path of the knowledge reference file") + parser.add_argument("--resp_ref_file", type=str, default=None, + help="path of the knowledge reference file") + parser.add_argument("--knwl_gen_file", type=str, default=None, + help="path of the generated knowledge file") + parser.add_argument("--test_file", type=str, default=None, + help="path of the test file") + parser.add_argument("--train_file", type=str, default=None, + help="path of the train file") + parser.add_argument("--model_file", type=str, default=None, + help="path of the model file") + parser.add_argument("--data_type", type=str, default=None, + help="data types, choose one out of three types: \ + wow_seen, wow_unseen, and woi") + parser.add_argument("--seed", type=int, default=1234, + help="random seed") + + args = parser.parse_args() + return args + + +def process_wow_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file): + """ + This is a function used for processing the wizard of wikipedia (wow) dataset + Expected processed format: + topic \t dialogue context \t golden knowledge \t golden response + """ + + # loading the raw data + print("> Loading data from %s" % raw_file) + with open(raw_file, "r") as fr: + dialog_data = json.load(fr) + + print("> Processing data ...") + fproc = open(processed_file, "w") + fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None + fresp = open(resp_ref_file, "w") if resp_ref_file else None + + for i, sample in enumerate(tqdm(dialog_data)): + # get all the dialog data for a single dialog sample + dialog = sample["dialog"] + + turn_list = [] # collect the dialog history + # processing for each single dialog sample + for j, turn in enumerate(dialog): + # text of each turn + text = turn["text"] + if not (text.endswith("?") or text.endswith(".") or text.endswith("!")): + text = text + "." 
+ + if j == 0: + # first turn + turn_list.append(text) + continue + + speaker = turn["speaker"].lower() + if "wizard" in speaker: + checked_sentence = list(turn["checked_sentence"].values()) # knowledge + checked_passage = list(turn["checked_passage"].values()) # topic + + assert len(checked_sentence) <= 1 + + # get the ground truth knowledge + if len(checked_sentence) > 0: + checked_sentence = checked_sentence[0] + else: + checked_sentence = "no_passages_used" + + if len(checked_passage) == 1: + checked_passage = checked_passage[0] + else: + checked_passage = "no_passages_used" + + # get the topic + if checked_passage != "no_passages_used": + topic = checked_passage + else: + topic = sample["chosen_topic"] + + dialog_context = " [SEP] ".join(turn_list) + knowledge = checked_sentence + response = text + # add the response into the dialog history + turn_list.append(response) + + # write to the output files + fproc.write(topic + "\t" + dialog_context + "\t" + \ + knowledge + "\t" + response + "\n") + + if fknwl: + fknwl.write(knowledge + "\n") + if fresp: + # tokenize for evaluation + response = " ".join(word_tokenize(response)) + fresp.write(response + "\n") + + else: + assert "apprentice" in speaker + turn_list.append(text) + + fproc.close() + if fknwl: + fknwl.close() + if fresp: + fresp.close() + + +def process_woi_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file): + """ + This is a function used for processing the wizard of internet (woi) dataset + Expected processed format: + topic \t dialogue context \t golden knowledge \t golden response + """ + + print("> Processing %s" % raw_file) + fproc = open(processed_file, "w") + fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None + fresp = open(resp_ref_file, "w") if resp_ref_file else None + + with open(raw_file, "r") as fr: + for i, line in tqdm(enumerate(fr)): + # read line by line, each line uses json format + line = line.strip() + item_dict = json.loads(line) + + # item_dict is a dictionary + # its key is the data id, and its value contains all the data content + item_dict = item_dict.values() + item_dict = list(item_dict)[0] # len(item_dict) == 1 + + # get the whole dialog data for a single dialog sample + dialog_data = item_dict['dialog_history'] + length = len(dialog_data) + + turn_list = [] # collect the dialog history + search_text = "" + for i in range(length): + item = dialog_data[i] + action = item['action'] + + if action == "Wizard => SearchAgent": + search_text = item['text'] + + elif action == "Wizard => Apprentice": + if len(turn_list) == 0: + # first turn + turn = item['text'] + turn_list.append(turn) + continue + + # get the relevant content + contents = item["context"]["contents"] + selects = item["context"]["selected_contents"] + flag = selects[0][0] + selects = selects[1:] + assert len(selects) == len(contents) + + # get the topic + if flag: + # no knowledge sentence is used for the response + topic = "no_topic" + knwl_sent = "no_passages_used" + else: + # we consider the search text as the topic + topic = search_text + # get the knowledge sentence + knwl_sent = "" + for content, select in zip(contents, selects): + content = content['content'] + assert len(content) == len(select) + for c, s in zip(content, select): + if s: + knwl_sent = c + break + + if knwl_sent == "": + # no knowledge is used for the response + topic = "no_topic" + knwl_sent = "no_passages_used" + + # get dialogue context, knowledge, and response + dialog_context = " [SEP] ".join(turn_list) + response = item['text'] + + # 
processing + topic = topic.replace("\n", "").replace("\r", \ + "").replace("\t", "") + dialog_context = dialog_context.replace("\n", "").replace("\r", \ + "").replace("\t", "") + knwl_sent = knwl_sent.replace("\n", "").replace("\r", \ + "").replace("\t", "") + response = response.replace("\n", "").replace("\r", \ + "").replace("\t", "") + + if topic != "no_topic": + # write to the ouput files + fproc.write(topic + "\t" + dialog_context + "\t" + \ + knwl_sent + "\t" + response + "\n") + if fknwl: + fknwl.write(knwl_sent + "\n") + if fresp: + # tokenize for evaluation + response = " ".join(word_tokenize(response)) + fresp.write(response + "\n") + + turn_list.append(response) + + elif action == "Apprentice => Wizard": + turn = item['text'] + turn_list.append(turn) + + else: + assert action == "SearchAgent => Wizard", \ + "Please check whether you have used the correct data!" + + fproc.close() + if fknwl: + fknwl.close() + if fresp: + fresp.close() + + +def get_database(test_datapath, train_datapath, data_type): + """Get the database by topics""" + + assert data_type in ["wow_seen", "wow_unseen", "woi"], \ + "Please input a correct data type!!" + + # get test data topic dictionary + print("> reading test data from %s" % test_datapath) + test_topics = {} + with open(test_datapath, "r") as f: + for i, line in enumerate(f): + line = line.strip() + splits = line.split("\t") + topic = splits[0] + test_topics[topic] = True + + print("> reading data from %s" % train_datapath) + train_data_by_topic = {} + dialog_data_by_topic = {} + dialog_examples = [] + with open(train_datapath, "r") as f: + for i, line in enumerate(f): + line = line.strip() + splits = line.split("\t") + topic = splits[0] + turns = splits[1].split(" [SEP] ")[-3:] + knowledge = splits[2] + response = splits[3] + # filtering data samples + if knowledge == "no_passages_used": + # when no knowledge is used + continue + if data_type != "wow_seen" and ("(" in knowledge or ")" in knowledge): + # when bracket exists in the knowledge + continue + if data_type != "wow_seen" and topic not in knowledge: + # when topic does not exist in the knowledge + continue + + # get the instance + last_turn = turns[-1] + instance = "( " + last_turn + " ) " + topic + " => " + knowledge + + # construct dialog example + dialog_example = "" + if data_type != "wow_seen": + dialog_example += "( " + topic + " ) " + for i, turn in enumerate(turns): + if i != 0: + dialog_example += " " + dialog_example += turn + + # check overlaps + if topic in test_topics: + if topic not in train_data_by_topic: + train_data_by_topic[topic] = [instance] + else: + train_data_by_topic[topic].append(instance) + + if topic not in dialog_data_by_topic: + dialog_data_by_topic[topic] = [dialog_example] + else: + dialog_data_by_topic[topic].append(dialog_example) + + else: + # filtering data samples + if len(knowledge.split()) > 20: + # knowledge is too long + continue + if knowledge.startswith("It") or knowledge.startswith("it") or \ + knowledge.startswith("This") or knowledge.startswith("this"): + continue + + # append all the data into dialogue examples list + dialog_examples.append((topic, dialog_example, instance)) + + return train_data_by_topic, dialog_data_by_topic, dialog_examples + + +emb_dict = {} +def select_prompts_based_on_similarity( + query, dialog_list, prompt_list, topic, tokenizer, encoder, topk): + """Select samples based on the similarity""" + + with torch.no_grad(): + # get the query embeddings + query_ids = tokenizer.encode(query) + query_ids = 
torch.LongTensor([query_ids]).cuda() + query_emb = encoder(input_ids=query_ids).pooler_output + query_emb = query_emb[0] + + # calculate embeddings for the samples in the database + if topic in emb_dict: + example_embeddings = emb_dict[topic] + example_embeddings = example_embeddings.cuda() + else: + for idx, example in enumerate(dialog_list): + example_ids = tokenizer.encode(example) + example_ids = torch.LongTensor([example_ids]).cuda() + example_emb = encoder(input_ids=example_ids).pooler_output + if idx == 0: + example_embeddings = example_emb + else: + example_embeddings = torch.cat( + (example_embeddings, example_emb), dim=0) + emb_dict[topic] = example_embeddings.cpu() + + # compare the similarity and select the topk samples + similarity_list = example_embeddings.matmul(query_emb) + _, indices = torch.topk(similarity_list, k=topk) + + indices = indices.tolist() + indices = indices[::-1] # reverse the order + selected_prompts = [] + for index in indices: + # index = index.item() + selected_prompts.append(prompt_list[index]) + + return selected_prompts + + +def prompt_selection_for_knowledge_generation( + test_datapath, train_datapath, model_path, output_prompt_path, data_type): + """Selecting prompts for the knowledge generation""" + + print("> Selecting prompts for the knowledge generation") + + train_data_by_topic, dialog_data_by_topic, dialog_examples = \ + get_database(test_datapath, train_datapath, data_type) + + from transformers import DPRQuestionEncoderTokenizer + print("> loading tokenizer and encoder") + tokenizer = DPRQuestionEncoderTokenizer.from_pretrained( + 'facebook/dpr-question_encoder-single-nq-base') + encoder = torch.load(model_path).cuda() + + print("> getting dialog embeddings") + with torch.no_grad(): + for idx, example in tqdm(enumerate(dialog_examples)): + dialog = example[1] + dialog_ids = tokenizer.encode(dialog) + dialog_ids = torch.LongTensor([dialog_ids]).cuda() + dialog_emb = encoder(input_ids=dialog_ids).pooler_output + + if idx == 0: + dialog_embeddings = dialog_emb + else: + dialog_embeddings = torch.cat((dialog_embeddings, dialog_emb), dim=0) + + print("> reading test data from %s" % test_datapath) + prompt_list_for_each_sample = [] + with open(test_datapath, "r") as f: + for i, line in tqdm(enumerate(f)): + line = line.strip() + + splits = line.split("\t") + topic = splits[0] + turns = splits[1].split(" [SEP] ")[-3:] + + # get the query sentence + query_sent = "" + if data_type != "seen": + query_sent += "( " + topic + " ) " + for i, turn in enumerate(turns): + if i != 0: + query_sent += " " + query_sent += turn + + if topic not in train_data_by_topic: + # get the query embedding + query_ids = tokenizer.encode(query_sent) + query_ids = torch.LongTensor([query_ids]).cuda() + query_emb = encoder(input_ids=query_ids).pooler_output + query_emb = query_emb[0] + + # calculate the similarity + similarity_list = dialog_embeddings.matmul(query_emb) + _, indices = torch.sort(similarity_list) + indices = indices.tolist() + selected_topics = {} + selected_prompts = [] + num_prompt = 0 + for index in indices: + example = dialog_examples[index] + topic_temp = example[0] + if topic_temp not in selected_topics: + selected_topics[topic_temp] = True + selected_prompts.append(example[2]) + num_prompt += 1 + if num_prompt == 10: + break + + # get the selected samples + example_list = selected_prompts[::-1] + key = topic + " " + turns[-1] + prompt_list_for_each_sample.append({key: example_list}) + + else: + num_data_sample = min(len(train_data_by_topic[topic]), 10) + 
total_example_list = train_data_by_topic[topic] + + dialog_list = dialog_data_by_topic[topic] + assert len(dialog_list) == len(train_data_by_topic[topic]) + + # calculate the similarity + example_list = select_prompts_based_on_similarity( + query_sent, dialog_list, total_example_list, + topic, tokenizer, encoder, topk=num_data_sample) + + key = topic + " " + turns[-1] + prompt_list_for_each_sample.append({key: example_list}) + + print("writing to %s" % output_prompt_path) + with open(output_prompt_path, "w") as f: + for instance in tqdm(prompt_list_for_each_sample): + json.dump(instance, f) + f.write("\n") + + +def prompt_selection_for_response_generation(input_path, output_path, seed): + """Selecting prompts for the response generation""" + + print("> Selecting prompts for the response generation") + print("> set random seed") + np.random.seed(seed) + + prompt_example_list = [] + print("> reading data from %s" % input_path) + with open(input_path, "r") as f: + for i, line in tqdm(enumerate(f)): + line = line.strip() + splits = line.split("\t") + + # get the topic, context, knowledge and response + topic = splits[0] + dialog_context = splits[1] + knowledge = splits[2] + response = splits[3] + turns = dialog_context.split(" [SEP] ")[-3:] + if knowledge == "no_passages_used": + continue + + # calculate the overlap ratio + from nltk import word_tokenize + knowledge_sent_token_list = word_tokenize(knowledge) + knowledge_sent_token_dict = {token: True for token in knowledge_sent_token_list} + knowledge_len = len(knowledge_sent_token_list) + response_token_list = word_tokenize(response) + response_len = len(response_token_list) + num_overlap_token = 0 + accumulator = 0 + for token in response_token_list: + if token in knowledge_sent_token_dict: + accumulator += 1 + else: + if accumulator >= 10: + num_overlap_token += accumulator + accumulator = 0 + if accumulator >= 10: + num_overlap_token += accumulator + + # filtering the data based on the ratio + if num_overlap_token > response_len * 0.9 or num_overlap_token < response_len * 0.6: + continue + if num_overlap_token < knowledge_len * 0.8: + continue + + last_turn = " ".join(word_tokenize(turns[-1])) + knowledge = " ".join(word_tokenize(knowledge)) + response = " ".join(word_tokenize(response)) + prompt_example = "" + # add dialog context + prompt_example += "Topic: " + topic + ". 
" + prompt_example += "User says: " + last_turn + " " + prompt_example += "We know that: " + knowledge + " " + prompt_example += "System replies: " + response + + prompt_example_list.append(prompt_example) + + # shuffle the prompt examples + np.random.shuffle(prompt_example_list) + + print("> writing to %s" % output_path) + with open(output_path, "w") as f: + # f.write("Generate the System's response based on the knowledge sentence:\n") + for i in tqdm(range(20)): + example = prompt_example_list[i] + f.write(example + "\n") + + +def prepare_input_for_response_generation(test_file, knwl_gen_file, processed_file): + """Preparing inputs for the response generation""" + + print("> Reading knowledge file from %s" % knwl_gen_file) + # get the knowledge list + with open(knwl_gen_file, "r") as f: + knowledge_list = f.readlines() + + print("> Processing ...") + with open(test_file, "r") as fr: + with open(processed_file, "w") as fw: + for line_num, line in enumerate(tqdm(fr)): + line = line.strip() + splits = line.split("\t") + # prepare topic, context, knowledge and response + topic = splits[0] + dialog_context = splits[1] + response = splits[3] + knowledge = knowledge_list[line_num] + knowledge = knowledge.strip() + if "<|endoftext|>" in knowledge: + knowledge = knowledge.replace("<|endoftext|>", "") + + # write to the output file + fw.write(topic + "\t" + dialog_context + "\t" \ + + knowledge + "\t" + response + "\n") + + +if __name__ == "__main__": + + args = get_args() + if args.func == "process_wow_dataset": + process_wow_dataset(args.raw_file, args.processed_file, args.knwl_ref_file, args.resp_ref_file) + + elif args.func == "process_woi_dataset": + process_woi_dataset(args.raw_file, args.processed_file, args.knwl_ref_file, args.resp_ref_file) + + elif args.func == "get_knwl_gen_prompts": + prompt_selection_for_knowledge_generation( + args.test_file, args.train_file, args.model_file, + args.processed_file, args.data_type) + + elif args.func == "get_resp_gen_prompts": + prompt_selection_for_response_generation( + args.train_file, args.processed_file, args.seed) + + elif args.func == "prepare_input": + prepare_input_for_response_generation( + args.test_file, args.knwl_gen_file, args.processed_file) diff --git a/tasks/msdp/prompt.py b/tasks/msdp/prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..2a3576a236280dfacba8d899fde832fd67fa81fe --- /dev/null +++ b/tasks/msdp/prompt.py @@ -0,0 +1,322 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Prompting the pretrained language model to generate knowledge/response""" + +import json +import torch +import requests +from nltk import word_tokenize +from megatron import mpu +from megatron import get_args +from megatron import print_rank_0 +from megatron import get_tokenizer +from megatron.model import GPTModel +from megatron.training import get_model +from megatron.checkpointing import load_checkpoint +from megatron.initialize import initialize_megatron +from megatron.text_generation import generate_and_post_process + + +def call_model_api(inputs, tokens_to_generate): + """Calling the model api to get the output generations""" + + args = get_args() + + # The following is an example of using the Megatron API + # You can also implement your own API function to place this part + headers = {'Content-Type': 'application/json; charset=UTF-8'} + data = {"prompts": [inputs], "tokens_to_generate": tokens_to_generate, "top_k": 1} + data_json = json.dumps(data) + outputs = requests.put(args.megatron_api_url, headers=headers, data=data_json).json()["text"][0] + + input_len = len(inputs) + outputs = outputs[input_len:] + outputs = outputs.split("\n")[0].strip() + + return outputs + + +def read_prompts(prompt_path, prompt_type, n_example): + """Read prompt data""" + + if prompt_type == "knowledge": + # prompts for the knowledge generation + prompt_examples_dict = {} + # read prompt_path + with open(prompt_path, "r") as f: + for i, line in enumerate(f): + line = line.strip() + line_dict = json.loads(line) + key = list(line_dict.keys())[0] + + if key not in prompt_examples_dict: + prompt_examples = line_dict[key] + prompt = "" + for instance in prompt_examples: + instance = instance.strip() + prompt += instance + " \n" + prompt_examples_dict[key] = prompt + + return prompt_examples_dict + + else: + # prompts for the response generation + # read prompt_path + prompt = "" + with open(prompt_path, "r") as f: + prompt_examples = f.readlines() + prompt_examples = prompt_examples[:n_example] + for instance in prompt_examples: + instance = instance.strip() + prompt += instance + " \n" + + return prompt + + +def generate_samples_by_calling_api(): + """ Generate outputs by calling""" + args = get_args() + assert args.prompt_type in ["knowledge", "response"], \ + "Please input a correct prompt type!" 
+ + if args.prompt_type == "knowledge": + # read knowledge generation prompts + knwl_gen_prompt_dict = read_prompts( + args.prompt_file, args.prompt_type, args.num_prompt_examples) + + else: + resp_gen_prompt = read_prompts( + args.prompt_file, args.prompt_type, args.num_prompt_examples) + + # read the test data + fname = open(args.sample_input_file, "r") + test_sample_list = fname.readlines() + # create output file + fname_out = open(args.sample_output_file, "w") + + # call the api to get the output generations + for test_sample in test_sample_list: + test_sample = test_sample.strip() + splits = test_sample.split("\t") + topic = splits[0] + + # prepare the inputs for the api + if args.prompt_type == "knowledge": + ## inputs = prompt + current test + # get the prompt + turns = splits[1].split(" [SEP] ") + last_turn = turns[-1] + key = topic + " " + last_turn + inputs = knwl_gen_prompt_dict[key] + + # add current test + inputs += "( " + last_turn + " ) " + topic + " =>" + + else: + # inputs = prompt + current test + # get the prompt + inputs = resp_gen_prompt + + # add current test + turns = splits[1].split(" [SEP] ") + knowledge = splits[2] + last_turn = turns[-1] + last_turn = " ".join(word_tokenize(last_turn)) + knowledge = " ".join(word_tokenize(knowledge)) + knowledge = knowledge.strip() + last_turn = last_turn.strip() + inputs += "Topic: " + topic + ". " + inputs += "User says: " + last_turn + " " + inputs += "We know that: " + knowledge + " " + inputs += "System replies:" + + # get the output generations from the api, + # and write to the output file + generations = call_model_api(inputs, args.out_seq_length) + fname_out.write(generations) + fname_out.write("\n") + + fname.close() + fname_out.close() + + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + + print_rank_0('building GPT model ...') + model = GPTModel( + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process + ) + return model + + +def generate_samples_by_prompting_input_from_file(model): + """Prompt a pretrained language model to generate knowledge/response""" + + # get tokenizer + args = get_args() + tokenizer = get_tokenizer() + + # Read the sample file and open the output file. + assert args.sample_input_file is not None, \ + 'sample input file is not provided.' + if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0: + fname = open(args.sample_input_file, "r") + all_raw_text = fname.readlines() + input_count = len(all_raw_text) + if args.sample_output_file is None: + sample_output_file = args.sample_input_file + ".out" + print('`sample-output-file` not specified, setting ' + 'it to {}'.format(sample_output_file)) + else: + sample_output_file = args.sample_output_file + + fname_out = open(sample_output_file, "w") + + # only two prompt types (i.e., knowledge and response) are allowed + assert args.prompt_type in ["knowledge", "response"], \ + "Please input a correct prompt type!" 
+ + # Read the prompt file + if args.prompt_type == "knowledge": + # read the prompts for the knowledge generation + prompt_examples_dict = {} + with open(args.prompt_file, "r") as f: + for i, line in enumerate(f): + line = line.strip() + line_dict = json.loads(line) + key = list(line_dict.keys())[0] + + # get the prompt examples based on the key + if key not in prompt_examples_dict: + prompt_examples = line_dict[key] + prompt = "" + for instance in prompt_examples: + instance = instance.strip() + prompt += instance + " \n" + prompt_examples_dict[key] = prompt + + else: + # read the prompts for the response generation + # prompts are fixed for all test samples + with open(args.prompt_file, "r") as f: + prompt_examples = f.readlines() + prompt_examples = prompt_examples[:args.num_prompt_examples] + + prompt = "" + for instance in prompt_examples: + instance = instance.strip() + prompt += instance + " \n" + + input_pos = 0 + model.eval() + # perform prompting + with torch.no_grad(): + while True: + raw_text_len = 0 + if mpu.is_pipeline_first_stage() \ + and mpu.get_tensor_model_parallel_rank() == 0: + input_str = all_raw_text[input_pos] + input_str = input_str.strip() + splits = input_str.split("\t") + topic = splits[0] + + if args.prompt_type == "knowledge": + # first add the prompt into the raw_text + turns = splits[1].split(" [SEP] ") + last_turn = turns[-1] + key = topic + " " + last_turn + raw_text = prompt_examples_dict[key] + + # construct inputs for knowledge generation + # then add the constructed inputs into the raw_text + raw_text += "( " + last_turn + " ) " + topic + " =>" + + else: + # first add the prompt into the raw_text + raw_text = prompt + + # construct inputs for response generation + # then add the constructed inputs into the raw_text + turns = splits[1].split(" [SEP] ") + knowledge = splits[2] + last_turn = turns[-1] + last_turn = " ".join(word_tokenize(last_turn)) + knowledge = " ".join(word_tokenize(knowledge)) + knowledge = knowledge.strip() + last_turn = last_turn.strip() + raw_text += "Topic: " + topic + ". " + raw_text += "User says: " + last_turn + " " + raw_text += "We know that: " + knowledge + " " + raw_text += "System replies:" + + input_pos += 1 + raw_text_len = len(raw_text) + + else: + raw_text = "EMPTY TEXT" + + if input_pos % 100 == 0: + print_rank_0("input_pos: %d" % input_pos) + + outputs = generate_and_post_process( + model=model, + prompts=[raw_text], + tokens_to_generate=args.out_seq_length, + top_k_sampling=1) + prompts_plus_generations = outputs[0] + prompts_plus_generations = prompts_plus_generations[0] + + # write the generated output to the output file + if mpu.get_tensor_model_parallel_rank() == 0: + if mpu.is_pipeline_first_stage(): + + generations = prompts_plus_generations[raw_text_len:] + generations = generations.split("\n")[0] + generations = generations.strip() + fname_out.write(generations) + fname_out.write("\n") + + raw_text = None + if input_pos == input_count: + return + + +def main(): + + args = get_args() + if args.api_prompt: + # obtain the generations by calling the api + generate_samples_by_calling_api() + return + + if args.num_layers_per_virtual_pipeline_stage is not None: + print("Interleaved pipeline schedule is not yet supported for text generation.") + exit() + + # Set up model and load checkpoint. 
+ model = get_model(model_provider, wrap_with_ddp=False) + if args.load is not None: + _ = load_checkpoint(model, None, None) + + assert len(model) == 1, "Above condition should have caught this" + model = model[0] + + # perform the prompting + generate_samples_by_prompting_input_from_file(model) diff --git a/tasks/orqa/README.md b/tasks/orqa/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a8e8f8e6fabcca14aacc3776a062f753b1253d27 --- /dev/null +++ b/tasks/orqa/README.md @@ -0,0 +1,36 @@ +## End-to-End Training of Neural Retrievers for Open-Domain Question Answering + +Below we present the steps to run unsupervised and supervised trainining and evaluation of the retriever for [open domain question answering](https://arxiv.org/abs/2101.00408). + +## Retriever Training + +#### Unsupervised pretraining +1. Use `tools/preprocess_data.py` to preprocess the dataset for Inverse Cloze Task (ICT), which we call unsupervised pretraining. This script takes as input a corpus in loose JSON format and creates fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block and multiple blocks per document. Run [`tools/preprocess_data.py`](../../tools/preprocess_data.py) to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. We construct two datasets, one with the title of every document and another with the body. + +
+python tools/preprocess_data.py \ + --input /path/to/corpus.json \ + --json-keys text title \ + --split-sentences \ + --tokenizer-type BertWordPieceLowerCase \ + --vocab-file /path/to/vocab.txt \ + --output-prefix corpus_indexed \ + --workers 10 ++ +2. The [`examples/pretrain_ict.sh`](../../examples/pretrain_ict.sh) script runs single-GPU training of a 217M-parameter biencoder model for ICT retriever training. Single-GPU training is primarily intended for debugging, as the code is developed for distributed training. The script uses a pretrained BERT model and a total batch size of 4096 for the ICT training. + +3. Evaluate the pretrained ICT model using [`examples/evaluate_retriever_nq.sh`](../../examples/evaluate_retriever_nq.sh) for [Google's Natural Questions Open dataset](https://arxiv.org/pdf/1906.00300.pdf). + +#### Supervised finetuning + +1. Use the above pretrained ICT model to finetune on [Google's Natural Questions Open dataset](https://github.com/google-research/language/tree/master/language/orqa). The script [`examples/finetune_retriever_distributed.sh`](../../examples/finetune_retriever_distributed.sh) provides an example of how to perform the training. Our finetuning process includes retriever score scaling and longer training (80 epochs) on top of [DPR training](https://arxiv.org/abs/2004.04906). + +2. Evaluate the finetuned model using the same evaluation script as mentioned above for the unsupervised model. + +More details on the retriever are available in [our paper](https://arxiv.org/abs/2101.00408). + +## Reader Training + +The reader component will be available soon. + diff --git a/tasks/orqa/evaluate_orqa.py b/tasks/orqa/evaluate_orqa.py new file mode 100644 index 0000000000000000000000000000000000000000..87c59ea30e2452851f2ea80edf3fa9ba436bf600 --- /dev/null +++ b/tasks/orqa/evaluate_orqa.py @@ -0,0 +1,52 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
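+# Overview: this entry point first builds and saves the evidence-block index +# with IndexBuilder, then runs ORQAEvaluator on the Natural Questions dev +# and/or test splits when args.qa_data_dev / args.qa_data_test are set.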
+ +"""Main tasks functionality.""" + +from megatron import get_args, print_rank_0 +from megatron.indexer import IndexBuilder +from tasks.orqa.evaluate_utils import ORQAEvaluator + +def main(): + """ + Main program + """ + + args = get_args() + + """ + Create a BlockData data structure by running an IndexBuilder over an + ICT Dataset and then evaluate on NQ task + """ + + print_rank_0("Starting index builder!") + + index_builder = IndexBuilder() + index_builder.build_and_save_index() + print_rank_0("Build and save indices: done!") + + + print_rank_0("Starting evaluations!") + + # Set up the model and evaluator + evaluator = ORQAEvaluator() + + # Run evaluation + if args.qa_data_dev is not None: + evaluator.evaluate(args.qa_data_dev, "DEV") + + if args.qa_data_test is not None: + evaluator.evaluate(args.qa_data_test, "TEST") + diff --git a/tasks/orqa/evaluate_utils.py b/tasks/orqa/evaluate_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..08b1e929b3e72179484bcfa22900661daf7ae267 --- /dev/null +++ b/tasks/orqa/evaluate_utils.py @@ -0,0 +1,188 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from megatron import get_args, print_rank_0 +from megatron.checkpointing import load_biencoder_checkpoint +from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset +from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex +from megatron.model.biencoder_model import get_model_provider +from megatron.training import get_model +from tasks.orqa.unsupervised.nq import get_nq_dataset +from tasks.orqa.unsupervised.nq import get_one_epoch_nq_dataloader +from tasks.orqa.unsupervised.nq import process_nq_batch +from tasks.orqa.unsupervised.qa_utils import calculate_matches + + +class ORQAEvaluator(object): + def __init__(self): + args = get_args() + self.embedding_size = args.hidden_size + self.faiss_use_gpu = args.faiss_use_gpu + self.evidence_embedder_obj = None + self.evidence_dataset = None + self.mips_index = None + self.eval_dataset = None + + # Get Evidence (Wikipedia) dataset + self.get_evidence_dataset() + + # Load query encoder checkpoint + only_query_model = True + if args.biencoder_shared_query_context_model: + only_query_model = False + + model = get_model(get_model_provider(only_query_model=only_query_model, + biencoder_shared_query_context_model=args.biencoder_shared_query_context_model)) + + self.model = load_biencoder_checkpoint(model, + only_query_model=only_query_model) + + assert len(self.model) == 1 + self.model[0].eval() + + # Load faiss indexer + self.faiss_wrapper() + + def get_evidence_embedding(self): + # This will load the embedding from the embedding path + self.evidence_embedder_obj = OpenRetreivalDataStore(load_from_path=True) + + def get_evidence_dataset(self): + self.evidence_dataset = get_open_retrieval_wiki_dataset() + + def faiss_wrapper(self): + # Initialize FAISS wrapper on local rank = 0 as the evidence embeddings + # 
is distributed over all the GPUs in a node and FAISS is not + # thread-safe + args = get_args() + if args.local_rank == 0: + # Get evidence embeddings computed using context encoder + self.get_evidence_embedding() + + assert self.evidence_embedder_obj is not None + self.mips_index = FaissMIPSIndex(embed_size=self.embedding_size, + embed_data=self.evidence_embedder_obj, + use_gpu=self.faiss_use_gpu) + + # Wait for the FAISS index to be initialized in all the nodes + torch.distributed.barrier() + + def generate_query_vectors(self, qa_data, split): + + self.eval_dataset = get_nq_dataset(qa_data, split) + dataloader = get_one_epoch_nq_dataloader(self.eval_dataset) + + query_vectors = [] + reference_list = [] + + for batch in dataloader: + # batch also has query_tokens and query_pad_data + query_tokens, query_mask, query_types, \ + query_len, reference = process_nq_batch(batch) + + assert len(self.model) == 1 + unwrapped_model = self.model[0] + while not hasattr(unwrapped_model, 'embed_text'): + unwrapped_model = unwrapped_model.module + + with torch.no_grad(): + query_logits = unwrapped_model.embed_text( + unwrapped_model.query_model, query_tokens, + query_mask, query_types) + + reference_list.extend(reference) + query_vectors.extend(query_logits.split(1, dim=0)) + if len(query_vectors) % 100 == 0: + print_rank_0('Encoded queries {}'.format(len(query_vectors))) + + query_tensor = torch.cat(query_vectors, dim=0) + print_rank_0('Total encoded queries tensor {}'.format(query_tensor.size())) + + assert query_tensor.size(0) == len(self.eval_dataset) + return query_tensor, reference_list + + def evaluate(self, qa_data, split): + args = get_args() + query_tensor, reference_list = self.generate_query_vectors(qa_data, \ + split) + local_rank = args.local_rank + rank = torch.distributed.get_rank() + device_count = torch.cuda.device_count() + num_nodes = torch.distributed.get_world_size() // device_count + node_id = rank // device_count + + for node in range(num_nodes): + start_rank = node * device_count + end_rank = (node + 1) * device_count + ranks_list = list(range(start_rank, end_rank)) + node_group = torch.distributed.new_group(ranks=ranks_list) + + if node_id == node: + device_start_rank = start_rank + group = node_group + + input_ = torch.empty_like(query_tensor).copy_(query_tensor).detach_() + tensor_list = [torch.empty_like(input_) for _ in range(device_count)] + torch.distributed.all_gather(tensor_list, query_tensor, group=group) + + if local_rank == 0 and self.mips_index is not None: + all_query_tensor = torch.cat(tensor_list, dim=0).contiguous() + + distance, topkindex = self.mips_index.search_mips_index( + all_query_tensor, top_k=args.faiss_topk_retrievals, + reconstruct=False) + distance = torch.from_numpy(distance).cuda() + topkindex = torch.LongTensor(topkindex).cuda() + + if local_rank != 0: + distance = torch.empty(device_count * len(query_tensor), \ + args.faiss_topk_retrievals, dtype=torch.float32).cuda() + topkindex = torch.empty(device_count * len(query_tensor), \ + args.faiss_topk_retrievals, dtype=torch.int64).cuda() + + torch.distributed.broadcast(distance, src=device_start_rank, \ + group=group) + torch.distributed.broadcast(topkindex, src=device_start_rank, \ + group=group) + + distance = torch.split(distance, len(query_tensor), dim=0)\ + [local_rank] + topkindex = torch.split(topkindex, len(query_tensor), dim=0)\ + [local_rank] + + top_ids_and_scores = [] + for darray, topkarray in zip(distance, topkindex): + top_ids_and_scores.append((topkarray.tolist(), darray.tolist())) + + 
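+ # Compare the retrieved passage ids against the reference answers with + # DPR-style matching (calculate_matches) and report top-k hit rates.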
passages = self.evidence_dataset.id2text + match_stats = calculate_matches(passages, + reference_list, + top_ids_and_scores, + workers_num=args.num_workers, + match_type=args.faiss_match) + top_k_hits = match_stats.top_k_hits + + print_rank_0("{} SET RESULTS".format(split)) + print_rank_0("topk-{} documents hits {}".format( + args.faiss_topk_retrievals, top_k_hits)) + top_k_hits = [v / len(top_ids_and_scores) for v in top_k_hits] + print_rank_0("top-k documents hits accuracy {}".format(top_k_hits)) + + for i in args.retriever_report_topk_accuracies: + print_rank_0("top-{}: {:.2f}".format(i, top_k_hits[i-1] * 100)) + + return diff --git a/tasks/orqa/supervised/data.py b/tasks/orqa/supervised/data.py new file mode 100644 index 0000000000000000000000000000000000000000..b45a842b61c40f18f2d742f443d43ecf9040c5ce --- /dev/null +++ b/tasks/orqa/supervised/data.py @@ -0,0 +1,300 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ORQA dataset.""" + +import json +import random +from abc import ABC +from abc import abstractmethod + +import numpy as np +from torch.utils.data import Dataset + +from megatron import print_rank_0, get_args +from megatron.data.biencoder_dataset_utils import make_attention_mask + +def build_token_types_from_context_list(ctx_list, tokenizer, max_seq_length): + ctx_id_list, ctx_types_list = [], [] + for context in ctx_list: + title_ids = tokenizer.tokenize(context['title']) + ctx_ids = tokenizer.tokenize(context['text']) + ctx_ids = title_ids + [tokenizer.sep_id] + ctx_ids + + ctx_ids, ctx_types, _ = build_tokens_types_paddings_from_ids(ctx_ids, + max_seq_length, tokenizer.cls, + tokenizer.sep, tokenizer.pad) + ctx_id_list.append(ctx_ids) + ctx_types_list.append(ctx_types) + + return ctx_id_list, ctx_types_list + + +def build_tokens_types_paddings_from_text(query, context, + tokenizer, max_seq_length): + """Build token types and paddings, trim if needed, and pad if needed.""" + + query_ids = tokenizer.tokenize(query) + query_ids, query_types, query_pad_mask = \ + build_tokens_types_paddings_from_ids(query_ids, max_seq_length, \ + tokenizer.cls, tokenizer.sep, tokenizer.pad) + + # Appending the title of the context at front + extended_ctx_ids = None + if context is not None: + title_ids = tokenizer.tokenize(context['title']) + ctx_ids = tokenizer.tokenize(context['text']) + extended_ctx_ids = title_ids + [tokenizer.sep] + ctx_ids + + ctx_ids, ctx_types, ctx_pad_mask = \ + build_tokens_types_paddings_from_ids(extended_ctx_ids, + max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad) + + return query_ids, query_types, query_pad_mask, \ + ctx_ids, ctx_types, ctx_pad_mask + + +# Similar code tasks/data_utils with some changes +def build_tokens_types_paddings_from_ids(text_ids, max_seq_length, + cls_id, sep_id, pad_id): + """Build token types and paddings, trim if needed, and pad if needed.""" + enc_ids = [] + tokentypes_enc = [] + + # [CLS]. 
+ enc_ids.append(cls_id) + tokentypes_enc.append(0) + + # A. + len_src = len(text_ids) + enc_ids.extend(text_ids) + tokentypes_enc.extend([0] * len_src) + + # Cap the size. + if len(enc_ids) > max_seq_length - 1: + enc_ids = enc_ids[0: max_seq_length - 1] + tokentypes_enc = tokentypes_enc[0: max_seq_length - 1] + + # [SEP]. + enc_ids.append(sep_id) + tokentypes_enc.append(0) + + num_tokens_enc = len(enc_ids) + # Padding. + padding_length = max_seq_length - len(enc_ids) + if padding_length > 0: + enc_ids.extend([pad_id] * padding_length) + tokentypes_enc.extend([pad_id] * padding_length) + + pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length) + pad_mask = np.array(pad_mask, dtype=np.int64) + + return enc_ids, tokentypes_enc, pad_mask + + +def build_sample(query_ids, query_types, query_pad_mask, + ctx_ids, ctx_types, ctx_pad_mask, answers, + neg_ctx_id_list=None, neg_ctx_types_list=None, + include_neg=False): + """Convert to numpy and return a sample consumed by the batch producer.""" + + query_ids = np.array(query_ids, dtype=np.int64) + query_types = np.array(query_types, dtype=np.int64) + query_mask = make_attention_mask(query_ids, query_ids) + + ctx_ids = np.array(ctx_ids, dtype=np.int64) + ctx_types = np.array(ctx_types, dtype=np.int64) + ctx_mask = make_attention_mask(ctx_ids, ctx_ids) + + sample = ({ + 'query': query_ids, + 'query_mask': query_mask, + 'query_types': query_types, + 'query_pad_mask': query_pad_mask, + 'context': ctx_ids, + 'context_mask': ctx_mask, + 'context_types': ctx_types, + 'context_pad_mask': ctx_pad_mask, + 'reference': answers + }) + + if include_neg: + neg_ctx_ids = np.array(neg_ctx_id_list, dtype=np.int64) + neg_ctx_id_types = np.array(neg_ctx_types_list, dtype=np.int64) + neg_ctx_mask = np.array([make_attention_mask(ids, ids) \ + for ids in neg_ctx_ids], dtype=np.int64) + + sample['neg_context'] = neg_ctx_ids + sample['neg_context_types'] = neg_ctx_id_types + sample['neg_context_mask'] = neg_ctx_mask + + return sample + + +class OpenRetrievalAbstractDataset(ABC, Dataset): + """Open Retrieval base dataset class.""" + + def __init__(self, task_name, dataset_name, datapaths, tokenizer, \ + max_seq_length, evaluate=False): + # Store inputs. + args = get_args() + self.evaluate = evaluate + self.val_av_rank_hard_neg = args.val_av_rank_hard_neg + self.val_av_rank_other_neg = args.val_av_rank_other_neg + self.train_with_neg = args.train_with_neg + self.train_hard_neg = args.train_hard_neg + + self.task_name = task_name + self.dataset_name = dataset_name + self.tokenizer = tokenizer + self.max_seq_length = max_seq_length + print_rank_0(' > building {} dataset for {}:'.format(self.task_name, + self.dataset_name)) + # Process the files. 
+ string = ' > paths:' + for path in datapaths: + string += ' ' + path + print_rank_0(string) + self.samples = [] + for datapath in datapaths: + self.samples.extend(self.process_samples_from_single_path(datapath)) + + args = get_args() + if args.sample_rate < 1: # subsample + k = int(len(self.samples) * args.sample_rate) + self.samples = random.sample(self.samples, k) + + print_rank_0(' >> total number of samples: {}'.format( + len(self.samples))) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + raw_sample = self.samples[idx] + + query_ids, query_types, query_pad_mask, ctx_ids, ctx_types, \ + ctx_pad_mask = build_tokens_types_paddings_from_text( \ + raw_sample['question'], raw_sample['pos_context'], \ + self.tokenizer, self.max_seq_length) + + if self.evaluate: + neg_ctx_list = \ + raw_sample['negative_context'][:self.val_av_rank_other_neg] + \ + raw_sample['hard_negative_context'][:self.val_av_rank_hard_neg] + neg_ctx_id_list, neg_ctx_types_list = \ + build_token_types_from_context_list(neg_ctx_list, \ + self.tokenizer, self.max_seq_length) + + elif self.train_with_neg: + hard_negative_ctx = raw_sample['hard_negative_context'] + negative_ctx = raw_sample['negative_context'] + if True: # TODO: fix this or remove this condition + random.shuffle(hard_negative_ctx) + random.shuffle(negative_ctx) + + neg_ctx_list = hard_negative_ctx[:self.train_hard_neg] + # In the Google NQ dataset by DPR paper, there are around more than + # 50 missing hard negatives in training data. + # In those cases, substitute hard negatives by simple negatives. + if len(neg_ctx_list) < self.train_hard_neg: + neg_ctx_list += negative_ctx[:self.train_hard_neg - \ + len(neg_ctx_list)] + + neg_ctx_id_list, neg_ctx_types_list = \ + build_token_types_from_context_list(neg_ctx_list, + self.tokenizer, self.max_seq_length) + else: + neg_ctx_id_list = None + neg_ctx_types_list = None + + sample = build_sample(query_ids, query_types, query_pad_mask, + ctx_ids, ctx_types, ctx_pad_mask, + raw_sample['answers'], + neg_ctx_id_list, neg_ctx_types_list, + include_neg=self.evaluate or self.train_with_neg) + + return sample + + @staticmethod + @abstractmethod + def process_samples_from_single_path(filename): + """Abstract method that takes a filename and + returns a list of dataset samples, each sample being a dict of + {'text': string, 'text': string} + """ + pass + + + +def normalize_question(question): + if question[-1] == '?': + question = question[:-1] + return question + +# The following class reads the datasets for training retriever as +# prepared by the DPR codebase (https://github.com/facebookresearch/DPR) + +class NQSupervisedDataset(OpenRetrievalAbstractDataset): + + def __init__(self, name, datapaths, tokenizer, max_seq_length, \ + evaluate=False): + super().__init__('natural_questions_ret', + name, + datapaths, + tokenizer, + max_seq_length, + evaluate=evaluate) + + @staticmethod + def process_samples_from_single_path(filename): + """"Implement abstract method.""" + print_rank_0(' > Processing {} ...'.format(filename)) + samples = [] + total = 0 + + with open(filename, 'r', encoding="utf-8") as f: + data = json.load(f) + for row in data: + question = normalize_question(row['question']) + pos_context = row['positive_ctxs'][0] + + # Hard Negative Contexts + if len(row['hard_negative_ctxs']) > 0: + hard_neg_context = row['hard_negative_ctxs'] + else: + hard_neg_context = [] + + # Negative Contexts + if len(row['negative_ctxs']) > 0: + neg_context = row['negative_ctxs'] + else: + neg_context = 
[] + + answers = row['answers'] + sample = {'question': question, + 'pos_context': pos_context, + 'hard_negative_context': hard_neg_context, + 'negative_context': neg_context, + 'answers': answers} + total += 1 + samples.append(sample) + + if total % 5000 == 0: + print_rank_0(' > processed {} so far ...'.format(total)) + + print_rank_0(' >> processed {} samples.'.format(len(samples))) + return samples + diff --git a/tasks/orqa/supervised/eval_utils.py b/tasks/orqa/supervised/eval_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..67dca512b0d1d30d79b7489891a31232fe49e0d5 --- /dev/null +++ b/tasks/orqa/supervised/eval_utils.py @@ -0,0 +1,206 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evaluation utilities.""" +from collections import OrderedDict +import math +import numpy as np +import time +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader + +from megatron import get_args, print_rank_0 +from megatron import mpu +from megatron.utils import average_losses_across_data_parallel_group +from tasks.finetune_utils import build_data_loader + +def task_collate_fn(batch_data): + # generate batch + batch_size = len(batch_data) + tensorized = OrderedDict() + for d in batch_data: + for k, v in d.items(): + tensorized.setdefault(k, []).append(v) + + tensorized['query'] = torch.LongTensor(tensorized['query']) + tensorized['query_mask'] = torch.LongTensor(tensorized['query_mask']) + tensorized['query_types'] = torch.LongTensor(tensorized['query_types']) + tensorized['query_pad_mask'] = \ + torch.LongTensor(tensorized['query_pad_mask']) + + tensorized['context'] = torch.LongTensor(tensorized['context']) + tensorized['context_mask'] = \ + torch.LongTensor(tensorized['context_mask']) + tensorized['context_types'] = \ + torch.LongTensor(tensorized['context_types']) + tensorized['context_pad_mask'] = \ + torch.LongTensor(tensorized['context_pad_mask']) + + if 'neg_context' in tensorized: + tensorized['neg_context'] = \ + torch.LongTensor(np.concatenate(tensorized['neg_context'])) + tensorized['neg_context_mask'] = \ + torch.LongTensor(np.concatenate(tensorized['neg_context_mask'])) + tensorized['neg_context_types'] = \ + torch.LongTensor(np.concatenate(tensorized['neg_context_types'])) + + return tensorized + + + +def process_batch(batch): + """Process batch and produce inputs for the model.""" + query_tokens = batch['query'].long().cuda() + query_mask = (batch['query_mask'] < 0.5).cuda() + query_types = batch['query_types'].long().cuda() + query_pad_mask = batch['query_pad_mask'].long().cuda() + + context_tokens = batch['context'].long().cuda() + context_mask = (batch['context_mask'] < 0.5).cuda() + context_types = batch['context_types'].long().cuda() + context_pad_mask = batch['context_pad_mask'].long().cuda() + + if 'neg_context' in batch: + neg_context_tokens = batch['neg_context'].long().cuda() + neg_context_mask = (batch['neg_context_mask'] < 0.5).cuda() + 
neg_context_types = batch['neg_context_types'].long().cuda() + else: + neg_context_tokens = None + neg_context_mask = None + neg_context_types = None + + reference = batch['reference'] + + return query_tokens, query_mask, query_types, query_pad_mask, \ + context_tokens, context_mask, context_types, context_pad_mask, \ + neg_context_tokens, neg_context_mask, neg_context_types, reference + +def accuracy_func_provider(single_dataset_provider, rank0sampler=False): + """Provide function that calculates accuracies.""" + args = get_args() + + print_rank_0("accuracy_func_provider is CALLED") + + # Build dataloaders + datapath = args.valid_data + dataset = single_dataset_provider(datapath) + + drop_last = False + if mpu.get_data_parallel_world_size() > 1 and not rank0sampler: + drop_last = True + + print_rank_0(datapath) + print_rank_0(rank0sampler) + + dataloader = build_data_loader(dataset, + args.eval_micro_batch_size, + num_workers=args.num_workers, + drop_last=drop_last, + task_collate_fn=task_collate_fn) + dataloaders = (dataset.dataset_name, dataloader) + + def metrics_func(model, epoch, output_predictions=False): + print_rank_0('calculating metrics by accuracy func in ORQA...') + + if output_predictions: + assert rank0sampler + names = 'predictions' + name, dataloader = dataloaders + if args.task == "RET-FINETUNE-NQ": + start_time = time.time() + output = retrieval_loss(model, dataloader) + stats_dict, total = output + format_string = "" + for k, v in stats_dict.items(): + format_string += "|{} = {:.2f}".format(k, v / total) + print_rank_0("epoch:{}{}".format(epoch, format_string)) + print_rank_0("time taken to calculate metrics {:.3f}".format(\ + time.time() - start_time)) + else: + raise AssertionError("{} Task not supported".format(args.task)) + + return metrics_func + + +def retrieval_loss(model, dataloader): + args = get_args() + total = 0 + topk_stats_dict = {'top{}_acc'.format(k): 0 for k in \ + args.retriever_report_topk_accuracies} + stats_dict = dict(rank=0, **topk_stats_dict) + + assert len(model) == 1 + unwrapped_model = model[0] + unwrapped_model.eval() + + with torch.no_grad(): + # For all the batches in the dataset. + for batch in dataloader: + # Run the model forward.
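+ # Each query is scored against its own positive context plus the hard + # negatives in the batch; the correct context for query i sits at row i, + # so the labels below are simply torch.arange(local_batch_size).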
+ query_tokens, query_mask, query_types, _, \ + context_tokens, context_mask, context_types, _, \ + neg_context_tokens, neg_context_mask, neg_context_types, \ + reference = process_batch(batch) + + query_logits, context_logits = unwrapped_model(query_tokens, + query_mask, query_types, + torch.cat([context_tokens, neg_context_tokens]), + torch.cat([context_mask, neg_context_mask]), + torch.cat([context_types, neg_context_types])) + + retrieval_scores = torch.matmul(query_logits, + torch.transpose(context_logits, 0, 1)) + + if args.retriever_score_scaling: + retrieval_scores = retrieval_scores / \ + math.sqrt(args.hidden_size) + + local_batch_size = query_logits.shape[0] + labels = torch.arange(local_batch_size).long().cuda() + + softmax_scores = F.softmax(retrieval_scores, dim=1) + sorted_vals, sorted_indices = torch.topk(softmax_scores, + k=softmax_scores.shape[1], + sorted=True) + + def topk_accuracy(k): + return torch.cuda.FloatTensor( + [sum([int(labels[i] in sorted_indices[i, :k]) for i in \ + range(local_batch_size)])]) + + def get_rank(): + return torch.cuda.FloatTensor( + [sum([torch.nonzero(labels[i] == sorted_indices[i])[0][0] \ + for i in range(local_batch_size)])]) + + topk_accs = [topk_accuracy(k) for k in \ + args.retriever_report_topk_accuracies] + rank = get_rank() + losses = average_losses_across_data_parallel_group([rank, \ + *topk_accs]) + + # create stats_dict with retrieval loss and all specified + # top-k accuracies + topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \ + zip(args.retriever_report_topk_accuracies, losses[1:])} + temp_stats_dict = dict(rank=losses[0], **topk_acc_dict) + for k in stats_dict.keys(): + stats_dict[k] += temp_stats_dict[k] + total += local_batch_size + + unwrapped_model.train() + + return stats_dict, total diff --git a/tasks/orqa/supervised/finetune.py b/tasks/orqa/supervised/finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..aed65ac979199a1d469a51d4c469ea9bd935e460 --- /dev/null +++ b/tasks/orqa/supervised/finetune.py @@ -0,0 +1,251 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""ORQA finetuning/evaluation.""" + +from functools import partial +import sys + +import math +import torch +import torch.nn.functional as F + +from megatron import get_args, get_timers, get_tokenizer +from megatron import mpu, print_rank_0 +from megatron.indexer import IndexBuilder +from megatron.model.biencoder_model import biencoder_model_provider +from megatron.utils import average_losses_across_data_parallel_group +from pretrain_ict import get_group_world_size_rank +from tasks.finetune_utils import finetune +from tasks.orqa.supervised.eval_utils import accuracy_func_provider +from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn +from tasks.orqa.evaluate_utils import ORQAEvaluator + +# input_ is a 2D tensor +def check_and_append_tensor_for_gather(group, rank, world_size, input_): + + # gather the size of the first dimension of the tensor from all ranks + current_length = input_.size()[0] + first_dim = torch.tensor([[current_length]], + device=torch.cuda.current_device()) + input_list = [torch.empty_like(first_dim) for _ in range(world_size)] + input_list[rank].copy_(first_dim) + torch.distributed.all_gather(input_list, first_dim, group=group) + all_input_list = torch.cat(input_list, dim=0).contiguous() + max_length = torch.max(all_input_list) + + # if the size are different than the max, extend the tensor + # accordingly + if max_length > current_length: + padding=tuple([0] * (input_.dim() * 2 - 1)) + \ + tuple([max_length - current_length]) + input_ = F.pad(input=input_, pad=padding) + + return input_ + +def orqa(Dataset): + + def cross_entropy_forward_step(batch, model): + """Simple forward step with cross-entropy loss.""" + timers = get_timers() + tokenizer = get_tokenizer() + + # Get the batch. + timers('batch generator').start() + try: + batch_ = next(batch) + except BaseException: + batch_ = batch + + group, rank, world_size = get_group_world_size_rank() + + query_tokens, query_mask, query_types, query_pad_mask, \ + context_tokens, context_mask, context_types, context_pad_mask, \ + neg_context_tokens, neg_context_mask, neg_context_types, \ + reference = process_batch(batch_) + + timers('batch generator').stop() + local_batch_size = query_tokens.shape[0] + + # Text representation of query and context + query_list, context_list = [], [] + for i in range(local_batch_size): + query_list.append(tokenizer.decode(query_tokens[i].tolist())) + context_list.append(tokenizer.decode(context_tokens[i].tolist())) + + if neg_context_tokens is not None: + neg_context_tokens = check_and_append_tensor_for_gather(group, + rank, world_size, neg_context_tokens) + neg_context_mask = check_and_append_tensor_for_gather(group, + rank, world_size, neg_context_mask) + neg_context_types = check_and_append_tensor_for_gather(group, + rank, world_size, neg_context_types) + + if neg_context_tokens is not None: + context_tokens = torch.cat([context_tokens, neg_context_tokens]) + context_mask = torch.cat([context_mask, neg_context_mask]) + context_types = torch.cat([context_types, neg_context_types]) + + # Forward model. 
+ output_tensor = model(query_tokens, query_mask, + query_types, context_tokens, + context_mask, context_types) + return output_tensor, partial(cross_entropy_loss_func, query_tokens, context_tokens) + + + def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor): + args = get_args() + + local_batch_size = query_tokens.shape[0] + group, rank, world_size = get_group_world_size_rank() + # recall we assert that model_parallel_size == 1 + global_batch_size = world_size * local_batch_size + + query_logits, context_logits = output_tensor + + if world_size > 1: + input_ = torch.empty_like(context_logits).copy_(\ + context_logits).detach_() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank].copy_(input_) + torch.distributed.all_gather(tensor_list, input_, group=group) + + # Check if all-gather happens in order + assert tensor_list[rank].sum().item() == \ + context_logits.sum().item() + + # Preserves the gradient + tensor_list[rank] = context_logits + all_context_logits = torch.cat(tensor_list, dim=0).contiguous() + + # Query tensors + input_ = torch.empty_like(query_logits).copy_(\ + query_logits).detach_() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank].copy_(input_) + torch.distributed.all_gather(tensor_list, input_, group=group) + + # Check if all-gather happens in order + assert tensor_list[rank].sum().item() == query_logits.sum().item() + + # Preserves the gradient + tensor_list[rank] = query_logits + all_query_logits = torch.cat(tensor_list, dim=0).contiguous() + else: + all_query_logits = query_logits + all_context_logits = context_logits + + retrieval_scores = torch.matmul(all_query_logits, + torch.transpose(all_context_logits, 0, 1)) + # Scaling the retrieval scores + if args.retriever_score_scaling: + retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size) + + if args.train_with_neg: + # if the world size is 3, local batch size is 4, and + # local context size is 8, what we want is + # labels = [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19] + labels = [] + local_context_size = context_tokens.shape[0] + for i in range(world_size): + j = i * local_context_size + labels.extend(list(range(j, j + local_batch_size))) + labels = torch.LongTensor(labels).cuda() + assert len(labels) == global_batch_size + else: + labels = torch.arange(global_batch_size).long().cuda() + + # Cross-entropy loss. + softmax_scores = F.log_softmax(retrieval_scores, dim=1) + + loss = F.nll_loss(softmax_scores, labels, reduction='mean') + + max_score, max_idxs = torch.max(softmax_scores, 1) + correct_predictions_count = (max_idxs == labels).sum().float() + + # Reduce loss for logging. 
+ reduced_loss = average_losses_across_data_parallel_group([loss, \ + correct_predictions_count]) + + # Loss scaling for correct losses in Supervised Retrieval + loss = loss * mpu.get_data_parallel_world_size() + + return loss, {'lm loss': reduced_loss[0], + 'correct_prediction_count': reduced_loss[1]} + + + def train_valid_datasets_provider(): + """Build train and validation dataset.""" + args = get_args() + tokenizer = get_tokenizer() + + train_dataset = Dataset('training', + args.train_data, + tokenizer, + args.retriever_seq_length, + evaluate=False) + valid_dataset = Dataset('validation', + args.valid_data, + tokenizer, + args.retriever_seq_length, + evaluate=True) + return train_dataset, valid_dataset + + def model_provider(pre_process=True, post_process=True): + """Build the model.""" + args = get_args() + print_rank_0('building retriever model for {} ...'.format(args.task)) + + model = biencoder_model_provider(only_context_model=False, + only_query_model=False, + biencoder_shared_query_context_model=\ + args.biencoder_shared_query_context_model, + pre_process=pre_process, post_process=post_process) + + return model + + def single_dataset_provider(datapath): + args = get_args() + tokenizer = get_tokenizer() + + name = datapath[0].split('/')[-1].split('.')[0] + return Dataset(name, + datapath, + tokenizer, + args.retriever_seq_length, + evaluate=True) + + def metrics_func_provider(): + """Provide metrics callback function.""" + return accuracy_func_provider(single_dataset_provider) + + """Finetune/evaluate.""" + finetune(train_valid_datasets_provider, + model_provider, + forward_step=cross_entropy_forward_step, + end_of_epoch_callback_provider=metrics_func_provider, + task_collate_fn=task_collate_fn) + +def main(): + args = get_args() + + if args.task == 'RET-FINETUNE-NQ': + from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset + else: + raise NotImplementedError('ORQA task {} is not implemented.'.format( + args.task)) + + orqa(Dataset) + diff --git a/tasks/orqa/unsupervised/nq.py b/tasks/orqa/unsupervised/nq.py new file mode 100644 index 0000000000000000000000000000000000000000..ca07fe4165cb780f53e50943612cc375c2e844e0 --- /dev/null +++ b/tasks/orqa/unsupervised/nq.py @@ -0,0 +1,228 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
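+# Note: the expected question file is tab-separated; column 0 holds the +# question and column 1 a list literal of reference answers that is parsed +# with eval() (see NQDataset.process_samples_from_single_path below).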
+ +""" + Data Loader for Google NQ dataset +""" + +from abc import ABC +import csv +from collections import OrderedDict +import numpy as np + +import torch +from torch.utils.data import DataLoader +from torch.utils.data import Dataset, BatchSampler + +from megatron import print_rank_0, get_args, get_tokenizer, mpu +from megatron.data.biencoder_dataset_utils import make_attention_mask + +def get_nq_dataset(qa_data, split): + args = get_args() + tokenizer = get_tokenizer() + + dataset = NQDataset('Google NQ {} Split'.format(split), + 'Google Natural Questions', + qa_data, + tokenizer, + args.retriever_seq_length) + return dataset + + +def process_nq_batch(batch): + query_tokens = batch['token_ids'].long().cuda() + query_mask = (batch['token_mask'] < 0.5).cuda() + query_types = batch['token_types'].long().cuda() + query_len = batch['seq_len'].long().cuda() + reference = batch['reference'] + + return query_tokens, query_mask, query_types, query_len, reference + + +class CustomDataLoader(DataLoader): + def __init__(self, dataset, eval=False, **kwargs): + if kwargs.get('collate_fn', None) is None: + kwargs['collate_fn'] = self._collate_fn + self.eval = eval + super().__init__(dataset, **kwargs) + + def _collate_fn(self, batch_data): + # generate batch + batch_size = len(batch_data) + tensorized = OrderedDict() + for d in batch_data: + for k, v in d.items(): + tensorized.setdefault(k, []).append(v) + assert len(tensorized) == 5 + + tensorized['token_ids'] = torch.LongTensor(tensorized['token_ids']) + tensorized['token_mask'] = torch.LongTensor(tensorized['token_mask']) + tensorized['token_types'] = torch.LongTensor(tensorized['token_types']) + tensorized['seq_len'] = torch.LongTensor(tensorized['seq_len']) + return tensorized + + +def get_one_epoch_nq_dataloader(dataset, micro_batch_size=None): + """Data loader. Note that batch-size is the local (per GPU) batch-size. + NOTE: This dataloader is not distributed !!! + """ + + args = get_args() + if micro_batch_size is None: + micro_batch_size = args.micro_batch_size + num_workers = args.num_workers + + sampler = torch.utils.data.SequentialSampler(dataset) + # importantly, drop_last must be False to get all the data. + batch_sampler = BatchSampler(sampler, + batch_size=micro_batch_size, + drop_last=False) + + # Data loader. Note that batch size is the per GPU batch size. + data_loader = CustomDataLoader(dataset, + batch_sampler=batch_sampler, + num_workers=num_workers, + pin_memory=True) + return data_loader + + +def build_tokens_types_paddings_from_text(src_text, tokenizer, max_seq_length): + """Build token types and paddings, trim if needed, and pad if needed.""" + + src_text_ids = tokenizer.tokenize(src_text) + + return build_tokens_types_paddings_from_ids(src_text_ids, + max_seq_length, + tokenizer.cls, + tokenizer.sep, + tokenizer.pad) + + +def build_tokens_types_paddings_from_ids(src_ids, max_seq_length, cls_id, \ + sep_id, pad_id): + """ + Build token types and paddings, trim if needed, and pad if needed. + + TODO: Design modular interface to reuse this function. This is getting + repeated multiple times in different tasks + """ + + enc_ids = [] + tokentypes_enc = [] + + # [CLS]. + enc_ids.append(cls_id) + tokentypes_enc.append(0) + + # A. + len_src = len(src_ids) + enc_ids.extend(src_ids) + tokentypes_enc.extend([0] * len_src) + + # Cap the size. + if len(enc_ids) > max_seq_length - 1: + enc_ids = enc_ids[0: max_seq_length - 1] + tokentypes_enc = tokentypes_enc[0: max_seq_length - 1] + + # [SEP]. 
+ enc_ids.append(sep_id) + tokentypes_enc.append(0) + + num_tokens_enc = len(enc_ids) + # Padding. + padding_length = max_seq_length - len(enc_ids) + if padding_length > 0: + enc_ids.extend([pad_id] * padding_length) + tokentypes_enc.extend([pad_id] * padding_length) + + return enc_ids, tokentypes_enc, num_tokens_enc + + +def build_sample(token_ids, token_types, num_tokens, reference): + """ + Convert to numpy and return a sample consumed by the + batch producer. + """ + + token_ids = np.array(token_ids, dtype=np.int64) + token_types = np.array(token_types, dtype=np.int64) + token_mask = make_attention_mask(token_ids, token_ids) + + sample = ({ + 'token_ids': token_ids, + 'token_mask': token_mask, + 'token_types': token_types, + 'seq_len': num_tokens, + 'reference': reference + }) + return sample + + +class NQDataset(ABC, Dataset): + """ + Open Retrieval Question Answering evaluation using Google NQ dataset. + """ + + def __init__(self, task_name, dataset_name, datapath, + tokenizer, max_seq_length): + # Store inputs. + self.task_name = task_name + self.dataset_name = dataset_name + self.tokenizer = tokenizer + self.max_seq_length = max_seq_length + print_rank_0(' > building {} dataset for {}:'.format(self.task_name, + self.dataset_name)) + print_rank_0(datapath) + self.samples = self.process_samples_from_single_path(datapath) + print_rank_0(' >> total number of samples: {}'.format(\ + len(self.samples))) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + raw_sample = self.samples[idx] + + ques_tokens, tokentypes_enc, num_tokens_ques = \ + build_tokens_types_paddings_from_text(raw_sample['question'], + self.tokenizer, self.max_seq_length) + + sample = build_sample(ques_tokens, + tokentypes_enc, + num_tokens_ques, + raw_sample['answers']) + return sample + + @staticmethod + def process_samples_from_single_path(filename): + print_rank_0(' > Processing {} ...'.format(filename)) + samples = [] + total = 0 + + with open(filename, 'r') as ifile: + reader = csv.reader(ifile, delimiter='\t') + for row in reader: + question = row[0] + answers = eval(row[1]) + + sample = {'question': question, 'answers': answers} + total += 1 + samples.append(sample) + + if total % 1000 == 0: + print_rank_0(' > processed {} so far ...'.format(total)) + + print_rank_0(' >> processed {} samples.'.format(len(samples))) + return samples diff --git a/tasks/orqa/unsupervised/qa_utils.py b/tasks/orqa/unsupervised/qa_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..811a05834a47ce1e9f9cca9bae9e0f77f937b588 --- /dev/null +++ b/tasks/orqa/unsupervised/qa_utils.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# + +# The following code has been taken from +# https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0 +# licensed as of now. 
More details on the license can be found +# at https://github.com/facebookresearch/DPR/blob/master/LICENSE + +""" + Set of utilities for Q&A results validation tasks - Retriver passage + validation and Reader predicted answer validation +""" + +import collections +import logging +import string +import unicodedata +from functools import partial +from multiprocessing import Pool as ProcessPool +from typing import Tuple, List, Dict + +import regex as re +from tasks.orqa.unsupervised.tokenizers import SimpleTokenizer + +logger = logging.getLogger(__name__) + +QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits',\ + 'questions_doc_hits']) + +def calculate_matches(all_docs: Dict[object, Tuple[str, str]], + answers: List[List[str]], closest_docs: List[Tuple[List[object], + List[float]]], workers_num: int, match_type: str) -> QAMatchStats: + """ + Evaluates answers presence in the set of documents. This function is + supposed to be used with a large collection of documents and results. + It internally forks multiple sub-processes for evaluation and then + merges results + :param all_docs: dictionary of the entire documents database. + doc_id -> (doc_text, title) + :param answers: list of answers's list. One list per question + :param closest_docs: document ids of the top results along with their + scores + :param workers_num: amount of parallel threads to process data + :param match_type: type of answer matching. Refer to has_answer code for + available options + :return: matching information tuple. + top_k_hits - a list where the index is the amount of top documents retrieved + and the value is the total amount of valid matches across an entire + dataset. + questions_doc_hits - more detailed info with answer matches for every + question and every retrieved document + """ + global dpr_all_documents + dpr_all_documents = all_docs + + tok_opts = {} + tokenizer = SimpleTokenizer(**tok_opts) + + processes = ProcessPool( + processes=workers_num, + ) + + logger.info('Matching answers in top docs...') + + get_score_partial = partial(check_answer, match_type=match_type, + tokenizer=tokenizer) + + questions_answers_docs = zip(answers, closest_docs) + + scores = processes.map(get_score_partial, questions_answers_docs) + + logger.info('Per question validation results len=%d', len(scores)) + + n_docs = len(closest_docs[0][0]) + top_k_hits = [0] * n_docs + for question_hits in scores: + best_hit = next((i for i, x in enumerate(question_hits) if x), None) + if best_hit is not None: + top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]] + + return QAMatchStats(top_k_hits, scores) + + +def check_answer(questions_answers_docs, tokenizer, match_type) -> List[bool]: + """ + Search through all the top docs to see if they have any of the answers. + """ + answers, (doc_ids, doc_scores) = questions_answers_docs + + global dpr_all_documents + hits = [] + + for i, doc_id in enumerate(doc_ids): + doc = dpr_all_documents[doc_id] + text = doc[0] + + answer_found = False + if text is None: # cannot find the document for some reason + logger.warning("no doc in db") + hits.append(False) + continue + + if has_answer(answers, text, tokenizer, match_type): + answer_found = True + hits.append(answer_found) + return hits + + +def has_answer(answers, text, tokenizer, match_type) -> bool: + """ + Check if a document contains an answer string. + If `match_type` is string, token matching is done between the text + and answer. + If `match_type` is regex, we search the whole text with the regex. 
+ """ + text = _normalize(text) + + if match_type == 'string': + # Answer is a list of possible strings + text = tokenizer.tokenize(text).words(uncased=True) + + for single_answer in answers: + single_answer = _normalize(single_answer) + single_answer = tokenizer.tokenize(single_answer) + single_answer = single_answer.words(uncased=True) + + for i in range(0, len(text) - len(single_answer) + 1): + if single_answer == text[i: i + len(single_answer)]: + return True + + elif match_type == 'regex': + # Answer is a regex + for single_answer in answers: + single_answer = _normalize(single_answer) + if regex_match(text, single_answer): + return True + return False + + +def regex_match(text, pattern): + """Test if a regex pattern is contained within a text.""" + try: + pattern = re.compile( + pattern, + flags=re.IGNORECASE + re.UNICODE + re.MULTILINE, + ) + except BaseException: + return False + return pattern.search(text) is not None + + +# function for the reader model answer validation +def exact_match_score(prediction, ground_truth): + return _normalize_answer(prediction) == _normalize_answer(ground_truth) + + +def _normalize_answer(s): + def remove_articles(text): + return re.sub(r'\b(a|an|the)\b', ' ', text) + + def white_space_fix(text): + return ' '.join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return ''.join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def _normalize(text): + return unicodedata.normalize('NFD', text) diff --git a/tasks/orqa/unsupervised/tokenizers.py b/tasks/orqa/unsupervised/tokenizers.py new file mode 100644 index 0000000000000000000000000000000000000000..fb23887ebdd43ca83b2a6746ddc77b2a69fc1dd8 --- /dev/null +++ b/tasks/orqa/unsupervised/tokenizers.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# + +# The following code has been taken from +# https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0 +# licensed as of now. 
More details on the license can be found +# at https://github.com/facebookresearch/DPR/blob/master/LICENSE + +""" +Most of the tokenizers code here is copied from DrQA codebase to avoid adding extra dependency +""" + +import copy +import logging + +import regex +import spacy + +logger = logging.getLogger(__name__) + + +class Tokens(object): + """A class to represent a list of tokenized text.""" + TEXT = 0 + TEXT_WS = 1 + SPAN = 2 + POS = 3 + LEMMA = 4 + NER = 5 + + def __init__(self, data, annotators, opts=None): + self.data = data + self.annotators = annotators + self.opts = opts or {} + + def __len__(self): + """The number of tokens.""" + return len(self.data) + + def slice(self, i=None, j=None): + """Return a view of the list of tokens from [i, j).""" + new_tokens = copy.copy(self) + new_tokens.data = self.data[i: j] + return new_tokens + + def untokenize(self): + """Returns the original text (with whitespace reinserted).""" + return ''.join([t[self.TEXT_WS] for t in self.data]).strip() + + def words(self, uncased=False): + """Returns a list of the text of each token + + Args: + uncased: lower cases text + """ + if uncased: + return [t[self.TEXT].lower() for t in self.data] + else: + return [t[self.TEXT] for t in self.data] + + def offsets(self): + """Returns a list of [start, end) character offsets of each token.""" + return [t[self.SPAN] for t in self.data] + + def pos(self): + """Returns a list of part-of-speech tags of each token. + Returns None if this annotation was not included. + """ + if 'pos' not in self.annotators: + return None + return [t[self.POS] for t in self.data] + + def lemmas(self): + """Returns a list of the lemmatized text of each token. + Returns None if this annotation was not included. + """ + if 'lemma' not in self.annotators: + return None + return [t[self.LEMMA] for t in self.data] + + def entities(self): + """Returns a list of named-entity-recognition tags of each token. + Returns None if this annotation was not included. + """ + if 'ner' not in self.annotators: + return None + return [t[self.NER] for t in self.data] + + def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True): + """Returns a list of all ngrams from length 1 to n. + + Args: + n: upper limit of ngram length + uncased: lower cases text + filter_fn: user function that takes in an ngram list and returns + True or False to keep or not keep the ngram + as_string: return the ngram as a string vs list + """ + + def _skip(gram): + if not filter_fn: + return False + return filter_fn(gram) + + words = self.words(uncased) + ngrams = [(s, e + 1) + for s in range(len(words)) + for e in range(s, min(s + n, len(words))) + if not _skip(words[s:e + 1])] + + # Concatenate into strings + if as_strings: + ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams] + + return ngrams + + def entity_groups(self): + """Group consecutive entity tokens with the same NER tag.""" + entities = self.entities() + if not entities: + return None + non_ent = self.opts.get('non_ent', 'O') + groups = [] + idx = 0 + while idx < len(entities): + ner_tag = entities[idx] + # Check for entity tag + if ner_tag != non_ent: + # Chomp the sequence + start = idx + while (idx < len(entities) and entities[idx] == ner_tag): + idx += 1 + groups.append((self.slice(start, idx).untokenize(), ner_tag)) + else: + idx += 1 + return groups + + +class Tokenizer(object): + """Base tokenizer class. + Tokenizers implement tokenize, which should return a Tokens class. 
+ """ + + def tokenize(self, text): + raise NotImplementedError + + def shutdown(self): + pass + + def __del__(self): + self.shutdown() + + +class SimpleTokenizer(Tokenizer): + ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' + NON_WS = r'[^\p{Z}\p{C}]' + + def __init__(self, **kwargs): + """ + Args: + annotators: None or empty set (only tokenizes). + """ + self._regexp = regex.compile( + '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), + flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE + ) + if len(kwargs.get('annotators', {})) > 0: + logger.warning('%s only tokenizes! Skipping annotators: %s' % + (type(self).__name__, kwargs.get('annotators'))) + self.annotators = set() + + def tokenize(self, text): + data = [] + matches = [m for m in self._regexp.finditer(text)] + for i in range(len(matches)): + # Get text + token = matches[i].group() + + # Get whitespace + span = matches[i].span() + start_ws = span[0] + if i + 1 < len(matches): + end_ws = matches[i + 1].span()[0] + else: + end_ws = span[1] + + # Format data + data.append(( + token, + text[start_ws: end_ws], + span, + )) + return Tokens(data, self.annotators) + + +class SpacyTokenizer(Tokenizer): + + def __init__(self, **kwargs): + """ + Args: + annotators: set that can include pos, lemma, and ner. + model: spaCy model to use (either path, or keyword like 'en'). + """ + model = kwargs.get('model', 'en') + self.annotators = copy.deepcopy(kwargs.get('annotators', set())) + nlp_kwargs = {'parser': False} + if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): + nlp_kwargs['tagger'] = False + if 'ner' not in self.annotators: + nlp_kwargs['entity'] = False + self.nlp = spacy.load(model, **nlp_kwargs) + + def tokenize(self, text): + # We don't treat new lines as tokens. + clean_text = text.replace('\n', ' ') + tokens = self.nlp.tokenizer(clean_text) + if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): + self.nlp.tagger(tokens) + if 'ner' in self.annotators: + self.nlp.entity(tokens) + + data = [] + for i in range(len(tokens)): + # Get whitespace + start_ws = tokens[i].idx + if i + 1 < len(tokens): + end_ws = tokens[i + 1].idx + else: + end_ws = tokens[i].idx + len(tokens[i].text) + + data.append(( + tokens[i].text, + text[start_ws: end_ws], + (tokens[i].idx, tokens[i].idx + len(tokens[i].text)), + tokens[i].tag_, + tokens[i].lemma_, + tokens[i].ent_type_, + )) + + # Set special option for non-entity tag: '' vs 'O' in spaCy + return Tokens(data, self.annotators, opts={'non_ent': ''}) diff --git a/tasks/race/data.py b/tasks/race/data.py new file mode 100644 index 0000000000000000000000000000000000000000..c4967a0842fc35b6cbfa20dff49a3dc93342f073 --- /dev/null +++ b/tasks/race/data.py @@ -0,0 +1,135 @@ + +import glob +import json +import os +import time + +from torch.utils.data import Dataset + +from megatron import print_rank_0 +from tasks.data_utils import build_sample +from tasks.data_utils import build_tokens_types_paddings_from_ids +from tasks.data_utils import clean_text + + +NUM_CHOICES = 4 +MAX_QA_LENGTH = 128 + + +class RaceDataset(Dataset): + + def __init__(self, dataset_name, datapaths, tokenizer, max_seq_length, + max_qa_length=MAX_QA_LENGTH): + + self.dataset_name = dataset_name + print_rank_0(' > building RACE dataset for {}:'.format( + self.dataset_name)) + + string = ' > paths:' + for path in datapaths: + string += ' ' + path + print_rank_0(string) + + self.samples = [] + for datapath in datapaths: + self.samples.extend(process_single_datapath(datapath, tokenizer, + max_qa_length, + max_seq_length)) + + 
print_rank_0(' >> total number of samples: {}'.format( + len(self.samples))) + + # This indicates that each "sample" has multiple samples that + # will collapse into batch dimension + self.sample_multiplier = NUM_CHOICES + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + return self.samples[idx] + + +def process_single_datapath(datapath, tokenizer, max_qa_length, max_seq_length): + """Read in RACE files, combine, clean-up, tokenize, and convert to + samples.""" + + print_rank_0(' > working on {}'.format(datapath)) + start_time = time.time() + + # Get list of files. + filenames = glob.glob(os.path.join(datapath, '*.txt')) + + samples = [] + num_docs = 0 + num_questions = 0 + num_samples = 0 + # Load all the files + for filename in filenames: + with open(filename, 'r') as f: + for line in f: + data = json.loads(line) + num_docs += 1 + + context = data["article"] + questions = data["questions"] + choices = data["options"] + answers = data["answers"] + # Check the length. + assert len(questions) == len(answers) + assert len(questions) == len(choices) + + # Context: clean up and convert to ids. + context = clean_text(context) + context_ids = tokenizer.tokenize(context) + + # Loop over questions. + for qi, question in enumerate(questions): + num_questions += 1 + # Label. + label = ord(answers[qi]) - ord("A") + assert label >= 0 + assert label < NUM_CHOICES + assert len(choices[qi]) == NUM_CHOICES + + # For each question, build num-choices samples. + ids_list = [] + types_list = [] + paddings_list = [] + for ci in range(NUM_CHOICES): + choice = choices[qi][ci] + # Merge with choice. + if "_" in question: + qa = question.replace("_", choice) + else: + qa = " ".join([question, choice]) + # Clean QA. + qa = clean_text(qa) + # Tokenize. + qa_ids = tokenizer.tokenize(qa) + # Trim if needed. + if len(qa_ids) > max_qa_length: + qa_ids = qa_ids[0:max_qa_length] + + # Build the sample. + ids, types, paddings \ + = build_tokens_types_paddings_from_ids( + qa_ids, context_ids, max_seq_length, + tokenizer.cls, tokenizer.sep, tokenizer.pad) + + ids_list.append(ids) + types_list.append(types) + paddings_list.append(paddings) + + # Convert to numpy and add to samples + samples.append(build_sample(ids_list, types_list, + paddings_list, label, + num_samples)) + num_samples += 1 + + elapsed_time = time.time() - start_time + print_rank_0(' > processed {} document, {} questions, and {} samples' + ' in {:.2f} seconds'.format(num_docs, num_questions, + num_samples, elapsed_time)) + + return samples diff --git a/tasks/race/finetune.py b/tasks/race/finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..e03f927ceb00dcca4da7e5fedd740108f32574fd --- /dev/null +++ b/tasks/race/finetune.py @@ -0,0 +1,67 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
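+# Overview: finetune the MultipleChoice head on RACE. Each question expands +# into NUM_CHOICES samples in tasks/race/data.py, and accuracy is reported +# through the shared accuracy_func_provider callback.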
+ +"""Race.""" + +from megatron import get_args +from megatron import print_rank_0 +from megatron import get_tokenizer +from megatron import mpu +from megatron.model.multiple_choice import MultipleChoice +from tasks.eval_utils import accuracy_func_provider +from tasks.finetune_utils import finetune +from tasks.race.data import RaceDataset + + +def train_valid_datasets_provider(): + """Provide train and validation datasets.""" + args = get_args() + tokenizer = get_tokenizer() + + train_dataset = RaceDataset('training', args.train_data, + tokenizer, args.seq_length) + valid_dataset = RaceDataset('validation', args.valid_data, + tokenizer, args.seq_length) + + return train_dataset, valid_dataset + + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + + print_rank_0('building multichoice model for RACE ...') + model = MultipleChoice(num_tokentypes=2, + pre_process=pre_process, + post_process=post_process) + + return model + + +def metrics_func_provider(): + """Privde metrics callback function.""" + args = get_args() + tokenizer = get_tokenizer() + + def single_dataset_provider(datapath): + name = datapath.split('RACE')[-1].strip('/').replace('/', '-') + return RaceDataset(name, [datapath], tokenizer, args.seq_length) + + return accuracy_func_provider(single_dataset_provider) + + +def main(): + + finetune(train_valid_datasets_provider, model_provider, + end_of_epoch_callback_provider=metrics_func_provider) diff --git a/tasks/vision/classification/classification.py b/tasks/vision/classification/classification.py new file mode 100644 index 0000000000000000000000000000000000000000..be31da9bda10864e3a627920d7d45057b35fb9f8 --- /dev/null +++ b/tasks/vision/classification/classification.py @@ -0,0 +1,94 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Vision-classification finetuning/evaluation.""" + +import torch.nn.functional as F +from functools import partial +from megatron import get_args, get_timers +from megatron import print_rank_0 +from megatron.model.vision.classification import VitClassificationModel +from megatron.data.vit_dataset import build_train_valid_datasets +from tasks.vision.classification.eval_utils import accuracy_func_provider +from tasks.vision.finetune_utils import finetune +from megatron.utils import average_losses_across_data_parallel_group + + +def classification(): + def train_valid_datasets_provider(): + """Build train and validation dataset.""" + args = get_args() + + train_ds, valid_ds = build_train_valid_datasets( + data_path=args.data_path, + image_size=(args.img_h, args.img_w), + ) + return train_ds, valid_ds + + def model_provider(pre_process=True, post_process=True): + """Build the model.""" + args = get_args() + + print_rank_0("building classification model for ImageNet ...") + + return VitClassificationModel(num_classes=args.num_classes, finetune=True, + pre_process=pre_process, post_process=post_process) + + def process_batch(batch): + """Process batch and produce inputs for the model.""" + images = batch[0].cuda().contiguous() + labels = batch[1].cuda().contiguous() + return images, labels + + def cross_entropy_loss_func(labels, output_tensor): + logits = output_tensor + + # Cross-entropy loss. + loss = F.cross_entropy(logits.contiguous().float(), labels) + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + + return loss, {'lm loss': averaged_loss[0]} + + def _cross_entropy_forward_step(batch, model): + """Simple forward step with cross-entropy loss.""" + timers = get_timers() + + # Get the batch. + timers("batch generator").start() + try: + batch_ = next(batch) + except BaseException: + batch_ = batch + images, labels = process_batch(batch_) + timers("batch generator").stop() + + # Forward model. + output_tensor = model(images) + + return output_tensor, partial(cross_entropy_loss_func, labels) + + """Finetune/evaluate.""" + finetune( + train_valid_datasets_provider, + model_provider, + forward_step=_cross_entropy_forward_step, + end_of_epoch_callback_provider=accuracy_func_provider, + ) + +def main(): + classification() + diff --git a/tasks/vision/classification/eval_utils.py b/tasks/vision/classification/eval_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..db14c3dc77d7380523d4eaf12f865f66be6f2d69 --- /dev/null +++ b/tasks/vision/classification/eval_utils.py @@ -0,0 +1,129 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Evaluation utilities.""" + +import os +from functools import partial + +import torch + +from megatron import get_args +from megatron import print_rank_0, print_rank_last +from megatron import mpu +from megatron.schedules import get_forward_backward_func +from tasks.vision.finetune_utils import build_data_loader +from tasks.vision.finetune_utils import process_batch +from torchvision import datasets, transforms + + +def accuracy_func_provider(): + """Provide function that calculates accuracies.""" + args = get_args() + data_path = args.data_path + crop_size = (args.img_h, args.img_w) + + # Build dataloaders. + val_data_path = data_path[1] + normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + transform_val = transforms.Compose( + [ + transforms.Resize(crop_size), + transforms.CenterCrop(crop_size), + transforms.ToTensor(), + normalize, + ] + ) + dataset = datasets.ImageFolder(root=val_data_path, transform=transform_val) + + dataloader = build_data_loader( + dataset, + args.micro_batch_size, + num_workers=args.num_workers, + drop_last=(mpu.get_data_parallel_world_size() > 1), + shuffle=False + ) + + def metrics_func(model, epoch): + print_rank_0("calculating metrics ...") + correct, total = calculate_correct_answers(model, dataloader, epoch) + percent = float(correct) * 100.0 / float(total) + print_rank_last( + " >> |epoch: {}| overall: correct / total = {} / {} = " + "{:.4f} %".format(epoch, correct, total, percent) + ) + + return metrics_func + + +def calculate_correct_answers(model, dataloader, epoch): + """Calculate correct over total answers""" + + forward_backward_func = get_forward_backward_func() + for m in model: + m.eval() + + def loss_func(labels, output_tensor): + logits = output_tensor + + loss_dict = {} + # Compute the correct answers. + predicted = torch.argmax(logits, dim=-1) + corrects = (predicted == labels).float() + # Add to the counters. + loss_dict['total'] = labels.size(0) + loss_dict['correct'] = corrects.sum().item() + + return 0, loss_dict + + #defined inside to capture output_predictions + def correct_answers_forward_step(batch, model): + try: + batch_ = next(batch) + except BaseException: + batch_ = batch + images, labels = process_batch(batch_) + + # Forward model. + output_tensor = model(images) + + return output_tensor, partial(loss_func, labels) + + with torch.no_grad(): + # For all the batches in the dataset. + total = 0 + correct = 0 + for _, batch in enumerate(dataloader): + + loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model, + optimizer=None, timers=None, forward_only=True) + + for loss_dict in loss_dicts: + total += loss_dict['total'] + correct += loss_dict['correct'] + + for m in model: + m.train() + + # Reduce. + if mpu.is_pipeline_last_stage(): + unreduced = torch.cuda.LongTensor([correct, total]) + torch.distributed.all_reduce(unreduced, + group=mpu.get_data_parallel_group()) + + # Print on screen. + correct_ans = unreduced[0].item() + total_count = unreduced[1].item() + return correct_ans, total_count diff --git a/tasks/vision/finetune_utils.py b/tasks/vision/finetune_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0f95da5a0c4cbd5f870363d12fcedc574cd71475 --- /dev/null +++ b/tasks/vision/finetune_utils.py @@ -0,0 +1,312 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Finetune utilities.""" + +import torch +import torch.nn.functional as F +from megatron import get_args +from megatron import print_rank_0 +from megatron import get_timers +from megatron import mpu, utils +from megatron.checkpointing import load_checkpoint +from megatron.checkpointing import save_checkpoint +from megatron.training import evaluate_and_print_results +from megatron.training import setup_model_and_optimizer +from megatron.training import train_step +from megatron.training import training_log +from megatron.utils import check_adlr_autoresume_termination +from megatron.utils import average_losses_across_data_parallel_group, print_params_min_max_norm +from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP +from megatron.model import DistributedDataParallel as LocalDDP +from megatron.model import Float16Module, ModelType + + +def process_batch(batch): + """Process batch and produce inputs for the model.""" + images = batch[0].cuda().contiguous() + labels = batch[1].cuda().contiguous() + return images, labels + + +def build_data_loader(dataset, micro_batch_size, + num_workers, drop_last, shuffle): + """Data loader. Note that batch-size is the local (per GPU) batch-size.""" + + # Sampler. + world_size = mpu.get_data_parallel_world_size() + rank = mpu.get_data_parallel_rank() + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, num_replicas=world_size, rank=rank, + drop_last=drop_last, shuffle=shuffle + ) + + # Data loader. Note that batch size is the per GPU batch size. + data_loader = torch.utils.data.DataLoader( + dataset, + batch_size=micro_batch_size, + sampler=sampler, + shuffle=False, + num_workers=num_workers, + drop_last=drop_last, + pin_memory=True, + ) + + return data_loader + + +def _build_infinite_size_dataloader(dataloader): + """Build a looped dataloader with infinite size.""" + + iterator = dataloader.__iter__() + while True: + try: + yield iterator.__next__() + except StopIteration: + iterator = dataloader.__iter__() + + +def _build_train_valid_dataloaders(train_dataset, valid_dataset): + """Training and validation dataloaders.""" + args = get_args() + + print_rank_0('building train and validation dataloaders ...') + # Training dataset. + train_dataloader = build_data_loader(train_dataset, args.micro_batch_size, + args.num_workers, False, True) + # Set the training iterations. + args.train_iters_per_epoch = len(train_dataloader) + args.train_iters = args.epochs * args.train_iters_per_epoch + # Validation dataset. For this dataset, we do not need to set up + # shuffling so we can just use a simple infinite loop. + valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size, + args.num_workers, True, False) + valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_) + + # Now that we've built the data loaders, set batch_size arguments + # to the actual batch size the model will see for this dataset. + # This is necessary so pipeline transfers know what size they are + # and the LR schedule, which is based on samples seen, gets set + # correctly. 
+ args.orig_micro_batch_size = args.micro_batch_size + args.orig_global_batch_size = args.global_batch_size + + return train_dataloader, valid_dataloader + + +def _train( + model, + optimizer, + opt_param_scheduler, + forward_step, + train_dataloader, + valid_dataloader, + end_of_epoch_callback, + process_non_loss_data_func=None +): + """Train the model.""" + args = get_args() + timers = get_timers() + + # Turn on training mode which enables dropout. + for m in model: + m.train() + + # Tracking loss. + losses_dict_sum = {} + + # Starting epoch and iteration + start_epoch = args.iteration // args.train_iters_per_epoch + start_iteration = args.iteration % args.train_iters_per_epoch + iteration = args.iteration + + # Memory reporting flag. + report_memory_flag = True + + # For each remaining epoch + timers("interval-time").start() + for epoch in range(start_epoch, args.epochs): + print_rank_0("working on epoch {} ...".format(epoch + 1)) + + # Set the data loader epoch to shuffle the index iterator. + train_dataloader.sampler.set_epoch(args.seed + epoch) + train_dataloader.dataset.set_epoch(epoch) + + # For all the batches in the dataset. + for iteration_, batch in enumerate(train_dataloader): + + # Ignore the iterations before the starting value. + if iteration_ < start_iteration: + continue + # Set to zero so the next epoch does not skip any batches. + start_iteration = 0 + + # Train for one step. + losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step( + forward_step, batch, model, optimizer, opt_param_scheduler + ) + iteration += 1 + + # Logging. + params_norm = None + + report_memory_flag = training_log( + losses_dict, + losses_dict_sum, + optimizer.param_groups[0]["lr"], + iteration, + optimizer.get_loss_scale().item(), + report_memory_flag, + skipped_iter, + grad_norm, + params_norm, + num_zeros_in_grad + ) + + # Autoresume + if args.adlr_autoresume and \ + iteration % args.adlr_autoresume_interval == 0: + check_adlr_autoresume_termination(iteration, model, optimizer, + opt_param_scheduler) + + # Checkpointing + if args.save and args.save_interval and \ + iteration % args.save_interval == 0: + save_checkpoint(iteration, model, optimizer, + opt_param_scheduler) + + # Evaluation + if args.eval_interval and iteration % args.eval_interval == 0: + prefix = "iteration {}".format(iteration) + evaluate_and_print_results( + prefix, + forward_step, + valid_dataloader, + model, + iteration, + process_non_loss_data_func, + False, + ) + + # Callback at the end of each epoch. + if end_of_epoch_callback is not None: + end_of_epoch_callback(model, epoch) + + +def finetune( + train_valid_datasets_provider, + model_provider, + forward_step, + model_type=ModelType.encoder_or_decoder, + process_non_loss_data_func=None, + end_of_epoch_callback_provider=None, +): + """Main finetune function used across all tasks.""" + args = get_args() + timers = get_timers() + + # Train and validation data loaders. + timers("train/valid/test dataset/dataloader").start() + if args.epochs > 0: + train_dataset, valid_dataset = train_valid_datasets_provider() + train_dataloader, valid_dataloader = _build_train_valid_dataloaders( + train_dataset, valid_dataset + ) + timers("train/valid/test dataset/dataloader").stop() + + # Build callback function. + timers("callback function").start() + end_of_epoch_callback = None + if end_of_epoch_callback_provider is not None: + end_of_epoch_callback = end_of_epoch_callback_provider() + timers("callback function").stop() + + # Build model, optimizer and learning rate scheduler. 
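+ # Note: scale_lr_cond/lr_mult passed to setup_model_and_optimizer below are meant to give the + # classification head (parameter names containing ".head.") a learning rate scaled by + # --head-lr-mult, while the backbone parameters keep the base learning rate.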
+ timers("model and optimizer").start() + model, optimizer, opt_param_scheduler = \ + setup_model_and_optimizer( + model_provider, + model_type, + scale_lr_cond=lambda name, param: ".head." in name, + lr_mult=args.head_lr_mult) + timers("model and optimizer").stop() + + # If pretrained checkpoint is provided and we have not trained for + # any iteration (i.e., iteration is zero), then load the pretrained + # checkpoint. + timers("pretrained checkpoint").start() + if args.iteration == 0 and args.pretrained_checkpoint is not None: + if args.pretrained_checkpoint_type == 'default': + original_load = args.load + args.load = args.pretrained_checkpoint + _ = load_checkpoint(model, None, None, strict=False) + args.load = original_load + elif args.pretrained_checkpoint_type == 'external': + unwrap_model = utils.unwrap_model(model) + state_dict = torch.load(args.pretrained_checkpoint, + map_location="cpu") + unwrap_model[0].module.backbone.load_state_dict(state_dict, + strict=False) + elif args.pretrained_checkpoint_type == 'constrastive': + unwrap_model = utils.unwrap_model(model) + state_dict = torch.load(args.pretrained_checkpoint, + map_location="cpu") + state_dict = state_dict["model"] + state_dict = {k.replace("teacher.backbone.", ""): v + for k, v in state_dict.items() + if k.startswith("teacher.backbone.")} + unwrap_model[0].module.backbone.load_state_dict(state_dict, + strict=False) + else: + raise Exception("pretrained checkpoint type {} not supported".format(args.pretrained_checkpoint_type)) + + # This is critical when only model is loaded. We should make sure + # master parameters are also updated. + optimizer.reload_model_params() + + timers("pretrained checkpoint").stop() + + # Print setup timing. + print_rank_0("done with setups ...") + timers.log( + [ + "train/valid/test dataset/dataloder", + "callback function", + "model and optimizer", + "pretrained checkpoint", + ] + ) + print_rank_0("training ...") + + # Finetune the model. + if args.epochs > 0: + _train( + model, + optimizer, + opt_param_scheduler, + forward_step, + train_dataloader, + valid_dataloader, + end_of_epoch_callback, + process_non_loss_data_func, + ) + # Or just evaluate. + else: + if end_of_epoch_callback is not None: + print_rank_0("evaluation only mode, setting epoch to -1") + end_of_epoch_callback(model, epoch=-1) + + print_rank_0("done :-)") + diff --git a/tasks/vision/main.py b/tasks/vision/main.py new file mode 100644 index 0000000000000000000000000000000000000000..ac789b20736909ed77bd1dbb6b6caed59d6bb285 --- /dev/null +++ b/tasks/vision/main.py @@ -0,0 +1,66 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Main tasks functionality.""" + +import os +import sys + +sys.path.append( + os.path.abspath( + os.path.join( + os.path.join(os.path.dirname(__file__), os.path.pardir), + os.path.pardir, + ) + ) +) +from megatron import get_args +from megatron.initialize import initialize_megatron + +def get_tasks_args(parser): + """Provide extra arguments required for tasks.""" + group = parser.add_argument_group(title="tasks") + + group.add_argument('--task', type=str, default='segment', + choices=['classify', 'segment_setr', 'segment_segformer'], + help='task name.') + group.add_argument("--epochs", type=int, default=None, + help="Number of finetunning epochs. Zero results in " + "evaluation only.") + group.add_argument('--pretrained-checkpoint-type', type=str, default='default', + choices=['default', 'external', 'constrastive'], + help='Type of pretrained checkpoint') + group.add_argument("--pretrained-checkpoint", type=str, default=None, + help="Pretrained checkpoint used for finetunning.") + group.add_argument('--seg-stride', type=int, default=None, + help='sliding window stride during evaluation') + return parser + + +if __name__ == "__main__": + + initialize_megatron(extra_args_provider=get_tasks_args) + args = get_args() + + if args.task == 'classify': + from tasks.vision.classification.classification import main + main() + elif args.task == 'segment_setr': + from tasks.vision.segmentation.finetune_setr import main + main() + elif args.task == 'segment_segformer': + from tasks.vision.segmentation.finetune_segformer import main + main() + diff --git a/tasks/vision/segmentation/cityscapes.py b/tasks/vision/segmentation/cityscapes.py new file mode 100644 index 0000000000000000000000000000000000000000..1a182288f2631d0dca1e282233bacdeb474be940 --- /dev/null +++ b/tasks/vision/segmentation/cityscapes.py @@ -0,0 +1,207 @@ +# BSD 3-Clause License +# +# Copyright (c) Soumith Chintala 2016, +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +# code taken from +# https://github.com/pytorch/vision/blob/main/torchvision/datasets/cityscapes.py +# modified to change the max label index from 255 to 19 (num_classes) + +import torch +import json +import os +from collections import namedtuple +from typing import Any, Callable, Dict, List, Optional, Union, Tuple +import numpy as np +from torchvision.datasets.utils import extract_archive, verify_str_arg, iterable_to_str +from torchvision.datasets import VisionDataset +from PIL import Image +from megatron import print_rank_0 + + +class Cityscapes(VisionDataset): + """`Cityscapes