"""Entry point for fine-tuning a Stable Diffusion model on instance images.

Steps: normalise the instance images to JPEG, optionally generate class
images for prior preservation, train the loaded models, then save the
resulting pipeline to ``OUTPUT_DIR``.
"""

from diffusers import StableDiffusionPipeline

from convert_images import convert_heic_to_jpeg_and_remove, convert_png_to_jpeg
from generate_class_images import generate_class_images
from model_loader import load_models
from train import parse_args, training_function
from inference import generate_images  # noqa: F401 -- unused here; retained pending review

# Number of class images sampled when prior preservation is enabled.
NUM_CLASS_IMAGES = 100
# Directory the fine-tuned pipeline is written to.
OUTPUT_DIR = "./output/"


def main():
    """Run preprocessing, optional class-image generation, training and save.

    All configuration is read from ``train.parse_args()``. The attributes
    used here are ``pretrained_model_name_or_path``, ``instance_data_dir``,
    ``with_prior_preservation``, ``class_prompt`` and ``class_data_dir``.
    """
    args = parse_args()

    # Normalise the instance images inside instance_data_dir. Judging by the
    # helper names, HEIC files are converted to JPEG and the originals
    # removed, then PNGs are converted to JPEG in the same directory.
    convert_heic_to_jpeg_and_remove(args.instance_data_dir)
    convert_png_to_jpeg(args.instance_data_dir, args.instance_data_dir)

    if args.with_prior_preservation:
        # A separate pipeline is loaded only to sample class images. We drop
        # our reference before loading the training models so this script
        # does not hold both model sets at once.
        class_pipeline = StableDiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path
        )
        generate_class_images(
            pipeline=class_pipeline,
            class_prompt=args.class_prompt,
            num_class_images=NUM_CLASS_IMAGES,  # Adjust as needed
            class_images_dir=args.class_data_dir,
        )
        del class_pipeline

    text_encoder, vae, unet, tokenizer = load_models(
        args.pretrained_model_name_or_path
    )

    training_function(args, text_encoder, vae, unet, tokenizer)

    # The original code saved an undefined `pipeline`. This fix assembles one
    # from the trained components, overriding the pretrained ones.
    # NOTE(review): assumes training_function updates these modules in place
    # (its return value is not used). Confirm against train.py.
    pipeline = StableDiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
    )
    pipeline.save_pretrained(OUTPUT_DIR)


if __name__ == "__main__":
    main()