File size: 1,097 Bytes
c09bcc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from convert_images import convert_heic_to_jpeg_and_remove, convert_png_to_jpeg
from generate_class_images import generate_class_images
from model_loader import load_models
from train import parse_args, training_function
from inference import generate_images

def main():
    """Prepare the data, optionally generate class images, train, and save.

    Steps, in order:

    1. Parse the command-line arguments and load the base model components
       (text encoder, VAE, UNet, tokenizer) via ``load_models``.
    2. Run the HEIC and PNG conversion helpers on ``args.instance_data_dir``.
    3. When ``args.with_prior_preservation`` is set, call
       ``generate_class_images`` with a pipeline loaded from the base
       checkpoint, writing into ``args.class_data_dir``.
    4. Call ``training_function`` with the loaded components.
    5. Assemble a ``StableDiffusionPipeline`` around those components and
       save it to ``./output/``.
    """
    # Function-scope import: the original body referenced
    # ``StableDiffusionPipeline`` without importing it anywhere, which raised
    # NameError as soon as the name was evaluated.
    from diffusers import StableDiffusionPipeline

    args = parse_args()
    base_model = args.pretrained_model_name_or_path
    text_encoder, vae, unet, tokenizer = load_models(base_model)

    # Convert images (the PNG helper reads from and writes to the same folder).
    convert_heic_to_jpeg_and_remove(args.instance_data_dir)
    convert_png_to_jpeg(args.instance_data_dir, args.instance_data_dir)

    # Generate class images if needed
    if args.with_prior_preservation:
        class_pipeline = StableDiffusionPipeline.from_pretrained(base_model)
        generate_class_images(
            pipeline=class_pipeline,
            class_prompt=args.class_prompt,
            num_class_images=100,  # Adjust as needed
            class_images_dir=args.class_data_dir,
        )

    # Train the model
    training_function(args, text_encoder, vae, unet, tokenizer)

    # The original code called ``pipeline.save_pretrained`` on a name that was
    # never assigned. Build the pipeline from the base checkpoint, overriding
    # its components with the objects that were handed to training.
    # NOTE(review): assumes training_function updates these components in
    # place and leaves them unwrapped (no DDP/accelerate wrapper) — confirm
    # against train.py.
    pipeline = StableDiffusionPipeline.from_pretrained(
        base_model,
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
    )
    pipeline.save_pretrained("./output/")
# Run the workflow only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()