# Uploaded by kothariyashhh ("Upload 72 files", commit c09bcc2, verified).
from convert_images import convert_heic_to_jpeg_and_remove, convert_png_to_jpeg
from generate_class_images import generate_class_images
from model_loader import load_models
from train import parse_args, training_function
from inference import generate_images
def main():
    """Run the fine-tuning script end to end.

    Steps:
      1. Parse command-line arguments with ``parse_args``.
      2. Load the text encoder, VAE, UNet and tokenizer from
         ``args.pretrained_model_name_or_path`` with ``load_models``.
      3. Convert the images in ``args.instance_data_dir`` (HEIC and PNG
         inputs, judging by the helper names) to JPEG.
      4. If ``args.with_prior_preservation`` is set, sample class images for
         ``args.class_prompt`` into ``args.class_data_dir``.
      5. Train with ``training_function``.
      6. Assemble a ``StableDiffusionPipeline`` from the trained components
         and save it to ``./output/``.
    """
    # Local import: diffusers is needed both for class-image sampling and for
    # assembling the pipeline that is saved at the end.
    from diffusers import StableDiffusionPipeline

    args = parse_args()
    text_encoder, vae, unet, tokenizer = load_models(args.pretrained_model_name_or_path)

    # Normalise the instance images to JPEG before training reads them.
    convert_heic_to_jpeg_and_remove(args.instance_data_dir)
    convert_png_to_jpeg(args.instance_data_dir, args.instance_data_dir)

    if args.with_prior_preservation:
        # A separately loaded pipeline is used for sampling, so this step does
        # not share module objects with the ones that are about to be trained.
        class_pipeline = StableDiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path
        )
        generate_class_images(
            pipeline=class_pipeline,
            class_prompt=args.class_prompt,
            # Use the CLI value when parse_args defines one; fall back to 100.
            num_class_images=getattr(args, "num_class_images", 100),
            class_images_dir=args.class_data_dir,
        )
        # Release the sampling pipeline before training starts.
        del class_pipeline

    training_function(args, text_encoder, vae, unet, tokenizer)

    # Build the pipeline to save from the components handed to training.
    # NOTE(review): assumes training_function updates these module objects in
    # place (rather than training private copies) -- confirm against train.py.
    pipeline = StableDiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
    )
    pipeline.save_pretrained("./output/")


if __name__ == "__main__":
    main()