# coding=utf8
import argparse
import os
import sys
from concurrent.futures import ProcessPoolExecutor
def _generate_cache_arrow(index, ds, path): | |
print('saving dataset shard {}'.format(index)) | |
ds.save_to_disk(os.path.join(path, 'part_{}'.format(index))) | |
return 'saving dataset shard {} done'.format(index) | |
def generate_arrow_cache(ds, args) -> None:
    """Split a dataset into train/test and cache both splits to disk.

    Reads the raw data (e.g. wudao_180g) or its tokenized form, performs a
    train/test split shuffled with seed 42, writes the train split as
    ``args.saved_data_shards`` arrow shards under
    ``args.saved_train_data_path`` using a process pool, then saves the
    test split to ``args.saved_test_data_path``.

    Args:
        ds: a HuggingFace ``datasets.Dataset`` (supports
            ``train_test_split`` and ``shard``) — assumed, confirm at caller.
        args: parsed CLI namespace; uses ``train_split_size``,
            ``preprocessing_num_workers``, ``saved_data_shards``,
            ``saved_train_data_path`` and ``saved_test_data_path``.
    """
    ds = ds.train_test_split(train_size=args.train_split_size, seed=42)
    print(ds)
    num_shards = args.saved_data_shards
    # Context manager guarantees the workers are shut down even if a
    # submit() or a shard save raises (the original leaked the pool on error).
    with ProcessPoolExecutor(max_workers=args.preprocessing_num_workers) as pool:
        futures = [
            pool.submit(_generate_cache_arrow, i,
                        ds['train'].shard(num_shards, i),
                        args.saved_train_data_path)
            for i in range(num_shards)
        ]
        # Surface worker results/exceptions in submission order.
        for future in futures:
            print(future.result(), flush=True)
    ds['test'].save_to_disk(args.saved_test_data_path)
    print('done')
if __name__ == '__main__':
    # CLI for caching (optionally tokenized) wudao data as arrow shards.
    parser = argparse.ArgumentParser("Save data Task")
    parser.add_argument(
        '--new_vocab_path', type=str,
        default='/cognitive_comp/ganruyi/hf_models/t5_cn_small/sentencepiece_cn.model')
    parser.add_argument('--preprocessing_num_workers', default=30, type=int)
    parser.add_argument(
        '--train_data_path', type=str,
        default='/cognitive_comp/common_data/test_wudao_180g_mt5_tokenized/')
    parser.add_argument('--saved_data_shards', default=800, type=int)
    parser.add_argument('--saved_train_data_path', default=None, type=str)
    parser.add_argument('--saved_test_data_path', default=None, type=str)
    parser.add_argument('--max_seq_length', default=512, type=int)
    parser.add_argument('--train_split_size', default=0.999, type=float)
    parser.add_argument('--pretrained_model_path', default=None, type=str)
    parser.add_argument('--tokenizer_type', default='t5_tokenizer',
                        choices=['t5_tokenizer', 'bert_tokenizer'])
    parser.add_argument('--text_column_name', default='text')
    parser.add_argument('--remove_columns', nargs='+', default=[])
    # * Args for data preprocessing
    cli_args = parser.parse_args()

    # Project-local import: package root is three directories up.
    sys.path.append('../../../')
    from fengshen.data.t5_dataloader.t5_datasets import UnsuperviseT5Dataset

    dataset = UnsuperviseT5Dataset(cli_args.train_data_path, cli_args)
    print(dataset)
    generate_arrow_cache(dataset.data, args=cli_args)
    # Spot-check the first two examples and their decoded text.
    for idx in range(2):
        print(dataset.data[idx])
        print(dataset.tokenizer.decode(dataset.data[idx]['input_ids']))
    print(dataset.data)