import os

# External libraries
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer

# Custom imports
from src.datasets.dresscode import DressCodeDataset
from src.datasets.vitonhd import VitonHDDataset
from src.mgd_pipelines.mgd_pipe import MGDPipe
from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
from src.utils.image_from_pipe import generate_images_from_mgd_pipe
from src.utils.set_seeds import set_seed

# Ensure the minimum version of diffusers is installed
check_min_version("0.10.0.dev0")

logger = get_logger(__name__, log_level="INFO")

os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["WANDB_START_METHOD"] = "thread"


def main(args):
    # Initialize Accelerator
    accelerator = Accelerator(mixed_precision=args.get("mixed_precision", "fp16"))
    device = accelerator.device

    # Set the seed for reproducible generation
    if args.get("seed") is not None:
        set_seed(args["seed"])

    # Load scheduler, tokenizer, and models
    val_scheduler = DDIMScheduler.from_pretrained(
        args["pretrained_model_name_or_path"], subfolder="scheduler"
    )
    val_scheduler.set_timesteps(50, device=device)

    tokenizer = CLIPTokenizer.from_pretrained(
        args["pretrained_model_name_or_path"],
        subfolder="tokenizer",
        revision=args.get("revision", None),
    )
    text_encoder = CLIPTextModel.from_pretrained(
        args["pretrained_model_name_or_path"],
        subfolder="text_encoder",
        revision=args.get("revision", None),
    )
    vae = AutoencoderKL.from_pretrained(
        args["pretrained_model_name_or_path"],
        subfolder="vae",
        revision=args.get("revision", None),
    )

    # Load UNet
    unet = torch.hub.load(
        repo_or_dir="aimagelab/multimodal-garment-designer",
        source="github",
        model="mgd",
        pretrained=True,
    )

    # Freeze models
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Enable memory-efficient attention if requested
    if args.get("enable_xformers_memory_efficient_attention", False):
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError(
                "xformers is not available. Install it to enable memory-efficient attention."
            )

    # Set dataset category
    category = [args.get("category", "dresses")]

    # Load dataset
    if args["dataset"] == "dresscode":
        test_dataset = DressCodeDataset(
            dataroot_path=args["dataset_path"],
            phase="test",
            order=args.get("test_order", 0),
            radius=5,
            sketch_threshold_range=(20, 20),
            tokenizer=tokenizer,
            category=category,
            size=(512, 384),
        )
    elif args["dataset"] == "vitonhd":
        test_dataset = VitonHDDataset(
            dataroot_path=args["dataset_path"],
            phase="test",
            order=args.get("test_order", 0),
            sketch_threshold_range=(20, 20),
            radius=5,
            tokenizer=tokenizer,
            size=(512, 384),
        )
    else:
        raise NotImplementedError(f"Dataset {args['dataset']} is not supported.")

    # Prepare dataloader
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        shuffle=False,
        batch_size=args.get("batch_size", 1),
        num_workers=args.get("num_workers_test", 4),
    )

    # Cast models to the appropriate precision
    weight_dtype = torch.float32 if args.get("mixed_precision") != "fp16" else torch.float16
    text_encoder.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)
    unet.eval()

    # Select pipeline and run generation under inference mode
    with torch.inference_mode():
        pipeline_class = MGDPipeDisentangled if args.get("disentagle", False) else MGDPipe
        val_pipe = pipeline_class(
            text_encoder=text_encoder,
            vae=vae,
            unet=unet.to(vae.dtype),
            tokenizer=tokenizer,
            scheduler=val_scheduler,
        ).to(device)
        val_pipe.enable_attention_slicing()

        # Prepare dataloader with accelerator
        test_dataloader = accelerator.prepare(test_dataloader)

        # Generate images
        output_path = os.path.join(args["output_dir"], args.get("save_name", "generated_image.png"))
        generate_images_from_mgd_pipe(
            test_order=args.get("test_order", 0),
            pipe=val_pipe,
            test_dataloader=test_dataloader,
            save_name=args.get("save_name", "generated_image"),
            dataset=args["dataset"],
            output_dir=args["output_dir"],
            guidance_scale=args.get("guidance_scale", 7.5),
            guidance_scale_pose=args.get("guidance_scale_pose", 0.5),
            guidance_scale_sketch=args.get("guidance_scale_sketch", 7.5),
            sketch_cond_rate=args.get("sketch_cond_rate", 1.0),
            start_cond_rate=args.get("start_cond_rate", 0.0),
            no_pose=False,
            disentagle=args.get("disentagle", False),
            seed=args.get("seed", None),
        )

    # Return the output image path for verification
    return output_path


if __name__ == "__main__":
    # Example usage for debugging
    example_args = {
        "pretrained_model_name_or_path": "./models",
        "dataset": "dresscode",
        "dataset_path": "./datasets/dresscode",
        "output_dir": "./outputs",
        "guidance_scale": 7.5,
        "guidance_scale_sketch": 7.5,
        "mixed_precision": "fp16",
        "batch_size": 1,
        "seed": 42,
    }
    output_image = main(example_args)
    print(f"Image generated at: {output_image}")
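
    # Sketch of an alternative configuration: driving the same entry point with the
    # VITON-HD dataset instead of Dress Code. The paths below are placeholders and
    # must point at a local VITON-HD checkout; all keys are the same ones main()
    # already reads via args / args.get(). Uncomment to try it.
    #
    # example_args_vitonhd = {
    #     "pretrained_model_name_or_path": "./models",
    #     "dataset": "vitonhd",
    #     "dataset_path": "./datasets/vitonhd",
    #     "output_dir": "./outputs",
    #     "guidance_scale": 7.5,
    #     "guidance_scale_sketch": 7.5,
    #     "mixed_precision": "fp16",
    #     "batch_size": 1,
    #     "seed": 42,
    # }
    # print(f"Image generated at: {main(example_args_vitonhd)}")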