import os

# external libraries
import torch
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfApi, HfFolder
from transformers import CLIPTextModel, CLIPTokenizer

# custom imports
from model.src.datasets.dresscode import DressCodeDataset
from model.src.datasets.vitonhd import VitonHDDataset
from model.src.mgd_pipelines.mgd_pipe import MGDPipe
from model.src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
from model.src.utils.arg_parser import eval_parse_args
from model.src.utils.image_from_pipe import generate_images_from_mgd_pipe
from model.src.utils.set_seeds import set_seed

# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.10.0.dev0")

logger = get_logger(__name__, log_level="INFO")

os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["WANDB_START_METHOD"] = "thread"

hf_token = os.getenv("HF_TOKEN")
api = HfApi()
if hf_token:  # avoid passing None to save_token when HF_TOKEN is unset
    HfFolder.save_token(hf_token)


def main(json_from_req: dict):
    args = eval_parse_args()
    accelerator = Accelerator(mixed_precision=args.mixed_precision)
    device = accelerator.device

    # If passed along, set the seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Load scheduler, tokenizer and models.
    val_scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    val_scheduler.set_timesteps(50, device=device)
    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
    )
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    unet = load_mgd_model(dataset=args.dataset, pretrained=True)
    # unet = torch.hub.load(dataset=args.dataset, repo_or_dir='aimagelab/multimodal-garment-designer',
    #                       source='github', model='mgd', pretrained=True)

    # Freeze vae and text_encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Enable memory-efficient attention if requested
    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    if args.category:
        category = [args.category]
    else:
        category = ['dresses', 'upper_body', 'lower_body']

    if args.dataset == "dresscode":
        test_dataset = DressCodeDataset(
            dataroot_path=args.dataset_path,
            phase='test',
            order=args.test_order,
            radius=5,
            sketch_threshold_range=(20, 20),
            tokenizer=tokenizer,
            category=category,
            size=(512, 384),
            json_from_req=json_from_req,
        )
    elif args.dataset == "vitonhd":
        test_dataset = VitonHDDataset(
            dataroot_path=args.dataset_path,
            phase='test',
            order=args.test_order,
            sketch_threshold_range=(20, 20),
            radius=5,
            tokenizer=tokenizer,
            size=(512, 384),
            json_from_req=json_from_req,
        )
    else:
        raise NotImplementedError(f"Unknown dataset: {args.dataset}")

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        shuffle=False,
        batch_size=args.batch_size,
        num_workers=args.num_workers_test,
    )

    # For mixed-precision inference we cast the text_encoder and vae weights to half precision:
    # these models are only used for inference, so keeping their weights in full precision is not required.
    weight_dtype = torch.float32
    if args.mixed_precision == 'fp16':
        weight_dtype = torch.float16

    # Move text_encoder and vae to the device and cast to weight_dtype
    text_encoder.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)

    unet.eval()

    # Select fast classifier-free guidance or disentangled classifier-free guidance
    # according to the disentagle flag in args.
    with torch.inference_mode():
        if args.disentagle:
            val_pipe = MGDPipeDisentangled(
                text_encoder=text_encoder,
                vae=vae,
                unet=unet.to(vae.dtype),
                tokenizer=tokenizer,
                scheduler=val_scheduler,
            ).to(device)
        else:
            val_pipe = MGDPipe(
                text_encoder=text_encoder,
                vae=vae,
                unet=unet.to(vae.dtype),
                tokenizer=tokenizer,
                scheduler=val_scheduler,
            ).to(device)

        val_pipe.enable_attention_slicing()
        test_dataloader = accelerator.prepare(test_dataloader)
        final_image = generate_images_from_mgd_pipe(
            test_order=args.test_order,
            pipe=val_pipe,
            test_dataloader=test_dataloader,
            save_name=args.save_name,
            dataset=args.dataset,
            output_dir=args.output_dir,
            guidance_scale=args.guidance_scale,
            guidance_scale_pose=args.guidance_scale_pose,
            guidance_scale_sketch=args.guidance_scale_sketch,
            sketch_cond_rate=args.sketch_cond_rate,
            start_cond_rate=args.start_cond_rate,
            no_pose=False,
            disentagle=False,
            seed=args.seed,
        )
    return final_image  # return the generated image instead of only writing it to disk


def load_mgd_model(dataset: str, pretrained: bool = True) -> UNet2DConditionModel:
    """Build the MGD UNet.

    pretrained (bool): load pretrained weights into the model.
    """
    config = UNet2DConditionModel.load_config("benjamin-paine/stable-diffusion-v1-5-inpainting", subfolder="unet")
    config['in_channels'] = 28  # MGD conditions the inpainting UNet on extra input channels
    unet = UNet2DConditionModel.from_config(config)
    if pretrained:
        checkpoint = f"https://github.com/aimagelab/multimodal-garment-designer/releases/download/weights/{dataset}.pth"
        unet.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=True))
    return unet


if __name__ == "__main__":
    main(json_from_req={})  # placeholder payload; pass the parsed request JSON in real use
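
# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original script): how a request
# handler might call main() with a parsed request body. The exact keys that
# DressCodeDataset/VitonHDDataset read from json_from_req are defined in those
# dataset classes; the "caption" and "im_name" keys below are assumptions for
# illustration only.
#
#   payload = {"caption": "red sleeveless mini dress", "im_name": "048392_0.jpg"}
#   image = main(json_from_req=payload)
#   image.save("generated.png")
# ---------------------------------------------------------------------------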