import sys
import traceback

from finetrainers import BaseArgs, SFTTrainer, TrainingType, get_logger
from finetrainers.config import _get_model_specifiction_cls
from finetrainers.trainer.sft_trainer.config import SFTFullRankConfig, SFTLowRankConfig


logger = get_logger()


def main():
    try:
        import multiprocessing

        multiprocessing.set_start_method("fork")
    except Exception as e:
        logger.error(
            f'Failed to set multiprocessing start method to "fork". This can lead to poor performance, high memory usage, or crashes. '
            f"See: https://pytorch.org/docs/stable/notes/multiprocessing.html\n"
            f"Error: {e}"
        )

    try:
        args = BaseArgs()

        # Split any space-joined tokens so "--training_type lora" passed as a
        # single argument is still found.
        argv = [y.strip() for x in sys.argv for y in x.split()]
        # list.index() raises ValueError when the flag is absent (it never
        # returns -1), so check membership first instead of comparing to -1.
        if "--training_type" not in argv:
            raise ValueError("Training type not provided in command line arguments.")
        training_type_index = argv.index("--training_type")
        if training_type_index + 1 >= len(argv):
            raise ValueError("No value provided for --training_type.")
        training_type = argv[training_type_index + 1]

        # Pick the training-type-specific config used to extend the base args.
        if training_type == TrainingType.LORA:
            training_cls = SFTLowRankConfig
        elif training_type == TrainingType.FULL_FINETUNE:
            training_cls = SFTFullRankConfig
        else:
            raise ValueError(f"Training type {training_type} not supported.")

        training_config = training_cls()
        args.extend_args(training_config.add_args, training_config.map_args, training_config.validate_args)
        args = args.parse_args()

        model_specification_cls = _get_model_specifiction_cls(args.model_name, args.training_type)
        model_specification = model_specification_cls(
            pretrained_model_name_or_path=args.pretrained_model_name_or_path,
            tokenizer_id=args.tokenizer_id,
            tokenizer_2_id=args.tokenizer_2_id,
            tokenizer_3_id=args.tokenizer_3_id,
            text_encoder_id=args.text_encoder_id,
            text_encoder_2_id=args.text_encoder_2_id,
            text_encoder_3_id=args.text_encoder_3_id,
            transformer_id=args.transformer_id,
            vae_id=args.vae_id,
            text_encoder_dtype=args.text_encoder_dtype,
            text_encoder_2_dtype=args.text_encoder_2_dtype,
            text_encoder_3_dtype=args.text_encoder_3_dtype,
            transformer_dtype=args.transformer_dtype,
            vae_dtype=args.vae_dtype,
            revision=args.revision,
            cache_dir=args.cache_dir,
        )

        if args.training_type in [TrainingType.LORA, TrainingType.FULL_FINETUNE]:
            trainer = SFTTrainer(args, model_specification)
        else:
            raise ValueError(f"Training type {args.training_type} not supported.")

        trainer.run()
    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt. Exiting...")
    except Exception as e:
        logger.error(f"An error occurred during training: {e}")
        logger.error(traceback.format_exc())


if __name__ == "__main__":
    main()
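
# Example invocation (a sketch, not an official recipe). Only "--training_type"
# is inspected manually above; every other flag is parsed by BaseArgs and the
# rank-specific config. The launcher, model name, and checkpoint below are
# illustrative assumptions, and "lora" assumes TrainingType.LORA compares equal
# to that string:
#
#   torchrun --standalone --nproc_per_node=1 train.py \
#       --model_name ltx_video \
#       --pretrained_model_name_or_path Lightricks/LTX-Video \
#       --training_type lora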