"""Wikipedia dataset from DPR code for ORQA.""" |
from abc import ABC |
import csv |
import numpy as np |
import random |
import torch |
from torch.utils.data import Dataset |
from megatron import print_rank_0, get_args, get_tokenizer, mpu |
from megatron.data.biencoder_dataset_utils import make_attention_mask |
def get_open_retrieval_wiki_dataset():
    """Build the Wikipedia evidence dataset from ``args.evidence_data_path``.

    The global tokenizer is used for tokenization, and
    ``args.retriever_seq_length`` is the fixed per-passage sequence length.
    """
    args = get_args()
    return OpenRetrievalEvidenceDataset(
        task_name='2018 Wikipedia from DPR codebase',
        dataset_name='evidence',
        datapath=args.evidence_data_path,
        tokenizer=get_tokenizer(),
        max_seq_length=args.retriever_seq_length)
def get_open_retrieval_batch(data_iterator):
    """Fetch one evidence batch and pass it through ``mpu.broadcast_data``.

    When ``data_iterator`` is None no batch is read locally and None is
    handed to ``mpu.broadcast_data`` (presumably so this rank receives the
    batch from another rank — confirm against mpu).

    Returns:
        Tuple ``(row_id, context, context_mask, context_types,
        context_pad_mask)``. ``context_mask`` is boolean, True where the
        broadcast mask value is below 0.5; the other tensors are converted
        with ``.long()``.
    """
    keys = ['row_id', 'context', 'context_mask', 'context_types',
            'context_pad_mask']
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, torch.int64)

    # Integer-valued fields are simply cast to int64.
    row_id, context, context_types, context_pad_mask = (
        data_b[name].long()
        for name in ('row_id', 'context', 'context_types',
                     'context_pad_mask'))
    # The attention mask is turned into a boolean tensor (True below 0.5).
    context_mask = data_b['context_mask'] < 0.5
    return row_id, context, context_mask, context_types, context_pad_mask
def build_tokens_types_paddings_from_text(row, tokenizer, max_seq_length):
    """Tokenize a passage as ``title [SEP] text`` and build model inputs.

    Args:
        row: Mapping with ``'title'`` and ``'text'`` strings.
        tokenizer: Tokenizer providing ``tokenize`` and the ``cls``, ``sep``
            and ``pad`` special-token ids.
        max_seq_length: Target length passed on to
            ``build_tokens_types_paddings_from_ids``.

    Returns:
        ``(context_ids, context_types, context_pad_mask)`` as produced by
        ``build_tokens_types_paddings_from_ids``.
    """
    title_ids = tokenizer.tokenize(row['title'])
    context_ids = tokenizer.tokenize(row['text'])
    # Use the same `sep` accessor as the call below, so this function only
    # depends on one tokenizer interface for special-token ids (previously
    # it needed both `sep_id` and `sep`).
    extended_context_ids = title_ids + [tokenizer.sep] + context_ids
    return build_tokens_types_paddings_from_ids(
        extended_context_ids, max_seq_length,
        tokenizer.cls, tokenizer.sep, tokenizer.pad)
def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
                                         cls_id, sep_id, pad_id):
    """Build token types and paddings, trim if needed, and pad if needed.

    The encoded sequence is ``[cls_id] + text_ids + [sep_id]``. ``text_ids``
    is truncated so that CLS + text occupy at most ``max_seq_length - 1``
    positions, leaving room for SEP. The result is then right-padded with
    ``pad_id`` up to ``max_seq_length``.

    Args:
        text_ids: Iterable of token ids (not modified).
        max_seq_length: Target sequence length.
        cls_id: Id placed at the start of the sequence.
        sep_id: Id placed after the (possibly truncated) text.
        pad_id: Id used to fill the padding positions of ``enc_ids``.

    Returns:
        ``(enc_ids, tokentypes_enc, pad_mask)``. ``enc_ids`` and
        ``tokentypes_enc`` are lists, and every token type is 0.
        ``pad_mask`` is an int64 numpy array holding 1 for real tokens
        (including CLS/SEP) and 0 for padding.
    """
    enc_ids = [cls_id]
    enc_ids.extend(text_ids)
    # Truncate so there is still room for the trailing SEP token.
    if len(enc_ids) > max_seq_length - 1:
        enc_ids = enc_ids[0: max_seq_length - 1]
    enc_ids.append(sep_id)
    num_tokens_enc = len(enc_ids)

    # Single segment: every real token has token type 0.
    tokentypes_enc = [0] * num_tokens_enc

    padding_length = max_seq_length - num_tokens_enc
    if padding_length > 0:
        enc_ids.extend([pad_id] * padding_length)
        # Padding positions get token type 0 rather than pad_id. pad_id is a
        # vocabulary id, and reusing it as a token type is only correct when
        # it happens to be 0. (Assumes downstream type ids are small, e.g.
        # 0/1 — confirm against the model's num_tokentypes.)
        tokentypes_enc.extend([0] * padding_length)

    # A non-positive padding_length makes the second term an empty list.
    pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
    pad_mask = np.array(pad_mask, dtype=np.int64)
    return enc_ids, tokentypes_enc, pad_mask
