|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""MNLI dataset.""" |
|
|
|
from megatron import print_rank_0 |
|
from tasks.data_utils import clean_text |
|
from .data import GLUEAbstractDataset |
|
import json |
|
from tasks.label_dict import get_label_dict |
|
|
|
# Lookup table for the "WSC" task: maps each raw label string found in the
# data files to the value stored under a sample's 'label' key.
LABELS = get_label_dict("WSC")
|
|
|
class WSCDataset(GLUEAbstractDataset):
    """WSC dataset whose samples are read from JSON-lines files.

    Every non-blank line of an input file is a JSON object with a 'text'
    string and a 'target' object holding 'span1_text'/'span1_index' (the
    query) and 'span2_text'/'span2_index' (the pronoun).  'id' and 'label'
    are optional: a missing 'id' defaults to the row position, and when the
    first row has no 'label' the whole file is treated as unlabeled and
    every sample receives ``test_label``.
    """

    def __init__(self, name, datapaths, tokenizer, max_seq_length,
                 test_label="false"):
        # Stored before the base initializer runs -- presumably the base
        # class reads the data files during construction (TODO confirm
        # against GLUEAbstractDataset).
        self.test_label = test_label
        super().__init__('WSC', name, datapaths,
                         tokenizer, max_seq_length)

    @staticmethod
    def _mark_spans(text, query, query_idx, pronoun, pronoun_idx):
        """Wrap the query in "_" and the pronoun in "[" / "]" inside text.

        Indices are character offsets into the unmodified ``text``.  The
        span that starts first is marked first, so the later span's
        offsets are shifted by the two marker characters already inserted.
        """
        chars = list(text)
        if pronoun_idx > query_idx:
            chars.insert(query_idx, "_")
            chars.insert(query_idx + len(query) + 1, "_")
            # Two "_" markers now precede the pronoun, hence the +2 shift.
            chars.insert(pronoun_idx + 2, "[")
            chars.insert(pronoun_idx + len(pronoun) + 2 + 1, "]")
        else:
            chars.insert(pronoun_idx, "[")
            chars.insert(pronoun_idx + len(pronoun) + 1, "]")
            # "[" and "]" now precede the query, hence the +2 shift.
            chars.insert(query_idx + 2, "_")
            chars.insert(query_idx + len(query) + 2 + 1, "_")
        return "".join(chars)

    def process_samples_from_single_path(self, filename):
        """Implement abstract method.

        Reads ``filename`` as JSON lines and returns a list of sample
        dicts with keys 'text_a' (the text with both spans marked),
        'text_b' (always None), 'label' (looked up in LABELS) and 'uid'.

        Raises:
            AssertionError: if a span does not match the text at its
                index, the label is not in LABELS, or the id is negative.
            json.JSONDecodeError: if a non-blank line is not valid JSON.
            KeyError: if a required field is missing from a row.
        """
        print_rank_0(' > Processing {} ...'.format(filename))

        # Decode explicitly as UTF-8 instead of the platform default
        # encoding, and skip blank lines (e.g. a trailing newline) rather
        # than failing inside json.loads.
        with open(filename, 'r', encoding='utf-8') as f:
            rows = [json.loads(line) for line in f if line.strip()]

        samples = []
        is_test = False
        # Nothing is ever dropped; the counter only feeds the final log line.
        drop_cnt = 0

        for index, row in enumerate(rows):
            if "id" not in row:
                row["id"] = index

            text = row['text']
            target = row['target']
            query = target['span1_text']
            query_idx = target['span1_index']
            pronoun = target['span2_text']
            pronoun_idx = target['span2_index']

            # Both spans must occur verbatim at their declared offsets.
            assert text[pronoun_idx: (pronoun_idx + len(pronoun))] == pronoun, "pronoun: {}".format(pronoun)
            assert text[query_idx: (query_idx + len(query))] == query, "query: {}".format(query)

            text_a = self._mark_spans(text, query, query_idx,
                                      pronoun, pronoun_idx)
            text_b = None

            # The first row alone decides whether the file carries labels.
            if index == 0:
                is_test = "label" not in row
                if is_test:
                    print_rank_0(
                        ' reading {}, {} and {} columns and setting '
                        'labels to {}'.format(
                            row["id"], text_a,
                            text_b, self.test_label))
                else:
                    print_rank_0(' reading {} , {}, {}, and {} columns '
                                 '...'.format(
                                     row["id"], text_a,
                                     text_b, row["label"].strip()))

            unique_id = int(row["id"])
            label = self.test_label if is_test else row["label"].strip()

            assert len(text_a) > 0
            assert label in LABELS, "found label {} {} {}".format(label, row, type(label))
            assert unique_id >= 0

            samples.append({'text_a': text_a,
                            'text_b': text_b,
                            'label': LABELS[label],
                            'uid': unique_id})

            if len(samples) % 5000 == 0:
                print_rank_0(' > processed {} so far ...'.format(len(samples)))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        print_rank_0(' >> drop {} samples.'.format(drop_cnt))

        return samples
|
|