import logging
import os
from dataclasses import dataclass
from typing import Dict, List, Sequence, Union

import datasets
import torch
import transformers
from datasets import concatenate_datasets, load_dataset

from .data_collator import DataCollatorForSupervisedDataset, collate_fn_deepspeed
from .dataset_cosyvoice2 import CosyVoice2Dataset
from .dataset_deepseek import DeepSeekDataset
from .dataset_hunyuan import HunyuanDataset
from .dataset_llama3 import Llama3Dataset
from .dataset_mistral import MistralDataset
from .dataset_qwen2 import Qwen2Dataset

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def build_supervised_dataset_deepspeed(
    model_config,
    model_args,
    data_args,
    training_args,
    tokenizer,
    create_position_ids=True,
    create_loss_mask=False,
    shift_token=False,
):
    """Build the supervised train dataset and collator for a DeepSpeed run.

    Selects a dataset class from ``model_config.model_type`` (substring
    match for most families, exact match for hunyuan/mixtral), constructs
    it from the ``model_args`` / ``data_args`` / ``training_args`` fields,
    and returns ``dict(train=..., validation=None, data_collator=...)``.

    Raises:
        NotImplementedError: if ``model_type`` matches no known family.
    """
    logger.info("building dataset...")

    cfg_path = data_args.dataset_name
    max_padding_length = model_args.model_max_length
    output_dir = training_args.output_dir
    # prompt_format = model_args.prompt_format
    create_attention_mask = data_args.create_attention_mask
    create_attention_mask_2d = data_args.create_attention_mask_2d

    image_size = model_args.image_size
    image_token_length = model_args.image_token_length
    max_num_frame = model_args.max_num_frame
    max_fps = model_args.max_fps

    reset_position_ids = data_args.reset_position_ids
    reset_attention_mask = data_args.reset_attention_mask
    variable_length = data_args.variable_length

    min_patch_grid = model_args.min_patch_grid
    max_patch_grid = model_args.max_patch_grid
    process_type = model_args.vision_process_type
    normalize_type = model_args.vision_normalize_type

    audio_tokenizer_path = model_args.audio_tokenizer_path
    audio_tokenizer_type = model_args.audio_tokenizer_type
    text_audio_interval_ratio = model_args.text_audio_interval_ratio

    seed = training_args.seed
    cross_dataset_joint = data_args.cross_dataset_joint
    dataset_joint = data_args.dataset_joint

    # Fall back to "" so the substring tests below cannot raise
    # ``TypeError: argument of type 'NoneType' is not iterable`` when the
    # config has no ``model_type`` attribute.
    model_type = getattr(model_config, "model_type", None) or ""

    # Order matters: "long_vita" must be checked before the generic "qwen2"
    # substring, since a long_vita model_type may embed "qwen2".
    if "long_vita" in model_type:
        TrainDataset = Qwen2Dataset
    elif "cosyvoice2" in model_type:
        TrainDataset = CosyVoice2Dataset
    elif "qwen2" in model_type:
        TrainDataset = Qwen2Dataset
    elif model_type == "hunyuan":
        TrainDataset = HunyuanDataset
    elif model_type == "mixtral":
        # NOTE(review): original code referenced ``Llama2Dataset``, which is
        # never imported and would raise NameError at runtime. Mixtral is the
        # Mistral MoE family, so MistralDataset is the closest imported
        # class — confirm against the dataset implementations.
        TrainDataset = MistralDataset
    elif "llama" in model_type:
        TrainDataset = Llama3Dataset
    elif "deepseek" in model_type:
        TrainDataset = DeepSeekDataset
    else:
        raise NotImplementedError

    train_dataset = TrainDataset(
        cfg_path,
        tokenizer,
        image_size=image_size,
        image_token_length=image_token_length,
        max_padding_length=max_padding_length,
        variable_length=variable_length,
        output_dir=output_dir,
        training_args=None,
        shift_token=shift_token,
        create_position_ids=create_position_ids,
        create_attention_mask=create_attention_mask,
        create_attention_mask_2d=create_attention_mask_2d,
        create_loss_mask=create_loss_mask,
        max_num_frame=max_num_frame,
        max_fps=max_fps,
        reset_position_ids=reset_position_ids,
        reset_attention_mask=reset_attention_mask,
        min_patch_grid=min_patch_grid,
        max_patch_grid=max_patch_grid,
        process_type=process_type,
        normalize_type=normalize_type,
        seed=seed,
        cross_dataset_joint=cross_dataset_joint,
        dataset_joint=dataset_joint,
        audio_tokenizer_type=audio_tokenizer_type,
        audio_tokenizer_path=audio_tokenizer_path,
        text_audio_interval_ratio=text_audio_interval_ratio,
        use_megatron=False,
    )

    eval_dataset = None

    # data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    data_collator = collate_fn_deepspeed

    return dict(train=train_dataset, validation=eval_dataset, data_collator=data_collator)


