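"""Inference script for the Multimodal Garment Designer (MGD) pipeline.

Loads a pretrained Stable Diffusion backbone (VAE, CLIP text encoder and
tokenizer, DDIM scheduler) together with the MGD denoising UNet from
torch.hub, builds the Dress Code or VITON-HD test split, and generates
garment images conditioned on text, pose, and sketch inputs.
"""
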
import os

# External libraries
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer

# Custom imports
from src.datasets.dresscode import DressCodeDataset
from src.datasets.vitonhd import VitonHDDataset
from src.mgd_pipelines.mgd_pipe import MGDPipe
from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
from src.utils.image_from_pipe import generate_images_from_mgd_pipe
from src.utils.set_seeds import set_seed

# Ensure the minimum version of diffusers is installed
check_min_version("0.10.0.dev0")

logger = get_logger(__name__, log_level="INFO")
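
# Allow tokenizer parallelism in dataloader workers and start any wandb
# run in a thread rather than a separate process.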
os.environ["TOKENIZERS_PARALLELISM"] = "true" | |
os.environ["WANDB_START_METHOD"] = "thread" | |


def main(args):
    # Initialize the accelerator for (optional) mixed-precision inference
    accelerator = Accelerator(mixed_precision=args.get("mixed_precision", "fp16"))
    device = accelerator.device

    # Seed all RNGs for reproducible generation
    if args.get("seed") is not None:
        set_seed(args["seed"])

    # Load the scheduler, tokenizer, and frozen Stable Diffusion components
    val_scheduler = DDIMScheduler.from_pretrained(args["pretrained_model_name_or_path"], subfolder="scheduler")
    val_scheduler.set_timesteps(50, device=device)  # 50 DDIM sampling steps
    tokenizer = CLIPTokenizer.from_pretrained(
        args["pretrained_model_name_or_path"], subfolder="tokenizer", revision=args.get("revision")
    )
    text_encoder = CLIPTextModel.from_pretrained(
        args["pretrained_model_name_or_path"], subfolder="text_encoder", revision=args.get("revision")
    )
    vae = AutoencoderKL.from_pretrained(
        args["pretrained_model_name_or_path"], subfolder="vae", revision=args.get("revision")
    )

    # Load the pretrained MGD denoising UNet from the official GitHub
    # repository via torch.hub (downloads weights on first run)
    unet = torch.hub.load(
        repo_or_dir="aimagelab/multimodal-garment-designer",
        source="github",
        model="mgd",
        pretrained=True,
    )

    # Freeze all models; this script performs inference only
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)

    # Enable memory-efficient attention if requested
    if args.get("enable_xformers_memory_efficient_attention", False):
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Install it to enable memory-efficient attention.")

    # Garment category to generate (consumed by the Dress Code dataset)
    category = [args.get("category", "dresses")]

    # Load dataset
    if args["dataset"] == "dresscode":
        test_dataset = DressCodeDataset(
            dataroot_path=args["dataset_path"],
            phase="test",
            order=args.get("test_order", 0),
            radius=5,
            sketch_threshold_range=(20, 20),
            tokenizer=tokenizer,
            category=category,
            size=(512, 384),
        )
    elif args["dataset"] == "vitonhd":
        test_dataset = VitonHDDataset(
            dataroot_path=args["dataset_path"],
            phase="test",
            order=args.get("test_order", 0),
            sketch_threshold_range=(20, 20),
            radius=5,
            tokenizer=tokenizer,
            size=(512, 384),
        )
    else:
        raise NotImplementedError(f"Dataset {args['dataset']} is not supported.")
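
    # Note: in the upstream MGD evaluation, the test order distinguishes the
    # paired from the unpaired setting; the integer default used above assumes
    # the custom dataset classes in src.datasets accept it.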

    # Prepare dataloader (no shuffling: keep a deterministic test order)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        shuffle=False,
        batch_size=args.get("batch_size", 1),
        num_workers=args.get("num_workers_test", 4),
    )

    # Cast the frozen components to the chosen inference precision
    weight_dtype = torch.float16 if args.get("mixed_precision") == "fp16" else torch.float32
    text_encoder.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)
    unet.eval()

    # Select the pipeline and run generation without autograd overhead
    with torch.inference_mode():
        # "disentagle" spelling follows the upstream MGD code base
        pipeline_class = MGDPipeDisentangled if args.get("disentagle", False) else MGDPipe
        val_pipe = pipeline_class(
            text_encoder=text_encoder,
            vae=vae,
            unet=unet.to(vae.dtype),
            tokenizer=tokenizer,
            scheduler=val_scheduler,
        ).to(device)
        val_pipe.enable_attention_slicing()  # reduce peak memory during attention

        # Prepare dataloader with accelerator (handles device placement)
        test_dataloader = accelerator.prepare(test_dataloader)
        # Generate images; the expected output path mirrors the save_name
        # passed to the generator below
        os.makedirs(args["output_dir"], exist_ok=True)
        output_path = os.path.join(args["output_dir"], args.get("save_name", "generated_image"))
        generate_images_from_mgd_pipe(
            test_order=args.get("test_order", 0),
            pipe=val_pipe,
            test_dataloader=test_dataloader,
            save_name=args.get("save_name", "generated_image"),
            dataset=args["dataset"],
            output_dir=args["output_dir"],
            guidance_scale=args.get("guidance_scale", 7.5),
            guidance_scale_pose=args.get("guidance_scale_pose", 0.5),
            guidance_scale_sketch=args.get("guidance_scale_sketch", 7.5),
            sketch_cond_rate=args.get("sketch_cond_rate", 1.0),
            start_cond_rate=args.get("start_cond_rate", 0.0),
            no_pose=False,
            disentagle=args.get("disentagle", False),
            seed=args.get("seed"),
        )

    # Return the output image path for verification
    return output_path


if __name__ == "__main__":
    # Example usage for debugging; the paths below are placeholders and must
    # point at a local Stable Diffusion checkpoint and dataset copy.
    example_args = {
        "pretrained_model_name_or_path": "./models",
        "dataset": "dresscode",
        "dataset_path": "./datasets/dresscode",
        "output_dir": "./outputs",
        "guidance_scale": 7.5,
        "guidance_scale_sketch": 7.5,
        "mixed_precision": "fp16",
        "batch_size": 1,
        "seed": 42,
    }
    output_image = main(example_args)
    print(f"Image generated at: {output_image}")