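# GLUE dataset for Uni-Perceiver: raw GLUE TSV rows are turned into
# prompt-style token sequences containing a <|spe|> slot whose representation
# is scored against a word-choice target set.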
import os
import copy
import pickle
import random
import json
import glob
import numpy as np
from uniperceiver.config import configurable
from uniperceiver.functional import dict_as_tensor
from uniperceiver.tokenization import ClipTokenizer
from ..build import DATASETS_REGISTRY
import pyarrow as pa

__all__ = ["GLUEDataset"]


@DATASETS_REGISTRY.register()
class GLUEDataset:
    @configurable
    def __init__(
        self,
        cfg: dict,
        stage: str,
        anno_file: str,
        max_seq_len: int,
        tokenizer,
        tokenizer_name,
        input_columns,
        label_column,
        input_count,
        task_name,
        data_percentage,
        data_k_sample,
    ):
        self.cfg = cfg
        self.stage = stage
        self.anno_file = anno_file
        self.tokenizer = tokenizer
        self.tokenizer_name = tokenizer_name
        self.max_seq_len = max_seq_len
        self.input_columns = input_columns
        self.label_column = label_column
        self.input_count = input_count
        self.task_name = task_name
        self.data_percentage = data_percentage
        self.data_k_sample = data_k_sample
        self.task_info = {
            'task_type': self.cfg.DATASETS.TASK_TYPE,
            'dataset_name': self.cfg.DATASETS.DATASET_NAME,
            'batch_size': self.cfg.DATALOADER.TRAIN_BATCH_SIZE if self.stage == 'train' else self.cfg.DATALOADER.TEST_BATCH_SIZE,
            'sampling_weight': self.cfg.DATALOADER.SAMPLING_WEIGHT,
        }
        self.target_set = cfg.DATASETS.TARGET_SET
        self.load_data(cfg)
    @classmethod
    def from_config(cls, cfg, stage: str = "train"):
        task_name = cfg.DATASETS.DATASET_NAME
        namesmapping = {
            "train": "train",
            "val": "dev",
            "test": "test",
        }
        data_dir = cfg.DATALOADER.ANNO_FOLDER
        if task_name in ['MNLI', 'QNLI', 'QQP', 'RTE', 'SST-2', 'MRPC', 'CoLA', 'STS-B']:
            anno_file = os.path.join(data_dir, task_name, 'processed/{name}.tsv'.format(name=namesmapping[stage]))
        elif task_name == 'MNLI_Match':
            namesmapping = {
                "train": "train",
                "val": "dev_matched",
                "test": "test_matched",
            }
            anno_file = os.path.join(data_dir, 'MNLI', 'processed/{name}.tsv'.format(name=namesmapping[stage]))
        elif task_name == 'MNLI_Mismatch':
            namesmapping = {
                "train": "train",
                "val": "dev_mismatched",
                "test": "test_mismatched",
            }
            anno_file = os.path.join(data_dir, 'MNLI', 'processed/{name}.tsv'.format(name=namesmapping[stage]))

        # Column indices below are 1-based positions in the raw GLUE TSV files;
        # load_raw_data() shifts them to 0-based when reading each row.
        input_count = 2
        if task_name == "QQP":
            input_columns = [4, 5]
            if stage == 'test':
                input_columns = [2, 3]
            label_column = 6
        elif task_name in ["MNLI_Match", "MNLI_Mismatch"]:
            # Plain "MNLI" resolves an anno_file above but has no column
            # mapping here, so it falls through to NotImplementedError.
            input_columns = [9, 10]
            label_column = 12
            if stage == 'val':
                label_column = 16
        elif task_name == "QNLI":
            input_columns = [2, 3]
            label_column = 4
        elif task_name == "MRPC":
            input_columns = [4, 5]
            label_column = 1
        elif task_name == "RTE":
            input_columns = [2, 3]
            label_column = 4
        elif task_name == "STS-B":
            input_columns = [8, 9]
            label_column = 10
        # The remaining tasks are single-sentence.
        elif task_name == "SST-2":
            input_columns = [1]
            if stage == 'test':
                input_columns = [2]
            label_column = 2
            input_count = 1
        elif task_name == "CoLA":
            input_columns = [4]
            if stage == 'test':
                input_columns = [2]
            label_column = 2
            input_count = 1
        else:
            raise NotImplementedError

        ret = {
            "cfg": cfg,
            "stage": stage,
            "anno_file": anno_file,
            "max_seq_len": cfg.MODEL.MAX_SEQ_LEN,
            "input_columns": input_columns,
            "label_column": label_column,
            "input_count": input_count,
            "task_name": task_name,
            "data_percentage": getattr(cfg.DATALOADER, "DATA_PERCENTAGE", 1.0),
            "data_k_sample": getattr(cfg.DATALOADER, "DATA_K_SAMPLE", -1),
            "tokenizer": ClipTokenizer(),
            "tokenizer_name": "clip",
        }
        return ret
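    # For illustration, from_config() for SST-2/train returns roughly the
    # following (paths and values here are assumptions, not real output):
    #
    #   {
    #       "anno_file": "<ANNO_FOLDER>/SST-2/processed/train.tsv",
    #       "max_seq_len": cfg.MODEL.MAX_SEQ_LEN,
    #       "input_columns": [1],   # 1-based TSV column of the sentence
    #       "label_column": 2,
    #       "input_count": 1,
    #       "task_name": "SST-2",
    #       "tokenizer": ClipTokenizer(),
    #       "tokenizer_name": "clip",
    #       ...
    #   }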
    def load_data(self, cfg):
        cache_path = os.path.join(
            os.path.dirname(self.anno_file),
            "cache_GLUE_raw_%s_%s_%s.pkl" % (self.task_name, self.tokenizer_name, self.stage),
        )
        if not os.path.exists(cache_path):
            datalist = self.load_raw_data(cfg)
            with open(cache_path, "wb") as f:
                pickle.dump(datalist, f)
        with open(cache_path, "rb") as f:
            datalist = pickle.load(f)

        # Few-shot experiments: subsample the training split per label.
        if self.data_percentage < 1.0 and self.stage == "train":
            print("will sample {} of the data for training -->".format(self.data_percentage))
            labels2l = dict()
            for data in datalist:
                label = data['label']
                if label not in labels2l:
                    labels2l[label] = list()
                labels2l[label].append(data)
            datalist = []
            for v in labels2l.values():
                datalist.extend(random.sample(v, k=int(self.data_percentage * len(v) + 1)))
        elif self.data_k_sample > 0 and self.stage == "train":
            print("will sample {} examples per class for training -->".format(self.data_k_sample))
            labels2l = dict()
            for data in datalist:
                label = data['label']
                if label not in labels2l:
                    labels2l[label] = list()
                labels2l[label].append(data)
            datalist = []
            for v in labels2l.values():
                datalist.extend(random.sample(v, k=int(self.data_k_sample)))
            # Repeat very small few-shot sets so an epoch has at least 200 items.
            while len(datalist) < 200:
                datalist = datalist + datalist

        self.datalist = datalist
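    # Worked example with invented numbers: data_k_sample=16 on a binary task
    # keeps 16 examples per label (32 total), then the doubling loop grows the
    # list 32 -> 64 -> 128 -> 256 so a training epoch sees at least 200 items.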
    def load_raw_data(self, cfg):
        datalist = []
        if self.task_name.startswith("MNLI"):
            labelmapping = {
                "contradiction": 0,
                "neutral": 1,
                "entailment": 2,
            }
        with open(self.anno_file, 'r') as f:
            lines = f.readlines()
        for line in lines:
            sensinfo = line.strip().split('\t')
            if self.task_name == "RTE":
                label = 1.0 if sensinfo[self.label_column - 1] == "entailment" else 0.0
            elif self.task_name.startswith("MNLI"):
                label = labelmapping[sensinfo[self.label_column - 1]]
            elif self.task_name == "QNLI":
                label = 1.0 if sensinfo[self.label_column - 1] == "entailment" else 0.0
            elif self.task_name == "STS-B":
                # STS-B similarity scores are in [0, 5]; normalize to [0, 1].
                label = float(sensinfo[self.label_column - 1]) / 5.0
            else:
                label = float(sensinfo[self.label_column - 1])
            datalist.append({
                # Column indices are 1-based in from_config(); shift to 0-based here.
                "sentences": [sensinfo[i - 1] for i in self.input_columns],
                "label": label
            })
        return datalist
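    # A parsed STS-B row, with invented contents, looks like:
    #   {"sentences": ["A plane is taking off.", "An air plane is taking off."],
    #    "label": 1.0}   # raw score 5.0 scaled by 1/5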
    def __len__(self):
        return len(self.datalist)
    def __getitem__(self, index):
        dataset_dict = copy.deepcopy(self.datalist[index])
        sentences = dataset_dict['sentences']
        # Prompts are built around a <|spe|> slot: input 1 is the sentence
        # followed by "it is <|spe|>"-style text, and input 2 is the
        # word-choice target set, e.g. positive vs. negative.
        if self.input_count == 1:
            if self.task_name == "SST-2":
                tokens = self.tokenizer.encode(sentences[0] + " <|endoftext|> It is <|spe|>. <|endoftext|>")
            elif self.task_name == "CoLA":
                tokens = self.tokenizer.encode(sentences[0] + " This is <|spe|>. <|endoftext|>")
            else:
                raise NotImplementedError
            # <|spe|> sits three tokens from the end ("<|spe|> . <|endoftext|>").
            index = len(tokens) - 3
            assert index < self.max_seq_len
            if len(tokens) > self.max_seq_len:
                # Truncate the sentence but keep the last four tokens, which
                # include the <|spe|> slot and the closing markers.
                tokens = tokens[:self.max_seq_len - 4] + tokens[-4:]
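            # For example, an invented SST-2 sentence "a gorgeous film" yields
            # the prompt "a gorgeous film <|endoftext|> It is <|spe|>. <|endoftext|>",
            # with `index` marking the <|spe|> position whose representation
            # is classified.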
        else:
            if self.task_name in ["RTE"]:
                tokens1 = self.tokenizer.encode(sentences[0])
                if tokens1[-1] == 269:  # 269 is "." in the CLIP BPE vocab: drop a trailing period
                    tokens1 = tokens1[:-1]
                tokens1 = tokens1 + self.tokenizer.encode(" ? <|endoftext|> it is ")
                tokens2 = self.tokenizer.encode(sentences[1] + " <|endoftext|> ")
                tokens2 = self.tokenizer.encode(" <|spe|> , ") + tokens2
                # Cap the second segment at half the budget, keeping its final token.
                if len(tokens2) > self.max_seq_len // 2:
                    tokens2 = tokens2[:self.max_seq_len // 2 - 1] + [tokens2[-1]]
                max_len = self.max_seq_len - len(tokens2)
            elif self.task_name in ["MRPC"]:
                tokens1 = self.tokenizer.encode(sentences[0])
                if tokens1[-1] == 269:
                    tokens1 = tokens1[:-1]
                tokens1 = tokens1 + self.tokenizer.encode(" . ")
                tokens2 = self.tokenizer.encode(sentences[1] + " <|endoftext|> ")
                tokens2 = self.tokenizer.encode(" <|spe|> , ") + tokens2
                if len(tokens2) > self.max_seq_len // 2:
                    tokens2 = tokens2[:self.max_seq_len // 2 - 1] + [tokens2[-1]]
                max_len = self.max_seq_len - len(tokens2)
            elif self.task_name in ["QQP"]:
                tokens1 = self.tokenizer.encode(sentences[0])
                if tokens1[-1] == 269:
                    tokens1 = tokens1[:-1]
                tokens1 = tokens1 + self.tokenizer.encode(" <|endoftext|> ")
                tokens2 = self.tokenizer.encode(sentences[1] + " <|endoftext|> ")
                tokens2 = self.tokenizer.encode(" <|spe|> , ") + tokens2
                if len(tokens2) > self.max_seq_len // 2:
                    tokens2 = tokens2[:self.max_seq_len // 2 - 1] + [tokens2[-1]]
                max_len = self.max_seq_len - len(tokens2)
            elif self.task_name in ["QNLI"]:
                tokens1 = self.tokenizer.encode(sentences[0])
                if tokens1[-1] == 269:
                    tokens1 = tokens1[:-1]
                tokens1 = tokens1 + self.tokenizer.encode(" <|endoftext|> it is ")
                tokens2 = self.tokenizer.encode(sentences[1] + " <|endoftext|> ")
                tokens2 = self.tokenizer.encode(" <|spe|> , ") + tokens2
                # QNLI lets the second segment take whatever room the first leaves.
                if len(tokens2) > self.max_seq_len - len(tokens1):
                    tokens2 = tokens2[:self.max_seq_len - len(tokens1) - 1] + [tokens2[-1]]
                max_len = self.max_seq_len - len(tokens2)
            elif self.task_name in ["MNLI", "MNLI_Match"]:
                tokens1 = self.tokenizer.encode(sentences[0])
                tokens2 = self.tokenizer.encode(sentences[1] + " <|endoftext|> ")
                tokens2 = self.tokenizer.encode(" <|spe|> , ") + tokens2
                if len(tokens2) > self.max_seq_len // 2:
                    tokens2 = tokens2[:self.max_seq_len // 2 - 1] + [tokens2[-1]]
                max_len = self.max_seq_len - len(tokens2)
elif self.task_name in ["RTE", "QNLI", "MNLI", "MNLI_Match"]:
tokens1 = self.tokenizer.encode(sentences[0] + "? <|endoftext|>")
tokens2 = self.tokenizer.encode(sentences[1] + " <|endoftext|> ")
if tokens1[-1] == 269:
tokens1 = tokens1[:-1]
tokens2 = self.tokenizer.encode(" <|spe|> , ") + tokens2
if len(tokens2) > self.max_seq_len // 2:
tokens2 = tokens2[:self.max_seq_len // 2 - 1] + [tokens2[-1]]
max_len = self.max_seq_len - len(tokens2)
elif self.task_name in ["MRPC", "QQP"]:
tokens1 = self.tokenizer.encode(sentences[0] + " <|endoftext|>")
tokens2 = self.tokenizer.encode(sentences[1] + " <|endoftext|> ")
tokens2 = self.tokenizer.encode(" <|spe|>, ") + tokens2
if len(tokens2) > self.max_seq_len // 2:
tokens2 = tokens2[:self.max_seq_len // 2 - 1] + [tokens2[-1]]
max_len = self.max_seq_len - len(tokens2)
else:
NotImplementedError
# tokens = self.tokenizer.add_special_tokens_sentences_pair(tokens1, tokens2, start_type='SPE')
if len(tokens1) > max_len:
tokens1 = tokens1[:max_len - 1] + [tokens1[-1]]
tokens = tokens1 + tokens2
index = len(tokens1)
assert index < self.max_seq_len
        sentences = np.array(tokens, dtype=np.int64)
        if self.task_name in ["SST-2", "CoLA", "MRPC", "RTE", "QNLI", "MNLI", "QQP", "MNLI_Match"]:
            label = int(dataset_dict['label'])
        else:
            raise NotImplementedError()
        ret = {
            'input_sample': [{
                'data': [sentences],
                'modality': 'text',
                'data_type': 'input',
                'invalid_mask': None,
                'sample_info': {
                    'spe_index': index
                }
            }],
            'target_sample': [],
            'target_idx': [label],
            'target_set': copy.deepcopy(self.target_set),
            'task_info': copy.deepcopy(self.task_info)
        }
        dict_as_tensor(ret)
        return ret
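

# A minimal smoke-test sketch. Everything below is an assumption for
# illustration: a YACS-style cfg is mocked with SimpleNamespace, GLUE TSVs are
# expected under a hypothetical ./data/GLUE folder, and the @configurable
# decorator is assumed to pass explicit keyword arguments through to __init__.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        DATASETS=SimpleNamespace(
            TASK_TYPE="text_classification",   # hypothetical value
            DATASET_NAME="SST-2",
            TARGET_SET=["SST-2_target"],       # hypothetical value
        ),
        DATALOADER=SimpleNamespace(
            ANNO_FOLDER="./data/GLUE",         # hypothetical path
            TRAIN_BATCH_SIZE=32,
            TEST_BATCH_SIZE=32,
            SAMPLING_WEIGHT=1.0,
            DATA_PERCENTAGE=1.0,
            DATA_K_SAMPLE=-1,
        ),
        MODEL=SimpleNamespace(MAX_SEQ_LEN=128),
    )
    dataset = GLUEDataset(**GLUEDataset.from_config(cfg, stage="train"))
    sample = dataset[0]
    print(len(dataset), sample['input_sample'][0]['sample_info']['spe_index'])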