from uniperceiver.functional import read_np, dict_as_tensor, boxes_to_locfeats
import random
import numpy as np
import copy
from torch.utils.data import Dataset
from uniperceiver.tokenization import ClipTokenizer
import logging
import os
from ..build import DATASETS_REGISTRY
from uniperceiver.config import configurable
import pickle
from uniperceiver.utils import comm

__all__ = ["GeneralCorpusDataset"]


@DATASETS_REGISTRY.register()
class GeneralCorpusDataset(Dataset):
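    """Text-only pretraining dataset over the BooksWiki corpus.

    Each local rank caches its shard of the corpus as a single bytearray plus a
    cursor array of sentence offsets (see ``load_corpus``). A sample is built by
    concatenating consecutive sentences until ``min_seq_len`` tokens are reached
    and is returned either with whole-word masking (``task_type == 'text_mlm'``)
    or as a language-modeling sample whose full token sequence is the prediction
    target (``task_type == 'caption'``).
    """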

    @configurable
    def __init__(self, ann_file, stage,
                 tokenizer, tokenizer_name,
                 seq_len=64, min_seq_len=64,
                 encoding="utf-8",
                 cache_mode=True, cache_local_rank=0, cache_local_size=1,
                 append_eos=False,
                 one_stream=False,
                 random_mask=False,
                 task_type=None,
                 text_type_id=0,
                 mask_bpe_word='spe',
                 version='v1',
                 task_info=None,
                 target_set=None,
                 **kwargs):
        assert cache_mode, "only cache mode is supported"
        assert task_type, "task_type must be specified"
        self.version = version
        self.stage = stage
        self.tokenizer = tokenizer
        self.tokenizer_name = tokenizer_name
        self.use_clip_tokenizer = tokenizer_name == 'clip'
        self.task_type = task_type
        self.append_eos = append_eos
        self.task_info = task_info
        self.target_set = target_set

        self.seq_len = seq_len
        self.min_seq_len = min_seq_len
        self.cache_mode = cache_mode
        self.cache_local_size = cache_local_size
        self.cache_local_rank = cache_local_rank

        self.ann_file = ann_file
        self.encoding = encoding
        self.test_mode = False
        self.random_mask = random_mask

        self.one_stream = one_stream

        self.text_type_id = text_type_id

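        # Mask placeholder token: '<|spe|>' is the dedicated special token,
        # otherwise the CLIP start-of-text token is reused.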
        self.mask_bpe_word = "<|spe|>" if mask_bpe_word == 'spe' else '<|startoftext|>'

        if cache_mode:
            print('dataset cache mode is ON: local size: {}; local rank: {}'.format(cache_local_size,
                                                                                    cache_local_rank))
        self.corpus, self.cursor = self.load_corpus()

    @classmethod
    def from_config(cls, cfg, stage: str = "train"):
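        """Build constructor kwargs from the global config.

        The annotation file is chosen from the corpus version (``cfg.DATASETS.VERSION``)
        and the world size: multi-process runs read the full corpus, while
        single-process runs fall back to the small 1000-document subset. A
        ``ClipTokenizer`` is always attached.
        """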
        version = getattr(cfg.DATASETS, 'VERSION', 'v1')
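        # When not launched under SLURM, fall back to the v1 corpus (the larger
        # versions are assumed to be available only on the cluster).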
        if 'SLURM_PROCID' not in os.environ:
            version = 'v1'
        if version == 'v2':
            ann_files = {
                "train":
                os.path.join(cfg.DATALOADER.ANNO_FOLDER, "bookswiki_v2.txt")
                if comm.get_world_size() > 1 else os.path.join(
                    cfg.DATALOADER.ANNO_FOLDER, "bookswiki_v2-1000.doc"),
                "val":
                os.path.join(cfg.DATALOADER.ANNO_FOLDER, "bookswiki_v2-1000.doc")
            }
        elif version == 'v3':
            ann_files = {
                "train":
                os.path.join(cfg.DATALOADER.ANNO_FOLDER, "bookswikiopen.txt")
                if comm.get_world_size() > 1 else os.path.join(
                    cfg.DATALOADER.ANNO_FOLDER, "bookswiki_v2-1000.doc"),
                "val":
                os.path.join(cfg.DATALOADER.ANNO_FOLDER, "bookswiki_v2-1000.doc")
            }
        else:
            ann_files = {
                "train": os.path.join(cfg.DATALOADER.ANNO_FOLDER, "bookswiki.doc") if comm.get_world_size() > 1 else
                os.path.join(cfg.DATALOADER.ANNO_FOLDER, "bookswiki-1000.doc"),
                "val": os.path.join(cfg.DATALOADER.ANNO_FOLDER, "bookswiki-1000.doc")
            }

        task_info = {
            'task_type'      : cfg.DATASETS.TASK_TYPE,
            'dataset_name'   : cfg.DATASETS.DATASET_NAME,
            'batch_size'     : cfg.DATALOADER.TRAIN_BATCH_SIZE if stage == 'train' else cfg.DATALOADER.TEST_BATCH_SIZE,
            'sampling_weight': cfg.DATALOADER.SAMPLING_WEIGHT
        }

        ret = {
            "version"         : version,
            "stage"           : stage,
            "ann_file"        : ann_files[stage],
            "seq_len"         : cfg.MODEL.MAX_SEQ_LEN,
            "min_seq_len"     : cfg.MODEL.MAX_SEQ_LEN,
            "cache_mode"      : cfg.DATALOADER.CACHE_MODE,
            "append_eos"      : cfg.DATALOADER.APPEND_EOS,
            "cache_local_rank": comm.get_local_rank(),
            "cache_local_size": comm.get_local_size(),
            "one_stream"      : cfg.DATALOADER.ONE_STREAM,
            "task_type"       : cfg.DATASETS.TASK_TYPE,
            "random_mask"     : getattr(cfg.DATALOADER, 'RANDOM_MASK', False),
            "text_type_id"    : getattr(cfg.DATALOADER, 'TYPE_EMBEDDING_ID', 0),
            "mask_bpe_word"   : getattr(cfg.DATALOADER, 'MASK_BPE_WORD', 'spe'),
            "task_info"       : task_info,
            "target_set"      : cfg.DATASETS.TARGET_SET
        }

        ret['tokenizer'] = ClipTokenizer()
        ret['tokenizer_name'] = "clip"

        return ret

    @classmethod
    def add_config(cls, cfg):
        cfg.DATALOADER.SAMPLER = "NodeDistributed"
        cfg.DATALOADER.CACHE_MODE = True
        cfg.DATALOADER.SEQ_PER_SAMPLE = 256
        cfg.DATALOADER.MIN_SEQ_PER_SAMPLE = 256
        cfg.DATALOADER.APPEND_EOS = True


    def load_corpus(self):
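        """Load this rank's shard of the corpus.

        The shard is stored as one bytearray plus a cursor array of sentence byte
        offsets and is pickled next to the annotation file, so later runs skip
        scanning the raw text. Under SLURM, ``cache_local_size`` is overridden to 8
        (an 8-GPU node appears to be assumed).
        """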
        if 'SLURM_PROCID' in os.environ:
            self.cache_local_size = 8
        cache_path = os.path.dirname(self.ann_file)
        if self.version == 'v2':
            cache_filename = 'cache/cache_block' + os.path.basename(self.ann_file).replace('.', "_") + "_" + str(self.cache_local_rank) + "_" + str(self.cache_local_size) + '.pkl'
        elif self.version == 'v3':
            cache_filename = 'cache_v3/cache_block_books_wiki_openweb' + "_" + str(self.cache_local_rank) + "_" + str(self.cache_local_size) + '.pkl'
        else:
            cache_filename = 'cache_block' + os.path.basename(self.ann_file).replace('.', "_") + "_" + str(self.cache_local_rank) + "_" + str(self.cache_local_size) + '.pkl'
        cache_file = os.path.join(cache_path, cache_filename)
        if not os.path.exists(cache_file):
            if self.version == 'v3':
                raise NotImplementedError

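            # Decide which line range of the raw corpus this rank is responsible for.
            # The v2 line count (244208263) and the 13M-lines-per-rank block for v1 are
            # hard-coded for the expected BooksWiki files; outside SLURM only the first
            # 1000 non-empty lines are read.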
            if 'SLURM_PROCID' in os.environ:
                if self.version == 'v2':
                    self.file_len = 244208263
                    block_size = (self.file_len + self.cache_local_size - 1) // self.cache_local_size
                    block_start = block_size * self.cache_local_rank
                    block_end = block_size * (
                        1 + self.cache_local_rank
                    ) if self.cache_local_rank + 1 < self.cache_local_size else self.file_len
                else:
                    block_start = self.cache_local_rank * 13000000
                    block_end = (self.cache_local_rank + 1) * 13000000
            else:
                block_start = 0
                block_end = 1000
            count = 0
            corpus = bytearray()
            cursor = []
            c_ = 0
            i_ = 0
            for ann_file in self.ann_file.split('+'):
                with open(ann_file, 'r', encoding=self.encoding) as f:
                    for l in f:
                        l = l.strip()
                        if l != '':
                            if i_ < block_start or i_ >= block_end:
                                i_ += 1
                                continue
                            l = l.encode()
                            corpus += l
                            cursor.append(c_)
                            c_ += len(l)
                            i_ += 1
                            count += 1
            cursor.append(len(corpus))
            cursor = np.array(cursor).astype(np.int64, copy=False)
            pickle.dump({
                "corpus": corpus,
                "cursor": cursor,
                "count": count,
            }, open(cache_file, "wb"), protocol=4)

        else:
            cachedata = pickle.load(open(cache_file, "rb"))
            corpus, cursor, count = cachedata['corpus'], cachedata['cursor'], cachedata['count']

        print("BooksWiki info: rank {} has {} sentences".format(self.cache_local_rank, count))

        return corpus, cursor


    def get_line(self, index):
        return self.corpus[self.cursor[index]:self.cursor[index + 1]].decode()

    @property
    def data_names(self):
        return ['text', 'mlm_labels']

    def __len__(self):
        return len(self.cursor) - 1

    def __getitem__(self, item):

        raw = self.get_line(item)
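
        # Word-level (basic) tokenization first; BPE / wordpiece encoding to ids
        # happens later.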
        if self.use_clip_tokenizer:
            tokens = self.tokenizer.basic_tokenize(raw)
            if len(tokens) > 0 and self.append_eos:
                tokens.append('<|endoftext|>')
        else:
            tokens = self.tokenizer.basic_tokenizer.tokenize(raw)
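
        # Concatenate the following sentences (wrapping around the shard) until the
        # sample reaches min_seq_len tokens.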
        _cur = (item + 1) % (len(self.cursor) - 1)
        while len(tokens) < self.min_seq_len:
            if self.use_clip_tokenizer:
                _cur_tokens = self.tokenizer.basic_tokenize(self.get_line(_cur))
                if len(_cur_tokens) > 0 and self.append_eos:
                    _cur_tokens.append('<|endoftext|>')
            else:
                _cur_tokens = self.tokenizer.basic_tokenizer.tokenize(self.get_line(_cur))
            tokens.extend(_cur_tokens)
            _cur = (_cur + 1) % (len(self.cursor) - 1)

        if self.task_type == 'text_mlm':
            tokens, mlm_labels = self.random_word_wwm(tokens)
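
        # For the 'caption' (language-modeling) task the input carries both the token
        # ids and a parallel sequence filled with the mask token; the ids themselves
        # become the prediction target below.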
        elif self.task_type == 'caption':
            tokens_tmp = []
            for token in tokens:
                tokens_tmp.extend(self.tokenizer.encode_basic_tokenized_token(token))
            tokens = tokens_tmp
            mlm_labels = self.tokenizer.encode(
                self.mask_bpe_word) * len(tokens)

        if self.use_clip_tokenizer:
            ids = tokens
        else:
            tokens = tokens + ['[SEP]']
            mlm_labels = mlm_labels + [-1]

            ids = self.tokenizer.convert_tokens_to_ids(tokens)

        # Truncate or pad to exactly seq_len; padded positions get id 0 and label -1.
        if len(ids) > self.seq_len:
            ids = ids[:(self.seq_len - 1)] + [ids[-1]]
            mlm_labels = mlm_labels[:(self.seq_len - 1)] + [mlm_labels[-1]]
        elif len(ids) < self.seq_len:
            # compute the pad length before extending ids, otherwise the labels are never padded
            pad_length = self.seq_len - len(ids)
            ids = ids + [0] * pad_length
            mlm_labels = mlm_labels + [-1] * pad_length

        if self.task_type == 'text_mlm':
            ret = {
                'input_sample': [{
                    'data'        : [np.array(ids, dtype=np.int64)],
                    'invalid_mask': None,
                    'modality'    : 'text',
                    'data_type'   : 'input',
                    'sample_info' : {
                        'seq_length': len(ids)
                    }
                }],
                'target_sample': [],
                'target_idx'   : [np.array(mlm_labels, dtype=np.int64)],
                'target_set'   : copy.deepcopy(self.target_set),
                'task_info'    : copy.deepcopy(self.task_info)
            }
        elif self.task_type == 'caption':
            source = np.array(ids, dtype=np.int64)
            source2 = np.array(mlm_labels, dtype=np.int64)

            ret = {
                'input_sample': [{
                    'data'        : [source, source2],
                    'invalid_mask': None,
                    'modality'    : 'text',
                    'data_type'   : 'input',
                    'sample_info' : {}
                }],
                'target_sample': [],
                'target_idx'   : [np.array(ids, dtype=np.int64)],
                'target_set'   : copy.deepcopy(self.target_set),
                'task_info'    : copy.deepcopy(self.task_info)
            }

        dict_as_tensor(ret)
        return ret

    def random_word_wwm(self, tokens):
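        """Apply BERT-style whole-word masking.

        Each basic token is selected with probability 0.15; all of its sub-tokens are
        then masked (80%), replaced by random tokens (10%), or kept (10%). Returns the
        possibly corrupted sub-token sequence and per-position labels, with -1 marking
        positions that are not predicted.
        """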
        output_tokens = []
        output_label = []

        for i, token in enumerate(tokens):
            if self.use_clip_tokenizer:
                sub_tokens = self.tokenizer.encode_basic_tokenized_token(token)
            else:
                sub_tokens = self.tokenizer.wordpiece_tokenizer.tokenize(token)
            prob = random.random()

            if prob < 0.15:
                prob /= 0.15

                if prob < 0.8:
                    for sub_token in sub_tokens:
                        if self.use_clip_tokenizer:
                            output_tokens.append(
                                self.tokenizer.encoder[self.mask_bpe_word])
                        else:
                            output_tokens.append("[MASK]")

                elif prob < 0.9:
                    for sub_token in sub_tokens:
                        if self.use_clip_tokenizer:
                            output_tokens.append(random.choice(list(range(len(self.tokenizer.encoder)))))
                        else:
                            output_tokens.append(random.choice(list(self.tokenizer.vocab.keys())))

                else:
                    for sub_token in sub_tokens:
                        output_tokens.append(sub_token)

                for sub_token in sub_tokens:
                    if self.use_clip_tokenizer:
                        output_label.append(sub_token)
                    else:
                        try:
                            output_label.append(self.tokenizer.vocab[sub_token])
                        except KeyError:
                            output_label.append(self.tokenizer.vocab["[UNK]"])
                            logging.warning("Cannot find sub_token '{}' in vocab. Using [UNK] instead".format(sub_token))
            else:
                for sub_token in sub_tokens:
                    output_tokens.append(sub_token)
                    output_label.append(-1)

        return output_tokens, output_label
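

# A minimal local usage sketch; the path, sequence lengths, task_info and target_set
# values below are assumptions and must be adapted to a real config / corpus file:
#
#   ds = GeneralCorpusDataset(
#       ann_file='/path/to/bookswiki-1000.doc', stage='train',
#       tokenizer=ClipTokenizer(), tokenizer_name='clip',
#       seq_len=128, min_seq_len=128, task_type='text_mlm',
#       task_info={'task_type': 'text_mlm'}, target_set=['Vocab_Word'])
#   sample = ds[0]   # dict with 'input_sample', 'target_idx', 'target_set', 'task_info'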