|
from convert_images import convert_heic_to_jpeg_and_remove, convert_png_to_jpeg
|
|
from generate_class_images import generate_class_images
|
|
from model_loader import load_models
|
|
from train import parse_args, training_function
|
|
from inference import generate_images
|
|
|
|
def main():
    """Run the fine-tuning workflow end to end.

    Steps, in order:
      1. Parse command-line arguments with ``parse_args``.
      2. Load the text encoder, VAE, UNet and tokenizer from
         ``args.pretrained_model_name_or_path`` via ``load_models``.
      3. Pre-process the images in ``args.instance_data_dir`` with
         ``convert_heic_to_jpeg_and_remove`` and ``convert_png_to_jpeg``
         (by their names, these convert HEIC/PNG files to JPEG in place).
      4. If ``args.with_prior_preservation`` is set, build a
         ``StableDiffusionPipeline`` from the pretrained weights and use it to
         generate class images for ``args.class_prompt`` into
         ``args.class_data_dir``.
      5. Call ``training_function`` with the loaded components, then build a
         pipeline from those components and save it to ``./output/``.
    """
    # Local import: the pipeline class is needed both for class-image
    # generation and for saving the final model, and is not imported at
    # module level in this script.
    from diffusers import StableDiffusionPipeline

    args = parse_args()

    text_encoder, vae, unet, tokenizer = load_models(args.pretrained_model_name_or_path)

    # Pre-process instance images before training reads them. Source and
    # destination directories are the same, so conversion happens in place.
    convert_heic_to_jpeg_and_remove(args.instance_data_dir)
    convert_png_to_jpeg(args.instance_data_dir, args.instance_data_dir)

    if args.with_prior_preservation:
        class_pipeline = StableDiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path
        )
        generate_class_images(
            pipeline=class_pipeline,
            class_prompt=args.class_prompt,
            # Honour a parsed --num_class_images when parse_args provides
            # one; fall back to the previous hard-coded value of 100.
            num_class_images=getattr(args, "num_class_images", 100),
            class_images_dir=args.class_data_dir,
        )
        # Drop our reference so the sampling pipeline can be garbage-collected
        # before training allocates its own memory.
        del class_pipeline

    training_function(args, text_encoder, vae, unet, tokenizer)

    # Assemble a pipeline from the components that were handed to training,
    # overriding the pretrained ones, and save it.
    # NOTE(review): assumes training_function updates text_encoder/vae/unet
    # in place (it returns nothing that is used here) — confirm. If it saves
    # or returns the trained model itself, this step may be redundant.
    pipeline = StableDiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
    )
    pipeline.save_pretrained("./output/")
|
|
if __name__ == "__main__":
    # Script entry point: importing this module must not kick off a
    # training run, so main() is invoked only on direct execution.
    main()
|
|
|