"""Wikipedia dataset from DPR code for ORQA.""" |
|
|
|
from abc import ABC |
|
import csv |
|
import numpy as np |
|
import random |
|
import torch |
|
from torch.utils.data import Dataset |
|
|
|
from megatron import print_rank_0, get_args, get_tokenizer, mpu |
|
from megatron.data.biencoder_dataset_utils import make_attention_mask |
|
|
|
def get_open_retrieval_wiki_dataset():
    args = get_args()
    tokenizer = get_tokenizer()

    dataset = OpenRetrievalEvidenceDataset('2018 Wikipedia from DPR codebase',
                                           'evidence',
                                           args.evidence_data_path,
                                           tokenizer,
                                           args.retriever_seq_length)
    return dataset
|
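# Illustrative sketch only (not part of the original module): once Megatron's
# global args and tokenizer are initialized, the evidence dataset can be
# wrapped in a standard DataLoader for embedding, along the lines of
#
#   dataset = get_open_retrieval_wiki_dataset()
#   loader = torch.utils.data.DataLoader(dataset, batch_size=128)
#
# The batch size and sampler used in practice depend on the surrounding
# retriever setup.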
|
def get_open_retrieval_batch(data_iterator):
    # Items and their type.
    keys = ['row_id', 'context', 'context_mask', 'context_types',
            'context_pad_mask']
    datatype = torch.int64

    # Broadcast data.
    data = None if data_iterator is None else next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    row_id = data_b['row_id'].long()
    context = data_b['context'].long()

    # Invert the 0/1 attention mask to booleans: True marks positions that
    # should be masked out.
    context_mask = (data_b['context_mask'] < 0.5)

    context_types = data_b['context_types'].long()
    context_pad_mask = data_b['context_pad_mask'].long()

    return row_id, context, context_mask, context_types, context_pad_mask
|
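# Illustrative sketch only (not part of the original module): the unpacked
# batch is typically fed to a BERT-style context encoder, roughly
#
#   row_id, context, context_mask, context_types, context_pad_mask = \
#       get_open_retrieval_batch(data_iterator)
#   context_logits = context_model(context, context_mask, context_types)
#
# The actual model call signature depends on the biencoder implementation.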
|
|
def build_tokens_types_paddings_from_text(row, tokenizer, max_seq_length):
    """Build token types and paddings, trim if needed, and pad if needed."""

    title_ids = tokenizer.tokenize(row['title'])
    context_ids = tokenizer.tokenize(row['text'])

    # Prepend the title to the passage text, separated by [SEP].
    extended_context_ids = title_ids + [tokenizer.sep_id] + context_ids

    context_ids, context_types, context_pad_mask = \
        build_tokens_types_paddings_from_ids(extended_context_ids,
            max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)

    return context_ids, context_types, context_pad_mask
|
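# Illustrative note (not part of the original module): for a row with a
# 'title' and 'text', the encoder input produced here is laid out as
#
#   [CLS] title tokens [SEP] passage tokens [SEP] [PAD] ...
#
# where the [CLS], trailing [SEP], and [PAD] handling is done by
# build_tokens_types_paddings_from_ids below.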
|
def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
                                         cls_id, sep_id, pad_id):
    """Build token types and paddings, trim if needed, and pad if needed."""
    enc_ids = []
    tokentypes_enc = []

    # [CLS].
    enc_ids.append(cls_id)
    tokentypes_enc.append(0)

    # Single-segment input: all token types are 0.
    len_src = len(text_ids)
    enc_ids.extend(text_ids)
    tokentypes_enc.extend([0] * len_src)

    # Cap the size, leaving room for the trailing [SEP].
    if len(enc_ids) > max_seq_length - 1:
        enc_ids = enc_ids[0: max_seq_length - 1]
        tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]

    # [SEP].
    enc_ids.append(sep_id)
    tokentypes_enc.append(0)

    num_tokens_enc = len(enc_ids)

    # Pad out to max_seq_length.
    padding_length = max_seq_length - len(enc_ids)
    if padding_length > 0:
        enc_ids.extend([pad_id] * padding_length)
        tokentypes_enc.extend([pad_id] * padding_length)

    # 1 for real tokens, 0 for padding.
    pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
    pad_mask = np.array(pad_mask, dtype=np.int64)

    return enc_ids, tokentypes_enc, pad_mask
|
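# Worked toy example (illustrative, assuming cls_id=101, sep_id=102, pad_id=0):
# for text_ids=[11, 12, 13] and max_seq_length=8,
# build_tokens_types_paddings_from_ids returns
#
#   enc_ids        = [101, 11, 12, 13, 102, 0, 0, 0]
#   tokentypes_enc = [  0,  0,  0,  0,   0, 0, 0, 0]
#   pad_mask       = [  1,  1,  1,  1,   1, 0, 0, 0]
#
# Note that tokentypes_enc is padded with pad_id (0 here), matching the code
# above.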
|
|
def build_sample(row_id, context_ids, context_types, context_pad_mask):
    """Convert to numpy and return a sample consumed by the batch producer."""

    context_ids = np.array(context_ids, dtype=np.int64)
    context_types = np.array(context_types, dtype=np.int64)
    context_mask = make_attention_mask(context_ids, context_ids)

    sample = {
        'row_id': row_id,
        'context': context_ids,
        'context_mask': context_mask,
        'context_types': context_types,
        'context_pad_mask': context_pad_mask
    }
    return sample
|
|
class OpenRetrievalEvidenceDataset(ABC, Dataset):
    """Open Retrieval Evidence dataset class."""

    def __init__(self, task_name, dataset_name, datapath, tokenizer,
                 max_seq_length):
        # Store inputs.
        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
                                                             self.dataset_name))
        # Process the evidence file.
        print_rank_0(datapath)
        self.samples, self.id2text = self.process_samples_from_single_path(
            datapath)

        args = get_args()
        if args.sample_rate < 1:  # subsample
            k = int(len(self.samples) * args.sample_rate)
            self.samples = random.sample(self.samples, k)

        print_rank_0(' >> total number of samples: {}'.format(
            len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        row = self.samples[idx]

        context_ids, context_types, context_pad_mask = \
            build_tokens_types_paddings_from_text(row, self.tokenizer,
                self.max_seq_length)

        sample = build_sample(row['doc_id'],
                              context_ids,
                              context_types,
                              context_pad_mask)
        return sample

    @staticmethod
    def process_samples_from_single_path(filename):
        print_rank_0(' > Processing {} ...'.format(filename))
        total = 0

        rows = []
        id2text = {}

        with open(filename) as tsvfile:
            reader = csv.reader(tsvfile, delimiter='\t')
            next(reader, None)  # skip the header row
            for row in reader:
                # File format: doc_id, doc_text, title.
                doc_id = int(row[0])
                text = row[1]
                title = row[2]

                rows.append({'doc_id': doc_id,
                             'text': text,
                             'title': title})

                assert doc_id not in id2text
                id2text[doc_id] = (text, title)

                total += 1
                if total % 100000 == 0:
                    print_rank_0(' > processed {} rows so far ...'.format(
                        total))

        print_rank_0(' >> processed {} samples.'.format(len(rows)))
        return rows, id2text
|
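# Illustrative sketch only (not part of the original module): the expected
# evidence file is a DPR-style tab-separated dump with a header row followed
# by one passage per line, roughly
#
#   id    text                           title
#   1     First hundred-word passage     Some article
#   2     Second hundred-word passage    Some article
#
# Iterating over the dataset then yields one padded, tokenized passage per
# row, keyed by 'row_id', 'context', 'context_mask', 'context_types' and
# 'context_pad_mask'.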
|