import contextlib
import json
import logging
import os
import pdb
import re
import traceback
import uuid

import numpy as np
import torch
import yaml
from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode

from .processor.audio_processor import AudioProcessor
from .processor.image_processor import ImageProcessor
from .utils import draw_data

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class BaseDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        cfg_path,
        tokenizer,
        image_size=448,
        image_token_length=1024,
        max_padding_length=32768,
        variable_length=False,
        output_dir="",
        add_task_symbol=True,
        training_args=None,
        shift_token=False,
        create_position_ids=True,
        create_attention_mask=True,
        create_attention_mask_2d=False,
        create_loss_mask=False,
        max_num_frame=8,
        max_fps=1,
        reset_position_ids=False,
        reset_attention_mask=False,
        min_patch_grid=1,
        max_patch_grid=6,
        process_type="anyres",
        normalize_type="imagenet",
        seed=42,
        cross_dataset_joint=False,
        dataset_joint=True,
        audio_tokenizer_type=None,
        audio_tokenizer_path=None,
        text_audio_interval_ratio=None,
        use_megatron=True,
    ):
        super(BaseDataset, self).__init__()

        self.cfg_path = cfg_path
        with open(self.cfg_path, "r", encoding="utf8") as cfg_file:
            cfg_data = cfg_file.read()
        self.cfg = yaml.load(cfg_data, Loader=yaml.CLoader)
        logger.info(f"cfg {self.cfg}")

        self.tokenizer = tokenizer
        self.max_padding_length = max_padding_length
        self.variable_length = variable_length
        self.output_dir = output_dir
        self.training_args = training_args
        self.shift_token = shift_token
        self.create_position_ids = create_position_ids
        self.create_attention_mask = create_attention_mask
        self.create_attention_mask_2d = create_attention_mask_2d
        self.create_loss_mask = create_loss_mask
        self.max_num_frame = max_num_frame
        self.max_fps = max_fps
        self.reset_position_ids = reset_position_ids
        self.reset_attention_mask = reset_attention_mask
        self.seed = seed
        self.cross_dataset_joint = cross_dataset_joint
        self.dataset_joint = dataset_joint

        self.image_size = image_size
        self.image_token_length = image_token_length

        self.do_dataset_format = self.cfg.get("do_dataset_format", False)
        self.do_dataset_cast = self.cfg.get("do_dataset_cast", False)
        self.xlsx_sample_num = self.cfg.get("xlsx_sample_num", 5)

        self.processor = {}
        self.processor["image"] = ImageProcessor(
            process_type,
            image_size=self.image_size,
            normalize_type=normalize_type,
            min_patch_grid=min_patch_grid,
            max_patch_grid=max_patch_grid,
        )
        self.processor["audio"] = AudioProcessor(
            audio_tokenizer_path=audio_tokenizer_path,
            audio_tokenizer_type=audio_tokenizer_type,
            text_audio_interval_ratio=text_audio_interval_ratio,
        )

        if use_megatron:
            self.load_data()
        else:
            with main_process_first(local=True, desc="Loading data"):
                self.load_data()

        self.processed_samples = 0
        self.unjoint_samples = 0
        self.joint_samples = 0
        self.skip_samples = 0

    def load_data(self):
        from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset

        raw_data = None
        sampled_data = {}
        source_idx = 0
        for data_name, data_info in self.cfg["dataset"].items():
            data_ratio = data_info.get("ratio", 1)
            data_num = data_info.get("num", 999999999)
            if data_ratio == 0:
                continue
            if data_num == 0:
                continue

            for data_idx, data_path in enumerate(data_info["data_paths"]):
                if not os.path.isfile(data_path) and not os.path.isdir(data_path):
                    logger.warning(f"Data file not found {data_path}")
                    continue

                this_data = load_json(data_path, self.output_dir)
                # this_data = load_data_one(data_path, self.output_dir)
                if this_data is None:
logger.warning(f"Failed to load {data_path}") continue # print(f"this_data {this_data}") column_names = list(this_data.features) if "id" in column_names: this_data = this_data.remove_columns("id") # sources = [data_path] * len(this_data) sources = [source_idx] * len(this_data) source_idx += 1 # sources = [data_name] * len(this_data) this_data = this_data.add_column("source", sources) if "images" not in column_names: # images = [[]] * len(this_data) images = [None] * len(this_data) this_data = this_data.add_column("images", images) if "videos" not in column_names: # videos = [[]] * len(this_data) videos = [None] * len(this_data) this_data = this_data.add_column("videos", videos) if "audios" not in column_names: # videos = [[]] * len(this_data) audios = [None] * len(this_data) this_data = this_data.add_column("audios", videos) if False: column_names = list(this_data.features) this_data = this_data.map( format_function_general, batched=True, batch_size=2560, num_proc=1, # batch_size=1, # num_proc=1, remove_columns=column_names, keep_in_memory=False, desc="Running format on dataset", ) this_data = this_data.shuffle(seed=self.seed) # this_data = this_data.flatten_indices() this_data = this_data.shuffle(seed=self.seed) # this_data = this_data.flatten_indices() data_ratio = float(data_ratio) total_num = len(this_data) used_num = min(int(total_num * data_ratio), data_num) logger.info(f"total_num {total_num}") logger.info(f"data_ratio {data_ratio}") logger.info(f"data_num {data_num}") logger.info(f"used_num {used_num}") indices = [x % total_num for x in range(used_num)] this_data = this_data.select(indices) if raw_data is None: raw_data = this_data else: if self.do_dataset_cast: this_data = this_data.cast(raw_data.features) raw_data = concatenate_datasets([raw_data, this_data]) sampled_data[data_path] = {} sampled_data[data_path]["data"] = this_data.select( range(min(self.xlsx_sample_num, used_num)) ) sampled_data[data_path]["total_num"] = total_num sampled_data[data_path]["used_num"] = used_num logger.info(f"this_data {this_data}") logger.info(f"raw_data {raw_data}") # logger.info(f"raw_data {raw_data[0]}") # logger.info(f"raw_data {raw_data[-1]}") logger.info(f"Successful load {data_path}") raw_data = raw_data.shuffle(seed=self.seed) # raw_data = raw_data.flatten_indices() raw_data = raw_data.shuffle(seed=self.seed) # raw_data = raw_data.flatten_indices() self.raw_data = raw_data if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: output_xlsx = os.path.basename(self.cfg_path).replace("yaml", "xlsx") output_xlsx = os.path.join(self.output_dir, output_xlsx) logger.info(f"output_xlsx {output_xlsx}") draw_data( sampled_data, output_xlsx, tokenizer=self.tokenizer, image_processor=self.processor["image"], ) logger.info(f"raw_data {raw_data}") if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: logger.info(f"raw_data {raw_data[:10]}") logger.info(f"raw_data {raw_data[-10:]}") def __len__(self): return len(self.raw_data) def format_function_general(examples): messages = [x for x in examples["messages"]] if "images" in examples: images = [x for x in examples["images"]] else: images = [None for _ in messages] if "videos" in examples: videos = [x for x in examples["videos"]] else: videos = [None for _ in messages] if "audios" in examples: audios = [x for x in examples["audios"]] else: audios = [None for _ in messages] return { "messages": messages, "images": images, "videos": videos, "audios": audios, } def load_json_A(data_file): from datasets import 
    from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset

    with open(data_file, "r") as f:
        raw_data = json.load(f)
    this_data = Dataset.from_list(raw_data)
    return this_data


def load_json_B(data_file):
    from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset

    this_data = load_dataset("json", data_files=data_file, keep_in_memory=False)
    return this_data["train"]


def load_json_C(data_file):
    from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset

    raw_data = []
    with open(data_file, "r") as f:
        for line in f.readlines():
            d = json.loads(line)
            # raw_data.append({"conversations": d["conversations"], "id": d["id"]})
            if "conversations" in d:
                raw_data.append({"conversations": d["conversations"]})
            if "messages" in d:
                raw_data.append({"messages": d["messages"]})
    this_data = Dataset.from_list(raw_data)
    return this_data


def load_json(data_file, output_dir):
    for func in [load_json_B, load_json_A, load_json_C]:
        try:
            this_data = func(data_file)
            return this_data
        except Exception as error:
            with open(os.path.join(output_dir, "data_error.log"), "a") as f:
                print("-" * 100, file=f)
                # print(error, file=f)
                print(traceback.format_exc(), file=f)
            continue
    return None


def load_data_one(data_file, output_dir):
    if data_file.endswith("json") or data_file.endswith("jsonl"):
        return load_json(data_file, output_dir)

    from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset

    this_data = load_dataset(data_file, keep_in_memory=False)
    return this_data["train"]


@contextlib.contextmanager
def main_process_first(local=True, desc="work"):
    if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
        if local:
            rank = int(os.environ["LOCAL_RANK"])
        else:
            rank = torch.distributed.get_rank()
        is_main_process = rank == 0
        try:
            if not is_main_process:
                torch.distributed.barrier()
            yield
        finally:
            if is_main_process:
                torch.distributed.barrier()
    else:
        yield
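

# Minimal usage sketch (illustrative only). The tokenizer name, cfg path, and
# output directory below are placeholders, not assets shipped with this module.
# Based on load_data() above, the YAML config is expected to contain a top-level
# `dataset` mapping whose entries provide `data_paths` plus optional `ratio`/`num`:
#
#   dataset:
#     my_chat_data:
#       ratio: 1.0
#       num: 100000
#       data_paths:
#         - /path/to/data.jsonl
#
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("path/to/tokenizer")  # placeholder
    dataset = BaseDataset(
        cfg_path="configs/data.yaml",  # placeholder path to a YAML like the one above
        tokenizer=tokenizer,
        output_dir="./outputs",
        use_megatron=False,  # single-process run: main_process_first() just yields
    )
    print(len(dataset))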