def build_sample(row_id, context_ids, context_types, context_pad_mask):
    """Pack one passage into the dict consumed by the batch producer.

    ``context_ids`` and ``context_types`` are converted to int64 numpy
    arrays. ``context_mask`` is derived from the id array via
    ``make_attention_mask``. ``row_id`` and ``context_pad_mask`` are stored
    unchanged.
    """
    ids_array = np.array(context_ids, dtype=np.int64)
    return {
        'row_id': row_id,
        'context': ids_array,
        'context_mask': make_attention_mask(ids_array, ids_array),
        'context_types': np.array(context_types, dtype=np.int64),
        'context_pad_mask': context_pad_mask,
    }
class OpenRetrievalEvidenceDataset(ABC, Dataset):
    """Open Retrieval Evidence dataset class.

    Loads every passage of a tab-separated evidence file into memory. On
    access, a passage is tokenized into a fixed-length sample (see
    ``__getitem__``).
    """

    def __init__(self, task_name, dataset_name, datapath, tokenizer,
                 max_seq_length):
        """Load the evidence file and optionally subsample it.

        Args:
            task_name: Task name, used only in log messages.
            dataset_name: Dataset name, used only in log messages.
            datapath: Path to the tab-separated evidence file.
            tokenizer: Tokenizer used when building samples.
            max_seq_length: Target length of each produced sample.
        """
        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
                                                              self.dataset_name))
        print_rank_0(datapath)
        self.samples, self.id2text = self.process_samples_from_single_path(
            datapath)

        args = get_args()
        # Optionally keep only a random fraction of the passages. Note that
        # self.id2text was built before this step and still covers all rows.
        if args.sample_rate < 1:
            k = int(len(self.samples) * args.sample_rate)
            self.samples = random.sample(self.samples, k)

        print_rank_0(' >> total number of samples: {}'.format(
            len(self.samples)))

    def __len__(self):
        """Return the number of (possibly subsampled) passages."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize passage ``idx`` and return it as a sample dict."""
        row = self.samples[idx]
        context_ids, context_types, context_pad_mask = \
            build_tokens_types_paddings_from_text(row, self.tokenizer,
                                                  self.max_seq_length)
        sample = build_sample(row['doc_id'],
                              context_ids,
                              context_types,
                              context_pad_mask)
        return sample

    @staticmethod
    def process_samples_from_single_path(filename):
        """Read a tab-separated evidence file.

        The first line is skipped as a header. Every following line is read
        as ``id<TAB>text<TAB>title``, and ``id`` is parsed as an int.

        Returns:
            ``(rows, id2text)``. ``rows`` is a list of dicts with keys
            ``'doc_id'``, ``'text'`` and ``'title'``. ``id2text`` maps each
            doc_id to its ``(text, title)`` tuple.

        Raises:
            ValueError: if the same doc_id appears more than once.
        """
        print_rank_0(' > Processing {} ...'.format(filename))
        total = 0
        rows = []
        id2text = {}

        # Decode explicitly as UTF-8 instead of the locale-dependent default
        # (assumes the evidence file is UTF-8 encoded — confirm for new data).
        # The csv module requires newline='' so that newlines embedded in
        # quoted fields are preserved.
        with open(filename, encoding='utf-8', newline='') as tsvfile:
            reader = csv.reader(tsvfile, delimiter='\t')
            # Skip the header line.
            next(reader, None)
            for row in reader:
                doc_id = int(row[0])
                text = row[1]
                title = row[2]
                # Raise explicitly: an `assert` would be stripped under -O.
                if doc_id in id2text:
                    raise ValueError('duplicate doc_id {} in {}'.format(
                        doc_id, filename))
                rows.append({'doc_id': doc_id,
                             'text': text,
                             'title': title})
                id2text[doc_id] = (text, title)
                total += 1
                if total % 100000 == 0:
                    print_rank_0(' > processed {} rows so far ...'.format(
                        total))

        print_rank_0(' >> processed {} samples.'.format(len(rows)))
        return rows, id2text