# fyp-deploy / src/eval.py
# Uploaded by Mairaaa — commit d308227 ("Update src/eval.py"), 5.8 kB.
import os
# External libraries
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer
# Custom imports
from src.datasets.dresscode import DressCodeDataset
from src.datasets.vitonhd import VitonHDDataset
from src.mgd_pipelines.mgd_pipe import MGDPipe
from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
from src.utils.image_from_pipe import generate_images_from_mgd_pipe
from src.utils.set_seeds import set_seed
# Abort at import time when the installed diffusers is older than 0.10.0.dev0.
check_min_version("0.10.0.dev0")

logger = get_logger(__name__, log_level="INFO")

# Process-wide environment switches, presumably consumed by Hugging Face
# tokenizers and Weights & Biases respectively (set in this order).
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "true",
        "WANDB_START_METHOD": "thread",
    }
)
def _build_test_dataset(args, tokenizer):
    """Instantiate the test split of the dataset named by ``args["dataset"]``.

    Args:
        args: Configuration dict; reads "dataset", "dataset_path", and the
            optional "test_order" (default 0) and "category" (default
            "dresses", DressCode only; a single value or a list/tuple).
        tokenizer: Tokenizer forwarded to the dataset class.

    Returns:
        A ``DressCodeDataset`` or ``VitonHDDataset`` configured for testing.

    Raises:
        NotImplementedError: if ``args["dataset"]`` is neither "dresscode"
            nor "vitonhd".
    """
    # Wrap a single category in a list, but pass an explicit list/tuple of
    # categories through instead of nesting it inside another list.
    category = args.get("category", "dresses")
    if not isinstance(category, (list, tuple)):
        category = [category]
    else:
        category = list(category)

    test_order = args.get("test_order", 0)

    if args["dataset"] == "dresscode":
        return DressCodeDataset(
            dataroot_path=args["dataset_path"],
            phase="test",
            order=test_order,
            radius=5,
            sketch_threshold_range=(20, 20),
            tokenizer=tokenizer,
            category=category,
            size=(512, 384),
        )
    if args["dataset"] == "vitonhd":
        return VitonHDDataset(
            dataroot_path=args["dataset_path"],
            phase="test",
            order=test_order,
            sketch_threshold_range=(20, 20),
            radius=5,
            tokenizer=tokenizer,
            size=(512, 384),
        )
    raise NotImplementedError(f"Dataset {args['dataset']} is not supported.")


