|
import os |
|
import copy |
|
import pickle |
|
import random |
|
import json |
|
import glob |
|
import numpy as np |
|
from uniperceiver.config import configurable |
|
from uniperceiver.functional import dict_as_tensor |
|
from uniperceiver.tokenization import ClipTokenizer |
|
from ..build import DATASETS_REGISTRY |
|
import pyarrow as pa |
|
|
|
__all__ = ["GLUEDataset"] |
|
|
|
|
|
@DATASETS_REGISTRY.register()
class GLUEDataset:
    """Dataset for GLUE-style single-sentence / sentence-pair classification.

    Reads a task's processed TSV annotation file, tokenizes each example with
    a CLIP tokenizer using a cloze-style prompt containing a ``<|spe|>``
    classification slot, and yields samples in the uni-perceiver
    input/target dict format.

    Supported tasks: MNLI / MNLI_Match / MNLI_Mismatch, QNLI, QQP, RTE,
    SST-2, MRPC, CoLA, STS-B (STS-B can be loaded, but ``__getitem__``
    raises for its continuous labels).
    """

    # Token id the CLIP tokenizer produces for a trailing period; it is
    # stripped from sentence 1 before a prompt suffix is appended.
    # NOTE(review): value 269 is taken verbatim from the original code —
    # confirm against the ClipTokenizer vocabulary.
    _PERIOD_TOKEN = 269

    # Per sentence-pair task: (strip trailing period from sentence 1?,
    # prompt suffix appended to sentence 1, or None for no suffix).
    _PAIR_PROMPTS = {
        "RTE": (True, " ? <|endoftext|> it is "),
        "MRPC": (True, " . "),
        "QQP": (True, " <|endoftext|> "),
        "QNLI": (True, " <|endoftext|> it is "),
        "MNLI": (False, None),
        "MNLI_Match": (False, None),
        "MNLI_Mismatch": (False, None),
    }

    @configurable
    def __init__(
        self,
        cfg: dict,
        stage: str,
        anno_file: str,
        max_seq_len: int,
        tokenizer,
        tokenizer_name,
        input_columns,
        label_column,
        input_count,
        task_name,
        data_percentage,
        data_k_sample,
    ):
        """Build the dataset and load (or re-use a cached copy of) its data.

        Args:
            cfg: global config node (read for task/dataloader settings).
            stage: one of ``"train"`` / ``"val"`` / ``"test"``.
            anno_file: path to the processed TSV annotation file.
            max_seq_len: maximum token sequence length per sample.
            tokenizer: tokenizer with an ``encode(str) -> list[int]`` method.
            tokenizer_name: short tag used in the cache file name.
            input_columns: 1-based TSV column indices of the input sentence(s).
            label_column: 1-based TSV column index of the label.
            input_count: 1 for single-sentence tasks, 2 for pair tasks.
            task_name: GLUE task identifier (e.g. ``"SST-2"``).
            data_percentage: if < 1.0 and training, per-class subsample ratio.
            data_k_sample: if > 0 and training, examples kept per class.
        """
        self.cfg = cfg
        self.stage = stage
        self.anno_file = anno_file
        self.tokenizer = tokenizer
        self.tokenizer_name = tokenizer_name
        self.max_seq_len = max_seq_len

        self.input_columns = input_columns
        self.label_column = label_column
        self.input_count = input_count

        self.task_name = task_name

        self.data_percentage = data_percentage
        self.data_k_sample = data_k_sample

        # Static per-task metadata attached to every sample.
        self.task_info = {
            'task_type': self.cfg.DATASETS.TASK_TYPE,
            'dataset_name': self.cfg.DATASETS.DATASET_NAME,
            'batch_size': self.cfg.DATALOADER.TRAIN_BATCH_SIZE if self.stage == 'train' else self.cfg.DATALOADER.TEST_BATCH_SIZE,
            'sampling_weight': self.cfg.DATALOADER.SAMPLING_WEIGHT,
        }
        self.target_set = cfg.DATASETS.TARGET_SET

        self.load_data(cfg)

    @classmethod
    def from_config(cls, cfg, stage: str = "train"):
        """Derive constructor kwargs for ``@configurable`` from the config.

        Maps each GLUE task to its annotation file path and the 1-based TSV
        column layout (input columns differ between train/val and test
        splits for some tasks).

        Raises:
            NotImplementedError: for an unrecognized task name.
        """
        task_name = cfg.DATASETS.DATASET_NAME

        # (train/val input columns, test input columns, label column, #inputs)
        column_spec = {
            "QQP": ([4, 5], [2, 3], 6, 2),
            "MNLI_Match": ([9, 10], [9, 10], 12, 2),
            "MNLI_Mismatch": ([9, 10], [9, 10], 12, 2),
            "QNLI": ([2, 3], [2, 3], 4, 2),
            "MRPC": ([4, 5], [4, 5], 1, 2),
            "RTE": ([2, 3], [2, 3], 4, 2),
            "STS-B": ([8, 9], [8, 9], 10, 2),
            "SST-2": ([1], [2], 2, 1),
            "CoLA": ([4], [2], 2, 1),
        }
        if task_name not in column_spec:
            raise NotImplementedError(task_name)

        default_cols, test_cols, label_column, input_count = column_spec[task_name]
        input_columns = test_cols if stage == 'test' else default_cols
        # The MNLI dev TSVs carry the gold label in a different column.
        if task_name in ("MNLI_Match", "MNLI_Mismatch") and stage == 'val':
            label_column = 16

        data_dir = cfg.DATALOADER.ANNO_FOLDER
        if task_name == "MNLI_Match":
            namesmapping = {"train": "train", "val": "dev_matched", "test": "test_matched"}
            anno_file = os.path.join(data_dir, 'MNLI', 'processed/{name}.tsv'.format(name=namesmapping[stage]))
        elif task_name == "MNLI_Mismatch":
            namesmapping = {"train": "train", "val": "dev_mismatched", "test": "test_mismatched"}
            anno_file = os.path.join(data_dir, 'MNLI', 'processed/{name}.tsv'.format(name=namesmapping[stage]))
        else:
            namesmapping = {"train": "train", "val": "dev", "test": "test"}
            anno_file = os.path.join(data_dir, task_name, 'processed/{name}.tsv'.format(name=namesmapping[stage]))

        ret = {
            "cfg": cfg,
            "stage": stage,
            "anno_file": anno_file,
            "max_seq_len": cfg.MODEL.MAX_SEQ_LEN,
            "input_columns": input_columns,
            "label_column": label_column,
            "input_count": input_count,
            "task_name": task_name,
            "data_percentage": getattr(cfg.DATALOADER, "DATA_PERCENTAGE", 1.0),
            "data_k_sample": getattr(cfg.DATALOADER, "DATA_K_SAMPLE", -1),
            "tokenizer": ClipTokenizer(),
            "tokenizer_name": "clip"
        }
        return ret

    @staticmethod
    def _stratified_sample(datalist, count_fn):
        """Randomly sample from each label group of *datalist*.

        ``count_fn(group)`` returns how many examples to keep for a group.
        """
        by_label = {}
        for data in datalist:
            by_label.setdefault(data['label'], []).append(data)
        sampled = []
        for group in by_label.values():
            sampled.extend(random.sample(group, k=count_fn(group)))
        return sampled

    def load_data(self, cfg):
        """Load the annotation list, caching it as a pickle next to the TSV.

        When training, optionally subsamples the data per class according to
        ``data_percentage`` (stratified fraction) or ``data_k_sample``
        (few-shot k per class). Sets ``self.datalist``.
        """
        cache_path = os.path.join(
            os.path.dirname(self.anno_file),
            "cache_GLUE_raw_%s_%s_%s.pkl" % (self.task_name, self.tokenizer_name, self.stage),
        )
        if not os.path.exists(cache_path):
            datalist = self.load_raw_data(cfg)
            # Context managers close the handles (original leaked them).
            with open(cache_path, "wb") as f:
                pickle.dump(datalist, f)
        with open(cache_path, "rb") as f:
            datalist = pickle.load(f)

        if self.data_percentage < 1.0 and self.stage == "train":
            # Stratified subsample: keep a fixed fraction of each class.
            print("will sample {} data for training-->".format(self.data_percentage))
            datalist = self._stratified_sample(
                datalist, lambda group: int(self.data_percentage * len(group) + 1))
        elif self.data_k_sample > 0 and self.stage == "train":
            # Few-shot: keep exactly k examples per class.
            print("will sample {} data for each class when training -->".format(self.data_k_sample))
            datalist = self._stratified_sample(datalist, lambda group: int(self.data_k_sample))
            # Duplicate tiny few-shot sets so an epoch has a workable length.
            while len(datalist) < 200:
                datalist = datalist + datalist

        self.datalist = datalist

    def load_raw_data(self, cfg):
        """Parse the raw TSV into ``{'sentences': [...], 'label': ...}`` dicts.

        Labels: MNLI* strings map to {contradiction: 0, neutral: 1,
        entailment: 2}; RTE/QNLI map entailment to 1.0 else 0.0; STS-B
        scores are scaled from [0, 5] to [0, 1]; all others are parsed as
        floats.

        NOTE(review): no header row is skipped — assumes the processed TSVs
        are header-free; confirm against the preprocessing step.
        """
        labelmapping = {
            "contradiction": 0,
            "neutral": 1,
            "entailment": 2,
        }
        datalist = []
        # Context manager closes the file (original leaked the handle).
        with open(self.anno_file, 'r') as fin:
            for line in fin:
                sensinfo = line.strip().split('\t')
                raw_label = sensinfo[self.label_column - 1]
                if self.task_name.startswith("MNLI"):
                    label = labelmapping[raw_label]
                elif self.task_name in ("RTE", "QNLI"):
                    label = 1.0 if raw_label == "entailment" else 0.0
                elif self.task_name == "STS-B":
                    # Normalize the 0-5 similarity score into [0, 1].
                    label = float(raw_label) / 5.0
                else:
                    label = float(raw_label)
                datalist.append({
                    "sentences": [sensinfo[i - 1] for i in self.input_columns],
                    "label": label,
                })
        return datalist

    def __len__(self):
        return len(self.datalist)

    def _encode_single(self, sentence):
        """Tokenize a single-sentence task with its cloze-style prompt.

        Returns ``(tokens, spe_index)`` where ``spe_index`` is the position
        of the ``<|spe|>`` classification token.
        """
        if self.task_name == "SST-2":
            tokens = self.tokenizer.encode(sentence + " <|endoftext|> It is <|spe|>. <|endoftext|>")
        elif self.task_name == "CoLA":
            tokens = self.tokenizer.encode(sentence + " This is <|spe|>. <|endoftext|>")
        else:
            raise NotImplementedError(self.task_name)

        # <|spe|> sits three tokens from the end (". <|endoftext|>" follows).
        spe_index = len(tokens) - 3
        assert spe_index < self.max_seq_len
        if len(tokens) > self.max_seq_len:
            # Truncate the sentence body but keep the 4-token prompt tail.
            # NOTE(review): spe_index is computed before truncation and is
            # not re-adjusted — preserved from the original implementation;
            # confirm downstream consumers tolerate the off-by-a-few slot
            # when a sample is truncated.
            tokens = tokens[:self.max_seq_len - 4] + tokens[-4:]
        return tokens, spe_index

    def _encode_pair(self, sentences):
        """Tokenize a sentence-pair task; returns ``(tokens1, tokens2)``.

        ``tokens2`` begins with the ``<|spe|>`` classification token; both
        segments are truncated so their concatenation fits ``max_seq_len``.
        """
        try:
            strip_period, suffix = self._PAIR_PROMPTS[self.task_name]
        except KeyError:
            # The original fell through with a bare `NotImplementedError`
            # expression (never raised) and then crashed on an undefined
            # variable; raise explicitly instead.
            raise NotImplementedError(self.task_name) from None

        tokens1 = self.tokenizer.encode(sentences[0])
        if strip_period and tokens1 and tokens1[-1] == self._PERIOD_TOKEN:
            tokens1 = tokens1[:-1]
        if suffix is not None:
            tokens1 = tokens1 + self.tokenizer.encode(suffix)

        tokens2 = self.tokenizer.encode(" <|spe|> , ") + \
            self.tokenizer.encode(sentences[1] + " <|endoftext|> ")

        # Length budget for segment 2: QNLI gives segment 1 priority; the
        # other tasks split the sequence budget evenly.
        if self.task_name == "QNLI":
            cap = self.max_seq_len - len(tokens1)
        else:
            cap = self.max_seq_len // 2
        if len(tokens2) > cap:
            # Truncate but keep the final <|endoftext|> token.
            tokens2 = tokens2[:cap - 1] + [tokens2[-1]]

        max_len = self.max_seq_len - len(tokens2)
        if len(tokens1) > max_len:
            # Truncate but keep segment 1's final token.
            tokens1 = tokens1[:max_len - 1] + [tokens1[-1]]
        return tokens1, tokens2

    def __getitem__(self, index):
        """Return one tokenized sample in the uni-perceiver dict format.

        Raises:
            NotImplementedError: for tasks without integer labels (STS-B).
        """
        dataset_dict = copy.deepcopy(self.datalist[index])
        sentences = dataset_dict['sentences']

        if self.input_count == 1:
            tokens, spe_index = self._encode_single(sentences[0])
        else:
            tokens1, tokens2 = self._encode_pair(sentences)
            tokens = tokens1 + tokens2
            # <|spe|> is the first token of the second segment.
            spe_index = len(tokens1)
            assert spe_index < self.max_seq_len

        token_ids = np.array(tokens, dtype=np.int64)

        if self.task_name in ("SST-2", "CoLA", "MRPC", "RTE", "QNLI",
                              "MNLI", "QQP", "MNLI_Match", "MNLI_Mismatch"):
            label = int(dataset_dict['label'])
        else:
            # STS-B labels are continuous; regression targets are not
            # supported by this sample format.
            raise NotImplementedError()

        ret = {
            'input_sample': [{
                'data': [token_ids],
                'modality': 'text',
                'data_type': 'input',
                'invalid_mask': None,
                'sample_info': {
                    'spe_index': spe_index
                }
            }],
            'target_sample': [],
            'target_idx': [label],
            'target_set': copy.deepcopy(self.target_set),
            'task_info': copy.deepcopy(self.task_info)
        }
        dict_as_tensor(ret)
        return ret
|
|