from uniperceiver.functional import dict_as_tensor
import random
import numpy as np
import copy
from torch.utils.data import Dataset
from uniperceiver.tokenization import ClipTokenizer
import logging
import os
from ..build import DATASETS_REGISTRY
from uniperceiver.config import configurable
import pickle
from uniperceiver.utils import comm
__all__ = ["GeneralCorpusDataset"]
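# Text-only pretraining dataset over a line-per-sentence corpus (BooksWiki-style).
# Supports two task types: 'text_mlm' (BERT-style whole-word masking) and
# 'caption' (autoregressive prediction where every position is a target).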
@DATASETS_REGISTRY.register()
class GeneralCorpusDataset(Dataset):
@configurable
def __init__(self, ann_file, stage,
tokenizer, tokenizer_name,
seq_len=64, min_seq_len=64,
encoding="utf-8",
cache_mode=True, cache_local_rank=0, cache_local_size=1,
append_eos=False,
one_stream=False,
random_mask=False,
task_type=None,
text_type_id=0,
mask_bpe_word='spe',
version='v1',
task_info=None,
target_set=None,
**kwargs):
assert cache_mode, "only cache mode is supported!"
assert len(task_type) > 0
self.version = version
self.stage = stage
self.tokenizer = tokenizer
self.tokenizer_name = tokenizer_name
self.use_clip_tokenizer = tokenizer_name == 'clip'
self.task_type = task_type
self.append_eos = append_eos
self.task_info = task_info
self.target_set = target_set
self.seq_len = seq_len
self.min_seq_len = min_seq_len
self.cache_mode = cache_mode
self.cache_local_size = cache_local_size
self.cache_local_rank = cache_local_rank
self.ann_file = ann_file
self.encoding = encoding
self.test_mode = False
self.random_mask = random_mask
self.one_stream = one_stream
self.text_type_id = text_type_id
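# token used as the mask placeholder in the CLIP BPE vocabulary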
self.mask_bpe_word = "<|spe|>" if mask_bpe_word == 'spe' else '<|startoftext|>'
# load samples into memory
if cache_mode:
print('dataset cache mode is ON: local size: {}; local rank: {}'.format(cache_local_size,
cache_local_rank))
self.corpus, self.cursor = self.load_corpus()
@classmethod
def from_config(cls, cfg, stage: str = "train"):
version = getattr(cfg.DATASETS, 'VERSION', 'v1')
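# outside a SLURM job, always fall back to the small v1 annotation files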
if 'SLURM_PROCID' not in os.environ:
version = 'v1'
if version == 'v2':
ann_files = {
"train":
os.path.join(cfg.DATALOADER.ANNO_FOLDER, "bookswiki_v2.txt")
if comm.get_world_size() > 1 else os.path.join(
cfg.DATALOADER.ANNO_FOLDER, "bookswiki_v2-1000.doc"),
"val":
os.path.join(cfg.DATALOADER.ANNO_FOLDER, "bookswiki_v2-1000.doc")
}
elif version == 'v3':
ann_files = {
"train":
os.path.join(cfg.DATALOADER.ANNO_FOLDER, "bookswikiopen.txt")
if comm.get_world_size() > 1 else os.path.join(
cfg.DATALOADER.ANNO_FOLDER, "bookswiki_v2-1000.doc"),
"val":
os.path.join(cfg.DATALOADER.ANNO_FOLDER, "bookswiki_v2-1000.doc")
}
else:
ann_files = {
"train": os.path.join(cfg.DATALOADER.ANNO_FOLDER, "bookswiki.doc") if comm.get_world_size() > 1 else
os.path.join(cfg.DATALOADER.ANNO_FOLDER, "bookswiki-1000.doc"),
"val": os.path.join(cfg.DATALOADER.ANNO_FOLDER, "bookswiki-1000.doc")
}
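# per-task bookkeeping that is deep-copied into every returned sample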
task_info = {
'task_type' : cfg.DATASETS.TASK_TYPE,
'dataset_name' : cfg.DATASETS.DATASET_NAME,
'batch_size' : cfg.DATALOADER.TRAIN_BATCH_SIZE if stage == 'train' else cfg.DATALOADER.TEST_BATCH_SIZE,
'sampling_weight': cfg.DATALOADER.SAMPLING_WEIGHT
}
ret = {
"version" : version,
"stage" : stage,
"ann_file" : ann_files[stage],
"seq_len" : cfg.MODEL.MAX_SEQ_LEN,
"min_seq_len" : cfg.MODEL.MAX_SEQ_LEN,
"cache_mode" : cfg.DATALOADER.CACHE_MODE,
"append_eos" : cfg.DATALOADER.APPEND_EOS,
"cache_local_rank": comm.get_local_rank(),
"cache_local_size": comm.get_local_size(),
"one_stream" : cfg.DATALOADER.ONE_STREAM,
"task_type" : cfg.DATASETS.TASK_TYPE,
"random_mask" : getattr(cfg.DATALOADER, 'RANDOM_MASK', False),
"text_type_id" : getattr(cfg.DATALOADER, 'TYPE_EMBEDDING_ID', 0),
"mask_bpe_word" : getattr(cfg.DATALOADER, 'MASK_BPE_WORD', 'spe'),
"task_info" : task_info,
"target_set" : cfg.DATASETS.TARGET_SET
}
ret['tokenizer'] = ClipTokenizer()
ret['tokenizer_name'] = "clip"
return ret
@classmethod
def add_config(cls, cfg):
cfg.DATALOADER.SAMPLER = "NodeDistributed"
cfg.DATALOADER.CACHE_MODE = True
cfg.DATALOADER.SEQ_PER_SAMPLE = 256
cfg.DATALOADER.MIN_SEQ_PER_SAMPLE = 256
cfg.DATALOADER.APPEND_EOS = True
def load_corpus(self):
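# Each local rank reads and caches only its own contiguous block of lines,
# so a single node never holds the whole corpus in memory. The raw bytes
# plus per-line byte offsets ('cursor') are pickled for reuse across runs.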
if 'SLURM_PROCID' in os.environ:
self.cache_local_size = 8 # for convenience
cache_path = os.path.dirname(self.ann_file)
if self.version == 'v2':
cache_filename = 'cache/cache_block' + os.path.basename(self.ann_file).replace('.', "_") + "_" + str(self.cache_local_rank) + "_" + str(self.cache_local_size) + '.pkl'
elif self.version == 'v3':
cache_filename = 'cache_v3/cache_block_books_wiki_openweb' + "_" + str(self.cache_local_rank) + "_" + str(self.cache_local_size) + '.pkl'
else:
cache_filename = 'cache_block' + os.path.basename(self.ann_file).replace('.', "_") + "_" + str(self.cache_local_rank) + "_" + str(self.cache_local_size) + '.pkl'
cache_file = os.path.join(cache_path, cache_filename)
if not os.path.exists(cache_file):
if self.version == 'v3':
raise NotImplementedError
# [HACK] we hard code the corpus length
if 'SLURM_PROCID' in os.environ:
if self.version == 'v2':
self.file_len = 244208263
block_size = (self.file_len + self.cache_local_size - 1)// self.cache_local_size
block_start = block_size * self.cache_local_rank
block_end = block_size * (self.cache_local_rank + 1) if self.cache_local_rank + 1 < self.cache_local_size else self.file_len
else:
block_start = self.cache_local_rank * 13000000
block_end = ( self.cache_local_rank + 1 ) * 13000000
else:
block_start = 0
block_end = 1000
count = 0
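# 'corpus' holds the concatenated UTF-8 bytes of the kept lines;
# 'cursor' records the starting byte offset of each line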
corpus = bytearray()
cursor = []
c_ = 0
i_ = 0
for ann_file in self.ann_file.split('+'):
with open(ann_file, 'r', encoding=self.encoding) as f:
for l in f:
l = l.strip()
if l != '':
# if i_ % self.cache_local_size != self.cache_local_rank:
if i_< block_start or i_ >= block_end:
# cursor.append(c_)
i_ += 1
continue
l = l.encode()
corpus += l
cursor.append(c_)
c_ += len(l)
i_ += 1
count += 1
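# final sentinel offset so get_line() can slice corpus[cursor[i]:cursor[i+1]]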
cursor.append(len(corpus))
cursor = np.array(cursor).astype(np.int64, copy=False)
pickle.dump({
"corpus": corpus,
"cursor": cursor,
"count": count,
}, open(cache_file, "wb"), protocol=4)
else:
cachedata = pickle.load(open(cache_file, "rb"))
corpus, cursor, count = cachedata['corpus'], cachedata['cursor'], cachedata['count']
print("BooksWiki info: rank {} has {} sentences".format(self.cache_local_rank, count))
return corpus, cursor
def get_line(self, index):
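# decode the index-th cached line from its byte offsets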
return self.corpus[self.cursor[index]:self.cursor[index+1]].decode()
@property
def data_names(self):
return ['text', 'mlm_labels']
def __len__(self):
return len(self.cursor) - 1
def __getitem__(self, item):
# def __call__(self, item):
raw = self.get_line(item)
# tokenize
if self.use_clip_tokenizer:
tokens = self.tokenizer.basic_tokenize(raw)
if len(tokens) > 0 and self.append_eos:
tokens.append('<|endoftext|>')
else:
tokens = self.tokenizer.basic_tokenizer.tokenize(raw)
# add more tokens if len(tokens) < min_len
_cur = (item + 1) % (len(self.cursor) - 1)
while len(tokens) < self.min_seq_len:
if self.use_clip_tokenizer:
_cur_tokens = self.tokenizer.basic_tokenize(self.get_line(_cur))
if len(_cur_tokens) > 0 and self.append_eos:
_cur_tokens.append('<|endoftext|>')
else:
_cur_tokens = self.tokenizer.basic_tokenizer.tokenize(self.get_line(_cur))
tokens.extend(_cur_tokens)
_cur = (_cur + 1) % (len(self.cursor) - 1)
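# 'text_mlm': apply whole-word masking; 'caption': BPE-encode every token and
# build a parallel sequence of mask-token ids (used as the second input stream below)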
if self.task_type == 'text_mlm':
tokens, mlm_labels = self.random_word_wwm(tokens)
elif self.task_type == 'caption':
tokens_tmp = []
for token in tokens:
tokens_tmp.extend(self.tokenizer.encode_basic_tokenized_token(token))
tokens = tokens_tmp
mlm_labels = self.tokenizer.encode(
self.mask_bpe_word) * len(tokens)
if self.use_clip_tokenizer:
ids = tokens
else:
# append [SEP] for the BERT-style tokenizer (no [CLS] is added here)
tokens = tokens + ['[SEP]']
mlm_labels = mlm_labels + [-1]
# convert token to its vocab id
ids = self.tokenizer.convert_tokens_to_ids(tokens)
# truncate
if len(ids) > self.seq_len:
ids = ids[:(self.seq_len-1)] + [ids[-1]]
mlm_labels = mlm_labels[:(self.seq_len-1)] + [mlm_labels[-1]]
elif len(ids) < self.seq_len:
# compute the pad length before extending ids, so mlm_labels is padded too
pad_len = self.seq_len - len(ids)
ids = ids + [0] * pad_len
mlm_labels = mlm_labels + [-1] * pad_len
if self.task_type == 'text_mlm':
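# single input stream of (partially masked) token ids; per-position label ids
# (-1 = not predicted) go into 'target_idx'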
ret = {
'input_sample': [{
'data' : [np.array(ids, dtype=np.int64)],
'invalid_mask': None,
'modality' : 'text',
'data_type': 'input',
'sample_info' : {
'seq_length': len(ids)
}
}],
'target_sample': [],
'target_idx' : [np.array(mlm_labels, dtype=np.int64)],
'target_set' : copy.deepcopy(self.target_set),
'task_info' : copy.deepcopy(self.task_info)
}
elif self.task_type == 'caption':
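# two input streams: the original token ids and the mask-id sequence; the
# prediction target is the original ids themselves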
source = np.array(ids, dtype=np.int64)
source2 = np.array(mlm_labels, dtype=np.int64)
ret = {
'input_sample': [{
'data': [source, source2],
'invalid_mask': None,
'modality': 'text',
'data_type': 'input',
'sample_info': {}
}],
'target_sample': [],
'target_idx': [np.array(ids, dtype=np.int64)],
'target_set' : copy.deepcopy(self.target_set),
'task_info' : copy.deepcopy(self.task_info)
}
dict_as_tensor(ret)
return ret
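# BERT-style whole-word masking: each basic token is masked/replaced/kept as a
# unit (80/10/10 within the 15% masking probability) and the original
# (sub-)token ids are recorded as labels; unmasked positions get label -1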
def random_word_wwm(self, tokens):
output_tokens = []
output_label = []
for i, token in enumerate(tokens):
if self.use_clip_tokenizer:
sub_tokens = self.tokenizer.encode_basic_tokenized_token(token)
else:
sub_tokens = self.tokenizer.wordpiece_tokenizer.tokenize(token)
prob = random.random()
# mask token with 15% probability
if prob < 0.15:
prob /= 0.15
# 80% randomly change token to mask token
if prob < 0.8:
for sub_token in sub_tokens:
if self.use_clip_tokenizer:
output_tokens.append(
self.tokenizer.encoder[self.mask_bpe_word])
else:
output_tokens.append("[MASK]")
# 10% randomly change token to random token
elif prob < 0.9:
for sub_token in sub_tokens:
if self.use_clip_tokenizer:
output_tokens.append(random.choice(list(range(len(self.tokenizer.encoder)))))
else:
output_tokens.append(random.choice(list(self.tokenizer.vocab.keys())))
# -> remaining 10%: keep the current token unchanged
else:
for sub_token in sub_tokens:
output_tokens.append(sub_token)
# record the original (sub-)token ids as labels; these positions will be predicted
for sub_token in sub_tokens:
if self.use_clip_tokenizer:
output_label.append(sub_token)
else:
try:
output_label.append(self.tokenizer.vocab[sub_token])
except KeyError:
# For unknown words (should not occur with BPE vocab)
output_label.append(self.tokenizer.vocab["[UNK]"])
logging.warning("Cannot find sub_token '{}' in vocab. Using [UNK] instead".format(sub_token))
else:
for sub_token in sub_tokens:
# no masking token (will be ignored by loss function later)
output_tokens.append(sub_token)
output_label.append(-1)
return output_tokens, output_label
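# Minimal usage sketch (hypothetical paths and target-set name, not a verified
# recipe; assumes a local line-per-sentence annotation file exists):
#
#   dataset = GeneralCorpusDataset(
#       ann_file='./data/bookswiki-1000.doc', stage='train',
#       tokenizer=ClipTokenizer(), tokenizer_name='clip',
#       seq_len=128, min_seq_len=128, task_type='text_mlm',
#       task_info={'task_type': 'text_mlm', 'dataset_name': 'bookswiki',
#                  'batch_size': 8, 'sampling_weight': 1.0},
#       target_set=['Vocab_Word'])
#   sample = dataset[0]  # dict with 'input_sample', 'target_idx', 'target_set', 'task_info'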