def main(args):
    """Run MGD inference over the test split of DressCode or VITON-HD.

    Args:
        args: Configuration dict.
            Required keys: "pretrained_model_name_or_path", "dataset"
            ("dresscode" or "vitonhd"), "dataset_path", "output_dir".
            Optional keys, with defaults:
            - "mixed_precision" ("fp16")
            - "seed" (None)
            - "revision" (None)
            - "enable_xformers_memory_efficient_attention" (False)
            - "category" ("dresses"; str or list, DressCode only)
            - "test_order" (0)
            - "batch_size" (1)
            - "num_workers_test" (4)
            - "disentagle" (False; the key keeps its historical spelling)
            - "save_name"
            - "guidance_scale" (7.5)
            - "guidance_scale_pose" (0.5)
            - "guidance_scale_sketch" (7.5)
            - "sketch_cond_rate" (1.0)
            - "start_cond_rate" (0.0)

    Returns:
        ``os.path.join(args["output_dir"], save_name)``, with save_name
        defaulting to "generated_image.png".

    Raises:
        NotImplementedError: if the dataset is not supported (checked before
            any model is downloaded).
        ValueError: if xformers attention is requested but not installed.
    """
    dataset_name = args["dataset"]
    # Fail fast, before the expensive model downloads below.
    if dataset_name not in ("dresscode", "vitonhd"):
        raise NotImplementedError(f"Dataset {args['dataset']} is not supported.")

    # Resolve precision once so the Accelerator and the weight dtype agree.
    # Previously, a missing key meant fp16 in Accelerator but fp32 weights.
    mixed_precision = args.get("mixed_precision", "fp16")
    accelerator = Accelerator(mixed_precision=mixed_precision)
    device = accelerator.device
    weight_dtype = torch.float16 if mixed_precision == "fp16" else torch.float32

    if args.get("seed") is not None:
        set_seed(args["seed"])

    # Scheduler, tokenizer, text encoder and VAE from the base checkpoint.
    model_path = args["pretrained_model_name_or_path"]
    revision = args.get("revision", None)
    val_scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
    val_scheduler.set_timesteps(50, device=device)
    tokenizer = CLIPTokenizer.from_pretrained(
        model_path, subfolder="tokenizer", revision=revision
    )
    text_encoder = CLIPTextModel.from_pretrained(
        model_path, subfolder="text_encoder", revision=revision
    )
    vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", revision=revision)

    # Per the upstream MGD README, the `mgd` hub entrypoint takes `dataset`
    # to pick the dresscode / vitonhd UNet checkpoint; torch.hub forwards
    # extra kwargs to the entrypoint.
    # NOTE(review): torch.hub downloads and executes repository code at runtime.
    unet = torch.hub.load(
        repo_or_dir="aimagelab/multimodal-garment-designer",
        source="github",
        model="mgd",
        pretrained=True,
        dataset=dataset_name,
    )

    # Inference only: freeze every model.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)

    if args.get("enable_xformers_memory_efficient_attention", False):
        if not is_xformers_available():
            raise ValueError("xformers is not available. Install it to enable memory-efficient attention.")
        unet.enable_xformers_memory_efficient_attention()

    test_dataset = _build_test_dataset(args, tokenizer)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        shuffle=False,
        batch_size=args.get("batch_size", 1),
        num_workers=args.get("num_workers_test", 4),
    )

    text_encoder.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)
    unet.eval()

    disentangle = args.get("disentagle", False)
    pipeline_class = MGDPipeDisentangled if disentangle else MGDPipe

    # NOTE(review): the returned path defaults save_name to
    # "generated_image.png", while the generator below receives
    # "generated_image". Confirm in src/utils/image_from_pipe.py where
    # images are actually written.
    output_path = os.path.join(args["output_dir"], args.get("save_name", "generated_image.png"))

    # Keep pipeline construction and generation inside inference mode, so
    # no autograd state is recorded during sampling.
    with torch.inference_mode():
        val_pipe = pipeline_class(
            text_encoder=text_encoder,
            vae=vae,
            unet=unet.to(vae.dtype),
            tokenizer=tokenizer,
            scheduler=val_scheduler,
        ).to(device)
        val_pipe.enable_attention_slicing()

        test_dataloader = accelerator.prepare(test_dataloader)

        generate_images_from_mgd_pipe(
            test_order=args.get("test_order", 0),
            pipe=val_pipe,
            test_dataloader=test_dataloader,
            save_name=args.get("save_name", "generated_image"),
            dataset=dataset_name,
            output_dir=args["output_dir"],
            guidance_scale=args.get("guidance_scale", 7.5),
            guidance_scale_pose=args.get("guidance_scale_pose", 0.5),
            guidance_scale_sketch=args.get("guidance_scale_sketch", 7.5),
            sketch_cond_rate=args.get("sketch_cond_rate", 1.0),
            start_cond_rate=args.get("start_cond_rate", 0.0),
            no_pose=False,
            disentagle=disentangle,
            seed=args.get("seed", None),
        )

    return output_path
if __name__ == "__main__":
    # Minimal DressCode configuration for local debugging. Any option not
    # listed here falls back to the default main() reads with args.get().
    debug_args = dict(
        pretrained_model_name_or_path="./models",
        dataset="dresscode",
        dataset_path="./datasets/dresscode",
        output_dir="./outputs",
        guidance_scale=7.5,
        guidance_scale_sketch=7.5,
        mixed_precision="fp16",
        batch_size=1,
        seed=42,
    )
    generated_path = main(debug_args)
    print(f"Image generated at: {generated_path}")