def build_supervised_dataset_megatron(
    args,
    tokenizer,
    create_position_ids=True,
    create_loss_mask=False,
    shift_token=False,
):
    """Build the supervised train dataset for a Megatron run.

    Selects the dataset class from ``args.prompt_format`` and returns the
    3-tuple ``(train_dataset, None, None)`` (valid/test are not built).

    Raises:
        AssertionError: if ``args.data_path`` does not hold exactly one path.
        NotImplementedError: for the "mistral" format (explicitly
            unsupported) or any unrecognized format.
    """
    logger.info("building dataset...")

    assert len(args.data_path) == 1
    cfg_path = args.data_path[0]
    max_padding_length = args.max_padding_length
    output_dir = args.save
    prompt_format = args.prompt_format
    create_attention_mask = args.create_attention_mask_in_dataloader
    create_attention_mask_2d = args.create_attention_mask_in_dataloader
    # create_attention_mask=False
    # create_attention_mask_2d=True

    image_size = args.image_size
    image_token_length = args.image_token_length
    max_num_frame = args.max_num_frame
    max_fps = args.max_fps

    reset_position_ids = args.reset_position_ids
    reset_attention_mask = args.reset_attention_mask
    # reset_position_ids=True
    # reset_attention_mask=True

    min_patch_grid = args.min_patch_grid
    max_patch_grid = args.max_patch_grid
    process_type = args.vision_process_type
    normalize_type = args.vision_normalize_type

    seed = args.seed
    cross_dataset_joint = args.cross_dataset_joint
    dataset_joint = args.dataset_joint

    if "qwen2" in prompt_format:
        TrainDataset = Qwen2Dataset
    elif prompt_format == "mistral":
        # Deliberately unsupported; the assignment that followed the raise
        # in the original was unreachable dead code.
        raise NotImplementedError
        # TrainDataset = MistralDataset
    elif prompt_format == "llama3":
        TrainDataset = Llama3Dataset
    # BUGFIX: this branch was a standalone ``if`` in the original, so every
    # non-deepseek prompt_format (qwen2, llama3) fell through to the
    # trailing ``else`` and raised NotImplementedError after a dataset
    # class had already been selected. It must be part of the elif chain.
    elif "deepseek" in prompt_format:
        TrainDataset = DeepSeekDataset
    else:
        raise NotImplementedError

    train_dataset = TrainDataset(
        cfg_path,
        tokenizer,
        image_size=image_size,
        image_token_length=image_token_length,
        max_padding_length=max_padding_length,
        variable_length=False,
        output_dir=output_dir,
        training_args=None,
        shift_token=shift_token,
        create_position_ids=create_position_ids,
        create_attention_mask=create_attention_mask,
        create_attention_mask_2d=create_attention_mask_2d,
        create_loss_mask=create_loss_mask,
        max_num_frame=max_num_frame,
        max_fps=max_fps,
        reset_position_ids=reset_position_ids,
        reset_attention_mask=reset_attention_mask,
        min_patch_grid=min_patch_grid,
        max_patch_grid=max_patch_grid,
        process_type=process_type,
        normalize_type=normalize_type,
        seed=seed,
        cross_dataset_joint=cross_dataset_joint,
        dataset_joint=dataset_joint,
        use_megatron=True,
    )

    eval_dataset = None

    return train_dataset, None, None